batch_size in PointCloud is ineffective when reverse-mode differentiation is used
I have a large-scale optimal transport problem between two PointClouds that I want to differentiate. The cost matrix does not fit into memory, so I was quite happy to see the support for the batch_size parameter in PointCloud.
Unfortunately, since reverse-mode differentiation stores all intermediate results of the forward pass, I still run out of memory when computing the gradients of the optimal transport problem. I believe the issue is related to the one discussed here: https://github.com/google/jax/issues/3186. There, the authors suggest wrapping the relevant code in the @jax.remat decorator so that intermediates are not saved for the backward pass and are instead recomputed.
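For context, here is a minimal, self-contained illustration of what @jax.remat (exposed as jax.checkpoint) does, independent of ott; `heavy` is just a stand-in function of my own:

```python
import jax
import jax.numpy as jnp

def heavy(x):
    # stand-in for a computation with large intermediates
    return jnp.tanh(x @ x.T).sum()

# jax.checkpoint (alias jax.remat) marks `heavy` so that its intermediates
# are discarded after the forward pass and recomputed during the backward pass
heavy_remat = jax.checkpoint(heavy)

x = jnp.eye(8)
g_plain = jax.grad(heavy)(x)
g_remat = jax.grad(heavy_remat)(x)
# the gradients are identical; only the memory/compute trade-off changes
```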
Here's a minimal example for reproduction. You might need to play around with the problem size (or reduce the batch size) depending on your GPU's memory.
from functools import partial
import jax
import jax.numpy as jnp
import jax.random as jrandom
import matplotlib.pyplot as plt
from ott.geometry import pointcloud
from ott.tools import sinkhorn_divergence
# %%
def sample_points_uniformly_from_disc(r, n, key):
    key_r, key_phi = jrandom.split(key, 2)
    # sqrt is necessary to achieve a uniform distribution, cf. https://stats.stackexchange.com/questions/481543/generating-random-points-uniformly-on-a-disk
    r_vals = r * jnp.sqrt(jrandom.uniform(key_r, shape=(n,)))
    phi_vals = 2 * jnp.pi * jrandom.uniform(key_phi, shape=(n,))
    y = jnp.stack(
        (r_vals * jnp.cos(phi_vals), r_vals * jnp.sin(phi_vals)),
        axis=1,
    )
    return y
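(As an aside, a quick way to convince yourself of the sqrt trick: sampling uniformly from a disc of radius R gives the radius density 2r/R², hence E[r] = 2R/3, whereas a uniform radius would give E[r] = R/2. A small sanity check, with the function repeated so the snippet is self-contained:)

```python
import jax.numpy as jnp
import jax.random as jrandom

def sample_points_uniformly_from_disc(r, n, key):
    key_r, key_phi = jrandom.split(key, 2)
    r_vals = r * jnp.sqrt(jrandom.uniform(key_r, shape=(n,)))
    phi_vals = 2 * jnp.pi * jrandom.uniform(key_phi, shape=(n,))
    return jnp.stack(
        (r_vals * jnp.cos(phi_vals), r_vals * jnp.sin(phi_vals)), axis=1
    )

pts = sample_points_uniformly_from_disc(5.0, 100_000, jrandom.PRNGKey(0))
mean_r = jnp.linalg.norm(pts, axis=1).mean()
# mean radius should be close to 2/3 * 5 ≈ 3.33 for a uniform disc
```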
# %% [markdown]
# # Sample points uniformly from discs with different radii
# %%
key = jrandom.PRNGKey(seed=42)
key, key_x, key_y = jrandom.split(key, 3)
n = 50000
x = sample_points_uniformly_from_disc(5, n, key_x)
y = sample_points_uniformly_from_disc(10, n, key_y)
plt.plot(x[:, 0], x[:, 1], ".")
plt.plot(y[:, 0], y[:, 1], ".")
# %% [markdown]
# # Run forward sinkhorn divergence with batch size
# %%
@partial(jax.jit, static_argnames=["batch_size"])
def f(x, y, a, b, batch_size=None):
    out = sinkhorn_divergence.sinkhorn_divergence(
        pointcloud.PointCloud, x, y, a=a, b=b,
        batch_size=batch_size,
        sinkhorn_kwargs={"use_danskin": True},
    )
    return out.divergence, out
batch_size = 10000
div, div_res = f(x, y, None, None, batch_size=batch_size)
print("div:", div)
# outputs 12.471696
# %% [markdown]
# # Run backward pass (fails with OOM)
# %%
df = jax.value_and_grad(f, has_aux=True)
(div, div_res), grad = df(x, y, None, None, batch_size=batch_size)
print("div=", div)
When I run the last cell on my machine, I get:
2023-08-23 12:06:41.435311: W external/xla/xla/service/hlo_rematerialization.cc:2202] Can't reduce memory use below 17.77GiB (19078594560 bytes) by rematerialization; only reduced to 29.82GiB (32015643581 bytes), down from 29.82GiB (32015644425 bytes) originally
2023-08-23 12:06:52.591406: W external/tsl/tsl/framework/bfc_allocator.cc:485] Allocator (GPU_0_bfc) ran out of memory trying to allocate 9.31GiB (rounded to 10000000000)requested by op
2023-08-23 12:06:52.591535: W external/tsl/tsl/framework/bfc_allocator.cc:497] *****************************************************_______________________________________________
2023-08-23 12:06:52.591823: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2593] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 10000000000 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
parameter allocation: 781.2KiB
constant allocation: 328B
maybe_live_out allocation: 18.63GiB
preallocated temp allocation: 9.32GiB
preallocated temp fragmentation: 0B (0.00%)
total allocation: 27.95GiB
Peak buffers:
Buffer 1:
Size: 9.31GiB
Operator: op_name="jit(f)/jit(main)/broadcast_in_dim[shape=(5, 10000, 50000) broadcast_dimensions=()]" source_file="/tmp/ipykernel_2433426/603878625.py" source_line=3
XLA Label: broadcast
Shape: f32[5,10000,50000]
==========================
Buffer 2:
Size: 9.31GiB
Operator: op_name="jit(f)/jit(main)/while/body/dynamic_update_slice" source_file="/tmp/ipykernel_2433426/603878625.py" source_line=3 deduplicated_name="fusion.145"
XLA Label: fusion
Shape: f32[5,10000,50000]
==========================
Buffer 3:
Size: 9.31GiB
Operator: op_name="jit(f)/jit(main)/while/body/dynamic_update_slice" source_file="/tmp/ipykernel_2433426/603878625.py" source_line=3 deduplicated_name="fusion.145"
XLA Label: fusion
Shape: f32[5,10000,50000]
==========================
Any help in addressing this would be really appreciated!
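In case it helps frame the request: below is a toy sketch (plain JAX, not ott) of the pattern I was hoping batch_size would enable end to end, i.e. accumulating a batched cost reduction under lax.scan with jax.checkpoint on the scan body, so the backward pass recomputes each batch's cost block instead of storing all of them. All names here are mine, not from the ott API.

```python
import jax
import jax.numpy as jnp
from jax import lax

def batched_sum_of_costs(x, y, batch_size):
    # split x into batches; jax.checkpoint on the body means the backward
    # pass recomputes each batch's cost block instead of materializing it
    n = x.shape[0]
    xb = x.reshape(n // batch_size, batch_size, -1)

    @jax.checkpoint
    def body(carry, x_batch):
        # pairwise squared distances between this batch and all of y
        cost = ((x_batch[:, None, :] - y[None, :, :]) ** 2).sum(-1)
        return carry + cost.sum(), None

    total, _ = lax.scan(body, 0.0, xb)
    return total

key = jax.random.PRNGKey(0)
kx, ky = jax.random.split(key)
x = jax.random.normal(kx, (64, 2))
y = jax.random.normal(ky, (64, 2))
val, grad = jax.value_and_grad(batched_sum_of_costs)(x, y, 16)
```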