Broadcasting error when saving recording after deep interpolation
Hi, I'm having issues with storing the recording object after applying deepinterpolation. When saving an interpolated recording, I get error messages like:
ValueError: could not broadcast input array from shape (29886,768) into shape (30000,768)
I assume that is because the convolutional network removes samples at the edges of the signal. What is the recommended way of dealing with this? Should I pad the traces before applying deepinterpolation? If so, is there a function for this? Or would it be better to create a new object for storing the interpolated recording?
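To show why a few missing frames break the save, here is a minimal NumPy reproduction. It assumes the writer fills a preallocated buffer chunk by chunk, as the `array[start_frame:end_frame, :] = traces` line in the traceback suggests. The helper `pad_to_length` is hypothetical. It only fixes the shape: the padded frames contain no real data, and if frames were actually dropped at the start of a chunk, the remaining samples would also be shifted in time.

```python
import numpy as np

num_channels = 768
chunk_size = 30000

# Preallocated output buffer, filled one chunk at a time (int16 as in SpikeGLX data)
buffer = np.zeros((chunk_size, num_channels), dtype=np.int16)

# Hypothetical model output that lost some frames at the chunk edges
traces = np.ones((29886, num_channels), dtype=np.int16)

try:
    buffer[0:chunk_size, :] = traces
    error_message = None
except ValueError as err:
    error_message = str(err)  # the same kind of broadcast error as in the report
print(error_message)


def pad_to_length(traces, num_frames):
    """Zero-pad at the end (or crop) so traces has exactly num_frames rows."""
    n = traces.shape[0]
    if n >= num_frames:
        return traces[:num_frames]
    pad = np.zeros((num_frames - n, traces.shape[1]), dtype=traces.dtype)
    return np.concatenate([traces, pad], axis=0)


# Shapes now match, so the assignment succeeds
buffer[0:chunk_size, :] = pad_to_length(traces, chunk_size)
```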
Below is a code example. I'm using spikeinterface version 0.100.6 because deepinterpolation requires Python <= 3.8.
from pathlib import Path
import numpy as np
import spikeinterface.extractors as se
from spikeinterface.preprocessing import common_reference, deepinterpolate, zero_channel_pad
root = Path(__file__).parent.parent.absolute()
model_path = list((root/"model").glob("*"))[0]
recording = se.read_spikeglx(folder_path=root/"data"/"neuropixels")
# pad so dimensions are compatible with model
recording = zero_channel_pad(recording, num_channels=384*2)
# create preprocessing pipeline
recording_cmr = common_reference(recording=recording, operator="median")
recording_deepint = deepinterpolate(recording=recording_cmr, model_path=str(model_path), use_gpu=False)
# process and save
recording_deepint.save()
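Before calling `save()`, it may help to confirm that the model returns fewer frames than requested, without writing anything to disk. This is a minimal sketch that relies only on the recording methods `get_num_segments()`, `get_num_samples()` and `get_traces()`. The helper name `find_short_chunks` and its `max_chunks` limit are hypothetical; since every `get_traces()` call runs the network, it only inspects the first few chunks of each segment.

```python
def find_short_chunks(recording, chunk_size=30000, max_chunks=3):
    """Request chunks like the binary writer does and report length mismatches.

    Returns a list of (segment_index, start_frame, end_frame, returned_frames)
    for every inspected chunk whose returned length differs from the request.
    """
    mismatches = []
    for segment_index in range(recording.get_num_segments()):
        num_samples = recording.get_num_samples(segment_index=segment_index)
        for i, start in enumerate(range(0, num_samples, chunk_size)):
            if i >= max_chunks:
                break  # each call runs the model, so keep the check short
            end = min(start + chunk_size, num_samples)
            traces = recording.get_traces(
                segment_index=segment_index, start_frame=start, end_frame=end
            )
            returned = len(traces)  # number of frames (rows) actually returned
            if returned != end - start:
                mismatches.append((segment_index, start, end, returned))
    return mismatches
```

For example, `print(find_short_chunks(recording_deepint, chunk_size=30000))` would list the chunks that come back short, using the same chunk size as the failing `save()` call.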
Thanks in advance!
Hi, can you share the complete error trace?
Sure, here it is:
Use cache_folder=/tmp/spikeinterface_cache/random_id
write_binary_recording with n_jobs = 1 and chunk_size = 30000
write_binary_recording: 0%| | 0/587 [00:00<?, ?it/s]/home/user/project/.venv/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py:2323: UserWarning: `Model.state_updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
warnings.warn('`Model.state_updates` will be removed in a future version. '
write_binary_recording: 0%| | 0/587 [00:57<?, ?it/s]
Traceback (most recent call last):
File "scripts/test_deepinterpolate.py", line 30, in <module>
recording_deepint.save()
File "/home/user/project/.venv/lib/python3.8/site-packages/spikeinterface/core/base.py", line 845, in save
loaded_extractor = self.save_to_folder(**kwargs)
File "/home/user/project/.venv/lib/python3.8/site-packages/spikeinterface/core/base.py", line 931, in save_to_folder
cached = self._save(folder=folder, verbose=verbose, **save_kwargs)
File "/home/user/project/.venv/lib/python3.8/site-packages/spikeinterface/core/baserecording.py", line 462, in _save
write_binary_recording(self, file_paths=file_paths, dtype=dtype, **job_kwargs)
File "/home/user/project/.venv/lib/python3.8/site-packages/spikeinterface/core/recording_tools.py", line 137, in write_binary_recording
executor.run()
File "/home/user/project/.venv/lib/python3.8/site-packages/spikeinterface/core/job_tools.py", line 381, in run
res = self.func(segment_index, frame_start, frame_stop, worker_ctx)
File "/home/user/project/.venv/lib/python3.8/site-packages/spikeinterface/core/recording_tools.py", line 171, in _write_binary_chunk
array[start_frame:end_frame, :] = traces
ValueError: could not broadcast input array from shape (29983,768) into shape (30000,768)
Yes, it would seem that the deepinterpolated recording returns fewer samples than requested for a chunk (29983 instead of 30000 frames), so I think your initial hunch is right.
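On the original padding question, one generic way to get full-length output from a model that drops edge frames is to feed it extra context around each chunk. Real neighbouring samples are used where they exist, and zeros are added only at the true start and end of the recording. The NumPy sketch below assumes the network consumes `pre` frames before and `post` frames after the chunk and drops them. `trimming_model`, `get_chunk_with_context` and `process_in_chunks` are hypothetical helpers, and the frame counts used here are illustrative; the real ones depend on the trained model. This is not SpikeInterface's implementation.

```python
import numpy as np


def get_chunk_with_context(data, start, end, pre, post):
    """Return data[start - pre : end + post] with shape (pre + end - start + post, C).

    Neighbouring samples are used as context; zeros are added only where the
    window runs past the first or last sample of the recording.
    """
    num_samples = data.shape[0]
    lo = max(start - pre, 0)
    hi = min(end + post, num_samples)
    pad_before = pre - (start - lo)
    pad_after = post - (hi - end)
    return np.pad(data[lo:hi], ((pad_before, pad_after), (0, 0)), mode="constant")


def trimming_model(x, pre, post):
    """Hypothetical stand-in for a network that drops `pre` and `post` edge frames."""
    return x[pre:x.shape[0] - post]


def process_in_chunks(data, chunk_size, pre, post):
    """Run the stand-in model chunk by chunk; every chunk comes back full length."""
    out = np.empty_like(data)
    for start in range(0, data.shape[0], chunk_size):
        end = min(start + chunk_size, data.shape[0])
        window = get_chunk_with_context(data, start, end, pre, post)
        out[start:end, :] = trimming_model(window, pre, post)  # shapes now match
    return out


# Usage with synthetic data; the stand-in model only trims, so output equals input
data = np.random.default_rng(0).normal(size=(100_000, 4)).astype(np.float32)
result = process_in_chunks(data, chunk_size=30000, pre=30, post=30)
```

With a real network the output would of course differ from the input; the point is that every chunk written back has exactly `end - start` frames.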