mlx
[WIP] qqmm
This is still a draft, but it adds a new op `mx::qqmm` and a new primitive `mx::DualQuantizedMatmul` (the naming is open to discussion).
At the moment, the implementation only supports the configuration where both inputs are quantized in the same way (this is also the only configuration supported by cublas). The output type is fixed to bf16.
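To make the "both inputs quantized in the same way" constraint concrete, here is a rough NumPy reference of what such a matmul computes. It is only a sketch: the function names, the symmetric per-group int8 scheme, the group size, and the `x @ w.T` layout are all assumptions for illustration and are not taken from this PR. NumPy has no bf16 dtype, so the sketch returns float32 instead.

```python
import numpy as np

def quantize(x, group_size=32, bits=8):
    """Symmetric per-group quantization along the last axis (illustrative scheme)."""
    qmax = 2 ** (bits - 1) - 1
    # Split the last axis into groups that each share one scale.
    groups = x.reshape(*x.shape[:-1], -1, group_size)
    scales = np.abs(groups).max(axis=-1, keepdims=True) / qmax
    scales = np.where(scales == 0, 1.0, scales)  # avoid division by zero
    q = np.clip(np.round(groups / scales), -qmax, qmax).astype(np.int8)
    return q, scales.astype(np.float32)

def dequantize(q, scales, shape):
    """Undo quantize(): rescale each group and restore the original shape."""
    return (q.astype(np.float32) * scales).reshape(shape)

def qqmm_reference(x, w, group_size=32, bits=8):
    """Quantize both operands with the same settings, then compute x @ w.T."""
    xq, xs = quantize(x, group_size, bits)
    wq, ws = quantize(w, group_size, bits)
    x_hat = dequantize(xq, xs, x.shape)
    w_hat = dequantize(wq, ws, w.shape)
    return x_hat @ w_hat.T  # float32 here; the PR fixes the real output to bf16
```

Calling `qqmm_reference(x, w)` with `x` of shape `(M, K)` and `w` of shape `(N, K)` gives an `(M, N)` result that approximates `x @ w.T` up to quantization error. A real kernel would multiply the quantized values directly and apply the scales afterwards rather than dequantizing first.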
There is also some restructuring of the ops and cublas utils.
Todo:
- batching logic in `CublasQQMM`
- `bias` and case when `c` is not `nullptr`
- not sure but we probably want `mx::qqmm` to return quantized output
- `CublasQQMM` also should be cleaned
- `jvp`, `vjp`, `vmap`