numpyro — GitHub issue

`SkewMultivariateNormal` and `SkewMultivariateStudentT`

Status: open · opened by colehaus 3 years ago · 4 comments

I have implemented versions of both of these:

from __future__ import annotations

from typing import Union, cast

import jax
from jax import lax
from jax import random
import jax.numpy as jnp
from jax.random import PRNGKey
from jax.scipy.linalg import cho_solve
from jax.scipy.special import gammaln, log_ndtr, ndtr
from numpy.typing import NDArray
from numpyro.distributions import (
    Chi2,
    Distribution,
    MultivariateNormal,
    MultivariateStudentT,
    Normal,
    constraints,
)
from numpyro.distributions.util import is_prng_key, promote_shapes, validate_sample

def delta(skewers_: NDArray[float], cov_: NDArray[float]):
    """Return ``cov @ s / sqrt(1 + s^T cov s)`` for skewness vector(s) ``s``.

    Leading batch dimensions of the two arguments are broadcast by ``einsum``.

    Args:
        skewers_: skewness vector(s) ``s`` with shape ``(..., d)``.
        cov_: matrix/matrices with shape ``(..., d, d)``.

    Returns:
        Array with shape ``(..., d)``.
    """
    cov_times_s = jnp.einsum("...ij,...j->...i", cov_, skewers_)
    quad_form = jnp.einsum("...j,...jk,...k->...", skewers_, cov_, skewers_)
    # Re-add the trailing axis so the per-batch normalizer divides every component.
    normalizer = jnp.sqrt(1 + quad_form)[..., jnp.newaxis]
    return cov_times_s / normalizer

# "Efficient computation of the distribution functions of Student's t,
# chi-squared and F to moderate accuracy" (J. Stat. Comput. Simul., 1982)
# https://doi.org/10.1080/00949658208810542
# We use an approximation because `betainc` doesn't have gradients defined for
# all of its arguments, which means we can't use the official `StudentT.cdf`.
@jax.jit
def t_cdf_approx(df: Union[NDArray[float], float], t: Union[NDArray[float], float]):
    """Approximate the CDF of a standard Student's t distribution.

    Maps ``t`` to an approximately standard-normal deviate ``u`` (a polynomial
    correction of ``z = sqrt((df - 1/2) * log(1 + t**2 / df))``) and returns
    ``Phi(sign(t) * u)``. The result is differentiable with respect to both
    ``df`` and ``t``.

    Args:
        df: degrees of freedom (broadcastable against ``t``).
        t: point(s) at which to evaluate the CDF.

    Returns:
        Approximate ``P(T <= t)`` with the broadcast shape of ``df`` and ``t``.
    """
    a = df - 0.5
    b = 48 * a**2
    # `log1p` keeps precision when t**2 / df is tiny; `log(1 + x)` rounds to
    # exactly 0 there in float32, collapsing the CDF to 0.5.
    # The epsilon avoids an undefined (infinite) gradient of sqrt at t = 0.
    z = jnp.sqrt(a * jnp.log1p(t**2 / df) + 1e-24)
    u = (
        z
        + (z**3 + 3 * z) / b
        - (4 * z**7 + 33 * z**5 + 240 * z**3 + 855 * z) / (10 * b * (b + 0.8 * z**4 + 100))
    )
    # z only depends on t**2, so restore the sign; sign(0) == 0 yields exactly 0.5.
    return ndtr(u * jnp.sign(t))

