
Added code for Full fine tune

ziozzang opened this issue 1 year ago · 10 comments

The code presented here is derived from the original lora.py file, with minimal modifications. The primary addition is the inclusion of full fine-tuning functionality, while preserving the core structure of the original code. This revised version offers a potential starting point for testing the training process on more powerful Mac devices.

Efforts were made to avoid altering any code within the tuner/* directory, ensuring that this update does not introduce any conflicts with the legacy codebase.

The code has been successfully tested on a Mac M2 Studio model with 192GB of memory, demonstrating its compatibility with high-performance hardware configurations.
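
For context, the core change relative to lora.py can be summarized with a short Python sketch (illustrative only, using public mlx / mlx_lm APIs; the model path, learning rate, and training-loop placement are assumptions, not the actual patch):

import mlx.optimizers as optim
from mlx_lm import load

model, tokenizer = load("microsoft/phi-2")  # any mlx-lm compatible model (assumed path)

# LoRA path (what lora.py does):
#   model.freeze()
#   <convert selected linear layers to LoRA>  # only adapter params remain trainable
#
# Full fine-tune path: leave the model unfrozen so the optimizer updates
# every parameter; the rest of the tuner/* training loop is reused as-is.
optimizer = optim.Adam(learning_rate=1e-5)

# ... existing tuner/* training loop runs here, unchanged ...

# Save the full weights instead of an adapter file.
model.save_weights("model.safetensors")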

ziozzang avatar Apr 03 '24 07:04 ziozzang

Do you perform your full fine-tune in float32?

N8python avatar Apr 04 '24 03:04 N8python

Fixed the load/save functions; fully tested with the Phi-2 (2.8B) model.

  • Model file saving works, and resuming training from a saved model also works.

Do you perform your full fine-tune in float32?

No. I just copied the code from mlx-lm/lora.py and modified it to run as a full fine-tune, keeping it compatible with the original tuner/* code.
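
As a rough illustration of what full-weight save/resume involves, here is a Python sketch using plain MLX APIs (the file paths and model are placeholders, and this is not necessarily the exact code in this change):

import mlx.core as mx
from mlx.utils import tree_flatten
from mlx_lm import load

model, tokenizer = load("path/to/phi-2")  # assumed local path

# Save every parameter as a flat name -> array mapping.
weights = dict(tree_flatten(model.parameters()))
mx.save_safetensors("model.safetensors", weights)

# Resume: load the checkpoint back onto the freshly loaded model
# and continue training from there.
model.load_weights("model.safetensors")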

ziozzang avatar Apr 04 '24 05:04 ziozzang

Test example:

  • M2 studio / 192GB
  • Model: phi-2 2.8B
  • Training set: chat completion, 400 items (iters set to only 10, to show a running demo)
$ python -m mlx_lm.full -c full_config.yaml

Loading configuration file full_config.yaml
Loading pretrained model
model file loaded: model.safetensors
Loading datasets
Training
Starting training..., iters: 10
Iter 1: Val loss 1.405, Val took 2.793s
Iter 5: Train loss 1.094, Learning Rate 1.000e-05, It/sec 0.956, Tokens/sec 779.347, Trained Tokens 4077, Peak mem 22.747 GB
Iter 10: Train loss 1.311, Learning Rate 1.000e-05, It/sec 0.668, Tokens/sec 589.665, Trained Tokens 8490, Peak mem 26.198 GB
Iter 10: Val loss 1.542, Val took 2.936s
Saved final adapter weights to adapter.npz.
Saved final model weights to model.safetensors.
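
For reference, full_config.yaml might look roughly like the sketch below. The key names are assumed to mirror mlx-lm's LoRA YAML config, and the values are partly inferred from the log above (iters, learning rate) and otherwise placeholders, so the actual file may differ:

# Hypothetical full_config.yaml (keys assumed, some values inferred from the log)
model: "path/to/phi-2"   # locally converted model directory
train: true
data: "data/"            # directory containing train.jsonl / valid.jsonl
batch_size: 4            # placeholder
iters: 10                # deliberately short, just to demo training
learning_rate: 1e-5
steps_per_report: 5
steps_per_eval: 10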

ziozzang avatar Apr 04 '24 05:04 ziozzang

Tried training qwen-1.8b. NaN loss immediately. Will try phi-2.

N8python avatar Apr 05 '24 03:04 N8python

Tried training qwen-1.8b. NaN loss immediately. Will try phi-2.

When I tried Gemma-2B, I got the same NaN loss. Maybe it's an issue in the foundation code, perhaps in models/*? I didn't check.

ziozzang avatar Apr 05 '24 03:04 ziozzang

I think it's the float16.

N8python avatar Apr 05 '24 03:04 N8python

Just checked: NaN with Phi as well.

N8python avatar Apr 05 '24 03:04 N8python

I was also getting NaN with Qwen 14B on my own dataset, but I couldn't reproduce it with the test data in lora/data. I tried again this morning with updates on main for both mlx and mlx_lm and have reached 4K iterations so far without NaNs.

In the past it had been a float16 issue for me. I don't remember if I quantized this one at 32 or 16, but the config.json of the locally converted model has:

{
    "architectures": [
        "Qwen2ForCausalLM"
    ],
    [..]
    "quantization": {
        "group_size": 64,
        "bits": 4
    },
    [..]
    "torch_dtype": "bfloat16",
    [..]
    "use_bfloat16": false,
}
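
Not part of this patch, but one common mitigation for fp16/bf16 NaN losses when fully fine-tuning an unquantized model is to upcast the parameters to float32 before training. A hedged Python sketch (the model path is a placeholder):

import mlx.core as mx
from mlx.utils import tree_map
from mlx_lm import load

model, tokenizer = load("path/to/unquantized-model")  # placeholder path

# Upcast every parameter to float32; peak memory roughly doubles vs. float16.
fp32_params = tree_map(lambda p: p.astype(mx.float32), model.parameters())
model.update(fp32_params)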

chimezie avatar Apr 05 '24 13:04 chimezie

I've opened an older issue (#620) about NaN loss values during training.

chimezie avatar Apr 09 '24 16:04 chimezie

This is cool, and I think it would be nice to support. We might be able to do it with a far smaller diff however. Something like:

  • Have a training type field in the config
  • If it's full_fine_tune then don't freeze the model / don't use LoRA layers

Everything else should be the same. Wdyt?
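
A rough sketch of that approach (illustrative only; the config key name and the exact linear_to_lora_layers signature are assumptions):

from mlx_lm.tuner.utils import linear_to_lora_layers

def prepare_model(model, config):
    if config.get("training_type") == "full_fine_tune":
        # Full fine-tune: keep every parameter trainable, no LoRA layers.
        model.unfreeze()
    else:
        # LoRA (existing behavior): freeze the base model and swap in
        # low-rank adapters on the last N transformer blocks.
        model.freeze()
        linear_to_lora_layers(model, config["lora_layers"], config["lora_parameters"])
    return model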

awni avatar Apr 17 '24 03:04 awni

This landed in #903

awni avatar Sep 30 '24 19:09 awni