DiffSBDD
Why is the reconstruction loss aggregated over the whole batch during training instead of only over the relevant samples?
On line 297 of lightning_modules.py, the reconstruction loss is aggregated over the whole batch:

info['loss_0'] = loss_0.mean(0)
However, in conditional_model.py, lines 291 to 294 (shown below) compute the reconstruction loss only for samples at denoising step t = 0, which makes sense, since a training batch normally contains samples from different stages of the denoising process.
loss_0_h = -log_ph_given_z0 * t_is_zero.squeeze()
# apply t_is_zero mask
error_t_lig = error_t_lig * t_is_not_zero.squeeze()
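
For concreteness, here is a minimal, self-contained sketch of the arithmetic (batch size, timesteps, and loss values are made up for illustration; this is not the repository's code):

import torch

batch_size = 8
# hypothetical per-sample timesteps; 3 of the 8 samples are at t = 0
t = torch.tensor([[0], [3], [0], [7], [2], [0], [5], [1]])
t_is_zero = (t == 0).float()
loss_0 = torch.rand(batch_size)

# masking as in conditional_model.py: non-t=0 samples are zeroed out
loss_0 = loss_0 * t_is_zero.squeeze()

# aggregation as on line 297 of lightning_modules.py: divides by 8
batch_mean = loss_0.mean(0)

# mean over only the t = 0 samples: divides by 3
masked_mean = loss_0.sum() / t_is_zero.sum()

# batch_mean == masked_mean * (3 / 8), i.e. the term is scaled down
print(batch_mean, masked_mean)

In this example the batch-wide mean is 3/8 of the mean over the actual t = 0 samples, so the reconstruction term is effectively scaled down by the fraction of t = 0 samples in the batch.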
Thus, the mean should be taken only over the samples at denoising step t = 0. Otherwise, the loss component is systematically under-weighted, because the expectation also averages in the zeros contributed by the masked samples that are not actually at t = 0.
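
If this is unintended, one possible fix (a sketch only; it assumes the t_is_zero mask is made available at the aggregation point, which the current lightning_modules.py code does not do) would be to normalise by the number of t = 0 samples instead of the batch size:

# clamp avoids division by zero if a batch contains no t = 0 samples
num_t0 = t_is_zero.sum().clamp(min=1.0)
info['loss_0'] = loss_0.sum(0) / num_t0

Equivalently, one could index the t = 0 samples directly, e.g. loss_0[t_is_zero.squeeze().bool()].mean().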
Is this an implementation error, or is there a reason behind the current implementation?