
bias and kernel params are put on different gpu devices

Open · YunxiTang opened this issue 1 year ago · 1 comment

System information

  • OS Platform and Distribution: Linux Ubuntu 20.04
  • Flax, jax, jaxlib versions: flax -> 0.6.11, jax -> 0.4.13, jaxlib -> 0.4.13+cuda11.cudnn86
  • Python version: Python 3.9.19
  • GPU/TPU model and memory: RTX 4090 with 24 GB
  • CUDA version: cuda 11.8

Problem you have encountered:

When I initialize a Flax model on a specific GPU device (for example, gpu 1), the bias and kernel params end up on different GPU devices.

What you expected to happen:

The bias and kernel params should both be placed on the requested GPU device (gpu 1).

Steps to reproduce:

  import jax
  import jax.numpy as jnp
  from jax import tree_util
  from flax import linen as nn

  device = jax.devices("gpu")[1]

  class MyModel(nn.Module):
      @nn.compact
      def __call__(self, x):
          x = nn.Conv(64, (3, 3), 1, name='conv1')(x)
          x = nn.relu(x)
          return x

  rng = jax.random.PRNGKey(0)
  rng = jax.device_put(rng, device)  # commit the rng to gpu 1
  dummy_input = jax.device_put(jnp.ones((5, 64, 64, 32)), device)  # commit the input to gpu 1

  model = MyModel()
  model_params = model.init({'params': rng}, dummy_input)
  # model_params = tree_util.tree_map(lambda x: jax.device_put(x, device), model_params)
  print(tree_util.tree_map(lambda x: x.device(), model_params))  # .device() is a method in jax 0.4.13

The output is

FrozenDict({
    params: {
        conv1: {
            bias: gpu(id=0),
            kernel: gpu(id=1),
        },
    },
})

Thanks!

YunxiTang · Aug 06 '24 06:08

Hi @YunxiTang, I am able to reproduce this issue.
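
What I believe is happening (my assumption, the Flax team can correct me): JAX runs an op on the device its committed inputs live on, while an op with no committed inputs lands on the default device (gpu 0). The kernel initializer consumes the rng you committed to gpu 1, but the default bias initializer is plain zeros and ignores the rng, so it has no committed inputs. A minimal sketch outside Flax that reproduces the same split:

  import jax
  import jax.numpy as jnp

  device = jax.devices("gpu")[1]
  rng = jax.device_put(jax.random.PRNGKey(0), device)  # committed to gpu 1

  kernel_like = jax.random.normal(rng, (3, 3, 32, 64))  # inherits gpu 1 from the committed rng
  bias_like = jnp.zeros((64,))                          # no committed inputs -> default gpu 0

  print(kernel_like.device, bias_like.device)  # on jax 0.4.13, use .device() instead of .device
  # cuda:1 cuda:0  (exact repr depends on the jax version)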

In practice, I have seen Flax models initialized on CPU and then migrated/replicated to the target devices. Two examples:

  1. Migrating params to the target GPU after initialization:
    # Optional: Init on `cpu`.
    model_params = jax.jit(model.init, backend="cpu")({'params': rng}, dummy_input)
    model_params = jax.device_put(model_params, device)
    jax.tree.map(lambda p: p.device, model_params)
    # {'params': {'conv1': {'bias': CudaDevice(id=1), 'kernel': CudaDevice(id=1)}}}
    
  2. Using a jax.default_device scope:
    with jax.default_device(device):
        model_params = model.init({'params': rng}, dummy_input)
        print(tree_util.tree_map(lambda x: (x.device), model_params))
        # {'params': {'conv1': {'bias': CudaDevice(id=1), 'kernel': CudaDevice(id=1)}}}
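
If that read is right, option 2 works because jax.default_device changes where uncommitted arrays (like the zeros-initialized bias) are placed, so both params land on gpu 1. Option 1 sidesteps the question by initializing everything on CPU and moving the whole tree with a single device_put, which is also a common pattern for models too large to initialize directly on one GPU.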
    

I will let the Flax team comment on the default behavior in your case.

MasterSkepticista · Aug 30 '24 04:08