infer a statement
I used the mt5-base pre-trained model to train my MASSIVE slot and intent model, and I saved checkpoints in a directory. I want to run inference on one simple statement without creating a test dataset or running evaluation. Is there sample code for this?
Hi @ehsanbarkhordar this sounds like a good feature. I'll check with the team to see if someone can pick it up. Alternatively, please feel free to post a PR. Thanks.
I wrote this piece of code, but I have no idea how to obtain intent_num and slots_num.
import argparse
import datetime
import logging
import os
import pickle
import sys
import datasets
import torch
import transformers
from ruamel.yaml import YAML
from massive import (
MASSIVETrainingArguments,
init_model,
init_tokenizer,
prepare_test_dataset,
read_conf,
)
from massive.models.mt5_ic_sf_encoder_only import MT5IntentClassSlotFillEncoderOnly
logger = logging.getLogger('massive_logger')
def main():
    """Run single-utterance inference with a trained MASSIVE IC/SF model.

    Loads the tokenizer, label maps, and model from the test config, then
    predicts the intent and per-token slots for one hard-coded utterance —
    no test dataset or full evaluation pass required.
    """
    # Mask GPUs BEFORE any CUDA call: torch.cuda.is_available() initializes
    # CUDA, after which CUDA_VISIBLE_DEVICES has no effect. Once masked,
    # the single visible GPU is addressed as 'cuda:0' (not 'cuda:1').
    os.environ['CUDA_VISIBLE_DEVICES'] = '1'
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    args = argparse.Namespace(config='examples/mt5_base_enc_test_20220411.yml', local_rank=None)

    # create the massive.Configuration master config object
    conf = read_conf(args.config)
    trainer_args = MASSIVETrainingArguments(**conf.get('test.trainer_args'))
    if args.local_rank:
        trainer_args.local_rank = int(args.local_rank)

    # Setup logging
    logging.basicConfig(
        format="[%(levelname)s] %(asctime)s >> %(message)s",
        datefmt="%H:%M",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = trainer_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()

    logger.info(f"Starting the run at {datetime.datetime.now()}")
    yaml = YAML(typ='safe')
    logger.info(f"Using the following config: {yaml.load(open(args.config, 'r'))}")

    # Check for right setup
    if not conf.get('test.predictions_file'):
        logger.warning("Outputs will not be saved because no test.predictions_file was given")
    if conf.get('test.predictions_file') and \
       (conf.get('test.trainer_args.locale_eval_strategy') != 'all only'):
        raise NotImplementedError("You must use 'all only' as the locale_eval_strategy if you"
                                  " include a predictions file")

    # Get all inputs needed for inference. `intents` and `slots` are the
    # label collections used to size the classification heads; we reuse them
    # below to map predicted indices back to label names.
    tokenizer = init_tokenizer(conf)
    test_ds, intents, slots = prepare_test_dataset(conf, tokenizer)
    model: MT5IntentClassSlotFillEncoderOnly = init_model(conf, intents, slots)
    model.to(device)
    model.eval()  # disable dropout etc. for deterministic inference

    article = "please delete bananas from my shopping list."
    inputs = tokenizer(article, return_tensors="pt").to(device)

    with torch.no_grad():
        # intent_num / slots_num are the numeric GOLD labels, only needed to
        # compute a loss during training/eval — pass None at inference time.
        outputs = model(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            intent_num=None,
            slots_num=None,
        )

    # This is an encoder-only classification model: it emits intent and slot
    # LOGITS, not generated token ids, so tokenizer.decode() does not apply.
    # Decode by argmax over the logits and map indices through the label maps.
    # NOTE(review): assumes `outputs` exposes `intent_logits` of shape
    # (1, num_intents) and `slot_logits` of shape (1, seq_len, num_slots) —
    # confirm against MT5IntentClassSlotFillEncoderOnly.forward().
    intent_idx = outputs.intent_logits.argmax(dim=-1).item()
    print(f"Predicted intent: {intents[intent_idx]}")

    slot_idxs = outputs.slot_logits.argmax(dim=-1).squeeze(0).tolist()
    tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids.squeeze(0))
    for token, slot_idx in zip(tokens, slot_idxs):
        print(f"{token}\t{slots[slot_idx]}")
Hi @ehsanbarkhordar , unfortunately this is still on our backlog, but hopefully we can get to it soon. Sorry for the wait. Thanks.