Bias and kernel params are placed on different GPU devices
System information
- OS Platform and Distribution: Linux Ubuntu 20.04
- Flax, jax, jaxlib versions: flax -> 0.6.11, jax -> 0.4.13, jaxlib -> 0.4.13+cuda11.cudnn86
- Python version: Python 3.9.19
- GPU/TPU model and memory: NVIDIA RTX 4090 with 24 GB
- CUDA version: CUDA 11.8
Problem you have encountered:
When I initialize a Flax model on a specific GPU device (for example, gpu 1), the bias and kernel params end up on different GPU devices.
What you expected to happen:
The bias and kernel params should be placed on the same GPU device.
Steps to reproduce:
import jax
import jax.numpy as jnp
from jax import tree_util
from flax import linen as nn

device = jax.devices("gpu")[1]

class MyModel(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(64, (3, 3), 1, name='conv1')(x)
        x = nn.relu(x)
        return x

rng = jax.random.PRNGKey(0)
rng = jax.device_put(rng, device)
dummy_input = jax.device_put(jnp.ones((5, 64, 64, 32)), device)

model = MyModel()
model_params = model.init({'params': rng}, dummy_input)
# model_params = tree_util.tree_map(lambda x: jax.device_put(x, device), model_params)
print(tree_util.tree_map(lambda x: x.device(), model_params))
The output is:

FrozenDict({
    params: {
        conv1: {
            bias: gpu(id=0),
            kernel: gpu(id=1),
        },
    },
})
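Uncommenting the tree_util.tree_map line in the snippet above does work around this by explicitly moving every param leaf to the target device:

model_params = tree_util.tree_map(lambda x: jax.device_put(x, device), model_params)
print(tree_util.tree_map(lambda x: x.device(), model_params))
# Both bias and kernel then report gpu(id=1).

Still, I would expect init to place them consistently by default.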
Thanks!
Hi @YunxiTang, I am able to reproduce this issue.
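My guess at the cause (not confirmed): JAX commits the result of a computation to the device of its committed inputs, while data created with no committed inputs lands on the default device. Conv's kernel initializer consumes the rng you committed to gpu 1, but the default zeros bias initializer ignores the rng, so the bias is created on the default gpu 0. A minimal sketch of that placement rule, reusing the imports from your snippet:

key = jax.device_put(jax.random.PRNGKey(0), jax.devices("gpu")[1])
kernel_like = jax.random.normal(key, (3, 3, 32, 64))  # consumes the committed key -> gpu(id=1)
bias_like = jnp.zeros((64,))                          # no committed inputs -> default gpu(id=0)
print(kernel_like.device, bias_like.device)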
In practice, I have seen Flax models initialized on CPU and then migrated/replicated to devices. Two examples:
- Migrating params post-initialization to GPU:

  # Optional: init on `cpu`.
  model_params = jax.jit(model.init, backend="cpu")({'params': rng}, dummy_input)
  model_params = jax.device_put(model_params, device)
  jax.tree.map(lambda p: p.device, model_params)
  # {'params': {'conv1': {'bias': CudaDevice(id=1), 'kernel': CudaDevice(id=1)}}}

- Using a jax.default_device scope:

  with jax.default_device(device):
      model_params = model.init({'params': rng}, dummy_input)
  print(tree_util.tree_map(lambda x: x.device, model_params))
  # {'params': {'conv1': {'bias': CudaDevice(id=1), 'kernel': CudaDevice(id=1)}}}
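A third option that I believe also works on your jax version (hedged: jax.jit's device= argument exists in 0.4.13 but was later deprecated in favor of shardings) is jitting init directly onto the target device:

model_params = jax.jit(model.init, device=device)({'params': rng}, dummy_input)
print(jax.tree.map(lambda p: p.device, model_params))
# Expect both bias and kernel on CudaDevice(id=1).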
I will let the Flax team comment on the default behavior in your case.