GaNDLF

Add 2D segmentation models from `segmentation_models.pytorch`

Open · sarthakpati opened this issue 3 years ago · 2 comments

**Is your feature request related to a problem? Please describe.**
Currently, we do not have any ImageNet-pretrained segmentation models.

**Describe the solution you'd like**
For some applications, the ImageNet-pretrained models from `segmentation_models.pytorch` could be useful.

**Describe alternatives you've considered**
N.A.

**Additional context**
These models should check that the patch size is [224, 224], since all encoder backbones are defined only for that size.
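A minimal sketch of the patch-size check described above. The helper name `validate_imagenet_patch_size` is hypothetical, and accepting a trailing singleton axis (`[224, 224, 1]`) is an assumption about how a 2D patch might be written in a 3-element config list; drop that branch if it does not apply.

```python
from typing import Sequence

# Patch size required by the issue for the ImageNet-pretrained encoders.
REQUIRED_2D_PATCH = (224, 224)


def validate_imagenet_patch_size(patch_size: Sequence[int]) -> None:
    """Hypothetical helper: raise ValueError unless the patch is 224 x 224."""
    dims = list(patch_size)
    # Assumption: a 2D patch may be given as [224, 224, 1]; drop the singleton axis.
    if len(dims) == 3 and dims[2] == 1:
        dims = dims[:2]
    if tuple(dims) != REQUIRED_2D_PATCH:
        raise ValueError(
            "ImageNet-pretrained segmentation models need a patch size of "
            f"{list(REQUIRED_2D_PATCH)}, got {list(patch_size)}"
        )
```

Calling it once when the model is constructed would fail fast on a mismatched configuration instead of partway through training.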

sarthakpati · May 19 '22 18:05

Following the implementation of `ImageNet_VGG`, we could write a new submodule called `imagenet_unet` and register each encoder backbone as its own dictionary entry in `GANDLF/models/__init__.py`:

```python
global_models_dict = {
    "unet": unet,
    "unet_multilayer": unet_multilayer,
    # ...
    "imagenet_vgg11": imagenet_vgg11,
    # ...
    "imagenet_unet_resnet34": imagenet_unet_resnet34,
    "imagenet_unet_resnet18": imagenet_unet_resnet18,
    # ...
}
```
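If many encoders are added, the entries could be generated instead of hand-writing one wrapper per encoder. This is a sketch of that alternative using `functools.partial`; `imagenet_unet` here is a hypothetical stand-in constructor that only records the requested encoder, and the encoder list is illustrative.

```python
from functools import partial
from typing import Any, Callable, Dict


def imagenet_unet(parameters: Dict[str, Any], encoder_name: str) -> Dict[str, Any]:
    # Hypothetical stand-in for the real model constructor: it just records
    # which encoder was requested so the registry wiring can be checked.
    return {"encoder_name": encoder_name, "parameters": parameters}


# Encoder names follow segmentation_models.pytorch naming; extend as needed.
IMAGENET_UNET_ENCODERS = ["resnet18", "resnet34"]

# One "imagenet_unet_<encoder>" entry per encoder, each callable as fn(params).
global_models_dict: Dict[str, Callable[[Dict[str, Any]], Any]] = {
    "imagenet_unet_" + name: partial(imagenet_unet, encoder_name=name)
    for name in IMAGENET_UNET_ENCODERS
}
```

Each generated entry keeps the same one-argument `fn(params)` shape as the explicit wrapper functions, so the rest of the registry would not need to change.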

where `imagenet_unet_resnet34` and `imagenet_unet_resnet18` would be thin wrappers around a shared class:

```python
import segmentation_models_pytorch as smp

from .modelBase import ModelBase


class imagenet_unet_resnet(ModelBase):
    def __init__(self, parameters, encoder_name="resnet34") -> None:
        # ModelBase parses `parameters` and sets self.n_channels, self.n_classes, etc.
        super(imagenet_unet_resnet, self).__init__(parameters)
        self.model = smp.Unet(
            encoder_name=encoder_name,  # encoder, e.g. resnet34, mobilenet_v2 or efficientnet-b7
            encoder_weights="imagenet",  # initialize the encoder with ImageNet-pretrained weights
            in_channels=self.n_channels,  # input channels (1 for grayscale, 3 for RGB, etc.), from ModelBase
            classes=self.n_classes,  # output channels (number of classes), from ModelBase
        )

    def forward(self, x):
        return self.model(x)


def imagenet_unet_resnet34(params):
    return imagenet_unet_resnet(params, "resnet34")


def imagenet_unet_resnet18(params):
    return imagenet_unet_resnet(params, "resnet18")
```

sarthakpati · May 19 '22 18:05

Stale issue message

github-actions[bot] · Jul 18 '22 19:07