[bug] Unable to freeze specific layers of a pretrained model
🐞 Describe the Bug
I'm trying to freeze specific layers of a pretrained model (for example, only layer 0).
The problem is that loading a pretrained model like Apriel-Thinker loads its decoder config as a FixedBlockSequenceConfig, whereas freezing only certain layers requires a pattern-block config, e.g.:

```yaml
decoder:
  type: pattern
  pattern:
    - train_block
    - freeze_block
    - freeze_block
  blocks:
    train_block:
      mlp:
        lr_scale: 1.e-12
    freeze_block:
      lr_scale: 1.e-12
```
We currently cannot reconcile these two configs.
The natural workaround would be to skip loading the pretrained config with `load_config: none` and pass the entire block config explicitly.
However, this does not work currently either, because some `type` parameters from the pretrained checkpoint creep into the decoder config:

```yaml
'!!! block':
  type: decoder
  mixer:
    type: attention
    rotary:
      type: none
  mlp:
    type: mlp
  normalization:
    type: layer_norm
```
🔄 Steps to Reproduce
Steps to reproduce the behavior:
- Fast-LLM version: https://github.com/ServiceNow/Fast-LLM/tree/b7c0de61662c61e83c617bd8157d0bf9426e3d52
- Train with the following config:

  ```yaml
  pretrained:
    format: mistral
    path: /mnt/checkpoints/upstream/Apriel-Nemotron-15b-Thinker-reinit-attn-layer-0
    load_config: none
  model:
    base_model:
      decoder:
        type: pattern
        pattern:
          - train_block
          - freeze_block
          - freeze_block
          - freeze_block
        blocks:
          train_block:
            mlp:
              lr_scale: 1.e-12
          freeze_block:
            lr_scale: 1.e-12
    type: gpt
  ```
- Fails during config validation with:

  ```
  fast_llm.config.NestedValidationError: Validation failed for field `model` of type `fast_llm.models.gpt.config.GPTModelConfig` in class fast_llm.models.gpt.config.GPTTrainerConfig:
  Validation failed for field `base_model` of type `fast_llm.models.gpt.config.GPTBaseModelConfig` in class fast_llm.models.gpt.config.GPTModelConfig:
  Validation failed for field `decoder` of type `fast_llm.layers.block.config.BlockSequenceConfig` in class fast_llm.models.gpt.config.GPTBaseModelConfig:
  Unknown field `block` in class fast_llm.layers.block.config.PatternBlockSequenceConfig
  ```
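As a standalone sketch of what appears to go wrong (illustrative only, not Fast-LLM code): the `block` field serialized by the checkpoint's FixedBlockSequenceConfig survives the config merge, and PatternBlockSequenceConfig has no such field, so validation rejects it.

```python
# Illustrative sketch (not Fast-LLM code) of the failure mode: the pretrained
# checkpoint serializes a FixedBlockSequenceConfig, whose `block` key survives
# a naive merge with the user's pattern config and is then rejected as an
# unknown field by PatternBlockSequenceConfig's validation.
pretrained_decoder = {"type": "fixed", "block": {"mlp": {}}, "num_blocks": 50}
user_decoder = {
    "type": "pattern",
    "pattern": ["train_block", "freeze_block"],
    "blocks": {"train_block": {}, "freeze_block": {}},
}

merged = {**pretrained_decoder, **user_decoder}  # user values win on conflicts

# Hypothetical field set of the pattern config class:
PATTERN_FIELDS = {"type", "pattern", "blocks", "num_blocks"}
unknown = set(merged) - PATTERN_FIELDS
# -> {"block"}, i.e. "Unknown field `block` in class PatternBlockSequenceConfig"
```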
The decoder config would look like:

```yaml
decoder:
  type: pattern
  blocks:
    train_block:
      [...]
    freeze_block:
      [...]
  pattern:
    - train_block
    - freeze_block
    - freeze_block
    - freeze_block
  num_blocks: 50
  '!!! block':  # <--- undesired entry coming from the pretrained checkpoint
    type: decoder
    mixer:
      type: attention
      rotary:
        type: none
    mlp:
      type: mlp
    normalization:
      type: layer_norm
```
🎯 Expected Behavior
With `load_config: none`, only the current (training) config should be used; nothing from the pretrained checkpoint's config should leak in.
Even after patching the creeping `type` parameters (https://github.com/ServiceNow/Fast-LLM/commit/f7a0837d5ba134a3941d2599dfc174b3eb3ef62f), loading the pretrained model into a PatternBlockSequenceConfig fails:
```
...
  File "/home/toolkit/code/Fast-LLM/fast_llm/engine/multi_stage/fast_llm_model.py", line 38, in load_checkpoint
    converter = config.format.get_handler_class()(self)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/toolkit/code/Fast-LLM/fast_llm/engine/checkpoint/huggingface.py", line 43, in __init__
    self._exported_config = self._export_config(model.config)
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/toolkit/code/Fast-LLM/fast_llm/engine/checkpoint/huggingface.py", line 120, in _export_config
    cls.base_model_converter_class.export_config(config.base_model),
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/toolkit/code/Fast-LLM/fast_llm/models/gpt/conversion/llama.py", line 522, in export_config
    cls.decoder_converter_class.export_config(config.decoder),
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/toolkit/code/Fast-LLM/fast_llm/models/gpt/conversion/llama.py", line 424, in export_config
    Assert.custom(isinstance, config, FixedBlockSequenceConfig)
  File "/home/toolkit/code/Fast-LLM/fast_llm/utils.py", line 198, in custom
    assert fn(
           ^^^
AssertionError: Assertion failed: fn(
```
The above can be fixed by adding support for compatible pattern block sequences in the Llama conversion. See: https://github.com/ServiceNow/Fast-LLM/pull/388/commits/6f2d5e3070f3d4188dd4aab6c63d194d1da916c1 and https://github.com/ServiceNow/Fast-LLM/pull/388/commits/52517190ecf61f16dedc8e68c7d305af2beece74
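The shape of that fix can be sketched as follows. This is a simplified, hypothetical stand-in for the linked commits (class and field names are illustrative, not Fast-LLM's actual API): a pattern block sequence is exportable whenever every block referenced by the pattern is architecturally identical, i.e. differs only in training-only knobs like `lr_scale`.

```python
from dataclasses import dataclass

# Hypothetical, simplified stand-ins for Fast-LLM's config classes.
@dataclass
class BlockConfig:
    hidden_size: int = 4096
    lr_scale: float = 1.0  # training-only; must not affect the exported model

    def architecture(self) -> tuple:
        # Architecture-relevant fields only; lr_scale is deliberately excluded.
        return (self.hidden_size,)

@dataclass
class PatternBlockSequenceConfig:
    blocks: dict[str, BlockConfig]
    pattern: list[str]
    num_blocks: int

def export_decoder(config: PatternBlockSequenceConfig) -> dict:
    # A pattern sequence is exportable iff all its blocks share one architecture,
    # in which case it can be exported exactly like a fixed block sequence.
    architectures = {config.blocks[name].architecture() for name in config.pattern}
    if len(architectures) != 1:
        raise ValueError("pattern blocks are not architecturally compatible")
    block = config.blocks[config.pattern[0]]
    return {"num_hidden_layers": config.num_blocks, "hidden_size": block.hidden_size}
```

With this shape, a freeze pattern like the one above (blocks differing only in `lr_scale`) would export cleanly, while a pattern mixing genuinely different architectures would still be rejected.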