Page MenuHomec4science
No OneTemporary

File Metadata

Thu, May 30, 21:01

# -*- coding: utf-8 -*-
# @Author: Theo Lemaire
# @Email:
# @Date: 2019-11-28 16:42:50
# @Last Modified by: Theo Lemaire
# @Last Modified time: 2021-05-27 09:10:29
import numpy as np
from .utils import logger, isWithin
class OutOfBoundsError(Exception):
    ''' Raised when no threshold can be found within a search interval. '''

    def __init__(self, bounds):
        ''' Initialization.

            :param bounds: (lower, upper) bounds of the searched interval
        '''
        self.bounds = bounds
        msg = f'No threshold found within the [{bounds[0]:.2e} - {bounds[1]:.2e}] interval'
        # Forward the message so that str(err) / logging of the error is informative
        super().__init__(msg)
class MaxNIterations(Exception):
    ''' Raised when a search exceeds its maximum number of iterations. '''

    def __init__(self, max_nit, history):
        ''' Initialization.

            :param max_nit: maximum number of iterations that was allowed
            :param history: sequence of values evaluated before aborting
        '''
        self.max_nit = max_nit
        # Copy, so that later updates of the caller's history do not alter the error
        self.history = list(history)
        msg = f'Maximum number of iterations ({max_nit}) reached, history = {history}'
        # Forward the message so that str(err) / logging of the error is informative
        super().__init__(msg)
