class LSTMModel(nn.Module):
    """LSTM encoder followed by a small ReLU MLP head with 6 softmax outputs.

    The LSTM is constructed with ``batch_first=True``, so batched input is
    expected as ``(batch, seq_len, input_dim)``; unbatched input of shape
    ``(seq_len, input_dim)`` is also accepted. The MLP head is applied to
    every time step of the LSTM output, not just the last one.

    Args:
        hidden_dim: Number of features in the LSTM hidden state.
        input_dim: Number of features per time step of the input.
        n_layers: Number of stacked LSTM layers.
    """

    def __init__(self, hidden_dim=80, input_dim=99, n_layers=1):
        super().__init__()
        self.input_dim = input_dim
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim
        self.rnn = nn.LSTM(input_dim, hidden_dim, n_layers, batch_first=True)
        # MLP head: hidden_dim -> 64 -> 20 -> 6
        self.fc = nn.Linear(hidden_dim, 64)
        self.relu1 = nn.ReLU()
        self.fc1 = nn.Linear(64, 20)
        self.relu2 = nn.ReLU()
        self.fc2 = nn.Linear(20, 6)

    def forward(self, input):
        """Run the LSTM and MLP head and return per-time-step probabilities.

        Args:
            input: Tensor of shape ``(batch, seq_len, input_dim)`` or, for a
                single unbatched sequence, ``(seq_len, input_dim)``.

        Returns:
            Tensor of shape ``(batch, seq_len, 6)`` (or ``(seq_len, 6)`` for
            unbatched input) whose last axis sums to 1.
        """
        # Passing no initial state makes nn.LSTM use zero (h0, c0) of the
        # right shape -- (n_layers, batch, hidden_dim) for batched input,
        # (n_layers, hidden_dim) for unbatched -- on the input's device and
        # dtype. Hand-built 2-D CPU zeros only matched unbatched CPU input.
        out, _ = self.rnn(input)
        pred = self.fc(out)
        pred = self.relu1(pred)
        pred = self.fc1(pred)
        pred = self.relu2(pred)
        pred = self.fc2(pred)
        # Normalise over the class axis, which is always the last one;
        # dim=1 would be the time axis for batched (batch_first) input.
        # NOTE(review): if this model is trained with nn.CrossEntropyLoss,
        # return the raw logits instead -- that loss applies log-softmax itself.
        return torch.softmax(pred, dim=-1)
How can I print a summary of this LSTM model?
I haven't been able to make any progress on this.