diff --git a/PySONIC/neurons/stn.py b/PySONIC/neurons/stn.py index d509627..0c85b07 100644 --- a/PySONIC/neurons/stn.py +++ b/PySONIC/neurons/stn.py @@ -1,581 +1,612 @@ import numpy as np +from scipy.optimize import brentq from ..core import PointNeuron from ..constants import FARADAY, Z_Ca from ..utils import nernst class OtsukaSTN(PointNeuron): ''' Class defining the Otsuka model of sub-thalamic nucleus neuron with 5 different current types: - Inward Sodium current (iNa) - Outward, delayed-rectifer Potassium current (iKd) - Inward, A-type Potassium current (iA) - Inward, low-threshold Calcium current (iCaT) - Inward, high-threshold Calcium current (iCaL) - Outward, Calcium-dependent Potassium current (iKCa) - Non-specific leakage current (iLeak) References: *Otsuka, T., Abe, T., Tsukagawa, T., and Song, W.-J. (2004). Conductance-Based Model of the Voltage-Dependent Generation of a Plateau Potential in Subthalamic Neurons. Journal of Neurophysiology 92, 255–264.* *Tarnaud, T., Joseph, W., Martens, L., and Tanghe, E. (2018). Computational Modeling of Ultrasonic Subthalamic Nucleus Stimulation. IEEE Trans Biomed Eng.* ''' name = 'STN' # Resting parameters Cm0 = 1e-2 # Cell membrane resting capacitance (F/m2) Vm0 = -58.0 # Resting membrane potential (mV) CCa_in0 = 5e-9 # M (5 nM) # Reversal potentials VNa = 60.0 # Sodium Nernst potential (mV) VK = -90.0 # Potassium Nernst potential (mV) # Physical constants T = 306.15 # K (33°C) # Calcium dynamics CCa_out = 2e-3 # M (2 mM) KCa = 2e3 # s-1 # Leakage current GLeak = 3.5 # Conductance of non-specific leakage current (S/m^2) VLeak = -60.0 # Leakage reversal potential (mV) # Fast Na current GNaMax = 490.0 # Max. conductance of Sodium current (S/m^2) thetax_m = -40 # mV thetax_h = -45.5 # mV kx_m = -8 # mV kx_h = 6.4 # mV tau0_m = 0.2 * 1e-3 # s tau1_m = 3 * 1e-3 # s tau0_h = 0 * 1e-3 # s tau1_h = 24.5 * 1e-3 # s thetaT_m = -53 # mV thetaT1_h = -50 # mV thetaT2_h = -50 # mV sigmaT_m = -0.7 # mV sigmaT1_h = -15 # mV sigmaT2_h = 16 # mV # Delayed rectifier K+ current GKMax = 570.0 # Max. conductance of delayed-rectifier Potassium current (S/m^2) thetax_n = -41 # mV kx_n = -14 # mV tau0_n = 0 * 1e-3 # s tau1_n = 11 * 1e-3 # s thetaT1_n = -40 # mV thetaT2_n = -40 # mV sigmaT1_n = -40 # mV sigmaT2_n = 50 # mV # T-type Ca2+ current GTMax = 50.0 # Max. conductance of low-threshold Calcium current (S/m^2) thetax_p = -56 # mV thetax_q = -85 # mV kx_p = -6.7 # mV kx_q = 5.8 # mV tau0_p = 5 * 1e-3 # s tau1_p = 0.33 * 1e-3 # s tau0_q = 0 * 1e-3 # s tau1_q = 400 * 1e-3 # s thetaT1_p = -27 # mV thetaT2_p = -102 # mV thetaT1_q = -50 # mV thetaT2_q = -50 # mV sigmaT1_p = -10 # mV sigmaT2_p = 15 # mV sigmaT1_q = -15 # mV sigmaT2_q = 16 # mV # L-type Ca2+ current GLMax = 150.0 # Max. conductance of high-threshold Calcium current (S/m^2) thetax_c = -30.6 # mV thetax_d1 = -60 # mV thetax_d2 = 0.1 * 1e-6 # M kx_c = -5 # mV kx_d1 = 7.5 # mV kx_d2 = 0.02 * 1e-6 # M tau0_c = 45 * 1e-3 # s tau1_c = 10 * 1e-3 # s tau0_d1 = 400 * 1e-3 # s tau1_d1 = 500 * 1e-3 # s tau_d2 = 130 * 1e-3 # s thetaT1_c = -27 # mV thetaT2_c = -50 # mV thetaT1_d1 = -40 # mV thetaT2_d1 = -20 # mV sigmaT1_c = -20 # mV sigmaT2_c = 15 # mV sigmaT1_d1 = -15 # mV sigmaT2_d1 = 20 # mV # A-type K+ current GAMax = 50.0 # Max. conductance of A-type Potassium current (S/m^2) thetax_a = -45 # mV thetax_b = -90 # mV kx_a = -14.7 # mV kx_b = 7.5 # mV tau0_a = 1 * 1e-3 # s tau1_a = 1 * 1e-3 # s tau0_b = 0 * 1e-3 # s tau1_b = 200 * 1e-3 # s thetaT_a = -40 # mV thetaT1_b = -60 # mV thetaT2_b = -40 # mV sigmaT_a = -0.5 # mV sigmaT1_b = -30 # mV sigmaT2_b = 10 # mV # Ca2+-activated K+ current GKCaMax = 10.0 # Max. conductance of Calcium-dependent Potassium current (S/m^2) thetax_r = 0.17 * 1e-6 # M kx_r = -0.08 * 1e-6 # M tau_r = 2 * 1e-3 # s # Default plotting scheme pltvars_scheme = { 'i_{Na}\ kin.': ['m', 'h'], 'i_{Kd}\ kin.': ['n'], 'i_A\ kin.': ['a', 'b'], 'i_{CaT}\ kin.': ['p', 'q'], 'i_{CaL}\ kin.': ['c', 'd1', 'd2'], 'Ca^{2+}_i': ['C_Ca'], 'i_{KCa}\ kin.': ['r'], 'I': ['iLeak', 'iNa', 'iKd', 'iA', 'iCaT2', 'iCaL2', 'iKCa', 'iNet'] } def __init__(self): ''' Constructor of the class ''' # Names and initial states of the channels state probabilities self.states_names = ['a', 'b', 'c', 'd1', 'd2', 'm', 'h', 'n', 'p', 'q', 'r', 'C_Ca'] # Names of the different coefficients to be averaged in a lookup table. self.coeff_names = [ 'alphaa', 'betaa', 'alphab', 'betab', 'alphac', 'betac', 'alphad1', 'betad1', 'alpham', 'betam', 'alphah', 'betah', 'alphan', 'betan', 'alphap', 'betap', 'alphaq', 'betaq', ] # Compute Calcium reversal potential for Cai = 5 nM self.VCa = nernst(Z_Ca, self.CCa_in0, self.CCa_out, self.T) # mV # Compute deff for that reversal potential iCaT = self.iCaT( self.pinf(self.Vm0), self.qinf(self.Vm0), self.Vm0) # mA/m2 iCaL = self.iCaL( self.cinf(self.Vm0), self.d1inf(self.Vm0), self.d2inf(self.CCa_in0), self.Vm0) # mA/m2 self.deff = -(iCaT + iCaL) / (Z_Ca * FARADAY * self.KCa * self.CCa_in0) * 1e-6 # m # Compute conversion factor from electrical current (mA/m2) to Calcium concentration (M) self.i2CCa = 1e-6 / (Z_Ca * self.deff * FARADAY) # Initial states self.states0 = self.steadyStates(self.Vm0) # Charge interval bounds for lookup creation self.Qbounds = np.array([np.round(self.Vm0 - 25.0), 50.0]) * self.Cm0 * 1e-3 # C/m2 def _xinf(self, var, theta, k): ''' Generic function computing the steady-state activation/inactivation of a particular ion channel at a given voltage or ion concentration. :param var: membrane potential (mV) or ion concentration (mM) :param theta: half-(in)activation voltage or concentration (mV or mM) :param k: slope parameter of (in)activation function (mV or mM) :return: steady-state (in)activation (-) ''' return 1 / (1 + np.exp((var - theta) / k)) def ainf(self, Vm): return self._xinf(Vm, self.thetax_a, self.kx_a) def binf(self, Vm): return self._xinf(Vm, self.thetax_b, self.kx_b) def cinf(self, Vm): return self._xinf(Vm, self.thetax_c, self.kx_c) def d1inf(self, Vm): return self._xinf(Vm, self.thetax_d1, self.kx_d1) def d2inf(self, Cai): return self._xinf(Cai, self.thetax_d2, self.kx_d2) def minf(self, Vm): return self._xinf(Vm, self.thetax_m, self.kx_m) def hinf(self, Vm): return self._xinf(Vm, self.thetax_h, self.kx_h) def ninf(self, Vm): return self._xinf(Vm, self.thetax_n, self.kx_n) def pinf(self, Vm): return self._xinf(Vm, self.thetax_p, self.kx_p) def qinf(self, Vm): return self._xinf(Vm, self.thetax_q, self.kx_q) def rinf(self, Cai): return self._xinf(Cai, self.thetax_r, self.kx_r) def _taux1(self, Vm, theta, sigma, tau0, tau1): ''' Generic function computing the voltage-dependent, activation/inactivation time constant of a particular ion channel at a given voltage (first variant). :param Vm: membrane potential (mV) :param theta: voltage at which (in)activation time constant is half-maximal (mV) :param sigma: slope parameter of (in)activation time constant function (mV) :param tau0: minimal time constant (s) :param tau1: modulated time constant (s) :return: (in)activation time constant (s) ''' return tau0 + tau1 / (1 + np.exp(-(Vm - theta) / sigma)) def taua(self, Vm): return self._taux1(Vm, self.thetaT_a, self.sigmaT_a, self.tau0_a, self.tau1_a) def taum(self, Vm): return self._taux1(Vm, self.thetaT_m, self.sigmaT_m, self.tau0_m, self.tau1_m) def _taux2(self, Vm, theta1, theta2, sigma1, sigma2, tau0, tau1): ''' Generic function computing the voltage-dependent, activation/inactivation time constant of a particular ion channel at a given voltage (second variant). :param Vm: membrane potential (mV) :param theta: voltage at which (in)activation time constant is half-maximal (mV) :param sigma: slope parameter of (in)activation time constant function (mV) :param tau0: minimal time constant (s) :param tau1: modulated time constant (s) :return: (in)activation time constant (s) ''' return tau0 + tau1 / (np.exp(-(Vm - theta1) / sigma1) + np.exp(-(Vm - theta2) / sigma2)) def taub(self, Vm): return self._taux2(Vm, self.thetaT1_b, self.thetaT2_b, self.sigmaT1_b, self.sigmaT2_b, self.tau0_b, self.tau1_b) def tauc(self, Vm): return self._taux2(Vm, self.thetaT1_c, self.thetaT2_c, self.sigmaT1_c, self.sigmaT2_c, self.tau0_c, self.tau1_c) def taud1(self, Vm): return self._taux2(Vm, self.thetaT1_d1, self.thetaT2_d1, self.sigmaT1_d1, self.sigmaT2_d1, self.tau0_d1, self.tau1_d1) def tauh(self, Vm): return self._taux2(Vm, self.thetaT1_h, self.thetaT2_h, self.sigmaT1_h, self.sigmaT2_h, self.tau0_h, self.tau1_h) def taun(self, Vm): return self._taux2(Vm, self.thetaT1_n, self.thetaT2_n, self.sigmaT1_n, self.sigmaT2_n, self.tau0_n, self.tau1_n) def taup(self, Vm): return self._taux2(Vm, self.thetaT1_p, self.thetaT2_p, self.sigmaT1_p, self.sigmaT2_p, self.tau0_p, self.tau1_p) def tauq(self, Vm): return self._taux2(Vm, self.thetaT1_q, self.thetaT2_q, self.sigmaT1_q, self.sigmaT2_q, self.tau0_q, self.tau1_q) def derA(self, Vm, a): return (self.ainf(Vm) - a) / self.taua(Vm) def derB(self, Vm, b): return (self.binf(Vm) - b) / self.taub(Vm) def derC(self, Vm, c): return (self.cinf(Vm) - c) / self.tauc(Vm) def derD1(self, Vm, d1): return (self.d1inf(Vm) - d1) / self.taud1(Vm) def derD2(self, Cai, d2): return (self.d2inf(Cai) - d2) / self.tau_d2 def derM(self, Vm, m): return (self.minf(Vm) - m) / self.taum(Vm) def derH(self, Vm, h): return (self.hinf(Vm) - h) / self.tauh(Vm) def derN(self, Vm, n): return (self.ninf(Vm) - n) / self.taun(Vm) def derP(self, Vm, p): return (self.pinf(Vm) - p) / self.taup(Vm) def derQ(self, Vm, q): return (self.qinf(Vm) - q) / self.tauq(Vm) def derR(self, Cai, r): return (self.rinf(Cai) - r) / self.tau_r def derC_Ca(self, C_Ca, iCaT, iCaL): ''' Compute the evolution of the Calcium concentration in submembranal space. :param Vm: membrane potential (mV) :param C_Ca: Calcium concentration in submembranal space (M) :param iCaT: inward, low-threshold Calcium current (mA/m2) :param iCaL: inward, high-threshold Calcium current (mA/m2) :return: derivative of Calcium concentration in submembranal space w.r.t. time (M/s) ''' return - self.i2CCa * (iCaT + iCaL) - C_Ca * self.KCa + def get_dCCa(self, C_Ca, Vm): + ''' Return the time derivative of intracellular Calcium concentration given + its current value at a specific membrane potential. + + :param C_Ca: Calcium concentration in submembranal space (M) + :param Vm: membrane potential (mV) + :return: time derivative of Calcium concentration in submembranal space (M/s) + ''' + self.VCa = nernst(Z_Ca, C_Ca, self.CCa_out, self.T) # mV + iCaT = self.iCaT(self.pinf(Vm), self.qinf(Vm), Vm) + iCaL = self.iCaL(self.cinf(Vm), self.d1inf(Vm), self.d2inf(C_Ca), Vm) + return self.derC_Ca(C_Ca, iCaT, iCaL) + + + def findCaeq(self, Vm): + ''' Find the equilibrium intracellular Calcium concentration for a + specific membrane potential. + + :param Vm: membrane potential (mV) + :return: equilibrium Calcium concentration in submembranal space (M) + ''' + Ca_eq = brentq( + lambda x: self.get_dCCa(x, Vm), + self.CCa_in0 * 1e-4, self.CCa_in0 * 1e3, + xtol=1e-16 + ) + return Ca_eq + + def iNa(self, m, h, Vm): ''' Compute the inward Sodium current per unit area. :param m: open-probability of m-gate :param h: inactivation-probability of h-gate :param Vm: membrane potential (mV) :return: current per unit area (mA/m2) ''' return self.GNaMax * m**3 * h * (Vm - self.VNa) def iKd(self, n, Vm): ''' Compute the outward delayed-rectifier Potassium current per unit area. :param n: open-probability of n-gate :param Vm: membrane potential (mV) :return: current per unit area (mA/m2) ''' return self.GKMax * n**4 * (Vm - self.VK) def iA(self, a, b, Vm): ''' Compute the outward A-type Potassium current per unit area. :param a: open-probability of a-gate :param b: open-probability of b-gate :param Vm: membrane potential (mV) :return: current per unit area (mA/m2) ''' return self.GAMax * a**2 * b * (Vm - self.VK) def iCaT(self, p, q, Vm): ''' Compute the inward low-threshold Calcium current per unit area. :param p: open-probability of p-gate :param q: open-probability of q-gate :param Vm: membrane potential (mV) :return: current per unit area (mA/m2) ''' return self.GTMax * p**2 * q * (Vm - self.VCa) def iCaL(self, c, d1, d2, Vm): ''' Compute the inward high-threshold Calcium current per unit area. :param c: open-probability of c-gate :param d1: open-probability of d1-gate :param d2: open-probability of d2-gate :param Vm: membrane potential (mV) :return: current per unit area (mA/m2) ''' return self.GLMax * c**2 * d1 * d2 * (Vm - self.VCa) def iKCa(self, r, Vm): ''' Compute the outward, Calcium activated Potassium current per unit area. :param r: open-probability of r-gate :param Vm: membrane potential (mV) :return: current per unit area (mA/m2) ''' return self.GKCaMax * r**2 * (Vm - self.VK) def iLeak(self, Vm): ''' Compute the non-specific leakage current per unit area. :param Vm: membrane potential (mV) :return: current per unit area (mA/m2) ''' return self.GLeak * (Vm - self.VLeak) def iNet(self, Vm, states): ''' Compute net membrane current per unit area. ''' a, b, c, d1, d2, m, h, n, p, q, r, CCa_in = states # update VCa based on intracellular Calcium concentration self.VCa = nernst(Z_Ca, CCa_in, self.CCa_out, self.T) # mV return ( self.iNa(m, h, Vm) + self.iKd(n, Vm) + self.iA(a, b, Vm) + self.iCaT(p, q, Vm) + self.iCaL(c, d1, d2, Vm) + self.iKCa(r, Vm) + self.iLeak(Vm) ) # mA/m2 def currents(self, Vm, states): ''' Compute all membrane currents per unit area. ''' a, b, c, d1, d2, m, h, n, p, q, r, CCa_in = states # update VCa based on intracellular Calcium concentration self.VCa = nernst(Z_Ca, CCa_in, self.CCa_out, self.T) # mV return { 'iNa': self.iNa(m, h, Vm), 'iKd': self.iKd(n, Vm), 'iA': self.iA(a, b, Vm), 'iCaT': self.iCaT(p, q, Vm), 'iCaL': self.iCaL(c, d1, d2, Vm), 'iKCa': self.iKCa(r, Vm), 'iLeak': self.iLeak(Vm) } # mA/m2 def steadyStates(self, Vm): ''' Concrete implementation of the abstract API method. ''' # Solve the equation dx/dt = 0 at Vm for each x-state aeq = self.ainf(Vm) beq = self.binf(Vm) ceq = self.cinf(Vm) d1eq = self.d1inf(Vm) meq = self.minf(Vm) heq = self.hinf(Vm) neq = self.ninf(Vm) peq = self.pinf(Vm) qeq = self.qinf(Vm) - d2eq = self.d2inf(self.CCa_in0) - req = self.rinf(self.CCa_in0) + C_Ca_eq = self.findCaeq(Vm) + d2eq = self.d2inf(C_Ca_eq) + req = self.rinf(C_Ca_eq) - return np.array([aeq, beq, ceq, d1eq, d2eq, meq, heq, neq, peq, qeq, req, self.CCa_in0]) + return np.array([aeq, beq, ceq, d1eq, d2eq, meq, heq, neq, peq, qeq, req, C_Ca_eq]) def derStates(self, Vm, states): ''' Concrete implementation of the abstract API method. ''' a, b, c, d1, d2, m, h, n, p, q, r, CCa_in = states dadt = self.derA(Vm, a) dbdt = self.derB(Vm, b) dcdt = self.derC(Vm, c) dd1dt = self.derD1(Vm, d1) dd2dt = self.derD2(CCa_in, d2) dmdt = self.derM(Vm, m) dhdt = self.derH(Vm, h) dndt = self.derN(Vm, n) dpdt = self.derP(Vm, p) dqdt = self.derQ(Vm, q) drdt = self.derR(CCa_in, r) iCaT = self.iCaT(p, q, Vm) iCaL = self.iCaL(c, d1, d2, Vm) dCCaindt = self.derC_Ca(CCa_in, iCaT, iCaL) return [dadt, dbdt, dcdt, dd1dt, dd2dt, dmdt, dhdt, dndt, dpdt, dqdt, drdt, dCCaindt] def getEffRates(self, Vm): ''' Concrete implementation of the abstract API method. ''' # Compute average cycle value for rate constants Ta = self.taua(Vm) alphaa_avg = np.mean(self.ainf(Vm) / Ta) betaa_avg = np.mean(1 / Ta) - alphaa_avg Tb = self.taub(Vm) alphab_avg = np.mean(self.binf(Vm) / Tb) betab_avg = np.mean(1 / Tb) - alphab_avg Tc = self.tauc(Vm) alphac_avg = np.mean(self.cinf(Vm) / Tc) betac_avg = np.mean(1 / Tc) - alphac_avg Td1 = self.taud1(Vm) alphad1_avg = np.mean(self.ainf(Vm) / Td1) betad1_avg = np.mean(1 / Td1) - alphad1_avg Tm = self.taum(Vm) alpham_avg = np.mean(self.minf(Vm) / Tm) betam_avg = np.mean(1 / Tm) - alpham_avg Th = self.tauh(Vm) alphah_avg = np.mean(self.hinf(Vm) / Th) betah_avg = np.mean(1 / Th) - alphah_avg Tn = self.taun(Vm) alphan_avg = np.mean(self.ninf(Vm) / Tn) betan_avg = np.mean(1 / Tn) - alphan_avg Tp = self.taup(Vm) alphap_avg = np.mean(self.pinf(Vm) / Tp) betap_avg = np.mean(1 / Tp) - alphap_avg Tq = self.tauq(Vm) alphaq_avg = np.mean(self.qinf(Vm) / Tq) betaq_avg = np.mean(1 / Tq) - alphaq_avg # Return array of coefficients return np.array([ alphaa_avg, betaa_avg, alphab_avg, betab_avg, alphac_avg, betac_avg, alphad1_avg, betad1_avg, alpham_avg, betam_avg, alphah_avg, betah_avg, alphan_avg, betan_avg, alphap_avg, betap_avg, alphaq_avg, betaq_avg ]) def derStatesEff(self, Qm, states, interp_data): ''' Concrete implementation of the abstract API method. ''' rates = np.array([np.interp(Qm, interp_data['Q'], interp_data[rn]) for rn in self.coeff_names]) Vmeff = np.interp(Qm, interp_data['Q'], interp_data['V']) a, b, c, d1, d2, m, h, n, p, q, r, CCa_in = states dadt = rates[0] * (1 - a) - rates[1] * a dbdt = rates[2] * (1 - b) - rates[3] * b dcdt = rates[4] * (1 - c) - rates[5] * c dd1dt = rates[6] * (1 - d1) - rates[7] * d1 dd2dt = self.derD2(CCa_in, d2) dmdt = rates[8] * (1 - m) - rates[9] * m dhdt = rates[10] * (1 - h) - rates[11] * h dndt = rates[12] * (1 - n) - rates[13] * n dpdt = rates[14] * (1 - p) - rates[15] * p dqdt = rates[16] * (1 - q) - rates[17] * q drdt = self.derR(CCa_in, r) iCaT = self.iCaT(p, q, Vmeff) iCaL = self.iCaL(c, d1, d2, Vmeff) dCCaindt = self.derC_Ca(CCa_in, iCaT, iCaL) return [dadt, dbdt, dcdt, dd1dt, dd2dt, dmdt, dhdt, dndt, dpdt, dqdt, drdt, dCCaindt] diff --git a/PySONIC/postpro.py b/PySONIC/postpro.py index bffef1b..aa41e77 100644 --- a/PySONIC/postpro.py +++ b/PySONIC/postpro.py @@ -1,444 +1,486 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Date: 2017-08-22 14:33:04 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2018-10-25 22:47:00 +# @Last Modified time: 2019-03-11 23:49:29 ''' Utility functions to detect spikes on signals and compute spiking metrics. ''' import pickle import numpy as np import pandas as pd from .constants import * from .utils import logger +def detectCrossings(x, thr=0.0, edge='both'): + ''' + Detect crossings of a threshold value in a 1D signal. + + :param x: 1D array_like data. + :param edge: 'rising', 'falling', or 'both' + :return: 1D array with the indices preceding the crossings + ''' + ine, ire, ife = np.array([[], [], []], dtype=int) + x_padright = np.hstack((x, x[-1])) + x_padleft = np.hstack((x[0], x)) + if not edge: + ine = np.where((x_padright < thr) & (x_padleft > thr))[0] + else: + if edge.lower() in ['rising', 'both']: + ire = np.where((x_padright <= thr) & (x_padleft > thr))[0] + if edge.lower() in ['falling', 'both']: + ife = np.where((x_padright < thr) & (x_padleft >= thr))[0] + ind = np.unique(np.hstack((ine, ire, ife))) - 1 + return ind + + +def getStableFixedPoints(x, dx): + ''' Find stable fixed points in a 1D plane phase profile. + + :param x: variable (1D array) + :param dx: derivative (1D array) + :return: array of stable fixed points values (or None if none is found). + ''' + + stable_fps = [] + i_stable_fp = detectCrossings(dx, edge='falling') + if i_stable_fp.size > 0: + for i in i_stable_fp: + a = (dx[i + 1] - dx[i]) / (x[i + 1] - x[i]) + b = dx[i] - a * x[i] + stable_fps.append(-b / a) + return np.array(stable_fps) + else: + return None + + def detectPeaks(x, mph=None, mpd=1, threshold=0, edge='rising', kpsh=False, valley=False, ax=None): ''' Detect peaks in data based on their amplitude and other features. Adapted from Marco Duarte: http://nbviewer.jupyter.org/github/demotu/BMC/blob/master/notebooks/DetectPeaks.ipynb :param x: 1D array_like data. :param mph: minimum peak height (default = None). :param mpd: minimum peak distance in indexes (default = 1) :param threshold : minimum peak prominence (default = 0) :param edge : for a flat peak, keep only the rising edge ('rising'), only the falling edge ('falling'), both edges ('both'), or don't detect a flat peak (None). (default = 'rising') :param kpsh: keep peaks with same height even if they are closer than `mpd` (default = False). :param valley: detect valleys (local minima) instead of peaks (default = False). :param show: plot data in matplotlib figure (default = False). :param ax: a matplotlib.axes.Axes instance, optional (default = None). :return: 1D array with the indices of the peaks ''' print('min peak height:', mph, ', min peak distance:', mpd, ', min peak prominence:', threshold) # Convert input to numpy array x = np.atleast_1d(x).astype('float64') # Revert signal sign for valley detection if valley: x = -x # Differentiate signal dx = np.diff(x) # Find indices of all peaks with edge criterion ine, ire, ife = np.array([[], [], []], dtype=int) if not edge: ine = np.where((np.hstack((dx, 0)) < 0) & (np.hstack((0, dx)) > 0))[0] else: if edge.lower() in ['rising', 'both']: ire = np.where((np.hstack((dx, 0)) <= 0) & (np.hstack((0, dx)) > 0))[0] if edge.lower() in ['falling', 'both']: ife = np.where((np.hstack((dx, 0)) < 0) & (np.hstack((0, dx)) >= 0))[0] ind = np.unique(np.hstack((ine, ire, ife))) # Remove first and last values of x if they are detected as peaks if ind.size and ind[0] == 0: ind = ind[1:] if ind.size and ind[-1] == x.size - 1: ind = ind[:-1] print('{} raw peaks'.format(ind.size)) # Remove peaks < minimum peak height if ind.size and mph is not None: ind = ind[x[ind] >= mph] print('{} height-filtered peaks'.format(ind.size)) # Remove peaks - neighbors < threshold if ind.size and threshold > 0: dx = np.min(np.vstack([x[ind] - x[ind - 1], x[ind] - x[ind + 1]]), axis=0) ind = np.delete(ind, np.where(dx < threshold)[0]) print('{} prominence-filtered peaks'.format(ind.size)) # Detect small peaks closer than minimum peak distance if ind.size and mpd > 1: ind = ind[np.argsort(x[ind])][::-1] # sort ind by peak height idel = np.zeros(ind.size, dtype=bool) for i in range(ind.size): if not idel[i]: # keep peaks with the same height if kpsh is True idel = idel | (ind >= ind[i] - mpd) & (ind <= ind[i] + mpd) \ & (x[ind[i]] > x[ind] if kpsh else True) idel[i] = 0 # Keep current peak # remove the small peaks and sort back the indices by their occurrence ind = np.sort(ind[~idel]) print('{} distance-filtered peaks'.format(ind.size)) return ind def detectPeaksTime(t, y, mph, mtd, mpp=0): ''' Extension of the detectPeaks function to detect peaks in data based on their amplitude and time difference, with a non-uniform time vector. :param t: time vector (not necessarily uniform) :param y: signal :param mph: minimal peak height :param mtd: minimal time difference :mpp: minmal peak prominence :return: array of peak indexes ''' # Determine whether time vector is uniform (threshold in time step variation) dt = np.diff(t) if (dt.max() - dt.min()) / dt.min() < 1e-2: isuniform = True else: isuniform = False if isuniform: print('uniform time vector') dt = t[1] - t[0] mpd = int(np.ceil(mtd / dt)) ipeaks = detectPeaks(y, mph, mpd=mpd, threshold=mpp) else: print('non-uniform time vector') # Detect peaks on signal with no restriction on inter-peak distance irawpeaks = detectPeaks(y, mph, mpd=1, threshold=mpp) npeaks = irawpeaks.size if npeaks > 0: # Filter relevant peaks with temporal distance ipeaks = [irawpeaks[0]] for i in range(1, npeaks): i1 = ipeaks[-1] i2 = irawpeaks[i] if t[i2] - t[i1] < mtd: if y[i2] > y[i1]: ipeaks[-1] = i2 else: ipeaks.append(i2) else: ipeaks = [] ipeaks = np.array(ipeaks) return ipeaks def detectSpikes(t, Qm, min_amp, min_dt): ''' Detect spikes on a charge density signal, and return their number, latency and rate. :param t: time vector (s) :param Qm: charge density vector (C/m2) :param min_amp: minimal charge amplitude to detect spikes (C/m2) :param min_dt: minimal time interval between 2 spikes (s) :return: 3-tuple with number of spikes, latency (s) and spike rate (sp/s) ''' i_spikes = detectPeaksTime(t, Qm, min_amp, min_dt) if len(i_spikes) > 0: latency = t[i_spikes[0]] # s n_spikes = i_spikes.size if n_spikes > 1: first_to_last_spike = t[i_spikes[-1]] - t[i_spikes[0]] # s spike_rate = (n_spikes - 1) / first_to_last_spike # spikes/s else: spike_rate = 'N/A' else: latency = 'N/A' spike_rate = 'N/A' n_spikes = 0 return (n_spikes, latency, spike_rate) def findPeaks(y, mph=None, mpd=None, mpp=None): ''' Detect peaks in a signal based on their height, prominence and/or separating distance. :param y: signal vector :param mph: minimum peak height (in signal units, default = None). :param mpd: minimum inter-peak distance (in indexes, default = None) :param mpp: minimum peak prominence (in signal units, default = None) :return: 4-tuple of arrays with the indexes of peaks occurence, peaks prominence, peaks width at half-prominence and peaks half-prominence bounds (left and right) Adapted from: - Marco Duarte's detect_peaks function (http://nbviewer.jupyter.org/github/demotu/BMC/blob/master/notebooks/DetectPeaks.ipynb) - MATLAB findpeaks function (https://ch.mathworks.com/help/signal/ref/findpeaks.html) ''' # Define empty output empty = (np.array([]),) * 4 # Differentiate signal dy = np.diff(y) # Find all peaks and valleys # s = np.sign(dy) # ipeaks = np.where(np.diff(s) < 0.0)[0] + 1 # ivalleys = np.where(np.diff(s) > 0.0)[0] + 1 ipeaks = np.where((np.hstack((dy, 0)) <= 0) & (np.hstack((0, dy)) > 0))[0] ivalleys = np.where((np.hstack((dy, 0)) >= 0) & (np.hstack((0, dy)) < 0))[0] # Return empty output if no peak detected if ipeaks.size == 0: return empty logger.debug('%u peaks found, starting at index %u and ending at index %u', ipeaks.size, ipeaks[0], ipeaks[-1]) if ivalleys.size > 0: logger.debug('%u valleys found, starting at index %u and ending at index %u', ivalleys.size, ivalleys[0], ivalleys[-1]) else: logger.debug('no valleys found') # Ensure each peak is bounded by two valleys, adding signal boundaries as valleys if necessary if ivalleys.size == 0 or ipeaks[0] < ivalleys[0]: ivalleys = np.insert(ivalleys, 0, -1) if ipeaks[-1] > ivalleys[-1]: ivalleys = np.insert(ivalleys, ivalleys.size, y.size - 1) if ivalleys.size - ipeaks.size != 1: logger.debug('Cleaning up incongruities') i = 0 while i < min(ipeaks.size, ivalleys.size) - 1: if ipeaks[i] < ivalleys[i]: # 2 peaks between consecutive valleys -> remove lowest idel = i - 1 if y[ipeaks[i - 1]] < y[ipeaks[i]] else i logger.debug('Removing abnormal peak at index %u', ipeaks[idel]) ipeaks = np.delete(ipeaks, idel) if ipeaks[i] > ivalleys[i + 1]: idel = i + 1 if y[ivalleys[i]] < y[ivalleys[i + 1]] else i logger.debug('Removing abnormal valley at index %u', ivalleys[idel]) ivalleys = np.delete(ivalleys, idel) else: i += 1 logger.debug('Post-cleanup: %u peaks and %u valleys', ipeaks.size, ivalleys.size) # Remove peaks < minimum peak height if mph is not None: ipeaks = ipeaks[y[ipeaks] >= mph] if ipeaks.size == 0: return empty # Detect small peaks closer than minimum peak distance if mpd is not None: ipeaks = ipeaks[np.argsort(y[ipeaks])][::-1] # sort ipeaks by descending peak height idel = np.zeros(ipeaks.size, dtype=bool) # initialize boolean deletion array (all false) for i in range(ipeaks.size): # for each peak if not idel[i]: # if not marked for deletion closepeaks = (ipeaks >= ipeaks[i] - mpd) & (ipeaks <= ipeaks[i] + mpd) # close peaks idel = idel | closepeaks # mark for deletion along with previously marked peaks # idel = idel | (ipeaks >= ipeaks[i] - mpd) & (ipeaks <= ipeaks[i] + mpd) idel[i] = 0 # keep current peak # remove the small peaks and sort back the indices by their occurrence ipeaks = np.sort(ipeaks[~idel]) # Detect smallest valleys between consecutive relevant peaks ibottomvalleys = [] if ipeaks[0] > ivalleys[0]: itrappedvalleys = ivalleys[ivalleys < ipeaks[0]] ibottomvalleys.append(itrappedvalleys[np.argmin(y[itrappedvalleys])]) for i, j in zip(ipeaks[:-1], ipeaks[1:]): itrappedvalleys = ivalleys[np.logical_and(ivalleys > i, ivalleys < j)] ibottomvalleys.append(itrappedvalleys[np.argmin(y[itrappedvalleys])]) if ipeaks[-1] < ivalleys[-1]: itrappedvalleys = ivalleys[ivalleys > ipeaks[-1]] ibottomvalleys.append(itrappedvalleys[np.argmin(y[itrappedvalleys])]) ipeaks = ipeaks ivalleys = np.array(ibottomvalleys, dtype=int) # Ensure each peak is bounded by two valleys, adding signal boundaries as valleys if necessary if ipeaks[0] < ivalleys[0]: ivalleys = np.insert(ivalleys, 0, 0) if ipeaks[-1] > ivalleys[-1]: ivalleys = np.insert(ivalleys, ivalleys.size, y.size - 1) # Remove peaks < minimum peak prominence if mpp is not None: # Compute peaks prominences as difference between peaks and their closest valley prominences = y[ipeaks] - np.amax((y[ivalleys[:-1]], y[ivalleys[1:]]), axis=0) # initialize peaks and valleys deletion tables idelp = np.zeros(ipeaks.size, dtype=bool) idelv = np.zeros(ivalleys.size, dtype=bool) # for each peak (sorted by ascending prominence order) for ind in np.argsort(prominences): ipeak = ipeaks[ind] # get peak index # get peak bases as first valleys on either side not marked for deletion indleftbase = ind indrightbase = ind + 1 while idelv[indleftbase]: indleftbase -= 1 while idelv[indrightbase]: indrightbase += 1 ileftbase = ivalleys[indleftbase] irightbase = ivalleys[indrightbase] # Compute peak prominence and mark for deletion if < mpp indmaxbase = indleftbase if y[ileftbase] > y[irightbase] else indrightbase if y[ipeak] - y[ivalleys[indmaxbase]] < mpp: idelp[ind] = True # mark peak for deletion idelv[indmaxbase] = True # mark highest surrouding valley for deletion # remove irrelevant peaks and valleys, and sort back the indices by their occurrence ipeaks = np.sort(ipeaks[~idelp]) ivalleys = np.sort(ivalleys[~idelv]) if ipeaks.size == 0: return empty # Compute peaks prominences and reference half-prominence levels prominences = y[ipeaks] - np.amax((y[ivalleys[:-1]], y[ivalleys[1:]]), axis=0) refheights = y[ipeaks] - prominences / 2 # Compute half-prominence bounds halfmaxbounds = np.empty((ipeaks.size, 2)) for i in range(ipeaks.size): # compute the index of the left-intercept at half max ileft = ipeaks[i] while ileft >= ivalleys[i] and y[ileft] > refheights[i]: ileft -= 1 if ileft < ivalleys[i]: # intercept exactly on valley halfmaxbounds[i, 0] = ivalleys[i] else: # interpolate intercept linearly between signal boundary points a = (y[ileft + 1] - y[ileft]) / 1 b = y[ileft] - a * ileft halfmaxbounds[i, 0] = (refheights[i] - b) / a # compute the index of the right-intercept at half max iright = ipeaks[i] while iright <= ivalleys[i + 1] and y[iright] > refheights[i]: iright += 1 if iright > ivalleys[i + 1]: # intercept exactly on valley halfmaxbounds[i, 1] = ivalleys[i + 1] else: # interpolate intercept linearly between signal boundary points if iright == y.size - 1: # special case: if end of signal is reached, decrement iright iright -= 1 a = (y[iright + 1] - y[iright]) / 1 b = y[iright] - a * iright halfmaxbounds[i, 1] = (refheights[i] - b) / a # Compute peaks widths at half-prominence widths = np.diff(halfmaxbounds, axis=1) # Convert halfmaxbounds to true integers halfmaxbounds[:, 0] = np.floor(halfmaxbounds[:, 0]) halfmaxbounds[:, 1] = np.ceil(halfmaxbounds[:, 1]) halfmaxbounds = halfmaxbounds.astype(int) bounds = np.array([ivalleys[:-1], ivalleys[1:]]).T - 1 bounds[bounds < 0] = 0 return (ipeaks - 1, prominences, widths, halfmaxbounds, bounds) def computeSpikingMetrics(filenames): ''' Analyze the charge density profile from a list of files and compute for each one of them the following spiking metrics: - latency (ms) - firing rate mean and standard deviation (Hz) - spike amplitude mean and standard deviation (nC/cm2) - spike width mean and standard deviation (ms) :param filenames: list of files to analyze :return: a dataframe with the computed metrics ''' # Initialize metrics dictionaries keys = [ 'latencies (ms)', 'mean firing rates (Hz)', 'std firing rates (Hz)', 'mean spike amplitudes (nC/cm2)', 'std spike amplitudes (nC/cm2)', 'mean spike widths (ms)', 'std spike widths (ms)' ] metrics = {k: [] for k in keys} # Compute spiking metrics for fname in filenames: # Load data from file logger.debug('loading data from file "{}"'.format(fname)) with open(fname, 'rb') as fh: frame = pickle.load(fh) df = frame['data'] meta = frame['meta'] tstim = meta['tstim'] t = df['t'].values Qm = df['Qm'].values dt = t[1] - t[0] # Detect spikes on charge profile mpd = int(np.ceil(SPIKE_MIN_DT / dt)) ispikes, prominences, widths, *_ = findPeaks(Qm, SPIKE_MIN_QAMP, mpd, SPIKE_MIN_QPROM) widths *= dt if ispikes.size > 0: # Compute latency latency = t[ispikes[0]] # Select prior-offset spikes ispikes_prior = ispikes[t[ispikes] < tstim] else: latency = np.nan ispikes_prior = np.array([]) # Compute spikes widths and amplitude if ispikes_prior.size > 0: widths_prior = widths[:ispikes_prior.size] prominences_prior = prominences[:ispikes_prior.size] else: widths_prior = np.array([np.nan]) prominences_prior = np.array([np.nan]) # Compute inter-spike intervals and firing rates if ispikes_prior.size > 1: ISIs_prior = np.diff(t[ispikes_prior]) FRs_prior = 1 / ISIs_prior else: ISIs_prior = np.array([np.nan]) FRs_prior = np.array([np.nan]) # Log spiking metrics logger.debug('%u spikes detected (%u prior to offset)', ispikes.size, ispikes_prior.size) logger.debug('latency: %.2f ms', latency * 1e3) logger.debug('average spike width within stimulus: %.2f +/- %.2f ms', np.nanmean(widths_prior) * 1e3, np.nanstd(widths_prior) * 1e3) logger.debug('average spike amplitude within stimulus: %.2f +/- %.2f nC/cm2', np.nanmean(prominences_prior) * 1e5, np.nanstd(prominences_prior) * 1e5) logger.debug('average ISI within stimulus: %.2f +/- %.2f ms', np.nanmean(ISIs_prior) * 1e3, np.nanstd(ISIs_prior) * 1e3) logger.debug('average FR within stimulus: %.2f +/- %.2f Hz', np.nanmean(FRs_prior), np.nanstd(FRs_prior)) # Complete metrics dictionaries metrics['latencies (ms)'].append(latency * 1e3) metrics['mean firing rates (Hz)'].append(np.mean(FRs_prior)) metrics['std firing rates (Hz)'].append(np.std(FRs_prior)) metrics['mean spike amplitudes (nC/cm2)'].append(np.mean(prominences_prior) * 1e5) metrics['std spike amplitudes (nC/cm2)'].append(np.std(prominences_prior) * 1e5) metrics['mean spike widths (ms)'].append(np.mean(widths_prior) * 1e3) metrics['std spike widths (ms)'].append(np.std(widths_prior) * 1e3) # Return dataframe with metrics return pd.DataFrame(metrics, columns=metrics.keys()) diff --git a/scripts/STN_regime_transition.py b/scripts/STN_regime_transition.py index f1eaa7d..36d56b2 100644 --- a/scripts/STN_regime_transition.py +++ b/scripts/STN_regime_transition.py @@ -1,260 +1,295 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Date: 2018-09-28 16:13:34 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2019-03-11 21:48:12 +# @Last Modified time: 2019-03-12 01:24:26 ''' Script to study STN transitions between different behavioral regimesl. ''' import os import pickle import numpy as np import matplotlib.pyplot as plt import matplotlib from argparse import ArgumentParser import logging from PySONIC.core import NeuronalBilayerSonophore from PySONIC.utils import * +from PySONIC.postpro import getStableFixedPoints from PySONIC.neurons import getNeuronsDict # Plot parameters matplotlib.rcParams['pdf.fonttype'] = 42 matplotlib.rcParams['ps.fonttype'] = 42 matplotlib.rcParams['font.family'] = 'arial' # Set logging level logger.setLevel(logging.INFO) +def getStableQmQSS(neuron, a, Fdrive, amps): + + # Compute net current profile for each amplitude, from QSS states and Vmeff profiles + nbls = NeuronalBilayerSonophore(a, neuron, Fdrive) + _, Qref, Vmeff, QS_states = nbls.getQSSvars(Fdrive, amps=amps) + iNet = neuron.iNet(Vmeff, QS_states) + + # Find stable fixed points in iNet(Qref) profile + Qstab = [] + for i, Adrive in enumerate(amps): + Qstab.append(getStableFixedPoints(Qref, -iNet[i, :])) + + return Qstab + + +def getChargeStabilizationFromSims(inputdir, neuron, a, Fdrive, amps, tstim, PRF=100, DC=1.0): + + # Get filenames + fnames = ['{}.pkl'.format(ASTIM_filecode(neuron.name, a, Fdrive, A, tstim, PRF, DC, 'sonic')) + for A in amps] + + # Initialize output arrays + tstab = np.empty(amps.size) + Qstab = np.empty(amps.size) + + # For each file + for i, fn in enumerate(fnames): + + # Extract charge temporal profile during stimulus + fp = os.path.join(inputdir, 'STN', fn) + # logger.info('loading data from file "{}"'.format(fn)) + with open(fp, 'rb') as fh: + frame = pickle.load(fh) + df = frame['data'] + t = df['t'].values + Qm = df['Qm'].values + Qm = Qm[t < tstim] + t = t[t < tstim] + dt = np.diff(t) + + # If charge signal is stable during last 100 ms of stimulus + if np.ptp(Qm[-int(100e-3 // dt[0]):]) < 5e-5: + + # Compute instant of stabilization by iNet thresholding + iNet_abs = np.abs(np.diff(Qm)) / dt + Qstab[i] = Qm[-1] + tstab[i] = t[np.where(iNet_abs > 1e-3)[0][-1] + 2] + logger.info('A = %.2f kPa: Qm stabilization around %.2f nC/cm2 from t = %.0f ms onward', + amps[i] * 1e-3, Qstab[i] * 1e5, tstab[i] * 1e3) + + # Otherwise, populate arrays with NaN + else: + Qstab[i] = np.nan + tstab[i] = np.nan + logger.info('A = %.2f kPa: no Qm stabilization', amps[i] * 1e-3) + + return Qstab, tstab + + def plotQSSvars_vs_Qm(neuron, a, Fdrive, Adrive, fs=12): # Get quasi-steady states and effective membrane potential profiles at this amplitude nbls = NeuronalBilayerSonophore(a, neuron, Fdrive) _, Qref, Vmeff, QS_states = nbls.getQSSvars(Fdrive, amps=Adrive) # Compute QSS currents currents = neuron.currents(Vmeff, QS_states) iNet = sum(currents.values()) # Create figure fig, axes = plt.subplots(3, 1, figsize=(7, 9)) axes[-1].set_xlabel('Charge Density (nC/cm2)', fontsize=fs) for ax in axes: for skey in ['top', 'right']: ax.spines[skey].set_visible(False) for item in ax.get_xticklabels() + ax.get_yticklabels(): item.set_fontsize(fs) for item in ax.get_xticklabels(minor=True): item.set_visible(False) figname = '{} neuron QSS dynamics @ {:.2f}kPa'.format(neuron.name, Adrive * 1e-3) fig.suptitle(figname, fontsize=fs) # Subplot 1: Vmeff ax = axes[0] ax.set_ylabel('$V_m^*$ (mV)', fontsize=fs) ax.plot(Qref * 1e5, Vmeff, color='C0') ax.axhline(neuron.Vm0, linewidth=0.5, color='k') # Subplot 2: quasi-steady states ax = axes[1] ax.set_ylabel('$X_\infty$', fontsize=fs) ax.set_yticks([0, 0.5, 1]) ax.set_ylim([-0.05, 1.05]) - for label, qsstate in zip(neuron.states_names[:-1], QS_states[:-1]): + for label, qsstate in zip(neuron.states_names, QS_states): ax.plot(Qref * 1e5, qsstate, label=label) # Subplot 3: currents ax = axes[2] ax.set_ylabel('QSS currents (A/m2)', fontsize=fs) for k, I in currents.items(): ax.plot(Qref * 1e5, I * 1e-3, label=k) ax.plot(Qref * 1e5, iNet * 1e-3, color='k', label='iNet') ax.axhline(0, color='k', linewidth=0.5) fig.tight_layout() fig.subplots_adjust(right=0.8) for ax in axes[1:]: ax.legend(loc='center right', fontsize=fs, frameon=False, bbox_to_anchor=(1.3, 0.5)) + fig.canvas.set_window_title( + '{}_QSS_states_vs_Qm_{:.2f}kPa'.format(neuron.name, Adrive * 1e-3)) + return fig def plotInetQSS_vs_Qm(neuron, a, Fdrive, amps, fs=12, cmap='viridis', zscale='lin'): # Compute net current profile for each amplitude, from QSS states and Vmeff profiles nbls = NeuronalBilayerSonophore(a, neuron, Fdrive) _, Qref, Vmeff, QS_states = nbls.getQSSvars(Fdrive, amps=amps) iNet = neuron.iNet(Vmeff, QS_states) # Define color code mymap = plt.get_cmap(cmap) zref = amps * 1e-3 if zscale == 'lin': norm = matplotlib.colors.Normalize(zref.min(), zref.max()) elif zscale == 'log': norm = matplotlib.colors.LogNorm(zref.min(), zref.max()) sm = matplotlib.cm.ScalarMappable(norm=norm, cmap=mymap) sm._A = [] # Create figure fig, ax = plt.subplots(figsize=(6, 4)) ax.set_xlabel('$\\rm Q_m\ (nC/cm^2)$', fontsize=fs) ax.set_ylabel('$\\rm I_{net, QSS}\ (A/m^2)$', fontsize=fs) for skey in ['top', 'right']: ax.spines[skey].set_visible(False) for item in ax.get_xticklabels() + ax.get_yticklabels(): item.set_fontsize(fs) figname = '{} neuron - QSS current imbalance vs. amplitude'.format(neuron.name) ax.set_title(figname, fontsize=fs) ax.axhline(0, color='k', linewidth=0.5) # Plot iNet profiles for each US amplitude (with specific color code) for i, Adrive in enumerate(amps): lbl = '{:.2f} kPa'.format(Adrive * 1e-3) c = sm.to_rgba(Adrive * 1e-3) ax.plot(Qref * 1e5, iNet[i] * 1e-3, label=lbl, c=c) + for i, Adrive in enumerate(amps): + Qstab = getStableFixedPoints(Qref, -iNet[i, :]) + if Qstab is not None: + ax.plot(Qstab * 1e5, np.zeros(Qstab.size), '.', c='k') fig.tight_layout() # Plot US amplitude colorbar fig.subplots_adjust(bottom=0.15, top=0.9, right=0.80, hspace=0.5) cbarax = fig.add_axes([0.85, 0.15, 0.03, 0.75]) fig.colorbar(sm, cax=cbarax) cbarax.set_ylabel('Amplitude (kPa)', fontsize=fs) for item in cbarax.get_yticklabels(): item.set_fontsize(fs) fig.canvas.set_window_title( '{}_iNet_QSS_vs_amp'.format(neuron.name)) return fig -def getChargeStabilizationFromSims(inputdir, neuron, a, Fdrive, amps, tstim, PRF=100, DC=1.0): - - # Get filenames - fnames = ['{}.pkl'.format(ASTIM_filecode(neuron.name, a, Fdrive, A, tstim, PRF, DC, 'sonic')) - for A in amps] - - # Initialize output arrays - tstab = np.empty(amps.size) - Qstab = np.empty(amps.size) - - # For each file - for i, fn in enumerate(fnames): - - # Extract charge temporal profile during stimulus - fp = os.path.join(inputdir, 'STN', fn) - logger.info('loading data from file "{}"'.format(fn)) - with open(fp, 'rb') as fh: - frame = pickle.load(fh) - df = frame['data'] - t = df['t'].values - Qm = df['Qm'].values - Qm = Qm[t < tstim] - t = t[t < tstim] - dt = np.diff(t) - - # If charge signal is stable during last 100 ms of stimulus - if np.ptp(Qm[-int(100e-3 // dt[0]):]) < 5e-5: - - # Compute instant of stabilization by iNet thresholding - iNet_abs = np.abs(np.diff(Qm)) / dt - Qstab[i] = Qm[-1] - tstab[i] = t[np.where(iNet_abs > 1e-3)[0][-1] + 2] - logger.info('Qm stabilization around %.2f nC/cm2 from t = %.0f ms onward', - Qstab[i] * 1e5, tstab[i] * 1e3) - - # Otherwise, populate arrays with NaN - else: - Qstab[i] = np.nan - tstab[i] = np.nan - logger.info('No Qm stabilization') - - return Qstab, tstab - - -def getEqChargesFromQSS(neuron, a, Fdrive, amps, Qthr=None): - - # Compute net current profile for each amplitude, from QSS states and Vmeff profiles - nbls = NeuronalBilayerSonophore(a, neuron, Fdrive) - _, Qref, Vmeff, QS_states = nbls.getQSSvars(Fdrive, amps=amps) - iNet = neuron.iNet(Vmeff, QS_states) - - # Restrict iNet root-finding to a certain charge interval if provided - if Qthr is not None: - iNet = iNet[:, Qref >= Qthr] - Qref = Qref[Qref >= Qthr] - - # Interpolate charge density vector at iNet = 0 for each amplitude - Qeq_QSS = np.array([np.interp(0, iNet[i, :], Qref, left=0., right=np.nan) - for i in range(amps.size)]) - - return Qeq_QSS - - def compareEqChargesQSSvsSim(inputdir, neuron, a, Fdrive, amps, tstim, fs=12): # Get charge value that cancels out net current in QSS approx. and sim - Qeq_QSS = getEqChargesFromQSS(neuron, a, Fdrive, amps, Qthr=-20e-5) + Qeq_QSS = getStableQmQSS(neuron, a, Fdrive, amps) + Qeq_sim, _ = getChargeStabilizationFromSims(inputdir, neuron, a, Fdrive, amps, tstim) # Plot Qm balancing net current as function of amplitude fig, ax = plt.subplots(figsize=(6, 4)) figname = '{} neuron - equilibrium charge vs. amplitude'.format(neuron.name) ax.set_title(figname) ax.set_xlabel('Amplitude (kPa)', fontsize=fs) - ax.set_ylabel('$\\rm Q_{thr}\ (nC/cm^2)$', fontsize=fs) + ax.set_ylabel('$\\rm Q_{eq}\ (nC/cm^2)$', fontsize=fs) for skey in ['top', 'right']: ax.spines[skey].set_visible(False) for item in ax.get_xticklabels() + ax.get_yticklabels(): item.set_fontsize(fs) - ax.plot(amps * 1e-3, Qeq_QSS * 1e5, label='QSS approximation') - ax.plot(amps * 1e-3, Qeq_sim * 1e5, label='end of {:.2f} s stimulus (simulation)'.format(tstim)) + + lgd = True + for Adrive, Qstab in zip(amps, Qeq_QSS): + if Qstab is not None: + if lgd: + lbl = 'QSS approximation' + lgd = False + else: + lbl = None + ax.plot(np.ones(Qstab.size) * Adrive * 1e-3, Qstab * 1e5, '.', c='C0', label=lbl) + # ax.plot(amps * 1e-3, Qeq_QSS * 1e5, label='QSS approximation') + ax.plot(amps * 1e-3, Qeq_sim * 1e5, c='C1', + label='end of {:.2f} s stimulus (simulation)'.format(tstim)) ax.legend(frameon=False, fontsize=fs) fig.tight_layout() fig.canvas.set_window_title( - '{}_Qthr_vs_amp'.format(neuron.name)) + '{}_Qeq_vs_amp'.format(neuron.name)) + + return fig def main(): ap = ArgumentParser() # Stimulation parameters ap.add_argument('-i', '--inputdir', type=str, default=None, help='Input directory') + ap.add_argument('-o', '--outputdir', type=str, default=None, help='Output directory') ap.add_argument('-f', '--figset', type=str, nargs='+', help='Figure set', default='all') ap.add_argument('-c', '--cmap', type=str, default='viridis', help='Colormap name') ap.add_argument('-v', '--verbose', default=False, action='store_true', help='Increase verbosity') + ap.add_argument('-s', '--save', default=False, action='store_true', + help='Save output figures as pdf') # Parse arguments args = ap.parse_args() logger.setLevel(logging.DEBUG if args.verbose else logging.INFO) figset = args.figset if figset == 'all': figset = ['a', 'b', 'c'] neuron = getNeuronsDict()['STN']() a = 32e-9 # m Fdrive = 500e3 # Hz intensities = getLowIntensitiesSTN() # W/m2 amps = Intensity2Pressure(intensities) # Pa tstim = 1.0 # s figs = [] if 'a' in figset: for Adrive in [amps[0], amps[-1]]: figs.append(plotQSSvars_vs_Qm(neuron, a, Fdrive, Adrive)) if 'b' in figset: figs.append(plotInetQSS_vs_Qm(neuron, a, Fdrive, amps)) if 'c' in figset: inputdir = args.inputdir if args.inputdir is not None else selectDirDialog() - figs.append(compareEqChargesQSSvsSim(inputdir, neuron, a, Fdrive, amps, tstim)) + if inputdir == '': + logger.error('no input directory') + else: + figs.append(compareEqChargesQSSvsSim(inputdir, neuron, a, Fdrive, amps, tstim)) - plt.show() + if args.save: + outputdir = args.outputdir if args.outputdir is not None else selectDirDialog() + if outputdir == '': + logger.error('no output directory') + else: + for fig in figs: + s = fig.canvas.get_window_title() + s = s.replace('(', '- ').replace('/', '_').replace(')', '') + figname = '{}.pdf'.format(s) + fig.savefig(os.path.join(outputdir, figname), transparent=True) + else: + plt.show() if __name__ == '__main__': main()