fix unet script problem
When running `python run_sweep.py -m pytorch_unet -t eval -d cuda --jit`, the following error is raised:
`AttributeError: 'RecursiveScriptModule' object has no attribute 'n_classes'`
This PR adds a `Final` annotation so that the attribute is kept when the model is scripted.
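A minimal sketch of the annotation (the single conv layer just stands in for the real UNet blocks; the `Final` annotation is the only relevant part):

```python
import torch
from torch.jit import Final  # typing.Final also works


class UNet(torch.nn.Module):
    # Final marks n_classes as a TorchScript constant, so it stays readable
    # on the RecursiveScriptModule produced by torch.jit.script.
    n_classes: Final[int]

    def __init__(self, n_channels: int, n_classes: int):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        # toy layer standing in for the real UNet blocks
        self.conv = torch.nn.Conv2d(n_channels, n_classes, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


scripted = torch.jit.script(UNet(n_channels=3, n_classes=2).eval())
print(scripted.n_classes)  # 2 -- no AttributeError after scripting
```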
Can you please also remove the not_implemented field in pytorch_unet's metadata.yaml? This will enable the jit unit test for this model. See https://github.com/pytorch/benchmark/blob/main/torchbenchmark/models/pytorch_unet/metadata.yaml#L7
OK. BTW, @xuzhao9 can we move these eval functions into the model so that trace/script can capture the whole model without changing the model semantics, just like https://github.com/pytorch/benchmark/commit/5b07f357eaed06ccda9e7283f838c11228755229? I think this is necessary for script and script-based backends. A rough sketch of what I mean is below.
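This is only an illustration of the idea (the wrapper and its names are made up here; the linked commit shows the actual change):

```python
import torch
import torch.nn.functional as F


class EvalUNet(torch.nn.Module):
    # Hypothetical wrapper: folds the eval-time post-processing into
    # forward() so torch.jit.script/trace captures all of it.
    def __init__(self, net: torch.nn.Module):
        super().__init__()
        self.net = net

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        logits = self.net(image)
        # post-processing that evaluate.py currently runs in eager mode
        return F.softmax(logits, dim=1).argmax(dim=1)
```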
@zzpmiracle The eval function code was directly copied from upstream: https://github.com/milesial/Pytorch-UNet/blob/master/evaluate.py, so the change in https://github.com/pytorch/benchmark/commit/5b07f357eaed06ccda9e7283f838c11228755229 would change the original code's behavior.
@davidberard98 do you know if this is the only way for torchscript to trace the whole model? Is there a better solution? If there is no better way to do this, I am also okay with accepting it.
@xuzhao9 @zzpmiracle so, a few comments:
- https://pytorch.org/docs/stable/jit.html#attributes is actually outdated; typically you will not need to annotate to preserve attributes. See below for an example.
- In this case, we do need to take extra steps to preserve the attribute because of optimize_for_inference (specifically, torch.jit.freeze), which inlines the constant. The preferred way to preserve attributes like this is to use `torch.jit.optimize_for_inference(torch.jit.freeze(model.eval(), preserved_attrs=["n_classes", ...], ...))` (see the sketch after the example below). So if possible it would be nice to add a way to set this in cases like this. However, marking with `Final[]` also seems to work (AFAIK this works because handling constants just isn't implemented for freezing).
Example for #1:
```python
import torch

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.x = 2

    def forward(self, y):
        return self.x

# demo of the outdated docs: no annotation is needed, and this doesn't error out.
m = torch.jit.script(Model())
# to repro the failure, do m = torch.jit.freeze(m.eval())
print(m.x)
```
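And for the second point, continuing the same toy Model (just a sketch; `x` plays the role of `n_classes` here):

```python
m = torch.jit.script(Model().eval())
# Freezing would normally inline self.x into the graph and drop the
# attribute; preserved_attrs keeps it on the frozen module.
frozen = torch.jit.freeze(m, preserved_attrs=["x"])
print(frozen.x)  # 2
# optimize_for_inference can then be applied on top of the frozen module.
opt = torch.jit.optimize_for_inference(frozen)
```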
@davidberard98 OK, I got this, thanks! How about the other question, about changing the code so that the whole evaluation model is traced?
I am trying to understand why this is necessary for script backends. Without tracing the whole computation, users can still get the result from the script-ed model plus the un-scripted evaluation code, right?
Yes, it can run, but some code runs in script mode while other code runs in eager mode, so this does not reflect the true performance of jit or other backends.
@zzpmiracle Actually I believe it should be the opposite - the upstream code doesn't run all of its evaluation code in script mode either (see https://github.com/milesial/Pytorch-UNet/blob/master/evaluate.py#L26). Therefore, the current impl in torchbench reflects the true performance of jit or other backends, unless upstream accepts a patch that moves all the code into the model like https://github.com/pytorch/benchmark/commit/5b07f357eaed06ccda9e7283f838c11228755229. What do you think?
OK, I have implemented a jit_callback function to enable jit now.
In my test, `torch.jit.optimize_for_inference(torch.jit.freeze(torch.jit.script(self.model.eval()), preserved_attrs=["n_classes"]))` has worse performance (39.47 ms) than `torch.jit.freeze(torch.jit.script(self.model.eval()), preserved_attrs=["n_classes"])`, which only takes 23.88 ms on an A10.
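Roughly what the callback looks like (shown standalone as a sketch; it is a method of the pytorch_unet Model class and assumes the --jit path invokes this hook):

```python
import torch


def jit_callback(self):
    # Variant that measured ~23.88 ms on A10:
    self.model = torch.jit.freeze(
        torch.jit.script(self.model.eval()), preserved_attrs=["n_classes"]
    )
    # Adding optimize_for_inference on top measured ~39.47 ms:
    # self.model = torch.jit.optimize_for_inference(self.model)
```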
> unless upstream accepts a patch that moves all the code into the model like https://github.com/pytorch/benchmark/commit/5b07f357eaed06ccda9e7283f838c11228755229

OK, let's keep it as is for now.
@xuzhao9 @davidberard98 thanks for your patient explanations, I learned a lot.
@xuzhao9 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.