GPJax
bug: Jitting objectives.
Bug Report
GPJax version: 0.7.2
Tagging @henrymoss.
As noted in #402, there is an issue with jitting objectives. This functionality should be removed.
Option (a)
Remove the following lines of code from AbstractObjective:
    ...
    def __hash__(self):
        return hash(tuple(jtu.tree_leaves(self)))  # Probably put this on the Module!

    def __call__(self, *args, **kwargs) -> ScalarFloat:
        return self.step(*args, **kwargs)
With these lines removed, jax.jit(gpx.ConjugateMLL(negative=True)) raises an error. The current code is dodgy.
Objectives could still be passed as objective=gpx.ConjugateMLL(negative=True) without the jit, which is not really needed in the first place, since the code is already traced inside the lax.scan.
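To illustrate the point about lax.scan, here is a minimal sketch. The toy loss, the step function, and all names are assumptions made up for this example and are not GPJax code. It counts how often Python actually runs the loss body: because scan traces its body rather than calling it on every iteration, the count stays far below the number of steps even though nothing is wrapped in jax.jit.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # how many times Python executes the loss body


def loss(params, batch):
    # Hypothetical stand-in objective: mean squared error of a 1-D linear model.
    global trace_count
    trace_count += 1
    x, y = batch
    return jnp.mean((x * params - y) ** 2)


def step(params, batch):
    # One gradient-descent step; note that jax.jit is not used anywhere.
    value, grad = jax.value_and_grad(loss)(params, batch)
    return params - 0.1 * grad, value


num_steps = 50
x = jnp.linspace(0.0, 1.0, 20)
y = 3.0 * x
# Repeat the same batch num_steps times so scan has a leading axis to iterate over.
batches = (jnp.tile(x, (num_steps, 1)), jnp.tile(y, (num_steps, 1)))

final_params, history = jax.lax.scan(step, jnp.array(0.0), batches)
print(trace_count, history[0], history[-1])
```

The side-effecting counter is only there to make tracing visible; it is not something to keep in real code.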
Option (b)
Revert to the previous objective design in GPJax, which comprised an outer and an inner function:
# e.g.,
def elbo(negative: bool ...) -> callable:
    def elbo_fun(model, batch) -> Float[Array, ""]:
        ...
    return elbo_fun
Or even just define the objectives from a minimisation perspective:
def negative_elbo(model, batch) -> Float[Array, ""]:
    ...

def negative_log_likelihood(model, batch) -> Float[Array, ""]:
    ...
Update: going with the design of option (b)(i), i.e., the outer and inner function form:
# e.g.,
def elbo(negative: bool ...) -> callable:
    def elbo_fun(model, batch) -> Float[Array, ""]:
        ...
    return elbo_fun
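For concreteness, here is a rough, self-contained sketch of how the outer and inner function pattern behaves under jax.jit and jax.grad. The Gaussian toy likelihood, the params dictionary, and the function names are hypothetical and only stand in for a real model; this is not the GPJax implementation. The outer function fixes the sign once, and the inner function is a plain pure function, so no __hash__ or __call__ is needed to make it jittable.

```python
import jax
import jax.numpy as jnp


def gaussian_log_likelihood(negative: bool = False):
    # Hypothetical toy objective: the outer function only fixes the sign.
    sign = -1.0 if negative else 1.0

    def objective(params, batch):
        # Inner function: pure in (params, batch), so jit and grad just work.
        x, y = batch
        residual = y - params["slope"] * x
        noise = params["noise"]
        log_lik = -0.5 * jnp.sum(residual**2 / noise + jnp.log(2.0 * jnp.pi * noise))
        return sign * log_lik

    return objective


x = jnp.linspace(0.0, 1.0, 10)
y = 2.0 * x
batch = (x, y)
params = {"slope": jnp.array(1.0), "noise": jnp.array(0.5)}

# The returned closure is an ordinary function and can be jitted directly.
negative_ll = jax.jit(gaussian_log_likelihood(negative=True))
value = negative_ll(params, batch)
grads = jax.grad(negative_ll)(params, batch)
print(value, grads)
```

Since negative is a Python bool captured by the closure, it is resolved once when the objective is built rather than being traced.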