PaliGemma: fix device and dtype assignments
What does this PR do?
Moves tensors to the correct devices for multi-GPU training with accelerate and device_map="auto". Also ensures that bf16 training works.
Fixes #30997
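For illustration only, here is a minimal sketch of the kind of device and dtype alignment described above. It is not the code from this PR: the helper name `merge_image_features` and the tensor shapes are assumptions made for the example. The idea is that a tensor written into another tensor should first be moved to the destination's device and cast to its dtype, otherwise indexed assignment can fail under bf16 or when layers are sharded across GPUs.

```python
import torch


def merge_image_features(inputs_embeds, image_features, image_mask):
    """Hypothetical helper (not from the PR): write image features into the
    text embedding sequence at the positions marked by image_mask."""
    # Align the source with the destination: same device (matters when
    # device_map="auto" spreads modules over several GPUs) and same dtype
    # (matters when the model runs in bf16 but the features are fp32).
    image_features = image_features.to(
        device=inputs_embeds.device, dtype=inputs_embeds.dtype
    )
    # The mask is used for indexing, so it must live on the same device too.
    image_mask = image_mask.to(inputs_embeds.device)

    merged = inputs_embeds.clone()
    hidden_size = inputs_embeds.shape[-1]
    # Boolean-mask assignment expects one row per selected position.
    merged[image_mask] = image_features.reshape(-1, hidden_size)
    return merged


# Toy inputs: bf16 text embeddings, fp32 image features.
text = torch.zeros(1, 4, 8, dtype=torch.bfloat16)
image = torch.ones(1, 2, 8, dtype=torch.float32)
mask = torch.tensor([[True, True, False, False]])

merged = merge_image_features(text, image, mask)
```

On a single-device setup the `.to(...)` calls are effectively no-ops for the device part, so the same code path should work with or without sharding.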
cc @ArthurZucker wdyt?
FWIW, most of the lines here are nearly identical to the changes I made locally, apart from the final_embedding one, which I believe can be done with only one cast, though I didn't think too deeply about it.
@grahamannett, good to know. The final_embedding change is also there to fix the bf16 dtype mismatch.