Say I have two DataArrays, A and B, both with dimensions time, x, z. I want to sort all values of A only in x and z, so that at each individual time I have a DataArray with sorted values. Simultaneously, I also want to sort B, but based on the values of A.
If I only had 1-D numpy arrays, I could do what I want by following this answer:
>>> import numpy
>>> a = numpy.array([2, 3, 1])
>>> b = numpy.array([4, 6, 7])
>>> p = a.argsort()
>>> p
array([2, 0, 1])
>>> a[p]
array([1, 2, 3])
>>> b[p]
array([7, 4, 6])
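For plain numpy arrays with more than one dimension, the same idea can be sketched with `np.take_along_axis`, which applies per-row indices from `argsort` along a chosen axis. This is only an illustration with made-up values, not part of the linked answer:

```python
import numpy as np

# Two 2-D arrays; each row is sorted independently along the last axis.
a = np.array([[2, 3, 1],
              [9, 7, 8]])
b = np.array([[4, 6, 7],
              [1, 2, 3]])

# Indices that sort each row of `a`.
p = np.argsort(a, axis=-1)

# Apply the same per-row permutation to both arrays.
a_sorted = np.take_along_axis(a, p, axis=-1)
b_sorted = np.take_along_axis(b, p, axis=-1)

print(a_sorted)  # [[1 2 3], [7 8 9]]
print(b_sorted)  # [[7 4 6], [2 3 1]]
```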
However, with DataArrays the problem is a bit more complicated. I can get something that works with the following code:
import numpy as np
import xarray as xr

def zipsort_xarray(da_a, da_b, unsorted_dim="time"):
    # NOTE: the flattening below assumes unsorted_dim is the first dimension of both arrays
    assert da_a.dims == da_b.dims, "Dimensions aren't the same"
    for dim in da_a.dims:
        assert np.allclose(da_a[dim], da_b[dim]), f"Coordinates of {dim} aren't the same"

    sorted_dims = [dim for dim in da_a.dims if dim != unsorted_dim]
    daa_aux = da_a.stack(aux_dim=sorted_dims)  # stack all dims to be sorted into one
    indices = np.argsort(daa_aux, axis=-1)  # get indices that sort the last (stacked) dim
    indices[unsorted_dim] = range(len(indices[unsorted_dim]))  # turn unsorted_dim into a counter

    # make indices appropriate for indexing a fully flattened version of the data array
    flat_indices = np.concatenate(indices + indices[unsorted_dim] * len(indices.aux_dim))

    daa_aux2 = da_a.stack(aux_dim2=da_a.dims)  # get a fully flattened version of the data array
    daa_aux2.values = daa_aux2.values[flat_indices]  # apply the flattened indices to sort it

    dab_aux2 = da_b.stack(aux_dim2=da_b.dims)  # get a fully flattened version of the data array
    dab_aux2.values = dab_aux2.values[flat_indices]  # apply the same flattened indices to sort it

    return daa_aux2.unstack(), dab_aux2.unstack()  # return unflattened (unstacked) DataArrays
tsize = 2
xsize = 2
zsize = 2
data1 = xr.DataArray(np.random.randn(tsize, xsize, zsize), dims=("time", "x", "z"),
                     coords=dict(time=range(tsize),
                                 x=range(xsize),
                                 z=range(zsize)))
data2 = xr.DataArray(np.random.randn(tsize, xsize, zsize), dims=("time", "x", "z"),
                     coords=dict(time=range(tsize),
                                 x=range(xsize),
                                 z=range(zsize)))

sort1, sort2 = zipsort_xarray(data1.transpose("time", "z", "x"), data2.transpose("time", "z", "x"))
However, not only does this feel a bit "hacky", it also doesn't work with dask.
I'm planning on using this on large DataArrays that will be chunked in time, so it's important that the solution works in that case. However, if I chunk the DataArrays in time I get:
data1 = data1.chunk(dict(time=1))
data2 = data2.chunk(dict(time=1))
sort1, sort2 = zipsort_xarray(data1.transpose("time", "z", "x"), data2.transpose("time", "z", "x"))
and the output is:
NotImplementedError: 'argsort' is not yet a valid method on dask arrays
Is there any way to make this work with chunked DataArrays?
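One direction that might work is `xr.apply_ufunc` with `dask="parallelized"`, so that `argsort` only ever runs on the numpy blocks rather than on the dask array itself. The following is a minimal sketch, not a verified solution for large data: the names `_zipsort_core`, `zipsort_apply_ufunc` and `sort_dims` are made up for illustration, and it assumes the sorted dimensions (`x` and `z`) each sit in a single chunk.

```python
import numpy as np
import xarray as xr


def _zipsort_core(a, b):
    """Sort `a` over its last two axes (flattened) and reorder `b` the same way."""
    shape = a.shape
    # Collapse the two core axes into one so argsort sees a single axis.
    a_flat = a.reshape(shape[:-2] + (-1,))
    b_flat = b.reshape(shape[:-2] + (-1,))
    order = np.argsort(a_flat, axis=-1)
    a_sorted = np.take_along_axis(a_flat, order, axis=-1)
    b_sorted = np.take_along_axis(b_flat, order, axis=-1)
    return a_sorted.reshape(shape), b_sorted.reshape(shape)


def zipsort_apply_ufunc(da_a, da_b, sort_dims=("z", "x")):
    """Hypothetical helper: sort da_a over sort_dims and reorder da_b to match."""
    core = [list(sort_dims), list(sort_dims)]
    return xr.apply_ufunc(
        _zipsort_core, da_a, da_b,
        input_core_dims=core,    # these dims are moved to the end, in this order
        output_core_dims=core,   # outputs keep the same core dims
        dask="parallelized",     # only takes effect for dask-backed inputs
        output_dtypes=[da_a.dtype, da_b.dtype],
    )


# Small in-memory example with reproducible random values.
rng = np.random.default_rng(0)
coords = dict(time=range(2), x=range(2), z=range(3))
data1 = xr.DataArray(rng.standard_normal((2, 2, 3)), dims=("time", "x", "z"), coords=coords)
data2 = xr.DataArray(rng.standard_normal((2, 2, 3)), dims=("time", "x", "z"), coords=coords)

sort1, sort2 = zipsort_apply_ufunc(data1, data2)
```

With dask-backed inputs, the same call would be made after something like `data1.chunk(dict(time=1))`; the core dimensions have to stay unchunked for `dask="parallelized"` to accept them. As in the original function, the `x` and `z` coordinates of the result no longer identify where each value came from.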