agent.train --> TypeError: call() missing 1 required positional argument: 'network_state'
I'm trying to train a DDPG agent (with actor and critic networks) on my custom environment, using a TFUniformReplayBuffer.
I sample a batch of training experience from the buffer using:
dataset = buffer.as_dataset(sample_batch_size=32,
                            num_steps=2)
iterator = iter(dataset)
experience, info = next(iterator)
but when I call
agent.train(experience)
I get a TypeError:
Traceback (most recent call last):
  File "/home/stephanie/Documents/TT/MyA2C.py", line 102, in <module>
    agent.train(experience)
  File "/home/stephanie/.local/lib/python3.8/site-packages/tf_agents/agents/tf_agent.py", line 506, in train
    loss_info = self._train_fn(
  File "/home/stephanie/.local/lib/python3.8/site-packages/tf_agents/utils/common.py", line 185, in with_check_resource_vars
    return fn(*fn_args, **fn_kwargs)
  File "/home/stephanie/.local/lib/python3.8/site-packages/tf_agents/agents/ddpg/ddpg_agent.py", line 237, in _train
    critic_loss = self.critic_loss(time_steps, actions, next_time_steps,
  File "/home/stephanie/.local/lib/python3.8/site-packages/tf_agents/agents/ddpg/ddpg_agent.py", line 297, in critic_loss
    target_actions, _ = self._target_actor_network(
  File "/home/stephanie/.local/lib/python3.8/site-packages/tf_agents/networks/network.py", line 413, in __call__
    outputs, new_state = super(Network, self).__call__(**normalized_kwargs)
  File "/home/stephanie/anaconda3/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py", line 985, in __call__
    outputs = call_fn(inputs, *args, **kwargs)
TypeError: call() missing 1 required positional argument: 'network_state'
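For context on the experience passed to agent.train: DDPG learns from single transitions (observation, action, reward, next observation), which is why the buffer is sampled with num_steps=2. With sample_batch_size=32, each batch holds 32 windows of two consecutive steps, and the agent splits each window into the time_steps, actions, and next_time_steps that appear in the critic_loss frame of the traceback. The plain-Python sketch below illustrates that split for one window. StepSketch, TimeStepSketch, and to_transition_sketch are hypothetical names; the layout is a simplified stand-in for what tf_agents' trajectory.to_transition does, not its actual code.

```python
from typing import List, NamedTuple, Tuple

# Step types numbered as in tf_agents' StepType: 0=FIRST, 1=MID, 2=LAST.
FIRST, MID, LAST = 0, 1, 2


class StepSketch(NamedTuple):
    """Hypothetical flat record of one stored step (not tf_agents' Trajectory)."""
    step_type: int
    observation: float
    action: float
    next_step_type: int
    reward: float      # reward received after taking `action`
    discount: float


class TimeStepSketch(NamedTuple):
    step_type: int
    reward: float
    discount: float
    observation: float


def to_transition_sketch(
    window: List[StepSketch],
) -> Tuple[TimeStepSketch, float, TimeStepSketch]:
    """Split one num_steps=2 window into (time_step, action, next_time_step)."""
    if len(window) != 2:
        raise ValueError("expected a window of exactly 2 consecutive steps")
    first, second = window
    # Reward and discount are not meaningful for the current step, so zero them.
    time_step = TimeStepSketch(first.step_type, 0.0, 0.0, first.observation)
    # The reward/discount for first.action are stored on the first step,
    # while the observation it leads to comes from the second step.
    next_time_step = TimeStepSketch(
        first.next_step_type, first.reward, first.discount, second.observation
    )
    return time_step, first.action, next_time_step


window = [
    StepSketch(FIRST, 1.0, 0.3, MID, -0.5, 1.0),
    StepSketch(MID, 1.2, -0.1, MID, -0.4, 1.0),
]
time_step, action, next_time_step = to_transition_sketch(window)
print(time_step, action, next_time_step)
```

Only the first step's action is used; the second step contributes its observation as the "next" observation for the TD target.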
I get the same error when plugging a DdpgAgent into the SAC tutorial notebook.
Did you solve it? I also got the same error when I implemented a Td3Agent.
Changing the network's parent class from tf_agents.networks.Network to tf.keras.layers.Layer (and adjusting __init__ to match) helped me find where the bug was, because the error message was clearer.
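The last frames of the traceback show DDPG's critic_loss calling self._target_actor_network(...) and the call failing inside call() because network_state is missing. One reading consistent with that: the actor's call() declares network_state without a default, and the target-actor call in critic_loss does not pass it. Assuming that is the cause, the mismatch can be reproduced in plain Python without TensorFlow. NetworkSketch, RequiredStateActor, and target_actor_step are hypothetical names, and the forwarding __call__ is a simplification of tf_agents' Network.__call__, which does more argument normalization.

```python
class NetworkSketch:
    """Hypothetical stand-in for a network base class: __call__ forwards
    its arguments unchanged to the subclass's call()."""

    def __call__(self, inputs, *args, **kwargs):
        return self.call(inputs, *args, **kwargs)


class RequiredStateActor(NetworkSketch):
    # step_type and network_state have no defaults, so every caller
    # must pass them explicitly.
    def call(self, observations, step_type, network_state, training=False):
        return [0.5 * o for o in observations], network_state


def target_actor_step(actor, observation, step_type):
    # Shaped like the target-actor call in the traceback: the observation
    # and step_type are passed, but no network_state.
    return actor(observation, step_type=step_type, training=False)


try:
    target_actor_step(RequiredStateActor(), [1.0, 2.0], step_type=[1, 1])
    error_message = None
except TypeError as err:
    error_message = str(err)

print(error_message)
```

The exact wording depends on the Python version (newer versions prefix the class name to call()), but the message names the same missing network_state argument as the traceback.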
Same problem with TD3. Any help is appreciated!
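I resolved this by using tf_agents.agents.ddpg.actor_network.ActorNetwork as the actor network. tf_agents.networks.actor_distribution_network.ActorDistributionNetwork does not work with DdpgAgent, since DDPG expects a deterministic actor that outputs actions rather than an action distribution.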
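For anyone writing a custom actor network for DdpgAgent or Td3Agent instead, a sketch of a call() signature that avoids the error, under the assumption that the cause is a required network_state argument: give step_type and network_state defaults, and return (actions, network_state). NetworkSketch and DeterministicActorSketch are hypothetical plain-Python stand-ins (no TensorFlow). The defaults follow, as far as I can tell, those of call() in tf_agents' DDPG ActorNetwork, and the clipping only stands in for scaling actions to the action spec.

```python
class NetworkSketch:
    """Hypothetical base class: __call__ forwards its arguments to call()."""

    def __call__(self, inputs, *args, **kwargs):
        return self.call(inputs, *args, **kwargs)


class DeterministicActorSketch(NetworkSketch):
    # step_type and network_state default to empty tuples, so callers that
    # omit them (like the target-actor call in critic_loss) still work.
    def call(self, observations, step_type=(), network_state=(), training=False):
        # Deterministic policy: exactly one action per observation,
        # clipped to [-1, 1] as a stand-in for the action spec bounds.
        actions = [max(-1.0, min(1.0, 0.5 * o)) for o in observations]
        # Return the state unchanged alongside the actions.
        return actions, network_state


actions, state = DeterministicActorSketch()([1.0, 4.0], step_type=[1, 1], training=False)
print(actions, state)  # [0.5, 1.0] ()
```

Called the same way as the failing target-actor call (no network_state), this returns the actions together with the default empty state instead of raising a TypeError.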