
Bug in the Replay Buffer: end of episodes is not correctly handled

Open · theovincent opened this issue 11 months ago · 1 comment

Hi,

Thank you for your great work. It is really cool to open source such an amazing code base!

TL;DR

@yogesh1q2w and I noticed that the last transitions of a trajectory are not handled properly: multiple ReplayElements with a terminal flag are stored even though only one terminal transition is given to the accumulator.

This is problematic because the additional terminal states do not correspond to states that can be observed in the environment, which matters since we use function approximation.
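
To make this concrete, here is a minimal sketch (plain NumPy, not Dopamine code) of the only frame stacks an agent can observe with stack_size=4 when zero padding is applied at the start of the episode:

import numpy as np

# Illustration only: the valid stacks are sliding windows over the episode,
# zero-padded at the start of the episode and nowhere else.
observations = np.arange(8, dtype=float)
padded = np.concatenate([np.zeros(3), observations])
valid_stacks = [padded[i:i + 4] for i in range(len(observations))]
print(valid_stacks[0])   # [0. 0. 0. 0.]
print(valid_stacks[-1])  # [4. 5. 6. 7.]
# Stacks such as [0., 4., 5., 6.] or [0., 0., 5., 6.] are not in this list,
# yet they end up in the replay buffer (see the reproduction below).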

How to reproduce?

After forking the repo and running

python3.11 -m venv env_cpu
source env_cpu/bin/activate
pip install --upgrade pip setuptools wheel
pip install -e .

I ran

import numpy as np
from dopamine.jax.replay_memory import accumulator, samplers, replay_buffer, elements

transition_accumulator = accumulator.TransitionAccumulator(stack_size=4, update_horizon=1, gamma=0.99)
sampling_distribution = samplers.UniformSamplingDistribution(seed=1)
rb = replay_buffer.ReplayBuffer(
    transition_accumulator=transition_accumulator,
    sampling_distribution=sampling_distribution,
    batch_size=1,
    max_capacity=50,
    compress=False,
)

# Add 8 transitions; only the last one (i == 7) is terminal.
for i in range(8):
    rb.add(elements.TransitionElement(i * np.ones(1), i, i, i == 7, False))

print(rb._memory)
OrderedDict([(0,
              ReplayElement(state=array([[0., 0., 0., 0.]]), action=0, reward=0.0, next_state=array([[0., 0., 0., 1.]]), is_terminal=False, episode_end=False)),
             (1,
              ReplayElement(state=array([[0., 0., 0., 1.]]), action=1, reward=1.0, next_state=array([[0., 0., 1., 2.]]), is_terminal=False, episode_end=False)),
             (2,
              ReplayElement(state=array([[0., 0., 1., 2.]]), action=2, reward=2.0, next_state=array([[0., 1., 2., 3.]]), is_terminal=False, episode_end=False)),
             (3,
              ReplayElement(state=array([[0., 1., 2., 3.]]), action=3, reward=3.0, next_state=array([[1., 2., 3., 4.]]), is_terminal=False, episode_end=False)),
             (4,
              ReplayElement(state=array([[1., 2., 3., 4.]]), action=4, reward=4.0, next_state=array([[2., 3., 4., 5.]]), is_terminal=False, episode_end=False)),
             (5,
              ReplayElement(state=array([[2., 3., 4., 5.]]), action=5, reward=5.0, next_state=array([[3., 4., 5., 6.]]), is_terminal=False, episode_end=False)),
             (6,
              ReplayElement(state=array([[3., 4., 5., 6.]]), action=6, reward=6.0, next_state=array([[4., 5., 6., 7.]]), is_terminal=True, episode_end=True)),
             (7,
              ReplayElement(state=array([[0., 4., 5., 6.]]), action=6, reward=6.0, next_state=array([[4., 5., 6., 7.]]), is_terminal=True, episode_end=True)),
             (8,
              ReplayElement(state=array([[0., 0., 5., 6.]]), action=6, reward=6.0, next_state=array([[0., 5., 6., 7.]]), is_terminal=True, episode_end=True)),
             (9,
              ReplayElement(state=array([[0., 0., 0., 6.]]), action=6, reward=6.0, next_state=array([[0., 0., 6., 7.]]), is_terminal=True, episode_end=True))])

The last 3 ReplayElements (keys 7, 8 and 9) are incorrect: their states cannot be observed from the environment, so they should never have been added.
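
A quick way to quantify the problem, continuing the same Python session (this assumes rb._memory maps indices to ReplayElement entries with an is_terminal field, as the output above shows):

# Only one terminal TransitionElement was added, so we would expect
# exactly one terminal ReplayElement in memory.
num_terminal = sum(e.is_terminal for e in rb._memory.values())
print(num_terminal)  # prints 4 with the current code, 1 after the fix below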

How to fix the bug?

Replacing the following lines https://github.com/google/dopamine/blob/bec5f4e108b0572e58fc1af73136e978237c8463/dopamine/jax/replay_memory/accumulator.py#L74-L82 with

    # Check if we have a valid transition, i.e. we either
    #   1) have accumulated more transitions than the update horizon and the
    #      last element is not terminal, or
    #   2) the last element is terminal and we have accumulated enough
    #      transitions to build the frame stack.
    if not (
        (trajectory_len > self._update_horizon and not last_transition.is_terminal)
        or (trajectory_len > self._stack_size and last_transition.is_terminal)
    ):
        return None

solves the issue. Indeed, running the same code again now yields:

OrderedDict([(0,
              ReplayElement(state=array([[0., 0., 0., 0.]]), action=0, reward=0.0, next_state=array([[0., 0., 0., 1.]]), is_terminal=False, episode_end=False)),
             (1,
              ReplayElement(state=array([[0., 0., 0., 1.]]), action=1, reward=1.0, next_state=array([[0., 0., 1., 2.]]), is_terminal=False, episode_end=False)),
             (2,
              ReplayElement(state=array([[0., 0., 1., 2.]]), action=2, reward=2.0, next_state=array([[0., 1., 2., 3.]]), is_terminal=False, episode_end=False)),
             (3,
              ReplayElement(state=array([[0., 1., 2., 3.]]), action=3, reward=3.0, next_state=array([[1., 2., 3., 4.]]), is_terminal=False, episode_end=False)),
             (4,
              ReplayElement(state=array([[1., 2., 3., 4.]]), action=4, reward=4.0, next_state=array([[2., 3., 4., 5.]]), is_terminal=False, episode_end=False)),
             (5,
              ReplayElement(state=array([[2., 3., 4., 5.]]), action=5, reward=5.0, next_state=array([[3., 4., 5., 6.]]), is_terminal=False, episode_end=False)),
             (6,
              ReplayElement(state=array([[3., 4., 5., 6.]]), action=6, reward=6.0, next_state=array([[4., 5., 6., 7.]]), is_terminal=True, episode_end=True))])

The last 3 ReplayElements have been filtered out 🎉
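
A regression test along these lines could accompany the fix (a sketch only; the expected counts match the example above):

import numpy as np
from dopamine.jax.replay_memory import accumulator, samplers, replay_buffer, elements


def test_single_terminal_element_per_episode():
    """Only one terminal ReplayElement should be stored per terminal transition."""
    rb = replay_buffer.ReplayBuffer(
        transition_accumulator=accumulator.TransitionAccumulator(
            stack_size=4, update_horizon=1, gamma=0.99
        ),
        sampling_distribution=samplers.UniformSamplingDistribution(seed=1),
        batch_size=1,
        max_capacity=50,
        compress=False,
    )
    for i in range(8):
        rb.add(elements.TransitionElement(i * np.ones(1), i, i, i == 7, False))

    assert sum(e.is_terminal for e in rb._memory.values()) == 1
    assert len(rb._memory) == 7  # keys 0 to 6, as in the output above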

theovincent · Feb 18 '25 14:02

I made the change on this fork: https://github.com/theovincent/dopamine

Let me know if you would like me to make a PR 🙂

theovincent · Feb 18 '25 14:02