Page Menu
Home
c4science
Search
Configure Global Search
Log In
Files
F65038621
pltutils.py
No One
Temporary
Actions
Download File
Edit File
Delete File
View Transforms
Subscribe
Mute Notifications
Award Token
Subscribers
None
File Metadata
Details
File Info
Storage
Attached
Created
Fri, May 31, 07:13
Size
19 KB
Mime Type
text/x-python
Expires
Sun, Jun 2, 07:13 (2 d)
Engine
blob
Format
Raw Data
Handle
17968096
Attached To
R4670 PySONIC (old)
pltutils.py
View Options
# -*- coding: utf-8 -*-
# @Author: Theo Lemaire
# @Email: theo.lemaire@epfl.ch
# @Date: 2017-08-21 14:33:36
# @Last Modified by: Theo Lemaire
# @Last Modified time: 2021-05-26 13:43:14
''' Useful functions to generate plots. '''
import
re
import
numpy
as
np
import
pandas
as
pd
import
matplotlib
from
matplotlib.patches
import
Polygon
,
Rectangle
from
matplotlib
import
cm
,
colors
import
matplotlib.pyplot
as
plt
from
..core
import
getModel
from
..utils
import
*
# Matplotlib parameters: embed fonts as TrueType (Type 42) in PDF and PS outputs,
# and use Arial as the default font family
matplotlib.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'font.family': 'arial',
})
def getSymmetricCmap(cmap_key):
    ''' Build a symmetric version of a colormap.

        The colors of the original colormap are stacked with those of its reversed
        version, so that the resulting colormap starts and ends with the same color.

        :param cmap_key: name of a registered colormap exposing a `colors` attribute
            (e.g. a listed colormap such as 'viridis')
        :return: LinearSegmentedColormap named 'sym_<cmap_key>'
    '''
    base_cmap = plt.get_cmap(cmap_key)
    forward_colors = base_cmap.colors
    backward_colors = base_cmap.reversed().colors
    stacked_colors = np.vstack((forward_colors, backward_colors))
    sym_name = f'sym_{cmap_key}'
    return colors.LinearSegmentedColormap.from_list(sym_name, stacked_colors)
# Register symmetric versions of some perceptually uniform colormaps and of their
# reversed versions (e.g. 'sym_viridis' and 'sym_viridis_r').
# The colormap registry API (matplotlib.colormaps.register) is used when available,
# since pyplot.register_cmap was deprecated in matplotlib 3.7 and later removed;
# force=True allows re-registration (e.g. upon module reload) instead of raising.
_cmap_registry = getattr(matplotlib, 'colormaps', None)
for k in ['viridis', 'plasma', 'inferno', 'magma', 'cividis']:
    for cmap_key in [k, f'{k}_r']:
        sym_cmap = getSymmetricCmap(cmap_key)
        if hasattr(_cmap_registry, 'register'):
            _cmap_registry.register(sym_cmap, name=sym_cmap.name, force=True)
        else:
            # Older matplotlib versions without a registry 'register' method
            plt.register_cmap(name=sym_cmap.name, cmap=sym_cmap)
def cm2inch(*tupl):
    ''' Convert lengths from centimeters to inches.

        Lengths can be given either as separate arguments or as a single tuple
        (e.g. a figure size).

        :return: tuple of lengths converted to inches
    '''
    cm_per_inch = 2.54
    lengths = tupl[0] if isinstance(tupl[0], tuple) else tupl
    return tuple(length / cm_per_inch for length in lengths)
