Shubham Krishna

Results: 25 comments by Shubham Krishna

R: @andyxiexu @damccorm

@damccorm linting and formatting are fixed

You can replicate the wrapping and unwrapping using this script:

```python
import torch
import torch_xla
import torch.nn as nn
import functools
from transformers import AutoModelForCausalLM
from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2
...
```
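
For reference, a minimal end-to-end sketch of such a script, assuming a TPU host with SPMD enabled; the checkpoint name, mesh shape, and the `unwrap` helper are illustrative, not the actual transformers code:

```python
import numpy as np
import torch.nn as nn
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from transformers import AutoModelForCausalLM
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2,
)

# Enable SPMD and register a global mesh, following the torch_xla FSDPv2 docs.
xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.array(range(num_devices)), (num_devices, 1), ("fsdp", "model"))
xs.set_global_mesh(mesh)

model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative checkpoint

# Wrapping: parameters get sharded across the "fsdp" mesh axis.
wrapped = FSDPv2(model)

# Unwrapping: FSDPv2, like DDP, exposes the inner model via `.module`.
def unwrap(m: nn.Module) -> nn.Module:
    return unwrap(m.module) if hasattr(m, "module") else m

assert unwrap(wrapped) is model
```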

@amyeroberts I had to change `unwrap_model` because of the changes introduced in #28949 (`Support PyTorch/XLA FSDP via SPMD`); the existing `unwrap_model` fails only in that case. I can...
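
For context, here is a sketch of the kind of adjustment described above, assuming the unwrapper needs to recognize the SPMD wrapper explicitly; this `unwrap_model` is a simplified stand-in, not the actual transformers implementation:

```python
import torch.nn as nn
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2,
)

def unwrap_model(model: nn.Module) -> nn.Module:
    # Hypothetical: peel off the SPMD FSDP wrapper explicitly, then fall
    # back to the generic recursion through `.module` used for wrappers
    # like DistributedDataParallel.
    if isinstance(model, FSDPv2):
        return unwrap_model(model.module)
    if isinstance(getattr(model, "module", None), nn.Module):
        return unwrap_model(model.module)
    return model
```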

@amyeroberts here is a small snippet for the test:

```python
import torch
import torch_xla
import torch.nn as nn
from transformers import AutoModelForCausalLM
from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2
import torch_xla.distributed.spmd
...
```
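
A plausible self-contained completion of that snippet (a sketch assuming an SPMD-enabled TPU host; the checkpoint and mesh layout are placeholders):

```python
import numpy as np
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from transformers import AutoModelForCausalLM
from transformers.modeling_utils import unwrap_model
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2,
)

# Enable SPMD and register a global mesh so FSDPv2 can pick it up.
xr.use_spmd()
n = xr.global_runtime_device_count()
xs.set_global_mesh(xs.Mesh(np.array(range(n)), (n, 1), ("fsdp", "model")))

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder checkpoint
wrapped = FSDPv2(model)

# The test: unwrapping the FSDPv2-wrapped model should return the original.
assert unwrap_model(wrapped) is model
```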

@PawKanarek just to isolate the error, what happens if you run the same code on a GPU instead of a TPU?

@PawKanarek can you please also provide the training logs and run with `logging_steps=1`? Also use `save_strategy="epoch"`.
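
For instance, those flags would look like this in a `TrainingArguments` config (a sketch; `output_dir` is a placeholder):

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="output_dir",  # placeholder path
    logging_steps=1,          # log the loss at every optimizer step
    save_strategy="epoch",    # write a checkpoint at each epoch boundary
)
```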

@PawKanarek also, after training, can you try saving with `trainer.save_model('output_dir')`?

@PawKanarek also, did it work with your [patch](https://github.com/huggingface/transformers/issues/29659#issuecomment-1999954329)?