
Natgrads

Open daniel-dodd opened this issue 3 years ago • 5 comments

This PR seeks to add natural gradients to GPJax, as well as two new Gaussian variational family parameterisations.

Please check the type of change your PR introduces:

  • [ ] Bugfix
  • [x] Feature
  • [ ] Code style update (formatting, renaming)
  • [x] Refactoring (no functional changes, no API changes)
  • [ ] Build related changes
  • [ ] Documentation content changes
  • [ ] Other (please describe):

Current state:

The code currently only works for the natural parameterisation case. The main outstanding task (aside from the obvious simplification of the rough codebase and improvement of the API) is that we lack natural gradients for general parameterisations.
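For context, the natural parameterisation referred to here is the standard exponential-family one: for a Gaussian variational distribution q(u) = N(m, S), the natural parameters are theta1 = S⁻¹m and theta2 = -½S⁻¹. A minimal NumPy sketch of this mapping (illustrative only, not GPJax code):

```python
import numpy as np

# Illustrative only (not GPJax code): the mapping between the usual
# mean/covariance parameterisation (m, S) of a Gaussian q(u) = N(m, S)
# and its natural parameterisation (theta1, theta2).

def to_natural(m, S):
    """theta1 = S^{-1} m,  theta2 = -(1/2) S^{-1}."""
    S_inv = np.linalg.inv(S)
    return S_inv @ m, -0.5 * S_inv

def from_natural(theta1, theta2):
    """Inverse map: S = (-2 theta2)^{-1},  m = S theta1."""
    S = np.linalg.inv(-2.0 * theta2)
    return S @ theta1, S
```

The two maps are mutually inverse, which is what makes a bijection-based design for the general case plausible.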

There are two notebooks associated with this PR:

  • natgrads.ipynb covers the case where the variational family is chosen as the natural parameterisation. This should work, but I have not tested it since I rebased with the master branch.
  • Natural Gradient General case.ipynb provides a rough sketch of what the general case might look like, which will involve the user defining a bijection between their parameterisation and the natural parameterisation.
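To make the bijection idea concrete, here is a hedged sketch for a one-dimensional Gaussian (hypothetical names, not the proposed GPJax API). It uses the standard exponential-family identity that the natural gradient with respect to the natural parameters theta equals the ordinary gradient with respect to the expectation parameters eta, so a natural gradient step only needs the user-supplied bijection between the two:

```python
import jax
import jax.numpy as jnp

# Hypothetical sketch (not GPJax code): natural gradients for a 1D
# Gaussian q(u) = N(m, s2) with natural parameters
# theta = (m / s2, -1 / (2 s2)) and expectation parameters
# eta = (E[u], E[u^2]) = (m, m^2 + s2).

def theta_to_eta(theta):
    theta1, theta2 = theta
    s2 = -0.5 / theta2
    m = s2 * theta1
    return jnp.array([m, m**2 + s2])

def eta_to_theta(eta):
    m, second_moment = eta
    s2 = second_moment - m**2
    return jnp.array([m / s2, -0.5 / s2])

def natgrad_step(theta, loss_in_theta, lr=0.1):
    # Natural gradient w.r.t. theta == plain gradient w.r.t. eta,
    # obtained by differentiating the loss through the bijection.
    loss_in_eta = lambda eta: loss_in_theta(eta_to_theta(eta))
    grad_eta = jax.grad(loss_in_eta)(theta_to_eta(theta))
    return theta - lr * grad_eta
```

The general case would presumably replace the hard-coded 1D maps above with the user-defined bijection between their chosen parameterisation and the natural one.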

There are some unit tests provided for the new variational families, and for natural_gradients.py in its current state, as well as changes made in parameters.py.

As a final note, some of the functions defined in natural_gradients.py for stopping gradients would likely be better generalised and added to abstractions.py or parameters.py.
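As an illustration of what such a generalisation could look like (names are hypothetical, not the GPJax API), a mask-based stop-gradient over a parameter pytree might be:

```python
import jax

# Hypothetical sketch, not the GPJax API: freeze the leaves of a
# parameter pytree selected by a boolean mask of the same structure,
# so gradients only flow through the trainable leaves.

def stop_gradients(params, trainable):
    return jax.tree_util.tree_map(
        lambda leaf, is_trainable: leaf if is_trainable
        else jax.lax.stop_gradient(leaf),
        params,
        trainable,
    )
```

Applying jax.grad to a loss that routes its parameters through this helper yields zero gradients for every leaf whose mask entry is False, which is the behaviour the natural-gradient code needs when alternating between hyperparameter and variational-parameter updates.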

daniel-dodd avatar Jul 21 '22 15:07 daniel-dodd

@thomaspinder it would be good to get your thoughts on the code so far (some of it is rather rough). I believe the implementation works. The main issue is that the notebook needs writing before merging; I'll get on with this. Some tests also need writing. I'll need to run some benchmarks against fit_batches once #99 is merged, to see if performance is up to scratch.

daniel-dodd avatar Aug 19 '22 21:08 daniel-dodd

Codecov Report

:exclamation: No coverage uploaded for pull request base (v0.5_update@7773eef). Click here to learn what that means. The diff coverage is n/a.

:exclamation: Current head cc1c318 differs from pull request most recent head 24032c0. Consider uploading reports for the commit 24032c0 to get more accurate results

@@              Coverage Diff               @@
##             v0.5_update      #90   +/-   ##
==============================================
  Coverage               ?   99.22%           
==============================================
  Files                  ?       14           
  Lines                  ?     1154           
  Branches               ?        0           
==============================================
  Hits                   ?     1145           
  Misses                 ?        9           
  Partials               ?        0           
| Flag      | Coverage Δ          |
| --------- | ------------------- |
| unittests | 99.22% <0.00%> (?)  |

Flags with carried forward coverage won't be shown. Click here to find out more.


codecov[bot] avatar Aug 19 '22 21:08 codecov[bot]

@thomaspinder I have rebased the branch with master and would appreciate a review.

On my end, I need to:

  • Update and add some tests
  • Write the notebook

daniel-dodd avatar Aug 23 '22 12:08 daniel-dodd

@thomaspinder in commits d84ede9, e0b60d4 and bd6e4aa, I have addressed many of your comments and suggestions.

Outstanding tasks are to write the unit tests, the notebook, and an explanation of the parameter optimisation order.

daniel-dodd avatar Aug 23 '22 17:08 daniel-dodd

Hi @thomaspinder, I believe I have addressed your comments so far. I expect all functions to be covered by unit tests now (but we'll see what Codecov says). I would appreciate a review of the latest code changes, and I expect we can improve them further.

One key thing I dislike is that the fit_natgrads abstraction takes its first argument as the variational inference strategy, while fit and fit_batches take objectives; I find this inconsistent. I'm not sure if you had any thoughts on this?

I am going to start work on writing the notebook today.

daniel-dodd avatar Aug 24 '22 13:08 daniel-dodd