
bug: Jitting objectives.

Open • daniel-dodd opened this issue 2 years ago • 1 comment

Bug Report

GPJax version: 0.7.2

Tagging @henrymoss.

As noted in #402, there is an issue with jitting objectives. This functionality should be removed.

Option (a)

Remove the following lines of code from AbstractObjective:

    ...
    def __hash__(self):
        return hash(tuple(jtu.tree_leaves(self)))  # Probably put this on the Module!

    def __call__(self, *args, **kwargs) -> ScalarFloat:
        return self.step(*args, **kwargs)

This code is dodgy. Removing it means that jax.jit(gpx.ConjugateMLL(negative=True)) will raise an error.

Objectives could still be passed as objective=gpx.ConjugateMLL(negative=True), just without the jit. The jit is not really needed in the first place, since the code is already traced within the lax.scan.
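A minimal sketch of why the jit is redundant, assuming a training loop built on jax.lax.scan similar to the one mentioned above. toy_objective, counting_objective and fit_sketch are hypothetical stand-ins (a linear least-squares loss and plain gradient descent), not GPJax's fit routine. The objective is never passed to jax.jit; lax.scan traces it as part of the loop body.

```python
import jax
import jax.numpy as jnp

# Hypothetical objective: a plain function (params, batch) -> scalar loss.
# Note that it is never wrapped in jax.jit.
def toy_objective(params, batch):
    x, y = batch
    pred = params["w"] * x + params["b"]
    return jnp.mean((pred - y) ** 2)

trace_count = 0  # counts how many times Python actually executes the objective

def counting_objective(params, batch):
    global trace_count
    trace_count += 1  # Python side effect: only runs while JAX is tracing
    return toy_objective(params, batch)

def fit_sketch(objective, params, batch, num_iters=200, lr=0.1):
    """Hypothetical fit loop: plain gradient descent driven by jax.lax.scan."""
    loss_and_grad = jax.value_and_grad(objective)  # gradient w.r.t. params (argnums=0)

    def step(carry, _):
        loss, grads = loss_and_grad(carry, batch)
        new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, carry, grads)
        return new_params, loss

    # lax.scan traces `step` (and therefore `objective`) into a single computation
    # and then runs it `num_iters` times; no per-iteration Python calls happen.
    return jax.lax.scan(step, params, None, length=num_iters)

x = jnp.linspace(-1.0, 1.0, 50)
y = 2.0 * x + 0.5
init_params = {"w": jnp.zeros(()), "b": jnp.zeros(())}
final_params, losses = fit_sketch(counting_objective, init_params, (x, y))
print(final_params, trace_count)
```

Because the objective only runs during tracing, trace_count stays at a small constant instead of growing with num_iters.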

Option (b)

Revert to the previous objective design in GPJax, which comprised an outer and an inner function:

    # e.g.,
    def elbo(negative: bool ...) -> callable:
        def elbo_fun(model, batch) -> Float[Array, ""]:
            ...
        return elbo_fun
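A minimal sketch of the outer/inner pattern, assuming a hypothetical toy objective in place of the elided ELBO body: gaussian_log_likelihood, the dict-based model and the (x, y) batch are illustrative names, not GPJax's API. The outer function only fixes configuration (the sign); the inner function is a pure function of (model, batch), so it can be jitted or differentiated like any other JAX function.

```python
from typing import Callable, Dict, Tuple

import jax
import jax.numpy as jnp

# Hypothetical stand-ins: a "model" is a dict of scalar parameters, a "batch" is (x, y).
Model = Dict[str, jax.Array]
Batch = Tuple[jax.Array, jax.Array]


def gaussian_log_likelihood(negative: bool = False) -> Callable[[Model, Batch], jax.Array]:
    """Outer function: fixes configuration (here only the sign) and returns the objective."""
    constant = -1.0 if negative else 1.0  # a plain Python float captured by the closure

    def log_likelihood_fun(model: Model, batch: Batch) -> jax.Array:
        """Inner function: a pure function of (model, batch) returning a scalar."""
        x, y = batch
        mean = model["w"] * x + model["b"]
        noise_var = model["noise_var"]
        # Sum of independent Gaussian log-densities log N(y_i | mean_i, noise_var).
        log_dens = -0.5 * (jnp.log(2.0 * jnp.pi * noise_var) + (y - mean) ** 2 / noise_var)
        return constant * jnp.sum(log_dens)

    return log_likelihood_fun


# The returned objective is an ordinary Python function, so it is hashable and
# jax.jit accepts it directly; no custom __hash__ or __call__ is required.
objective = jax.jit(gaussian_log_likelihood(negative=True))

model = {"w": jnp.array(2.0), "b": jnp.array(0.5), "noise_var": jnp.array(1.0)}
x = jnp.linspace(-1.0, 1.0, 5)
batch = (x, 2.0 * x + 0.5)
value = objective(model, batch)
print(value)
```

Configuration lives in the closure rather than in object state, which sidesteps the question of how to hash an objective instance.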

Or, even simpler, define objectives purely from a minimisation perspective:

    def negative_elbo(model, batch) -> Float[Array, ""]:
        ...

    def negative_log_likelihood(model, batch) -> Float[Array, ""]:
        ...
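A sketch of the minimisation-only alternative, assuming a hypothetical toy model (a linear mean with Gaussian noise, parameters stored in a dict); this negative_log_likelihood is illustrative, not GPJax's implementation. Because the sign is fixed, the function plugs straight into jax.grad and jax.jit with no configuration step.

```python
from typing import Dict, Tuple

import jax
import jax.numpy as jnp


def negative_log_likelihood(model: Dict[str, jax.Array],
                            batch: Tuple[jax.Array, jax.Array]) -> jax.Array:
    """Negative Gaussian log-likelihood of a hypothetical linear model; lower is better."""
    x, y = batch
    mean = model["w"] * x + model["b"]
    noise_var = model["noise_var"]
    return 0.5 * jnp.sum(jnp.log(2.0 * jnp.pi * noise_var) + (y - mean) ** 2 / noise_var)


# No factory and no sign flag: the function is already written for minimisation,
# so it can be differentiated and jitted directly.
grad_fn = jax.jit(jax.grad(negative_log_likelihood))

model = {"w": jnp.array(2.0), "b": jnp.array(0.5), "noise_var": jnp.array(1.0)}
x = jnp.linspace(-1.0, 1.0, 5)
grads = grad_fn(model, (x, 2.0 * x + 0.5))
print(grads)
```

At the exact generating parameters the residuals are zero, so the gradients with respect to w and b vanish, while the noise_var gradient is positive (shrinking the noise variance would lower the objective further).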

daniel-dodd, Nov 06 '23 16:11

Update: going with the design of option (b)(i), i.e.:

    # e.g.,
    def elbo(negative: bool ...) -> callable:
        def elbo_fun(model, batch) -> Float[Array, ""]:
            ...
        return elbo_fun

daniel-dodd, Nov 16 '23 13:11