
Seastar - RGCN

Open JoelMathewC opened this issue 2 years ago • 4 comments

Seastar's original implementation does not present a vertex-centric program for RGCN; instead it uses a handwritten kernel in dgl-hack. Let's try to write a vertex-centric program for RGCN. This issue will track any problems faced along the way.

Details about the original implementation of RGCN can be found here.

JoelMathewC avatar Feb 22 '23 05:02 JoelMathewC

Trying to run train.py using dgl.RelGraphConv to see if the training file can execute without dgl-hack. It turns out dgl-hack implemented a custom function, add_edges_with_type. I also had to make some modifications to the input tensor sizes, but after that I got RGCN to run successfully on DGL.

The training log for DGL-RGCN on the aifb dataset is:

Train Accuracy: 0.2857 | Train Loss: 48903460.0000 | Validation Accuracy: 0.3214 | Validation loss: 56560608.0000
Train Accuracy: 0.2857 | Train Loss: 48815928.0000 | Validation Accuracy: 0.3214 | Validation loss: 56243768.0000
Train Accuracy: 0.2857 | Train Loss: 47918792.0000 | Validation Accuracy: 0.3214 | Validation loss: 54764828.0000
Epoch 00003 | Train Forward Time(s) 0.0026 | Backward Time(s) 0.0017
Train Accuracy: 0.2946 | Train Loss: 47904824.0000 | Validation Accuracy: 0.3214 | Validation loss: 54428792.0000
Epoch 00004 | Train Forward Time(s) 0.0026 | Backward Time(s) 0.0018
Train Accuracy: 0.2946 | Train Loss: 48096584.0000 | Validation Accuracy: 0.3214 | Validation loss: 54330104.0000
Epoch 00005 | Train Forward Time(s) 0.0029 | Backward Time(s) 0.0022
Train Accuracy: 0.3036 | Train Loss: 48391628.0000 | Validation Accuracy: 0.3214 | Validation loss: 54448644.0000
Epoch 00006 | Train Forward Time(s) 0.0029 | Backward Time(s) 0.0020
Train Accuracy: 0.3036 | Train Loss: 48688448.0000 | Validation Accuracy: 0.3214 | Validation loss: 54529480.0000
Epoch 00007 | Train Forward Time(s) 0.0028 | Backward Time(s) 0.0021
Train Accuracy: 0.3125 | Train Loss: 48979032.0000 | Validation Accuracy: 0.3214 | Validation loss: 54579388.0000
Epoch 00008 | Train Forward Time(s) 0.0027 | Backward Time(s) 0.0018
Train Accuracy: 0.3125 | Train Loss: 49252564.0000 | Validation Accuracy: 0.3214 | Validation loss: 54631652.0000
Epoch 00009 | Train Forward Time(s) 0.0028 | Backward Time(s) 0.0018
Train Accuracy: 0.3125 | Train Loss: 49504344.0000 | Validation Accuracy: 0.3214 | Validation loss: 54682868.0000
Epoch 00010 | Train Forward Time(s) 0.0030 | Backward Time(s) 0.0020
Train Accuracy: 0.3125 | Train Loss: 49736864.0000 | Validation Accuracy: 0.3214 | Validation loss: 54717436.0000

However, it seems the model is not training: the loss stays in the tens of millions and the validation accuracy is stuck at 0.3214 throughout.

JoelMathewC avatar Feb 22 '23 09:02 JoelMathewC

Fixed the DGL code on the aifb dataset. This dataset has no node features, so instead we index each node and assign it a learnable feature vector from a lookup table of features, as given by torch.nn.Embedding. The task is to predict the type of each node.
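The embedding trick described above can be sketched as follows. This is a minimal illustration, not the actual training script; the sizes are made up, and the variable names are mine.

```python
import torch
import torch.nn as nn

num_nodes, feat_dim = 100, 16  # illustrative sizes, not aifb's

# One learnable feature vector per node, looked up by node index.
# These vectors are trained jointly with the rest of the model.
embed = nn.Embedding(num_nodes, feat_dim)

node_ids = torch.arange(num_nodes)
features = embed(node_ids)  # shape: (num_nodes, feat_dim)
```

Since the embedding table is an nn.Module, its vectors receive gradients during backprop, which is why the loss now decreases where the featureless version did not.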

The updated training log is:

Epoch 00000 | Train Accuracy: 0.3393 | Train Loss: 1.3365 | Validation Accuracy: 0.3214 | Validation loss: 1.3165
Epoch 00001 | Train Accuracy: 0.3929 | Train Loss: 1.2702 | Validation Accuracy: 0.3929 | Validation loss: 1.2714
Epoch 00002 | Train Accuracy: 0.4375 | Train Loss: 1.2172 | Validation Accuracy: 0.4286 | Validation loss: 1.2256
Epoch 00003 | Train Accuracy: 0.6429 | Train Loss: 1.1597 | Validation Accuracy: 0.6786 | Validation loss: 1.1632
Epoch 00004 | Train Accuracy: 0.6964 | Train Loss: 1.1040 | Validation Accuracy: 0.7500 | Validation loss: 1.0950
Epoch 00005 | Train Accuracy: 0.7321 | Train Loss: 1.0607 | Validation Accuracy: 0.7500 | Validation loss: 1.0387
Epoch 00006 | Train Accuracy: 0.7768 | Train Loss: 1.0308 | Validation Accuracy: 0.7857 | Validation loss: 1.0008
Epoch 00007 | Train Accuracy: 0.8036 | Train Loss: 1.0075 | Validation Accuracy: 0.7500 | Validation loss: 0.9773
Epoch 00008 | Train Accuracy: 0.8125 | Train Loss: 0.9861 | Validation Accuracy: 0.8214 | Validation loss: 0.9633
Epoch 00009 | Train Accuracy: 0.8214 | Train Loss: 0.9651 | Validation Accuracy: 0.7857 | Validation loss: 0.9560
Epoch 00010 | Train Accuracy: 0.8571 | Train Loss: 0.9454 | Validation Accuracy: 0.7857 | Validation loss: 0.9532
Epoch 00011 | Train Accuracy: 0.8571 | Train Loss: 0.9280 | Validation Accuracy: 0.7500 | Validation loss: 0.9518
Epoch 00012 | Train Accuracy: 0.8661 | Train Loss: 0.9129 | Validation Accuracy: 0.7500 | Validation loss: 0.9480
Epoch 00013 | Train Accuracy: 0.8750 | Train Loss: 0.8998 | Validation Accuracy: 0.8214 | Validation loss: 0.9403
Epoch 00014 | Train Accuracy: 0.8750 | Train Loss: 0.8888 | Validation Accuracy: 0.8571 | Validation loss: 0.9298
Epoch 00015 | Train Accuracy: 0.8839 | Train Loss: 0.8798 | Validation Accuracy: 0.8929 | Validation loss: 0.9183
Epoch 00016 | Train Accuracy: 0.8839 | Train Loss: 0.8723 | Validation Accuracy: 0.8929 | Validation loss: 0.9078
Epoch 00017 | Train Accuracy: 0.8839 | Train Loss: 0.8660 | Validation Accuracy: 0.8929 | Validation loss: 0.8992
Epoch 00018 | Train Accuracy: 0.8839 | Train Loss: 0.8601 | Validation Accuracy: 0.8929 | Validation loss: 0.8923
Epoch 00019 | Train Accuracy: 0.8839 | Train Loss: 0.8544 | Validation Accuracy: 0.8929 | Validation loss: 0.8874
Epoch 00020 | Train Accuracy: 0.9018 | Train Loss: 0.8487 | Validation Accuracy: 0.8929 | Validation loss: 0.8839
Epoch 00021 | Train Accuracy: 0.9107 | Train Loss: 0.8434 | Validation Accuracy: 0.8929 | Validation loss: 0.8817
Epoch 00022 | Train Accuracy: 0.9196 | Train Loss: 0.8384 | Validation Accuracy: 0.8929 | Validation loss: 0.8804
Epoch 00023 | Train Accuracy: 0.9196 | Train Loss: 0.8336 | Validation Accuracy: 0.8929 | Validation loss: 0.8794
Epoch 00024 | Train Accuracy: 0.9196 | Train Loss: 0.8298 | Validation Accuracy: 0.8929 | Validation loss: 0.8785

JoelMathewC avatar Apr 17 '23 17:04 JoelMathewC

The modified code has been moved to the exp/rgcn/dgl folder on the new seastar/rgcn branch. Using a learning rate of 0.001, the training log is as follows:

Train Accuracy: 0.2143 | Train Loss: 1.4091 | Validation Accuracy: 0.2500 | Validation loss: 1.3989
Train Accuracy: 0.7857 | Train Loss: 0.9575 | Validation Accuracy: 0.8571 | Validation loss: 0.8758
Train Accuracy: 0.8214 | Train Loss: 0.9158 | Validation Accuracy: 0.9286 | Validation loss: 0.8145
Epoch 00003 | Train Forward Time(s) 0.0689 | Backward Time(s) 0.1214
Train Accuracy: 0.8661 | Train Loss: 0.8774 | Validation Accuracy: 0.9286 | Validation loss: 0.8150
Epoch 00004 | Train Forward Time(s) 0.0696 | Backward Time(s) 0.1222
Train Accuracy: 0.8661 | Train Loss: 0.8770 | Validation Accuracy: 0.9286 | Validation loss: 0.8157
Epoch 00005 | Train Forward Time(s) 0.0702 | Backward Time(s) 0.1221
Train Accuracy: 0.8661 | Train Loss: 0.8773 | Validation Accuracy: 0.9286 | Validation loss: 0.8161
Epoch 00006 | Train Forward Time(s) 0.0701 | Backward Time(s) 0.1212
Train Accuracy: 0.8661 | Train Loss: 0.8774 | Validation Accuracy: 0.9286 | Validation loss: 0.8159
Epoch 00007 | Train Forward Time(s) 0.0692 | Backward Time(s) 0.1210
Train Accuracy: 0.8661 | Train Loss: 0.8770 | Validation Accuracy: 0.9286 | Validation loss: 0.8155
Epoch 00008 | Train Forward Time(s) 0.0692 | Backward Time(s) 0.1227
Train Accuracy: 0.8750 | Train Loss: 0.8695 | Validation Accuracy: 0.9286 | Validation loss: 0.8143
Epoch 00009 | Train Forward Time(s) 0.0706 | Backward Time(s) 0.1212
Train Accuracy: 0.8750 | Train Loss: 0.8687 | Validation Accuracy: 0.9286 | Validation loss: 0.8117
Epoch 00010 | Train Forward Time(s) 0.0699 | Backward Time(s) 0.1214
Train Accuracy: 0.8750 | Train Loss: 0.8687 | Validation Accuracy: 0.9286 | Validation loss: 0.8063
Epoch 00011 | Train Forward Time(s) 0.0692 | Backward Time(s) 0.1210
Train Accuracy: 0.8750 | Train Loss: 0.8687 | Validation Accuracy: 0.9286 | Validation loss: 0.7981
Epoch 00012 | Train Forward Time(s) 0.0692 | Backward Time(s) 0.1227
Train Accuracy: 0.8750 | Train Loss: 0.8687 | Validation Accuracy: 0.9643 | Validation loss: 0.7901
Epoch 00013 | Train Forward Time(s) 0.0702 | Backward Time(s) 0.1224
Train Accuracy: 0.8750 | Train Loss: 0.8687 | Validation Accuracy: 0.9643 | Validation loss: 0.7849
Epoch 00014 | Train Forward Time(s) 0.0702 | Backward Time(s) 0.1209
Train Accuracy: 0.8750 | Train Loss: 0.8687 | Validation Accuracy: 0.9643 | Validation loss: 0.7822
Epoch 00015 | Train Forward Time(s) 0.0694 | Backward Time(s) 0.1217
Train Accuracy: 0.8750 | Train Loss: 0.8685 | Validation Accuracy: 0.9643 | Validation loss: 0.7808
Epoch 00016 | Train Forward Time(s) 0.0690 | Backward Time(s) 0.1226
Train Accuracy: 0.8839 | Train Loss: 0.8597 | Validation Accuracy: 0.9643 | Validation loss: 0.7800
Epoch 00017 | Train Forward Time(s) 0.0712 | Backward Time(s) 0.1218
Train Accuracy: 0.8839 | Train Loss: 0.8597 | Validation Accuracy: 0.9643 | Validation loss: 0.7797
Epoch 00018 | Train Forward Time(s) 0.0695 | Backward Time(s) 0.1217
Train Accuracy: 0.8839 | Train Loss: 0.8597 | Validation Accuracy: 0.9643 | Validation loss: 0.7796
Epoch 00019 | Train Forward Time(s) 0.0694 | Backward Time(s) 0.1222
Train Accuracy: 0.8839 | Train Loss: 0.8597 | Validation Accuracy: 0.9643 | Validation loss: 0.7795
Epoch 00020 | Train Forward Time(s) 0.0718 | Backward Time(s) 0.1224
Train Accuracy: 0.8839 | Train Loss: 0.8597 | Validation Accuracy: 0.9643 | Validation loss: 0.7795
Epoch 00021 | Train Forward Time(s) 0.0708 | Backward Time(s) 0.1217
Train Accuracy: 0.8839 | Train Loss: 0.8597 | Validation Accuracy: 0.9643 | Validation loss: 0.7795
Epoch 00022 | Train Forward Time(s) 0.0710 | Backward Time(s) 0.1216
Train Accuracy: 0.8839 | Train Loss: 0.8597 | Validation Accuracy: 0.9643 | Validation loss: 0.7795
Epoch 00023 | Train Forward Time(s) 0.0711 | Backward Time(s) 0.1220
Train Accuracy: 0.8839 | Train Loss: 0.8597 | Validation Accuracy: 0.9643 | Validation loss: 0.7795
Epoch 00024 | Train Forward Time(s) 0.0713 | Backward Time(s) 0.1229
Train Accuracy: 0.8839 | Train Loss: 0.8597 | Validation Accuracy: 0.9643 | Validation loss: 0.7795
max memory allocated 9701254656
Test Accuracy: 0.9444 | Test loss: 0.7994

Mean forward time: 0.070180
Mean backward time: 0.121936
^^^9.034997^^^0.192116

JoelMathewC avatar Apr 17 '23 20:04 JoelMathewC

We've isolated quite a few changes that need to be made in the codegen portion to support RGCN. We stopped when faced with the question of whether there is any benefit to moving this support into the compiler. Technically, it is possible to split a relational graph into homogeneous subgraphs, each of which can be processed by Seastar as-is.
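One way to picture the split: group the edges by relation type, giving one homogeneous edge list per relation. A toy sketch in plain PyTorch (the graph and names here are invented for illustration, not taken from the codebase):

```python
import torch

# Toy relational graph: edges as (src, dst) pairs plus a relation type per edge.
src   = torch.tensor([0, 1, 2, 3, 0])
dst   = torch.tensor([1, 2, 3, 0, 2])
etype = torch.tensor([0, 0, 1, 1, 1])

# One homogeneous subgraph (edge list) per relation type; each of these
# could then be fed to a standard vertex-centric GCN program.
subgraphs = {
    r: (src[etype == r], dst[etype == r])
    for r in etype.unique().tolist()
}
```

The open question from the comment above remains: whether doing this split inside the compiler beats doing it once in user code before calling Seastar.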

Note: To handle multiplying the input feature vectors with multiple relation-specific weight matrices, torch.bmm can be used.
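A minimal sketch of the torch.bmm approach, with illustrative sizes: stack one weight matrix per relation, broadcast the node features across relations, and do a single batched matmul.

```python
import torch

num_rels, num_nodes, in_dim, out_dim = 3, 5, 8, 4  # illustrative sizes

# One weight matrix per relation, stacked into a (R, in, out) tensor.
W = torch.randn(num_rels, in_dim, out_dim)
feats = torch.randn(num_nodes, in_dim)

# View the same features once per relation: (R, N, in).
feats_per_rel = feats.unsqueeze(0).expand(num_rels, -1, -1)

# Batched matmul: (R, N, in) @ (R, in, out) -> (R, N, out),
# i.e. every node's feature transformed by every relation's weights.
out = torch.bmm(feats_per_rel, W)
```

Each slice out[r] then holds the node features transformed by relation r's weight matrix, ready to be scattered along that relation's edges.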

JoelMathewC avatar Apr 24 '23 07:04 JoelMathewC