
Graphormer module

Open paridhimaheshwari2708 opened this issue 3 years ago • 10 comments

How can I use the Graphormer model with a custom dataloader and training scripts (not the fairseq commands)? My data consists of DGL graphs, and my setup uses DGL GraphConv layers. I want to use Graphormer as a torch.nn.Module (like any other GNN layer) to encode DGL graphs in my setup. How can I use the Graphormer model alone and replace the DGL GraphConv layers with Graphormer layers?

paridhimaheshwari2708 avatar Aug 11 '22 22:08 paridhimaheshwari2708

Hi! With our current implementation, it will take some extra effort to meet your needs. You can try wrapping the code directly related to the Graphormer model (which is scattered across 4~5 files, like graphormer/models/graphormer.py, graphormer/tasks/graph_prediction.py, etc.) into a single file. For this, you can also refer to fairseq's build_model function implementation. Then you can import it as a normal Python module.

mavisguan avatar Aug 15 '22 06:08 mavisguan

@mavisguan Hi, thank you for your suggestion. Could you also provide more information about the dataloader? I have a custom dataset of DGL graphs and a task-specific sampling that happens in the dataloader. How can I wrap it into the format needed by Graphormer? Specifically, what does the input to graphormer (batched_data in code snippet here) look like?

paridhimaheshwari2708 avatar Aug 17 '22 01:08 paridhimaheshwari2708

@paridhimaheshwari2708 You can use an example DGL dataset (like qm7b), put a breakpoint before the line of code you've mentioned, and inspect what batched_data looks like. I use this debugging toolkit to set breakpoints in big Python projects: https://github.com/volltin/vpack. I think it's a very helpful tool for digging into Graphormer's code. We're sorry that our tutorials on customizing datasets are incomplete and a bit ambiguous; we're working on updating them, so please stay tuned.

mavisguan avatar Aug 17 '22 05:08 mavisguan

@mavisguan Thank you, this is really helpful! I noticed here that DGL graphs are actually converted into PyG graphs. Is that right? If so, is batched_data obtained from torch_geometric.loader.DataLoader or is it a dictionary with the above keys?

paridhimaheshwari2708 avatar Aug 17 '22 16:08 paridhimaheshwari2708

Yes, you're right. I think batched_data is obtained from torch_geometric.loader.DataLoader, which contains batch_size PyG graphs, and it's also a dictionary.
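To make those batched shapes concrete, here is a minimal pure-Python sketch of the padding step (this is an illustration, not Graphormer's actual collator, which lives in graphormer/data/collator.py): each graph's node-feature matrix is padded to the largest node count in the batch, giving one (batch_size, max_nodes, num_features) array.

```python
# Sketch of batch collation: pad each graph's node-feature matrix to
# the largest node count in the batch. Nested lists stand in for
# tensors; the real collator works on torch tensors.
def collate_node_features(graphs, pad_value=0):
    max_nodes = max(len(g) for g in graphs)
    num_feats = len(graphs[0][0])
    batched = []
    for g in graphs:
        padded = [list(row) for row in g]
        while len(padded) < max_nodes:
            padded.append([pad_value] * num_feats)  # pad missing atoms
        batched.append(padded)
    return batched  # shape (batch_size, max_nodes, num_feats)

# Two toy graphs with 2 and 3 nodes, 2 features each
g1 = [[1, 0], [2, 1]]
g2 = [[3, 0], [4, 1], [5, 2]]
batch = collate_node_features([g1, g2])
```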

mavisguan avatar Aug 18 '22 03:08 mavisguan

Sorry to bother you. I am wondering how I can convert my dataset (not DGL, OGB, or PyG, but a graph adjacency matrix) into the type that fits batched_data. Thanks a lot.

laowu-code avatar Aug 19 '22 03:08 laowu-code

@laowu-code You can try converting your dataset into a custom PyG dataset, following their official documentation https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html, then register the dataset in graphormer/data/pyg_datasets/pyg_dataset_lookup_table.py.
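Since your raw data is an adjacency matrix, the key step is converting it to the COO edge list that PyG's Data object expects. Here is a minimal pure-Python sketch of that conversion (nested lists stand in for tensors; in real PyG code you would wrap the result as Data(x=x, edge_index=torch.tensor([src, dst]))):

```python
# Sketch: convert a dense adjacency matrix to PyG-style COO
# edge_index, i.e. two parallel lists of source and target nodes.
def adj_to_edge_index(adj):
    src, dst = [], []
    for i, row in enumerate(adj):
        for j, w in enumerate(row):
            if w:  # nonzero entry means an edge i -> j
                src.append(i)
                dst.append(j)
    return src, dst

# Undirected path graph 0 - 1 - 2: each edge appears in both directions
adj = [
    [0, 1, 0],
    [1, 0, 1],
    [0, 1, 0],
]
src, dst = adj_to_edge_index(adj)
```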

mavisguan avatar Aug 22 '22 01:08 mavisguan

@mavisguan Thanks for your reply, I will try it. It means a lot to me.

laowu-code avatar Aug 22 '22 03:08 laowu-code


Hi mavisguan,

Is it possible to construct this dictionary manually, without a dataloader, to pass to the model for testing purposes (just to ensure that the modules are set up correctly)? For example:

batched_data = { "idx": ..., "x": ..., "attn_bias": ..., "attn_edge_type": ..., "spatial_pos": ..., "degree": ..., "edge_input": ..., "y": ... }

If so, what format is each component? And does this need to be wrapped as a TensorDict, or is it a plain dictionary of tensors? I just want to load a single molecule example to check that I'm doing this correctly (I modified the code to combine "in degree" and "out degree" and to remove multi-hop). However, I am confused by this line:

node_feature = self.atom_encoder(x).sum(dim=-2)

I would have expected the 'x' key to hold an adjacency matrix, but it appears that it should instead hold node types for the encoder. Yet the same dictionary key is passed to multi-head attention, so I think I'm confusing myself.
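For what it's worth, x holds per-atom integer category indices (atomic number, chirality, etc.), not an adjacency matrix; each feature column gets its own embedding table, and .sum(dim=-2) sums the per-feature embeddings into one vector per atom. A minimal pure-Python sketch of that idea (toy hand-picked embedding tables standing in for Graphormer's learned atom_encoder):

```python
# Sketch of atom_encoder(x).sum(dim=-2): x has shape
# (num_atoms, num_features) of integer category indices. Each feature
# index is looked up in its own embedding table, and the resulting
# embeddings are summed over the feature axis (dim=-2).
EMB_DIM = 4

# One hypothetical table per feature column (toy values, not learned)
tables = [
    {0: [1.0] * EMB_DIM, 1: [2.0] * EMB_DIM},  # e.g. atomic number
    {0: [0.5] * EMB_DIM, 1: [0.1] * EMB_DIM},  # e.g. chirality
]

def encode_atoms(x):
    out = []
    for atom in x:  # one row of feature indices per atom
        vec = [0.0] * EMB_DIM
        for feat_idx, table in zip(atom, tables):
            vec = [a + b for a, b in zip(vec, table[feat_idx])]
        out.append(vec)
    return out  # shape (num_atoms, EMB_DIM)

x = [[0, 1], [1, 0]]  # two atoms, two categorical features each
node_feature = encode_atoms(x)
```

The graph structure itself enters the model separately, through spatial_pos, attn_bias, and edge_input, which is why x never needs to be an adjacency matrix.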

BrandenKeck avatar Jun 14 '24 16:06 BrandenKeck

For anyone who comes across this and is also struggling with the data structure: I think I understand now after reading wrapper.py. "batched_data" appears to be a Python dictionary with tensors of the following shapes:

  • idx: (batch size) dataset indices
  • x: (batch size, num atoms, num atom features), where the features are atomic number, chirality, etc. (there is an example from torch geometric with these features here)
  • edge_input: (batch size, num atoms, num atoms, max dist. between atoms, num bond features) contains the edge-feature information along shortest paths. I'm still trying to understand it fully, but it can be generated from the code in the repo without a full understanding
  • attn_bias: (batch size, num atoms + 1, num atoms + 1) is initialized to zeros, with the extra atom dimension presumably for the virtual node
  • attn_edge_type: (batch size, num atoms, num atoms, num bond features) is simply a list of edge types in roughly an adjacency-matrix format
  • spatial_pos: (batch size, num atoms, num atoms) is the shortest path length between each atom and every other atom
  • degree: (batch size, num atoms) is the degree of each node in the graph. I modified wrapper.py and the Graphormer code to include only "degree" since I am testing on undirected graphs; the wrapper.py code sets in_degree = out_degree by default if both are included.
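The structural entries in that list can be built by hand for a single toy graph. A minimal pure-Python sketch (batch size 1, nested lists standing in for tensors; note the real wrapper.py also offsets feature indices so 0 can mean padding, which is omitted here):

```python
from collections import deque

# Toy undirected path graph 0 - 1 - 2 as an adjacency matrix
adj = [
    [0, 1, 0],
    [1, 0, 1],
    [0, 1, 0],
]
n = len(adj)

def bfs_dist(adj, src):
    """Shortest path length from src to every node (-1 = unreachable)."""
    dist = [-1] * len(adj)
    dist[src] = 0
    q = deque([src])
    while q:
        u = q.popleft()
        for v, w in enumerate(adj[u]):
            if w and dist[v] == -1:
                dist[v] = dist[u] + 1
                q.append(v)
    return dist

spatial_pos = [bfs_dist(adj, i) for i in range(n)]   # (n, n) shortest paths
degree = [sum(row) for row in adj]                   # (n,) node degrees
attn_bias = [[0.0] * (n + 1) for _ in range(n + 1)]  # (n+1, n+1), zeros

batched_data = {
    "idx": [0],                          # (1,)
    "x": [[[6, 0], [6, 0], [8, 0]]],     # toy (1, n, num atom features)
    "attn_bias": [attn_bias],            # (1, n+1, n+1)
    "spatial_pos": [spatial_pos],        # (1, n, n)
    "degree": [degree],                  # (1, n)
}
```

edge_input and attn_edge_type are omitted above because they additionally need per-edge bond features along each shortest path; for a first smoke test of the modules, the entries shown are the ones the structural encodings consume.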

BrandenKeck avatar Jun 19 '24 15:06 BrandenKeck