Shubham Krishna
R: @andyxiexu @damccorm
@damccorm linting and formatting fixed
@alanwaketan can you also take a look, please?
You can replicate the wrapping and unwrapping using this script:

```python
import torch
import torch_xla
import torch.nn as nn
import functools
from transformers import AutoModelForCausalLM
from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2
...
```
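The ellipsis is a preview truncation; below is a minimal self-contained sketch of what the rest of such a script could look like. The model name `gpt2` and the `(num_devices, 1)` mesh shape are assumptions for illustration, not taken from the original comment:

```python
import numpy as np
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
from transformers import AutoModelForCausalLM
from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2

xr.use_spmd()  # enable the SPMD execution mode before creating tensors

# Mesh over all devices; FSDPv2 shards along the axis named "fsdp".
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ("fsdp", "tensor"))
xs.set_global_mesh(mesh)

# "gpt2" is a placeholder model (assumption).
model = AutoModelForCausalLM.from_pretrained("gpt2").to(xm.xla_device())

# Wrapping: parameters get sharded across the "fsdp" mesh axis.
wrapped = FSDPv2(model)
print(type(wrapped))         # SpmdFullyShardedDataParallel

# Unwrapping: the original module is reachable via the `module` attribute.
print(type(wrapped.module))  # the original model class
```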
@amyeroberts I had to change `unwrap_model` because of the changes introduced in #28949 (`Support PyTorch/XLA FSDP via SPMD`); the existing `unwrap_model` fails only in that case. I can...
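For context, the generic pattern such an `unwrap_model` relies on is the recursive `.module` walk sketched below. This is illustrative only, not the exact transformers implementation:

```python
import torch.nn as nn

def unwrap(model: nn.Module) -> nn.Module:
    # Wrappers like DDP, FSDP, and FSDPv2 expose the wrapped model
    # through a `module` attribute; peel layers off until none remain.
    while hasattr(model, "module"):
        model = model.module
    return model
```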
@amyeroberts here is a small snippet for the test:

```python
import torch
import torch_xla
import torch.nn as nn
from transformers import AutoModelForCausalLM
from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2
import torch_xla.distributed.spmd...
```
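The snippet is cut off by the preview; a hedged guess at how such a test could continue is below. The setup mirrors the script above, `gpt2` and the mesh shape are again placeholders, and `unwrap_model` is the helper from `transformers.modeling_utils`:

```python
import numpy as np
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
from transformers import AutoModelForCausalLM
from transformers.modeling_utils import unwrap_model
from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2

xr.use_spmd()  # enable the SPMD execution mode

# Sharding mesh over all devices, axis named "fsdp" as FSDPv2 expects.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ("fsdp", "tensor"))
xs.set_global_mesh(mesh)

model = AutoModelForCausalLM.from_pretrained("gpt2")
wrapped = FSDPv2(model)

# The point of the test: unwrapping must strip the FSDPv2 wrapper
# and return an instance of the original model class.
unwrapped = unwrap_model(wrapped)
assert not isinstance(unwrapped, FSDPv2)
assert isinstance(unwrapped, type(model))
```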
@PawKanarek just to isolate the error, what happens if you run the same code on a GPU instead of a TPU?
@PawKanarek can you also provide the training logs, please, and rerun with `logging_steps=1`? Also use `save_strategy="epoch"`.
@PawKanarek also, after training, can you try saving with `trainer.save_model('output_dir')`?
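For reference, a minimal sketch of the settings being asked for in the two comments above; `output_dir` is a placeholder path:

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="output_dir",
    logging_steps=1,        # log the loss at every optimizer step
    save_strategy="epoch",  # write a checkpoint at the end of each epoch
)

# After trainer.train() finishes, save the final model explicitly:
#   trainer.save_model("output_dir")
```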
@PawKanarek also, did it work with your [patch](https://github.com/huggingface/transformers/issues/29659#issuecomment-1999954329)?