The general idea is to picture the recursive call graph (or tree) and identify its leaf nodes; the iterative solution then starts from the leaf nodes and builds the tree up iteratively, all the way to the root.
Of course, this is easier said than done, but the problem often has structure that can guide your intuition. In this particular case, that structure is the 2D grid.
Intuition
Let's start by building some intuition. Look at the branches in the recursive code: each one decides whether to recurse in a particular circumstance. What do they correspond to? When do you not recurse? In order, these are:
- the right edge of the grid (`x == cols - 1`: no move to the right)
- the top edge of the grid (`y == 0`: no move up)
- the bottom edge of the grid (`y == rows - 1`: no move down)
We need to build these first.
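For reference, here is a memoized top-down version with the same three branches. It is a hypothetical reconstruction inferred from the answer's `fill` function, not necessarily the code from the question. It assumes `matrix` is a list of rows with `y = 0` as the top row, and that `dir` records the last move (0 = down, 1 = up, 2 = right); the wrapper name `make_recursive_maxsum` is made up for this sketch.

```python
from functools import lru_cache


def make_recursive_maxsum(matrix):
    """Hypothetical top-down version: returns maxsum(dir, x, y), memoized."""
    rows, cols = len(matrix), len(matrix[0])

    @lru_cache(maxsize=None)
    def maxsum(dir, x, y):
        base = matrix[y][x]
        # Branch 1 (move right): skipped on the right edge.
        if x < cols - 1:
            best = base + maxsum(2, x + 1, y)
        else:
            best = base
        # Branch 2 (move up): skipped on the top edge or just after moving down.
        if dir != 0 and y > 0:
            best = max(best, base + maxsum(1, x, y - 1))
        # Branch 3 (move down): skipped on the bottom edge or just after moving up.
        if dir != 1 and y < rows - 1:
            best = max(best, base + maxsum(0, x, y + 1))
        return best

    return maxsum


maxsum = make_recursive_maxsum([[1, 2], [3, 4]])
print(maxsum(2, 0, 0))  # 10: 1 -> down 3 -> right 4 -> up 2
```

Each call either hits one of the edge conditions or recurses toward them, which is exactly the dependency structure the iterative version has to reproduce in reverse.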
Base Case
Ask yourself: in what circumstances do we not recurse at all? This is the base case. In no particular order, these are:
- the top-right corner of the grid, with `dir=1`, and
- the bottom-right corner of the grid, with `dir=0`.
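In code, these two base cases need no lookups at all: their value is just the cell itself. A minimal sketch, assuming `matrix` is a list of rows, `y = 0` is the top edge, and results are keyed by `(dir, x, y)`; `base_cases` is a hypothetical helper name.

```python
def base_cases(matrix):
    """Return the two entries that need no recursion, keyed by (dir, x, y)."""
    rows, cols = len(matrix), len(matrix[0])
    right = cols - 1
    return {
        # Top-right corner, dir=1: no cell to the right, none above,
        # and moving down is forbidden right after moving up.
        (1, right, 0): matrix[0][right],
        # Bottom-right corner, dir=0: no cell to the right, none below,
        # and moving up is forbidden right after moving down.
        (0, right, rows - 1): matrix[rows - 1][right],
    }


print(base_cases([[1, 2], [3, 4]]))  # {(1, 1, 0): 2, (0, 1, 1): 4}
```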
Recursive Cases
Finally, ask yourself: starting with the values that we have, what can we calculate?
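- starting from the top-right corner, we can calculate the entire right edge for `dir=1` (on the right edge, a `dir=1` cell can only keep moving up, so it depends only on the cell above it)
- starting from the bottom-right corner, we can calculate the entire right edge for `dir=0` (likewise, it depends only on the cell below it)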
From these two, we can then calculate the entire right edge for `dir=2`.
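The three passes over the right edge can be sketched as follows, under the same assumptions (list of rows, `y = 0` at the top, results in a dict keyed by `(dir, x, y)`); `right_edge` is a hypothetical name. Each step takes a `max` with the cell on its own because, as in the answer's `fill`, a path may stop on the right edge; this matters when the grid contains negative numbers.

```python
def right_edge(matrix):
    """Fill the (dir, x, y) entries for the rightmost column only."""
    rows, cols = len(matrix), len(matrix[0])
    x = cols - 1
    cache = {}
    # dir=1: can only keep moving up, so sweep downward from the top-right base case.
    cache[1, x, 0] = matrix[0][x]
    for y in range(1, rows):
        cache[1, x, y] = max(matrix[y][x], matrix[y][x] + cache[1, x, y - 1])
    # dir=0: can only keep moving down, so sweep upward from the bottom-right base case.
    cache[0, x, rows - 1] = matrix[rows - 1][x]
    for y in range(rows - 2, -1, -1):
        cache[0, x, y] = max(matrix[y][x], matrix[y][x] + cache[0, x, y + 1])
    # dir=2: may move up (into dir=1) or down (into dir=0), both already known.
    for y in range(rows):
        best = matrix[y][x]
        if y > 0:
            best = max(best, matrix[y][x] + cache[1, x, y - 1])
        if y < rows - 1:
            best = max(best, matrix[y][x] + cache[0, x, y + 1])
        cache[2, x, y] = best
    return cache


edge = right_edge([[1, 2], [3, 4]])
print(edge[2, 1, 0], edge[2, 1, 1])  # 6 6
```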
Now that we've filled in the values for the right edge, what can we calculate next? Remember the special cases above. The cells that depend only on the right edge are the two cells on the top and bottom edges immediately to the left of the right edge, with `dir=1` and `dir=0`, respectively.
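Concretely, those two cells cannot move vertically at all, so each is its own value plus the `dir=2` entry directly to its right. A small sketch, assuming a grid with at least two columns and the right-edge entries passed in as a dict keyed by `(dir, x, y)`; `second_column_seeds` is a hypothetical name.

```python
def second_column_seeds(matrix, right):
    """Entries of column cols - 2 that depend only on the right edge.

    `right` maps (dir, cols - 1, y) to the right-edge values.
    """
    rows, cols = len(matrix), len(matrix[0])
    x = cols - 2
    return {
        # Top edge, dir=1: cannot move up (top edge) or down (just moved up).
        (1, x, 0): matrix[0][x] + right[2, x + 1, 0],
        # Bottom edge, dir=0: cannot move down (bottom edge) or up (just moved down).
        (0, x, rows - 1): matrix[rows - 1][x] + right[2, x + 1, rows - 1],
    }


grid = [[1, 2], [3, 4]]
right = {(2, 1, 0): 6, (2, 1, 1): 6}  # dir=2 values of this grid's right edge
print(second_column_seeds(grid, right))  # {(1, 0, 0): 7, (0, 0, 1): 9}
```

These two entries play the same role for the second column that the corner base cases play for the right edge.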
With those in hand, we can now calculate the second column from the right for `dir=1` and `dir=0`, and therefore for `dir=2`.
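Generalizing, any column can be filled from the column to its right in three sweeps: `dir=1` top-down, `dir=0` bottom-up, then `dir=2`. A sketch that stores each column as a dict from `dir` to a list indexed by `y` (an assumed layout, not the answer's `cache`); `next_column` is a hypothetical name.

```python
def next_column(matrix, x, right):
    """Compute column x from `right`, the already-filled column x + 1.

    Both columns are dicts mapping dir (0, 1, 2) to a list indexed by y.
    """
    rows = len(matrix)
    # Value of moving right from each cell (always allowed left of the right edge).
    go_right = [matrix[y][x] + right[2][y] for y in range(rows)]
    up, down, free = [0] * rows, [0] * rows, [0] * rows
    # dir=1 (last move up): may also keep going up, so fill from the top edge down.
    for y in range(rows):
        up[y] = go_right[y]
        if y > 0:
            up[y] = max(up[y], matrix[y][x] + up[y - 1])
    # dir=0 (last move down): may also keep going down, so fill from the bottom edge up.
    for y in range(rows - 1, -1, -1):
        down[y] = go_right[y]
        if y < rows - 1:
            down[y] = max(down[y], matrix[y][x] + down[y + 1])
    # dir=2 (last move right): may go up or down, both of which are now known.
    for y in range(rows):
        best = go_right[y]
        if y > 0:
            best = max(best, matrix[y][x] + up[y - 1])
        if y < rows - 1:
            best = max(best, matrix[y][x] + down[y + 1])
        free[y] = best
    return {0: down, 1: up, 2: free}


grid = [[1, 2], [3, 4]]
right = {0: [6, 4], 1: [2, 6], 2: [6, 6]}  # right edge of this grid
print(next_column(grid, 0, right))  # {0: [10, 9], 1: [7, 10], 2: [10, 10]}
```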
Repeat, moving one column to the left each time, until you reach the column of the cell you want.
The Code
Note: this is a little suboptimal because it fills the entire table, including the columns to the left of the target cell, but it should suffice to illustrate the idea.
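```python
# Assumes `matrix` (a list of rows, y = 0 at the top) is defined as in the question.
rows, cols = len(matrix), len(matrix[0])
cache = {}  # (dir, x, y) -> best path sum from (x, y) given the last move dir

# dir is the last move: 0 = down (can't go up next), 1 = up (can't go down next),
# 2 = right (may go either way).
def fill(dir, x, y):
    base = matrix[y][x]
    # Move right, unless we're on the right edge.
    if x < cols - 1:
        best = base + cache[2, x + 1, y]
    else:
        best = base
    # Move up, unless we're on the top edge or just moved down.
    if dir != 0 and y > 0:
        best = max(best, base + cache[1, x, y - 1])
    # Move down, unless we're on the bottom edge or just moved up.
    if dir != 1 and y < rows - 1:
        best = max(best, base + cache[0, x, y + 1])
    cache[dir, x, y] = best

def maxsum(dir, x, y):
    # Columns right to left: each column only needs the one to its right.
    for i in range(cols - 1, -1, -1):
        # dir=0 reads the cell below, so sweep bottom-up.
        for j in range(rows - 1, -1, -1):
            fill(0, i, j)
        # dir=1 reads the cell above, so sweep top-down.
        for j in range(rows):
            fill(1, i, j)
        # dir=2 reads dir=0/dir=1 of this column, which are now complete.
        for j in range(rows):
            fill(2, i, j)
    return cache[dir, x, y]
```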
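Regarding the note above about filling the entire table: since each column only depends on the one to its right, one possible refinement is to stop at column `x` and keep just one column at a time, using O(rows) memory instead of O(rows * cols). A sketch under the same conventions; `maxsum_rolling` is a hypothetical name, and it takes the grid as a parameter instead of using globals.

```python
def maxsum_rolling(matrix, dir, x, y):
    """Same recurrence as fill(), but stops at column x and keeps one column."""
    rows, cols = len(matrix), len(matrix[0])
    col = None  # dict: dir -> list of values for the column to the right
    for i in range(cols - 1, x - 1, -1):
        # Value of moving right from each cell, or just the cell on the right edge.
        start = [matrix[j][i] + (col[2][j] if col is not None else 0) for j in range(rows)]
        new = {0: start[:], 1: start[:], 2: start[:]}
        for j in range(1, rows):  # dir=1 may keep moving up: fill top-down
            new[1][j] = max(new[1][j], matrix[j][i] + new[1][j - 1])
        for j in range(rows - 2, -1, -1):  # dir=0 may keep moving down: fill bottom-up
            new[0][j] = max(new[0][j], matrix[j][i] + new[0][j + 1])
        for j in range(rows):  # dir=2 may go either way
            if j > 0:
                new[2][j] = max(new[2][j], matrix[j][i] + new[1][j - 1])
            if j < rows - 1:
                new[2][j] = max(new[2][j], matrix[j][i] + new[0][j + 1])
        col = new  # the column to the right is no longer needed
    return col[dir][y]


print(maxsum_rolling([[1, 2], [3, 4]], 2, 0, 0))  # 10
```

The trade-off is that each call recomputes from the right edge, whereas the table-filling version answers any later query from `cache` directly.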