`ZeroSum` bijector and `ZeroSumNormal` distribution
NumPyro and PyMC have a zero-sum normal distribution based on a zero-sum bijector (see e.g. NumPyro's ZeroSumNormal and ZeroSumTransform).
I was wondering if there is any appetite for adding this to TFP? I already have a simple port working (it still needs some changes, in particular perhaps allowing a variable number of axes to be constrained to sum to zero):
"""ZeroSum bijector."""
import tensorflow as tf
from tensorflow_probability.python.bijectors import bijector
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.internal import tensorshape_util
class ZeroSum(bijector.AutoCompositeTensorBijector):
def __init__(self, validate_args=False, name="zero_sum"):
parameters = dict(locals())
super(ZeroSum, self).__init__(
is_constant_jacobian=True,
forward_min_event_ndims=1,
validate_args=validate_args,
parameters=parameters,
name=name,
)
@classmethod
def _parameter_properties(cls, dtype):
return dict()
def _forward(self, x):
n = ps.cast(ps.shape(x)[-1], x.dtype) + 1
sum_vals = tf.reduce_sum(x, axis=-1, keepdims=True)
norm = sum_vals / (ps.sqrt(n) + n)
fill_val = norm - sum_vals / ps.sqrt(n)
out = tf.concat([x, fill_val], axis=-1)
return out - norm
def _inverse(self, y):
normalized_axis = ps.rank(y) - 1
n = ps.cast(ps.shape(y)[normalized_axis], y.dtype)
last = y[..., -1]
sum_vals = -last * ps.sqrt(n)
norm = sum_vals / (ps.sqrt(n) + n)
slice_before = (slice(None, None),) * normalized_axis
return y[(*slice_before, slice(None, -1))] + norm
def _inverse_log_det_jacobian(self, y):
return tf.zeros([], dtype=y.dtype)
def _forward_log_det_jacobian(self, x):
return tf.zeros([], dtype=x.dtype)
def _forward_event_shape(self, input_shape):
return tensorshape_util.concatenate(input_shape[:-1], input_shape[-1] + 1)
def _forward_event_shape_tensor(self, input_shape):
n = ps.shape(input_shape)[-1]
return ps.tensor_scatter_nd_add(input_shape, [[n - 1]], [1])
def _inverse_event_shape(self, input_shape):
return tensorshape_util.concatenate(input_shape[:-1], input_shape[-1] + 1)
def _inverse_event_shape_tensor(self, input_shape):
n = ps.shape(input_shape)[-1]
return ps.tensor_scatter_nd_sub(input_shape, [[n - 1]], [1])
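As a quick sanity check (a minimal sketch, assuming the ZeroSum class above is in scope), the forward output sums to zero along the last axis and the inverse recovers the input:
import tensorflow as tf

bij = ZeroSum()
x = tf.random.normal([4, 2])                        # batch of unconstrained vectors
y = bij.forward(x)                                  # shape [4, 3]
tf.reduce_max(tf.abs(tf.reduce_sum(y, axis=-1)))    # ~0 up to float32 round-off
tf.reduce_max(tf.abs(bij.inverse(y) - x))           # ~0, round trip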
usage:
import numpy as np
import tensorflow_probability as tfp
tfd = tfp.distributions
zero_sum_normal = tfd.TransformedDistribution(
    distribution=tfd.MultivariateNormalDiag(loc=0.0, scale_diag=[1.0, 1.0]),
    bijector=ZeroSum(),
)
zero_sum_normal
# <tfp.distributions.TransformedDistribution 'zero_sumMultivariateNormalDiag' batch_shape=[] event_shape=[3] dtype=float32>
samples = zero_sum_normal.sample(int(1e7))
np.max(np.abs(np.sum(samples, axis=-1)))
# 4.7683716e-07
np.mean(samples, axis=0)
# array([ 4.8274879e-04, -5.6865485e-04, 8.5900021e-05], dtype=float32)
np.std(samples, axis=0)
# array([0.8100124, 0.8101918, 0.8100392], dtype=float32)
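With the inverse broadcasting over batches, log_prob should also go through the usual TransformedDistribution machinery (the log-det Jacobian is defined as zero, so this is a density with respect to the zero-sum hyperplane); a small sketch, reusing zero_sum_normal and samples from above:
log_probs = zero_sum_normal.log_prob(samples[:5])   # shape [5]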
compare to NumPyro:
import jax.numpy as jnp
import jax.random as jr
import numpyro.distributions as dist
zero_sum_normal = dist.ZeroSumNormal(scale=jnp.array(1.0), event_shape=[3])
rng = jr.key(123)
samples = zero_sum_normal.sample(rng, sample_shape=(int(1e7),))
jnp.max(jnp.abs(jnp.sum(samples, axis=1)))
# Array(5.9604645e-07, dtype=float32)
jnp.mean(samples, axis=0)
# Array([-1.7739683e-04, -6.7088688e-05, 2.4448565e-04], dtype=float32)
jnp.std(samples, axis=0)
# Array([0.8164292 , 0.81622946, 0.8164195 ], dtype=float32)
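On the distribution side, a thin ZeroSumNormal-style convenience wrapper could then sit on top of the bijector. The sketch below is only illustrative (the helper name and arguments are hypothetical, not an existing TFP API); it simply packages the TransformedDistribution construction from the usage example above, reusing tfd and ZeroSum:
def make_zero_sum_normal(scale, event_size):
  # Hypothetical helper: the base normal has `event_size - 1` unconstrained
  # dimensions and is pushed onto the zero-sum hyperplane by the bijector.
  return tfd.TransformedDistribution(
      distribution=tfd.MultivariateNormalDiag(
          scale_diag=[scale] * (event_size - 1)),
      bijector=ZeroSum(),
  )

zsn = make_zero_sum_normal(scale=1.0, event_size=3)  # event_shape=[3]
Open API questions would include what scale should parameterize (the unconstrained base normal, as in this sketch, or the marginal scale of each constrained component) and the variable-number-of-axes generalization mentioned above.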
If this would be useful, I can work on bits of it over the next couple of weeks, or if someone else wants to take it over, that's great too.
Thanks.