CSI
How to define the joint_labels
I do not understand what the following line of code means:
joint_labels = torch.cat([labels + P.n_classes * i for i in range(4)], dim=0)
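To see what the line computes, it may help to run it on a tiny batch. The sketch below is only an illustration: `n_classes` is a hypothetical stand-in for `P.n_classes`, and the label values are made up. It shows that the labels are repeated four times, and copy `i` is shifted by `n_classes * i`, so each (copy index, class) pair gets its own integer id. In CSI the four copies presumably correspond to four shifted (for example rotated) versions of the same batch, but that reading is an assumption here, not something the line itself states.

```python
import torch

n_classes = 3                      # hypothetical stand-in for P.n_classes
labels = torch.tensor([0, 2, 1])   # made-up class labels for a batch of 3

# Copy i of the labels is offset by n_classes * i, then all 4 copies are
# concatenated along dim 0, giving a tensor of length 4 * batch_size.
joint_labels = torch.cat([labels + n_classes * i for i in range(4)], dim=0)
print(joint_labels.tolist())
# [0, 2, 1, 3, 5, 4, 6, 8, 7, 9, 11, 10]

# Both parts can be recovered from the joint id.
copy_index = joint_labels // n_classes   # which of the 4 copies
class_label = joint_labels % n_classes   # original class label
print(copy_index.tolist())   # [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
print(class_label.tolist())  # [0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1]
```

So with `n_classes` classes and 4 copies, `joint_labels` takes values in `0 .. 4 * n_classes - 1`, and a classifier trained on it would have to predict the class and the copy index together.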