diff --git a/PySONIC/core/bls.py b/PySONIC/core/bls.py index 833968c..739d429 100644 --- a/PySONIC/core/bls.py +++ b/PySONIC/core/bls.py @@ -1,808 +1,808 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2016-09-29 16:16:19 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2020-07-22 18:44:26 +# @Last Modified time: 2020-07-30 19:01:21 from enum import Enum import os import json import numpy as np import pandas as pd import scipy.integrate as integrate from scipy.optimize import brentq, curve_fit from .model import Model from .solvers import PeriodicSolver from .drives import Drive, AcousticDrive from ..utils import logger, si_format from ..constants import * class PmCompMethod(Enum): ''' Enum: types of computation method for the intermolecular pressure ''' direct = 1 predict = 2 def LennardJones(x, beta, alpha, C, m, n): ''' Generic expression of a Lennard-Jones function, adapted for the context of symmetric deflection (distance = 2x). :param x: deflection (i.e. half-distance) :param beta: x-shifting factor :param alpha: x-scaling factor :param C: y-scaling factor :param m: exponent of the repulsion term :param n: exponent of the attraction term :return: Lennard-Jones potential at given distance (2x) ''' return C * (np.power((alpha / (2 * x + beta)), m) - np.power((alpha / (2 * x + beta)), n)) def lookup(func): ''' Load parameters from lookup file, or compute them and store them in lookup file. ''' lookup_path = os.path.join(os.path.split(__file__)[0], 'bls_lookups.json') def wrapper(obj): akey = f'{obj.a * 1e9:.1f}' Qkey = f'{obj.Qm0 * 1e5:.2f}' # Open lookup files try: with open(lookup_path, 'r') as fh: lookups = json.load(fh) except FileNotFoundError: lookups = {} # If info not in lookups, compute parameters and add them if akey not in lookups or Qkey not in lookups[akey]: func(obj) if akey not in lookups: lookups[akey] = {Qkey: {'LJ_approx': obj.LJ_approx, 'Delta_eq': obj.Delta}} else: lookups[akey][Qkey] = {'LJ_approx': obj.LJ_approx, 'Delta_eq': obj.Delta} logger.debug('Saving BLS derived parameters to lookup file') with open(lookup_path, 'w') as fh: json.dump(lookups, fh, indent=2) # If lookup exists, load parameters from it else: logger.debug('Loading BLS derived parameters from lookup file') obj.LJ_approx = lookups[akey][Qkey]['LJ_approx'] obj.Delta = lookups[akey][Qkey]['Delta_eq'] return wrapper class BilayerSonophore(Model): ''' Definition of the Bilayer Sonophore Model - geometry - pressure terms - cavitation dynamics ''' # BIOMECHANICAL PARAMETERS T = 309.15 # Temperature (K) delta0 = 2.0e-9 # Thickness of the leaflet (m) Delta_ = 1.4e-9 # Initial gap between the two leaflets on a non-charged membrane at equil. (m) pDelta = 1.0e5 # Attraction/repulsion pressure coefficient (Pa) m = 5.0 # Exponent in the repulsion term (dimensionless) n = 3.3 # Exponent in the attraction term (dimensionless) rhoL = 1075.0 # Density of the surrounding fluid (kg/m^3) muL = 7.0e-4 # Dynamic viscosity of the surrounding fluid (Pa.s) muS = 0.035 # Dynamic viscosity of the leaflet (Pa.s) kA = 0.24 # Area compression modulus of the leaflet (N/m) alpha = 7.56 # Tissue shear loss modulus frequency coefficient (Pa.s) C0 = 0.62 # Initial gas molar concentration in the surrounding fluid (mol/m^3) kH = 1.613e5 # Henry's constant (Pa.m^3/mol) P0 = 1.0e5 # Static pressure in the surrounding fluid (Pa) Dgl = 3.68e-9 # Diffusion coefficient of gas in the fluid (m^2/s) xi = 0.5e-9 # Boundary layer thickness for gas transport across leaflet (m) c = 1515.0 # Speed of sound in medium (m/s) # BIOPHYSICAL PARAMETERS epsilon0 = 8.854e-12 # Vacuum permittivity (F/m) epsilonR = 1.0 # Relative permittivity of intramembrane cavity (dimensionless) rel_Zmin = -0.49 # relative deflection range lower bound (in multiples of Delta) tscale = 'us' # relevant temporal scale of the model simkey = 'MECH' # keyword used to characterize simulations made with this model def __init__(self, a, Cm0, Qm0, embedding_depth=0.0): ''' Constructor of the class. :param a: in-plane radius of the sonophore structure within the membrane (m) :param Cm0: membrane resting capacitance (F/m2) :param Qm0: membrane resting charge density (C/m2) :param embedding_depth: depth of the embedding tissue around the membrane (m) ''' # Extract resting constants and geometry self.Cm0 = Cm0 self.Qm0 = Qm0 self.a = a self.d = embedding_depth self.S0 = np.pi * self.a**2 # Initialize null elastic modulus for tissue self.kA_tissue = 0. # Compute Pm params self.computePMparams() # Compute initial volume and gas content self.V0 = np.pi * self.Delta * self.a**2 self.ng0 = self.gasPa2mol(self.P0, self.V0) def copy(self): return self.__class__(self.a, self.Cm0, self.Qm0, embedding_depth=self.d) @property def a(self): return self._a @a.setter def a(self, value): if value <= 0.: raise ValueError('Sonophore radius must be positive') self._a = value @property def Cm0(self): return self._Cm0 @Cm0.setter def Cm0(self, value): if value <= 0.: raise ValueError('Resting membrane capacitance must be positive') self._Cm0 = value @property def d(self): return self._d @d.setter def d(self, value): if value < 0.: raise ValueError('Embedding depth cannot be negative') self._d = value def __repr__(self): s = f'{self.__class__.__name__}({self.a * 1e9:.1f} nm' if self.d > 0.: s += f', d={si_format(self.d, precision=1)}m' return f'{s})' @property def meta(self): return { 'a': self.a, 'd': self.d, 'Cm0': self.Cm0, 'Qm0': self.Qm0, } @classmethod def initFromMeta(cls, d): return cls(d['a'], d['Cm0'], d['Qm0']) @staticmethod def inputs(): return { 'a': { 'desc': 'sonophore radius', 'label': 'a', 'unit': 'm', 'precision': 0 }, 'Qm': { 'desc': 'membrane charge density', 'label': 'Q_m', 'unit': 'nC/cm^2', 'factor': 1e5, 'precision': 1 }, **AcousticDrive.inputs() } def filecodes(self, drive, Qm, PmCompMethod='predict'): return { 'simkey': self.simkey, 'a': f'{self.a * 1e9:.0f}nm', **drive.filecodes, 'Qm': f'{Qm * 1e5:.1f}nCcm2' } @staticmethod def getPltVars(wl='df["', wr='"]'): return { 'Pac': { 'desc': 'acoustic pressure', 'label': 'P_{AC}', 'unit': 'kPa', 'factor': 1e-3, 'func': f'meta["drive"].compute({wl}t{wr})' }, 'Z': { 'desc': 'leaflets deflection', 'label': 'Z', 'unit': 'nm', 'factor': 1e9, 'bounds': (-1.0, 10.0) }, 'ng': { 'desc': 'gas content', 'label': 'n_g', 'unit': '10^{-22}\ mol', 'factor': 1e22, 'bounds': (1.0, 15.0) }, 'Pmavg': { 'desc': 'average intermolecular pressure', 'label': 'P_M', 'unit': 'kPa', 'factor': 1e-3, 'func': f'PMavgpred({wl}Z{wr})' }, 'Telastic': { 'desc': 'leaflet elastic tension', 'label': 'T_E', 'unit': 'mN/m', 'factor': 1e3, 'func': f'TEleaflet({wl}Z{wr})' }, 'Cm': { 'desc': 'membrane capacitance', 'label': 'C_m', 'unit': 'uF/cm^2', 'factor': 1e2, 'bounds': (0.0, 1.5), 'func': f'v_capacitance({wl}Z{wr})' } } @property def pltScheme(self): return { 'P_{AC}': ['Pac'], 'Z': ['Z'], 'n_g': ['ng'] } @property def Zmin(self): return self.rel_Zmin * self.Delta def curvrad(self, Z): ''' Leaflet curvature radius (signed variable) :param Z: leaflet apex deflection (m) :return: leaflet curvature radius (m) ''' if Z == 0.0: return np.inf else: return (self.a**2 + Z**2) / (2 * Z) def v_curvrad(self, Z): ''' Vectorized curvrad function ''' return np.array(list(map(self.curvrad, Z))) def surface(self, Z): ''' Surface area of the stretched leaflet (spherical cap formula) :param Z: leaflet apex deflection (m) :return: stretched leaflet surface (m^2) ''' return np.pi * (self.a**2 + Z**2) def volume(self, Z): ''' Volume of the inter-leaflet space (cylinder +/- 2 spherical caps) :param Z: leaflet apex deflection (m) :return: bilayer sonophore inner volume (m^3) ''' return np.pi * self.a**2 * self.Delta\ * (1 + (Z / (3 * self.Delta) * (3 + Z**2 / self.a**2))) def arealStrain(self, Z): ''' Areal strain of the stretched leaflet epsilon = (S - S0)/S0 = (Z/a)^2 :param Z: leaflet apex deflection (m) :return: areal strain (dimensionless) ''' return (Z / self.a)**2 def capacitance(self, Z): ''' Membrane capacitance (parallel-plate capacitor evaluated at average inter-layer distance) :param Z: leaflet apex deflection (m) :return: capacitance per unit area (F/m2) ''' if Z == 0.0: return self.Cm0 else: return ((self.Cm0 * self.Delta / self.a**2) * (Z + (self.a**2 - Z**2 - Z * self.Delta) / (2 * Z) * np.log((2 * Z + self.Delta) / self.Delta))) def v_capacitance(self, Z): ''' Vectorized capacitance function ''' return np.array(list(map(self.capacitance, Z))) def derCapacitance(self, Z, U): ''' Evolution of membrane capacitance :param Z: leaflet apex deflection (m) :param U: leaflet apex deflection velocity (m/s) :return: time derivative of capacitance per unit area (F/m2.s) ''' dCmdZ = ((self.Cm0 * self.Delta / self.a**2) * ((Z**2 + self.a**2) / (Z * (2 * Z + self.Delta)) - ((Z**2 + self.a**2) * np.log((2 * Z + self.Delta) / self.Delta)) / (2 * Z**2))) return dCmdZ * U @staticmethod def localDeflection(r, Z, R): ''' Local leaflet deflection at specific radial distance (signed) :param r: in-plane distance from center of the sonophore (m) :param Z: leaflet apex deflection (m) :param R: leaflet curvature radius (m) :return: local transverse leaflet deviation (m) ''' if np.abs(Z) == 0.0: return 0.0 else: return np.sign(Z) * (np.sqrt(R**2 - r**2) - np.abs(R) + np.abs(Z)) def PMlocal(self, r, Z, R): ''' Local intermolecular pressure :param r: in-plane distance from center of the sonophore (m) :param Z: leaflet apex deflection (m) :param R: leaflet curvature radius (m) :return: local intermolecular pressure (Pa) ''' z = self.localDeflection(r, Z, R) relgap = (2 * z + self.Delta) / self.Delta_ return self.pDelta * ((1 / relgap)**self.m - (1 / relgap)**self.n) def PMavg(self, Z, R, S): ''' Average intermolecular pressure across the leaflet (computed by quadratic integration) :param Z: leaflet apex outward deflection value (m) :param R: leaflet curvature radius (m) :param S: surface of the stretched leaflet (m^2) :return: averaged intermolecular resultant pressure (Pa) .. warning:: quadratic integration is computationally expensive. ''' # Integrate intermolecular force over an infinitely thin ring of radius r from 0 to a fTotal, _ = integrate.quad(lambda r, Z, R: 2 * np.pi * r * self.PMlocal(r, Z, R), 0, self.a, args=(Z, R)) return fTotal / S def v_PMavg(self, Z, R, S): ''' Vectorized PMavg function ''' return np.array(list(map(self.PMavg, Z, R, S))) def LJfitPMavg(self): ''' Determine optimal parameters of a Lennard-Jones expression approximating the average intermolecular pressure. These parameters are obtained by a nonlinear fit of the Lennard-Jones function for a range of deflection values between predetermined Zmin and Zmax. :return: 3-tuple with optimized LJ parameters for PmAvg prediction (Map) and the standard and max errors of the prediction in the fitting range (in Pascals) ''' # Determine lower bound of deflection range: when Pm = Pmmax PMmax = LJFIT_PM_MAX # Pa Zlb_range = (self.Zmin, 0.0) Zlb = brentq(lambda Z, Pmmax: self.PMavg(Z, self.curvrad(Z), self.surface(Z)) - PMmax, *Zlb_range, args=(PMmax), xtol=1e-16) # Create vectors for geometric variables Zub = 2 * self.a Z = np.arange(Zlb, Zub, 1e-11) Pmavg = self.v_PMavg(Z, self.v_curvrad(Z), self.surface(Z)) # Compute optimal nonlinear fit of custom LJ function with initial guess x0_guess = self.delta0 C_guess = 0.1 * self.pDelta nrep_guess = self.m nattr_guess = self.n pguess = (x0_guess, C_guess, nrep_guess, nattr_guess) popt, _ = curve_fit(lambda x, x0, C, nrep, nattr: LennardJones(x, self.Delta, x0, C, nrep, nattr), Z, Pmavg, p0=pguess, maxfev=100000) (x0_opt, C_opt, nrep_opt, nattr_opt) = popt Pmavg_fit = LennardJones(Z, self.Delta, x0_opt, C_opt, nrep_opt, nattr_opt) # Compute prediction error residuals = Pmavg - Pmavg_fit ss_res = np.sum(residuals**2) N = residuals.size std_err = np.sqrt(ss_res / N) max_err = max(np.abs(residuals)) logger.debug('LJ approx: x0 = %.2f nm, C = %.2f kPa, m = %.2f, n = %.2f', x0_opt * 1e9, C_opt * 1e-3, nrep_opt, nattr_opt) LJ_approx = {"x0": x0_opt, "C": C_opt, "nrep": nrep_opt, "nattr": nattr_opt} return (LJ_approx, std_err, max_err) @lookup def computePMparams(self): # Find Delta that cancels out Pm + Pec at Z = 0 (m) if self.Qm0 == 0.0: D_eq = self.Delta_ else: (D_eq, Pnet_eq) = self.findDeltaEq(self.Qm0) assert Pnet_eq < PNET_EQ_MAX, 'High Pnet at Z = 0 with ∆ = %.2f nm' % (D_eq * 1e9) self.Delta = D_eq # Find optimal Lennard-Jones parameters to approximate PMavg (self.LJ_approx, std_err, _) = self.LJfitPMavg() assert std_err < PMAVG_STD_ERR_MAX, 'High error in PmAvg nonlinear fit:'\ ' std_err = %.2f Pa' % std_err def PMavgpred(self, Z): ''' Approximated average intermolecular pressure (using nonlinearly fitted Lennard-Jones function) :param Z: leaflet apex deflection (m) :return: predicted average intermolecular pressure (Pa) ''' return LennardJones(Z, self.Delta, self.LJ_approx['x0'], self.LJ_approx['C'], self.LJ_approx['nrep'], self.LJ_approx['nattr']) def Pelec(self, Z, Qm): ''' Electrical pressure term :param Z: leaflet apex deflection (m) :param Qm: membrane charge density (C/m2) :return: electrical pressure (Pa) ''' relS = self.S0 / self.surface(Z) abs_perm = self.epsilon0 * self.epsilonR # F/m return - relS * Qm**2 / (2 * abs_perm) # Pa def findDeltaEq(self, Qm): ''' Compute the Delta that cancels out the (Pm + Pec) equation at Z = 0 for a given membrane charge density, using the Brent method to refine the pressure root iteratively. :param Qm: membrane charge density (C/m2) :return: equilibrium value (m) and associated pressure (Pa) ''' def dualPressure(Delta): x = (self.Delta_ / Delta) return (self.pDelta * (x**self.m - x**self.n) + self.Pelec(0.0, Qm)) Delta_eq = brentq(dualPressure, 0.1 * self.Delta_, 2.0 * self.Delta_, xtol=1e-16) logger.debug('∆eq = %.2f nm', Delta_eq * 1e9) return (Delta_eq, dualPressure(Delta_eq)) def gasFlux(self, Z, P): ''' Gas molar flux through the sonophore boundary layers :param Z: leaflet apex deflection (m) :param P: internal gas pressure (Pa) :return: gas molar flux (mol/s) ''' dC = self.C0 - P / self.kH return 2 * self.surface(Z) * self.Dgl * dC / self.xi @classmethod def gasmol2Pa(cls, ng, V): ''' Internal gas pressure for a given molar content :param ng: internal molar content (mol) :param V: sonophore inner volume (m^3) :return: internal gas pressure (Pa) ''' return ng * Rg * cls.T / V @classmethod def gasPa2mol(cls, P, V): ''' Internal gas molar content for a given pressure :param P: internal gas pressure (Pa) :param V: sonophore inner volume (m^3) :return: internal gas molar content (mol) ''' return P * V / (Rg * cls.T) def PtotQS(self, Z, ng, Qm, Pac, Pm_comp_method): ''' Net quasi-steady pressure for a given acoustic pressure (Ptot = Pm + Pg + Pec - P0 - Pac) :param Z: leaflet apex deflection (m) :param ng: internal molar content (mol) :param Qm: membrane charge density (C/m2) :param Pac: acoustic pressure (Pa) :param Pm_comp_method: computation method for average intermolecular pressure :return: total balance pressure (Pa) ''' if Pm_comp_method is PmCompMethod.direct: Pm = self.PMavg(Z, self.curvrad(Z), self.surface(Z)) elif Pm_comp_method is PmCompMethod.predict: Pm = self.PMavgpred(Z) return Pm + self.gasmol2Pa(ng, self.volume(Z)) - self.P0 - Pac + self.Pelec(Z, Qm) def balancedefQS(self, ng, Qm, Pac=0.0, Pm_comp_method=PmCompMethod.predict): ''' Quasi-steady equilibrium deflection for a given acoustic pressure (computed by approximating the root of quasi-steady pressure) :param ng: internal molar content (mol) :param Qm: membrane charge density (C/m2) :param Pac: external acoustic perturbation (Pa) :param Pm_comp_method: computation method for average intermolecular pressure :return: leaflet deflection canceling quasi-steady pressure (m) ''' Zbounds = (self.Zmin, self.a) Plb, Pub = [self.PtotQS(x, ng, Qm, Pac, Pm_comp_method) for x in Zbounds] assert (Plb > 0 > Pub), '[{}, {}] is not a sign changing interval for PtotQS'.format(*Zbounds) return brentq(self.PtotQS, *Zbounds, args=(ng, Qm, Pac, Pm_comp_method), xtol=1e-16) def TEleaflet(self, Z): ''' Elastic tension in leaflet :param Z: leaflet apex deflection (m) :return: circumferential elastic tension (N/m) ''' return self.kA * self.arealStrain(Z) def setTissueModulus(self, drive): ''' Set the frequency-dependent elastic modulus of the surrounding tissue. ''' G_tissue = self.alpha * drive.modulationFrequency # G'' (Pa) self.kA_tissue = 2 * G_tissue * self.d # kA of the tissue layer (N/m) def TEtissue(self, Z): ''' Elastic tension in surrounding viscoelastic layer :param Z: leaflet apex deflection (m) :return: circumferential elastic tension (N/m) ''' return self.kA_tissue * self.arealStrain(Z) def TEtot(self, Z): ''' Total elastic tension (leaflet + surrounding viscoelastic layer) :param Z: leaflet apex deflection (m) :return: circumferential elastic tension (N/m) ''' return self.TEleaflet(Z) + self.TEtissue(Z) def PEtot(self, Z, R): ''' Total elastic tension pressure (leaflet + surrounding viscoelastic layer) :param Z: leaflet apex deflection (m) :param R: leaflet curvature radius (m) :return: elastic tension pressure (Pa) ''' return - self.TEtot(Z) / R @classmethod def PVleaflet(cls, U, R): ''' Viscous stress pressure in leaflet :param U: leaflet apex deflection velocity (m/s) :param R: leaflet curvature radius (m) :return: leaflet viscous stress pressure (Pa) ''' return - 12 * U * cls.delta0 * cls.muS / R**2 @classmethod def PVfluid(cls, U, R): ''' Viscous stress pressure in surrounding medium :param U: leaflet apex deflection velocity (m/s) :param R: leaflet curvature radius (m) :return: fluid viscous stress pressure (Pa) ''' return - 4 * U * cls.muL / np.abs(R) @classmethod def accP(cls, Ptot, R): ''' Leaflet transverse acceleration resulting from pressure imbalance :param Ptot: net pressure (Pa) :param R: leaflet curvature radius (m) :return: pressure-driven acceleration (m/s^2) ''' return Ptot / (cls.rhoL * np.abs(R)) @staticmethod def accNL(U, R): ''' Leaflet transverse nonlinear acceleration :param U: leaflet apex deflection velocity (m/s) :param R: leaflet curvature radius (m) :return: nonlinear acceleration term (m/s^2) .. note:: A simplified version of nonlinear acceleration (neglecting dR/dH) is used here. ''' # return - (3/2 - 2*R/H) * U**2 / R return -(3 * U**2) / (2 * R) @staticmethod def checkInputs(drive, Qm, Pm_comp_method): ''' Check validity of stimulation parameters :param drive: acoustic drive object :param Qm: imposed membrane charge density (C/m2) :param Pm_comp_method: type of method used to compute average intermolecular pressure ''' if not isinstance(drive, Drive): raise TypeError(f'Invalid "drive" parameter (must be an "Drive" object)') if not isinstance(Qm, float): raise TypeError(f'Invalid "Qm" parameter (must be float typed)') Qmin, Qmax = CHARGE_RANGE if Qm < Qmin or Qm > Qmax: raise ValueError( f'Invalid applied charge: {Qm * 1e5} nC/cm2 (must be within [{Qmin * 1e5}, {Qmax * 1e5}] interval') if not isinstance(Pm_comp_method, PmCompMethod): raise TypeError('Invalid Pm computation method (must be "PmCompmethod" type)') def derivatives(self, t, y, drive, Qm, Pm_comp_method=PmCompMethod.predict): ''' Evolution of the mechanical system :param t: time instant (s) :param y: vector of HH system variables at time t :param drive: acoustic drive object :param Qm: membrane charge density (F/m2) :param Pm_comp_method: computation method for average intermolecular pressure :return: vector of mechanical system derivatives at time t ''' # Split input vector explicitly U, Z, ng = y # Correct deflection value is below critical compression if Z < self.Zmin: logger.warning('Deflection out of range: Z = %.2f nm', Z * 1e9) Z = self.Zmin # Compute curvature radius R = self.curvrad(Z) # Compute total pressure Pg = self.gasmol2Pa(ng, self.volume(Z)) if Pm_comp_method is PmCompMethod.direct: Pm = self.PMavg(Z, self.curvrad(Z), self.surface(Z)) elif Pm_comp_method is PmCompMethod.predict: Pm = self.PMavgpred(Z) Pac = drive.compute(t) Pv = self.PVleaflet(U, R) + self.PVfluid(U, R) Ptot = Pm + Pg - self.P0 - Pac + self.PEtot(Z, R) + Pv + self.Pelec(Z, Qm) # Compute derivatives dUdt = self.accP(Ptot, R) + self.accNL(U, R) dZdt = U dngdt = self.gasFlux(Z, Pg) # Return derivatives vector return [dUdt, dZdt, dngdt] def computeInitialDeflection(self, drive, Qm, dt, Pm_comp_method=PmCompMethod.predict): ''' Compute non-zero deflection value for a small perturbation (solving quasi-steady equation). ''' Pac = drive.compute(dt) return self.balancedefQS(self.ng0, Qm, Pac, Pm_comp_method) @classmethod @Model.checkOutputDir def simQueue(cls, freqs, amps, charges, **kwargs): drives = AcousticDrive.createQueue(freqs, amps) queue = [] for drive in drives: for Qm in charges: queue.append([drive, Qm]) return queue def initialConditions(self, *args, **kwargs): ''' Compute simulation initial conditions. ''' # Compute initial non-zero deflection Z = self.computeInitialDeflection(*args, **kwargs) # Return initial conditions dictionary return { 'U': [0.] * 2, 'Z': [0., Z], 'ng': [self.ng0] * 2, } - def simCycles(self, drive, Qm, n=None, Pm_comp_method=PmCompMethod.predict): + def simCycles(self, drive, Qm, nmax=None, nmin=None, Pm_comp_method=PmCompMethod.predict): ''' Simulate for a specific number of cycles or until periodic stabilization, for a specific set of ultrasound parameters, and return output data in a dataframe. :param drive: acoustic drive object :param Qm: imposed membrane charge density (C/m2) :param n: number of cycles (optional) :param Pm_comp_method: type of method used to compute average intermolecular pressure :return: output dataframe ''' # Set the tissue elastic modulus self.setTissueModulus(drive) # Compute initial conditions y0 = self.initialConditions(drive, Qm, drive.dt, Pm_comp_method=Pm_comp_method) # Initialize solver and compute solution solver = PeriodicSolver( drive.periodicity, # periodicty y0.keys(), # variables list lambda t, y: self.derivatives(t, y, drive, Qm, Pm_comp_method), # dfunc primary_vars=['Z', 'ng'], # primary variables dt=drive.dt # time step ) - data = solver(y0, nmax=n) + data = solver(y0, nmax=nmax, nmin=nmin) # Remove velocity timeries from solution del data['U'] # Return solution dataframe return data @Model.addMeta @Model.logDesc @Model.checkSimParams def simulate(self, drive, Qm, Pm_comp_method=PmCompMethod.predict): ''' Wrapper around the simUntilConvergence method, with decorators. ''' return self.simCycles(drive, Qm, Pm_comp_method=Pm_comp_method) def desc(self, meta): return f'{self}: simulation @ {meta["drive"].desc}, Q = {si_format(meta["Qm"] * 1e-4, 2)}C/cm2' def getCycleProfiles(self, drive, Qm): ''' Simulate mechanical system and compute pressures over the last acoustic cycle :param drive: acoustic drive object :param Qm: imposed membrane charge density (C/m2) :return: dataframe with the time, kinematic and pressure profiles over the last cycle. ''' # Run default simulation and retrieve last cycle solution logger.info(f'Running mechanical simulation (a = {si_format(self.a, 1)}m, {drive.desc})') data = self.simulate( drive, Qm, Pm_comp_method=PmCompMethod.direct)[0].iloc[-drive.nPerCycle:, :] # Extract relevant variables and de-offset time vector t, Z, ng = [data[key].values for key in ['t', 'Z', 'ng']] dt = (t[-1] - t[0]) / (NPC_DENSE - 1) t -= t[0] # Compute pressure cyclic profiles logger.info('Computing pressure cyclic profiles') R = self.v_curvrad(Z) U = np.diff(Z) / dt U = np.hstack((U, U[-1])) data = { 't': t, 'Z': Z, 'Cm': self.v_capacitance(Z), 'P_M': self.v_PMavg(Z, R, self.surface(Z)), 'P_Q': self.Pelec(Z, Qm), 'P_{VE}': self.PEtot(Z, R) + self.PVleaflet(U, R), 'P_V': self.PVfluid(U, R), 'P_G': self.gasmol2Pa(ng, self.volume(Z)), 'P_0': - np.ones(Z.size) * self.P0 } return pd.DataFrame(data, columns=data.keys()) diff --git a/PySONIC/core/solvers.py b/PySONIC/core/solvers.py index e5d9461..53f4b7f 100644 --- a/PySONIC/core/solvers.py +++ b/PySONIC/core/solvers.py @@ -1,630 +1,633 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2019-05-28 14:45:12 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2020-06-03 16:44:22 +# @Last Modified time: 2020-07-30 18:57:58 import numpy as np import pandas as pd from scipy.interpolate import interp1d from scipy.integrate import ode, odeint, solve_ivp from tqdm import tqdm from ..utils import * from ..constants import * class ODESolver: ''' Generic interface to ODE solver object. ''' def __init__(self, ykeys, dfunc, dt=None): ''' Initialization. :param ykeys: list of differential variables names :param dfunc: derivative function :param dt: integration time step (s) ''' self.ykeys = ykeys self.dfunc = dfunc self.dt = dt def checkFunc(self, key, value): if not callable(value): raise ValueError(f'{key} function must be a callable object') @property def ykeys(self): return self._ykeys @ykeys.setter def ykeys(self, value): if not isIterable(value): value = list(value) for item in value: if not isinstance(item, str): raise ValueError('ykeys must be a list of strings') self._ykeys = value @property def nvars(self): return len(self.ykeys) @property def dfunc(self): return self._dfunc @dfunc.setter def dfunc(self, value): self.checkFunc('derivative', value) self._dfunc = value @property def dt(self): return self._dt @dt.setter def dt(self, value): if value is None: self._dt = None else: if not isinstance(value, float): raise ValueError('time step must be float-typed') if value <= 0: raise ValueError('time step must be strictly positive') self._dt = value def getNSamples(self, t0, tend, dt=None): ''' Get the number of samples required to integrate across 2 times with a given time step. :param t0: initial time (s) :param tend: final time (s) :param dt: integration time step (s) :return: number of required samples, rounded to nearest integer ''' if dt is None: dt = self.dt return max(int(np.round((tend - t0) / dt)), 2) def getTimeVector(self, t0, tend, **kwargs): ''' Get the time vector required to integrate from an initial to a final time with a specific time step. :param t0: initial time (s) :param tend: final time (s) :return: vector going from current time to target time with appropriate step (s) ''' return np.linspace(t0, tend, self.getNSamples(t0, tend, **kwargs)) def initialize(self, y0, t0=0.): ''' Initialize global time vector, state vector and solution array. :param y0: dictionary of initial conditions :param t0: optional initial time or time vector (s) ''' keys = list(y0.keys()) if len(keys) != len(self.ykeys): raise ValueError("Initial conditions do not match system's dimensions") for k in keys: if k not in self.ykeys: raise ValueError(f'{k} is not a differential variable') y0 = {k: np.asarray(v) if isIterable(v) else np.array([v]) for k, v in y0.items()} ref_size = y0[keys[0]].size if not all(v.size == ref_size for v in y0.values()): raise ValueError('dimensions of initial conditions are inconsistent') self.y = np.array(list(y0.values())).T self.t = np.ones(self.y.shape[0]) * t0 self.x = np.zeros(self.t.size) def append(self, t, y): ''' Append to global time vector, state vector and solution array. :param t: new time vector to append (s) :param y: new solution matrix to append ''' self.t = np.concatenate((self.t, t)) self.y = np.concatenate((self.y, y), axis=0) self.x = np.concatenate((self.x, np.ones(t.size) * self.xref)) def bound(self, tbounds): ''' Restrict global time vector, state vector ans solution matrix within specific time range. :param tbounds: minimal and maximal allowed time restricting the global arrays (s). ''' i_bounded = np.logical_and(self.t >= tbounds[0], self.t <= tbounds[1]) self.t = self.t[i_bounded] self.y = self.y[i_bounded, :] self.x = self.x[i_bounded] @staticmethod def timeStr(t): return f'{t * 1e3:.5f} ms' def timedlog(self, s, t=None): ''' Add preceding time information to log string. ''' if t is None: t = self.t[-1] return f't = {self.timeStr(t)}: {s}' def integrateUntil(self, target_t, remove_first=False): ''' Integrate system until a target time and append new arrays to global arrays. :param target_t: target time (s) :param remove_first: optional boolean specifying whether to remove the first index of the new arrays before appending ''' if target_t < self.t[-1]: raise ValueError(f'target time ({target_t} s) precedes current time {self.t[-1]} s') elif target_t == self.t[-1]: t, y = self.t[-1], self.y[-1] if self.dt is None: sol = solve_ivp( self.dfunc, [self.t[-1], target_t], self.y[-1], method='LSODA') t, y = sol.t, sol.y.T else: t = self.getTimeVector(self.t[-1], target_t) y = odeint(self.dfunc, self.y[-1], t, tfirst=True) if remove_first: t, y = t[1:], y[1:] self.append(t, y) def resampleArrays(self, t, y, target_dt): ''' Resample a time vector and soluton matrix to target time step. :param t: time vector to resample (s) :param y: solution matrix to resample :target_dt: target time step (s) :return: resampled time vector and solution matrix ''' tnew = self.getTimeVector(t[0], t[-1], dt=target_dt) ynew = np.array([np.interp(tnew, t, x) for x in y.T]).T return tnew, ynew def resample(self, target_dt): ''' Resample global arrays to a new target time step. :target_dt: target time step (s) ''' tnew, self.y = self.resampleArrays(self.t, self.y, target_dt) self.x = interp1d(self.t, self.x, kind='nearest', assume_sorted=True)(tnew) self.t = tnew def solve(self, y0, tstop, **kwargs): ''' Simulate system for a given time interval for specific initial conditions. :param y0: dictionary of initial conditions :param tstop: stopping time (s) ''' # Initialize system self.initialize(y0, **kwargs) # Integrate until tstop self.integrateUntil(tstop, remove_first=True) @property def solution(self): ''' Return solution as a pandas dataframe. :return: timeseries dataframe with labeled time, state and variables vectors. ''' return TimeSeries(self.t, self.x, {k: self.y[:, i] for i, k in enumerate(self.ykeys)}) def __call__(self, *args, target_dt=None, max_nsamples=None, **kwargs): ''' Specific call method: solve the system, resample solution if needed, and return solution dataframe. ''' self.solve(*args, **kwargs) if target_dt is not None: self.resample(target_dt) elif max_nsamples is not None and self.t.size > max_nsamples: self.resample(np.ptp(self.t) / max_nsamples) return self.solution class PeriodicSolver(ODESolver): ''' ODE solver that integrates periodically until a stable periodic behavior is detected.''' def __init__(self, T, *args, primary_vars=None, **kwargs): ''' Initialization. :param T: periodicity (s) :param primary_vars: keys of the primary solution variables to check for stability ''' super().__init__(*args, **kwargs) self.T = T self.primary_vars = primary_vars @property def T(self): return self._T @T.setter def T(self, value): if not isinstance(value, float): raise ValueError('periodicity must be float-typed') if value <= 0: raise ValueError('periodicity must be strictly positive') self._T = value @property def primary_vars(self): return self._primary_vars @primary_vars.setter def primary_vars(self, value): if value is None: # If none specified, set all variables to be checked for stability value = self.ykeys if not isIterable(value): value = [value] for item in value: if item not in self.ykeys: raise ValueError(f'{item} is not a differential variable') self._primary_vars = value @property def i_primary_vars(self): return [self.ykeys.index(k) for k in self.primary_vars] @property def xref(self): return 1. def getNPerCycle(self, dt=None): ''' Compute number of samples per cycle. :param dt: optional integration time step (s) :return: number of samples per cycle, rounded to nearest integer ''' # if time step not provided, compute dt from last 2 elements of time vector if dt is None: dt = self.t[-1] - self.t[-2] return int(np.round(self.T / dt)) def getCycle(self, i, ivars=None): ''' Get time vector and solution matrix for the ith cycle. :param i: cycle index :param ivars: optional indexes of subset of variables of interest :return: solution matrix for ith cycle, filtered for variables of interest ''' # By default, consider indexes of all variables if ivars is None: ivars = range(self.nvars) # Get first time index where time difference differs from solver's time step, if any i_diff_dt = np.where(np.invert(np.isclose(np.diff(self.t)[::-1], self.dt)))[0] # Determine the number of samples to consider in the backwards direction nsamples = i_diff_dt[0] if i_diff_dt.size > 0 else self.t.size npc = self.getNPerCycle() # number of samples per cycle ncycles = int(np.round(nsamples / npc)) # rounded number of cycles ioffset = self.t.size - npc * ncycles # corresponding initial index offset # Check index validity if i < 0: i += ncycles if i < 0 or i >= ncycles: raise ValueError('Invalid index') # Compute start and end indexes istart = i * npc + ioffset iend = istart + npc # Return arrays for corresponding cycle return self.t[istart:iend], self.y[istart:iend, ivars] def isPeriodicallyStable(self): ''' Assess the periodic stabilization of a solution, by evaluating the deviation of system variables between the last two periods. :return: boolean stating whether the solution is periodically stable or not ''' # Extract the last 2 cycles of the primary variables from the solution y_last, y_prec = [self.getCycle(-i, ivars=self.i_primary_vars)[1] for i in [1, 2]] # Evaluate ratios of RMSE between the two cycles / variation range over the last cycle ratios = rmse(y_last, y_prec, axis=0) / np.ptp(y_last, axis=0) # Classify solution as periodically stable only if all ratios are below critical threshold return np.all(ratios < MAX_RMSE_PTP_RATIO) def integrateCycle(self): ''' Integrate system for a cycle. ''' self.integrateUntil(self.t[-1] + self.T, remove_first=True) - def solve(self, y0, nmax=None, **kwargs): + def solve(self, y0, nmax=None, nmin=None, **kwargs): ''' Simulate system with a specific periodicity until stopping criterion is met. :param y0: dictionary of initial conditions :param nmax: maximum number of integration cycles (optional) ''' if nmax is None: nmax = NCYCLES_MAX + if nmin is None: + nmin = 2 + assert nmin < nmax, 'incorrect bounds for number of cycles (min > max)' # Initialize system if y0 is not None: self.initialize(y0, **kwargs) - # Integrate system for 2 cycles - for i in range(2): + # Integrate system for minimal number of cycles + for i in range(nmin): self.integrateCycle() # Keep integrating system periodically until stopping criterion is met while not self.isPeriodicallyStable() and i < nmax: self.integrateCycle() i += 1 # Log stopping criterion if i == nmax: logger.warning(self.timedlog(f'criterion not met -> stopping after {i} cycles')) else: logger.debug(self.timedlog(f'stopping criterion met after {i} cycles')) class EventDrivenSolver(ODESolver): ''' Event-driven ODE solver. ''' def __init__(self, eventfunc, *args, event_params=None, **kwargs): ''' Initialization. :param eventfunc: function called on each event :param event_params: dictionary of parameters used by the derivatives function ''' super().__init__(*args, **kwargs) self.eventfunc = eventfunc self.assignEventParams(event_params) def assignEventParams(self, event_params): ''' Assign event parameters as instance attributes. ''' if event_params is not None: for k, v in event_params.items(): setattr(self, k, v) @property def eventfunc(self): return self._eventfunc @eventfunc.setter def eventfunc(self, value): self.checkFunc('event', value) self._eventfunc = value @property def xref(self): return self._xref @xref.setter def xref(self, value): self._xref = value def initialize(self, *args, **kwargs): self.xref = 0 super().initialize(*args, **kwargs) def fireEvent(self, xevent): ''' Call event function and set new xref value. ''' if xevent is not None: if xevent == 'log': self.logProgress() else: self.eventfunc(xevent) self.xref = xevent def initLog(self, logfunc, n): ''' Initialize progress logger. ''' self.logfunc = logfunc if self.logfunc is None: setHandler(logger, TqdmHandler(my_log_formatter)) self.pbar = tqdm(total=n) else: self.np = n logger.debug('integrating stimulus') def logProgress(self): ''' Log simulation progress. ''' if self.logfunc is None: self.pbar.update() else: logger.debug(self.timedlog(self.logfunc(self.y[-1]))) def terminateLog(self): ''' Terminate progress logger. ''' if self.logfunc is None: self.pbar.close() else: logger.debug('integration completed') def sortEvents(self, events): ''' Sort events pairs by occurence time. ''' return sorted(events, key=lambda x: x[0]) def solve(self, y0, events, tstop, log_period=None, logfunc=None, **kwargs): ''' Simulate system for a specific stimulus application pattern. :param y0: 1D vector of initial conditions :param events: list of events :param tstop: stopping time (s) ''' # Sort events according to occurrence time events = self.sortEvents(events) # Make sure all events occur before tstop if events[-1][0] > tstop: raise ValueError('all events must occur before stopping time') if log_period is not None: # Add log events if any tlogs = np.arange(kwargs.get('t0', 0.), tstop, log_period)[1:] if tstop not in tlogs: tlogs = np.hstack((tlogs, [tstop])) events = self.sortEvents(events + [(t, 'log') for t in tlogs]) self.initLog(logfunc, tlogs.size) else: # Otherwise, add None event at tstop events.append((tstop, None)) # Initialize system self.initialize(y0, **kwargs) # For each upcoming event for i, (tevent, xevent) in enumerate(events): self.integrateUntil( # integrate until event time tevent, remove_first=i > 0 and events[i - 1][1] == 'log') self.fireEvent(xevent) # fire event # Terminate log if any if log_period is not None: self.terminateLog() class HybridSolver(EventDrivenSolver, PeriodicSolver): def __init__(self, ykeys, dfunc, dfunc_sparse, predfunc, eventfunc, T, dense_vars, dt_dense, dt_sparse, **kwargs): ''' Initialization. :param ykeys: list of differential variables names :param dfunc: derivatives function :param dfunc_sparse: derivatives function for sparse integration periods :param predfunc: function computing the extra arguments necessary for sparse integration :param eventfunc: function called on each event :param T: periodicity (s) :param dense_vars: list of fast-evolving differential variables :param dt_dense: dense integration time step (s) :param dt_sparse: sparse integration time step (s) ''' PeriodicSolver.__init__( self, T, ykeys, dfunc, primary_vars=kwargs.get('primary_vars', None), dt=dt_dense) self.eventfunc = eventfunc self.assignEventParams(kwargs.get('event_params', None)) self.predfunc = predfunc self.dense_vars = dense_vars self.dt_sparse = dt_sparse self.sparse_solver = ode(dfunc_sparse) self.sparse_solver.set_integrator('dop853', nsteps=SOLVER_NSTEPS, atol=1e-12) @property def predfunc(self): return self._predfunc @predfunc.setter def predfunc(self, value): self.checkFunc('prediction', value) self._predfunc = value @property def dense_vars(self): return self._dense_vars @dense_vars.setter def dense_vars(self, value): if value is None: # If none specified, set all variables as dense variables value = self.ykeys if not isIterable(value): value = [value] for item in value: if item not in self.ykeys: raise ValueError(f'{item} is not a differential variable') self._dense_vars = value @property def is_dense_var(self): return np.array([x in self.dense_vars for x in self.ykeys]) @property def is_sparse_var(self): return np.invert(self.is_dense_var) def integrateSparse(self, ysparse, target_t): ''' Integrate sparse system until a specific time. :param ysparse: sparse 1-cycle solution matrix of fast-evolving variables :paramt target_t: target time (s) ''' # Compute number of samples in the sparse cycle solution npc = ysparse.shape[0] # Initialize time vector and solution array for the current interval n = int(np.ceil((target_t - self.t[-1]) / self.dt_sparse)) t = np.linspace(self.t[-1], target_t, n + 1)[1:] y = np.empty((n, self.y.shape[1])) # Initialize sparse integrator self.sparse_solver.set_initial_value(self.y[-1, self.is_sparse_var], self.t[-1]) for i, tt in enumerate(t): # Integrate to next time only if dt is above given threshold if tt - self.sparse_solver.t > MIN_SPARSE_DT: self.sparse_solver.set_f_params(self.predfunc(ysparse[i % npc])) self.sparse_solver.integrate(tt) if not self.sparse_solver.successful(): raise ValueError(self.timedlog('integration error', tt)) # Assign solution values (computed and propagated) to sparse solution array y[i, self.is_dense_var] = ysparse[i % npc, self.is_dense_var] y[i, self.is_sparse_var] = self.sparse_solver.y # Append to global solution self.append(t, y) def solve(self, y0, events, tstop, update_interval, logfunc=None, **kwargs): ''' Integrate system using a hybrid scheme: - First, the full ODE system is integrated for a few cycles with a dense time granularity until a stopping criterion is met - Second, the profiles of all variables over the last cycle are downsampled to a far lower (i.e. sparse) sampling rate - Third, a subset of the ODE system is integrated with a sparse time granularity, for the remaining of the time interval, while the remaining variables are periodically expanded from their last cycle profile. ''' # Sort events according to occurrence time events = self.sortEvents(events) # Make sure all events occur before tstop if events[-1][0] > tstop: raise ValueError('all events must occur before stopping time') # Add None event at tstop events.append((tstop, None)) # Initialize system self.initialize(y0) # Initialize event iterator ievent = iter(events) tevent, xevent = next(ievent) stop = False # While final event is not reached while not stop: # Determine end-time of current interval tend = min(tevent, self.t[-1] + update_interval) # If time interval encompasses at least one cycle, solve periodic system nmax = int(np.round((tend - self.t[-1]) / self.T)) if nmax > 0: logger.debug(self.timedlog('integrating dense system')) PeriodicSolver.solve(self, None, nmax=nmax) # If end-time of current interval has been exceeded, bound solution to that time if self.t[-1] > tend: logger.debug(self.timedlog(f'bounding system at {self.timeStr(tend)}')) self.bound((self.t[0], tend)) # If end-time of current interval has not been reached if self.t[-1] < tend: # Get solution over last cycle and resample it to sparse time step tlast, ylast = self.getCycle(-1) _, ysparse = self.resampleArrays(tlast, ylast, self.dt_sparse) # Integrate sparse system for the rest of the current interval logger.debug(self.timedlog(f'integrating sparse system until {self.timeStr(tend)}')) self.integrateSparse(ysparse, tend) # If end-time corresponds to event, fire it and move to next event if self.t[-1] == tevent: logger.debug(self.timedlog('firing event')) self.fireEvent(xevent) try: tevent, xevent = next(ievent) except StopIteration: stop = True diff --git a/PySONIC/parsers.py b/PySONIC/parsers.py index 97cbeef..3f696fa 100644 --- a/PySONIC/parsers.py +++ b/PySONIC/parsers.py @@ -1,746 +1,745 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2019-06-04 18:24:29 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2020-04-17 19:56:16 +# @Last Modified time: 2020-07-31 10:51:21 import os import logging import pprint import numpy as np import matplotlib.pyplot as plt from argparse import ArgumentParser from .utils import Intensity2Pressure, selectDirDialog, OpenFilesDialog, isIterable from .neurons import getPointNeuron, CorticalRS from .plt import GroupedTimeSeries, CompTimeSeries DEFAULT_OUTPUT_FOLDER = os.path.abspath(os.path.split(__file__)[0] + '../../../../dump') class Parser(ArgumentParser): ''' Generic parser interface. ''' dist_str = '[scale min max n]' def __init__(self): super().__init__() self.pp = pprint.PrettyPrinter(indent=4) self.defaults = {} self.allowed = {} self.factors = {} self.to_parse = {} self.addPlot() self.addVerbose() def pprint(self, args): self.pp.pprint(args) def getDistribution(self, xmin, xmax, nx, scale='lin'): if scale == 'log': xmin, xmax = np.log10(xmin), np.log10(xmax) return {'lin': np.linspace, 'log': np.logspace}[scale](xmin, xmax, nx) def getDistFromList(self, xlist): if not isinstance(xlist, list): raise TypeError('Input must be a list') if len(xlist) != 4: raise ValueError('List must contain exactly 4 arguments ([type, min, max, n])') scale = xlist[0] if scale not in ('log', 'lin'): raise ValueError('Unknown distribution type (must be "lin" or "log")') xmin, xmax = [float(x) for x in xlist[1:-1]] if xmin >= xmax: raise ValueError('Specified minimum higher or equal than specified maximum') nx = int(xlist[-1]) if nx < 2: raise ValueError('Specified number must be at least 2') return self.getDistribution(xmin, xmax, nx, scale=scale) def addRangeParam(self, key, desc, shortcut=None): if shortcut is not None: args = [f'-{shortcut}'] rangekey = shortcut else: args = [] rangekey = key args.append(f'--{key}') self.add_argument(*args, nargs='+', type=float, help=desc) self.add_argument( f'--{rangekey}range', type=str, nargs='+', help=f'Range of {desc}: {self.dist_str}') self.to_parse[key] = self.parseAmplitude def parseRangeParam(self, args, key): rangekey = f'{key}range' - params = [key, rangekey] self.restrict(args) if key in args: return np.array(args[key]) * self.factors[key] elif rangekey in args: return self.getDistFromList(args[rangekey]) * self.factors[key] def addVerbose(self): self.add_argument( '-v', '--verbose', default=False, action='store_true', help='Increase verbosity') self.to_parse['loglevel'] = self.parseLogLevel def addPlot(self): self.add_argument( '-p', '--plot', type=str, nargs='+', help='Variables to plot') self.to_parse['pltscheme'] = self.parsePltScheme def addPhase(self): self.add_argument( '--phase', default=False, action='store_true', help='Phase plot') def addMPI(self): self.add_argument( '--mpi', default=False, action='store_true', help='Use multiprocessing') def addTest(self): self.add_argument( '--test', default=False, action='store_true', help='Run test configuration') def addSave(self): self.add_argument( '-s', '--save', default=False, action='store_true', help='Save output(s)') def addCheckForOutput(self): self.add_argument( '--checkout', default=False, action='store_true', help='Run only simulations for which there is no output file in the output directory') self.to_parse['overwrite'] = self.parseOverwrite def addOverwrite(self): self.add_argument( '--overwrite', default=False, action='store_true', help='Overwrite pre-existing simulation files with new output') def addFigureExtension(self): self.add_argument( '--figext', type=str, default='png', help='Figure type (extension)') def addCompare(self, desc='Comparative graph'): self.add_argument( '--compare', default=False, action='store_true', help=desc) def addSamplingRate(self): self.add_argument( '--sr', type=int, default=1, help='Sampling rate for plot') def addSpikes(self): self.add_argument( '--spikes', type=str, default='none', help='How to indicate spikes on charge profile ("none", "marks" or "details")') def addNColumns(self): self.add_argument( '--ncol', type=int, default=1, help='Number of columns in figure') def addNLevels(self): self.add_argument( '--nlevels', type=int, default=10, help='Number of levels') def addHideOutput(self): self.add_argument( '--hide', default=False, action='store_true', help='Hide output') def addTimeRange(self, default=None): self.add_argument( '--trange', type=float, nargs=2, default=default, help='Time lower and upper bounds (ms)') self.to_parse['trange'] = self.parseTimeRange def addZvar(self, default): self.add_argument( '-z', '--zvar', type=str, default=default, help='z-variable type') def addYscale(self, default='lin'): self.add_argument( '--yscale', type=str, choices=('lin', 'log'), default=default, help='y-scale type ("lin" or "log")') def addZscale(self, default='lin'): self.add_argument( '--zscale', type=str, choices=('lin', 'log'), default=default, help='z-scale type ("lin" or "log")') def addZbounds(self, default): self.add_argument( '--zbounds', type=float, nargs=2, default=default, help='z-scale lower and upper bounds') def addCmap(self, default=None): self.add_argument( '--cmap', type=str, default=default, help='Colormap name') def addCscale(self, default='lin'): self.add_argument( '--cscale', type=str, default=default, choices=('lin', 'log'), help='Color scale ("lin" or "log")') def addInputDir(self, dep_key=None): self.inputdir_dep_key = dep_key self.add_argument( '-i', '--inputdir', type=str, help='Input directory') self.to_parse['inputdir'] = self.parseInputDir def addOutputDir(self, dep_key=None): self.outputdir_dep_key = dep_key self.add_argument( '-o', '--outputdir', type=str, help='Output directory') self.to_parse['outputdir'] = self.parseOutputDir def addInputFiles(self, dep_key=None): self.inputfiles_dep_key = dep_key self.add_argument( '-i', '--inputfiles', type=str, help='Input files') self.to_parse['inputfiles'] = self.parseInputFiles def addPatches(self): self.add_argument( '--patches', type=str, default='one', help='Stimulus patching mode ("none", "one", all", or a boolean list)') self.to_parse['patches'] = self.parsePatches def addThresholdCurve(self): self.add_argument( '--threshold', default=False, action='store_true', help='Show threshold amplitudes') def addNeuron(self): self.add_argument( '-n', '--neuron', type=str, nargs='+', help='Neuron name (string)') self.to_parse['neuron'] = self.parseNeuron def parseNeuron(self, args): return [getPointNeuron(n) for n in args['neuron']] def addInteractive(self): self.add_argument( '--interactive', default=False, action='store_true', help='Make interactive') def addLabels(self): self.add_argument( '--labels', type=str, nargs='+', default=None, help='Labels') def addRelativeTimeBounds(self): self.add_argument( '--rel_tbounds', type=float, nargs='+', default=None, help='Relative time lower and upper bounds') def addPretty(self): self.add_argument( '--pretty', default=False, action='store_true', help='Make figure pretty') def addSubset(self, choices): self.add_argument( '--subset', type=str, nargs='+', default=['all'], choices=choices + ['all'], help='Run specific subset(s)') self.subset_choices = choices self.to_parse['subset'] = self.parseSubset def parseSubset(self, args): if args['subset'] == ['all']: return self.subset_choices else: return args['subset'] def parseTimeRange(self, args): if args['trange'] is None: return None return np.array(args['trange']) * 1e-3 def parsePatches(self, args): if args['patches'] not in ('none', 'one', 'all'): return eval(args['patches']) else: return args['patches'] def parseInputFiles(self, args): if self.inputfiles_dep_key is not None and not args[self.inputfiles_dep_key]: return None elif args['inputfiles'] is None: return OpenFilesDialog('pkl')[0] def parseDir(self, key, args, title, dep_key=None): if dep_key is not None and args[dep_key] is False: return None try: if args[key] is not None: return os.path.abspath(args[key]) else: return selectDirDialog(title=title) except ValueError: raise ValueError(f'No {key} selected') def parseInputDir(self, args): return self.parseDir( 'inputdir', args, 'Select input directory', self.inputdir_dep_key) def parseOutputDir(self, args): if hasattr(self, 'outputdir') and self.outputdir is not None: return self.outputdir else: if args['outputdir'] is not None and args['outputdir'] == 'dump': return DEFAULT_OUTPUT_FOLDER else: return self.parseDir( 'outputdir', args, 'Select output directory', self.outputdir_dep_key) def parseLogLevel(self, args): return logging.DEBUG if args.pop('verbose') else logging.INFO def parsePltScheme(self, args): if args['plot'] is None or args['plot'] == ['all']: return None else: return {x: [x] for x in args['plot']} def parseOverwrite(self, args): check_for_output = args.pop('checkout') return not check_for_output def restrict(self, args, keys): if sum([args[x] is not None for x in keys]) > 1: raise ValueError( f'You must provide only one of the following arguments: {", ".join(keys)}') def parse2array(self, args, key, factor=1): return np.array(args[key]) * factor def parse(self): args = vars(super().parse_args()) for k, v in self.defaults.items(): if k in args and args[k] is None: args[k] = v if isIterable(v) else [v] for k, parse_method in self.to_parse.items(): args[k] = parse_method(args) return args @staticmethod def parsePlot(args, output): render_args = {} if 'spikes' in args: render_args['spikes'] = args['spikes'] if args['compare']: if args['plot'] == ['all']: logger.error('Specific variables must be specified for comparative plots') return for key in ['cmap', 'cscale']: if key in args: render_args[key] = args[key] for pltvar in args['plot']: comp_plot = CompTimeSeries(output, pltvar) comp_plot.render(**render_args) else: scheme_plot = GroupedTimeSeries(output, pltscheme=args['pltscheme']) scheme_plot.render(**render_args) # phase_plot = PhaseDiagram(output, args['plot'][0]) # phase_plot.render( # # trange=args['trange'], # # rel_tbounds=args['rel_tbounds'], # labels=args['labels'], # prettify=args['pretty'], # cmap=args['cmap'], # cscale=args['cscale'] # ) plt.show() class TestParser(Parser): def __init__(self, valid_subsets): super().__init__() self.addProfiling() self.addSubset(valid_subsets) def addProfiling(self): self.add_argument( '--profile', default=False, action='store_true', help='Run with profiling') class FigureParser(Parser): def __init__(self, valid_subsets): super().__init__() self.addSubset(valid_subsets) self.addSave() self.addOutputDir(dep_key='save') class PlotParser(Parser): def __init__(self): super().__init__() self.addHideOutput() self.addInputFiles() self.addOutputDir(dep_key='save') self.addSave() self.addFigureExtension() self.addCmap() self.addPretty() self.addTimeRange() self.addCscale() self.addLabels() class TimeSeriesParser(PlotParser): def __init__(self): super().__init__() self.addSpikes() self.addSamplingRate() self.addCompare() self.addPatches() class SimParser(Parser): ''' Generic simulation parser interface. ''' def __init__(self, outputdir=None): super().__init__() self.outputdir = outputdir self.addMPI() self.addOutputDir(dep_key='save') self.addSave() self.addCheckForOutput() self.addCompare() self.addCmap() self.addCscale() def parse(self): args = super().parse() return args class MechSimParser(SimParser): ''' Parser to run mechanical simulations from the command line. ''' def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.defaults.update({ 'radius': 32.0, # nm 'embedding': 0., # um 'Cm0': CorticalRS().Cm0 * 1e2, # uF/m2 'Qm0': CorticalRS().Qm0 * 1e5, # nC/m2 'freq': 500.0, # kHz 'amp': 100.0, # kPa 'charge': 0., # nC/cm2 'fs': 100. # % }) self.factors.update({ 'radius': 1e-9, 'embedding': 1e-6, 'Cm0': 1e-2, 'Qm0': 1e-5, 'freq': 1e3, 'amp': 1e3, 'charge': 1e-5, 'fs': 1e-2 }) self.addRadius() self.addEmbedding() self.addCm0() self.addQm0() self.addFrequency() self.addAmplitude() self.addCharge() self.addFs() def addRadius(self): self.add_argument( '-a', '--radius', nargs='+', type=float, help='Sonophore radius (nm)') def addEmbedding(self): self.add_argument( '--embedding', nargs='+', type=float, help='Embedding depth (um)') def addCm0(self): self.add_argument( '--Cm0', type=float, nargs='+', help='Resting membrane capacitance (uF/cm2)') def addQm0(self): self.add_argument( '--Qm0', type=float, nargs='+', help='Resting membrane charge density (nC/cm2)') def addFrequency(self): self.add_argument( '-f', '--freq', nargs='+', type=float, help='US frequency (kHz)') def addAmplitude(self): self.add_argument( '-A', '--amp', nargs='+', type=float, help='Acoustic pressure amplitude (kPa)') self.add_argument( '--Arange', type=str, nargs='+', help=f'Amplitude range {self.dist_str} (kPa)') self.add_argument( '-I', '--intensity', nargs='+', type=float, help='Acoustic intensity (W/cm2)') self.add_argument( '--Irange', type=str, nargs='+', help=f'Intensity range {self.dist_str} (W/cm2)') self.to_parse['amp'] = self.parseAmplitude def addCharge(self): self.add_argument( '-Q', '--charge', nargs='+', type=float, help='Membrane charge density (nC/cm2)') def addFs(self): self.add_argument( '--fs', nargs='+', type=float, help='Sonophore coverage fraction (%%)') self.add_argument( '--spanFs', default=False, action='store_true', help='Span Fs from 1 to 100%%') self.to_parse['fs'] = self.parseFs def parseAmplitude(self, args): params = ['Irange', 'Arange', 'intensity', 'amp'] self.restrict(args, params[:-1]) Irange, Arange, Int, A = [args.pop(k) for k in params] if Irange is not None: amps = Intensity2Pressure(self.getDistFromList(Irange) * 1e4) # Pa elif Int is not None: amps = Intensity2Pressure(np.array(Int) * 1e4) # Pa elif Arange is not None: amps = self.getDistFromList(Arange) * self.factors['amp'] # Pa else: amps = np.array(A) * self.factors['amp'] # Pa return amps def parseFs(self, args): if args.pop('spanFs', False): return np.arange(1, 101) * self.factors['fs'] # (-) else: return np.array(args['fs']) * self.factors['fs'] # (-) def parse(self): args = super().parse() for key in ['radius', 'embedding', 'Cm0', 'Qm0', 'freq', 'charge']: args[key] = self.parse2array(args, key, factor=self.factors[key]) return args @staticmethod def parseSimInputs(args): return [args[k] for k in ['freq', 'amp', 'charge']] class NeuronSimParser(SimParser): def __init__(self): super().__init__() self.defaults.update({ 'neuron': 'RS', 'tstim': 100.0, # ms 'toffset': 50. # ms }) self.factors.update({ 'tstim': 1e-3, 'toffset': 1e-3 }) self.addNeuron() self.addTstim() self.addToffset() def addTstim(self): self.add_argument( '-t', '--tstim', nargs='+', type=float, help='Stimulus / burst duration (ms)') def addToffset(self): self.add_argument( '--toffset', nargs='+', type=float, help='Offset duration (ms)') class VClampParser(NeuronSimParser): def __init__(self): super().__init__() self.defaults.update({ 'vhold': -70.0, # mV 'vstep': 0.0 # mV }) self.factors.update({ 'vhold': 1., 'vstep': 1. }) self.addVhold() self.addVstep() def addVhold(self): self.add_argument( '--vhold', nargs='+', type=float, help='Held membrane potential (mV)') def addVstep(self): self.add_argument( '--vstep', nargs='+', type=float, help='Step membrane potential (mV)') self.add_argument( '--vsteprange', type=str, nargs='+', help=f'Step membrane potential range {self.dist_str} (mV)') self.to_parse['vstep'] = self.parseVstep def parseVstep(self, args): vstep_range, vstep = [args.pop(k) for k in ['vsteprange', 'vstep']] if vstep_range is not None: vsteps = self.getDistFromList(vstep_range) # mV else: vsteps = np.array(vstep) # mV return vsteps def parse(self, args=None): if args is None: args = super().parse() for key in ['vhold', 'vstep', 'tstim', 'toffset']: args[key] = self.parse2array(args, key, factor=self.factors[key]) return args @staticmethod def parseSimInputs(args): return [args[k] for k in ['vhold', 'vstep', 'tstim', 'toffset']] class PWSimParser(NeuronSimParser): ''' Generic parser interface to run PW patterned simulations from the command line. ''' def __init__(self): super().__init__() self.defaults.update({ 'PRF': 100.0, # Hz 'DC': 100.0, # % 'BRF': 1.0, # Hz 'nbursts': 1, # (-) }) self.factors.update({ 'PRF': 1., 'DC': 1e-2, 'BRF': 1., 'nbursts': 1, }) self.allowed.update({ 'DC': range(101) }) self.addPRF() self.addDC() self.addBRF() self.addNBursts() self.addTitrate() self.addSpikes() def addPRF(self): self.add_argument( '--PRF', nargs='+', type=float, help='PRF (Hz)') def addDC(self): self.add_argument( '--DC', nargs='+', type=float, help='Duty cycle (%%)') self.add_argument( '--spanDC', default=False, action='store_true', help='Span DC from 1 to 100%%') self.to_parse['DC'] = self.parseDC def addBRF(self): self.add_argument( '--BRF', nargs='+', type=float, help='Burst repetition frequency (Hz)') def addNBursts(self): self.add_argument( '--nbursts', nargs='+', type=int, help='Number of bursts') def addTitrate(self): self.add_argument( '--titrate', default=False, action='store_true', help='Perform titration') def parseAmplitude(self, args): raise NotImplementedError def parseDC(self, args): if args.pop('spanDC'): return np.arange(1, 101) * self.factors['DC'] # (-) else: return np.array(args['DC']) * self.factors['DC'] # (-) def parse(self, args=None): if args is None: args = super().parse() for key in ['tstim', 'toffset', 'PRF']: args[key] = self.parse2array(args, key, factor=self.factors[key]) return args @staticmethod def parseSimInputs(args): keys = ['amp', 'tstim', 'toffset', 'PRF', 'DC'] if len(args['nbursts']) > 1 or args['nbursts'][0] > 1: del keys[2] keys += ['BRF', 'nbursts'] return [args[k] for k in keys] class EStimParser(PWSimParser): ''' Parser to run E-STIM simulations from the command line. ''' def __init__(self): super().__init__() self.defaults.update({'amp': 10.0}) # mA/m2 self.factors.update({'amp': 1.}) self.addAmplitude() def addAmplitude(self): self.add_argument( '-A', '--amp', nargs='+', type=float, help='Amplitude of injected current density (mA/m2)') self.add_argument( '--Arange', type=str, nargs='+', help=f'Amplitude range {self.dist_str} (mA/m2)') self.to_parse['amp'] = self.parseAmplitude def addVext(self): self.add_argument( '--Vext', nargs='+', type=float, help='Extracellular potential (mV)') def parseAmplitude(self, args): if args.pop('titrate'): return None Arange, A = [args.pop(k) for k in ['Arange', 'amp']] if Arange is not None: amps = self.getDistFromList(Arange) * self.factors['amp'] # mA/m2 else: amps = np.array(A) * self.factors['amp'] # mA/m2 return amps class AStimParser(PWSimParser, MechSimParser): ''' Parser to run A-STIM simulations from the command line. ''' def __init__(self): MechSimParser.__init__(self) PWSimParser.__init__(self) self.defaults.update({'method': 'sonic'}) self.allowed.update({'method': ['full', 'hybrid', 'sonic']}) self.addMethod() self.addQSSVars() def addMethod(self): self.add_argument( '-m', '--method', nargs='+', type=str, help=f'Numerical integration method ({", ".join(self.allowed["method"])})') self.to_parse['method'] = self.parseMethod def parseMethod(self, args): for item in args['method']: if item not in self.allowed['method']: raise ValueError(f'Unknown method type: "{item}"') return args['method'] def addQSSVars(self): self.add_argument( '--qss', nargs='+', type=str, help='QSS variables') def parseAmplitude(self, args): if args.pop('titrate'): return None return MechSimParser.parseAmplitude(self, args) def parse(self): args = PWSimParser.parse(self, args=MechSimParser.parse(self)) for k in ['Cm0', 'Qm0', 'embedding', 'charge']: del args[k] return args @staticmethod def parseSimInputs(args): return [args['freq']] + PWSimParser.parseSimInputs(args) + [args[k] for k in ['fs', 'method', 'qss']] diff --git a/PySONIC/plt/pltutils.py b/PySONIC/plt/pltutils.py index d38966a..f031dd3 100644 --- a/PySONIC/plt/pltutils.py +++ b/PySONIC/plt/pltutils.py @@ -1,457 +1,458 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2017-08-21 14:33:36 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2020-07-22 10:57:25 +# @Last Modified time: 2020-08-03 17:04:22 ''' Useful functions to generate plots. ''' import re import numpy as np import pandas as pd import matplotlib from matplotlib.patches import Rectangle from matplotlib import cm, colors import matplotlib.pyplot as plt from ..core import getModel from ..utils import * # Matplotlib parameters matplotlib.rcParams['pdf.fonttype'] = 42 matplotlib.rcParams['ps.fonttype'] = 42 matplotlib.rcParams['font.family'] = 'arial' def getSymmetricCmap(cmap_key): cmap = plt.get_cmap(cmap_key) cl = np.vstack((cmap.colors, cmap.reversed().colors)) return colors.LinearSegmentedColormap.from_list(f'sym_{cmap_key}', cl) for k in ['viridis', 'plasma', 'inferno', 'magma', 'cividis']: for cmap_key in [k, f'{k}_r']: sym_cmap = getSymmetricCmap(cmap_key) plt.register_cmap(name=sym_cmap.name, cmap=sym_cmap) def cm2inch(*tupl): inch = 2.54 if isinstance(tupl[0], tuple): return tuple(i / inch for i in tupl[0]) else: return tuple(i / inch for i in tupl) def extractPltVar(model, pltvar, df, meta=None, nsamples=0, name=''): if 'func' in pltvar: s = pltvar['func'] if not s.startswith('meta'): s = f'model.{s}' try: var = eval(s) except AttributeError as err: if hasattr(model, 'pneuron'): var = eval(s.replace('model', 'model.pneuron')) else: raise err elif 'key' in pltvar: var = df[pltvar['key']] elif 'constant' in pltvar: var = eval(pltvar['constant']) * np.ones(nsamples) else: var = df[name] if isinstance(var, pd.Series): var = var.values var = var.copy() if var.size == nsamples - 1: var = np.insert(var, 0, var[0]) var *= pltvar.get('factor', 1) return var def setGrid(n, ncolmax=3): ''' Determine number of rows and columns in figure grid, based on number of variables to plot. ''' if n <= ncolmax: return (1, n) else: return ((n - 1) // ncolmax + 1, ncolmax) def setNormalizer(cmap, bounds, scale='lin'): norm = { 'lin': colors.Normalize, - 'log': colors.LogNorm + 'log': colors.LogNorm, + 'symlog': colors.SymLogNorm }[scale](*bounds) sm = cm.ScalarMappable(norm=norm, cmap=cmap) sm._A = [] return norm, sm class GenericPlot: def __init__(self, outputs): ''' Constructor. :param outputs: list / generator of simulation outputs ''' try: iter(outputs) except TypeError: outputs = [outputs] self.outputs = outputs def __call__(self, *args, **kwargs): return self.render(*args, **kwargs) def figtitle(self, model, meta): return model.desc(meta) @staticmethod def wraptitle(ax, title, maxwidth=120, sep=':', fs=10, y=1.0): if len(title) > maxwidth: title = '\n'.join(title.split(sep)) y = 0.94 h = ax.set_title(title, fontsize=fs) h.set_y(y) @staticmethod def getData(entry, frequency=1, trange=None): if entry is None: raise ValueError('non-existing data') if isinstance(entry, str): data, meta = loadData(entry, frequency) else: data, meta = entry data = data.iloc[::frequency] if trange is not None: tmin, tmax = trange data = data.loc[(data['t'] >= tmin) & (data['t'] <= tmax)] return data, meta def render(self, *args, **kwargs): raise NotImplementedError @staticmethod def getSimType(fname): ''' Get sim type from filename. ''' mo = re.search('(^[A-Z]*)_(.*).pkl', fname) if not mo: raise ValueError(f'Could not find sim-key in filename: "{fname}"') return mo.group(1) @staticmethod def getModel(*args, **kwargs): return getModel(*args, **kwargs) @staticmethod def getTimePltVar(tscale): ''' Return time plot variable for a given temporal scale. ''' return { 'desc': 'time', 'label': 'time', 'unit': tscale, 'factor': {'ms': 1e3, 'us': 1e6}[tscale], 'onset': {'ms': 1e-3, 'us': 1e-6}[tscale] } @staticmethod def createBackBone(*args, **kwargs): raise NotImplementedError @staticmethod def prettify(ax, xticks=None, yticks=None, xfmt='{:.0f}', yfmt='{:+.0f}'): try: ticks = ax.get_ticks() ticks = (min(ticks), max(ticks)) ax.set_ticks(ticks) ax.set_ticklabels([xfmt.format(x) for x in ticks]) except AttributeError: if xticks is None: xticks = ax.get_xticks() xticks = (min(xticks), max(xticks)) if yticks is None: yticks = ax.get_yticks() yticks = (min(yticks), max(yticks)) ax.set_xticks(xticks) ax.set_yticks(yticks) if xfmt is not None: ax.set_xticklabels([xfmt.format(x) for x in xticks]) if yfmt is not None: ax.set_yticklabels([yfmt.format(y) for y in yticks]) @staticmethod def addInset(fig, ax, inset): ''' Create inset axis. ''' inset_ax = fig.add_axes(ax.get_position()) inset_ax.set_zorder(1) inset_ax.set_xlim(inset['xlims'][0], inset['xlims'][1]) inset_ax.set_ylim(inset['ylims'][0], inset['ylims'][1]) inset_ax.set_xticks([]) inset_ax.set_yticks([]) inset_ax.add_patch(Rectangle((inset['xlims'][0], inset['ylims'][0]), inset['xlims'][1] - inset['xlims'][0], inset['ylims'][1] - inset['ylims'][0], color='w')) return inset_ax @staticmethod def materializeInset(ax, inset_ax, inset): ''' Materialize inset with zoom boox. ''' # Re-position inset axis axpos = ax.get_position() left, right, = rescale(inset['xcoords'], ax.get_xlim()[0], ax.get_xlim()[1], axpos.x0, axpos.x0 + axpos.width) bottom, top, = rescale(inset['ycoords'], ax.get_ylim()[0], ax.get_ylim()[1], axpos.y0, axpos.y0 + axpos.height) inset_ax.set_position([left, bottom, right - left, top - bottom]) for i in inset_ax.spines.values(): i.set_linewidth(2) # Materialize inset target region with contour frame ax.plot(inset['xlims'], [inset['ylims'][0]] * 2, linestyle='-', color='k') ax.plot(inset['xlims'], [inset['ylims'][1]] * 2, linestyle='-', color='k') ax.plot([inset['xlims'][0]] * 2, inset['ylims'], linestyle='-', color='k') ax.plot([inset['xlims'][1]] * 2, inset['ylims'], linestyle='-', color='k') # Link target and inset with dashed lines if possible if inset['xcoords'][1] < inset['xlims'][0]: ax.plot([inset['xcoords'][1], inset['xlims'][0]], [inset['ycoords'][0], inset['ylims'][0]], linestyle='--', color='k') ax.plot([inset['xcoords'][1], inset['xlims'][0]], [inset['ycoords'][1], inset['ylims'][1]], linestyle='--', color='k') elif inset['xcoords'][0] > inset['xlims'][1]: ax.plot([inset['xcoords'][0], inset['xlims'][1]], [inset['ycoords'][0], inset['ylims'][0]], linestyle='--', color='k') ax.plot([inset['xcoords'][0], inset['xlims'][1]], [inset['ycoords'][1], inset['ylims'][1]], linestyle='--', color='k') else: logger.warning('Inset x-coordinates intersect with those of target region') def postProcess(self, *args, **kwargs): raise NotImplementedError @staticmethod def removeSpines(ax): for item in ['top', 'right']: ax.spines[item].set_visible(False) @staticmethod def setXTicks(ax, xticks=None): if xticks is not None: ax.set_xticks(xticks) @staticmethod def setYTicks(ax, yticks=None): if yticks is not None: ax.set_yticks(yticks) @staticmethod def setTickLabelsFontSize(ax, fs): for tick in ax.xaxis.get_major_ticks() + ax.yaxis.get_major_ticks(): tick.label.set_fontsize(fs) @staticmethod def setXLabel(ax, xplt, fs): ax.set_xlabel('$\\rm {}\ ({})$'.format(xplt["label"], xplt["unit"]), fontsize=fs) @staticmethod def setYLabel(ax, yplt, fs): ax.set_ylabel('$\\rm {}\ ({})$'.format(yplt["label"], yplt.get("unit", "")), fontsize=fs) @classmethod def addCmap(cls, fig, cmap, handles, comp_values, comp_info, fs, prettify, zscale='lin'): if all(isinstance(x, str) for x in comp_values): # If list of strings, assume that index suffixes can be extracted prefix, suffixes = extractCommonPrefix(comp_values) comp_values = [int(s) for s in suffixes] desc_str = f'{prefix}\ index' else: # Rescale comparison values and adjust unit comp_values = np.asarray(comp_values) * comp_info.get('factor', 1.) comp_factor, comp_prefix = getSIpair(comp_values, scale=zscale) comp_values /= comp_factor comp_info['unit'] = comp_prefix + comp_info['unit'] desc_str = comp_info["desc"].replace(" ", "\ ") if len(comp_info['unit']) > 0: desc_str = f"{desc_str}\ ({comp_info['unit']})" nvalues = len(comp_values) # Create colormap and normalizer try: mymap = plt.get_cmap(cmap) except ValueError: mymap = plt.get_cmap(swapFirstLetterCase(cmap)) norm, sm = setNormalizer(mymap, (min(comp_values), max(comp_values)), zscale) # Extract and adjust line colors zcolors = sm.to_rgba(comp_values) for lh, c in zip(handles, zcolors): if isIterable(lh): for item in lh: item.set_color(c) else: lh.set_color(c) # Add colorbar fig.subplots_adjust(left=0.1, right=0.8, bottom=0.15, top=0.95, hspace=0.5) cbarax = fig.add_axes([0.85, 0.15, 0.03, 0.8]) cbar_kwargs = {} if all(isinstance(x, int) for x in comp_values): bounds = np.arange(nvalues + 1) + min(comp_values) - 0.5 ticks = bounds[:-1] + 0.5 if nvalues > 10: ticks = [ticks[0], ticks[-1]] cbar_kwargs.update({'ticks': ticks, 'boundaries': bounds, 'format': '%1i'}) cbarax.tick_params(axis='both', which='both', length=0) cbar = fig.colorbar(sm, cax=cbarax, **cbar_kwargs) cbarax.set_ylabel(f'$\\rm {desc_str}$', fontsize=fs) if prettify: cls.prettify(cbar) for item in cbarax.get_yticklabels(): item.set_fontsize(fs) class ComparativePlot(GenericPlot): def __init__(self, outputs, varname): ''' Constructor. :param outputs: list /generator of simulation outputs to be compared. :param varname: name of variable to extract and compare. ''' super().__init__(outputs) self.varname = varname self.comp_ref_key = None self.meta_ref = None self.comp_info = None self.is_unique_comp = False def checkLabels(self, labels): if labels is not None: if not isIterable(labels): raise TypeError('Invalid labels: must be an iterable') if not all(isinstance(x, str) for x in labels): raise TypeError('Invalid labels: must be string typed') def checkSimType(self, meta): ''' Check consistency of sim types across files. ''' if meta['simkey'] != self.meta_ref['simkey']: raise ValueError('Invalid comparison: different simulation types') def checkCompValues(self, meta, comp_values): ''' Check consistency of differing values across files. ''' # Get differing values across meta dictionaries diffs = differing(self.meta_ref, meta, subdkey='meta') # Check that only one value differs if len(diffs) > 1: logger.warning('More than one differing inputs') self.comp_ref_key = None return [] # Get the key and differing values zkey, refval, val = diffs[0] # If no comparison key yet, fill it up if self.comp_ref_key is None: self.comp_ref_key = zkey self.is_unique_comp = True comp_values += [refval, val] # Otherwise, check that comparison matches the existing one else: if zkey != self.comp_ref_key: logger.warning('inconsistent differing inputs') self.comp_ref_key = None return [] else: comp_values.append(val) return comp_values def checkConsistency(self, meta, comp_values): ''' Check consistency of sim types and check differing inputs. ''' if self.meta_ref is None: self.meta_ref = meta else: self.checkSimType(meta) comp_values = self.checkCompValues(meta, comp_values) if self.comp_ref_key is None: self.is_unique_comp = False return comp_values def getCompLabels(self, comp_values): if self.comp_info is not None: comp_values = np.array(comp_values) * self.comp_info.get('factor', 1) if 'unit' in self.comp_info: p = self.comp_info.get('precision', 0) comp_values = [f"{si_format(v, p)}{self.comp_info['unit']}".replace(' ', '\ ') for v in comp_values] comp_labels = ['$\\rm{} = {}$'.format(self.comp_info['label'], x) for x in comp_values] else: comp_labels = comp_values return comp_values, comp_labels def chooseLabels(self, labels, comp_labels, full_labels): if labels is not None: return labels else: if self.is_unique_comp: return comp_labels else: return full_labels @staticmethod def getCommonLabel(lbls, seps='_'): ''' Get a common label from a list of labels, by removing parts that differ across them. ''' # Split every label according to list of separator characters, and save splitters as well splt_lbls = [re.split(f'([{seps}])', x) for x in lbls] pieces = [x[::2] for x in splt_lbls] splitters = [x[1::2] for x in splt_lbls] ncomps = len(pieces[0]) # Assert that splitters are equivalent across all labels, and reduce them to a single array assert (x == x[0] for x in splitters), 'Inconsistent splitters' splitters = np.array(splitters[0]) # Transform pieces into 2D matrix, and evaluate equality of every piece across labels pieces = np.array(pieces).T all_identical = [np.all(x == x[0]) for x in pieces] if np.sum(all_identical) < ncomps - 1: logger.warning('More than one differing inputs') return '' # Discard differing pieces and remove associated splitters pieces = pieces[all_identical, 0] splitters = splitters[all_identical[:-1]] # Remove last splitter if the last pieces were discarded if splitters.size == pieces.size: splitters = splitters[:-1] # Join common pieces and associated splitters into a single label common_lbl = '' for p, s in zip(pieces, splitters): common_lbl += f'{p}{s}' common_lbl += pieces[-1] return common_lbl.replace('( ', '(') def addExcitationInset(ax, is_excited): ''' Add text inset on axis stating excitation status. ''' ax.text( 0.7, 0.7, f'{"" if is_excited else "not "}excited', transform=ax.transAxes, ha='center', va='center', size=30, bbox=dict( boxstyle='round', fc=(0.8, 1.0, 0.8) if is_excited else (1., 0.8, 0.8) ))