DI-engine

PPO Policy Bug in Parallel Mode

Open • tianhan4 opened this issue 4 years ago • 2 comments

  • [ ] I have marked all applicable categories:
    • [ ] exception-raising bug
    • [x] RL algorithm bug
    • [ ] system worker bug
    • [ ] system utils bug
    • [ ] code design/refactor
    • [ ] documentation request
    • [ ] new feature request
  • [x] I have visited the readme and doc
  • [x] I have searched through the issue tracker and pr tracker
  • [ ] I have mentioned version numbers, operating system and environment, where applicable:
    import ding, torch, sys
    print(ding.__version__, torch.__version__, sys.version, sys.platform)
    

PPO Policy uses value_norm in _get_train_sample, which the collector calls from its _process_timestep function. In parallel mode, however, the collector's policy never runs _init_learn, and that is the only place _value_norm is initialized. As a result, collection fails with: AttributeError: 'PPOCommandModePolicy' object has no attribute '_value_norm'.
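A minimal, self-contained sketch of the failure and one possible fix. It does not import DI-engine: ValueNorm, BuggyPPOPolicy, FixedPPOPolicy and the enable_modes argument are hypothetical stand-ins that only mirror the structure described above (normalizer created in _init_learn, read by _get_train_sample during collection, and a worker that initializes only the modes it runs).

```python
class ValueNorm:
    """Hypothetical stand-in for a running mean/std tracker of returns."""

    def __init__(self):
        self.mean, self.std = 0.0, 1.0


class BuggyPPOPolicy:
    def __init__(self, enable_modes):
        # Each worker initializes only the modes it runs; a collector-only
        # worker would pass ('collect',) and never reach _init_learn.
        for mode in enable_modes:
            getattr(self, '_init_' + mode)()

    def _init_learn(self):
        self._value_norm = ValueNorm()  # only created on the learner side

    def _init_collect(self):
        pass  # _value_norm is never set here

    def _get_train_sample(self, values):
        # Called from the collector's _process_timestep: denormalize values.
        return [v * self._value_norm.std + self._value_norm.mean for v in values]


class FixedPPOPolicy(BuggyPPOPolicy):
    def _init_collect(self):
        # Fix sketch (assumption): give collect mode its own normalizer.
        # The learner would still have to ship updated statistics to the
        # collector alongside the policy weights for the values to be right.
        if not hasattr(self, '_value_norm'):
            self._value_norm = ValueNorm()


try:
    BuggyPPOPolicy(enable_modes=('collect',))._get_train_sample([0.5, 1.0])
except AttributeError as e:
    print(e)  # 'BuggyPPOPolicy' object has no attribute '_value_norm'

print(FixedPPOPolicy(enable_modes=('collect',))._get_train_sample([0.5, 1.0]))  # [0.5, 1.0]
```

Creating the normalizer in collect mode only removes the crash. In a real parallel pipeline the collector's copy would keep its initial statistics unless the learner sends the running mean/std together with the model weights, so that synchronization is the part any actual fix would need to address.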

tianhan4 • Jan 06 '22 07:01

Could you share your launch script or main entry file so we can reproduce the error?

PaParaZz1 • Jan 06 '22 07:01

from easydict import EasyDict

cartpole_dqn_config = dict(
    exp_name='cartpole_ppo',
    env=dict(
        collector_env_num=8,
        collector_episode_num=2,
        evaluator_env_num=5,
        evaluator_episode_num=1,
        stop_value=195,
    ),
    policy=dict(
        cuda=False,
        action_space='discrete',
        model=dict(
            obs_shape=4,
            action_shape=2,
            action_space='discrete',
        ),
        learn=dict(
            batch_size=32,
            learning_rate=0.001,
            value_weight=0.5,
            entropy_weight=0.01,
            clip_ratio=0.2,
            learner=dict(
                learner_num=1,
                send_policy_freq=1,
            ),
        ),
        collect=dict(
            n_sample=256,
            unroll_len=1,
            discount_factor=0.9,
            gae_lambda=0.95,
            collector=dict(
                collector_num=2,
                update_policy_second=3,
            ),
        ),
        eval=dict(evaluator=dict(eval_freq=50, )),
        other=dict(
            eps=dict(
                type='exp',
                start=0.95,
                end=0.1,
                decay=100000,
            ),
            replay_buffer=dict(
                replay_buffer_size=100000,
                enable_track_used_data=False,
            ),
            commander=dict(
                collector_task_space=2,
                learner_task_space=1,
                eval_interval=5,
            ),
        ),
    ),
)
cartpole_dqn_config = EasyDict(cartpole_dqn_config)
main_config = cartpole_dqn_config

cartpole_dqn_create_config = dict(
    env=dict(
        type='cartpole',
        import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
    ),
    env_manager=dict(type='base'),
    policy=dict(type='ppo_command'),
    learner=dict(type='base', import_names=['ding.worker.learner.base_learner']),
    collector=dict(
        type='zergling',
        import_names=['ding.worker.collector.zergling_parallel_collector'],
    ),
    commander=dict(
        type='solo',
        import_names=['ding.worker.coordinator.solo_parallel_commander'],
    ),
    comm_learner=dict(
        type='flask_fs',
        import_names=['ding.worker.learner.comm.flask_fs_learner'],
    ),
    comm_collector=dict(
        type='flask_fs',
        import_names=['ding.worker.collector.comm.flask_fs_collector'],
    ),
)
cartpole_dqn_create_config = EasyDict(cartpole_dqn_create_config)
create_config = cartpole_dqn_create_config

cartpole_dqn_system_config = dict(
    coordinator=dict(),
    path_data='./{}/data'.format(main_config.exp_name),
    path_policy='./{}/policy'.format(main_config.exp_name),
    communication_mode='auto',
    learner_gpu_num=1,
)
cartpole_dqn_system_config = EasyDict(cartpole_dqn_system_config)
system_config = cartpole_dqn_system_config

if __name__ == '__main__':
    from ding.entry.parallel_entry import parallel_pipeline
    parallel_pipeline([main_config, create_config, system_config], seed=9)

tianhan4 • Jan 06 '22 08:01