MAS-PyTorch
About classification_head
Hi, is the `classification_head` in utils/model_class correct? Its `forward` returns the input unchanged, so `self.fc` is never applied:
```python
import torch.nn as nn


class classification_head(nn.Module):
    """
    Each task has a separate classification head which houses the features that
    are specific to that particular task. These features are unshared across tasks
    as described in section 5.1 of the paper.
    """

    def __init__(self, in_features, out_features):
        super(classification_head, self).__init__()
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x):
        # Returns the input unchanged: self.fc is defined but never used.
        return x
```
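If the head is meant to map the shared features to per-task class scores, `forward` presumably should pass `x` through `self.fc`. The sketch below assumes that is the intended behavior; it is not a confirmed fix from the repository, and the feature and class sizes in the usage example are hypothetical.

```python
import torch
import torch.nn as nn


class classification_head(nn.Module):
    """Task-specific head: one linear layer, not shared across tasks.

    Assumed fix: forward() applies self.fc instead of returning x unchanged.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x):
        # Map shared features (batch, in_features) to logits (batch, out_features).
        return self.fc(x)


if __name__ == "__main__":
    # Hypothetical sizes: 4096-d shared features, 10 classes for this task.
    head = classification_head(4096, 10)
    feats = torch.randn(8, 4096)
    logits = head(feats)
    print(logits.shape)  # torch.Size([8, 10])
```

With the original `return x`, the output keeps `in_features` columns instead of `out_features`, and the weights of `self.fc` never receive gradients, so the head cannot learn anything task-specific.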