Handle automatic chunks duration for SC2 #3721

Open · wants to merge 60 commits into main from yger:total_memory
Commits (60)
1ebb6db
Handle automatic RAM allocation for chunks
yger Feb 25, 2025
0ce092c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 25, 2025
cf31e7f
Default
yger Feb 25, 2025
b299dd3
Merge branch 'total_memory' of github.com:yger/spikeinterface into to…
yger Feb 25, 2025
7255530
Default
yger Feb 25, 2025
b6dc572
Sync with main
yger Feb 25, 2025
26fc226
WIP
yger Feb 27, 2025
3819f8c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 27, 2025
42d1c2c
Keeping the input dict
yger Feb 27, 2025
d156f60
Merge branch 'total_memory' of github.com:yger/spikeinterface into to…
yger Feb 27, 2025
29eb160
Reducing memory footprint
yger Feb 28, 2025
ed319bf
Patch for small num_channels
yger Feb 28, 2025
8c1ca39
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 28, 2025
0a2bd23
Saving the final analyzer
yger Feb 28, 2025
b913be2
Merge branch 'total_memory' of github.com:yger/spikeinterface into to…
yger Feb 28, 2025
5d88e72
Docstrings
yger Mar 2, 2025
f746d41
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 2, 2025
d6b3bdb
More docstrings
yger Mar 2, 2025
a891d13
Merge branch 'total_memory' of github.com:yger/spikeinterface into to…
yger Mar 2, 2025
f2a3ac4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 2, 2025
b72ee86
Cosmetic and bug fixes
yger Mar 3, 2025
151fefb
Merge branch 'total_memory' of github.com:yger/spikeinterface into to…
yger Mar 3, 2025
0445d4e
Remove HDBSCAN dependency
yger Mar 5, 2025
6409824
Remove hdbscan
yger Mar 5, 2025
b9b3457
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2025
5d4e516
Adding sklearn as a dependency for testing
yger Mar 5, 2025
03ccc7e
Merge branch 'total_memory' of github.com:yger/spikeinterface into to…
yger Mar 5, 2025
fcea1b6
Spaces
yger Mar 5, 2025
3812ece
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2025
8c400f4
HDBSCAN will go in other PR
yger Mar 5, 2025
d6917c9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2025
12b5edb
Reverting
yger Mar 5, 2025
297d9b9
Merge branch 'main' into total_memory
yger Mar 7, 2025
9bb2f33
Merge branch 'main' of https://github.com/SpikeInterface/spikeinterfa…
yger Mar 7, 2025
c86a587
Merge branch 'main' of github.com:spikeinterface/spikeinterface into …
yger Mar 7, 2025
0e9468c
Merge branch 'main' of https://github.com/SpikeInterface/spikeinterfa…
yger Mar 10, 2025
b7de6a1
Merge branch 'total_memory' of github.com:yger/spikeinterface into to…
yger Mar 10, 2025
8282461
Bringing back optimal n jobs
yger Mar 12, 2025
ce56760
Fixes
yger Mar 12, 2025
7cc16ce
WIP
yger Mar 12, 2025
c4faee6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 12, 2025
d25105f
Merge branch 'main' of github.com:spikeinterface/spikeinterface into …
yger Mar 12, 2025
08213c4
Merge branch 'SpikeInterface:main' into total_memory
yger Mar 14, 2025
eae719c
Merge branch 'SpikeInterface:main' into total_memory
yger Mar 20, 2025
ba96c50
Merge branch 'SpikeInterface:main' into total_memory
yger Mar 26, 2025
68c03f3
Merge branch 'main' into total_memory
yger Mar 27, 2025
0138b17
Merge branch 'SpikeInterface:main' into total_memory
yger Mar 28, 2025
bd07dc3
Merge branch 'main' of github.com:spikeinterface/spikeinterface into …
yger Apr 2, 2025
b68e2fb
Merge branch 'total_memory' of github.com:yger/spikeinterface into to…
yger Apr 2, 2025
b569023
Merge branch 'main' into total_memory
yger Apr 4, 2025
08f579e
WIP
yger Apr 4, 2025
c26b092
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 4, 2025
16e9a51
Desactivate by default
yger Apr 4, 2025
7559588
Merge branch 'main' of https://github.com/SpikeInterface/spikeinterfa…
yger Apr 7, 2025
51f753a
WIP
yger Apr 7, 2025
b7fb2bf
Merge branch 'main' of github.com:spikeinterface/spikeinterface into …
yger Apr 9, 2025
5b47453
Sync with main
yger Apr 16, 2025
456e257
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 16, 2025
71c9eec
Merge branch 'main' into total_memory
yger Apr 17, 2025
c97dff3
Merge branch 'total_memory' of github.com:yger/spikeinterface into to…
yger Apr 17, 2025
2 changes: 1 addition & 1 deletion src/spikeinterface/core/recording_tools.py
@@ -541,7 +541,7 @@ def get_random_recording_slices(
chunk_duration : str | float | None, default "500ms"
The duration of each chunk in 's' or 'ms'
chunk_size : int | None
Size of a chunk in number of frames. This is ued only if chunk_duration is None.
Size of a chunk in number of frames. This is used only if chunk_duration is None.
This is kept for backward compatibility, you should prefer 'chunk_duration=500ms' instead.
concatenated : bool, default: True
If True chunk are concatenated along time axis
16 changes: 13 additions & 3 deletions src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -13,6 +13,7 @@
cache_preprocessing,
get_prototype_and_waveforms_from_recording,
get_shuffled_recording_slices,
_set_optimal_chunk_size,
)
from spikeinterface.core.basesorting import minimum_spike_dtype
from spikeinterface.core.sparsity import compute_sparsity
@@ -39,6 +40,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
"apply_preprocessing": True,
"templates_from_svd": True,
"cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True},
"chunk_preprocessing": {"memory_limit": None},
"multi_units_only": False,
"job_kwargs": {"n_jobs": 0.75},
"seed": 42,
@@ -66,6 +68,9 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
"matched_filtering": "Boolean to specify whether circus 2 should detect peaks via matched filtering (slightly slower)",
"cache_preprocessing": "How to cache the preprocessed recording. Mode can be memory, file, zarr, with extra arguments. In case of memory (default), \
memory_limit will control how much RAM can be used. In case of folder or zarr, delete_cache controls if cache is cleaned after sorting",
"chunk_preprocessing": "How much RAM (approximately) should be devoted to load all data chunks (given n_jobs).\
memory_limit will control how much RAM can be used as a fraction of available memory. Otherwise, use total_memory to fix a hard limit, with\
a string syntax (e.g. '1G', '500M')",
"multi_units_only": "Boolean to get only multi units activity (i.e. one template per electrode)",
"job_kwargs": "A dictionary to specify how many jobs and which parameters they should used",
"seed": "An int to control how chunks are shuffled while detecting peaks",
@@ -100,8 +105,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):

job_kwargs = fix_job_kwargs(params["job_kwargs"])
job_kwargs.update({"progress_bar": verbose})

recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False)
if params["chunk_preprocessing"].get("memory_limit", None) is not None:
job_kwargs = _set_optimal_chunk_size(recording, job_kwargs, **params["chunk_preprocessing"])

sampling_frequency = recording.get_sampling_frequency()
num_channels = recording.get_num_channels()
@@ -382,7 +388,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
sorting.save(folder=curation_folder)
# np.save(fitting_folder / "amplitudes", guessed_amplitudes)

sorting = final_cleaning_circus(recording_w, sorting, templates, **merging_params, **job_kwargs)
final_analyzer = final_cleaning_circus(recording_w, sorting, templates, **merging_params, **job_kwargs)
final_analyzer.save_as(format="binary_folder", folder=sorter_output_folder / "final_analyzer")

sorting = final_analyzer.sorting

if verbose:
print(f"Kept {len(sorting.unit_ids)} units after final merging")
@@ -441,4 +450,5 @@ def final_cleaning_circus(
sparsity_overlap=sparsity_overlap,
**job_kwargs,
)
return final_sa.sorting

return final_sa
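
For orientation, here is a minimal sketch of how the new `chunk_preprocessing` option could be passed to the sorter. The recording and the exact values are illustrative assumptions; only the `chunk_preprocessing` dictionary and its `memory_limit`/`total_memory` keys come from this diff.

```python
from spikeinterface.core import generate_recording
from spikeinterface.sorters import run_sorter

# Illustrative recording; any (preprocessed) recording would do.
recording = generate_recording(num_channels=64, durations=[300.0])

# Ask SC2 to size its processing chunks so that, across all workers, the raw
# chunks together use roughly 25% of the currently available RAM.
sorting = run_sorter(
    "spykingcircus2",
    recording,
    chunk_preprocessing={"memory_limit": 0.25},
    job_kwargs={"n_jobs": 4},
)
```

With the default `{"memory_limit": None}` this stays disabled, matching the "Desactivate by default" commit above.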
14 changes: 7 additions & 7 deletions src/spikeinterface/sortingcomponents/clustering/circus.py
@@ -14,17 +14,12 @@

import random, string
from spikeinterface.core import get_global_tmp_folder
from spikeinterface.core.basesorting import minimum_spike_dtype
from spikeinterface.core.waveform_tools import estimate_templates
from .clustering_tools import remove_duplicates_via_matching
from spikeinterface.core.recording_tools import get_noise_levels, get_channel_distances
from spikeinterface.sortingcomponents.peak_selection import select_peaks
from spikeinterface.core.template import Templates
from spikeinterface.core.sparsity import compute_sparsity
from spikeinterface.sortingcomponents.tools import remove_empty_templates
from spikeinterface.sortingcomponents.tools import remove_empty_templates, _get_optimal_n_jobs
from spikeinterface.sortingcomponents.clustering.peak_svd import extract_peaks_svd


from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel


@@ -62,6 +57,7 @@ class CircusClustering:
"noise_levels": None,
"tmp_folder": None,
"verbose": True,
"memory_limit": 0.25,
"debug": False,
}

@@ -162,13 +158,17 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
if not params["templates_from_svd"]:
from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_recording

job_kwargs_local = job_kwargs.copy()
unit_ids = np.unique(peak_labels)
ram_requested = recording.get_num_channels() * (nbefore + nafter) * len(unit_ids) * 4
job_kwargs_local = _get_optimal_n_jobs(job_kwargs_local, ram_requested, params["memory_limit"])
templates = get_templates_from_peaks_and_recording(
recording,
peaks,
peak_labels,
ms_before,
ms_after,
**job_kwargs,
**job_kwargs_local,
)
else:
from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd
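
The `_get_optimal_n_jobs` call above caps the worker count so that each worker's copy of the templates buffer (num_channels × (nbefore + nafter) × num_units × 4 bytes) stays within a fraction of the available RAM. A standalone sketch of that arithmetic, assuming psutil is installed (names and numbers are illustrative, not part of the diff):

```python
import psutil

def cap_n_jobs(n_jobs, ram_per_job_bytes, memory_limit=0.25):
    # Keep n_jobs * ram_per_job below memory_limit * available RAM, but never below 1 job.
    budget = memory_limit * psutil.virtual_memory().available
    return max(1, int(min(n_jobs, budget // ram_per_job_bytes)))

# Example: 1000 units, 384 channels, 90 samples per template, float32
ram_requested = 384 * 90 * 1000 * 4  # ~138 MB of templates per worker
print(cap_n_jobs(8, ram_requested, memory_limit=0.25))
```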
15 changes: 10 additions & 5 deletions src/spikeinterface/sortingcomponents/matching/circus.py
@@ -50,11 +50,16 @@ def compress_templates(
if remove_mean:
templates_array -= templates_array.mean(axis=(1, 2))[:, None, None]

temporal, singular, spatial = np.linalg.svd(templates_array, full_matrices=False)
# Keep only the strongest components
temporal = temporal[:, :, :approx_rank].astype(np.float32)
singular = singular[:, :approx_rank].astype(np.float32)
spatial = spatial[:, :approx_rank, :].astype(np.float32)
num_templates, num_samples, num_channels = templates_array.shape
temporal = np.zeros((num_templates, num_samples, approx_rank), dtype=np.float32)
spatial = np.zeros((num_templates, approx_rank, num_channels), dtype=np.float32)
singular = np.zeros((num_templates, approx_rank), dtype=np.float32)

for i in range(num_templates):
i_temporal, i_singular, i_spatial = np.linalg.svd(templates_array[i], full_matrices=False)
temporal[i, :, : min(approx_rank, num_channels)] = i_temporal[:, :approx_rank]
spatial[i, : min(approx_rank, num_channels), :] = i_spatial[:approx_rank, :]
singular[i, : min(approx_rank, num_channels)] = i_singular[:approx_rank]

if return_new_templates:
templates_array = np.matmul(temporal * singular[:, np.newaxis, :], spatial)
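
The rewritten `compress_templates` runs the SVD template by template instead of on the whole (num_templates, num_samples, num_channels) stack at once, which keeps peak memory proportional to a single template and zero-pads the factors when `num_channels < approx_rank` (the "Patch for small num_channels" commit). A small self-contained sketch of the same decomposition and of the low-rank reconstruction used when `return_new_templates=True` (shapes are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
num_templates, num_samples, num_channels, approx_rank = 10, 90, 4, 5  # channels < rank on purpose
templates = rng.normal(size=(num_templates, num_samples, num_channels)).astype(np.float32)

temporal = np.zeros((num_templates, num_samples, approx_rank), dtype=np.float32)
spatial = np.zeros((num_templates, approx_rank, num_channels), dtype=np.float32)
singular = np.zeros((num_templates, approx_rank), dtype=np.float32)

for i in range(num_templates):
    u, s, vh = np.linalg.svd(templates[i], full_matrices=False)  # u: (90, 4), s: (4,), vh: (4, 4)
    k = min(approx_rank, num_channels)  # only k components exist when channels < rank
    temporal[i, :, :k] = u[:, :approx_rank]
    spatial[i, :k, :] = vh[:approx_rank, :]
    singular[i, :k] = s[:approx_rank]

# Low-rank reconstruction, as done when return_new_templates=True
reconstructed = np.matmul(temporal * singular[:, np.newaxis, :], spatial)
assert np.allclose(reconstructed, templates, atol=1e-4)  # exact here since rank >= channels
```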
141 changes: 133 additions & 8 deletions src/spikeinterface/sortingcomponents/tools.py
@@ -12,7 +12,7 @@
from spikeinterface.core.sparsity import ChannelSparsity
from spikeinterface.core.template import Templates
from spikeinterface.core.waveform_tools import extract_waveforms_to_single_buffer
from spikeinterface.core.job_tools import split_job_kwargs
from spikeinterface.core.job_tools import split_job_kwargs, fix_job_kwargs
from spikeinterface.core.sortinganalyzer import create_sorting_analyzer
from spikeinterface.core.sparsity import ChannelSparsity
from spikeinterface.core.analyzer_extension_core import ComputeTemplates
@@ -249,19 +249,144 @@ def check_probe_for_drift_correction(recording, dist_x_max=60):
return True


def cache_preprocessing(recording, mode="memory", memory_limit=0.5, delete_cache=True, **extra_kwargs):
save_kwargs, job_kwargs = split_job_kwargs(extra_kwargs)
def _set_optimal_chunk_size(recording, job_kwargs, memory_limit=0.5, total_memory=None):
"""
Set the optimal chunk size for a job given the memory_limit and the number of jobs

if mode == "memory":
Parameters
----------

recording: Recording
The recording object
job_kwargs: dict
The job kwargs
memory_limit: float
The memory limit in fraction of available memory
total_memory: str | None, default: None
The total memory to use for the job, given as a string (e.g. '1G', '500M')

Returns
-------

job_kwargs: dict
The updated job kwargs
"""
job_kwargs = fix_job_kwargs(job_kwargs)
n_jobs = job_kwargs["n_jobs"]
if total_memory is None:
if HAVE_PSUTIL:
assert 0 < memory_limit < 1, "memory_limit should be in ]0, 1["
memory_usage = memory_limit * psutil.virtual_memory().available
if recording.get_total_memory_size() < memory_usage:
recording = recording.save_to_memory(format="memory", shared=True, **job_kwargs)
num_channels = recording.get_num_channels()
dtype_size_bytes = recording.get_dtype().itemsize
chunk_size = memory_usage / ((num_channels * dtype_size_bytes) * n_jobs)
chunk_duration = chunk_size / recording.get_sampling_frequency()
job_kwargs.update(dict(chunk_duration=f"{chunk_duration}s"))
job_kwargs = fix_job_kwargs(job_kwargs)
else:
import warnings

warnings.warn("psutil is required to use only a fraction of available memory")
else:
from spikeinterface.core.job_tools import convert_string_to_bytes

total_memory = convert_string_to_bytes(total_memory)
num_channels = recording.get_num_channels()
dtype_size_bytes = recording.get_dtype().itemsize
chunk_size = total_memory / ((num_channels * dtype_size_bytes) * n_jobs)
chunk_duration = chunk_size / recording.get_sampling_frequency()
job_kwargs.update(dict(chunk_duration=f"{chunk_duration}s"))
job_kwargs = fix_job_kwargs(job_kwargs)
return job_kwargs


def _get_optimal_n_jobs(job_kwargs, ram_requested, memory_limit=0.25):
"""
Set the optimal number of jobs given the requested RAM and the memory_limit

Parameters
----------

job_kwargs: dict
The job kwargs
ram_requested: int
The amount of RAM (in bytes) requested for the job
memory_limit: float
The memory limit in fraction of available memory

Returns
-------

job_kwargs: dict
The updated job kwargs
"""
job_kwargs = fix_job_kwargs(job_kwargs)
n_jobs = job_kwargs["n_jobs"]
if HAVE_PSUTIL:
assert 0 < memory_limit < 1, "memory_limit should be in ]0, 1["
memory_usage = memory_limit * psutil.virtual_memory().available
n_jobs = max(1, int(min(n_jobs, memory_usage // ram_requested)))
job_kwargs.update(dict(n_jobs=n_jobs))
else:
import warnings

warnings.warn("psutil is required to use only a fraction of available memory")
return job_kwargs


def cache_preprocessing(
recording, mode="memory", memory_limit=0.5, total_memory=None, delete_cache=True, **extra_kwargs
):
"""
Cache the preprocessing of a recording object

Parameters
----------

recording: Recording
The recording object
mode: str
The mode to cache the preprocessing, can be 'memory', 'folder', 'zarr' or 'no-cache'
memory_limit: float
The memory limit in fraction of available memory
total_memory: str | None, default: None
The total memory available to cache the recording (hard limit), e.g. '1G'
delete_cache: bool
If True, delete the cache after the job
**extra_kwargs: dict
The extra kwargs for the job

Returns
-------

recording: Recording
The cached recording object
"""

save_kwargs, job_kwargs = split_job_kwargs(extra_kwargs)

if mode == "memory":
if total_memory is None:
if HAVE_PSUTIL:
assert 0 < memory_limit < 1, "memory_limit should be in ]0, 1["
memory_usage = memory_limit * psutil.virtual_memory().available
if recording.get_total_memory_size() < memory_usage:
recording = recording.save_to_memory(format="memory", shared=True, **job_kwargs)
else:
import warnings

warnings.warn("Recording too large to be preloaded in RAM...")
else:
print("Recording too large to be preloaded in RAM...")
import warnings

warnings.warn("psutil is required to preload in memory given only a fraction of available memory")
else:
print("psutil is required to preload in memory")
if recording.get_total_memory_size() < total_memory:
recording = recording.save_to_memory(format="memory", shared=True, **job_kwargs)
else:
import warnings

warnings.warn("Recording too large to be preloaded in RAM...")
elif mode == "folder":
recording = recording.save_to_folder(**extra_kwargs)
elif mode == "zarr":
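
A minimal usage sketch of the memory helpers added in this file; the recording and the numbers are illustrative assumptions, while the function names and keyword arguments are the ones from this diff. In `_set_optimal_chunk_size`, the chunk length in frames is roughly memory_budget / (num_channels * itemsize * n_jobs), then converted into a `chunk_duration` string.

```python
from spikeinterface.core import generate_recording
from spikeinterface.sortingcomponents.tools import _set_optimal_chunk_size, cache_preprocessing

recording = generate_recording(num_channels=64, durations=[300.0])

# Fraction of available RAM: with 4 workers, chunks are sized so that the 4
# in-flight chunks together use ~50% of what is currently free.
job_kwargs = _set_optimal_chunk_size(recording, {"n_jobs": 4}, memory_limit=0.5)

# Hard cap instead: spread 1 GB across the workers.
job_kwargs = _set_optimal_chunk_size(recording, {"n_jobs": 4}, total_memory="1G")

# Preload the preprocessed recording into shared memory if it fits in the budget.
recording = cache_preprocessing(recording, mode="memory", memory_limit=0.5)
```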
2 changes: 1 addition & 1 deletion src/spikeinterface/widgets/crosscorrelograms.py
@@ -120,7 +120,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):

if i < len(self.axes) - 1:
self.axes[i, j].set_xticks([], [])
plt.tight_layout()
self.figure.tight_layout()

for i, unit_id in enumerate(unit_ids):
self.axes[0, i].set_title(str(unit_id))