I ran across this function that mostly does what I want, but I need to tweak it a bit.
The data I have looks like this:
import torch
import torch.nn as nn
actual_x = torch.randn(13, 16, 64, 768)
But to work in the function below, I need to permute it to:
x = torch.randn(16, 64, 768, 13)
Inside the function I cannot manipulate the value of *args, because it arrives as a tuple. So if I add this line to reshape my data correctly inside the function:

args[0] = args[0].permute(1, 2, 3, 0)

I get: TypeError: 'tuple' object does not support item assignment.
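For context, the error happens because `*args` is collected into a tuple, and tuples are immutable. A minimal sketch of the usual workaround (convert the tuple to a list first, then assign) looks like this — `forward_like` is just an illustrative name, not part of the original code:

```python
import torch

# tuples (like *args) are immutable, so convert to a list first
def forward_like(*args):
    args = list(args)                      # a list supports item assignment
    args[0] = args[0].permute(1, 2, 3, 0)  # move dim 0 to the end
    return args[0]

permuted = forward_like(torch.randn(13, 16, 64, 768))
print(permuted.shape)  # torch.Size([16, 64, 768, 13])
```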
class TimeDistributed(nn.Module):
    '''Flattens the batch and sequence dims, applies self.module, then restores them.'''
    def __init__(self):
        super(TimeDistributed, self).__init__()
        self.n_layers = 13
        self.n_tokens = 64
        self.module = torch.nn.Linear(self.n_layers, self.n_tokens)

    def forward(self, *args, **kwargs):
        # only supports tdim=1
        # args[0] = args[0].permute(1, 2, 3, 0)  # fails: 'tuple' object does not support item assignment
        # args = list(args[0])                   # my second attempt...
        # args = args.permute(1, 2, 3, 0)        # ...fails too: 'list' object has no attribute 'permute'
        inp_shape = args[0].shape
        bs, seq_len = inp_shape[0], inp_shape[1]
        # merge batch and sequence dims, apply the module, then split them back out
        out = self.module(*[x.reshape(bs * seq_len, *x.shape[2:]) for x in args], **kwargs)
        out_shape = out.shape
        return out.view(bs, seq_len, *out_shape[1:])
It runs with the pre-permuted input:
TD1 = TimeDistributed()
out = TD1(x)
out.shape
It fails with my actual data:
TD1 = TimeDistributed()
out = TD1(actual_x)
out.shape
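For reference, this is the variant I am considering: copy `args` into a local list inside `forward` and permute there, so `actual_x` can be passed in directly. This is my own sketch of a fix, not the original author's code:

```python
import torch
import torch.nn as nn

class TimeDistributedPermuted(nn.Module):
    '''Sketch: same as TimeDistributed, but permutes the first input inside forward.'''
    def __init__(self):
        super().__init__()
        self.n_layers = 13
        self.n_tokens = 64
        self.module = nn.Linear(self.n_layers, self.n_tokens)

    def forward(self, *args, **kwargs):
        args = list(args)                      # tuple -> list so items can be reassigned
        args[0] = args[0].permute(1, 2, 3, 0)  # (13, 16, 64, 768) -> (16, 64, 768, 13)
        bs, seq_len = args[0].shape[0], args[0].shape[1]
        out = self.module(*[x.reshape(bs * seq_len, *x.shape[2:]) for x in args], **kwargs)
        return out.view(bs, seq_len, *out.shape[1:])

actual_x = torch.randn(13, 16, 64, 768)
out = TimeDistributedPermuted()(actual_x)
print(out.shape)  # torch.Size([16, 64, 768, 64])
```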