pytorch
Enable previously disabled flash attention (FA) related operators in unit tests
These tests were disabled under AOTriton V1; AOTriton V2 should fix most of the underlying failures.
The tests pass with:

```bash
PYTORCH_TEST_WITH_ROCM=1 PYTORCH_TESTING_DEVICE_ONLY_FOR="cuda" python test/test_meta.py -k flash_attention -v
PYTORCH_TEST_WITH_ROCM=1 PYTORCH_TESTING_DEVICE_ONLY_FOR="cuda" python test/test_ops.py -k flash_attention -v
PYTORCH_TEST_WITH_ROCM=1 PYTORCH_TESTING_DEVICE_ONLY_FOR="cuda" python test/test_meta.py -k functional_scaled_dot_product_attention_cuda -v
PYTORCH_TEST_WITH_ROCM=1 PYTORCH_TESTING_DEVICE_ONLY_FOR="cuda" python test/test_ops.py -k functional_scaled_dot_product_attention_cuda -v
```
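To rerun all four invocations in one pass, they can be wrapped in a loop. This is a minimal sketch, assuming bash, a ROCm build of PyTorch, and execution from the repository root. The function name `run_rocm_fa_tests` and the `DRY_RUN` switch are hypothetical conveniences and not part of PyTorch's test tooling.

```bash
#!/usr/bin/env bash
# Sketch: run the ROCm flash-attention test invocations in the same order as listed.
# Assumes a ROCm build of PyTorch and a working directory at the repo root.
# DRY_RUN=1 (hypothetical helper switch) prints each command instead of running it.
set -euo pipefail

run_rocm_fa_tests() {
  local filter test_file
  # Outer loop over test filters, inner loop over test files, matching the order above.
  for filter in flash_attention functional_scaled_dot_product_attention_cuda; do
    for test_file in test/test_meta.py test/test_ops.py; do
      local cmd=(python "$test_file" -k "$filter" -v)
      if [[ "${DRY_RUN:-0}" == "1" ]]; then
        echo "${cmd[*]}"
      else
        # Same environment variables as the manual invocations.
        PYTORCH_TEST_WITH_ROCM=1 PYTORCH_TESTING_DEVICE_ONLY_FOR="cuda" "${cmd[@]}"
      fi
    done
  done
}
```

Calling `run_rocm_fa_tests` runs the suites; with `set -e` in effect, the script stops at the first failing invocation. Calling `DRY_RUN=1 run_rocm_fa_tests` only lists the commands.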