optax

Memory overflow using scale_by_radam

HGangloff opened this issue

Hi,

My RAM fills up until it overflows when I use the scale_by_radam gradient transformation, or equivalently optax.radam, without JIT-compiling the code. The problem appears on both CPU and GPU, but not when the code is JIT-compiled. It does not seem to occur with optax.adam.

Here is an MWE derived from the optax quick start tutorial:

```python
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''  # hides GPUs to force CPU; comment out to use a GPU

import optax
import jax.numpy as jnp
import jax
import numpy as np

BATCH_SIZE = 500
NUM_TRAIN_STEPS = 10000
RAW_TRAINING_DATA = np.random.randint(255, size=(NUM_TRAIN_STEPS, BATCH_SIZE, 1))

TRAINING_DATA = np.unpackbits(RAW_TRAINING_DATA.astype(np.uint8), axis=-1)
LABELS = jax.nn.one_hot(RAW_TRAINING_DATA % 2, 2).astype(jnp.float32).reshape(NUM_TRAIN_STEPS, BATCH_SIZE, 2)

initial_params = {
    'hidden': jax.random.normal(shape=[8, 200], key=jax.random.PRNGKey(0)),
    'hidden2': jax.random.normal(shape=[200, 100], key=jax.random.PRNGKey(0)),
    'output': jax.random.normal(shape=[100, 2], key=jax.random.PRNGKey(1)),
}


def net(x: jnp.ndarray, params: optax.Params) -> jnp.ndarray:
  x = jnp.dot(x, params['hidden'])
  x = jax.nn.relu(x)
  x = jnp.dot(x, params['hidden2'])
  x = jax.nn.relu(x)
  x = jnp.dot(x, params['output'])
  return x


def loss(params: optax.Params, batch: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
  y_hat = net(batch, params)

  # optax also provides a number of common loss functions.
  loss_value = optax.sigmoid_binary_cross_entropy(y_hat, labels).sum(axis=-1)

  return loss_value.mean()

def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
  opt_state = optimizer.init(params)

  # Left commented out on purpose: with @jax.jit enabled, the memory growth disappears.
  #@jax.jit
  def step(params, opt_state, batch, labels):
    loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_value

  for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
    params, opt_state, loss_value = step(params, opt_state, batch, labels)
    if i % 100 == 0:
      print(f'step {i}, loss: {loss_value}')

  return params

# Finally, fit the parametrized function using the RAdam optimizer
# provided by optax.
optimizer = optax.radam(learning_rate=1e-2)
params = fit(initial_params, optimizer)
```

Of course, this example is small and takes a long time to saturate the RAM, but the issue is a real problem in another research project I am working on.

The problem seems to be linked to this RAdam-specific computation: https://github.com/deepmind/optax/blob/fc5de3d3951c4dfd87513c6426e30baf505d89ae/optax/_src/transform.py#L685C7-L685C7, but I do not know how to investigate further.

Thanks for your feedback.

HGangloff, Aug 29 '23 15:08

Hi HGangloff,

Prioritize JIT compilation:

Compile your code with jax.jit whenever possible to benefit from JAX's optimizations and potentially avoid the RAM issue.

Investigate the RAdam implementation:

Explore the RAdam implementation in Optax: https://github.com/deepmind/optax/blob/fc5de3d3951c4dfd87513c6426e30baf505d89ae/optax/_src/transform.py#L685C7-L685C7. Focus on areas that might create large temporary arrays or perform memory-intensive operations, and consider profiling memory usage to pinpoint the specific lines or functions causing the excessive consumption.

Experiment with alternative optimizers:

If RAdam's performance is crucial for your research, consider:
- modifying RAdam's implementation to reduce its memory footprint (if feasible);
- exploring alternative optimizers such as Yogi, which shares similarities with RAdam but might have different memory characteristics.

Report to the Optax maintainers:

Share your findings and code examples with the Optax maintainers to bring attention to the issue and potentially contribute to a fix.

Additional considerations:
- Memory profiling: use tools like jax.profiler or external profilers to track memory usage and identify bottlenecks.
- Batch size adjustment: experiment with smaller batch sizes to reduce the memory required per step.
- Hardware constraints: take the available RAM and other hardware limitations into account.

I'm ready to assist further if you have more questions or need additional guidance. I'll be waiting for your positive response!

itstalmeez, Dec 25 '23 11:12

Since #969 has been closed as completed, I think this issue can be closed.

carlosgmartin, Nov 04 '24 23:11

Right, thank you for catching this, @carlosgmartin! @HGangloff, feel free to reopen if the matter is not solved, but I think it should be fine.

vroulet, Nov 04 '24 23:11