[GPU] Modify GEMM attr group to support reshape
Description
Address performance regressions in batched gemm with grouped attrs. Previously disabled 3d->2d reshape to avoid correctness issues, in some cases we can instead reshape grouped dims to preserve correctness and performance.
Added support for per_tensor grouped attrs with 3d->2d reshape to support OV use cases.
Fixes # https://jira.devtools.intel.com/browse/MFDNN-13651
Checklist
General
- [x] Do all unit and benchdnn tests (
make testandmake test_benchdnn_*) pass locally for each commit? - [x] Have you formatted the code using clang-format?
Performance improvements
- [x] Have you submitted performance data that demonstrates performance improvements?
make test set test_scope=NIGHTLY disable test_device_cpu disable benchdnn_all enable benchdnn_matmul
make test perf-gpu set primitive=matmul ip
make test set test_scope=NIGHTLY disable test_device_cpu disable benchdnn_all enable benchdnn_matmul
make test perf-gpu set primitive=matmul ip
make test set test_scope=NIGHTLY disable test_device_cpu disable benchdnn_all enable benchdnn_matmul
make test set test_scope=NIGHTLY disable test_device_cpu disable benchdnn_all enable benchdnn_matmul
make test perf-gpu set primitive=matmul
@Simonsays095 Here are some impacted layers on BMG:
matmul improvements: | layer | main | kealanba/cvt_attrs | ratio |
| :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ---------: | -----------------: | ----: |
| --dt=s8:u8:f16 --stag=abc --wtag=cab --dtag=abc --attr-scales=src:per_ocic:f16:1x4096+wei:per_oc:f16 --attr-zero-points=wei:per_oc:u8 2056x1x4096:1x4096x27392 | 284.426000 | 2.915210 | 97.57 |
| --dt=u8:u8:f16 --stag=abc --wtag=cab --dtag=abc --attr-scales=src:per_ocic:f16:1x4096+wei:per_oc:f16 --attr-zero-points=src:per_ocic:u8:1x4096+wei:per_oc:u8 2056x1x4096:1x4096x27392 | 283.988000 | 2.954380 | 96.12 |
| --dt=s8:u8:f16 --stag=abc --wtag=cab --dtag=abc --attr-scales=src:per_ocic:f16:1x4096+wei:per_oc:f16 --attr-zero-points=wei:per_oc:u8 1032x1x4096:1x4096x27392 | 142.770000 | 1.555830 | 91.76 |
| --dt=u8:u8:f16 --stag=abc --wtag=cab --dtag=abc --attr-scales=src:per_ocic:f16:1x4096+wei:per_oc:f16 --attr-zero-points=src:per_ocic:u8:1x4096+wei:per_oc:u8 1032x1x4096:1x4096x27392 | 142.558000 | 1.573440 | 90.60 |
| --dt=u8:u8:f16 --stag=abc --wtag=cab --dtag=abc --attr-scales=src:per_ocic:f16:1x3072+wei:per_oc:f16 --attr-zero-points=src:per_ocic:u8:1x3072+wei:per_oc:u8 2056x1x3072:1x3072x24576 | 160.054000 | 2.153330 | 74.33 |
| --dt=u8:u8:f16 --stag=abc --wtag=cab --dtag=abc --attr-scales=src:per_ocic:f16:1x4096+wei:per_oc:f16 --attr-zero-points=src:per_ocic:u8:1x4096+wei:per_oc:u8 2176x1x4096:1x4096x14336 | 113.579000 | 1.555000 | 73.04 |
| --dt=s8:u8:f16 --stag=abc --wtag=cab --dtag=abc --attr-scales=src:per_ocic:f16:1x3072+wei:per_oc:f16 --attr-zero-points=wei:per_oc:u8 2056x1x3072:1x3072x24576 | 156.052000 | 2.144480 | 72.77 |
| --dt=u8:u8:f16 --stag=abc --wtag=cab --dtag=abc --attr-scales=src:per_ocic:f16:1x4096+wei:per_oc:f16 --attr-zero-points=src:per_ocic:u8:1x4096+wei:per_oc:u8 2048x1x4096:1x4096x14336 | 106.613000 | 1.468330 | 72.61 |
| --dt=u8:u8:f16 --stag=abc --wtag=cab --dtag=abc --attr-scales=src:per_ocic:f16:1x3584+wei:per_oc:f16 --attr-zero-points=src:per_ocic:u8:1x3584+wei:per_oc:u8 2056x1x3584:1x3584x18944 | 132.195000 | 1.826980 | 72.36 |
| --dt=s8:u8:f16 --stag=abc --wtag=cab --dtag=abc --bia-dt=f16 --bia_mask=4 --attr-scales=src:per_ocic:f16:1x10240+wei:per_oc:f16 --attr-zero-points=wei:per_oc:u8 2048x1x10240:1x10240x2560 | 44.543200 | 0.615729 | 72.34 |
| --dt=u8:u8:f16 --stag=abc --wtag=cab --dtag=abc --bia-dt=f16 --bia_mask=4 --attr-scales=src:per_ocic:f16:1x10240+wei:per_oc:f16 --attr-zero-points=src:per_ocic:u8:1x10240+wei:per_oc:u8 2048x1x10240:1x10240x2560 | 44.692100 | 0.620625 | 72.01 |
| --dt=s8:u8:f16 --stag=abc --wtag=cab --dtag=abc --attr-scales=src:per_ocic:f16:1x4096+wei:per_oc:f16 --attr-zero-points=wei:per_oc:u8 2176x1x4096:1x4096x14336 | 109.409000 | 1.533230 | 71.36 |
| --dt=s8:u8:f16 --stag=abc --wtag=cab --dtag=abc --attr-scales=src:per_ocic:f16:1x3584+wei:per_oc:f16 --attr-zero-points=wei:per_oc:u8 2056x1x3584:1x3584x18944 | 128.134000 | 1.801460 | 71.13 |
| --dt=u8:u8:f16 --stag=abc --wtag=cab --dtag=abc --attr-scales=src:per_ocic:f16:1x3072+wei:per_oc:f16 --attr-zero-points=src:per_ocic:u8:1x3072+wei:per_oc:u8 1032x1x3072:1x3072x24576 | 79.856800 | 1.123750 | 71.06 |
| --dt=s8:u8:f16 --stag=abc --wtag=cab --dtag=abc --attr-scales=src:per_ocic:f16:1x4096+wei:per_oc:f16 --attr-zero-points=wei:per_oc:u8 2048x1x4096:1x4096x14336 | 102.968000 | 1.450310 | 71.00 |
| --dt=s8:u8:f16 --stag=abc --wtag=cab --dtag=abc --attr-scales=src:per_ocic:f16:1x3072+wei:per_oc:f16 --attr-zero-points=wei:per_oc:u8 1032x1x3072:1x3072x24576 | 78.211000 | 1.121770 | 69.72 |
| --dt=u8:u8:f16 --stag=abc --wtag=cab --dtag=abc --attr-scales=src:per_ocic:f16:1x3584+wei:per_oc:f16 --attr-zero-points=src:per_ocic:u8:1x3584+wei:per_oc:u8 1032x1x3584:1x3584x18944 | 66.300300 | 0.959479 | 69.10 |
| --dt=u8:u8:f16 --stag=abc --wtag=cab --dtag=abc --attr-scales=src:per_ocic:f16:1x4096+wei:per_oc:f16 --attr-zero-points=src:per_ocic:u8:1x4096+wei:per_oc:u8 1088x1x4096:1x4096x14336 | 56.294000 | 0.818541 | 68.77 |
| --dt=s8:u8:f16 --stag=abc --wtag=cab --dtag=abc --attr-scales=src:per_ocic:f16:1x3584+wei:per_oc:f16 --attr-zero-points=wei:per_oc:u8 1032x1x3584:1x3584x18944 | 64.187200 | 0.948541 | 67.67 |
| --dt=s8:u8:f16 --stag=abc --wtag=cab --dtag=abc --attr-scales=src:per_ocic:f16:1x4096+wei:per_oc:f16 --attr-zero-points=wei:per_oc:u8 1088x1x4096:1x4096x14336 | 54.688400 | 0.808333 | 67.66 |
| --dt=u8:u8:f16 --stag=abc --wtag=cab --dtag=abc --attr-scales=src:per_ocic:f16:1x4096+wei:per_oc:f16 --attr-zero-points=src:per_ocic:u8:1x4096+wei:per_oc:u8 2128x1x4096:1x4096x11008 | 80.824800 | 1.198960 | 67.41 |
| --dt=s8:u8:f16 --stag=abc --wtag=cab --dtag=abc --attr-scales=src:per_ocic:f16:1x4096+wei:per_oc:f16 --attr-zero-points=wei:per_oc:u8 2128x1x4096:1x4096x11008 | 79.580300 | 1.183650 | 67.23 |
| --dt=u8:u8:f16 --stag=abc --wtag=cab --dtag=abc --attr-scales=src:per_ocic:f16:1x4096+wei:per_oc:f16 --attr-zero-points=src:per_ocic:u8:1x4096+wei:per_oc:u8 2176x1x4096:1x4096x6144 | 46.123100 | 0.688125 | 67.03 |
| --dt=u8:u8:f16 --stag=abc --wtag=cab --dtag=abc --attr-scales=src:per_ocic:f16:1x4096+wei:per_oc:f16 --attr-zero-points=src:per_ocic:u8:1x4096+wei:per_oc:u8 2208x1x4096:1x4096x11008 | 83.881500 | 1.255100 | 66.83 |
| --dt=s8:u8:f16 --stag=abc --wtag=cab --dtag=abc --attr-scales=src:per_ocic:f16:1x4096+wei:per_oc:f16 --attr-zero-points=wei:per_oc:u8 2208x1x4096:1x4096x11008 | 82.528400 | 1.239270 | 66.59 |
| --dt=s8:u8:f16 --stag=abc --wtag=cab --dtag=abc --attr-scales=src:per_ocic:f16:1x4096+wei:per_oc:f16 --attr-zero-points=wei:per_oc:u8 2176x1x4096:1x4096x6144 | 45.410300 | 0.682708 | 66.51 |
| --dt=u8:u8:f16 --stag=abc --wtag=cab --dtag=abc --attr-scales=src:per_ocic:f16:1x4096+wei:per_oc:f16 --attr-zero-points=src:per_ocic:u8:1x4096+wei:per_oc:u8 2128x1x4096:1x4096x12288 | 90.266400 | 1.360000 | 66.37 |
| --dt=s8:u8:f16 --stag=abc --wtag=cab --dtag=abc --attr-scales=src:per_ocic:f16:1x4096+wei:per_oc:f16 --attr-zero-points=wei:per_oc:u8 2128x1x4096:1x4096x12288 | 88.814100 | 1.343960 | 66.08 |
| --dt=u8:u8:f16 --stag=abc --wtag=cab --dtag=abc --attr-scales=src:per_ocic:f16:1x4096+wei:per_oc:f16 --attr-zero-points=src:per_ocic:u8:1x4096+wei:per_oc:u8 2048x1x4096:1x4096x6144 | 43.423600 | 0.657916 | 66.00 |
| --dt=u8:u8:f16 --stag=abc --wtag=cab --dtag=abc --attr-scales=src:per_ocic:f16:1x4096+wei:per_oc:f16 --attr-zero-points=src:per_ocic:u8:1x4096+wei:per_oc:u8 2208x1x4096:1x4096x12288 | 93.630100 | 1.419060 | 65.98 |
| --dt=u8:u8:f16 --stag=abc --wtag=cab --dtag=abc --attr-scales=src:per_ocic:f16:1x4096+wei:per_oc:f16 --attr-zero-points=src:per_ocic:u8:1x4096+wei:per_oc:u8 1112x1x4096:1x4096x12288 | 47.162600 | 0.715937 | 65.88 |
| --dt=u8:u8:f16 --stag=abc --wtag=cab --dtag=abc --attr-scales=src:per_ocic:f16:1x4096+wei:per_oc:f16 --attr-zero-points=src:per_ocic:u8:1x4096+wei:per_oc:u8 2176x1x4096:1x4096x4096 | 30.755800 | 0.466979 | 65.86 |
| --dt=s8:u8:f16 --stag=abc --wtag=cab --dtag=abc --attr-scales=src:per_ocic:f16:1x4096+wei:per_oc:f16 --attr-zero-points=wei:per_oc:u8 2208x1x4096:1x4096x12288 | 92.124400 | 1.401870 | 65.72 |
| --dt=u8:u8:f16 --stag=abc --wtag=cab --dtag=abc --attr-scales=src:per_ocic:f16:1x4096+wei:per_oc:f16 --attr-zero-points=src:per_ocic:u8:1x4096+wei:per_oc:u8 1112x1x4096:1x4096x11008 | 42.259200 | 0.643125 | 65.71 |
| --dt=u8:u8:f16 --stag=abc --wtag=cab --dtag=abc --attr-scales=src:per_ocic:f16:1x4096+wei:per_oc:f16 --attr-zero-points=src:per_ocic:u8:1x4096+wei:per_oc:u8 1088x1x4096:1x4096x4096 | 15.384900 | 0.234270 | 65.67 |
| --dt=s8:u8:f16 --stag=abc --wtag=cab --dtag=abc --attr-scales=src:per_ocic:f16:1x4096+wei:per_oc:f16 --attr-zero-points=wei:per_oc:u8 2176x1x4096:1x4096x4096 | 30.254500 | 0.461354 | 65.58 |
| --dt=s8:u8:f16 --stag=abc --wtag=cab --dtag=abc --attr-scales=src:per_ocic:f16:1x4096+wei:per_oc:f16 --attr-zero-points=wei:per_oc:u8 1112x1x4096:1x4096x11008 | 41.585200 | 0.634791 | 65.51 |
| --dt=s8:u8:f16 --stag=abc --wtag=cab --dtag=abc --attr-scales=src:per_ocic:f16:1x4096+wei:per_oc:f16 --attr-zero-points=wei:per_oc:u8 1088x1x4096:1x4096x4096 | 15.114500 | 0.230729 | 65.51 |
@Simonsays095 Here are some impacted layers on BMG:
Thanks, Kealan. And just to make sure - no regressions?
@Simonsays095 correct, verified on PVC and BMG, I'll run a full set when CI is back up for thoroughness too.
make test set test_scope=NIGHTLY disable test_device_cpu disable benchdnn_all enable benchdnn_matmul
make test perf-gpu set primitive=matmul ip
make test set test_scope=NIGHTLY disable test_device_cpu disable benchdnn_all enable benchdnn_matmul