To run a GPU-capable Dask cluster on Kubernetes you need the following:
- Kubernetes nodes need GPUs with the NVIDIA drivers installed, plus the NVIDIA k8s device plugin so that the GPUs are exposed to Kubernetes as schedulable resources.
- Scheduler and worker pods need a Docker image with the NVIDIA tools installed. As you suggest, the RAPIDS images are good for this.
- The pod container spec needs to request GPU resources, such as `resources.limits.nvidia.com/gpu: 1`.
- The Dask workers need to be started with the `dask-cuda-worker` command from the `dask_cuda` package (which is included in the RAPIDS images).
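To make the last two requirements concrete, here is a rough sketch that builds the container part of a GPU worker pod spec as a plain Python dict. The helper `gpu_worker_container`, the container name and the scheduler address are hypothetical; only the image, the `nvidia.com/gpu` limit and the `dask-cuda-worker` command come from the list above. Dask Gateway generates the real pod spec itself from its Helm config, so this is purely illustrative.

```python
def gpu_worker_container(
    scheduler_address: str,
    gpus: int = 1,
    image: str = "rapidsai/rapidsai:cuda11.0-runtime-ubuntu18.04-py3.8",
) -> dict:
    """Return a Kubernetes container spec (as a dict) for a GPU Dask worker.

    Hypothetical helper for illustration only.
    """
    if gpus < 1:
        raise ValueError("a GPU worker needs at least one GPU")
    return {
        "name": "dask-worker",  # hypothetical container name
        "image": image,  # image with the NVIDIA tools and dask_cuda installed
        # Using "args" (not "command") keeps the image's own ENTRYPOINT in place;
        # dask-cuda-worker takes the scheduler address as a positional argument.
        "args": ["dask-cuda-worker", scheduler_address],
        "resources": {
            # GPUs are requested as limits; the NVIDIA device plugin schedules them
            "limits": {"nvidia.com/gpu": gpus},
        },
    }


spec = gpu_worker_container("tcp://dask-scheduler:8786", gpus=2)
print(spec["resources"]["limits"])  # {'nvidia.com/gpu': 2}
```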
Note: For Dask Gateway your container image also needs the `dask-gateway` package to be installed. We can configure this to be installed at runtime, but it's probably best to create a custom image with this package already installed.
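If you do rely on a runtime install, it can be worth confirming that the package actually ended up in every worker environment. A minimal sketch, assuming you already have a `dask.distributed` `Client` connected to the cluster; the helper `has_module` is hypothetical and only uses the standard library's `importlib.util.find_spec`, while `Client.run` executes it on every worker.

```python
import importlib.util


def has_module(name: str) -> bool:
    """Return True if the top-level module `name` is importable here.

    Hypothetical helper, intended to be shipped to each worker via Client.run.
    """
    return importlib.util.find_spec(name) is not None


# Locally this just checks the current environment.
print(has_module("json"))  # True

# On the cluster (assumes `client` is a connected dask.distributed.Client):
# client.run(has_module, "dask_gateway")
# returns {worker_address: True or False, ...}
```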
Here is a minimal Dask Gateway config that will give you a GPU cluster.
```yaml
# config.yaml
gateway:
  backend:
    image:
      name: rapidsai/rapidsai
      tag: cuda11.0-runtime-ubuntu18.04-py3.8  # Be sure to match your k8s CUDA version and user's Python version

    worker:
      extraContainerConfig:
        env:
          - name: EXTRA_PIP_PACKAGES
            value: "dask-gateway"
        resources:
          limits:
            nvidia.com/gpu: 1  # This could be >1, you will get one worker process in the pod per GPU

    scheduler:
      extraContainerConfig:
        env:
          - name: EXTRA_PIP_PACKAGES
            value: "dask-gateway"
        resources:
          limits:
            nvidia.com/gpu: 1  # The scheduler requires a GPU in case of accidental deserialisation

  extraConfig:
    cudaworker: |
      c.ClusterConfig.worker_cmd = "dask-cuda-worker"
```
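As the comment on the worker GPU limit says, each pod gets one worker process per GPU, so the number of workers the client sees is pods × GPUs per pod. A small sketch of that arithmetic; the helper name is hypothetical, and the commented usage assumes `cluster.scale(n)` starts `n` worker pods and uses `Client.wait_for_workers` from `dask.distributed`.

```python
def expected_worker_processes(n_pods: int, gpus_per_pod: int) -> int:
    """Worker processes expected when every pod runs dask-cuda-worker,
    which starts one worker process per GPU it can see.

    Hypothetical helper; gpus_per_pod should match the nvidia.com/gpu limit.
    """
    if n_pods < 0 or gpus_per_pod < 1:
        raise ValueError("need n_pods >= 0 and gpus_per_pod >= 1")
    return n_pods * gpus_per_pod


# Two pods, each with nvidia.com/gpu: 4
print(expected_worker_processes(2, 4))  # 8

# Assumed usage against a live cluster (not run here):
# cluster.scale(2)
# client.wait_for_workers(expected_worker_processes(2, 4))
```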
We can test that things work by installing Dask Gateway, creating a Dask cluster and running some GPU-specific work. Here is an example where we get the NVIDIA driver version from each worker.
```console
$ helm install dgwtest daskgateway/dask-gateway -f config.yaml
```
```python
In [1]: from dask_gateway import Gateway

In [2]: gateway = Gateway("http://dask-gateway-service")

In [3]: cluster = gateway.new_cluster()

In [4]: cluster.scale(1)

In [5]: from dask.distributed import Client

In [6]: client = Client(cluster)

In [7]: def get_nvidia_driver_version():
   ...:     import pynvml
   ...:     return pynvml.nvmlSystemGetDriverVersion()
   ...:

In [9]: client.run(get_nvidia_driver_version)
Out[9]: {'tls://10.42.0.225:44899': b'450.80.02'}
```
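Once you scale beyond one worker, it is handy to check that every worker reports the same driver. `client.run` returns a dict keyed by worker address, so a small helper can collapse it. This is a sketch with a hypothetical helper name; it normalises values to `str` because the session above returned `bytes`, and other `pynvml` versions may return `str` instead.

```python
def driver_versions(results: dict) -> set:
    """Collapse {worker_address: driver_version} into the set of distinct versions.

    Hypothetical helper; accepts bytes or str values.
    """
    return {
        v.decode() if isinstance(v, bytes) else str(v)
        for v in results.values()
    }


# The output from the session above
results = {'tls://10.42.0.225:44899': b'450.80.02'}
versions = driver_versions(results)
print(versions)  # {'450.80.02'}
assert len(versions) == 1, f"workers disagree on driver version: {versions}"
```

In a live session you would pass `client.run(get_nvidia_driver_version)` to `driver_versions` directly.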