
Background

I'm using dask to manage tens of thousands, sometimes hundreds of thousands, of jobs, each of which involves reading in zarr data, transforming it in some way, and writing out a result (one output per job). I'm running on a pangeo/daskhub-style jupyterhub kubernetes cluster with dask gateway.

The data isn't alignable into one giant dask-backed xarray DataArray; instead, our common usage pattern is simply to use dask.distributed's Client.map to map functions onto tasks. Each individual task fits in memory.

The problem is that, for some operations, including reading a zarr array, xarray automatically schedules each I/O operation on the cluster scheduler, even when the operation is called from a remote worker. This multiplies the number of tasks the scheduler has to manage, sometimes by a large factor. If the zarr arrays have many chunks, the dask scheduler sometimes tries to distribute the reads, producing a huge amount of network traffic within the cluster; progress can grind to a halt when many tasks queue up to have their chunks read at the same time.
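
To make this concrete (using the test files from the MRE below): with the default chunks argument, open_zarr returns dask-backed variables, so loading them builds a graph with roughly one task per zarr chunk, and those tasks go to whatever scheduler the calling process resolves (from a worker, that appears to be the cluster scheduler).

import xarray as xr

ds = xr.open_zarr('infiles/data_0.zarr')  # lazy, dask-backed by default
arr = ds['var1'].data                     # a dask.array.Array
print(arr.numblocks)                      # (1000, 1) -> ~1000 read tasks
print(len(arr.__dask_graph__()))          # size of the graph that .load() will submit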

It's entirely possible that the right answer here is "don't do this", e.g. for a job this big and complex, use something like Argo or Kubeflow. But I wanted to see if anyone had ideas for how to make this work well with dask.

Question

My question, essentially, is whether it's possible to prevent xarray (or another library with native dask support) from using the cluster's scheduler when run within a dask.distributed task.

The ideal, as I see it, would be something like:

import dask.distributed as dd
import xarray as xr


def mappable_task(args):
    input_fp, output_fp = args

    # goal would be for all of the code within this block to operate as if the
    # cluster scheduler did not exist, to ensure data is read locally and the
    # additional tasks created by reading the zarr array do not bog down the 
    # cluster scheduler.
    with dd.disable_scheduler():
        
        # in our workflow we're reading/writing to google cloud storage using
        # gcsfs.GCSFileSystem.get_mapper
        # https://gcsfs.readthedocs.io/en/latest/api.html
        # Also, we sometimes will only be reading a small portion of the
        # data, or are combining multiple datasets. Just noting that this
        # may involve many scheduled operations per `mappable_task` call
        ds = xr.open_zarr(input_fp).load()
        
        # some long-running operation on the data, which depending on our
        # use case has run from nonlinear transforms to geospatial ops to
        # calling hydrodynamics models
        res = ds * 2

        res.to_zarr(output_fp)

def main():
    JOB_SIZE = 10_000
    jobs = [(f'/infiles/{i}.zarr', f'/outfiles/{i}.zarr') for i in range(JOB_SIZE)]
    client = dd.Client()
    futures = client.map(mappable_task, jobs)
    dd.wait(futures)

I'm not sure if this would involve changing the behavior of xarray, zarr, or dd.get_client(), or something else.
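
For reference, the closest I've gotten with existing APIs is something like the following. I'm not sure either option is the intended approach, so this is just a sketch:

import dask
import xarray as xr

def mappable_task(args):
    input_fp, output_fp = args

    # Option 1: skip dask entirely for the read. With chunks=None, open_zarr
    # returns plain (non-dask) arrays, so no task graph is built and nothing
    # is sent to any scheduler.
    ds = xr.open_zarr(input_fp, chunks=None).load()

    # Option 2 (if the data needs to stay chunked/lazy): force any graphs
    # created in this block onto the local single-threaded scheduler. I
    # *think* the scheduler named in dask's config takes precedence over the
    # worker's client, but I haven't confirmed that's guaranteed behavior.
    # with dask.config.set(scheduler="synchronous"):
    #     ds = xr.open_zarr(input_fp).load()

    res = ds * 2
    res.to_zarr(output_fp)  # res is already in memory, so this is a plain local write

Option 1 is effectively the chunks=None suggestion from the comments below, which is what ended up helping in practice.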

MRE

The above could be tweaked into a testable example. The goal would be to see no tasks appear other than the mapped function itself. I ran the following in a jupyterlab ipython notebook and watched tasks with dask-labextension (the scheduler dashboard shows the same result).

Imports

import xarray as xr
import dask.distributed as dd
import numpy as np
import os
import shutil

Test file setup

shutil.rmtree('infiles', ignore_errors=True)
shutil.rmtree('outfiles', ignore_errors=True)

os.makedirs('infiles', exist_ok=True)
os.makedirs('outfiles', exist_ok=True)

# create two Zarr stores, each with 1000 chunks. This isn't an uncommon
# structure, though each chunk would normally have far more data
for i in range(2):
    ds = xr.Dataset(
        {'var1': (('dim1', 'dim2', ), np.random.random(size=(1000, 100)))},
        coords={'dim1': np.arange(1000), 'dim2': np.arange(100)},
    ).chunk({'dim1': 1})
    
    ds.to_zarr(f'infiles/data_{i}.zarr')

Function definition

def mappable_task(args):
    input_fp, output_fp = args
        
    # in our workflow we're reading/writing to google cloud storage using
    # gcsfs.GCSFileSystem.get_mapper
    # https://gcsfs.readthedocs.io/en/latest/api.html
    ds = xr.open_zarr(input_fp).load()

    # some long-running operation on the data, which depending on our
    # use case has run from nonlinear transforms to geospatial ops to
    # calling hydrodynamics models
    res = ds * 2

    res.to_zarr(output_fp)

Create a client and watch the dashboard

client = dd.Client()
client

Map the job

JOB_SIZE = 2
jobs = [(f'infiles/data_{i}.zarr', f'outfiles/out_{i}.zarr') for i in range(JOB_SIZE)]

futures = client.map(mappable_task, jobs)
dd.wait(futures);

Cleanup (if running again)

shutil.rmtree('outfiles', ignore_errors=True)
os.makedirs('outfiles', exist_ok=True)

# refresh the client (in case of running multiple times)
client.restart()

Full cleanup

shutil.rmtree('infiles', ignore_errors=True)
shutil.rmtree('outfiles', ignore_errors=True)
client.close();

Note that thousands of tasks are scheduled despite there being only two jobs.
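
To quantify this without watching the dashboard, wrapping the client.map call above in get_task_stream records one entry per executed task (a rough diagnostic only; it shows the same thing the dashboard does):

with dd.get_task_stream(client) as ts:
    futures = client.map(mappable_task, jobs)
    dd.wait(futures)

print(len(ts.data))  # thousands of task records for just two mapped jobs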

I'm using a conda environment with (among many other packages) the following:

dask                      2.30.0                     py_0    conda-forge
dask-gateway              0.9.0            py38h578d9bd_0    conda-forge
dask-labextension         3.0.0                      py_0    conda-forge
jupyter-server-proxy      1.5.0                      py_0    conda-forge
jupyter_client            6.1.7                      py_0    conda-forge
jupyter_core              4.7.0            py38h578d9bd_0    conda-forge
jupyter_server            1.1.3            py38h578d9bd_0    conda-forge
jupyter_telemetry         0.1.0              pyhd8ed1ab_1    conda-forge
jupyterhub                1.2.2            py38h578d9bd_0    conda-forge
jupyterhub-base           1.2.2            py38h578d9bd_0    conda-forge
jupyterlab                2.2.9                      py_0    conda-forge
jupyterlab_server         1.2.0                      py_0    conda-forge
jupyterlab_widgets        1.0.0              pyhd8ed1ab_1    conda-forge
kubernetes                1.18.8                        0    conda-forge
kubernetes-client         1.18.8               haa36a5b_0    conda-forge
kubernetes-node           1.18.8               haa36a5b_0    conda-forge
kubernetes-server         1.18.8               haa36a5b_0    conda-forge
nb_conda_kernels          2.3.1            py38h578d9bd_0    conda-forge
nbclient                  0.5.1                      py_0    conda-forge
nodejs                    15.2.1               h914e61d_0    conda-forge
notebook                  6.1.6            py38h578d9bd_0    conda-forge
numpy                     1.19.4           py38hf0fd68c_1    conda-forge
pandas                    1.1.5            py38h51da96c_0    conda-forge
python                    3.8.6           h852b56e_0_cpython    conda-forge
python-kubernetes         11.0.0           py38h32f6830_0    conda-forge
xarray                    0.16.2             pyhd8ed1ab_0    conda-forge
zarr                      2.6.1              pyhd8ed1ab_0    conda-forge
Comments

  • This is a great question! If you run with chunks=None, does that reduce the task count sufficiently? – Maximilian Apr 04 '21 at 04:52
  • It might be worth also posting this as a GitHub discussions question if you don't get a good answer here. – Maximilian Apr 04 '21 at 04:53
  • :facepalm: always read the docs many times. Thanks @Maximilian! I would be interested to know if there's a general answer to this question for dask-backed xarray ops, but this is really helpful for zarr reads specifically! – Michael Delgado Apr 04 '21 at 05:01
  • wow - I'm finding chunks=None improves performance significantly, even for loading subsets of the data, when reading in zarr arrays to a single worker. Really appreciate the tip! – Michael Delgado Apr 04 '21 at 18:25

0 Answers