pytorch.github.io
Code example at https://pytorch.org/features/ is incorrect
📚 Documentation
The first code shown in https://pytorch.org/features is
import torch

class MyModule(torch.nn.Module):

    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
            output = self.weight.mv(input)
        else:
            output = self.weight + input
        return output

    # Compile the model code to a static representation
    my_script_module = torch.jit.script(MyModule(3, 4))

    # Save the compiled code and model data so it can be loaded elsewhere
    my_script_module.save("my_script_module.pt")
This is incorrect Python: the last four lines should be outdented to module level rather than placed in the body of MyModule.
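To illustrate why the placement matters, here is a minimal torch-free sketch. It assumes the trailing lines sit directly in the class body, as described above. The plain `MyModule` stand-in, the `BROKEN_LAYOUT` and `FIXED_LAYOUT` strings, and the `fails_with_name_error` helper are hypothetical names used only for this illustration; they are not part of the PyTorch site or API. A class name is bound only after its class statement finishes, so a statement in the class body that refers to the class should raise a `NameError`.

```python
# Torch-free stand-ins for the two layouts (hypothetical, for illustration only).
BROKEN_LAYOUT = """
class MyModule:
    def __init__(self, n, m):
        self.shape = (n, m)

    # Still inside the class body: the name MyModule is not bound yet.
    my_module = MyModule(3, 4)
"""

FIXED_LAYOUT = """
class MyModule:
    def __init__(self, n, m):
        self.shape = (n, m)

# Module level: the class statement has finished, so MyModule exists.
my_module = MyModule(3, 4)
"""


def fails_with_name_error(source):
    """Run the source in a fresh namespace and report whether it raises NameError."""
    try:
        exec(source, {})
    except NameError:
        return True
    return False


print(fails_with_name_error(BROKEN_LAYOUT))  # True
print(fails_with_name_error(FIXED_LAYOUT))   # False
```

The same reasoning would apply to the `torch.jit.script(MyModule(3, 4))` call in the snippet from the features page: moving it and the `save` call to module level lets them run after the class has been defined.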
I can take this issue up. Please assign me this :))