__shfl_sync instructions may have wrong member mask
When using WarpScanShfl from warp_scan_shfl.cuh inside a while() loop together with a sub-warp LOGICAL_WARP_THREADS argument (i.e. LOGICAL_WARP_THREADS = 2^n with n < 5), I get many errors like the following from cuda-memcheck --tool synccheck:
========= Barrier error detected. Invalid arguments
========= at 0x000000d0 in __cuda_sm70_shflsync_idx
========= by thread (17,0,0) in block (204,0,0)
========= Device Frame:__cuda_sm70_shflsync_idx (__cuda_sm70_shflsync_idx : 0xd0)
========= Device Frame:/ccsopen/home/glaser/hoomd-blue/hoomd/extern/cub/cub/block/specializations/../../block/../util_ptx.cuh:358:void gpu_compute_nlist_binned_kernel<unsigned char=0, int=1, int=1>(unsigned int*, unsigned int*, double4*, unsigned int*, unsigned int const *, unsigned int const *, double4 const *, unsigned int const *, double const *, unsigned int, unsigned int const *, double4 const *, unsigned int const *, double4 const *, unsigned int const *, Index3D, Index2D, Index2D, BoxDim, double const *, double, unsigned int, double3, unsigned int, unsigned int, unsigned int) (void gpu_compute_nlist_binned_kernel<unsigned char=0, int=1, int=1>(unsigned int*, unsigned int*, double4*, unsigned int*, unsigned int const *, unsigned int const *, double4 const *, unsigned int const *, double const *, unsigned int, unsigned int const *, double4 const *, unsigned int const *, double4 const *, unsigned int const *, Index3D, Index2D, Index2D, BoxDim, double const *, double, unsigned int, double3, unsigned int, unsigned int, unsigned
I believe the root cause is the following.
WarpScanShfl sets its member_mask for the shfl_sync to reflect the sub-warp membership. However, if some threads exit early, the compiler may predicate off this initialization statement, leaving member_mask in an invalid state. Later, when the PTX instruction shfl.sync.idx.b32 is reached, it is executed without a predicate (such as @p) and thus with a wrong mask. cuda-memcheck then finds that the executing warp lane performs an implicit syncwarp without its own mask bit set, and reports an error, as documented here:
https://docs.nvidia.com/cuda/cuda-memcheck/index.html#synccheck-demo-illegal-syncwarp
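To make the mask arithmetic concrete, here is a minimal host-side C++ sketch of the condition described above. It is a model under stated assumptions, not CUB's code: `sub_warp_member_mask` and `lane_in_mask` are hypothetical helper names, the hardware warp is assumed to have 32 lanes, and the invalid mask is modelled as 0 (the real value would be indeterminate). Lane 17 is taken from the log above; the logical width of 4 is an arbitrary choice for illustration.

```cpp
#include <cstdint>

// Hypothetical helper (not a CUB function): mask covering the logical warp
// that contains `lane`. Assumes a 32-lane hardware warp and a power-of-two
// logical width between 1 and 32.
constexpr std::uint32_t sub_warp_member_mask(unsigned lane,
                                             unsigned logical_warp_threads)
{
    return (0xffffffffu >> (32u - logical_warp_threads))
           << ((lane / logical_warp_threads) * logical_warp_threads);
}

// The condition the report attributes to synccheck: the executing lane
// must have its own bit set in the mask it passes to the sync instruction.
constexpr bool lane_in_mask(unsigned lane, std::uint32_t mask)
{
    return ((mask >> lane) & 1u) != 0u;
}

// Lane 17 with a logical width of 4 belongs to lanes 16..19.
static_assert(sub_warp_member_mask(17, 4) == 0x000F0000u,
              "sub-warp mask covers lanes 16..19");

// A properly initialized mask contains the calling lane.
static_assert(lane_in_mask(17, sub_warp_member_mask(17, 4)),
              "initialized mask includes the caller");

// A mask whose initialization was skipped (modelled as 0) does not.
static_assert(!lane_in_mask(17, 0u),
              "uninitialized mask excludes the caller");

// The full mask contains every lane regardless of sub-warp layout.
static_assert(lane_in_mask(17, 0xffffffffu),
              "full mask includes the caller");
```

All checks run at compile time, so the file needs no main() to confirm the arithmetic.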
The safe solution would be to always use the full mask (0xffffffffu) to synchronize the entire warp. I realize this may not fully take advantage of Volta's independent thread scheduling. However, if that were the goal, I think the CUB API would have to expose member_mask so that the user can set it, e.g. by issuing a ballot_sync outside CUB first and then passing the resulting mask to CUB, as discussed here: https://devblogs.nvidia.com/using-cuda-warp-level-primitives/
I will submit a pull request shortly for this solution.
This had a PR (#155). It looks like it was closed during the reorganizations.
We should resurrect the PR, rebase it on main, and take a look; this sounds important.
@jglaser I know this issue is several years old, but do you happen to have some code that reproduces this or a recollection of how this situation occurred for you?
I'll summarize your report to make sure I understand it:
- A partial warp enters WarpScanShfl::InclusiveScan using a LOGICAL_WARP_SIZE less than 32.
- The compiler implements this divergence by predicating the instructions from the warp scan implementation.
- WarpScanShuffle injects inline PTX that calls an unpredicated shfl.sync.up.b32.
- The lanes that /should/ be inactive execute this instruction with an uninitialized member_mask.
Basically, this: https://www.godbolt.org/z/T6jrx75M6
I'm trying to find a repro where the compiler implements the branch with an execution mask (predication) instead of an actual branch so I can test this, but it's not clear how to trigger that.