def extractPltVar(model, pltvar, df, meta=None, nsamples=0, name=''):
    ''' Extract the values of a plot variable from simulation results.

        :param model: model object, on which 'func' expressions are evaluated
        :param pltvar: plot variable dictionary. Values are obtained from (by order of
            priority) its 'func' entry (expression evaluated on `model`, or on `meta`
            if it starts with 'meta'), its 'key' entry (column of `df`), or its
            'constant' entry (evaluated expression broadcast over `nsamples`). If none
            of these is present, the `name` column of `df` is used. An optional
            'factor' entry scales the output values.
        :param df: simulation results dataframe
        :param meta: simulation metadata, possibly referenced by 'func' expressions
        :param nsamples: expected number of samples of the output vector
        :param name: dataframe column used when `pltvar` has no extraction entry
        :return: new numpy array of (scaled) variable values
    '''
    # NOTE: 'func' and 'constant' entries are evaluated with eval() in this local
    # scope and may reference `model`, `df` and `meta`: these names must therefore
    # not be renamed. eval() executes arbitrary code, so plot variables must never
    # come from untrusted sources.
    if 'func' in pltvar:
        s = pltvar['func']
        if not s.startswith('meta'):
            s = f'model.{s}'
        try:
            var = eval(s)
        except AttributeError:
            # Retry on the model's wrapped `pneuron` object, if any
            if hasattr(model, 'pneuron'):
                var = eval(s.replace('model', 'model.pneuron'))
            else:
                raise
    elif 'key' in pltvar:
        var = df[pltvar['key']]
    elif 'constant' in pltvar:
        var = eval(pltvar['constant']) * np.ones(nsamples)
    else:
        var = df[name]
    if isinstance(var, pd.Series):
        var = var.values
    var = np.asarray(var)

    # If exactly one sample is missing, duplicate the first value at the start
    if var.size == nsamples - 1:
        var = np.insert(var, 0, var[0])

    # Out-of-place scaling: always returns a new array (the dataframe is never
    # modified), and promotes integer data to float when the factor is a float
    # (an in-place `*=` would raise a casting error in that case)
    return var * pltvar.get('factor', 1)
def setGrid(n, ncolmax=3):
    ''' Compute the layout of a figure grid holding a given number of subplots.

        :param n: number of variables (subplots) to plot
        :param ncolmax: maximal number of columns
        :return: (number of rows, number of columns) tuple
    '''
    # Everything fits on a single row
    if n <= ncolmax:
        return (1, n)
    # Otherwise fill full rows of ncolmax columns
    nrows = (n - 1) // ncolmax + 1
    return (nrows, ncolmax)
def setNormalizer(cmap, bounds, scale='lin', linthresh=None):
    ''' Create a color normalizer and an associated scalar mappable.

        :param cmap: colormap (object or registered name)
        :param bounds: (vmin, vmax) normalization bounds
        :param scale: normalization scale: 'lin', 'log' or 'symlog'
        :param linthresh: for the 'symlog' scale only, half-width of the linear region
            around zero (defaults to 1% of the largest absolute bound)
        :return: (normalizer, scalar mappable) tuple
    '''
    if scale == 'symlog':
        # SymLogNorm's first positional arguments are (linthresh, linscale), so the
        # bounds must be passed as keyword arguments
        vmin, vmax = bounds[0], bounds[1]
        if linthresh is None:
            linthresh = 1e-2 * max(abs(vmin), abs(vmax))
        norm = colors.SymLogNorm(linthresh, vmin=vmin, vmax=vmax)
    else:
        norm = {
            'lin': colors.Normalize,
            'log': colors.LogNorm
        }[scale](*bounds)
    sm = cm.ScalarMappable(norm=norm, cmap=cmap)
    # Assign an empty data array to the mappable, presumably so that it can be
    # used to draw colorbars with older matplotlib versions
    sm._A = []
    return norm, sm
