Question on training Stage2 with Real-ESRGAN degradation
Hi, thanks for sharing the great work!
I tried to follow the training process, but ran into a problem when training the Stage 2 model.
After filling in the train_cldm.yaml file, I ran python train.py --config configs/train_cldm.yaml and got the error below:
Traceback (most recent call last):
File "/home/jaeha/Research/DiffBIR/train.py", line 32, in <module>
main()
File "/home/jaeha/Research/DiffBIR/train.py", line 28, in main
trainer.fit(model, datamodule=data_module)
File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 553, in fit
self._run(model)
File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 918, in _run
self._dispatch()
File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 986, in _dispatch
self.accelerator.start_training(self)
File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/accelerators/accelerator.py", line 92, in start_training
self.training_type_plugin.start_training(trainer)
File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 161, in start_training
self._results = trainer.run_stage()
File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 996, in run_stage
return self._run_train()
File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1045, in _run_train
self.fit_loop.run()
File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 111, in run
self.advance(*args, **kwargs)
File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py", line 200, in advance
epoch_output = self.epoch_loop.run(train_dataloader)
File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 111, in run
self.advance(*args, **kwargs)
File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 149, in advance
self.trainer.call_hook(
File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1217, in call_hook
trainer_hook(*args, **kwargs)
File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/trainer/callback_hook.py", line 189, in on_train_batch_end
callback.on_train_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx)
File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/utilities/distributed.py", line 48, in wrapped_fn
return fn(*args, **kwargs)
File "/home/jaeha/Research/DiffBIR/model/callbacks.py", line 55, in on_train_batch_end
images: Dict[str, torch.Tensor] = pl_module.log_images(batch, **self.log_images_kwargs)
File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/home/jaeha/Research/DiffBIR/model/cldm.py", line 379, in log_images
samples = self.sample_log(
File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/home/jaeha/Research/DiffBIR/model/cldm.py", line 394, in sample_log
samples = sampler.sample(
File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
TypeError: sample() got an unexpected keyword argument 'unconditional_guidance_scale'
I suspect the error comes from here: https://github.com/XPixelGroup/DiffBIR/blob/7bd5675823c157b9afdd479b59a2bf0a8954ce11/model/cldm.py#L394
where the sample function in SpacedSampler does not accept unconditional_guidance_scale as an argument.
Could you please let me know how to fix this?
A quick look at git log --patch model/spaced_sampler.py suggests that model/spaced_sampler.py was updated, but model/cldm.py was not updated to reflect the changes in the sampler API.
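One way to confirm this kind of caller/callee mismatch without reading the whole diff is to compare the keyword arguments a call passes against the signature of the function it targets. The sketch below uses the standard-library inspect module. OldSampler and NewSampler are hypothetical stand-ins, not the real DiffBIR classes: their parameter names are taken from the two sampler.sample calls quoted in this thread, and the defaults are assumptions.

```python
import inspect


class OldSampler:
    """Hypothetical stand-in for the sampler API that model/cldm.py was written against."""

    def sample(self, steps, shape, cond, unconditional_guidance_scale=1.0,
               unconditional_conditioning=None):
        return "old"


class NewSampler:
    """Hypothetical stand-in for the updated API (names from the call in this thread)."""

    def sample(self, steps, shape, cond_img, positive_prompt, negative_prompt,
               x_T=None, cfg_scale=1.0, cond_fn=None, color_fix_type="wavelet"):
        return "new"


def unsupported_kwargs(func, kwargs):
    """Return the keyword names in `kwargs` that `func` would reject."""
    params = inspect.signature(func).parameters
    # A **kwargs parameter accepts any keyword, so nothing is unsupported.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return []
    return [name for name in kwargs if name not in params]


# Keyword arguments used by the old-style call in sample_log.
old_style_call = {"unconditional_guidance_scale": 1.0, "unconditional_conditioning": None}

print(unsupported_kwargs(OldSampler().sample, old_style_call))
# []
print(unsupported_kwargs(NewSampler().sample, old_style_call))
# ['unconditional_guidance_scale', 'unconditional_conditioning']
```

Running the same helper against the real sampler's sample method in your checkout would list exactly the keywords that trigger the TypeError.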
Just bumped into this myself. You have two options:
- Reset the repo (or just model/spaced_sampler.py, but you might end up with something inconsistent) to a previous version in which the spaced sampler had the old API. From what I can see, d3e29f7 is the last commit where spaced_sampler had the old API.
- Look at the code, scratch your head, and rewrite the necessary code.
Let me know if you have already made some progress on this, since you asked a while ago. I'd be interested in something other than git reset --hard d3e29f7.
Thanks for sharing the awesome information!
In my solution, I replaced the original sampler.sample call with the sampler.sample call used in inference.py. (To be precise, I additionally modified the code here to add "c_lq" to the condition, and removed the decoding and normalizing steps.)
I'm not sure it is the correct solution, but it now seems to work as I expected.
Thank you for the solution. Could you please provide a more specific code snippet? I have now changed the function call to this:
samples = sampler.sample(
    steps=steps,
    shape=shape,
    cond_img=cond["c_concat"][0],
    positive_prompt="",
    negative_prompt="",
    cfg_scale=1.0
)
But another error occurred:
Traceback (most recent call last):
File "/home/jt/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1045, in _run_train
self.fit_loop.run()
File "/home/jt/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 111, in run
self.advance(*args, **kwargs)
File "/home/jt/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py", line 200, in advance
epoch_output = self.epoch_loop.run(train_dataloader)
File "/home/jt/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 111, in run
self.advance(*args, **kwargs)
File "/home/jt/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 149, in advance
self.trainer.call_hook(
File "/home/jt/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1217, in call_hook
trainer_hook(*args, **kwargs)
File "/home/jt/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/trainer/callback_hook.py", line 189, in on_train_batch_end
callback.on_train_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx)
File "/home/jt/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/utilities/distributed.py", line 48, in wrapped_fn
return fn(*args, **kwargs)
File "/data/jt/projects/DiffBIR/model/callbacks.py", line 55, in on_train_batch_end
images: Dict[str, torch.Tensor] = pl_module.log_images(batch, **self.log_images_kwargs)
File "/home/jt/anaconda3/envs/diffbir/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/data/jt/projects/DiffBIR/model/cldm.py", line 385, in log_images
x_samples = self.decode_first_stage(samples)
File "/home/jt/anaconda3/envs/diffbir/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/data/jt/projects/DiffBIR/ldm/models/diffusion/ddpm.py", line 832, in decode_first_stage
return self.first_stage_model.decode(z)
File "/data/jt/projects/DiffBIR/ldm/models/autoencoder.py", line 90, in decode
z = self.post_quant_conv(z)
File "/home/jt/anaconda3/envs/diffbir/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/home/jt/anaconda3/envs/diffbir/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 463, in forward
return self._conv_forward(input, self.weight, self.bias)
File "/home/jt/anaconda3/envs/diffbir/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 459, in _conv_forward
return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Given groups=1, weight of size [4, 4, 1, 1], expected input[8, 3, 512, 512] to have 4 channels, but got 3 channels instead
Does "remove the decoding and normalizing steps" mean deleting https://github.com/XPixelGroup/DiffBIR/blob/7bd5675823c157b9afdd479b59a2bf0a8954ce11/model/cldm.py#L384, where the error happened? Could you give me a more detailed solution? Thank you!
Below is my code snippet. But again, note that it is NOT the official solution.
from .spaced_sampler import SpacedSampler
...
class ControlLDM(LatentDiffusion):
    ...
    @torch.no_grad()
    def log_images(self, batch, sample_steps=50):
        log = dict()
        z, c = self.get_input(batch, self.first_stage_key)
        c_lq = c["lq"][0]
        c_latent = c["c_latent"][0]
        c_cat, c = c["c_concat"][0], c["c_crossattn"][0]

        log["hq"] = (self.decode_first_stage(z) + 1) / 2
        log["control"] = c_cat
        log["decoded_control"] = (self.decode_first_stage(c_latent) + 1) / 2
        log["lq"] = c_lq
        log["text"] = (log_txt_as_img((512, 512), batch[self.cond_stage_key], size=16) + 1) / 2

        samples = self.sample_log(
            # TODO: remove c_concat from cond
            # cond={"c_concat": [c_cat], "c_crossattn": [c], "c_latent": [c_latent]},
            cond={"c_lq": c_lq, "c_concat": [c_cat], "c_crossattn": [c], "c_latent": [c_latent]},
            steps=sample_steps
        )
        # x_samples = self.decode_first_stage(samples)
        # log["samples"] = (x_samples + 1) / 2
        log["samples"] = samples
        return log

    @torch.no_grad()
    def sample_log(self, cond, steps, cond_fn=None, color_fix_type="wavelet"):
        sampler = SpacedSampler(self)
        b, c, h, w = cond["c_concat"][0].shape
        shape = (b, self.channels, h // 8, w // 8)
        x_T = torch.randn(shape, device=self.model.device, dtype=torch.float32)
        # samples = sampler.sample(
        #     steps, shape, cond, unconditional_guidance_scale=1.0,
        #     unconditional_conditioning=None
        # )
        samples = sampler.sample(
            steps=steps, shape=shape, cond_img=cond["c_lq"],
            positive_prompt="", negative_prompt="", x_T=x_T,
            cfg_scale=1.0, cond_fn=cond_fn,
            color_fix_type=color_fix_type
        )
        return samples
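
As a rough check on why the decode_first_stage lines are commented out above: the RuntimeError earlier in the thread shows an input of shape [8, 3, 512, 512] reaching a 1x1 convolution with weight [4, 4, 1, 1]. That suggests (an assumption based on the error message, not confirmed by the maintainers) that the updated sampler.sample already returns decoded 3-channel images rather than 4-channel latents at 1/8 resolution. The shape arithmetic can be sketched in plain Python; latent_shape and needs_decoding are hypothetical helper names, with the 4 latent channels taken from the error message and the factor 8 from the h // 8 in sample_log.

```python
def latent_shape(image_shape, latent_channels=4, downscale=8):
    """Latent shape for an image batch shaped (batch, channels, height, width)."""
    b, _, h, w = image_shape
    return (b, latent_channels, h // downscale, w // downscale)


def needs_decoding(sample_shape, latent_channels=4):
    """True if the samples still look like latents, judged by channel count only."""
    return sample_shape[1] == latent_channels


image_shape = (8, 3, 512, 512)  # shape reported in the RuntimeError

print(latent_shape(image_shape))
# (8, 4, 64, 64)
print(needs_decoding(latent_shape(image_shape)))
# True
print(needs_decoding(image_shape))
# False
```

If the tensor returned by the sampler in your checkout has the second shape, passing it to decode_first_stage again would reproduce the channel mismatch, which is consistent with logging samples directly.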