class Thresholder:
    ''' Class used to determine the threshold satisfying a given condition within a
        continuous search interval, using a binary search with initial preconditioning.

        The search assumes the condition is monotonic over the interval: evaluations
        returning True move the upper bound down, evaluations returning False move the
        lower bound up. Call ``run()`` to perform the search; the final threshold is then
        available as ``x`` (``err_val`` if the search failed).
    '''

    # Strictly positive fallback lower bound, used when the interval starts at zero,
    # factor bounding is requested and no absolute tolerance is given
    eps_machine = np.sqrt(np.finfo(float).eps)
    # Threshold value reported when the search fails
    err_val = np.nan

    def __init__(self, feval, xbounds, x0=None, eps_thr=None, rel_eps_thr=1e-2,
                 max_nit=50, precheck=False, fbound=2):
        ''' Initialization.

            :param feval: evaluation function returning whether condition is satisfied
            :param xbounds: initial search interval for threshold
            :param x0: initial evaluation value
            :param eps_thr: maximum absolute error
            :param rel_eps_thr: maximum relative error
            :param max_nit: maximum number of evaluations before the search is aborted
            :param precheck: boolean stating whether to perform an initial check
                for the existence of a threshold within the interval
            :param fbound: integer factor indicating the magnitude of the initial bounding
                procedure (None to skip the bounding procedure)
        '''
        self._x_history = []     # evaluated x values, in evaluation order
        self._eval_history = []  # matching evaluation outcomes
        # Assignment order matters: the fbound setter reads xbounds and eps_thr (and may
        # move a zero lower bound), and the x0 setter reads the resulting xbounds.
        self.feval = feval
        self.xbounds = xbounds
        self.rel_eps_thr = rel_eps_thr
        self.eps_thr = eps_thr
        self.max_nit = max_nit
        self.fbound = fbound
        self.precheck = precheck
        self.x0 = x0

    # ------------------------------ Configuration ------------------------------

    @property
    def feval(self):
        ''' Evaluation function returning whether the condition is satisfied at x. '''
        return self._feval

    @feval.setter
    def feval(self, value):
        if not callable(value):
            raise ValueError('feval must be a callable object')
        self._feval = value

    @property
    def xbounds(self):
        ''' Fixed (lower, upper) search interval. '''
        return self._xbounds

    @xbounds.setter
    def xbounds(self, value):
        if len(value) != 2:
            raise ValueError('xbounds must be an iterbale of size 2')
        if value[0] >= value[1]:
            raise ValueError('lower bound must be smaller than upper bound')
        self._xbounds = value

    @property
    def fixed_lb(self):
        ''' Lower bound of the fixed search interval. '''
        return self.xbounds[0]

    @fixed_lb.setter
    def fixed_lb(self, value):
        self.xbounds = (value, self.fixed_ub)

    @property
    def fixed_ub(self):
        ''' Upper bound of the fixed search interval. '''
        return self.xbounds[1]

    @fixed_ub.setter
    def fixed_ub(self, value):
        self.xbounds = (self.fixed_lb, value)

    @property
    def x0(self):
        ''' Initial evaluation value. '''
        return self._x0

    @x0.setter
    def x0(self, value):
        if value is None:
            # If not specified, use the geometric mean of the search interval when it is
            # defined (strictly positive lower bound), and its mid-point otherwise
            # (a logarithmic start point would be 0 or NaN for a lower bound <= 0)
            scale = 'log' if self.fixed_lb > 0 else 'lin'
            value = self.getStartPoint(self.xbounds, x=0.5, scale=scale)
        elif value == 0.:
            # If zero, set to mid-point of search interval
            value = self.getStartPoint(self.xbounds, x=0.5, scale='lin')
        self._x0 = value

    @property
    def eps_thr(self):
        ''' Maximum absolute error on the threshold (infinity if unconstrained). '''
        return self._eps_thr

    @eps_thr.setter
    def eps_thr(self, value):
        if value is None:  # If not specified, set to infinity
            value = np.inf
        self._eps_thr = value

    @property
    def rel_eps_thr(self):
        ''' Maximum relative error on the threshold, within [0, 1]. '''
        return self._rel_eps_thr

    @rel_eps_thr.setter
    def rel_eps_thr(self, value):
        value = isWithin('rel_eps_thr', value, (0., 1.))
        self._rel_eps_thr = value

    @property
    def max_nit(self):
        ''' Maximum number of evaluations. '''
        return self._max_nit

    @max_nit.setter
    def max_nit(self, value):
        if not isinstance(value, int):
            raise ValueError('max_nit must be of type int')
        if value < 1:
            raise ValueError('max_nit must be greater than 0')
        self._max_nit = value

    @property
    def precheck(self):
        ''' Whether to check for a threshold within the interval before searching. '''
        return self._precheck

    @precheck.setter
    def precheck(self, value):
        if not isinstance(value, bool):
            raise ValueError('precheck must be of type bool')
        self._precheck = value

    @property
    def fbound(self):
        ''' Multiplicative factor of the initial bounding procedure (None if disabled). '''
        return self._fbound

    @fbound.setter
    def fbound(self, value):
        if value is not None:
            if value <= 1:
                raise ValueError('bounding factor must be greater than 1')
            # If fixed lower bound is zero, re-assign it to absolute threshold (if provided)
            # or to machine epsilon, since factor bounding divides / multiplies by it
            if self.fixed_lb == 0.:
                self.fixed_lb = self.eps_thr / 2 if self.eps_thr < np.inf else self.eps_machine
            # Search interval must span more than 2 times bounding factor
            if self.fixed_ub / self.fixed_lb <= 2 * value:
                raise ValueError('search interval too narrow for factor bounding')
        self._fbound = value

    # ------------------------------ Search state ------------------------------

    @property
    def x(self):
        ''' Current (last assigned) x value. '''
        return self._x_history[-1]

    @x.setter
    def x(self, value):
        # Every assignment is recorded, so that the full search can be retraced
        self._x_history.append(value)

    @property
    def x_history(self):
        ''' Array of all x values assigned so far. '''
        return np.array(self._x_history)

    @property
    def is_above(self):
        ''' Outcome of the last evaluation. '''
        return self._eval_history[-1]

    @is_above.setter
    def is_above(self, value):
        self._eval_history.append(value)

    @property
    def eval_history(self):
        ''' Array of all evaluation outcomes so far. '''
        return np.array(self._eval_history)

    @property
    def has_changed_eval(self):
        ''' Whether evaluations have returned both outcomes at least once. '''
        return len(set(self._eval_history)) > 1

    @property
    def nits(self):
        ''' Number of x values assigned so far. '''
        return len(self._x_history)

    @property
    def midpoint(self):
        ''' Mid-point of the current search interval. '''
        return (self.ub + self.lb) / 2

    @property
    def eff_thr(self):
        ''' Effective convergence tolerance: most stringent of the relative (w.r.t. the
            current lower bound) and absolute criteria. Note that it is zero if the lower
            bound is zero and no absolute tolerance is given.
        '''
        return min(self.rel_eps_thr * abs(self.lb), self.eps_thr)

    def hasConverged(self):
        ''' Whether the current interval is narrow enough for its mid-point to lie within
            the effective tolerance of any point of the interval.
        '''
        return np.abs(self.ub - self.lb) <= 2 * self.eff_thr

    @staticmethod
    def getStartPoint(bounds, x=0.5, scale='lin'):
        ''' Define a value located at a given relative distance between two bounds.

            :param bounds: lower and upper bound values
            :param x: relative distance, between 0 (lower bound) and 1 (upper bound)
            :param scale: scale type between bounds ('lin' / 'log')
            :return: scaled starting value
        '''
        bounds = np.asarray(bounds, dtype=float)
        if scale == 'log':
            bounds = np.log10(bounds)
        x0 = (1 - x) * bounds[0] + x * bounds[1]
        if scale == 'log':
            x0 = np.power(10., x0)
        return x0

    # ------------------------------ Search steps ------------------------------

    def checkNiterations(self):
        ''' Check that number of iterations does not exceed limit.

            Called before each evaluation, so that at most max_nit evaluations are run.

            :raise MaxNIterations: if the limit is exceeded
        '''
        if self.nits > self.max_nit:
            raise MaxNIterations(self.max_nit, self._x_history)

    def eval(self):
        ''' Evaluate the condition at the current x value and record the outcome. '''
        self.checkNiterations()
        # NOTE(review): isWithin (from .utils) is assumed to reject x values outside the
        # search interval while tolerating rounding errors — confirm its semantics.
        isWithin('x', self.x, self.xbounds, raise_warning=False)
        self.is_above = bool(self.feval(self.x))

    def initBounds(self):
        ''' Reset the current search interval to the fixed search interval. '''
        self.lb, self.ub = self.xbounds

    def checkAtBound(self):
        ''' Evaluate at the appropriate bound based on last evaluation result, and
            raise error if evaluation indicates no threshold within interval.

            :raise OutOfBoundsError: if the bound gives the same outcome as the last evaluation
        '''
        last_eval = self.is_above
        self.x = self.lb if self.is_above else self.ub
        self.eval()
        if self.is_above == last_eval:
            raise OutOfBoundsError(self.xbounds)

    def preCondition(self):
        ''' Refine search interval by either multiplying or dividing x by a specific integer
            factor k until target lies within an interval [x, kx].

            :raise OutOfBoundsError: if the interval collapses
        '''
        # Local copy of the factor, so that the configured fbound is never altered
        k = self.fbound
        # If exact match between (k * x) and ub or between (x / k) and lb, adapt k slightly
        if self.x * k == self.ub or self.lb * k == self.x:
            k *= 0.99
        # Iterate while upper bound is more than (k * x) or lower bound is less than (x / k)
        while self.lb < self.x / k or self.ub > self.x * k:
            # Refine interval and candidate x based on feval result
            if self.is_above:
                self.ub = self.x
                x_new = self.ub / k
            else:
                self.lb = self.x
                x_new = k * self.lb
            # If lower bound greater or equal to upper bound -> raise error
            if self.lb >= self.ub:
                raise OutOfBoundsError(self.xbounds)
            # Interval already narrower than the bounding factor: stepping by k would leave
            # it, so stop bounding and let the binary search take over
            if not self.lb < x_new < self.ub:
                break
            # Evaluate
            self.x = x_new
            self.eval()
        # Set x to interval mid-point and re-evaluate
        self.x = self.midpoint
        self.eval()

    def binSearch(self):
        ''' Binary search until interval is smaller than most stringent threshold criterion. '''
        while not self.hasConverged():
            # Refine interval based on feval result
            if self.is_above:
                self.ub = self.x
            else:
                self.lb = self.x
            # Set x to interval mid-point and re-evaluate
            self.x = self.midpoint
            self.eval()

    def refine(self):
        ''' Refine threshold once convergence has been reached, so that the final value
            satisfies the condition.
        '''
        # If last value is not above threshold
        if not self.is_above:
            # Move lower bound up first, so that the new mid-point lies between the last
            # value and the upper bound (computing both in one tuple assignment would use
            # the old lower bound and re-evaluate the same point)
            self.lb = self.x
            self.x = self.midpoint
            self.eval()
            # If last value still not above threshold, evaluate at upper bound
            if not self.is_above:
                self.x = self.ub
                self.eval()

    def run(self, output_history=False):
        ''' Run the threshold search.

            Search failures (no threshold within interval, too many iterations) are logged
            and result in x being set to err_val.

            :param output_history: boolean stating whether to return history of search procedure
            :return: final threshold, or full search history as (x values, evaluation results)
        '''
        # Start from a clean history, so that repeated runs are independent
        self._x_history = []
        self._eval_history = []
        try:
            self.initBounds()  # Bounds are needed by the pre-check
            self.x = self.x0
            self.eval()
            if self.precheck:  # Run pre-check at the appropriate interval bound if required
                self.checkAtBound()
            if self.fbound is not None:  # Perform initial factor bounding if required
                self.preCondition()
            self.binSearch()  # Run binary search until convergence
            if not self.has_changed_eval:  # if feval has not changed output, evaluate at the bound
                self.checkAtBound()
            self.refine()  # refine to make sure final value is above threshold
        except (OutOfBoundsError, MaxNIterations) as err:  # if error is raised, assign nan and log
            logger.warning(err)
            self.x = self.err_val
        if output_history:
            return self.x_history, self.eval_history
        return self.x
