Bug in the Replay Buffer: the end of an episode is not handled correctly
Hi,
Thank you for your great work. It is really cool to open source such an amazing code base!
TL;DR
@yogesh1q2w and I noticed that the last transitions of a trajectory are not handled properly: when a single terminal transition is given to the accumulator, multiple ReplayElements with a terminal flag are stored.
This is problematic because the additional terminal states do not correspond to states that can be observed in the environment. Since we use function approximation, training on these unobservable states can also influence the predictions for the states that do occur.
How to reproduce?
After forking the repo and running
python3.11 -m venv env_cpu  # Python 3.11.5
source env_cpu/bin/activate
pip install --upgrade pip setuptools wheel
pip install -e .
I ran
import numpy as np
from dopamine.jax.replay_memory import accumulator, samplers, replay_buffer, elements
# Accumulator that stacks 4 frames and builds 1-step transitions.
transition_accumulator = accumulator.TransitionAccumulator(stack_size=4, update_horizon=1, gamma=0.99)
sampling_distribution = samplers.UniformSamplingDistribution(seed=1)
rb = replay_buffer.ReplayBuffer(
    transition_accumulator=transition_accumulator,
    sampling_distribution=sampling_distribution,
    batch_size=1,
    max_capacity=50,
    compress=False,
)
# One episode of 8 steps: observation, action and reward all equal i,
# and only the last step (i == 7) is terminal.
for i in range(8):
    rb.add(elements.TransitionElement(i * np.ones(1), i, i, i == 7, False))
print(rb._memory)
OrderedDict([(0,
ReplayElement(state=array([[0., 0., 0., 0.]]), action=0, reward=0.0, next_state=array([[0., 0., 0., 1.]]), is_terminal=False, episode_end=False)),
(1,
ReplayElement(state=array([[0., 0., 0., 1.]]), action=1, reward=1.0, next_state=array([[0., 0., 1., 2.]]), is_terminal=False, episode_end=False)),
(2,
ReplayElement(state=array([[0., 0., 1., 2.]]), action=2, reward=2.0, next_state=array([[0., 1., 2., 3.]]), is_terminal=False, episode_end=False)),
(3,
ReplayElement(state=array([[0., 1., 2., 3.]]), action=3, reward=3.0, next_state=array([[1., 2., 3., 4.]]), is_terminal=False, episode_end=False)),
(4,
ReplayElement(state=array([[1., 2., 3., 4.]]), action=4, reward=4.0, next_state=array([[2., 3., 4., 5.]]), is_terminal=False, episode_end=False)),
(5,
ReplayElement(state=array([[2., 3., 4., 5.]]), action=5, reward=5.0, next_state=array([[3., 4., 5., 6.]]), is_terminal=False, episode_end=False)),
(6,
ReplayElement(state=array([[3., 4., 5., 6.]]), action=6, reward=6.0, next_state=array([[4., 5., 6., 7.]]), is_terminal=True, episode_end=True)),
(7,
ReplayElement(state=array([[0., 4., 5., 6.]]), action=6, reward=6.0, next_state=array([[4., 5., 6., 7.]]), is_terminal=True, episode_end=True)),
(8,
ReplayElement(state=array([[0., 0., 5., 6.]]), action=6, reward=6.0, next_state=array([[0., 5., 6., 7.]]), is_terminal=True, episode_end=True)),
(9,
ReplayElement(state=array([[0., 0., 0., 6.]]), action=6, reward=6.0, next_state=array([[0., 0., 6., 7.]]), is_terminal=True, episode_end=True))])
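The last three ReplayElements (keys 7, 8 and 9) are incorrect and should not have been added. Each one repeats the final terminal transition with a zero-padded state that the environment never produced.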
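The sketch below is a simplified, pure-Python model of the flush that happens when a terminal transition arrives. It shows where the three extra elements come from. It is not Dopamine's code, and it rests on several assumptions:

- `Transition`, `Element`, `stack_frames`, `run_episode` and `original_is_valid` are hypothetical names.
- Only `update_horizon=1` is modelled, and rewards are left out.
- `original_is_valid` is an assumption about the shape of the validity check in `accumulator.py`: any terminal trajectory longer than one transition is accepted.

Under these assumptions the model emits the same ten elements as the output above, including the zero-padded states of keys 7 to 9.

```python
from collections import deque
from typing import Callable, NamedTuple

# Hypothetical, simplified model of a transition accumulator (update_horizon=1 only).
# It mimics the behaviour observed above; it is NOT Dopamine's implementation.


class Transition(NamedTuple):
    observation: float
    action: int
    is_terminal: bool


class Element(NamedTuple):
    state: tuple
    action: int
    next_state: tuple
    is_terminal: bool


Predicate = Callable[[int, bool, int, int], bool]


def original_is_valid(n: int, terminal: bool, stack_size: int, update_horizon: int) -> bool:
    # Assumed pre-fix check: any terminal trajectory with more than one transition is valid.
    return n > update_horizon or (terminal and n > 1)


def stack_frames(trajectory: deque, end: int, stack_size: int) -> tuple:
    # Stack the `stack_size` observations at indices [end - stack_size, end),
    # zero-padding on the left when the trajectory is too short.
    return tuple(trajectory[i].observation if i >= 0 else 0.0
                 for i in range(end - stack_size, end))


def run_episode(num_steps: int, stack_size: int, is_valid: Predicate) -> list:
    """Feed one episode (observation = action = step index, last step terminal)."""
    update_horizon = 1
    trajectory: deque = deque(maxlen=stack_size + update_horizon)
    emitted = []

    def make_element():
        n = len(trajectory)
        if not is_valid(n, trajectory[-1].is_terminal, stack_size, update_horizon):
            return None
        t = n - update_horizon - 1  # index of the transition whose action is stored
        return Element(stack_frames(trajectory, t + 1, stack_size), trajectory[t].action,
                       stack_frames(trajectory, n, stack_size), trajectory[-1].is_terminal)

    for i in range(num_steps):
        trajectory.append(Transition(float(i), i, i == num_steps - 1))
        if trajectory[-1].is_terminal:
            # Flush: emit, drop the oldest transition, and repeat while still "valid".
            while (element := make_element()) is not None:
                emitted.append(element)
                trajectory.popleft()
            trajectory.clear()
        elif (element := make_element()) is not None:
            emitted.append(element)
    return emitted


emitted = run_episode(num_steps=8, stack_size=4, is_valid=original_is_valid)
for element in emitted[6:]:
    print(element.state, element.next_state, element.is_terminal)
```

In this model, the extra elements appear because each `popleft` during the flush shifts the window. The final transition (action 6) is then replayed with a stack that has one more zero frame each time.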
How to fix the bug?
Replacing the following lines https://github.com/google/dopamine/blob/bec5f4e108b0572e58fc1af73136e978237c8463/dopamine/jax/replay_memory/accumulator.py#L74-L82 with
# Check if we have a valid transition, i.e. we either
#   1) have accumulated more transitions than the update horizon and the
#      last element is not terminal
#   2) have a terminal last element and have accumulated more transitions
#      than the stack size
if not (
    (trajectory_len > self._update_horizon and not last_transition.is_terminal)
    or (trajectory_len > self._stack_size and last_transition.is_terminal)
):
    return None
solves the issue. Indeed, by running the same code again, we obtain:
OrderedDict([(0,
ReplayElement(state=array([[0., 0., 0., 0.]]), action=0, reward=0.0, next_state=array([[0., 0., 0., 1.]]), is_terminal=False, episode_end=False)),
(1,
ReplayElement(state=array([[0., 0., 0., 1.]]), action=1, reward=1.0, next_state=array([[0., 0., 1., 2.]]), is_terminal=False, episode_end=False)),
(2,
ReplayElement(state=array([[0., 0., 1., 2.]]), action=2, reward=2.0, next_state=array([[0., 1., 2., 3.]]), is_terminal=False, episode_end=False)),
(3,
ReplayElement(state=array([[0., 1., 2., 3.]]), action=3, reward=3.0, next_state=array([[1., 2., 3., 4.]]), is_terminal=False, episode_end=False)),
(4,
ReplayElement(state=array([[1., 2., 3., 4.]]), action=4, reward=4.0, next_state=array([[2., 3., 4., 5.]]), is_terminal=False, episode_end=False)),
(5,
ReplayElement(state=array([[2., 3., 4., 5.]]), action=5, reward=5.0, next_state=array([[3., 4., 5., 6.]]), is_terminal=False, episode_end=False)),
(6,
ReplayElement(state=array([[3., 4., 5., 6.]]), action=6, reward=6.0, next_state=array([[4., 5., 6., 7.]]), is_terminal=True, episode_end=True))])
The spurious terminal ReplayElements have been filtered out 🎉
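The two conditions can be compared directly as plain functions. The sketch uses the setting of the reproduction script (`stack_size=4`, `update_horizon=1`) and evaluates both conditions while a terminal trajectory is flushed from 5 transitions down to 1. The function names are hypothetical. `original_is_valid` is an assumption about the pre-fix check, inferred from the first output rather than copied from the repository. `fixed_is_valid` is the condition proposed above.

```python
def original_is_valid(n: int, terminal: bool, stack_size: int, update_horizon: int) -> bool:
    # Assumed pre-fix check (inferred, not copied from Dopamine).
    return n > update_horizon or (terminal and n > 1)


def fixed_is_valid(n: int, terminal: bool, stack_size: int, update_horizon: int) -> bool:
    # The condition proposed in this issue.
    return (n > update_horizon and not terminal) or (n > stack_size and terminal)


STACK_SIZE, UPDATE_HORIZON = 4, 1

# Trajectory lengths visited while a terminal trajectory is flushed
# (the oldest transition is dropped after each emitted element).
flush_lengths = range(STACK_SIZE + UPDATE_HORIZON, 0, -1)
accepted_before = [n for n in flush_lengths if original_is_valid(n, True, STACK_SIZE, UPDATE_HORIZON)]
accepted_after = [n for n in flush_lengths if fixed_is_valid(n, True, STACK_SIZE, UPDATE_HORIZON)]
print(accepted_before)  # [5, 4, 3, 2]
print(accepted_after)   # [5]
```

For non-terminal transitions, both functions reduce to `n > update_horizon`, so only the flush changes. The same functions also point to a case worth checking: with the fixed condition, a terminal trajectory with at most `stack_size` transitions is never accepted (`fixed_is_valid(3, True, 4, 1)` is `False`). If the trajectory starts empty at each episode, as the zero-padded first element above suggests, an episode with at most `stack_size` transitions would end without storing any terminal element.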
I made the change on this fork: https://github.com/theovincent/dopamine
Let me know if you would like me to open a PR 🙂