inversecrime

15 comments by inversecrime

Is it possible that this is a RAM issue? It almost seems like it - with smaller arrays, even with x64 enabled, everything looks OK. But then again, I do not...
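For context on the "x64" mention: JAX uses 32-bit floats by default, and enabling the `jax_enable_x64` flag makes 64-bit dtypes available, which doubles the per-element size of a float array. A minimal sketch (the array size is arbitrary and only for illustration):

```python
import jax

# Enable 64-bit types; this should run at startup, before any arrays are created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

n = 1_000_000
x32 = jnp.ones((n,), dtype=jnp.float32)
x64 = jnp.ones((n,), dtype=jnp.float64)

# nbytes = number of elements * bytes per element
print(x32.nbytes)  # 4000000
print(x64.nbytes)  # 8000000
```

So a memory problem that only shows up with large arrays would be expected to appear at roughly half the array size once x64 is on.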

I think I tracked down the issue, and it looks like this even happens with 32-bit numbers. Consider the following example, where the second function uses almost twice as much...

Thanks for the answer. I don't have any experience with XLA per se, so I wouldn't know where to start there. Also, I don't really know how to observe RAM...
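One way to look at memory without watching system RAM is to ask XLA for its own estimate of a compiled function's buffers. The sketch below uses JAX's ahead-of-time API (`jax.jit(...).lower(...).compile()` and `Compiled.memory_analysis()`); `compiled_memory_stats`, `one_step` and `two_steps` are hypothetical names, and `memory_analysis()` may return `None` on backends that don't report these numbers:

```python
import jax
import jax.numpy as jnp

def compiled_memory_stats(fn, *args):
    """Compile `fn` for `args` and return XLA's buffer-size estimates in bytes."""
    compiled = jax.jit(fn).lower(*args).compile()
    stats = compiled.memory_analysis()  # None if the backend does not report it
    if stats is None:
        return None
    return {
        "arguments": stats.argument_size_in_bytes,
        "outputs": stats.output_size_in_bytes,
        "temporaries": stats.temp_size_in_bytes,
    }

# Two hypothetical functions to compare.
def one_step(x, a):
    return x * a

def two_steps(x, a):
    y = x * a
    return y + jnp.sin(y)

x = jnp.ones((1024, 1024), dtype=jnp.float32)
a = jnp.ones((1024, 1024), dtype=jnp.float32)
for name, fn in [("one_step", one_step), ("two_steps", two_steps)]:
    print(name, compiled_memory_stats(fn, x, a))
```

These are XLA's estimates for the compiled program, not the process's total RAM usage, but comparing the `temporaries` entry between two functions is a quick way to see whether one of them needs extra intermediate buffers.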

I mean, it's just a suspicion. But why else would it need so much more memory than

```
def f(x, a, b, c):
    a = x * a
    d =...
```

Thanks for helping @jakevdp. I agree: _IF_ this is an issue, it's definitely not JAX's fault but rather something the compiler does. In the meantime I played around with this a bit...

I'm aware that `other_jit` recompiles the function with every call - in real-world scenarios it would be better to save and reuse compiled functions.
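`other_jit` itself isn't shown here. As a sketch of the "save and reuse" alternative, JAX's ahead-of-time API can compile a function once for given argument shapes and dtypes and then call the compiled executable directly; `f` below is a hypothetical function:

```python
import jax
import jax.numpy as jnp

def f(x, a):
    return jnp.sin(x) * a

x = jnp.arange(4.0)      # float32 by default
a = jnp.float32(2.0)

# Lower and compile once for these argument shapes and dtypes...
compiled_f = jax.jit(f).lower(x, a).compile()

# ...then call the compiled executable repeatedly without retracing or recompiling.
for _ in range(3):
    y = compiled_f(x, a)
```

The compiled object is specialized to the shapes and dtypes it was lowered with; calling it with differently shaped inputs raises an error instead of triggering a recompilation.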

Thanks for the reply! It also seems to depend on the operation itself. For example, with a double `vmap` (i.e. a sum over the last axis) it happens, but it doesn't happen...
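For reference, a double `vmap` over a function that sums a 1-D vector reduces over the last axis of a 3-D array, matching `jnp.sum(..., axis=-1)`. A minimal sketch with an arbitrary array:

```python
import jax
import jax.numpy as jnp

x = jnp.arange(24.0).reshape(2, 3, 4)

# The outer vmap maps over axis 0 of x, the inner vmap over axis 0 of each (3, 4) slice,
# so jnp.sum only ever sees one length-4 vector at a time.
double_vmapped = jax.vmap(jax.vmap(jnp.sum))(x)
direct = jnp.sum(x, axis=-1)

print(double_vmapped.shape)  # (2, 3)
```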

Thanks for clarifying! Would it be a useful addition to `jax.jit` to make it possible to turn this behavior off? Instead, constants could be treated as regular variables (that then...
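A sketch of the "constants as regular variables" idea using only existing JAX APIs: a NumPy array that a function closes over is captured as a constant at trace time (and, depending on the JAX version, can end up embedded in the compiled program), while the same array passed as an argument stays an ordinary input. `big_const`, `closed_over` and `as_argument` are hypothetical names; `jax.make_jaxpr` shows the difference in the traced inputs:

```python
import numpy as np
import jax
import jax.numpy as jnp

big_const = np.ones((1000, 1000), dtype=np.float32)

def closed_over(x):
    # big_const is captured from the enclosing scope when the function is traced.
    return x @ big_const

def as_argument(x, c):
    # c is an ordinary traced input, just like x.
    return x @ c

x = jnp.ones((1000,), dtype=jnp.float32)
jaxpr_closed = jax.make_jaxpr(closed_over)(x)
jaxpr_arg = jax.make_jaxpr(as_argument)(x, big_const)

print(len(jaxpr_closed.jaxpr.invars))  # 1: only x is an input
print(len(jaxpr_arg.jaxpr.invars))     # 2: x and c are both inputs

# Usage: jit the argument version and pass the array explicitly on each call.
y = jax.jit(as_argument)(x, big_const)
```

Passing the array explicitly means the compiler sees it as a runtime parameter rather than a literal; whether that avoids the extra memory in a particular case is not guaranteed.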

Thanks for helping! It would be nice to have an option like this in `jax.jit` as well to control this behavior - something like a `constant_folding: bool` parameter, maybe.

That was a fast comment! When trying this, I get the following error: `jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: While setting option xla_disable_hlo_passes, '1' is not a valid string value.`
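The message suggests that `xla_disable_hlo_passes` expects a string (a comma-separated list of HLO pass names) rather than a value like `1`. One way to pass it as a string is the `XLA_FLAGS` environment variable, set before JAX initializes its backend. This is a sketch assuming `constant_folding` is the name of XLA's constant-folding pass in the installed version; pass names are internal to XLA and can change, and `f` is a hypothetical function:

```python
import os

# Append to any existing XLA_FLAGS. This must happen before JAX creates its backend,
# so it is done here before importing jax.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_disable_hlo_passes=constant_folding"
).strip()

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return x + jnp.float32(2.0) * jnp.float32(3.0)

result = f(jnp.float32(1.0))
print(result)
```

If JAX has already initialized a backend in the same process, the flag is not picked up, so in practice it is usually set in the shell before starting Python.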