
Add support for GPT-OSS models

Open Comet0322 opened this issue 6 months ago • 8 comments

Summary

Add GPT-OSS model support, addressing https://github.com/linkedin/Liger-Kernel/issues/848. Patching is complete for RoPE, RMSNorm, cross_entropy, and fused_linear_cross_entropy.

Known Issues

  • Gated SwiGLU Patching Support: The current Hugging Face implementation of gated SwiGLU in GptOssExperts makes patching difficult. This will be addressed in a future update.
  • GptOssExperts MXFP4 Format Support: MXFP4 tests are pending due to ongoing changes in the Hugging Face Transformers interface.
  • BF16 Convergence Issue: The BF16 convergence test is failing, while FP32 passes. This issue is under investigation.
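To make the gated SwiGLU issue concrete: standard (Llama-style) SwiGLU computes silu(gate) * up. As we read the GPT-OSS reference code and GptOssExperts in transformers v4.55 (please verify against the version you patch), the GPT-OSS activation instead interleaves the gate and up channels, clamps both, uses a scaled sigmoid, and adds 1 to the up branch. The per-element sketch below assumes alpha = 1.702 and limit = 7.0, which we believe are those defaults. The function names are hypothetical and are not Liger or Hugging Face APIs.

```python
import math


def sigmoid(x: float) -> float:
    # Numerically stable logistic function (avoids overflow in math.exp)
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def standard_swiglu(gate: list[float], up: list[float]) -> list[float]:
    """Standard SwiGLU, elementwise: silu(gate) * up = gate * sigmoid(gate) * up."""
    return [g * sigmoid(g) * u for g, u in zip(gate, up)]


def gpt_oss_swiglu(gate_up: list[float], alpha: float = 1.702, limit: float = 7.0) -> list[float]:
    """Sketch of the GPT-OSS gated activation (alpha and limit defaults are assumptions).

    gate_up holds gate and up channels interleaved: [g0, u0, g1, u1, ...].
    """
    gate, up = gate_up[0::2], gate_up[1::2]
    out = []
    for g, u in zip(gate, up):
        g = min(g, limit)               # gate is clamped from above only
        u = max(-limit, min(u, limit))  # up is clamped on both sides
        glu = g * sigmoid(alpha * g)    # scaled sigmoid instead of plain SiLU
        out.append(glu * (u + 1.0))     # note the +1 shift on the up branch
    return out
```

With gate = up = 1.0, standard SwiGLU gives about 0.731 while the GPT-OSS variant gives about 1.69, so an existing SwiGLU kernel cannot simply be swapped in; the clamps and the +1 shift also change the backward pass that a fused kernel would need to implement.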

Testing Done

FP32 Log
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_gpt_oss-32-1e-05-dtype0-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED                        [100%]
test/convergence/fp32/test_mini_models_with_logits.py::test_mini_model[mini_gpt_oss-32-1e-05-dtype0-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED            [100%]
BF16 Log
pytest --disable-warnings test/convergence/bf16/test_mini_models.py 
====================================================================== test session starts =======================================================================
platform linux -- Python 3.10.18, pytest-8.4.1, pluggy-1.6.0
rootdir: /home/admin/Liger-Kernel
configfile: pyproject.toml
plugins: xdist-3.8.0, rerunfailures-15.1
collected 1 item                                                                                                                                                 

test/convergence/bf16/test_mini_models.py::test_mini_model[mini_gpt_oss-32-1e-05-dtype0-0.01-0.05-0.1-0.01-0.01-0.01] FAILED                               [100%]

============================================================================ FAILURES ============================================================================
___________________________________________ test_mini_model[mini_gpt_oss-32-1e-05-dtype0-0.01-0.05-0.1-0.01-0.01-0.01] ___________________________________________

model_name = 'mini_gpt_oss', num_steps = 32, lr = 1e-05, dtype = torch.bfloat16, loss_atol = 0.01, loss_rtol = 0.05, logprobs_atol = 0.1, logprobs_rtol = 0.01
param_atol = 0.01, param_rtol = 0.01

    @pytest.mark.parametrize(
        "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logprobs_atol, logprobs_rtol, param_atol, param_rtol",
        [
            pytest.param(
                "mini_gpt_oss",
                32,
                1e-5,
                torch.bfloat16,
                1e-2,
                5e-2,
                1e-1,
                1e-2,
                1e-2,
                1e-2,
                marks=[
                    pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
                    pytest.mark.skipif(
                        not GPT_OSS_AVAILABLE,
                        reason="GPT OSS not available in this version of transformers",
                    ),
                ],
            ),
        ],
    )
    def test_mini_model(
        model_name,
        num_steps,
        lr,
        dtype,
        loss_atol,
        loss_rtol,
        logprobs_atol,
        logprobs_rtol,
        param_atol,
        param_rtol,
    ):
        # Non-liger models should be initialized and tested first to avoid the module being overridden
    
        expected_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr)
    
        actual_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True)
    
        # Compare every step of the loss
>       assert_verbose_allclose(
            torch.tensor([expected_output["loss"]]),
            torch.tensor([actual_output["loss"]]),
            atol=loss_atol,
            rtol=loss_rtol,
            extra_info="[Loss]",
        )

test/convergence/bf16/test_mini_models.py:1395: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

tensor1 = tensor([[10.4809, 10.2822, 10.0886,  9.8527,  9.6104,  9.4217,  9.1856,  8.9703,
          8.7122,  8.5283,  8.2974,  ...  6.1561,  6.0330,
          6.9142,  5.7746,  5.6058,  5.5196,  5.4399,  5.1645,  5.2462,  4.9314,
          5.8588]])
tensor2 = tensor([[10.4806, 10.2860, 10.0742,  9.8525,  9.6143,  9.4222,  9.1932,  8.9775,
          8.7126,  8.5138,  8.2971,  ...  8.8191,  6.0038,
          7.8853,  5.7446,  5.5767,  5.5193,  5.4383,  5.1643,  5.2460,  4.9023,
          4.8585]])
rtol = 0.05, atol = 0.01, max_print = 5, extra_info = '[Loss]'

    def assert_verbose_allclose(tensor1, tensor2, rtol=1e-05, atol=1e-08, max_print=5, extra_info=""):
        """
        Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches.
    
        Parameters:
        tensor1 (torch.Tensor): First tensor to compare.
        tensor2 (torch.Tensor): Second tensor to compare.
        rtol (float): Relative tolerance.
        atol (float): Absolute tolerance.
        max_print (int): Maximum number of mismatched elements to print.
        extra_info (str): Extra information to show at the start of the error message.
    
        Raises:
        AssertionError: If the tensors are not all close within the given tolerance.
        """
        # Check if the shapes of the tensors match
        if tensor1.shape != tensor2.shape:
            raise AssertionError("Input tensors must have the same shape.")
    
        # Calculate the difference between the tensors
        diff = torch.abs(tensor1 - tensor2)
    
        # Determine the tolerance
        tolerance = atol + rtol * torch.abs(tensor2)
    
        # Find tolerance mismatched elements
        tol_mismatched = diff > tolerance
    
        # Find nan mismatched elements
        nan_mismatched = torch.logical_xor(torch.isnan(tensor1), torch.isnan(tensor2))
    
        # Find +inf mismatched elements
        posinf_mismatched = torch.logical_xor(torch.isposinf(tensor1), torch.isposinf(tensor2))
        # Find -inf mismatched elements
        neginf_mismatched = torch.logical_xor(torch.isneginf(tensor1), torch.isneginf(tensor2))
    
        # Find all mismatched elements
        mismatched = torch.logical_or(
            torch.logical_or(tol_mismatched, nan_mismatched),
            torch.logical_or(posinf_mismatched, neginf_mismatched),
        )
    
        mismatched_indices = torch.nonzero(mismatched)
    
        # Count the number of mismatched elements
        num_mismatched = mismatched.sum().item()
    
        # Check if all elements are close
        all_close = num_mismatched == 0
    
        # Raise AssertionError with detailed information if there are mismatches
        if not all_close and num_mismatched >= 1:
            mismatch_details = [f"Number of mismatched elements: {num_mismatched}"]
            print_count = min(max_print, num_mismatched)
            for index in mismatched_indices[:print_count]:
                i = tuple(index.tolist())
                mismatch_details.append(f"Mismatch at index {i}: tensor1[{i}] = {tensor1[i]}, tensor2[{i}] = {tensor2[i]}")
            if num_mismatched > max_print:
                mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.")
    
>           raise AssertionError(extra_info + "\n".join(mismatch_details))
E           AssertionError: [Loss]Number of mismatched elements: 2
E           Mismatch at index (0, 21): tensor1[(0, 21)] = 8.848180770874023, tensor2[(0, 21)] = 6.204418182373047
E           Mismatch at index (0, 22): tensor1[(0, 22)] = 6.156144142150879, tensor2[(0, 22)] = 8.819061279296875

test/utils.py:131: AssertionError
---------------------------------------------------------------------- Captured stdout call ----------------------------------------------------------------------
Liger kernel patches have been reverted.
Step 0, Loss: 10.480887413024902
Step 1, Loss: 10.282181739807129
Step 2, Loss: 10.088647842407227
Step 3, Loss: 9.852700233459473
Step 4, Loss: 9.610376358032227
Step 5, Loss: 9.421696662902832
Step 6, Loss: 9.185647010803223
Step 7, Loss: 8.970256805419922
Step 8, Loss: 8.712227821350098
Step 9, Loss: 8.528286933898926
Step 10, Loss: 8.297395706176758
Step 11, Loss: 8.022294998168945
Step 12, Loss: 7.8798322677612305
Step 13, Loss: 7.648087501525879
Step 14, Loss: 7.404232501983643
Step 15, Loss: 7.2665910720825195
Step 16, Loss: 9.697592735290527
Step 17, Loss: 9.604401588439941
Step 18, Loss: 9.46231746673584
Step 19, Loss: 6.633795738220215
Step 20, Loss: 9.081354141235352
Step 21, Loss: 8.848180770874023
Step 22, Loss: 6.156144142150879
Step 23, Loss: 6.032957077026367
Step 24, Loss: 5.914245128631592
Step 25, Loss: 5.774555206298828
Step 26, Loss: 5.605828285217285
Step 27, Loss: 5.519618988037109
Step 28, Loss: 5.439865589141846
Step 29, Loss: 5.164504528045654
Step 30, Loss: 5.246169567108154
Step 31, Loss: 4.93139123916626
Eval Loss: 4.858840465545654
Liger kernel patches have been reverted.
Step 0, Loss: 10.480640411376953
Step 1, Loss: 10.286001205444336
Step 2, Loss: 10.074193954467773
Step 3, Loss: 9.852497100830078
Step 4, Loss: 9.614309310913086
Step 5, Loss: 9.422234535217285
Step 6, Loss: 9.19322395324707
Step 7, Loss: 8.977506637573242
Step 8, Loss: 8.712628364562988
Step 9, Loss: 8.51380729675293
Step 10, Loss: 8.297117233276367
Step 11, Loss: 8.0232572555542
Step 12, Loss: 7.879528522491455
Step 13, Loss: 7.649257659912109
Step 14, Loss: 7.418290138244629
Step 15, Loss: 7.2662553787231445
Step 16, Loss: 9.697296142578125
Step 17, Loss: 9.61251449584961
Step 18, Loss: 9.455589294433594
Step 19, Loss: 6.634244918823242
Step 20, Loss: 9.089345932006836
Step 21, Loss: 6.204418182373047
Step 22, Loss: 8.819061279296875
Step 23, Loss: 6.003849029541016
Step 24, Loss: 5.8852691650390625
Step 25, Loss: 5.744577884674072
Step 26, Loss: 5.576700210571289
Step 27, Loss: 5.519281387329102
Step 28, Loss: 5.438257217407227
Step 29, Loss: 5.164332389831543
Step 30, Loss: 5.246014595031738
Step 31, Loss: 4.902283668518066
Eval Loss: 4.858528137207031
Liger kernel patches have been reverted.
==================================================================== short test summary info =====================================================================
FAILED test/convergence/bf16/test_mini_models.py::test_mini_model[mini_gpt_oss-32-1e-05-dtype0-0.01-0.05-0.1-0.01-0.01-0.01] - AssertionError: [Loss]Number of mismatched elements: 2
================================================================= 1 failed, 1 warning in 16.33s ==================================================================
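The failure comes from the tolerance rule in assert_verbose_allclose above: an element mismatches when |tensor1 - tensor2| > atol + rtol * |tensor2|. The plain-Python sketch below re-applies that rule to losses copied from the captured stdout (first run without Liger, second run with Liger), using loss_atol = 0.01 and loss_rtol = 0.05 from the test parameters. The helper exceeds_tolerance is hypothetical and not part of the test suite; step 31 is included as a contrast case that passes.

```python
def exceeds_tolerance(expected: float, actual: float, atol: float, rtol: float) -> bool:
    """Same criterion as assert_verbose_allclose: |expected - actual| > atol + rtol * |actual|."""
    return abs(expected - actual) > atol + rtol * abs(actual)


# Losses copied from the captured stdout: (run without Liger, run with Liger)
steps = {
    21: (8.848180770874023, 6.204418182373047),
    22: (6.156144142150879, 8.819061279296875),
    31: (4.93139123916626, 4.902283668518066),
}
LOSS_ATOL, LOSS_RTOL = 1e-2, 5e-2

for step, (expected, actual) in steps.items():
    diff = abs(expected - actual)
    tol = LOSS_ATOL + LOSS_RTOL * abs(actual)
    flag = exceeds_tolerance(expected, actual, LOSS_ATOL, LOSS_RTOL)
    print(f"step {step}: |diff| = {diff:.4f}, tolerance = {tol:.4f}, mismatch = {flag}")
```

For step 21, for example, |diff| is about 2.644 against a tolerance of about 0.320, so loosening the tolerances slightly would not make this test pass.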

Env: torch 2.8.0, triton 3.4.0, transformers 4.55.0

  • Hardware Type: H200
  • [x] run make test to ensure correctness
  • [ ] run make checkstyle to ensure code style
  • [ ] run make test-convergence to ensure convergence

Comet0322 (Aug 09 '25 15:08)

@Tcc0403

Comet0322 (Aug 09 '25 15:08)

@Tcc0403

Many thanks for the contribution.

lancerts (Aug 11 '25 14:08)

Thanks for the contribution! Any benchmarking results for the GPT-OSS models?

PKUWZP (Aug 11 '25 16:08)

@PKUWZP I'll add the benchmark results as soon as the swiglu implementation is complete.

Comet0322 (Aug 12 '25 06:08)

@shimizust During the convergence test, the loss values for the two models running in bf16 diverged significantly at certain steps. This is likely related to the issue discussed here: https://github.com/linkedin/Liger-Kernel/issues/742.

[Image: loss_curve_final]

Comet0322 (Aug 12 '25 09:08)

For future contributors:

When patching RMSNorm, there are 4 init args that are easily overlooked.

  1. casting_mode (str): "gemma" or "llama" (more detail)
    • llama: downcasting back to original precision "before" multiplying weight
    • gemma: downcasting back to original precision "after" multiplying weight
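  2. init_fn (str): "zeros" or "ones", again essentially a Llama vs. Gemma distinction ("ones" for Llama-style weights, "zeros" for Gemma-style weights that are shifted by 1.0)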
  3. bias (float): default to 0.0
    • 0.0 (llama): no ops to weight before multiplication
    • 1.0 (gemma): adding 1.0 to weight before multiplication
  4. in_place (bool): True or False
    • True: reuses the dY (grad_output) tensor in place to save memory, but this breaks if dY is needed elsewhere (e.g., when a residual is added after RMSNorm)
    • False: avoids the issue above and always works (see #376)
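A minimal pure-PyTorch reference for how casting_mode, init_fn, and the weight shift (bias in the list above) interact in the forward pass. The functions init_weight and rms_norm_reference are hypothetical, with arguments chosen to mirror the list; this is not Liger's Triton kernel. in_place is omitted because it only affects how the backward pass reuses dY.

```python
import torch


def init_weight(hidden_size: int, init_fn: str) -> torch.Tensor:
    # "ones" for Llama-style weights, "zeros" for Gemma-style weights shifted by 1.0
    if init_fn == "ones":
        return torch.ones(hidden_size)
    if init_fn == "zeros":
        return torch.zeros(hidden_size)
    raise ValueError(f"unknown init_fn: {init_fn}")


def rms_norm_reference(x, weight, eps=1e-6, bias=0.0, casting_mode="llama"):
    input_dtype = x.dtype
    x32 = x.to(torch.float32)
    # Normalize in fp32: x / sqrt(mean(x^2) + eps)
    normed = x32 * torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + eps)
    if casting_mode == "llama":
        # Downcast to the input dtype *before* multiplying by the (shifted) weight
        return (weight + bias) * normed.to(input_dtype)
    if casting_mode == "gemma":
        # Multiply by the (shifted) weight in fp32, then downcast *after*
        return (normed * (weight.to(torch.float32) + bias)).to(input_dtype)
    raise ValueError(f"unknown casting_mode: {casting_mode}")


torch.manual_seed(0)
x = torch.randn(2, 8)
llama_like = rms_norm_reference(x, init_weight(8, "ones"), bias=0.0, casting_mode="llama")
gemma_like = rms_norm_reference(x, init_weight(8, "zeros"), bias=1.0, casting_mode="gemma")
```

With fp32 inputs the casts are no-ops, so the Llama-style ("ones", bias 0.0) and Gemma-style ("zeros", bias 1.0) configurations produce identical outputs; the casting order only matters for low-precision inputs such as bf16.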

Take GptOss as an example: https://github.com/huggingface/transformers/blob/v4.55.2/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L42

It requires

  1. casting_mode="gemma" (casting back after multiplying weight)
  2. init_fn="ones" (self.weight = nn.Parameter(torch.ones(hidden_size)))
  3. bias=0.0 (no shifting for weight)
  4. in_place=True (no dY usage elsewhere)

Then create a LigerRMSNormForXXXModel class in liger_kernel/transformers/rms_norm.py that applies these init args.

Tcc0403 (Aug 14 '25 19:08)
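The LigerRMSNormForXXXModel pattern Tcc0403 describes amounts to a subclass that pins the init args. The sketch below shows that pattern in plain PyTorch: ConfigurableRMSNorm is a hypothetical stand-in for Liger's own RMSNorm module (check its real constructor in liger_kernel/transformers/rms_norm.py), and RMSNormForGptOssSketch is a hypothetical name that fixes the GptOss values listed above (casting_mode="gemma", init_fn="ones", bias=0.0). in_place is left out because this reference has no custom backward.

```python
import torch
import torch.nn as nn


class ConfigurableRMSNorm(nn.Module):
    """Hypothetical stand-in for Liger's RMSNorm module: plain PyTorch, no Triton kernel."""

    def __init__(self, hidden_size, eps=1e-6, bias=0.0, casting_mode="llama", init_fn="ones"):
        super().__init__()
        if casting_mode not in ("llama", "gemma"):
            raise ValueError(f"unknown casting_mode: {casting_mode}")
        init = {"ones": torch.ones, "zeros": torch.zeros}[init_fn]
        self.weight = nn.Parameter(init(hidden_size))
        self.eps, self.bias, self.casting_mode = eps, bias, casting_mode

    def forward(self, x):
        input_dtype = x.dtype
        x32 = x.to(torch.float32)
        normed = x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + self.eps)
        if self.casting_mode == "gemma":
            # Downcast *after* multiplying by the (shifted) weight
            return (normed * (self.weight.to(torch.float32) + self.bias)).to(input_dtype)
        # "llama": downcast *before* multiplying by the (shifted) weight
        return (self.weight + self.bias) * normed.to(input_dtype)


class RMSNormForGptOssSketch(ConfigurableRMSNorm):
    """Pins the GptOss init args: gemma casting, ones init, no weight shift."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__(hidden_size, eps=eps, bias=0.0, casting_mode="gemma", init_fn="ones")
```

Keeping the model-specific values in a tiny subclass means the patching code only has to swap the class, and the easily overlooked init args live in one place.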

@Comet0322 Thanks for updating. Do you think we can check in what you have and figure out swiglu after? After @Tcc0403's accum_dtype changes to the bf16 tests, do they pass now?

shimizust (Sep 03 '25 16:09)

@Comet0322 and I have discussed this elsewhere. It seems the router's top-k expert choice is quite sensitive in bf16, leading to the discrepancy in final losses. I suggest skipping the bf16 convergence test for now.

cc @shimizust

Tcc0403 (Sep 03 '25 18:09)
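A toy illustration of why top-k routing makes bf16 comparisons brittle: when two experts' router logits are nearly tied, a perturbation smaller than the gap between adjacent bf16 values in [1, 2) (2**-7, about 0.0078) can change which experts are selected. That discrete switch sends the token through different expert weights, so the resulting difference is no longer small. The logits and the helper top_k_experts are made up for illustration and are not taken from GPT-OSS.

```python
def top_k_experts(logits: list[float], k: int) -> list[int]:
    """Indices of the k largest router logits; ties go to the lower index."""
    return sorted(range(len(logits)), key=lambda i: (-logits[i], i))[:k]


# Hypothetical router logits for one token over four experts
logits_a = [2.000, 1.500, 1.497, 0.300]
# The same logits after small numerical perturbations (e.g., a different kernel or accumulation order)
logits_b = [2.000, 1.496, 1.499, 0.300]

BF16_SPACING_1_TO_2 = 2.0 ** -7  # gap between adjacent bf16 values in [1, 2)
max_perturbation = max(abs(a - b) for a, b in zip(logits_a, logits_b))

experts_a = sorted(top_k_experts(logits_a, k=2))
experts_b = sorted(top_k_experts(logits_b, k=2))
print(f"max perturbation {max_perturbation:.4f} < bf16 spacing {BF16_SPACING_1_TO_2:.4f}")
print(f"top-2 experts: {experts_a} vs {experts_b}")
```

Here a change of at most 0.004 in any logit moves the second selected expert from index 1 to index 2, while every continuous quantity changed by less than one bf16 step.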