Getting-Things-Done-with-Pytorch
How to train a binary classification model
I am using the sentiment analysis with BERT example, but it does multiclass classification. How can I change it for binary text classification?
It is the same as multiclass classification, with a few modifications:
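- n_classes = 2 in the last layer, self.out = nn.Linear(self.bert.config.hidden_size, n_classes). The multiclass code already handles this automatically: model = SentimentClassifier(len(class_names)) gives n_classes = 2 when there are two class names. With two output logits you can keep F.softmax(model(input_ids, attention_mask), dim=1) and loss_fn = nn.CrossEntropyLoss().to(device) as they are, since a two-class softmax is already a binary classifier.
- If you prefer a sigmoid output instead, use a single output unit (n_classes = 1) and replace softmax with sigmoid here: F.softmax(model(input_ids, attention_mask), dim=1) becomes torch.sigmoid(model(input_ids, attention_mask)).
- In that case, change the loss function from loss_fn = nn.CrossEntropyLoss().to(device) to binary cross-entropy: nn.BCELoss() on the sigmoid probabilities, or nn.BCEWithLogitsLoss() on the raw logits (more numerically stable). Both expect float targets with the same shape as the output, and predictions come from thresholding the probability (e.g. at 0.5) rather than taking an argmax over classes.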
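A minimal sketch of the two binary setups described above, runnable without downloading BERT. It assumes a random tensor as a stand-in for BERT's pooled output, a hidden size of 768 (bert-base), and a plain nn.Linear in place of the full SentimentClassifier; the names head_a, head_b and HIDDEN_SIZE are hypothetical and only for illustration. It shows which output shape, target dtype and loss go together in each option.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Assumption: 768 is the hidden size of bert-base; in the tutorial this is
# self.bert.config.hidden_size.
HIDDEN_SIZE = 768

# Stand-in for BERT's pooled output for a batch of 4 texts (random, not real data).
pooled = torch.randn(4, HIDDEN_SIZE)
targets = torch.tensor([0, 1, 1, 0])  # 0 = negative, 1 = positive

# Option A: two logits, softmax + CrossEntropyLoss (the multiclass code as-is).
head_a = nn.Linear(HIDDEN_SIZE, 2)                  # hypothetical stand-in for self.out
logits_a = head_a(pooled)                           # shape (4, 2)
loss_a = nn.CrossEntropyLoss()(logits_a, targets)   # integer class indices as targets
probs_a = F.softmax(logits_a, dim=1)                # each row sums to 1
preds_a = torch.argmax(logits_a, dim=1)             # predicted class per text

# Option B: one logit, sigmoid + binary cross-entropy.
head_b = nn.Linear(HIDDEN_SIZE, 1)                  # hypothetical single-output head
logits_b = head_b(pooled).squeeze(-1)               # shape (4,)
loss_b = nn.BCEWithLogitsLoss()(logits_b, targets.float())  # raw logits, float targets
probs_b = torch.sigmoid(logits_b)                   # probability of the positive class
preds_b = (probs_b > 0.5).long()                    # threshold instead of argmax

# nn.BCELoss gives the same value but expects probabilities, not logits.
loss_b_alt = nn.BCELoss()(probs_b, targets.float())

print(f"CrossEntropyLoss (2 logits): {loss_a.item():.4f}")
print(f"BCEWithLogitsLoss (1 logit): {loss_b.item():.4f}")
print(f"BCELoss on sigmoid output:   {loss_b_alt.item():.4f}")
```

The two options describe the same kind of model: for two logits z0 and z1, softmax gives the positive-class probability 1 / (1 + exp(z0 - z1)), which is exactly sigmoid(z1 - z0). Mixing them, for example feeding two logits with integer targets into nn.BCELoss, fails because the shapes and dtypes do not match what that loss expects.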