
[BUG] CUDAgraph policy changes the learning process

m-krastev opened this issue 8 months ago

Describe the bug

I have a single collector and I am training a simple path-finding policy with ClipPPO. The agent must find a path in 3D around a simple maze while maximizing its coverage of the maze. It receives as input small patches centered on its current position and outputs the parameters of a Beta distribution (see example). From that distribution we sample a value which is mapped to an action delta ($a_t \in \mathbb{R}^3$) that determines the next position.

To accelerate experience sampling I enabled CUDA graphs, but this seems to lead to unstable learning.

A specific pathology I notice is that the standard deviation of the actions determined by the CUDAgraph'd policy tends to zero.
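
(For reference, for a Beta($\alpha, \beta$) rescaled to $[\text{min}, \text{max}]$, the standard deviation is $(\text{max} - \text{min})\sqrt{\alpha\beta / \left((\alpha + \beta)^2 (\alpha + \beta + 1)\right)}$; this is the quantity I am referring to.)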

To Reproduce

I can't provide my full code, but the only difference between the two runs is setting cudagraph_policy=True in the collector. It may be relevant that the policy module is torch.compile'd BEFORE it is passed to the collector.

Policy (CNN extraction head + MLP, Beta distribution):

from typing import Union

import torch
from torch.distributions import Beta, Independent


class IndependentBeta(Independent):
    def __init__(
        self,
        alpha: torch.Tensor,
        beta: torch.Tensor,
        min: Union[float, torch.Tensor] = 0.0,
        max: Union[float, torch.Tensor] = 1.0,
        event_dims: int = 1,
    ):
        self.min = torch.as_tensor(min, device=alpha.device).broadcast_to(alpha.shape)
        self.max = torch.as_tensor(max, device=alpha.device).broadcast_to(alpha.shape)
        self.scale = self.max - self.min
        self.eps = torch.finfo(alpha.dtype).eps
        base_dist = Beta(alpha, beta)
        super().__init__(base_dist, event_dims)

    def sample(self, sample_shape: torch.Size = torch.Size()):
        return super().sample(sample_shape) * self.scale + self.min

    def rsample(self, sample_shape: torch.Size = torch.Size()):
        return super().rsample(sample_shape) * self.scale + self.min

    def log_prob(self, value: torch.Tensor):
        return super().log_prob(((value - self.min) / self.scale).clamp(self.eps, 1.0 - self.eps))

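As a quick sanity check of the distribution class above (the parameter values here are hypothetical, not actual network outputs):

alpha = torch.full((4, 3), 2.0)
beta = torch.full((4, 3), 5.0)
dist = IndependentBeta(alpha, beta, min=-1.0, max=1.0)
action = dist.sample()        # shape (4, 3), values in [-1, 1]
logp = dist.log_prob(action)  # shape (4,); event_dims=1 sums over the last (action) dim

The actor module wrapping this distribution:
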
policy_module = ProbabilisticActor(
        module=TensorDictSequential(
            actor_cnn_module,  # Outputs TD with "alpha, beta"
        ),
        spec=action_spec,
        in_keys=["alpha", "beta"],
        out_keys=["action"],
        distribution_class=IndependentBeta,
        return_log_prob=True,
        default_interaction_type=InteractionType.RANDOM,
).to(device)

policy_module.compile(fullgraph=True, dynamic=False)
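
One way to narrow this down (a sketch, not something I have run yet) would be to wrap the compiled policy in a CUDA graph manually, outside the collector. As far as I understand, cudagraph_policy=True wraps the policy in tensordict.nn.CudaGraphModule, so something like this should reveal whether the graphed policy returns the same alpha/beta as the plain compiled one:

from tensordict.nn import CudaGraphModule

# assumption: roughly what the collector does internally when cudagraph_policy=True
graphed_policy = CudaGraphModule(policy_module)

td = env_maker().reset().to(device)   # a single static-shape input on the CUDA device
for _ in range(5):
    graphed_policy(td.clone())        # a few calls so the module is past warmup and the graph is captured

out_plain = policy_module(td.clone())
out_graphed = graphed_policy(td.clone())
# alpha/beta are deterministic given the input, so they should match closely
print((out_plain["alpha"] - out_graphed["alpha"]).abs().max())
print((out_plain["beta"] - out_graphed["beta"]).abs().max())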

Collector:

    collector = SyncDataCollector(
        create_env_fn=env_maker,  # Function to create environments
        policy=policy_module,
        # Total frames (steps) to collect in training
        total_frames=total_timesteps - collected_frames,
        frames_per_batch=config.frames_per_batch,
        # No initial random exploration phase needed if policy handles exploration
        init_random_frames=-1,
        split_trajs=False,  # Process rollouts as single batch
        device=device,  # Device for collector ops (usually same as models/env)
        # Device where data is stored (can be CPU if memory is tight)
        storing_device=device,
        max_frames_per_traj=config.max_episode_steps, 
        # num_threads=8
        # cudagraph_policy=True
    )

Training code (AMP+BF16):

    for i, batch_data in enumerate(collector, start=collected_frames):
        current_frames = batch_data.numel()  # Number of steps collected in this batch
        collected_frames += current_frames

        for _ in range(config.update_epochs):
            batch_data = batch_data.reshape(-1)

            with (
                torch.no_grad(),
                torch.autocast(device.type, amp_dtype, enabled=config.amp),
            ):
                adv_module(batch_data)

            for j in range(0, config.frames_per_batch, batch_size):
                minibatch = batch_data[j : j + batch_size]
                with torch.autocast(device.type, amp_dtype, enabled=config.amp):
                    loss_dict = loss_module(minibatch)

                    actor_loss = loss_dict["loss_objective"] + loss_dict["loss_entropy"]
                    critic_loss = loss_dict["loss_critic"]

                optimizer.zero_grad()
                scaler.scale(actor_loss).backward()
                scaler.scale(critic_loss).backward()
                scaler.unscale_(optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    loss_module.parameters(), config.max_grad_norm
                )
                scaler.step(optimizer)
                scaler.update()

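To quantify the std collapse mentioned above, something like the following could be logged once per collected batch (inside the outer collector loop), using the Beta std formula given earlier; the "alpha" and "beta" keys should be present in the collected batch since the policy writes them:

# diagnostic sketch: mean action std implied by the predicted Beta parameters
with torch.no_grad():
    a, b = batch_data["alpha"], batch_data["beta"]
    beta_std = ((a * b) / ((a + b) ** 2 * (a + b + 1))).sqrt()
    # multiply by (max - min) of the action range to get the std of the scaled action
    print(f"mean action std: {beta_std.mean().item():.4f}")
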
Expected behavior

I would expect the loss/training curves to be (roughly) the same with and without cudagraph_policy=True.

Screenshots

The environment is a custom environment. Red is cudagraph_policy=False, gray is cudagraph_policy=True.

[training curves, page 1]

[training curves, page 2]

System info

  • torchrl==0.8.0 (pip, uv)
  • Python 3.12.3
  • Linux


Reason and Possible fixes

No real idea why that might be.

Checklist

  • [x] I have checked that there is no similar issue in the repo (required)
  • [x] I have read the documentation (required)
  • [x] I have provided a minimal working example to reproduce the bug (required)

m-krastev, Jun 15 '25