How to improve interpretability for multi-label classification BERT models?
❓ Questions and Help
Our Use-Case
Thanks to the authors/contributors for such an amazing project! Our use-case is to use Captum to help visualize the word importance for our BERT-based multi-label classification model.
Our Problem
The issue we are experiencing is that we can only make this work well for single-label classification models. For multi-label, the results are poor: irrelevant words are highlighted as important, unlike in the single-label case.
Our Model
Our BERT model is fine-tuned on over a million records across 125 classes. Each record has at least one class.
Our Understanding
To level-set our understanding of how to use Captum for multi-label classification: we need to call attribute once per class of interest, passing the corresponding class index as the target argument.
Our Evaluation and Findings
To evaluate, we trained both a multi-label and single-label classification BERT model and used Captum to compare the visualization for the same class between both models.
But we found that, for the same class, the multi-label visualization is significantly worse than the single-label one. For example, the multi-label model highlights irrelevant words, even punctuation, while the single-label model pays more attention to the key words and produces much more intuitive highlighting.
Help Needed
Is it possible to achieve the same quality of visualization in the multi-label case as we see in the single-label case? Any tips or guidance would be much appreciated.
@mgh1, as you mentioned, with Captum you can call attribute for each target task. Do you have a Colab notebook that you can share so we can take a look? The visualization should still be meaningful for both single- and multi-task use cases.