StopIteration encountered when running MNLI
Hi there! I got a StopIteration error while following the steps to run your code with scripts/run_glue.sh:
... previous messages hidden...
2024-05-13 16:17:00,172 [INFO]: module.classifier: Linear(in_features=768, out_features=3, bias=True)
Evaluating: 0%| | 0/614 [00:01<?, ?it/s]
Traceback (most recent call last):
File "/home/cy/bit/quant_task_distill_glue.py", line 251, in <module>
main()
File "/home/cy/bit/quant_task_distill_glue.py", line 242, in main
learner.train(train_examples, task_name, output_mode, eval_labels,
File "/home/cy/bit/kd_learner_glue.py", line 152, in train
teacher_results = self._do_eval(self.teacher_model, task_name, eval_dataloader, output_mode, eval_labels, num_labels)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cy/bit/kd_learner_glue.py", line 95, in _do_eval
logits, _, _ = model(input_ids, segment_ids, input_mask)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cy/miniforge3/envs/pq/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cy/miniforge3/envs/pq/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cy/miniforge3/envs/pq/lib/python3.12/site-packages/torch/nn/parallel/data_parallel.py", line 185, in forward
outputs = self.parallel_apply(replicas, inputs, module_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cy/miniforge3/envs/pq/lib/python3.12/site-packages/torch/nn/parallel/data_parallel.py", line 200, in parallel_apply
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cy/miniforge3/envs/pq/lib/python3.12/site-packages/torch/nn/parallel/parallel_apply.py", line 108, in parallel_apply
output.reraise()
File "/home/cy/miniforge3/envs/pq/lib/python3.12/site-packages/torch/_utils.py", line 722, in reraise
raise exception
StopIteration: Caught StopIteration in replica 1 on device 1.
Original Traceback (most recent call last):
File "/home/cy/miniforge3/envs/pq/lib/python3.12/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in _worker
output = module(*input, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cy/miniforge3/envs/pq/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cy/miniforge3/envs/pq/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cy/bit/transformer/modeling_bert.py", line 498, in forward
sequence_output, att_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cy/miniforge3/envs/pq/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cy/miniforge3/envs/pq/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cy/bit/transformer/modeling_bert.py", line 471, in forward
dtype=next(self.parameters()).dtype) # fp16 compatibility
^^^^^^^^^^^^^^^^^^^^^^^
StopIteration
The error occurs inside the training routine, specifically while the teacher model is being evaluated (learner.train calls self._do_eval on self.teacher_model). I used the provided pretrained full-precision bert_base for MNLI and changed the model and dataset paths accordingly. I suspect a library version conflict, since model.save_pretrained() no longer produces a pytorch_model.bin file by default; it writes a model.safetensors file instead. Could you provide the environment configuration file so I can match your setup?
Since I'm using the W1A1 config, I guess I could set it to binary to make it run, but I would also like to understand why the dtype is taken from the first value of the parameter generator, i.e. next(self.parameters()).dtype. If an environment conflict is not the cause, resolving this question is crucial for getting the code to work.
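For reference, here is a minimal pure-Python sketch of what I understand line 471 of modeling_bert.py to be doing (the "# fp16 compatibility" comment suggests it reads the dtype of the model's first parameter so the attention mask matches fp16/fp32 weights). FakeTensor, FakeModule, param_dtype_strict and param_dtype_safe are hypothetical stand-ins, not part of the repo or of PyTorch. The point is only that next() on an empty generator raises StopIteration. My assumption, based on the "replica 1 on device 1" in the traceback, is that a DataParallel replica may expose no parameters through parameters() in my PyTorch version.

```python
from typing import Iterator, List, Optional


class FakeTensor:
    """Hypothetical stand-in for a tensor; it only carries a dtype name."""

    def __init__(self, dtype: str) -> None:
        self.dtype = dtype


class FakeModule:
    """Hypothetical stand-in for nn.Module exposing a parameters() generator."""

    def __init__(self, params: List[FakeTensor]) -> None:
        self._params = params

    def parameters(self) -> Iterator[FakeTensor]:
        # Like nn.Module.parameters(), this returns a lazy generator.
        yield from self._params


def param_dtype_strict(module: FakeModule) -> str:
    # Same pattern as modeling_bert.py line 471: read the first parameter's dtype.
    # If the generator yields nothing, next() raises StopIteration.
    return next(module.parameters()).dtype


def param_dtype_safe(module: FakeModule, default: str = "float32") -> str:
    # next() with a default value never raises StopIteration;
    # fall back to an assumed default dtype when there are no parameters.
    first: Optional[FakeTensor] = next(module.parameters(), None)
    return first.dtype if first is not None else default


if __name__ == "__main__":
    fp16_model = FakeModule([FakeTensor("float16")])
    print(param_dtype_strict(fp16_model))  # float16

    empty_replica = FakeModule([])
    try:
        param_dtype_strict(empty_replica)
    except StopIteration:
        print("StopIteration")
    print(param_dtype_safe(empty_replica))  # float32
```

If this reading is right, running on a single GPU (no DataParallel) or reading the dtype with a fallback might avoid the crash, but I have not verified either against your code.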
Thank you.