Add support for GPT-OSS models
Summary
Add GPT-OSS model support, addressing https://github.com/linkedin/Liger-Kernel/issues/848. Patching is complete for RoPE, RMSNorm, cross_entropy, and fused_linear_cross_entropy.
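For illustration, a hedged usage sketch, assuming this PR exposes a patch function following the existing apply_liger_kernel_to_<model> convention; the exact function name, flags, and model checkpoint below are assumptions, not confirmed by this description:

```python
# Hypothetical usage sketch. The function name apply_liger_kernel_to_gpt_oss and its
# flag set are assumed from the existing apply_liger_kernel_to_<model> convention.
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss  # assumed name

# Enable only the kernels covered by this PR; SwiGLU is left off until the gated
# SwiGLU patching issue noted under Known Issues is resolved.
apply_liger_kernel_to_gpt_oss(
    rope=True,
    rms_norm=True,
    cross_entropy=False,
    fused_linear_cross_entropy=True,
)

model = AutoModelForCausalLM.from_pretrained("openai/gpt-oss-20b", torch_dtype="bfloat16")
```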
Known Issues
- Gated SwiGLU Patching Support: The current Hugging Face implementation of gated SwiGLU in GptOssExperts makes patching difficult. This will be addressed in a future update.
- GptOssExperts MXFP4 Format Support: MXFP4 tests are pending due to ongoing changes in the Hugging Face Transformers interface.
- BF16 Convergence Issue: The BF16 convergence test is failing, while FP32 passes. This issue is under investigation.
Testing Done
FP32 Log
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_gpt_oss-32-1e-05-dtype0-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED [100%]
test/convergence/fp32/test_mini_models_with_logits.py::test_mini_model[mini_gpt_oss-32-1e-05-dtype0-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED [100%]
BF16 Log
pytest --disable-warnings test/convergence/bf16/test_mini_models.py
====================================================================== test session starts =======================================================================
platform linux -- Python 3.10.18, pytest-8.4.1, pluggy-1.6.0
rootdir: /home/admin/Liger-Kernel
configfile: pyproject.toml
plugins: xdist-3.8.0, rerunfailures-15.1
collected 1 item
test/convergence/bf16/test_mini_models.py::test_mini_model[mini_gpt_oss-32-1e-05-dtype0-0.01-0.05-0.1-0.01-0.01-0.01] FAILED [100%]
============================================================================ FAILURES ============================================================================
___________________________________________ test_mini_model[mini_gpt_oss-32-1e-05-dtype0-0.01-0.05-0.1-0.01-0.01-0.01] ___________________________________________
model_name = 'mini_gpt_oss', num_steps = 32, lr = 1e-05, dtype = torch.bfloat16, loss_atol = 0.01, loss_rtol = 0.05, logprobs_atol = 0.1, logprobs_rtol = 0.01
param_atol = 0.01, param_rtol = 0.01
@pytest.mark.parametrize(
"model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logprobs_atol, logprobs_rtol, param_atol, param_rtol",
[
pytest.param(
"mini_gpt_oss",
32,
1e-5,
torch.bfloat16,
1e-2,
5e-2,
1e-1,
1e-2,
1e-2,
1e-2,
marks=[
pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
pytest.mark.skipif(
not GPT_OSS_AVAILABLE,
reason="GPT OSS not available in this version of transformers",
),
],
),
],
)
def test_mini_model(
model_name,
num_steps,
lr,
dtype,
loss_atol,
loss_rtol,
logprobs_atol,
logprobs_rtol,
param_atol,
param_rtol,
):
# Non-liger models should be initialized and tested first to avoid the module being overridden
expected_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr)
actual_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True)
# Compare every step of the loss
> assert_verbose_allclose(
torch.tensor([expected_output["loss"]]),
torch.tensor([actual_output["loss"]]),
atol=loss_atol,
rtol=loss_rtol,
extra_info="[Loss]",
)
test/convergence/bf16/test_mini_models.py:1395:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tensor1 = tensor([[10.4809, 10.2822, 10.0886, 9.8527, 9.6104, 9.4217, 9.1856, 8.9703,
8.7122, 8.5283, 8.2974, ... 6.1561, 6.0330,
6.9142, 5.7746, 5.6058, 5.5196, 5.4399, 5.1645, 5.2462, 4.9314,
5.8588]])
tensor2 = tensor([[10.4806, 10.2860, 10.0742, 9.8525, 9.6143, 9.4222, 9.1932, 8.9775,
8.7126, 8.5138, 8.2971, ... 8.8191, 6.0038,
7.8853, 5.7446, 5.5767, 5.5193, 5.4383, 5.1643, 5.2460, 4.9023,
4.8585]])
rtol = 0.05, atol = 0.01, max_print = 5, extra_info = '[Loss]'
def assert_verbose_allclose(tensor1, tensor2, rtol=1e-05, atol=1e-08, max_print=5, extra_info=""):
"""
Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches.
Parameters:
tensor1 (torch.Tensor): First tensor to compare.
tensor2 (torch.Tensor): Second tensor to compare.
rtol (float): Relative tolerance.
atol (float): Absolute tolerance.
max_print (int): Maximum number of mismatched elements to print.
extra_info (str): Extra information to show at the start of the error message.
Raises:
AssertionError: If the tensors are not all close within the given tolerance.
"""
# Check if the shapes of the tensors match
if tensor1.shape != tensor2.shape:
raise AssertionError("Input tensors must have the same shape.")
# Calculate the difference between the tensors
diff = torch.abs(tensor1 - tensor2)
# Determine the tolerance
tolerance = atol + rtol * torch.abs(tensor2)
# Find tolerance mismatched elements
tol_mismatched = diff > tolerance
# Find nan mismatched elements
nan_mismatched = torch.logical_xor(torch.isnan(tensor1), torch.isnan(tensor2))
# Find +inf mismatched elements
posinf_mismatched = torch.logical_xor(torch.isposinf(tensor1), torch.isposinf(tensor2))
# Find -inf mismatched elements
neginf_mismatched = torch.logical_xor(torch.isneginf(tensor1), torch.isneginf(tensor2))
# Find all mismatched elements
mismatched = torch.logical_or(
torch.logical_or(tol_mismatched, nan_mismatched),
torch.logical_or(posinf_mismatched, neginf_mismatched),
)
mismatched_indices = torch.nonzero(mismatched)
# Count the number of mismatched elements
num_mismatched = mismatched.sum().item()
# Check if all elements are close
all_close = num_mismatched == 0
# Raise AssertionError with detailed information if there are mismatches
if not all_close and num_mismatched >= 1:
mismatch_details = [f"Number of mismatched elements: {num_mismatched}"]
print_count = min(max_print, num_mismatched)
for index in mismatched_indices[:print_count]:
i = tuple(index.tolist())
mismatch_details.append(f"Mismatch at index {i}: tensor1[{i}] = {tensor1[i]}, tensor2[{i}] = {tensor2[i]}")
if num_mismatched > max_print:
mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.")
> raise AssertionError(extra_info + "\n".join(mismatch_details))
E AssertionError: [Loss]Number of mismatched elements: 2
E Mismatch at index (0, 21): tensor1[(0, 21)] = 8.848180770874023, tensor2[(0, 21)] = 6.204418182373047
E Mismatch at index (0, 22): tensor1[(0, 22)] = 6.156144142150879, tensor2[(0, 22)] = 8.819061279296875
test/utils.py:131: AssertionError
---------------------------------------------------------------------- Captured stdout call ----------------------------------------------------------------------
Liger kernel patches have been reverted.
Step 0, Loss: 10.480887413024902
Step 1, Loss: 10.282181739807129
Step 2, Loss: 10.088647842407227
Step 3, Loss: 9.852700233459473
Step 4, Loss: 9.610376358032227
Step 5, Loss: 9.421696662902832
Step 6, Loss: 9.185647010803223
Step 7, Loss: 8.970256805419922
Step 8, Loss: 8.712227821350098
Step 9, Loss: 8.528286933898926
Step 10, Loss: 8.297395706176758
Step 11, Loss: 8.022294998168945
Step 12, Loss: 7.8798322677612305
Step 13, Loss: 7.648087501525879
Step 14, Loss: 7.404232501983643
Step 15, Loss: 7.2665910720825195
Step 16, Loss: 9.697592735290527
Step 17, Loss: 9.604401588439941
Step 18, Loss: 9.46231746673584
Step 19, Loss: 6.633795738220215
Step 20, Loss: 9.081354141235352
Step 21, Loss: 8.848180770874023
Step 22, Loss: 6.156144142150879
Step 23, Loss: 6.032957077026367
Step 24, Loss: 5.914245128631592
Step 25, Loss: 5.774555206298828
Step 26, Loss: 5.605828285217285
Step 27, Loss: 5.519618988037109
Step 28, Loss: 5.439865589141846
Step 29, Loss: 5.164504528045654
Step 30, Loss: 5.246169567108154
Step 31, Loss: 4.93139123916626
Eval Loss: 4.858840465545654
Liger kernel patches have been reverted.
Step 0, Loss: 10.480640411376953
Step 1, Loss: 10.286001205444336
Step 2, Loss: 10.074193954467773
Step 3, Loss: 9.852497100830078
Step 4, Loss: 9.614309310913086
Step 5, Loss: 9.422234535217285
Step 6, Loss: 9.19322395324707
Step 7, Loss: 8.977506637573242
Step 8, Loss: 8.712628364562988
Step 9, Loss: 8.51380729675293
Step 10, Loss: 8.297117233276367
Step 11, Loss: 8.0232572555542
Step 12, Loss: 7.879528522491455
Step 13, Loss: 7.649257659912109
Step 14, Loss: 7.418290138244629
Step 15, Loss: 7.2662553787231445
Step 16, Loss: 9.697296142578125
Step 17, Loss: 9.61251449584961
Step 18, Loss: 9.455589294433594
Step 19, Loss: 6.634244918823242
Step 20, Loss: 9.089345932006836
Step 21, Loss: 6.204418182373047
Step 22, Loss: 8.819061279296875
Step 23, Loss: 6.003849029541016
Step 24, Loss: 5.8852691650390625
Step 25, Loss: 5.744577884674072
Step 26, Loss: 5.576700210571289
Step 27, Loss: 5.519281387329102
Step 28, Loss: 5.438257217407227
Step 29, Loss: 5.164332389831543
Step 30, Loss: 5.246014595031738
Step 31, Loss: 4.902283668518066
Eval Loss: 4.858528137207031
Liger kernel patches have been reverted.
==================================================================== short test summary info =====================================================================
FAILED test/convergence/bf16/test_mini_models.py::test_mini_model[mini_gpt_oss-32-1e-05-dtype0-0.01-0.05-0.1-0.01-0.01-0.01] - AssertionError: [Loss]Number of mismatched elements: 2
================================================================= 1 failed, 1 warning in 16.33s ==================================================================
Env: torch 2.8.0, triton 3.4.0, transformers 4.55.0
- Hardware Type: H200
- [x] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
@Tcc0403
Many thanks for the contribution.
Thanks for the contribution! Any benchmarking results for the GPT-OSS models?
@PKUWZP I'll add the benchmark results as soon as the swiglu implementation is complete.
@shimizust During the convergence test, the loss values for the two models running in bf16 diverged significantly at certain steps. This is likely related to the issue discussed here: https://github.com/linkedin/Liger-Kernel/issues/742.
For future contributors:
When patching RMSNorm, there are 4 init args that are easily overlooked (see the sketch after this list):
- casting_mode (str): "gemma" or "llama"
  - llama: downcasting back to the original precision "before" multiplying by weight
  - gemma: downcasting back to the original precision "after" multiplying by weight
- init_fn (str): "zeros" or "ones"; this largely follows the llama vs gemma implementations as well
- bias (float): defaults to 0.0
  - 0.0 (llama): no op applied to weight before multiplication
  - 1.0 (gemma): adds 1.0 to weight before multiplication
- in_place (bool): True or False
  - True: reuses the dY (grad_output) tensor to save some memory, but it doesn't work if dY is required elsewhere (e.g. adding a residual after rmsnorm)
  - False: addresses the above issue and always works. #376
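For reference, a minimal plain-PyTorch sketch (not Liger's actual Triton kernel) of what casting_mode and bias mean, assuming x_fp32 is the already-normalized hidden state in fp32 and input_dtype is the original activation dtype:

```python
import torch


def rmsnorm_output_llama_style(x_fp32: torch.Tensor, weight: torch.Tensor, input_dtype: torch.dtype):
    # llama: downcast back to the original precision BEFORE multiplying by weight;
    # weight is used as-is (bias = 0.0).
    return weight * x_fp32.to(input_dtype)


def rmsnorm_output_gemma_style(x_fp32: torch.Tensor, weight: torch.Tensor, input_dtype: torch.dtype, bias: float = 1.0):
    # gemma: multiply by (bias + weight) in fp32, then downcast AFTER the multiply.
    return ((bias + weight.float()) * x_fp32).to(input_dtype)
```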
Take GptOss for example: https://github.com/huggingface/transformers/blob/v4.55.2/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L42
It requires:
- casting_mode="gemma" (casting back after multiplying by weight)
- init_fn="ones" (`self.weight = nn.Parameter(torch.ones(hidden_size))`)
- bias=0.0 (no shifting of weight)
- in_place=True (no dY usage elsewhere)
Create a LigerRMSNormForXXXModel under liger_kernel/transformers/rms_norm.py to apply these init params, e.g. the sketch below.
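One possible shape for that helper, sketched under the assumption that LigerRMSNorm accepts the four args exactly as named in the list above (the real signature in liger_kernel/transformers/rms_norm.py may use different names, e.g. an offset instead of bias):

```python
# Sketch only: parameter names follow the list above and may not match the actual
# LigerRMSNorm signature.
from liger_kernel.transformers.rms_norm import LigerRMSNorm


class LigerRMSNormForGptOss(LigerRMSNorm):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__(
            hidden_size=hidden_size,
            eps=eps,
            casting_mode="gemma",  # GptOssRMSNorm downcasts after multiplying by weight
            init_fn="ones",        # self.weight = nn.Parameter(torch.ones(hidden_size))
            bias=0.0,              # no shift applied to weight before multiplication
            in_place=True,         # dY is not needed elsewhere, so it can be reused
        )
```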
@Comet0322 Thanks for updating. Do you think we can check in what you have and figure out SwiGLU afterward? After @Tcc0403's accum_dtype changes to the bf16 tests, do they pass now?
@Comet0322 and I have discussed this elsewhere. It seems the router top-k choice for experts is quite sensitive in bf16, leading to the discrepancy in the final losses. I suggest just skipping the bf16 convergence test for now.
cc @shimizust