relationPrediction
Why are only the tail entities updated? Is there an error here?
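```python
# Entities that occur in the tail column (index 2) of the current batch
mask_indices = torch.unique(batch_inputs[:, 2]).cuda()
mask = torch.zeros(self.entity_embeddings.shape[0]).cuda()
mask[mask_indices] = 1.0  # 1.0 for tail entities, 0.0 for all others

entities_upgraded = self.entity_embeddings.mm(self.W_entities)
# out_entity_1 on the right must already hold the GAT-layer output from earlier
# in forward(); it is added only to the masked (tail) rows
out_entity_1 = entities_upgraded + \
    mask.unsqueeze(-1).expand_as(out_entity_1) * out_entity_1
```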
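To make the effect of the mask concrete, here is a small CPU-only sketch on toy data. It assumes triples are stored as `(head, relation, tail)` so that column 2 is the tail. `gat_out` is a random stand-in for the GAT output, and the `columns` parameter of the hypothetical `combine` helper is not part of the original code. It exists only so the tail-only mask can be compared with a variant that also includes heads (column 0).

```python
import torch

torch.manual_seed(0)

num_entities, dim = 5, 3
entity_embeddings = torch.randn(num_entities, dim)
W_entities = torch.randn(dim, dim)
# Hypothetical stand-in for the GAT-layer output (random, purely illustrative)
gat_out = torch.randn(num_entities, dim)

# Toy batch of (head, relation, tail) triples: heads 0, 2, 4; tails 1, 3
batch_inputs = torch.tensor([[0, 0, 1],
                             [2, 1, 3],
                             [4, 0, 1]])

def combine(entity_embeddings, W_entities, gat_out, batch_inputs, columns=(2,)):
    """Projected embeddings plus the GAT output, the latter kept only for masked rows.

    columns=(2,) mirrors the snippet above (tails only); any other value is a
    hypothetical variant for comparison.
    """
    mask_indices = torch.unique(batch_inputs[:, list(columns)])
    mask = torch.zeros(entity_embeddings.shape[0])
    mask[mask_indices] = 1.0
    entities_upgraded = entity_embeddings.mm(W_entities)
    out = entities_upgraded + mask.unsqueeze(-1).expand_as(gat_out) * gat_out
    return out, mask

out_tail, mask_tail = combine(entity_embeddings, W_entities, gat_out, batch_inputs)
out_both, mask_both = combine(entity_embeddings, W_entities, gat_out, batch_inputs,
                              columns=(0, 2))

print(mask_tail)  # tensor([0., 1., 0., 1., 0.])
print(mask_both)  # tensor([1., 1., 1., 1., 1.])
```

With the tail-only mask, entities 0, 2 and 4 (which appear here only as heads) end up with just the linear projection `entity_embeddings.mm(W_entities)`. Entities 1 and 3 also receive the GAT output. The `columns=(0, 2)` variant is only a point of comparison, not a claim about what the code should do.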