class GenericPlot:
    ''' Generic interface for plots generated from simulation outputs.

        Calling an instance invokes its `render` method, which (along with
        `createBackBone` and `postProcess`) must be implemented by subclasses.
    '''

    def __init__(self, outputs):
        ''' Constructor.

            :param outputs: list / generator of simulation outputs (a single,
                non-iterable output is wrapped into a list)
        '''
        try:
            iter(outputs)
        except TypeError:
            outputs = [outputs]
        self.outputs = outputs

    def __call__(self, *args, **kwargs):
        ''' Render the plot (shortcut for `render`). '''
        return self.render(*args, **kwargs)

    def figtitle(self, model, meta):
        ''' Return a figure title describing a simulation, as provided by the model.

            :param model: model object
            :param meta: simulation metadata
            :return: title string
        '''
        return model.desc(meta)

    @staticmethod
    def wraptitle(ax, title, maxwidth=120, sep=':', fs=10, y=1.0):
        ''' Set an axis title, splitting it over several lines if it is too long.

            :param ax: axis object
            :param title: title string
            :param maxwidth: maximal title length (in characters) before splitting
            :param sep: separator at which a long title is split into lines (the
                separator itself is removed)
            :param fs: title font size
            :param y: title vertical position (lowered to 0.94 for split titles)
        '''
        if len(title) > maxwidth:
            title = '\n'.join(title.split(sep))
            y = 0.94
        h = ax.set_title(title, fontsize=fs)
        h.set_y(y)

    @staticmethod
    def getData(entry, frequency=1, trange=None):
        ''' Get simulation data and metadata from an output entry.

            :param entry: either a file path (loaded with `loadData`) or a
                (data, meta) tuple
            :param frequency: downsampling factor: one every `frequency` rows is kept
                for (data, meta) tuples (passed on to `loadData` for file paths)
            :param trange: optional (tmin, tmax) tuple restricting data to rows whose
                't' column lies within these bounds (inclusive)
            :return: (data, meta) tuple
            :raises ValueError: if entry is None
        '''
        if entry is None:
            raise ValueError('non-existing data')
        if isinstance(entry, str):
            data, meta = loadData(entry, frequency)
        else:
            data, meta = entry
            data = data.iloc[::frequency]
        if trange is not None:
            tmin, tmax = trange
            data = data.loc[(data['t'] >= tmin) & (data['t'] <= tmax)]
        return data, meta

    def render(self, *args, **kwargs):
        ''' Render the plot (to be implemented by subclasses). '''
        raise NotImplementedError

    @staticmethod
    def getSimType(fname):
        ''' Get sim type from filename.

            :param fname: file name of the form '<SIMKEY>_<...>.pkl', where SIMKEY
                is made of upper-case letters
            :return: sim type key
            :raises ValueError: if the file name does not match that pattern
        '''
        # The dot is escaped so that it only matches the literal '.pkl' extension
        mo = re.search(r'(^[A-Z]*)_(.*)\.pkl', fname)
        if not mo:
            raise ValueError(f'Could not find sim-key in filename: "{fname}"')
        return mo.group(1)

    @staticmethod
    def getModel(*args, **kwargs):
        ''' Get a model object (delegates to the module-level `getModel` function). '''
        return getModel(*args, **kwargs)

    @staticmethod
    def getTimePltVar(tscale):
        ''' Return time plot variable for a given temporal scale.

            :param tscale: temporal scale: 'ms' or 'us'
            :return: plot variable dictionary
        '''
        return {
            'desc': 'time',
            'label': 'time',
            'unit': tscale,
            'factor': {'ms': 1e3, 'us': 1e6}[tscale],
            'onset': {'ms': 1e-3, 'us': 1e-6}[tscale]
        }

    @staticmethod
    def createBackBone(*args, **kwargs):
        ''' Create the figure backbone (to be implemented by subclasses). '''
        raise NotImplementedError

    @staticmethod
    def prettify(ax, xticks=None, yticks=None, xfmt='{:.0f}', yfmt='{:+.0f}'):
        ''' Reduce ticks to their extreme values and format their labels.

            Objects exposing `get_ticks` (e.g. colorbars) have their ticks reduced and
            formatted with `xfmt`. Other objects (axes) fall back to separate x and y
            ticks handling.

            :param ax: axis or colorbar object
            :param xticks: x-ticks to set (defaults to extreme current x-ticks)
            :param yticks: y-ticks to set (defaults to extreme current y-ticks)
            :param xfmt: format string of x-tick labels (None to keep default labels)
            :param yfmt: format string of y-tick labels (None to keep default labels)
        '''
        try:
            ticks = ax.get_ticks()
            ticks = (min(ticks), max(ticks))
            ax.set_ticks(ticks)
            ax.set_ticklabels([xfmt.format(x) for x in ticks])
        except AttributeError:
            if xticks is None:
                xticks = ax.get_xticks()
                xticks = (min(xticks), max(xticks))
            if yticks is None:
                yticks = ax.get_yticks()
                yticks = (min(yticks), max(yticks))
            ax.set_xticks(xticks)
            ax.set_yticks(yticks)
            if xfmt is not None:
                ax.set_xticklabels([xfmt.format(x) for x in xticks])
            if yfmt is not None:
                ax.set_yticklabels([yfmt.format(y) for y in yticks])

    @staticmethod
    def addInset(fig, ax, inset):
        ''' Create inset axis.

            The inset axis is created over `ax`, restricted to the region defined by
            inset['xlims'] and inset['ylims'], without ticks, and covered by a white
            background rectangle spanning that region.

            :param fig: figure object
            :param ax: parent axis object
            :param inset: inset dictionary with 'xlims' and 'ylims' entries
            :return: inset axis object
        '''
        inset_ax = fig.add_axes(ax.get_position())
        inset_ax.set_zorder(1)
        inset_ax.set_xlim(inset['xlims'][0], inset['xlims'][1])
        inset_ax.set_ylim(inset['ylims'][0], inset['ylims'][1])
        inset_ax.set_xticks([])
        inset_ax.set_yticks([])
        inset_ax.add_patch(Rectangle(
            (inset['xlims'][0], inset['ylims'][0]),
            inset['xlims'][1] - inset['xlims'][0],
            inset['ylims'][1] - inset['ylims'][0],
            color='w'))
        return inset_ax

    @staticmethod
    def materializeInset(ax, inset_ax, inset):
        ''' Materialize inset with zoom box.

            The inset axis is moved to the region of the parent axis defined (in data
            coordinates) by inset['xcoords'] and inset['ycoords']. The zoomed region
            (inset['xlims'] x inset['ylims']) is framed on the parent axis and linked
            to the inset with dashed lines when both do not overlap along x.

            :param ax: parent axis object
            :param inset_ax: inset axis object
            :param inset: inset dictionary with 'xcoords', 'ycoords', 'xlims' and
                'ylims' entries
        '''
        # Re-position inset axis (converting data coordinates to figure coordinates)
        axpos = ax.get_position()
        left, right, = rescale(inset['xcoords'], ax.get_xlim()[0], ax.get_xlim()[1],
                               axpos.x0, axpos.x0 + axpos.width)
        bottom, top, = rescale(inset['ycoords'], ax.get_ylim()[0], ax.get_ylim()[1],
                               axpos.y0, axpos.y0 + axpos.height)
        inset_ax.set_position([left, bottom, right - left, top - bottom])
        for spine in inset_ax.spines.values():
            spine.set_linewidth(2)

        # Materialize inset target region with contour frame
        ax.plot(inset['xlims'], [inset['ylims'][0]] * 2, linestyle='-', color='k')
        ax.plot(inset['xlims'], [inset['ylims'][1]] * 2, linestyle='-', color='k')
        ax.plot([inset['xlims'][0]] * 2, inset['ylims'], linestyle='-', color='k')
        ax.plot([inset['xlims'][1]] * 2, inset['ylims'], linestyle='-', color='k')

        # Link target and inset with dashed lines if possible
        if inset['xcoords'][1] < inset['xlims'][0]:
            # Inset on the left of the target region
            ax.plot([inset['xcoords'][1], inset['xlims'][0]],
                    [inset['ycoords'][0], inset['ylims'][0]], linestyle='--', color='k')
            ax.plot([inset['xcoords'][1], inset['xlims'][0]],
                    [inset['ycoords'][1], inset['ylims'][1]], linestyle='--', color='k')
        elif inset['xcoords'][0] > inset['xlims'][1]:
            # Inset on the right of the target region
            ax.plot([inset['xcoords'][0], inset['xlims'][1]],
                    [inset['ycoords'][0], inset['ylims'][0]], linestyle='--', color='k')
            ax.plot([inset['xcoords'][0], inset['xlims'][1]],
                    [inset['ycoords'][1], inset['ylims'][1]], linestyle='--', color='k')
        else:
            logger.warning('Inset x-coordinates intersect with those of target region')

    def postProcess(self, *args, **kwargs):
        ''' Post-process the figure (to be implemented by subclasses). '''
        raise NotImplementedError

    @staticmethod
    def removeSpines(ax):
        ''' Hide the top and right spines of an axis. '''
        for item in ['top', 'right']:
            ax.spines[item].set_visible(False)

    @staticmethod
    def setXTicks(ax, xticks=None):
        ''' Set x-ticks of an axis, if provided. '''
        if xticks is not None:
            ax.set_xticks(xticks)

    @staticmethod
    def setYTicks(ax, yticks=None):
        ''' Set y-ticks of an axis, if provided. '''
        if yticks is not None:
            ax.set_yticks(yticks)

    @staticmethod
    def setTickLabelsFontSize(ax, fs):
        ''' Set the font size of the major tick labels of both axes.

            `label1` (the primary tick label) is used rather than the `label` alias,
            which was deprecated and then removed in recent matplotlib versions.
        '''
        for tick in ax.xaxis.get_major_ticks() + ax.yaxis.get_major_ticks():
            tick.label1.set_fontsize(fs)

    @staticmethod
    def setXLabel(ax, xplt, fs):
        ''' Set x-axis label from a plot variable ('label' and 'unit' entries). '''
        ax.set_xlabel('$\\rm {}\\ ({})$'.format(xplt['label'], xplt['unit']), fontsize=fs)

    @staticmethod
    def setYLabel(ax, yplt, fs):
        ''' Set y-axis label from a plot variable ('label' and optional 'unit' entries). '''
        ax.set_ylabel(
            '$\\rm {}\\ ({})$'.format(yplt['label'], yplt.get('unit', '')), fontsize=fs)

    @classmethod
    def addCmap(cls, fig, cmap, handles, comp_values, comp_info, fs, prettify, zscale='lin'):
        ''' Color plot handles according to comparison values, and add a colorbar.

            :param fig: figure object
            :param cmap: colormap name (retried with swapped first-letter case if not found)
            :param handles: plot handles (or iterables of handles), one per comparison value
            :param comp_values: comparison values, either numbers or strings made of a
                common prefix followed by integer index suffixes
            :param comp_info: dictionary describing the compared quantity ('desc', 'unit'
                and optional 'factor' entries). It is not modified by this method.
            :param fs: font size
            :param prettify: whether to reduce colorbar ticks to their extreme values
            :param zscale: colormap normalization scale ('lin', 'log' or 'symlog')
        '''
        if all(isinstance(x, str) for x in comp_values):
            # If list of strings, assume that index suffixes can be extracted
            prefix, suffixes = extractCommonPrefix(comp_values)
            comp_values = [int(s) for s in suffixes]
            desc_str = f'{prefix}\\ index'
        else:
            # Rescale comparison values and adjust unit. The prefixed unit is kept in a
            # local variable, so that repeated calls with the same comp_info dictionary
            # do not accumulate SI prefixes.
            comp_values = np.asarray(comp_values) * comp_info.get('factor', 1.)
            comp_factor, comp_prefix = getSIpair(comp_values, scale=zscale)
            comp_values /= comp_factor
            unit = comp_prefix + comp_info['unit']
            desc_str = comp_info['desc'].replace(' ', '\\ ')
            if len(unit) > 0:
                desc_str = f'{desc_str}\\ ({unit})'
        nvalues = len(comp_values)

        # Create colormap and normalizer
        try:
            mymap = plt.get_cmap(cmap)
        except ValueError:
            mymap = plt.get_cmap(swapFirstLetterCase(cmap))
        _, sm = setNormalizer(mymap, (min(comp_values), max(comp_values)), zscale)

        # Extract and adjust line colors
        zcolors = sm.to_rgba(comp_values)
        for lh, c in zip(handles, zcolors):
            if isIterable(lh):
                for item in lh:
                    item.set_color(c)
            else:
                lh.set_color(c)

        # Add colorbar
        fig.subplots_adjust(left=0.1, right=0.8, bottom=0.15, top=0.95, hspace=0.5)
        cbarax = fig.add_axes([0.85, 0.15, 0.03, 0.8])
        cbar_kwargs = {}
        if all(isinstance(x, int) for x in comp_values):
            dx = np.diff(comp_values)
            # Evenly spaced integer values (at least 2 of them, otherwise no spacing can
            # be inferred): center one discrete colorbar bin on each value
            if dx.size > 0 and all(x == dx[0] for x in dx):
                dx = dx[0]
                ticks = comp_values
                bounds = np.hstack((ticks, [max(ticks) + dx])) - dx / 2
                if nvalues > 10:
                    ticks = [ticks[0], ticks[-1]]
                cbar_kwargs.update({'ticks': ticks, 'boundaries': bounds, 'format': '%1i'})
                cbarax.tick_params(axis='both', which='both', length=0)
        cbar = fig.colorbar(sm, cax=cbarax, **cbar_kwargs)
        fig.sm = sm  # add scalar mappable as figure attribute in case of future need
        cbarax.set_ylabel(f'$\\rm {desc_str}$', fontsize=fs)
        if prettify:
            cls.prettify(cbar)
        for item in cbarax.get_yticklabels():
            item.set_fontsize(fs)
