
Align SequentialModule.add_module behavior with torch.nn.Module.add_module

Open dmklee opened this issue 2 years ago • 0 comments

The assertion checks in SequentialModule.add_module assume that the module is always being appended to the end of the sequence. However, the method delegates to torch.nn.Module.add_module, which also allows overwriting an existing module by name; attempting to do so currently raises an assertion error. I modified the assertion checks so that a module can be overwritten as long as its in_type and out_type are consistent with the modules before and after it in the sequence.
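For context, the relaxed check can be sketched with toy stand-ins. The Module and Sequential classes below are illustrative only, not escnn's actual implementation; in escnn the types would be FieldType instances rather than strings:

```python
from collections import OrderedDict

class Module:
    """Toy stand-in for an escnn EquivariantModule: just records its types."""
    def __init__(self, in_type, out_type):
        self.in_type = in_type
        self.out_type = out_type

class Sequential:
    """Toy stand-in for SequentialModule, showing the relaxed add_module checks."""
    def __init__(self, *modules):
        self._modules = OrderedDict()
        for i, m in enumerate(modules):
            self._modules[str(i)] = m

    def add_module(self, name, module):
        names = list(self._modules)
        if name in names:
            # Overwrite by name: the replacement must fit between its neighbours.
            idx = names.index(name)
            if idx > 0:
                prev = self._modules[names[idx - 1]]
                assert module.in_type == prev.out_type
            if idx < len(names) - 1:
                nxt = self._modules[names[idx + 1]]
                assert module.out_type == nxt.in_type
        elif names:
            # Append: the original behaviour, chained to the last module's out_type.
            last = self._modules[names[-1]]
            assert module.in_type == last.out_type
        self._modules[name] = module
```

With this logic, `seq.add_module("0", Module("a", "b"))` on a sequence `a→b→c` succeeds (the replacement still chains into the next module), while a replacement with mismatched types still trips the assertion.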

I encountered this issue when calling torch.nn.SyncBatchNorm.convert_sync_batchnorm on a SequentialModule. With this change, InnerBatchNorm layers can be synchronized across machines during training, since they use torch.nn.BatchNormXd under the hood.
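The use case can be sketched as follows, with a plain torch.nn.Sequential standing in for an escnn SequentialModule (BatchNorm2d standing in for the torch.nn.BatchNormXd wrapped by InnerBatchNorm):

```python
import torch
import torch.nn as nn

# Stand-in for a SequentialModule containing an InnerBatchNorm layer.
model = nn.Sequential(
    nn.Conv2d(3, 8, 3),
    nn.BatchNorm2d(8),   # escnn's InnerBatchNorm wraps a layer like this
    nn.ReLU(),
)

# convert_sync_batchnorm walks the module tree and re-registers every child
# under its existing name via add_module (replacing BatchNormXd layers with
# SyncBatchNorm), which is why SequentialModule.add_module must tolerate
# overwriting modules by name.
sync_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
```

After conversion, `sync_model[1]` is a torch.nn.SyncBatchNorm; on the original escnn SequentialModule the same call fails at the assertion checks before this change.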

dmklee · Sep 11 '23 14:09