API compatibility of DynUNet and MedNeXt in deep supervision mode
Is your feature request related to a problem? Please describe. If deep supervision is set to true, the DynUNet model returns a single tensor whose second dimension contains the upscaled intermediate outputs of the model, whereas the new MedNeXt implementation returns a tuple of tensors holding the intermediate outputs at their original, lower resolutions.
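For illustration, a rough sketch of the two conventions as I understand them (the shapes and number of heads below are made-up assumptions for a 3D, 2-channel setup, not the documented API):

```python
import torch

B, C, H, W, D = 2, 2, 96, 96, 96  # hypothetical batch/channel/spatial sizes

# DynUNet-style: one tensor with the heads stacked along dim 1, all upsampled
# to the full output resolution.
dynunet_out = torch.zeros(B, 3, C, H, W, D)  # [batch, heads, channels, spatial...]

# MedNeXt/SegResNetDS-style: a tuple of tensors, each head at its own
# (progressively lower) resolution.
mednext_out = (
    torch.zeros(B, C, H, W, D),                 # final prediction, full resolution
    torch.zeros(B, C, H // 2, W // 2, D // 2),  # intermediate head, half resolution
    torch.zeros(B, C, H // 4, W // 4, D // 4),  # intermediate head, quarter resolution
)
```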
Describe the solution you'd like For compatibility purposes and ease of use, I would suggest using the DynUNet behavior in MedNeXt as well. I think having the intermediate outputs all at the same output resolution makes them easier to use for loss function calculations. Plus, the DynUNet has been stable for some time, whereas the MedNeXt has not.
Describe alternatives you've considered Obviously, the DynUNet behavior could be changed to act identically to the MedNeXt, which might be slightly more memory efficient, but that would certainly break more existing code, and it is less convenient if references need to be downscaled for loss calculations.
The tuple output also exists for SegResNetDS, where deep supervision is available. The MedNeXt implementation uses it because it was developed to be consistent with SegResNetDS.
Only the tuple option seems to make sense to me here. While upscaling is an option, I don't think it should be the default. What if the developer needs to operate on the intermediate maps in their original dims? And as you pointed out, it is the more memory-efficient option too.
I do agree that there should be consistency - given that, should DynUNet be ported to SegResNetDS and MedNeXt behaviour?
Hm, I didn't know about the SegResNetDS. I only ran into this issue as I was swapping a DynUNet for the MedNeXt (btw, thanks for the implementation :) ).
I agree the tuple version makes more sense, and the user should decide how to handle the different scales. However, I found the DynUNet version quite handy to use in loss calls and didn't want to break too much existing code. I think there is at least one tutorial using the DynUNet somewhere around. But I will look into adapting the DynUNet tomorrow.
The authors of the DynUNet cite "a restriction of TorchScript" in their args description to justify the tensor version with an additional dim. Do you know what they are referring to? And out of interest, what does your code for the loss calculation look like when using the tuple version? Do you upscale the output or downscale the reference?
I use the loss here, which does upsample as you mentioned: https://github.com/Project-MONAI/MONAI/blob/46a5272196a6c2590ca2589029eed8e4d56ff008/monai/losses/ds_loss.py#L23-L85
While the end result is upsampling of the network intermediates, I do think this should be performed within the loss calculation module as opposed to the network itself.
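A minimal sketch of how that wrapper can be used with a tuple-style output (assuming the DeepSupervisionLoss class from the linked ds_loss.py; the base loss and its arguments are just my own choices, not anything prescribed):

```python
import torch
from monai.losses import DiceCELoss
from monai.losses.ds_loss import DeepSupervisionLoss

# Wrap a standard segmentation loss; the wrapper handles the resolution
# mismatch between the deep-supervision heads and the full-resolution target.
base_loss = DiceCELoss(to_onehot_y=True, softmax=True)
ds_loss = DeepSupervisionLoss(base_loss, weight_mode="exp")

def compute_loss(outputs, label):
    # outputs: tuple/list of tensors at decreasing resolutions (MedNeXt/SegResNetDS style)
    # label:   full-resolution ground truth, e.g. [B, 1, H, W, D]
    return ds_loss(list(outputs), label)
```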
For the TorchScript question, I am not very familiar with the DynUNet structure. Could you point me to where this is mentioned?
It's mentioned here:
https://github.com/Project-MONAI/MONAI/blob/8ac8e0d52dced4667fcc73812e012078a35e3359/monai/networks/nets/dynunet.py#L114
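If I read that docstring right, the issue is that TorchScript wants a fixed return type, so rather than returning a Python list of differently sized tensors, DynUNet interpolates the heads to the output resolution and stacks them along a new dim 1. On the loss side I then just unpack that stacked tensor (shapes below are assumptions from my own config, not the documented API):

```python
import torch

# Assumed DynUNet deep-supervision output: [B, 1 + num_ds_heads, C, H, W, D],
# where index 0 along dim 1 is the actual prediction and the rest are the
# already-upsampled deep supervision heads.
out = torch.zeros(2, 3, 2, 96, 96, 96)

pred, *ds_heads = torch.unbind(out, dim=1)  # each tensor: [B, C, H, W, D]
```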