# Regularized Multivariate Regression Models with Skew-t Error Distributions
# https://epublications.marquette.edu/cgi/viewcontent.cgi?article=1225&context=mscs_fac
class SkewMultivariateNormal(Distribution):  # type: ignore # pylint: disable=too-many-instance-attributes
    """Multivariate skew-normal distribution.

    Density::

        p(x) = 2 * N(x; loc, S) * Phi(skewers . ((x - loc) / sigma))

    where ``S = scale_tril @ scale_tril.T``, ``sigma = sqrt(diag(S))`` and
    ``Phi`` is the standard normal CDF.

    Args:
        scale_tril: lower Cholesky factor of ``S``, shape ``(..., d, d)``.
        skewers: skewness parameters, shape ``(..., d)``.
        loc: location, shape ``(..., d)``, or a scalar applied to every component.
        validate_args: forwarded to ``Distribution``.
    """

    arg_constraints = {
        "loc": constraints.real_vector,
        "scale_tril": constraints.lower_cholesky,
        "skewers": constraints.real_vector,
    }
    support = constraints.real_vector
    reparametrized_params = ["loc", "scale_tril", "skewers"]
    # Kept as a public attribute for backward compatibility; `log_prob` uses `log_ndtr`.
    uv_norm = Normal(0.0, 1.0)

    @staticmethod
    def mk_big_mv_norm(loc: NDArray[float], skewers: NDArray[float], scale_tril: NDArray[float]):
        """Build the zero-mean (d + 1)-dimensional normal used by ``sample``.

        Its covariance is ``[[1, delta^T], [delta, S]]`` with
        ``delta = delta(skewers, S)`` and ``S = scale_tril @ scale_tril.T``.

        Args:
            loc: unused; kept for backward compatibility. The event size is
                taken from ``scale_tril`` because a scalar ``loc`` is promoted
                to shape ``(..., 1)`` and does not carry the event size.
            skewers: skewness vector(s), shape ``(..., d)``.
            scale_tril: lower Cholesky factor(s), shape ``(..., d, d)``.
        """
        del loc
        cov = jnp.einsum("...ij,...hj->...ih", scale_tril, scale_tril)
        delta_ = delta(skewers, cov)
        width = cov.shape[-1]
        # `jnp.block` does not broadcast batch dimensions, so align them explicitly.
        batch_shape = lax.broadcast_shapes(delta_.shape[:-1], cov.shape[:-2])
        cov = jnp.broadcast_to(cov, batch_shape + (width, width))
        delta_ = jnp.broadcast_to(delta_, batch_shape + (width,))
        cov_star = jnp.block(
            [
                [jnp.ones(batch_shape + (1, 1), dtype=cov.dtype), delta_[..., jnp.newaxis, :]],
                [delta_[..., :, jnp.newaxis], cov],
            ]
        )
        return MultivariateNormal(
            loc=jnp.zeros(width + 1, dtype=cov.dtype), scale_tril=jnp.linalg.cholesky(cov_star)
        )

    def __init__(
        self,
        scale_tril: NDArray[float],
        skewers: NDArray[float],
        loc: Union[NDArray[float], float] = 0,
        validate_args: bool | None = None,
    ):
        if jnp.ndim(loc) == 0:
            (loc_,) = promote_shapes(loc, shape=(1,))
        else:
            loc_ = cast(NDArray[float], loc)
        batch_shape = lax.broadcast_shapes(
            jnp.shape(loc_)[:-1], jnp.shape(scale_tril)[:-2], jnp.shape(skewers)[:-1]
        )
        (self.loc,) = promote_shapes(loc_, shape=batch_shape + jnp.shape(loc_)[-1:])
        (self.skewers,) = promote_shapes(skewers, shape=batch_shape + jnp.shape(skewers)[-1:])
        (self.scale_tril,) = promote_shapes(scale_tril, shape=batch_shape + jnp.shape(scale_tril)[-2:])
        cov_batch = jnp.einsum("...ij,...hj->...ih", self.scale_tril, self.scale_tril)
        self._std_devs = jnp.sqrt(jnp.diagonal(cov_batch, axis1=-2, axis2=-1))
        standardized_skewers = self.skewers / self._std_devs

        # Used for sampling. The blog post uses unstandardized skewers here, but
        # that leads to a discrepancy between sampling and log_prob.
        self._big_mv_norm = self.mk_big_mv_norm(
            loc=self.loc,
            skewers=standardized_skewers,
            scale_tril=self.scale_tril,
        )
        # Used for log_prob.
        self._mv_norm = MultivariateNormal(self.loc, scale_tril=self.scale_tril)

        skew_mean = jnp.sqrt(2 / jnp.pi) * delta(standardized_skewers, cov_batch)
        self._mean = self.loc + skew_mean
        # The paper uses `mean` here, but that is wrong: it can produce covariance
        # matrices which are not positive semi-definite.
        self._covariance = cov_batch - jnp.einsum("...i,...j->...ij", skew_mean, skew_mean)

        event_shape = jnp.shape(self.scale_tril)[-1:]
        super().__init__(
            batch_shape=batch_shape,
            event_shape=event_shape,
            validate_args=validate_args,
        )

    @validate_sample
    def log_prob(self, value: NDArray[float]) -> NDArray[float]:
        """Return ``log 2 + log N(value; loc, S) + log Phi(skewers . ((value - loc) / sigma))``."""
        skew_arg = jnp.einsum("...k,...k->...", (value - self.loc) / self._std_devs, self.skewers)
        # `log_ndtr` stays finite where `log(Phi(x))` underflows to -inf for x << 0.
        return jnp.log(2) + self._mv_norm.log_prob(value) + log_ndtr(skew_arg)

    @staticmethod
    def infer_shapes(loc: NDArray[float], scale_tril: NDArray[float], skewers: NDArray[float]):
        """Infer ``(batch_shape, event_shape)`` from the parameter *shapes*."""
        event_shape = (scale_tril[-1],)
        batch_shape = lax.broadcast_shapes(loc[:-1], scale_tril[:-2], skewers[:-1])
        return batch_shape, event_shape

    # https://gregorygundersen.com/blog/2020/12/29/multivariate-skew-normal/
    def sample(self, key: PRNGKey, sample_shape: tuple[int, ...] = ()) -> NDArray[float]:
        """Draw ``(x0, x)`` from the (d + 1)-dim normal; return ``loc + x`` if ``x0 > 0`` else ``loc - x``."""
        assert is_prng_key(key)
        x = self._big_mv_norm.sample(key, sample_shape=sample_shape)
        sign_bit, samples = x[..., 0, jnp.newaxis], x[..., 1:]
        return jnp.where(sign_bit <= 0, -1 * samples, samples) + self.loc

    @property
    def mean(self):
        """Mean ``loc + sqrt(2 / pi) * delta``, broadcast to ``batch_shape + event_shape``."""
        return jnp.broadcast_to(self._mean, self.shape())

    @property
    def covariance_matrix(self):
        """Covariance ``S - m m^T`` where ``m = mean - loc``, broadcast to the full batch shape."""
        return jnp.broadcast_to(self._covariance, self.batch_shape + self.event_shape + self.event_shape)

