I have written this Python generator function, which I expected to compile with Numba. Unfortunately it does not, and I'm not sure I understand the error:

`Invalid use of getiter with parameters (none)`

Does Numba need to know the type of the generator? Or is it because the generator yields tuples of variable length?
```python
from numba import njit

# @njit
def iterator(N, k):
    r"""Numba implementation of an iterator over tuples of N non-negative
    integers such that sum(tuple) == k.

    Args:
        N (int): number of elements in the tuple
        k (int): sum of the elements

    Yields:
        tuple(int): a tuple of N integers
    """
    if N == 1:
        yield (k,)
    else:
        for i in range(k + 1):
            for j in iterator(N - 1, k - i):
                yield (i,) + j
```
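Independently of Numba, a brute-force reference is handy for testing any rewrite of this generator. The sketch below is a hypothetical helper, not part of the original post: it filters `itertools.product` by the sum, and the number of results should match the stars-and-bars count C(k+N-1, N-1).

```python
from itertools import product
from math import comb


def compositions_bruteforce(N, k):
    # Hypothetical reference helper: try every N-tuple with entries in 0..k
    # and keep only the ones that sum to k. product() emits tuples in
    # lexicographic order, so the result comes out sorted.
    return [t for t in product(range(k + 1), repeat=N) if sum(t) == k]


tuples = compositions_bruteforce(4, 2)
print(len(tuples), comb(2 + 4 - 1, 4 - 1))  # 10 10
print(tuples[:3])  # [(0, 0, 0, 2), (0, 0, 1, 1), (0, 0, 2, 0)]
```

In plain Python (with `@njit` commented out), `list(iterator(4, 2)) == compositions_bruteforce(4, 2)` should hold, since the recursive generator also produces its tuples in lexicographic order.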
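EDIT

Thanks to Jerome for the tips. Here's the solution I eventually wrote. It starts from the left, i.e. from [k, 0, ..., 0], so each array it yields is the mirror image of the corresponding tuple from the recursive version, in the same order: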
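```python
import numpy as np
from numba import njit

@njit
def next_lst(lst, i, reset=False):
    r"""Computes the next list of indices given the current list
    and the current index.
    """
    if lst[i] == 0:
        # skip zero slots; the remainder must then be moved back to slot 0
        return next_lst(lst, i + 1, reset=True)
    else:
        # move one unit from slot i to slot i+1
        lst[i] -= 1
        lst[i + 1] += 1
        if reset:
            # slots 0..i-1 are zero, so this puts what is left of slot i
            # into slot 0 and clears slots 1..i
            lst[0] = np.sum(lst[:i + 1])
            lst[1:i + 1] = 0
            i = 0
        return lst, i

@njit
def generator(N, k):
    r"""Yields all the lists of indices, starting from [k, 0, ..., 0].

    Note: the same array is modified in place and yielded every time,
    so copy it (lst.copy()) if you need to keep it.
    """
    lst = np.zeros(N, dtype=np.int64)
    lst[0] = k
    i = 0
    yield lst
    while lst[-1] < k:
        lst, i = next_lst(lst, i)
        yield lst
```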
This gives the correct result and it's jitted!
```python
for lst in generator(4, 2):
    print(lst)
```

```
[2 0 0 0]
[1 1 0 0]
[0 2 0 0]
[1 0 1 0]
[0 1 1 0]
[0 0 2 0]
[1 0 0 1]
[0 1 0 1]
[0 0 1 1]
[0 0 0 2]
```
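If you need all the lists at once rather than iterating, one possible alternative (a sketch, not the code above) is to preallocate a 2D array, since the number of rows is known in advance: C(k+N-1, N-1). The names `n_compositions`, `next_composition` and `all_compositions` are hypothetical; `next_composition` performs the same step as `next_lst`, written as a loop instead of recursion. The `try`/`except` fallback is an assumption so the sketch also runs without Numba installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: fall back to plain Python without Numba
    def njit(func):
        return func


@njit
def n_compositions(N, k):
    # Stars and bars: C(k + N - 1, N - 1), built with exact integer steps
    c = 1
    for j in range(1, N):
        c = c * (k + j) // j
    return c


@njit
def next_composition(lst):
    # Same step as next_lst, without recursion: find the first nonzero
    # slot i, move one unit to slot i + 1, and put what is left of
    # slot i back into slot 0 (slots 1..i end up zero).
    i = 0
    while lst[i] == 0:
        i += 1
    remaining = lst[i] - 1
    lst[i] = 0
    lst[i + 1] += 1
    lst[0] = remaining


@njit
def all_compositions(N, k):
    # One row per list, in the same order as generator(N, k)
    out = np.empty((n_compositions(N, k), N), dtype=np.int64)
    lst = np.zeros(N, dtype=np.int64)
    lst[0] = k
    out[0] = lst
    row = 1
    while lst[-1] < k:
        next_composition(lst)
        out[row] = lst  # copies the current values into the row
        row += 1
    return out


print(all_compositions(4, 2))
```

Each row is copied out of `lst` when it is produced, so the rows do not share memory, whereas arrays collected from `generator` without `.copy()` would all refer to the same buffer that keeps being mutated.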