
Derivative of Network Loglikelihood for brlen optimization

Open lutteropp opened this issue 4 years ago • 44 comments

(from https://github.com/lutteropp/NetRAX/issues/43)

Use Newton-Raphson instead of Brent for optimizing branch lengths (this will give us some niiiiiice additional speedup). This requires some maths now, as I need to understand how to compute the derivative of the network loglikelihood. Given the derivatives of the displayed tree loglikelihoods, this should be easy. I then need to understand how pll_modules computes likelihood derivatives for trees (something something sumtable), and how this can be generalized to networks.

Luckily, given the derivatives of the displayed tree loglikelihoods, the derivative of the network loglikelihood is easy to compute (there is a small mistake in the image, but it is not relevant here: the displayed tree loglikelihoods of course also depend on the partition, the branch lengths, the model parameters, etc.):

(EDIT: THIS IS WRONG, SEE LATER COMMENT) [Image: Untitled note - 14.03.2021 14:09 - Page 1]

Thus, the main challenge here is to understand the sumtables used for tree loglikelihood derivatives in pll_modules, being especially careful about how this works with virtual rerooting. Then I need to do the same for the trees displayed by the network, and voilà, we can switch from Brent brlen-opt to Newton-Raphson brlen-opt! :-)

lutteropp, Mar 14 '21 13:03

One tricky thing: there can be displayed trees in which the branch we want to optimize is inactive. What happens to their sumtables? Since we care about derivatives here, we can safely skip those trees and not compute sumtables for them: the first and second derivatives of their tree loglikelihoods are zero.

There is even more potential for saving CLV updates here (to be explored later, if CLV updates again take up a significant share of the runtime): if we can safely skip a tree, then we don't need to update its CLVs after virtual rerooting.

lutteropp, Mar 15 '21 11:03

I made a mistake in the partition loglikelihood formulas above. New image coming soon.

lutteropp, Mar 15 '21 13:03

Do we really need the correct derivative of the phylogenetic network likelihood? It is apparently pretty ugly: https://en.wikipedia.org/wiki/Product_rule#Product_of_more_than_two_factors
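For reference, here is the n-factor product rule from the linked page, next to what the same product looks like once the log is taken (assuming all f_i > 0, as with per-site likelihoods). The derivative of the log reduces to a plain sum of ratios f_i'/f_i:

```latex
\Big(\prod_{i=1}^{n} f_i\Big)' = \sum_{i=1}^{n} f_i' \prod_{j \neq i} f_j
\qquad
\Big(\log \prod_{i=1}^{n} f_i\Big)' = \sum_{i=1}^{n} \frac{f_i'}{f_i}
\qquad
\Big(\log \prod_{i=1}^{n} f_i\Big)'' = \sum_{i=1}^{n} \left(\frac{f_i''}{f_i} - \Big(\frac{f_i'}{f_i}\Big)^{2}\right)
```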

lutteropp, Mar 15 '21 14:03

Now I did the maths correctly, and it does not look good for Newton-Raphson on phylogenetic networks :-( We can still do it, as it is known to converge faster, but it does not help in getting rid of pll_compute_edge_loglikelihood. [Image: Untitled note - 15.03.2021 16:28 - Page 3]

lutteropp, Mar 15 '21 15:03

Or maybe we can still do something here? I need to dig into how the maths is done for the derivative of the tree likelihood. After all, there is a product involved there, too (the product over per-site likelihoods, plus a product over partitions).

lutteropp, Mar 15 '21 15:03

Whatever trick is used for trees should also work for networks, because with the approach I used here, I run into the same problem even if there is only a single displayed tree. There has to be some further trick involved!

lutteropp, Mar 15 '21 15:03

Okay... now I understand the text in Alexey's thesis (https://cme.h-its.org/exelixis/pubs/dissAlexey.pdf, page 26). What this sumtable gives us is fast computation of:

  • L(b)
  • L'(b)
  • L''(b)

My problem was that I thought it only gives L'(b) and L''(b).
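A minimal sketch of why one table yields all three values, based on the standard sumtable construction rather than on pll_modules' actual code or memory layout (all names and the table layout are hypothetical, and CLV scaling is ignored). Once the two CLVs at the ends of branch b are combined in the eigenbasis of the substitution matrix, each site's likelihood is a weighted sum of exponentials in the branch length t. L(t), L'(t) and L''(t) then come out of one pass over the table, for any t, without recomputing CLVs:

```python
import math


def site_lh_derivatives(sumtable_site, eigenvalues, rates, rate_weights, t):
    """Return (L(t), L'(t), L''(t)) for one site from its sumtable entries.

    Hypothetical layout: sumtable_site[c][k] is the precomputed product of the
    two CLVs around the branch, projected onto eigenvector k, for rate
    category c. CLV scaling is ignored in this sketch.
    """
    lh = d1 = d2 = 0.0
    for c, (rate, weight) in enumerate(zip(rates, rate_weights)):
        for k, lam in enumerate(eigenvalues):
            a = lam * rate                      # d/dt of the exponent lam * rate * t
            term = weight * sumtable_site[c][k] * math.exp(a * t)
            lh += term                          # contributes to L(t)
            d1 += a * term                      # contributes to L'(t)
            d2 += a * a * term                  # contributes to L''(t)
    return lh, d1, d2


# Toy numbers (not from a real model): 3 eigenvalues, 2 rate categories.
eig = [0.0, -1.0, -2.0]
rates = [0.5, 1.5]
weights = [0.5, 0.5]
sumtable = [[0.3, 0.1, 0.05], [0.2, 0.05, 0.02]]
print(site_lh_derivatives(sumtable, eig, rates, weights, 0.1))
```

Changing t only changes the exp() factors, which is what makes repeated evaluations during branch-length optimization cheap.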

lutteropp, Mar 15 '21 15:03

My confusion originated from this: for trees, we end up computing and using these:

  • (logL(T))' -> first derivative of the loglikelihood of tree T
  • (logL(T))'' -> second derivative of the loglikelihood of tree T

But what the sumtable trick actually gives us is fast computation of (with respect to branch b):

  • L(b) -> per-site likelihood of tree T
  • L'(b) -> first derivative of the per-site likelihood of tree T
  • L''(b) -> second derivative of the per-site likelihood of tree T

Instead of directly reusing the final results ((logL(T))' and (logL(T))'') from the tree case, I need to plug these faster-computed per-site tree likelihoods and their derivatives into the formulas for the network loglikelihood and its derivatives.
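For comparison, here is a sketch of the tree-case step that turns per-site (L, L', L'') into the tree loglikelihood and its derivatives, using hypothetical names. This is the sum-of-ratios form of the product rule. It is exactly the per-site step that the text above says cannot simply be reused as-is for networks. Weighting by site-pattern counts is an assumption here:

```python
import math


def tree_logl_derivatives(site_values, pattern_weights=None):
    """Combine per-site (L, L', L'') into (logL(T), (logL(T))', (logL(T))'').

    site_values: list of (L_s, L'_s, L''_s) tuples, e.g. from a sumtable.
    pattern_weights: optional site-pattern counts (hypothetical parameter).
    """
    if pattern_weights is None:
        pattern_weights = [1] * len(site_values)
    logl = d1 = d2 = 0.0
    for (lh, lh1, lh2), w in zip(site_values, pattern_weights):
        r1 = lh1 / lh                      # (log L_s)'
        logl += w * math.log(lh)
        d1 += w * r1
        d2 += w * (lh2 / lh - r1 * r1)     # (log L_s)''
    return logl, d1, d2


# Two toy sites at t = 0.1: L_1(t) = exp(-2t) and L_2(t) = t^2.
e = math.exp(-0.2)
print(tree_logl_derivatives([(e, -2 * e, 4 * e), (0.01, 0.2, 2.0)]))
```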

lutteropp, Mar 15 '21 15:03

I got it all together now. The main difference compared to the tree case is that, in order to compute network loglikelihood derivatives, we also need the sumtables to compute the displayed tree loglikelihoods themselves. The tree loglikelihood derivatives alone are not enough for us.

[Image: Untitled note - 15.03.2021 17:41 - Page 5]

lutteropp, Mar 15 '21 16:03

Wait... do I really need the fully computed displayed tree loglikelihoods here?

For LikelihoodModel.BEST, we can do without explicitly computing them, because there we essentially have one tree per partition. For LikelihoodModel.AVERAGE, I currently see no way around actually computing the displayed tree loglikelihoods... but I am still not convinced that they are really needed.

--> It looks like there is no way around inserting the displayed tree per-site loglikelihoods (and their derivatives) into the formulas for the network case. Thus, another drawing will follow here soon. :-)

lutteropp, Mar 15 '21 16:03

It turns out that for networks, we can't go down to per-site computation if there is more than one displayed tree.

  • I am convinced about this in the LikelihoodType.AVERAGE case.
  • For the LikelihoodType.BEST case it depends on whether max{f', g'} / max{f, g} = max{f'/f, g'/g} always holds or not.

[Image: Untitled note - 15.03.2021 18:33 - Page 6]

lutteropp, Mar 15 '21 17:03

Okay, found a counterexample (for x > 0):

  • f(x) = 2x + 3
  • g(x) = 2x + 5
  • f'(x) = 2
  • g'(x) = 2
  • max{f'(x), g'(x)} = 2
  • max{f(x), g(x)} = 2x + 5
  • max{f'(x), g'(x)} / max{f(x), g(x)} = 2/(2x + 5)
  • max{f'/f, g'/g} = max{2/(2x + 3), 2/(2x + 5)} = 2/(2x + 3)

----> max{f', g'} / max{f, g} != max{f'/f, g'/g}
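A quick exact-arithmetic check of the counterexample. It uses integer x > 0 so that Fraction represents every value exactly:

```python
from fractions import Fraction


def lhs(x):
    """max{f', g'} / max{f, g} with f = 2x + 3, g = 2x + 5 (so f' = g' = 2)."""
    return Fraction(max(2, 2), max(2 * x + 3, 2 * x + 5))


def rhs(x):
    """max{f'/f, g'/g}."""
    return max(Fraction(2, 2 * x + 3), Fraction(2, 2 * x + 5))


for x in range(1, 4):
    print(x, lhs(x), rhs(x))
```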

---> Thus, no matter which likelihood model we use (AVERAGE or BEST), as soon as there is more than one displayed tree, we cannot benefit from going down to the per-site level.

lutteropp, Mar 15 '21 17:03

This makes it clear that for networks, we need to compute all three of logL(T|P), (logL(T|P))', and (logL(T|P))'' with the help of the sumtable.

The only remaining question is whether it makes sense to go down to the per-site level when the network has only one displayed tree. Special treatment of the single-tree case has caused numerical quirks in the network loglikelihood computation before: although the results are mathematically identical, the different operations (e.g., wrapping a log(exp()) around a value or not) led to a slight numerical difference that confused BIC.

lutteropp, Mar 15 '21 17:03

Computing the second network loglikelihood derivative caused numerical problems (underflow/overflow). Thus, I had to play a bit with the formula: [Image: Untitled notebook (19)-1]
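The reformulation on the notebook page isn't reproduced here. As a generic illustration of the usual remedy (a swapped-in standard technique, not necessarily what NetRAX does), a tree_prob-weighted sum of displayed tree likelihoods can be kept in the log domain by factoring out the largest term. The function name is hypothetical:

```python
import math


def log_weighted_sum_exp(tree_logls, tree_probs):
    """Stable log(sum_i p_i * exp(l_i)) without forming exp(l_i) directly."""
    terms = [logl + math.log(p) for logl, p in zip(tree_logls, tree_probs) if p > 0]
    m = max(terms)                       # shift so the largest term becomes exp(0) = 1
    return m + math.log(sum(math.exp(t - m) for t in terms))


# exp(-800.0) underflows to 0.0 in double precision; the shifted form does not.
print(math.exp(-800.0))
print(log_weighted_sum_exp([-800.0, -801.0], [0.5, 0.5]))
```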

lutteropp, Mar 27 '21 16:03

Damn, the Newton-Raphson in pll-modules is awful. Also, the second network loglikelihood derivative is such a large number that it does not fit into a double. Luckily, all Newton-Raphson actually needs is the quotient of the first and second derivatives, as well as the sign of the second derivative...

I'll just switch to implementing vanilla Newton-Raphson myself, using https://www.cup.uni-muenchen.de/ch/compchem/geom/nr.html and Alexis' lecture slides...
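A minimal sketch of such a vanilla Newton-Raphson loop for one branch length. As described above, it only asks for the quotient of the first and second derivatives and the sign of the second one, so neither derivative has to fit into a double on its own. The callback, clamping bounds and the doubling/halving fallback for a non-concave region are hypothetical choices of this sketch, not taken from pll-modules or the linked page:

```python
def newton_raphson_brlen(ratio_and_sign, brlen, min_brlen=1e-6, max_brlen=100.0,
                         tol=1e-8, max_iter=50):
    """Maximize logL(b) over one branch length b with plain Newton-Raphson.

    ratio_and_sign(b) must return (logL'(b) / logL''(b), sign of logL''(b)).
    """
    for _ in range(max_iter):
        ratio, d2_sign = ratio_and_sign(brlen)
        if d2_sign < 0:
            new = brlen - ratio            # regular Newton step toward a maximum
        elif ratio > 0:
            # logL'' > 0, so logL' has the sign of ratio: logL is increasing here
            new = brlen * 2.0              # hypothetical fallback step
        else:
            new = brlen / 2.0              # hypothetical fallback step
        new = min(max(new, min_brlen), max_brlen)  # keep b in a valid range
        if abs(new - brlen) < tol:
            return new
        brlen = new
    return brlen


# Toy loglikelihood logL(b) = 3 log(b) - 10 b, maximized at b = 0.3.
def toy(b):
    d1 = 3.0 / b - 10.0
    d2 = -3.0 / (b * b)
    return d1 / d2, -1 if d2 < 0 else 1


print(newton_raphson_brlen(toy, 0.1))
```

In a real setting, the quotient would be computed in a numerically safe way (e.g., from log-domain quantities) before being handed to the loop.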

lutteropp, Mar 27 '21 16:03

Or no wait, maybe I can adapt the pll_modules implementation to directly take the quotient as argument...

lutteropp, Mar 27 '21 17:03


I tried this (passing the quotient and clamping too-large and too-small values before passing them to the pll-modules NR implementation), but then NR never found a better branch length.

Switching to writing my own NR implementation now.

lutteropp, Mar 27 '21 17:03

Mhm. It also happens with my own NR implementation... maybe the derivatives are still wrong?

lutteropp, Mar 27 '21 17:03

I checked how Newton-Raphson decides on the next branch length to propose: it essentially subtracts first_logl_derivative/second_logl_derivative from the previous branch length. With the first and second network logl derivatives I get (computed on a network with 0 reticulations), NR cannot work:

Starting with network logl: -771.9593909 proposing brlen: 0.067574

Network partition likelihood derivatives for partition 0: partition_lh: 1.971781113e-182 partition_lh_prime: 6.673864004e-06 partition_lh_prime_prime: 4.482882894e+181 partition_logl: -418.3915497 partition_logl_prime: 3.384688068e+176 partition_logl_prime_prime: 2.273519543e+363

Network partition likelihood derivatives for partition 1: partition_lh: 2.80180277e-154 partition_lh_prime: 322.845173 partition_lh_prime_prime: 6.284610245e+108 partition_logl: -353.5678413 partition_logl_prime: 1.152276586e+156 partition_logl_prime_prime: -1.32774133e+312

Network loglikelihood derivatives: network_logl_prime: 3.384688068e+176 network_logl_prime_prime: 2.273519543e+363 network_logl_prime / network_logl_prime_prime: 1.488743775e-187

Maybe something is wrong with my network loglikelihood derivative formula?

lutteropp, Mar 27 '21 17:03

I also printed the loglikelihood and its derivatives for the single displayed tree of the network:

For partition 0: tree_logl: -418.3915497 tree_logl_prime: -11.91731155 tree_logl_prime_prime: 418.2681682

For partition 1: tree_logl: -353.5678413 tree_logl_prime: 5.777172868 tree_logl_prime_prime: 250.5172939

lutteropp, Mar 27 '21 17:03

From looking at these numbers, I infer that something must have gone wrong when computing partition_logl_prime and partition_logl_prime_prime.

lutteropp, Mar 27 '21 18:03

Formulas I used (same here for LikelihoodModel.AVERAGE and LikelihoodModel.BEST, as there is only 1 displayed tree with tree_prob 1.0):

partition_lh = exp(tree_logl) * tree_prob
partition_lh_prime = exp(tree_logl_prime) * tree_prob
partition_lh_prime_prime = exp(tree_logl_prime_prime) * tree_prob

partition_logl = log(partition_lh)
partition_logl_prime = (partition_lh_prime / partition_lh)
partition_logl_prime_prime = exp(log(partition_lh_prime_prime) - log(partition_lh)) - exp(2*log(partition_lh_prime) - 2*log(partition_lh))

Of course, in this case partition_logl_prime should be the same as tree_logl_prime, and partition_logl_prime_prime should be the same as tree_logl_prime_prime. The question is: why isn't this the case?
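For comparison, here is a sketch of the same single-tree computation with the chain rule applied to the likelihood itself. Since partition_lh = exp(tree_logl) * tree_prob, its derivatives are partition_lh * tree_logl_prime and partition_lh * (tree_logl_prime_prime + tree_logl_prime^2), not exp(tree_logl_prime) * tree_prob and exp(tree_logl_prime_prime) * tree_prob. With the partition 0 values printed above, this form reproduces the tree derivatives up to rounding. The function name is hypothetical:

```python
import math


def single_tree_partition_derivatives(tree_logl, tree_logl_prime,
                                      tree_logl_prime_prime, tree_prob=1.0):
    """Partition logl and its first two derivatives for one displayed tree."""
    # Chain rule on exp(): (e^l)' = e^l * l' and (e^l)'' = e^l * (l'' + l'^2)
    partition_lh = math.exp(tree_logl) * tree_prob
    partition_lh_prime = partition_lh * tree_logl_prime
    partition_lh_prime_prime = partition_lh * (tree_logl_prime_prime
                                               + tree_logl_prime ** 2)

    # Derivatives of log(): (log L)' = L'/L and (log L)'' = L''/L - (L'/L)^2
    partition_logl = math.log(partition_lh)
    partition_logl_prime = partition_lh_prime / partition_lh
    partition_logl_prime_prime = (partition_lh_prime_prime / partition_lh
                                  - partition_logl_prime ** 2)
    return partition_logl, partition_logl_prime, partition_logl_prime_prime


# Partition 0 values from the output above.
print(single_tree_partition_derivatives(-418.3915497, -11.91731155, 418.2681682))
```

This still forms exp(tree_logl) explicitly, so it underflows to zero once tree_logl drops below roughly -745. The exp() factor cancels in both ratios anyway, which is one more reason to stay in the log domain.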

lutteropp, Mar 27 '21 18:03

Maybe I misunderstood these formulas from Alexey's PhD thesis, which I used to derive the partition_logl_prime and partition_logl_prime_prime formulas?

[Image: Screenshot from 2021-03-27 19-14-11]

lutteropp, Mar 27 '21 18:03

I got it now: my mistake was that I differentiated the log. Instead, one should treat it as a function g. Example:

Partition loglh, 2 displayed trees, LikelihoodModel.AVERAGE: logl(P) = logl(T_1|P) * p(T_1) + logl(T_2|P) * p(T_2)

Now, pretend we already have the tree loglikelihoods. Instead of logl, we write g, which gives us:

g(P) = g(T_1|P) * p(T_1) + g(T_2|P) * p(T_2)

And then we easily see how to differentiate it: g'(P) = g'(T_1|P) * p(T_1) + g'(T_2|P) * p(T_2).

---> We can compute the partition loglikelihood derivatives out of the displayed tree loglikelihood derivatives.

lutteropp, Mar 27 '21 18:03

Mhm. But if I follow this, then the network loglikelihood derivative is simply the sum of the partition loglikelihood derivatives... I tried this as well, but this also didn't work out with Newton-Raphson.

lutteropp, Mar 27 '21 18:03

It looks much better now though. Now the error is "Exceeded maximum number of iterations" and the proposed branch lengths appear at least somewhat reasonable...

lutteropp, Mar 27 '21 19:03

But if this is now the correct definition of the network loglikelihood derivatives (-> sum of the partition loglikelihood derivatives, computed from the displayed tree loglikelihood derivatives), then we don't even need to compute tree_logl from the sumtable anymore...

I'm still leaving it in as an option, as this allows combining Brent with the sumtable.

lutteropp, Mar 27 '21 19:03

OK, it looks like I got Newton-Raphson branch-length optimization for networks to work. It is lots of dark voodoo, though!!!

lutteropp, Mar 27 '21 19:03

Posting the correct maths for network loglikelihood derivatives here soon. (I'm still on vacation, after all)

lutteropp, Mar 27 '21 19:03

Mhmm... no. Now NR only works for zero-reticulation networks. The network loglikelihood derivatives are still wrong if there is more than one displayed tree.
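A sketch of what the chain rule gives with more than one displayed tree. It rests on one assumption (hedged, not confirmed in this thread): for LikelihoodModel.AVERAGE, the partition likelihood is the tree_prob-weighted sum of the displayed tree likelihoods, partition_lh = sum_i tree_prob_i * exp(tree_logl_i), matching how partition_lh is formed in the single-tree formulas above. Let w_i = tree_prob_i * exp(tree_logl_i) / partition_lh. Then (logL(P))' = sum_i w_i * tree_logl_prime_i, and (logL(P))'' = sum_i w_i * (tree_logl_prime_prime_i + tree_logl_prime_i^2) - ((logL(P))')^2. In this form the tree loglikelihoods enter through the weights w_i. Trees in which the branch is inactive pass zero derivatives but still count in the weights. The function name is hypothetical:

```python
import math


def average_partition_logl_derivatives(tree_logls, tree_probs,
                                       tree_logl_primes, tree_logl_prime_primes):
    """(logL(P), logL(P)', logL(P)'') for L(P) = sum_i p_i * exp(l_i).

    Trees in which the branch is inactive should pass 0.0 for both
    derivatives; they still contribute to the weights through their logl.
    """
    idx = [i for i, p in enumerate(tree_probs) if p > 0]
    # log(p_i) + l_i, shifted by the maximum so the largest term becomes exp(0) = 1
    terms = [math.log(tree_probs[i]) + tree_logls[i] for i in idx]
    m = max(terms)
    scaled = [math.exp(t - m) for t in terms]
    total = sum(scaled)
    weights = [s / total for s in scaled]   # w_i, summing to 1

    logl = m + math.log(total)
    first = sum(w * tree_logl_primes[i] for w, i in zip(weights, idx))
    second = sum(w * (tree_logl_prime_primes[i] + tree_logl_primes[i] ** 2)
                 for w, i in zip(weights, idx)) - first ** 2
    return logl, first, second


# Two displayed trees whose loglikelihoods alone would underflow exp().
print(average_partition_logl_derivatives([-1000.0, -1002.0], [0.5, 0.5],
                                         [1.0, 2.0], [-3.0, -4.0]))
```

If the network loglikelihood is the sum of the partition loglikelihoods, its derivatives are the sums of these per-partition derivatives. With a single tree and tree_prob 1.0, the formulas reduce to the tree derivatives.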

lutteropp, Mar 27 '21 20:03