pytorch_coma
pytorch_coma copied to clipboard
Not using real ChebConv?
I noticed that ChebConv_Coma adds no self-loops during propagation, so the code does not appear to use the "real" Laplacian matrix. Is this behavior consistent with the TensorFlow version?
Minimal code to reproduce the issue:
"""
Check that the dense ChebConv is consistent with the sparse ChebConv.
"""
import torch
from layers import ChebConv_Coma, DenseChebConv
from psbody.mesh import Mesh
import mesh_operations
from main import scipy_to_torch_sparse
import numpy as np
# Layer hyper-parameters shared by both convolutions, plus the input batch size.
in_channels = 1
out_channels = 1
batch_size = 2
K = 2
# Load the template mesh whose connectivity defines the graph.
m = Mesh(filename='../template/template.obj')
# Alternative tiny input (three vertices, one triangle) for debugging by hand:
# m.v = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)
# m.f = np.array([[0, 1, 2]])
# Vertex connectivity as a scipy COO matrix, then converted to a torch sparse tensor.
sparse_adj = mesh_operations.get_vert_connectivity(m.v, m.f).tocoo()
sparse_adj = scipy_to_torch_sparse(sparse_adj)
# Halve every adjacency value before densifying.
# NOTE(review): the reason for the division by 2 is not visible here -- confirm
# against what get_vert_connectivity stores per edge.
dense_adj = (sparse_adj / 2).to_dense()
n_vertex = m.v.shape[0]
# Random input of shape batch_size x n_vertex x in_channels.
x = torch.randn(batch_size, n_vertex, in_channels)
# Hand-checkable input matching the three-vertex mesh above:
# x = torch.tensor([[[1], [2], [3]], [[1], [2], [3]]], dtype=torch.float32)
sparse_conv = ChebConv_Coma(in_channels, out_channels, K)
dense_conv = DenseChebConv(in_channels, out_channels, K)
# Make both layers use the very same parameter objects so that any output
# difference comes from the propagation step, not from initialization.
dense_conv.weight = sparse_conv.weight
dense_conv.bias = sparse_conv.bias
# presumably returns the edge list and its normalized edge weights -- verify
# against ChebConv_Coma.norm.
edge_index, edge_norm = ChebConv_Coma.norm(sparse_adj._indices(), n_vertex)
y_sparse = sparse_conv(x, edge_index, edge_norm)
y_dense = dense_conv(x, dense_adj)
# Element-wise boolean mask: True where the two outputs agree within tolerance.
print(torch.isclose(y_sparse, y_dense))
layers.py
...
class DenseChebConv(nn.Module):
    """Chebyshev graph convolution driven by a dense V x V adjacency matrix.

    The propagation operator is ``L = -D^{-1/2} A D^{-1/2}``. With
    ``L_sym = I - D^{-1/2} A D^{-1/2}`` this is exactly the rescaled Laplacian
    ``2 * L_sym / lambda_max - I`` when ``lambda_max`` is taken to be 2, which
    is why no identity (self-loop) term shows up explicitly.

    :param in_channels: number of input features per vertex
    :param out_channels: number of output features per vertex
    :param K: Chebyshev order, i.e. number of weight matrices
    :param bias: if True, learn an additive bias of size out_channels
    """

    def __init__(self, in_channels, out_channels, K, bias=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.cheb_order = K
        # One (in_C x out_C) weight matrix per Chebyshev term.
        self.weight = nn.Parameter(torch.empty(K, in_channels, out_channels))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Draw weight (and bias, when present) from N(mean=0, std=0.1)."""
        nn.init.normal_(self.weight, mean=0.0, std=0.1)
        # The bias is None when the layer was built with bias=False.
        if self.bias is not None:
            nn.init.normal_(self.bias, mean=0.0, std=0.1)

    def forward(self, x, adj):
        """
        :param x: batch x V x in_C
        :param adj: V x V adjacency, floating dtype; a sparse tensor is
            converted to dense before use
        :return: batch x V x out_C
        """
        if adj.is_sparse:
            adj = adj.to_dense()

        d_vec = torch.sum(adj, dim=1)
        # 0 ** -0.5 is inf, and inf * 0 would spread NaN through L, so
        # zero-degree (isolated) vertices get a zero scaling factor instead.
        inv_sqrt_d = d_vec.pow(-0.5).masked_fill(d_vec == 0, 0.0)
        # Row/column scaling by broadcasting: same values as
        # diag(d) @ -adj @ diag(d) without building a V x V diagonal matrix.
        L = -(inv_sqrt_d.unsqueeze(1) * adj * inv_sqrt_d.unsqueeze(0))

        n_terms = self.weight.size(0)
        Tx_0 = x
        out = torch.matmul(Tx_0, self.weight[0])
        if n_terms > 1:
            Tx_1 = L @ Tx_0
            out = out + torch.matmul(Tx_1, self.weight[1])
            for i in range(2, n_terms):
                # Chebyshev recurrence T_k = 2 L T_{k-1} - T_{k-2}; scaling the
                # product rather than L avoids a V x V multiply per step.
                Tx_2 = 2 * (L @ Tx_1) - Tx_0
                out = out + torch.matmul(Tx_2, self.weight[i])
                Tx_0, Tx_1 = Tx_1, Tx_2

        if self.bias is not None:
            out = out + self.bias
        return out
By "Laplacian matrix" I mean the definition given on Wikipedia.