DiffSBDD

Why is the reconstruction loss aggregated over the whole batch during training instead of only over the relevant samples?

Open yemanbh opened this issue 10 months ago • 0 comments

At line 297 of lightning_modules.py, the reconstruction loss is aggregated over the whole batch: `info['loss_0'] = loss_0.mean(0)`. However, in conditional_model.py, lines 291 to 294 (shown below) compute the reconstruction loss only for samples at denoising step t=0, which makes sense since a training batch normally contains samples from different stages of the denoising process.

```python
loss_0_h = -log_ph_given_z0 * t_is_zero.squeeze()

# apply t_is_zero mask
error_t_lig = error_t_lig * t_is_not_zero.squeeze()
```

Thus, the aggregation should be taken only over the samples at t=0. Otherwise the loss component is diluted, since the mean is computed over samples that are not actually at t=0 and whose contribution is zero after masking.
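For illustration, here is a minimal sketch (not the repository's code; the tensor values and the masked-mean expression are hypothetical) of how averaging over the whole batch dilutes the loss relative to averaging only over the t=0 samples:

```python
import torch

# Hypothetical batch of 8 samples, of which only 2 were drawn at t=0.
t_is_zero = torch.tensor([1., 0., 0., 1., 0., 0., 0., 0.])
loss_0 = torch.tensor([2.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, 0.0])  # already masked

# Current behaviour: mean over the whole batch, including masked-out samples.
full_batch_mean = loss_0.mean(0)  # (2 + 4) / 8 = 0.75

# Suggested behaviour: mean over the t=0 samples only.
# clamp(min=1) guards against batches with no t=0 samples.
masked_mean = loss_0.sum() / t_is_zero.sum().clamp(min=1)  # (2 + 4) / 2 = 3.0

print(full_batch_mean, masked_mean)  # tensor(0.7500) tensor(3.)
```

With the full-batch mean, the effective weight of the reconstruction term shrinks as the fraction of t=0 samples in the batch shrinks, rather than staying constant.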

Is this an implementation error, or is there a reason behind the current implementation?

yemanbh · Mar 18 '25