accelerated-scan
warpscan: try float accumulation for bf16 and float16
@sustcsonglin has suggested that float accumulation might improve the stability of the implementation. The test I'm currently using to check this is:
python -m pytest tests -s -v -k forward | grep 'max abs error'
The results are the same so far; I may need to find better test conditions to see a difference:
tests/test_eq.py::test_eq_forward[dtype0-32-3407-scan] max abs error 7.152557373046875e-07 seqlen 32 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-32-4-scan] max abs error 4.76837158203125e-07 seqlen 32 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-32-42-scan] max abs error 7.152557373046875e-07 seqlen 32 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-32-57-scan] max abs error 7.152557373046875e-07 seqlen 32 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-64-3407-scan] max abs error 7.152557373046875e-07 seqlen 64 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-64-4-scan] max abs error 7.152557373046875e-07 seqlen 64 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-64-42-scan] max abs error 7.152557373046875e-07 seqlen 64 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-64-57-scan] max abs error 7.152557373046875e-07 seqlen 64 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-128-3407-scan] max abs error 9.5367431640625e-07 seqlen 128 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-128-4-scan] max abs error 7.152557373046875e-07 seqlen 128 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-128-42-scan] max abs error 7.152557373046875e-07 seqlen 128 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-128-57-scan] max abs error 9.5367431640625e-07 seqlen 128 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-256-3407-scan] max abs error 7.152557373046875e-07 seqlen 256 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-256-4-scan] max abs error 9.5367431640625e-07 seqlen 256 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-256-42-scan] max abs error 9.5367431640625e-07 seqlen 256 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-256-57-scan] max abs error 7.152557373046875e-07 seqlen 256 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-512-3407-scan] max abs error 7.152557373046875e-07 seqlen 512 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-512-4-scan] max abs error 9.5367431640625e-07 seqlen 512 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-512-42-scan] max abs error 9.5367431640625e-07 seqlen 512 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-512-57-scan] max abs error 7.152557373046875e-07 seqlen 512 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-1024-3407-scan] max abs error 9.5367431640625e-07 seqlen 1024 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-1024-4-scan] max abs error 9.5367431640625e-07 seqlen 1024 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-1024-42-scan] max abs error 9.5367431640625e-07 seqlen 1024 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-1024-57-scan] max abs error 9.5367431640625e-07 seqlen 1024 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-2048-3407-scan] max abs error 9.5367431640625e-07 seqlen 2048 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-2048-4-scan] max abs error 9.5367431640625e-07 seqlen 2048 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-2048-42-scan] max abs error 9.5367431640625e-07 seqlen 2048 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-2048-57-scan] max abs error 9.5367431640625e-07 seqlen 2048 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-4096-3407-scan] max abs error 9.5367431640625e-07 seqlen 4096 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-4096-4-scan] max abs error 9.5367431640625e-07 seqlen 4096 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-4096-42-scan] max abs error 9.5367431640625e-07 seqlen 4096 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-4096-57-scan] max abs error 9.5367431640625e-07 seqlen 4096 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-8192-3407-scan] max abs error 9.5367431640625e-07 seqlen 8192 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-8192-4-scan] max abs error 9.5367431640625e-07 seqlen 8192 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-8192-42-scan] max abs error 9.5367431640625e-07 seqlen 8192 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-8192-57-scan] max abs error 9.5367431640625e-07 seqlen 8192 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-16384-3407-scan] max abs error 9.5367431640625e-07 seqlen 16384 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-16384-4-scan] max abs error 9.5367431640625e-07 seqlen 16384 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-16384-42-scan] max abs error 9.5367431640625e-07 seqlen 16384 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-16384-57-scan] max abs error 9.5367431640625e-07 seqlen 16384 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-32768-3407-scan] max abs error 9.5367431640625e-07 seqlen 32768 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-32768-4-scan] max abs error 9.5367431640625e-07 seqlen 32768 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-32768-42-scan] max abs error 9.5367431640625e-07 seqlen 32768 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-32768-57-scan] max abs error 1.430511474609375e-06 seqlen 32768 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-65536-3407-scan] max abs error 9.5367431640625e-07 seqlen 65536 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-65536-4-scan] max abs error 1.430511474609375e-06 seqlen 65536 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-65536-42-scan] max abs error 9.5367431640625e-07 seqlen 65536 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype0-65536-57-scan] max abs error 1.430511474609375e-06 seqlen 65536 dtype torch.float32
tests/test_eq.py::test_eq_forward[dtype1-32-3407-scan] max abs error 0.03125 seqlen 32 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-32-4-scan] max abs error 0.03125 seqlen 32 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-32-42-scan] max abs error 0.03125 seqlen 32 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-32-57-scan] max abs error 0.03125 seqlen 32 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-64-3407-scan] max abs error 0.03125 seqlen 64 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-64-4-scan] max abs error 0.03125 seqlen 64 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-64-42-scan] max abs error 0.03125 seqlen 64 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-64-57-scan] max abs error 0.03125 seqlen 64 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-128-3407-scan] max abs error 0.03125 seqlen 128 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-128-4-scan] max abs error 0.03125 seqlen 128 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-128-42-scan] max abs error 0.03125 seqlen 128 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-128-57-scan] max abs error 0.03125 seqlen 128 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-256-3407-scan] max abs error 0.03125 seqlen 256 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-256-4-scan] max abs error 0.03125 seqlen 256 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-256-42-scan] max abs error 0.03125 seqlen 256 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-256-57-scan] max abs error 0.03125 seqlen 256 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-512-3407-scan] max abs error 0.046875 seqlen 512 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-512-4-scan] max abs error 0.03125 seqlen 512 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-512-42-scan] max abs error 0.046875 seqlen 512 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-512-57-scan] max abs error 0.03125 seqlen 512 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-1024-3407-scan] max abs error 0.03125 seqlen 1024 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-1024-4-scan] max abs error 0.03125 seqlen 1024 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-1024-42-scan] max abs error 0.03125 seqlen 1024 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-1024-57-scan] max abs error 0.046875 seqlen 1024 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-2048-3407-scan] max abs error 0.03125 seqlen 2048 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-2048-4-scan] max abs error 0.046875 seqlen 2048 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-2048-42-scan] max abs error 0.03125 seqlen 2048 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-2048-57-scan] max abs error 0.046875 seqlen 2048 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-4096-3407-scan] max abs error 0.046875 seqlen 4096 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-4096-4-scan] max abs error 0.046875 seqlen 4096 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-4096-42-scan] max abs error 0.046875 seqlen 4096 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-4096-57-scan] max abs error 0.046875 seqlen 4096 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-8192-3407-scan] max abs error 0.0625 seqlen 8192 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-8192-4-scan] max abs error 0.046875 seqlen 8192 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-8192-42-scan] max abs error 0.046875 seqlen 8192 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-8192-57-scan] max abs error 0.0625 seqlen 8192 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-16384-3407-scan] max abs error 0.0625 seqlen 16384 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-16384-4-scan] max abs error 0.0625 seqlen 16384 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-16384-42-scan] max abs error 0.046875 seqlen 16384 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-16384-57-scan] max abs error 0.0625 seqlen 16384 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-32768-3407-scan] max abs error 0.0625 seqlen 32768 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-32768-4-scan] max abs error 0.0625 seqlen 32768 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-32768-42-scan] max abs error 0.0625 seqlen 32768 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-32768-57-scan] max abs error 0.0625 seqlen 32768 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-65536-3407-scan] max abs error 0.0625 seqlen 65536 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-65536-4-scan] max abs error 0.0625 seqlen 65536 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-65536-42-scan] max abs error 0.0625 seqlen 65536 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype1-65536-57-scan] max abs error 0.0625 seqlen 65536 dtype torch.bfloat16
tests/test_eq.py::test_eq_forward[dtype2-32-3407-scan] max abs error 0.00390625 seqlen 32 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-32-4-scan] max abs error 0.00390625 seqlen 32 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-32-42-scan] max abs error 0.00390625 seqlen 32 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-32-57-scan] max abs error 0.00390625 seqlen 32 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-64-3407-scan] max abs error 0.00390625 seqlen 64 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-64-4-scan] max abs error 0.00390625 seqlen 64 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-64-42-scan] max abs error 0.00390625 seqlen 64 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-64-57-scan] max abs error 0.00390625 seqlen 64 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-128-3407-scan] max abs error 0.00390625 seqlen 128 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-128-4-scan] max abs error 0.00390625 seqlen 128 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-128-42-scan] max abs error 0.00390625 seqlen 128 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-128-57-scan] max abs error 0.005859375 seqlen 128 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-256-3407-scan] max abs error 0.00390625 seqlen 256 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-256-4-scan] max abs error 0.005859375 seqlen 256 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-256-42-scan] max abs error 0.00390625 seqlen 256 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-256-57-scan] max abs error 0.00390625 seqlen 256 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-512-3407-scan] max abs error 0.00390625 seqlen 512 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-512-4-scan] max abs error 0.0078125 seqlen 512 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-512-42-scan] max abs error 0.00390625 seqlen 512 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-512-57-scan] max abs error 0.00390625 seqlen 512 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-1024-3407-scan] max abs error 0.00390625 seqlen 1024 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-1024-4-scan] max abs error 0.005859375 seqlen 1024 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-1024-42-scan] max abs error 0.005859375 seqlen 1024 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-1024-57-scan] max abs error 0.00390625 seqlen 1024 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-2048-3407-scan] max abs error 0.005859375 seqlen 2048 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-2048-4-scan] max abs error 0.005859375 seqlen 2048 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-2048-42-scan] max abs error 0.0078125 seqlen 2048 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-2048-57-scan] max abs error 0.00390625 seqlen 2048 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-4096-3407-scan] max abs error 0.0078125 seqlen 4096 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-4096-4-scan] max abs error 0.005859375 seqlen 4096 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-4096-42-scan] max abs error 0.005859375 seqlen 4096 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-4096-57-scan] max abs error 0.0078125 seqlen 4096 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-8192-3407-scan] max abs error 0.005859375 seqlen 8192 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-8192-4-scan] max abs error 0.0078125 seqlen 8192 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-8192-42-scan] max abs error 0.005859375 seqlen 8192 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-8192-57-scan] max abs error 0.0078125 seqlen 8192 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-16384-3407-scan] max abs error 0.005859375 seqlen 16384 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-16384-4-scan] max abs error 0.0078125 seqlen 16384 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-16384-42-scan] max abs error 0.0078125 seqlen 16384 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-16384-57-scan] max abs error 0.0078125 seqlen 16384 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-32768-3407-scan] max abs error 0.0078125 seqlen 32768 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-32768-4-scan] max abs error 0.0078125 seqlen 32768 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-32768-42-scan] max abs error 0.0078125 seqlen 32768 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-32768-57-scan] max abs error 0.0078125 seqlen 32768 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-65536-3407-scan] max abs error 0.0078125 seqlen 65536 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-65536-4-scan] max abs error 0.0078125 seqlen 65536 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-65536-42-scan] max abs error 0.0078125 seqlen 65536 dtype torch.float16
tests/test_eq.py::test_eq_forward[dtype2-65536-57-scan] max abs error 0.0078125 seqlen 65536 dtype torch.float16
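As a rough illustration of what float accumulation could change, here is a minimal pure-Python sketch, not the warpscan kernel. It assumes the scan computes the first-order recurrence `state[t] = gates[t] * state[t-1] + tokens[t]`, and it emulates rounding with the standard `struct` module (`'e'` is float16, `'f'` is float32). The helper names `round_to`, `scan` and `max_abs_error`, and the input ranges, are hypothetical and chosen only for this example.

```python
import random
import struct


def round_to(fmt, x):
    """Round a Python float to struct format `fmt` ('e' = float16, 'f' = float32)."""
    return struct.unpack(fmt, struct.pack(fmt, x))[0]


def scan(gates, tokens, acc_fmt=None):
    """Sequential scan for state[t] = gates[t] * state[t-1] + tokens[t] (assumed form).

    acc_fmt is the precision the running state is rounded to after every step;
    None keeps the state in float64 and serves as the reference.
    """
    state, out = 0.0, []
    for g, x in zip(gates, tokens):
        state = g * state + x
        if acc_fmt is not None:
            state = round_to(acc_fmt, state)
        out.append(state)
    return out


def max_abs_error(seqlen, seed=0):
    rng = random.Random(seed)
    # Inputs are rounded to float16 first, as if they were stored in half precision.
    # The ranges below are arbitrary choices for this sketch.
    gates = [round_to("e", rng.uniform(0.9, 0.999)) for _ in range(seqlen)]
    tokens = [round_to("e", rng.uniform(-1.0, 1.0)) for _ in range(seqlen)]
    reference = scan(gates, tokens)
    errors = {}
    for name, fmt in (("float16", "e"), ("float32", "f")):
        # The output is stored as float16 in both cases; only the accumulator differs.
        out = [round_to("e", s) for s in scan(gates, tokens, fmt)]
        errors[name] = max(abs(a - b) for a, b in zip(out, reference))
    return errors


errors = max_abs_error(seqlen=4096)
print(errors)
```

In this emulation the float16 accumulator rounds the state at every step, so its rounding errors carry forward through the recurrence, while the float32 accumulator pays the float16 rounding only once on output. Whether the real kernel shows a comparable gap depends on where it already rounds, which may be why the numbers above have not moved.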