jax.lax.cond() significantly slower when called inside jax.lax.scan()
I've been using cond() inside my scan() loop and noticed it is much slower than replacing it with cond * true + (1-cond) * false. Outside of scan(), cond(cond,true,false) has the same runtime as cond * true + (1-cond) * false, so I suspect there is an optimization bug when cond() is called inside scan().
An example in Google Colab can be found here: https://colab.research.google.com/drive/1HFYaf-IbRu7ooLGhDvqFY7Bz39OW4bqX
I did verify that _cond2_scan is much slower. @hawkinsp do you know if there are any XLA optimizations that may be relevant here?
Here's the jaxpr:
```
make_jaxpr(_cond1_scan)(True,jnp.zeros((1000,10)),jnp.ones((1000,10)))
{ lambda ; a b c.
  let d = xla_call[ backend=None
        call_jaxpr={ lambda ; a b c.
          let d = scan[ jaxpr={ lambda ; a b c.
                  let d = xla_call[ backend=None
                        call_jaxpr={ lambda ; a b c.
                          let d = convert_element_type[ new_dtype=float32 ] a
                              e = mul d b
                              f = convert_element_type[ new_dtype=int32 ] a
                              g = sub 1 f
                              h = convert_element_type[ new_dtype=float32 ] g
                              i = mul h c
                              j = add e i
                          in (j,) }
                        device=None
                        donated_invars=(False, False, False)
                        name=_cond1 ] a b c
                  in (d,) }
                length=1000
                linear=(False, False, False)
                num_carry=0
                num_consts=1
                reverse=False
                unroll=1 ] a b c
          in (d,) }
        device=None
        donated_invars=(False, False, False)
        name=_cond1_scan ] a b c
  in (d,) }

make_jaxpr(_cond2_scan)(True,jnp.zeros((1000,10)),jnp.ones((1000,10)))
{ lambda ; a b c.
  let d = xla_call[ backend=None
        call_jaxpr={ lambda ; a b c.
          let d = scan[ jaxpr={ lambda ; a b c.
                  let d = xla_call[ backend=None
                        call_jaxpr={ lambda ; a b c.
                          let d = convert_element_type[ new_dtype=int32 ] a
                              e = cond[ branches=( { lambda ; b_ a.
                                    let
                                    in (a,) }
                                  { lambda ; a c_.
                                    let
                                    in (a,) } )
                                  linear=(False, False) ] d b c
                          in (e,) }
                        device=None
                        donated_invars=(False, False, False)
                        name=_cond2 ] a b c
                  in (d,) }
                length=1000
                linear=(False, False, False)
                num_carry=0
                num_consts=1
                reverse=False
                unroll=1 ] a b c
          in (d,) }
        device=None
        donated_invars=(False, False, False)
        name=_cond2_scan ] a b c
  in (d,) }
```
The slowdown in the colab above is on the CPU backend.
The slowdown is proportionally less significant on TPU.
Maybe there's some issue in XLA/CPU?
I'm not actually seeing a big difference on CPU. The significant slowdown was on GPU.
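Since the reports differ by backend, it may help to time things the same way everywhere. Below is a small timing sketch; `time_fn` and `example` are hypothetical names of mine, not anything from the colab. It does a warm-up call so compilation is excluded, and it calls `block_until_ready()` so that asynchronous dispatch does not hide the real runtime.

```python
import time

import jax
import jax.numpy as jnp


def time_fn(fn, *args, repeats=10):
    """Average wall-clock seconds per call, excluding compilation."""
    # Warm-up call: triggers tracing and compilation.
    fn(*args).block_until_ready()
    start = time.perf_counter()
    for _ in range(repeats):
        # Wait for the device to finish before stopping the clock.
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / repeats


# Placeholder function to time; substitute the function under test.
@jax.jit
def example(x):
    return jnp.sin(x) * 2.0


x = jnp.ones((1000, 10))
print("backend:", jax.default_backend())
print("seconds per call:", time_fn(example, x))
```

Passing `_cond1_scan` and `_cond2_scan` with the same arguments in place of `example` would give directly comparable numbers on CPU, GPU, and TPU.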