GLJeff
Results: 3 comments of GLJeff
To further clarify: I believe both implementations are wrong in the sense that they are not finding a scaling vector independent of the number of states. SM-G-SUM should set: grad_output[:,...
You're using a later version of PyTorch than they did. Just remove the [0].
Is something like DeviceRadixSort even safe to run right now without your fix?