# https://epublications.marquette.edu/cgi/viewcontent.cgi?article=1225&context=mscs_fac
class SkewMultivariateStudentT(Distribution):  # type: ignore # pylint: disable=too-many-instance-attributes
    """Multivariate skew-t distribution.

    Density::

        p(x) = 2 * t_d(x; df, loc, S)
                 * T_1(skewers . ((x - loc) / sigma) * sqrt((df + d) / (Q + df)); df + d)

    where ``S = scale_tril @ scale_tril.T``, ``sigma = sqrt(diag(S))``,
    ``Q = (x - loc)^T S^{-1} (x - loc)`` and ``T_1`` is the univariate t CDF
    (approximated by ``t_cdf_approx``).

    Args:
        df: degrees of freedom (batch-shaped).
        scale_tril: lower Cholesky factor of ``S``, shape ``(..., d, d)``.
        skewers: skewness parameters, shape ``(..., d)``.
        loc: location, shape ``(..., d)``, or a scalar applied to every component.
        validate_args: forwarded to ``Distribution``.
    """

    arg_constraints = {
        "df": constraints.positive,
        "loc": constraints.real_vector,
        "scale_tril": constraints.lower_cholesky,
        "skewers": constraints.real_vector,
    }
    support = constraints.real_vector
    reparametrized_params = ["df", "loc", "scale_tril", "skewers"]

    def __init__(  # pylint: disable=too-many-arguments
        self,
        df: float,
        scale_tril: NDArray[float],
        skewers: NDArray[float],
        loc: Union[NDArray[float], float] = 0,
        validate_args: bool | None = None,
    ):
        if jnp.ndim(loc) == 0:
            (loc_,) = promote_shapes(loc, shape=(1,))
        else:
            loc_ = cast(NDArray[float], loc)
        batch_shape = lax.broadcast_shapes(
            jnp.shape(df), jnp.shape(loc_)[:-1], jnp.shape(scale_tril)[:-2], jnp.shape(skewers)[:-1]
        )
        (self.df,) = promote_shapes(df, shape=batch_shape)
        (self.loc,) = promote_shapes(loc_, shape=batch_shape + jnp.shape(loc_)[-1:])
        (self.skewers,) = promote_shapes(skewers, shape=batch_shape + jnp.shape(skewers)[-1:])
        (self.scale_tril,) = promote_shapes(scale_tril, shape=batch_shape + jnp.shape(scale_tril)[-2:])

        self._width = jnp.shape(scale_tril)[-1]
        width = self._width
        # `promote_shapes` only prepends size-1 axes. Batched triangular solves and
        # the skew normal's `jnp.block` need the full batch shape materialized.
        full_scale_tril = jnp.broadcast_to(self.scale_tril, batch_shape + (width, width))
        full_skewers = jnp.broadcast_to(self.skewers, batch_shape + (width,))

        # For log_prob
        self._mv_t = MultivariateStudentT(df=df, scale_tril=scale_tril, loc=loc)
        # S^{-1} directly from the Cholesky factor. Taking a Cholesky of this result
        # and multiplying it back together would only reproduce the same matrix
        # at an extra O(d^3) cost.
        eye = jnp.broadcast_to(jnp.eye(width), batch_shape + (width, width))
        self.prec = cho_solve((full_scale_tril, True), eye)
        cov_batch = jnp.einsum("...ij,...hj->...ih", self.scale_tril, self.scale_tril)
        self._std_devs = jnp.sqrt(jnp.diagonal(cov_batch, axis1=-2, axis2=-1))

        # For sample. Use fully broadcast parameters so the skew normal's batch shape
        # matches `df`'s, keeping sample and batch axes aligned in `sample`.
        self._mv_skew_norm = SkewMultivariateNormal(
            scale_tril=full_scale_tril, loc=jnp.zeros(width), skewers=full_skewers
        )
        self._chi2 = Chi2(self.df)

        df_arr = jnp.asarray(self.df)
        # Mean: loc + b * delta with b = sqrt(df / pi) * Gamma((df - 1) / 2) / Gamma(df / 2).
        # It exists only for df > 1. `gammaln` returns finite but meaningless values
        # for negative arguments, so evaluate at a safe df and mask the result
        # (this also keeps gradients free of NaN).
        mean_defined = df_arr > 1
        safe_df = jnp.where(mean_defined, df_arr, 2.0)
        b = jnp.sqrt(safe_df / jnp.pi) * jnp.exp(gammaln((safe_df - 1) / 2) - gammaln(safe_df / 2))
        skew_mean = b[..., jnp.newaxis] * delta(self.skewers / self._std_devs, cov_batch)
        self._mean = jnp.where(mean_defined[..., jnp.newaxis], self.loc + skew_mean, jnp.inf)

        # The paper says we should multiply by the std devs, but that produces results
        # that disagree with `sample` and with `SkewMultivariateNormal`.
        # It also says we should use `_mean` instead of `skew_mean`, but that allows
        # covariance matrices which are not positive semi-definite.
        # The covariance is finite only for df > 2: it is inf for 1 < df <= 2 and NaN for df <= 1.
        cov_defined = df_arr > 2
        safe_df_cov = jnp.where(cov_defined, df_arr, 3.0)
        cov_scale = (safe_df_cov / (safe_df_cov - 2))[..., jnp.newaxis, jnp.newaxis]
        covariance = cov_scale * cov_batch - jnp.einsum("...i,...j->...ij", skew_mean, skew_mean)
        undefined_fill = jnp.where(mean_defined, jnp.inf, jnp.nan)[..., jnp.newaxis, jnp.newaxis]
        self._covariance = jnp.where(cov_defined[..., jnp.newaxis, jnp.newaxis], covariance, undefined_fill)

        event_shape = jnp.shape(self.scale_tril)[-1:]
        super().__init__(
            batch_shape=batch_shape,
            event_shape=event_shape,
            validate_args=validate_args,
        )

    @validate_sample
    def log_prob(self, value: NDArray[float]) -> NDArray[float]:
        """Return ``log 2 + log t_d(value) + log T_1(skew_arg; df + d)``."""
        distance = value - self.loc
        # Squared Mahalanobis distance (x - loc)^T S^{-1} (x - loc).
        Qy = jnp.einsum("...j,...jk,...k->...", distance, self.prec, distance)
        df_skew = self.df + self._width
        skew_arg = jnp.einsum("...k,...k->...", self.skewers, distance / self._std_devs) * jnp.sqrt(
            df_skew / (Qy + self.df)
        )
        # Have to use an approximation because `betainc` doesn't have grads defined,
        # which means we can't use the official `StudentT.cdf`.
        # NOTE: jnp.log(skew) can reach -inf if the approximate CDF underflows for
        # very negative arguments.
        skew = t_cdf_approx(df_skew, skew_arg)
        return jnp.log(2) + self._mv_t.log_prob(value) + jnp.log(skew)

    @staticmethod
    def infer_shapes(df: float, loc: NDArray[float], scale_tril: NDArray[float], skewers: NDArray[float]):
        """Infer ``(batch_shape, event_shape)`` from the parameter *shapes*."""
        event_shape = (scale_tril[-1],)
        batch_shape = lax.broadcast_shapes(df, loc[:-1], scale_tril[:-2], skewers[:-1])
        return batch_shape, event_shape

    def sample(self, key: PRNGKey, sample_shape: tuple[int, ...] = ()) -> NDArray[float]:
        """Return ``loc + z * sqrt(df / w)`` with ``z`` ~ zero-location skew normal and ``w`` ~ Chi2(df)."""
        assert is_prng_key(key)
        key_normal, key_chi2 = random.split(key)
        normal = self._mv_skew_norm.sample(key_normal, sample_shape=sample_shape)
        chi2 = self._chi2.sample(key_chi2, sample_shape)
        # Trailing axis lets the per-draw scale multiply every event component.
        return self.loc + normal * jnp.sqrt(self.df / chi2)[..., jnp.newaxis]

    @property
    def mean(self):
        """Mean, broadcast to ``batch_shape + event_shape``; inf where df <= 1."""
        return jnp.broadcast_to(self._mean, self.shape())

    @property
    def covariance_matrix(self):
        """Covariance, broadcast to the full batch shape; inf for 1 < df <= 2, NaN for df <= 1."""
        return jnp.broadcast_to(self._covariance, self.batch_shape + self.event_shape + self.event_shape)

(I also have some code testing them.)

  1. Is there interest in upstreaming these?
  2. Are there obvious simplifications?
  3. SkewMultivariateStudentT is notably slower than MultivariateStudentT in some circumstances. Are there any obvious performance improvements available?

— colehaus, Jul 17 '22 04:07

@colehaus yes, i'm sure a PR would be welcome.

Would it help to use tfp.math.betainc?

import tensorflow_probability.substrates.jax as tfp

— martinjankowiak, Jul 17 '22 18:07

Unless I'm misunderstanding you, there's a comment in the source describing the problem: `betainc` doesn't have gradients defined with respect to all of its arguments (https://github.com/tensorflow/probability/issues/655#issuecomment-558236514).

— colehaus, Jul 17 '22 19:07

FYI, since the last release `tfp.math.betainc` has gradients w.r.t. all parameters. I would suggest splitting this into 3 PRs:

  • `StudentT.cdf` (which locally imports `tensorflow_probability.substrates.jax as tfp`)
  • SkewMVN
  • SkewMVT

— fehiepsi, Jul 17 '22 21:07

Ah, that's good news! And that sounds like a reasonable plan. I probably won't be able to think about doing it for a few weeks though.

— colehaus, Jul 18 '22 18:07