import numpy as np
import warnings
import logging
from . import spike_features as spkf
from . import error as er
def basic_spike_train_features(t, spikes_df, start, end, exclude_clipped=False):
    """Compute summary features of a spike train.

    Parameters
    ----------
    t : numpy array of times in seconds
    spikes_df : DataFrame of spike features; must contain a "threshold_index"
        column (and a "clipped" column when `exclude_clipped` is True)
    start : start of the analysis window (None means t[0] for the latency
        and average-rate calculations)
    end : end of the analysis window (None means t[-1] for the average rate)
    exclude_clipped : if True, spikes whose "clipped" flag is set are ignored

    Returns
    -------
    features : dict
        ``{"avg_rate": 0}`` when there are no spikes; otherwise the keys
        "adapt", "latency", "isi_cv", "mean_isi", "median_isi", "first_isi"
        and "avg_rate". ISI-based values are NaN when there are no ISIs.
    """
    if len(spikes_df) == 0 or spikes_df.empty:
        return {"avg_rate": 0}

    thresholds = spikes_df["threshold_index"].values.astype(int)
    if exclude_clipped:
        clipped = spikes_df["clipped"].values.astype(bool)
        thresholds = thresholds[~clipped]

    isis = get_isis(t, thresholds)
    has_isis = len(isis) > 0

    with warnings.catch_warnings():
        # Best-effort suppression of numpy "empty slice" warnings; the explicit
        # length guards below mean the ISI statistics do not depend on it.
        warnings.filterwarnings("ignore", category=RuntimeWarning, module="numpy")
        features = {
            "adapt": adaptation_index(isis),
            "latency": latency(t, thresholds, start),
            "isi_cv": (isis.std() / isis.mean()) if has_isis else np.nan,
            "mean_isi": isis.mean() if has_isis else np.nan,
            # Guarded like the other ISI stats: np.median of an empty array
            # warns and returns NaN, and the filter above may not catch it.
            "median_isi": np.median(isis) if has_isis else np.nan,
            "first_isi": isis[0] if has_isis else np.nan,
            "avg_rate": average_rate(t, thresholds, start, end),
        }
    return features
def pause(t, spikes_df, start, end, cost_weight=1.0):
    """Estimate the number of pauses and the fraction of time spent in pauses.

    Pauses are detected by ``spkf.detect_pauses`` from the interspike
    intervals and their "isi_type" labels.

    Parameters
    ----------
    t : numpy array of times in seconds
    spikes_df : DataFrame with "threshold_index" and "isi_type" columns
    start : start of the interval used to normalize the pause fraction
    end : end of the interval used to normalize the pause fraction
    cost_weight : weight for the pause-detection cost function, passed to
        ``spkf.detect_pauses``

    Returns
    -------
    n_pauses : number of pauses detected (0 when none)
    pause_frac : total duration of the pause ISIs divided by (end - start)
        (0. when no pauses are detected)
    """
    warnings.warn("This function will be removed")
    # Pauses are unusually long ISIs with a "detour reset" among direct resets
    thresholds = spikes_df["threshold_index"].values.astype(int)
    isis = get_isis(t, thresholds)
    # one type per ISI, so the last spike's label is dropped
    isi_types = spikes_df["isi_type"][:-1].values
    pause_list = spkf.detect_pauses(isis, isi_types, cost_weight)
    if len(pause_list) == 0:
        return 0, 0.
    n_pauses = len(pause_list)
    pause_frac = isis[pause_list].sum()
    pause_frac /= end - start
    return n_pauses, pause_frac
