0
class LSTMModel(nn.Module):
    """LSTM encoder followed by a three-layer MLP classification head.

    The recurrent layer is built with ``batch_first=True``, so batched
    inputs are laid out as ``(batch, seq_len, input_dim)``; unbatched
    inputs of shape ``(seq_len, input_dim)`` are accepted as well.

    Args:
        hidden_dim: Size of the LSTM hidden state.
        input_dim: Number of features per time step.
        n_layers: Number of stacked LSTM layers.
    """

    def __init__(self, hidden_dim=80, input_dim=99, n_layers=1):
        super().__init__()
        self.input_dim = input_dim
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim
        # Attribute names are kept as-is so previously saved state_dicts load.
        self.rnn = nn.LSTM(input_dim, hidden_dim, n_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, 64)
        self.relu1 = nn.ReLU()
        self.fc1 = nn.Linear(64, 20)
        self.relu2 = nn.ReLU()
        self.fc2 = nn.Linear(20, 6)

    def forward(self, input):
        """Return per-time-step probabilities over the 6 output classes.

        Args:
            input: Tensor of shape ``(seq_len, input_dim)`` or
                ``(batch, seq_len, input_dim)``.

        Returns:
            Tensor with the same leading dimensions as ``input`` and a final
            dimension of size 6 whose entries sum to 1.
        """
        # Omitting (h0, c0) makes nn.LSTM start from zero states that already
        # match the input's batch size, device and dtype. The hand-built 2-D
        # zero tensors only worked for unbatched CPU input.
        out, _ = self.rnn(input)
        pred = self.relu1(self.fc(out))
        pred = self.relu2(self.fc1(pred))
        pred = self.fc2(pred)
        # Normalise over the class axis (always the last one); dim=1 would be
        # the sequence axis for batched input.
        # NOTE(review): if this model is trained with nn.CrossEntropyLoss,
        # that loss expects raw logits, not softmax output — confirm.
        return torch.softmax(pred, dim=-1)

How to find the summary of this LSTM model?

I haven't been able to make any progress on this so far.

  • 1
    are you looking for something like [this](https://pypi.org/project/torch-summary/)? – Plagon May 22 '23 at 10:35
  • 1
    Does this answer your question? [How do I print the model summary in PyTorch?](https://stackoverflow.com/questions/42480111/how-do-i-print-the-model-summary-in-pytorch) – ndrwnaguib May 22 '23 at 15:41

0 Answers