Clement Gehring

9 issues by Clement Gehring

Preallocating memory based on available memory easily leads to a race condition when launching several jax processes that share a GPU. It would be useful to be able to set...

P2 (eventual)
NVIDIA GPU
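For context, JAX's GPU memory allocation documentation describes the `XLA_PYTHON_CLIENT_PREALLOCATE` environment variable, which has to be set before `jax` is imported in a process. Below is a minimal sketch that assumes the workers are started with `subprocess`. It launches each worker with preallocation disabled, so memory is allocated on demand rather than sized from whatever happens to be free at startup. The helpers `worker_env` and `launch_workers` and the script name `train.py` are hypothetical.

```python
import os
import subprocess
import sys
from typing import Dict, List, Mapping, Optional


def worker_env(base: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    """Return a copy of `base` (default: os.environ) with JAX GPU preallocation disabled."""
    env = dict(os.environ if base is None else base)
    # Documented JAX setting: allocate GPU memory as needed instead of
    # reserving a large pool up front when the process starts using the GPU.
    env["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    return env


def launch_workers(script: str, n: int) -> List[subprocess.Popen]:
    """Hypothetical launcher: start `n` copies of `script` that share one GPU."""
    return [
        subprocess.Popen([sys.executable, script], env=worker_env())
        for _ in range(n)
    ]
```

Usage would look like `procs = launch_workers("train.py", 4)` followed by `for p in procs: p.wait()`. Disabling preallocation sidesteps the startup race. However, the JAX docs note that on-demand allocation is more prone to fragmentation, so a program that needs most of the GPU memory may run out of memory with preallocation turned off.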

This is as much a question as it is a feature request. What is the reasoning for not allowing a module instance to be created (but not used) outside `hk.transform`?...

enhancement

Currently, there is no easy way to just call an initializer outside `hk.transform`. This can be worked around, but it discourages using `hk.next_rng_key` when writing custom initializers since whatever...

enhancement

With the (somewhat) recent changes to how `jax` handles custom VJPs, it is now possible to define derivatives using the function for which we are defining the derivative. Since the...

The simplest way would be to use [`tjax.custom_vjp`](https://github.com/NeilGirdhar/tjax/blob/a38695f328ec891fa2f4b78be23ec0abde34bb30/tjax/shims.py#L20), but this is currently not possible due to `tjax`'s Python version requirements (#27).
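Here is a minimal sketch of defining a derivative in terms of the function itself, using `jax.custom_vjp` (not `tjax.custom_vjp`). It uses a hypothetical toy function, `my_exp`, whose derivative equals itself. The backward rule calls `my_exp` again, so second-order derivatives are routed through the same custom rule. This only illustrates the pattern and is not the solver the issue refers to.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def my_exp(x):
    return jnp.exp(x)


def my_exp_fwd(x):
    # Primal output plus the input, saved as the residual for the backward pass.
    return my_exp(x), x


def my_exp_bwd(x, g):
    # d/dx exp(x) = exp(x): the rule is written with my_exp itself, so
    # differentiating this rule again also goes through my_exp's custom VJP.
    return (g * my_exp(x),)


my_exp.defvjp(my_exp_fwd, my_exp_bwd)

first = jax.grad(my_exp)(1.0)
second = jax.grad(jax.grad(my_exp))(1.0)
```

Both `first` and `second` should come out as approximately e. In a fixed-point setting, the same idea lets the backward pass reuse the solver to solve the adjoint problem.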

Now that the two-phase solver only returns the final solution, we should make sure there is a working mechanism for extracting from the forward and backward solvers any internal...

If we want to leverage @NeilGirdhar's [`tjax`](https://github.com/NeilGirdhar/tjax/) (e.g., a more flexible `custom_vjp`, pytree typing, pytree dataclasses), we'll need to use Python 3.8 or greater since this package and some of...

The mapping of problem index to problem files is not visible without diving into the code. Either we establish an explicit mapping of indices to PDDL problem files, e.g., a...