'timestep_sampling=flux_shift' can offer quality improvements, but barely mentioned in docs
So I just discovered this Flux (SD3 branch) parameter:
```
--timestep_sampling flux_shift
```
Previously I'd been using
```
--timestep_sampling shift
```
due to:
- The README.md having `--timestep_sampling shift` in the example training command.
- The extensive diagrams showing the effect of changing the `discrete_flow_shift` value and the effect of different `sigmoid` values - none of which apply to the `--timestep_sampling flux_shift` mode. The space given to these modes in the README.md kind of implies (in my opinion) that these are the modes to use.
- Almost zero casual references to `flux_shift` on this sd-scripts repository. Searching for 'flux_shift' in the GitHub search bar returns very few results.
But now that I've stumbled across this mode and tried it, my sample images are immediately more realistic.
I don't think enough people are aware of this mode and how good it can be, so I'm mostly filing this issue to raise its profile. @kohya-ss, should the README.md be updated to make flux_shift the default in the example training command for Flux?
@recris, we've been talking about how to avoid the vertical lines artifact, and I came across this thread where someone has a similar-looking horizontal line artifact, and claims he 'fixed' it by using flux_shift:
https://github.com/kohya-ss/sd-scripts/issues/1948
Worth a try if you have some line artifacts and you currently aren't using this timestep sampling mode?
I've been using flux_shift for a long time; it does not fully solve the vertical lines issue.
The only thing that seems to fully avoid the problem is to train with larger batch sizes and/or low learning rates.
I have a (working) hypothesis on the cause for those annoying artifacts.
Basically, those vertical (or horizontal) stripes are the result of "fried" latents: during training the network learned to generate extreme noise predictions (high variance), which cause the image latent to deviate from the expected value distribution, and when decoded this results in those white bands across the image. This learned behavior seems more likely to happen with high learning rates, although with low rates it still happens given enough training time.
Inspired by the research in https://github.com/kohya-ss/sd-scripts/discussions/294, I changed the loss function and added a regularization term that penalizes the network for predicting noise that deviates too much from the normal distribution's mean and variance.
In train_network.py:
```python
# keep this low - increasing too much causes image sharpness and contrast to explode
dist_loss_weight = 0.03

(...)

loss = train_util.conditional_loss(...)

dist_loss = None
if dist_loss_weight > 0.0:
    # penalise high noise timesteps more than low noise
    ts = timesteps / 1000.0
    w = ts * dist_loss_weight
    # noise mean and variance per channel
    n_var, n_mean = torch.var_mean(noise_pred.float(), dim=(2, 3), correction=0)
    n_logvar = torch.log(n_var)
    # KL divergence loss to standard gaussian
    kl_div_loss = -0.5 * torch.sum(1 + n_logvar - n_mean.pow(2) - n_logvar.exp(), dim=1)
    dist_loss = w * kl_div_loss

loss = loss.mean([1, 2, 3])
if dist_loss is not None:
    loss = loss + dist_loss
```
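For reference (this is just the standard closed-form result, written in my notation): with per-channel mean $\mu_c$ and variance $\sigma_c^2$ of `noise_pred`, the `kl_div_loss` term above is the KL divergence from a diagonal Gaussian to the standard normal,

$$
D_{\mathrm{KL}}\big(\mathcal{N}(\mu,\operatorname{diag}(\sigma^2)) \,\|\, \mathcal{N}(0, I)\big) = -\tfrac{1}{2}\sum_{c}\left(1 + \log\sigma_c^2 - \mu_c^2 - \sigma_c^2\right),
$$

which is zero only when every channel has mean 0 and variance 1.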
This approach seems to improve some of my test cases; I was able to use more "aggressive" Huber loss settings with fewer side effects.
There are a bunch of assumptions being made here:
- The ideal noise distribution is gaussian with mean=0 and variance=1 - I currently have no strong evidence for this
- `dist_loss_weight` has a linear schedule (the effective weight grows with the timestep) - I've tried a few others; at least it seems to work better than a constant schedule? (see the small sketch below)
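To spell out what "linear schedule" means here (illustration only, not extra code from the patch above):

```python
# timesteps: (B,) tensor in [0, 1000); higher value = more noise
ts = timesteps.float() / 1000.0
w_constant = torch.full_like(ts, dist_loss_weight)   # constant schedule
w_linear = ts * dist_loss_weight                     # schedule used above: penalise high-noise timesteps more
```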
That's interesting, @recris. Back when I was complaining about the vertical stripes, I found I had actually alpha-masked most of my images in a way that didn't preserve the RGB color channels for alpha=0 pixels in the resulting .png file. (It's a checkbox in GIMP's .png export, but it doesn't always work even when you tick it.) This created large invisible black regions in the training images, which would not produce good latents.
The black-region latents would still affect the denoising, even though the resulting loss would be masked. And even then, the mask effectively operates on 8x8-pixel squares (latent resolution), so it would still include some of the black regions of the training images even though they're in alpha=0 pixels.
Once I fixed that, my stripe issues became much milder, but they were still around. If the theory is that they're caused by training images whose latents don't have 0.0/1.0 mean/std values, then that would explain why I still get them, just more slowly.
This weekend I think I'll try to write the code that prints out which of my training images produce out-of-tolerance mean/std values in their latents. And I want to give your code snippet a try too. It might even improve training quality with the un-fixed training images.
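For anyone who wants to try the same check, here's the kind of thing I have in mind - a rough, untested sketch where the model path, preprocessing and tolerance thresholds are all my own assumptions, not anything from sd-scripts:

```python
# Encode each training image with the VAE and report per-image latent mean/std,
# so outliers (e.g. images with large hidden black regions) can be spotted.
from pathlib import Path

import torch
from diffusers import AutoencoderKL
from PIL import Image
from torchvision import transforms

device = "cuda" if torch.cuda.is_available() else "cpu"
# adjust to whatever AE/VAE checkpoint you actually train with
vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.float32
).to(device)
scaling = getattr(vae.config, "scaling_factor", 1.0)
shift = getattr(vae.config, "shift_factor", None) or 0.0

preprocess = transforms.Compose([
    transforms.Resize(1024),
    transforms.CenterCrop(1024),
    transforms.ToTensor(),                       # [0, 1]
    transforms.Normalize([0.5] * 3, [0.5] * 3),  # [-1, 1], what the VAE expects
])

for path in sorted(Path("train_images").glob("*.png")):
    img = Image.open(path).convert("RGB")  # drops alpha; hidden black RGB regions still count
    x = preprocess(img).unsqueeze(0).to(device)
    with torch.no_grad():
        latent = vae.encode(x).latent_dist.sample()
    latent = (latent - shift) * scaling    # same scaling the trainer applies
    mean, std = latent.mean().item(), latent.std().item()
    flag = "  <-- out of tolerance" if abs(mean) > 0.5 or not (0.5 < std < 1.5) else ""
    print(f"{path.name}: mean={mean:+.3f} std={std:.3f}{flag}")
```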
I gave that code a try on a training run that had already started showing vertical lines, to see if it could repair/recover it. No such luck - the lines stay around. And it also started making very 'grainy' images, like film grain, and kinda darker shadows on the images too. I don't think it's working out for me.
I put it after:
```python
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
    loss = apply_masked_loss(loss, batch)
```
and my training images are roughly 95% alpha-masked, apart from the (smallish) object I'm training it to learn. I don't know if that affects anything or not.
You have to be careful with loss masks: the way you inserted the change is not going to work, because dist_loss will be applied across the whole image while the regular loss will not.
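If you want to keep dist_loss consistent with a masked main loss, one rough option (just a sketch, I haven't verified this exact form) is to compute the per-channel statistics only over the masked region:

```python
# mask: alpha mask resized to latent resolution, shape (B, 1, H, W), values in [0, 1]
m = mask.float()
denom = m.sum(dim=(2, 3)).clamp_min(1.0)                           # (B, 1)
n_mean = (noise_pred.float() * m).sum(dim=(2, 3)) / denom          # (B, C)
centered = noise_pred.float() - n_mean[:, :, None, None]
n_var = (centered.pow(2) * m).sum(dim=(2, 3)) / denom              # (B, C)
# then reuse the same KL term on these masked statistics
```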
I've actually made other changes to the way the loss mask is applied: instead of doing `loss * mask_image`, I've changed it to apply the mask directly to the gradient during back-propagation via `register_hook`:
```python
noise_pred.register_hook(lambda grad: grad * mask)  # mask is rescaled to same shape as the latent tensor
```
This allows me to more easily play with the loss function and not have to worry if the mask is being incorrectly applied.
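Sketching it out a bit more fully (simplified; the exact resize call may differ from what I actually run):

```python
import torch.nn.functional as F

# alpha_mask: (B, 1, H_img, W_img), values in [0, 1]; noise_pred: (B, C, H_lat, W_lat)
mask = F.interpolate(alpha_mask, size=noise_pred.shape[2:], mode="area")
mask = mask.to(noise_pred.dtype)

# scale the gradient flowing back through noise_pred instead of masking the loss itself
noise_pred.register_hook(lambda grad: grad * mask)
```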
https://github.com/kohya-ss/sd-scripts/pull/1541 This is a PR I created that adds a function using the shift conditions mentioned as possible in the paper to adaptively adjust shift based on resolution...
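For anyone curious what "adaptively adjust shift based on resolution" means, the idea follows the Flux reference sampler; a rough sketch is below (the 256/4096-token and 0.5/1.15 constants are from that reference, and the PR may differ in details):

```python
import math

def resolution_dependent_shift(height: int, width: int) -> float:
    # Flux token count: the 8x-downscaled latent is patchified 2x2,
    # so image tokens = (H/16) * (W/16); 1024x1024 -> 4096 tokens.
    seq_len = (height // 16) * (width // 16)
    # Interpolate mu linearly between 0.5 at 256 tokens and 1.15 at 4096 tokens,
    # then use shift = exp(mu).
    slope = (1.15 - 0.5) / (4096 - 256)
    mu = 0.5 + slope * (seq_len - 256)
    return math.exp(mu)

def shift_timestep(t: float, shift: float) -> float:
    # Standard timestep/sigma shifting for t in (0, 1]; a larger shift pushes
    # training and sampling toward the high-noise end of the schedule.
    return shift * t / (1.0 + (shift - 1.0) * t)

# e.g. resolution_dependent_shift(1024, 1024) ~= exp(1.15) ~= 3.16
```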