Page Menu
Home
c4science
Search
Configure Global Search
Log In
Files
F60647034
actmap.py
No One
Temporary
Actions
Download File
Edit File
Delete File
View Transforms
Subscribe
Mute Notifications
Award Token
Subscribers
None
File Metadata
Details
File Info
Storage
Attached
Created
Wed, May 1, 17:03
Size
16 KB
Mime Type
text/x-python
Expires
Fri, May 3, 17:03 (1 d, 23 h)
Engine
blob
Format
Raw Data
Handle
17389845
Attached To
R4670 PySONIC (old)
actmap.py
View Options
# -*- coding: utf-8 -*-
# @Author: Theo Lemaire
# @Email: theo.lemaire@epfl.ch
# @Date: 2019-06-04 18:24:29
# @Last Modified by: Theo Lemaire
# @Last Modified time: 2020-08-04 17:52:34
import
abc
import
csv
from
itertools
import
product
import
numpy
as
np
import
pandas
as
pd
import
matplotlib.pyplot
as
plt
import
copy
from
..core
import
NeuronalBilayerSonophore
,
PulsedProtocol
,
AcousticDrive
,
LogBatch
,
Batch
from
..utils
import
logger
,
si_format
,
isIterable
from
.pltutils
import
cm2inch
,
setNormalizer
from
.timeseries
import
GroupedTimeSeries
from
..postpro
import
detectSpikes
class
XYMap
(
LogBatch
):
''' Generic 2D map object interface. '''
offset_options
=
{
'lr'
:
(
1
,
-
1
),
'ur'
:
(
1
,
1
),
'll'
:
(
-
1
,
-
1
),
'ul'
:
(
-
1
,
1
)
}
def
__init__
(
self
,
root
,
xvec
,
yvec
):
self
.
root
=
root
self
.
xvec
=
xvec
self
.
yvec
=
yvec
super
()
.
__init__
([
list
(
pair
)
for
pair
in
product
(
self
.
xvec
,
self
.
yvec
)],
root
=
root
)
def
checkVector
(
self
,
name
,
value
):
if
not
isIterable
(
value
):
raise
ValueError
(
f
'{name} vector must be an iterable'
)
if
not
isinstance
(
value
,
np
.
ndarray
):
value
=
np
.
asarray
(
value
)
if
len
(
value
.
shape
)
>
1
:
raise
ValueError
(
f
'{name} vector must be one-dimensional'
)
return
value
@property
def
in_key
(
self
):
return
self
.
xkey
@property
def
unit
(
self
):
return
self
.
xunit
@property
def
xvec
(
self
):
return
self
.
_xvec
@xvec.setter
def
xvec
(
self
,
value
):
self
.
_xvec
=
self
.
checkVector
(
'x'
,
value
)
@property
def
yvec
(
self
):
return
self
.
_yvec
@yvec.setter
def
yvec
(
self
,
value
):
self
.
_yvec
=
self
.
checkVector
(
'x'
,
value
)
@property
@abc.abstractmethod
def
xkey
(
self
):
raise
NotImplementedError
@property
@abc.abstractmethod
def
xfactor
(
self
):
raise
NotImplementedError
@property
@abc.abstractmethod
def
xunit
(
self
):
raise
NotImplementedError
@property
@abc.abstractmethod
def
ykey
(
self
):
raise
NotImplementedError
@property
@abc.abstractmethod
def
yfactor
(
self
):
raise
NotImplementedError
@property
@abc.abstractmethod
def
yunit
(
self
):
raise
NotImplementedError
@property
@abc.abstractmethod
def
zkey
(
self
):
raise
NotImplementedError
@property
@abc.abstractmethod
def
zunit
(
self
):
raise
NotImplementedError
@property
@abc.abstractmethod
def
zfactor
(
self
):
raise
NotImplementedError
@property
def
out_keys
(
self
):
return
[
f
'{self.zkey} ({self.zunit})'
]
@property
def
in_labels
(
self
):
return
[
f
'{self.xkey} ({self.xunit})'
,
f
'{self.ykey} ({self.yunit})'
]
def
getLogData
(
self
):
''' Retrieve the batch log file data (inputs and outputs) as a dataframe. '''
return
pd
.
read_csv
(
self
.
fpath
,
sep
=
self
.
delimiter
)
.
sort_values
(
self
.
in_labels
)
def
getInput
(
self
):
''' Retrieve the logged batch inputs as an array. '''
return
self
.
getLogData
()[
self
.
in_labels
]
.
values
def
getOutput
(
self
):
return
np
.
reshape
(
super
()
.
getOutput
(),
(
self
.
xvec
.
size
,
self
.
yvec
.
size
))
.
T
def
writeLabels
(
self
):
with
open
(
self
.
fpath
,
'w'
)
as
csvfile
:
writer
=
csv
.
writer
(
csvfile
,
delimiter
=
self
.
delimiter
)
writer
.
writerow
([
*
self
.
in_labels
,
*
self
.
out_keys
])
def
isEntry
(
self
,
comb
):
''' Check if a given input is logged in the batch log file. '''
inputs
=
self
.
getInput
()
if
len
(
inputs
)
==
0
:
return
False
imatches_x
=
np
.
where
(
np
.
isclose
(
inputs
[:,
0
],
comb
[
0
],
rtol
=
self
.
rtol
,
atol
=
self
.
atol
))[
0
]
imatches_y
=
np
.
where
(
np
.
isclose
(
inputs
[:,
1
],
comb
[
1
],
rtol
=
self
.
rtol
,
atol
=
self
.
atol
))[
0
]
imatches
=
list
(
set
(
imatches_x
)
.
intersection
(
imatches_y
))
if
len
(
imatches
)
==
0
:
return
False
return
True
@property
def
inputscode
(
self
):
''' String describing the batch inputs. '''
xcode
=
self
.
rangecode
(
self
.
xvec
,
self
.
xkey
,
self
.
xunit
)
ycode
=
self
.
rangecode
(
self
.
yvec
,
self
.
ykey
,
self
.
yunit
)
return
'_'
.
join
([
xcode
,
ycode
])
@staticmethod
def
getScaleType
(
x
):
xmin
,
xmax
,
nx
=
x
.
min
(),
x
.
max
(),
x
.
size
if
np
.
all
(
np
.
isclose
(
x
,
np
.
logspace
(
np
.
log10
(
xmin
),
np
.
log10
(
xmax
),
nx
))):
return
'log'
else
:
return
'lin'
# elif np.all(np.isclose(x, np.linspace(xmin, xmax, nx))):
# return 'lin'
# else:
# raise ValueError('Unknown distribution type')
@property
def
xscale
(
self
):
return
self
.
getScaleType
(
self
.
xvec
)
@property
def
yscale
(
self
):
return
self
.
getScaleType
(
self
.
yvec
)
@staticmethod
def
computeMeshEdges
(
x
,
scale
):
''' Compute the appropriate edges of a mesh that quads a linear or logarihtmic distribution.
:param x: the input vector
:param scale: the type of distribution ('lin' for linear, 'log' for logarihtmic)
:return: the edges vector
'''
if
scale
==
'log'
:
x
=
np
.
log10
(
x
)
range_func
=
np
.
logspace
else
:
range_func
=
np
.
linspace
dx
=
x
[
1
]
-
x
[
0
]
n
=
x
.
size
+
1
return
range_func
(
x
[
0
]
-
dx
/
2
,
x
[
-
1
]
+
dx
/
2
,
n
)
@abc.abstractmethod
def
compute
(
self
,
x
):
''' Compute the necessary output(s) for a given inputs combination. '''
raise
NotImplementedError
def
run
(
self
,
**
kwargs
):
super
()
.
run
(
**
kwargs
)
self
.
getLogData
()
.
to_csv
(
self
.
filepath
(),
sep
=
self
.
delimiter
,
index
=
False
)
def
getOnClickXY
(
self
,
event
):
''' Get x and y values from from x and y click event coordinates. '''
x
=
self
.
xvec
[
np
.
searchsorted
(
self
.
xedges
,
event
.
xdata
)
-
1
]
y
=
self
.
yvec
[
np
.
searchsorted
(
self
.
yedges
,
event
.
ydata
)
-
1
]
return
x
,
y
def
onClick
(
self
,
event
):
''' Exexecute specific action when the user clicks on a cell in the 2D map. '''
pass
@property
@abc.abstractmethod
def
title
(
self
):
raise
NotImplementedError
def
getZBounds
(
self
):
matrix
=
self
.
getOutput
()
zmin
,
zmax
=
np
.
nanmin
(
matrix
),
np
.
nanmax
(
matrix
)
logger
.
info
(
f
'{self.zkey} range: {zmin:.0f} - {zmax:.0f} {self.zunit}'
)
return
zmin
,
zmax
def
checkZbounds
(
self
,
zbounds
):
zmin
,
zmax
=
self
.
getZBounds
()
if
zmin
<
zbounds
[
0
]:
logger
.
warning
(
f
'Minimal {self.zkey} ({zmin:.0f} {self.zunit}) is below defined lower bound ({zbounds[0]:.0f} {self.zunit})'
)
if
zmax
>
zbounds
[
1
]:
logger
.
warning
(
f
'Maximal {self.zkey} ({zmax:.0f} {self.zunit}) is above defined upper bound ({zbounds[1]:.0f} {self.zunit})'
)
def
render
(
self
,
xscale
=
'lin'
,
yscale
=
'lin'
,
zscale
=
'lin'
,
zbounds
=
None
,
fs
=
8
,
cmap
=
'viridis'
,
interactive
=
False
,
figsize
=
None
,
insets
=
None
,
inset_offset
=
0.05
,
extend_under
=
False
,
extend_over
=
False
):
# Get figure size
if
figsize
is
None
:
figsize
=
cm2inch
(
12
,
7
)
# Compute Z normalizer
mymap
=
copy
.
copy
(
plt
.
get_cmap
(
cmap
))
if
not
extend_under
:
mymap
.
set_under
(
'silver'
)
if
not
extend_over
:
mymap
.
set_over
(
'silver'
)
if
zbounds
is
None
:
zbounds
=
self
.
getZBounds
()
else
:
self
.
checkZbounds
(
zbounds
)
norm
,
sm
=
setNormalizer
(
mymap
,
zbounds
,
zscale
)
nan_eq
=
zbounds
[
0
]
-
1
if
zscale
==
'lin'
else
0.5
*
zbounds
[
0
]
# Compute mesh edges
self
.
xedges
=
self
.
computeMeshEdges
(
self
.
xvec
,
xscale
)
self
.
yedges
=
self
.
computeMeshEdges
(
self
.
yvec
,
yscale
)
# Create figure
fig
,
ax
=
plt
.
subplots
(
figsize
=
figsize
)
fig
.
subplots_adjust
(
left
=
0.15
,
bottom
=
0.15
,
right
=
0.8
,
top
=
0.92
)
ax
.
set_title
(
self
.
title
,
fontsize
=
fs
)
ax
.
set_xlabel
(
f
'{self.xkey} ({self.xunit})'
,
fontsize
=
fs
,
labelpad
=-
0.5
)
ax
.
set_ylabel
(
f
'{self.ykey} ({self.yunit})'
,
fontsize
=
fs
)
for
item
in
ax
.
get_xticklabels
()
+
ax
.
get_yticklabels
():
item
.
set_fontsize
(
fs
)
if
xscale
==
'log'
:
ax
.
set_xscale
(
'log'
)
if
yscale
==
'log'
:
ax
.
set_yscale
(
'log'
)
# Retrieve data and replace NaNs with specific out-of-bounds value
data
=
self
.
getOutput
()
data
[
np
.
isnan
(
data
)]
=
nan_eq
# Plot map with specific color code
ax
.
pcolormesh
(
self
.
xedges
,
self
.
yedges
,
data
,
cmap
=
mymap
,
norm
=
norm
)
# Plot potential insets
if
insets
is
not
None
:
x_data
,
y_data
,
*
_
=
zip
(
*
insets
)
ax
.
scatter
(
x_data
,
y_data
,
s
=
80
,
facecolors
=
'none'
,
edgecolors
=
'k'
,
linestyle
=
'--'
,
lw
=
1
)
axis_to_data
=
ax
.
transAxes
+
ax
.
transData
.
inverted
()
data_to_axis
=
axis_to_data
.
inverted
()
for
x
,
y
,
label
,
direction
in
insets
:
xyoffset
=
np
.
array
(
self
.
offset_options
[
direction
])
*
inset_offset
# in axis coords
xytext
=
axis_to_data
.
transform
(
np
.
array
(
data_to_axis
.
transform
((
x
,
y
)))
+
xyoffset
)
ax
.
annotate
(
label
,
xy
=
(
x
,
y
),
xytext
=
xytext
,
fontsize
=
fs
,
horizontalalignment
=
'right'
,
arrowprops
=
{
'facecolor'
:
'black'
,
'arrowstyle'
:
'-'
})
# Plot z-scale colorbar
pos1
=
ax
.
get_position
()
# get the map axis position
cbarax
=
fig
.
add_axes
([
pos1
.
x1
+
0.02
,
pos1
.
y0
,
0.03
,
pos1
.
height
])
if
not
extend_under
and
not
extend_over
:
extend
=
'neither'
elif
extend_under
and
extend_over
:
extend
=
'both'
else
:
extend
=
'max'
if
extend_over
else
'min'
fig
.
colorbar
(
sm
,
cax
=
cbarax
,
extend
=
extend
)
cbarax
.
set_ylabel
(
f
'{self.zkey} ({self.zunit})'
,
fontsize
=
fs
)
for
item
in
cbarax
.
get_yticklabels
():
item
.
set_fontsize
(
fs
)
if
interactive
:
fig
.
canvas
.
mpl_connect
(
'button_press_event'
,
lambda
event
:
self
.
onClick
(
event
))
return
fig
class
ActivationMap
(
XYMap
):
xkey
=
'Duty cycle'
xfactor
=
1e2
xunit
=
'%'
ykey
=
'Amplitude'
yfactor
=
1e-3
yunit
=
'kPa'
onclick_colors
=
None
def
__init__
(
self
,
root
,
pneuron
,
a
,
fs
,
f
,
tstim
,
PRF
,
amps
,
DCs
):
self
.
nbls
=
NeuronalBilayerSonophore
(
a
,
pneuron
)
self
.
drive
=
AcousticDrive
(
f
,
None
)
self
.
pp
=
PulsedProtocol
(
tstim
,
0.
,
PRF
,
.
5
)
self
.
fs
=
fs
super
()
.
__init__
(
root
,
DCs
*
self
.
xfactor
,
amps
*
self
.
yfactor
)
@property
def
sim_args
(
self
):
return
[
self
.
drive
,
self
.
pp
,
self
.
fs
,
'sonic'
,
None
]
@property
def
title
(
self
):
s
=
'Activation map - {} neuron @ {}Hz, {}Hz PRF ({}m sonophore'
.
format
(
self
.
nbls
.
pneuron
.
name
,
*
si_format
([
self
.
drive
.
f
,
self
.
pp
.
PRF
,
self
.
nbls
.
a
]))
if
self
.
fs
<
1
:
s
=
f
'{s}, {self.fs * 1e2:.0f}
% c
overage'
return
f
'{s})'
def
corecode
(
self
):
corecodes
=
self
.
nbls
.
filecodes
(
*
self
.
sim_args
)
del
corecodes
[
'nature'
]
if
'DC'
in
corecodes
:
del
corecodes
[
'DC'
]
return
'_'
.
join
(
filter
(
lambda
x
:
x
is
not
None
,
corecodes
.
values
()))
def
compute
(
self
,
x
):
''' Compute firing rate from simulation output '''
# Adapt drive and pulsed protocol
self
.
pp
.
DC
=
x
[
0
]
/
self
.
xfactor
self
.
drive
.
A
=
x
[
1
]
/
self
.
yfactor
# Get model output, running simulation if needed
data
,
_
=
self
.
nbls
.
getOutput
(
*
self
.
sim_args
,
outputdir
=
self
.
root
)
return
self
.
xfunc
(
data
)
@abc.abstractmethod
def
xfunc
(
self
,
data
):
raise
NotImplementedError
def
addThresholdCurve
(
self
,
ax
,
fs
,
mpi
=
False
):
queue
=
[[
self
.
drive
,
PulsedProtocol
(
self
.
pp
.
tstim
,
self
.
pp
.
toffset
,
self
.
pp
.
PRF
,
DC
/
self
.
xfactor
),
self
.
fs
,
'sonic'
,
None
]
for
DC
in
self
.
xvec
]
batch
=
Batch
(
self
.
nbls
.
titrate
,
queue
)
Athrs
=
np
.
array
(
batch
.
run
(
mpi
=
mpi
,
loglevel
=
logger
.
level
))
ax
.
plot
(
self
.
xvec
,
Athrs
*
self
.
yfactor
,
'-'
,
color
=
'#F26522'
,
linewidth
=
3
,
label
=
'threshold amplitudes'
)
ax
.
legend
(
loc
=
'lower center'
,
frameon
=
False
,
fontsize
=
fs
)
@property
@abc.abstractmethod
def
onclick_pltscheme
(
self
):
raise
NotImplementedError
def
onClick
(
self
,
event
):
''' Execute action when the user clicks on a cell in the 2D map. '''
DC
,
A
=
self
.
getOnClickXY
(
event
)
self
.
plotTimeseries
(
DC
,
A
)
plt
.
show
()
def
plotTimeseries
(
self
,
DC
,
A
,
**
kwargs
):
''' Plot related timeseries for a given duty cycle and amplitude. '''
self
.
drive
.
A
=
A
/
self
.
yfactor
self
.
pp
.
DC
=
DC
/
self
.
xfactor
# Get model output, running simulation if needed
data
,
meta
=
self
.
nbls
.
getOutput
(
*
self
.
sim_args
,
outputdir
=
self
.
root
)
# Plot timeseries of appropriate variables
timeseries
=
GroupedTimeSeries
([(
data
,
meta
)],
pltscheme
=
self
.
onclick_pltscheme
)
return
timeseries
.
render
(
colors
=
self
.
onclick_colors
,
**
kwargs
)[
0
]
def
render
(
self
,
yscale
=
'log'
,
thresholds
=
False
,
mpi
=
False
,
**
kwargs
):
fig
=
super
()
.
render
(
yscale
=
yscale
,
**
kwargs
)
if
thresholds
:
self
.
addThresholdCurve
(
fig
.
axes
[
0
],
fs
=
12
,
mpi
=
mpi
)
return
fig
class
FiringRateMap
(
ActivationMap
):
zkey
=
'Firing rate'
zunit
=
'Hz'
zfactor
=
1e0
suffix
=
'FRmap'
onclick_pltscheme
=
{
'V_m\ |\ Q_/C_{m0}'
:
[
'Vm'
,
'Qm/Cm0'
]}
onclick_colors
=
[
'darkgrey'
,
'k'
]
def
xfunc
(
self
,
data
):
''' Detect spikes in data and compute firing rate. '''
ispikes
,
_
=
detectSpikes
(
data
)
if
ispikes
.
size
>
1
:
t
=
data
[
't'
]
.
values
sr
=
1
/
np
.
diff
(
t
[
ispikes
])
return
np
.
mean
(
sr
)
else
:
return
np
.
nan
def
render
(
self
,
zscale
=
'log'
,
**
kwargs
):
return
super
()
.
render
(
zscale
=
zscale
,
**
kwargs
)
class
CalciumMap
(
ActivationMap
):
zkey
=
'[Ca2+]i'
zunit
=
'uM'
zfactor
=
1e6
suffix
=
'Camap'
onclick_pltscheme
=
{
'Cai'
:
[
'Cai'
]}
def
xfunc
(
self
,
data
):
''' Detect spikes in data and compute firing rate. '''
Cai
=
data
[
'Cai'
]
.
values
*
self
.
zfactor
# uM
return
np
.
mean
(
Cai
)
def
render
(
self
,
zscale
=
'log'
,
**
kwargs
):
return
super
()
.
render
(
zscale
=
zscale
,
**
kwargs
)
map_classes
=
{
'FR'
:
FiringRateMap
,
'Cai'
:
CalciumMap
}
def
getActivationMap
(
key
,
*
args
,
**
kwargs
):
if
key
not
in
map_classes
:
raise
ValueError
(
f
'{key} is not a valid map type'
)
return
map_classes
[
key
](
*
args
,
**
kwargs
)
class
GammaMap
(
XYMap
):
''' Interface to a 2D map showing relative capacitance oscillation amplitude
resulting from BLS simulations at various frequencies and amplitude.
'''
xkey
=
'f_US'
xfactor
=
1e0
xunit
=
'kHz'
ykey
=
'A'
yfactor
=
1e0
yunit
=
'kPa'
zkey
=
'gamma'
zfactor
=
1e0
zunit
=
'-'
suffix
=
'gamma'
def
__init__
(
self
,
root
,
bls
,
freqs
,
amps
):
self
.
bls
=
bls
.
copy
()
super
()
.
__init__
(
root
,
freqs
,
amps
)
@property
def
title
(
self
):
return
f
'Gamma map - {self.bls}'
def
corecode
(
self
):
return
f
'gamma_map_bls{self.bls.a * 1e9:.0f}nm'
def
compute
(
self
,
x
):
f
,
A
=
x
data
,
meta
=
self
.
bls
.
simulate
(
AcousticDrive
(
f
*
1e3
,
A
*
1e3
),
0.
)
Cm
=
self
.
bls
.
v_capacitance
(
data
[
'Z'
])
gamma
=
np
.
ptp
(
Cm
)
/
(
2
*
self
.
bls
.
Cm0
)
logger
.
info
(
f
'f = {f:.2f} kHz, A = {A:.2f} kPa, gamma = {gamma:.2f}'
)
return
gamma
def
onClick
(
self
,
event
):
''' Execute action when the user clicks on a cell in the 2D map. '''
x
=
self
.
getOnClickXY
(
event
)
f
,
A
=
x
out
=
self
.
bls
.
simulate
(
AcousticDrive
(
f
*
1e3
,
A
*
1e3
),
0.
)
GroupedTimeSeries
([
out
])
.
render
()
plt
.
show
()
def
render
(
self
,
xscale
=
'log'
,
yscale
=
'log'
,
figsize
=
(
6
,
4
),
fs
=
12
,
**
kwargs
):
fig
=
super
()
.
render
(
xscale
=
xscale
,
yscale
=
yscale
,
figsize
=
figsize
,
fs
=
fs
,
**
kwargs
)
levels
=
[
0.1
,
0.3
,
0.5
,
0.7
]
colors
=
[
'w'
,
'k'
,
'k'
,
'k'
]
ax
=
fig
.
axes
[
0
]
CS
=
ax
.
contour
(
self
.
xvec
,
self
.
yvec
,
self
.
getOutput
(),
levels
,
colors
=
colors
)
ax
.
clabel
(
CS
,
fontsize
=
fs
,
fmt
=
lambda
x
:
f
'{x:g}'
,
inline_spacing
=
2
)
return
fig
Event Timeline
Log In to Comment