
Using Actor-Learner API and Reverb for PPO agent

sibyjackgrove opened this issue 4 years ago · 4 comments

I am trying to adapt the SAC minitaur tutorial, which uses the Actor-Learner API and Reverb, to work with the PPO agent. I changed the agent from sac_agent.SacAgent to ppo_clip_agent.PPOClipAgent:

tf_agent = ppo_clip_agent.PPOClipAgent(
        time_step_spec,
        action_spec,
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        actor_net=actor_net,
        value_net=value_net,
        entropy_regularization=0.0,
        importance_ratio_clipping=0.2,
        normalize_observations=False,
        normalize_rewards=False,
        use_gae=True,
        num_epochs=25,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        train_step_counter=train_step)

However, I keep getting this error when the training loop starts and collect_actor.run() is called. Could you point me to what else I need to change to get the PPO agent working with the Actor-Learner API and Reverb?

step = 0: AverageReturn = 5.256985, AverageEpisodeLength = 10.000000
Start training...
Traceback (most recent call last):
  File "/home/user/.local/lib/python3.7/site-packages/reverb/trajectory_writer.py", line 276, in append
    flat_column_data = self._reorder_like_flat_structure(data_with_path_flat)
  File "/home/user/.local/lib/python3.7/site-packages/reverb/trajectory_writer.py", line 444, in _reorder_like_flat_structure
    flat_data[self._path_to_column_index[path]] = value
KeyError: ('policy_info', 'dist_params', 'loc')

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "PPO_unitcommitment_rev0.py", line 294, in <module>
    collect_actor.run()
  File "/home/user/.conda/envs/miplearn-rl/lib/python3.7/site-packages/tf_agents/train/actor.py", line 149, in run
    self._time_step, self._policy_state)
  File "/home/user/.conda/envs/miplearn-rl/lib/python3.7/site-packages/tf_agents/drivers/py_driver.py", line 123, in run
    observer(traj)
  File "/home/user/.conda/envs/miplearn-rl/lib/python3.7/site-packages/tf_agents/replay_buffers/reverb_utils.py", line 353, in __call__
    self._writer.append(trajectory)
  File "/home/user/.local/lib/python3.7/site-packages/reverb/trajectory_writer.py", line 284, in append
    flat_column_data = self._reorder_like_flat_structure(data_with_path_flat)
  File "/home/user/.local/lib/python3.7/site-packages/reverb/trajectory_writer.py", line 444, in _reorder_like_flat_structure
    flat_data[self._path_to_column_index[path]] = value
KeyError: ('policy_info', 'dist_params', 'loc')
[reverb/cc/platform/default/server.cc:84] Shutting down replay server
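The failure can be reproduced in miniature without Reverb or TF-Agents. This is an illustrative sketch, not Reverb's actual API: the trajectory writer maps flattened structure paths to column indices based on the structure it was configured with, and a PPO collect trajectory carries extra policy_info entries (the distribution parameters) whose paths the mapping has never seen:

```python
def flatten_with_paths(nested, prefix=()):
    """Flatten a nested dict into (path, value) pairs, like tree-flattening."""
    if isinstance(nested, dict):
        for key, value in nested.items():
            yield from flatten_with_paths(value, prefix + (key,))
    else:
        yield prefix, nested

# Writer configured from a SAC-style trajectory: policy_info is empty.
sac_spec = {"observation": 0.0, "action": 0.0, "policy_info": {}}
path_to_column = {
    path: i for i, (path, _) in enumerate(flatten_with_paths(sac_spec))
}

# A PPO collect trajectory additionally carries distribution parameters.
ppo_traj = {
    "observation": 1.0,
    "action": 0.5,
    "policy_info": {"dist_params": {"loc": 0.1, "scale": 0.9}},
}

try:
    for path, value in flatten_with_paths(ppo_traj):
        _ = path_to_column[path]  # fails on ('policy_info', 'dist_params', 'loc')
except KeyError as err:
    print("KeyError:", err.args[0])
```

The first unmatched path is exactly the one in the traceback, which is why the mismatch surfaces inside trajectory_writer.append rather than in the agent itself.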

sibyjackgrove avatar Jan 14 '22 21:01 sibyjackgrove

You also need to use the PPOPolicy, which will collect the proper extra information (the distribution parameters in policy_info). Take a look at https://github.com/tensorflow/agents/blob/master/tf_agents/examples/ppo/schulman17/train_eval_lib.py

sguada avatar Jan 15 '22 01:01 sguada

@sguada Thank you for the link to this code. I had missed it, and it is exactly what I want. I see that there is a num_environments parameter on line 175. Can this code spawn parallel environments and collect data from all of them, like we could do in the previous train_eval_clip_agent.py? If not, could you point me to the changes I need to make to do that?

sibyjackgrove avatar Jan 17 '22 17:01 sibyjackgrove

I'd like to know this as well.

profPlum avatar Jan 18 '22 01:01 profPlum

@profPlum It seems parallel environments may not work with Reverb, so I am using train_eval_clip_agent.py for now.

It seems train_eval_clip_agent.py will not work as-is due to an error in parallel_py_environment.ParallelPyEnvironment. To make it work, you have to replace:

tf_env = tf_py_environment.TFPyEnvironment(
    parallel_py_environment.ParallelPyEnvironment(
        [lambda: env_load_fn(env_name)] * num_parallel_environments))

with the following:

import functools

from tf_agents.environments import parallel_py_environment
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment

def _env_creator(environment_name='CartPole-v1'):
    return functools.partial(suite_gym.load, environment_name=environment_name)

def _parallel_env_creator(num_parallel_environments=5, environment_name='CartPole-v1'):
    return functools.partial(
        parallel_py_environment.ParallelPyEnvironment,
        env_constructors=[_env_creator(environment_name)
                          for _ in range(num_parallel_environments)])

# Note the extra call: _parallel_env_creator() returns a constructor,
# which must itself be called to build the environment instance that
# TFPyEnvironment expects.
tf_env = tf_py_environment.TFPyEnvironment(_parallel_env_creator()())
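The pattern above works because functools.partial packages a constructor and its arguments into a zero-argument callable, which each ParallelPyEnvironment worker can invoke to build its own environment instance. A minimal pure-Python sketch of that deferred-construction idea, with a hypothetical DummyEnv standing in for a real environment:

```python
import functools

class DummyEnv:
    """Stand-in for a real environment such as one loaded by suite_gym.load."""
    def __init__(self, name):
        self.name = name

def env_creator(environment_name="CartPole-v1"):
    # Returns a constructor, not an instance: construction is deferred
    # until the callable is invoked (e.g. inside a worker process).
    return functools.partial(DummyEnv, environment_name)

constructors = [env_creator() for _ in range(5)]
envs = [ctor() for ctor in constructors]  # each call builds a fresh env
```

Passing constructors instead of instances also matters for subprocess-based parallelism: the callable can be shipped to each worker, which then constructs its environment locally rather than trying to share one instance.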

sibyjackgrove avatar Jan 24 '22 21:01 sibyjackgrove