
Kullback–Leibler divergence using TF distributions

Open vdutor opened this issue 7 years ago • 14 comments

Hi,

I've re-implemented GPflow's gauss_kl method using the tensorflow_probability.distributions.kl_divergence method. It looks as follows:

import tensorflow as tf
import tensorflow_probability as tfp

def gauss_kl_tfp_distributions(q_mu, q_sqrt, K=None):
    """
    Kullback-Leibler divergence KL[q(U) || p(U)],
    with q(U) ~ N(q_mu, q_sqrt^2) the variational Gaussian posterior
    and p(U) ~ N(0, K) the prior. If K is None we assume a whitened 
    prior, p(U) ~ N(0, I).
    
    We assume L independent distributions and return the sum of their KL divergences.
    
    :param q_mu: L variational means, M x L
    :param q_sqrt: square roots of the L variational covariances, given as
                - Cholesky factors: L x M x M, or
                - diagonal elements: M x L
    :param K: prior covariance Kuu, M x M or L x M x M (None for a whitened prior)
    """
    q_mu = tf.linalg.matrix_transpose(q_mu)  # L x M
    if K is None:
        prior = tfp.distributions.MultivariateNormalDiag(loc=tf.zeros_like(q_mu), scale_diag=tf.ones_like(q_mu))
        # alternatively:
        # identity = tf.linalg.LinearOperatorIdentity(tf.shape(q_mu)[-1], dtype=q_mu.dtype)
        # prior = tfp.distributions.MultivariateNormalLinearOperator(loc=tf.zeros_like(q_mu), scale=identity)
    else:
        prior = tfp.distributions.MultivariateNormalTriL(
            loc=tf.zeros_like(q_mu), scale_tril=tf.linalg.cholesky(K)
        )
    
    if q_sqrt.shape.ndims == 2:  # M x L
        q_sqrt_T = tf.linalg.matrix_transpose(q_sqrt)
        posterior = tfp.distributions.MultivariateNormalDiag(loc=q_mu, scale_diag=q_sqrt_T)

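    elif q_sqrt.shape.ndims == 3:  # L x M x M
        posterior = tfp.distributions.MultivariateNormalTriL(loc=q_mu, scale_tril=q_sqrt)
    else:
        raise ValueError("q_sqrt must have shape M x L (diagonal) or L x M x M (Cholesky)")

    # kl_divergence returns one KL per latent (batch shape L); sum them
    return tf.reduce_sum(posterior.kl_divergence(prior))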

I think this implementation is more readable, but unfortunately it is slower. Here is a comparison of execution time:

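| M = 750, L = 10          | q_sqrt    | q_mu  | TF (tfp.distributions) | GPflow (gauss_kl) |
|--------------------------|-----------|-------|------------------------|-------------------|
| whitened                 | M x L     | M x L | 410 µs ± 61.7 µs       | 374 µs ± 58.6 µs  |
| whitened                 | L x M x M | M x L | 395 µs ± 68.2 µs       | 310 µs ± 41.8 µs  |
| unwhitened, K: M x M     | M x L     | M x L | 109 ms ± 6.44 ms       | 1.01 ms ± 49.7 µs |
| unwhitened, K: M x M     | L x M x M | M x L | 2.93 ms ± 103 µs       | 0.86 ms ± 77.5 µs |
| unwhitened, K: L x M x M | M x L     | M x L | 107 ms ± 3.37 ms       |                   |
| unwhitened, K: L x M x M | L x M x M | M x L | 2.95 ms ± 38.2 µs      |                   |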

As long as we use the whitened representation, both implementations take roughly the same time. In the unwhitened cases, however, the slowdown is not negligible (about 100x with a diagonal q_sqrt and about 3x with a full one). It is caused by the fact that tf.distributions don't broadcast as we would expect, which forces us to tile covariance matrices.
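The broadcasting the TFP version needs can be sketched outside TF. Below is a minimal NumPy sketch, assuming only the full-covariance case (q_sqrt of shape L x M x M); it is neither GPflow's gauss_kl nor TFP's implementation, and the name gauss_kl_numpy is hypothetical. For each latent it evaluates the closed form KL[N(m, S) || N(0, K)] = ½ [tr(K^-1 S) + m^T K^-1 m - M + log|K| - log|S|] and sums over L. NumPy's stacked linalg routines broadcast leading batch dimensions, so an M x M prior is shared by all L latents without tiling. A general np.linalg.solve stands in for a triangular solve.

```python
import numpy as np


def gauss_kl_numpy(q_mu, q_sqrt, K=None):
    """Sum over l of KL[N(m_l, S_l) || N(0, K_l)] with S_l = q_sqrt[l] @ q_sqrt[l].T.

    Hypothetical helper for illustration only.
    q_mu: M x L means; q_sqrt: L x M x M lower-triangular factors;
    K: None (whitened, K = I), M x M (shared prior) or L x M x M.
    """
    num_latent, M = q_sqrt.shape[0], q_sqrt.shape[-1]
    mu = q_mu.T[..., None]  # L x M x 1
    Lp = np.eye(M) if K is None else np.linalg.cholesky(K)  # prior Cholesky factor
    # np.linalg.solve broadcasts batch dims, so an M x M Lp is reused for all
    # L latents without building an L x M x M copy.
    A = np.linalg.solve(Lp, q_sqrt)  # Lp^-1 q_sqrt, so tr(K^-1 S) = ||A||_F^2
    b = np.linalg.solve(Lp, mu)      # Lp^-1 m, so m^T K^-1 m = ||b||^2
    trace = np.sum(A ** 2)
    mahalanobis = np.sum(b ** 2)
    logdet_q = 2.0 * np.sum(np.log(np.abs(np.diagonal(q_sqrt, axis1=-2, axis2=-1))))
    logdet_p = 2.0 * np.sum(np.log(np.diagonal(Lp, axis1=-2, axis2=-1)))
    if Lp.ndim == 2:
        logdet_p *= num_latent  # a shared prior contributes log|K| once per latent
    return 0.5 * (trace + mahalanobis - M * num_latent + logdet_p - logdet_q)
```

Passing K tiled to L x M x M gives the same value as passing the M x M K directly; only the latter avoids materialising L copies of the prior.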

I'll try to resolve this issue and report back.

vdutor, Mar 08 '18 11:03

@dustinvtran is there a neat way to make the KL divergences defined in tfd broadcast nicely? It's annoying that the tile op is killing our performance.

jameshensman, Mar 08 '18 12:03

Thanks for pinging me. That's a great question. My understanding is that, since you'd like to broadcast K along a batch dimension, this should work without having to tile (at least according to the docstrings: https://www.tensorflow.org/api_docs/python/tf/distributions/Distribution).

@jvdillon, @langmore may be able to say more?

dustinvtran, Mar 08 '18 21:03

That's right. Needing tf.tile is actually fairly rare. This is good, because using it is expensive. And if things aren't broadcasting, it's more likely that we have a bug than that you're doing something wrong.
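Why tiling is expensive can be illustrated with NumPy as an analogy (TF's internals may differ): np.tile materialises L copies of K, whereas np.broadcast_to returns a read-only view whose batch stride is zero, so no data is copied. The sizes mirror the M = 200, L = 10 example in this thread.

```python
import numpy as np

M, L = 200, 10
K = np.eye(M)[None, ...]  # 1 x M x M prior covariance

tiled = np.tile(K, (L, 1, 1))              # allocates a real L x M x M array
broadcast = np.broadcast_to(K, (L, M, M))  # read-only view onto K, no new data

print(tiled.nbytes // K.nbytes)        # 10: tiling stores K L times
print(broadcast.strides[0])            # 0: every batch entry reads the same memory
print(np.shares_memory(broadcast, K))  # True
```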

jvdillon, Mar 08 '18 22:03

Hi @jvdillon, @dustinvtran. Thanks for your response.

I've tried broadcasting over the batch (first dimension) before and it doesn't seem to behave as I would expect. Please consider the following minimal pathological example:

import tensorflow as tf
tf.__version__   # 1.5.0
import tensorflow.contrib.distributions as tfd
import numpy as np

L, M = 10, 200
q_mu_posterior = tf.random_normal((L, M))  # L x M
q_sqrt = tf.tile(tf.eye(M)[None, ...], [L, 1, 1])  # L x M x M
K = tf.eye(M)[None, ...]  # 1 x M x M
q_mu_prior = tf.zeros((1, M))  # 1 x M

prior = tfd.MultivariateNormalFullCovariance(loc=q_mu_prior, covariance_matrix=K)  # N(0, K)
posterior = tfd.MultivariateNormalTriL(loc=q_mu_posterior, scale_tril=q_sqrt) # L times N(q_mu, q_sqrt^2)

posterior.kl_divergence(prior) 

I think this should return L KL divergences, i.e. KL[posterior_i(u) || prior(u)] for i = 1..L, but instead this error is thrown:

InvalidArgumentError: Dimension 0 in both shapes must be equal, but are 1 and 10. for 'MatrixTriangularSolve' (op: 'MatrixTriangularSolve') with input shapes: [1,200,200], [10,200,200].

Looks like this boils down to TF issue 216 (https://github.com/tensorflow/tensorflow/issues/216)?

vdutor, Mar 09 '18 08:03

Hi all, one of our internal 20%-ers, Anudhyan, has just added broadcasting support to matmul (thanks again for the awesome work!): https://github.com/tensorflow/tensorflow/commit/47ab68d265a96b6e7be06afd1b4b47e0114c0ee9

This will surface as a change to matmul in a few weeks (4/18), after which broadcasting will be done by default (we'll need to make sure internal users don't get broken by this). We'll also need to change our tf.linalg.LinearOperator(s) to use the broadcasted matmul, and then most things in TFP should just benefit from the broadcasting :).
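"Broadcasting by default" means matmul treats a batch dimension of size 1 as if it were repeated to match the other operand. NumPy's np.matmul already follows these rules, so it serves as a minimal analogy here (this is not TF code, and the shapes are illustrative assumptions): the broadcasted product equals the product after explicit tiling, without building the tiled operand.

```python
import numpy as np

rng = np.random.default_rng(0)
L, M, N = 10, 4, 3
A = rng.normal(size=(1, M, M))  # batch of one, e.g. a shared prior factor
B = rng.normal(size=(L, M, N))  # batch of L right-hand sides

broadcasted = np.matmul(A, B)                # batch dims (1,) and (L,) broadcast to (L,)
tiled = np.matmul(np.tile(A, (L, 1, 1)), B)  # the same product via an explicit copy
print(broadcasted.shape)  # (10, 4, 3)
```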

.solve and .matrix_triangular_solve will still need changes.

srvasude, Apr 04 '19 03:04

Hi @srvasude, thanks for keeping us up to date. We (GPflow devs) are really happy with this development!

vdutor, Apr 04 '19 09:04

The date for compatibility was delayed (to today). I am in the process of submitting a PR to make tf.linalg.LinearOperator use this new broadcasted op (so I expect people can take advantage of this starting some time next week).

If things work out, I should be doing the same thing to matrix_triangular_solve.

srvasude, Apr 27 '19 01:04

Hi all, I recently added broadcasting triangular solve to TF, and switched LinearOperator to use this.
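The MatrixTriangularSolve failure reported earlier came from batch shapes [1,200,200] and [10,200,200]. What a broadcasting triangular solve computes can be sketched in NumPy with plain forward substitution; this is not TF's kernel, and broadcast_triangular_solve is a hypothetical name. Leading batch dimensions are broadcast NumPy-style before solving.

```python
import numpy as np


def broadcast_triangular_solve(chol, rhs):
    """Solve chol @ X = rhs for lower-triangular chol, broadcasting batch dims.

    Hypothetical helper. chol: (..., M, M) lower triangular; rhs: (..., M, K).
    Leading dims broadcast NumPy-style, e.g. (1, M, M) against (L, M, K).
    """
    batch = np.broadcast_shapes(chol.shape[:-2], rhs.shape[:-2])
    chol = np.broadcast_to(chol, batch + chol.shape[-2:])  # views, no copies
    rhs = np.broadcast_to(rhs, batch + rhs.shape[-2:])
    x = np.zeros(rhs.shape, dtype=np.result_type(chol, rhs, 1.0))
    for i in range(chol.shape[-1]):
        # Forward substitution: x_i = (b_i - sum_{j<i} L_ij x_j) / L_ii,
        # applied to every batch member at once.
        partial = np.einsum("...j,...jk->...k", chol[..., i, :i], x[..., :i, :])
        x[..., i, :] = (rhs[..., i, :] - partial) / chol[..., i, i][..., None]
    return x
```

With a 1 x M x M factor and an L x M x M right-hand side, the result has shape L x M x M, matching what a solve against the tiled factor would return.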

srvasude, Jan 28 '20 16:01

@srvasude, can you give a link to the code?

awav, Jan 28 '20 17:01

Broadcasting triangular solve: https://github.com/tensorflow/tensorflow/commit/b105944eb6c563849a085a1765d6700ee2c0f35c

Using this in LinearOperator and a few other places: https://github.com/tensorflow/tensorflow/commit/0897278b482e7f015ac3cbc98b60eb0ed0f9bede

srvasude, Jan 28 '20 17:01

Thanks a lot @srvasude. I can probably close this one as well: https://github.com/tensorflow/tensorflow/issues/26204.

awav, Jan 28 '20 17:01

Yep, that sounds good. Let me know if you see performance or memory issues, and I can try to improve on this. (At some point in the future I'll also try to get broadcasting to work for matrix_solve and a few others, but it's not a high priority for me.)

srvasude, Jan 28 '20 19:01

Hi, I am facing a very strange issue: I get Cholesky errors in some places in my code, but no errors when I replace GPflow's gauss_kl with the KL computed using TF distributions.

Any clue what could be the reason?

ayush29, May 30 '20 05:05

tfp.distributions now broadcasts successfully (I've updated the example in the issue description to work with TF 2.3/tfp 0.12), but unfortunately, in the unwhitened case (i.e., K is not None), our implementation is still roughly 3x faster than TFP. It would be nice to eventually be able to stop maintaining our own code...

(NB: GPflow does support all combinations of K = None, an [M, M] tensor, or an [L, M, M] tensor.)

st--, Oct 05 '20 14:10