[BUG] Inconsistent behavior in RenameTransform
Describe the bug
I think there are a couple of bugs in the RenameTransform implementation. Let's go through them one by one:
- Specs are only modified if the key is in in_keys. Keys specified in in_keys_inv don't lead to a change in the specs.
- Only the full_action_spec is modified in the transform, but not the action_spec.
- The behavior of the _inv keys is different to the documented behavior.
To Reproduce
Take the following simple snippet as a basis. The goal is to map all inputs and outputs of an environment to another key.
from torchrl.envs import RenameTransform, TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from tensordict import TensorDict
import torch
base_env = GymEnv("CartPole-v1")
transformed_env = TransformedEnv(
    base_env,
    RenameTransform(
        in_keys=[
            "observation",
            "terminated",
            "truncated",
            "reward",
            "done",
        ],
        out_keys=[
            ("stuff", "observation"),
            ("stuff", "terminated"),
            ("stuff", "truncated"),
            ("stuff", "reward"),
            ("stuff", "done"),
        ],
        in_keys_inv=[
            ("stuff", "action"),
        ],
        out_keys_inv=[
            "action",
        ],
    ),
)
At first, the transformation seems to have worked:
>>> base_env.reset()
TensorDict(
fields={
done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
observation: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
>>> transformed_env.reset()
TensorDict(
fields={
stuff: TensorDict(
fields={
done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
observation: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
Specs are only modified if the key is in in_keys. Keys specified in in_keys_inv don't lead to a change in the specs.
But looking at the specs we see that the input spec was not modified:
>>> transformed_env.input_spec
CompositeSpec(
full_state_spec: CompositeSpec(
,
device=None,
shape=torch.Size([])),
full_action_spec: CompositeSpec(
action: OneHotDiscreteTensorSpec(
shape=torch.Size([2]),
space=DiscreteBox(n=2),
device=cpu,
dtype=torch.int64,
domain=discrete),
device=None,
shape=torch.Size([])),
device=None,
shape=torch.Size([]))
Only the full_action_spec is modified in the transform, but not the action_spec.
By adding the action key to in_keys (and ("stuff", "action") to the out_keys), the full_action_spec is correctly transformed, but not the action_spec:
>>> transformed_env.full_action_spec
CompositeSpec(
stuff: CompositeSpec(
action: OneHotDiscreteTensorSpec(
shape=torch.Size([2]),
space=DiscreteBox(n=2),
device=cpu,
dtype=torch.int64,
domain=discrete),
device=None,
shape=torch.Size([])),
device=None,
shape=torch.Size([]))
>>> transformed_env.action_spec
OneHotDiscreteTensorSpec(
shape=torch.Size([2]),
space=DiscreteBox(n=2),
device=cpu,
dtype=torch.int64,
domain=discrete)
However, doing this (adding the action key to in_keys) will cause the environment to crash when step is called, as the action does not exist in the outputs:
>>> transformed_env.step(TensorDict({"stuff": {"action": torch.zeros(2)}}))
KeyError: 'key "action" not found in TensorDict with keys [\'done\', \'observation\', \'reward\', \'terminated\', \'truncated\']'
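To make the failure above easier to follow, here is a toy model of a strict forward rename on a plain dict. It is not torchrl code: `rename_forward` is a hypothetical helper, dotted strings stand in for nested keys, and the strictness is only inferred from the error message above. The forward pass runs on what the base env returns from step, and that output never contains "action".

```python
def rename_forward(data, in_keys, out_keys):
    """Toy forward rename: move each in_key to its out_key, failing on missing keys."""
    renamed = dict(data)
    for src, dst in zip(in_keys, out_keys):
        if src not in renamed:
            raise KeyError(f'key "{src}" not found in data with keys {sorted(data)}')
        renamed[dst] = renamed.pop(src)
    return renamed


# Keys returned by the base env's step, as listed in the error message above.
step_output = {
    "done": False,
    "observation": [0.0, 0.0, 0.0, 0.0],
    "reward": 1.0,
    "terminated": False,
    "truncated": False,
}

# "action" is an input to step, so it is absent from the step output.
in_keys = ["observation", "reward", "action"]
out_keys = ["stuff.observation", "stuff.reward", "stuff.action"]

try:
    rename_forward(step_output, in_keys, out_keys)
    failed_message = None
except KeyError as err:
    failed_message = err.args[0]
```

In this sketch the rename of "observation" and "reward" would succeed on its own; only the input-side key breaks the forward pass.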
The behavior of the _inv keys is different to the documented behavior.
The documentation states:
in_keys_inv: the entries to rename before passing the input tensordict to :meth:`EnvBase._step`.
out_keys_inv: the names of the renamed entries passed to :meth:`EnvBase._step`.
Thus, our transform should be:
in_keys_inv=[
    ("stuff", "action"),  # input to `transformed_env.step`
]
out_keys_inv=[
    "action",  # input to `EnvBase._step`
]
However, this is not the case:
transformed_env.step(TensorDict({"stuff": {"action": torch.zeros(2)}}))
*** KeyError: 'key "action" not found in TensorDict with keys [\'stuff\']'
By inverting the in_keys_inv and out_keys_inv, we get it to run:
>>> transformed_env.step(TensorDict({"stuff": {"action": torch.zeros(2)}}))
TensorDict(
fields={
next: TensorDict(
fields={
stuff: TensorDict(
fields={
done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
observation: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False),
stuff: TensorDict(
fields={
action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
Expected behavior
The following snippet should run without errors (I think this covers the bugs described above):
from torchrl.envs import RenameTransform, TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from tensordict import TensorDict
import torch
base_env = GymEnv("CartPole-v1")
transformed_env = TransformedEnv(
    base_env,
    RenameTransform(
        in_keys=[
            "observation",
            "terminated",
            "truncated",
            "reward",
            "done",
        ],
        out_keys=[
            ("stuff", "observation"),
            ("stuff", "terminated"),
            ("stuff", "truncated"),
            ("stuff", "reward"),
            ("stuff", "done"),
        ],
        in_keys_inv=[
            ("stuff", "action"),
        ],
        out_keys_inv=[
            "action",
        ],
    ),
)
# BUG: Specs are only modified if the key is in in_keys. Keys specified in in_keys_inv don't lead to a change in the specs.
assert "stuff" in transformed_env.full_action_spec
assert "action" in transformed_env.full_action_spec["stuff"]
# BUG: Only the full_action_spec is modified in the transform, but not the action_spec.
assert "stuff" in transformed_env.action_spec
assert "action" in transformed_env.action_spec["stuff"]
base_env.reset()
transformed_env.reset()
base_env.step(TensorDict({"action": torch.zeros(2)}))
# BUG: The behavior of the _inv keys is different to the documented behavior.
transformed_env.step(TensorDict({"stuff": {"action": torch.zeros(2)}}))
System info
Ubuntu 22.04, Python 3.10.14, torch 2.4.1, torchrl 0.5.0
Checklist
- [x] I have checked that there is no similar issue in the repo (required)
- [x] I have read the documentation (required)
- [x] I have provided a minimal working example to reproduce the bug (required)
I've created a PR with fixes for the issues described here: https://github.com/pytorch/rl/pull/2442
Just to be clear, here's how one should think about in/out keys:
I think the logic within _inv is right. I agree the phrasing of the docstrings isn't super clear, so here's how to understand it:
in_keys (sequence of NestedKey): the entries to rename
out_keys (sequence of NestedKey): the name of the entries after renaming.
in_keys_inv (sequence of NestedKey, optional): the entries to rename before
passing the input tensordict to :meth:`EnvBase._step`.
out_keys_inv (sequence of NestedKey, optional): the names of the renamed
entries passed to :meth:`EnvBase._step`.
out_keys_inv are the names of the renamed entries: we have renamed "action" to "other_action".
in_keys_inv are the names just before we pass the data to the base env, i.e. the names that the base env is expecting.
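For readers who prefer code, the reading above can be sketched on plain nested dicts. This is only a toy model under that reading, not the torchrl implementation, and the helpers (`pop_nested`, `set_nested`, `rename_inv`) are made-up names: on the way into the base env, the entry found under the out_keys_inv name is moved to the in_keys_inv name.

```python
import copy


def _as_tuple(key):
    # A nested key is either a plain string or a tuple of strings.
    return (key,) if isinstance(key, str) else tuple(key)


def pop_nested(data, key):
    """Remove and return the entry at a nested key, pruning emptied parents."""
    key = _as_tuple(key)
    parents = []
    node = data
    for part in key[:-1]:
        parents.append((node, part))
        node = node[part]
    value = node.pop(key[-1])
    for parent, part in reversed(parents):
        if parent[part]:
            break
        del parent[part]
    return value


def set_nested(data, key, value):
    """Write a value at a nested key, creating intermediate dicts as needed."""
    key = _as_tuple(key)
    node = data
    for part in key[:-1]:
        node = node.setdefault(part, {})
    node[key[-1]] = value


def rename_inv(data, in_keys_inv, out_keys_inv):
    """Toy inverse pass: outer names (out_keys_inv) become base-env names (in_keys_inv)."""
    result = copy.deepcopy(data)
    for base_name, outer_name in zip(in_keys_inv, out_keys_inv):
        set_nested(result, base_name, pop_nested(result, outer_name))
    return result


user_input = {"stuff": {"action": [1, 0]}}

# Under this reading, in_keys_inv holds the name the base env expects.
for_base_env = rename_inv(
    user_input,
    in_keys_inv=["action"],
    out_keys_inv=[("stuff", "action")],
)

# The arrangement from the bug report looks for a top-level "action" and fails.
try:
    rename_inv(user_input, in_keys_inv=[("stuff", "action")], out_keys_inv=["action"])
    report_arrangement_failed = False
except KeyError:
    report_arrangement_failed = True
```

In this model the swapped arrangement is the one that delivers "action" to the base env, which matches what the report observed on torchrl 0.5.0.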
Discussion continues in https://github.com/pytorch/rl/pull/2442#issuecomment-2361757878
The visualization really helps! 🙏 Was it extracted from the documentation, or should it be added?
Yeah let's add it! Should it go in the generic transform doc or in the doc string of rename?
A lot of transforms use in_keys, out_keys, in_keys_inv and out_keys_inv, so I would suggest putting it in the API reference for torchrl.envs.transforms.