PPO train code refactor for checkpointing and curriculum compatibility

Open btnorman opened this issue 3 years ago • 2 comments

I refactored the PPO train code to make it more compatible with checkpointing and environment variation, while keeping backwards compatibility.

This splits the train function into multiple functions:

  • make_train_space: a function that creates the training functions (e.g. training_epoch_with_timing, evaluator) and returns them wrapped in a simple namespace
  • init_training_state: a function that initializes the training state
  • init_env_state: a function that initializes the environment state
  • and run_train: a function that performs a training run, when passed a training state, environment state, and train space

This enables the training functions to be run multiple times with jit compilation only occurring once. This also adds:

  • train: a function that works identically to the previous train function, by calling the above described functions
  • checkpoint_train: a simple checkpointing version of train that should be compatible with preemption.
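
A minimal sketch of how the split described above might fit together. This is a toy stand-in in plain Python, not the code from the PR: the function names come from the description, but every signature, the dictionary-based state, and the `learning_rate` parameter are assumptions made for illustration, and no JAX or Brax calls are used.

```python
from types import SimpleNamespace


def make_train_space(learning_rate=0.1):
    """Build the training functions once and bundle them in a namespace.

    Hypothetical signature: in real PPO code this is where the jitted
    functions would be constructed.
    """
    def training_epoch(training_state, env_state):
        # Toy update: move the parameter toward the environment's target.
        params = training_state["params"]
        new_params = params + learning_rate * (env_state["target"] - params)
        new_state = {"params": new_params, "epochs": training_state["epochs"] + 1}
        return new_state, env_state

    def evaluator(training_state, env_state):
        # Toy score: higher is better, 0.0 means the target was reached.
        return -abs(env_state["target"] - training_state["params"])

    return SimpleNamespace(training_epoch=training_epoch, evaluator=evaluator)


def init_training_state():
    return {"params": 0.0, "epochs": 0}


def init_env_state(target=1.0):
    return {"target": target}


def run_train(training_state, env_state, train_space, num_epochs):
    """Run some epochs from the given states, reusing the same train space."""
    for _ in range(num_epochs):
        training_state, env_state = train_space.training_epoch(
            training_state, env_state)
    return training_state, env_state


def train(num_epochs=10):
    """Backwards-compatible entry point built from the pieces above."""
    train_space = make_train_space()
    return run_train(
        init_training_state(), init_env_state(), train_space, num_epochs)
```

Because the states are passed in and returned explicitly, calling `run_train` twice for 5 epochs each with the same train space ends in the same state as one call to `train(10)`.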

The general aim is that run_train should let people easily build their own checkpointing and environment variation / curriculum generation on top of the Brax PPO training code, without having to modify the Brax PPO code internally.
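
As an illustration of the curriculum case, here is a hedged sketch of a loop that swaps the environment between calls to run_train while carrying the training state over. The toy `run_train`, the `step_size` field, and the `curriculum_train` helper are all hypothetical and written in plain Python; they only mirror the shape of the interface described above.

```python
from types import SimpleNamespace


def run_train(training_state, env_state, train_space, num_epochs):
    """Toy stand-in: each epoch moves the parameter toward the env target."""
    for _ in range(num_epochs):
        error = env_state["target"] - training_state["params"]
        training_state = {
            "params": training_state["params"] + train_space.step_size * error
        }
    return training_state, env_state


def curriculum_train(targets, epochs_per_stage):
    """Train through a sequence of environments, keeping the learned state."""
    # Built once and reused for every stage (hypothetical contents).
    train_space = SimpleNamespace(step_size=0.5)
    training_state = {"params": 0.0}
    history = []
    for target in targets:
        # New environment variation for this stage of the curriculum.
        env_state = {"target": target}
        training_state, env_state = run_train(
            training_state, env_state, train_space, epochs_per_stage)
        history.append((target, training_state["params"]))
    return training_state, history
```

The outer loop never touches the update rule itself, which is the point of the proposed split: the curriculum logic lives entirely outside the training code.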

This is my first pull request, so let me know if I have made any rookie mistakes! And thanks for the great physics engine!

FYI, I have not been able to test in an environment with multiple processors, e.g. a TPU slice.

btnorman avatar Jul 26 '22 15:07 btnorman

Hi btnorman,

I'm very glad to see you're showing interest in Brax!

The idea of the agents directory is to show example implementations of popular algorithms with Brax. We know it doesn't cover all usages, and so it's expected that people will fork those examples to get the algorithm to do what they want.

We expect different users will have different opinions about checkpointing, for example, so we prefer to leave it unimplemented.

We also want the interface for all algorithms to be as close as possible, so I don't believe introducing the abstractions you propose only for PPO can go through.

For those reasons, I think it makes sense this PR stays out of the main branch.

Have fun with Brax!

m-orsini avatar Jul 28 '22 08:07 m-orsini

Hi! Thanks for looking over it.

I want to make sure I have not misrepresented the intention of the proposed contribution!

The contribution comes in two separate parts.

  1. Modifying the structure of the PPO code to more easily allow Brax users to build on top of it, without Brax users having to modify the internals of the Brax PPO code.

    To illustrate the value, imagine a user who only wants to add checkpointing and is otherwise happy with the out-of-the-box PPO code. At the moment, this user has two options: A) copy the Brax code and add checkpointing to the internals, or B) use a different PPO implementation, either their own or e.g. one written in PyTorch.

    Both options come with significant overhead. Copying the Brax code and modifying it first requires understanding the logic and structure of the code, which is a lot of work for someone unfamiliar with JAX. It also requires consulting multiple Brax files to understand how the code fits together. Using a different PPO implementation is also a lot of work, because one has to work out how to make Brax and that implementation work well together. Both options are also error prone: a single error or typo can subtly break the RL training.

    The proposed change introduces a third option for this user: C) take the abstractions provided and build on top of them. This is significantly less work, as it needs minimal understanding of JAX and minimal consultation of other Brax files.

  2. Introducing a very simple example implementation of checkpointing that demonstrates how someone could use the new components to write their own checkpointing code easily, building on top of the Brax code, without having to write their own PPO logic.

    This relates to the hypothetical user who wants their own custom checkpointing code. They can copy the example implementation of checkpointing, which only involves file loading, and modify it to suit their checkpointing needs without any understanding of JAX or any need to consult other Brax files.

    This second part is only meant to be illustrative. If the changes went ahead then a checkpointing example might be better placed in a notebook instead.
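
To make the second part concrete, here is a minimal sketch of what such a file-based checkpointing wrapper could look like. It is not the checkpoint_train from the PR: the toy `run_train` (which omits the environment state and train space for brevity), the helper names `save_checkpoint` and `load_checkpoint`, and the use of `pickle` are all assumptions for illustration.

```python
import os
import pickle
import tempfile


def run_train(training_state, num_epochs):
    """Toy stand-in for a training run: updates a parameter, counts epochs."""
    for _ in range(num_epochs):
        training_state = {
            "params": training_state["params"] * 0.5 + 1.0,
            "epochs": training_state["epochs"] + 1,
        }
    return training_state


def save_checkpoint(path, training_state):
    # Write to a temporary file first, then rename it into place, so a
    # preemption in the middle of a write cannot leave a truncated checkpoint.
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory)
    with os.fdopen(fd, "wb") as f:
        pickle.dump(training_state, f)
    os.replace(tmp_path, path)


def load_checkpoint(path):
    """Return the saved training state, or None if no checkpoint exists."""
    if not os.path.exists(path):
        return None
    with open(path, "rb") as f:
        return pickle.load(f)


def checkpoint_train(path, total_epochs, epochs_per_checkpoint):
    """Train in chunks, saving after each chunk and resuming if a file exists."""
    training_state = load_checkpoint(path) or {"params": 0.0, "epochs": 0}
    while training_state["epochs"] < total_epochs:
        remaining = total_epochs - training_state["epochs"]
        training_state = run_train(
            training_state, min(epochs_per_checkpoint, remaining))
        save_checkpoint(path, training_state)
    return training_state
```

A user with different needs (a different file format, remote storage, keeping several checkpoints) would only have to change the two file helpers; the training logic stays untouched.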

If the aim of making it easier to build on top of the existing Brax training code seems valuable but the proposed implementation is lacking, perhaps there is something else we can do, and I would be happy to help!

If you decide the change does seem valuable then I would be happy to modify the other algorithms so that they are all consistent.

FYI, to allay a potential concern: using the checkpointing code with these abstractions adds very little overhead.

btnorman avatar Jul 28 '22 13:07 btnorman