=============
Tools for JAX
=============

.. role:: bash(code)
    :language: bash

.. role:: python(code)
    :language: python

This repository implements a variety of tools for the differentiable programming library
`JAX <https://github.com/google/jax>`_.


----------------
Major components
----------------

Tjax's major components are:

- A `dataclass <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/dataclasses>`_ decorator
  :python:`dataclass`, with an accompanying
  `mypy_plugin <https://github.com/NeilGirdhar/tjax/blob/master/tjax/mypy_plugin.py>`_, that
  facilitates defining structured JAX objects (so-called "pytrees"). It offers:

  - the ability to mark fields as static (not available in ``chex.dataclass``),
  - a MyPy plugin, and
  - a display method that produces formatted text according to the tree structure.

- A `fixed_point <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/fixed_point>`_ finding
  library heavily based on `fax <https://github.com/gehring/fax>`_. Our library

  - supports stochastic iterated functions, and
  - uses dataclasses instead of closures to avoid leaking JAX tracers.

- A `shim <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/gradient>`_ for the gradient
  transformation library `optax <https://github.com/deepmind/optax>`_ that supports:

  - easy differentiation and vectorization of "gradient transformation" (learning rule) parameters,
  - gradient transformation objects that can be passed dynamically to jitted functions, and
  - generic type annotations.

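
The static-field idea from the first item can be illustrated with plain JAX. The following is a minimal sketch, assuming only the standard ``jax.tree_util.register_pytree_node_class`` decorator. It does not use tjax's own decorator, and the names ``Scaled`` and ``total`` are hypothetical, chosen purely for illustration.

.. code:: python

    from dataclasses import dataclass

    import jax
    import jax.numpy as jnp
    from jax.tree_util import register_pytree_node_class


    @register_pytree_node_class
    @dataclass(frozen=True)
    class Scaled:
        # Hypothetical example class; not part of tjax.
        weights: jax.Array  # dynamic field: a pytree leaf that JAX traces
        power: int          # static field: carried as auxiliary data

        def tree_flatten(self):
            # Children are the dynamic leaves; aux data holds the static fields.
            return (self.weights,), (self.power,)

        @classmethod
        def tree_unflatten(cls, aux_data, children):
            (weights,) = children
            (power,) = aux_data
            return cls(weights, power)


    @jax.jit
    def total(s: Scaled) -> jax.Array:
        # s.power is an ordinary Python int inside the jitted function.
        return jnp.sum(s.weights ** s.power)


    result = total(Scaled(jnp.array([1.0, 2.0, 3.0]), 2))
    leaves = jax.tree_util.tree_leaves(Scaled(jnp.ones(3), 2))

Because ``power`` lives in the auxiliary data and not among the leaves, changing it triggers a recompilation of ``total``, whereas changing ``weights`` does not. A decorator that marks fields as static presumably removes the need to write ``tree_flatten`` and ``tree_unflatten`` by hand.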

----------------
Minor components
----------------

Tjax also includes:

- A pretty printer :python:`print_generic` for aggregate and vector types, including dataclasses.
  (See `display <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/display.py>`_.)

- Versions of :python:`custom_vjp` and :python:`custom_jvp` that support being used on methods.
  (See `shims <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/shims.py>`_.)

- Tools for working with cotangents.
  (See `cotangent_tools <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/cotangent_tools.py>`_.)

- JAX tree registration for `NetworkX <https://networkx.github.io/>`_ graph types.
  (See `graph <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/graph.py>`_.)

- Leaky integration :python:`leaky_integrate` and Ornstein-Uhlenbeck process iteration
  :python:`diffused_leaky_integrate`.
  (See `leaky_integral <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/leaky_integral.py>`_.)

- An improved version of :python:`jax.tree_util.Partial`.
  (See `partial <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/partial.py>`_.)

- A Matplotlib trajectory plotter :python:`PlottableTrajectory`.
  (See `plottable_trajectory <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/plottable_trajectory.py>`_.)

- A testing function :python:`assert_tree_allclose` that automatically produces testing code, along
  with a related function :python:`tree_allclose`.
  (See `testing <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/testing.py>`_.)

- Basic tools like :python:`divide_where`.
  (See `tools <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/tools.py>`_.)

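
To give a sense of what comparing two pytrees for approximate equality involves, here is a minimal sketch written with plain JAX. It is not tjax's implementation of :python:`tree_allclose`. The helper name ``trees_close`` and its default tolerances are assumptions made for illustration.

.. code:: python

    import jax
    import jax.numpy as jnp


    def trees_close(actual, desired, *, rtol=1e-5, atol=1e-8):
        # Hypothetical helper; not part of tjax.
        # Trees with different structures cannot be compared leaf by leaf.
        same_structure = (jax.tree_util.tree_structure(actual)
                          == jax.tree_util.tree_structure(desired))
        if not same_structure:
            return False
        # Compare corresponding leaves, producing a tree of Python bools.
        leafwise = jax.tree_util.tree_map(
            lambda a, d: bool(jnp.allclose(a, d, rtol=rtol, atol=atol)),
            actual, desired)
        return all(jax.tree_util.tree_leaves(leafwise))


    x = {"a": jnp.array([1.0, 2.0]), "b": (jnp.array(3.0),)}
    y = {"a": jnp.array([1.0, 2.0]), "b": (jnp.array(3.0),)}
    z = {"a": jnp.array([1.0, 2.5]), "b": (jnp.array(3.0),)}

    same = trees_close(x, y)
    different = trees_close(x, z)

This sketch only returns a boolean. An asserting variant would additionally report which leaves differ.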

Also, see the `documentation <https://neilgirdhar.github.io/tjax/tjax/index.html>`_.


-----------------------
Contribution guidelines
-----------------------

- Conventions: PEP8.

- How to run tests: :bash:`pytest .`

- How to clean the source:

  - :bash:`isort tjax`
  - :bash:`pylint tjax`
  - :bash:`mypy tjax`
  - :bash:`flake8 tjax`