def threshold(*args, output_history=False, **kwargs):
    ''' Wrapper function around the Thresholder class.

        :param args: positional arguments forwarded to the Thresholder constructor
        :param output_history: boolean stating whether to return history of search procedure
        :param kwargs: keyword arguments forwarded to the Thresholder constructor
        :return: final threshold, or full search history as (x values, evaluation results)
    '''
    th = Thresholder(*args, **kwargs)
    # The search must actually be run before its outcome / history can be read
    th.run()
    if output_history:
        return th.x_history, th.eval_history
    return th.x
def titrate(model, drive, pp, **kwargs):
''' Use a binary search to determine the threshold amplitude needed
to obtain neural excitation for a given duration, PRF and duty cycle.
:param model: model object
:param drive: unresolved drive object
:param pp: pulsed protocol object
:param xfunc: function determining whether condition is reached from simulation output
:param Arange: search interval for electric current amplitude, iteratively refined
:return: excitation amplitude (in drive units)
xfunc = kwargs.pop('xfunc', None)
Arange = kwargs.pop('Arange', None)
# Default output function
if xfunc is None:
xfunc = model.titrationFunc
# Default amplitude interval
if Arange is None:
Arange = model.getArange(drive)
return threshold(
lambda x: xfunc(model.simulate(drive.updatedX(x), pp, **kwargs)[0]),

Event Timeline