Added code for full fine-tune
The code here is derived from the original lora.py with minimal modifications: the primary addition is full fine-tuning support, while the core structure of the original code is preserved. It offers a starting point for testing full fine-tuning on more powerful Mac devices.
No code under the tuner/* directory was altered, so this update should not conflict with the existing codebase.
The code has been tested on a Mac M2 Studio with 192GB of memory.
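For illustration, the core difference from the LoRA path can be sketched as follows. This is a minimal sketch assuming the standard MLX `freeze`/`unfreeze` module helpers, not the exact diff in this PR:

```python
import mlx.optimizers as optim

def setup_lora(model):
    # LoRA path (as in the original lora.py): freeze the base weights,
    # then inject trainable adapter layers, e.g. via
    # mlx_lm.tuner.utils.linear_to_lora_layers(...).
    model.freeze()
    return model

def setup_full_fine_tune(model):
    # Full fine-tune path: leave every parameter trainable.
    model.unfreeze()
    return model

# Either way, the optimizer simply receives whatever is trainable;
# 1e-5 matches the learning rate shown in the demo log below.
optimizer = optim.Adam(learning_rate=1e-5)
```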
Do you perform your full fine-tune in float32?
Fixed the load/save functions; fully tested with the Phi-2 2.8B model.
- Model file saving works well, and resuming training from a saved model also works.
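For reference, a minimal sketch of what full-weight saving and resuming can look like using the generic mlx.nn.Module helpers; the file and model names are taken from the test below, and the exact code in this PR may differ:

```python
from mlx_lm import load

# Load the pretrained model and tokenizer (Phi-2, as used in the test below).
model, tokenizer = load("microsoft/phi-2")

# Save the full model weights (not an adapter); nn.Module provides this helper.
model.save_weights("model.safetensors")

# To resume training later: reload the base model, then restore the saved weights.
model.load_weights("model.safetensors")
```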
> Do you perform your full fine-tune in float32?
No. I just copied the code from mlx-lm/lora.py and fixed it to run as a full fine-tune, keeping it compatible with the original tuner/* code.
Test example:
- M2 Studio / 192GB
- Model: Phi-2 2.8B
- Training set: chat completion / 400 items (iters set to only 10, to show a running demo)
```
$ python -m mlx_lm.full -c full_config.yaml
Loading configuration file full_config.yaml
Loading pretrained model
model file loaded: model.safetensors
Loading datasets
Training
Starting training..., iters: 10
Iter 1: Val loss 1.405, Val took 2.793s
Iter 5: Train loss 1.094, Learning Rate 1.000e-05, It/sec 0.956, Tokens/sec 779.347, Trained Tokens 4077, Peak mem 22.747 GB
Iter 10: Train loss 1.311, Learning Rate 1.000e-05, It/sec 0.668, Tokens/sec 589.665, Trained Tokens 8490, Peak mem 26.198 GB
Iter 10: Val loss 1.542, Val took 2.936s
Saved final adapter weights to adapter.npz.
Saved final model weights to model.safetensors.
```
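The full_config.yaml itself isn't shown in the thread; a config for a demo run like this would presumably follow the existing lora_config.yaml conventions, roughly along these lines (all keys and values here are assumptions, not taken from the PR):

```yaml
# Hypothetical full_config.yaml; keys mirror the existing LoRA YAML config.
model: "microsoft/phi-2"   # base model to fine-tune
train: true
data: "data"               # directory with train/valid JSONL files
batch_size: 4
iters: 10                  # short demo run, as in the log above
learning_rate: 1.0e-5      # matches the learning rate in the log
steps_per_report: 5
steps_per_eval: 10
max_seq_length: 2048
```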
Tried training qwen-1.8b. NaN loss immediately. Will try phi-2.
When I tried Gemma-2b, I got the same NaN loss. Maybe it's an issue in the base code, perhaps in models/*? I didn't check.
I think it's the float16.
Just checked: NaN with phi as well.
I was also getting NaNs using Qwen 14B against my dataset, but I couldn't reproduce it with the test data in lora/data. I tried again this morning with the latest updates on main for both mlx and mlx_lm, and I've reached 4K iterations so far without NaNs.
In the past it had been a float16 issue for me. I don't remember whether I quantized this one at 32 or 16, but the config.json of the locally converted model has:
```
{
  "architectures": [
    "Qwen2ForCausalLM"
  ],
  [..]
  "quantization": {
    "group_size": 64,
    "bits": 4
  },
  [..]
  "torch_dtype": "bfloat16",
  [..]
  "use_bfloat16": false,
}
```
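If float16 is indeed the culprit, one possible workaround (not something done in this PR, and only applicable to an unquantized model) is to cast the parameters up to float32 before training:

```python
import mlx.core as mx
from mlx.utils import tree_map
from mlx_lm import load

# Example model path; any unquantized model converted at float16 applies.
model, tokenizer = load("microsoft/phi-2")

# Promote every parameter to float32 and write the result back to the model,
# to rule out float16 overflow/underflow as the source of the NaN losses.
fp32_params = tree_map(lambda p: p.astype(mx.float32), model.parameters())
model.update(fp32_params)
```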
I opened an older issue (#620) about NaN values during training.
This is cool, and I think it would be nice to support. We might be able to do it with a far smaller diff, however. Something like:
- Have a training type field in the config
- If it's `full_fine_tune`, then don't freeze the model / don't use LoRA layers
Everything else should be the same. Wdyt?
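A minimal sketch of what that switch might look like, assuming a hypothetical `training_type` config field; the helper signature is approximated and may differ from the lora.py this PR was based on:

```python
from mlx_lm.tuner.utils import linear_to_lora_layers

def prepare_model(model, args):
    # `args` stands in for the parsed YAML config; `training_type` is the
    # hypothetical field suggested above, not necessarily what landed later.
    if getattr(args, "training_type", "lora") == "full_fine_tune":
        # Full fine-tune: keep every parameter trainable, no LoRA layers.
        model.unfreeze()
    else:
        # Default LoRA path: freeze the base model and inject adapter layers.
        model.freeze()
        linear_to_lora_layers(model, args.lora_layers, args.lora_parameters)
    return model
```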
This landed in #903.