
Convert DenseGeneral to NNX

Open • bvandermoon opened this issue 8 months ago • 0 comments

Description

Finishing out PR 1604 for @cgarciae

Note: I updated the logits in golden_data_grpo_default.jsonl to the NNX values. The NNX logits don't match the Linen ones because the RNG keys are set up differently.

Previous PR description:

This commit converts DenseGeneral to NNX and adds a dense_general helper that interfaces with it through a Linen wrapper. dense_general accepts the same arguments as the Linen version, plus two additional ones:

  • input_shape: the expected shape of the input.
  • in_features: an int or tuple representing the input features.

Only one of them can be set at a time.
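The argument handling described above can be sketched in plain Python. This is a minimal sketch, not MaxText's implementation: the helpers `resolve_in_features` and `kernel_shape` are hypothetical. The sketch assumes that exactly one of `input_shape` or `in_features` must be provided (an NNX module creates its kernel at construction time, so it needs the input feature sizes up front). It also assumes that `input_shape` is reduced to feature sizes along `axis`, the same way Linen's `DenseGeneral` derives its kernel shape from the input.

```python
from typing import Optional, Sequence, Tuple, Union


def _canonicalize_tuple(x: Union[int, Sequence[int]]) -> Tuple[int, ...]:
    """Turn an int or a sequence of ints into a tuple of ints."""
    if isinstance(x, int):
        return (x,)
    return tuple(x)


def resolve_in_features(
    input_shape: Optional[Sequence[int]] = None,
    in_features: Optional[Union[int, Sequence[int]]] = None,
    axis: Union[int, Sequence[int]] = -1,
) -> Tuple[int, ...]:
    """Hypothetical helper: derive the contracted input feature sizes.

    Assumption: exactly one of `input_shape` or `in_features` is set.
    """
    if (input_shape is None) == (in_features is None):
        raise ValueError("Set exactly one of input_shape or in_features.")
    if in_features is not None:
        return _canonicalize_tuple(in_features)
    ndim = len(input_shape)
    # Normalize negative axes so that -1 refers to the last dimension.
    return tuple(input_shape[ax % ndim] for ax in _canonicalize_tuple(axis))


def kernel_shape(
    in_feats: Sequence[int], features: Union[int, Sequence[int]]
) -> Tuple[int, ...]:
    """Hypothetical helper: kernel shape is contracted dims + output features."""
    return tuple(in_feats) + _canonicalize_tuple(features)


# Usage: both ways of specifying the input produce the same kernel.
print(resolve_in_features(input_shape=(8, 128, 512)))  # (512,)
print(resolve_in_features(in_features=512))  # (512,)
print(kernel_shape(resolve_in_features(input_shape=(8, 128, 16, 64), axis=(-2, -1)), 512))  # (16, 64, 512)
```

Deriving `in_features` from `input_shape` lets callers that only know the activation shape construct the NNX module eagerly, while callers that already know the feature sizes can skip the shape entirely.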

Tests

  • Unit tests
  • Will be running a logits check for gpt3, which is the first model we are using

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • [x] I have performed a self-review of my code.
  • [x] I have necessary comments in my code, particularly in hard-to-understand areas.
  • [x] I have run end-to-end tests and provided workload links above if applicable.
  • [x] I have made or will make corresponding changes to the doc if needed.

bvandermoon, May 14 '25 22:05