[BUG] FlopsProfiler upsample flops compute bug
Describe the bug
The upsample flops compute code:

```python
def _upsample_flops_compute(*args, **kwargs):
    scale_factor = kwargs.get('scale_factor', None)
    if scale_factor is None and len(args) > 2:
        scale_factor = args[2]
    assert scale_factor is not None, "either size or scale_factor should be defined"
    flops = input.numel()
    if isinstance(scale_factor, tuple) and len(scale_factor) == len(input):
        flops *= int(_prod(scale_factor))
    else:
        flops *= scale_factor**len(input)
    return flops, 0
```
Both the tuple-length check and the exponent should use the tensor's number of dimensions, not `len(input)`, which returns the size of the first dimension:

- `len(scale_factor) == len(input)` -> `len(scale_factor) == len(input.size())`
- `flops *= scale_factor**len(input)` -> `flops *= scale_factor**(len(input.size()) - 1)`
upsample flops compute in torchstat: https://github.com/Swall0w/torchstat/blob/master/torchstat/compute_flops.py#L83
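For reference, a minimal sketch of the helper with the two changes above applied. The `_prod` helper, the binding of `input` from `args[0]`, and the example call are assumptions for illustration, not the actual DeepSpeed implementation:

```python
import torch

def _prod(dims):
    # product of an iterable of dimension sizes (stand-in for the profiler's helper)
    p = 1
    for d in dims:
        p *= d
    return p

def _upsample_flops_compute(*args, **kwargs):
    input = args[0]  # assumed: the input tensor is the first positional argument
    scale_factor = kwargs.get('scale_factor', None)
    if scale_factor is None and len(args) > 2:
        scale_factor = args[2]
    assert scale_factor is not None, "either size or scale_factor should be defined"

    flops = input.numel()
    if isinstance(scale_factor, tuple) and len(scale_factor) == len(input.size()):
        # one scale factor per dimension: multiply by their product
        flops *= int(_prod(scale_factor))
    else:
        # single scale factor applied once per non-batch dimension
        flops *= scale_factor ** (len(input.size()) - 1)
    return flops, 0

# example: NCHW input upsampled with scale_factor=2
x = torch.randn(1, 3, 16, 16)
print(_upsample_flops_compute(x, None, 2))  # (1*3*16*16) * 2**3 = 6144 flops
```

With this change, a scalar `scale_factor` is raised to the number of non-batch dimensions (e.g. C, H, W for a 4D input) instead of to `len(input)`, which is the batch size.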