diff --git a/PySONIC/parsers.py b/PySONIC/parsers.py index 65af52c..e8fcec9 100644 --- a/PySONIC/parsers.py +++ b/PySONIC/parsers.py @@ -1,512 +1,536 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2019-06-04 18:24:29 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2019-06-17 17:47:27 +# @Last Modified time: 2019-06-17 18:52:05 import logging import pprint import numpy as np from argparse import ArgumentParser from .utils import Intensity2Pressure, selectDirDialog, OpenFilesDialog, isIterable from .neurons import getPointNeuron class Parser(ArgumentParser): ''' Generic parser interface. ''' dist_str = '[scale min max n]' def __init__(self): super().__init__() self.pp = pprint.PrettyPrinter(indent=4) self.defaults = {} self.allowed = {} self.factors = {} self.to_parse = {} self.addPlot() self.addVerbose() def pprint(self, args): self.pp.pprint(args) def getDistribution(self, xmin, xmax, nx, scale='lin'): if scale == 'log': xmin, xmax = np.log10(xmin), np.log10(xmax) return {'lin': np.linspace, 'log': np.logspace}[scale](xmin, xmax, nx) def getDistFromList(self, xlist): if not isinstance(xlist, list): raise TypeError('Input must be a list') if len(xlist) != 4: raise ValueError('List must contain exactly 4 arguments ([type, min, max, n])') scale = xlist[0] if scale not in ('log', 'lin'): raise ValueError('Unknown distribution type (must be "lin" or "log")') xmin, xmax = [float(x) for x in xlist[1:-1]] if xmin >= xmax: raise ValueError('Specified minimum higher or equal than specified maximum') nx = int(xlist[-1]) if nx < 2: raise ValueError('Specified number must be at least 2') return self.getDistribution(xmin, xmax, nx, scale=scale) def addVerbose(self): self.add_argument( '-v', '--verbose', default=False, action='store_true', help='Increase verbosity') self.to_parse['loglevel'] = self.parseLogLevel def addPlot(self): self.add_argument( '-p', '--plot', type=str, nargs='+', help='Variables to plot') self.to_parse['pltscheme'] = self.parsePltScheme def addMPI(self): self.add_argument( '--mpi', default=False, action='store_true', help='Use multiprocessing') def addTest(self): self.add_argument( '--test', default=False, action='store_true', help='Run test configuration') def addSave(self): self.add_argument( '-s', '--save', default=False, action='store_true', help='Save output figure(s)') def addFigureExtension(self): self.add_argument( '--figext', type=str, default='png', help='Figure type (extension)') def addCompare(self, desc='Comparative graph'): self.add_argument( '--compare', default=False, action='store_true', help=desc) def addSamplingRate(self): self.add_argument( '--sr', type=int, default=1, help='Sampling rate for plot') def addSpikes(self): self.add_argument( '--spikes', type=str, default='none', help='How to indicate spikes on charge profile ("none", "marks" or "details")') def addHideOutput(self): self.add_argument( '--hide', default=False, action='store_true', help='Hide output') def addTimeRange(self, default=None): self.add_argument( '--trange', type=float, nargs=2, default=default, help='Time lower and upper bounds (ms)') self.to_parse['trange'] = self.parseTimeRange def addPotentialBounds(self, default=None): self.add_argument( '--Vbounds', type=float, nargs=2, default=default, help='Membrane potential lower and upper bounds (mV)') def addFiringRateBounds(self, default): self.add_argument( '--FRbounds', type=float, nargs=2, default=default, help='Firing rate lower and upper bounds (Hz)') def addFiringRateScale(self, default='lin'): self.add_argument( '--FRscale', type=str, choices=('lin', 'log'), default=default, help='Firing rate scale for plot ("lin" or "log")') def addCmap(self, default=None): self.add_argument( '--cmap', type=str, default=default, help='Colormap name') def addCscale(self, default='lin'): self.add_argument( '--cscale', type=str, default=default, choices=('lin', 'log'), help='Color scale ("lin" or "log")') def addInputDir(self, dep_key=None): self.inputdir_dep_key = dep_key self.add_argument( '-i', '--inputdir', type=str, help='Input directory') self.to_parse['inputdir'] = self.parseInputDir def addOutputDir(self, dep_key=None): self.outputdir_dep_key = dep_key self.add_argument( '-o', '--outputdir', type=str, help='Output directory') self.to_parse['outputdir'] = self.parseOutputDir def addInputFiles(self, dep_key=None): self.inputfiles_dep_key = dep_key self.add_argument( '-i', '--inputfiles', type=str, help='Input files') self.to_parse['inputfiles'] = self.parseInputFiles def addPatches(self): self.add_argument( '--patches', type=str, default='one', help='Stimulus patching mode ("none", "one", all", or a boolean list)') self.to_parse['patches'] = self.parsePatches def addThresholdCurve(self): self.add_argument( '--threshold', default=False, action='store_true', help='Show threshold amplitudes') def addNeuron(self): self.add_argument( '-n', '--neuron', type=str, nargs='+', help='Neuron name (string)') self.to_parse['neuron'] = self.parseNeuron def parseNeuron(self, args): return [getPointNeuron(n) for n in args['neuron']] def addInteractive(self): self.add_argument( '--interactive', default=False, action='store_true', help='Make interactive') def addLabels(self): self.add_argument( '--labels', type=str, nargs='+', default=None, help='Labels') def addRelativeTimeBounds(self): self.add_argument( '--rel_tbounds', type=float, nargs='+', default=None, help='Relative time lower and upper bounds') def addPretty(self): self.add_argument( '--pretty', default=False, action='store_true', help='Make figure pretty') def parseTimeRange(self, args): if args['trange'] is None: return None return np.array(args['trange']) * 1e-3 def parsePatches(self, args): if args['patches'] not in ('none', 'one', 'all'): return eval(args['patches']) else: return args['patches'] def parseInputFiles(self, args): if self.inputfiles_dep_key is not None and not args[self.inputfiles_dep_key]: return None elif args['inputfiles'] is None: return OpenFilesDialog('pkl')[0] def parseDir(self, key, args, title, dep_key=None): if dep_key is not None and args[dep_key] is False: return None try: if args[key] is not None: return args[key] else: return selectDirDialog(title=title) except ValueError: raise ValueError('No {} selected'.format(key)) def parseInputDir(self, args): return self.parseDir( 'inputdir', args, 'Select input directory', self.inputdir_dep_key) def parseOutputDir(self, args): if hasattr(self, 'outputdir') and self.outputdir is not None: return self.outputdir else: return self.parseDir( 'outputdir', args, 'Select output directory', self.outputdir_dep_key) def parseLogLevel(self, args): return logging.DEBUG if args.pop('verbose') else logging.INFO def parsePltScheme(self, args): if args['plot'] is None or args['plot'] == ['all']: return None else: return {x: [x] for x in args['plot']} def restrict(self, args, keys): if sum([args[x] is not None for x in keys]) > 1: raise ValueError( 'You must provide only one of the following arguments: {}'.format(', '.join(keys))) def parse2array(self, args, key, factor=1): return np.array(args[key]) * factor def parse(self): args = vars(super().parse_args()) for k, v in self.defaults.items(): if k in args and args[k] is None: args[k] = v if isIterable(v) else [v] for k, parse_method in self.to_parse.items(): args[k] = parse_method(args) return args +class PlotParser(Parser): + def __init__(self): + super().__init__() + self.addHideOutput() + self.addInputFiles() + self.addOutputDir(dep_key='save') + self.addSave() + self.addFigureExtension() + self.addCmap() + self.addPretty() + self.addTimeRange() + self.addCscale() + self.addLabels() + + +class TimeSeriesParser(PlotParser): + def __init__(self): + super().__init__() + self.addSpikes() + self.addSamplingRate() + self.addCompare() + self.addPatches() + + class SimParser(Parser): ''' Generic simulation parser interface. ''' def __init__(self, outputdir=None): super().__init__() self.outputdir = outputdir self.addMPI() self.addOutputDir() def parse(self): args = super().parse() return args class MechSimParser(SimParser): ''' Parser to run mechanical simulations from the command line. ''' def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.defaults.update({ 'radius': 32.0, # nm 'embedding': 0., # um 'Cm0': getPointNeuron('RS').Cm0 * 1e2, # uF/m2 'Qm0': getPointNeuron('RS').Qm0 * 1e5, # nC/m2 'freq': 500.0, # kHz 'amp': 100.0, # kPa 'charge': 0., # nC/cm2 'fs': 100. # % }) self.factors.update({ 'radius': 1e-9, 'embedding': 1e-6, 'Cm0': 1e-2, 'Qm0': 1e-5, 'freq': 1e3, 'amp': 1e3, 'charge': 1e-5, 'fs': 1e-2 }) self.addRadius() self.addEmbedding() self.addCm0() self.addQm0() self.addFdrive() self.addAdrive() self.addCharge() self.addFs() def addRadius(self): self.add_argument( '-a', '--radius', nargs='+', type=float, help='Sonophore radius (nm)') def addEmbedding(self): self.add_argument( '--embedding', nargs='+', type=float, help='Embedding depth (um)') def addCm0(self): self.add_argument( '--Cm0', type=float, help='Resting membrane capacitance (uF/cm2)') def addQm0(self): self.add_argument( '--Qm0', type=float, help='Resting membrane charge density (nC/cm2)') def addFdrive(self): self.add_argument( '-f', '--freq', nargs='+', type=float, help='US frequency (kHz)') def addAdrive(self): self.add_argument( '-A', '--amp', nargs='+', type=float, help='Acoustic pressure amplitude (kPa)') self.add_argument( '--Arange', type=str, nargs='+', help='Amplitude range {} (kPa)'.format(self.dist_str)) self.add_argument( '-I', '--intensity', nargs='+', type=float, help='Acoustic intensity (W/cm2)') self.add_argument( '--Irange', type=str, nargs='+', help='Intensity range {} (W/cm2)'.format(self.dist_str)) self.to_parse['amp'] = self.parseAmp def addAscale(self, default='lin'): self.add_argument( '--Ascale', type=str, choices=('lin', 'log'), default=default, help='Amplitude scale for plot ("lin" or "log")') def addCharge(self): self.add_argument( '-Q', '--charge', nargs='+', type=float, help='Membrane charge density (nC/cm2)') def addFs(self): self.add_argument( '--fs', nargs='+', type=float, help='Sonophore coverage fraction (%%)') self.add_argument( '--spanFs', default=False, action='store_true', help='Span Fs from 1 to 100%%') self.to_parse['fs'] = self.parseFs def parseAmp(self, args): params = ['Irange', 'Arange', 'intensity', 'amp'] self.restrict(args, params[:-1]) Irange, Arange, Int, Adrive = [args.pop(k) for k in params] if Irange is not None: amps = Intensity2Pressure(self.getDistFromList(Irange) * 1e4) # Pa elif Int is not None: amps = Intensity2Pressure(np.array(Int) * 1e4) # Pa elif Arange is not None: amps = self.getDistFromList(Arange) * self.factors['amp'] # Pa else: amps = np.array(Adrive) * self.factors['amp'] # Pa return amps def parseFs(self, args): if args.pop('spanFs', False): return np.arange(1, 101) * self.factors['fs'] # (-) else: return np.array(args['fs']) * self.factors['fs'] # (-) def parse(self): args = super().parse() for key in ['radius', 'embedding', 'Cm0', 'Qm0', 'freq', 'charge']: args[key] = self.parse2array(args, key, factor=self.factors[key]) return args class PWSimParser(SimParser): ''' Generic parser interface to run PW patterned simulations from the command line. ''' def __init__(self): super().__init__() self.defaults.update({ 'neuron': 'RS', 'tstim': 100.0, # ms 'toffset': 50., # ms 'PRF': 100.0, # Hz 'DC': 100.0 # % }) self.factors.update({ 'tstim': 1e-3, 'toffset': 1e-3, 'PRF': 1., 'DC': 1e-2 }) self.allowed.update({ 'DC': range(101) }) self.addNeuron() self.addTstim() self.addToffset() self.addPRF() self.addDC() self.addTitrate() self.addSpikes() def addTstim(self): self.add_argument( '-t', '--tstim', nargs='+', type=float, help='Stimulus duration (ms)') def addToffset(self): self.add_argument( '--toffset', nargs='+', type=float, help='Offset duration (ms)') def addPRF(self): self.add_argument( '--PRF', nargs='+', type=float, help='PRF (Hz)') def addDC(self): self.add_argument( '--DC', nargs='+', type=float, help='Duty cycle (%%)') self.add_argument( '--spanDC', default=False, action='store_true', help='Span DC from 1 to 100%%') self.to_parse['DC'] = self.parseDC def addTitrate(self): self.add_argument( '--titrate', default=False, action='store_true', help='Perform titration') def parseAmp(self, args): return NotImplementedError def parseDC(self, args): if args.pop('spanDC'): return np.arange(1, 101) * self.factors['DC'] # (-) else: return np.array(args['DC']) * self.factors['DC'] # (-) def parse(self, args=None): if args is None: args = super().parse() for key in ['tstim', 'toffset', 'PRF']: args[key] = self.parse2array(args, key, factor=self.factors[key]) return args class EStimParser(PWSimParser): ''' Parser to run E-STIM simulations from the command line. ''' def __init__(self): super().__init__() self.defaults.update({'amp': 10.0}) # mA/m2 self.factors.update({'amp': 1.}) self.addAstim() def addAstim(self): self.add_argument( '-A', '--amp', nargs='+', type=float, help='Amplitude of injected current density (mA/m2)') self.add_argument( '--Arange', type=str, nargs='+', help='Amplitude range {} (mA/m2)'.format(self.dist_str)) self.to_parse['amp'] = self.parseAmp def parseAmp(self, args): if args.pop('titrate'): return None Arange, Astim = [args.pop(k) for k in ['Arange', 'amp']] if Arange is not None: amps = self.getDistFromList(Arange) * self.factors['amp'] # mA/m2 else: amps = np.array(Astim) * self.factors['amp'] # mA/m2 return amps def parse(self): args = super().parse() return args class AStimParser(PWSimParser, MechSimParser): ''' Parser to run A-STIM simulations from the command line. ''' def __init__(self): MechSimParser.__init__(self) PWSimParser.__init__(self) self.defaults.update({'method': 'sonic'}) self.allowed.update({'method': ['classic', 'hybrid', 'sonic']}) self.addMethod() def addMethod(self): self.add_argument( '-m', '--method', nargs='+', type=str, help='Numerical integration method ({})'.format(', '.join(self.allowed['method']))) self.to_parse['method'] = self.parseMethod def parseMethod(self, args): for item in args['method']: if item not in self.allowed['method']: raise ValueError('Unknown neuron type: "{}"'.format(item)) return args['method'] def parseAmp(self, args): if args.pop('titrate'): return None return MechSimParser.parseAmp(self, args) def parse(self): args = PWSimParser.parse(self, args=MechSimParser.parse(self)) for k in ['Cm0', 'Qm0', 'embedding', 'charge']: del args[k] return args diff --git a/PySONIC/plt/timeseries.py b/PySONIC/plt/timeseries.py index 3766979..6c5e24c 100644 --- a/PySONIC/plt/timeseries.py +++ b/PySONIC/plt/timeseries.py @@ -1,445 +1,447 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2018-09-25 16:18:45 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2019-06-17 18:12:37 +# @Last Modified time: 2019-06-17 18:40:56 import numpy as np import matplotlib.pyplot as plt from ..core import getModel from ..utils import * from .pltutils import * class TimeSeriesPlot(GenericPlot): ''' Generic interface to build a plot displaying temporal profiles of model simulations. ''' def setTimeLabel(self, ax, tplt, fs): return super().setXLabel(ax, tplt, fs) def setYLabel(self, ax, yplt, fs, grouplabel=None): if grouplabel is not None: yplt['label'] = grouplabel return super().setYLabel(ax, yplt, fs) def checkInputs(self, *args, **kwargs): return NotImplementedError def getStimStates(self, df): try: stimstate = df['stimstate'] except KeyError: stimstate = df['states'] return stimstate.values def getStimPulses(self, t, states): ''' Determine the onset and offset times of pulses from a stimulation vector. :param t: time vector (s). :param states: a vector of stimulation state (ON/OFF) at each instant in time. :return: 3-tuple with number of patches, timing of STIM-ON an STIM-OFF instants. ''' # Compute states derivatives and identify bounds indexes of pulses dstates = np.diff(states) ipulse_on = np.insert(np.where(dstates > 0.0)[0] + 1, 0, 0) ipulse_off = np.where(dstates < 0.0)[0] + 1 if ipulse_off.size < ipulse_on.size: ioff = t.size - 1 if ipulse_off.size == 0: ipulse_off = np.array([ioff]) else: ipulse_off = np.insert(ipulse_off, ipulse_off.size - 1, ioff) # Get time instants for pulses ON and OFF tpulse_on = t[ipulse_on] tpulse_off = t[ipulse_off] return tpulse_on, tpulse_off def addLegend(self, ax, handles, labels, fs, color=None, ls=None): lh = ax.legend(handles, labels, loc=1, fontsize=fs, frameon=False) if color is not None: for l in lh.get_lines(): l.set_color(color) if ls: for l in lh.get_lines(): l.set_linestyle(ls) def materializeSpikes(self, ax, data, tplt, yplt, color, mode, add_to_legend=False): tspikes, Qspikes, Qprominences, thalfmaxbounds, _ = self.getSpikes(data) if tspikes is not None: ax.scatter(tspikes * tplt['factor'], Qspikes * yplt['factor'] + 10, color=color, label='spikes' if add_to_legend else None, marker='v') if mode == 'details': Qbottoms = Qspikes - Qprominences Qmiddles = Qspikes - 0.5 * Qprominences for i in range(len(tspikes)): ax.plot(np.array([tspikes[i]] * 2) * tplt['factor'], np.array([Qspikes[i], Qbottoms[i]]) * yplt['factor'], '--', color=color, label='prominences' if i == 0 and add_to_legend else '') ax.plot(thalfmaxbounds[i] * tplt['factor'], np.array([Qmiddles[i]] * 2) * yplt['factor'], '-.', color=color, label='widths' if i == 0 and add_to_legend else '') return add_to_legend def prepareTime(self, t, tplt): if tplt['onset'] > 0.0: tonset = np.array([-tplt['onset'], -t[0] - t[1]]) t = np.hstack((tonset, t)) return t * tplt['factor'] def addPatches(self, ax, tpatch_on, tpatch_off, tplt, color='#8A8A8A'): for i in range(tpatch_on.size): ax.axvspan(tpatch_on[i] * tplt['factor'], tpatch_off[i] * tplt['factor'], edgecolor='none', facecolor=color, alpha=0.2) def plotInset(self, inset_ax, t, y, tplt, yplt, line, color, lw): inset_window = np.logical_and(t > (inset['xlims'][0] / tplt['factor']), t < (inset['xlims'][1] / tplt['factor'])) inset_ax.plot(t[inset_window] * tplt['factor'], y[inset_window] * yplt['factor'], linewidth=lw, linestyle=line, color=color) return inset_ax def addInsetPatches(self, ax, inset_ax, inset, tpatch_on, tpatch_off, tplt, color): ybottom, ytop = ax.get_ylim() cond_on = np.logical_and(tpatch_on > (inset['xlims'][0] / tfactor), tpatch_on < (inset['xlims'][1] / tfactor)) cond_off = np.logical_and(tpatch_off > (inset['xlims'][0] / tfactor), tpatch_off < (inset['xlims'][1] / tfactor)) cond_glob = np.logical_and(tpatch_on < (inset['xlims'][0] / tfactor), tpatch_off > (inset['xlims'][1] / tfactor)) cond_onoff = np.logical_or(cond_on, cond_off) cond = np.logical_or(cond_onoff, cond_glob) npatches_inset = np.sum(cond) for i in range(npatches_inset): inset_ax.add_patch(Rectangle((tpatch_on[cond][i] * tfactor, ybottom), (tpatch_off[cond][i] - tpatch_on[cond][i]) * tfactor, ytop - ybottom, color=color, alpha=0.1)) class CompTimeSeries(ComparativePlot, TimeSeriesPlot): ''' Interface to build a comparative plot displaying profiles of a specific output variable across different model simulations. ''' def __init__(self, filepaths, varname): ''' Constructor. :param filepaths: list of full paths to output data files to be compared :param varname: name of variable to extract and compare ''' ComparativePlot.__init__(self, filepaths, varname) def checkPatches(self, patches): greypatch = False if patches == 'none': patches = [False] * len(self.filepaths) elif patches == 'all': patches = [True] * len(self.filepaths) elif patches == 'one': patches = [True] + [False] * (len(self.filepaths) - 1) greypatch = True elif isinstance(patches, list): if len(patches) != len(self.filepaths): raise ValueError( 'Invalid patches ({}): not matching number of compared files ({})'.format( len(patches), len(self.filepaths))) if not all(isinstance(p, bool) for p in patches): raise TypeError('Invalid patch sequence: all list items must be boolean typed') else: raise ValueError( 'Invalid patches: must be either "none", all", "one", or a boolean list') return patches, greypatch def checkInputs(self, lines, labels, colors, patches): self.checkLabels(labels) lines = self.checkLines(lines) colors = self.checkColors(colors) patches, greypatch = self.checkPatches(patches) return lines, labels, colors, patches, greypatch def createBackBone(self, figsize): fig, ax = plt.subplots(figsize=figsize) ax.set_zorder(0) return fig, ax def postProcess(self, ax, tplt, yplt, fs, xticks, yticks): self.removeSpines(ax) if 'bounds' in yplt: ax.set_ylim(*yplt['bounds']) self.setTimeLabel(ax, tplt, fs) self.setYLabel(ax, yplt, fs) self.setXTicks(ax, xticks) self.setYTicks(ax, yticks) self.setTickLabelsFontSize(ax, fs) def render(self, figsize=(11, 4), fs=10, lw=2, labels=None, colors=None, lines=None, patches='one', xticks=None, yticks=None, inset=None, frequency=1, spikes='none', cmap=None, cscale='lin', trange=None): ''' Render plot. :param figsize: figure size (x, y) :param fs: labels fontsize :param lw: linewidth :param labels: list of labels to use in the legend :param colors: list of colors to use for each curve :param lines: list of linestyles :param patches: string indicating whether/how to mark stimulation periods with rectangular patches :param xticks: list of x-ticks :param yticks: list of y-ticks :param inset: string indicating whether/how to mark an inset zooming on a particular region of the graph :param frequency: frequency at which to plot samples :param spikes: string indicating how to show spikes ("none", "marks" or "details") :param cmap: color map to use for colobar-based comparison (if not None) :param cscale: color scale to use for colobar-based comparison :param trange: optional lower and upper bounds to time axis :return: figure handle ''' lines, labels, colors, patches, greypatch = self.checkInputs( lines, labels, colors, patches) fig, ax = self.createBackBone(figsize) if inset is not None: inset_ax = self.addInset(fig, ax, inset) # Loop through data files handles, comp_values, full_labels = [], [], [] tmin, tmax = np.inf, -np.inf for j, filepath in enumerate(self.filepaths): # Load data data, meta = self.getData(filepath, frequency, trange) meta.pop('tcomp') full_labels.append(figtitle(meta)) # Extract model model = getModel(meta) # Check consistency of sim types and check differing inputs comp_values = self.checkConsistency(meta, comp_values) # Extract time and stim pulses t = data['t'].values stimstate = self.getStimStates(data) tpatch_on, tpatch_off = self.getStimPulses(t, stimstate) tplt = self.getTimePltVar(model.tscale) t = self.prepareTime(t, tplt) # Extract y-variable pltvars = model.getPltVars() if self.varname not in pltvars: raise KeyError( 'Unknown plot variable: "{}". Possible plot variables are: {}'.format( self.varname, ', '.join(['"{}"'.format(p) for p in pltvars.keys()]))) yplt = pltvars[self.varname] y = extractPltVar(model, yplt, data, meta, t.size, self.varname) # Plot time series handles.append(ax.plot(t, y, linewidth=lw, linestyle=lines[j], color=colors[j])[0]) # Optional: add spikes if self.varname == 'Qm' and spikes != 'none': self.materializeSpikes(ax, data, tplt, yplt, colors[j], spikes) # Plot optional inset if inset is not None: inset_ax = self.plotInset(inset_ax, t, y, tplt, yplt, lines[j], colors[j], lw) # Add optional STIM-ON patches if patches[j]: ybottom, ytop = ax.get_ylim() color = '#8A8A8A' if greypatch else handles[j].get_color() self.addPatches(ax, tpatch_on, tpatch_off, tplt, color) if inset is not None: self.addInsetPatches(ax, inset_ax, inset, tpatch_on, tpatch_off, tplt, color) tmin, tmax = min(tmin, t.min()), max(tmax, t.max()) # Determine labels if self.comp_ref_key is not None: self.comp_info = model.inputVars().get(self.comp_ref_key, None) comp_values, comp_labels = self.getCompLabels(comp_values) labels = self.chooseLabels(labels, comp_labels, full_labels) # Post-process figure self.postProcess(ax, tplt, yplt, fs, xticks, yticks) ax.set_xlim(tmin, tmax) fig.tight_layout() if inset is not None: self.materializeInset(ax, inset_ax, inset) # Add labels or colorbar legend if cmap is not None: if not self.is_unique_comp: raise ValueError('Colormap mode unavailable for multiple differing parameters') if self.comp_info is None: raise ValueError('Colormap mode unavailable for qualitative comparisons') self.addCmap(fig, cmap, handles, comp_values, self.comp_info, fs, zscale=cscale) else: self.addLegend(ax, handles, labels, fs) return fig class GroupedTimeSeries(TimeSeriesPlot): ''' Interface to build a plot displaying profiles of several output variables arranged into specific schemes. ''' def __init__(self, filepaths, pltscheme=None): ''' Constructor. :param filepaths: list of full paths to output data files to be compared :param varname: name of variable to extract and compare ''' super().__init__(filepaths) self.pltscheme = pltscheme def createBackBone(self, pltscheme): naxes = len(pltscheme) if naxes == 1: fig, ax = plt.subplots(figsize=(11, 4)) axes = [ax] else: fig, axes = plt.subplots(naxes, 1, figsize=(11, min(3 * naxes, 9))) return fig, axes def postProcess(self, axes, tplt, yplt, fs): for ax in axes: self.removeSpines(ax) self.setTickLabelsFontSize(ax, fs) for ax in axes[:-1]: ax.set_xticklabels([]) self.setTimeLabel(axes[-1], tplt, fs) - def render(self, fs=10, lw=2, labels=None, colors=None, lines=None, patches=True, save=False, + def render(self, fs=10, lw=2, labels=None, colors=None, lines=None, patches='one', save=False, outputdir=None, fig_ext='png', frequency=1, spikes='none', trange=None): ''' Render plot. :param fs: labels fontsize :param lw: linewidth :param labels: list of labels to use in the legend :param colors: list of colors to use for each curve :param lines: list of linestyles :param patches: boolean indicating whether to mark stimulation periods with rectangular patches :param save: boolean indicating whether or not to save the figure(s) :param outputdir: path to output directory in which to save figure(s) :param fig_ext: string indcating figure extension ("png", "pdf", ...) :param frequency: frequency at which to plot samples :param spikes: string indicating how to show spikes ("none", "marks" or "details") :param trange: optional lower and upper bounds to time axis :return: figure handle(s) ''' figs = [] for filepath in self.filepaths: # Load data and extract model data, meta = self.getData(filepath, frequency, trange) model = getModel(meta) # Extract time and stim pulses t = data['t'].values stimstate = self.getStimStates(data) tpatch_on, tpatch_off = self.getStimPulses(t, stimstate) tplt = self.getTimePltVar(model.tscale) t = self.prepareTime(t, tplt) # Check plot scheme if provided, otherwise generate it pltvars = model.getPltVars() if self.pltscheme is not None: for key in list(sum(list(self.pltscheme.values()), [])): if key not in pltvars: raise KeyError('Unknown plot variable: "{}"'.format(key)) pltscheme = self.pltscheme else: pltscheme = model.getPltScheme() # Create figure fig, axes = self.createBackBone(pltscheme) # Loop through each subgraph for ax, (grouplabel, keys) in zip(axes, pltscheme.items()): ax_legend_spikes = False # Extract variables to plot nvars = len(keys) ax_pltvars = [pltvars[k] for k in keys] if nvars == 1: ax_pltvars[0]['color'] = 'k' ax_pltvars[0]['ls'] = '-' # Set y-axis unit and bounds self.setYLabel(ax, ax_pltvars[0], fs, grouplabel=grouplabel) if 'bounds' in ax_pltvars[0]: ax_min = min([ap['bounds'][0] for ap in ax_pltvars]) ax_max = max([ap['bounds'][1] for ap in ax_pltvars]) ax.set_ylim(ax_min, ax_max) # Plot time series icolor = 0 for yplt, name in zip(ax_pltvars, pltscheme[grouplabel]): color = yplt.get('color', 'C{}'.format(icolor)) y = extractPltVar(model, yplt, data, meta, t.size, name) ax.plot(t, y, yplt.get('ls', '-'), c=color, lw=lw, label='$\\rm {}$'.format(yplt['label'])) if 'color' not in yplt: icolor += 1 # Optional: add spikes if name == 'Qm' and spikes != 'none': ax_legend_spikes = self.materializeSpikes( ax, data, tplt, yplt, color, spikes, add_to_legend=True) # Add legend if nvars > 1 or 'gate' in ax_pltvars[0]['desc'] or ax_legend_spikes: ax.legend(fontsize=fs, loc=7, ncol=nvars // 4 + 1, frameon=False) # Set x-limits and add optional patches for ax in axes: ax.set_xlim(t.min(), t.max()) - if patches: + if patches != 'none': self.addPatches(ax, tpatch_on, tpatch_off, tplt) # Post-process figure self.postProcess(axes, tplt, yplt, fs) axes[0].set_title(figtitle(meta), fontsize=fs) fig.tight_layout() + fig.canvas.set_window_title(model.filecode(meta)) + # Save figure if needed (automatic or checked) if save: filecode = model.filecode(meta) if outputdir is None: outputdir = os.path.split(filepath)[0] plt_filename = '{}/{}.{}'.format(outputdir, filecode, fig_ext) plt.savefig(plt_filename) logger.info('Saving figure as "{}"'.format(plt_filename)) plt.close() figs.append(fig) return figs if __name__ == '__main__': # example of use filepaths = OpenFilesDialog('pkl')[0] comp_plot = CompTimeSeries(filepaths, 'Qm') fig = comp_plot.render( lines=['-', '--'], labels=['60 kPa', '80 kPa'], patches='one', colors=['r', 'g'], xticks=[0, 100], yticks=[-80, +50], inset={'xcoords': [5, 40], 'ycoords': [-35, 45], 'xlims': [57.5, 60.5], 'ylims': [10, 35]} ) scheme_plot = GroupedTimeSeries(filepaths) figs = scheme_plot.render() plt.show() diff --git a/scripts/plot_activation_map.py b/scripts/plot_activation_map.py index 08ba677..533be3a 100644 --- a/scripts/plot_activation_map.py +++ b/scripts/plot_activation_map.py @@ -1,61 +1,62 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2018-09-26 09:51:43 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2019-06-17 16:33:26 +# @Last Modified time: 2019-06-17 18:51:26 ''' Plot (duty-cycle x amplitude) US activation map of a neuron at a given frequency and PRF. ''' import numpy as np import matplotlib.pyplot as plt from PySONIC.utils import logger from PySONIC.plt import ActivationMap from PySONIC.parsers import AStimParser def main(): # Parse command line arguments parser = AStimParser() parser.defaults['amp'] = np.logspace(np.log10(10), np.log10(600), 30) # kPa parser.defaults['DC'] = np.arange(1, 101) # % parser.defaults['tstim'] = 1000. # ms parser.defaults['toffset'] = 0. # ms parser.addInputDir() parser.addThresholdCurve() parser.addInteractive() + parser.addSave() parser.addAscale() parser.addTimeRange(default=(0., 240.)) parser.addCmap(default='viridis') parser.addFiringRateBounds((1e0, 1e3)) parser.addFiringRateScale() parser.addPotentialBounds(default=(-150, 50)) - parser.addSave() parser.outputdir_dep_key = 'save' args = parser.parse() logger.setLevel(args['loglevel']) for pneuron in args['neuron']: for a in args['radius']: for Fdrive in args['freq']: for tstim in args['tstim']: for PRF in args['PRF']: actmap = ActivationMap(args['inputdir'], pneuron, a, Fdrive, tstim, PRF, args['amp'], args['DC']) actmap.render( + cmap=args['cmap'], Ascale=args['Ascale'], FRscale=args['FRscale'], FRbounds=args['FRbounds'], interactive=args['interactive'], Vbounds=args['Vbounds'], trange=args['trange'], thresholds=args['threshold'], ) plt.show() if __name__ == '__main__': main() diff --git a/scripts/plot_phase_diagram.py b/scripts/plot_phase_diagram.py index c538c6d..c17fa09 100644 --- a/scripts/plot_phase_diagram.py +++ b/scripts/plot_phase_diagram.py @@ -1,43 +1,39 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2017-02-13 12:41:26 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2019-06-17 18:14:20 +# @Last Modified time: 2019-06-17 18:50:48 ''' Plot phase plane diagram of specific simulation output variables. ''' import matplotlib.pyplot as plt from PySONIC.utils import logger from PySONIC.plt import PhaseDiagram -from PySONIC.parsers import Parser +from PySONIC.parsers import PlotParser def main(): - parser = Parser() - parser.addInputFiles() - parser.addTimeRange() - parser.addLabels() + parser = PlotParser() parser.addRelativeTimeBounds() - parser.addPretty() - parser.addCmap() - parser.addCscale() args = parser.parse() logger.setLevel(args['loglevel']) # Plot phase-plane diagram phase_diag = PhaseDiagram(args['inputfiles'], args['plot'][0]) phase_diag.render( trange=args['trange'], rel_tbounds=args['rel_tbounds'], labels=args['labels'], pretty=args['pretty'], cmap=args['cmap'], cscale=args['cscale'] ) - plt.show() + + if not args['hide']: + plt.show() if __name__ == '__main__': main() diff --git a/scripts/plot_timeseries.py b/scripts/plot_timeseries.py index 9c3da51..5e9dda9 100644 --- a/scripts/plot_timeseries.py +++ b/scripts/plot_timeseries.py @@ -1,68 +1,59 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2017-02-13 12:41:26 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2019-06-17 18:13:56 +# @Last Modified time: 2019-06-17 18:46:09 ''' Plot temporal profiles of specific simulation output variables. ''' import matplotlib.pyplot as plt from PySONIC.utils import logger from PySONIC.plt import CompTimeSeries, GroupedTimeSeries -from PySONIC.parsers import Parser +from PySONIC.parsers import TimeSeriesParser def main(): # Parse command line arguments - parser = Parser() - parser.addHideOutput() - parser.addInputFiles() - parser.addOutputDir(dep_key='save') - parser.addCompare() - parser.addSave() - parser.addFigureExtension() - parser.addSamplingRate() - parser.addSpikes() - parser.addPatches() - parser.addCmap() - parser.addTimeRange() - parser.addCscale() + parser = TimeSeriesParser() args = parser.parse() logger.setLevel(args['loglevel']) # Plot appropriate graph if args['compare']: if args['plot'] == ['all'] or args['plot'] is None: logger.error('Specific variables must be specified for comparative plots') return for pltvar in args['plot']: try: comp_plot = CompTimeSeries(args['inputfiles'], pltvar) comp_plot.render( + patches=args['patches'], spikes=args['spikes'], frequency=args['sr'], - patches=args['patches'], + trange=args['trange'], cmap=args['cmap'], - cscale=args['cscale'], - trange=args['trange']) + cscale=args['cscale'] + ) except KeyError as e: logger.error(e) return else: scheme_plot = GroupedTimeSeries(args['inputfiles'], pltscheme=args['pltscheme']) scheme_plot.render( - title=True, - save=args['save'], - outputdir=args['outputdir'], - fig_ext=args['figext'], + patches=args['patches'], spikes=args['spikes'], frequency=args['sr'], - trange=args['trange']) + trange=args['trange'], + save=args['save'], + outputdir=args['outputdir'], + fig_ext=args['figext'] + ) + if not args['hide']: plt.show() if __name__ == '__main__': main()