I am looking for an alternative way to write
arr_1d = np.zeros(d ** n)
arr_nd = arr_1d.reshape((d,) * n)
inside a function decorated with @numba.jit(nopython=True). (The rank n is not fixed, so hard-coding the shape as (d, ..., d) is not an option.)
As an example:
import numpy as np
from numba import njit, objmode

def tensor(series: np.ndarray, dim: int):
    # infer the rank n from len(series) == dim ** n
    n = int(np.log(len(series)) / np.log(dim))
    new_shape = (dim,) * n
    # new_shape = [dim] * n
    # new_shape = tuple([dim] * n)
    # with objmode(new_shape="UniTuple(dtype=np.int8, count="+str(n)+")"):
    #     new_shape = (dim,) * n
    return series.reshape(new_shape)
d = 2
n = 3
series = np.zeros(shape=d**n)
print(tensor(series=series, dim=d))
# [[[0. 0.]
#   [0. 0.]]
#
#  [[0. 0.]
#   [0. 0.]]]
print(njit(tensor)(series=series, dim=d))
# Error (see below)
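Side note on the rank computation in tensor: the floating-point log ratio can land just below an integer (for example, math.log(1000, 10) returns 2.9999999999999996 in CPython), and int() then truncates it to the wrong rank. The same kind of rounding may affect np.log(len(series)) / np.log(dim) for some sizes. A minimal sketch of an integer-only alternative, assuming len(series) is an exact power of dim; exact_rank is a hypothetical helper name, not part of NumPy or Numba:

```python
def exact_rank(size: int, dim: int) -> int:
    """Return n such that dim ** n == size (hypothetical helper).

    Uses only integer arithmetic, so there is no floating-point rounding.
    """
    if dim < 2 or size < 1:
        raise ValueError("need dim >= 2 and size >= 1")
    n = 0
    while size > 1:
        # every division step must be exact, otherwise size is not dim ** n
        if size % dim != 0:
            raise ValueError("size is not a power of dim")
        size //= dim
        n += 1
    return n


print(exact_rank(8, 2))      # 3
print(exact_rank(1000, 10))  # 3
```

The log-based line in tensor could then read n = exact_rank(len(series), dim); this does not fix the Numba typing error itself.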
I tried several ways of building the new_shape tuple.
Multiplying a tuple by an integer does not seem to be supported:
new_shape = (dim,) * n
# TypingError: No implementation of function Function(<built-in function mul>) found for signature
Multiplying a list is supported, but reshape does not accept a list as the new shape:
new_shape = [dim] * n
# TypingError: Invalid use of BoundFunction(array.reshape for array(float64, 1d, C)) with parameters (list(int64)<iv=None>)
As the documentation states, tuple() is not supported in nopython mode, but it was worth a shot:
new_shape = tuple([dim] * n)
# TypingError: No implementation of function Function(<class 'tuple'>) found for signature:
So I tried objmode, but I don't understand it well enough to make it work:
with objmode(new_shape="UniTuple(np.int8,"+str(n)+")"):
    new_shape = (dim,) * n
# CompilerError: Error handling objmode argument 'new_shape'. The value must be a compile-time constant either as a non-local variable or a getattr expression that refers to a Numba type.
There might not be a solution yet, but if I missed something, I'd be delighted to receive any advice!
Thanks