Getting-Things-Done-with-Pytorch

How to train binary classification

Open ps3-app opened this issue 5 years ago • 2 comments

I'm using the sentiment analysis example with BERT, but it does multiclass classification. How do I change it for binary text classification?

ps3-app · Jan 20 '21 16:01

It's the same as multiclass classification, with a few modifications:

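  1. Set n_classes = 2 in the last layer, self.out = nn.Linear(self.bert.config.hidden_size, n_classes). The existing multiclass code already handles this, because the number of outputs comes from model = SentimentClassifier(len(class_names)).
  2. Replace softmax with sigmoid here: F.softmax(model(input_ids, attention_mask), dim=1)
  3. Change the loss function from loss_fn = nn.CrossEntropyLoss().to(device) to binary cross-entropy, i.e. nn.BCELoss(). Note that nn.BCELoss expects probabilities (the sigmoid output, not raw logits) and float targets with the same shape as the model output, so either use a single output unit (n_classes = 1) or one-hot encode the two targets. nn.BCEWithLogitsLoss() combines the sigmoid and the loss in one, more numerically stable, step.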
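A minimal sketch of steps 2 and 3, assuming the single-output variant (one logit per text) rather than two outputs. BinaryHead is a hypothetical stand-in for the head of SentimentClassifier; random tensors replace BERT's pooled output so the snippet runs without downloading weights, and the hidden size of 768 (bert-base) is an assumption. It trains with nn.BCEWithLogitsLoss, checks that it matches sigmoid + nn.BCELoss, and uses sigmoid with a 0.5 threshold at inference.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Assumption: 768 is the hidden size of bert-base; in the notebook this value
# would come from self.bert.config.hidden_size.
HIDDEN_SIZE = 768


class BinaryHead(nn.Module):
    """Hypothetical stand-in for the classification head of SentimentClassifier,
    using one output unit for binary classification."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.drop = nn.Dropout(p=0.3)
        self.out = nn.Linear(hidden_size, 1)  # a single logit instead of n_classes

    def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
        # Return raw logits of shape (batch,); the sigmoid is applied later
        return self.out(self.drop(pooled_output)).squeeze(-1)


head = BinaryHead(HIDDEN_SIZE)

# Random tensors stand in for BERT's pooled output for a batch of 4 texts
pooled_output = torch.randn(4, HIDDEN_SIZE)
targets = torch.tensor([0, 1, 1, 0])  # integer labels, as a data loader would yield

# --- Training step (step 3): binary cross-entropy on float targets ---
head.train()
logits = head(pooled_output)
loss_fn = nn.BCEWithLogitsLoss()  # sigmoid + BCE in one numerically stable op
loss = loss_fn(logits, targets.float())

# Equivalent formulation with an explicit sigmoid and nn.BCELoss
loss_bce = nn.BCELoss()(torch.sigmoid(logits), targets.float())

loss.backward()

# --- Inference (step 2): sigmoid instead of softmax, threshold at 0.5 ---
head.eval()
with torch.no_grad():
    probs = torch.sigmoid(head(pooled_output))  # probability of the positive class
    preds = (probs > 0.5).long()

print(f"loss={loss.item():.4f}, probs={probs.tolist()}, preds={preds.tolist()}")
```

Keeping two outputs instead also works: with n_classes = 2, softmax and nn.CrossEntropyLoss with integer targets handle the binary case directly. The single-logit form is the usual pairing when switching to sigmoid, since it avoids one-hot encoding the targets.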

kforcodeai · Jan 23 '21 12:01