`apply_to_collection` doesn't work for cached properties
Motivation
When running `apply_to_collection` on a dataclass, cached properties are not transformed. This can cause subtle bugs. For example, suppose I initialize a dataclass on CPU in a data worker and then move it onto the GPU for a model batch. All of the dataclass fields that contain Tensors are moved correctly, but any `cached_property` value that was already computed continues to reside on the original device.
Steps to reproduce
import dataclasses
from functools import cached_property
import torch
from lightning_utilities import apply_to_collection
from torch import Tensor


@dataclasses.dataclass
class Data:
    a: Tensor

    @cached_property
    def b(self):
        print("*" * 10)
        print("Computing and cache prop b")
        print("*" * 10)
        return self.a * 2

print("*" * 10)
print("Start with data on GPU")
print("*" * 10)
data = Data(a=torch.tensor([1, 2, 3], device="cuda"))
print(f"{data.a=}")
print(f"{data.a.device=}")
print(f"{data.b=}")
print(f"{data.b=}") # do this a second time to make sure we're caching it
print(f"{data.b.device=}")
print("*" * 10)
print("Move Data to CPU")
print("*" * 10)
new_data = apply_to_collection(data, Tensor, lambda x: x.to("cpu"))
print(f"{new_data.a=}")
print(f"{new_data.a.device=}")
print(f"{new_data.b=}")
print(f"{new_data.b=}") # do this a second time to make sure we're caching it
print(f"{new_data.b.device=}")
Yields the following output:
**********
Start with data on GPU
**********
data.a=tensor([1, 2, 3], device='cuda:0')
data.a.device=device(type='cuda', index=0)
**********
Computing and cache prop b
**********
data.b=tensor([2, 4, 6], device='cuda:0')
data.b=tensor([2, 4, 6], device='cuda:0')
data.b.device=device(type='cuda', index=0)
**********
Move Data to CPU
**********
new_data.a=tensor([1, 2, 3])
new_data.a.device=device(type='cpu')
new_data.b=tensor([2, 4, 6], device='cuda:0')
new_data.b=tensor([2, 4, 6], device='cuda:0')
new_data.b.device=device(type='cuda', index=0)
The Lightning `apply_to_collection` logic is defined here and relies on `dataclasses.fields`, which doesn't include cached properties.
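The root cause can be reproduced with the standard library alone. This is a minimal sketch, assuming the dataclass instance is copied wholesale (which would carry its `__dict__`, cache included) and only the entries returned by `dataclasses.fields` are then rewritten; that would explain why `new_data.b` above stays on `cuda:0`. `apply_to_fields` is a hypothetical stand-in, not Lightning's code, and plain lists stand in for Tensors. `functools.cached_property` stores its value in the instance `__dict__` under the property's name, and `b` is not an annotated field, so `dataclasses.fields` never sees it.

```python
import copy
import dataclasses
from functools import cached_property


@dataclasses.dataclass
class Data:
    a: list

    @cached_property
    def b(self):
        # Derived from `a`; computed once, then stored in self.__dict__["b"]
        return [x * 2 for x in self.a]


def apply_to_fields(obj, fn):
    """Hypothetical stand-in for a field-based transform (not Lightning's code).

    Copies the instance, then applies `fn` only to declared dataclass fields.
    """
    result = copy.deepcopy(obj)  # also copies any cached values in __dict__
    for field in dataclasses.fields(obj):
        setattr(result, field.name, fn(getattr(obj, field.name)))
    return result


data = Data(a=[1, 2, 3])
print([f.name for f in dataclasses.fields(data)])  # ['a']: `b` is not a field
print(data.b)  # [2, 4, 6], now cached in data.__dict__
print("b" in data.__dict__)  # True

new_data = apply_to_fields(data, lambda xs: [x + 100 for x in xs])
print(new_data.a)  # [101, 102, 103]
print(new_data.b)  # [2, 4, 6]: stale, copied from data's cache
```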
@awaelchli, do you have any experience with this one?
Hey @jackdent
This is a rare use case and I won't have the bandwidth to look into it. We would be grateful for a contribution here if you're interested. The fix is probably to just reset the cache when running apply_to_collection.
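A minimal sketch of the "reset the cache" idea, assuming it is acceptable to discard cached values and let them be recomputed lazily from the already-transformed fields. `clear_cached_properties` is a hypothetical helper name, and the copy-then-assign lines stand in for `apply_to_collection`, so no Lightning or torch import is needed. The helper looks up `functools.cached_property` attributes on the class and removes their entries from the instance `__dict__`; nested dataclasses would need a recursive version.

```python
import copy
import dataclasses
from functools import cached_property


def clear_cached_properties(obj):
    """Hypothetical helper: drop cached_property values stored on `obj`.

    Each cached value lives in obj.__dict__ under the property's name, so
    deleting that entry makes the next access recompute it from the
    (already transformed) fields. Only the top-level object is handled.
    """
    for klass in type(obj).__mro__:
        for name, attr in vars(klass).items():
            if isinstance(attr, cached_property):
                obj.__dict__.pop(name, None)
    return obj


@dataclasses.dataclass
class Data:
    a: list

    @cached_property
    def b(self):
        return [x * 2 for x in self.a]


data = Data(a=[1, 2, 3])
_ = data.b  # populate the cache on the original instance

# Stand-in for apply_to_collection: copy, then transform the declared field.
new_data = copy.deepcopy(data)
new_data.a = [x + 100 for x in data.a]
print(new_data.b)  # [2, 4, 6]: stale cached value carried over by the copy

clear_cached_properties(new_data)
print(new_data.b)  # [202, 204, 206]: recomputed from the new `a`
print(data.b)  # [2, 4, 6]: the original's cache is untouched
```

With Lightning, the untested equivalent would be to wrap the call, e.g. `clear_cached_properties(apply_to_collection(data, Tensor, lambda x: x.to("cpu")))`. Recomputing avoids having to guess whether a cached value contains Tensors, at the cost of redoing the computation once on the new device.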