def burst(t, spikes_df, tol=0.5, pause_cost=1.0):
    """Detect bursts and summarize the strongest one.

    Parameters
    ----------
    t : numpy array of times in seconds
    spikes_df : DataFrame with "threshold_index", "isi_type", "fast_trough_v",
        "fast_trough_t", "slow_trough_v", "slow_trough_t" and "threshold_v"
        columns
    tol : tolerance passed to ``spkf.detect_bursts``
    pause_cost : pause cost weight passed to ``spkf.detect_bursts``

    Returns
    -------
    max_burstiness_index : largest "burstiness" index among the detected
        bursts (0. when none are found)
    num_bursts : number of bursts detected (0 when none are found)
    """
    warnings.warn("This function will be removed")

    spike_indexes = spikes_df["threshold_index"].values.astype(int)
    intervals = get_isis(t, spike_indexes)
    # one type per interval, so the final spike's label is dropped
    interval_types = spikes_df["isi_type"][:-1].values

    per_spike_columns = ("fast_trough_v", "fast_trough_t",
                         "slow_trough_v", "slow_trough_t", "threshold_v")
    fast_v, fast_t, slow_v, slow_t, threshold_v = (
        spikes_df[name].values for name in per_spike_columns)

    found = np.array(spkf.detect_bursts(intervals, interval_types,
                                        fast_v, fast_t,
                                        slow_v, slow_t,
                                        threshold_v, tol, pause_cost))
    if found.shape[0] == 0:
        return 0., 0
    return found[:, 0].max(), found.shape[0]
def delay(t, v, spikes_df, start, end):
    """Compare the latency of the first spike with the rise time constant before it.

    Parameters
    ----------
    t : numpy array of times in seconds
    v : voltage trace passed to the pre-spike time-constant fit
    spikes_df : DataFrame with a "threshold_t" column
    start : time from which the first-spike latency is measured
    end : accepted for interface compatibility; not used

    Returns
    -------
    delay_ratio : (first spike's threshold time - start) / tau; higher means
        more delay
    tau : dominant time constant of the rise before the spike, as returned by
        ``spkf.fit_prespike_time_constant``

    Both values are 0. when `spikes_df` is empty.
    """
    warnings.warn("This function will be removed")

    if len(spikes_df) == 0:
        logging.info("No spikes available for delay calculation")
        return 0., 0.

    first_spike_t = spikes_df["threshold_t"].values[0]
    tau = spkf.fit_prespike_time_constant(t, v, start, first_spike_t)
    return (first_spike_t - start) / tau, tau
def fit_fi_slope(stim_amps, avg_rates):
    """Return the slope of the least-squares line through (stim_amps, avg_rates).

    Parameters
    ----------
    stim_amps : sequence of stimulus amplitudes
    avg_rates : sequence of average firing rates, one per amplitude

    Raises
    ------
    FeatureError
        If fewer than two stimulus amplitudes are given.
    """
    if len(stim_amps) < 2:
        raise er.FeatureError("Cannot fit f-I curve slope with less than two sweeps")

    # Design matrix columns: amplitude (slope term) and ones (intercept term)
    design = np.column_stack((stim_amps, np.ones_like(stim_amps)))
    slope, _intercept = np.linalg.lstsq(design, avg_rates, rcond=None)[0]
    return slope
def get_isis(t, spikes):
    """Return interspike intervals (in the units of `t`, i.e. seconds) between consecutive spikes.

    `spikes` holds indexes into `t`; an empty array is returned when there
    are fewer than two spikes.
    """
    if len(spikes) < 2:
        return np.array([])
    spike_times = t[spikes]
    return np.diff(spike_times)
def adaptation_index(isis):
    """Return the adaptation index of `isis`: the mean normalized difference
    between consecutive ISIs as computed by `norm_diff` (NaN for no ISIs)."""
    return np.nan if len(isis) == 0 else norm_diff(isis)
def latency(t, spikes, start):
    """Return the time from `start` (t[0] when None) to the first spike.

    `spikes` holds indexes into `t`; NaN is returned when it is empty.
    """
    if not len(spikes):
        return np.nan
    origin = t[0] if start is None else start
    return t[spikes[0]] - origin
def average_rate(t, spikes, start, end):
    """Return the average firing rate between `start` and `end` (inclusive).

    Parameters
    ----------
    t : numpy array of times in seconds
    spikes : numpy array of spike indexes into `t`
    start : start of the time window (None means t[0])
    end : end of the time window (None means t[-1])

    Returns
    -------
    avg_rate : number of spikes in the window divided by the window length,
        in spikes/sec
    """
    window_start = t[0] if start is None else start
    window_end = t[-1] if end is None else end
    n_in_window = sum(1 for spk in spikes if window_start <= t[spk] <= window_end)
    return n_in_window / (window_end - window_start)
