Feature request: Support for quantization
It'd be great if penzai supported model quantization out of the box. I know this is a lot of work to implement, but right now the lack of quantization support is the main reason I wouldn't want to fine-tune models with penzai.
One could use AQT for this, if penzai exposed the `dot_general` function somehow. But, e.g., `Linear` is implemented in terms of `jnp.tensordot`, which uses `lax.dot_general` under the hood, and `NamedEinsum` is implemented in terms of `jnp.einsum`, which does have an (experimental) `_dot_general` keyword argument, but it isn't exposed by penzai.
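To illustrate the point about `lax.dot_general` being the underlying primitive: a linear contraction written with `jnp.tensordot` can equivalently be expressed through an injectable `dot_general` callable, which is exactly the hook a quantized implementation would need. This is a minimal sketch (the `linear_apply` helper is hypothetical, not part of penzai or AQT):

```python
import jax.numpy as jnp
from jax import lax

def linear_apply(x, w, dot_general=lax.dot_general):
    # The contraction x @ w expressed via dot_general: contract the last
    # axis of x with the first axis of w, with no batch dimensions.
    # A quantized dot_general (e.g. one built from an AQT config) could
    # be passed in here instead of the default lax.dot_general.
    return dot_general(x, w, (((x.ndim - 1,), (0,)), ((), ())))

x = jnp.ones((2, 3))
w = jnp.ones((3, 4))
out = linear_apply(x, w)
# Matches the tensordot formulation used by penzai's Linear.
assert jnp.allclose(out, jnp.tensordot(x, w, axes=1))
```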
Agreed this would be a very useful feature!
I think it should be pretty easy to prototype something like this without needing to directly change Penzai's implementation, because Penzai is designed to make it easy to hot-swap out model components. One implementation strategy:
- Define a new class `AQTLinear` with the same `__call__` interface as `pz.nn.Linear`, but defined in terms of `dot_general` instead of `jnp.tensordot`
  - It could have a classmethod `AQTLinear.from_linear(cls, orig: pz.nn.Linear, config: aqt_config.DotGeneral) -> AQTLinear` that builds itself and adopts the parameters from the original `pz.nn.Linear`. (This is similar to how LowRankAdapter replaces a Linear, or how KVCachingAttention replaces an Attention.)
  - Perhaps `AQTLinear.__call__` could be implemented by using `jnp.einsum` instead of `jnp.tensordot`
- Similarly define a new class `AQTNamedEinsum` based on penzai's `NamedEinsum`
- Use selectors to replace them, e.g.

  ```python
  (
      pz.select(model)
      .at_instances_of(pz.nn.Linear)
      .apply(lambda lin: AQTLinear.from_linear(lin, aqt_config))
  )
  ```
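The `from_linear` adoption pattern above can be sketched in a self-contained way. Note that `Linear` and `AQTLinear` below are simplified stand-ins for illustration only, not penzai's actual classes, and the real version would carry an AQT config rather than a bare callable:

```python
from dataclasses import dataclass, field
from typing import Callable

import jax.numpy as jnp
from jax import lax

@dataclass
class Linear:
    # Simplified stand-in for pz.nn.Linear: contracts the last axis
    # of the input against the first axis of the weights.
    weights: jnp.ndarray

    def __call__(self, x):
        return jnp.tensordot(x, self.weights, axes=1)

@dataclass
class AQTLinear:
    # Same contraction, but routed through an injectable dot_general
    # so a quantized kernel (e.g. one built by AQT) can be plugged in.
    weights: jnp.ndarray
    dot_general: Callable = field(default=lax.dot_general)

    def __call__(self, x):
        return self.dot_general(
            x, self.weights, (((x.ndim - 1,), (0,)), ((), ())))

    @classmethod
    def from_linear(cls, orig, dot_general=lax.dot_general):
        # Adopt the original layer's parameters unchanged; only the
        # contraction implementation is swapped out.
        return cls(weights=orig.weights, dot_general=dot_general)

lin = Linear(jnp.ones((3, 4)))
qlin = AQTLinear.from_linear(lin)
x = jnp.ones((2, 3))
# With the default dot_general, the swap is numerically a no-op.
assert jnp.allclose(lin(x), qlin(x))
```

With penzai's real classes, the selector expression above would apply this kind of `from_linear` conversion across the whole model tree in one pass.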
If this works, it might make sense to add the `AQTLinear`/`AQTNamedEinsum` classes (and some helper functions) into penzai, perhaps under `penzai.toolshed.aqt`. Then people who want to use AQT quantization could enable it with just a few extra lines.
(I probably won't have much bandwidth to experiment with this myself, but contributions are welcome!)
I had the same thought! Your naming is better, though :) I'll try to implement the AQT layers when I find the time. But I think doing so will involve a lot of code copying, which is not ideal. If `Linear` etc. directly exposed `dot_general`, then `AQTLinear` could just build a `Linear` with a `dot_general` function derived from the AQT config. In any case, I'll implement it first in the code-copy manner and see if it works.
Are you accepting contributions from the community? I would love to work on this issue.
It looks like using AQT directly is a bit trickier than I thought, as AQT objects carry around state for calibration, and the AQT code generally seems to be in an unfinished and abandoned state. I'll see if I can implement some simple post-training quantization myself, but I can't guarantee that I'll find enough time to do so.
As inspiration, I think this section of the AQT README and maybe this outdated user guide for Flax might be helpful.
@demoncoder-crypto I'm also just a community contributor, so I'm sure contributions would be welcome. Let me know if you make progress on this!
@demoncoder-crypto This looks like the successor to AQT and might be interesting to look into: https://github.com/google/qwix