
Issue with inference for the ESWT-12-12_LSR_x4 model converted to TFLite

koranten2 opened this issue · 2 comments

We tried to convert the ESWT-12-12_LSR_x4.pth model from PyTorch to TFLite. We had to enable Flex TF ops (SELECT_TF_OPS), since not all layers could be converted with the builtin kernels alone, but with those ops the model converted successfully without errors. On inference, however, we hit:

    RuntimeError: tensorflow/lite/kernels/reshape.cc:92 num_input_elements != num_output_elements (0 != 8)
    Node number 0 (RESHAPE) failed to prepare.
    Node number 360 (IF) failed to prepare.

Please tell us whether you have tried converting to TFLite, and could you check this issue on your side? Note that the intermediate TF (SavedModel) model works well.
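A small check that could help narrow this down (just a debugging sketch, not part of our scripts): a -1 in shape_signature would mean the converted model still carries a dynamic dimension, which would explain a 0-element RESHAPE.

    import tensorflow as tf

    # Debugging sketch: inspect the converted model's input signature.
    # get_input_details() works before allocate_tensors(), so it runs even though
    # allocation itself fails with the RESHAPE/IF prepare error.
    interpreter = tf.lite.Interpreter(model_path="sr_12-12.tflite")
    for detail in interpreter.get_input_details():
        print(detail["name"], detail["shape"], detail["shape_signature"], detail["dtype"])

    # If a dimension shows up as -1, forcing the fixed NCHW shape used at export time
    # before allocation is one thing to try:
    interpreter.resize_tensor_input(interpreter.get_input_details()[0]["index"], [1, 3, 256, 256])
    interpreter.allocate_tensors()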

The conversion followed the scheme pth -> onnx -> tf -> tflite. Conversion script:

    import numpy as np

    import torch
    from basicsr.models import build_model
    from .utils import get_config
    import onnx
    import torchvision
    import onnx_tf
    import tensorflow as tf
    from onnx import helper


    class Converter:  # wrapper class around the original __init__/saveModel methods
        def __init__(self, model_config_path, task_config_path, checkpoint_path):
            self.opt = get_config(model_config_path, task_config_path, checkpoint_path)
            self.device = torch.device('cpu')
            self.model = build_model(self.opt).net_g.to(self.device).to(torch.float32).eval()
            self.saveModel(self.model)

        def saveModel(self, model):
            modelName = "sr"
            input_shape = (1, 3, 256, 256)
            # Export the PyTorch model to ONNX with a fixed input shape
            torch.onnx.export(model, torch.randn(input_shape), modelName + '-new.onnx',
                              opset_version=12, input_names=['input'], output_names=['output'])
            onnx_model = onnx.load(modelName + '-new.onnx')
            # Convert ONNX model to TensorFlow format
            tf_model = onnx_tf.backend.prepare(onnx_model)
            # Export TensorFlow model
            tf_model.export_graph(modelName + '.tf')
            # Convert the SavedModel to TFLite, allowing Flex (SELECT_TF_OPS) kernels
            converter = tf.lite.TFLiteConverter.from_saved_model(modelName + '.tf')
            converter.target_spec.supported_ops = [
                tf.lite.OpsSet.TFLITE_BUILTINS,
                tf.lite.OpsSet.SELECT_TF_OPS,
            ]
            tflite_model = converter.convert()
            with open(modelName + '.tflite', 'wb') as f:
                f.write(tflite_model)
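One thing we have not tried yet (just an idea, so not verified) is simplifying the ONNX graph before the onnx-tf step, so that shape-computation subgraphs get constant-folded; those subgraphs are a common source of IF/RESHAPE nodes in the resulting TFLite model. A minimal sketch, assuming the onnx-simplifier package:

    import onnx
    from onnxsim import simplify

    # Constant-fold shape computations in the exported ONNX model before
    # handing it to onnx_tf.backend.prepare() in saveModel() above.
    onnx_model = onnx.load('sr-new.onnx')
    model_simp, check = simplify(onnx_model)
    assert check, "simplified ONNX model failed validation"
    onnx.save(model_simp, 'sr-simplified.onnx')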

Some artifacts: sr.tf.zip

sr_12-12.tflight.zip

Our inference scripts

tf:

    import tensorflow as tf
    import numpy as np
    from PIL import Image
    import PIL

    import torch
    import torchvision
    import torchvision.transforms as T

    def swapChannelsInput(input_tensor):
        input_tensor = input_tensor[tf.newaxis, ...]

        out = input_tensor.numpy()
        torchTensor = torch.from_numpy(out)
        torchTensor = torchTensor.permute(0, 3, 1, 2)
        np_arr = torchTensor.detach().cpu().numpy()
        tensorflow_tensor = tf.constant(np_arr)
        return tensorflow_tensor

    def showOutput(res):
        res = tf.squeeze(res)

        res = res.numpy()
        torchTensorRes = torch.from_numpy(res)
        torchTensorRes = torchTensorRes.permute(1, 2, 0)
        resFinal = torchTensorRes.detach().cpu().numpy()
        return PIL.Image.fromarray(resFinal.astype(np.uint8))

    extraction_path = "sr.tf/"
    test_image_path = "frame0.jpg"
    model = tf.saved_model.load(extraction_path)

    infer = model.signatures["serving_default"]

    image_np = np.array(Image.open(test_image_path))
    input_tensor = tf.convert_to_tensor(image_np, tf.float32)
    input_tensor = swapChannelsInput(input_tensor)

    res = infer(tf.constant(input_tensor))['output']
    showOutput(res).show()
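Side note: the round trip through torch in this script is only there to reorder channels; the same NHWC -> NCHW swap can be done directly in TensorFlow if you want to drop the torch dependency (a sketch of equivalent helpers):

    def swapChannelsInput(input_tensor):
        # HWC image -> NCHW batch, done entirely in TensorFlow
        input_tensor = input_tensor[tf.newaxis, ...]
        return tf.transpose(input_tensor, [0, 3, 1, 2])

    def showOutput(res):
        # NCHW output -> HWC uint8 image
        res = tf.squeeze(res)
        res = tf.transpose(res, [1, 2, 0])
        return PIL.Image.fromarray(res.numpy().astype(np.uint8))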

tflite:

        import tensorflow as tf
        import numpy as np
        import cv2
        
        from PIL import Image
        import PIL
        
        import torch
        import torchvision
        
        class TFLiteModel:
            def __init__(self, model_path: str):
                self.interpreter = tf.lite.Interpreter(model_path)
                self.interpreter.allocate_tensors()
        
                self.input_details = self.interpreter.get_input_details()
                self.output_details = self.interpreter.get_output_details()
        
            def predict(self, *data_args):
                assert len(data_args) == len(self.input_details)
                for data, details in zip(data_args, self.input_details):
                    self.interpreter.set_tensor(details["index"], data)
                self.interpreter.invoke()
                return self.interpreter.get_tensor(self.output_details[0]["index"])
        
        
        model = TFLiteModel("sr_12-12.tflite")
        
        test_image_path = "frame0.jpg"
        image_np = np.array(Image.open(test_image_path))
        input_tensor = tf.convert_to_tensor(image_np, tf.float32)
        
        # Add the batch dimension and reorder NHWC -> NCHW (via torch) to match the ONNX export layout
        input_tensor = input_tensor[tf.newaxis, ...]

        out = input_tensor.numpy()
        torchTensor = torch.from_numpy(out)
        torchTensor = torchTensor.permute(0, 3, 1, 2)
        np_arr = torchTensor.detach().cpu().numpy()
        tensorflow_tensor = tf.constant(np_arr)
        res = model.predict(tensorflow_tensor)[0]
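For completeness, turning the predicted NCHW array back into an image mirrors the showOutput helper from the TF script above (a sketch, assuming predict() succeeds):

        # res has shape (C, H, W); reorder to (H, W, C) and clamp to the valid pixel range
        res_img = np.transpose(res, (1, 2, 0))
        res_img = np.clip(res_img, 0, 255).astype(np.uint8)
        Image.fromarray(res_img).show()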

koranten2 · Nov 15 '24

Thanks a lot for checking out our work! I really appreciate it. But I’ve got to be honest — I only work with PyTorch and don’t have experience with TensorFlow or TFLite... Sorry I can’t be more helpful here :(

liozur · Nov 15 '24

I see, thank you for your work and the feedback! If I find a solution, I will write it here; it may be useful for others who try to run this network on mobile platforms.

koranten2 · Nov 15 '24