The main issue with the current code lies in the line np.where(img_stack_np == label). Indeed, it iterates over the 700 * 700 * 1000 = 490,000,000 values of img_stack_np for each of the 120,000 values of label_list, resulting in 490,000,000 * 120,000 = 58,800,000,000,000 values to check.
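To make the scale concrete, the following sketch redoes the arithmetic and compares it with the roughly n * log2(n) steps of a comparison-based sort. The sort estimate ignores constant factors and is only an order-of-magnitude assumption, not a measurement:

```python
import math

# Sizes taken from the question: a 700 x 700 x 1000 stack and 120,000 labels.
n_voxels = 700 * 700 * 1000
n_labels = 120_000

# The original loop scans the whole stack once per label.
naive_checks = n_voxels * n_labels

# Rough step count for sorting all voxels once (assumption: n * log2(n)).
sort_steps = n_voxels * math.log2(n_voxels)
ratio = naive_checks / sort_steps

print(f"naive checks: {naive_checks:,}")
print(f"sort steps (rough): {sort_steps:.2e}")
print(f"ratio (rough): {ratio:.0f}")
```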
You do not need to iterate over img_stack_np for every label. What you need is to classify the "voxels" by their value (i.e. label). You can do that using a custom sort:
- first, store the position of each voxel in an array with the label;
- then, do a key-value sort where the label is the key and the voxel position is the value;
- then, iterate through the sorted items to group them by label (or use the less efficient np.unique for the sake of simplicity);
- finally, store the position of each group in the final dict.
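The steps above can be sketched on a tiny volume. This is only an illustration under a few assumptions: the 2 x 2 x 3 array is made up, np.argsort stands in for an explicit key-value sort, and group boundaries are found with np.diff on the sorted labels rather than with np.unique:

```python
import numpy as np

# Tiny hypothetical 2 x 2 x 3 label volume (0 is treated as background).
volume = np.array([[[0, 2, 1],
                    [1, 0, 2]],
                   [[2, 2, 0],
                    [1, 0, 0]]])

# Sort the flat voxel indices by label; the label is the key.
order = np.argsort(volume, axis=None, kind="stable")
sorted_labels = volume.ravel()[order]

# A new group starts wherever the sorted label changes.
starts = np.concatenate(([0], np.flatnonzero(np.diff(sorted_labels)) + 1))
ends = np.concatenate((starts[1:], [sorted_labels.size]))

# Convert each group of flat indices back to (i, j, k) positions.
groups = {}
for start, end in zip(starts, ends):
    label = int(sorted_labels[start])
    if label == 0:
        continue
    groups[label] = np.unravel_index(order[start:end], volume.shape)

print(sorted(groups))  # [1, 2]
```

Because the sort is stable, the positions within each group come out in the same order as np.where would return them.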
For the sake of simplicity and to limit memory usage, an index-based sort can be used in place of the key-value sort. This can be done with np.argsort. Here is an example:
import numpy as np

s1, s2, s3 = img_stack_np.shape
# Classification by label.
# Remove the "kind='stable'" argument if you do not care about the ordering
# of the voxel positions for a given label in the resulting dict (much faster).
index = np.argsort(img_stack_np, axis=None, kind='stable')
labels = img_stack_np.reshape(img_stack_np.size)[index]
# Generate the associated position
i1 = np.arange(s1).repeat(s2*s3)[index]
i2 = np.tile(np.arange(s2), s1).repeat(s3)[index]
i3 = np.tile(np.arange(s3), s1*s2)[index]
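# Unique labels and their voxel counts; the cumulative sum of the counts
# gives the offset of each group in the sorted arrays.
groupLabels, groupSizes = np.unique(img_stack_np, return_counts=True)
groupOffsets = np.concatenate(([0], groupSizes.cumsum()))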
dict_labels_and_voxels = {}
for i, label in enumerate(groupLabels):
    if label == 0:
        continue
    start, end = groupOffsets[i], groupOffsets[i] + groupSizes[i]
    index = (i1[start:end], i2[start:end], i3[start:end])
    dict_labels_and_voxels[label] = [index]
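As a sanity check, the approach can be wrapped in a function and compared with the direct np.where lookup on a small random input. The function name group_voxels_by_label and its background parameter are hypothetical, and np.unravel_index is swapped in for the repeat/tile construction to keep the sketch short:

```python
import numpy as np

def group_voxels_by_label(stack, background=0):
    """Map each non-background label to a list holding its index arrays."""
    order = np.argsort(stack, axis=None, kind="stable")
    # Positions of all voxels, reordered so that equal labels are contiguous.
    i1, i2, i3 = np.unravel_index(order, stack.shape)
    labels, sizes = np.unique(stack, return_counts=True)
    offsets = np.concatenate(([0], sizes.cumsum()))
    result = {}
    for n, label in enumerate(labels):
        if label == background:
            continue
        start, end = offsets[n], offsets[n + 1]
        result[label] = [(i1[start:end], i2[start:end], i3[start:end])]
    return result

# Small random test volume (sizes and label range chosen arbitrarily).
rng = np.random.default_rng(0)
small_stack = rng.integers(0, 50, size=(10, 12, 14))
grouped = group_voxels_by_label(small_stack)

# Compare with the straightforward np.where lookup, label by label.
for label in np.unique(small_stack):
    if label == 0:
        continue
    expected = np.where(small_stack == label)
    found = grouped[label][0]
    assert all(np.array_equal(f, e) for f, e in zip(found, expected))
```

The comparison only holds element by element with kind="stable"; with the relaxed ordering the same positions are present but possibly in a different order.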
Here are the results on my machine using a 100x100x1000 integer-based random input with 12000 labels:
Reference algorithm: 363.29 s
Proposed algorithm (strict ordering): 3.91 s
Proposed algorithm (relaxed ordering): 2.02 s
Thus, the proposed implementation is about 93 times faster with strict ordering and about 180 times faster with relaxed ordering in this case.
It should be several thousand times faster for your specified input.
Additional notes:
Using float32 values in img_stack_np takes a lot of memory: 700 * 700 * 1000 * 4 ~= 2 GB. Use float16 or even int8/int16 types if you can. Storing the position of each voxel also takes a lot of memory: from 3 GB (with int16) up to 12 GB (with int64). Consider using a more compact data representation or a smaller dataset. For further improvement, look at this post to replace the quite slow np.unique call.
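The memory figures above can be reproduced with a few lines. This is a rough sketch that counts only the raw array data (the helper name gigabytes is made up, and 1 GB is taken as 10^9 bytes):

```python
import numpy as np

n_voxels = 700 * 700 * 1000

def gigabytes(n_items, dtype):
    """Raw size of n_items elements of the given dtype, in GB (10**9 bytes)."""
    return n_items * np.dtype(dtype).itemsize / 1e9

# Size of the label stack itself for a few candidate dtypes.
for dtype in (np.float32, np.float16, np.int16, np.int8):
    print(f"stack as {np.dtype(dtype).name}: {gigabytes(n_voxels, dtype):.2f} GB")

# Size of the three position arrays (one coordinate per axis and per voxel).
for dtype in (np.int16, np.int32, np.int64):
    print(f"positions as {np.dtype(dtype).name}: {gigabytes(3 * n_voxels, dtype):.2f} GB")

# Largest value an int16 can hold, relevant when choosing a label dtype.
print(np.iinfo(np.int16).max)
```

Note that int16 tops out at 32767, which is enough for positions along axes of at most 1000 elements but not for 120,000 distinct label values, so the narrower types only apply to the labels if the label range allows it.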