
In the torch.nn.Linear class (and other classes too), the forward method includes a @weak_script_method decorator as follows:

@weak_script_method
def forward(self, input):
    return F.linear(input, self.weight, self.bias)

What does this decorator do? Should I include it if I'm overriding the forward method in my own subclass of the Linear module?

dkv
`@weak_script_method` is for the [internal](https://github.com/pytorch/pytorch/issues/13221#issuecomment-434007127) use case in PyTorch. It's [used](https://github.com/pytorch/pytorch/blob/master/torch/_jit_internal.py#L84) to let the function be lazily compiled and inlined in the graph. I hope someone can explain it better. – kHarshit Feb 16 '19 at 05:45

1 Answer


You can look at the decorator's definition to get the idea:

def weak_script_method(fn):
    weak_script_methods[fn] = {
        "rcb": createResolutionCallback(frames_up=2),
        "original_method": fn
    }
    return fn

However, you shouldn't need to worry about this decorator; it is internal to the JIT.

Technically, a method decorated with @weak_script_method is added to the weak_script_methods dictionary created beforehand, like this:

weak_script_methods = weakref.WeakKeyDictionary() 

That dictionary tracks the methods so they can be compiled lazily, which avoids circular-dependency problems when methods call other methods while the PyTorch graph is being built.
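To see why a WeakKeyDictionary is used here rather than a plain dict, here is a minimal standalone sketch (not PyTorch's actual code): the dictionary holds per-function metadata without keeping the function alive, so the entry vanishes once the last strong reference to the key is gone.

```python
import gc
import weakref

# The registry holds only a weak reference to each key,
# so it never prevents a registered function from being collected.
registry = weakref.WeakKeyDictionary()

def fn(x):
    return x

registry[fn] = "some metadata"  # key is held only weakly
print(len(registry))            # 1 while fn is alive

del fn                          # drop the last strong reference
gc.collect()                    # immediate in CPython via refcounting anyway
print(len(registry))            # 0 - the entry disappeared with its key
```

A plain dict would keep every registered method alive for the lifetime of the module, which is exactly what the weak-reference variant avoids.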


This won't make much sense unless you understand the concept of TorchScript in general.

The idea of TorchScript is to train models in PyTorch and export them to a non-Python production environment (read: C++/C/CUDA) that supports static typing.

The PyTorch team built TorchScript on a limited subset of Python to support static typing. By default, Python is a dynamically typed language, but with a few tricks (read: checks) it can be made to behave like a statically typed one.

TorchScript is thus a statically typed subset of Python that includes all of PyTorch's built-in tensor operations. This restriction allows TorchScript module code to run without a Python interpreter.

You can either convert existing PyTorch code to TorchScript using tracing (the torch.jit.trace() function), or write TorchScript by hand using the @torch.jit.script decorator.
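Before the tracing example below, here is a minimal sketch of the hand-written route (the function name is just an illustration): @torch.jit.script compiles the function body into TorchScript, and the type annotations are enforced statically.

```python
import torch

# @torch.jit.script compiles the annotated function into TorchScript.
@torch.jit.script
def scripted_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + y

out = scripted_add(torch.ones(3), torch.ones(3))
print(out)                  # tensor([2., 2., 2.])
print(scripted_add.graph)   # the compiled TorchScript graph
```

Unlike tracing, scripting preserves control flow (if/for statements), because it compiles the source rather than recording one execution.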

If you use tracing, you will get a single class module at the end. Here is an example:

import inspect
import torch

def foo(x, y):
    return x + y

traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))

print(type(traced_foo)) #<class 'torch.jit.TopLevelTracedModule'>
print(traced_foo) #foo()
print(traced_foo.forward) #<bound method TopLevelTracedModule.forward of foo()>

lines = inspect.getsource(traced_foo.forward)
print(lines)

Output:

<class 'torch.jit.TopLevelTracedModule'>
foo()
<bound method TopLevelTracedModule.forward of foo()>
    def forward(self, *args, **kwargs):
        return self._get_method('forward')(*args, **kwargs)

You can investigate further using the inspect module. This was just a demonstration of how to convert a single function using tracing.

prosti