[Bug]: SDXL DoRA training fails to start when base model weight data type is set to float8
What happened?
When I activate Decomposed Weights (DoRA) training in the "Lora" tab and have the base SDXL model loaded as float8 in the "model" tab, the training process fails to start with RuntimeError: Promotion for Float8 Types is not supported, attempted to promote Float8_e4m3fn and Half.
If the DoRA toggle is deactivated, training starts as usual.
What did you expect would happen?
DoRA training to work.
Relevant log output
epoch: 0%| | 0/100 [00:26<?, ?it/s]
Traceback (most recent call last):
File "/home/user/apps/OneTrainer/scripts/train.py", line 38, in <module>
main()
File "/home/user/apps/OneTrainer/scripts/train.py", line 29, in main
trainer.train()
File "/home/user/apps/OneTrainer/modules/trainer/GenericTrainer.py", line 575, in train
model_output_data = self.model_setup.predict(self.model, batch, self.config, train_progress)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/apps/OneTrainer/modules/modelSetup/BaseStableDiffusionXLSetup.py", line 467, in predict
predicted_latent_noise = model.unet(
^^^^^^^^^^^
File "/home/user/apps/OneTrainer/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/apps/OneTrainer/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/apps/OneTrainer/venv/src/diffusers/src/diffusers/models/unets/unet_2d_condition.py", line 1135, in forward
emb = self.time_embedding(t_emb, timestep_cond)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/apps/OneTrainer/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/apps/OneTrainer/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/apps/OneTrainer/venv/src/diffusers/src/diffusers/models/embeddings.py", line 376, in forward
sample = self.linear_1(sample)
^^^^^^^^^^^^^^^^^^^^^
File "/home/user/apps/OneTrainer/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/apps/OneTrainer/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/apps/OneTrainer/modules/module/LoRAModule.py", line 374, in forward
WP = self.orig_module.weight + (self.make_weight(A, B) * (self.alpha / self.rank))
~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
RuntimeError: Promotion for Float8 Types is not supported, attempted to promote Float8_e4m3fn and Half
Output of pip freeze
No response
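For context, the crash reduces to a PyTorch restriction: float8 tensors are excluded from type promotion, so any mixed-dtype arithmetic involving them fails immediately. A minimal sketch that reproduces the same error outside OneTrainer (assuming PyTorch 2.1+ for float8 support):

```python
import torch

# Base weight stored as float8, LoRA/DoRA delta computed in half precision,
# mirroring the failing line in LoRAModule.forward.
w = torch.zeros(4, 4, dtype=torch.float8_e4m3fn)
delta = torch.zeros(4, 4, dtype=torch.float16)

# PyTorch refuses to promote float8 in mixed-dtype ops, so this raises:
# RuntimeError: Promotion for Float8 Types is not supported, ...
wp = w + delta
```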
Handling this correctly is a fairly subtle problem and might depend on the state of the Stable Diffusion ecosystem at large. It's not necessarily as simple as upcasting the float8 weights to the LoRA dtype, or vice versa. Sorry there's no fast fix here.
Autocast doesn't seem to work in my experiments, sadly (I wrapped the forward pass in an autocast context for the train_dtype, which was float16). We'd need to explicitly upcast the fp8 tensors. I'm not sure how to guarantee that's a safe operation that reflects unit scaling. Does Torch even have a standard unit scaling API, or is this implemented in an ad-hoc way by Nvidia Transformer Engine and others?
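A minimal sketch of the explicit-upcast idea, assuming the fp8 weights are unscaled (hypothetical names; make_weight in LoRAModule presumably computes B @ A, and this is not OneTrainer's actual code):

```python
import torch

FLOAT8_DTYPES = (torch.float8_e4m3fn, torch.float8_e5m2)

def compose_dora_weight(orig_weight, A, B, alpha, rank):
    # Explicitly upcast fp8 base weights to the delta's dtype before the
    # addition that currently crashes. Only valid if the fp8 checkpoint is
    # unscaled; a scaled format would need its scale factor multiplied in.
    w = orig_weight
    if w.dtype in FLOAT8_DTYPES:
        w = w.to(A.dtype)
    return w + (B @ A) * (alpha / rank)
```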
Looks like there's no standard unit scaling. It would be nice if bitsandbytes supported FP8; then everyone would just use it as the standard. We're going to have to arbitrarily decide that FP8 weights are unscaled in order to fix this bug. To my knowledge, anything in the community that is actually FP8 is unscaled, so this shouldn't be a big restriction. If we get future checkpoints in FP8 that were trained with Transformer Engine or something else, we'll need to revisit that. I'm hopeful that everyone just uses bitsandbytes for everything.
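To illustrate the distinction being assumed away: a scaled fp8 format (Transformer Engine style, layout hypothetical here) ships a scale factor alongside the tensor, and a bare upcast would silently drop it:

```python
import torch

w_fp8 = torch.ones(2, 2, dtype=torch.float8_e4m3fn)
scale = torch.tensor(0.5, dtype=torch.float16)  # would accompany a scaled checkpoint

unscaled_view = w_fp8.to(torch.float16)        # what this fix assumes
scaled_view = w_fp8.to(torch.float16) * scale  # what a scaled format would need
```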
@ikarsokolov There is now a beta branch with proper support for fp8 LoRA training; the branch name is fp8. Please give it a try if you know how to use branches.