guided-diffusion

How can I train the model using Accelerator?

Open • huanghaosen110 opened this issue 1 year ago • 1 comment

The official training code is too slow, so I am using Accelerate's Accelerator to speed up training. With mixed_precision="bf16", the forward pass runs fine, but the backward pass always fails with this error:

  File "E:\DeepLearningProject\Anti-DreamBooth-main2\train2.py", line 233, in <module>
    trainUnet()
  File "E:\DeepLearningProject\Anti-DreamBooth-main2\train2.py", line 195, in trainUnet
    accelerator.backward(total_loss)
  File "E:\miniconda\envs\py310\lib\site-packages\accelerate\accelerator.py", line 1853, in backward
    loss.backward(**kwargs)
  File "E:\miniconda\envs\py310\lib\site-packages\torch\_tensor.py", line 522, in backward
    torch.autograd.backward(
  File "E:\miniconda\envs\py310\lib\site-packages\torch\autograd\__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "E:\miniconda\envs\py310\lib\site-packages\torch\autograd\function.py", line 289, in apply
    return user_fn(self, *args)
  File "E:\DeepLearningProject\Anti-DreamBooth-main2\guided_diffusion\nn.py", line 168, in backward
    output_tensors = ctx.run_function(*shallow_copies)
  File "E:\DeepLearningProject\Anti-DreamBooth-main2\guided_diffusion\unet.py", line 304, in _forward
    qkv = self.qkv(self.norm(x))
  File "E:\miniconda\envs\py310\lib\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "E:\miniconda\envs\py310\lib\site-packages\torch\nn\modules\module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "E:\miniconda\envs\py310\lib\site-packages\torch\nn\modules\conv.py", line 310, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "E:\miniconda\envs\py310\lib\site-packages\torch\nn\modules\conv.py", line 306, in _conv_forward
    return F.conv1d(input, weight, bias, self.stride,
RuntimeError: Input type (struct c10::BFloat16) and bias type (float) should be the same
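A plausible reading of the traceback: the failure happens inside the custom CheckpointFunction in guided_diffusion/nn.py, which re-runs the block's forward pass during backward. Accelerate applies autocast around the model's forward call, not around accelerator.backward, so that recomputation runs without autocast: the saved input is already bfloat16 while the Conv1d weight and bias are still float32, which matches the RuntimeError. Assuming that diagnosis, one possible workaround is to route checkpointing through torch.utils.checkpoint, which records the forward pass's autocast state and re-enters it when recomputing. The sketch below keeps the (func, inputs, params, flag) signature of guided-diffusion's checkpoint helper; TinyAttentionStem is a hypothetical stand-in for the norm + qkv part of AttentionBlock, not the repository's code. The demo uses CPU autocast so it runs anywhere; on a GPU the device_type would be "cuda".

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint as torch_checkpoint


def checkpoint(func, inputs, params, flag):
    """Sketch of a replacement for guided_diffusion.nn.checkpoint.

    torch.utils.checkpoint saves the autocast state active during forward and
    restores it while recomputing in backward, so bf16 activations meet
    bf16-cast weights again. `params` is kept only for signature compatibility:
    the non-reentrant variant tracks parameter gradients on its own.
    """
    if flag:
        return torch_checkpoint(func, *inputs, use_reentrant=False)
    return func(*inputs)


class GroupNorm32(nn.GroupNorm):
    # Mirrors guided-diffusion's GroupNorm32: normalize in fp32, cast back.
    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)


class TinyAttentionStem(nn.Module):
    """Hypothetical stand-in for the norm + qkv part of AttentionBlock."""

    def __init__(self, channels=8, use_checkpoint=True):
        super().__init__()
        self.norm = GroupNorm32(4, channels)
        self.qkv = nn.Conv1d(channels, channels * 3, 1)
        self.use_checkpoint = use_checkpoint

    def forward(self, x):
        return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)

    def _forward(self, x):
        return self.qkv(self.norm(x))


if __name__ == "__main__":
    model = TinyAttentionStem()
    # bf16 activation, like the output of an earlier autocast layer.
    x = torch.randn(2, 8, 16).to(torch.bfloat16)
    # Forward under autocast, as Accelerate does for mixed_precision="bf16"...
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        out = model(x)
    # ...then backward outside it, which is where the recompute used to fail.
    out.float().sum().backward()
    print(out.dtype, model.qkv.weight.grad.dtype)
```

Turning gradient checkpointing off (use_checkpoint=False on the UNet) should also avoid the recompute entirely, at the cost of storing more activations.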

huanghaosen110, Jun 04 '24

Have you fixed it yet? Thanks!

Benny0323, Nov 07 '25