
JIT constant folding

Open inversecrime opened this issue 1 year ago • 6 comments

Description

Hi, I was hoping that someone could help me with this.

Sometimes, when using constants in jitted functions, I get warnings like this one:

2024-05-19 20:16:26.694439: E external/xla/xla/service/slow_operation_alarm.cc:65] Constant folding an instruction is taking > 1s:

  %reduce.8 = f64[200000,10]{1,0} reduce(f64[200000,10,10]{2,1,0} %broadcast.2, f64[] %constant.3), dimensions={2}, to_apply=%region_0.4, metadata={op_name="jit(f)/jit(main)/reduce_sum[axes=(2,)]" source_file="..." source_line=13}

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.

These warnings appear seemingly at random, for example with the following code:

from functools import wraps
import jax
import jax.numpy as jnp
import jax.core

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platforms", "cpu")

v = jnp.zeros((200000, 10, 10))


def f():
    return jax.vmap(jax.vmap(jnp.sum))(v)


jax.jit(f)()

This code produces "constant folding" warnings on both Windows and Linux. It probably also depends on the OS version, CPU type, etc.

When playing around with array shapes and the number of nested vmaps, these messages appear or disappear without any clear pattern (at least none that is clear to me). For example, this is fast:

v = jnp.zeros((1000000, 10, 10))
def f():
    return jax.vmap(jnp.sum)(v)
jax.jit(f)()

While this is slow and produces the warning:

v = jnp.zeros((1000000, 10, 2))
def f():
    return jax.vmap(jnp.sum)(v)
jax.jit(f)()

Constant folding only happens when compiling with jax.jit; building jaxprs is not affected. Since a jaxpr records closed-over constants explicitly (as its consts), it is possible to compile it while treating those constants as variables. The following function demonstrates this:

def other_jit(f):
    @wraps(f)
    def wrapper(*args):
        jaxpr = jax.make_jaxpr(f)(*args)
        return jax.jit(lambda c, *a: jax.core.eval_jaxpr(jaxpr.jaxpr, c, *a))(jaxpr.consts, *args)
    return wrapper

Now, using other_jit(f)() instead of jax.jit(f)() prevents the issue.

I was wondering if this is intended behavior. Wouldn't it be a better solution in most cases to always treat constants as variables while compiling, to prevent constant folding from slowing down compilations?

In real-world scenarios, using (a generalized version of) the other_jit function I presented here can significantly reduce compilation times from a few minutes to just seconds.

What's your opinion on this? I would appreciate any help or suggestions.

System info (python version, jaxlib version, accelerator, etc.)

CPU, jax 0.4.28, jaxlib 0.4.28

inversecrime, May 19 '24

I'm aware that other_jit recompiles the function on every call; in real-world scenarios it would be better to save and reuse the compiled functions.
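A minimal sketch of that idea, assuming positional array (or pytree) arguments. `cached_other_jit` is a hypothetical name, not a JAX API. It keeps the `make_jaxpr` + `eval_jaxpr` trick from `other_jit`, but traces once per input signature (pytree structure, shapes, dtypes) and reuses the jitted evaluator afterwards. It also flattens pytree inputs and restores the output structure; `other_jit` returns the flat list produced by `eval_jaxpr`. `eval_jaxpr` lives in `jax.core`, which is semi-public and may change between releases.

```python
import functools

import jax
import jax.numpy as jnp
from jax import core as jax_core  # eval_jaxpr: semi-public API, may move between releases


def cached_other_jit(f):
    # Hypothetical helper: like `other_jit`, but traces and compiles once per
    # input signature instead of on every call.
    cache = {}

    @functools.wraps(f)
    def wrapper(*args):
        flat_args, in_tree = jax.tree_util.tree_flatten(args)
        key = (in_tree, tuple((jnp.shape(x), jnp.result_type(x)) for x in flat_args))
        if key not in cache:
            closed, out_shape = jax.make_jaxpr(f, return_shape=True)(*args)
            out_tree = jax.tree_util.tree_structure(out_shape)

            def run(consts, *flat):
                # Closed-over constants arrive as ordinary traced inputs, so the
                # compiled computation sees parameters rather than literals.
                return jax_core.eval_jaxpr(closed.jaxpr, consts, *flat)

            cache[key] = (jax.jit(run), closed.consts, out_tree)
        run_jit, consts, out_tree = cache[key]
        # eval_jaxpr returns a flat list; rebuild the caller's output structure.
        return jax.tree_util.tree_unflatten(out_tree, run_jit(consts, *flat_args))

    return wrapper
```

Usage mirrors `other_jit`: `cached_other_jit(f)()` traces on the first call and reuses the compiled evaluator (and the captured constants) on later calls with the same signature. The trade-off is that the constants stay alive for as long as the cache does.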

inversecrime, May 19 '24

Some explanation why it depends on the shape: We have a heuristic to not apply constant folding if the operand shape is too large. The cutoff is 45 * 1000 * 1000 elements. In the "fast" cases we don't apply constant folding.
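To make the shape dependence concrete, here is the element-count arithmetic for the shapes used in this thread. The helper `exceeds_folding_cutoff` is hypothetical: it models only the single operand-size rule described above (a cutoff of 45 * 1000 * 1000 elements, assumed here to be a strict "greater than" comparison) and ignores every other rule in XLA's constant-folding pass.

```python
import math

# Cutoff quoted above: operands with more elements than this are not constant-folded.
MAX_FOLD_ELEMENTS = 45 * 1000 * 1000


def exceeds_folding_cutoff(shape):
    # Hypothetical helper: models only the operand-size heuristic (assumes ">").
    return math.prod(shape) > MAX_FOLD_ELEMENTS


for shape in [(1000000, 10, 10), (1000000, 10, 2), (200000, 10, 10)]:
    n = math.prod(shape)
    verdict = "skipped (too large)" if exceeds_folding_cutoff(shape) else "may be folded"
    print(f"{shape}: {n:,} elements -> {verdict}")
```

Only the first shape (the "fast" case, 100,000,000 elements) is over the cutoff; the other two have 20,000,000 elements each and remain eligible. Size is not the only criterion the pass applies, so "eligible" does not mean "will be folded".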

akuegel, May 21 '24

Thanks for the reply! It also seems to depend on the operation itself. For example, with a double vmap (i.e. a sum over the last axis) it happens, but it doesn't happen with a single vmap (i.e. a sum over the last two axes):

import jax
import jax.numpy as jnp
import jax.core

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platforms", "cpu")

v = jnp.zeros((200000, 10, 10))


def f():
    return jax.vmap(jax.vmap(jnp.sum))(v)


def g():
    return jax.vmap(jnp.sum)(v)


print("f")
jax.jit(f)()
print("g")
jax.jit(g)()

Maybe it would help to clarify what constant folding is used for / where it makes sense to apply it. As far as I know, it basically means that the compiler evaluates some operations at compile time (to save runtime), if all inputs for these operations are known in advance (i.e. if they are "constants").

I'm wondering why this is so slow - intuitively, I would think that constant folding happens approximately at the speed of numpy or uncompiled jax.numpy. But it seems to be much slower than that!

inversecrime, May 21 '24

For constant folding, the HloEvaluator is used. It is not optimized for speed, but for correctness, as it is used as reference backend in tests. You can see the rules that we have for the ConstantFolding pass here:

https://github.com/openxla/xla/blob/main/xla/service/hlo_constant_folding.cc

I don't know what the nested jax.vmap would translate to, but I think you can safely assume that fast runtime means that constant folding is not applied. Constant folding only makes sense if what is being constant folded would run several times. If it is run only a single time, then you would be better off without constant folding.

akuegel, May 21 '24

Thanks for clarifying!

Would it be a useful addition to jax.jit to make it possible to turn this behavior off? Instead, constants could be treated as regular variables (that then get passed to the compiled function), preventing constant folding from ever happening.
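Until such an option exists, the most direct way to get this behavior for a single array is to make it an explicit argument of the jitted function instead of a closed-over global. A jit argument is a runtime parameter of the compiled computation, so its value is not known at compile time and cannot be constant-folded. A minimal sketch using the array from this thread:

```python
import jax
import jax.numpy as jnp

v = jnp.zeros((200000, 10, 10))


def f(x):
    # `x` is a traced argument, so the reduction runs at runtime
    # instead of being evaluated by XLA's constant folder at compile time.
    return jax.vmap(jax.vmap(jnp.sum))(x)


f_jit = jax.jit(f)
out = f_jit(v)  # compiled once for this shape/dtype, reused on later calls
```

This only helps when you control where the array enters the function; constants created deep inside other code are the case the make_jaxpr-based wrappers in this thread are aimed at.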

For example, you could force this with the current API using jax.make_jaxpr and partial_eval - basically first extracting all constants (i.e. known values) with make_jaxpr, then computing as many values as possible using partial_eval, and then compiling the remaining jaxpr, using the precomputed values whenever it's called.
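The partial_eval route relies on JAX internals. A similar effect can often be had by hand: compute the input-independent part once, outside `jax.jit`, and pass the result in as an argument. This is a manual substitute for partial evaluation, not the `partial_eval` machinery itself; `weights`, `expensive_table` and `model` are hypothetical stand-ins.

```python
import jax
import jax.numpy as jnp

# Hypothetical large static array (a "constant" from jit's point of view).
weights = jnp.linspace(0.0, 1.0, 200000 * 10).reshape(200000, 10)


def expensive_table(w):
    # Input-independent work: depends only on the static array.
    return jnp.cos(w).sum(axis=1)


@jax.jit
def model(x, table):
    # Only the input-dependent part is compiled; `table` is a runtime argument.
    return table * x


table = expensive_table(weights)  # computed once, eagerly, outside jit
y1 = model(2.0, table)
y2 = model(3.0, table)  # reuses the compiled model and the precomputed table
```

The trade-off matches the one for precomputing via partial_eval: repeated calls skip the constant work, but the precomputed results have to be kept in memory.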

Maybe this would be a nice addition for those users (like me) who use many and large static arrays (i.e. constants in the context of jit) but don't want constant folding to slow the compilation down.

inversecrime, May 22 '24

I am not familiar with the JAX side of things. On XLA side we have a flag that could be used to turn off constant folding:

--xla_disable_hlo_passes=constant_folding

This can be set via the XLA_FLAGS environment variable. So something like os.environ['XLA_FLAGS'] = "--xla_disable_hlo_passes=constant_folding" from python
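A minimal sketch of that, assuming the variable is set before JAX initializes its XLA backend (setting it before `import jax` is the safest place). It appends to any existing `XLA_FLAGS` instead of overwriting them. Note that this disables the pass for every computation compiled in the process, not just one function.

```python
import os

# Must run before JAX creates its backend; appending keeps any flags already set.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_disable_hlo_passes=constant_folding"
).strip()

import jax
import jax.numpy as jnp

v = jnp.zeros((200000, 10, 10))


def f():
    return jax.vmap(jax.vmap(jnp.sum))(v)


# v is still a closed-over constant here, but the constant_folding pass is disabled.
out = jax.jit(f)()
```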

akuegel, May 23 '24

Thanks for helping!

It would be nice to also have an option like this in jax.jit to control this behavior - something like constant_folding: bool maybe.

inversecrime, May 28 '24

You can do this via jax.jit(f).lower(*args).compile(compiler_options={'xla_disable_hlo_passes': True}). We are looking into supporting this as an option to jit but you can do it via the AOT API for now.

yashk2810, May 28 '24

That was a fast comment!

When trying this, I get the following error: jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: While setting option xla_disable_hlo_passes, '1' is not a valid string value.

inversecrime, May 28 '24

Ohh sorry you need 'xla_disable_hlo_passes': 'constant_folding'

yashk2810, May 28 '24

RuntimeError: Protocol Buffer reflection usage error:
  Method      : google::protobuf::Reflection::SetString
  Message type: xla.DebugOptions
  Field       : xla.DebugOptions.xla_disable_hlo_passes
  Problem     : Field is repeated; the method requires a singular field.

The code I used:

v = jnp.zeros((200000, 10, 10))

def f():
    return jax.vmap(jax.vmap(jnp.sum))(v)

jax.jit(f).lower().compile(compiler_options={'xla_disable_hlo_passes': 'constant_folding'})

inversecrime, May 28 '24

Hmm, this might require some fixes in the jax code. I'll take a look.

yashk2810, May 28 '24

Was there any solution to this issue? I am also getting these warnings when scanning over large tensors.

I'd also be happy with some way to silence the warnings and just accept the long compile time without the terminal spam, but I couldn't find any logger that corresponds to these messages (I tried setting all jax loggers to logging.ERROR using logging.root.manager.loggerDict).

jessegrabowski, Jul 30 '24

I just ended up first converting my main functions to jaxpr and then compiling them while treating their constants as dynamic variables.

Additionally, this lets you use partial_eval to precompute the constant part of your jaxpr (which can save a lot of time if you call it often enough).

Sadly, jax does not seem to have any api for automating these processes.

inversecrime, Jul 30 '24

@inversecrime Do you have code snippets to perform this? I am having similar troubles

renecotyfanboy, Nov 28 '24

my_jit.txt

This is (up to some minor differences) what I currently use. I don't exactly know how well the current implementation of jax.jit handles constant folding, but the "jit" function in this file will always treat constants that appear during jaxpr tracing as dynamic inputs to jax.jit and thus prevent any constants from appearing during compile time.

It should be straightforward to see what this code does. If you enable "precompute", your code will run faster (if you have lots of constants in your main function) but may need considerably more memory.

Note: As far as I know, jaxprs are considered to be private/semi-public api.

inversecrime, Nov 29 '24

Thank you very much, I finally tested it and it works as intended. I still hope for a better workaround in the future, haha.

renecotyfanboy, Dec 05 '24