josipd
Thanks for the fast reply! I observe the same behavior even with `torch.set_deterministic(True)` at the top and `CUBLAS_WORKSPACE_CONFIG=:4096:8`.
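For reference, roughly the setup I'm using (the PyTorch 1.7-era `torch.set_deterministic`; later releases renamed it to `torch.use_deterministic_algorithms`):

```python
import os

# The cuBLAS workspace variable has to be set before CUDA is initialized,
# so it goes before the torch import.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

import torch

torch.set_deterministic(True)  # PyTorch 1.7-era API
```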
By the way, the problem is much worse under PyTorch 1.7.0 (the maximum absolute difference goes up to 2.7236), and, quite interestingly, when I use it for zero-shot classification only...
Btw, the above problem goes away on 1.7.0 with `jit=True` in `clip.load`.
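For anyone hitting this, the workaround looks roughly like the following (model name is just an example):

```python
import torch
import clip

device = "cuda" if torch.cuda.is_available() else "cpu"
# jit=True loads the TorchScript weights, which avoids the discrepancy
# on PyTorch 1.7.0 in my runs.
model, preprocess = clip.load("ViT-B/32", device=device, jit=True)
```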
At the moment there are only CPU implementations of the functions.
Not at the moment, but you could easily use any p-norm kernel (i.e., `exp(-‖x - x'‖_p)`) by modifying the following line: https://github.com/josipd/torch-two-sample/blob/d0771287fa1ba820ad975f1f038bfd8e155d2b91/torch_two_sample/statistics_diff.py#L370. I will submit a patch soon that accepts `norm` as an...
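Until then, a rough sketch of what such a kernel could look like outside the library (the function name and the `alpha` scale are illustrative, not part of torch-two-sample):

```python
import torch

def pnorm_kernel(x, y, p=2.0, alpha=1.0):
    """exp(-alpha * ||x - x'||_p) for all pairs of rows of x and y."""
    # torch.cdist computes pairwise p-norm distances for arbitrary p.
    distances = torch.cdist(x, y, p=p)
    return torch.exp(-alpha * distances)
```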
Thanks for the bug report! Does this also happen for `k=1`?
Hello, the difference is that the soft sort might result in easier optimization problems; see "convexification" in https://arxiv.org/pdf/2002.08871.pdf. Intuitively, while the Jacobian of the normal sort operator is a permutation matrix,...
Yes, it is differentiable almost everywhere, with a Jacobian that is a permutation matrix.
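A quick way to see this (a toy check, not library code): backpropagating through `torch.sort` just permutes the upstream gradient back to the input positions.

```python
import torch

x = torch.tensor([3.0, 1.0, 2.0], requires_grad=True)
sorted_x, _ = torch.sort(x)
# Push a distinguishable cotangent through the sort.
sorted_x.backward(torch.tensor([10.0, 20.0, 30.0]))
print(x.grad)  # tensor([30., 10., 20.]) -- the cotangent, permuted
```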
Hello all, I have coincidentally also been working on this and just pushed the WIP to my fork: https://github.com/josipd/jax/blob/master/jax/experimental/jambax.py. Would it make sense for us to collaborate to avoid duplicate work?...
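For context, the core idea is to compile a Numba `cfunc` against XLA's CPU custom-call convention (an output pointer plus an array of input pointers). A rough, self-contained sketch of such a kernel, with the shape hard-coded for illustration:

```python
from numba import carray, cfunc, types

# XLA's CPU custom-call convention is void fn(void* out, const void** ins).
# Both pointers are typed as doubles here; real code dispatches on dtype
# and passes shapes separately (hard-coded to 4 below for illustration).
sig = types.void(
    types.CPointer(types.float64),
    types.CPointer(types.CPointer(types.float64)),
)

@cfunc(sig)
def add_one(out_ptr, in_ptrs):
    x = carray(in_ptrs[0], (4,))
    out = carray(out_ptr, (4,))
    for i in range(4):
        out[i] = x[i] + 1.0
```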
Hey all, should we agree on the API and get at least the CPU version checked in soon? Looking at the XLA GPU convention, I think what I suggested could...