k-diffusion
Add transform_last for BrownianTreeNoiseSampler
torch.Generator is not supported on Intel GPUs (XPU), so BrownianTreeNoiseSampler has to run on the CPU. The existing transform argument moves the sigma values to the CPU, and the new transform_last argument sends the resulting noise back to the GPU.
Moving the sampler to the GPU from outside the class does not work, because the sampler object has no to method. It fails with: AttributeError: 'BrownianTreeNoiseSampler' object has no attribute 'to'
Example use: return BrownianTreeNoiseSampler(x.to("cpu"), sigma_min, sigma_max, seed=current_iter_seeds, transform=lambda x: x.to("cpu"), transform_last=lambda x: x.to("xpu"))
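As a rough illustration of what transform_last does, here is a minimal sketch written as a standalone wrapper. It assumes that transform_last is simply applied to each noise sample the sampler returns. The class name TransformLastSampler is hypothetical and is not part of k-diffusion; this PR adds the parameter to BrownianTreeNoiseSampler itself. The sketch is plain Python, so it runs without torch.

```python
from typing import Any, Callable


class TransformLastSampler:
    """Hypothetical wrapper (not part of k-diffusion) mirroring a transform_last hook.

    The inner sampler runs wherever its inputs live (e.g. on the CPU), and
    transform_last is applied to every noise sample it returns (e.g. to move
    the tensor back to an XPU device).
    """

    def __init__(
        self,
        sampler: Callable[[Any, Any], Any],
        transform_last: Callable[[Any], Any] = lambda x: x,  # identity by default
    ) -> None:
        self.sampler = sampler
        self.transform_last = transform_last

    def __call__(self, sigma: Any, sigma_next: Any) -> Any:
        # Draw noise from the inner sampler, then post-process the result.
        noise = self.sampler(sigma, sigma_next)
        return self.transform_last(noise)
```

Because the identity is the default, samplers that do not need a device round trip behave as before. Only callers on devices without torch.Generator support have to pass something like lambda n: n.to("xpu").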