git-re-basin

Cost Matrix Computation in Weight Matching

Open • frallebini opened this issue 3 years ago • 10 comments

Hi, I read the paper and I am having a really hard time reconciling the formula

[image: weight_matching formula from the paper]

with the actual computation of the cost matrix for the LAP in weight_matching.py, namely

A = jnp.zeros((n, n))
for wk, axis in ps.perm_to_axes[p]:
  w_a = params_a[wk]
  w_b = get_permuted_param(ps, perm, wk, params_b, except_axis=axis)
  w_a = jnp.moveaxis(w_a, axis, 0).reshape((n, -1))
  w_b = jnp.moveaxis(w_b, axis, 0).reshape((n, -1))
  A += w_a @ w_b.T

Are you following a different mathematical derivation or am I missing something?

frallebini, Oct 29 '22 18:10

Hi @frallebini! The writeup in the paper is for the special case of an MLP with no bias terms; the version in the code is just more general. The connection is that there's a sum over all weight arrays that interact with P_\ell. For each one, we apply its relevant permutations on all other axes, take the Frobenius inner product with the reference model, and add all those terms together. So A represents that sum: each for-loop iteration adds a single term to the sum, get_permuted_param applies the other (non-P_\ell) permutations to w_b, and the moveaxis-reshape-matmul corresponds to the Frobenius inner product with w_a.

samuela, Oct 29 '22 18:10

Thanks @samuela, I understand that the code is a generalization of the MLP with no bias case, but still:

  1. If the moveaxis-reshape-@ operation corresponded to the Frobenius inner product with w_a, wouldn't A be a scalar?
  2. How does get_permuted_param "skip" the non-P_\ell permutations? Doesn't the except_axis argument mean that, for example, if I want to permute rows, then I have to apply the permutation vector perm[p] along the column dimension?

frallebini, Oct 29 '22 20:10

> If the moveaxis-reshape-@ operation corresponded to the Frobenius inner product with w_a, wouldn't A be a scalar?

Ack, you're right! I messed up: it's not actually a Frobenius inner product, just a regular matrix product. The moveaxis-reshape combo is necessary to flatten the dimensions we don't care about in the case of non-2D weight arrays.
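To make the distinction concrete, here is a small NumPy sketch (NumPy standing in for jax.numpy; the sizes and random data are made up). The product w_a @ w_b.T is an (n, n) matrix whose (i, j) entry compares unit i of model A with unit j of model B. The scalar Frobenius objective $\langle W^A, P W^B \rangle_F$ is then recovered as $\langle P, A \rangle_F$, i.e. the sum of the entries of A that the permutation selects.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 4, 6  # hypothetical: n units being permuted, d flattened "other" dims
w_a = rng.normal(size=(n, d))
w_b = rng.normal(size=(n, d))

# The cost-matrix term: a plain matrix product of shape (n, n), not a scalar.
A = w_a @ w_b.T

perm = rng.permutation(n)   # perm[i] = unit of B matched to unit i of A
P = np.eye(n)[perm]         # permutation matrix with P[i, perm[i]] = 1

frob = np.sum(w_a * (P @ w_b))       # <W_a, P W_b>_F, the scalar objective
via_A = A[np.arange(n), perm].sum()  # <P, A>_F: sum of the entries P selects
```

Since the objective is linear in P with coefficients A, maximizing it over permutations is exactly a linear assignment problem on A.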

> How does get_permuted_param "skip" the non-P_\ell permutations? Doesn't the except_axis argument mean that, for example, if I want to permute rows, then I have to apply the permutation vector perm[p] along the column dimension?

Yup, that's exactly what except_axis is doing. But I think you may have it backwards: except_axis skips the P_\ell axis and applies all the other fixed P's to the other axes.
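A minimal sketch of the except_axis behavior, assuming a hypothetical axes_to_perm dict that maps each parameter name to a tuple with one permutation name (or None) per axis. The function name get_permuted_param_sketch is made up for illustration; it is not the repository's get_permuted_param, only a NumPy re-creation of the behavior described above.

```python
import numpy as np

def get_permuted_param_sketch(axes_to_perm, perm, key, params, except_axis=None):
    """Hypothetical: apply every permutation acting on params[key],
    except the one on `except_axis` (the axis being solved for)."""
    w = params[key]
    for axis, p in enumerate(axes_to_perm[key]):
        if axis == except_axis:
            continue  # leave the P_ell axis untouched
        if p is not None:  # None means no permutation acts on this axis
            w = np.take(w, perm[p], axis=axis)
    return w

# Usage: a (3, 4) kernel stored as (out_dim, in_dim); names are hypothetical.
W = np.arange(12).reshape(3, 4)
params = {"Dense_1/kernel": W}
axes_to_perm = {"Dense_1/kernel": ("P_1", "P_0")}
perm = {"P_1": np.array([2, 0, 1]), "P_0": np.array([3, 2, 1, 0])}

# Solving for P_1 (axis 0): only P_0 is applied, along the columns.
rows_free = get_permuted_param_sketch(axes_to_perm, perm, "Dense_1/kernel", params, except_axis=0)
# Solving for P_0 (axis 1): only P_1 is applied, along the rows.
cols_free = get_permuted_param_sketch(axes_to_perm, perm, "Dense_1/kernel", params, except_axis=1)
```

So when permuting rows, the fixed permutations are applied to the columns, and the row axis itself is left for the LAP to decide.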

samuela, Oct 29 '22 20:10

OK, but let's consider the MLP-without-biases case. The paper models weight matching as the following LAP:

[image: weight_matching_complete, the full LAP formulation from the paper]

In other words, it computes A as

[image: equation (1), A as computed in the paper]

What the code does instead, if I understood correctly, is compute A by

  1. Permuting w_b disregarding P_\ell
  2. Transposing it
  3. Multiplying w_a by it

In other words

[image: equation (2), A as computed by the code]

I don't think (1) and (2) are the same thing though.

frallebini, Oct 29 '22 22:10

Hmm, I think the error is in the first line of (2): the shapes don't line up, since $W_\ell^A$ has shape (n, *) and $W_{\ell+1}^A$ has shape (*, n). Adding them together will result in a shape error if your layers have different widths.
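A quick NumPy check of the shape argument, with made-up widths where layer $\ell$ maps m inputs to n units and layer $\ell+1$ maps n units to k outputs (only the shapes matter here, so the arrays are filled with ones):

```python
import numpy as np

m, n, k = 3, 4, 5           # hypothetical layer widths
W_ell_A = np.ones((n, m))   # shape (n, *)
W_next_A = np.ones((k, n))  # shape (*, n)

try:
    W_ell_A + W_next_A      # (4, 3) + (5, 4) cannot be broadcast
    summed = True
except ValueError:
    summed = False

# Each term the code accumulates is (n, n) on its own, so those can be summed.
term_ell = W_ell_A @ W_ell_A.T     # (n, m) @ (m, n) -> (n, n)
term_next = W_next_A.T @ W_next_A  # (n, k) @ (k, n) -> (n, n)
```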

I think tracing out the code for the MLP-without-biases case is a good idea. In that case we run through the for wk, axis in ps.perm_to_axes[p]: loop twice: once for $W_\ell$ and once for $W_{\ell+1}$.

  • For $W_\ell$: First of all, axis=0 since $W_\ell$ has shape (n, *). Then, w_b = get_permuted_param(ps, perm, wk, params_b, except_axis=axis) will give us $W_\ell^B P_{\ell-1}^T$. In other words, $W_\ell^B$ but with the other permutations -- $P_{\ell-1}$ in this case -- applied to the other axes. jnp.moveaxis(w_a, axis, 0).reshape((n, -1)) will be a no-op since axis = 0, and w_b = jnp.moveaxis(w_b, axis, 0).reshape((n, -1)) will also be a no-op. So w_a @ w_b.T is $W_\ell^A (W_\ell^B P_{\ell-1}^T)^T$, which matches up with the first term in the sum.
  • For $W_{\ell+1}$: In this case axis = 1 since $W_{\ell+1}$ has shape (*, n). Then, w_b = get_permuted_param(ps, perm, wk, params_b, except_axis=axis) will give us $P_{\ell+1} W_{\ell+1}^B$. In other words, $W_{\ell+1}^B$ but with the other permutations -- $P_{\ell+1}$ in this case -- applied to the other axes. jnp.moveaxis(w_a, axis, 0).reshape((n, -1)) will result in a transpose, aka $(W_{\ell+1}^A)^T$. And w_b = jnp.moveaxis(w_b, axis, 0).reshape((n, -1)) will also result in a transpose, aka $(W_{\ell+1}^B)^T P_{\ell+1}^T$. So, w_a @ w_b.T matches up with the second term in the sum.
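Putting the two traced terms together gives $A = W_\ell^A (W_\ell^B P_{\ell-1}^T)^T + (W_{\ell+1}^A)^T P_{\ell+1} W_{\ell+1}^B = W_\ell^A P_{\ell-1} (W_\ell^B)^T + (W_{\ell+1}^A)^T P_{\ell+1} W_{\ell+1}^B$. The NumPy sketch below (NumPy standing in for jax.numpy, hypothetical widths, weights stored as (out_dim, in_dim)) runs the same moveaxis-reshape-matmul loop, applying the fixed neighbouring permutations by plain indexing instead of get_permuted_param, and compares it against that closed form.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n, k = 3, 4, 5  # hypothetical widths: layer ell is (n, m), layer ell+1 is (k, n)

Wa_ell, Wb_ell = rng.normal(size=(n, m)), rng.normal(size=(n, m))
Wa_next, Wb_next = rng.normal(size=(k, n)), rng.normal(size=(k, n))

# Fixed neighbouring permutations as index arrays, plus their matrices.
perm_prev = rng.permutation(m)   # P_{ell-1}
perm_next = rng.permutation(k)   # P_{ell+1}
P_prev = np.eye(m)[perm_prev]    # P[i, perm[i]] = 1
P_next = np.eye(k)[perm_next]

A = np.zeros((n, n))
# (w_a, w_b with the other permutation already applied, axis of P_ell)
for w_a, w_b, axis in [
    (Wa_ell, Wb_ell[:, perm_prev], 0),    # W_ell^B P_{ell-1}^T: permute the input axis
    (Wa_next, Wb_next[perm_next, :], 1),  # P_{ell+1} W_{ell+1}^B: permute the output axis
]:
    w_a = np.moveaxis(w_a, axis, 0).reshape((n, -1))
    w_b = np.moveaxis(w_b, axis, 0).reshape((n, -1))
    A += w_a @ w_b.T

closed_form = Wa_ell @ P_prev @ Wb_ell.T + Wa_next.T @ P_next @ Wb_next
```

Note that the two layers have different widths (m and k), yet each loop iteration contributes an (n, n) term.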

samuela, Nov 01 '22 00:11

Ok, the role of moveaxis is clear, and the computation matches the formula in the paper for an MLP with no biases.

On the other hand, the reshape((n, -1)) (extending the reasoning to the presence of biases):

  • Is always a no-op for weight matrices, since n is either the number of rows of $W_\ell$ or the number of columns of $W_{\ell+1}$, which has already been transposed by the moveaxis.
  • Is needed to turn the (n,) bias vectors into (n, 1) arrays, so that w_a @ w_b.T is an (n, n) matrix that can be added to A.
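Both points can be checked with a small NumPy sketch (hypothetical sizes and random data; NumPy standing in for jax.numpy). For a bias the relevant axis is 0, so moveaxis does nothing and the reshape alone turns the vector into a column:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 4
b_a, b_b = rng.normal(size=n), rng.normal(size=n)  # hypothetical biases, shape (n,)

# Without the reshape, 1-D @ 1-D is a scalar dot product.
unreshaped = b_a @ b_b.T

# With axis = 0, moveaxis is a no-op; reshape turns (n,) into (n, 1).
w_a = np.moveaxis(b_a, 0, 0).reshape((n, -1))
w_b = np.moveaxis(b_b, 0, 0).reshape((n, -1))
term = w_a @ w_b.T  # (n, 1) @ (1, n) -> (n, n): the outer product of the biases

# For an already-2-D weight matrix with the n axis in front, reshape changes nothing.
W = rng.normal(size=(n, 3))
W_flat = W.reshape((n, -1))
```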

Right?

frallebini, Nov 01 '22 17:11

That's correct! It's also necessary for higher-rank weight arrays, e.g. in a convolutional layer where the weights have shape (w, h, channel_in, channel_out).
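A NumPy sketch of the convolutional case, using the (w, h, channel_in, channel_out) layout from the comment above with made-up sizes, and permuting the output channels (axis 3). moveaxis brings the channel axis to the front and reshape flattens the remaining axes; the einsum line computes the same (n, n) term without any reshaping, as a cross-check.

```python
import numpy as np

rng = np.random.default_rng(0)
w, h, c_in, c_out = 3, 3, 2, 4  # hypothetical kernel dims
n = c_out                       # permuting the output channels
k_a = rng.normal(size=(w, h, c_in, c_out))
k_b = rng.normal(size=(w, h, c_in, c_out))
axis = 3                        # the output-channel axis

# Bring the permuted axis to the front, then flatten everything else.
f_a = np.moveaxis(k_a, axis, 0).reshape((n, -1))  # (c_out, w * h * c_in)
f_b = np.moveaxis(k_b, axis, 0).reshape((n, -1))
term = f_a @ f_b.T                                # (n, n)

# Same quantity: contract over every axis except the two channel axes.
check = np.einsum("xyio,xyip->op", k_a, k_b)
```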

samuela, Nov 01 '22 23:11

Hi, I read the code and really did not understand the following snippet. Since it relates to the weight matching algorithm, I'm posting it here. Line 199 of weight_matching.py:

perm_sizes = {p: params_a[axes[0][0]].shape[axes[0][1]] for p, axes in ps.perm_to_axes.items()}

According to the above line, if W_\ell has shape [m, n] (m is output feature dim, n is input feature dim) in the Dense layer, then the shape of the permutation matrix P_\ell will be [n, n]. But when I read the paper, I think it should be [m, m].

Sorry for the silly question, but could you explain? @samuela @frallebini

Thank you!

LeCongThuong, Dec 09 '22 16:12

Hi @LeCongThuong, ps.perm_to_axes is a dict of the form PermutationId => [(ParamId, Axis), ...], where PermutationIds are strings, ParamIds are also strings, and Axis values are integers. So, for example, in an MLP without bias terms (and assuming weights have shape [out_dim, in_dim]) this dict would look something like

{ "P_5": [("Dense_5/kernel", 0), ("Dense_6/kernel", 1)], ... }

Therefore, axes[0][0] will be something like "Dense_5/kernel" and axes[0][1] will be 0. HTH!
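Under the [out_dim, in_dim] convention above, the line-199 expression can be evaluated by hand with a plain dict standing in for ps.perm_to_axes. The parameter names and widths below are hypothetical: Dense_5 maps 32 inputs to 64 outputs and Dense_6 maps those 64 to 10.

```python
import numpy as np

# Hypothetical weights stored as (out_dim, in_dim), no biases.
params_a = {
    "Dense_5/kernel": np.zeros((64, 32)),  # 32 -> 64
    "Dense_6/kernel": np.zeros((10, 64)),  # 64 -> 10
}
# P_5 permutes the 64 outputs of Dense_5, which are also the 64 inputs of Dense_6.
perm_to_axes = {"P_5": [("Dense_5/kernel", 0), ("Dense_6/kernel", 1)]}

# Same expression as line 199, with the plain dict in place of ps.perm_to_axes.
perm_sizes = {p: params_a[axes[0][0]].shape[axes[0][1]] for p, axes in perm_to_axes.items()}

# Every (ParamId, Axis) entry for a permutation should agree on its size.
consistent = all(params_a[k].shape[a] == perm_sizes["P_5"] for k, a in perm_to_axes["P_5"])
```

Here perm_sizes["P_5"] is the output width of Dense_5, so P_5 is a 64 x 64 permutation; reading only the first entry is enough because all entries share the same size.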

samuela, Dec 09 '22 20:12

Thank you so much for replying @samuela!

I tried to understand ps.perm_to_axes and now get the meaning of Axis. From your comment, Axis tells us to permute W_b along every axis other than Axis. Following your example above, I think it should be

{ "P_5": [("Dense_5/kernel", 1), ("Dense_6/kernel", 0)], ... }

Then axes[0][1] will be 1, so the shape of P_\ell will be [n, n].

Thank you again for replying to my question.

LeCongThuong, Dec 10 '22 02:12