
[Feature Request] LLaVA 1.6 LoRA fine-tuning example

Open tctrautman opened this issue 1 year ago • 17 comments

Building on the amazing work by @mzbac and @nkasmanoff in https://github.com/ml-explore/mlx-examples/pull/461, I'd really love an example of how LLaVA 1.6 (aka llava next) can be fine-tuned with a LoRA.

I might be able to make progress on this myself, but it'll take me some time. Any help or thoughts on how to best approach this would be appreciated. (Especially from @mzbac and/or @nkasmanoff.)

tctrautman avatar Mar 21 '24 21:03 tctrautman

I was waiting for LLaVA 1.6 support to be merged into transformers so that we would have consistent model weight naming conventions. I haven't looked at the details, but the most challenging part would be in https://github.com/ml-explore/mlx-examples/blob/main/llava/llava.py#L104. Do you mind starting a draft PR for it? I am happy to help in any way if you get stuck on the implementation.

mzbac avatar Mar 22 '24 03:03 mzbac

@mzbac sounds good! I'll take a look over the weekend and see how it goes -- a lot of this is going to be pretty new for me.

Quick question as I'm getting started: do you expect it to be complex to support both the Mistral and Vicuna models?

I ask because I see the line below where we've explicitly said we only support the Llama language model.

https://github.com/ml-explore/mlx-examples/blob/fbed720d6f9ac7c854c1fdd8d9954fa1ba47691c/llava/language.py#L210

I'm personally most interested in getting the Mistral 7B version working, but I'd be open to working with Vicuna if you expect that to be a better place to get started.

tctrautman avatar Mar 22 '24 05:03 tctrautman

Mistral uses the same Llama architecture. If the step that merges the image features with the input ids works properly, the language model part should just work out of the box. You can take a look at the model weights here -> https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf Also, check out this ASCII diagram I created for the LLaVA 1.5 model: https://gist.github.com/mzbac/00ebe60bb36fa4d8f65509f8e47350d5
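
As a rough illustration of the merge step (a sketch, not the code in llava.py), the snippet below splices image feature vectors into the text embedding sequence at the position of the image placeholder token. The placeholder id, the function name, and the toy 2-dimensional embeddings are assumptions made for illustration only.

```python
def merge_image_features(input_ids, text_embeddings, image_features, image_token_id):
    """Replace the single image placeholder with all image feature vectors."""
    merged = []
    for token_id, embedding in zip(input_ids, text_embeddings):
        if token_id == image_token_id:
            # One placeholder token expands into many image positions
            merged.extend(image_features)
        else:
            merged.append(embedding)
    return merged


IMAGE_TOKEN_ID = 32000  # assumed placeholder id, for illustration
input_ids = [1, IMAGE_TOKEN_ID, 42, 43]
text_embeddings = [[0.1, 0.1], [0.0, 0.0], [0.2, 0.2], [0.3, 0.3]]
image_features = [[9.0, 9.0], [8.0, 8.0], [7.0, 7.0]]

merged = merge_image_features(input_ids, text_embeddings, image_features, IMAGE_TOKEN_ID)
print(len(merged))  # 6: three text positions plus three image positions
```

The merged sequence is longer than the original input ids, which is why anything aligned with token positions (labels, masks) has to be adjusted after the merge.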

mzbac avatar Mar 22 '24 06:03 mzbac

I couldn't get 1.6 to work either as it said that only llama and not mistral models were supported.

On a further test, trying llava-hf/llava-v1.6-34b-hf, I got:

preprocessor_config.json: 100%|█████████████████████████████████████████████████████| 754/754 [00:00<00:00, 958kB/s]
tokenizer_config.json: 100%|███████████████████████████████████████████████████| 1.86k/1.86k [00:00<00:00, 4.99MB/s]
tokenizer.model: 100%|█████████████████████████████████████████████████████████| 1.03M/1.03M [00:00<00:00, 5.56MB/s]
added_tokens.json: 100%|█████████████████████████████████████████████████████████| 23.0/23.0 [00:00<00:00, 81.9kB/s]
special_tokens_map.json: 100%|█████████████████████████████████████████████████████| 748/748 [00:00<00:00, 3.56MB/s]
You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
config.json: 100%|█████████████████████████████████████████████████████████████| 1.41k/1.41k [00:00<00:00, 6.51MB/s]
Traceback (most recent call last):
  File "/Users/jrp/Documents/AI/mlx/mlx-examples/llava/mytest.py", line 3, in <module>
    processor, model = load_model("llava-hf/llava-v1.6-34b-hf")
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jrp/Documents/AI/mlx/mlx-examples/llava/generate.py", line 83, in load_model
    processor = AutoProcessor.from_pretrained(model_path)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/miniconda/base/envs/mlx/lib/python3.12/site-packages/transformers/models/auto/processing_auto.py", line 340, in from_pretrained
    raise ValueError(
ValueError: Unrecognized processing class in llava-hf/llava-v1.6-34b-hf. Can't instantiate a processor, a tokenizer, an image processor or a feature extractor for this model. Make sure the repository contains the files of at least one of those processing classes.

... updating the transformers library and installing protobuf seems to be more promising ...

jrp2014 avatar Mar 22 '24 22:03 jrp2014

@jrp2014 I ran into a similar processor error before I realized I hadn't updated HF transformers -- updating to 4.39.1 resolved it on my end.
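
A quick way to check the installed version before debugging further might look like the sketch below. Treating 4.39.1 as a minimum is an assumption based only on the report in this thread, and the helper names are hypothetical.

```python
from importlib import metadata


def version_tuple(version):
    """Turn '4.39.1' or '4.40.0.dev0' into a tuple of its leading integers."""
    parts = []
    for piece in version.split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


def meets_minimum(installed, minimum="4.39.1"):
    # 4.39.1 is the version reported to work in this thread (assumed minimum)
    return version_tuple(installed) >= version_tuple(minimum)


try:
    installed = metadata.version("transformers")
    status = "ok" if meets_minimum(installed) else "probably too old for LLaVA 1.6"
    print(installed, status)
except metadata.PackageNotFoundError:
    print("transformers is not installed")
```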

tctrautman avatar Mar 23 '24 01:03 tctrautman

I now get

You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Fetching 23 files: 100%|█████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 28676.87it/s]
Traceback (most recent call last):
  File "/Users/jrp/Documents/AI/mlx/mlx-examples/llava/mytest.py", line 3, in <module>
    processor, model = load_model("llava-hf/llava-v1.6-34b-hf")
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jrp/Documents/AI/mlx/mlx-examples/llava/generate.py", line 84, in load_model
    model = LlavaModel.from_pretrained(model_path)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jrp/Documents/AI/mlx/mlx-examples/llava/llava.py", line 178, in from_pretrained
    model.load_weights(list(weights.items()))
  File "/opt/homebrew/Caskroom/miniconda/base/envs/mlx/lib/python3.12/site-packages/mlx/nn/layers/base.py", line 203, in load_weights
    raise ValueError(f"Received parameters not in model: {extras}.")
ValueError: Received parameters not in model: image_newline.
python mytest.py  1.41s user 3.73s system 171% cpu 3.002 total

I'll have a look to see whether there are any obvious fixes.
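
For context, this error means the checkpoint contains a tensor name that the model class does not define. A small, framework-free sketch for comparing the two sets of names is shown below; apart from image_newline, the parameter names are made up for illustration.

```python
def diff_parameter_names(checkpoint_names, model_names):
    """Return (extras, missing) between a checkpoint and a model definition."""
    checkpoint_names, model_names = set(checkpoint_names), set(model_names)
    extras = sorted(checkpoint_names - model_names)   # in checkpoint, not in model
    missing = sorted(model_names - checkpoint_names)  # in model, not in checkpoint
    return extras, missing


# Illustrative names only, apart from image_newline
model_names = ["language_model.lm_head.weight", "multi_modal_projector.linear_1.weight"]
checkpoint_names = model_names + ["image_newline"]

extras, missing = diff_parameter_names(checkpoint_names, model_names)
print(extras, missing)  # ['image_newline'] []
```

In MLX the model-side names could be gathered from `mlx.utils.tree_flatten(model.parameters())`, and the checkpoint-side names from the keys of the loaded weights dictionary.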

jrp2014 avatar Mar 23 '24 11:03 jrp2014

Hey, great initiative! Excited to try out these new and improved LLaVA models. Agreed that using the standard transformers format makes it easier.

I've been thinking a bit about the fine tuning aspect, and can share some code snippets I have so far which might help the discussion.

There's a ton of variety in fine-tuning the vision models, but to me the easiest direction would be to start by fine-tuning only the language model. If you wanted to do just that, I am following the code from https://github.com/ml-explore/mlx-examples/tree/main/llms, where, at least to start, you can attach LoRA layers to the LLM. Here's how I have that for now.

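# linear_to_lora_layers and print_trainable_parameters are assumed to be the
# helpers from the mlx-examples llms LoRA code linked above.
# Freeze everything first; LoRA then adds the only trainable weights.
model.vision_tower.freeze()
model.multi_modal_projector.freeze()
model.language_model.freeze()

lora_layers = 32  # number of transformer blocks that receive LoRA adapters
lora_parameters = {
    # Llama-style attention names its output projection "o_proj", so
    # "self_attn.out_proj" probably matches no layer here.
    "keys": ["self_attn.q_proj", "self_attn.v_proj", "self_attn.k_proj", "self_attn.out_proj"],
    "rank": 64,
    "alpha": 16.0,
    "scale": 10.0,
    "dropout": 0.0,
}

# Only the language model gets adapters; the vision tower and projector stay frozen.
linear_to_lora_layers(model.language_model.model, lora_layers, lora_parameters)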

print_trainable_parameters(model.vision_tower)
print_trainable_parameters(model.multi_modal_projector)
print_trainable_parameters(model.language_model)

# Trainable parameters: 0.000% (0.000M/303.506M)
# Trainable parameters: 0.000% (0.000M/20.980M)
# Trainable parameters: 0.747% (50.332M/6738.940M)
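
The 50.332M figure can be reproduced with simple arithmetic. The sketch below assumes a hidden size of 4096 with square attention projections (the usual Llama/Vicuna 7B value, which is not stated in this thread) and counts the two low-rank matrices each adapter adds.

```python
def lora_param_count(in_features, out_features, rank):
    # LoRA adds A (in_features x rank) and B (rank x out_features)
    return rank * (in_features + out_features)


hidden_size = 4096  # assumed: Llama/Vicuna 7B hidden size
num_layers = 32     # lora_layers in the snippet above
rank = 64

per_projection = lora_param_count(hidden_size, hidden_size, rank)
three_matched = 3 * num_layers * per_projection
four_matched = 4 * num_layers * per_projection

print(three_matched)  # 50331648, i.e. 50.332M
print(four_matched)   # 67108864, i.e. 67.109M
```

Under these assumptions the reported count corresponds to three adapted projections per layer rather than four, which suggests that one of the four keys did not match any layer.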

I'm still not sure what the best approach would be for the language modeling piece. Maybe you mask everything except the answer? The fact that the shape of the inputs changes once you insert all the image embeddings makes this tough, and I haven't found any proper demos, at least so far. The caveat is that if you want your vision tower to get better, maybe you fine-tune that separately, or start with a CLIP-like model already fine-tuned to a different domain?
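
One way to realize the "mask everything except the answer" idea is to compute a per-token cross-entropy and average it only over answer positions. The following is a framework-free sketch with toy logits; the function names and numbers are illustrative, not taken from any existing trainer.

```python
import math


def token_cross_entropy(logits, target):
    """Cross-entropy at one position: -log softmax(logits)[target]."""
    peak = max(logits)  # subtract the max for numerical stability
    log_sum_exp = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return log_sum_exp - logits[target]


def masked_loss(all_logits, targets, loss_mask):
    """Average cross-entropy over the positions where loss_mask is 1."""
    total, count = 0.0, 0
    for logits, target, keep in zip(all_logits, targets, loss_mask):
        if keep:
            total += token_cross_entropy(logits, target)
            count += 1
    return total / max(count, 1)


# Toy sequence of 4 positions over a 3-word vocabulary
all_logits = [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0], [1.0, 1.0, 1.0]]
targets = [0, 1, 2, 0]
loss_mask = [0, 0, 1, 1]  # only the last two positions belong to the answer

print(masked_loss(all_logits, targets, loss_mask))
```

Because the image placeholder expands into many feature vectors, the mask has to be built (or expanded) in the coordinates of the merged sequence rather than the original input ids.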

nkasmanoff avatar Mar 23 '24 15:03 nkasmanoff

So for 1.6 it seems that we need to produce an image_newline token at the end of each image row. (I have no idea whether it is each actual row or what the role of patches is …).
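
To make the "newline at the end of each row" idea concrete, here is a simplified sketch that assumes the patch features are arranged as a grid of rows and that a single learned image_newline vector is appended after every row before flattening. It is an interpretation for illustration, not the reference implementation.

```python
def add_row_newlines(feature_grid, image_newline):
    """Flatten a grid of feature vectors, appending image_newline after each row."""
    sequence = []
    for row in feature_grid:
        sequence.extend(row)
        sequence.append(image_newline)
    return sequence


# Toy 2x3 grid of 2-dimensional features
feature_grid = [
    [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]],
    [[4.0, 4.0], [5.0, 5.0], [6.0, 6.0]],
]
image_newline = [0.5, 0.5]  # stands in for the learned image_newline parameter

sequence = add_row_newlines(feature_grid, image_newline)
print(len(sequence))  # 8: two rows of three features, each followed by a newline
```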

jrp2014 avatar Mar 23 '24 15:03 jrp2014

I’m unfortunately busy and won’t be able to take a closer look again until later today, but re: fine-tuning demos, this video might be helpful:

https://www.youtube.com/watch?v=eIziN2QUt8U

tctrautman avatar Mar 23 '24 16:03 tctrautman

Thanks. If I cheat and pass False to the load_weights call, so that it ignores the image_newline parameter mismatch, I get:

You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Fetching 23 files: 100%|███████████████████████████████████████████████| 23/23 [00:00<00:00, 319433.75it/s]
Traceback (most recent call last):
  File "/Users/jrp/Documents/AI/mlx/mlx-examples/llava/mytest.py", line 13, in <module>
    reply = generate_text(
            ^^^^^^^^^^^^^^
  File "/Users/jrp/Documents/AI/mlx/mlx-examples/llava/generate.py", line 97, in generate_text
    logits, cache = model(input_ids, pixel_values)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jrp/Documents/AI/mlx/mlx-examples/llava/llava.py", line 135, in __call__
    input_embddings = self.get_input_embeddings(input_ids, pixel_values)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jrp/Documents/AI/mlx/mlx-examples/llava/llava.py", line 79, in get_input_embeddings
    pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: Transpose axes don't match array dimensions.

jrp2014 avatar Mar 23 '24 18:03 jrp2014

We need to preprocess the image features before merging them with the input embeddings. You can find more information at https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_next/modeling_llava_next.py#L514-L553.
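
The transpose error above is consistent with pixel_values carrying an extra patch axis in LLaVA 1.6, i.e. (batch, patches, channels, height, width) instead of four axes. That layout is an assumption about the processor output, so treat the following as a sketch: it folds the patch axis into the batch axis and then moves channels last, which is what `transpose(0, 2, 3, 1)` does for a four-axis array. Nested lists are used so the example runs without any array library.

```python
def to_channels_last(image):
    """[C][H][W] nested lists -> [H][W][C]."""
    channels, height, width = len(image), len(image[0]), len(image[0][0])
    return [
        [[image[c][h][w] for c in range(channels)] for w in range(width)]
        for h in range(height)
    ]


def flatten_patches(pixel_values):
    """[B][P][C][H][W] -> [B*P][H][W][C]."""
    return [to_channels_last(patch) for image in pixel_values for patch in image]


def shape(nested):
    """Shape of a rectangular nested list."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return tuple(dims)


# 1 image, 2 patches, 3 channels, 2x2 pixels; each value encodes its channel index
pixel_values = [[[[[float(c)] * 2 for _ in range(2)] for c in range(3)] for _ in range(2)]]

flat = flatten_patches(pixel_values)
print(shape(pixel_values), "->", shape(flat))  # (1, 2, 3, 2, 2) -> (2, 2, 2, 3)
```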

I personally feel it would be better to copy the existing llava 1.5 and create a new example for 1.6, as there have been quite a few changes introduced in llava 1.6.

mzbac avatar Mar 24 '24 08:03 mzbac

> I’m unfortunately busy and won’t be able to take a closer look again until later today, but re: fine-tuning demos, this video might be helpful: https://www.youtube.com/watch?v=eIziN2QUt8U

It did, in that it confirmed that the only way to do this, it seems, is to train on what comes in the answer, or after the image. With that in mind I have a quick implementation working here, which fine-tunes on a dataset where every token after the <image> one is trained.

Since LLaVA works with <image> in arbitrary places, supporting that would be a cool enhancement, but it isn't feasible with how it's set up, from what I could tell.
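
A sketch of the "train only on what follows the image" idea, assuming the image is marked by a placeholder token in the input ids. The token id and function name are hypothetical, and returning an all-zero mask for samples without an image is just one possible design choice.

```python
def train_mask_after_image(input_ids, image_token_id):
    """Return 1 for every token after the last image placeholder, else 0."""
    if image_token_id not in input_ids:
        # No image in this sample: nothing is trained on (a design choice)
        return [0] * len(input_ids)
    last_image = len(input_ids) - 1 - input_ids[::-1].index(image_token_id)
    return [1 if i > last_image else 0 for i in range(len(input_ids))]


IMAGE_TOKEN_ID = 32000  # assumed placeholder id, for illustration
input_ids = [1, 15, IMAGE_TOKEN_ID, 20, 21, 2]

print(train_mask_after_image(input_ids, IMAGE_TOKEN_ID))  # [0, 0, 0, 1, 1, 1]
```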

nkasmanoff avatar Mar 24 '24 18:03 nkasmanoff

Unfortunately, it's taking me longer than I had hoped to get up to speed -- I still might be able to get a draft PR up but it won't be for a while longer.

tctrautman avatar Mar 24 '24 22:03 tctrautman

I am really interested in this work, but I don't know where to start. Have you had time to make progress, by any chance, @tctrautman? I know you are doing this in your spare time and we are all grateful for that. Please let me know if I can help with anything.

ahmetkca avatar Apr 11 '24 20:04 ahmetkca

@ahmetkca Unfortunately I haven't had time to make any progress on this, and I'm honestly not sure when I will 😞

tctrautman avatar Apr 12 '24 19:04 tctrautman

PR to support LLaVA v1.6 will be merged to mlx-vlm early tomorrow :)

And the trainer (FFT and LoRA) should follow soon.

https://github.com/Blaizzy/mlx-vlm/pull/43

Blaizzy avatar Jun 22 '24 00:06 Blaizzy

> PR to support LLaVA v1.6 will be merged to mlx-vlm early tomorrow :)
>
> And the trainer (FFT and LoRA) should follow soon.
>
> Blaizzy/mlx-vlm#43

Can the models fused with MLX-trained LoRA adapters be used in other environments outside MLX such as ollama?

Thanks

amirvenus avatar Feb 21 '25 09:02 amirvenus