GAE does not support an LSTM-based value network.
## Motivation
I got the following error when I used GAE with an LSTM-based value network:

```
RuntimeError: Batching rule not implemented for aten::lstm.input. We could not generate a fallback.
```
Here is the code I ran:

```python
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.objectives.value import GAE


class ValueNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=2,
            hidden_size=1,
            num_layers=1,
            bidirectional=False,
            batch_first=True,
        )

    def forward(self, i):
        output, (hidden_state, cell_state) = self.lstm(i)
        return hidden_state


def main():
    value_network = ValueNetwork()
    value_dict_module = TensorDictModule(
        value_network, in_keys=["observation"], out_keys=["value"]
    )
    gae = GAE(
        gamma=0.98,
        lmbda=0.95,
        value_network=value_dict_module,
    )
    gae.set_keys(
        advantage="advantage",
        value_target="value_target",
        value="value",
    )
    tensor_dict = TensorDict({
        "next": {
            "observation": torch.FloatTensor([
                [[8, 9], [10, 11]],
                [[12, 13], [14, 15]],
            ]),
            "reward": torch.FloatTensor([[1], [-1]]),
            "done": torch.BoolTensor([[1], [1]]),
            "terminated": torch.BoolTensor([[1], [1]]),
        },
        "observation": torch.FloatTensor([
            [[0, 1], [2, 3]],
            [[4, 5], [6, 7]],
        ]),
    }, batch_size=2)
    output_tensor_dict = gae(tensor_dict)
    print(f"output_tensor_dict: {output_tensor_dict}")
    advantage = output_tensor_dict["advantage"]
    print(f"advantage: {advantage}")


main()
```
The error was caused by this exact line:

```python
output_tensor_dict = gae(tensor_dict)
```
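For what it's worth, the message looks like a `vmap` error: GAE appears to evaluate the value network under `torch.vmap`, and PyTorch has no batching rule for the native LSTM op. Here is a minimal sketch that seems to reproduce the same failure with plain PyTorch (no TorchRL involved; the mechanism is my assumption):

```python
import torch
from torch import nn

# Plain nn.LSTM, as in the value network above.
lstm = nn.LSTM(input_size=2, hidden_size=1, batch_first=True)

def value_fn(x):
    # Under vmap, x is a single sequence of shape (seq_len, input_size).
    output, (hidden_state, cell_state) = lstm(x)
    return hidden_state

batch = torch.randn(4, 3, 2)  # (batch, seq_len, input_size)
error = None
try:
    torch.vmap(value_fn)(batch)
except RuntimeError as e:
    error = e
print(error)  # the "Batching rule not implemented" complaint
```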
I also tried using unbatched input and realized that GAE does not support that either. For example, this is the unbatched input I tried:
```python
tensor_dict = TensorDict({
    "next": {
        "observation": torch.FloatTensor([[4, 5], [6, 7]]),
        "reward": torch.FloatTensor([1]),
        "done": torch.BoolTensor([1]),
        "terminated": torch.BoolTensor([1]),
    },
    "observation": torch.FloatTensor([[0, 1], [2, 3]]),
}, batch_size=[])
```
And I got this error from GAE:

```
RuntimeError: Expected input tensordict to have at least one dimensions, got tensordict.batch_size = torch.Size([])
```
Therefore, I concluded that GAE does not support an LSTM-based value network.
## Solution
GAE should support an LSTM-based value network.
## Alternatives
GAE should support an unbatched tensor dict as input.
## Additional context
I'm using `torchrl` version 0.5.0.
I found ticket #2372, which might be related to this issue, but I was not sure how to make my code work.
## Checklist
- [x] I have checked that there is no similar issue in the repo (required)
I see that you're using the Torch LSTM (`nn.LSTM`) in your snippet. Maybe try with TorchRL's version of it (`LSTMModule`)?
Also, as described in #2372, you'll need to use `python_based=True` in your `LSTMModule`.
@thomasbbrunner, I have the same problem you had in #2372. I am trying recurrent PPO and used `LSTMModule` with `python_based=True`, but I got the same error:

```
RuntimeError: vmap: It looks like you're attempting to use a Tensor in some data-dependent control flow.
```

It comes from this part of the rnn script:

```python
if self.recurrent_mode and is_init[..., 1:].any():
```

Basically, slicing `is_init` as `is_init[..., 1:]` gives the error. The code works with the `shifted` flag in GAE set to `True`, but the resulting PPO performance is bad. Do you have any insight into the problem?
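For what it's worth, the failure reproduces without TorchRL: a Python `if` on a tensor's value inside a `vmap`-ed function trips the same check. A minimal sketch of what I believe is happening:

```python
import torch

def check(is_init):
    # A Python `if` on a batched tensor's value: vmap cannot resolve
    # this, because each element of the batch could take a different
    # branch. Converting the .any() result to bool raises the error.
    if is_init[..., 1:].any():
        return is_init.float() + 1.0
    return is_init.float()

error = None
try:
    torch.vmap(check)(torch.zeros(2, 4, dtype=torch.bool))
except RuntimeError as e:
    error = e
print(error)  # the "data-dependent control flow" complaint
```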
Same issue here. @vmoens, any ideas? I can try to patch it if you point me in the right direction.
Closed by #2962 and #2941.