
I'm able to access and update the .data attribute of a PyTorch tensor when the variable is outside a function's namespace:

import torch

x = torch.zeros(5)
def my_function():
    x.data += torch.ones(5)
my_function()
print(x)       # tensor([1., 1., 1., 1., 1.])

However, when I try to update x in the regular way (i.e. x += y), I get the error "UnboundLocalError: local variable 'x' referenced before assignment". This is expected, because x is outside of my_function's namespace.

import torch

x = torch.zeros(5)
def my_function():
    x += torch.ones(5)   # UnboundLocalError: local variable 'x' referenced before assignment
my_function()

Why can I update x via .data but not with its regular += operator?

jstm

2 Answers


This doesn't have to do with PyTorch specifically. Python treats any name that is assigned anywhere within a function as a local variable of that function, unless the name is explicitly declared global (or nonlocal) in that scope. A similar question: Why does this UnboundLocalError occur (closure)?

For your particular question, the problem is that x is defined only in the global scope, so you can't assign a new value to the name x inside the function without declaring it global. On the other hand, x.data += ... only reads the name x (which Python then finds in the global scope) and assigns to an attribute of the object it refers to. That is not a binding of the name x, so you can do it without the global keyword.

As an example, consider the following code:

class Foo():
    def __init__(self):
        self.data = 1

x = Foo()

def f():
    x.data += 1

f()
print(x.data)  # 2

This code updates x.data as expected, because f only reads the global name x and assigns to an attribute of the object; it never rebinds x itself.

On the other hand, the following code

class Foo():
    def __init__(self):
        self.data = 1          # the attribute must actually be assigned
    def __iadd__(self, v):     # the colon was missing
        self.data += v
        return self

x = Foo()

def f():
    x += 1    # UnboundLocalError

f()
print(x.data)

will raise an UnboundLocalError, because the Python compiler treats x += 1 as an assignment to the name x, so x is a local variable throughout f. Since that local x hasn't been bound to a value when x += 1 tries to read it, you get an exception.
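One way to see this compile-time decision is to inspect the compiled code objects. A minimal sketch, assuming CPython (the function names attr_update and name_update are hypothetical): names bound in a function body appear in co_varnames (locals), while names that are only read, along with attribute names, appear in co_names.

```python
class Foo:
    def __init__(self):
        self.data = 1

    def __iadd__(self, v):
        self.data += v
        return self

x = Foo()

def attr_update():
    x.data += 1   # reads the name x, then assigns to the attribute data

def name_update():
    x += 1        # augmented assignment binds the name x, so x is local here

# Locals decided by the compiler
print(attr_update.__code__.co_varnames)      # ()
print("x" in attr_update.__code__.co_names)  # True: x is looked up globally
print(name_update.__code__.co_varnames)      # ('x',)

attr_update()
print(x.data)  # 2

try:
    name_update()
except UnboundLocalError as e:
    print("UnboundLocalError:", e)
```

The function name_update fails before __iadd__ is ever called, because reading the local x happens first.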

For the previous code to work, we need to explicitly declare x as global within the function's scope.

class Foo():
    def __init__(self):
        self.data = 1
    def __iadd__(self, v):
        self.data += v
        return self

x = Foo()

def f():
    global x   # tell python that x refers to a global variable
    x += 1

f()
print(x.data)  # 2
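The same rule applies when the name lives in an enclosing function rather than at module level; there nonlocal plays the role that global plays above. A minimal sketch (the names outer, increment, and count are illustrative only):

```python
def outer():
    count = 0

    def increment():
        nonlocal count  # rebind count in outer's scope instead of creating a new local
        count += 1

    increment()
    increment()
    return count

print(outer())  # 2
```

Without the nonlocal line, increment would raise the same UnboundLocalError as f above.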
jodag
  • Got it, thanks for clarifying. The fact that attributes of objects outside of the current namespace are assignable but the objects themselves aren't seems pretty weird to me. Do you know why this is the case by chance? – jstm Mar 13 '23 at 01:09
  • In the case of, for example, `x.data = 1`, there is no namespace ambiguity about where `data` should be, because `x.data` is explicitly in the namespace of `x`, which must therefore already exist. As long as you don't rebind `x` somewhere else in the local scope, Python knows `x` must exist in an outer scope. You're not actually modifying what `x` refers to in this case. On the other hand, assignment to `x` like `x = 1` is a "name binding operation", which always binds the name in the local namespace unless it is declared with the `global` or `nonlocal` keywords. – jodag Mar 13 '23 at 16:24
  • Reference https://docs.python.org/3/reference/executionmodel.html – jodag Mar 13 '23 at 16:25
  • Also, I believe Python determines the scope of a name at compile time, so if you do any name binding operation to `x` anywhere in the function, then all instances of `x` will be assumed to be local, even if the name binding occurs after its first use. This means that if you add something like `x = Foo()` to the end of `f` in the first code example, it would cause an exception to be thrown in the previous line. – jodag Mar 13 '23 at 16:30
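The compile-time scoping described in the last comment can be checked directly. A minimal sketch reusing the Foo class from the first example: adding a later x = Foo() makes x local for the whole body of f, so the earlier x.data += 1 fails.

```python
class Foo:
    def __init__(self):
        self.data = 1

x = Foo()

def f():
    x.data += 1   # fails: x is local to f because of the assignment on the next line
    x = Foo()     # this later binding makes x local for the entire function body

try:
    f()
    raised = False
except UnboundLocalError as e:
    raised = True
    print("UnboundLocalError:", e)

print(x.data)  # 1: the global object was never touched
```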

You actually can, if you pass the tensor into the function as an argument. I suppose the reason is how PyTorch handles augmented assignment: for a tensor, x += 1 modifies the existing tensor in place, so the parameter x still refers to the caller's object. For a plain int, x += 1 creates a new object and rebinds only the local name, so the caller sees no change. But in my opinion this approach goes against usual Python conventions and shouldn't be used.

>>> import torch
>>> def fn(x):
...     x += 1
...

>>> a = 0
>>> fn(a)
>>> a
0
>>> a = torch.tensor([0.])
>>> a
tensor([0.])
>>> fn(a)
>>> a
tensor([1.])
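Whether the caller sees the change depends on whether the type implements += in place. A minimal sketch of the same behavior without PyTorch, using a Python list as a stand-in for the tensor (an assumption for illustration: lists, like tensors, update in place under +=; the function names fn_rebind and bump are hypothetical):

```python
def fn(x):
    x += [1]      # list.__iadd__ extends the caller's list in place and returns it

def fn_rebind(x):
    x = x + [1]   # builds a new list; only the local name x is rebound

def bump(x):
    x += 1        # int has no __iadd__, so this creates a new int bound to the local x

a = []
fn(a)
print(a)  # [1]

b = []
fn_rebind(b)
print(b)  # []

n = 0
bump(n)
print(n)  # 0
```

With a tensor, writing x = x + 1 inside fn would likewise leave the caller's tensor unchanged, since + returns a new tensor.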