`vmap` with `scatter_add` extremely slow when using `xla_gpu_deterministic_ops`
Description
This issue is about 1) a significant slow-down of the scatter_add operation when running JAX with the xla_gpu_deterministic_ops=true flag, and 2) a further, disproportionately large slow-down when wrapping the scatter_add operation in vmap.
Below is the code to reproduce the issue. The timings are run with and without prepending os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true" at the start of the script.
Firstly, just a regular scatter_add benchmark:
import jax
import jax.numpy as jnp
def scatter_add(
    operand,  # [operand_size]
    updates,  # [updates_size]
    indices,  # [updates_size, 1]
):
    # Define dimension numbers
    update_window_dims = tuple()
    inserted_window_dims = (0,)
    scatter_dims_to_operand_dims = (0,)
    res = jax.lax.scatter_add(
        operand,
        indices,
        updates,
        dimension_numbers=jax.lax.ScatterDimensionNumbers(
            update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims
        ),
        mode="drop",
    )
    return res
operand_size = 64 * 64 # e.g. a 64x64 image
operand = jnp.zeros((operand_size,))
updates = jnp.ones((operand_size * 4))
rng = jax.random.PRNGKey(0)
indices = jax.random.randint(rng, shape=(operand_size * 4, 1), minval=0, maxval=operand_size)
scatter_add_jit = jax.jit(scatter_add)
scatter_add_jit(operand, updates, indices).block_until_ready()
%timeit scatter_add_jit(operand, updates, indices).block_until_ready()
# Without: os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
>>> 25.3 µs ± 81 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
# With: os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
>>> 46.1 ms ± 4.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Secondly, a scatter_add benchmark with vmap:
n_batches = 100
operand = jnp.zeros((n_batches, operand_size,))
updates = jnp.ones((n_batches, operand_size * 4))
rng = jax.random.PRNGKey(0)
indices = jax.random.randint(rng, shape=(n_batches, operand_size * 4, 1), minval=0, maxval=operand_size)
scatter_add_batched = jax.vmap(scatter_add, in_axes=(0, 0, 0), out_axes=0)
scatter_add_batched_jit = jax.jit(scatter_add_batched)
scatter_add_batched_jit(operand, updates, indices).block_until_ready()
%timeit scatter_add_batched_jit(operand, updates, indices).block_until_ready()
# Without: os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
>>> 79.7 µs ± 173 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
# With: os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
>>> 17.4 s ± 61.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
It seems pretty unexpected that, when the xla_gpu_deterministic_ops flag is set to true, calling scatter_add under vmap with a batch size of 100 makes the runtime 377x longer than a single unbatched call (17.4 s vs. 46.1 ms), i.e. roughly 3.8x slower than just using a manual Python for-loop over the batch.
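For reference, the manual for-loop baseline referred to here is essentially the one shown further down in the thread (a sketch, using the batched arrays defined above):
# Manual Python loop over the batch dimension, calling the unbatched jitted function.
%timeit [scatter_add_jit(operand[i], updates[i], indices[i]).block_until_ready() for i in range(n_batches)]
# With --xla_gpu_deterministic_ops=true this takes roughly 100 x 46.1 ms ≈ 4.6 s per loop.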
Separately, although some slow-down of scatter_add is to be expected when enforcing determinism, it is rather severe (almost 2000x slower without vmap, and over 200,000x slower with vmap).
I guess this operation doesn't come up very regularly, but it appears, for example, in the backward pass through a bilinear interpolation of an image (e.g. when using jax.scipy.ndimage.map_coordinates). Even if the vmap issue gets resolved, it would be fantastic if, in addition, some kind of warning about the potential runtime impact were shown when executing code with --xla_gpu_deterministic_ops=true.
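For context, here is a minimal sketch (with illustrative shapes, not taken from the original report) of how such a scatter-add arises when differentiating a bilinear interpolation:
import jax
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates

image = jnp.ones((64, 64))
# Query the image at 128x128 sub-pixel locations (illustrative shapes).
grid = jnp.meshgrid(jnp.linspace(0, 63, 128), jnp.linspace(0, 63, 128), indexing="ij")
coords = jnp.stack(grid)

def interp_sum(img):
    return map_coordinates(img, coords, order=1).sum()

# The backward pass scatters gradient contributions back onto the image grid,
# which lowers to the scatter-add affected by --xla_gpu_deterministic_ops.
grad_image = jax.grad(interp_sum)(image)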
What jax/jaxlib version are you using?
jax v0.4.16, jaxlib v0.4.16+cuda12.cudnn89
Which accelerator(s) are you using?
GPU
Additional system info
Linux, Ubuntu 22.04.3 LTS, Python 3.11.3
NVIDIA GPU info
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02 Driver Version: 530.30.02 CUDA Version: 12.1 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA GeForce RTX 4090 On | 00000000:01:00.0 Off | Off |
| 30% 43C P2 139W / 500W| 20400MiB / 24564MiB | 100% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
Reproduced on a 3090 as well.
At least at the moment I think this is expected: deterministic scatters are much slower on GPU because they eliminate any parallelism. XLA would need to emit different code for a faster deterministic scatter.
@hawkinsp thanks for the response, it absolutely makes sense that deterministic scatters should be slower.
Do you think it's expected that you should get an additional slow-down from vmap? I.e. that this:
scatter_add_batched = jax.vmap(scatter_add, in_axes=(0, 0, 0), out_axes=0)
scatter_add_batched_jit = jax.jit(scatter_add_batched)
scatter_add_batched_jit(operand, updates, indices).block_until_ready()
%timeit scatter_add_batched_jit(operand, updates, indices).block_until_ready()
>>> 17.4 s
is roughly 3.8x as slow as this:
scatter_add_jit = jax.jit(scatter_add)
scatter_add_jit(operand, updates, indices).block_until_ready()
%timeit [scatter_add_jit(operand[i], updates[i], indices[i]).block_until_ready() for i in range(len(operand))]
>>> 4.6 s
with os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true" enabled? I couldn't think of a reason this should be the case; both are deterministic, and I'd think compiling with XLA should be at least as fast as doing the loop in Python.
I encountered what I'm fairly confident is the same vmap-related slowdown on TPU. I profiled it and discovered that, although vmap of my function produces a scatter in the jaxpr, that op lowers in XLA to a while loop over the mapped axis, with the body of the while containing a dynamic-update-slice rather than a scatter. Presumably, since XLA while loops cannot in general be parallelized, the compiler is unable to see this potential optimization. I don't know whether this issue is common to the lowering of any 'ragged' scatter.
@BrunoKM I will try your Python loop workaround and see whether that improves my use-case for now.
Setup:
In [1]: from jax import jit, lax, vmap, make_jaxpr
In [2]: import jax.numpy as jnp
In [3]: operand = jnp.ones((3, 4, 5))
In [4]: updates = jnp.ones((3, 2, 5))
In [5]: starts = jnp.ones((3,), dtype='int32')
In [6]: from functools import partial
In [7]: f = partial(lax.dynamic_update_slice_in_dim, axis=0)
Printing the jaxpr, note there is a single scatter op:
In [8]: make_jaxpr(vmap(f))(operand, updates, starts)
Out[8]:
{ lambda ; a:f32[3,4,5] b:f32[3,2,5] c:i32[3]. let
d:bool[3] = lt c 0
e:i32[3] = add c 4
f:i32[3] = select_n d c e
g:i32[] = add 0 5
h:i32[] = select_n False 0 g
i:i32[3,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(3, 1)] f
j:i32[3,1] = broadcast_in_dim[broadcast_dimensions=() shape=(3, 1)] h
k:i32[3,2] = concatenate[dimension=1] i j
l:i32[3,1] = iota[dimension=0 dtype=int32 shape=(3, 1)]
m:i32[3,3] = concatenate[dimension=1] l k
n:f32[3,4,5] = scatter[
dimension_numbers=ScatterDimensionNumbers(update_window_dims=(1, 2), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0, 1, 2))
indices_are_sorted=True
mode=GatherScatterMode.CLIP
unique_indices=True
update_consts=()
update_jaxpr=None
] a m b
in (n,) }
Printing the compiled HLO; because it is long and hard to read, I've surrounded the relevant parts with ########:
In [9]: print(jit(vmap(f)).lower(operand, updates, starts).compile().as_text())
HloModule jit__unnamed_function_, entry_computation_layout={(f32[3,4,5]{2,1,0}, f32[3,2,5]{2,1,0}, s32[3]{0})->f32[3,4,5]{2,1,0}}, allow_spmd_sharding_propagation_to_output={true}
#############################################################
%fused_computation (param_0: f32[3,4,5], param_1.3: f32[3,2,5], param_2.7: s32[], param_3.8: pred[], param_4.8: s32[3,3]) -> f32[3,4,5] {
%param_0 = f32[3,4,5]{2,1,0} parameter(0)
%param_3.8 = pred[] parameter(3)
%broadcast.18 = pred[1,2,5]{2,1,0} broadcast(pred[] %param_3.8), dimensions={}
%param_1.3 = f32[3,2,5]{2,1,0} parameter(1)
%param_2.7 = s32[] parameter(2)
%constant.23 = s32[] constant(0)
%dynamic-slice.7 = f32[1,2,5]{2,1,0} dynamic-slice(f32[3,2,5]{2,1,0} %param_1.3, s32[] %param_2.7, s32[] %constant.23, s32[] %constant.23), dynamic_slice_sizes={1,2,5}
%param_4.8 = s32[3,3]{1,0} parameter(4)
%dynamic-slice.8 = s32[1,3]{1,0} dynamic-slice(s32[3,3]{1,0} %param_4.8, s32[] %param_2.7, s32[] %constant.23), dynamic_slice_sizes={1,3}
%slice.23 = s32[1,1]{1,0} slice(s32[1,3]{1,0} %dynamic-slice.8), slice={[0:1], [0:1]}
%bitcast.6 = s32[] bitcast(s32[1,1]{1,0} %slice.23)
%bitcast.7 = s32[3]{0} bitcast(s32[1,3]{1,0} %dynamic-slice.8)
%slice.22 = s32[1]{0} slice(s32[3]{0} %bitcast.7), slice={[1:2]}
%bitcast.5 = s32[] bitcast(s32[1]{0} %slice.22)
%dynamic-slice.6 = f32[1,2,5]{2,1,0} dynamic-slice(f32[3,4,5]{2,1,0} %param_0, s32[] %bitcast.6, s32[] %bitcast.5, s32[] %constant.23), dynamic_slice_sizes={1,2,5}
%select.1 = f32[1,2,5]{2,1,0} select(pred[1,2,5]{2,1,0} %broadcast.18, f32[1,2,5]{2,1,0} %dynamic-slice.7, f32[1,2,5]{2,1,0} %dynamic-slice.6)
###########################################################
# Note the shape of the update array, it is 1 in the batch dimension
ROOT %dynamic-update-slice.2 = f32[3,4,5]{2,1,0} dynamic-update-slice(f32[3,4,5]{2,1,0} %param_0, f32[1,2,5]{2,1,0} %select.1, s32[] %bitcast.6, s32[] %bitcast.5, s32[] %constant.23)
###########################################################
}
#############################################################
%and.reduce_sub_computation (lhs: pred[], rhs: pred[]) -> pred[] {
%lhs = pred[] parameter(0)
%rhs = pred[] parameter(1)
ROOT %and = pred[] and(pred[] %lhs, pred[] %rhs)
}
%fused_computation.1 (param_0.4: s32[3]) -> pred[] {
%constant.26 = s32[] constant(0)
%broadcast.19 = s32[3]{0} broadcast(s32[] %constant.26), dimensions={}
%param_0.4 = s32[3]{0} parameter(0)
%compare.4 = pred[3]{0} compare(s32[3]{0} %broadcast.19, s32[3]{0} %param_0.4), direction=LE
%constant.25 = s32[3]{0} constant({2, 2, 0})
%compare.3 = pred[3]{0} compare(s32[3]{0} %constant.25, s32[3]{0} %param_0.4), direction=GE
%and.2 = pred[3]{0} and(pred[3]{0} %compare.4, pred[3]{0} %compare.3)
%constant.24 = pred[] constant(true)
ROOT %reduce.1 = pred[] reduce(pred[3]{0} %and.2, pred[] %constant.24), dimensions={0}, to_apply=%and.reduce_sub_computation
}
%fused_computation.2 (param_0.7: s32[3,3], param_1.15: s32[]) -> s32[3] {
%param_0.7 = s32[3,3]{1,0} parameter(0)
%param_1.15 = s32[] parameter(1)
%constant.27 = s32[] constant(0)
%dynamic-slice.9 = s32[1,3]{1,0} dynamic-slice(s32[3,3]{1,0} %param_0.7, s32[] %param_1.15, s32[] %constant.27), dynamic_slice_sizes={1,3}
%slice.25 = s32[1,1]{1,0} slice(s32[1,3]{1,0} %dynamic-slice.9), slice={[0:1], [0:1]}
%bitcast.9 = s32[1]{0} bitcast(s32[1,1]{1,0} %slice.25)
%bitcast.8 = s32[3]{0} bitcast(s32[1,3]{1,0} %dynamic-slice.9)
%slice.24 = s32[2]{0} slice(s32[3]{0} %bitcast.8), slice={[1:3]}
ROOT %concatenate.2 = s32[3]{0} concatenate(s32[1]{0} %bitcast.9, s32[2]{0} %slice.24), dimensions={0}
}
#############################################################
%while_body (param.1: (s32[], f32[3,4,5], s32[3,3], f32[3,2,5])) -> (s32[], f32[3,4,5], s32[3,3], f32[3,2,5]) {
%param.1 = (s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) parameter(0)
%get-tuple-element.12 = s32[] get-tuple-element((s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) %param.1), index=0
%copy.3 = s32[] copy(s32[] %get-tuple-element.12)
%constant.10 = s32[] constant(1)
%add = s32[] add(s32[] %copy.3, s32[] %constant.10)
%get-tuple-element.13 = f32[3,4,5]{2,1,0} get-tuple-element((s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) %param.1), index=1
%get-tuple-element.19 = f32[3,2,5]{2,1,0} get-tuple-element((s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) %param.1), index=3
%get-tuple-element.18 = s32[3,3]{1,0} get-tuple-element((s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) %param.1), index=2
%fusion.2 = s32[3]{0} fusion(s32[3,3]{1,0} %get-tuple-element.18, s32[] %copy.3), kind=kLoop, calls=%fused_computation.2
%fusion.1 = pred[] fusion(s32[3]{0} %fusion.2), kind=kLoop, calls=%fused_computation.1
###########################################################
%fusion = f32[3,4,5]{2,1,0} fusion(f32[3,4,5]{2,1,0} %get-tuple-element.13, f32[3,2,5]{2,1,0} %get-tuple-element.19, s32[] %copy.3, pred[] %fusion.1, s32[3,3]{1,0} %get-tuple-element.18), kind=kLoop, calls=%fused_computation
###########################################################
ROOT %tuple.5 = (s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) tuple(s32[] %add, f32[3,4,5]{2,1,0} %fusion, s32[3,3]{1,0} %get-tuple-element.18, f32[3,2,5]{2,1,0} %get-tuple-element.19)
}
#############################################################
%while_cond (param.0: (s32[], f32[3,4,5], s32[3,3], f32[3,2,5])) -> pred[] {
%param.0 = (s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) parameter(0)
%get-tuple-element = s32[] get-tuple-element((s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) %param.0), index=0
%constant.1 = s32[] constant(3)
ROOT %compare = pred[] compare(s32[] %get-tuple-element, s32[] %constant.1), direction=LT
}
%fused_computation.3 (param_0.10: s32[3]) -> s32[3,3] {
%constant.30 = s32[] constant(0)
%broadcast.25 = s32[3,3]{1,0} broadcast(s32[] %constant.30), dimensions={}
%iota.1 = s32[3,1]{1,0} iota(), iota_dimension=0, metadata={op_name="jit(<unnamed function>)/jit(main)/iota[dtype=int32 shape=(3, 1) dimension=0]" source_file="<ipython-input-9-61375df8be79>" source_line=1}
%param_0.10 = s32[3]{0} parameter(0)
%broadcast.24 = s32[3]{0} broadcast(s32[] %constant.30), dimensions={}
%compare.5 = pred[3]{0} compare(s32[3]{0} %param_0.10, s32[3]{0} %broadcast.24), direction=LT, metadata={op_name="jit(<unnamed function>)/jit(main)/lt" source_file="<ipython-input-9-61375df8be79>" source_line=1}
%constant.29 = s32[] constant(4)
%broadcast.23 = s32[3]{0} broadcast(s32[] %constant.29), dimensions={}
%add.1 = s32[3]{0} add(s32[3]{0} %param_0.10, s32[3]{0} %broadcast.23), metadata={op_name="jit(<unnamed function>)/jit(main)/add" source_file="<ipython-input-9-61375df8be79>" source_line=1}
%select.2 = s32[3]{0} select(pred[3]{0} %compare.5, s32[3]{0} %add.1, s32[3]{0} %param_0.10), metadata={op_name="jit(<unnamed function>)/jit(main)/select_n" source_file="<ipython-input-9-61375df8be79>" source_line=1}
%bitcast.10 = s32[3,1]{1,0} bitcast(s32[3]{0} %select.2), metadata={op_name="jit(<unnamed function>)/jit(main)/select_n" source_file="<ipython-input-9-61375df8be79>" source_line=1}
%broadcast.22 = s32[3,1]{1,0} broadcast(s32[] %constant.30), dimensions={}
%concatenate.3 = s32[3,3]{1,0} concatenate(s32[3,1]{1,0} %iota.1, s32[3,1]{1,0} %bitcast.10, s32[3,1]{1,0} %broadcast.22), dimensions={1}, metadata={op_name="jit(<unnamed function>)/jit(main)/concatenate[dimension=1]" source_file="<ipython-input-9-61375df8be79>" source_line=1}
%constant.28 = s32[3]{0} constant({2, 2, 0})
%broadcast.21 = s32[3,3]{1,0} broadcast(s32[3]{0} %constant.28), dimensions={1}, metadata={op_name="jit(<unnamed function>)/jit(main)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=(1,)]" source_file="<ipython-input-9-61375df8be79>" source_line=1}
ROOT %clamp.0 = s32[3,3]{1,0} clamp(s32[3,3]{1,0} %broadcast.25, s32[3,3]{1,0} %concatenate.3, s32[3,3]{1,0} %broadcast.21), metadata={op_name="jit(<unnamed function>)/jit(main)/clamp" source_file="<ipython-input-9-61375df8be79>" source_line=1}
}
#############################################################
ENTRY %main.26 (Arg_0.1: f32[3,4,5], Arg_1.2: f32[3,2,5], Arg_2.3: s32[3]) -> f32[3,4,5] {
%constant.4 = s32[] constant(0)
%copy.8 = s32[] copy(s32[] %constant.4)
%Arg_0.1 = f32[3,4,5]{2,1,0} parameter(0), sharding={replicated}
%copy.7 = f32[3,4,5]{2,1,0} copy(f32[3,4,5]{2,1,0} %Arg_0.1)
%Arg_2.3 = s32[3]{0} parameter(2), sharding={replicated}
%fusion.3 = s32[3,3]{1,0} fusion(s32[3]{0} %Arg_2.3), kind=kLoop, calls=%fused_computation.3, metadata={op_name="jit(<unnamed function>)/jit(main)/clamp" source_file="<ipython-input-9-61375df8be79>" source_line=1}
%Arg_1.2 = f32[3,2,5]{2,1,0} parameter(1), sharding={replicated}
%tuple.3 = (s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) tuple(s32[] %copy.8, f32[3,4,5]{2,1,0} %copy.7, s32[3,3]{1,0} %fusion.3, f32[3,2,5]{2,1,0} %Arg_1.2)
###########################################################
%while = (s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) while((s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) %tuple.3), condition=%while_cond, body=%while_body, metadata={op_name="jit(<unnamed function>)/jit(main)/scatter[dimension_numbers=ScatterDimensionNumbers(update_window_dims=(1, 2), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0, 1, 2)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.CLIP update_jaxpr=None update_consts=()]" source_file="<ipython-input-9-61375df8be79>" source_line=1}
###########################################################
ROOT %get-tuple-element.5 = f32[3,4,5]{2,1,0} get-tuple-element((s32[], f32[3,4,5]{2,1,0}, s32[3,3]{1,0}, f32[3,2,5]{2,1,0}) %while), index=1, metadata={op_name="jit(<unnamed function>)/jit(main)/scatter[dimension_numbers=ScatterDimensionNumbers(update_window_dims=(1, 2), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0, 1, 2)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.CLIP update_jaxpr=None update_consts=()]" source_file="<ipython-input-9-61375df8be79>" source_line=1}
}
#############################################################
Here's a more minimal, self-contained snippet which reproduces the vmap slowdown. On TPU the vmap version is around 2x slower than using a Python loop and jnp.stacking the result.
from timeit import timeit
from jax import jit, lax, vmap
import jax.numpy as jnp
# For f which outputs a single array, this simulates vmap using Python map
pymap = lambda f: lambda *args: jnp.stack(list(map(f, *args)))
operands = jnp.ones((100, 32))
updates = jnp.ones((100, 2))
starts = jnp.ones((100, 1), dtype='int32')
f = lax.dynamic_update_slice
f_vmapped = jit(vmap(f))
f_pymapped = jit(pymap(f))
# Ensure compiled
f_vmapped(operands, updates, starts)
f_pymapped(operands, updates, starts)
t_vmapped = timeit(
    lambda: f_vmapped(operands, updates, starts).block_until_ready(), number=100
) / 100
t_pymapped = timeit(
    lambda: f_pymapped(operands, updates, starts).block_until_ready(), number=100
) / 100
print(f"Time vmap(f): {t_vmapped:.2}s")
print(f"Time pymap(f): {t_pymapped:.2}s")
On a TPU v4-8 VM I get:
Time vmap(f): 0.00088s
Time pymap(f): 0.00036s
Running the script on CPU on my laptop, the Python loop version is instead slower than the vmap version:
Time vmap(f): 1.3e-05s
Time pymap(f): 3.3e-05s
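One quick way to check whether the vmapped op got serialized on a given backend is to look for a while loop in the compiled HLO (a sketch, reusing the arrays and the public lower/compile APIs shown earlier in this thread):
# Sketch: a while op in the compiled HLO indicates the batched update was
# serialized over the mapped axis rather than emitted as a parallel scatter.
hlo_text = f_vmapped.lower(operands, updates, starts).compile().as_text()
print("contains while loop:", " while(" in hlo_text)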
I realize this is an older issue, but one option is to roll your own deterministic scatter_add (using prefix sums):
import jax
import jax.numpy as jp

def add_segment(iv, jt):
    i, v = iv
    j, t = jt
    return j, v * jp.equal(i, j) + t

@jax.jit
def scatter_add_det(operand, updates, indices):
    indices = jp.reshape(indices, updates.shape)
    # Sort the indices and the values by the indices.
    indices, sorted_updates = jax.lax.sort_key_val(indices, updates, dimension=-1)
    # Sum up runs of the same index - the sum for each index will be at the end of each run.
    _, sums = jax.lax.associative_scan(add_segment, (indices, sorted_updates))
    # Produce an array of bools - an element is set if its position
    # is the end of a run of the same index.
    end_of_run = jp.concatenate([jp.not_equal(indices[1:], indices[:-1]), jp.array([True])])
    # Set all positions that are not at the end of a run to an out-of-bounds index.
    indices = jp.where(end_of_run, indices, operand.shape[-1])
    # Now do the scatter-add where we know the (in-bounds) indices are unique.
    # That is still fast on GPUs (no non-determinism from atomics).
    return operand.at[indices].add(sums, mode='drop', unique_indices=True)
This is 5-15x slower than the non-deterministic one (depending on the shapes involved), but at least it's not multiple orders of magnitude slower. It would be nice if XLA could lower to something like this automatically.
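For reference, a hypothetical check against the unbatched operand, updates, and indices from the first benchmark snippet (jp and jnp both alias jax.numpy), comparing the workaround to the non-deterministic scatter_add:
# Hypothetical check against the first benchmark's unbatched arrays; float
# summation order differs, so compare with a tolerance rather than bitwise equality.
out_det = scatter_add_det(operand, updates, indices)
out_ref = scatter_add_jit(operand, updates, indices)
print(jnp.allclose(out_det, out_ref))  # expected: True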
Has this issue been fixed? My test shows that deterministic scatter is currently about 10x slower than the regular scatter, even with vmap.