def norm_diff(a):
    """Calculate the average of (a[i+1] - a[i]) / (a[i+1] + a[i]).

    Pairs in which both values are zero contribute 0. NaN terms are ignored
    when averaging.

    Parameters
    ----------
    a : sequence or numpy array of numbers

    Returns
    -------
    avg : float
        NaN if `a` has fewer than two elements or every term is NaN;
        0. if every pairwise sum is (close to) zero.
    """
    if len(a) <= 1:
        return np.nan
    a = np.asarray(a, dtype=float)
    later, earlier = a[1:], a[:-1]
    sums = later + earlier
    if np.allclose(sums, 0.):
        return 0.
    # 0/0 terms are overwritten just below, so silence numpy's divide/invalid
    # warnings here instead of leaking them to the caller.
    with np.errstate(divide="ignore", invalid="ignore"):
        norm_diffs = (later - earlier) / sums
    norm_diffs[(later == 0) & (earlier == 0)] = 0.
    if np.isnan(norm_diffs).all():
        # np.nanmean would return NaN here too, but with a RuntimeWarning.
        return np.nan
    return np.nanmean(norm_diffs)
def norm_sq_diff(a):
    """Calculate the average of (a[i] - a[i+1])^2 / (a[i] + a[i+1])^2.

    `a` must be a numpy array (it is converted with ``astype``). NaN is
    returned when it has fewer than two elements.
    """
    if len(a) < 2:
        return np.nan
    values = a.astype(float)
    nxt, prev = values[1:], values[:-1]
    return (np.square(nxt - prev) / np.square(nxt + prev)).mean()
def detect_pauses(isis, isi_types, cost_weight=1.0):
    """Determine which ISIs are "pauses" in ongoing firing.

    Pauses are unusually long ISIs with a "detour reset" among "direct resets".
    Candidates are every "detour" ISI followed by every "direct" ISI longer
    than three times the median direct ISI; they are accepted greedily in
    that order while doing so improves the net benefit.

    Parameters
    ----------
    isis : sequence or numpy array of interspike intervals
    isi_types : sequence or numpy array of interspike interval types
        ('direct' or 'detour')
    cost_weight : weight for cost function for calling an ISI a pause
        Higher cost weights lead to fewer ISIs identified as pauses. The cost function
        also depends on the difference between the duration of the "pause" ISIs and the
        average duration and standard deviation of "non-pause" ISIs.

    Returns
    -------
    pauses : sorted numpy integer array of indices into `isis`
        (empty when no pauses are found)

    Raises
    ------
    FeatureError
        If `isis` and `isi_types` have different lengths.
    """
    if len(isis) != len(isi_types):
        raise er.FeatureError("Wrong number of ISIs")

    # Integer dtype so the result can always be used to index `isis`,
    # including when it is empty.
    no_pauses = np.array([], dtype=int)

    # Arrays are required for the element-wise comparisons below; a plain
    # list compared with "direct" would always be False.
    isis = np.asarray(isis)
    isi_types = np.asarray(isi_types)

    is_direct = isi_types == "direct"
    if not np.any(is_direct):
        # Need some direct-type firing to have pauses
        return no_pauses

    median_direct = np.median(isis[is_direct])
    detour_candidates = [i for i, isi_type in enumerate(isi_types) if isi_type == "detour"]
    direct_candidates = [i for i, isi_type in enumerate(isi_types)
                         if isi_type == "direct" and isis[i] > 3 * median_direct]
    candidates = detour_candidates + direct_candidates
    if not candidates:
        return no_pauses

    pause_list = no_pauses
    all_cv = isis.std() / isis.mean()
    best_net = 0
    for i in candidates:
        temp_pause_list = np.append(pause_list, i)
        non_pause_isis = np.delete(isis, temp_pause_list)
        if len(non_pause_isis) < 2:
            break
        pause_isis = isis[temp_pause_list]

        # Benefit: reduction in ISI variability once the candidates are removed.
        cv = non_pause_isis.std() / non_pause_isis.mean()
        benefit = all_cv - cv
        # Cost: large when pause ISIs are close to the typical non-pause ISI.
        cost = np.sum(non_pause_isis.std() / np.abs(non_pause_isis.mean() - pause_isis))
        cost *= cost_weight
        net = benefit - cost

        if net > 0 and net < best_net:
            # The net benefit is still positive but has started to fall: stop.
            break
        if net > best_net:
            best_net = net
            pause_list = temp_pause_list

    # pause_list only grows when best_net rises above 0, so it is already
    # empty exactly when no candidate paid off.
    return np.sort(pause_list)
