[QST] Deploy a Transformers4rec model with pre-trained embeddings

Status: Open. mvidela31 opened this issue 1 year ago • 0 comments

❓ Questions & Help

Hi everyone,

I tried to deploy a Transformers4Rec model that uses pre-trained embeddings, following the Transformers4Rec with pre-trained embeddings example and the transformers-next-item-prediction-with-pretrained-embeddings.ipynb notebook (for TensorFlow Merlin Models). However, there seem to be problems tracing the PyTorch model when it uses pre-trained embeddings.

Details

Based on the examples above, I wrote the following script:

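import cudf
import numpy as np
import torch
import transformers4rec.torch as tr
from merlin.dataloader.ops.embeddings import EmbeddingOperator
from merlin.io import Dataset
from merlin.systems.dag.ensemble import Ensemble
from merlin.systems.dag.ops.pytorch import PredictPyTorch
from merlin.table import TensorTable, TorchColumn
from merlin.table.conversions import convert_col
from transformers4rec.config.trainer import T4RecTrainingArguments
from transformers4rec.torch import Trainer
from transformers4rec.torch.utils.data_utils import MerlinDataLoader

data = tr.data.music_streaming_testing_data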
schema = data.merlin_schema.select_by_name([
    "item_id",
    "item_category",
    "item_recency",
    "item_genres",
])

batch_size, max_length, pretrained_dim = 128, 20, 16

item_cardinality = schema["item_id"].int_domain.max + 1
np_emb_item_id = np.random.rand(item_cardinality, pretrained_dim)
embeddings_op = EmbeddingOperator(
    np_emb_item_id, lookup_key="item_id", embedding_name="pretrained_item_id_embeddings"
)

# set dataloader with pre-trained embeddings
data_loader = MerlinDataLoader.from_schema(
    schema,
    Dataset(data.path, schema=schema),
    max_sequence_length=max_length,
    batch_size=batch_size,
    transforms=[embeddings_op],
    shuffle=False,
)

# set the model schema from data-loader
model_schema = data_loader.output_schema
inputs = tr.TabularSequenceFeatures.from_schema(
    model_schema,
    max_sequence_length=max_length,
    pretrained_output_dims=8,
    normalizer="layer-norm",
    d_output=64,
    masking="mlm",
)
transformer_config = tr.XLNetConfig.build(64, 4, 2, 20)
task = tr.NextItemPredictionTask(weight_tying=True)
model = transformer_config.to_torch_model(inputs, task, max_sequence_length=max_length)

args = T4RecTrainingArguments(
    output_dir=".",
    max_steps=5,
    num_train_epochs=1,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size // 2,
    max_sequence_length=max_length,
    fp16=False,
    report_to=[],
    debug=["r"],
)

# Explicitly pass the merlin dataloader with pre-trained embeddings
recsys_trainer = Trainer(
    model=model,
    args=args,
    schema=schema,
    train_dataloader=data_loader,
    eval_dataloader=data_loader,
    compute_metrics=True,
)

recsys_trainer.train()
eval_metrics = recsys_trainer.evaluate(eval_dataset=data.path, metric_key_prefix="eval")

### Model export
topk = 20
model.top_k = topk
model.eval()

df = cudf.read_parquet(data.path, columns=model.input_schema.column_names)
table = TensorTable.from_df(df.loc[:10])
for column in table.columns:
    table[column] = convert_col(table[column], TorchColumn)
model_input_dict = table.to_dict()

traced_model = torch.jit.trace(model, model_input_dict, strict=True)
input_schema = model.input_schema
output_schema = model.output_schema

torch_op = schema.column_names >> embeddings_op >> PredictPyTorch(
    traced_model, input_schema, output_schema
)

ensemble = Ensemble(torch_op, schema)
ens_config, node_configs = ensemble.export(".")
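For context, here is my understanding of what the EmbeddingOperator transform contributes during training (an assumption on my part, not its actual implementation): for every item_id in a session it gathers the matching row of np_emb_item_id and exposes the result as an extra column named after embedding_name. A minimal numpy sketch of that lookup, with made-up sizes and padded 2-D id arrays; the helper lookup_pretrained is hypothetical:

```python
import numpy as np


def lookup_pretrained(item_ids: np.ndarray, table: np.ndarray) -> np.ndarray:
    """Hypothetical helper: gather one pre-trained vector per item id.

    item_ids: integer array of shape (batch, seq_len)
    table:    float array of shape (cardinality, dim)
    returns:  float array of shape (batch, seq_len, dim)
    """
    # Fancy indexing with a 2-D id array returns a 3-D array of table rows
    return table[item_ids]


rng = np.random.default_rng(0)
cardinality, dim = 100, 16                               # made-up sizes
table = rng.random((cardinality, dim))                   # stands in for np_emb_item_id
item_ids = rng.integers(0, cardinality, size=(11, 20))   # 11 sessions x 20 positions

batch = {"item_id": item_ids}
# The lookup result is attached as an extra input column
batch["pretrained_item_id_embeddings"] = lookup_pretrained(item_ids, table)
print(batch["pretrained_item_id_embeddings"].shape)  # (11, 20, 16)
```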

As shown below, tracing the PyTorch model raises a matrix shape mismatch error:

Traceback (most recent call last):
  File "/opt/ml/code/train.py", line 899, in test_trainer_with_pretrained_embeddings
    traced_model = torch.jit.trace(model, model_input_dict, strict=True)
  File "/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py", line 794, in trace
    return trace_module(
  File "/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py", line 1056, in trace_module
    module._c._create_method_from_trace(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers4rec/torch/model/base.py", line 581, in forward
    head(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers4rec/torch/model/base.py", line 382, in forward
    body_outputs = self.body(body_outputs, training=training, testing=testing, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers4rec/config/schema.py", line 50, in __call__
    return super().__call__(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers4rec/torch/block/base.py", line 256, in forward
    input = module(input, training=training, testing=testing)
  File "/usr/local/lib/python3.10/dist-packages/transformers4rec/config/schema.py", line 50, in __call__
    return super().__call__(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers4rec/torch/tabular/base.py", line 392, in __call__
    outputs = super().__call__(inputs, *args, **kwargs)  # noqa
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers4rec/torch/features/sequence.py", line 259, in forward
    outputs = self.projection_module(outputs)
  File "/usr/local/lib/python3.10/dist-packages/transformers4rec/config/schema.py", line 50, in __call__
    return super().__call__(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers4rec/torch/block/base.py", line 252, in forward
    input = module(input, **filtered_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers4rec/config/schema.py", line 50, in __call__
    return super().__call__(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers4rec/torch/block/base.py", line 260, in forward
    input = module(input)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (220x128 and 136x64)
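My reading of the shapes in the error (an interpretation, not confirmed): the 220 rows are 11 sessions (df.loc[:10] is label-based and inclusive) times max_length = 20 positions, and the linear layer expects 136 input features, which is exactly 128 + 8, where 8 is pretrained_output_dims. A small numpy sketch reproducing the same mismatch under that assumption:

```python
import numpy as np

# Sizes taken from the script and the error message.
# Assumption: the expected 136 features = 128 + pretrained_output_dims.
rows = 11 * 20                  # df.loc[:10] keeps 11 rows; max_length = 20
received_features = 128         # width of mat1 in the error
pretrained_output_dims = 8
d_output = 64

# Input seen while tracing (no pre-trained feature) and the layer's weight, transposed
x_traced = np.zeros((rows, received_features))
weight_t = np.zeros((received_features + pretrained_output_dims, d_output))

try:
    x_traced @ weight_t
except ValueError:
    print("cannot multiply", x_traced.shape, "and", weight_t.shape)
    # prints: cannot multiply (220, 128) and (136, 64)

# Appending the 8 missing columns makes the product well-defined
x_full = np.hstack([x_traced, np.zeros((rows, pretrained_output_dims))])
out = x_full @ weight_t
print(out.shape)  # (220, 64)
```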

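It seems that torch.jit.trace() does not receive the pre-trained embeddings that the dataloader provides during training: the example inputs used for tracing are read directly from the parquet file, which does not contain the pretrained_item_id_embeddings column added by EmbeddingOperator.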
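One way to check this before calling torch.jit.trace() might be to compare the columns the dataloader produced during training with the columns used to build the trace inputs. A small sketch, where missing_columns is a hypothetical helper and the column lists are hypothetical stand-ins (I assume data_loader.output_schema also lists pretrained_item_id_embeddings, since that is the embedding_name passed to EmbeddingOperator):

```python
from typing import Iterable, List


def missing_columns(loader_columns: Iterable[str], trace_columns: Iterable[str]) -> List[str]:
    """Hypothetical helper: columns the training dataloader produced
    that the trace inputs do not contain, in sorted order."""
    return sorted(set(loader_columns) - set(trace_columns))


# Hypothetical values. In the script these would presumably be
#   loader_columns = data_loader.output_schema.column_names
#   trace_columns  = model.input_schema.column_names
loader_columns = ["item_id", "item_category", "item_recency", "item_genres",
                  "pretrained_item_id_embeddings"]
trace_columns = ["item_id", "item_category", "item_recency", "item_genres"]

print(missing_columns(loader_columns, trace_columns))
# ['pretrained_item_id_embeddings']
```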

Do you have any suggestions on how to deploy a Transformers4Rec model with pre-trained embeddings on Triton Inference Server?

Thanks for your amazing work!

mvidela31 · Jan 10 '25 20:01