A Small Issue with the MLP Model
In the MLP model, I think the last layer should use F.log_softmax instead of softmax. Otherwise, the NLL loss returns negative values.
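For concreteness, a minimal sketch of the suggested fix (the layer names and dimensions here are illustrative, not the repo's exact code from models.py):

```python
import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
    # Illustrative layer names/sizes; see models.py in the repo for the real ones.
    def __init__(self, dim_in, dim_hidden, dim_out):
        super().__init__()
        self.layer_input = nn.Linear(dim_in, dim_hidden)
        self.layer_hidden = nn.Linear(dim_hidden, dim_out)

    def forward(self, x):
        x = F.relu(self.layer_input(x))
        x = self.layer_hidden(x)
        # F.log_softmax instead of F.softmax, so NLLLoss receives log-probabilities
        return F.log_softmax(x, dim=1)
```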
I agree with you; I came across the same problem.
Can you please specify which line inside models.py this is on? I am new to this, thanks!
We agree with mahf93, tiangency. The code uses Softmax() as the output unit (line 16 in ref [1]) and NLLLoss() as the loss function (line 34 in ref [2]), which results in negative loss values.
According to ref [3], with the default settings (mean reduction, unit class weights), NLLLoss is computed as:
$$ \ell(x, y) = \frac{1}{N} \sum_{n=1}^N \ell_n = \frac{1}{N} \sum_{n=1}^N (-x_{n,y_n}), $$
where $x$ is the input, $y$ is the target, $N$ is the batch size, and $\ell_n$ is the per-example loss. Suppose $x_n = [0.5, 0.3, 0.2]$ (i.e., there are 3 classes) and $y_n = 1$; then the loss is $\ell_n = -x_n[y_n] = -x_n[1] = -0.3$.
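This worked example is easy to check in PyTorch (a self-contained snippet, not code from the repo):

```python
import torch
import torch.nn as nn

x = torch.tensor([[0.5, 0.3, 0.2]])  # a batch of N=1, 3 classes
y = torch.tensor([1])                # target class index
print(nn.NLLLoss()(x, y))            # tensor(-0.3000): NLLLoss just negates x[n, y_n]
```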
Since $\text{softmax}(x) \in (0,1)$ (i.e., $0 < \text{softmax}(x) < 1$), the loss value is always negative. However, if we replace nn.Softmax(dim=1) with nn.LogSoftmax(dim=1) (line 16 in ref [1]), then since $\log(\text{softmax}(x)) < 0$, the loss value will always be positive.
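The sign difference is easy to see on random logits (an illustrative snippet, assuming a plain NLLLoss with the default reduction):

```python
import torch
import torch.nn as nn

logits = torch.randn(8, 3)                # arbitrary raw scores
target = torch.randint(0, 3, (8,))
nll = nn.NLLLoss()

probs = nn.Softmax(dim=1)(logits)         # entries in (0, 1)
log_probs = nn.LogSoftmax(dim=1)(logits)  # entries < 0

print(nll(probs, target))      # negative: mean of -probs[n, y_n], which lies in (-1, 0)
print(nll(log_probs, target))  # positive: mean of -log_probs[n, y_n] > 0
```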
Note that the combination of LogSoftmax() and NLLLoss() is equivalent to CrossEntropyLoss() according to ref [4] (stated loosely; see ref [4] for the exact conditions).
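This equivalence can be verified numerically (again a standalone sketch, with the default reduction and no class weights):

```python
import torch
import torch.nn as nn

logits = torch.randn(8, 3)
target = torch.randint(0, 3, (8,))

ce  = nn.CrossEntropyLoss()(logits, target)
nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), target)
print(torch.allclose(ce, nll))  # True
```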
Finally, we ran `python federated_main.py --model=mlp --dataset=mnist --gpu=0 --iid=1 --epochs=10` to test the two settings: LogSoftmax reaches 93.64% test accuracy versus 90.64% for Softmax, which shows that the former is better. That said, we could not find any rule requiring the loss value to be positive.
Refs:
[1] https://github.com/AshwinRJ/Federated-Learning-PyTorch/blob/master/src/models.py
[2] https://github.com/AshwinRJ/Federated-Learning-PyTorch/blob/master/src/update.py
[3] https://pytorch.org/docs/stable/generated/torch.nn.NLLLoss.html#torch.nn.NLLLoss
[4] https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html