DPO training error: KeyError: 'prompt_input_ids'

Open JiaweiZhao-git opened this issue 1 year ago • 4 comments

When training DPO with a custom dataset format, the Map step fails with:

File "ms-swift/swift/trainers/dpo_trainer.py", line 114, in tokenize_row
    if len(answer_tokens['prompt_input_ids']) + longer_response_length > self.max_length:
KeyError: 'prompt_input_ids'

Printing the keys of answer gives: dict_keys(['input_ids', 'attention_mask', 'prompt_inputs_embeds', 'prompt_attention_mask'])
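The failing line reads `answer_tokens['prompt_input_ids']`, but the printed dict for this multimodal model only carries `prompt_inputs_embeds` and `prompt_attention_mask` on the prompt side. A minimal sketch of a more defensive version of that length check, assuming a hypothetical helper `prompt_length` that falls back to `prompt_attention_mask` and assuming that mask is a flat per-token list. This is an illustration, not the actual ms-swift fix:

```python
from typing import Mapping, Sequence


def prompt_length(answer_tokens: Mapping[str, Sequence]) -> int:
    """Hypothetical helper: length of the tokenized prompt.

    Prefers 'prompt_input_ids' (text-only path) and falls back to
    'prompt_attention_mask', assumed here to hold one entry per prompt token.
    """
    for key in ("prompt_input_ids", "prompt_attention_mask"):
        if key in answer_tokens:
            return len(answer_tokens[key])
    raise KeyError(f"no prompt length key found; available keys: {sorted(answer_tokens)}")


def exceeds_max_length(answer_tokens, longer_response_length: int, max_length: int) -> bool:
    # Same condition as dpo_trainer.py line 114 in the report,
    # without assuming 'prompt_input_ids' is present.
    return prompt_length(answer_tokens) + longer_response_length > max_length


# Keys as printed in the report; the values are dummy placeholders.
row = {
    "input_ids": [1, 2, 3],
    "attention_mask": [1, 1, 1],
    "prompt_inputs_embeds": [[0.0] * 4] * 5,
    "prompt_attention_mask": [1] * 5,
}
print(prompt_length(row))                # 5
print(exceeds_max_length(row, 10, 12))   # True
```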

Training command:

CUDA_VISIBLE_DEVICES=2 \
swift rlhf \
    --rlhf_type dpo \
    --model_type internvl2-4b \
    --model_id_or_path ./OpenGVLab/InternVL2-4B \
    --beta 0.1 \
    --sft_beta 0.1 \
    --sft_type lora \
    --dataset {custom_dataset_path}.jsonl \
    --num_train_epochs 2 \
    --lora_target_modules DEFAULT \
    --gradient_checkpointing true \
    --batch_size 1 \
    --learning_rate 5e-5 \
    --gradient_accumulation_steps 16 \
    --warmup_ratio 0.03 \
    --save_total_limit 1

JiaweiZhao-git • Aug 14 '24 11:08

Please use version 2.3 for now; I'll fix it.

hjh0119 • Aug 14 '24 11:08

Another problem: KeyError: 'prompt_pixel_values'

linzhenyuyuchen • Aug 15 '24 08:08

I get the same error when the only change is switching the model from llava1_6-mistral-7b-instruct to glm4v-9b-chat in the first DPO example here:

CUDA_VISIBLE_DEVICES=0 \
swift rlhf \
    --rlhf_type dpo \
    --model_type glm4v-9b-chat \
    --beta 0.1 \
    --sft_beta 0.1 \
    --sft_type lora \
    --dataset rlaif-v#1000 \
    --num_train_epochs 2 \
    --lora_target_modules DEFAULT \
    --gradient_checkpointing true \
    --batch_size 1 \
    --learning_rate 5e-5 \
    --gradient_accumulation_steps 16 \
    --warmup_ratio 0.03 \
    --save_total_limit 2

Lopa07 • Aug 15 '24 17:08

After changing the model, three batch keys are missing: prompt_input_ids, prompt_pixel_values, and prompt_image_sizes. There is also a new key, prompt_images.
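To see exactly which batch keys change when swapping models, compare the two key sets. A minimal sketch in plain Python; `diff_batch_keys` is a hypothetical helper, and the hard-coded lists only contain the keys named in the comment above. In practice you would pass the `.keys()` of one batch from each model instead:

```python
def diff_batch_keys(expected, actual):
    """Return (missing, unexpected) keys, each as a sorted list."""
    expected, actual = set(expected), set(actual)
    return sorted(expected - actual), sorted(actual - expected)


# Only the keys mentioned in the comment; real batches contain more.
expected_keys = ["prompt_input_ids", "prompt_pixel_values", "prompt_image_sizes"]
actual_keys = ["prompt_images"]

missing, unexpected = diff_batch_keys(expected_keys, actual_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```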

Lopa07 • Aug 15 '24 18:08

Fixed.

hjh0119 • Aug 16 '24 10:08

With this update, running the following command fails with the error below:

CUDA_VISIBLE_DEVICES=0 \
swift rlhf \
    --rlhf_type dpo \
    --model_type glm4v-9b-chat \
    --beta 0.1 \
    --sft_beta 0.1 \
    --sft_type lora \
    --dataset rlaif-v#1000 \
    --num_train_epochs 2 \
    --lora_target_modules DEFAULT \
    --gradient_checkpointing true \
    --batch_size 1 \
    --learning_rate 5e-5 \
    --gradient_accumulation_steps 16 \
    --warmup_ratio 0.03 \
    --save_total_limit 2

Output:
Parameter 'function'=<bound method DPOTrainer.tokenize_row of <swift.trainers.dpo_trainer.DPOTrainer object at 0x7f8b0d91ef10>> of the transform datasets.arrow_dataset.Dataset._map_single couldn't be hashed properly, a random hash was used instead. Make sure your transforms and parameters are serializable with pickle or dill for the dataset fingerprinting and caching to work. If you reuse this transform, the caching mechanism will consider it to be different from the previous calls and recompute everything. This warning is only showed once. Subsequent hashing failures won't be showed.
Map: 100%|████████████████████████████████████████████████████████████████| 990/990 [01:29<00:00, 11.09 examples/s]
Map: 100%|██████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 10.40 examples/s]
[INFO:swift] Dataset Token Length: 119.152525±56.107634, min=18.000000, max=304.000000, size=990
[INFO:swift] Dataset Token Length: 102.500000±36.414969, min=60.000000, max=162.000000, size=10
[INFO:swift] The RLHFArguments will be saved in: /VDIL_COREML/m.banerjee/ms-swift/output/glm4v-9b-chat/v1-20240816-101916/sft_args.json
[INFO:swift] The Seq2SeqTrainingArguments will be saved in: /VDIL_COREML/m.banerjee/ms-swift/output/glm4v-9b-chat/v1-20240816-101916/training_args.json
[INFO:swift] The logging file will be saved in: /VDIL_COREML/m.banerjee/ms-swift/output/glm4v-9b-chat/v1-20240816-101916/logging.jsonl
Train:   0%|                                                                               | 0/122 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/VDIL_COREML/m.banerjee/ms-swift/swift/cli/rlhf.py", line 5, in <module>
    rlhf_main()
  File "/VDIL_COREML/m.banerjee/ms-swift/swift/utils/run_utils.py", line 32, in x_main
    result = llm_x(args, **kwargs)
  File "/VDIL_COREML/m.banerjee/ms-swift/swift/llm/rlhf.py", line 231, in llm_rlhf
    trainer.train(training_args.resume_from_checkpoint)
  File "/VDIL_COREML/m.banerjee/ms-swift/swift/trainers/dpo_trainer.py", line 63, in train
    res = super().train(*args, **kwargs)
  File "/VDIL_COREML/m.banerjee/ms-swift/swift/trainers/mixin.py", line 538, in train
    res = super().train(resume_from_checkpoint, *args, **kwargs)
  File "/VDIL_COREML/m.banerjee/anaconda3/envs/swift/lib/python3.9/site-packages/transformers/trainer.py", line 1948, in train
    return inner_training_loop(
  File "/VDIL_COREML/m.banerjee/anaconda3/envs/swift/lib/python3.9/site-packages/transformers/trainer.py", line 2289, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/VDIL_COREML/m.banerjee/anaconda3/envs/swift/lib/python3.9/site-packages/transformers/trainer.py", line 3328, in training_step
    loss = self.compute_loss(model, inputs)
  File "/VDIL_COREML/m.banerjee/anaconda3/envs/swift/lib/python3.9/site-packages/trl/trainer/dpo_trainer.py", line 1408, in compute_loss
    loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
  File "/VDIL_COREML/m.banerjee/ms-swift/swift/trainers/dpo_trainer.py", line 204, in get_batch_loss_metrics
    forward_output = self.concatenated_forward(model, batch)
  File "/VDIL_COREML/m.banerjee/ms-swift/swift/trainers/dpo_trainer.py", line 314, in concatenated_forward
    outputs = model(
  File "/VDIL_COREML/m.banerjee/anaconda3/envs/swift/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/VDIL_COREML/m.banerjee/anaconda3/envs/swift/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/VDIL_COREML/m.banerjee/anaconda3/envs/swift/lib/python3.9/site-packages/accelerate/utils/operations.py", line 819, in forward
    return model_forward(*args, **kwargs)
  File "/VDIL_COREML/m.banerjee/anaconda3/envs/swift/lib/python3.9/site-packages/accelerate/utils/operations.py", line 807, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/VDIL_COREML/m.banerjee/anaconda3/envs/swift/lib/python3.9/site-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
    return func(*args, **kwargs)
  File "/VDIL_COREML/m.banerjee/anaconda3/envs/swift/lib/python3.9/site-packages/peft/peft_model.py", line 1577, in forward
    return self.base_model(
  File "/VDIL_COREML/m.banerjee/anaconda3/envs/swift/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/VDIL_COREML/m.banerjee/anaconda3/envs/swift/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/VDIL_COREML/m.banerjee/anaconda3/envs/swift/lib/python3.9/site-packages/peft/tuners/tuners_utils.py", line 188, in forward
    return self.model.forward(*args, **kwargs)
  File "/home/m.banerjee/.cache/huggingface/modules/transformers_modules/01328faefe122fe605c1c127b62e6031d3ffebf7/modeling_chatglm.py", line 1176, in forward
    transformer_outputs = self.transformer(
  File "/VDIL_COREML/m.banerjee/anaconda3/envs/swift/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/VDIL_COREML/m.banerjee/anaconda3/envs/swift/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/m.banerjee/.cache/huggingface/modules/transformers_modules/01328faefe122fe605c1c127b62e6031d3ffebf7/modeling_chatglm.py", line 1024, in forward
    (attention_mask[i, :boi_token_pos + 1], torch.ones(num_patches).to(attention_mask.device),
UnboundLocalError: local variable 'num_patches' referenced before assignment
Train:   0%|                                                                               | 0/122 [00:12<?, ?it/s]
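An `UnboundLocalError` means `num_patches` is assigned on only some code paths inside `forward`, and a later branch reads it on a path where that assignment never ran. The reproduction below is a simplified, hypothetical sketch of that pattern, not the actual `modeling_chatglm.py` code: it assumes the patch count is set only when images are passed, while the mask-extension branch is triggered by a separate condition (here a `boi_pos` position):

```python
from typing import List, Optional


def extend_mask(mask: List[int], boi_pos: Optional[int], images: Optional[list]) -> List[int]:
    """Hypothetical simplification of the failing pattern."""
    if images is not None:
        # num_patches is bound only on this path.
        num_patches = len(images) * 4  # dummy patch count
    if boi_pos is not None:
        # This branch depends on boi_pos, not on `images`,
        # so it can run when num_patches was never assigned.
        return mask[: boi_pos + 1] + [1] * num_patches + mask[boi_pos + 1:]
    return mask


print(extend_mask([1, 1, 1], 1, images=["img"]))  # [1, 1, 1, 1, 1, 1, 1]
try:
    extend_mask([1, 1, 1], 1, images=None)
except UnboundLocalError as err:
    print(type(err).__name__)  # UnboundLocalError
```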

Lopa07 • Aug 16 '24 17:08

I created a separate issue for this: DPO training error UnboundLocalError: local variable 'num_patches' referenced before assignment #1734.

Lopa07 • Aug 16 '24 17:08

Has the KeyError: 'prompt_input_ids' error been fixed yet?

workmistm • Aug 19 '24 03:08

Has the KeyError: 'prompt_input_ids' error been fixed yet?

Update ms-swift to v2.3.1.
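To confirm that an environment already includes the fix, compare the installed version against 2.3.1. A minimal sketch using only the standard library; it assumes the pip distribution is named `ms-swift` and that its version string starts with plain numeric `major.minor.patch` components. The hypothetical `parse_release` helper ignores any pre-release or post-release suffix:

```python
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Tuple

FIXED_IN = (2, 3, 1)  # version suggested in the reply above


def parse_release(ver: str) -> Tuple[int, ...]:
    """Extract the leading numeric components, e.g. '2.3.1.post2' -> (2, 3, 1)."""
    match = re.match(r"\d+(?:\.\d+)*", ver)
    if match is None:
        raise ValueError(f"unrecognized version string: {ver!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def has_fix(ver: str) -> bool:
    # Pad to three components so '2.4' compares as (2, 4, 0).
    release = parse_release(ver)
    release = release + (0,) * (3 - len(release))
    return release >= FIXED_IN


if __name__ == "__main__":
    try:
        installed = version("ms-swift")
        print(installed, "has fix:", has_fix(installed))
    except PackageNotFoundError:
        print("ms-swift is not installed")
```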

hjh0119 • Aug 19 '24 07:08