RuntimeError: expected scalar type Half but found Float
Hello, I am trying to use amp for mixed precision training of my model. I am implementing a multi-task learning algorithm, so my loss is the sum of loss_a and loss_b, which I then back-propagate. When I do so, following the amp usage instructions, I get:
RuntimeError: expected scalar type Half but found Float
I also tried back-propagating one loss at a time while retaining the graph, following your instructions for multiple losses. I get the same error.
In the previous cases, I didn't use model.half(). When I did, my losses started becoming NaN.
Do you have any suggestions on how to proceed?
Thanks in advance! VglsD
*Edit: I should also note that my network has batch normalization layers.
Environment: Miniconda, Python 3.7, Cuda 10, PyTorch 1.0, apex master branch
If your backward pass looks like

```python
loss = loss_a + loss_b
with amp_handle.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
```

the fact that loss is built as a sum of two sub-losses shouldn't affect anything. PyTorch doesn't care; it just sees the combination as one big graph. Also, you shouldn't have to call model.half() when using amp.
This seems like a relatively benign use case. Do you have a minimal/standalone repro?
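The "one big graph" point can be checked in plain PyTorch on CPU. This is a sketch, not the reporter's model: the two-headed network, shapes and loss functions are hypothetical, and amp is left out so it runs without a GPU. It compares gradients from backward() on loss_a + loss_b against two separate backward() calls with retain_graph=True (the multiple-losses pattern from the original post).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class TwoHeadNet(nn.Module):
    """Hypothetical multi-task model: a shared trunk and two heads."""

    def __init__(self):
        super().__init__()
        self.trunk = nn.Linear(8, 16)
        self.head_a = nn.Linear(16, 1)
        self.head_b = nn.Linear(16, 3)

    def forward(self, x):
        h = torch.relu(self.trunk(x))
        return self.head_a(h), self.head_b(h)

def grads(model):
    # Snapshot every parameter's .grad.
    return [p.grad.detach().clone() for p in model.parameters()]

model = TwoHeadNet()
x = torch.randn(4, 8)
y_a = torch.randn(4, 1)
y_b = torch.randint(0, 3, (4,))

# Variant 1: sum the losses, one backward() call.
out_a, out_b = model(x)
loss = nn.functional.mse_loss(out_a, y_a) + nn.functional.cross_entropy(out_b, y_b)
model.zero_grad()
loss.backward()
summed = grads(model)

# Variant 2: one backward() per loss. retain_graph keeps the shared trunk's
# saved tensors alive for the second call; .grad accumulates across calls.
out_a, out_b = model(x)
loss_a = nn.functional.mse_loss(out_a, y_a)
loss_b = nn.functional.cross_entropy(out_b, y_b)
model.zero_grad()
loss_a.backward(retain_graph=True)
loss_b.backward()
separate = grads(model)
```

Both variants produce the same gradients up to floating-point rounding, so a dtype error is unlikely to come from summing the losses itself.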
Same problem here. At first I could use the apex package without issues at opt level O1, but after I added Stochastic Weight Averaging (SWA), this problem appeared.
@USTClj Could you post a code snippet to reproduce this issue?
I tried to reproduce it using opt_level='O1', but it seems to work on my machine using:
```python
import torch.optim as optim
from torchcontrib.optim import SWA
from apex import amp

optimizer = optim.SGD(model.parameters(), lr=1e-3)
optimizer = SWA(optimizer, swa_start=1, swa_freq=1, swa_lr=0.05)
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
```
The problem appears when I train the network with frozen BN layers. I can reload the net and continue training if I convert the BN layers to half manually.
I'm not sure I understand your use case completely. Even if I freeze the affine parameters in all batchnorm layers, my code snippet still works. Could you have a look at this code and tell me what I would have to change to reproduce your issue?
@ptrblck @mcarilli I am getting the same error, but strangely it always appears after one training epoch; the first epoch is always fine. It also goes away if I use keep_batchnorm_fp32=False. And I am not summing losses; it's just a normal CrossEntropyLoss.
So I guess we have a common problem when batch norm layers are kept in fp32.
Hi @Msabih,
do you have a small code snippet to reproduce this issue? I've posted some code in the last post, which is working fine for me. If you don't have a code snippet ready, could you take a look at my code and compare it to yours, and let me know what I should change to trigger this issue?
@ptrblck I figured out the issue in my case. The problem occurs when training and validation run in the same epoch: model.eval() is called for validation, and the error then occurs at the next training iteration. The solution is to call model.train() at the beginning of the training loop.
Could you call model.eval() at the end of an epoch in your script and see if the error appears? Then call model.train() and check whether the error goes away.
Normally, if distributed data parallel is not used, we don't need to call model.train() explicitly. That's why this might be a common mistake for anyone converting code from single-GPU/DataParallel to DistributedDataParallel.
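A minimal sketch of the loop structure described above, with model.train() at the top of every epoch so the model.eval() from the previous epoch's validation is undone before training resumes. The tiny model and random batches are hypothetical placeholders; amp is omitted so it runs on CPU.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical model with a BatchNorm layer at index 1; data is random.
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 2))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
train_batches = [(torch.randn(16, 4), torch.randint(0, 2, (16,))) for _ in range(3)]
val_batches = [(torch.randn(16, 4), torch.randint(0, 2, (16,))) for _ in range(2)]

modes_seen = []  # BN layer's .training flag at each training step

for epoch in range(2):
    model.train()  # every epoch, not just once before the loop
    for x, y in train_batches:
        modes_seen.append(model[1].training)
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()

    model.eval()  # validation: BN switches to its running statistics
    with torch.no_grad():
        for x, y in val_batches:
            criterion(model(x), y)
```

If the model.train() line is moved above the epoch loop, the second epoch trains with BN still in eval mode, which is the situation that triggered the error here.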
With opt_level="O1", when switching from train() to eval(), there used to be some strange errors that can result from the way Amp caches casts. However, we fixed what we observed, and I actually have a test for that exact circumstance: https://github.com/NVIDIA/apex/blob/master/tests/L0/run_amp/test_cache.py#L70. However, this test does not use batchnorm.
What opt_level are you using? And can you post the full backtrace of the error you receive?
@mcarilli I got the error regardless of the opt_level, on both O1 and O2. As I said, the error was due to me not calling train() after eval() for the next epoch, so it's not really a bug. However, since DataParallel and single-GPU models seem to work without explicitly calling train() (only calling eval()), others may fall into the same mistake, so it's good to highlight this.
I will post the backtrace anyways later.
In my case, I do switch to train() in the training loop, and I freeze the BN layers right after that train() call. I'll share my code in a few days.
Freeze BN:
```python
def fix_bn(m):
    classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
        m.eval().half()  # with half() added, we can train the network successfully
```
An example: https://github.com/jialee93/Improved-Body-Parts/blob/0cb0d94ea94518406319ca1091c0b62f5e86f92a/train_distributed_SWA.py#L221
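Since nn.Module.train() recursively puts every submodule (BN layers included) back into training mode, a function like fix_bn has to be re-applied after each model.train() call, which matches the ordering described above. The sketch below is a hypothetical, CPU-only variant: it drops the .half() so it runs in fp32 and only shows the mode handling.

```python
import torch.nn as nn

def fix_bn(m):
    # Same class-name check as fix_bn above; .half() is omitted here (CPU, fp32).
    if m.__class__.__name__.find('BatchNorm') != -1:
        m.eval()

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

model.train()        # puts every submodule, BN included, into training mode
model.apply(fix_bn)  # freeze BN after train()
bn_frozen_after_apply = not model[1].training

model.train()        # e.g. the start of the next epoch
bn_frozen_after_retrain = not model[1].training  # False: train() undid the freeze
```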
I solved the BN layer issue following @jialee93's workaround, but then ran into other float/half dtype mismatch issues during backpropagation through custom loss functions.
https://github.com/tianzhi0549/FCOS/issues/121
@mcarilli Is it possible to use FP32 BN params in eval() mode? I would like to use the model to generate adversarial examples within the validation loop, which requires backprop through the network to find the adversarial perturbation in the input space. I currently have to cast the BN params to .half() like @jialee93 mentioned, but this significantly reduces the model's accuracy due to the loss of precision.
I am using an FP16 model and optimizer. The code is an adaptation of the following example which uses APEX to convert to FP16: https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/RN50v1.5
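For reference, backprop to the input works with BN in eval mode and fp32 BN params in plain PyTorch. This sketch uses the Fast Gradient Sign Method (FGSM) as an illustrative stand-in for whatever attack is actually used; the model is hypothetical and everything runs in fp32 on CPU, so it shows the eval-mode backward path but not the amp/fp16 interaction this issue is about.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical classifier; BN parameters and buffers stay in fp32.
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10),
)
model.eval()  # BN uses running statistics; gradients still flow through it

def fgsm(model, x, y, eps):
    """One-step FGSM: move x along the sign of the loss gradient w.r.t. x."""
    x = x.clone().detach().requires_grad_(True)
    loss = nn.functional.cross_entropy(model(x), y)
    (grad,) = torch.autograd.grad(loss, x)  # gradient w.r.t. the input only
    return (x + eps * grad.sign()).detach()

x = torch.rand(2, 3, 16, 16)
y = torch.tensor([1, 7])
running_mean_before = model[1].running_mean.clone()
x_adv = fgsm(model, x, y, eps=8 / 255)
```

In eval mode the BN running statistics are not updated by this forward/backward pass, so the validation model is left unchanged.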
@ptrblck I have modified your example to highlight the issue: https://gist.github.com/shoaibahmed/35ea724e742f51f317eee8118424b053
I still get the same error when generating adversarial examples using backprop. Any updates?
@LayneH I added a custom wrapper on top of the BN layer in order to make it compatible. I have added the code for your reference. Just let me know if you still encounter any errors in this regard.
```python
import torch.nn as nn

class BatchNormWrapper(nn.Module):
    def __init__(self, m):
        super(BatchNormWrapper, self).__init__()
        self.m = m
        self.m.eval()  # Set the batch norm to eval mode

    def forward(self, x):
        input_type = x.dtype
        x = self.m(x.float())
        return x.to(input_type)

# fp16 and model_and_loss are defined elsewhere in the training script
if fp16:
    # Replace the BN layers with a custom wrapper.
    # Otherwise, PyTorch won't use the cuDNN batch norm, resulting in errors for FP16 inputs to the BN module.
    # The wrapper casts the input to float, performs batch norm and then casts back to the input data type.
    def add_fp16_bn_wrapper(model):
        for child_name, child in model.named_children():
            if isinstance(child, nn.BatchNorm2d):
                setattr(model, child_name, BatchNormWrapper(child))
            else:
                add_fp16_bn_wrapper(child)

    add_fp16_bn_wrapper(model_and_loss)
```
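One thing to watch with the wrapper: the self.m.eval() in __init__ only holds until the next model.train(), because nn.Module.train() recurses into children and puts self.m back into training mode (the same train()/eval() interaction discussed earlier in this thread). Below is a hypothetical variant, EvalBatchNormWrapper, that overrides train() to keep the wrapped layer frozen. The input is fp16 and the BN runs in fp32 on CPU, so the dtype round-trip can be checked without a GPU.

```python
import torch
import torch.nn as nn

class EvalBatchNormWrapper(nn.Module):
    """Hypothetical variant of BatchNormWrapper that keeps the wrapped BN in eval mode."""

    def __init__(self, m):
        super().__init__()
        self.m = m
        self.m.eval()

    def train(self, mode=True):
        # nn.Module.train() would otherwise switch self.m back to training mode.
        super().train(mode)
        self.m.eval()
        return self

    def forward(self, x):
        input_type = x.dtype
        # BN runs in fp32; the output is cast back to the input dtype.
        return self.m(x.float()).to(input_type)

model = nn.Sequential(nn.Conv2d(3, 4, 3), EvalBatchNormWrapper(nn.BatchNorm2d(4)))
model.train()  # the wrapper itself is in train mode, the wrapped BN is not

x = torch.randn(2, 4, 5, 5).half()
out = model[1](x)  # call the wrapper directly: fp16 in, fp16 out
```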
@shoaibahmed Thanks for your code snippet!
BTW, I adapted the add_fp16_bn_wrapper function to make it compatible with sync BN:
```python
def add_fp16_bn_wrapper(model):
    for child_name, child in model.named_children():
        classname = child.__class__.__name__
        if classname.find('BatchNorm') != -1:
            setattr(model, child_name, BatchNormWrapper(child))
        else:
            add_fp16_bn_wrapper(child)
```
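The difference between the two checks can be seen by listing which modules each one matches. This sketch uses torch's own nn.SyncBatchNorm as a stand-in for a sync BN layer (an assumption; apex has its own SyncBatchNorm), and the model is only inspected, never run.

```python
import torch.nn as nn

# Only the module types matter here; the model is never run.
model = nn.ModuleDict({
    'bn2d': nn.BatchNorm2d(8),
    'sync_bn': nn.SyncBatchNorm(8),  # torch's SyncBatchNorm as a stand-in
    'bn1d': nn.BatchNorm1d(8),
    'conv': nn.Conv2d(3, 8, 3),
})

def find(model, predicate):
    # Names of all submodules (recursively) for which predicate(module) is True.
    return [name for name, m in model.named_modules() if predicate(m)]

by_isinstance = find(model, lambda m: isinstance(m, nn.BatchNorm2d))
by_classname = find(model, lambda m: m.__class__.__name__.find('BatchNorm') != -1)
```

nn.SyncBatchNorm is not a subclass of nn.BatchNorm2d, so the isinstance check skips it. A side effect of the name check is that it also matches BatchNormWrapper itself, so calling add_fp16_bn_wrapper twice on the same model would wrap the already-wrapped layers again.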
In my situation, directly calling module.eval().half() is dangerous: in some cases, values in the BN state_dict already exceed the range fp16 can represent, i.e. values above roughly 6.5e4 (the fp16 maximum is 65504; for me it was the running_var of the first BN layer in the ResNet). Calling half() directly turns such values into inf, which makes training more unstable. @shoaibahmed & @LayneH provided a better solution!
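A quick way to check whether a model is affected before calling .half(). This is a sketch: the helper bn_buffers_overflowing_fp16 is hypothetical, and the large running_var is set artificially to mimic the case described above.

```python
import torch
import torch.nn as nn

FP16_MAX = torch.finfo(torch.float16).max  # 65504.0

def bn_buffers_overflowing_fp16(model):
    """Names of BN parameters/buffers with values above the fp16 maximum."""
    bad = []
    for mod_name, m in model.named_modules():
        if isinstance(m, nn.modules.batchnorm._BatchNorm):
            tensors = list(m.named_parameters(recurse=False)) + list(m.named_buffers(recurse=False))
            for t_name, t in tensors:
                # Skip integer buffers such as num_batches_tracked. Values just above
                # 65504 can still round to 65504, so this check is slightly conservative.
                if t.is_floating_point() and t.abs().max().item() > FP16_MAX:
                    bad.append(f"{mod_name}.{t_name}")
    return bad

model = nn.Sequential(nn.Conv2d(3, 4, 3), nn.BatchNorm2d(4))
with torch.no_grad():
    model[1].running_var.fill_(7e4)  # artificially large, mimicking the case above

flagged = bn_buffers_overflowing_fp16(model)
var_in_half = model[1].running_var.half()  # 7e4 is not representable in fp16
```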
I encountered the same problem when using a custom module with a parameter, say self.paramA, whose forward function includes input = torch.where(cond, self.paramA, input). I definitely call model.train() and model.eval() in each epoch. The problem occurs at the first evaluation after training in the first epoch, but not during training. Following the previous suggestion, I now use
```python
paramA = self.paramA.to(input.dtype)
input = torch.where(cond, paramA, input)
```
and it works now, although I have no idea why this solves it.
I did not call eval() in the __init__ of my custom module. One downside might be that it runs a little slower due to the extra .to() call.
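A self-contained sketch of that fix, with a hypothetical WhereModule standing in for the real module. One plausible explanation, not confirmed in this thread: under O1 the input can arrive as fp16 while self.paramA stays fp32, and PyTorch versions from around the time of this thread required both value tensors passed to torch.where to share a dtype. Casting with .to(input.dtype) makes them match whatever precision the input has. The test uses float64 input so it runs on CPU.

```python
import torch
import torch.nn as nn

class WhereModule(nn.Module):
    """Hypothetical stand-in for the custom module described above."""

    def __init__(self, n):
        super().__init__()
        self.paramA = nn.Parameter(torch.zeros(n))  # created in fp32

    def forward(self, input):
        cond = input < 0
        # Cast the parameter to the input's dtype so torch.where gets matching dtypes.
        paramA = self.paramA.to(input.dtype)
        return torch.where(cond, paramA, input)

m = WhereModule(4)
x = torch.tensor([[-1.0, 2.0, -3.0, 4.0]], dtype=torch.float64)
out = m(x)
```

On the speed concern: .to() returns the tensor itself when the dtype already matches, so the extra cost only applies when the input really is in a different precision.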