TransformerEngine

NaN loss when switching from a PyTorch layer to the Transformer Engine TransformerLayer

Open • jasonkrone opened this issue 1 year ago • 1 comment

Summary

I'm hitting a NaN loss issue when I use the TransformerLayer in place of a PyTorch transformer layer I wrote.

Details

I'm using the nvcr.io/nvidia/pytorch:24.04-py3 Docker container. I train with PyTorch FSDP and use bfloat16 mixed precision.
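For reference, "bfloat16 mixed precision" under FSDP is usually expressed through a `MixedPrecision` policy. The issue does not show the actual policy, so the dtypes below are assumptions for illustration and `bf16_policy` is a hypothetical name. One detail relevant to the wrapper class further down: FSDP's `buffer_dtype` applies only to registered buffers, so a plain tensor attribute such as `rope_freqs` should keep the dtype it was created with.

```python
import torch
from torch.distributed.fsdp import MixedPrecision

# Assumed policy: the issue only says "bfloat16 mixed precision".
bf16_policy = MixedPrecision(
    param_dtype=torch.bfloat16,   # dtype parameters are cast to for forward/backward compute
    reduce_dtype=torch.float32,   # dtype used for gradient reduce-scatter/all-reduce
    buffer_dtype=torch.bfloat16,  # dtype registered buffers are cast to
)
```

The policy would then be passed as `FullyShardedDataParallel(model, mixed_precision=bf16_policy)`. Keeping the gradient reduction in fp32 rather than bf16 is one option sometimes used for numerical stability.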

Question

Has the TransformerEngine team trained a model with the TELlamaDecoderLayer to ensure that everything works as expected? If so, could you share that example? My use case is very similar.

Code

Here's the code I wrote to wrap the TransformerLayer so that it uses RoPE embeddings. This is the class I swapped into my model.

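from transformer_engine.pytorch import TransformerLayer
from transformer_engine.pytorch.attention import RotaryPositionEmbedding


class TransformerLayerWithPOS(TransformerLayer):
    """TransformerLayer that applies RoPE using precomputed rotary frequencies."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # RoPE is applied per attention head, so its dimension is the head size.
        head_dim = kwargs["hidden_size"] // kwargs["num_attention_heads"]
        rope = RotaryPositionEmbedding(head_dim)
        # Precompute the rotary frequencies once for the maximum sequence length
        # and keep them on the GPU (a plain attribute, not a registered buffer).
        self.rope_freqs = rope(max_seq_len=kwargs["seq_length"]).cuda()

    def forward(self, hidden_states):
        """
        Custom forward that passes only the arguments relevant to
        `TransformerLayer.forward`, plus the precomputed RoPE frequencies.
        Returns the TransformerLayer output tensor unchanged.
        """
        return super().forward(hidden_states, rotary_pos_emb=self.rope_freqs)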

In addition, here are the kwargs I pass to the transformer layer.

import torch


def transformer_layer_kwargs(config):  # hypothetical name; only the function body was shown
    device = "meta" if config.use_meta_device else "cuda"
    return {
        "device": device,
        "params_dtype": torch.float32,
        "hidden_size": config.d_model,
        "ffn_hidden_size": config.d_hidden,
        "num_attention_heads": config.n_heads,
        "self_attn_mask_type": "causal",
        "normalization": "RMSNorm",
        "bias": False,
        "activation": "swiglu",
        "attn_input_format": "bshd",
        "fuse_wgrad_accumulation": False,
        "seq_length": config.max_len,  # also read by TransformerLayerWithPOS to size the RoPE table
        "fuse_qkv_params": True,
    }

Learning curve

The attached learning curve shows the NaN issue, which appears around step 350.

[Screenshot: learning curve, 2024-06-21]
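To narrow down which part of the model produces the first NaN around step 350, one option is to register forward hooks that record every submodule whose output contains NaN or Inf. The sketch below uses only standard PyTorch hooks; `install_nonfinite_hooks` is a hypothetical helper, not part of Transformer Engine.

```python
import torch
import torch.nn as nn


def install_nonfinite_hooks(model: nn.Module):
    """Register forward hooks that record modules whose output contains NaN/Inf.

    Returns (records, handles): records lists module names in the order they
    produced non-finite outputs; call h.remove() on each handle when done.
    """
    records = []
    handles = []

    def make_hook(name):
        def hook(module, inputs, output):
            # Some modules return tuples, so check every floating-point tensor.
            outs = output if isinstance(output, (tuple, list)) else (output,)
            for t in outs:
                if torch.is_tensor(t) and t.is_floating_point() and not torch.isfinite(t).all():
                    records.append(name)
                    break
        return hook

    # named_modules() includes the root module, whose name is "".
    for name, module in model.named_modules():
        handles.append(module.register_forward_hook(make_hook(name)))
    return records, handles
```

After a forward pass, `records[0]` names the first submodule that produced a non-finite output. Forward hooks do not see gradients; for the backward pass, `torch.autograd.set_detect_anomaly(True)` raises an error at the operation that produced a NaN gradient, at a substantial speed cost.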

jasonkrone • Jun 21, 2024, 17:06

Hi @jasonkrone. Did you compare this loss curve with the one from your PyTorch implementation? I believe the chart only shows one curve. The first step in troubleshooting this would be to pass the same input to your transformer layer implementation and to the TE implementation (with dropout set to 0 for an apples-to-apples comparison) and confirm that the outputs of both the forward and backward passes match. They will not match exactly because of numerical differences, but they should be very close.
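The comparison described above can be sketched as a small harness. This is a minimal sketch under several assumptions: both layers take the same input and return a single tensor of the same shape on the same device, the TE layer's weights have already been copied from the custom layer so both compute the same function (the mapping onto TE's fused QKV, RMSNorm and SwiGLU parameters is model-specific and not shown), and dropout is 0 in both. For TE's `TransformerLayer`, `hidden_dropout` and `attention_dropout` default to 0.1 and are not set in the kwargs in the issue, so they would need to be passed as 0.0. `compare_layers` and its default tolerances are hypothetical.

```python
import torch
import torch.nn as nn


def compare_layers(ref: nn.Module, test: nn.Module, x: torch.Tensor,
                   rtol: float = 1e-2, atol: float = 1e-2) -> dict:
    """Feed the same input to two layers and compare forward outputs and input gradients."""
    # Independent leaf copies of the input so each backward pass fills its own .grad.
    x_ref = x.detach().clone().requires_grad_(True)
    x_test = x.detach().clone().requires_grad_(True)

    y_ref = ref(x_ref)
    y_test = test(x_test)

    # Use one random upstream gradient for both backward passes instead of .sum(),
    # so every output element contributes a distinct weight to the input gradient.
    grad_out = torch.randn_like(y_ref)
    y_ref.backward(grad_out)
    y_test.backward(grad_out.to(dtype=y_test.dtype, device=y_test.device))

    # Upcast to fp32 so outputs of different dtypes (e.g. bf16 vs fp32) can be compared.
    fwd_a, fwd_b = y_ref.float(), y_test.float()
    bwd_a, bwd_b = x_ref.grad.float(), x_test.grad.float()
    return {
        "fwd_max_abs_diff": (fwd_a - fwd_b).abs().max().item(),
        "bwd_max_abs_diff": (bwd_a - bwd_b).abs().max().item(),
        "fwd_close": torch.allclose(fwd_a, fwd_b, rtol=rtol, atol=atol),
        "bwd_close": torch.allclose(bwd_a, bwd_b, rtol=rtol, atol=atol),
    }
```

For example, `compare_layers(my_layer, te_layer, torch.randn(2, seq_len, hidden, device="cuda"))` (names hypothetical, input in the `bshd` layout) returns the maximum absolute differences and whether each pass is within tolerance. Suitable tolerances depend on the compute dtype; bf16 generally needs looser ones than fp32. Parameter gradients can be checked the same way by comparing matching `.grad` tensors after the backward pass.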

ptrendx • Jul 15, 2024, 21:07