GraphGPS

Where is the "batch.node_label_index" property set?

Open • jasperzelu opened this issue 2 years ago • 1 comment

When I set `cfg.dataset.task=node`, `cfg.model.type=gnn`, and `cfg.gnn.stage_type=stack`, execution reaches `self.post_mp = GNNHead(dim_in=d_in, dim_out=dim_out)` in the `gnn` module, which uses the following node head:

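```python
import torch.nn as nn

# Module paths from the standalone GraphGym repository.
from graphgym.config import cfg
from graphgym.models.layer import MLP


class GNNNodeHead(nn.Module):
    '''Head of GNN, node prediction'''
    def __init__(self, dim_in, dim_out):
        super(GNNNodeHead, self).__init__()
        # Post-message-passing MLP mapping node embeddings to dim_out.
        self.layer_post_mp = MLP(dim_in,
                                 dim_out,
                                 num_layers=cfg.gnn.layers_post_mp,
                                 bias=True)

    def _apply_index(self, batch):
        # Keep only the nodes listed in node_label_index. If node_label has
        # the same length as that index, use it as-is; otherwise index it too.
        if batch.node_label_index.shape[0] == batch.node_label.shape[0]:
            return batch.node_feature[batch.node_label_index], batch.node_label
        else:
            return batch.node_feature[batch.node_label_index], \
                   batch.node_label[batch.node_label_index]

    def forward(self, batch):
        # Apply the MLP to the node features, then select predictions/labels.
        batch = self.layer_post_mp(batch)
        pred, label = self._apply_index(batch)
        return pred, label
```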
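To see what `_apply_index` does with the two attributes, here is a small standalone copy of its logic run on toy tensors. The `SimpleNamespace` "batch" and the helper name `apply_index` are assumptions for illustration only; a real GraphGym batch is a different object that merely carries attributes with these names.

```python
from types import SimpleNamespace

import torch


def apply_index(batch):
    """Standalone copy of GNNNodeHead._apply_index, for illustration only."""
    if batch.node_label_index.shape[0] == batch.node_label.shape[0]:
        # node_label is already aligned with node_label_index: keep it as-is.
        return batch.node_feature[batch.node_label_index], batch.node_label
    # node_label covers every node: select the same subset as the predictions.
    return (batch.node_feature[batch.node_label_index],
            batch.node_label[batch.node_label_index])


# Toy graph: 5 nodes with 2 output dims (as after the post-MP MLP).
node_feature = torch.arange(10, dtype=torch.float).view(5, 2)
index = torch.tensor([0, 3])

# Case 1: node_label holds one label per node (length 5).
full = SimpleNamespace(node_feature=node_feature,
                       node_label_index=index,
                       node_label=torch.tensor([1, 0, 1, 1, 0]))
pred, label = apply_index(full)
print(pred.tolist(), label.tolist())  # [[0.0, 1.0], [6.0, 7.0]] [1, 1]

# Case 2: node_label already holds only the labels of the indexed nodes.
sliced = SimpleNamespace(node_feature=node_feature,
                         node_label_index=index,
                         node_label=torch.tensor([1, 1]))
pred, label = apply_index(sliced)
print(pred.tolist(), label.tolist())  # [[0.0, 1.0], [6.0, 7.0]] [1, 1]
```

Both conventions yield the same predictions and labels. Note that the length comparison is a heuristic: if `node_label_index` covers every node, both conventions have the same length and the first branch is taken, which is only correct when `node_label` is ordered like `node_label_index`.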

I want to know where the `batch.node_label_index` and `batch.node_label` properties are set.
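The thread does not say what the answer turned out to be. The standalone GraphGym loads data through DeepSNAP (`deepsnap.dataset.GraphDataset`, `deepsnap.batch.Batch`), so these attributes are most likely populated by DeepSNAP's dataset split logic rather than by GraphGym itself; treat that as an assumption to verify against the DeepSNAP source. The sketch below is a hypothetical, simplified transductive node split (the function `hypothetical_node_split` and its parameters are invented for illustration, not DeepSNAP or GraphGym code) that produces attributes with the same names, which shows why the head compares the lengths of `node_label_index` and `node_label`.

```python
import torch


def hypothetical_node_split(num_nodes, labels, train_ratio=0.8, seed=0):
    """Hypothetical transductive split; not GraphGym or DeepSNAP code.

    Returns one attribute dict per split. Every split would keep all node
    features (message passing sees the whole graph), but node_label_index
    restricts which nodes are scored, and node_label is sliced to match it.
    """
    gen = torch.Generator().manual_seed(seed)
    perm = torch.randperm(num_nodes, generator=gen)
    n_train = int(train_ratio * num_nodes)
    splits = {}
    for name, idx in (("train", perm[:n_train]), ("val", perm[n_train:])):
        splits[name] = {
            "node_label_index": idx,    # nodes whose predictions count
            "node_label": labels[idx],  # labels aligned with that index
        }
    return splits


labels = torch.tensor([0, 1, 1, 0, 1, 0, 1, 1, 0, 0])
splits = hypothetical_node_split(10, labels)
for name, attrs in splits.items():
    print(name, attrs["node_label_index"].tolist(), attrs["node_label"].tolist())
```

Under this convention `node_label` has the same length as `node_label_index`, so the head would take its first branch and use the labels as-is.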

jasperzelu • Dec 15 '23 02:12

Sorry to bother you, but I think I've found the answer.

jasperzelu • Dec 15 '23 02:12