
While implementing the "Variable Elimination" algorithm for a Bayes' Nets program, I ran into an unexpected bug caused by repeatedly applying map to a sequence of objects inside a loop.

For simplicity's sake, I'll use an analogous piece of code here:

>>> nums = [1, 2, 3]
>>> for x in [4, 5, 6]:
...     # Uses n if x is odd, uses (n + 10) if x is even
...     nums = map(
...         lambda n: n if x % 2 else n + 10,
...         nums,
...     )
...
>>> list(nums)
[31, 32, 33]

This is definitely the wrong result. Since [4, 5, 6] contains two even numbers, 10 should be added to each element at most twice. I was getting unexpected behaviour with this in the VE algorithm as well, so I modified it to convert the map iterator to a list after each iteration.

>>> nums = [1, 2, 3]
>>> for x in [4, 5, 6]:
...     # Uses n if x is odd, uses (n + 10) if x is even
...     nums = map(
...         lambda n: n if x % 2 else n + 10,
...         nums,
...     )
...     nums = list(nums)
...
>>> list(nums)
[21, 22, 23]

From my understanding of iterables, this modification shouldn't change anything, but it does. Clearly, the n + 10 transform for the not x % 2 case is applied one fewer time in the list-ed version.

My Bayes Nets program also worked once I had found and fixed this bug, but I'm looking for an explanation of why it occurred.


This behaviour is 3.x-specific, and a special case of "What do lambda function closures capture?". Because 3.x's map is lazy, the first version of the code doesn't call the lambdas until list is used outside the loop, at which point every lambda sees the final value of x. In 2.x, map simply creates a list, so list(nums) is redundant.
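
To see the laziness on its own, here is a minimal sketch for Python 3 (the names demo_lazy_map and record are made up for illustration) showing that map calls nothing until it is consumed:

```python
def demo_lazy_map():
    calls = []

    def record(n):
        # Remember every value the mapped function is actually called with.
        calls.append(n)
        return n * 2

    lazy = map(record, [1, 2, 3])
    calls_before = list(calls)   # snapshot taken before the map is consumed
    result = list(lazy)          # consuming the map triggers the calls
    return calls_before, calls, result


before, after, result = demo_lazy_map()
print(before)  # []
print(after)   # [1, 2, 3]
print(result)  # [2, 4, 6]
```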

cosmicFluke
  • Please don't write code like that. It makes my brain hurt. – Kevin Apr 05 '16 at 05:39
  • Your code doesn't do what you explain it should do. If you want it to keep the numbers if they are odd and add ten if they are even, you must put `nums = map(lambda n: n if x % 2 != 0 else n + 10, nums)` .... there needs to be something to evaluate for the if function, otherwise it always evaluates as true. BTW. the problem with the code was not newlines, as your edits suggest. I think, what kevin complains about are the dots and arrows to the left. – Mikkel Bue Tellus Apr 05 '16 at 05:47
  • Could it be that in Python 3, `map` is an iterator? – hpaulj Apr 05 '16 at 05:52
  • @MikkelBueTellus, `x % 2` evaluates to `False` if `x % 2 == 0`, and `True` if `x % 2 == 1` – cosmicFluke Apr 05 '16 at 05:52
  • Oh, and since you iterate on the list [4,5,6], the function overwrites nums for each iteration. Which means that your output is only related to the last element (6) in your list. – Mikkel Bue Tellus Apr 05 '16 at 05:54
  • @hpaulj I'm expecting it to be an iterator. – cosmicFluke Apr 05 '16 at 05:54
  • @cosmicFluke sorry, you are right. I learned something new today :) – Mikkel Bue Tellus Apr 05 '16 at 06:01
  • Yeah, the lambdas are run at the line `list(nums)`; all 3 of them see `x` then bound to 6, and thus add 10. You can verify this by adding `del x` before `list(nums)`; you'd get a `NameError`. – Antti Haapala -- Слава Україні Apr 05 '16 at 07:32
  • The same will happen with other on-demand iterators, such as `filter`: see e.g. https://stackoverflow.com/questions/34843139/trouble-with-applying-python-2-code-in-python-3. – Karl Knechtel Aug 16 '22 at 02:42

3 Answers


The answer is very simple: map is lazy in Python 3. It returns an iterator (in Python 2 it returns a list). Let me add some output to your example:

In [6]: nums = [1, 2, 3]

In [7]: for x in [4, 5, 6]:
   ...:     nums = map(lambda n: n if x % 2 else n + 10, nums)
   ...:     print(x)
   ...:     print(nums)
   ...:     
4
<map object at 0x7ff5e5da6320>
5
<map object at 0x7ff5e5da63c8>
6
<map object at 0x7ff5e5da6400>

In [8]: print(x)
6

In [9]: list(nums)
Out[9]: [31, 32, 33]

Note In [8]: by the time list is called, the value of x is 6. We could also change the lambda function passed to map so that it prints the value of x each time it runs:

In [10]: nums = [1, 2, 3]

In [11]: for x in [4, 5, 6]:
   ....:     nums = map(lambda n: print(x) or (n if x % 2 else n + 10), nums)
   ....:     

In [12]: list(nums)
6
6
6
6
6
6
6
6
6
Out[12]: [31, 32, 33]

Because map is lazy, the lambdas are only evaluated when list is called. By then the loop has finished and x is 6, so every one of the nested lambdas adds 10, which is why the output is confusing. Evaluating nums inside the loop produces the expected output.

In [13]: nums = [1, 2, 3]

In [14]: for x in [4, 5, 6]:
   ....:     nums = map(lambda n: print(x) or (n if x % 2 else n + 10), nums)
   ....:     nums = list(nums)
   ....:     
4
4
4
5
5
5
6
6
6

In [15]: nums
Out[15]: [21, 22, 23]
awesoon
  • Ah! It was the `x`! Thank you. I know iterators in general are "lazy", but I missed that part of the laziness. Is there any way to set up a lambda expression to immediately evaluate a variable during each iteration of a loop? (edit: never mind, got it!) This seems to be a danger of mixing functional and imperative programming! – cosmicFluke Apr 05 '16 at 05:58
  • "it evaluates when the x value has been changed to 6". Rather, it evaluates when `list` is called on the map object, which happens when `x = 6`. –  Apr 05 '16 at 05:58
  • @cosmicFluke besides the `map`'s laziness, it's also a case of [late binding](http://docs.python-guide.org/en/latest/writing/gotchas/#late-binding-closures), as pointed out in Blckknght's and Mike's answers. – bereal Apr 05 '16 at 06:27
  • @bereal Thanks for the link! It's nice to know there's a term for it. I'm not surprised that it's an entry in a "common gotchas" article - easy to miss. – cosmicFluke Apr 05 '16 at 06:59

The issue has to do with how the x variable is accessed by the lambda functions you are creating. Because of the way Python's scoping works, each lambda function looks up x in the enclosing scope when it is called, so it sees the current value of x at that moment, not the value x had when the lambda was defined.

Since map is lazy, the lambda functions don't get called until after the loop (when you consume the nested maps by passing them to list), so they all use the last value of x.
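
A minimal sketch of the late-binding part in isolation, with no map involved (build_getters is a hypothetical helper, not something from the question):

```python
def build_getters():
    getters = []
    for x in [4, 5, 6]:
        # Each lambda looks up x when it is called, not when it is created.
        getters.append(lambda: x)
    return getters


# All three lambdas are called after the loop has finished, when x is 6.
print([g() for g in build_getters()])  # [6, 6, 6]
```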

To make each lambda function save the value x has when it is defined, add x=x like this:

lambda n, x=x: n if x % 2 else n + 10

This specifies an argument and its default value. The default will be evaluated at the time the lambda is defined, so when the lambda gets called later (without a second argument), the x inside the expression will be that saved default value.
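
Applied to the loop from the question, this might look like the following sketch (apply_with_defaults is a name invented here purely to wrap the loop):

```python
def apply_with_defaults(nums, xs):
    for x in xs:
        # x=x stores the current loop value as the lambda's default argument,
        # so the still-lazy map uses that value when it finally runs.
        nums = map(lambda n, x=x: n if x % 2 else n + 10, nums)
    return list(nums)


print(apply_with_defaults([1, 2, 3], [4, 5, 6]))  # [21, 22, 23]
```

The maps stay lazy here; only the value of x is captured early.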

Blckknght

If you want to keep the lazy version, you need to fix the value of x in each iteration of the loop. functools.partial does exactly that:

from functools import partial

def myfilter(n, x):
    return n if x % 2 else n + 10

nums = [1, 2, 3]
for x in [4, 5, 6]:
    f = partial(myfilter, x=x)
    nums = map(f, nums)

print(list(nums))  # prints [21, 22, 23]
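
The same idea can also be sketched without functools, using a small factory function instead of partial. This is an alternative technique rather than part of the answer above, and make_transform and apply_with_factory are hypothetical names:

```python
def make_transform(x):
    # x is a parameter of this call, so each returned function has its own x.
    def transform(n):
        return n if x % 2 else n + 10
    return transform


def apply_with_factory(nums, xs):
    for x in xs:
        # make_transform(x) is called right away, fixing x for this iteration.
        nums = map(make_transform(x), nums)
    return list(nums)


print(apply_with_factory([1, 2, 3], [4, 5, 6]))  # [21, 22, 23]
```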
Mike Müller