def detect_bursts(isis, isi_types, fast_tr_v, fast_tr_t, slow_tr_v, slow_tr_t,
                  thr_v, tol=0.5, pause_cost=1.0):
    """Detect bursts in spike train.

    Parameters
    ----------
    isis : numpy array of n interspike intervals
    isi_types : sequence or numpy array of n interspike interval types
    fast_tr_v : numpy array of fast trough voltages for the n + 1 spikes of the train
    fast_tr_t : numpy array of fast trough times for the n + 1 spikes of the train
    slow_tr_v : numpy array of slow trough voltages for the n + 1 spikes of the train
    slow_tr_t : numpy array of slow trough times for the n + 1 spikes of the train
    thr_v : numpy array of threshold voltages for the n + 1 spikes of the train
    tol : tolerance for the difference in slow trough voltages and thresholds (default 0.5 mV)
        Used to identify "delay" interspike intervals that occur within a burst
    pause_cost : cost weight passed to `detect_pauses`; ISIs it flags as pauses
        cannot be burst boundaries

    Returns
    -------
    bursts : list of bursts (an empty numpy array when none are found)
        Each item in list is a tuple of the form (burst_index, start, end) where `burst_index`
        is a comparison index between the highest instantaneous rate within the burst vs
        the highest instantaneous rate outside the burst. `start` is the index of the first
        ISI of the burst, and `end` is the ISI index immediately following the burst.

    Raises
    ------
    FeatureError
        If `isis` and `isi_types` differ in length, or if burst start and end
        boundaries cannot be paired up.
    """
    if len(isis) != len(isi_types):
        raise er.FeatureError("Wrong number of ISIs")

    if len(isis) < 2: # can't determine burstiness for a single ISI
        return np.array([])

    # Per-spike arrays have n + 1 entries; keep the n entries belonging to the
    # spike that starts each ISI.
    fast_tr_v = fast_tr_v[:-1]
    fast_tr_t = fast_tr_t[:-1]
    slow_tr_v = slow_tr_v[:-1]
    slow_tr_t = slow_tr_t[:-1]

    # Copy so the caller's data is not changed. Object dtype is required: a
    # fixed-width string array (e.g. '<U6' for "direct"/"detour") would
    # silently truncate the "pauselike"/"midburst" labels assigned below, and
    # the "pauselike" check in the boundary scan would never match.
    isi_types = np.array(isi_types, dtype=object)

    # Burst transitions can't be at "pause"-like ISIs
    pauses = detect_pauses(isis, isi_types, cost_weight=pause_cost).astype(int)
    isi_types[pauses] = "pauselike"

    if not (np.any(isi_types == "direct") and np.any(isi_types == "detour")):
        # no candidates that could be bursts
        return np.array([])

    # Want to catch special case of detour in the middle of a large burst where
    # the slow trough value is higher than the previous spike's threshold
    isi_types[(thr_v[:-1] < (slow_tr_v + tol)) & (isi_types == "detour")] = "midburst"

    # Find transitions from direct -> detour and vice versa for burst boundaries.
    # A burst starts at each "direct" ISI that follows a "detour" ISI.
    into_burst = np.array([i + 1 for i, (prev, cur) in
                           enumerate(zip(isi_types[:-1], isi_types[1:])) if
                           cur == "direct" and prev == "detour"],
                          dtype=int)
    if isi_types[0] == "direct":
        into_burst = np.append(np.array([0]), into_burst)

    # Scan from each burst start up to the next one: the first "detour" ISI
    # ends the burst; reaching a "pauselike" ISI first discards the burst.
    drop_into = []
    out_of_burst = []
    for j, (into, next_into) in enumerate(zip(into_burst, np.append(into_burst[1:], len(isis)))):
        for i, isi in enumerate(isi_types[into + 1:next_into]):
            if isi == "detour":
                out_of_burst.append(i + into + 1)
                break
            elif isi == "pauselike":
                drop_into.append(j)
                break

    mask = np.ones_like(into_burst, dtype=bool)
    mask[drop_into] = False
    into_burst = into_burst[mask]

    out_of_burst = np.array(out_of_burst, dtype=int)
    # The final burst may run to the end of the train without a closing detour
    if len(out_of_burst) == len(into_burst) - 1:
        out_of_burst = np.append(out_of_burst, len(isi_types))

    if not (into_burst.size or out_of_burst.size):
        return np.array([])

    if len(into_burst) != len(out_of_burst):
        raise er.FeatureError("Inconsistent burst boundary identification")

    inout_pairs = list(zip(into_burst, out_of_burst))
    delta_t = slow_tr_t - fast_tr_t

    # Repeatedly drop the lowest-scoring burst while that raises the mean score
    scores = _score_burst_set(inout_pairs, isis, delta_t)
    best_score = np.mean(scores)
    worst = np.argmin(scores)
    test_bursts = list(inout_pairs)
    del test_bursts[worst]
    while len(test_bursts) > 0:
        scores = _score_burst_set(test_bursts, isis, delta_t)
        if np.mean(scores) > best_score:
            best_score = np.mean(scores)
            inout_pairs = list(test_bursts)
            worst = np.argmin(scores)
            del test_bursts[worst]
        else:
            break

    if best_score < 0:
        return np.array([])

    # Compare each burst with the ISIs between it and a neighbouring burst
    # (or the rest of the train)
    bursts = []
    for i, (into, outof) in enumerate(inout_pairs):
        if i == len(inout_pairs) - 1: # last burst to evaluate
            if outof <= len(isis) - 1: # are there spikes left after the burst?
                metric = _burstiness_index(isis[into:outof], isis[outof:])
            elif i == 0: # was this the first one (and there weren't spikes after)?
                metric = _burstiness_index(isis[into:outof], isis[:into])
            else:
                prev_burst = inout_pairs[i - 1]
                metric = _burstiness_index(isis[into:outof], isis[prev_burst[1]:into])
        else:
            next_burst = inout_pairs[i + 1]
            metric = _burstiness_index(isis[into:outof], isis[outof:next_burst[0]])

        bursts.append((metric, into, outof))

    return bursts
