Memory usage for finding nearest channels is much higher in Kilosort4 vs. Kilosort3

naterenegar opened this issue 1 year ago • 4 comments

Describe the issue:

Hello,

I'm working with data from a 2D MEA. It seems that Kilosort3 works on the data but Kilosort4 does not, because of differences in how the two versions upsample electrode locations. In the spike extraction step, there are two calls to a function that finds nearby channels. The first call finds, for each upsampled location, the nearest original channels, and the second call finds, for each upsampled location, the nearest upsampled locations.

In Kilosort3, the upsampled locations seem to be distance-gated before the second call. Is this so that only upsampled locations near original channels are considered? The snippet from extract_spikes.m:

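```matlab
NchanNear = 8;
% nearest real channels to each upsampled location (one column per location)
[iC, dist] = getClosestChannels2(ycup, xcup, rez.yc, rez.xc, NchanNear);

% keep only upsampled locations whose closest real channel is within dNearActiveSite
igood = dist(1,:)<dNearActiveSite;
iC = iC(:, igood);
dist = dist(:, igood);

ycup = ycup(igood);
xcup = xcup(igood);

% nearest-neighbor search among the surviving upsampled locations only
NchanNearUp = min(numel(ycup), 10*NchanNear);
[iC2, dist2] = getClosestChannels2(ycup, xcup, ycup, xcup, NchanNearUp);
```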
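For comparison with the Kilosort4 code below, the same gating idea can be sketched in NumPy. This is a minimal illustration, not code from either Kilosort version: `gate_upsampled_sites` is a hypothetical helper, and it assumes Euclidean distances, whereas the units of `dist` from `getClosestChannels2` (and so of `dNearActiveSite`) are not shown in the snippet.

```python
import numpy as np

def gate_upsampled_sites(xc, yc, xcup, ycup, d_near_active_site):
    """Hypothetical helper: keep only upsampled locations whose nearest real
    channel is closer than d_near_active_site (Euclidean distance assumed)."""
    # Distances from every real channel to every upsampled location,
    # shape (n_channels, n_upsampled); far smaller than (n_up, n_up).
    d = np.sqrt((ycup - yc[:, np.newaxis])**2 + (xcup - xc[:, np.newaxis])**2)
    # Column-wise minimum = distance to the closest real channel.
    igood = d.min(axis=0) < d_near_active_site
    return xcup[igood], ycup[igood]
```

Only the surviving locations would then go into the second, all-pairs nearest-neighbor search, which is what keeps that search small on a sparse array.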

But in Kilosort4, all of the upsampled locations are kept:

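```python
# every point of the upsampled grid becomes a template location
[ys, xs] = np.meshgrid(ops['yup'], ops['xup'])
ys, xs = ys.flatten(), xs.flatten()
ops['ycup'], ops['xcup'] = ys, xs

xc, yc = ops['xc'], ops['yc']
Nfilt = len(ys)

nC = ops['settings']['nearest_chans']
nC2 = ops['settings']['nearest_templates']
# nearest real channels to each upsampled location
iC, ds = nearest_chans(ys, yc, xs, xc, nC, device=device)
# nearest upsampled locations to each upsampled location: an (Nfilt, Nfilt) distance matrix
iC2, ds2 = nearest_chans(ys, ys, xs, xs, nC2, device=device)
```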

This snippet fails at the very last line with a CUDA out of memory error. My GPU has 24GB of VRAM. Specifically, the distance calculation in nearest_chans fails:

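```python
import numpy as np
import torch

def nearest_chans(ys, yc, xs, xc, nC, device=torch.device('cuda')):
    # squared distances, shape (len(yc), len(ys)), allocated all at once
    ds = (ys - yc[:,np.newaxis])**2 + (xs - xc[:,np.newaxis])**2 # <-- fails here
    # for each point in ys, indices of the nC closest points in yc
    iC = np.argsort(ds, 0)[:nC]
    iC = torch.from_numpy(iC).to(device)
    ds = np.sort(ds, 0)[:nC]
    return iC, ds
```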
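Independently of gating, the full distance matrix in `nearest_chans` does not have to exist all at once. Below is a hypothetical chunked variant (not part of Kilosort): it processes the target points in blocks and uses `np.argpartition` to pick the `nC` closest before sorting, so peak memory scales with `len(yc) * chunk` instead of `len(yc) * len(ys)`. It returns NumPy arrays in the same `(nC, len(ys))` layout as the function above; moving `iC` to the GPU is omitted, and it assumes `nC <= len(yc)`.

```python
import numpy as np

def nearest_chans_chunked(ys, yc, xs, xc, nC, chunk=1024):
    """Hypothetical helper: for each point (ys[j], xs[j]), indices and squared
    distances of its nC nearest points among (yc, xc), sorted along axis 0."""
    ys, yc, xs, xc = map(np.asarray, (ys, yc, xs, xc))
    n = len(ys)
    iC = np.empty((nC, n), dtype=np.int64)
    ds = np.empty((nC, n), dtype=np.float64)
    for start in range(0, n, chunk):
        stop = min(start + chunk, n)
        # (len(yc), stop - start) block of squared distances
        d = (ys[start:stop] - yc[:, np.newaxis])**2 + (xs[start:stop] - xc[:, np.newaxis])**2
        # nC smallest per column, in arbitrary order (no full sort needed)
        idx = np.argpartition(d, nC - 1, axis=0)[:nC]
        dd = np.take_along_axis(d, idx, axis=0)
        # sort just those nC candidates by distance
        order = np.argsort(dd, axis=0)
        iC[:, start:stop] = np.take_along_axis(idx, order, axis=0)
        ds[:, start:stop] = np.take_along_axis(dd, order, axis=0)
    return iC, ds
```

With ~100,000 upsampled locations and `chunk=1024`, each block holds about 100,000 × 1,024 float64 values (~0.8 GB). This only fixes the distance step; later steps that scale with the number of templates still see every upsampled location, so gating addresses more of the problem.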

I understand the main application of Kilosort is in vivo shank probes. Such probes are not very wide, and most electrodes are close together, so the number of upsampled locations is manageable. 2D MEAs can be wider than they are tall and can have large gaps between recording clusters, so upsampling adds many locations that are nowhere near recording sites.

In my case, the original HD-MEA has ~26,000 electrodes, but we can only record from ~1,000 sites. If the sites are spread across the MEA, the upsampling procedure creates 4 times as many locations as there are electrodes on the whole MEA (doubling in each dimension), i.e. ~100,000. The distance calculation then creates a ~100,000 × 100,000 matrix of floats (~10^10 entries), which is tens of gigabytes. If the upsampled locations were distance-gated, I'd guess this number would be drastically reduced.
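The size estimate can be checked with a quick calculation. The helper name is hypothetical; the figures use ~26,000 electrodes × 4 from the post, and the gated count of ~4,000 locations is purely an illustrative guess (four upsampled locations per recorded site), not a measured value. Intermediate arrays created while evaluating the expression in `nearest_chans` push the real peak higher still.

```python
def dense_matrix_gb(n_rows, n_cols, bytes_per_value=8):
    """Size in GB (1e9 bytes) of one dense n_rows x n_cols numeric array."""
    return n_rows * n_cols * bytes_per_value / 1e9

n_up = 4 * 26_000        # upsampled grid over the whole MEA (2x in each dimension)
print(dense_matrix_gb(n_up, n_up))        # float64 -> 86.528
print(dense_matrix_gb(n_up, n_up, 4))     # float32 -> 43.264

n_gated = 4 * 1_000      # hypothetical: only locations near the ~1,000 recorded sites
print(dense_matrix_gb(n_gated, n_gated))  # float64 -> 0.128
```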

Thanks, Nathan

naterenegar • Mar 27 '24 15:03

Thank you for this issue. I think we just missed the distance gating in Kilosort4. Were you getting good results overall with Kilosort3? There is another step where I would expect problems: clustering in groups of nearest channels.

marius10p • Mar 28 '24 22:03

Hi Marius,

Yeah, Kilosort3 was giving me good results! I haven't looked at the specific step you mentioned to see how it would do on MEA data.

P.S. It seems someone else has encountered this issue: #647

naterenegar • Mar 29 '24 14:03

@naterenegar I just pushed version 4.0.4, which should address this problem. Would you mind trying it on your data when you have time and letting us know how it goes? Note that you will likely need to set the new x_centers parameter for a 2D array like this (under "Extra settings" if you're using the GUI). The goal of that parameter is to avoid including too many templates in a single grouping: it divides the horizontal space of the probe into that many sections. I would start with a value around 10 for a large array, and try increasing it if the clustering step seems exceptionally slow. You can see where the grouping centers get placed by checking the box with that name under the probe plot in the GUI.
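To make "divides the horizontal space of the probe into that many sections" concrete, here is a hypothetical illustration; it is not Kilosort's implementation, and the helper name and equal-width sections are assumptions. Each template is assigned to the section containing its x position, so raising `x_centers` shrinks the number of templates per group.

```python
import numpy as np

def group_by_x_centers(x_templates, x_centers):
    """Hypothetical: split the horizontal extent of the template positions into
    x_centers equal-width sections; return each template's section index."""
    x = np.asarray(x_templates, dtype=float)
    edges = np.linspace(x.min(), x.max(), x_centers + 1)
    # digitize against the interior edges -> indices 0 .. x_centers - 1
    return np.digitize(x, edges[1:-1])

x = np.arange(0, 1001, 10)         # 101 template x-positions spanning 1000 units
groups = group_by_x_centers(x, 10)
print(np.bincount(groups))         # -> [10 10 10 10 10 10 10 10 10 11]
```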

You might also need to set max_channel_distance (also under "Extra settings") if you still run into memory issues. This controls the distance gating you pointed out in previous versions. By default it keeps templates that are within max(dmin, dminx) of a channel, but that might be overkill if there's a lot of space between channels. You can also preview those templates in the GUI with the "Universal Templates" checkbox.
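When running from a script rather than the GUI, the same two options go into the settings dictionary, which would then be passed as the `settings` argument of `run_kilosort`. The values are placeholders: `x_centers = 10` is the starting point suggested above, while `n_chan_bin` and `max_channel_distance` are hypothetical and must match your recording and electrode spacing.

```python
# Placeholder values for a large 2D MEA; adjust to your own recording.
settings = {
    'n_chan_bin': 1024,          # hypothetical: number of channels in the binary file
    'x_centers': 10,             # horizontal sections used to group templates
    'max_channel_distance': 20,  # hypothetical: smaller values keep fewer templates
}
```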

jacobpennington • Apr 13 '24 21:04

@naterenegar Hi, we are currently trying to process 2D MEA data with Kilosort, and it seems we are using the same product. When creating the probe, do I need to include all 26,400 electrodes, or only the channels currently in use? Thanks :)

yyyaaaaaaa • Apr 25 '24 07:04

Closing this for now, please let us know if you try out the new version and still have problems sorting this dataset.

jacobpennington • May 09 '24 20:05

Hi Jacob, the new version ran through without any issues on this dataset. Thanks for the update!

naterenegar • Jul 09 '24 16:07