
Infer a statement

Open · ebarkhordar opened this issue 3 years ago · 3 comments

I used the pre-trained mt5-base model to train my MASSIVE slot-filling and intent-classification model, and I saved the checkpoints in a directory. Now I want to run inference on one simple statement without creating a test dataset or running an evaluation. Is there any sample code for this?

ebarkhordar · Jun 15 '22 21:06

Hi @ehsanbarkhordar, this sounds like a good feature. I'll check with the team to see if someone can pick it up. Alternatively, please feel free to post a PR. Thanks.

jgmf-amazon · Jun 16 '22 14:06

I wrote the following code, but I have no idea how to obtain intent_num and slots_num.

import argparse
import datetime
import logging
import os
import sys

import datasets
import torch
import transformers
from ruamel.yaml import YAML

from massive import (
    MASSIVETrainingArguments,
    init_model,
    init_tokenizer,
    prepare_test_dataset,
    read_conf,
)
from massive.models.mt5_ic_sf_encoder_only import MT5IntentClassSlotFillEncoderOnly

logger = logging.getLogger('massive_logger')


def main():
    """ Run Testing/Inference """
    # Make only physical GPU 1 visible before any CUDA call; inside this process
    # that GPU is then addressed as 'cuda:0'.
    os.environ['CUDA_VISIBLE_DEVICES'] = '1'
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    args = argparse.Namespace(config='examples/mt5_base_enc_test_20220411.yml', local_rank=None)

    # create the massive.Configuration master config object
    conf = read_conf(args.config)
    trainer_args = MASSIVETrainingArguments(**conf.get('test.trainer_args'))
    if args.local_rank:
        trainer_args.local_rank = int(args.local_rank)

    # Setup logging
    logging.basicConfig(
        format="[%(levelname)s] %(asctime)s >> %(message)s",
        datefmt="%H:%M",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = trainer_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    logger.info(f"Starting the run at {datetime.datetime.now()}")
    yaml = YAML(typ='safe')
    with open(args.config, 'r') as config_file:  # closes the file after reading
        logger.info(f"Using the following config: {yaml.load(config_file)}")

    # Check for right setup
    if not conf.get('test.predictions_file'):
        logger.warning("Outputs will not be saved because no test.predictions_file was given")
    if conf.get('test.predictions_file') and \
            (conf.get('test.trainer_args.locale_eval_strategy') != 'all only'):
        raise NotImplementedError("You must use 'all only' as the locale_eval_strategy if you"
                                  " include a predictions file")

    # Get all inputs to the trainer
    tokenizer = init_tokenizer(conf)
    test_ds, intents, slots = prepare_test_dataset(conf, tokenizer)
    model: MT5IntentClassSlotFillEncoderOnly = init_model(conf, intents, slots)
    model.to(device)
    model.eval()  # switch off dropout for inference

    article = "please delete bananas from my shopping list."

    # Tokenize the single utterance and move the tensors to the model's device
    inputs = tokenizer(article, return_tensors="pt").to(device)

    with torch.no_grad():  # inference only, so skip gradient tracking
        outputs = model(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            intent_num=None,
            slots_num=None,
        )

    # This encoder-only model outputs intent and slot classification logits, not
    # generated token IDs, so tokenizer.decode() cannot turn its output into text.
    print(outputs)


if __name__ == '__main__':
    main()

ebarkhordar · Jun 20 '22 12:06
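Because the encoder-only model produces logits rather than text, turning its output into readable predictions needs a small post-processing step. The sketch below is plain Python and rests on several assumptions that are not confirmed by the thread: that the model output exposes per-example intent scores and per-subword slot scores (called `intent_logits` and `slot_logits` here; check the actual field names in the massive model code), that these were converted to nested lists with `.tolist()`, that `word_ids` comes from a fast tokenizer's `BatchEncoding.word_ids()`, that training aligned each word's slot label to its first subword, and that the id-to-label dictionaries are inverted from the label mappings used during training. The label names in the toy example are illustrative only.

```python
from typing import Dict, List, Optional, Sequence, Tuple

# ASSUMPTIONS (hypothetical, not the massive package's API): the model output
# has been reduced to plain lists via .tolist(), word_ids comes from a fast
# tokenizer's BatchEncoding.word_ids(), and each word's slot label was trained
# on its first subword token.


def argmax(scores: Sequence[float]) -> int:
    """Return the index of the largest score (the first one wins on ties)."""
    return max(range(len(scores)), key=lambda i: scores[i])


def decode_predictions(
    intent_logits: Sequence[float],
    slot_logits: Sequence[Sequence[float]],
    word_ids: Sequence[Optional[int]],
    intent_id2label: Dict[int, str],
    slot_id2label: Dict[int, str],
) -> Tuple[str, List[str]]:
    """Map one example's logits to an intent label and one slot label per word.

    intent_logits: scores over intents, length num_intents
    slot_logits:   one row of slot scores per subword token
    word_ids:      word index of each subword token, None for special tokens
    """
    intent = intent_id2label[argmax(intent_logits)]

    slots: List[str] = []
    previous_word: Optional[int] = None
    for token_scores, word in zip(slot_logits, word_ids):
        # Skip special tokens and every subword after the first one of a word
        if word is None or word == previous_word:
            continue
        slots.append(slot_id2label[argmax(token_scores)])
        previous_word = word
    return intent, slots


# Toy example: made-up scores and illustrative label names, not real model output
words = ["please", "delete", "bananas", "from", "my", "shopping", "list"]
word_ids = [0, 1, 2, 2, 3, 4, 5, 6, None]  # "bananas" split into two subwords, then </s>
intent_id2label = {0: "lists_createoradd", 1: "lists_remove", 2: "lists_query"}
slot_id2label = {0: "Other", 1: "list_name", 2: "item"}
intent_logits = [0.1, 2.3, -0.5]
slot_logits = [
    [3, 0, 0], [3, 0, 0], [0, 0, 3], [3, 0, 0], [3, 0, 0],
    [3, 0, 0], [0, 3, 0], [3, 0, 0], [0, 0, 0],
]

intent, slots = decode_predictions(
    intent_logits, slot_logits, word_ids, intent_id2label, slot_id2label
)
print(intent)  # lists_remove
print(list(zip(words, slots)))
```

Skipping the later subwords of a word mirrors the common "label only the first subword" alignment; if training labelled every subword instead, the per-word label would need a different aggregation, such as a majority vote over that word's subwords.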

Hi @ehsanbarkhordar, unfortunately this is still on our backlog, but hopefully we can get to it soon. Sorry for the wait. Thanks.

jgmf-amazon · Jun 22 '22 20:06
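When a prediction for a single statement needs to be compared against MASSIVE's `annot_utt` field, per-word slot labels can be rendered back into the bracketed `[slot : words]` style. The helper below is a hypothetical sketch, not part of the massive package; it assumes non-slot words carry the tag `Other` (passed as `outside`, so change it if your label set uses a different tag) and that adjacent words sharing a slot label belong to one span.

```python
from typing import List, Sequence


def to_annotated_utterance(
    words: Sequence[str], slots: Sequence[str], outside: str = "Other"
) -> str:
    """Group consecutive words with the same slot label into '[label : words]' spans.

    Hypothetical helper: 'outside' is the assumed tag for words that are not in a slot.
    """
    parts: List[str] = []
    i = 0
    while i < len(words):
        label = slots[i]
        if label == outside:
            parts.append(words[i])
            i += 1
            continue
        # Extend the span over following words that carry the same slot label
        j = i
        while j + 1 < len(words) and slots[j + 1] == label:
            j += 1
        parts.append(f"[{label} : {' '.join(words[i:j + 1])}]")
        i = j + 1
    return " ".join(parts)


words = ["wake", "me", "up", "at", "nine", "am"]
slots = ["Other", "Other", "Other", "Other", "time", "time"]
print(to_annotated_utterance(words, slots))  # wake me up at [time : nine am]
```

Merging adjacent identical labels means two distinct back-to-back spans of the same slot type would be joined into one; a BIO-style tagging scheme avoids that ambiguity.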