Model and input data types do not match
Is your feature request related to a problem? Please describe.
Hi, when I trained the SD v1.5 model in fp16 mode using examples/text_to_image/train_text_to_image.py, I found a dtype mismatch between the UNet model and its input data. Specifically, in this line, the UNet has float32 weights, but noisy_latents is float16. Although this does not raise an error on CUDA, it does raise an error on my custom device. How can I change this code so that the model and inputs agree when training in float16?
Describe the solution you'd like. To avoid training an incorrect model, I would like guidance on the correct way to make the UNet and its inputs use matching dtypes.
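One possible workaround, sketched under assumptions: keep the trainable UNet in float32 (as observed above) and explicitly cast the float16 inputs to the model's dtype before the forward call, instead of relying on the implicit casting that autocast typically performs under fp16 mixed precision on CUDA. In this sketch `nn.Linear` is a hypothetical stand-in for the UNet, and `module_dtype` / `forward_in_model_dtype` are hypothetical helpers, not diffusers or accelerate APIs. Integer inputs such as timesteps are left untouched.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def module_dtype(module: nn.Module) -> torch.dtype:
    """Return the dtype of the module's first floating-point parameter."""
    for p in module.parameters():
        if p.is_floating_point():
            return p.dtype
    raise ValueError("module has no floating-point parameters")


def forward_in_model_dtype(model: nn.Module, *inputs: torch.Tensor) -> torch.Tensor:
    """Cast floating-point inputs to the model's dtype, then run the model.

    Non-floating inputs (e.g. integer timesteps) are passed through unchanged.
    """
    dtype = module_dtype(model)
    cast = [x.to(dtype) if x.is_floating_point() else x for x in inputs]
    return model(*cast)


torch.manual_seed(0)

# Hypothetical stand-in for the UNet: trainable weights kept in float32.
model = nn.Linear(4, 4)

# Stand-ins for noisy_latents and the training target, produced in float16.
noisy_latents = torch.randn(2, 4, dtype=torch.float16)
target = torch.randn(2, 4, dtype=torch.float16)

# Without autocast, mixing float16 inputs with float32 weights is expected
# to raise on CPU, similar to the error seen on the custom device.
try:
    model(noisy_latents)
except RuntimeError as err:
    print("dtype mismatch:", err)

# Cast inputs to the model dtype, compute the loss in float32, backpropagate.
pred = forward_in_model_dtype(model, noisy_latents)
loss = F.mse_loss(pred.float(), target.float())
loss.backward()
print(pred.dtype, loss.dtype)
```

Keeping the trainable weights in float32 and casting only the activations is generally preferred over converting the whole UNet to float16 for training, since small fp16 weight updates can be rounded away. If the custom device supports `torch.autocast`, enabling it there may be an alternative to manual casting.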