
GAE does not support an LSTM-based value network

levelrin opened this issue.

Motivation

I got the following error when I used GAE with an LSTM-based value network:

RuntimeError: Batching rule not implemented for aten::lstm.input. We could not generate a fallback.

Here is the code I ran:

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.objectives.value import GAE

class ValueNetwork(nn.Module):

    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=2,
            hidden_size=1,
            num_layers=1,
            bidirectional=False,
            batch_first=True
        )

    def forward(self, i):
        output, (hidden_state, cell_state) = self.lstm(i)
        return hidden_state

def main():
    value_network = ValueNetwork()
    value_dict_module = TensorDictModule(value_network, in_keys=["observation"], out_keys=["value"])
    gae = GAE(
        gamma=0.98,
        lmbda=0.95,
        value_network=value_dict_module
    )
    gae.set_keys(
        advantage="advantage",
        value_target="value_target",
        value="value",
    )
    tensor_dict = TensorDict({
        "next": {
            "observation": torch.FloatTensor([
                [[8, 9], [10, 11]],
                [[12, 13], [14, 15]]
            ]),
            "reward": torch.FloatTensor([[1], [-1]]),
            "done": torch.BoolTensor([[1], [1]]),
            "terminated": torch.BoolTensor([[1], [1]])
        },
        "observation": torch.FloatTensor([
            [[0, 1], [2, 3]],
            [[4, 5], [6, 7]]
        ])
    }, batch_size=2)

    output_tensor_dict = gae(tensor_dict)
    print(f"output_tensor_dict: {output_tensor_dict}")
    advantage = output_tensor_dict["advantage"]
    print(f"advantage: {advantage}")

main()

The error is raised by this line:

output_tensor_dict = gae(tensor_dict)

I then tried an unbatched input and found that GAE does not accept unbatched input either. This is the unbatched input I tried:

tensor_dict = TensorDict({
    "next": {
        "observation": torch.FloatTensor([[4, 5], [6, 7]]),
        "reward": torch.FloatTensor([1]),
        "done": torch.BoolTensor([1]),
        "terminated": torch.BoolTensor([1])
    },
    "observation": torch.FloatTensor([[0, 1], [2, 3]])
}, batch_size=[])

And I got this error from GAE:

RuntimeError: Expected input tensordict to have at least one dimensions, got tensordict.batch_size = torch.Size([])

Therefore, I concluded that GAE does not support an LSTM-based value network.
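For context, GAE itself is only a backward recursion over the time dimension: delta_t = r_t + gamma * (1 - d_t) * V(s_{t+1}) - V(s_t), and A_t = delta_t + gamma * lmbda * (1 - d_t) * A_{t+1}. The sketch below is plain Python, not TorchRL's implementation. It assumes the per-step values V(s_t) and V(s_{t+1}) have already been computed (for example, by running the LSTM value network in an ordinary forward pass, without vmap) and it treats every done flag as a true termination, ignoring the done/terminated distinction. The function name gae_advantages is hypothetical.

```python
def gae_advantages(rewards, values, next_values, dones, gamma=0.98, lmbda=0.95):
    """Compute GAE advantages and value targets for one trajectory.

    All inputs are lists indexed by time step. Assumption: a done flag means
    the episode terminated, so the next value is not bootstrapped and the
    lambda-trace is cut at that step.
    """
    advantages = [0.0] * len(rewards)
    running = 0.0  # A_{t+1}, accumulated from the end of the trajectory
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        # One-step TD error
        delta = rewards[t] + gamma * not_done * next_values[t] - values[t]
        # Exponentially weighted sum of future TD errors
        running = delta + gamma * lmbda * not_done * running
        advantages[t] = running
    # Value targets are advantages plus the baseline values
    value_targets = [a + v for a, v in zip(advantages, values)]
    return advantages, value_targets


if __name__ == "__main__":
    adv, targets = gae_advantages(
        rewards=[1.0, -1.0],
        values=[0.5, 0.2],
        next_values=[0.2, 0.0],
        dones=[False, True],
    )
    print(adv, targets)
```

Because nothing here is vmapped, the values can come from any module, recurrent or not; the cost is that the bookkeeping GAE normally handles has to be redone by hand.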

Solution

GAE should support an LSTM-based value network.

Alternatives

GAE should accept an unbatched TensorDict as input.

Additional context

I'm using TorchRL version 0.5.0.

I found issue #2372, which might be related, but I could not work out how to make my code work.

Checklist

  • [x] I have checked that there is no similar issue in the repo (required)

levelrin, Sep 19 '24 11:09

I see that you're using the PyTorch LSTM in your snippet. Maybe try TorchRL's own version (LSTMModule) instead?

Also, as described in #2372, you'll need to use python_based=True in your LSTMModule.

thomasbbrunner, Sep 27 '24 08:09

@thomasbbrunner, I have the same problem you had in #2372. I am trying recurrent PPO using LSTMModule with python_based=True, and I get the same error:

"RuntimeError: vmap: It looks like you're attempting to use a Tensor in some data-dependent control flow."

It is raised in this part of the RNN code:

if self.recurrent_mode and is_init[..., 1:].any():

Basically, the if check on is_init[..., 1:].any() is what raises the error. The code runs with the shifted flag in GAE set to True, but PPO then performs poorly. Do you have any insight into the problem?
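The error message points at Python control flow rather than at the slicing itself: under torch.func.vmap each argument is a batched tensor, so an if statement that needs a single True/False value cannot be evaluated per sample. A minimal sketch of the difference, using hypothetical function names and a made-up reset rule (zero the sequence if any step after the first is an initial step); this is not TorchRL's actual rnn code:

```python
import torch
from torch.func import vmap


def reset_with_if(x, is_init):
    # Python branch on a tensor value: under vmap this calls bool() on a
    # batched tensor and raises the "data-dependent control flow" error.
    if is_init[..., 1:].any():
        return torch.zeros_like(x)
    return x


def reset_masked(x, is_init):
    # Branch-free version: the decision stays inside tensor ops, which
    # vmap can batch. any_init has shape (1,) and broadcasts over x.
    any_init = is_init[..., 1:].any(dim=-1, keepdim=True)
    return torch.where(any_init, torch.zeros_like(x), x)


if __name__ == "__main__":
    x = torch.arange(6.0).reshape(2, 3)
    is_init = torch.tensor([[False, False, True], [False, False, False]])
    print(vmap(reset_masked)(x, is_init))
    try:
        vmap(reset_with_if)(x, is_init)
    except RuntimeError as err:
        print("vmap failed:", err)
```

Rewriting such branches with tensor ops like torch.where is the usual way to make code vmap-compatible.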

MahmoudUwk, Oct 21 '24 16:10

Same issue here. @vmoens, any ideas? I can try to patch it if you point me in the right direction.

edavidk7, May 04 '25 21:05

Closed by #2962 and #2941.

vmoens, May 20 '25 10:05