
Say I have two DataArrays, A and B, both with dimensions time, x, z. I want to sort all values of A along x and z only, so that at each individual time I have a DataArray with sorted values. At the same time, I want to reorder B using the ordering that sorts A.

If I only had 1-D NumPy arrays, I could do what I want by following this answer:

>>> import numpy
>>> a = numpy.array([2, 3, 1])
>>> b = numpy.array([4, 6, 7])
>>> p = a.argsort()
>>> p
array([2, 0, 1])
>>> a[p]
array([1, 2, 3])
>>> b[p]
array([7, 4, 6])
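
For reference, the same idea extends to N-D NumPy arrays when the sort runs along a single axis. The sketch below is only an illustration: `zipsort_numpy` is a made-up name, and it assumes `np.take_along_axis` is available (NumPy 1.15 or later).

```python
import numpy as np


def zipsort_numpy(a, b, axis=-1):
    """Sort `a` along `axis` and reorder `b` with the same permutation."""
    order = np.argsort(a, axis=axis)  # per-row permutation that sorts `a`
    a_sorted = np.take_along_axis(a, order, axis=axis)
    b_sorted = np.take_along_axis(b, order, axis=axis)
    return a_sorted, b_sorted


# Two "times" (rows), each sorted independently along the last axis
a = np.array([[2, 3, 1], [9, 7, 8]])
b = np.array([[4, 6, 7], [1, 2, 3]])
a_sorted, b_sorted = zipsort_numpy(a, b)
# a_sorted -> [[1, 2, 3], [7, 8, 9]]
# b_sorted -> [[7, 4, 6], [2, 3, 1]]
```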

However, with DataArrays the problem is a bit more complicated. I can get something that works with the following code:

import numpy as np
import xarray as xr


def zipsort_xarray(da_a, da_b, unsorted_dim="time"):
    assert da_a.dims == da_b.dims, "Dimensions aren't the same"
    for dim in da_a.dims:
        assert np.allclose(da_a[dim], da_b[dim]), f"Coordinates of {dim} aren't the same"

    sorted_dims = [ dim for dim in da_a.dims if dim != unsorted_dim ]
    daa_aux = da_a.stack(aux_dim=sorted_dims) # stack all dims to be sorted into one

    indices = np.argsort(daa_aux, axis=-1) # get indices that sort the last (stacked) dim
    indices[unsorted_dim] = range(len(indices[unsorted_dim])) # turn unsorted_dim into a counter
    flat_indices = np.concatenate(indices + indices[unsorted_dim]*len(indices.aux_dim)) # make indices appropriate for indexing a fully flattened version of the data array

    daa_aux2 = da_a.stack(aux_dim2=da_a.dims) # get a fully flattened version of the data array
    daa_aux2.values = daa_aux2.values[flat_indices] # apply the flattened indices to sort it

    dab_aux2 = da_b.stack(aux_dim2=da_b.dims) # get a fully flattened version of the data array
    dab_aux2.values = dab_aux2.values[flat_indices] # apply the same flattened indices to sort it

    return daa_aux2.unstack(), dab_aux2.unstack() # return unflattened (unstacked) DataArrays



tsize=2
xsize=2
zsize=2

data1 = xr.DataArray(np.random.randn(tsize, xsize, zsize), dims=("time", "x", "z"),
                     coords=dict(time=range(tsize),
                                 x=range(xsize),
                                 z=range(zsize)))
data2 = xr.DataArray(np.random.randn(tsize, xsize, zsize), dims=("time", "x", "z"),
                     coords=dict(time=range(tsize),
                                 x=range(xsize),
                                 z=range(zsize)))
sort1, sort2 = zipsort_xarray(data1.transpose("time", "z", "x"), data2.transpose("time", "z", "x"))

However, not only do I feel this is a bit "hacky", it also doesn't work well with dask.

I'm planning on using this on large DataArrays that will be chunked in time, so it's important that I get something that works in those cases. However, if I chunk the DataArrays in time and call the function:

data1 = data1.chunk(dict(time=1))
data2 = data2.chunk(dict(time=1))
sort1, sort2 = zipsort_xarray(data1.transpose("time", "z", "x"), data2.transpose("time", "z", "x"))

I get the following error:

NotImplementedError: 'argsort' is not yet a valid method on dask arrays

Is there any way to make this work with chunked DataArrays?

TomCho

1 Answer


I think I have something that works and seems to be fully parallel. It only works when the time dimension is chunked with size one:

import xarray as xr
import numpy as np


def zipsort3(da_a, da_b, unsorted_dim="time"):
    """
    Only works if both `da_a` and `da_b` are chunked in `unsorted_dim`
    with size 1 chunks
    """
    from dask.array import map_blocks
    assert da_a.dims == da_b.dims, "Dimensions aren't the same"
    for dim in da_a.dims:
        assert np.allclose(da_a[dim], da_b[dim]), f"Coordinates of {dim} aren't the same"

    sorted_dims = [ dim for dim in da_a.dims if dim != unsorted_dim ]
    daa_aux = da_a.stack(aux_dim=sorted_dims).transpose(unsorted_dim, "aux_dim") # stack all dims to be sorted into one
    dab_aux = da_b.stack(aux_dim=sorted_dims).transpose(unsorted_dim, "aux_dim") # stack all dims to be sorted into one

    indices = map_blocks(np.argsort, daa_aux.data, axis=-1, dtype=np.int64)

    def reorder(A, ind):
        # each block has shape (1, n), so row 0 is the only row in the block
        return A[0, ind]
    daa_aux.data = map_blocks(reorder, daa_aux.data, indices, dtype=daa_aux.dtype)
    dab_aux.data = map_blocks(reorder, dab_aux.data, indices, dtype=dab_aux.dtype)
    return daa_aux.unstack(), dab_aux.unstack()


tsize=2
xsize=2
zsize=2

data1 = xr.DataArray(np.random.randn(tsize, xsize, zsize), dims=("time", "x", "z"),
                     coords=dict(time=range(tsize),
                                 x=range(xsize),
                                 z=range(zsize)))
data2 = xr.DataArray(np.random.randn(tsize, xsize, zsize), dims=("time", "x", "z"),
                     coords=dict(time=range(tsize),
                                 x=range(xsize),
                                 z=range(zsize)))

data1 = data1.chunk(dict(time=1))
data2 = data2.chunk(dict(time=1))

sorted1, sorted2 = zipsort3(data1.transpose("time", "z", "x"), data2.transpose("time", "z", "x"))
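
A possible alternative, which I haven't benchmarked, is to let `xr.apply_ufunc` do the block-wise dispatch instead of calling `map_blocks` directly. This is only a sketch: `zipsort_ufunc` and `_sort_pair` are made-up names, and it assumes a reasonably recent xarray in which `dask="parallelized"` supports functions with more than one output. It also assumes the stacked dimension sits in a single chunk, which should be the case when the arrays are chunked only along `unsorted_dim`.

```python
import numpy as np
import xarray as xr


def _sort_pair(a, b):
    # Plain NumPy kernel: sort `a` along the last axis, reorder `b` alike
    order = np.argsort(a, axis=-1)
    return (np.take_along_axis(a, order, axis=-1),
            np.take_along_axis(b, order, axis=-1))


def zipsort_ufunc(da_a, da_b, unsorted_dim="time"):
    sorted_dims = [dim for dim in da_a.dims if dim != unsorted_dim]
    # Stack all dims to be sorted into one core dimension
    a = da_a.stack(aux_dim=sorted_dims)
    b = da_b.stack(aux_dim=sorted_dims)
    out_a, out_b = xr.apply_ufunc(
        _sort_pair, a, b,
        input_core_dims=[["aux_dim"], ["aux_dim"]],
        output_core_dims=[["aux_dim"], ["aux_dim"]],
        dask="parallelized",  # ignored for in-memory (NumPy-backed) inputs
        output_dtypes=[a.dtype, b.dtype],
    )
    return out_a.unstack("aux_dim"), out_b.unstack("aux_dim")
```

Because the kernel uses `np.take_along_axis` rather than indexing row 0 of each block, it shouldn't depend on the chunks having size one along `unsorted_dim`. With chunked inputs the result should stay lazy until `.compute()` is called.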
TomCho