PaliGemma: fix devices and dtype assignments

Open molbap opened this issue 1 year ago • 2 comments

What does this PR do?

Moves tensors to the correct devices for multi-GPU training with accelerate and `device_map="auto"`. Also ensures that bf16 training works.

Fixes #30997
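
For readers unfamiliar with this class of bug, here is a minimal sketch of the kind of cast involved. It is not the code from this PR: `merge_image_features` is a hypothetical helper, and the tensor shapes are assumptions chosen for illustration. The idea is to align the image features (and the mask) with the text embeddings' device and dtype before writing them in, since an indexed assignment between float32 and bf16 tensors, or between tensors on different GPUs, would otherwise fail.

```python
import torch


def merge_image_features(inputs_embeds, image_features, image_mask):
    """Hypothetical helper (not the transformers implementation).

    inputs_embeds:  (batch, seq_len, hidden) text embeddings, e.g. bf16
    image_features: (num_image_tokens, hidden) projected image features
    image_mask:     (batch, seq_len) bool, True where an image token sits
    """
    # Align device and dtype with the destination tensor in one call.
    image_features = image_features.to(
        device=inputs_embeds.device, dtype=inputs_embeds.dtype
    )
    # The mask must live on the same device as the tensor it indexes.
    image_mask = image_mask.to(inputs_embeds.device)

    # Work on a copy so the caller's embeddings are left untouched.
    final_embedding = inputs_embeds.clone()
    final_embedding[image_mask] = image_features
    return final_embedding


# Usage on CPU: bf16 embeddings, float32 image features.
embeds = torch.zeros(1, 4, 2, dtype=torch.bfloat16)
feats = torch.ones(2, 2, dtype=torch.float32)
mask = torch.tensor([[True, True, False, False]])
merged = merge_image_features(embeds, feats, mask)
print(merged.dtype)  # torch.bfloat16
```

With `device_map="auto"` the embeddings and the vision tower output can sit on different GPUs, which is why the sketch takes the target device from `inputs_embeds` rather than hard-coding one.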

molbap avatar May 24 '24 09:05 molbap

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

cc @ArthurZucker wdyt?

molbap avatar May 24 '24 11:05 molbap

fwiw most of the lines in here are nearly identical to the changes I made locally as well, besides the `final_embedding` one, which I believe can be done with only 1 cast, but I didn't think too deeply about it

grahamannett avatar May 24 '24 15:05 grahamannett

@grahamannett, good to know. For `final_embedding`, the cast is also there to fix the bf16 dtype mismatch.

molbap avatar May 24 '24 15:05 molbap