CTranslate2
Add support for HF transformers T5Model
I would really appreciate it if CTranslate2 could support the Hugging Face Transformers `T5Model`, as this model is commonly used for translation tasks.
I've tried subclassing `MarianMTLoader` for this. However, `T5Config` is quite different: for example, it uses `num_layers`/`num_heads` instead of `encoder_layers`/`encoder_attention_heads`. So T5 would need a loader of its own. Here's my work-in-progress loader:
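```python
# Imports assume this loader lives outside converters/transformers.py
# (CTranslate2 2.x layout); inside that module they are not needed.
from ctranslate2.converters.transformers import (
    _SUPPORTED_ACTIVATIONS,
    MarianMTLoader,
    register_loader,
)
from ctranslate2.specs import transformer_spec


@register_loader("T5Config")
class T5Loader(MarianMTLoader):
    @property
    def architecture_name(self):
        # T5Model has no lm_head; T5ForConditionalGeneration does.
        return "T5ForConditionalGeneration"

    def get_model_spec(self, model):
        spec = transformer_spec.TransformerSpec(
            model.config.num_layers,
            model.config.num_heads,
            pre_norm=True,  # T5 normalizes the input of each sublayer
            activation=_SUPPORTED_ACTIVATIONS["relu"],  # original T5 FFN (feed_forward_proj="relu")
            layernorm_embedding=False,  # T5 has no layer norm on the embeddings
        )
        spec.with_target_bos = False
        # Still fails: the inherited set_encoder/set_decoder expect
        # Marian-style modules (e.g. embed_scale), which T5Stack lacks.
        self.set_encoder(spec.encoder, model.encoder)
        self.set_decoder(spec.decoder, model.decoder)
        self.set_linear(spec.decoder.projection, model.lm_head)
        final_logits_bias = getattr(model, "final_logits_bias", None)
        if final_logits_bias is not None and final_logits_bias.nonzero().numel() != 0:
            spec.decoder.projection.bias = final_logits_bias.squeeze().numpy()
        return spec

    def set_decoder(self, spec, decoder):
        spec.start_from_zero_embedding = True
        super().set_decoder(spec, decoder)
```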
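The loader above fails in `set_common_layers` with `AttributeError: 'T5Stack' object has no attribute 'embed_scale'`. The inherited Marian logic reads `embed_scale`, which the Marian encoder and decoder define. T5's `T5Stack` does not define it, because T5 does not scale its input embeddings.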
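`embed_scale` is not the only mismatch. The two module layouts differ in several other attribute names, which suggests that overriding `set_encoder`/`set_decoder` would likely be needed, not just `get_model_spec`. The sketch below compares them. The attribute names are my reading of the Hugging Face `MarianEncoder` and `T5Stack` classes in Transformers 4.x and may differ between versions. `missing_attributes` is a hypothetical helper, not part of CTranslate2.

```python
# Attribute names defined on the Hugging Face modules (Transformers 4.x; may vary).
# MarianEncoder scales embeddings and uses an absolute position table.
MARIAN_ENCODER_ATTRS = {"embed_tokens", "embed_scale", "embed_positions", "layers", "dropout"}
# T5Stack (used for both the T5 encoder and decoder) has no embedding scale and
# no position table; its layers and final layer norm use different names.
T5_STACK_ATTRS = {"embed_tokens", "block", "final_layer_norm", "dropout"}


def missing_attributes(expected, available):
    """Hypothetical helper: names a Marian-style loader may read that T5 lacks."""
    return sorted(expected - available)


missing = missing_attributes(MARIAN_ENCODER_ATTRS, T5_STACK_ATTRS)
t5_only = sorted(T5_STACK_ATTRS - MARIAN_ENCODER_ATTRS)
print("missing in T5Stack:", missing)
print("only in T5Stack:", t5_only)
```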
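T5 also requires changes in the core C++ implementation. The model uses a different layer normalization, which rescales by the root mean square only, with no mean subtraction and no bias. It also uses relative position embeddings instead of absolute ones: a learned per-head bias, indexed by the bucketed query-key distance, is added to the attention scores.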
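Both differences can be illustrated in NumPy. This is a sketch modeled on my reading of the Hugging Face T5 implementation (`T5LayerNorm` and `T5Attention._relative_position_bucket`), not CTranslate2 code. The function names are hypothetical, and `num_buckets=32`/`max_distance=128` are assumed to match the usual T5 defaults.

```python
import numpy as np


def layer_norm(x, gamma, beta, eps=1e-6):
    """Standard layer norm: subtract the mean, divide by the std, scale and shift."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * gamma + beta


def t5_layer_norm(x, weight, eps=1e-6):
    """T5-style norm: divide by the root mean square only (no centering, no bias)."""
    variance = np.mean(np.square(x), axis=-1, keepdims=True)
    return weight * x / np.sqrt(variance + eps)


def relative_position_bucket(relative_position, bidirectional=True,
                             num_buckets=32, max_distance=128):
    """Map signed distances (key position - query position) to bucket ids."""
    relative_position = np.asarray(relative_position, dtype=np.int64)
    buckets = np.zeros_like(relative_position)
    if bidirectional:
        # Encoder: half of the buckets for keys after the query, half for keys before.
        num_buckets //= 2
        buckets += (relative_position > 0).astype(np.int64) * num_buckets
        relative_position = np.abs(relative_position)
    else:
        # Decoder self-attention: only past positions (distance <= 0) are distinguished.
        relative_position = -np.minimum(relative_position, 0)
    # Small distances get one bucket each...
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact
    # ...larger ones share logarithmically wider buckets, capped at the last bucket.
    safe = np.maximum(relative_position, 1)  # avoid log(0); those entries use is_small
    large = max_exact + (
        np.log(safe / max_exact) / np.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).astype(np.int64)
    large = np.minimum(large, num_buckets - 1)
    return buckets + np.where(is_small, relative_position, large)


def position_bias(query_len, key_len, table, bidirectional=True, max_distance=128):
    """Look up a learned [num_buckets, num_heads] table -> [heads, query, key] bias."""
    context = np.arange(query_len)[:, None]
    memory = np.arange(key_len)[None, :]
    buckets = relative_position_bucket(memory - context, bidirectional,
                                       table.shape[0], max_distance)
    # The result is added to the attention logits before the softmax.
    return table[buckets].transpose(2, 0, 1)


x = np.array([3.0, 4.0])
print(t5_layer_norm(x, np.ones(2)))            # not centered: both entries stay positive
print(layer_norm(x, np.ones(2), np.zeros(2)))  # centered: roughly [-1, 1]
print(relative_position_bucket(np.array([-3, 0, 3, 127])))  # -> [3, 0, 19, 31]
```

In the Hugging Face implementation, only the first block of each stack owns the bias table. The computed bias is then reused by the later blocks, so a C++ port would presumably compute it once per stack.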