pytorch-struct

differentiable samples (rsample)

Open • RafaelPo opened this issue 5 years ago • 11 comments

Are there plans to introduce differentiable samples?

Thanks!

RafaelPo • Nov 24 '20 23:11

Yeah... we are trying that out currently actually. There are a lot of different ways to do it with discrete distributions, did you have one in mind?

srush • Nov 24 '20 23:11

Hi,

I was thinking of applying the results from https://arxiv.org/pdf/2002.08676.pdf recursively on the marginals... do you think that would work?

RafaelPo • Nov 25 '20 00:11

Yes, I think that would be cool. We already have some of the papers referenced in that work implemented, such as the differentiable dynamic programming semiring, but it is not exposed in the API. I'm a bit hesitant to call it rsample, because it is biased. Maybe we should have a separate API function that exposes some of these tricks? If you are interested, I would be happy to accept a contribution.
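To illustrate the idea behind a differentiable dynamic programming semiring (a generic sketch, not the library's implementation): replacing the hard `max` of the Viterbi (max-plus) semiring with a temperature-scaled logsumexp makes every DP step differentiable. The gradient of that step is a softmax, which sharpens toward a one-hot argmax as the temperature goes to 0. The name `smoothed_max` is hypothetical.

```python
import torch

def smoothed_max(x: torch.Tensor, temp: float) -> torch.Tensor:
    # temp * logsumexp(x / temp) is a smooth upper bound on max(x) along the last dim;
    # it approaches max(x) as temp -> 0.
    return temp * torch.logsumexp(x / temp, dim=-1)

scores = torch.tensor([1.0, 3.0, 2.0], requires_grad=True)
for temp in (1.0, 0.1):
    value = smoothed_max(scores, temp)
    # The gradient is softmax(scores / temp): a "soft argmax" that sharpens as temp shrinks.
    (grad,) = torch.autograd.grad(value, scores)
    print(f"temp={temp}: value={value.item():.4f}, grad={[round(g, 4) for g in grad.tolist()]}")
```

At `temp=0.1` the gradient is already nearly one-hot on index 1, the position of the true max.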

srush • Nov 25 '20 03:11

Hi,

Here is some code I have been playing with: [two code screenshots]

RafaelPo • Nov 25 '20 14:11

Nice, that is similar in spirit to this code we have been working on: https://github.com/harvardnlp/pytorch-struct/pull/81.

We can integrate them both into the library.

There might also be a way to do this with far fewer calls to cvxpy.

srush • Nov 25 '20 14:11

I will have a look, thanks!

How could you save on the number of runs?

Also, I think they are supposed to be unbiased, no?

RafaelPo • Nov 25 '20 14:11

Very neat. So I think that instead of first computing marginals, we can apply this approach in the backward operation of the semiring itself. This is how I compute unbiased Gumbel-max samples (https://github.com/harvardnlp/pytorch-struct/pull/81/files#diff-5775ad09d6cfbdc4d52edd6797aba8e68ac66ae04dee680b0e456058bef106dcR70).

It seems like I can just change this line (https://github.com/harvardnlp/pytorch-struct/pull/81/files#diff-5775ad09d6cfbdc4d52edd6797aba8e68ac66ae04dee680b0e456058bef106dcR71) from an argmax to your CVX code to get a differentiable sample? This should work for all models.
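A minimal sketch of exact sampling with Gumbel-max in a backward pass, assuming a single unbatched linear chain with no unary scores (those can be folded into the edges): forward filtering in the log semiring, then backward sampling where each step draws a label with the Gumbel-max trick. This is the generic forward-filtering/backward-sampling pattern written sequentially, not the code in PR #81; the `edge[t, i, j]` layout and the helper names are assumptions. The `gumbel_argmax` call is the step that the comment above proposes swapping for a differentiable (CVX-based) choice.

```python
from typing import List

import torch

def gumbel_argmax(logits: torch.Tensor) -> int:
    # Gumbel-max trick: argmax(logits + G), G ~ Gumbel(0, 1), is an exact sample from softmax(logits).
    noise = torch.distributions.Gumbel(0.0, 1.0).sample(logits.shape)
    return int(torch.argmax(logits + noise))

def ffbs_sample(edge: torch.Tensor) -> List[int]:
    # Hypothetical layout: edge[t, i, j] scores label i at position t followed by label j at t + 1.
    T = edge.shape[0] + 1
    C = edge.shape[1]
    # Forward filtering: alpha[t][j] = log-sum of the scores of all prefixes ending in label j at t.
    alpha = [torch.zeros(C, dtype=edge.dtype)]
    for t in range(T - 1):
        alpha.append(torch.logsumexp(alpha[-1].unsqueeze(1) + edge[t], dim=0))
    # Backward sampling: draw the last label, then each earlier label given its successor.
    z = [gumbel_argmax(alpha[-1])]
    for t in range(T - 2, -1, -1):
        z.append(gumbel_argmax(alpha[t] + edge[t, :, z[-1]]))
    return z[::-1]
```

For example, `ffbs_sample(torch.randn(4, 3, 3))` returns a list of 5 label indices, one per position.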

Another advantage of this method is that it will batch across n (our internal code does log n steps instead of n for the linear chain).
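One common way to get logarithmic depth for a linear chain is to treat each edge as a C x C matrix in the log semiring and multiply neighbouring matrices pairwise in rounds, i.e. a parallel (tree) reduction. This is an illustration of that technique under a hypothetical `edge[t, i, j]` layout, not pytorch-struct's internal code; a plain sequential forward pass is included for comparison.

```python
import torch

def log_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Matrix product in the log semiring: out[..., i, k] = logsumexp_j(a[..., i, j] + b[..., j, k]).
    return torch.logsumexp(a.unsqueeze(-1) + b.unsqueeze(-3), dim=-2)

def chain_log_partition_tree(edge: torch.Tensor) -> torch.Tensor:
    # edge: (n_edges, C, C); edge[t, i, j] scores label i at t followed by label j at t + 1.
    mats = edge
    # Each round multiplies neighbouring pairs in one batched call, so there are
    # about log2(n_edges) sequential rounds instead of n_edges steps.
    while mats.shape[0] > 1:
        if mats.shape[0] % 2 == 1:
            # Pad with the log-semiring identity: 0 on the diagonal, -inf elsewhere.
            C = mats.shape[-1]
            eye = torch.full((1, C, C), float("-inf"), dtype=mats.dtype)
            eye[0].fill_diagonal_(0.0)
            mats = torch.cat([mats, eye], dim=0)
        mats = log_matmul(mats[0::2], mats[1::2])
    # Sum over the first and last labels.
    return torch.logsumexp(mats[0].reshape(-1), dim=0)

def chain_log_partition_seq(edge: torch.Tensor) -> torch.Tensor:
    # Sequential forward algorithm: n_edges vector-matrix steps.
    alpha = torch.zeros(edge.shape[1], dtype=edge.dtype)
    for t in range(edge.shape[0]):
        alpha = torch.logsumexp(alpha.unsqueeze(1) + edge[t], dim=0)
    return torch.logsumexp(alpha, dim=0)
```

The trade-off is that matrix-matrix products replace vector-matrix products, so total work grows by roughly a factor of C in exchange for the shorter sequential depth.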

I agree the forward sample is unbiased, but I will have to read the paper to understand whether the gradient is unbiased too (but I believe you).

srush • Nov 25 '20 16:11

Hi,

Not sure how this compares to what you guys have been working on, but for what it's worth, I have implemented a version of a biased rsample that uses local Gumbel perturbations and temperature-controlled marginals (this is the marginal stochastic softmax trick from https://arxiv.org/abs/2006.08063) directly in the StructDistribution class:

```python
# Method of StructDistribution; assumes `torch` and `LogSemiring` are in scope (as in torch_struct's distributions module).
def rsample(self, sample_shape=torch.Size(), temp=1.0, noise_shape=None, sample_batch_size=10):
    r"""
    Compute structured samples from the _relaxed_ distribution :math:`z \sim p(z;\theta+\gamma, \tau)`

    NOTE: These samples are biased.

    This uses Gumbel perturbations on the potentials followed by the positive-temperature marginals to get
    approximate samples. As temp varies from 0 to inf, the samples vary from exact one-hots from an approximate
    distribution to a deterministic distribution that is always uniform over all values.

    The approximation in the zero-temp limit comes from the fact that we use polynomial (instead of exponential)
    perturbations, see:
      [Perturb-and-MAP](https://ttic.uchicago.edu/~gpapan/pubs/confr/PapandreouYuille_PerturbAndMap_ieee-c-iccv11.pdf)
      [Stochastic Softmax Tricks](https://arxiv.org/abs/2006.08063)

    Parameters:
        sample_shape (int or torch.Size): number of samples
        temp (float): (default=1.0) relaxation temperature
        noise_shape (torch.Size): specify lower-order perturbations by placing ones along any of the potential dims
        sample_batch_size (int): number of samples computed per batch

    Returns:
        samples (*sample_shape x batch_shape x event_shape*)

    """
    # Sanity checks
    if isinstance(sample_shape, int):
        nsamples = sample_shape
    else:
        # An empty torch.Size() (the default) is treated as a single sample.
        assert len(sample_shape) <= 1
        nsamples = sample_shape[0] if len(sample_shape) == 1 else 1
    sample_batch_size = min(sample_batch_size, nsamples)
    samples = []

    if noise_shape is None:
        noise_shape = self.log_potentials.shape[1:]

    assert len(noise_shape) == len(self.log_potentials.shape[1:])
    assert all(
        s1 == 1 or s1 == s2 for s1, s2 in zip(noise_shape, self.log_potentials.shape[1:])
    ), f"Noise shapes must match dimension or be 1: got: {list(zip(noise_shape, self.log_potentials.shape[1:]))}"

    # Sampling: one pass per batch of `sample_batch_size` samples
    for _ in range(0, nsamples, sample_batch_size):
        shape = self.log_potentials.shape
        B = shape[0]
        # Tile the potentials to (sample_batch_size * B, *event_shape)
        s_log_potentials = (
            self.log_potentials.reshape(1, *shape)
            .repeat(sample_batch_size, *tuple(1 for _ in shape))
            .reshape(-1, *shape[1:])
        )

        s_lengths = self.lengths
        if s_lengths is not None:
            s_shape = s_lengths.shape
            s_lengths = (
                s_lengths.reshape(1, *s_shape)
                .repeat(sample_batch_size, *tuple(1 for _ in s_shape))
                .reshape(-1, *s_shape[1:])
            )

        # Gumbel noise, broadcast along any dims where noise_shape is 1
        noise = (
            torch.distributions.Gumbel(0, 1)
            .sample((sample_batch_size * B, *noise_shape))
            .expand_as(s_log_potentials)
        ).to(s_log_potentials.device)
        noisy_potentials = (s_log_potentials + noise) / temp

        # Relaxed sample = marginals of the perturbed, temperature-scaled potentials
        r_sample = (
            self._struct(LogSemiring)
            .marginals(noisy_potentials, s_lengths)
            .reshape(sample_batch_size, B, *shape[1:])
        )
        samples.append(r_sample)
    return torch.cat(samples, dim=0)[:nsamples]
```

Let me know if you'd like me to submit this as a PR (with whatever changes you think make sense).

Thanks, Tom

teffland • Dec 05 '20 19:12
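For a self-contained view of the marginal stochastic softmax idea in the method above, here is a minimal sketch for one unbatched linear chain that does not depend on torch_struct. It assumes a hypothetical `edge[t, i, j]` layout and hypothetical helper names; the method above instead delegates the marginal computation to `self._struct(LogSemiring).marginals`. Here the marginals are obtained as the gradient of log Z with respect to the perturbed, temperature-scaled potentials.

```python
import torch

def chain_log_partition(edge: torch.Tensor) -> torch.Tensor:
    # Forward algorithm in the log semiring for one linear chain.
    # Hypothetical layout: edge[t, i, j] scores label i at position t followed by label j at t + 1.
    alpha = torch.zeros(edge.shape[1], dtype=edge.dtype)
    for t in range(edge.shape[0]):
        alpha = torch.logsumexp(alpha.unsqueeze(1) + edge[t], dim=0)
    return torch.logsumexp(alpha, dim=0)

def relaxed_chain_sample(edge: torch.Tensor, temp: float = 1.0) -> torch.Tensor:
    # Local perturbations: independent Gumbel(0, 1) noise on every edge potential.
    noise = torch.distributions.Gumbel(0.0, 1.0).sample(edge.shape)
    noisy = (edge + noise) / temp
    if not noisy.requires_grad:
        # Lets autograd compute marginals even when `edge` itself does not require grad.
        noisy.requires_grad_(True)
    log_z = chain_log_partition(noisy)
    # Edge marginals are d log Z / d potentials; create_graph=True keeps them differentiable
    # with respect to `edge`, so a downstream loss can backpropagate through the sample.
    (marginals,) = torch.autograd.grad(log_z, noisy, create_graph=True)
    return marginals
```

Lowering `temp` pushes the output toward a one-hot structure, the MAP of the perturbed potentials.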

Awesome, sounds like we have three different methods. The one in my PR is from Yao's NeurIPS work https://arxiv.org/abs/2011.14244, which is unbiased in the forward pass (the samples) and biased in the backward pass (the gradients). Maybe we should have a phone call and figure out the differences and how to document and compare them.
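In its simplest, unstructured form, "unbiased forward, biased backward" is the straight-through Gumbel-softmax estimator for a single categorical variable: the forward value is an exact one-hot sample, while the gradient is taken through the relaxed softmax. This sketch covers only that single-variable case; it is not the structured method from the linked paper or PR #81, and `straight_through_gumbel` is a hypothetical name.

```python
import torch
import torch.nn.functional as F

def straight_through_gumbel(logits: torch.Tensor, temp: float = 1.0) -> torch.Tensor:
    # Perturb the logits once with Gumbel(0, 1) noise.
    gumbel = torch.distributions.Gumbel(0.0, 1.0).sample(logits.shape)
    # Relaxed sample: its gradient is what flows backward (biased).
    soft = torch.softmax((logits + gumbel) / temp, dim=-1)
    # Hard sample: argmax of the perturbed logits is an exact draw from softmax(logits).
    hard = F.one_hot(soft.argmax(dim=-1), logits.shape[-1]).to(soft.dtype)
    # Forward value equals `hard` (up to rounding); backward uses the gradient of `soft`.
    return hard - soft.detach() + soft
```

PyTorch's built-in `torch.nn.functional.gumbel_softmax(logits, tau=temp, hard=True)` implements the same single-variable estimator.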

srush • Dec 05 '20 20:12

Very interesting, I'll take a look at the paper -- unbiased forward sounds like a big plus. I'm available for a call to discuss pretty much whenever.

teffland • Dec 07 '20 14:12

Not sure how what I proposed compares to the rest; it seems (way) more computationally expensive. I would be interested in a call as well, though I am based in England.

RafaelPo • Dec 07 '20 15:12