"""
detect.py - Detection module of DAS4Whales package
This module provides functions for detecting whale calls in DAS strain data.
Author: Quentin Goestchel
Date: 2023-2024
"""
from __future__ import annotations
import concurrent.futures
from typing import Dict, List, Tuple, Union, Optional, Any
import librosa
import matplotlib.pyplot as plt
import numpy as np
import scipy.signal as sp
import scipy.stats as st
from tqdm import tqdm
from das4whales.plot import import_roseus
## Matched filter detection functions:
[docs]
def gen_linear_chirp(fmin: float, fmax: float, duration: float, sampling_rate: int) -> np.ndarray:
"""Generate a linear chirp signal.
Parameters
----------
fmin : float
The ending frequency of the chirp signal.
fmax : float
The starting frequency of the chirp signal.
duration : float
The duration of the chirp signal in seconds.
sampling_rate : int
The sampling rate of the chirp signal in Hz.
Returns
-------
numpy.ndarray
The generated linear chirp signal.
"""
t = np.arange(0, duration, 1/sampling_rate)
y = sp.chirp(t, f0=fmax, f1=fmin, t1=duration, method='linear')
return y
[docs]
def gen_hyperbolic_chirp(fmin: float, fmax: float, duration: float, sampling_rate: int) -> np.ndarray:
"""Generate a hyperbolic chirp signal.
Parameters
----------
fmin : float
The ending frequency of the chirp signal.
fmax : float
The starting frequency of the chirp signal.
duration : float
The duration of the chirp signal in seconds.
sampling_rate : int
The sampling rate of the chirp signal in Hz.
Returns
-------
numpy.ndarray
The generated hyperbolic chirp signal.
"""
t = np.arange(0, duration, 1/sampling_rate)
y = sp.chirp(t, f0=fmax, f1=fmin, t1=duration, method='hyperbolic')
return y
[docs]
def gen_template_fincall(time: np.ndarray, fs: float, fmin: float = 15., fmax: float = 25., duration: float = 1., window: bool = True) -> np.ndarray:
""" generate template of a fin whale call
Parameters
----------
time : numpy.ndarray
time vector
fs : float
sampling rate in Hz
fmin : float, optional
Minimum frequency, by default 15
fmax : float, optional
Maximum frequency, by default 25
duration : float, optional
Duration of the chirp signal in seconds, by default 1.
"""
# 1 Hz frequency buffer to compensate the windowing
df = 0
chirp_signal = gen_hyperbolic_chirp(fmin-df, fmax + df, duration, fs)
template = np.zeros(np.shape(time))
#TODO: remove the padding and keep just the short window values
if window:
template[:len(chirp_signal)] = chirp_signal * np.hanning(len(chirp_signal))
else:
template[:len(chirp_signal)] = chirp_signal
return template
[docs]
def shift_xcorr(x: np.ndarray, y: np.ndarray) -> np.ndarray:
"""compute the shifted (positive lags only) cross correlation between two 1D arrays
Parameters
----------
x : numpy.ndarray
1D array containing signal
y : numpy.ndarray
1D array containing signal
Returns
-------
numpy.ndarray
1D array cross-correlation betweem x and y, only for positive lags
"""
corr = sp.correlate(x, y, mode='full', method='fft')
# TODO: Modify to use with the short window values (mode = 'same' instead of 'full')
return corr[len(x)-1 :]
[docs]
def shift_nxcorr(x: np.ndarray, y: np.ndarray) -> np.ndarray:
"""Compute the shifted (positive lags only) normalized cross-correlation with standard deviation normalization.
Parameters
----------
x : numpy.ndarray
first input signal.
y : numpy.ndarray
second input signal
Returns
-------
numpy.ndarray
The normalized cross-correlation between the two signals
"""
#TODO: Modify to use with the short window values (mode = 'same' instead of 'full')
# Compute cross-correlation
cross_corr = sp.correlate(x, y, mode='full', method='fft')
# Normalize using standard deviation
normalized_corr = cross_corr / (np.std(x) * np.std(y) * len(x))
return normalized_corr[len(x)-1 :]
[docs]
def compute_cross_correlogram(data: np.ndarray, template: np.ndarray) -> np.ndarray:
"""
Compute the cross correlogram between the given data and template.
Parameters
----------
data : numpy.ndarray
The input data array.
template : numpy.ndarray
The template array.
Returns
-------
numpy.ndarray
The cross correlogram array.
"""
# Normalize data along axis 1 by its maximum (peak normalization)
norm_data = (data - np.mean(data, axis=1, keepdims=True)) / np.max(np.abs(data), axis=1, keepdims=True)
template = (template - np.mean(template)) / np.max(np.abs(template))
# Compute correlation along axis 1
cross_correlogram = np.empty_like(data)
for i in tqdm(range(data.shape[0])):
cross_correlogram[i, :] = shift_xcorr(norm_data[i, :], template)
return cross_correlogram
[docs]
def calc_nmf(data: np.ndarray, template: np.ndarray) -> np.ndarray:
"""
Calculate the normalized matched filter between the input data and the template.
Parameters
----------
data : numpy.ndarray
The input data array.
template : numpy.ndarray
The template array.
Returns
-------
numpy.ndarray
The normalized matched filter array (vector).
"""
nmf = sp.correlate(data, template, mode='same', method='fft') / np.sqrt((np.sum(data ** 2) * np.sum(template ** 2)))
return nmf
[docs]
def calc_nmf_correlogram(data: np.ndarray, template: np.ndarray) -> np.ndarray:
"""
Calculate the normalized matched filter correlogram between the input data and the template.
Parameters
----------
data : numpy.ndarray
The input data array.
template : numpy.ndarray
The template array.
Returns
-------
numpy.ndarray
The normalized matched filter correlogram array.
"""
# Normalize data along axis 1 by its maximum (peak normalization)
norm_data = (data - np.mean(data, axis=1, keepdims=True)) / np.max(np.abs(data), axis=1, keepdims=True)
template = (template - np.mean(template)) / np.max(np.abs(template))
# Compute correlation along axis 1
nmf_correlogram = np.empty_like(data)
for i in tqdm(range(data.shape[0])):
nmf_correlogram[i, :] = calc_nmf(data[i, :], template)
# Parallelized version:
# with concurrent.futures.ThreadPoolExecutor() as executor:
# results = [executor.submit(calc_nmf, data[i, :], template) for i in range(data.shape[0])]
# # Use tqdm to display a progress bar for the as_completed iterator
# for i, future in enumerate(tqdm(concurrent.futures.as_completed(results), total=len(results))):
# nmf_correlogram[i, :] = future.result()
return nmf_correlogram
[docs]
def pick_times_env(corr_m: np.ndarray, threshold: float) -> List[np.ndarray]:
"""Detects the peak times in a correlation matrix. Parallelized version : pick_times_par
This function takes a correlation matrix, computes the Hilbert transform of each correlation,
and detects the peak times based on a given threshold.
Parameters
----------
corr_m : numpy.ndarray
The correlation matrix.
threshold : float, optional
The threshold value for peak detection. Defaults to 0.3.
Returns
-------
list
A list of arrays, where each array contains the peak indexes for each correlation.
"""
peaks_indexes_m = []
for corr in tqdm(corr_m, desc="Processing corr_m"):
peaks_indexes = sp.find_peaks(abs(sp.hilbert(corr)), prominence=threshold)[0] # Change distance in indexes, ex: 'distance=200'
peaks_indexes,_ = sp.find_peaks(corr, distance = ipi * fs, height=th)
peaks_indexes_m.append(peaks_indexes)
return peaks_indexes_m
[docs]
def process_corr(corr: np.ndarray, threshold: float) -> np.ndarray:
"""Detects the peak times in a correlation serie, kernel for parallelization.
This function takes a correlation serie, computes the Hilbert transform of the correlation, and detects the peak times based on a given threshold.
Parameters
----------
corr : np.ndarray
The correlogram array.
threshold : float, optional
The threshold value for peak detection. Defaults to 0.3.
Returns
-------
np.ndarray
The peak indexes for the given correlation.
"""
peaks_indexes = sp.find_peaks(abs(sp.hilbert(corr)), prominence=threshold)[0]
return peaks_indexes
[docs]
def pick_times_par(corr_m: np.ndarray, threshold: float) -> List[np.ndarray]:
"""Detects the peak times in a correlation matrix using parallel processing.
This function takes a correlation matrix, computes the Hilbert transform of each correlation,
and detects the peak times based on a given threshold using parallel processing.
Parameters
----------
corr_m : numpy.ndarray
The correlation matrix.
threshold : float, optional
The threshold value for peak detection. Defaults to 0.3.
Returns
-------
list
A list of arrays, where each array contains the peak indexes for each correlation.
"""
peaks_indexes_m = []
with concurrent.futures.ThreadPoolExecutor() as executor:
results = [executor.submit(process_corr, corr, threshold) for corr in corr_m]
for result in concurrent.futures.as_completed(results):
peaks_indexes_m.append(result.result())
return peaks_indexes_m
[docs]
def pick_times(corr_m: np.ndarray, threshold: float, ipi_idx: int) -> List[np.ndarray]:
"""Detects the peak times in a correlation matrix.
This function takes a correlation matrix, computes the Hilbert transform of each correlation,
and detects the peak times based on a given threshold.
Parameters
----------
corr_m : numpy.ndarray
The correlation matrix.
threshold : float, optional
The threshold value for peak detection. Defaults to 0.3.
ipi_idx : int
The minimum inter-pulse interval in indexes.
Returns
-------
list
A list of arrays, where each array contains the peak indexes for each correlation.
"""
peaks_indexes_m = []
for corr in tqdm(corr_m, desc=f"Picking times, threshold: {threshold}, ipi: {ipi_idx} time samples"):
peaks_indexes,_ = sp.find_peaks(corr, distance = ipi_idx, height=threshold)
peaks_indexes_m.append(peaks_indexes)
return peaks_indexes_m
[docs]
def convert_pick_times(peaks_indexes_m: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
"""
Convert pick times from a list of lists to a numpy array.
Parameters
----------
peaks_indexes_m : list of lists
A list of lists containing the pick times. The indexes of each list correspond to the space index.
[[t1, t2, t3, ...], [t1, t2, t3, ...], ...]
Returns
-------
numpy.ndarray
A numpy array containing a tuple (time index, spatial index) of the converted pick times.
"""
peaks_indexes_tp = ([], [])
for i in range(len(peaks_indexes_m)):
nb_el = len(peaks_indexes_m[i])
for j in range(nb_el):
peaks_indexes_tp[0].append(i)
for el in peaks_indexes_m[i]:
peaks_indexes_tp[1].append(el)
peaks_indexes_tp = np.asarray(peaks_indexes_tp)
# TODO: test = np.column_stack((nlf_assoc_list[0][0], nlf_assoc_list[0][1]))
return peaks_indexes_tp
[docs]
def select_picked_times(idx_tp, tstart, tend, fs):
"""
Select the picked times within a given time range.
Parameters
----------
idx_tp : numpy.ndarray
The time and spatial indexes of the picked times.
tstart : float
The starting time of the time range [s].
tend : float
The ending time of the time range [s].
fs : float
The sampling rate of the data.
Returns
-------
numpy.ndarray
The selected picked times within the given time range (time index, spatial index).
"""
idx_tp_selected = (idx_tp[0][(idx_tp[1] >= tstart * fs) & (idx_tp[1] <= tend * fs)],
idx_tp[1][(idx_tp[1] >= tstart * fs) & (idx_tp[1] <= tend * fs)])
return idx_tp_selected
## Spectrogram correlation functions:
[docs]
def get_sliced_nspectrogram(trace, fs, fmin, fmax, nperseg, nhop, plotflag=False):
"""
Compute the sliced non-stationary spectrogram of a given trace.
Parameters
----------
trace : numpy.ndarray
The input trace signal.
fs : float
The sampling rate of the trace signal.
fmin : float
The minimum frequency of interest.
fmax : float
The maximum frequency of interest.
nperseg : int
The length of each segment for the spectrogram computation.
nhop : int
The number of samples to advance between segments.
plotflag : bool, optional
Whether to plot the spectrogram, defaults to False.
Returns
-------
spectrogram : numpy.ndarray
The computed spectrogram.
ff : ndarray
The frequency values of the spectrogram.
tt : ndarray
The time values of the spectrogram.
Notes
-----
This function computes the non-stationary spectrogram of a given trace signal.
The spectrogram is computed using the Short-Time Fourier Transform (STFT) with
a specified segment length and hop size. The resulting spectrogram is then sliced
between the specified minimum and maximum frequencies of interest.
Examples
--------
>>> trace = np.random.randn(1000)
>>> fs = 1000
>>> fmin = 10
>>> fmax = 100
>>> nperseg = 256
>>> nhop = 128
>>> spectrogram, ff, tt = get_sliced_nspectrogram(trace, fs, fmin, fmax, nperseg, nhop, plotflag=True)
"""
spectrogram = np.abs(librosa.stft(y=trace, n_fft=nperseg, hop_length=nhop))
# Axis
nf, nt = spectrogram.shape
tt = np.linspace(0, len(trace)/fs, num=nt)
ff = np.linspace(0, fs / 2, num=nf)
p = spectrogram / np.max(spectrogram)
# Slice the spectrogram betweem fmin and fmax
ff_idx = np.where((ff >= fmin) & (ff <= fmax))
p = p[ff_idx]
ff = ff[ff_idx]
if plotflag:
roseus = import_roseus()
fig, ax = plt.subplots(figsize=(12,4))
shw = ax.pcolormesh(tt, ff, 20 * np.log10(p / np.max(p)), cmap=roseus, vmin=None, vmax=None)
# Colorbar
bar = fig.colorbar(shw, aspect=20, pad=0.015)
bar.set_label('Normalized magnitude [-]')
plt.xlim(0, len(trace)/fs)
plt.ylim(fmin, fmax)
plt.xlabel('Time (s)')
plt.ylabel('Frequency (Hz)')
plt.tight_layout()
plt.show()
return p, ff, tt
[docs]
def buildkernel(f0, f1, bdwdth, dur, f, t, samp, fmin, fmax, plotflag=False):
"""
Calculate kernel and plot.
Parameters
----------
f0 : float
Starting frequency.
f1 : float
Ending frequency.
bdwdth : float
Frequency width of call.
dur : float
Call length (seconds).
f : np.array
Vector of frequencies returned from plotwav.
t : np.array
Vector of times returned from plotwav.
samp : float
Sample rate.
plotflag : bool, optional
If True, plots kernel. If False, no plot. Default is False.
kernel_lims : tuple, optional
Tuple of minimum kernel range and maximum kernel range. Default is finKernelLims.
Returns
-------
tvec : numpy.array
Vector of kernel times.
fvec_sub : numpy.array
Vector of kernel frequencies.
BlueKernel : 2-d numpy.array
Matrix of kernel values.
Key variables:
-------------
tvec : numpy.array
Kernel times (seconds).
fvec : numpy.array
Kernel frequencies.
BlueKernel : numpy.array
Matrix of kernel values.
"""
# create a time vector of the same length as the call, with the same number of points as the spectrogram
tvec = np.linspace(0, dur, np.size(np.nonzero((t < dur*8) & (t > dur*7))))
# another way: int(dur * fs / (nperseg * (1-overlap_pct)) + 1)
# define frequency span of kernel to match spectrogram
fvec = f
# preallocate kernel array
Kdist = np.zeros((len(fvec), len(tvec)))
ker_min, ker_max = fmin, fmax
for j in range(len(tvec)):
# calculate hat function that is centered on linearly decreasing
# frequency values for each time in tvec
# Linearly decreasing frequency values
# x = fvec - (f0 + (tvec[j] / dur) * (f1 - f0))
# Hyperbolic decreasing frequency values
x = fvec - (f0 * f1 * dur / ((f0 - f1) * (tvec[j]) + f1 * dur))
Kval = (1 - np.square(x) / (bdwdth * bdwdth)) * np.exp(-np.square(x) / (2 * (bdwdth * bdwdth)))
# store hat function values in preallocated array
Kdist[:, j] = Kval
BlueKernel = Kdist * np.hanning(len(tvec))[np.newaxis, :]
# freq_inds = np.where(np.logical_and(fvec >= ker_min, fvec <= ker_max))
# fvec_sub = fvec[freq_inds]
# BlueKernel = BlueKernel_full[freq_inds, :][0]
if plotflag:
plt.figure(figsize=(1, 5))
img = plt.pcolormesh(tvec, fvec, BlueKernel, cmap="RdBu_r", vmin=-np.max(np.abs(BlueKernel)), vmax=np.max(np.abs(BlueKernel)),)
plt.axis([0, dur, np.min(fvec), np.max(fvec)])
plt.colorbar(img, format='%.1f')
plt.clim(-1, 1)
plt.ylim(ker_min, ker_max)
plt.title('Fin whale call kernel')
plt.xlabel('t [s]')
plt.ylabel('f [Hz]')
plt.show()
return tvec, fvec, BlueKernel
[docs]
def buildkernel_from_template(fmin, fmax, dur, fs, nperseg, nhop, plotflag=False):
"""
Build a kernel from a template.
Parameters
----------
fmin : float
The minimum frequency of interest.
fmax : float
The maximum frequency of interest.
dur : float
The duration of the kernel in seconds.
fs : float
The sampling rate of the kernel in Hz.
nperseg : int
The length of each segment for the spectrogram computation.
nhop : int
The number of samples to advance between segments.
plotflag : bool, optional
Whether to plot the kernel, defaults to False.
Returns
-------
numpy.ndarray
The computed kernel.
"""
template = gen_hyperbolic_chirp(fmin, fmax, dur, fs)
template *= np.hanning(len(template))
spectro, ff, tt = get_sliced_nspectrogram(template, fs, fmin, fmax, nperseg, nhop, plotflag=False)
if plotflag:
roseus = import_roseus()
fig, ax = plt.subplots(figsize=(2,4))
shw = ax.pcolormesh(tt, ff, spectro, cmap=roseus, vmin=None, vmax=None)
# Colorbar
bar = fig.colorbar(shw, aspect=20, pad=0.015)
bar.set_label('Normalized magnitude [-]')
plt.xlim(0, dur)
plt.ylim(fmin, fmax)
plt.xlabel('Time (s)')
plt.ylabel('Frequency (Hz)')
plt.tight_layout()
plt.show()
return spectro
[docs]
def nxcorr2d(spectro, kernel):
"""
Calculate the normalized cross-correlation between a spectrogram and a kernel.
Parameters
----------
spectro : numpy.ndarray
The spectrogram array.
kernel : numpy.ndarray
The kernel array.
Returns
-------
numpy.ndarray
The maximum correlation values along the time axis.
Notes
-----
The normalized cross-correlation is calculated using `scipy.signal.correlate2d`.
The correlation values are normalized by dividing by the standard deviation of the spectrogram and the kernel,
multiplied by the number of columns in the spectrogram.
Examples
--------
>>> spectro = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> kernel = np.array([[1, 0], [0, 1]])
>>> nxcorr2d(spectro, kernel)
array([0. , 0.33333333, 0.66666667])
"""
correlation = sp.correlate(spectro, kernel, mode='same', method='fft') / (np.std(spectro) * np.std(kernel) * spectro.shape[1])
maxcorr_t = np.max(correlation, axis=0)
return maxcorr_t
[docs]
def xcorr2d(spectro, kernel):
"""
Calculate the 2D cross-correlation between a spectrogram and a kernel.
Parameters
----------
spectro : numpy.ndarray
The input spectrogram array [frequency x time].
kernel : numpy.ndarray
The kernel array used for cross-correlation [frequency x time].
Returns
-------
numpy.ndarray
The resulting cross-correlation array.
"""
correlation = sp.fftconvolve(spectro, np.flip(kernel, axis=1), mode='same', axes=1)
maxcorr_t = np.sum(correlation, axis=0)
maxcorr_t[maxcorr_t < 0] = 0
maxcorr_t /= (np.median(spectro) * kernel.shape[1])
return maxcorr_t
[docs]
def xcorr(t, f, Sxx, tvec, fvec, BlueKernel):
"""
Cross-correlate kernel with spectrogram
Parameters
----------
t : np.array
Vector of times returned from plotwav
f : np.array
Vector of frequencies returned from plotwav
Sxx : np.array
2-D array of spectrogram amplitudes
tvec : np.array
Vector of times of kernel
fvec : np.array
Vector of frequencies of kernel
BlueKernel : np.array
2-D array of kernel amplitudes
Returns
-------
t_scale : numpy.array
Vector of correlation lags
CorrVal : numpy.array
Vector of correlation values
"""
tvec_size = np.size(tvec)
fvec_size = np.size(fvec)
CorrVal = np.zeros(np.size(t) - (tvec_size-1))
corrchunk= np.zeros((fvec_size, tvec_size))
for ind1 in range(np.size(t) - tvec_size + 1):
ind2 = ind1 + tvec_size
corrchunk = Sxx[:fvec_size, ind1:ind2]
CorrVal[ind1] = np.sum(BlueKernel * corrchunk)
CorrVal /= (np.median(Sxx)*tvec_size)
CorrVal[0] = 0
CorrVal[-1] = 0
CorrVal[CorrVal < 0] = 0
t_scale = t[int(tvec_size / 2)-1:-int(np.ceil(tvec_size / 2))]
return [t_scale, CorrVal]
[docs]
def compute_cross_correlogram_spectrocorr(data, fs, flims, kernel, win_size, overlap_pct):
"""Compute the cross-correlogram via spectrogram correlation.
This function computes the cross-correlogram spectrocorr between the input data and a kernel.
The cross-correlogram spectrocorr is a measure of similarity between the spectrogram of the input data
and the kernel.
Parameters
----------
data : ndarray
Input data array of shape (n, m), where n is the number of samples and m is the number of channels.
fs : float
Sampling frequency of the input data.
flims : tuple
Frequency limits (fmin, fmax) for the spectrogram computation.
kernel : dict
Dictionary containing the kernel parameters (f0, f1, duration, bandwidth).
win_size : float
Window size in seconds for the spectrogram computation.
overlap_pct : float
Percentage of overlap between consecutive windows for the spectrogram computation.
Returns
-------
cross_correlogram : ndarray
Cross-correlogram spectrocorr array of shape (n, p), where n is the number of samples and p is the number of time bins.
"""
norm_data = (data - np.mean(data, axis=1, keepdims=True)) / np.max(np.abs(data), axis=1, keepdims=True)
nperseg = int(win_size * fs)
nhop = int(np.floor(nperseg * (1 - overlap_pct)))
noverlap = nperseg - nhop
print(f'nperseg: {nperseg}, noverlap: {noverlap}, hop_length: {nhop}')
fmin, fmax = flims
# get call kernel attributes
f1 = kernel["f1"]
f0 = kernel["f0"]
duration = kernel["dur"]
bandwidth = kernel["bdwidth"]
# check that hat function is within frequency range of spectrogram
if fmax-f1 < 2 * bandwidth:
fmax = f1 + 3 * bandwidth
if f0-fmin < 2 * bandwidth:
fmin = f0 - 3 * bandwidth
# Compute correlation along axis 1
spectro, ff, tt = get_sliced_nspectrogram(data[0, :], fs, fmin, fmax, nperseg, nhop, plotflag=False)
# TODO: Try weighting the spectrogram with the Cable frequency response (channel, bearing dependant)
cross_correlogram = np.empty((data.shape[0], len(tt)))
_, _, kernel = buildkernel(f0, f1, bandwidth, duration, ff, tt, fs, fmin, fmax, plotflag=False)
# kernel = buildkernel_from_template(fmin, fmax, duration, fs, nperseg, nhop, plotflag=False)
for i in tqdm(range(data.shape[0])):
spectro, _, _ = get_sliced_nspectrogram(data[i, :], fs, fmin, fmax, nperseg, nhop, plotflag=False)
cross_correlogram[i, :] = xcorr2d(spectro, kernel)
return cross_correlogram
[docs]
def resolve_hf_lf_crosstalk(input_peaks: np.ndarray, comp_peaks: np.ndarray,
input_SNR: np.ndarray, comp_SNR: np.ndarray, dt_tol: int, dx_tol: int):
#TODO: maybe parallelize this function to speed up the process
""" Sort peaks that are at the same distance and time but keep the one with higher SNR
to differentiate between HF and LF peaks.
Parameters
----------
input_peaks : np.ndarray
Array of shape (2, n_peaks) containing [distance_indices, time_indices] of the input peaks.
comp_peaks : np.ndarray
Array of shape (2, n_peaks) containing [distance_indices, time_indices] of the comparison peaks.
input_SNR : np.ndarray
Array of shape (n_peaks_input) containing the SNR values for the input peaks
comp_SNR : np.ndarray
Array of shape (n_peaks_comp) containing the SNR values for the comparison peaks
dt_tol : int
Tolerance in time index to consider peaks as matching.
dx_tol : int
Tolerance in distance index to consider peaks as matching.
"""
# Make copies to avoid modifying input arrays
input_peaks = input_peaks.copy()
comp_peaks = comp_peaks.copy()
input_SNR = input_SNR.copy()
comp_SNR = comp_SNR.copy()
# Track which peaks to keep (start with all True)
input_keep = np.ones(input_peaks.shape[1], dtype=bool)
comp_keep = np.ones(comp_peaks.shape[1], dtype=bool)
ix = comp_peaks[0, :]
it = comp_peaks[1, :]
for i, (d, t) in tqdm(enumerate(zip(ix, it)), total=len(ix), desc="Post-filtering hf/lf detections"):
# Skip if this comparison peak is already marked for removal
if not comp_keep[i]:
continue
# Find matching input peaks within tolerance
dist_match = np.abs(input_peaks[0, :] - d) <= dx_tol
time_match = np.abs(input_peaks[1, :] - t) <= dt_tol
mask = dist_match & time_match
# Only consider peaks that are still marked to keep
valid_mask = mask & input_keep
if np.sum(valid_mask) > 0:
# print(f"Found {np.sum(valid_mask)} matching input peaks for comparison peak {i} at distance {d} and time {t}.")
# Get indices of valid matching input peaks
input_match_indices = np.where(valid_mask)[0]
# Compare with the first valid matching input peak
input_idx = input_match_indices[0]
if input_SNR[input_idx] > comp_SNR[i]:
# Mark comparison peak for removal
comp_keep[i] = False
# print(f"Removing comparison peak {i} (SNR: {comp_SNR[i]:.2f}) in favor of input peak {input_idx} (SNR: {input_SNR[input_idx]:.2f})")
else:
# Mark all matching input peaks for removal
input_keep[input_match_indices] = False
# print(f"Removing {len(input_match_indices)} input peaks in favor of comparison peak {i} (SNR: {comp_SNR[i]:.2f})")
# Filter arrays to keep only selected peaks
input_peaks_out = input_peaks[:, input_keep]
input_SNR_out = input_SNR[input_keep]
comp_peaks_out = comp_peaks[:, comp_keep]
comp_SNR_out = comp_SNR[comp_keep]
return input_peaks_out, input_SNR_out, comp_peaks_out, comp_SNR_out