Source code for ipfx.feature_extractor

# Allen Institute Software License - This software license is the 2-clause BSD
# license plus a third clause that prohibits redistribution for commercial
# purposes without further permission.
#
# Copyright 2015-2016. Allen Institute. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Redistributions for commercial purposes are not permitted without the
# Allen Institute's written permission.
# For purposes of this license, commercial purposes is the incorporation of the
# Allen Institute's software into anything for which you will charge fees or
# other compensation. Contact terms@alleninstitute.org for commercial licensing
# opportunities.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
import numpy as np
from pandas import DataFrame

from . import spike_features as spkf
from . import subthresh_features as subf
from . import spike_detector as spkd
from . import spike_train_features as strf
from . import time_series_utils as tsu


class SpikeFeatureExtractor(object):
    """Feature calculation for a sweep (voltage and/or current time series)."""

    AFFECTED_BY_CLIPPING = [
        "trough_t", "trough_v", "trough_i", "trough_index",
        "downstroke", "downstroke_t", "downstroke_v", "downstroke_index",
        "fast_trough_t", "fast_trough_v", "fast_trough_i", "fast_trough_index",
        "adp_t", "adp_v", "adp_i", "adp_index",
        "slow_trough_t", "slow_trough_v", "slow_trough_i", "slow_trough_index",
        "isi_type", "width", "upstroke_downstroke_ratio"
    ]

    def __init__(self, start=None, end=None, filter=10., dv_cutoff=20.,
                 max_interval=0.005, min_height=2., min_peak=-30.,
                 thresh_frac=0.05, reject_at_stim_start_interval=0):
        """Initialize a SpikeFeatureExtractor.

        Parameters
        ----------
        start : start of time window for feature analysis (optional)
        end : end of time window for feature analysis (optional)
        filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz (optional, default 10)
        dv_cutoff : minimum dV/dt to qualify as a spike in V/s (optional, default 20)
        max_interval : maximum acceptable time between start of spike and time of peak in sec (optional, default 0.005)
        min_height : minimum acceptable height from threshold to peak in mV (optional, default 2)
        min_peak : minimum acceptable absolute peak level in mV (optional, default -30)
        thresh_frac : fraction of average upstroke for threshold calculation (optional, default 0.05)
        reject_at_stim_start_interval : duration of window after start to reject potential spikes (optional, default 0)
        """
        self.start = start
        self.end = end
        self.filter = filter
        self.dv_cutoff = dv_cutoff
        self.max_interval = max_interval
        self.min_height = min_height
        self.min_peak = min_peak
        self.thresh_frac = thresh_frac
        self.reject_at_stim_start_interval = reject_at_stim_start_interval

    def process(self, t, v, i):
        """Detect spikes in the sweep and compute their features.

        Parameters
        ----------
        t : ndarray of times (seconds)
        v : ndarray of voltages (mV)
        i : ndarray of currents (pA), or None

        Returns
        -------
        spikes_df : DataFrame with one row per detected spike
        """
        dvdt = tsu.calculate_dvdt(v, t, self.filter)

        # Basic features of spikes
        putative_spikes = spkd.detect_putative_spikes(v, t, self.start, self.end,
                                                      dv_cutoff=self.dv_cutoff, dvdt=dvdt)
        peaks = spkd.find_peak_indexes(v, t, putative_spikes, self.end)
        putative_spikes, peaks = spkd.filter_putative_spikes(v, t, putative_spikes, peaks,
                                                             self.min_height, self.min_peak,
                                                             dvdt=dvdt)

        if not putative_spikes.size:
            # Save time if no spikes detected
            return DataFrame()

        upstrokes = spkd.find_upstroke_indexes(v, t, putative_spikes, peaks, dvdt=dvdt)
        thresholds = spkd.refine_threshold_indexes(v, t, upstrokes, self.thresh_frac,
                                                   dvdt=dvdt)

        thresholds, peaks, upstrokes, clipped = spkd.check_thresholds_and_peaks(
            v, t, thresholds, peaks, upstrokes, self.start, self.end, self.max_interval,
            dvdt=dvdt, reject_at_stim_start_interval=self.reject_at_stim_start_interval)

        if not thresholds.size:
            # Save time if no spikes detected
            return DataFrame()

        # Spike list and thresholds have been refined - now find other features
        upstrokes = spkd.find_upstroke_indexes(v, t, thresholds, peaks, self.filter, dvdt)
        troughs = spkd.find_trough_indexes(v, t, thresholds, peaks, clipped, self.end)
        downstrokes = spkd.find_downstroke_indexes(v, t, peaks, troughs, clipped, dvdt=dvdt)
        trough_details, clipped = spkf.analyze_trough_details(v, t, thresholds, peaks,
                                                              clipped, self.end, dvdt=dvdt)
        widths = spkf.find_widths(v, t, thresholds, peaks, trough_details[1], clipped)

        # Points where we care about t, v, and i if available
        vit_data_indexes = {
            "threshold": thresholds,
            "peak": peaks,
            "trough": troughs,
        }

        # Points where we care about t and dv/dt
        dvdt_data_indexes = {
            "upstroke": upstrokes,
            "downstroke": downstrokes,
        }

        # Trough details
        isi_types = trough_details[0]
        trough_detail_indexes = dict(zip(["fast_trough", "adp", "slow_trough"],
                                         trough_details[1:]))

        # Redundant, but ensures that DataFrame has right number of rows
        # Any better way to do it?
        spikes_df = DataFrame(data=thresholds, columns=["threshold_index"])
        spikes_df["clipped"] = clipped

        for k, all_vals in vit_data_indexes.items():
            valid_ind = ~np.isnan(all_vals)
            vals = all_vals[valid_ind].astype(int)
            spikes_df[k + "_index"] = np.nan
            spikes_df[k + "_t"] = np.nan
            spikes_df[k + "_v"] = np.nan

            if len(vals) > 0:
                spikes_df.loc[valid_ind, k + "_index"] = vals
                spikes_df.loc[valid_ind, k + "_t"] = t[vals]
                spikes_df.loc[valid_ind, k + "_v"] = v[vals]

            if i is not None:
                spikes_df[k + "_i"] = np.nan
                if len(vals) > 0:
                    spikes_df.loc[valid_ind, k + "_i"] = i[vals]

        for k, all_vals in dvdt_data_indexes.items():
            valid_ind = ~np.isnan(all_vals)
            vals = all_vals[valid_ind].astype(int)
            spikes_df[k + "_index"] = np.nan
            spikes_df[k] = np.nan
            if len(vals) > 0:
                spikes_df.loc[valid_ind, k + "_index"] = vals
                spikes_df.loc[valid_ind, k + "_t"] = t[vals]
                spikes_df.loc[valid_ind, k + "_v"] = v[vals]
                spikes_df.loc[valid_ind, k] = dvdt[vals]

        spikes_df["isi_type"] = isi_types

        for k, all_vals in trough_detail_indexes.items():
            valid_ind = ~np.isnan(all_vals)
            vals = all_vals[valid_ind].astype(int)
            spikes_df[k + "_index"] = np.nan
            spikes_df[k + "_t"] = np.nan
            spikes_df[k + "_v"] = np.nan
            if len(vals) > 0:
                spikes_df.loc[valid_ind, k + "_index"] = vals
                spikes_df.loc[valid_ind, k + "_t"] = t[vals]
                spikes_df.loc[valid_ind, k + "_v"] = v[vals]

            if i is not None:
                spikes_df[k + "_i"] = np.nan
                if len(vals) > 0:
                    spikes_df.loc[valid_ind, k + "_i"] = i[vals]

        spikes_df["width"] = widths
        spikes_df["upstroke_downstroke_ratio"] = spikes_df["upstroke"] / -spikes_df["downstroke"]

        return spikes_df
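
    # Note: the DataFrame returned by process() has one row per detected spike.
    # Columns include "threshold"/"peak"/"trough" with "_index"/"_t"/"_v" (and
    # "_i" when a current trace is supplied) suffixes, "upstroke" and
    # "downstroke" values with their "_index"/"_t"/"_v" positions, trough
    # details ("fast_trough", "adp", "slow_trough"), plus "clipped",
    # "isi_type", "width", and "upstroke_downstroke_ratio".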

    def spikes(self, spikes_df):
        """Get all features for each spike as a list of records."""
        return spikes_df.to_dict(orient='records')

    def is_spike_feature_affected_by_clipping(self, key):
        """Check whether a feature is unreliable for clipped spikes."""
        return key in self.AFFECTED_BY_CLIPPING

    def spike_feature_keys(self, spikes_df):
        """Get list of every available spike feature."""
        return spikes_df.columns.values.tolist()

    def spike_feature(self, spikes_df, key, include_clipped=False,
                      force_exclude_clipped=False):
        """Get specified feature for every spike.

        Parameters
        ----------
        key : feature name
        include_clipped : return values for every identified spike, even when
            clipping means they will be incorrect/undefined
        force_exclude_clipped : drop clipped spikes even for features that are
            not affected by clipping

        Returns
        -------
        spike_feature_values : ndarray of features for each spike
        """
        if len(spikes_df) == 0:
            return np.array([])

        if key not in spikes_df.columns:
            raise KeyError("requested feature '{:s}' not available".format(key))

        values = spikes_df[key].values

        if include_clipped and force_exclude_clipped:
            raise ValueError("include_clipped and force_exclude_clipped cannot both be true")

        if not include_clipped and self.is_spike_feature_affected_by_clipping(key):
            values = values[~spikes_df["clipped"].values]
        elif force_exclude_clipped:
            values = values[~spikes_df["clipped"].values]

        return values
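

# Usage sketch (illustrative; here `t`, `v`, and `i` are assumed to be NumPy
# arrays holding one sweep's times in seconds, voltages in mV, and currents
# in pA):
#
#     ext = SpikeFeatureExtractor(start=0.1, end=1.1)
#     spikes_df = ext.process(t, v, i)
#     peak_voltages = ext.spike_feature(spikes_df, "peak_v")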


class SpikeTrainFeatureExtractor(object):
    """Feature calculation for the spike train of a sweep."""

    def __init__(self, start, end,
                 burst_tol=0.5, pause_cost=1.0,
                 deflect_type=None,
                 stim_amp_fn=None,
                 baseline_interval=0.1, filter_frequency=1.0,
                 sag_baseline_interval=0.03, peak_width=0.005):
        self.start = start
        self.end = end
        self.burst_tol = burst_tol
        self.pause_cost = pause_cost
        self.deflect_type = deflect_type
        self.stim_amp_fn = stim_amp_fn
        self.baseline_interval = baseline_interval
        self.filter_frequency = filter_frequency
        self.sag_baseline_interval = sag_baseline_interval
        self.peak_width = peak_width

    def process(self, t, v, i, spikes_df, extra_features=None, exclude_clipped=False):
        features = strf.basic_spike_train_features(t, spikes_df, self.start, self.end,
                                                   exclude_clipped=exclude_clipped)

        if self.start is None:
            self.start = 0.0

        if extra_features is None:
            extra_features = []

        if 'peak_deflect' in extra_features:
            features['peak_deflect'] = subf.voltage_deflection(
                t, v, i, self.start, self.end, self.deflect_type)

        if 'stim_amp' in extra_features:
            features['stim_amp'] = self.stim_amp_fn(t, i, self.start) if self.stim_amp_fn else None

        if 'v_baseline' in extra_features:
            features['v_baseline'] = subf.baseline_voltage(
                t, v, self.start, self.baseline_interval, self.filter_frequency)

        if 'sag' in extra_features:
            features['sag'] = subf.sag(
                t, v, i, self.start, self.end, self.peak_width, self.sag_baseline_interval)

        if features["avg_rate"] > 0:
            if 'pause' in extra_features:
                features['pause'] = strf.pause(t, spikes_df, self.start, self.end,
                                               self.pause_cost)
            if 'burst' in extra_features:
                features['burst'] = strf.burst(t, spikes_df, self.burst_tol, self.pause_cost)
            if 'delay' in extra_features:
                features['delay'] = strf.delay(t, v, spikes_df, self.start, self.end)

        return features
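

if __name__ == "__main__":
    # Minimal end-to-end sketch on a synthetic sweep (illustrative only: the
    # triangular "spike" below is artificial, and real sweeps are normally
    # loaded from NWB files elsewhere in ipfx).
    fs = 200e3  # sampling rate in Hz
    t = np.arange(0.0, 1.0, 1.0 / fs)
    v = np.full_like(t, -70.0)
    i = np.zeros_like(t)

    # One artificial spike at 0.3 s: a 1 ms rise to +30 mV (~100 V/s, well
    # above the default 20 V/s dv_cutoff) followed by a 2 ms decay.
    onset = int(0.3 * fs)
    rise = int(0.001 * fs)
    decay = int(0.002 * fs)
    v[onset:onset + rise] = np.linspace(-70.0, 30.0, rise)
    v[onset + rise:onset + rise + decay] = np.linspace(30.0, -70.0, decay)

    sfx = SpikeFeatureExtractor(start=0.1, end=0.9)
    spikes_df = sfx.process(t, v, i)
    print("spikes detected:", len(spikes_df))

    stfx = SpikeTrainFeatureExtractor(start=0.1, end=0.9)
    print(stfx.process(t, v, i, spikes_df))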