Inference fails for the ESWT-12-12_LSR_x4 model after conversion to TFLite
We tried to convert the ESWT-12-12_LSR_x4.pth model from PyTorch to TFLite. At first, not all layers could be converted to TFLite builtin ops, so we enabled Flex (Select TF) ops. With them, the model converted successfully, without errors. At inference time, however, we get this error:

`RuntimeError: tensorflow/lite/kernels/reshape.cc:92 num_input_elements != num_output_elements (0 != 8)`
`Node number 0 (RESHAPE) failed to prepare.`
`Node number 360 (IF) failed to prepare.`
Have you tried converting the model to TFLite? Could you check this issue on your side? Note that the intermediate TensorFlow model works well.
Conversion pipeline: .pth -> ONNX -> TensorFlow SavedModel -> TFLite.

Conversion script:
import numpy as np
import torch
from basicsr.models import build_model
from .utils import get_config
import onnx
import torchvision
import onnx_tf
import tensorflow as tf
from onnx import helper


# NOTE(review): `__init__` and `saveModel` both take `self`, so they are
# presumably methods of a converter class whose `class ...:` line was not
# included in this paste — confirm against the original script.
def __init__(self, model_config_path, task_config_path, checkpoint_path):
    """Build the network from its configs and export it right away.

    Args:
        model_config_path: Model config path, forwarded to ``get_config``.
        task_config_path: Task config path, forwarded to ``get_config``.
        checkpoint_path: Checkpoint (``.pth``) path, forwarded to ``get_config``.
    """
    self.opt = get_config(model_config_path, task_config_path, checkpoint_path)
    # Export from CPU in float32.
    self.device = torch.device('cpu')
    # `.eval()` puts the network in inference mode before it is traced.
    self.model = build_model(self.opt).net_g.to(self.device).to(torch.float32).eval()
    # Side effect: constructing the object immediately runs the export (saveModel).
    self.saveModel(self.model)
def saveModel(self, model, model_name="sr", input_shape=(1, 3, 256, 256)):
    """Export ``model`` to ONNX, then to a TensorFlow SavedModel, then to TFLite.

    Files written (paths relative to the current working directory):
    ``<model_name>-new.onnx``, the ``<model_name>.tf`` SavedModel directory,
    and ``<model_name>.tflite``.

    Args:
        model: The PyTorch module to export.
        model_name: Base name shared by all exported artifacts.
        input_shape: Shape of the random dummy input used for tracing. No
            ``dynamic_axes`` are passed to ``torch.onnx.export``, so the
            exported graph is specialised to exactly this input shape.
    """
    onnx_path = model_name + '-new.onnx'
    saved_model_dir = model_name + '.tf'
    tflite_path = model_name + '.tflite'

    # PyTorch -> ONNX (traced with a random input of `input_shape`).
    torch.onnx.export(
        model,
        torch.randn(input_shape),
        onnx_path,
        opset_version=12,
        input_names=['input'],
        output_names=['output'],
    )
    onnx_model = onnx.load(onnx_path)

    # ONNX -> TensorFlow SavedModel.
    tf_model = onnx_tf.backend.prepare(onnx_model)
    tf_model.export_graph(saved_model_dir)

    # SavedModel -> TFLite. SELECT_TF_OPS lets the converter fall back to
    # TensorFlow (Flex) ops for anything without a TFLite builtin kernel.
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS,
    ]
    tflite_model = converter.convert()

    # Context manager so the file is flushed and closed deterministically
    # (the previous bare open(...).write(...) never closed the handle).
    with open(tflite_path, 'wb') as f:
        f.write(tflite_model)
Artifacts: sr.tf.zip (the zipped `sr.tf` SavedModel directory).
Our inference scripts:

TensorFlow (SavedModel):
import tensorflow as tf
import numpy as np
from PIL import Image
import PIL
import torch
import torchvision
import torchvision.transforms as T


def swapChannelsInput(input_tensor):
    """Turn one channels-last image tensor into a batched channels-first one.

    Args:
        input_tensor: A ``tf.Tensor`` of shape (H, W, C).

    Returns:
        A ``tf.Tensor`` of shape (1, C, H, W) with the same dtype.
    """
    # Add the batch axis, then move channels from the last axis to axis 1
    # (NHWC -> NCHW). This is the same reordering the earlier
    # TF -> NumPy -> torch.permute -> NumPy -> TF round trip produced,
    # done in a single TensorFlow op.
    batched = input_tensor[tf.newaxis, ...]
    return tf.transpose(batched, perm=[0, 3, 1, 2])
