Compatibility with Torch > 2.0.1
Hello @shariqfarooq123
We're working with the latest nightly version of torch, which seems to have added type assertions to interpolate (this commit).
This causes a runtime error because the size we pass in is of type numpy.int32. A simple cast to int should fix this without any side effects.
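A minimal sketch of the int-cast fix, assuming the affected calls pass size either as a single number or as a sequence of numpy integers. as_int_size is a hypothetical helper name, not part of ZoeDepth or PyTorch:

```python
from typing import Any, Optional, Tuple, Union


def as_int_size(size: Any) -> Optional[Union[int, Tuple[int, ...]]]:
    """Cast an interpolate-style size argument to plain Python ints (hypothetical helper)."""
    if size is None:
        # Leave None untouched so scale_factor-based calls keep working
        return None
    if hasattr(size, "__iter__"):
        # Sequence form, e.g. (np.int32(384), np.int32(512)) -> (384, 512)
        return tuple(int(s) for s in size)
    # Scalar form, e.g. np.int32(384) -> 384
    return int(size)
```

At each affected call site it would be used as F.interpolate(x, size=as_int_size(size), ...). Because int() accepts numpy.int32 directly, the helper does not need to import numpy itself.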
@thias15 could you please merge this? It causes issues with newer torch versions :)
@thias15
Would be helpful if someone could merge this one :) @thias15 @shariqfarooq123
Edit: a different interpolate call also requires the int cast.
Monkey-patching F.interpolate may work as a workaround:
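```python
import torch
import torch.nn.functional as F

# Back up the original interpolate function
original_interpolate = F.interpolate

def patched_interpolate(input, size=None, scale_factor=None, mode="nearest",
                        align_corners=None, recompute_scale_factor=None, antialias=False):
    # Cast numpy integer sizes (e.g. numpy.int32) to plain Python ints
    if size is not None:
        size = tuple(int(s) for s in size) if hasattr(size, "__iter__") else int(size)
    return original_interpolate(input, size=size, scale_factor=scale_factor, mode=mode,
                                align_corners=align_corners,
                                recompute_scale_factor=recompute_scale_factor,
                                antialias=antialias)

model = torch.hub.load("isl-org/ZoeDepth", "ZoeD_K", pretrained=True).cuda()
F.interpolate = patched_interpolate
try:
    depths_zoe = model.infer(imgs)  # imgs: your batch of input images
finally:
    F.interpolate = original_interpolate  # restore even if inference fails
```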