class ComparativePlot(GenericPlot):
    ''' Base class for plots comparing a variable across several simulation outputs,
        ideally differing by a single input. '''

    def __init__(self, outputs, varname):
        ''' Constructor.

            :param outputs: list /generator of simulation outputs to be compared.
            :param varname: name of variable to extract and compare.
        '''
        super().__init__(outputs)
        self.varname = varname
        # Key of the unique input differing across outputs, once identified
        self.comp_ref_key = None
        # Metadata of the reference (first checked) simulation output
        self.meta_ref = None
        # Description of the compared input (presumably set by subclasses)
        self.comp_info = None
        # Whether all outputs differ from the reference by the same single input
        self.is_unique_comp = False

    def checkLabels(self, labels):
        ''' Check that user-provided labels (if any) are an iterable of strings.

            :param labels: labels, or None
            :raises TypeError: if labels are not iterable or not all strings
        '''
        if labels is not None:
            if not isIterable(labels):
                raise TypeError('Invalid labels: must be an iterable')
            if not all(isinstance(x, str) for x in labels):
                raise TypeError('Invalid labels: must be string typed')

    def checkSimType(self, meta):
        ''' Check consistency of sim types across files.

            :param meta: metadata of the output to check
            :raises ValueError: if its 'simkey' differs from that of the reference
        '''
        if meta['simkey'] != self.meta_ref['simkey']:
            raise ValueError('Invalid comparison: different simulation types')

    def checkCompValues(self, meta, comp_values):
        ''' Check consistency of differing values across files.

            If `meta` differs from the reference metadata by exactly one input, consistent
            with previous comparisons, the differing value(s) are appended to
            `comp_values` (both reference and new values upon the first comparison).
            Otherwise, a warning is logged, the comparison key is reset, and an empty
            list is returned.

            :param meta: metadata of the output to check
            :param comp_values: list of comparison values accumulated so far
            :return: updated list of comparison values
        '''
        # Get differing values across meta dictionaries
        diffs = differing(self.meta_ref, meta, subdkey='meta')

        # Check that exactly one value differs
        if len(diffs) > 1:
            logger.warning('More than one differing inputs')
            self.comp_ref_key = None
            return []
        if len(diffs) == 0:
            logger.warning('No differing inputs')
            self.comp_ref_key = None
            return []

        # Get the key and differing values
        zkey, refval, val = diffs[0]

        # If no comparison key yet, fill it up
        if self.comp_ref_key is None:
            self.comp_ref_key = zkey
            self.is_unique_comp = True
            comp_values += [refval, val]

        # Otherwise, check that comparison matches the existing one
        else:
            if zkey != self.comp_ref_key:
                logger.warning('inconsistent differing inputs')
                self.comp_ref_key = None
                return []
            else:
                comp_values.append(val)
        return comp_values

    def checkConsistency(self, meta, comp_values):
        ''' Check consistency of sim types and check differing inputs.

            The first metadata passed is stored as reference. Subsequent ones are
            checked against it.

            :param meta: metadata of the output to check
            :param comp_values: list of comparison values accumulated so far
            :return: updated list of comparison values
        '''
        if self.meta_ref is None:
            self.meta_ref = meta
        else:
            self.checkSimType(meta)
            comp_values = self.checkCompValues(meta, comp_values)
            if self.comp_ref_key is None:
                self.is_unique_comp = False
        return comp_values

    def getCompLabels(self, comp_values):
        ''' Format comparison values into labels.

            If `comp_info` is set, values are scaled by its 'factor' entry, formatted
            with SI prefix and unit (if it has a 'unit' entry), and labeled as
            '<label> = <value>'. Otherwise, values are used as labels.

            :param comp_values: comparison values
            :return: (processed comparison values, comparison labels) tuple
        '''
        if self.comp_info is not None:
            comp_values = np.array(comp_values) * self.comp_info.get('factor', 1)
            if 'unit' in self.comp_info:
                p = self.comp_info.get('precision', 0)
                comp_values = [
                    f"{si_format(v, p)}{self.comp_info['unit']}".replace(' ', '\\ ')
                    for v in comp_values]
            comp_labels = ['$\\rm{} = {}$'.format(self.comp_info['label'], x)
                           for x in comp_values]
        else:
            comp_labels = comp_values
        return comp_values, comp_labels

    def chooseLabels(self, labels, comp_labels, full_labels):
        ''' Choose plot labels: user-provided labels if any, otherwise comparison labels
            if outputs differ by a unique input, otherwise full labels. '''
        if labels is not None:
            return labels
        else:
            if self.is_unique_comp:
                return comp_labels
            else:
                return full_labels

    @staticmethod
    def getCommonLabel(lbls, seps='_'):
        ''' Get a common label from a list of labels, by removing parts that differ across them.

            :param lbls: list of labels
            :param seps: characters at which labels are split into pieces
            :return: common label, or an empty string if it cannot be built (no labels,
                more than one differing piece, or no common piece)
            :raises ValueError: if labels are not split by the same sequence of separators
        '''
        if len(lbls) == 0:
            return ''

        # Split every label according to list of separator characters, and save splitters
        # as well (separators are escaped so that they are matched literally)
        splt_lbls = [re.split(f'([{re.escape(seps)}])', x) for x in lbls]
        pieces = [x[::2] for x in splt_lbls]
        splitters = [x[1::2] for x in splt_lbls]
        ncomps = len(pieces[0])

        # Check that splitters are equivalent across all labels, and reduce them to a single array
        if any(x != splitters[0] for x in splitters):
            raise ValueError('Inconsistent splitters')
        splitters = np.array(splitters[0])

        # Transform pieces into 2D matrix, and evaluate equality of every piece across labels
        pieces = np.array(pieces).T
        all_identical = [np.all(x == x[0]) for x in pieces]
        if np.sum(all_identical) < ncomps - 1:
            logger.warning('More than one differing inputs')
            return ''

        # Discard differing pieces and remove associated splitters
        pieces = pieces[all_identical, 0]
        if pieces.size == 0:
            # Labels made of a single, differing piece: nothing in common
            return ''
        splitters = splitters[all_identical[:-1]]

        # Remove last splitter if the last pieces were discarded
        if splitters.size == pieces.size:
            splitters = splitters[:-1]

        # Join common pieces and associated splitters into a single label
        common_lbl = ''
        for p, s in zip(pieces, splitters):
            common_lbl += f'{p}{s}'
        common_lbl += pieces[-1]

        return common_lbl.replace('( ', '(')
