dynamic
dynamic copied to clipboard
Trace fails sanity checks during export to lite for android : torch.jit.trace()
Code:
import cv2
import math
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from torchsummary import summary
from torch.utils.mobile_optimizer import optimize_for_mobile
device=None
seed=0
lr=1e-4
weight_decay=1e-4
lr_step_period=15
model_name="r2plus1d_18"
weights = "file_path.pt"
np.random.seed(seed)
torch.manual_seed(seed)
if device is None:
print("in if cond")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torchvision.models.video.__dict__[model_name](pretrained=True)
model.fc = torch.nn.Linear(model.fc.in_features, 1)
model.fc.bias.data[0] = 55.6
if device.type == "cuda":
model = torch.nn.DataParallel(model)
model.to(device)
if weights is not None:
checkpoint = torch.load(weights)
model.load_state_dict(checkpoint['state_dict'])
model.eval()
summary(model, (3, 1, 112, 112))
example = torch.rand((1, 3, 1, 112, 112))
traced_script_module = torch.jit.trace(model, example)
traced_script_module_optimized = optimize_for_mobile(traced_script_module)
traced_script_module_optimized._save_for_lite_interpreter("path_to_save.ptl")
The sanity check problem can be avoided by specifying check_trace= False But in that case the save_for_lite_interpreter fails:
RuntimeError: Could not export Python function call 'Scatter'. Remove calls to Python functions before export.
Using pytorch documentation to export the EF prediction model for Android
Is this as issue in the input shape? What should be the exact input shape for the first layer?