GraphGPS
GraphGPS copied to clipboard
Where is the "batch.node_label_index" property set?
When I set cfg.dataset.task=node, cfg.model.type=gnn, and cfg.gnn.stage_type=stack, execution reaches self.post_mp = GNNHead(dim_in=d_in, dim_out=dim_out) in the GNN model, which uses:
class GNNNodeHead(nn.Module):
    """GNN head for node-level prediction.

    Runs the post-message-passing MLP over the batch, then keeps only the
    predictions (and matching labels) for the nodes listed in
    ``batch.node_label_index``.

    NOTE(review): ``node_label_index`` / ``node_label`` are not set in this
    class; they are presumably attached to the batch by the dataset/split
    loader upstream -- confirm against the data pipeline.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # Depth of the prediction MLP comes from the global config.
        self.layer_post_mp = MLP(
            dim_in,
            dim_out,
            num_layers=cfg.gnn.layers_post_mp,
            bias=True,
        )

    def _apply_index(self, batch):
        """Return ``(pred, label)`` restricted to the labeled nodes.

        Predictions are always gathered from ``batch.node_feature`` using
        ``batch.node_label_index``. Labels are gathered the same way only when
        their length differs from the index length; when the lengths match,
        ``batch.node_label`` is assumed to be already aligned with the index
        and is returned as-is.
        """
        index = batch.node_label_index
        labels = batch.node_label
        # Decide up front (before any indexing) whether labels need gathering.
        labels_pre_aligned = index.shape[0] == labels.shape[0]
        selected_pred = batch.node_feature[index]
        if not labels_pre_aligned:
            labels = labels[index]
        return selected_pred, labels

    def forward(self, batch):
        """Apply the MLP head and return ``(pred, label)`` for labeled nodes."""
        batch = self.layer_post_mp(batch)
        return self._apply_index(batch)
I want to know where the "batch.node_label_index" and "batch.node_label" properties are set.
Sorry to bother you, but I think I've found the answer.