def showOutput(res):
    """Convert a channels-first model output into an 8-bit ``PIL.Image``.

    Args:
        res: Model output of shape (1, C, H, W); all size-1 axes are
            squeezed. Anything ``np.asarray`` accepts works (a ``tf.Tensor``
            or a NumPy array such as a TFLite interpreter output).

    Returns:
        A ``PIL.Image`` built from the (H, W, C) uint8 array.
    """
    chw = np.squeeze(np.asarray(res))
    # CHW -> HWC, the layout PIL expects.
    hwc = np.transpose(chw, (1, 2, 0))
    # The network output is not guaranteed to lie in [0, 255]; casting
    # out-of-range floats directly to uint8 is undefined / wraps around and
    # yields garbage pixels, so clip before the cast.
    return PIL.Image.fromarray(np.clip(hwc, 0, 255).astype(np.uint8))
extraction_path = "sr.tf/"
test_image_path = "frame0.jpg"

# Load the exported SavedModel and take its default serving signature.
model = tf.saved_model.load(extraction_path)
infer = model.signatures["serving_default"]

# convert("RGB") guarantees 3 channels, matching the 3-channel input the model
# was exported with (grayscale / RGBA / palette images would not); the context
# manager closes the image file.
with Image.open(test_image_path) as img:
    image_np = np.array(img.convert("RGB"))

input_tensor = tf.convert_to_tensor(image_np, tf.float32)
input_tensor = swapChannelsInput(input_tensor)
# input_tensor is already a tf.Tensor, so no extra tf.constant wrap is needed.
res = infer(input_tensor)['output']
showOutput(res).show()
TFLite:
import tensorflow as tf
import numpy as np
import cv2
from PIL import Image
import PIL
import torch
import torchvision
class TFLiteModel:
    """Thin wrapper around ``tf.lite.Interpreter`` returning the first output."""

    def __init__(self, model_path: str):
        # Load the flatbuffer and allocate all tensors once, up front.
        self.interpreter = tf.lite.Interpreter(model_path)
        self.interpreter.allocate_tensors()
        self.input_details = self.interpreter.get_input_details()
        self.output_details = self.interpreter.get_output_details()

    def predict(self, *data_args):
        """Run one inference and return the first output tensor.

        Args:
            *data_args: One value per model input, in the order of
                ``self.input_details``; each is passed unchanged to
                ``Interpreter.set_tensor``.

        Returns:
            The value of the first output tensor, as returned by
            ``Interpreter.get_tensor``.

        Raises:
            ValueError: If the number of inputs does not match the model.
        """
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if len(data_args) != len(self.input_details):
            raise ValueError(
                f"expected {len(self.input_details)} input(s), got {len(data_args)}"
            )
        for data, details in zip(data_args, self.input_details):
            self.interpreter.set_tensor(details["index"], data)
        self.interpreter.invoke()
        return self.interpreter.get_tensor(self.output_details[0]["index"])
model = TFLiteModel("sr_12-12.tflite")
test_image_path = "frame0.jpg"

# Force 3 channels (RGB) and close the image file once its pixels are read.
with Image.open(test_image_path) as img:
    image_np = np.array(img.convert("RGB"), dtype=np.float32)

# HWC -> NCHW with a leading batch axis, matching the channels-first layout
# the PyTorch model was exported with. A single NumPy transpose replaces the
# TF -> NumPy -> torch -> NumPy -> TF round trip; the array is made
# C-contiguous before it is copied into the interpreter.
input_array = np.ascontiguousarray(np.transpose(image_np, (2, 0, 1))[np.newaxis, ...])

# NOTE(review): the ONNX export traced a fixed (1, 3, 256, 256) input without
# dynamic_axes, so the image presumably has to be exactly 256x256 — confirm
# against model.input_details[0]["shape"].
res = model.predict(input_array)[0]
Thanks a lot for checking out our work! I really appreciate it. But I’ve got to be honest — I only work with PyTorch and don’t have experience with TensorFlow or TFLite... Sorry I can’t be more helpful here :(
I see, thank you for your work and feedback! If I find a solution, I will post it here. It may be useful for others who try to run this model on mobile platforms.