1

I am trying to use itertools.product to manage the bookkeeping of some nested for loops, where the number of nested loops is not known in advance. Below is a specific example with two nested for loops; the choice of two is only for clarity. What I need is a solution that works for an arbitrary number of loops.

This question provides an extension/generalization of the question appearing here: Efficient algorithm for evaluating a 1-d array of functions on a same-length 1d numpy array

Now I am extending the above technique using an itertools trick I learned here: Iterating over an unknown number of nested loops in python

Preamble:

from itertools import product

def trivial_functional(i, j): return lambda x : (i+j)*x

idx1 = [1, 2, 3, 4]
idx2 = [5, 6, 7]
joint = [idx1, idx2]

func_table  = []
for items in product(*joint):
    f = trivial_functional(*items)
    func_table.append(f)

At the end of the above itertools loop, I have a 12-element flat list of functions, func_table, each element built by trivial_functional.

Question:

Suppose I am given a pair of integers, (i_1, i_2), where these integers are to be interpreted as the indices of idx1 and idx2, respectively. How can I use itertools.product to determine the correct corresponding element of the func_table array?

I know how to hack the answer by writing my own function that mimics the itertools.product bookkeeping, but surely there is a built-in feature of itertools.product that is intended for exactly this purpose?

Community
aph
  • Let me try to rephrase your question: You basically not only need the elements of the input sequences but also their indexes? – dhke Mar 20 '15 at 20:04
  • Yes, that's an efficient way of putting it. – aph Mar 20 '15 at 20:04
  • Again, I can manually look at the way itertools works and write an independent function that returns the indices. But this seems like needless repetition of work, since itertools must have solved this problem already, so I'd much prefer to use a unified itertools syntax, if available. – aph Mar 20 '15 at 20:06
  • First idea would be `for items in product(*(enumerate(j) for j in joint))`, but while that gives you the information you need, it's not necessarily in a nice format ... – dhke Mar 20 '15 at 20:06
  • How is the index in `func_table` calculated? You state it's a 1d array, but you have n indexes (one for each input sequence). From your code, each iteration gets its own position in `func_table`. – dhke Mar 20 '15 at 20:18
  • (0,0)-->0, (1,0)-->1, (2,0)-->2, (0,1)-->3, etc. From this knowledge one can reasonably easily hack the solution. – aph Mar 20 '15 at 20:26

4 Answers

2

I don't know of a way of calculating the flat index other than doing it yourself. Fortunately this isn't that difficult:

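def product_flat_index(factors, indices):
  # itertools.product varies the last factor fastest, so accumulate
  # mixed-radix style: scale by each factor's length, then add its index.
  flat = 0
  for factor, index in zip(factors, indices):
    flat = flat * len(factor) + index
  return flat

>>> product_flat_index(joint, (2, 1))
7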
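
A quick way to sanity-check any flat-index formula is to compare it against the order in which itertools.product actually yields tuples (last factor varying fastest). The following is only a sketch: flat_index is a hypothetical helper name, and idx1/idx2 are the lists from the question.

```python
from itertools import product

idx1 = [1, 2, 3, 4]
idx2 = [5, 6, 7]
joint = [idx1, idx2]

def flat_index(factors, indices):
    # Hypothetical helper: mixed-radix accumulation, last factor fastest.
    flat = 0
    for factor, index in zip(factors, indices):
        flat = flat * len(factor) + index
    return flat

# Index tuples, generated in the same order product() yields value tuples.
index_tuples = product(*(range(len(f)) for f in joint))

# Collect every position where the formula disagrees with product()'s order.
mismatches = [
    (pos, idx)
    for pos, idx in enumerate(index_tuples)
    if flat_index(joint, idx) != pos
]
print(mismatches)  # an empty list means the formula agrees with product()
```

The same check works unchanged for three or more factors, since nothing in it depends on the length of joint.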

An alternative approach is to store the results in a nested array in the first place, making translation unnecessary, though this is more complex:

from functools import reduce
from operator import getitem, setitem, itemgetter

def get_items(container, indices):
  return reduce(getitem, indices, container)

def set_items(container, indices, value):
  c = reduce(getitem, indices[:-1], container)
  setitem(c, indices[-1], value)

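def initialize_table(lengths):
  if len(lengths) == 1: return [0] * lengths[0]
  # Build a fresh subtable for every row; a shallow copy of one subtable
  # would share inner lists once there are three or more levels.
  return [initialize_table(lengths[1:]) for _ in range(lengths[0])]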

func_table = initialize_table(list(map(len, joint)))
for items in product(*map(enumerate, joint)):
  f = trivial_functional(*map(itemgetter(1), items))
  set_items(func_table, list(map(itemgetter(0), items)), f)

>>> get_items(func_table, (2, 1)) # same as func_table[2][1]
<function>
Uri Granta
  • This is elegant and I think it's the "correct" way to deal with an unknown degree of nesting. It's almost surely slower. But this is python so we don't mind :-) – Sanjay Manohar Mar 20 '15 at 21:15
  • In your first code snippet, Uri, what is the product_index function on line 3? I think this is a typo? – aph Mar 21 '15 at 12:38
2

Several of the answers were quite useful; thanks to everyone for the solutions.

It turns out that if I recast the problem slightly with Numpy, I can accomplish the same bookkeeping, and solve the problem I was trying to solve with vastly improved speed relative to pure python solutions. The trick is just to use Numpy's reshape method together with the normal multi-dimensional array indexing syntax.

Here's how this works. We just convert func_table into a Numpy array, and reshape it:

import numpy as np

# One length per input sequence; for arbitrary nesting use [len(idx) for idx in joint]
component_dimensions = [len(idx1), len(idx2)]
func_table = np.array(func_table).reshape(component_dimensions)

Now func_table can be used to return the correct function not just for a single 2d point, but for a full array of 2d points:

dim1_pts = [3,1,2,1,3,3,1,3,0]
dim2_pts = [0,1,2,1,2,0,1,2,1]
func_array = func_table[dim1_pts, dim2_pts]

As usual, Numpy to the rescue!
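
If the flat list is kept as-is, NumPy's np.ravel_multi_index gives the flat position directly; its default C (row-major) order matches the order itertools.product uses. A minimal sketch, assuming the idx1/idx2 and trivial_functional from the question (func_list and shape are names introduced here for illustration):

```python
import numpy as np
from itertools import product

def trivial_functional(i, j):
    return lambda x: (i + j) * x

idx1 = [1, 2, 3, 4]
idx2 = [5, 6, 7]
joint = [idx1, idx2]

# Flat list built exactly as in the question.
func_list = [trivial_functional(*items) for items in product(*joint)]
shape = [len(idx) for idx in joint]

# Convert the per-sequence indices (2, 1) into a flat position.
flat = np.ravel_multi_index((2, 1), shape)
f = func_list[flat]

# The reshaped object array from this answer points at the same function.
table = np.array(func_list, dtype=object).reshape(shape)
print(flat, f(10), table[2, 1] is f)
```

Here (2, 1) selects idx1[2] = 3 and idx2[1] = 6, so f is the function x -> 9*x.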

aph
1

This is a little messy, but here you go:

from itertools import product

def trivial_functional(i, j): return lambda x : (i+j)*x

idx1 = [1, 2, 3, 4]
idx2 = [5, 6, 7]
joint = [enumerate(idx1), enumerate(idx2)]

func_map  = {}
for indexes, items in map(lambda x: zip(*x), product(*joint)):
    f = trivial_functional(*items)
    func_map[indexes] = f

print(func_map[(2, 0)](5)) # 40 = (3+5)*5
axblount
0

I'd suggest using enumerate() in the right place:

from itertools import product

def trivial_functional(i, j): return lambda x : (i+j)*x

idx1 = [1, 2, 3, 4]
idx2 = [5, 6, 7]
joint = [idx1, idx2]

func_table  = []
for items in product(*joint):
    f = trivial_functional(*items)
    func_table.append(f)

From what I understood from your comments and your code, func_table is simply indexed by the position at which a given input tuple occurs in the product sequence. You can recover that position using:

for index, items in enumerate(product(*joint)):
    # because of the append(), index is now the
    # position of the function created from the
    # respective tuple in product(*joint)
    func_table[index](some_value)
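
To go the other way, from a tuple of per-sequence indices to the position in func_table, the same enumerate() idea can be applied once up front to build a lookup table. This is a sketch under the question's setup; position_of and index_ranges are hypothetical names.

```python
from itertools import product

def trivial_functional(i, j):
    return lambda x: (i + j) * x

idx1 = [1, 2, 3, 4]
idx2 = [5, 6, 7]
joint = [idx1, idx2]

func_table = [trivial_functional(*items) for items in product(*joint)]

# product() over index ranges yields index tuples in the same order that
# product(*joint) yields value tuples, so enumerate() pairs them with positions.
index_ranges = [range(len(seq)) for seq in joint]
position_of = {idx: pos for pos, idx in enumerate(product(*index_ranges))}

pos = position_of[(2, 1)]
result = func_table[pos](10)
print(pos, result)
```

The dictionary costs memory proportional to the size of the product, so it mainly pays off when many lookups are needed.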
dhke