def addExcitationInset(ax, is_excited):
    ''' Add a text inset stating the excitation status on an axis.

        The text is centered at (0.7, 0.7) in axes coordinates, inside a rounded box
        colored light green if excited, light red otherwise.

        :param ax: axis object
        :param is_excited: excitation status
    '''
    if is_excited:
        box_color = (0.8, 1.0, 0.8)
    else:
        box_color = (1., 0.8, 0.8)
    status_text = f'{"" if is_excited else "not "}excited'
    box_props = dict(boxstyle='round', fc=box_color)
    ax.text(0.7, 0.7, status_text, transform=ax.transAxes,
            ha='center', va='center', size=30, bbox=box_props)
def mirrorProp(org, new, prop):
    ''' Copy a property from one object onto another one of the same class.

        Equivalent to calling new.set_<prop>(org.get_<prop>()).

        :param org: object from which the property is read
        :param new: object onto which the property is written
        :param prop: property name
    '''
    # Setter is looked up first, as in a single nested call expression
    setter = getattr(new, f'set_{prop}')
    getter = getattr(org, f'get_{prop}')
    setter(getter())
def mirrorAxis(org_ax, new_ax, p=False):
    ''' Mirror content of original axis to a new axis. That includes:
        - position on the figure
        - spines visibility
        - ticks, ticklabels, and labels (for both x and y axes)
        - vertical spans

        :param org_ax: original axis object
        :param new_ax: new axis object
        :param p: unused
    '''
    mirrorProp(org_ax, new_ax, 'position')
    for spine_key in ('bottom', 'left', 'right', 'top'):
        mirrorProp(org_ax.spines[spine_key], new_ax.spines[spine_key], 'visible')
    for prop in ('label', 'ticks', 'ticklabels'):
        for axis_key in ('x', 'y'):
            mirrorProp(org_ax, new_ax, f'{axis_key}{prop}')

    # Every Polygon child of the original axis is treated as a vertical span: its
    # x-extent is read from its 1st and 3rd vertices and re-drawn with axvspan
    polygons = [child for child in org_ax.get_children() if isinstance(child, Polygon)]
    for polygon in polygons:
        props = polygon.properties()
        vertices = props['xy']
        xmin, xmax = vertices[0][0], vertices[2][0]
        span_kwargs = {key: props[key] for key in ('alpha', 'edgecolor', 'facecolor')}
        # A fully transparent edge color is translated into no edge at all
        if span_kwargs['edgecolor'] == (0.0, 0.0, 0.0, 0.0):
            span_kwargs['edgecolor'] = 'none'
        new_ax.axvspan(xmin, xmax, **span_kwargs)
