jax.lax.cond() significantly slower when called inside jax.lax.scan()

Open sokrypton opened this issue 5 years ago • 4 comments

I've been using `cond()` inside my `scan()` loop and noticed it was much slower than replacing it with `cond * true + (1 - cond) * false`. Outside of `scan()`, `cond(cond, true, false)` has the same runtime as `cond * true + (1 - cond) * false`, so I suspect an optimization bug when `cond()` is called inside `scan()`.

An example in Google Colab can be found here: https://colab.research.google.com/drive/1HFYaf-IbRu7ooLGhDvqFY7Bz39OW4bqX

sokrypton • Mar 09 '21 09:03

I verified that `_cond2_scan` is much slower. @hawkinsp, do you know of any XLA optimizations that might be relevant here?
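To compare the two versions fairly, compilation time and JAX's asynchronous dispatch need to be kept out of the measurement. A minimal timing sketch, where `blend`, `branch`, `make_scan`, and `time_fn` are illustrative helpers (not from the colab): each jitted function is called once to compile, then the median wall-clock time of repeated calls is reported, waiting on `block_until_ready()` each time.

```python
import statistics
import time

import jax
import jax.numpy as jnp
from jax import lax


def blend(pred, t, f):
    # Branchless select: both inputs are weighted by pred (as 0.0 or 1.0).
    p = jnp.asarray(pred, dtype=jnp.float32)
    return p * t + (1 - p) * f


def branch(pred, t, f):
    # Branching select: lax.cond executes one of the two lambdas.
    return lax.cond(pred, lambda ops: ops[0], lambda ops: ops[1], (t, f))


def make_scan(select_fn):
    # Jitted row-by-row scan with pred as a closed-over constant and no carry.
    @jax.jit
    def run(pred, ts, fs):
        def body(carry, x):
            return carry, select_fn(pred, *x)

        return lax.scan(body, None, (ts, fs))[1]

    return run


def time_fn(fn, *args, repeats=20):
    # The first call triggers compilation, so keep it out of the samples.
    fn(*args).block_until_ready()
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        # JAX dispatches asynchronously; wait for the result before stopping.
        fn(*args).block_until_ready()
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


b = jnp.zeros((1000, 10))
c = jnp.ones((1000, 10))
scan_blend = make_scan(blend)
scan_branch = make_scan(branch)
t_blend = time_fn(scan_blend, True, b, c)
t_branch = time_fn(scan_branch, True, b, c)
print(f"backend={jax.default_backend()} "
      f"blend={t_blend * 1e6:.1f} us, cond={t_branch * 1e6:.1f} us")
```

Printing `jax.default_backend()` alongside the timings makes it clear which backend a given measurement came from.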

Here are the jaxprs:

```
make_jaxpr(_cond1_scan)(True,jnp.zeros((1000,10)),jnp.ones((1000,10)))

{ lambda  ; a b c.
  let d = xla_call[ backend=None
                    call_jaxpr={ lambda  ; a b c.
                                 let d = scan[ jaxpr={ lambda  ; a b c.
                                                       let d = xla_call[ backend=None
                                                                         call_jaxpr={ lambda  ; a b c.
                                                                                      let d = convert_element_type[ new_dtype=float32 ] a
                                                                                          e = mul d b
                                                                                          f = convert_element_type[ new_dtype=int32 ] a
                                                                                          g = sub 1 f
                                                                                          h = convert_element_type[ new_dtype=float32 ] g
                                                                                          i = mul h c
                                                                                          j = add e i
                                                                                      in (j,) }
                                                                         device=None
                                                                         donated_invars=(False, False, False)
                                                                         name=_cond1 ] a b c
                                                       in (d,) }
                                               length=1000
                                               linear=(False, False, False)
                                               num_carry=0
                                               num_consts=1
                                               reverse=False
                                               unroll=1 ] a b c
                                 in (d,) }
                    device=None
                    donated_invars=(False, False, False)
                    name=_cond1_scan ] a b c
  in (d,) }
```


```
make_jaxpr(_cond2_scan)(True,jnp.zeros((1000,10)),jnp.ones((1000,10)))

{ lambda  ; a b c.
  let d = xla_call[ backend=None
                    call_jaxpr={ lambda  ; a b c.
                                 let d = scan[ jaxpr={ lambda  ; a b c.
                                                       let d = xla_call[ backend=None
                                                                         call_jaxpr={ lambda  ; a b c.
                                                                                      let d = convert_element_type[ new_dtype=int32 ] a
                                                                                          e = cond[ branches=( { lambda  ; b_ a.
                                                                                                                 let 
                                                                                                                 in (a,) }
                                                                                                               { lambda  ; a c_.
                                                                                                                 let 
                                                                                                                 in (a,) } )
                                                                                                    linear=(False, False) ] d b c
                                                                                      in (e,) }
                                                                         device=None
                                                                         donated_invars=(False, False, False)
                                                                         name=_cond2 ] a b c
                                                       in (d,) }
                                               length=1000
                                               linear=(False, False, False)
                                               num_carry=0
                                               num_consts=1
                                               reverse=False
                                               unroll=1 ] a b c
                                 in (d,) }
                    device=None
                    donated_invars=(False, False, False)
                    name=_cond2_scan ] a b c
  in (d,) }
```
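To check which primitive ends up inside the scan body on a current JAX install, `jax.make_jaxpr` can be run on simplified, un-jitted versions of both functions. This is a sketch with hypothetical helper names (`blend`, `branch`, `scan_rows`); the printed form, for example the call primitive shown as `xla_call` in the 2021 output, may differ between JAX versions.

```python
import jax
import jax.numpy as jnp
from jax import lax


def blend(pred, t, f):
    # Elementwise arithmetic only; no control-flow primitive is emitted.
    p = jnp.asarray(pred, dtype=jnp.float32)
    return p * t + (1 - p) * f


def branch(pred, t, f):
    # Emits a cond primitive whose branches each return one operand.
    return lax.cond(pred, lambda ops: ops[0], lambda ops: ops[1], (t, f))


def scan_rows(select_fn, pred, ts, fs):
    # Row-by-row scan; pred is a closed-over constant and there is no carry.
    def body(carry, x):
        return carry, select_fn(pred, *x)

    return lax.scan(body, None, (ts, fs))[1]


b = jnp.zeros((1000, 10))
c = jnp.ones((1000, 10))
jaxpr_blend = jax.make_jaxpr(lambda p, t, f: scan_rows(blend, p, t, f))(True, b, c)
jaxpr_branch = jax.make_jaxpr(lambda p, t, f: scan_rows(branch, p, t, f))(True, b, c)

# Look for the cond primitive in the pretty-printed jaxprs.
has_cond_blend = "cond[" in str(jaxpr_blend)
has_cond_branch = "cond[" in str(jaxpr_branch)
print(jaxpr_branch)
```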

zhangqiaorjc • Mar 09 '21 23:03

The slowdown in the colab above is on the CPU backend.

On TPU, the slowdown is proportionately smaller.

[Screenshot from 2021-03-09, 3:27:29 PM]

Could this be an issue in XLA's CPU backend?

zhangqiaorjc • Mar 09 '21 23:03

I'm not actually seeing a big difference on CPU. The significant slowdown was on GPU.
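Since the arithmetic workaround evaluates both inputs anyway, `jnp.where` expresses the same elementwise select without `lax.cond`. A sketch, assuming both branches are cheap and safe to compute unconditionally (`where_select`, `where_scan`, and `blend` are hypothetical names). One difference worth knowing: the multiply-and-add form propagates `inf`/`nan` from the unselected input, because `0 * inf` is `nan`, whereas `jnp.where` does not.

```python
import jax
import jax.numpy as jnp
from jax import lax


def where_select(pred, t, f):
    # Elementwise select: both t and f exist, no control-flow primitive.
    return jnp.where(pred, t, f)


def blend(pred, t, f):
    # The multiply-and-add workaround from the issue, for comparison.
    p = jnp.asarray(pred, dtype=jnp.float32)
    return p * t + (1 - p) * f


@jax.jit
def where_scan(pred, ts, fs):
    # Same row-by-row scan shape as in the issue, using jnp.where per row.
    def body(carry, x):
        return carry, where_select(pred, *x)

    return lax.scan(body, None, (ts, fs))[1]


ts = jnp.full((4, 3), 2.0)
fs = jnp.full((4, 3), jnp.inf)
picked = where_scan(True, ts, fs)      # selects ts; the inf in fs is ignored
blended = blend(True, ts[0], fs[0])    # 1*2 + 0*inf, which is nan
```

Both forms compute both sides, so when a branch is expensive, `lax.cond` may still be preferable; the timings reported in this thread suggest the trade-off depends on the backend.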

sokrypton • Mar 10 '21 02:03