relationPrediction

Why are only the tail entities updated? Is there an error here?

Open • jweihe opened this issue 3 years ago • 0 comments

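```python
# Entity ids in column 2 of each batch triple (the tail position)
mask_indices = torch.unique(batch_inputs[:, 2]).cuda()
# Binary mask over all entities: 1.0 for entities that appear as a tail in this batch
mask = torch.zeros(self.entity_embeddings.shape[0]).cuda()
mask[mask_indices] = 1.0

# Every entity gets the linear projection of its input embedding...
entities_upgraded = self.entity_embeddings.mm(self.W_entities)
# ...but only masked (tail) entities also receive out_entity_1,
# which is computed earlier in forward() (not shown here)
out_entity_1 = entities_upgraded + \
    mask.unsqueeze(-1).expand_as(out_entity_1) * out_entity_1
```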
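To make the question concrete, here is a minimal CPU sketch of what the quoted lines compute, assuming each row of `batch_inputs` is a `(head, relation, tail)` triple so that column 2 holds tail ids. The function name `masked_update`, its `columns` parameter, and the toy tensors are hypothetical and only for illustration. The head-and-tail variant is not from the repository; it is included purely for comparison.

```python
import torch


def masked_update(entity_emb, W_entities, gat_out, batch_inputs, columns=(2,)):
    """CPU mirror of the quoted snippet (hypothetical helper).

    Every entity receives entity_emb @ W_entities; only entities whose id
    appears in the selected batch columns additionally receive gat_out.
    """
    num_entities = entity_emb.shape[0]
    # Unique entity ids from the chosen triple columns
    # (2 = tail as in the issue; 0 = head under the assumed triple layout).
    idx = torch.unique(batch_inputs[:, list(columns)])
    mask = torch.zeros(num_entities)
    mask[idx] = 1.0
    projected = entity_emb.mm(W_entities)
    # mask[:, None] broadcasts the same way as mask.unsqueeze(-1).expand_as(gat_out)
    return projected + mask.unsqueeze(-1).expand_as(gat_out) * gat_out, mask


# Toy setup: 4 entities, 4-dim embeddings, identity projection, all-ones "GAT output"
entity_emb = torch.eye(4)
W_entities = torch.eye(4)
gat_out = torch.ones(4, 4)
# Two (head, relation, tail) triples: heads {0, 2}, tails {1}
batch_inputs = torch.tensor([[0, 0, 1], [2, 1, 1]])

# Behaviour of the quoted code: only the tail entity 1 gets gat_out added
out_tail, mask_tail = masked_update(entity_emb, W_entities, gat_out, batch_inputs)
# Hypothetical comparison: mask both heads and tails
out_both, mask_both = masked_update(
    entity_emb, W_entities, gat_out, batch_inputs, columns=(0, 2)
)
print(mask_tail.tolist())
print(mask_both.tolist())
```

With the tail-only mask, entities 0 and 2 appear in the batch only as heads, so their rows keep just the linear projection and receive none of the attention output in this step. That is the behaviour the question is asking about.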

jweihe • Apr 10 '22 08:04