
Commit "Update T5 attention mask handling in FLUX" results in an increase of tens of GB of video memory usage when using --apply_t5_attn_mask

Open · leonary opened this issue 1 year ago · 8 comments

Commit "Update T5 attention mask handling in FLUX" results in an increase of tens of GB of video memory usage when using --apply_t5_attn_mask.

leonary · Aug 24 '24

Applying the T5 mask seems to increase memory usage by about 1 GB, but tens of GB is far too much. Please let me know your PyTorch and CUDA versions. Upgrading PyTorch to 2.4.0 and CUDA to 12.4 may solve the problem.
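To report the exact versions, something like the following works (plain PyTorch attributes, nothing sd-scripts-specific):

```python
import torch

# PyTorch build, the CUDA version it was compiled against, and GPU visibility.
print(torch.__version__)          # e.g. "2.4.0+cu124"
print(torch.version.cuda)         # e.g. "12.4"
print(torch.cuda.is_available())
```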

kohya-ss · Aug 24 '24

> Applying the T5 mask seems to increase memory usage by about 1 GB, but tens of GB is far too much. Please let me know your PyTorch and CUDA versions. Upgrading PyTorch to 2.4.0 and CUDA to 12.4 may solve the problem.

After updating to torch 2.4.0+cu121, I encountered `ImportError: cannot import name 'log' from 'torch.distributed.elastic.agent.server.api' (/root/miniconda3/envs/3.10/lib/python3.10/site-packages/torch/distributed/elastic/agent/server/api.py)`. I am also curious why --apply_t5_attn_mask did not take up tens of GB of additional video memory before commit "Update T5 attention mask handling in FLUX" (hash: 7e459c00b2e142e40a9452341934c2eb9f70a172).

leonary · Aug 24 '24

Is this why, after I updated from a version a week or two old, a 1024x1024 training now OOMs instantly on a 4090? Nothing else changed.

DarkAlchy · Aug 25 '24

For some reason, the FlashAttention backend of PyTorch's scaled_dot_product_attention may be disabled when a mask is applied. This may depend on the CUDA version, GPU, mixed precision dtype, etc.

If you turn off apply_t5_attn_mask, you should get training results similar to before.
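A rough way to see the backend fallback in isolation (a minimal sketch, assuming torch >= 2.3 with a CUDA GPU; the shapes are arbitrary, not FLUX's actual ones):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(1, 8, 512, 64, device="cuda", dtype=torch.bfloat16)
mask = torch.ones(1, 1, 512, 512, device="cuda", dtype=torch.bool)

# Restrict SDPA to the FlashAttention backend only. Flash does not accept an
# arbitrary attn_mask, so instead of silently falling back to a slower backend
# (as SDPA would by default), this raises and makes the fallback visible.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    try:
        F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    except RuntimeError as e:
        print("FlashAttention backend rejected the mask:", e)
```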

kohya-ss · Aug 25 '24

That did not help; it got a little further, but OOMed as soon as step 1 was about to start, so I gave up.

DarkAlchy · Aug 25 '24

> For some reason, the FlashAttention backend of PyTorch's scaled_dot_product_attention may be disabled when a mask is applied. This may depend on the CUDA version, GPU, mixed precision dtype, etc.
>
> If you turn off apply_t5_attn_mask, you should get training results similar to before.

@kohya-ss I updated torch to 2.4.0 and deepspeed to the latest version. Using --apply_t5_attn_mask no longer seems to overflow video memory, but training time is doubled compared to not using the option. Again, this problem did not exist before this commit. Why is this?

leonary · Aug 30 '24

> @kohya-ss I updated torch to 2.4.0 and deepspeed to the latest version. Using --apply_t5_attn_mask no longer seems to overflow video memory, but training time is doubled compared to not using the option. Again, this problem did not exist before this commit. Why is this?

When --apply_t5_attn_mask is specified, we fixed the implementation to correctly mask attention from the zero padding in the embeddings. Attention masking seems to be quite slow in PyTorch's SDPA in some cases.
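A hypothetical sketch of the idea (illustrative only, not sd-scripts' actual code; the shapes and lengths are made up):

```python
import torch

batch, seq, dim = 2, 512, 4096
t5_out = torch.randn(batch, seq, dim)        # T5 encoder output (made-up shape)
lengths = torch.tensor([77, 300])            # real prompt lengths (made up)

# True where the token is real, False where it is zero padding.
keep = torch.arange(seq)[None, :] < lengths[:, None]   # (batch, seq)

# Zero out the padded embeddings, and build a key mask broadcastable to
# (batch, heads, q_len, kv_len) so SDPA cannot attend to the padding.
t5_out = t5_out * keep[:, :, None]
sdpa_mask = keep[:, None, None, :]
```

Passing such a mask to scaled_dot_product_attention is what can disable the FlashAttention backend mentioned above.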

kohya-ss · Aug 30 '24

> When --apply_t5_attn_mask is specified, we fixed the implementation to correctly mask attention from the zero padding in the embeddings. Attention masking seems to be quite slow in PyTorch's SDPA in some cases.
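For reference, the slowdown is reproducible in isolation; a rough micro-benchmark (a sketch with arbitrary FLUX-like shapes, assuming a CUDA GPU):

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(2, 24, 4096, 128, device="cuda", dtype=torch.bfloat16)
mask = torch.ones(2, 1, 4096, 4096, device="cuda", dtype=torch.bool)

def bench(fn, iters=20):
    # Warm up once, then time with CUDA events for proper GPU synchronization.
    fn(); torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record(); torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # ms per call

print("no mask  :", bench(lambda: F.scaled_dot_product_attention(q, k, v)), "ms")
print("with mask:", bench(lambda: F.scaled_dot_product_attention(q, k, v, attn_mask=mask)), "ms")
```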

Thank you for the explanation; it seems I will have to give up on this option for the time being.

leonary · Aug 30 '24