def _score_burst_set(bursts, isis, delta_t, c_n=0.1, c_tx=0.01):
    """Score each burst in a candidate set of bursts.

    Parameters
    ----------
    bursts : sequence of (start, end) ISI index pairs; ISIs start..end-1
        belong to the burst
    isis : numpy array of interspike intervals
    delta_t : numpy array with one value per ISI (in `detect_bursts` this is
        slow trough time minus fast trough time)
    c_n : penalty per ISI by which a burst extends beyond a single ISI
    c_tx : weight of the transition penalties, each scaled by
        isis / delta_t at the ISI just before or just after the burst

    Returns
    -------
    scores : list with one score per burst; every score is -1e12 when the
        bursts cover all ISIs (nothing left to compare against)
    """
    in_burst = np.zeros_like(isis, dtype=bool)
    for b in bursts:
        in_burst[b[0]:b[1]] = True

    # If all ISIs are part of a burst, give it a bad score
    if in_burst.all():
        return [-1e12] * len(bursts)

    out_burst_isis = isis[~in_burst]  # same comparison set for every burst
    delta_frac = delta_t / isis
    scores = []
    for b in bursts:
        start, end = b[0], b[1]
        score = _burstiness_index(isis[start:end], out_burst_isis) # base score
        if end < len(delta_t):
            # cost for ending a burst (ISI immediately after it)
            score -= c_tx * (1. / (delta_frac[end]))
        if start > 0:
            # cost for starting a burst (ISI immediately before it)
            score -= c_tx * (1. / delta_frac[start - 1])
        score -= c_n * (end - start - 1) # cost for extending a burst
        scores.append(score)

    return scores
def _burstiness_index(in_burst_isis, out_burst_isis):
burst_rate = 1. / in_burst_isis.min()
out_rate = 1. / out_burst_isis.min()
return (burst_rate - out_rate) / (burst_rate + out_rate)