This is a recursive attempt at solving your problem. It only works if all the lists at the same depth have the same length; otherwise it raises a ValueError:
from collections.abc import Sequence
def get_shape(lst, shape=()):
    """
    Return the shape of nested lists, similarly to numpy's ``shape``.

    Only the first item of each level is followed downwards; its direct
    siblings are checked to be sequences of the same length.  Strings are
    treated as scalar leaves, and an empty sequence contributes a
    dimension of 0 and ends the shape.

    :param lst: the nested list (or a scalar leaf)
    :param shape: the shape up to the current recursion depth
    :return: the shape including the current depth
             (finally this will be the full depth)
    :raises ValueError: if the items next to a nested first item are not
                        all sequences of the same length
    """
    # a str is a Sequence whose items are again str, so descending into
    # it would never terminate; treat it like any other scalar
    if isinstance(lst, str) or not isinstance(lst, Sequence):
        # base case
        return shape

    shape += (len(lst), )
    if not lst:
        # nothing to peek at: an empty sequence ends the shape
        return shape

    first = lst[0]
    # peek ahead and assure all lists in the next depth
    # have the same length
    if isinstance(first, Sequence) and not isinstance(first, str):
        expected = len(first)
        for item in lst:
            is_nested = (isinstance(item, Sequence)
                         and not isinstance(item, str))
            # a scalar sibling has no len(); report it as a shape
            # mismatch instead of letting len() raise TypeError
            if not is_nested or len(item) != expected:
                msg = 'not all lists have the same length'
                raise ValueError(msg)

    # recurse
    return get_shape(first, shape)
Given your input (and the inputs from the comments), these are the results:
a1 = [1, 2, 3, 4, 5, 6]
b1 = [[1, 2, 3], [4, 5, 6]]
print(get_shape(a1))  # (6,)
print(get_shape(b1))  # (2, 3)
# the ragged input raises ValueError; catch it so that the remaining
# example below is still executed
try:
    print(get_shape([[0, 1], [2, 3, 4]]))
except ValueError as err:
    print('ValueError:', err)  # ValueError: not all lists have the same length
print(get_shape([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]))  # (2, 2, 2)
I am not sure whether the last result is what you wanted.
UPDATE
As pointed out in the comments by mkl, the code above will not catch all the cases where the shape of the nested list is inconsistent; e.g. [[0, 1], [2, [3, 4]]] will not raise an error.
This is a shot at checking whether or not the shape is consistent (there might be a more efficient way to do this...):
from collections.abc import Sequence, Iterator
from itertools import tee, chain
def _is_nested(item):
    """Return True if item adds a nesting level (strings count as leaves)."""
    # a str is a Sequence whose items are again str; flattening it would
    # yield the same characters forever, so it is treated as a leaf
    return isinstance(item, Sequence) and not isinstance(item, str)


def is_shape_consistent(lst: Iterator) -> bool:
    """
    check if all the elements of a nested list have the same
    nesting structure.

    first check the 'top level' of the given lst, then flatten
    it by one level and recursively check that.

    at every level the items must be either all nested sequences or all
    leaves.  only this nesting structure is compared; the lengths of
    sibling sequences are not checked.

    :param lst: an iterator over the items of the (nested) list
    :return: True if every level is uniformly nested or uniformly made
             of leaves (an exhausted iterator counts as consistent),
             False otherwise
    """
    # one copy is consumed by the uniformity check, the other one is
    # kept untouched so it can be flattened for the next level
    lst0, lst1 = tee(lst, 2)

    try:
        item0 = next(lst0)
    except StopIteration:
        return True
    is_seq = _is_nested(item0)

    if not all(is_seq == _is_nested(item) for item in lst0):
        return False

    if not is_seq:
        return True

    # flatten lazily by one level and check the next depth
    return is_shape_consistent(chain.from_iterable(lst1))
It could be used this way:
# every level holds either only lists or only scalars
lst0 = [
    [[1, 2], [3, 4]],
    [[5, 6], [7, 8]],
]
# the second level mixes scalars with lists
lst1 = [
    [0, 1, 2],
    [3, [4, 5]],
    [7, [8, 9]],
]
assert is_shape_consistent(iter(lst0))
assert not is_shape_consistent(iter(lst1))