import logging
import os
import json
import traceback
import numpy as np
import pandas as pd
import h5py
from allensdk.core.cell_types_cache import CellTypesCache
import ipfx.lims_queries as lq
import ipfx.stim_features as stf
import ipfx.stimulus_protocol_analysis as spa
import ipfx.data_set_features as dsf
import ipfx.time_series_utils as tsu
import ipfx.error as er
from ipfx.stimulus import StimulusType
from ipfx.sweep import SweepSet
from ipfx.dataset.create import create_ephys_data_set
def dataset_for_specimen_id(specimen_id, data_source, ontology, file_list=None):
    """Load an ephys data set for a specimen from LIMS, the Allen SDK, or the
    file system. Returns an error dictionary instead of raising on failure."""
if data_source == "lims":
nwb_path, h5_path = lims_nwb_information(specimen_id)
        if isinstance(nwb_path, dict) and "error" in nwb_path:
logging.warning("Problem getting NWB file for specimen {:d} from LIMS".format(specimen_id))
return nwb_path
try:
data_set = create_ephys_data_set(
nwb_file=nwb_path, ontology=ontology)
except Exception as detail:
logging.warning("Exception when loading specimen {:d} from LIMS".format(specimen_id))
logging.warning(detail)
return {"error": {"type": "dataset", "details": traceback.format_exc(limit=None)}}
elif data_source == "sdk":
nwb_path, sweep_info = sdk_nwb_information(specimen_id)
try:
data_set = create_ephys_data_set(
nwb_file=nwb_path, sweep_info=sweep_info, ontology=ontology)
except Exception as detail:
logging.warning("Exception when loading specimen {:d} via Allen SDK".format(specimen_id))
logging.warning(detail)
return {"error": {"type": "dataset", "details": traceback.format_exc(limit=None)}}
elif data_source == "filesystem":
nwb_path = file_list[specimen_id]
try:
data_set = create_ephys_data_set(nwb_file=nwb_path)
except Exception as detail:
logging.warning("Exception when loading specimen {:d} via file system".format(specimen_id))
logging.warning(detail)
return {"error": {"type": "dataset", "details": traceback.format_exc(limit=None)}}
    else:
        logging.error("Invalid data source specified ({})".format(data_source))
        return {"error": {"type": "dataset",
                          "details": "invalid data source {}".format(data_source)}}

    return data_set
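# A minimal usage sketch (hypothetical; the path and specimen ID are
# placeholders, not real data). The "filesystem" branch does not use the
# ontology, so None is acceptable there; callers must check for an error
# dictionary because failures are returned rather than raised.
def _example_load_from_filesystem(nwb_path):
    result = dataset_for_specimen_id(
        specimen_id=0, data_source="filesystem", ontology=None,
        file_list={0: nwb_path})
    if isinstance(result, dict) and "error" in result:
        raise RuntimeError(result["error"]["details"])
    return result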
def categorize_iclamp_sweeps(data_set, stimuli_names, sweep_qc_option="none", specimen_id=None):
    """Return sorted current-clamp sweep numbers matching the given stimuli,
    optionally restricted by LIMS sweep-level QC state."""
exist_sql = """
select swp.sweep_number from ephys_sweeps swp
where swp.specimen_id = :1
and swp.sweep_number = any(:2)
"""
passed_sql = """
select swp.sweep_number from ephys_sweeps swp
where swp.specimen_id = :1
and swp.sweep_number = any(:2)
and swp.workflow_state like '%%passed'
"""
passed_except_delta_vm_sql = """
select swp.sweep_number, tag.name
from ephys_sweeps swp
join ephys_sweep_tags_ephys_sweeps estes on estes.ephys_sweep_id = swp.id
join ephys_sweep_tags tag on tag.id = estes.ephys_sweep_tag_id
where swp.specimen_id = :1
and swp.sweep_number = any(:2)
"""
iclamp_st = data_set.filtered_sweep_table(clamp_mode=data_set.CURRENT_CLAMP, stimuli=stimuli_names)
if iclamp_st.shape[0] == 0:
return np.array([])
if sweep_qc_option == "none":
return iclamp_st["sweep_number"].sort_values().values
elif sweep_qc_option == "lims-passed-only":
# check that sweeps exist in LIMS
sweep_num_list = iclamp_st["sweep_number"].sort_values().tolist()
results = lq.query(exist_sql, (specimen_id, sweep_num_list))
res_nums = pd.DataFrame(results, columns=["sweep_number"])["sweep_number"].tolist()
not_checked_list = []
for swp_num in sweep_num_list:
if swp_num not in res_nums:
logging.debug("Could not find sweep {:d} from specimen {:d} in LIMS for QC check".format(swp_num, specimen_id))
not_checked_list.append(swp_num)
# Get passed sweeps
results = lq.query(passed_sql, (specimen_id, sweep_num_list))
results_df = pd.DataFrame(results, columns=["sweep_number"])
passed_sweep_nums = results_df["sweep_number"].values
return np.sort(np.hstack([passed_sweep_nums, np.array(not_checked_list)])) # deciding to keep non-checked sweeps for now
elif sweep_qc_option == "lims-passed-except-delta-vm":
# check that sweeps exist in LIMS
sweep_num_list = iclamp_st["sweep_number"].sort_values().tolist()
results = lq.query(exist_sql, (specimen_id, sweep_num_list))
res_nums = pd.DataFrame(results, columns=["sweep_number"])["sweep_number"].tolist()
not_checked_list = []
for swp_num in sweep_num_list:
if swp_num not in res_nums:
logging.debug("Could not find sweep {:d} from specimen {:d} in LIMS for QC check".format(swp_num, specimen_id))
not_checked_list.append(swp_num)
# get straight-up passed sweeps
results = lq.query(passed_sql, (specimen_id, sweep_num_list))
results_df = pd.DataFrame(results, columns=["sweep_number"])
passed_sweep_nums = results_df["sweep_number"].values
# also get sweeps that only fail due to delta Vm
failed_sweep_list = list(set(sweep_num_list) - set(passed_sweep_nums))
if len(failed_sweep_list) == 0:
return np.sort(passed_sweep_nums)
results = lq.query(passed_except_delta_vm_sql, (specimen_id, failed_sweep_list))
results_df = pd.DataFrame(results, columns=["sweep_number", "name"])
# not all cells have tagged QC status - if there are no tags assume the
# fail call is correct and exclude those sweeps
tagged_mask = np.array([sn in results_df["sweep_number"].tolist() for sn in failed_sweep_list])
# otherwise, check for having an error tag that isn't 'Vm delta'
# and exclude those sweeps
has_non_delta_tags = np.array([np.any((results_df["sweep_number"].values == sn) &
(results_df["name"].values != "Vm delta")) for sn in failed_sweep_list])
also_passing_nums = np.array(failed_sweep_list)[tagged_mask & ~has_non_delta_tags]
return np.sort(np.hstack([passed_sweep_nums, also_passing_nums, np.array(not_checked_list)]))
else:
raise ValueError("Invalid sweep-level QC option {}".format(sweep_qc_option))
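# Hypothetical sketch of selecting QC'd long square sweeps. It assumes the
# ontology exposes a `long_square_names` attribute (as ipfx's StimulusOntology
# does); the "lims-passed-only" option requires a LIMS connection, and the
# specimen ID is only needed for the LIMS-backed QC options.
def _example_select_long_square_sweeps(data_set, ontology, specimen_id):
    return categorize_iclamp_sweeps(
        data_set,
        ontology.long_square_names,
        sweep_qc_option="lims-passed-only",
        specimen_id=specimen_id)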
def validate_sweeps(data_set, sweep_numbers, extra_dur=0.2):
    """Identify sweeps with a valid stimulus that run at least extra_dur seconds
    past the stimulus end. Returns the valid sweeps and the stimulus interval."""
check_sweeps = data_set.sweep_set(sweep_numbers)
valid_sweep_stim = []
start = None
dur = None
for swp in check_sweeps.sweeps:
if len(swp.t) == 0:
valid_sweep_stim.append(False)
continue
swp_start, swp_dur, _, _, _ = stf.get_stim_characteristics(swp.i, swp.t)
if swp_start is None:
valid_sweep_stim.append(False)
else:
start = swp_start
dur = swp_dur
valid_sweep_stim.append(True)
    if start is None:
        # Could not find any sweeps to define the stimulus interval; return an
        # empty SweepSet so callers can still access the .sweeps attribute
        return SweepSet(sweeps=[]), None, None
end = start + dur
    # Check that each sweep has a valid stimulus, is long enough, and was not
    # ended early (a window of all-zero voltages just before the stimulus end
    # indicates a truncated recording). Testing `v` first also guards the
    # time indexing below against empty sweeps.
    good_sweeps = [s for s, v in zip(check_sweeps.sweeps, valid_sweep_stim)
                   if v and s.t[-1] >= end + extra_dur
                   and not np.all(s.v[tsu.find_time_index(s.t, end) - 100:tsu.find_time_index(s.t, end)] == 0)]
return SweepSet(sweeps=good_sweeps), start, end
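# Hedged illustration of the validate_sweeps contract: the returned SweepSet
# may be empty, and start/end are None when no sweep defines a stimulus
# interval, so both cases need checking before feature extraction.
def _example_validated_interval(data_set, sweep_numbers):
    sweeps, start, end = validate_sweeps(data_set, sweep_numbers)
    if start is None or len(sweeps.sweeps) == 0:
        return None
    return end - start  # stimulus duration in seconds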
def preprocess_long_square_sweeps(data_set, sweep_numbers, extra_dur=0.2, subthresh_min_amp=-100.):
    """Validate long square sweeps and run ipfx's long square analysis on them."""
if len(sweep_numbers) == 0:
raise er.FeatureError("No long square sweeps available for feature extraction")
lsq_sweeps, lsq_start, lsq_end = validate_sweeps(data_set, sweep_numbers, extra_dur=extra_dur)
    if len(lsq_sweeps.sweeps) == 0:
        raise er.FeatureError("All long square sweeps were too short or ended early")
lsq_spx, lsq_spfx = dsf.extractors_for_sweeps(
lsq_sweeps,
start=lsq_start,
end=lsq_end,
min_peak=-25,
**dsf.detection_parameters(StimulusType.LONG_SQUARE)
)
lsq_an = spa.LongSquareAnalysis(lsq_spx, lsq_spfx,
subthresh_min_amp=subthresh_min_amp)
lsq_features = lsq_an.analyze(lsq_sweeps)
return lsq_sweeps, lsq_features, lsq_an, lsq_start, lsq_end
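# End-to-end sketch for long squares (hypothetical glue code, assuming the
# ontology provides `long_square_names`): pick QC'd sweeps, then extract
# features. `lsq_features` is the dictionary produced by LongSquareAnalysis.
def _example_long_square_features(data_set, ontology, specimen_id):
    numbers = categorize_iclamp_sweeps(
        data_set, ontology.long_square_names,
        sweep_qc_option="lims-passed-only", specimen_id=specimen_id)
    _, lsq_features, _, _, _ = preprocess_long_square_sweeps(data_set, numbers)
    return lsq_features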
def preprocess_short_square_sweeps(data_set, sweep_numbers, extra_dur=0.2, spike_window=0.05):
    """Validate short square sweeps and run ipfx's short square analysis on them."""
if len(sweep_numbers) == 0:
raise er.FeatureError("No short square sweeps available for feature extraction")
ssq_sweeps, ssq_start, ssq_end = validate_sweeps(data_set, sweep_numbers, extra_dur=extra_dur)
    if len(ssq_sweeps.sweeps) == 0:
        raise er.FeatureError("All short square sweeps were too short or ended early")
    ssq_spx, ssq_spfx = dsf.extractors_for_sweeps(
        ssq_sweeps,
        est_window=[ssq_start, ssq_start + 0.001],
        start=ssq_start,
        end=ssq_end + spike_window,
        reject_at_stim_start_interval=0.0002,
        **dsf.detection_parameters(StimulusType.SHORT_SQUARE))
ssq_an = spa.ShortSquareAnalysis(ssq_spx, ssq_spfx)
ssq_features = ssq_an.analyze(ssq_sweeps)
return ssq_sweeps, ssq_features, ssq_an
def preprocess_ramp_sweeps(data_set, sweep_numbers):
    """Run ipfx's ramp analysis on the given sweeps."""
if len(sweep_numbers) == 0:
raise er.FeatureError("No ramp sweeps available for feature extraction")
ramp_sweeps = data_set.sweep_set(sweep_numbers)
    ramp_start, ramp_dur, _, _, _ = stf.get_stim_characteristics(
        ramp_sweeps.sweeps[0].i, ramp_sweeps.sweeps[0].t)
    ramp_spx, ramp_spfx = dsf.extractors_for_sweeps(
        ramp_sweeps,
        start=ramp_start,
        **dsf.detection_parameters(StimulusType.RAMP))
ramp_an = spa.RampAnalysis(ramp_spx, ramp_spfx)
ramp_features = ramp_an.analyze(ramp_sweeps)
return ramp_sweeps, ramp_features, ramp_an
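# The short square and ramp preprocessors follow the same pattern; a hedged
# sketch combining both (assumes `short_square_names` and `ramp_names` on the
# ontology, matching ipfx's StimulusOntology):
def _example_transient_stim_features(data_set, ontology):
    ssq_numbers = categorize_iclamp_sweeps(data_set, ontology.short_square_names)
    _, ssq_features, _ = preprocess_short_square_sweeps(data_set, ssq_numbers)
    ramp_numbers = categorize_iclamp_sweeps(data_set, ontology.ramp_names)
    _, ramp_features, _ = preprocess_ramp_sweeps(data_set, ramp_numbers)
    return ssq_features, ramp_features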
def filter_results(specimen_ids, results):
    """Split results into IDs and results for successful cells, plus a list of
    error records for the cells that failed."""
    filtered_set = [(i, r) for i, r in zip(specimen_ids, results) if "error" not in r]
    error_set = [{"id": i, "error": d} for i, d in zip(specimen_ids, results) if "error" in d]
if len(filtered_set) == 0:
logging.info("No specimens had results")
return
used_ids, results = zip(*filtered_set)
return used_ids, results, error_set
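# Self-contained demo of the success/error split (toy data, safe to run):
def _example_filter_results():
    ids = [101, 102]
    results = [{"tau": np.array([1.0])},
               {"error": {"type": "dataset", "details": "load failed"}}]
    used_ids, ok_results, errors = filter_results(ids, results)
    assert used_ids == (101,) and errors[0]["id"] == 102
    return ok_results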
def organize_results(specimen_ids, results):
    """Build a dictionary of result arrays, filling in data from cells that are
    missing a feature with appropriate-length NaN arrays."""
result_sizes = {}
output = {}
all_keys = np.unique(np.concatenate([list(r.keys()) for r in results]))
    for k in all_keys:
        if k not in result_sizes:
            for r in results:
                if k in r and r[k] is not None:
                    result_sizes[k] = len(r[k])
                    break
        # Treat missing or None entries as NaN arrays of the feature's length
        data = np.array([r[k] if k in r and r[k] is not None
                         else np.nan * np.zeros(result_sizes[k])
                         for r in results])
        output[k] = data
return output
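# Runnable toy example of the NaN-filling behavior: the second cell has no
# "tau" entry, so it receives a length-matched NaN array in the output.
def _example_organize_results():
    results = [{"tau": np.array([1.0, 2.0]), "rin": np.array([100.0])},
               {"rin": np.array([150.0])}]
    output = organize_results([101, 102], results)
    assert output["tau"].shape == (2, 2)
    assert np.all(np.isnan(output["tau"][1]))
    return output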
def save_results_to_npy(specimen_ids, results_dict, output_dir, output_code):
    """Save each feature array, plus the specimen IDs, as separate .npy files."""
    for k in results_dict:
np.save(os.path.join(output_dir, "fv_{:s}_{:s}.npy".format(k, output_code)), results_dict[k])
np.save(os.path.join(output_dir, "fv_ids_{:s}.npy".format(output_code)), specimen_ids)
def save_results_to_h5(specimen_ids, results_dict, output_dir, output_code):
    """Save all feature arrays and the specimen IDs to a single HDF5 file."""
    ids_arr = np.array(specimen_ids)
    h5_path = os.path.join(output_dir, "fv_{}.h5".format(output_code))
    with h5py.File(h5_path, "w") as h5_file:
        for k in results_dict:
            data = results_dict[k]
            dset = h5_file.create_dataset(k, data.shape, dtype=data.dtype,
                                          compression="gzip")
            dset[...] = data
        dset = h5_file.create_dataset("ids", ids_arr.shape,
                                      dtype=ids_arr.dtype, compression="gzip")
        dset[...] = ids_arr
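# Reading the file back (hedged sketch; the key names match the writer above):
def _example_read_h5(output_dir, output_code):
    with h5py.File(os.path.join(output_dir, "fv_{}.h5".format(output_code)), "r") as f:
        ids = f["ids"][...]
        features = {k: f[k][...] for k in f.keys() if k != "ids"}
    return ids, features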
def save_errors_to_json(error_set, output_dir, output_code):
    """Write the per-cell error records to a JSON file."""
with open(os.path.join(output_dir, "fv_errors_{:s}.json".format(output_code)), "w") as f:
json.dump(error_set, f, indent=4)