def harmonizeAxesLimits(axes, dim='xy'):
    ''' Harmonize x and/or y limits across an array of axes.

        All axes are given the same limits: the minimum of their lower limits and the
        maximum of their upper limits.

        :param axes: array of axis objects (flattened before use)
        :param dim: which limits to harmonize: 'x', 'y' or 'xy' (both)
    '''
    flat_axes = axes.flatten()
    xmin, xmax = np.inf, -np.inf
    ymin, ymax = np.inf, -np.inf
    for ax in flat_axes:
        ax_xmin, ax_xmax = ax.get_xlim()
        ax_ymin, ax_ymax = ax.get_ylim()
        xmin, xmax = min(xmin, ax_xmin), max(xmax, ax_xmax)
        ymin, ymax = min(ymin, ax_ymin), max(ymax, ax_ymax)
    update_x = dim in ['xy', 'x']
    update_y = dim in ['xy', 'y']
    for ax in flat_axes:
        if update_x:
            ax.set_xlim(xmin, xmax)
        if update_y:
            ax.set_ylim(ymin, ymax)
def hideSpines(ax, spines='all'):
    ''' Hide spines of one or several axes.

        :param ax: axis object, or iterable of axis objects (handled recursively)
        :param spines: list of spine keys to hide, or 'all' for top, bottom, right and left
    '''
    if not isIterable(ax):
        spine_keys = ['top', 'bottom', 'right', 'left'] if spines == 'all' else spines
        for spine_key in spine_keys:
            ax.spines[spine_key].set_visible(False)
        return
    for sub_ax in ax:
        hideSpines(sub_ax, spines=spines)
