
I'm using Dask to process a massive dataset and eventually build a model for a classification task, and I'm running into problems. I hope I can get some help.

Main Task

I'm working with clinical notes. Each clinical note has a note type associated with it. However, over 60% of the notes are of type *Missing*. I'm trying to train a classifier on the notes that are labeled and run inference on the notes that have the missing type.
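
For context, a minimal sketch of the split I have in mind once the notes are in a Dask DataFrame (the NoteType column and the Missing label are as described above; the variable names are just illustrative):

labeled_df = notes_df[notes_df['NoteType'] != 'Missing']    # training data
unlabeled_df = notes_df[notes_df['NoteType'] == 'Missing']  # inference targets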

Data

I'm working with 3 years' worth of clinical notes. The total data size is ~1.3TB. These were pulled from a database using PySpark (I have no control over this process) and are organized as year/month/partitions.parquet, with raw_data as the root directory. The number of partitions within each month varies (e.g., one month has 2620 partitions), and the total number of partitions is over 50,000.
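
For reference, a minimal sketch of how the full layout could be pointed at lazily (the glob pattern is my assumption about how the year/month directories would be matched; raw_data is the root mentioned above):

import dask.dataframe as dd

# Lazily reference every partition under raw_data/<year>/<month>/;
# nothing is loaded until a compute/persist/to_parquet call.
notes_df = dd.read_parquet('raw_data/*/*/*.parquet')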

Machine

Cores: 64
Memory: 1TB

The machine is shared with others, so I can't use all of its hardware resources at any given time.
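
As a rough sketch, this is how I'd cap the cluster so it doesn't take over the whole machine (the numbers are placeholders, not the values I actually use):

from dask.distributed import Client, LocalCluster

# Limit workers, threads per worker, and per-worker memory on the shared machine
cluster = LocalCluster(n_workers=16, threads_per_worker=2, memory_limit='16GB')
client = Client(cluster)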

Code

As a first step towards building the model, I want to preprocess the data and do some EDA. I'm using the TextDescriptives package, which uses spaCy, to get some basic information about the text.

import re
import numpy as np
import pandas as pd
import spacy
import textdescriptives as td
import dask.dataframe as dd
import pyarrow as pa
from dask.distributed import Client

def replace_empty(text, replace=np.nan):
  """
  Replace empty notes with NaN's which can be removed later
  """
  if pd.isnull(text):
    return text
  elif text.isspace() or text == '':
    return replace
  return text

def fix_ws(text):
  """
  Replace multiple carriage returns with a single newline
  and multiple new lines with a single new line
  """
  text = re.sub('\r', '\n', text)
  text = re.sub('\n+', '\n', text)
  return text

def replace_empty_part(df, **kwargs):
  return df.apply(replace_empty)

def fix_ws_part(df, **kwargs):
  return df.apply(fix_ws)

def fix_missing_part(df, **kwargs):
  return df.apply(lambda t: 'Missing' if t == 'Unknown at this time' else t)

def extract_td_metrics(text, spacy_model):
  # cols is a predefined list of the textdescriptives metric columns to keep
  try:
    doc = spacy_model(text)
    metrics_df = td.extract_df(doc)[cols]
    return metrics_df.squeeze()
  except Exception:
    return pd.Series([np.nan for _ in range(len(cols))], index=cols)
  
def extract_metrics_part(df, **kwargs):
  spacy_model = spacy.load('en_core_web_sm', disable=['tok2vec', 'parser', 'ner', 'attribute_ruler', 'lemmatizer'])
  spacy_model.add_pipe('textdescriptives')
  return df.apply(extract_td_metrics, spacy_model=spacy_model)

client = Client(n_workers=32)
notes_df = dd.read_parquet(single_month)  # single_month: path to one month of partitions

notes_df['Text'] = notes_df['Text'].map_partitions(replace_empty_part, meta='string')
notes_df = notes_df.dropna()
notes_df['Text'] = notes_df['Text'].map_partitions(fix_ws_part, meta='string')
notes_df['NoteType'] = notes_df['NoteType'].map_partitions(fix_missing_part, meta='string')
metrics_df = notes_df['Text'].map_partitions(extract_metrics_part)
notes_df = dd.concat([notes_df, metrics_df], axis=1)
notes_df = notes_df.dropna()
notes_df = notes_df.repartition(npartitions=4)
notes_df.to_parquet(processed_notes, schema={'NoteType': pa.string(), 'Text': pa.string()}, write_index=False)

All of this code was tested on a small sample with Pandas to make sure it works, and with Dask (on the same sample) to make sure the results matched. When I run it on just a single month of data, the process runs for a few seconds and then hangs, outputting a stream of warnings of this type:

timestamp - distributed.utils_perf - WARNING - full garbage collections took 35% CPU time recently (threshold: 10%) 

The machine is in a secure enclave with no copy/paste facility, so I'm typing everything out here. After some research I came across two threads, here and here. Neither contained a direct solution, but suggestions included disabling Python garbage collection with gc.disable() and starting a clean environment with Dask freshly installed. Neither of these helped. I'm wondering if I can modify my code so that this problem doesn't happen; there is no way to load all of this data into memory and use Pandas directly.
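
For completeness, this is a minimal sketch of how I understood the garbage-collection suggestion, applied to every worker rather than just the client process (my interpretation, not something spelled out in those threads):

import gc
from dask.distributed import Client

client = Client(n_workers=32)
# Run gc.disable() in every worker process; returns a dict of worker address -> result
client.run(gc.disable)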

Thanks.

shaun
  • Hi @shaun, is this somewhat related to https://dask.discourse.group/t/client-not-starting-and-hangs-after-setting-temporary-directory-in-dask-config/1497? Anyway, do you have access to the dashboard to see what is happening on the Dask cluster side? Also, `notes_df = notes_df.repartition(npartitions=4)` seems weird to me: why do you always want 4 partitions? Does that fit in memory? – Guillaume EB Feb 03 '23 at 07:14
  • Thank you for the reply. The problem is basically with extract_df which in turn uses SpaCy. I need to figure out how to fix that. – shaun Feb 06 '23 at 21:10
  • Does the code work with Pandas on a single partition? – Guillaume EB Feb 08 '23 at 15:43

0 Answers