
Feature request: Support for quantization

Open JEM-Mosig opened this issue 1 year ago • 6 comments

It'd be great if Penzai supported model quantization out of the box. I know this is a lot of work to implement, but right now the lack of quantization support is the main reason why I wouldn't want to fine-tune models with Penzai.

JEM-Mosig avatar Apr 17 '25 10:04 JEM-Mosig

One could use AQT for this if Penzai exposed the dot_general function somehow. However, Linear is implemented in terms of jnp.tensordot, which uses lax.dot_general under the hood, and NamedEinsum is implemented in terms of jnp.einsum, which does have an (experimental) _dot_general keyword argument, but Penzai doesn't expose it.

JEM-Mosig avatar Apr 19 '25 09:04 JEM-Mosig

Agreed this would be a very useful feature!

I think it should be pretty easy to prototype something like this without needing to directly change Penzai's implementation, because Penzai is designed to make it easy to hot-swap out model components. One implementation strategy:

  • Define a new class AQTLinear with the same __call__ interface as pz.nn.Linear, but defined in terms of dot_general instead of jnp.tensordot
    • It could have a classmethod AQTLinear.from_linear(cls, orig: pz.nn.Linear, config: aqt_config.DotGeneral) -> AQTLinear that builds itself and adopts the parameters from the original pz.nn.Linear. (This is similar to how LowRankAdapter replaces a Linear, or how KVCachingAttention replaces an Attention.)
    • Perhaps AQTLinear.__call__ could be implemented by using jnp.einsum instead of jnp.tensordot
  • Similarly define a new class AQTNamedEinsum based on penzai's NamedEinsum
  • Use selectors to replace them, e.g.
    (
      pz.select(model)
      .at_instances_of(pz.nn.Linear)
      .apply(lambda lin: AQTLinear.from_linear(lin, aqt_config))
    )
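
The hot-swap strategy above can be sketched in plain JAX. Note this is a hypothetical, simplified stand-in: `TinyLinear` and `DotGeneralLinear` are not Penzai's actual classes (pz.nn.Linear's real interface differs), and AQT's quantized dot_general is replaced by jax.lax.dot_general to keep the example self-contained:

```python
# Minimal sketch of the "swap in a dot_general-based layer" idea.
# All class and method names here are hypothetical illustrations,
# not Penzai or AQT APIs.
import dataclasses
from typing import Callable

import jax
import jax.numpy as jnp


@dataclasses.dataclass
class TinyLinear:
    """Stand-in for pz.nn.Linear: contracts the input's last axis."""
    weights: jax.Array  # shape (in_features, out_features)

    def __call__(self, x):
        return jnp.tensordot(x, self.weights, axes=1)


@dataclasses.dataclass
class DotGeneralLinear:
    """Same __call__ interface, but built on a swappable dot_general.

    AQT (or another quantization library) would inject its own
    dot_general here instead of the default jax.lax.dot_general.
    """
    weights: jax.Array
    dot_general: Callable = jax.lax.dot_general

    def __call__(self, x):
        # Contract x's last axis with weights' first axis, no batch dims.
        return self.dot_general(
            x, self.weights,
            dimension_numbers=(((x.ndim - 1,), (0,)), ((), ())),
        )

    @classmethod
    def from_linear(cls, orig, dot_general=jax.lax.dot_general):
        # Adopt the original layer's parameters, analogous to the
        # proposed AQTLinear.from_linear.
        return cls(weights=orig.weights, dot_general=dot_general)


x = jnp.ones((2, 4))
orig = TinyLinear(weights=jnp.arange(12.0).reshape(4, 3))
swapped = DotGeneralLinear.from_linear(orig)
# With the default dot_general, the swapped layer matches the original.
assert jnp.allclose(orig(x), swapped(x))
```

With the default dot_general the two layers agree exactly; substituting a quantized dot_general would change the numerics but not the interface, which is what lets a selector replace the layers wholesale.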
    

If this works, it might make sense to add the AQTLinear/AQTNamedEinsum classes (and some helper functions) into Penzai, perhaps under penzai.toolshed.aqt. Then people who want to use AQT quantization could enable it with just a few extra lines.

(I probably won't have much bandwidth to experiment with this myself, but contributions are welcome!)

danieldjohnson avatar Apr 20 '25 22:04 danieldjohnson

I had the same thought! Your naming is better, though :) I'll try to implement the AQT layers when I find the time. But I think doing so will involve a lot of code copying, which is not ideal. If Linear etc. directly exposed dot_general, then AQTLinear could just build a Linear with a dot_general function derived from the AQT config. In any case, I'll implement it first in the code-copy manner and see if it works.

JEM-Mosig avatar Apr 22 '25 08:04 JEM-Mosig

Are you accepting contributions from the community? I would love to work on this issue.

demoncoder-crypto avatar Apr 22 '25 11:04 demoncoder-crypto

It looks like using AQT directly is a bit more tricky than I thought, as AQT objects carry around state for calibration and the AQT code generally seems to be in an unfinished and abandoned state. I'll see if I can implement some simple post-training quantization myself, but I can't guarantee that I'll find enough time to do so.

As inspiration, I think this section of the AQT Readme and maybe this outdated user guide for flax might be helpful.

@demoncoder-crypto I'm also just a community contributor, so I'm sure contributions would be welcome. Let me know if you make progress on this!

JEM-Mosig avatar Apr 24 '25 08:04 JEM-Mosig

@demoncoder-crypto This looks like the successor to AQT and might be interesting to look into: https://github.com/google/qwix

JEM-Mosig avatar May 07 '25 14:05 JEM-Mosig