How can I simplify this function that converts strings of slices for PyTorch / NumPy to slice list objects that can then be used to slice arrays & tensors?
The code below works, but it seems rather inefficient in terms of how many lines it takes.
def str_to_slice_indices(slicing_str: str):
# Convert indices to lists
indices = [
[i if i else None for i in indice_set.strip().split(":")]
for indice_set in slicing_str.strip("[]").split(",")
]
# Handle Ellipsis "..."
indices = [
... if index_slice == ["..."] else index_slice for index_slice in indices
]
# Handle "None" values
indices = [
None if index_slice == ["None"] else index_slice for index_slice in indices
]
# Handle single number values
indices = [
int(index_slice[0])
if isinstance(index_slice, list)
and len(index_slice) == 1
and index_slice[0].lstrip("-").isdigit()
else index_slice
for index_slice in indices
]
# Create indice slicing list
indices = [
slice(*[int(i) if i and i.lstrip("-").isdigit() else None for i in index_slice])
if isinstance(index_slice, list)
else index_slice
for index_slice in indices
]
return indices
Running the above function with an example covering the various types of inputs, give this:
out = str_to_slice_indices("[None, :1, 3:4, 2, :, 2:, ...]")
print(out)
# out:
# [None, slice(None, 1, None), slice(3, 4, None), 2, slice(None, None, None), slice(2, None, None), Ellipsis]