[Bug] Log transform not applied when computing posterior in heteroskedastic GP model
The heteroskedastic GP model uses two GPs internally: the first models the observation noise, and the second estimates the function value. As far as I understand, the noise model's output is in the log domain, so that after exponentiation it is positive.
Therefore, I expected to find something like torch.exp that undoes the log transformation before the noise estimate is combined with the second GP. However, the only transformation I found comes from the noise constraint of HeteroskedasticNoise, GreaterThan(1e-4), which applies a softplus transform. For strongly negative inputs (below about -2), softplus yields values similar to exp, but for larger noise the two differ significantly. As a result, the noise is underestimated.
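To make the size of the gap concrete, here is a small pure-Python sketch that needs neither BoTorch nor GPyTorch. It assumes the noise model predicts the log-variance exactly, and it mimics a GreaterThan-style constraint as `transform(raw) + lower_bound`; the helpers `softplus` and `constrained_noise` and the constant `LOWER_BOUND` are hypothetical stand-ins, not GPyTorch API.

```python
import math

LOWER_BOUND = 1e-4  # assumed to match MIN_INFERRED_NOISE_LEVEL


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    if x > 0:
        return x + math.log1p(math.exp(-x))
    return math.log1p(math.exp(x))


def constrained_noise(raw: float, transform) -> float:
    """Hypothetical stand-in for a GreaterThan constraint: transform(raw) + lower bound."""
    return transform(raw) + LOWER_BOUND


for true_var in [0.01, 0.1, 1.0, 10.0, 100.0]:
    log_var = math.log(true_var)  # what a noise model trained on log-variances predicts
    print(
        f"true var {true_var:>6}: "
        f"exp -> {constrained_noise(log_var, math.exp):.4f}, "
        f"softplus -> {constrained_noise(log_var, softplus):.4f}"
    )
```

Since softplus(log v) = log(1 + v), the softplus path returns roughly v only when v is small; for v = 10 it yields log(11) ≈ 2.40 instead of 10.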
Have I missed the piece of code that untransforms the noise from the log domain, or is the softplus in the noise constraint the only transform applied? If the latter, shouldn't the noise constraint use an exp transform instead of softplus?
If so, I would propose changing this line to:
```python
heteroskedastic_noise = HeteroskedasticNoise(
    noise_model=noise_model,
    noise_constraint=GreaterThan(MIN_INFERRED_NOISE_LEVEL, transform=torch.exp, inv_transform=torch.log),
)
likelihood = _GaussianLikelihoodBase(heteroskedastic_noise)
```
System Info
- BoTorch Version 0.4.0
- GPyTorch Version 1.4.2
- PyTorch Version 1.7.1
- Ubuntu 20.04.2 LTS
Great question. So the idea behind the model is to use the Log OutcomeTransform in the noise model: https://github.com/pytorch/botorch/blob/master/botorch/models/gp_regression.py#L394
In general, the outcome transform should automatically untransform the predictions back to the original scale. However, this only happens in BoTorch's posterior call, which is not used when the noise model is evaluated inside the GPyTorch likelihood (there it is evaluated via forward).
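The mechanism can be illustrated with a toy sketch. `ToyLogNoiseModel` is hypothetical, not a BoTorch class: its `forward` returns predictions in the log space the model was trained in (which is what the likelihood sees), while only its `posterior` undoes the Log transform, analogous to BoTorch untransforming only inside `posterior`. Softplus stands in for GPyTorch's default positivity transform.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return x + math.log1p(math.exp(-x)) if x > 0 else math.log1p(math.exp(x))


class ToyLogNoiseModel:
    """Hypothetical noise model that was trained on log-variances."""

    def __init__(self, log_var: float) -> None:
        self.log_var = log_var

    def forward(self) -> float:
        # Path used inside the likelihood: the Log outcome transform is NOT undone.
        return self.log_var

    def posterior(self) -> float:
        # Path used by model.posterior(): the Log outcome transform IS undone.
        return math.exp(self.log_var)


noise_model = ToyLogNoiseModel(log_var=math.log(10.0))
noise_in_likelihood = softplus(noise_model.forward())  # what the outer GP actually uses
noise_in_posterior = noise_model.posterior()  # what one would expect it to use
print(noise_in_likelihood, noise_in_posterior)
```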
cc @saitcakmak
I'll have to take a closer look at this to confirm, but it appears that you're right about the noise potentially being underestimated due to this bug. In the meantime, your suggested change seems reasonable, but you will want to take out the Log outcome transform in that case.
@Balandat, looks like outcome transforms have been applied in the posterior call, going all the way back to #327. So, this bug would've been around since then.
Yeah, I'm not suggesting it wasn't, just flagging the issue of a BoTorch model with an outcome transform being used in a context (inside the GPyTorch likelihood) where the transform is never undone.
I looked into it more closely and can confirm that the Log transform is never undone during inference. As suggested, this leads to severely underestimating the noise when the noise is large. Here is a simple script:
```python
import torch
from botorch import fit_gpytorch_model
from botorch.models.gp_regression import HeteroskedasticSingleTaskGP
from gpytorch import ExactMarginalLogLikelihood

t_X = torch.rand(10, 2)
t_Y_var = torch.ones(10, 1) * 10
model = HeteroskedasticSingleTaskGP(
    train_X=t_X,
    train_Y=torch.randn(10, 1),
    train_Yvar=t_Y_var,
)
mll = ExactMarginalLogLikelihood(model.likelihood, model)
fit_gpytorch_model(mll)

noise_free = model.posterior(t_X, observation_noise=False)
w_noise = model.posterior(t_X, observation_noise=True)
print(noise_free.variance)
print(w_noise.variance)
print(noise_free.variance + t_Y_var)
```
Output:
```
tensor([[0.4486],
        [0.4773],
        [0.4461],
        [0.6378],
        [0.5526],
        [0.5563],
        [0.5930],
        [0.6342],
        [0.5976],
        [0.6315]], grad_fn=<UnsqueezeBackward0>)
tensor([[2.8468],
        [2.8750],
        [2.8443],
        [3.0356],
        [2.9505],
        [2.9540],
        [2.9910],
        [3.0320],
        [2.9954],
        [3.0294]], grad_fn=<UnsqueezeBackward0>)
tensor([[10.4486],
        [10.4773],
        [10.4461],
        [10.6378],
        [10.5526],
        [10.5563],
        [10.5930],
        [10.6342],
        [10.5976],
        [10.6315]], grad_fn=<AddBackward0>)
```
You can see that the posterior with noise only adds about 2.4 of variance, which matches softplus(log(10)) = log(11) ≈ 2.398, rather than the full 10.
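A quick arithmetic check on the printed values supports this. The sketch below uses only the numbers copied from the output above and subtracts the noise-free variance from the noisy one row by row; `math.log1p(10.0)` is log(1 + 10), i.e. softplus applied to log(10).

```python
import math

# Variances copied from the first two printed tensors above.
noise_free = [0.4486, 0.4773, 0.4461, 0.6378, 0.5526, 0.5563, 0.5930, 0.6342, 0.5976, 0.6315]
w_noise = [2.8468, 2.8750, 2.8443, 3.0356, 2.9505, 2.9540, 2.9910, 3.0320, 2.9954, 3.0294]

# Noise variance actually added by observation_noise=True.
added = [w - f for w, f in zip(w_noise, noise_free)]
expected_softplus = math.log1p(10.0)  # softplus(log(10)) = log(11) ≈ 2.3979

for a in added:
    print(f"added {a:.4f} vs softplus(log 10) {expected_softplus:.4f} vs true 10")
```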
Thanks for checking. Did you also try this with @btlorch's suggested modification (and with the Log outcome transform removed from the noise model)?
I tried a few versions of that but I couldn't get it to work correctly.
Thanks for looking into this issue.
> [...] your suggested change seems reasonable, but you will want to take out the Log outcome transform in that case.
But if the Log outcome transform is removed, is the noise model still learning the log variance? Should the caller then provide log(train_Y_var) instead of train_Y_var as noise labels?
Below is an example. Unfortunately, I don't remember where all the code pieces came from, but some are based on this GPyTorch tutorial and this notebook from a GPyTorch issue. In this example, I modified the heteroskedastic GP by replacing the softplus transform of the noise constraint with an exp transform, as described above. I hope we can find a fix that works for both @saitcakmak's example and mine.
Unmodified heteroskedastic GP: It underestimates the noise variance.

Heteroskedastic GP with noise constraint that uses exp transform instead of softplus (see code below): The estimated noise variance closely follows the ground truth.


```python
from __future__ import annotations

from typing import Any, List, Optional

import matplotlib.pyplot as plt
import torch
from botorch.fit import fit_gpytorch_model
from botorch.models.gp_regression import SingleTaskGP, MIN_INFERRED_NOISE_LEVEL
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import Log, OutcomeTransform
from botorch.models.utils import validate_input_scaling
from gpytorch.constraints.constraints import GreaterThan
from gpytorch.likelihoods.gaussian_likelihood import (
    GaussianLikelihood,
    _GaussianLikelihoodBase,
)
from gpytorch.likelihoods.noise_models import HeteroskedasticNoise
from gpytorch.mlls import ExactMarginalLogLikelihood
from gpytorch.mlls.noise_model_added_loss_term import NoiseModelAddedLossTerm
from gpytorch.priors.smoothed_box_prior import SmoothedBoxPrior
from torch import Tensor


class CustomHeteroskedasticSingleTaskGP(SingleTaskGP):
    r"""A single-task exact GP model using a heteroskedastic noise model.

    This model internally wraps another GP (a SingleTaskGP) to model the
    observation noise. This allows the likelihood to make out-of-sample
    predictions for the observation noise levels.
    """

    def __init__(
        self,
        train_X: Tensor,
        train_Y: Tensor,
        train_Yvar: Tensor,
        outcome_transform: Optional[OutcomeTransform] = None,
        input_transform: Optional[InputTransform] = None,
    ) -> None:
        r"""A single-task exact GP model using a heteroskedastic noise model.

        Args:
            train_X: A `batch_shape x n x d` tensor of training features.
            train_Y: A `batch_shape x n x m` tensor of training observations.
            train_Yvar: A `batch_shape x n x m` tensor of observed measurement
                noise.
            outcome_transform: An outcome transform that is applied to the
                training data during instantiation and to the posterior during
                inference (that is, the `Posterior` obtained by calling
                `.posterior` on the model will be on the original scale).
                Note that the noise model internally log-transforms the
                variances, which will happen after this transform is applied.
            input_transform: An input transform that is applied in the model's
                forward pass.

        Example:
            >>> train_X = torch.rand(20, 2)
            >>> train_Y = torch.sin(train_X).sum(dim=1, keepdim=True)
            >>> se = torch.norm(train_X, dim=1, keepdim=True)
            >>> train_Yvar = 0.1 + se * torch.rand_like(train_Y)
            >>> model = HeteroskedasticSingleTaskGP(train_X, train_Y, train_Yvar)
        """
        if outcome_transform is not None:
            train_Y, train_Yvar = outcome_transform(train_Y, train_Yvar)
        self._validate_tensor_args(X=train_X, Y=train_Y, Yvar=train_Yvar)
        validate_input_scaling(train_X=train_X, train_Y=train_Y, train_Yvar=train_Yvar)
        self._set_dimensions(train_X=train_X, train_Y=train_Y)
        noise_likelihood = GaussianLikelihood(
            noise_prior=SmoothedBoxPrior(-3, 5, 0.5, transform=torch.log),
            batch_shape=self._aug_batch_shape,
            noise_constraint=GreaterThan(
                MIN_INFERRED_NOISE_LEVEL, transform=None, initial_value=1.0
            ),
        )
        noise_model = SingleTaskGP(
            train_X=train_X,
            train_Y=train_Yvar,
            likelihood=noise_likelihood,
            outcome_transform=Log(),
            input_transform=input_transform,
        )
        # Modification: exp instead of the default softplus, so that the noise
        # model's log-scale output is mapped back to a variance.
        heteroskedastic_noise = HeteroskedasticNoise(
            noise_model=noise_model,
            noise_constraint=GreaterThan(
                MIN_INFERRED_NOISE_LEVEL, transform=torch.exp, inv_transform=torch.log
            ),
        )
        likelihood = _GaussianLikelihoodBase(heteroskedastic_noise)
        super().__init__(
            train_X=train_X,
            train_Y=train_Y,
            likelihood=likelihood,
            input_transform=input_transform,
        )
        self.register_added_loss_term("noise_added_loss")
        self.update_added_loss_term(
            "noise_added_loss", NoiseModelAddedLossTerm(noise_model)
        )
        if outcome_transform is not None:
            self.outcome_transform = outcome_transform
        self.to(train_X)


if __name__ == "__main__":
    # Define the training data
    X_train = torch.linspace(0, 6, 500)
    y_clean = 10 * torch.sin(X_train)
    y_noise = 0.05 + 0.1 * X_train.pow(2)
    y_train = y_clean + y_noise * torch.randn_like(X_train)
    y_train_var = y_noise ** 2

    X_train_botorch = X_train.unsqueeze(dim=-1)
    y_train_botorch = y_train.unsqueeze(dim=-1)
    y_train_var_botorch = y_train_var.unsqueeze(dim=-1)

    def plot_predictions(model):
        X_test = torch.linspace(-2, 8, 100)
        X_test_botorch = X_test.unsqueeze(dim=-1)
        fig, ax = plt.subplots(1, 1, figsize=(6, 4))
        with torch.no_grad():
            f_posterior = model.posterior(X_test_botorch)
            y_posterior = model.posterior(X_test_botorch, observation_noise=True)
            # Get confidence region of latent posterior
            f_lower, f_upper = f_posterior.mvn.confidence_region()
            # Plot training samples
            ax.plot(X_train.numpy(), y_train.numpy(), 'k*', label="Observed data")
            # Plot predictive mean in blue
            line, = ax.plot(X_test.numpy(), f_posterior.mean.numpy(), 'b', label="Mean")
            # Confidence interval according to latent posterior
            ax.fill_between(X_test.numpy(), f_lower.detach().cpu().numpy(), f_upper.detach().cpu().numpy(), alpha=0.3,
                            color=line.get_color(), label="Confidence (f)")
            # Get confidence region of posterior predictive distribution
            y_lower, y_upper = y_posterior.mvn.confidence_region()
            # Plot confidence interval from posterior predictive distribution
            ax.fill_between(X_test.numpy(), y_lower.detach().cpu().numpy(), y_upper.detach().cpu().numpy(), alpha=0.1,
                            color=line.get_color(), label="Confidence (y)")
            ax.legend(loc="best")
            # ax.set_ylim([-2, 2])
        fig.tight_layout()
        return fig, ax

    def plot_predicted_vs_true_observation_noise(model):
        fig, ax = plt.subplots(1, 1, figsize=(6, 4))
        with torch.no_grad():
            y_train_posterior = model.posterior(X_train_botorch, observation_noise=True)
        ax.plot(X_train, y_train_var, label="Ground truth observation noise")
        ax.plot(X_train, y_train_posterior.mvn.variance.detach().numpy(), label="Predicted observation noise")
        ax.set_xlabel("x")
        ax.set_ylabel("y_var")
        ax.legend(loc="best")
        fig.tight_layout()
        return fig, ax

    het_model = CustomHeteroskedasticSingleTaskGP(train_X=X_train_botorch, train_Y=y_train_botorch, train_Yvar=y_train_var_botorch)
    het_mll = ExactMarginalLogLikelihood(het_model.likelihood, het_model)
    het_mll = fit_gpytorch_model(het_mll)

    # Show fit
    plot_predictions(het_model)
    # Show ground truth variance vs. predicted variance
    plot_predicted_vs_true_observation_noise(het_model)
    plt.show()
```
@btlorch, that's interesting. In your plots, it definitely looks like switching to exp fixes things. For some reason, the same change leads to underestimating the noise in my example. Also, FYI, keeping outcome_transform=Log() is equivalent to simply replacing train_Y=train_Yvar with train_Y=torch.log(train_Yvar), since the transform is never undone during inference. It also seems that inv_transform=torch.log has no effect. I'll dig in some more to see what is happening in HeteroskedasticNoise that may lead to this inconsistent behavior.
@Balandat, I think the use of input_transform in the noise_model is also problematic, since train_X doesn't get stored in transformed form here.
> I think the use of input_transform in the noise_model is also problematic since the train_X doesn't get stored as transformed here.
Yeah. https://github.com/cornellius-gp/gpytorch/issues/1652 would help a lot here.
Re the outcome transform, the fundamental issue is that we'd have to use a BoTorch concept within a GPyTorch component, which is probably not something we want to do. One could also try to upstream the outcome transforms into GPyTorch, but that would require some thinking and a fair bit of work. The right thing to do seems to be to find a more direct short-term fix for this particular issue, and then take a step back and think about this class of problems more carefully.
Here's a solution that works for both @btlorch's example and mine. Note that I removed noise_likelihood (it leads to some weird predictions from the noise_model), replaced the outcome transform with a plain log of train_Yvar, and pre-trained the noise_model. I also dropped inv_transform=torch.log since it never gets used.
```python
noise_model = SingleTaskGP(
    train_X=train_X,
    train_Y=torch.log(train_Yvar),
    # likelihood=noise_likelihood,
    # outcome_transform=Log(),
    input_transform=input_transform,  # NOTE: potential bug here
)
mll = ExactMarginalLogLikelihood(noise_model.likelihood, noise_model)
fit_gpytorch_model(mll)
heteroskedastic_noise = HeteroskedasticNoise(
    noise_model=noise_model,
    noise_constraint=GreaterThan(MIN_INFERRED_NOISE_LEVEL, transform=torch.exp, inv_transform=None),
)
```
I don't particularly like it (especially the training part) but it's the only one that worked for both examples so far.
Interesting. Yeah, the fact that you have to pre-train the noise model isn't ideal: jointly training the noise model and the outer model that uses it will presumably result in a somewhat different overall model. Using non-inverse functions for the transforms is also a bit hacky...
@saitcakmak Do you know if this has been fixed?
@esantorella AFAIK, it is still broken
:(
:(