coremltools
Einsum conversion doesn't handle spaces in the equation string.
The following doesn't convert as expected because the einsum equation parser treats spaces as valid subscript characters. PyTorch itself allows whitespace between the elements of an einsum equation, so this is a problem when converting models whose code we don't control. This sample also hits https://github.com/apple/coremltools/issues/1359, but I sidestepped that by modifying build_einsum_mil on my local install to support that case. That change doesn't affect this problem in any way.
import numpy as np
import torch
import coremltools as ct


class EinsumModule(torch.nn.Module):
    def forward(self, x, y):
        # Outer product written with spaces around "," and inside the output subscripts.
        out = torch.einsum("i , j-> i j", x, y)
        return out


model = EinsumModule()
i, j = 2, 3
x = torch.zeros((i,))
y = torch.zeros((j,))
traced_model = torch.jit.trace(model, (x, y))

input_types = [
    ct.TensorType(name="x", dtype=np.float32, shape=x.shape),
    ct.TensorType(name="y", dtype=np.float32, shape=y.shape),
]
# Conversion does not work as expected: the spaces are read as einsum subscripts.
mlmodel = ct.convert(traced_model, source="pytorch", inputs=input_types)
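To make the failure mode concrete, here is a toy illustration of what happens when an equation is split on "->" and "," without removing whitespace first. naive_split is a hypothetical, simplified parser written for this example, not the actual coremltools code; it only shows how each 1-D operand ends up with two "labels" (one of them a space), which can no longer match the tensor's rank.

```python
def naive_split(equation):
    """Hypothetical parser that splits an einsum equation without stripping whitespace."""
    lhs, output = equation.split("->")
    inputs = lhs.split(",")
    return inputs, output


inputs, output = naive_split("i , j-> i j")
# Each operand is 1-D, but the parsed terms carry two characters each.
for term in inputs:
    print(repr(term), "-> parsed label count:", len(term))
print(repr(output), "-> parsed label count:", len(output))
```

With the equation from the repro, the inputs come back as 'i ' and ' j', and the output as ' i j', so a parser that counts characters sees rank-2 operands and a rank-4 result for what is really an outer product of two vectors.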
This should be as simple as changing parse_einsum_equation to strip the spaces from the equation before parsing it:
    equation = equation.replace(" ", "")
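A minimal sketch of that normalization step, assuming the parser only needs the equation split into input terms and an optional output term. normalize_einsum_equation and split_einsum_equation are hypothetical names for illustration; coremltools' real parse_einsum_equation may return something different. One design choice worth considering: "".join(equation.split()) drops every whitespace character (tabs and newlines as well as spaces), whereas replace(" ", "") only removes plain spaces.

```python
def normalize_einsum_equation(equation):
    # Hypothetical helper: remove all whitespace (spaces, tabs, newlines).
    return "".join(equation.split())


def split_einsum_equation(equation):
    # Simplified stand-in for a parser, not coremltools' implementation.
    equation = normalize_einsum_equation(equation)
    if "->" in equation:
        lhs, output = equation.split("->")
    else:
        # Implicit-output form; resolving the output labels is left to the caller.
        lhs, output = equation, None
    inputs = lhs.split(",")
    return inputs, output


print(split_einsum_equation("i , j-> i j"))  # (['i', 'j'], 'ij')
```

After normalization, the equation from the repro is parsed exactly like "i,j->ij", so each operand again has one label per dimension.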