GLJeff

3 comments by GLJeff

To further clarify: I believe both implementations are wrong, in the sense that neither finds a scaling vector that is independent of the number of states. SM-G-SUM should set: grad_output[:,...

You're using a later version of PyTorch than they did. Just remove the [0].

Is something like DeviceRadixSort even safe to run right now without your fix?