Optimize `_kl_matrix_normal_matrix_normal` for `tfd.MatrixNormalLinearOperator`
Hello! I was glad to notice the introduction of `tfd.MatrixNormalLinearOperator` in 0.13 (even if I saw it a bit late, i.e., today). However, its KL divergence still relies on the general implementation for `MultivariateNormalLinearOperator`, so a dense operation is performed on the scale parameter (a Kronecker product), which can be very expensive in high dimensions. I had implemented a MatrixNormal distribution before this one was introduced, including the explicit KL divergence for the case where the `LinearOperatorKronecker` has exactly two factors (i.e., the MatrixNormal case). The implementation is the following (based on `_kl_brute_force` from `mvn_linear_operator`):
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import tensorshape_util

tfb = tfp.bijectors


def _kl_matrix_normal_matrix_normal_optimized(a, b, name=None):
    """Batched KL divergence `KL(a || b)` for Matrix Normals.

    Args:
      a: Instance of `MatrixNormalLinearOperator`.
      b: Instance of `MatrixNormalLinearOperator`.
      name: (optional) name to use for created ops. Default "kl_mn".

    Returns:
      Batchwise `KL(a || b)`.
    """
    def squared_frobenius_norm(x):
        """Helper to make KL calculation slightly more readable."""
        # http://mathworld.wolfram.com/FrobeniusNorm.html
        # The gradient of KL[p, q] is not defined when p == q. The culprit is
        # tf.norm, i.e., we cannot use the commented-out code.
        # return tf.square(tf.norm(x, ord="fro", axis=[-2, -1]))
        return tf.reduce_sum(tf.square(x), axis=[-2, -1])

    with tf.name_scope(name or "kl_mn"):
        # Calculation is based on:
        # http://stats.stackexchange.com/questions/60680/kl-divergence-between-two-multivariate-gaussians
        # and,
        # https://en.wikipedia.org/wiki/Matrix_norm#Frobenius_norm
        # i.e.,
        # if Ca = AA', Cb = BB', then
        #   tr[inv(Cb) Ca] = tr[inv(B)' inv(B) A A']
        #                  = tr[inv(B) A A' inv(B)']
        #                  = tr[(inv(B) A) (inv(B) A)']
        #                  = sum_{ij} (inv(B) A)_{ij}**2
        #                  = ||inv(B) A||_F**2
        # where ||.||_F is the Frobenius norm and the second equality follows
        # from the cyclic permutation property.
        KInvS_h = b.scale_row.solve(a.scale_row.to_dense())
        KInvS_x = b.scale_column.solve(a.scale_column.to_dense())
        Mt = b.mean() - a.mean()
        transpose = tfb.Transpose(rightmost_transposed_ndims=2)
        n = tf.cast(b.scale_row.domain_dimension_tensor(), b.dtype)
        p = tf.cast(b.scale_column.domain_dimension_tensor(), b.dtype)
        kl_div = (
            p * (b.scale_row.log_abs_determinant()
                 - a.scale_row.log_abs_determinant())
            + n * (b.scale_column.log_abs_determinant()
                   - a.scale_column.log_abs_determinant())
            - 0.5 * n * p
            + 0.5 * squared_frobenius_norm(KInvS_h)
                  * squared_frobenius_norm(KInvS_x)
            + 0.5 * tf.reduce_sum(
                tf.linalg.cholesky_solve(
                    b.scale_column.to_dense(), transpose.forward(Mt))
                * transpose.forward(
                    tf.linalg.cholesky_solve(b.scale_row.to_dense(), Mt)),
                [-1, -2]))
        tensorshape_util.set_shape(
            kl_div, tf.broadcast_static_shape(a.batch_shape, b.batch_shape))
        return kl_div
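For reference, the closed-form expression the function computes is (assuming the usual matrix normal conventions: the mean `M` is `n x p`, `U = scale_row @ scale_row^T` is the `n x n` among-row covariance, and `V = scale_column @ scale_column^T` is the `p x p` among-column covariance):

$$
\mathrm{KL}(a \,\|\, b) = \tfrac{1}{2}\Big[\, p \ln\tfrac{|U_b|}{|U_a|} + n \ln\tfrac{|V_b|}{|V_a|} - np + \operatorname{tr}(U_b^{-1} U_a)\,\operatorname{tr}(V_b^{-1} V_a) + \operatorname{tr}\big(V_b^{-1} (M_b - M_a)^\top U_b^{-1} (M_b - M_a)\big) \Big]
$$

Because the covariance is Kronecker-structured, the big trace term factors into the product of two small traces (the two squared Frobenius norms above), so nothing of size `np x np` is ever formed.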
A quick correctness check and timing comparison against the currently registered implementation:

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from sklearn import datasets

tfd = tfp.distributions

# Initialize a single 2 x 500 Matrix Normal.
n_dim = 500
mu = tf.ones((2, n_dim))

col_cov = datasets.make_spd_matrix(n_dim).astype(np.float32)
scale_column = tf.linalg.LinearOperatorLowerTriangular(tf.linalg.cholesky(col_cov))
scale_row = tf.linalg.LinearOperatorDiag([0.9, 0.8])
mvn = tfd.MatrixNormalLinearOperator(
    loc=mu,
    scale_row=scale_row,
    scale_column=scale_column)

# A second Matrix Normal with different scales.
col_cov = datasets.make_spd_matrix(n_dim).astype(np.float32)
scale_column = tf.linalg.LinearOperatorLowerTriangular(tf.linalg.cholesky(col_cov))
scale_row = tf.linalg.LinearOperatorDiag([0.2, 0.4])
mvn2 = tfd.MatrixNormalLinearOperator(
    loc=mu,
    scale_row=scale_row,
    scale_column=scale_column)
# The optimized KL agrees with the currently registered (brute-force) one.
tf.debugging.assert_near(
    tfd.kl_divergence(mvn, mvn2),
    _kl_matrix_normal_matrix_normal_optimized(mvn, mvn2))
%%timeit
tfd.kl_divergence(mvn, mvn2)
# 46.9 ms ± 14.3 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

%%timeit
_kl_matrix_normal_matrix_normal_optimized(mvn, mvn2)
# 11.2 ms ± 4.6 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
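The gap is essentially the dense Kronecker materialization in the current code path. As a rough illustration of the sizes involved (a minimal sketch using plain `tf.linalg` operators, not the exact internals of the registered KL), the vectorized scale of the 2 x 500 example above densifies to a 1000 x 1000 matrix:

import tensorflow as tf

# Sketch only: the registered KL goes through the vectorized (MVN) view, whose
# scale is the Kronecker product of the row and column scales.
row_scale = tf.linalg.LinearOperatorDiag([0.9, 0.8])        # 2 x 2
col_scale = tf.linalg.LinearOperatorIdentity(num_rows=500)  # 500 x 500 stand-in
kron_scale = tf.linalg.LinearOperatorKronecker([row_scale, col_scale])

print(kron_scale.shape)        # (1000, 1000)
dense = kron_scale.to_dense()  # materializes the full 1000 x 1000 matrix

whereas the optimized version only ever solves against the 2 x 2 and 500 x 500 factors.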
If this is welcome, I would happily open a PR. If you also want to keep the current (and surely more readable) implementation, though, I am not sure how to let users select one or the other implementation when computing the KL divergence.
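For what it's worth, as far as I know `tfd.kl_divergence` dispatches through a single registration per pair of distribution types, so the two implementations could not be selected at call time; the optimized version would instead replace the body of the registered `_kl_matrix_normal_matrix_normal`. A sketch of what that in-library registration could look like (not runnable as user code, since `tfd.RegisterKL` raises if the same pair is registered twice):

import tensorflow_probability as tfp

tfd = tfp.distributions


# In-library sketch: this would replace the existing registration for the
# (MatrixNormalLinearOperator, MatrixNormalLinearOperator) pair rather than
# coexist with it.
@tfd.RegisterKL(tfd.MatrixNormalLinearOperator, tfd.MatrixNormalLinearOperator)
def _kl_matrix_normal_matrix_normal(a, b, name=None):
    return _kl_matrix_normal_matrix_normal_optimized(a, b, name=name)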
Hi, contributions welcome! Feel free to open a PR with this!