def hideTicks(ax, key='xy'):
    ''' Remove the ticks of one or several axes.

        :param ax: axis object, or iterable of axis objects (handled recursively)
        :param key: which ticks to remove: 'x', 'y' or 'xy' (both)
    '''
    if isIterable(ax):
        for item in ax:
            hideTicks(item, key=key)
        # The container itself has no ticks: stop here (previously, execution fell
        # through and called set_xticks/set_yticks on the container, which fails)
        return
    if key in ['xy', 'x']:
        ax.set_xticks([])
    if key in ['xy', 'y']:
        ax.set_yticks([])
def addXscale(ax, xoffset, yoffset, unit='', lw=2, fmt='.0f', fs=10, side='bottom'):
    ''' Draw a horizontal scale bar outside an axis, annotated with the axis x-range.

        The bar spans the full axis width (in axes coordinates) and its annotation
        is the peak-to-peak value of the current x-limits.

        :param ax: axis object
        :param xoffset: horizontal offset of the bar (in axes fraction)
        :param yoffset: distance between the bar and the axis side (in axes fraction),
            counted away from the axis
        :param unit: unit string appended to the annotation
        :param lw: bar line width
        :param fmt: format specification of the annotated x-range
        :param fs: annotation font size
        :param side: axis side along which the bar is drawn ('bottom' or 'top')
    '''
    ybase = {'bottom': 0, 'top': 1}[side]
    # Below the axis, offsets are counted downwards; above it, upwards
    direction = -1 if side == 'bottom' else 1
    ybar = ybase + direction * yoffset
    ax.plot([xoffset, 1 + xoffset], [ybar] * 2, c='k', transform=ax.transAxes,
            linewidth=lw, clip_on=False)
    text_x = 0.5 + xoffset
    text_y = ybar + direction * 0.07
    text_va = 'top' if side == 'bottom' else 'bottom'
    span = np.ptp(ax.get_xlim())
    ax.text(text_x, text_y, f'{span:{fmt}} {unit}', transform=ax.transAxes,
            ha='center', va=text_va, fontsize=fs)
