Support dict input?

Open zhangjing1997 opened this issue 6 years ago • 7 comments

I got a runtime error when I tried to plot the graph of a model whose input x is a dict:

```python
import hiddenlayer as hl

# getBuff is my own helper; it returns the dict that is fed to the network
x = getBuff(0)
hl.build_graph(net, x)
```

So I'm wondering: can this plotting tool support dict input?

zhangjing1997 · Jun 17 '19

That is an issue with the backend this package uses, torch.jit.get_trace_graph, which does not support dict input. But more and more models take dicts these days, and hl is amazing for analysis, so maybe we should figure out a way to add this capability to the package. I can help with that.

Ridhwanluthra · Jun 17 '19
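Until the tracer backend accepts dicts, one common workaround is to hide the dict behind a thin adapter module whose forward takes plain tensors and rebuilds the dict before calling the real model, so the tracer only ever sees tensors. Below is a minimal sketch, assuming the model's forward takes a single dict of tensors; ToyDictNet, DictInputAdapter and the keys img/mask are hypothetical names chosen for illustration, not part of hiddenlayer or any real model. Passing the tensors as a tuple through build_graph's args parameter also sidesteps the **data call, since build_graph does not accept the input as keyword arguments.

```python
import torch
import torch.nn as nn


class ToyDictNet(nn.Module):
    """Hypothetical stand-in for a model whose forward() takes a dict."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)

    def forward(self, batch):
        # Uses two entries of the input dict; mask broadcasts over channels.
        return self.conv(batch["img"]) * batch["mask"]


class DictInputAdapter(nn.Module):
    """Hypothetical adapter: takes plain tensors, rebuilds the dict the model expects."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, img, mask):
        # The dict is built inside forward, so only tensors cross the tracing boundary.
        return self.model({"img": img, "mask": mask})


example = {"img": torch.randn(1, 3, 32, 32), "mask": torch.ones(1, 1, 32, 32)}
adapter = DictInputAdapter(ToyDictNet()).eval()
args = (example["img"], example["mask"])

# The JIT tracer now receives a tuple of tensors instead of a dict.
traced = torch.jit.trace(adapter, args)

# With hiddenlayer installed, the same tuple should be usable as `args`:
#   import hiddenlayer as hl
#   hl_graph = hl.build_graph(adapter, args)
```

The adapter's forward parameters have to mirror the keys your model reads; a model that also expects non-tensor entries (metadata lists, for example) would need those captured as constants inside the adapter instead.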

I have the same problem with dict input:

```
File "/home/snow_ripple/workspace/01_detection/mmdet/apis/inference.py", line 96, in _inference_single
    hl_graph = hl.build_graph(model, **data)
TypeError: build_graph() got an unexpected keyword argument 'img'
```

Is there a workaround solution?

SnowRipple · Jun 20 '19

Yeah, I also think dict support would be very helpful, especially in this hl module, because it really helps to understand and present large networks visually. Looking forward to your contribution. Thanks!

zhangjing1997 · Jun 21 '19

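Maybe you can try something like the following, using make_dot from pytorchviz:

```python
from torchviz import make_dot

# getBuff(0) is my own helper that returns the dict input for `net`
test_buff = getBuff(0)

graph = make_dot(net(test_buff), params=dict(net.named_parameters()))
graph.format = "pdf"
graph.render("visPDF")  # writes the DOT source to visPDF and the drawing to visPDF.pdf
```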

This plotting interface only needs the network's output: make_dot builds the graph by walking the autograd history backward from the output tensor, so it does not matter that the input was a dict. I guess it may be helpful to you. BTW, credit to the AlexNet example at https://github.com/szagoruyko/pytorchviz/blob/master/examples.ipynb.

zhangjing1997 · Jun 21 '19
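To see why the input type does not matter to make_dot, here is a minimal sketch of the traversal idea behind it (not torchviz's actual code): start from the output's grad_fn and follow next_functions until the leaf parameters are reached. The dict passed to the model never appears in this walk. ToyDictNet and autograd_node_names are hypothetical names for illustration, and the exact node names (e.g. MulBackward0) vary between PyTorch versions.

```python
import torch
import torch.nn as nn


class ToyDictNet(nn.Module):
    """Hypothetical model whose forward() takes a dict."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, batch):
        return self.fc(batch["x"]) * batch["scale"]


def autograd_node_names(output):
    """Walk the autograd graph backward from `output`, collecting node type names."""
    seen, names = set(), []
    stack = [output.grad_fn]
    while stack:
        fn = stack.pop()
        # None marks an input that does not require grad (e.g. the dict's tensors).
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        names.append(type(fn).__name__)
        for next_fn, _ in fn.next_functions:
            stack.append(next_fn)
    return names


net = ToyDictNet()
batch = {"x": torch.randn(3, 4), "scale": torch.tensor(2.0)}
out = net(batch)
print(autograd_node_names(out))
```

The printed list should contain a multiplication node, the linear layer's backward node and one AccumulateGrad node per parameter, which is the same information make_dot draws.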

Thanks @zhangjing1997!

What is the point of getBuff(0)? Which module defines this function?

I assume it is just a placeholder, so I changed it to:

```python
graph = make_dot(model(torch.zeros([1, 3, 224, 224])), params=dict(model.named_parameters()))
```

But it still complains about the missing dictionary arguments.

SnowRipple · Jun 26 '19

In my project, getBuff(0) is just a function that returns the dict used as the network's input. I think your code should run. But your model's input actually seems to be a torch tensor, not a dict. If that's the case, I guess you don't need the approach I mentioned earlier, and maybe you could refer to:

```python
import torch
import torch.onnx

# model: your nn.Module that takes a single image tensor (no Variable wrapper needed since PyTorch 0.4)
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy_input, "model.onnx")
```

Otherwise, if your model does need a dict input, could you share the exact error?

zhangjing1997 · Jun 27 '19

Hey, any progress on this? Anyone making a PR?

manesioz · Nov 05 '19