Some changes to the `inner-padding` option
Summary:
This diff modifies the padding option and adds tests with compile:
- For a `scaled_mm` of shape MxKxN, the current `inner_padding` option only pads the `K` dimension. However, if `N` is not divisible by 16, we also get the error:
  ```
  E RuntimeError: CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling `cublasLtMatmulAlgoGetHeuristic( ltHandle, computeDesc.descriptor(), Adesc.descriptor(), Bdesc.descriptor(), Cdesc.descriptor(), Ddesc.descriptor(), preference.descriptor(), 1, &heuristicResult, &returnedResult)`
  ```
  So this diff modifies the `pad_inner` option to also pad the `N` dimension.
- The compile of inner-padding only works with the triton PR https://github.com/triton-lang/triton/pull/4222. Before that PR, the inductor code-gen kernel fails at:
  ```
  tmp10 = tl.where(tmp6, tmp8, tmp9)
  TypeError: unexpected type fp8e5 and fp8e5
  ```
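The padding rule from the first bullet can be sketched in plain Python. The helper names below are illustrative, not the actual torchao API: for an (M, K) x (K, N) `scaled_mm`, both the inner `K` dimension and the `N` dimension get rounded up to multiples of 16.

```python
# Sketch of the padding rule described above (helper names are illustrative,
# not the actual torchao API): for an (M, K) x (K, N) scaled_mm, both the
# inner K dimension and the N dimension are rounded up to multiples of 16.
def pad_amount(dim: int, multiple: int = 16) -> int:
    """Extra elements needed so `dim` becomes a multiple of `multiple`."""
    return (-dim) % multiple

def padded_mm_shapes(m: int, k: int, n: int):
    """Operand shapes after padding the K and N dimensions; the padded
    tensors themselves would be built with e.g. torch.nn.functional.pad."""
    k_p = k + pad_amount(k)
    n_p = n + pad_amount(n)
    return (m, k_p), (k_p, n_p)

# K=100 -> 112 and N=50 -> 64; M is left untouched.
print(padded_mm_shapes(64, 100, 50))  # ((64, 112), (112, 64))
```

Before this diff, only `k_p` was computed; the cuBLAS error above comes from leaving `n` unpadded.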
Reviewed By: irobert0126
Differential Revision: D62003827
Helpful Links
- See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/858
1 New Failure
As of commit 9b6212f83e950ce070b6278cd5233a6ca7b2dac7 with merge base 8aa6533ae08d97664105b03d0e93bcbacd50da0b, the following job has failed:
- Run Float8 Tests / test (SM-89, linux.g6.4xlarge.experimental.nvidia.gpu, --pre torch --index-url https://download.p... / linux-job (gh)
  ```
  RuntimeError: Command docker exec -t e50b322cfcf1cd127d433e98a31e5aaa2b7eafbb18d4fc75996fda16486b8df4 /exec failed with exit code 1
  ```
This pull request was exported from Phabricator. Differential Revision: D62003827
Overall looks great! Can we have two more things:
- update https://github.com/pytorch/ao/blob/main/benchmarks/float8/bench_padding.py and share data on how this PR impacts performance
- mark this PR as BC-breaking and add the before and after state to the PR summary
About the test failures:
- Some `aot_eager` tests failed due to:
  ```
  E RecursionError: maximum recursion depth exceeded while calling a Python object
  /opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py:309: RecursionError
  _ test_aot_eager[dtype2-True-ScalingType.DYNAMIC-ScalingType.DELAYED-ScalingType.STATIC-True-True] _
  ```
  From the error traces, it keeps calling the functions below until the maximum recursion depth is exceeded:
  ```
  /opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/backends/common.py:51: in _wrapped_bw_compiler
      return disable(disable(bw_compiler)(*args, **kwargs))
  /opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py:632: in _fn
      return fn(*args, **kwargs)
  /opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/backends/common.py:51: in _wrapped_bw_compiler
      return disable(disable(bw_compiler)(*args, **kwargs))
  /opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py:632: in _fn
      return fn(*args, **kwargs)
  /opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/backends/common.py:51: in _wrapped_bw_compiler
      return disable(disable(bw_compiler)(*args, **kwargs))
  /opt/conda/envs/venv/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py:632: in _fn
  ```
  I'm still debugging this error.
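The looping pattern in that trace looks like a wrapper that re-wraps itself on every call. A toy reproduction of that shape (this is NOT the actual dynamo code, just an illustration of the suspected failure mode):

```python
# Toy reproduction of the looping pattern in the trace above. This is NOT
# the actual dynamo code, just the general shape of the suspected bug: a
# wrapper that re-wraps the already-wrapped function on every call.
def make_wrapper(fn):
    def wrapped(*args, **kwargs):
        # Wrapping `wrapped` itself (instead of calling through to `fn`)
        # adds another layer per invocation, until the recursion limit is
        # hit; compare `disable(disable(bw_compiler)(*args, **kwargs))`.
        return make_wrapper(wrapped)(*args, **kwargs)
    return wrapped

compiler = make_wrapper(lambda gm: gm)
try:
    compiler(None)
except RecursionError as e:
    print(type(e).__name__)  # RecursionError
```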
- `test_inductor` tests failed at:
  ```
  tmp10 = tl.where(tmp6, tmp8, tmp9)
  TypeError: unexpected type fp8e5 and fp8e5
  ```
  This is expected: the compile of inner-padding only works if the triton package contains https://github.com/triton-lang/triton/pull/4222.
- How should we add such tests in GitHub CI?
cc @vkuzo do you have any suggestions? thanks!
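One possible way to gate these tests in CI would be a version-based skip. Everything below is a sketch, not an existing torchao helper, and the `(3, 2)` version floor is an assumption (the PR is not in 3.1.x per the compare link in this discussion), so the first release that actually ships the fix should be verified:

```python
import importlib.metadata
import unittest

def triton_has_pr_4222() -> bool:
    """Best-effort gate: True if the installed triton should contain
    triton-lang/triton#4222. The (3, 2) floor is an assumption and should
    be replaced with the first release that actually ships the fix."""
    try:
        major, minor = importlib.metadata.version("triton").split(".")[:2]
        return (int(major), int(minor)) >= (3, 2)
    except Exception:
        # triton not installed, or a dev version string we can't parse
        return False

class TestInnerPaddingCompile(unittest.TestCase):
    @unittest.skipUnless(triton_has_pr_4222(),
                         "compile of inner-padding needs triton-lang/triton#4222")
    def test_compile_with_inner_padding(self):
        # Placeholder body: the real test would torch.compile a scaled_mm
        # with inner-padding enabled and check numerics.
        pass
```

An alternative is pytest's `skipif` marker with the same predicate, which matches the style of the existing float8 tests.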
> It's expected - the compile of inner-padding only works if the triton package contains this PR https://github.com/triton-lang/triton/pull/4222
Can you check if that PR is in triton 3.1.0? https://github.com/pytorch/pytorch/blob/main/.ci/docker/triton_version.txt is the current triton version in OSS PyTorch.
@vkuzo, it's not included in triton 3.1. The PR (#4222) shows up in the diff between 3.1.x and main: https://github.com/triton-lang/triton/compare/release/3.1.x...main