
Please provide more 4-bit matmul docs and examples

Open alanzhai219 opened this issue 9 months ago • 10 comments

oneDNN enables 4-bit data types for matmul, such as u4/s4 and 4-bit floats. However, there is no documentation or example covering the 4-bit data types, in particular the 4-bit storage layout and its use cases in real inference. Could you provide more details?

alanzhai219 avatar Jul 21 '25 08:07 alanzhai219

Hello @alanzhai219, Thank you for your request. The core concepts behind INT4 and INT8 MatMul are quite similar. You can refer to the MatMul Tutorial for details on weight decompression, as well as examples available in the examples/tutorials/matmul folder.

For INT4 weight decompression specifically, we have a draft PR (#2193) that introduces an example demonstrating how 8 INT4 values are packed along the N dimension into an INT32 using the tag::ba format. While this PR has been in our backlog for some time, I plan to update it soon to move it forward. In the meantime, you're welcome to use it (code link) as a reference.

shu1chen avatar Jul 21 '25 15:07 shu1chen

@shu1chen Is the INT4 weight decompression feature supported on CPU platforms?

alanzhai219 avatar Jul 24 '25 02:07 alanzhai219

Is the INT4 weight decompression feature supported on CPU platforms?

@alanzhai219 Yes. Please refer to the oneDNN documentation for the list of data types supported on both CPU and GPU. You can also use the benchdnn tool provided by oneDNN to verify this. The test has passed on CPU, which confirms that this case is supported.

$ ./tests/benchdnn/benchdnn --engine=cpu --matmul --dt=bf16:u4:bf16 --stag=ab --wtag=ba --dtag=ab --attr-fpmath=any:true --attr-scales=wei:per_tensor:bf16:256x1 --attr-zero-points=wei:per_tensor:u4:256x1 1024x1024:1024x1024
0:PASSED (1577 ms) __REPRO: --matmul --dt=bf16:u4:bf16 --stag=ab --wtag=ba --dtag=ab --attr-scales=wei:per_tensor:bf16:256x1 --attr-zero-points=wei:per_tensor:u4:256x1 --attr-fpmath=any:true 1024x1024:1024x1024
tests:1 passed:1 skipped:0 mistrusted:0 unimplemented:0 invalid_arguments:0 failed:0 listed:0
total: 1.58s; create_pd: 0.00s (0%); create_prim: 0.00s (0%); fill: 0.07s (4%); execute: 1.44s (92%); compute_ref: 0.06s (3%); compare: 0.01s (0%);

shu1chen avatar Jul 24 '25 03:07 shu1chen

@shu1chen Thanks for your comments. The example works. But I still have some questions.

  1. Will src & dst with the f32 data type be supported for u4/s4 decompression?
  2. My branch is rls-v3.8, and the implementation kernel is ref. Why?
./benchdnn --engine=cpu --matmul --dt=f16:s4:f16 --stag=ab --wtag=ba --dtag=ab --attr-fpmath=any:true --attr-scales=wei:per_tensor:f16:256x1 --attr-zero-points=wei:per_tensor:s4:256x1 1024x1024:1024x1024
onednn_verbose,v1,info,oneDNN v3.8.0 (commit a762d3248ee5e04b2348f3a5aeecfa64da4634d8)
onednn_verbose,v1,info,cpu,runtime:OpenMP,nthr:192
onednn_verbose,v1,info,cpu,isa:Intel AVX-512 with float16, Intel DL Boost and bfloat16 support and Intel AMX with bfloat16 and 8-bit integer support
onednn_verbose,v1,info,gpu,runtime:none
onednn_verbose,v1,info,graph,backend,0:dnnl_backend
onednn_verbose,v1,info,experimental features are enabled
onednn_verbose,v1,info,use batch_normalization stats one pass is enabled
onednn_verbose,v1,info,GPU convolution v2 is disabled
onednn_verbose,v1,info,experimental functionality for sparse domain is enabled
onednn_verbose,v1,primitive,info,template:operation,engine,primitive,implementation,prop_kind,memory_descriptors,attributes,auxiliary,problem_desc,exec_time
onednn_verbose,v1,graph,info,template:operation,engine,partition_id,partition_kind,op_names,data_formats,logical_tensors,fpmath_mode,implementation,backend,exec_time
onednn_verbose,v1,primitive,create:dispatch,matmul,cpu,matmul,brg_matmul:avx10_1_512_amx,undef,src:f16::blocked:ab::f0 wei:s4::blocked:ba::f0 dst:f16::blocked:ab::f0,attr-fpmath:any:true attr-scales:wei:4095:f16:256x1 attr-zero-points:wei:4095:s4:256x1,,1024x1024:1024x1024,unsupported attribute,src/cpu/x64/matmul/brgemm_matmul.cpp:178
onednn_verbose,v1,primitive,create:dispatch,matmul,cpu,matmul,brg_matmul:avx10_1_512,undef,src:f16::blocked:ab::f0 wei:s4::blocked:ba::f0 dst:f16::blocked:ab::f0,attr-fpmath:any:true attr-scales:wei:4095:f16:256x1 attr-zero-points:wei:4095:s4:256x1,,1024x1024:1024x1024,unsupported attribute,src/cpu/x64/matmul/brgemm_matmul.cpp:178
onednn_verbose,v1,primitive,create:dispatch,matmul,cpu,matmul,brg_matmul:avx512_core_bf16,undef,src:f16::blocked:ab::f0 wei:s4::blocked:ba::f0 dst:f16::blocked:ab::f0,attr-fpmath:any:true attr-scales:wei:4095:f16:256x1 attr-zero-points:wei:4095:s4:256x1,,1024x1024:1024x1024,unsupported attribute,src/cpu/x64/matmul/brgemm_matmul.cpp:178
onednn_verbose,v1,primitive,create:dispatch,matmul,cpu,matmul,brg_matmul:avx512_core_vnni,undef,src:f16::blocked:ab::f0 wei:s4::blocked:ba::f0 dst:f16::blocked:ab::f0,attr-fpmath:any:true attr-scales:wei:4095:f16:256x1 attr-zero-points:wei:4095:s4:256x1,,1024x1024:1024x1024,unsupported attribute,src/cpu/x64/matmul/brgemm_matmul.cpp:178
onednn_verbose,v1,primitive,create:dispatch,matmul,cpu,matmul,brg_matmul:avx512_core,undef,src:f16::blocked:ab::f0 wei:s4::blocked:ba::f0 dst:f16::blocked:ab::f0,attr-fpmath:any:true attr-scales:wei:4095:f16:256x1 attr-zero-points:wei:4095:s4:256x1,,1024x1024:1024x1024,unsupported attribute,src/cpu/x64/matmul/brgemm_matmul.cpp:178
onednn_verbose,v1,primitive,create:dispatch,matmul,cpu,matmul,brg_matmul:avx2_vnni,undef,src:f16::blocked:ab::f0 wei:s4::blocked:ba::f0 dst:f16::blocked:ab::f0,attr-fpmath:any:true attr-scales:wei:4095:f16:256x1 attr-zero-points:wei:4095:s4:256x1,,1024x1024:1024x1024,unsupported attribute,src/cpu/x64/matmul/brgemm_matmul.cpp:178
onednn_verbose,v1,primitive,create:dispatch,matmul,cpu,matmul,gemm:jit:f32,undef,src:f16::blocked:ab::f0 wei:s4::blocked:ba::f0 dst:f16::blocked:ab::f0,attr-fpmath:any:true attr-scales:wei:4095:f16:256x1 attr-zero-points:wei:4095:s4:256x1,,1024x1024:1024x1024,unsupported datatype combination,src/cpu/matmul/gemm_f32_matmul.cpp:93
onednn_verbose,v1,primitive,create:dispatch,matmul,cpu,matmul,gemm:jit:bf16,undef,src:f16::blocked:ab::f0 wei:s4::blocked:ba::f0 dst:f16::blocked:ab::f0,attr-fpmath:any:true attr-scales:wei:4095:f16:256x1 attr-zero-points:wei:4095:s4:256x1,,1024x1024:1024x1024,unsupported datatype combination,src/cpu/matmul/gemm_bf16_matmul.cpp:63
onednn_verbose,v1,primitive,create:dispatch,matmul,cpu,matmul,gemm:jit:bf16,undef,src:f16::blocked:ab::f0 wei:s4::blocked:ba::f0 dst:f16::blocked:ab::f0,attr-fpmath:any:true attr-scales:wei:4095:f16:256x1 attr-zero-points:wei:4095:s4:256x1,,1024x1024:1024x1024,unsupported datatype combination,src/cpu/matmul/gemm_bf16_matmul.cpp:63
onednn_verbose,v1,primitive,create:dispatch,matmul,cpu,matmul,gemm:jit,undef,src:f16::blocked:ab::f0 wei:s4::blocked:ba::f0 dst:f16::blocked:ab::f0,attr-fpmath:any:true attr-scales:wei:4095:f16:256x1 attr-zero-points:wei:4095:s4:256x1,,1024x1024:1024x1024,unsupported datatype combination,src/cpu/matmul/gemm_x8s8s32x_matmul.cpp:120
onednn_verbose,v1,primitive,create:dispatch,matmul,cpu,matmul,brg_matmul:avx2,undef,src:f16::blocked:ab::f0 wei:s4::blocked:ba::f0 dst:f16::blocked:ab::f0,attr-fpmath:any:true attr-scales:wei:4095:f16:256x1 attr-zero-points:wei:4095:s4:256x1,,1024x1024:1024x1024,unsupported attribute,src/cpu/x64/matmul/brgemm_matmul.cpp:178
onednn_verbose,v1,primitive,create:cache_miss,cpu,matmul,ref:any,undef,src:f16::blocked:ab::f0 wei:s4::blocked:ba::f0 dst:f16::blocked:ab::f0,attr-fpmath:any:true attr-scales:wei:4095:f16:256x1 attr-zero-points:wei:4095:s4:256x1,,1024x1024:1024x1024,0.0410156
onednn_verbose,v1,primitive,create:cache_hit,cpu,matmul,ref:any,undef,src:f16::blocked:ab::f0 wei:s4::blocked:ba::f0 dst:f16::blocked:ab::f0,attr-fpmath:any:true attr-scales:wei:4095:f16:256x1 attr-zero-points:wei:4095:s4:256x1,,1024x1024:1024x1024,0.00512695
onednn_verbose,v1,primitive,create:check,matmul,unsupported attribute,src/common/matmul.cpp:78
onednn_verbose,v1,primitive,create:cache_miss,cpu,reorder,simple:any,undef,src:f32::blocked:ba::f0 dst:s4::blocked:ba::f0,,,1024x1024,0.0510254
onednn_verbose,v1,primitive,exec,cpu,reorder,simple:any,undef,src:f32::blocked:ba::f0 dst:s4::blocked:ba::f0,,,1024x1024,34.833
onednn_verbose,v1,primitive,create:cache_miss,cpu,reorder,jit_direct_copy:uni,undef,src:f32::blocked:ab::f0 dst:f16::blocked:ab::f0,,,1024x1024,0.26709
onednn_verbose,v1,primitive,exec,cpu,reorder,jit_direct_copy:uni,undef,src:f32::blocked:ab::f0 dst:f16::blocked:ab::f0,,,1024x1024,0.072998
onednn_verbose,v1,primitive,exec,cpu,matmul,ref:any,undef,src:f16::blocked:ab::f0 wei:s4::blocked:ba::f0 dst:f16::blocked:ab::f0,attr-fpmath:any:true attr-scales:wei:4095:f16:256x1 attr-zero-points:wei:4095:s4:256x1,,1024x1024:1024x1024,4397.66
onednn_verbose,v1,primitive,create:dispatch,reorder,unsupported datatype,src/cpu/x64/matmul/brgemm_matmul_reorders.cpp:294
onednn_verbose,v1,primitive,create:cache_miss,cpu,reorder,jit_direct_copy:uni,undef,src:f16::blocked:ab::f0 dst:f32::blocked:ab::f0,,,1024x1024,0.411865
onednn_verbose,v1,primitive,exec,cpu,reorder,jit_direct_copy:uni,undef,src:f16::blocked:ab::f0 dst:f32::blocked:ab::f0,,,1024x1024,0.187012
0:PASSED (4929 ms) __REPRO: --matmul --dt=f16:s4:f16 --stag=ab --wtag=ba --dtag=ab --attr-scales=wei:per_tensor:f16:256x1 --attr-zero-points=wei:per_tensor:s4:256x1 --attr-fpmath=any:true 1024x1024:1024x1024
tests:1 passed:1 skipped:0 mistrusted:0 unimplemented:0 invalid_arguments:0 failed:0 listed:0
total: 4.93s; create_pd: 0.00s (0%); create_prim: 0.00s (0%); fill: 0.07s (1%); execute: 4.40s (89%); compute_ref: 0.44s (9%); compare: 0.01s (0%);

alanzhai219 avatar Jul 25 '25 08:07 alanzhai219

  1. Will src & dst with the f32 data type be supported for u4/s4 decompression?

When using u4 or s4 quantized weights, it's more typical to pair them with bf16 (bfloat16) or f16 (float16) for src and dst, rather than f32. Although the combinations f32:u4:f32 and f32:s4:f32 are not explicitly listed in the oneDNN matmul documentation, they are functionally supported on CPU through the ref:any (reference) implementation. You can verify this using benchdnn with the following command:

$ ./tests/benchdnn/benchdnn -v7 --engine=cpu --fast-ref=false --matmul --dt=f32:u4:f32 --stag=ab --wtag=ba --dtag=ab --attr-scales=wei:common --attr-zero-points=wei:common --attr-fpmath=strict:true 1024x1024:1024x1024
create: --matmul --fast-ref=false --dt=f32:u4:f32 --stag=ab --wtag=ba --dtag=ab --attr-fpmath=strict:true 1024x1024:1024x1024
oneDNN implementation: ref:any
  2. My branch is rls-v3.8, and the implementation kernel is ref. Why?

In your case, the ref implementation is selected due to the specific configuration of scales and zero-points in the benchdnn command. You can observe the dispatching logic in the verbose logs (look for create:dispatch) to understand why optimized paths are not selected.

For example, when using f16:u4:f16 with appropriate attributes and allowing flexible math modes, oneDNN dispatches to an optimized kernel like brg_matmul:avx10_1_512:

$ ./tests/benchdnn/benchdnn -v7 --engine=cpu --fast-ref=false --matmul --dt=f16:u4:f16 --stag=ab --wtag=ba --dtag=ab --attr-scales=wei:common --attr-zero-points=wei:common --attr-fpmath=any:true 1024x1024:1024x1024
create: --matmul --fast-ref=false --dt=f16:u4:f16 --stag=ab --wtag=ba --dtag=ab --attr-fpmath=any:true 1024x1024:1024x1024
oneDNN implementation: brg_matmul:avx10_1_512

This shows that optimized kernels are available under certain conditions, and the fallback to ref is likely due to attribute constraints or platform limitations.

shu1chen avatar Jul 28 '25 03:07 shu1chen

@shu1chen I did some debugging based on https://github.com/uxlfoundation/oneDNN/blob/f29c7d0f19ef01912a0c690c0a7f6b510493bd4d/examples/tutorials/matmul/int4_weight_decompression.cpp. The error is reported from the code below.

    // https://github.com/uxlfoundation/oneDNN/blob/rls-v3.8/src/common/primitive_attr.cpp#L103
    CHECK_ARG(IMPLICATION((bool)(~mask & smask_t::zero_points_groups),
            zero_points_.has_default_groups()));
  1. I don't know how to set the zero-point groups. Could you suggest how to modify the sample above? My goal is to call the oneDNN BRGEMM kernel for u4/s4 grouped weight decompression.
  2. Is there any documentation on group configuration limitations? Must it be divisible by 32?

alanzhai219 avatar Jul 30 '25 03:07 alanzhai219

I don't know how to set the zero-point groups. Could you suggest how to modify the sample above? My goal is to call the oneDNN BRGEMM kernel for u4/s4 grouped weight decompression.

To experiment with different configurations, you can use benchdnn following the guidance in the attributes documentation. This can help you verify whether a given configuration dispatches to the BRGEMM kernel or falls back to ref. Based on the results, you can then adjust the inputs to the set_zero_points function in the test code accordingly.

Is there any documentation on group configuration limitations? Must it be divisible by 32?

Yes, due to hardware specifics, when the number of groups is greater than 1, it must be a multiple of 32. This is related to hardware alignment, and there is currently no documentation for it.

shu1chen avatar Jul 30 '25 09:07 shu1chen

@shu1chen I see a reorder between ab and ba in the u4/s4 decompression sample. Will you provide such a reorder kernel?

alanzhai219 avatar Aug 08 '25 07:08 alanzhai219

@shu1chen ping

alanzhai219 avatar Aug 11 '25 05:08 alanzhai219

I see a reorder between ab and ba in the u4/s4 decompression sample. Will you provide such a reorder kernel?

Hello @alanzhai219, just to clarify: are you asking whether the reorder primitive can transpose between ab and ba for u4/s4 data types? If so, this functionality is currently not supported; INT4 support in oneDNN is limited. Also, as documented in the oneDNN Developer Guide, s4/u4 data types are only supported as a storage data type for the weights argument in the case of weights decompression.

shu1chen avatar Aug 11 '25 06:08 shu1chen