
Optimize `_kl_matrix_normal_matrix_normal` for `tfd.MatrixNormalLinearOperator`

Open · fdtomasi opened this issue on Jan 14, 2022 · 1 comment

Hello! I was glad to notice the introduction of tfd.MatrixNormal in 0.13 (even if I saw it a bit late, i.e., today). However, its KL divergence still falls back to the generic implementation for MultivariateNormalLinearOperator, which densifies the scale parameter (a Kronecker product); this can be very expensive in high dimensions. I had implemented a MatrixNormal distribution before this one was introduced, including an explicit KL divergence for the case where the LinearOperatorKronecker has only two factors (i.e., the MatrixNormal case). The implementation is the following (based on _kl_brute_force of MultivariateNormalLinearOperator):

import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import tensorshape_util

tfb = tfp.bijectors


def _kl_matrix_normal_matrix_normal_optimized(a, b, name=None):
    """Batched KL divergence `KL(a || b)` for Matrix Normals.

    Args:
      a: Instance of `MatrixNormalLinearOperator`.
      b: Instance of `MatrixNormalLinearOperator`.
      name: (optional) name to use for created ops. Default "kl_mn".

    Returns:
      Batchwise `KL(a || b)`.
    """

    def squared_frobenius_norm(x):
        """Helper to make KL calculation slightly more readable."""
        # http://mathworld.wolfram.com/FrobeniusNorm.html
        # The gradient of KL[p,q] is not defined when p==q. The culprit is
        # tf.norm, i.e., we cannot use the commented out code.
        # return tf.square(tf.norm(x, ord="fro", axis=[-2, -1]))
        return tf.reduce_sum(tf.square(x), axis=[-2, -1])

    with tf.name_scope(name or "kl_mn"):
        # Calculation is based on:
        # http://stats.stackexchange.com/questions/60680/kl-divergence-between-two-multivariate-gaussians
        # and,
        # https://en.wikipedia.org/wiki/Matrix_norm#Frobenius_norm
        # i.e.,
        #   If Ca = AA', Cb = BB', then
        #   tr[inv(Cb) Ca] = tr[inv(B)' inv(B) A A']
        #                  = tr[inv(B) A A' inv(B)']
        #                  = tr[(inv(B) A) (inv(B) A)']
        #                  = sum_{ij} (inv(B) A)_{ij}**2
        #                  = ||inv(B) A||_F**2
        # where ||.||_F is the Frobenius norm and the second equality follows from
        # the cyclic permutation property.
        # tr[inv(U_b) U_a] = ||inv(B_row) A_row||_F**2, and likewise for the
        # column factors (see the derivation above).
        KInvS_h = b.scale_row.solve(a.scale_row.to_dense())
        KInvS_x = b.scale_column.solve(a.scale_column.to_dense())

        # Difference of means and a batched matrix-transpose bijector.
        Mt = b.mean() - a.mean()
        transpose = tfb.Transpose(rightmost_transposed_ndims=2)

        n = tf.cast(b.scale_row.domain_dimension_tensor(), b.dtype)
        p = tf.cast(b.scale_column.domain_dimension_tensor(), b.dtype)

        kl_div = (
            p * (b.scale_row.log_abs_determinant() - a.scale_row.log_abs_determinant())
            + n * (b.scale_column.log_abs_determinant() - a.scale_column.log_abs_determinant())
            - 0.5 * n * p
            + 0.5 * squared_frobenius_norm(KInvS_h) * squared_frobenius_norm(KInvS_x)
            + 0.5 * tf.reduce_sum(
                tf.linalg.cholesky_solve(b.scale_column.to_dense(), transpose.forward(Mt))
                * transpose.forward(tf.linalg.cholesky_solve(b.scale_row.to_dense(), Mt)),
                axis=[-1, -2],
            )
        )
        tensorshape_util.set_shape(
            kl_div, tf.broadcast_static_shape(a.batch_shape, b.batch_shape)
        )
        return kl_div
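
For reference, writing $U$ for the $n \times n$ row covariance (`scale_row @ scale_row^T`) and $V$ for the $p \times p$ column covariance (`scale_column @ scale_column^T`), the closed-form KL that the function above computes is (if I got the algebra right):

$$
\mathrm{KL}(a \,\|\, b) = \tfrac{1}{2}\Big[
\operatorname{tr}\big(U_b^{-1} U_a\big)\,\operatorname{tr}\big(V_b^{-1} V_a\big)
- np
+ p \ln\tfrac{|U_b|}{|U_a|}
+ n \ln\tfrac{|V_b|}{|V_a|}
+ \operatorname{tr}\big(V_b^{-1} (M_b - M_a)^\top U_b^{-1} (M_b - M_a)\big)
\Big],
$$

so every term is computed on the $n \times n$ and $p \times p$ factors instead of the dense $np \times np$ Kronecker product.

A quick correctness check and benchmark on a 2 x 500 matrix normal: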
from sklearn import datasets
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Initialize a single 2 x 500 Matrix Normal.
n_dim = 500
mu = tf.ones((2, n_dim))
col_cov = datasets.make_spd_matrix(n_dim).astype(np.float32)
scale_column = tf.linalg.LinearOperatorLowerTriangular(tf.linalg.cholesky(col_cov))
scale_row = tf.linalg.LinearOperatorDiag([0.9, 0.8])

mvn = tfd.MatrixNormalLinearOperator(
    loc=mu,
    scale_row=scale_row,
    scale_column=scale_column
)

col_cov = datasets.make_spd_matrix(n_dim).astype(np.float32)
scale_column = tf.linalg.LinearOperatorLowerTriangular(tf.linalg.cholesky(col_cov))
scale_row = tf.linalg.LinearOperatorDiag([0.2, 0.4])

mvn2 = tfd.MatrixNormalLinearOperator(
    loc=mu,
    scale_row=scale_row,
    scale_column=scale_column
)

# The optimized KL should match the current implementation used by kl_divergence.
tf.debugging.assert_near(tfd.kl_divergence(mvn, mvn2),
                         _kl_matrix_normal_matrix_normal_optimized(mvn, mvn2))
%%timeit
tfd.kl_divergence(mvn, mvn2)

46.9 ms ± 14.3 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

%%timeit
_kl_matrix_normal_matrix_normal_optimized(mvn, mvn2)

11.2 ms ± 4.6 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)

If this is welcome, I would happily open a PR. If you want to keep the current (and surely more readable) implementation as well, though, I am not sure how to expose the choice between the two implementations when computing the KL divergence.
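
One option, just as a sketch of what I have in mind: tfd.kl_divergence dispatches through the RegisterKL decorator, so the optimized function could simply be registered for the (MatrixNormalLinearOperator, MatrixNormalLinearOperator) pair, assuming that exact pair is not already registered (RegisterKL raises an error on duplicate registrations):

import tensorflow_probability as tfp

tfd = tfp.distributions

# Hypothetical registration: make tfd.kl_divergence dispatch to the optimized
# KL whenever both arguments are MatrixNormalLinearOperator instances.
# Assumes this exact class pair has no existing registration.
@tfd.RegisterKL(tfd.MatrixNormalLinearOperator, tfd.MatrixNormalLinearOperator)
def _kl_matrix_normal_matrix_normal(a, b, name=None):
    return _kl_matrix_normal_matrix_normal_optimized(a, b, name=name)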

fdtomasi · Jan 14, 2022

Hi, contributions welcome! Feel free to open a PR with this!

srvasude · Mar 22, 2022