
Improve memory management in clustering_qr.kmeans_plusplus

RobertoDF opened this pull request 1 year ago • 13 comments

This modification avoids creating unnecessary tensors in clustering_qr.kmeans_plusplus, or deletes them immediately after use. It helps with the OOM errors (#746) raised at https://github.com/MouseLand/Kilosort/blob/b2f5ded41aaa7e13ed44bf43ecf488770ef54753/kilosort/clustering_qr.py#L202

Xg can sometimes be quite large (5 GB in the case where I get OOM). In both of the lines linked below, a copy of Xg was created unnecessarily on the GPU.

https://github.com/MouseLand/Kilosort/blob/b2f5ded41aaa7e13ed44bf43ecf488770ef54753/kilosort/clustering_qr.py#L166-L168 & https://github.com/MouseLand/Kilosort/blob/b2f5ded41aaa7e13ed44bf43ecf488770ef54753/kilosort/clustering_qr.py#L202

The fix for line 202 does not affect speed. The fix for line 167 might, but not noticeably in my tests; for this reason I didn't extend the clear_cache argument to kmeans_plusplus.

Tested on PyTorch 2.1.2 and 2.4.1.
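Why `vtot = (Xg**2).sum(1)` hurts: `Xg**2` is materialized as a new tensor with the same shape and dtype as `Xg` before the sum runs, so for a 5 GB `Xg` another 5 GB must fit on the GPU at the same moment. The closing comment in this thread describes this PR's first change as computing vtot on the CPU instead. A minimal sketch of that idea, assuming a floating-point `Xg`; `nbytes` and `vtot_on_cpu` are hypothetical helper names, not code from the PR:

```python
import torch

def nbytes(t: torch.Tensor) -> int:
    # Size of the tensor's elements in bytes. An elementwise op such as t**2
    # allocates a new tensor of exactly this size.
    return t.numel() * t.element_size()

def vtot_on_cpu(Xg: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: move Xg to host memory (a no-op if it is already
    # there), square and reduce on the CPU, and send only the (N,) result back
    # to Xg's device. The full-size Xg**2 temporary then lives in host RAM.
    return (Xg.cpu() ** 2).sum(1).to(Xg.device)

# Small float32 example; runs on CPU as well as GPU.
Xg = torch.randn(1000, 64)
print(nbytes(Xg))       # 1000 * 64 * 4 bytes = 256000
vtot = vtot_on_cpu(Xg)
print(vtot.shape)       # torch.Size([1000])
```

The trade-off is a device-to-host copy on every call and enough host RAM for both `Xg` and its square.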

RobertoDF, Sep 04 '24

@RobertoDF Are you able to share the data that you're seeing this problem with so that I can test this myself?

jacobpennington, Sep 04 '24

Sure, compressing the files now.

RobertoDF, Sep 04 '24

The zip contains a Jupyter notebook that shows the problem, along with the specific Xd tensor that causes the crash on my machine. I included both the standard and the modified versions of kmeans_plusplus. The old one should crash; if you run the new one afterwards, it should run without errors. https://we.tl/t-40kiuNy3Cd

RobertoDF, Sep 05 '24

Just noticed that the notebook doesn't include the change to the line vtot = (Xg**2).sum(1).

RobertoDF, Sep 05 '24

@RobertoDF Those are not the files I would need. I mean the full recording, either a .bin file or whatever format you converted from, along with the probe file you used.

jacobpennington, Sep 05 '24

This last commit seems to really solve the OOM problems.

RobertoDF, Sep 07 '24

Hello, I tried to use your last commit, but I'm still getting a CUDA OOM error in the final clustering phase. How much dedicated GPU memory do you have? I have 8 GB, and Kilosort used on average 6-7 GB throughout sorting until crashing at the end.

Peyton-D, Sep 26 '24

I have 12 GB. Without the modification I often got OOM errors inside kmeans_plusplus. Which line fails for you exactly, and what does the error message say? Also, how long is your recording?
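When comparing setups such as 8 GB and 12 GB cards, it can help to check how much memory PyTorch actually sees. A small sketch using `torch.cuda.mem_get_info`, which reports free and total memory for the device as seen by the driver, so memory used by other processes (or by PyTorch's own cache) already counts as used; the helper name `gpu_memory_gib` is hypothetical:

```python
import torch

def gpu_memory_gib():
    # Returns (free, total) memory of the current CUDA device in GiB,
    # or None when PyTorch cannot see a CUDA device.
    if not torch.cuda.is_available():
        return None
    free, total = torch.cuda.mem_get_info()
    return free / 2**30, total / 2**30

info = gpu_memory_gib()
if info is None:
    print("No CUDA device visible to PyTorch")
else:
    print(f"free: {info[0]:.1f} GiB, total: {info[1]:.1f} GiB")
```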

RobertoDF, Sep 27 '24

Thanks for the quick response. Yes, kmeans_plusplus inside clustering_qr seems to cause the crash every time. My recording duration is 90 min. Here's the failing line and the Kilosort log, in case it helps:

File "C:\miniconda3\envs\kilosort\lib\site-packages\kilosort\clustering_qr.py", line 215, in kmeans_plusplus
    mu[j] = Xg[ix].mean(0)

kilosort4_9_26_1700.log
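The failing line `mu[j] = Xg[ix].mean(0)` uses advanced indexing, which first copies every selected row of `Xg` into a new tensor and only then averages it. If `ix` is a boolean mask (an assumption; the thread does not show how `ix` is built), the same mean can be computed as a weighted matrix-vector product that never materializes the selected rows. This is a hypothetical alternative, not what Kilosort does, and its speed has not been measured here; `masked_row_mean` is an illustrative name:

```python
import torch

def masked_row_mean(Xg: torch.Tensor, ix: torch.Tensor) -> torch.Tensor:
    # Mean of the rows of Xg selected by the boolean mask ix, computed as
    # (w @ Xg) / count with w a 0/1 weight vector. Only an (N,) weight vector
    # and a (D,) result are allocated, instead of a (count, D) copy of rows.
    w = ix.to(Xg.dtype)
    return (w @ Xg) / w.sum()

torch.manual_seed(0)
Xg = torch.randn(500, 16)
ix = Xg[:, 0] > 0              # example boolean mask
mu_j = masked_row_mean(Xg, ix)  # same values as Xg[ix].mean(0), up to rounding
```

Like `Xg[ix].mean(0)`, this returns NaNs when the mask selects no rows.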

Peyton-D, Sep 27 '24

Hmm, I never had a crash at that line. If you use the normal version rather than my fork, does it also crash at the same line?

RobertoDF, Sep 27 '24

Just ran another attempt with normal version. Here's the problem line:

File "C:\Users\ColginLab\miniconda3\envs\kilosort\lib\site-packages\kilosort\clustering_qr.py", line 167, in kmeans_plusplus
    vtot = (Xg**2).sum(1)

kilosort4_normal_version.log

Peyton-D, Sep 27 '24

OK, that line was problematic for me too, and I would expect my fix to solve it. But I never had a problem at the line you showed me before. Maybe it can be optimized further, but I won't have time to look into this in the near future. If you have access to a 12 GB GPU, I would expect that to solve the problem. If you are on Windows, you can set a debugger breakpoint at that line and inspect GPU memory in Task Manager.

RobertoDF, Sep 27 '24

Alright, I'll look into getting more GPU memory. Thanks for the help!

Peyton-D, Sep 27 '24

@RobertoDF Are you able to provide a bit more explanation for the changes you proposed? I can see from other issues that they're helping with some memory problems, but I'm having a hard time finding any information in the PyTorch docs that would explain why these changes prevent copies or otherwise reduce memory usage.

jacobpennington, Nov 19 '24

Sure! I stepped through the code with a debugger breakpoint while watching GPU memory consumption, and substituted lines (checking that the output stayed identical) until I found a combination that avoided the unnecessary creation of large arrays on the GPU without sacrificing any speed (at least in my tests). Loads of trial and error!
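The same trial-and-error loop can be scripted instead of watching Task Manager: reset PyTorch's peak-memory counter, run one variant of a line, and read the peak back. A minimal sketch; `peak_cuda_mib` is a hypothetical helper. Note that `torch.cuda.max_memory_allocated` counts memory held by live tensors (including ones allocated before the call, such as `X`), not the caching allocator's reserved pool or the CUDA context, so it will read lower than Task Manager:

```python
import torch

def peak_cuda_mib(fn):
    # Runs fn() and returns (result, peak MiB of CUDA memory held by PyTorch
    # tensors during the call, including tensors that already existed).
    # Without CUDA the peak is reported as None.
    if not torch.cuda.is_available():
        return fn(), None
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    result = fn()
    torch.cuda.synchronize()
    return result, torch.cuda.max_memory_allocated() / 2**20

dev = "cuda" if torch.cuda.is_available() else "cpu"
X = torch.randn(100_000, 128, device=dev)
vtot, peak = peak_cuda_mib(lambda: (X ** 2).sum(1))
x_mib = X.numel() * X.element_size() / 2**20
print(f"X: {x_mib:.1f} MiB, peak during call: {peak} MiB")
```

On a GPU the peak for `(X ** 2).sum(1)` should come out at roughly twice the size of `X`, since `X ** 2` exists alongside `X` before the reduction. Checking each variant's result with `torch.allclose` against the original mirrors the "output stays identical" check.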

RobertoDF, Nov 19 '24

@Peyton-D did you find a solution for your problem?

I'm having an identical issue: it fails at line 215 in kmeans_plusplus, mu[j] = Xg[ix].mean(0)

I have a large recording (several hours, ~300 GB), but I have successfully sorted many others of similar or larger size with Kilosort. I'm only getting this issue with recordings from the same probe, so I suspect it is something probe-related, though the channels look perfectly normal when I plot them, not especially noisy or anything like that.

kilosort4.log

EmmettJT, Jan 28 '25

@EmmettJT Yes, my issue was fixed by updating my NVIDIA driver to version 561.09 and CUDA to version 12.6 specifically. I did not end up using this commit. My recordings are ~120 GB each, and I was able to sort them using a GeForce GTX 1080.

Peyton-D, Jan 28 '25

Codecov Report

Attention: Patch coverage is 0% with 10 lines in your changes missing coverage. Please review.

Project coverage is 0.00%. Comparing base (d8ba42f) to head (f73c7a5). Report is 533 commits behind head on main.

Files with missing lines      Patch %   Lines
kilosort/clustering_qr.py     0.00%     10 Missing
Additional details and impacted files
@@          Coverage Diff           @@
##            main    #775    +/-   ##
======================================
  Coverage   0.00%   0.00%            
======================================
  Files         32      33     +1     
  Lines       4649    5589   +940     
======================================
- Misses      4649    5589   +940     


codecov-commenter, Jan 30 '25

I'm closing this pull request because it's superseded by the changes made in v4.0.26.

A short explanation of the two changes this PR made and how the new changes address those:

  1. Computing vtot on CPU instead of GPU. The previous way of computing vtot, (Xg**2).sum(1), required temporarily allocating a full copy of the tensor, and moving that computation to the CPU took some of the extra load off the GPU. The new changes compute this with torch.norm instead, which is a bit faster and avoids the duplicate allocation altogether.

  2. Breaking up the vexp assignment into multiple lines of code. This improved memory usage in a roundabout way: assigning vexp to a low-memory tensor as a first step releases the old memory before it is allocated again on the next iteration. The new changes accomplish the same thing by adding del(vexp) and del(dexp) at the end of the loop.
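Both patterns can be shown in a simplified, hypothetical greedy k-means++ seeding loop. This is NOT Kilosort's `kmeans_plusplus`; the function name, the `ntry` candidate scheme and the shapes are assumptions chosen for illustration. Squared row norms come from `torch.linalg.vector_norm` (the current PyTorch API covering what `torch.norm` does here), which is expected to reduce in a single kernel rather than allocate an `(N, D)` tensor of squares. The `del` at the end of the loop matters because Python only rebinds `dexp` after the right-hand side of the next assignment has been fully computed, so without it the old and new `(N, ntry)` buffers would be alive at the same time; after `del`, PyTorch's caching allocator can reuse the freed block.

```python
import torch

def greedy_seed(Xg: torch.Tensor, n_clusters: int, ntry: int = 10, seed: int = 0) -> torch.Tensor:
    # Hypothetical greedy k-means++ seeding, written only to illustrate the two
    # memory patterns above. It is NOT Kilosort's kmeans_plusplus.
    gen = torch.Generator(device=Xg.device).manual_seed(seed)
    N = Xg.shape[0]
    # Change 1: squared row norms from a norm reduction instead of
    # (Xg**2).sum(1), which first allocates an (N, D) tensor of squares.
    vtot = torch.linalg.vector_norm(Xg, dim=1) ** 2
    # Squared distance from each point to its nearest chosen center.
    dmin = torch.full((N,), float("inf"), device=Xg.device)
    centers = []
    for j in range(n_clusters):
        # First center is drawn uniformly, later ones proportionally to dmin.
        weights = torch.ones_like(dmin) if j == 0 else dmin
        cand = torch.multinomial(weights, ntry, replacement=True, generator=gen)
        # (N, ntry) squared distances via ||x||^2 - 2 x.c + ||c||^2.
        dexp = (vtot[:, None] - 2 * (Xg @ Xg[cand].T) + vtot[cand][None, :]).clamp_(min=0)
        # Total cost if each candidate were added as a center; keep the best.
        cost = torch.minimum(dmin[:, None], dexp).sum(0)
        best = int(cost.argmin())
        dmin = torch.minimum(dmin, dexp[:, best])
        centers.append(int(cand[best]))
        # Change 2: drop the large buffers now, so the old dexp is not still
        # alive while the next iteration allocates its replacement.
        del dexp, cost
    return torch.tensor(centers)

torch.manual_seed(0)
X = torch.randn(2000, 32)
print(greedy_seed(X, n_clusters=4, ntry=16))  # row indices of 4 seed points
```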

The new changes also remove an unused Xg[...] call that could result in a full (or almost full) copy of Xg being temporarily allocated on the first few iterations. Additionally, the size of Xg is reduced (sometimes by quite a lot) by updating code that assumed channel maps were consecutive.
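The "unused Xg[...] call" point comes down to PyTorch indexing semantics, assuming the removed call used advanced indexing (the exact expression is elided above): basic slicing returns a view that shares storage with the original tensor, while indexing with an index tensor or a boolean mask allocates a new tensor and copies the selected rows. A mask that selects every row therefore costs a full copy. A small CPU check of the difference, using an arbitrary example tensor:

```python
import torch

X = torch.arange(12.0).reshape(4, 3)

view = X[1:3]                    # basic slicing: shares X's storage
copy = X[torch.tensor([1, 2])]   # advanced indexing: new storage
mask_copy = X[X[:, 0] >= 0]      # mask selecting every row: a full copy of X

print(view.data_ptr() == X[1].data_ptr())  # True, same memory
print(copy.data_ptr() == X[1].data_ptr())  # False, separate memory
print(mask_copy.shape)                     # torch.Size([4, 3])

# Writes through the view reach X; writes to the copy do not.
view[0, 0] = -1.0
copy[1, 0] = -2.0
print(X[1, 0].item(), X[2, 0].item())      # -1.0 6.0
```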

jacobpennington, Feb 27 '25