Add 2D segmentation models from `segmentation_models.pytorch`
Is your feature request related to a problem? Please describe.
Currently, we do not have any ImageNet-pretrained segmentation models.
Describe the solution you'd like
For some applications, the ImageNet-pretrained models from `segmentation_models.pytorch` could be useful.
Describe alternatives you've considered
N.A.
Additional context
These models should check that the patch size is [224, 224], since the ImageNet weights for all encoder backends were pretrained at that input size.
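A minimal sketch of the proposed patch-size check follows. `check_imagenet_patch_size` and `IMAGENET_PATCH_SIZE` are hypothetical names, and the helper assumes a 2D patch size may arrive either as `[224, 224]` or with a trailing singleton depth such as `[224, 224, 1]`.

```python
from typing import Sequence

# Hypothetical constant: the only patch size this issue proposes to allow.
IMAGENET_PATCH_SIZE = (224, 224)


def check_imagenet_patch_size(patch_size: Sequence[int]) -> None:
    """Raise ValueError unless patch_size is 224 x 224.

    Assumption: a 2D patch may carry a trailing singleton depth, e.g. [224, 224, 1].
    """
    dims = list(patch_size)
    # Drop a trailing singleton depth so [224, 224, 1] is treated as 2D.
    if len(dims) == 3 and dims[2] == 1:
        dims = dims[:2]
    if tuple(dims) != IMAGENET_PATCH_SIZE:
        raise ValueError(
            f"ImageNet-pretrained UNet models expect patch_size "
            f"{list(IMAGENET_PATCH_SIZE)}, got {list(patch_size)}"
        )


check_imagenet_patch_size([224, 224, 1])  # passes silently
```

Failing fast at model construction gives a clearer error than a shape mismatch deep inside the encoder.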
Following the implementation of ImageNet_VGG, we could write a new submodule called `imagenet_unet` and define specific encoder backends as dictionary entries in `GANDLF/models/__init__.py`:
```python
# GANDLF/models/__init__.py
from .imagenet_unet import imagenet_unet_resnet18, imagenet_unet_resnet34

global_models_dict = {
    "unet": unet,
    "unet_multilayer": unet_multilayer,
    # ...
    "imagenet_vgg11": imagenet_vgg11,
    # ...
    "imagenet_unet_resnet34": imagenet_unet_resnet34,
    "imagenet_unet_resnet18": imagenet_unet_resnet18,
    # ...
}
```
Here, `imagenet_unet_resnet34` and `imagenet_unet_resnet18` would be thin wrappers around a single class built on `smp.Unet`:
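```python
import segmentation_models_pytorch as smp

from .modelBase import ModelBase  # GaNDLF base class that parses `parameters`


class imagenet_unet_resnet(ModelBase):
    def __init__(self, parameters, encoder_name="resnet34") -> None:
        # Initialize ModelBase first: it sets self.n_channels and self.n_classes from `parameters`.
        super(imagenet_unet_resnet, self).__init__(parameters)
        # Store the network as an attribute so PyTorch registers its parameters.
        self.model = smp.Unet(
            encoder_name=encoder_name,  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
            encoder_weights="imagenet",  # use `imagenet` pre-trained weights for encoder initialization
            in_channels=self.n_channels,  # model input channels (1 for gray-scale images, 3 for RGB, etc.), defined from ModelBase
            classes=self.n_classes,  # model output channels (number of classes in your dataset), defined from ModelBase
        )

    def forward(self, x):
        # Delegate the forward pass to the wrapped smp.Unet.
        return self.model(x)


def imagenet_unet_resnet34(params):
    return imagenet_unet_resnet(params, "resnet34")


def imagenet_unet_resnet18(params):
    return imagenet_unet_resnet(params, "resnet18")
```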
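The per-encoder factory functions differ only in the encoder name, so one option is to generate them and register them in the name-keyed dictionary in a loop. The sketch below is self-contained and does not need PyTorch: `imagenet_unet_resnet` here is a stub that only records its arguments, `make_encoder_entries` and `build_model` are hypothetical helpers, and reading the architecture name from `parameters["model"]["architecture"]` is an assumption about the config layout.

```python
from functools import partial
from typing import Any, Callable, Dict, List


class imagenet_unet_resnet:
    """Stub standing in for the ModelBase subclass; it only records its arguments."""

    def __init__(self, parameters: Dict[str, Any], encoder_name: str = "resnet34") -> None:
        self.parameters = parameters
        self.encoder_name = encoder_name


def make_encoder_entries(encoder_names: List[str]) -> Dict[str, Callable[[Dict[str, Any]], Any]]:
    """Map 'imagenet_unet_<encoder>' to a factory with encoder_name pre-bound."""
    # partial binds each name eagerly, avoiding the late-binding pitfall of a lambda in a loop.
    return {
        f"imagenet_unet_{name}": partial(imagenet_unet_resnet, encoder_name=name)
        for name in encoder_names
    }


global_models_dict: Dict[str, Callable[[Dict[str, Any]], Any]] = {}
global_models_dict.update(make_encoder_entries(["resnet18", "resnet34"]))


def build_model(parameters: Dict[str, Any]) -> Any:
    """Hypothetical helper: look up the configured architecture and call its factory."""
    name = parameters["model"]["architecture"]
    if name not in global_models_dict:
        raise KeyError(f"unknown architecture {name!r}; choose from {sorted(global_models_dict)}")
    return global_models_dict[name](parameters)


model = build_model({"model": {"architecture": "imagenet_unet_resnet18"}})
print(model.encoder_name)  # resnet18
```

Adding another backend then only means appending its name to the list, rather than writing a new wrapper function by hand.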