Mixed precision modeling
Hello! First of all, this is truly wonderful work, and I expect to be playing with this repo for some time to come. LDMs are a great idea, and in my opinion they make latent generative models far more feasible to implement across our domain; this could not only democratize DDPMs, but also make them applicable to very complex domains that are currently out of reach in generative modeling and beyond.
I had a question regarding the implementation of mixed precision in these models. In principle, mixed precision should work reliably as long as it is implemented correctly. When I used PyTorch Lightning's automatic mixed precision (AMP) training settings, however, I encountered vanishing gradients resulting in very small outputs (and thus a NaN reconstruction loss). I somewhat expected this; for example, I am not sure whether the noising process is being cast to mixed precision properly by Lightning, and I am currently experimenting with this to better understand what is happening. I am also not sure whether Lightning is correctly converting all operations to mixed precision (I tried both the `native` and `apex` AMP backends, with `native` resulting in NaN loss and apex... well, just being apex). I believe that with some work this can be remedied without gradient clipping, but I was wondering whether the authors (or any other users) have insight into how to implement this properly, given the characteristics of both the encoder-decoder network and the LDM module.
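For concreteness, here is the kind of workaround I have been experimenting with (a toy sketch, not this repo's actual training code; `ToyLatentDiffusion`, the simplified noise schedule, and the placeholder UNet call are all my own): carve the schedule buffers and the noising arithmetic out of autocast so they stay in fp32, and let only the network forward run under AMP.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyLatentDiffusion(nn.Module):
    """Toy sketch of fp32 noising under AMP. Not the repo's actual LDM class."""

    def __init__(self, unet: nn.Module, num_timesteps: int = 1000):
        super().__init__()
        self.unet = unet  # placeholder: any module taking (z_t, t)
        betas = torch.linspace(1e-4, 2e-2, num_timesteps)
        alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
        # Keep the schedule in fp32; these values get close enough to 0/1
        # that fp16 arithmetic loses the small differences that matter.
        self.register_buffer("sqrt_ac", alphas_cumprod.sqrt())
        self.register_buffer("sqrt_om_ac", (1.0 - alphas_cumprod).sqrt())

    def training_step(self, z0: torch.Tensor) -> torch.Tensor:
        t = torch.randint(0, self.sqrt_ac.numel(), (z0.shape[0],), device=z0.device)
        noise = torch.randn_like(z0, dtype=torch.float32)
        # Carve the noising process out of autocast so it stays in fp32,
        # even when Lightning wraps this step in an AMP context.
        with torch.autocast(device_type="cuda", enabled=False):
            z_t = (
                self.sqrt_ac[t, None, None, None] * z0.float()
                + self.sqrt_om_ac[t, None, None, None] * noise
            )
        # The UNet forward still runs under whatever autocast context the
        # trainer set up, so the expensive convolutions stay in fp16.
        pred = self.unet(z_t, t)
        # Accumulate the loss in fp32 for stability.
        return F.mse_loss(pred.float(), noise)
```

As far as I understand, with `Trainer(precision=16)` Lightning wraps the forward pass in autocast automatically, so the explicit `enabled=False` block above should be the only manual carve-out needed; whether that alone is sufficient here is exactly what I am unsure about.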
I noticed that the LDM UNet module has an fp16 conversion function that appears to be unused in the codebase:
```python
def convert_to_fp16(self):
    """
    Convert the torso of the model to float16.
    """
    self.input_blocks.apply(convert_module_to_f16)
    self.middle_block.apply(convert_module_to_f16)
    self.output_blocks.apply(convert_module_to_f16)
```
Does anyone have any insight into why this was not used? If mixed precision is implemented correctly, these models could be used to generate very high-resolution images (I was able to train reconstruction networks at a training resolution of around 1024, so I think this could be pushed to UHD resolutions during training!). I realize higher resolutions can be reached at inference time, but the resulting image is still dependent on the scale of the training data. It would be preferable to train at higher resolutions in the first place, so that reliable UHD outputs can be achieved with inference scaling proportionately. I think this is a worthwhile objective for anyone using this codebase, given how powerful a UHD-trained model could be!
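Coming back to `convert_to_fp16`: if I read the upstream guided-diffusion code correctly, this hook is meant to be paired with a manual fp32-master-weights / loss-scaling loop rather than with autocast (upstream I believe flattens parameters and uses dynamic loss scaling; the per-parameter version and names below, like `master_params` and `train_step_fp16`, are my own simplification, not this repo's code):

```python
import torch
import torch.nn.functional as F

# One-time setup (illustrative): fp16 torso, fp32 master copies of the weights.
# unet.convert_to_fp16()
# master_params = [p.detach().clone().float().requires_grad_(True)
#                  for p in unet.parameters()]
# opt = torch.optim.AdamW(master_params, lr=1e-4)

def train_step_fp16(unet, opt, master_params, z_t, t, noise, loss_scale=2.0 ** 16):
    loss = F.mse_loss(unet(z_t, t).float(), noise.float())
    unet.zero_grad(set_to_none=True)
    # Scale the loss so that small gradients survive the fp16 backward pass.
    (loss * loss_scale).backward()
    # Unscale the fp16 gradients into the fp32 master copies...
    for master, model in zip(master_params, unet.parameters()):
        if model.grad is not None:
            master.grad = model.grad.float() / loss_scale
    # ...take the optimizer step in fp32...
    opt.step()
    # ...and copy the updated masters back into the (partly fp16) model weights.
    with torch.no_grad():
        for master, model in zip(master_params, unet.parameters()):
            model.copy_(master)
    return loss.detach()
```

That extra bookkeeping (plus dynamic adjustment of `loss_scale`) may simply be why the hook was left unwired here, but it would be good to know for sure.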
@pesser @rromb @patrickvonplaten @ablattmann @owenvincent @apolinario @cpacker I would truly appreciate any input from your experience with this!