
Tutorial to export a SPLADE model to ONNX

Open ntnq4 opened this issue 2 years ago • 6 comments

Hello,

I trained a SPLADE model on my own recently. To reduce the inference time, I tried to export my model to ONNX with torch.onnx.export() but I encountered a few errors.

Is there a tutorial somewhere for this conversion?

ntnq4 avatar Dec 01 '23 13:12 ntnq4

Hi @ntnq4

Not that I am aware of. I am not super familiar with ONNX - did you manage to make it work?

thibault-formal avatar Dec 12 '23 14:12 thibault-formal

Hi @thibault-formal

I didn't manage to make it work unfortunately... I tried this tutorial but it didn't work for my SPLADE model.

I also found this recent paper that mentioned this conversion.

ntnq4 avatar Dec 12 '23 15:12 ntnq4

Hi @ntnq4 , I have managed to convert the SPLADE models to ONNX, although I used the pretrained checkpoint rather than a custom-trained one. I know that is not exactly your situation, but if it helps, I am glad. To reproduce:

  • Convert the model to a TorchScript trace.

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer # type: ignore

class TransformerRep(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # torchscript=True makes the model return plain tuples, which is trace-friendly
        self.model = AutoModelForMaskedLM.from_pretrained("naver/splade-cocondenser-ensembledistil", torchscript=True) # type: ignore
        self.model.eval() # type: ignore
        self.fp16 = True

    def encode(self, input_ids, token_type_ids, attention_mask):
        # Pass the tensors by keyword: the HF forward signature is
        # (input_ids, attention_mask, token_type_ids, ...), so passing them
        # positionally in this order would silently swap the two masks.
        return self.model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)[0]



class SpladeModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = TransformerRep()
        self.agg = "max"
        self.model.eval()
    
    def forward(self, input_ids, token_type_ids, attention_mask):
        with torch.cuda.amp.autocast(): # type: ignore
            with torch.no_grad():
                lm_logits = self.model.encode(input_ids, token_type_ids, attention_mask)[0]
                # SPLADE aggregation: log-saturated ReLU of the MLM logits,
                # masked and max-pooled over the sequence dimension
                vec, _ = torch.max(torch.log(1 + torch.relu(lm_logits)) * attention_mask.unsqueeze(-1), dim=1)
                indices = vec.nonzero().squeeze()
                weights = vec.squeeze()[indices]
        return indices[:, 1], weights[:, 1]

# Convert the model to TorchScript
model = SpladeModel()
tokenizer = AutoTokenizer.from_pretrained("naver/splade-cocondenser-ensembledistil")
sample = "the capital of france is paris"
inputs = tokenizer(sample, return_tensors="pt")
traced_model = torch.jit.trace(model, (inputs["input_ids"], inputs["token_type_ids"], inputs["attention_mask"]))
traced_model.save("splade_traced.pt") # path is arbitrary; reused in the next step
  • Later, load it from file and convert it using a dummy input. Make sure to adjust the script above to match your implementation.
import torch
from transformers import AutoTokenizer # type: ignore

dyn_axis = {
    'input_ids': {0: 'batch_size', 1: 'sequence'},
    'attention_mask': {0: 'batch_size', 1: 'sequence'},
    'token_type_ids': {0: 'batch_size', 1: 'sequence'},
    'indices': {0: 'batch_size', 1: 'sequence'},
    'weights': {0: 'batch_size', 1: 'sequence'}
}

model_file = "splade_traced.pt"   # the TorchScript file saved in the previous step
model_onnx_file = "splade.onnx"
model = torch.jit.load(model_file)

# Dummy input to drive the export (any tokenized sentence works)
tokenizer = AutoTokenizer.from_pretrained("naver/splade-cocondenser-ensembledistil")
dummy = tokenizer("the capital of france is paris", return_tensors="pt")
dummy_input = (dummy["input_ids"], dummy["token_type_ids"], dummy["attention_mask"])

# Note: torch.onnx.export returns None; it writes the model to f
torch.onnx.export(
    model,
    dummy_input, # type: ignore
    f=model_onnx_file,
    input_names=['input_ids', 'token_type_ids', 'attention_mask'],
    output_names=['indices', 'weights'],
    dynamic_axes=dyn_axis,
    do_constant_folding=True,
    opset_version=15,
    verbose=False,
)
  • Using this method, I have managed to convert the following HF models successfully:
model_names= [
   "naver/splade_v2_max",
   "naver/splade_v2_distil",
   "naver/splade-cocondenser-ensembledistil",
   "naver/efficient-splade-VI-BT-large-query",
   "naver/efficient-splade-VI-BT-large-doc",
]

requirements:

  • torch==2.2.0

Hope this helps! :)

risan-raja avatar Feb 07 '24 07:02 risan-raja

Hi @risan-raja,

Thank you for your help : ) I will try your solution on my side.

ntnq4 avatar Feb 07 '24 09:02 ntnq4

If an ONNX conversion were added to the HuggingFace repo in a folder called onnx, it would automatically become available to HuggingFace Transformers.js and be usable locally on the web.

sroussey avatar Feb 08 '24 17:02 sroussey

Example: https://huggingface.co/Xenova/t5-small-awesome-text-to-sql/tree/main/

sroussey avatar Feb 08 '24 17:02 sroussey