
I ran across this function that mostly does what I want, but I need to tweak it a bit.

The data I have looks like this:

import torch
import torch.nn as nn
actual_x = torch.randn(13, 16, 64, 768)

But for it to work with the function below, I need to permute it to:

x = torch.randn(16, 64, 768, 13)

Inside the function, I cannot modify *args. If I add this line to permute my data inside the function:

args[0] = args[0].permute(1, 2, 3, 0)

I get TypeError: 'tuple' object does not support item assignment.

class TimeDistributed(nn.Module):
    '''
    '''
    def __init__(self):
        super(TimeDistributed, self).__init__()
        self.n_layers = 13
        self.n_tokens = 64
        self.module = torch.nn.Linear(self.n_layers, self.n_tokens)

    def forward(self, *args, **kwargs):
        #only support tdim=1
        #args[0] = args[0].permute(1, 2, 3, 0)
        args = list(args[0])
        args = args.permute(1, 2, 3, 0)
        inp_shape = args[0].shape
        bs, seq_len = inp_shape[0], inp_shape[1]
        out = self.module(*[x.reshape(bs*seq_len, *x.shape[2:]) for x in args], **kwargs)
        out_shape = out.shape
        return out.view(bs, seq_len,*out_shape[1:])

It runs with:

TD1 = TimeDistributed()
out = TD1(x)
out.shape

It fails with:

TD1 = TimeDistributed()
out = TD1(actual_x)
out.shape
John Stud
  • Nowhere in class `TimeDistributed` do I see any reference to `torch.randn` or `args[0] = args[0].permute(1, 2, 3, 0)`. Why is this class being shown then? – Booboo Jan 14 '21 at 20:11
  • I've edited it to show exactly where the problem is. – John Stud Jan 14 '21 at 20:13
  • `*args` is a tuple; you will need to convert it to a list to support item assignment – C.Nivs Jan 14 '21 at 20:21
  • There's nothing particularly special about `args`: it's a tuple, and tuples are immutable. – chepner Jan 14 '21 at 20:21
  • This makes even less sense. `TD1 = TimeDistributed()` is missing an argument and `out = TD1(x)` requires that class `TimeDistributed` define a `__call__` method. Or am I missing something? – Booboo Jan 14 '21 at 20:22
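The point made in these comments, that `args` is an ordinary immutable tuple, can be reproduced without PyTorch. This plain-Python sketch (the function names are hypothetical) shows the error from the question and one alternative to converting to a list: rebinding the local name `args` to a new tuple with the first item replaced.

```python
def assign_in_place(*args):
    """Try to assign into the *args tuple, as the question does."""
    try:
        args[0] = "changed"
    except TypeError as exc:
        return f"{type(args).__name__}: {exc}"
    return "assignment succeeded"


def rebuild_tuple(*args):
    """Rebind the local name args to a new tuple with a replaced first item."""
    args = ("changed",) + args[1:]
    return args


print(assign_in_place("a", "b"))  # tuple: 'tuple' object does not support item assignment
print(rebuild_tuple("a", "b"))    # ('changed', 'b')
```

Rebinding only affects the name inside the function; the caller's objects are not modified.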

2 Answers


I found this by googling:

a_list = ["a", "b", "c"]
order = [1, 0, 2]

a_list = [a_list[i] for i in order]

print(a_list)
Output:
['b', 'a', 'c']

More here: How can I reorder a list?
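The same index-based reordering describes what `permute(1, 2, 3, 0)` does to the tensor's shape in the question: output dimension i is input dimension order[i]. A small plain-Python check on the shape tuple (this sketch only reorders the shape, not the tensor data):

```python
# Shape of actual_x in the question, and the order passed to permute(1, 2, 3, 0).
shape = (13, 16, 64, 768)
order = (1, 2, 3, 0)

# permute places old dimension order[i] at position i, which is the same
# index-based reordering as the list comprehension above.
new_shape = tuple(shape[i] for i in order)
print(new_shape)  # (16, 64, 768, 13)
```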

Barmar
pippo1980

Just convert args from a tuple to a list with args = list(args). Then you can reassign or reorder its items as you please.
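Applied to the `TimeDistributed` class from the question, a minimal sketch (assuming every positional input is laid out as (13, batch, seq_len, 768) like `actual_x`) rebinds `args` to a new list of permuted tensors instead of assigning into the tuple. Note that the question's `args = list(args[0])` converts the tensor, not the tuple: iterating over a tensor splits it along its first dimension, and the resulting Python list has no `permute` method.

```python
import torch
import torch.nn as nn


class TimeDistributed(nn.Module):
    """Sketch: accepts inputs laid out as (n_layers, batch, seq_len, features)."""

    def __init__(self):
        super().__init__()
        self.n_layers = 13
        self.n_tokens = 64
        self.module = nn.Linear(self.n_layers, self.n_tokens)

    def forward(self, *args, **kwargs):
        # args is a tuple, so instead of assigning into it, rebind the name
        # to a new list of permuted tensors:
        # (13, bs, seq_len, 768) -> (bs, seq_len, 768, 13)
        args = [x.permute(1, 2, 3, 0) for x in args]
        bs, seq_len = args[0].shape[0], args[0].shape[1]
        # Merge batch and sequence dims; nn.Linear acts on the last dim (13).
        # reshape copies only when needed, so it works on permuted
        # (possibly non-contiguous) tensors where view can fail.
        flat = [x.reshape(bs * seq_len, *x.shape[2:]) for x in args]
        out = self.module(*flat, **kwargs)
        # Split the merged dimension back into (bs, seq_len).
        return out.reshape(bs, seq_len, *out.shape[1:])


TD1 = TimeDistributed()
out = TD1(torch.randn(13, 2, 3, 5))
print(out.shape)  # torch.Size([2, 3, 5, 64])
```

With the question's `actual_x` of shape (13, 16, 64, 768), the same call should return shape torch.Size([16, 64, 768, 64]). Rebinding `args` only changes the local name inside `forward`; the caller's tensor is not modified, and `permute` returns a view rather than a copy.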

Matt Miguel
  • Would `listarg = *args` give a list? – pippo1980 Jan 14 '21 at 21:16
  • No. The * used in that way is only valid when defining a function's inputs or when calling a function. When used as part of the argument list of a function, it signifies that any positional arguments (those provided without a keyword) will be collected into the input variable being denoted with the *. When calling a function, you can use it in front of a list or tuple to make it equivalent as if you passed each element of your list/tuple in as separate inputs to the function. You can't use the * in an assignment by itself, it's not a pointer dereference operator or anything like that. – Matt Miguel Jan 14 '21 at 21:23
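To add to this comment: since Python 3.0 a single starred target is allowed in an unpacking assignment (PEP 3132), and since Python 3.5 `*` is also accepted inside list and tuple displays (PEP 448). A bare `listarg = *args` is still a SyntaxError, but `[*args]` does give a list. A short sketch, using a plain tuple in place of a function's `args`:

```python
args = ("a", "b", "c")

# listarg = *args is a SyntaxError: a bare starred expression cannot be
# the whole right-hand side of an assignment.
try:
    compile("listarg = *args", "<example>", "exec")
except SyntaxError:
    bare_star_ok = False
else:
    bare_star_ok = True

as_list = [*args]     # unpacking inside a list display -> list
as_tuple = *args,     # trailing comma makes a tuple display -> tuple
first, *rest = args   # starred target collects the remainder into a list

print(bare_star_ok)  # False
print(as_list)       # ['a', 'b', 'c']
print(as_tuple)      # ('a', 'b', 'c')
print(first, rest)   # a ['b', 'c']
```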