RNN cannot be displayed correctly

Open · AutuanLiu opened this issue 7 years ago · 4 comments

  • Here is a simple example of an LSTM neural network: (screenshot)

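We need `hl.transforms.Rename()` to rename the RNN node, which shows up as a generic `prim::PythonOp`:

```python
tsfm = [hl.transforms.Rename(op='prim::PythonOp', to='LSTM')]
hl.build_graph(model, torch.zeros([32, 20, 5]), transforms=tsfm)  # apply the rename
```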

(screenshot)
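The rename matches on the node's op type, so every node with that op gets the new label, not only the LSTM. A minimal toy sketch of the idea in plain Python, assuming a hypothetical `Node` record and a hypothetical `rename_by_op` helper (these are not hiddenlayer's classes or its implementation, and the toy graph is made up for illustration):

```python
import re
from dataclasses import dataclass


@dataclass
class Node:
    """Hypothetical stand-in for a graph node: id, op type, and display title."""
    id: str
    op: str
    title: str


def rename_by_op(nodes, op_pattern, to):
    """Return a new node list where every node whose op fully matches
    op_pattern (a regex) gets the display title `to`; ops are left as-is."""
    pattern = re.compile(op_pattern)
    return [
        Node(n.id, n.op, to) if pattern.fullmatch(n.op) else n
        for n in nodes
    ]


# Toy graph: two nodes share the generic op 'prim::PythonOp',
# but only n0 is actually the LSTM.
graph = [
    Node("n0", "prim::PythonOp", "prim::PythonOp"),  # the LSTM
    Node("n1", "onnx::Gather", "Gather"),
    Node("n2", "onnx::Gemm", "Linear"),
    Node("n3", "prim::PythonOp", "prim::PythonOp"),  # some other Python op
]

renamed = rename_by_op(graph, r"prim::PythonOp", "LSTM")
for n in renamed:
    print(n.id, n.op, n.title)
```

In this toy run, `n3` is labeled `LSTM` as well, even though it is a different operation.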

AutuanLiu · Nov 07 '18 07:11

Thanks for the report. Admittedly, I haven't tested with RNNs. Would you mind sharing the sample code used to generate this?

waleedka · Nov 08 '18 09:11

  1. RNN (LSTM) code:

```python
import torch
import hiddenlayer as hl
from torch import nn

class RNN_Net(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        # batch_first=True: input is (batch, seq_len, input_dim)
        self.rnn = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        hidden = self.initHidden(x.size(0))
        y, _ = self.rnn(x, hidden)
        # feed the output of the last time step to the linear layer
        return self.fc(y[:, -1, :])

    def initHidden(self, batchsize):
        # zero (h0, c0) with the same device and dtype as the model parameters
        weight = next(self.parameters())
        h0 = weight.new_zeros(self.num_layers, batchsize, self.hidden_dim)
        return (h0, h0)

model = RNN_Net(5, 15, 5)
# dummy input: 32 sequences, 20 time steps, 5 features each
hl.build_graph(model, torch.zeros([32, 20, 5]))
```
  2. The graph of the network with the RNN (LSTM) layer: (screenshot)

  3. Warning:

```
UserWarning: ONNX export failed on RNN/GRU/LSTM because batch_first not supported
  warnings.warn("ONNX export failed on " + op + " because " + msg + " not supported")
```
  4. The LSTM is rendered correctly when batch_first=False.
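Since the LSTM renders when batch_first=False, one possible workaround is to keep the batch-first interface of the model but run the LSTM in its default sequence-first layout, transposing the input inside `forward()`. A minimal sketch, assuming the rendering problem comes only from batch_first=True as the warning suggests; `RNNNetSeqFirst` is a hypothetical name, and this has not been checked against hiddenlayer itself:

```python
import torch
from torch import nn


class RNNNetSeqFirst(nn.Module):
    """Hypothetical variant of RNN_Net: same batch-first input,
    but the LSTM itself uses the default batch_first=False layout."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        # batch_first=False (default): the LSTM expects (seq_len, batch, input_dim)
        self.rnn = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # (batch, seq_len, feat) -> (seq_len, batch, feat)
        x = x.transpose(0, 1)
        # zero (h0, c0) matching the input's device and dtype
        h0 = x.new_zeros(self.num_layers, x.size(1), self.hidden_dim)
        y, _ = self.rnn(x, (h0, h0))
        # y is (seq_len, batch, hidden); y[-1] is the last time step
        return self.fc(y[-1])
```

With identical weights this computes the same outputs as the batch-first version, so `hl.build_graph(RNNNetSeqFirst(5, 15, 5), torch.zeros([32, 20, 5]))` could be called as before. The transpose will probably show up as an extra node in the graph.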

AutuanLiu · Nov 08 '18 12:11

@AutuanLiu I want to give an update on this. Thank you for raising the issue and providing details, and sorry it took so long.

I ran experiments with LSTMs and tried to find a simple pattern that makes it easy to render them. Unfortunately, I couldn't find one. The method you used works only in a subset of cases, and `prim::PythonOp` can also represent other operations.

So I think your solution is useful for your case, but it shouldn't be built into the library, as it might cause unintended problems for other network types. At this point, I'm afraid I don't have a good solution for LSTMs. I hope to find some free time in the near future to dive deeper into how PyTorch represents LSTMs and find a general solution that works in all (or most) cases.

waleedka · Dec 03 '18 04:12

Thank you!

AutuanLiu · Dec 04 '18 06:12