You can avoid some of the problems that arise from NumPy trying to find a catch-all dtype by explicitly specifying a compound (structured) dtype:
Code + some timings:
import numpy as np
import itertools
def cartesian_product_mixed_type(*arrays):
    """Return the Cartesian product of the inputs as a flat structured array.

    Each input is converted with ``np.asanyarray`` and becomes one field of
    the result, named ``f0``, ``f1``, ... in argument order and keeping that
    input's own dtype, so no common "catch-all" dtype has to be found.

    The records are laid out in C order, i.e. the last input varies fastest.
    """
    columns = tuple(np.asanyarray(arr) for arr in arrays)
    ndim = len(columns)

    # One field per input; each field keeps the dtype of its input.
    record_dtype = np.dtype(
        [(f'f{k}', col.dtype) for k, col in enumerate(columns)]
    )

    # N-dimensional grid: axis k has the length of input k.
    grid = np.empty(tuple(len(col) for col in columns), record_dtype)

    for k, col in enumerate(columns):
        # Appending (ndim - k - 1) unit axes makes the input's own axis line
        # up with grid axis k under right-aligned broadcasting.
        expand = (slice(None),) + (None,) * (ndim - k - 1)
        grid[f'f{k}'] = col[expand]

    return grid.ravel()
import timeit

# --- Small demo: compare the structured result with itertools.product ---
np.set_printoptions(threshold=10)  # abbreviate long arrays when printing
a = np.arange(4)
# Code points ord('A') .. ord('D') - 1 as int32, reinterpreted as 1-char strings.
b = np.arange(ord('A'), ord('D'), dtype=np.int32).view('U1')
c = np.arange(2.)
print(f'a={a}', f'b={b}', f'c={c}', sep='\n')
print('itertools')
print(list(itertools.product(a,b,c)))
print('numpy')
print(cartesian_product_mixed_type(a,b,c))

# --- Timing, three inputs of mixed type ---
a = np.arange(100)
b = np.arange(ord('A'), ord('z'), dtype=np.int32).view('U1')
c = np.arange(20.)
# timeit returns the total seconds for `number` runs; with number=1000 that
# figure is numerically the milliseconds per single call, hence the "ms" label.
kwds = {'globals': globals(), 'number': 1000}
print()
print(f'a={a}', f'b={b}', f'c={c}', sep='\n')
itertools_ms = timeit.timeit('list(itertools.product(a,b,c))', **kwds)
print(f"itertools: {itertools_ms:7.4f} ms")
numpy_ms = timeit.timeit('cartesian_product_mixed_type(a,b,c)', **kwds)
print(f"numpy: {numpy_ms:7.4f} ms")

# --- Timing, two larger inputs ---
a = np.arange(1000)
b = np.arange(1000, dtype=np.int32).view('U1')
print()
print(f'a={a}', f'b={b}', sep='\n')
itertools_ms = timeit.timeit('list(itertools.product(a,b))', **kwds)
print(f"itertools: {itertools_ms:7.4f} ms")
numpy_ms = timeit.timeit('cartesian_product_mixed_type(a,b)', **kwds)
print(f"numpy: {numpy_ms:7.4f} ms")
Sample output:
a=[0 1 2 3]
b=['A' 'B' 'C']
c=[0. 1.]
itertools
[(0, 'A', 0.0), (0, 'A', 1.0), (0, 'B', 0.0), (0, 'B', 1.0), (0, 'C', 0.0), (0, 'C', 1.0), (1, 'A', 0.0), (1, 'A', 1.0), (1, 'B', 0.0), (1, 'B', 1.0), (1, 'C', 0.0), (1, 'C', 1.0), (2, 'A', 0.0), (2, 'A', 1.0), (2, 'B', 0.0), (2, 'B', 1.0), (2, 'C', 0.0), (2, 'C', 1.0), (3, 'A', 0.0), (3, 'A', 1.0), (3, 'B', 0.0), (3, 'B', 1.0), (3, 'C', 0.0), (3, 'C', 1.0)]
numpy
[(0, 'A', 0.) (0, 'A', 1.) (0, 'B', 0.) ... (3, 'B', 1.) (3, 'C', 0.)
(3, 'C', 1.)]
a=[ 0 1 2 ... 97 98 99]
b=['A' 'B' 'C' ... 'w' 'x' 'y']
c=[ 0. 1. 2. ... 17. 18. 19.]
itertools: 7.4339 ms
numpy: 1.5701 ms
a=[ 0 1 2 ... 997 998 999]
b=['' '\x01' '\x02' ... 'ϥ' 'Ϧ' 'ϧ']
itertools: 62.6357 ms
numpy: 8.0249 ms