def addYscale(ax, xoffset, yoffset, unit='', lw=2, fmt='.0f', fs=10, side='right'):
    ''' Draw a vertical scale bar outside an axis, annotated with the axis y-range.

        The bar spans the full axis height (in axes coordinates) and its annotation
        (rotated by 90 degrees) is the peak-to-peak value of the current y-limits.

        :param ax: axis object
        :param xoffset: distance between the bar and the axis side (in axes fraction),
            counted away from the axis
        :param yoffset: vertical offset of the bar (in axes fraction)
        :param unit: unit string appended to the annotation
        :param lw: bar line width
        :param fmt: format specification of the annotated y-range
        :param fs: annotation font size
        :param side: axis side along which the bar is drawn ('left' or 'right')
    '''
    xbase = {'left': 0, 'right': 1}[side]
    # Left of the axis, offsets are counted leftwards; right of it, rightwards
    direction = -1 if side == 'left' else 1
    xbar = xbase + direction * xoffset
    ax.plot([xbar] * 2, [yoffset, 1 + yoffset], c='k', transform=ax.transAxes,
            linewidth=lw, clip_on=False)
    text_x = xbar + direction * 0.07
    text_y = .5 + yoffset
    text_ha = 'right' if side == 'left' else 'left'
    span = np.ptp(ax.get_ylim())
    ax.text(text_x, text_y, f'{span:{fmt}} {unit}', transform=ax.transAxes,
            ha=text_ha, va='center', rotation=90, fontsize=fs)
Event Timeline
Log In to Comment