StopIteration encountered running MNLI

Open CanYing0913 opened this issue 1 year ago • 0 comments

Hi there! I got a StopIteration while following the steps to run your code with scripts/run_glue.sh:

... previous messages hidden...
2024-05-13 16:17:00,172 [INFO]: module.classifier: Linear(in_features=768, out_features=3, bias=True)
Evaluating:   0%|                                                                                                                         | 0/614 [00:01<?, ?it/s]
Traceback (most recent call last):
  File "/home/cy/bit/quant_task_distill_glue.py", line 251, in <module>
    main()
  File "/home/cy/bit/quant_task_distill_glue.py", line 242, in main
    learner.train(train_examples, task_name, output_mode, eval_labels,
  File "/home/cy/bit/kd_learner_glue.py", line 152, in train
    teacher_results = self._do_eval(self.teacher_model, task_name, eval_dataloader, output_mode, eval_labels, num_labels)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cy/bit/kd_learner_glue.py", line 95, in _do_eval
    logits, _, _ = model(input_ids, segment_ids, input_mask)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cy/miniforge3/envs/pq/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cy/miniforge3/envs/pq/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cy/miniforge3/envs/pq/lib/python3.12/site-packages/torch/nn/parallel/data_parallel.py", line 185, in forward
    outputs = self.parallel_apply(replicas, inputs, module_kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cy/miniforge3/envs/pq/lib/python3.12/site-packages/torch/nn/parallel/data_parallel.py", line 200, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cy/miniforge3/envs/pq/lib/python3.12/site-packages/torch/nn/parallel/parallel_apply.py", line 108, in parallel_apply
    output.reraise()
  File "/home/cy/miniforge3/envs/pq/lib/python3.12/site-packages/torch/_utils.py", line 722, in reraise
    raise exception
StopIteration: Caught StopIteration in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/home/cy/miniforge3/envs/pq/lib/python3.12/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in _worker
    output = module(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cy/miniforge3/envs/pq/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cy/miniforge3/envs/pq/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cy/bit/transformer/modeling_bert.py", line 498, in forward
    sequence_output, att_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
                                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cy/miniforge3/envs/pq/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cy/miniforge3/envs/pq/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cy/bit/transformer/modeling_bert.py", line 471, in forward
    dtype=next(self.parameters()).dtype)  # fp16 compatibility
          ^^^^^^^^^^^^^^^^^^^^^^^
StopIteration

It seems there is an error in the training routine. I used the provided pretrained full-precision bert_base for MNLI and modified the model and dataset paths accordingly. I suspect this might be due to a library version conflict, since a pytorch_model.bin file is no longer produced by default (instead, model.save_pretrained() gives a model.safetensors file). Could you provide the environment configuration file to help address this issue?

Since I'm using W1A1 as the config, I guess I could set it to binary to make it run, but I would also like to understand why the data type is taken from the first value yielded by the parameters generator, i.e. next(self.parameters()).dtype. If an environment conflict is not the cause, I need this clarified to get the code working.
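To make the question concrete, here is a minimal sketch of the idiom, under stated assumptions. `FakeParam`, `FakeModule`, and `first_param_dtype` are hypothetical stand-ins written for illustration and are not part of this repository or of PyTorch. The idiom reads the dtype of the first parameter so the mask can be cast to match the model (the "fp16 compatibility" comment in the traceback). It raises StopIteration whenever the generator is empty. As far as I know, since PyTorch 1.5 the replicas created by nn.DataParallel no longer expose their parameters through parameters(), which would be consistent with the error appearing in "replica 1 on device 1".

```python
class FakeParam:
    """Stand-in for a torch parameter; it only carries a dtype label."""

    def __init__(self, dtype):
        self.dtype = dtype


class FakeModule:
    """Stand-in for an nn.Module whose parameters() returns a generator."""

    def __init__(self, params):
        self._params = list(params)

    def parameters(self):
        yield from self._params


def first_param_dtype(module, default="float32"):
    """Hypothetical helper: dtype of the first parameter, or `default` if none are visible."""
    first = next(module.parameters(), None)  # the second argument suppresses StopIteration
    return default if first is None else first.dtype


full = FakeModule([FakeParam("float16")])  # a module that still sees its parameters
replica = FakeModule([])                   # mimics a replica with an empty parameters()

# The idiom from modeling_bert.py works while at least one parameter is visible.
idiom_dtype = next(full.parameters()).dtype

# With no visible parameters, the same expression raises StopIteration.
try:
    next(replica.parameters())
    raised = False
except StopIteration:
    raised = True
```

If this is indeed the cause, two possible workarounds are to run on a single GPU so that DataParallel is not used, or to take the dtype from something that is always available in forward, such as a fixed default as in `first_param_dtype` above. Both are untested suggestions for this codebase.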

Thank you.

CanYing0913, May 13 '24 23:05