Page Menu
Home
c4science
Search
Configure Global Search
Log In
Files
F86463682
metrics.py
No One
Temporary
Actions
Download File
Edit File
Delete File
View Transforms
Subscribe
Mute Notifications
Award Token
Subscribers
None
File Metadata
Details
File Info
Storage
Attached
Created
Sun, Oct 6, 15:47
Size
1 KB
Mime Type
text/x-python
Expires
Tue, Oct 8, 15:47 (1 d, 23 h)
Engine
blob
Format
Raw Data
Handle
21427649
Attached To
R12535 ME-390-2022
metrics.py
View Options
# Source: https://github.com/dmizr/phuber/blob/master/phuber/metrics.py
import
torch
class LossMetric:
    """Keeps track of the loss over an epoch.

    Each batch loss is weighted by its batch size, so ``compute()`` yields the
    per-sample average even when batches differ in size (e.g. a smaller final
    batch) -- assuming each reported ``loss`` is the mean over its batch.
    """

    def __init__(self) -> None:
        self.reset()

    def update(self, loss: float, batch_size: int) -> None:
        """Accumulate the loss of one batch.

        Args:
            loss: Loss of the batch, expected to be averaged over its samples
                (it is re-weighted by ``batch_size`` here). A Python number or
                a single-element tensor.
            batch_size: Number of samples in the batch.
        """
        # float() turns a single-element tensor into a plain Python number, so
        # a graph-attached loss passed by mistake does not keep the autograd
        # graph alive for the whole epoch, and compute() stays a float.
        self.running_loss += float(loss) * batch_size
        self.count += batch_size

    def compute(self) -> float:
        """Return the sample-weighted mean loss.

        Returns:
            The accumulated loss divided by the number of samples seen, or
            NaN if no samples have been recorded since the last reset.
        """
        if self.count == 0:
            # The mean of zero samples is undefined; report NaN rather than
            # raising ZeroDivisionError (e.g. on an empty evaluation set).
            return float("nan")
        return self.running_loss / self.count

    def reset(self) -> None:
        """Clear all accumulated state, e.g. at the start of a new epoch."""
        self.running_loss = 0.0
        self.count = 0
class AccuracyMetric:
    """Keeps track of the top-k accuracy over an epoch.

    Args:
        k (int): Value of k for top-k accuracy
    """

    def __init__(self, k: int = 1) -> None:
        self.k = k
        self.reset()

    def update(self, out: torch.Tensor, target: torch.Tensor) -> None:
        """Accumulate the top-k hits of one batch.

        Args:
            out: Per-class scores with classes along the last dimension;
                presumably shaped (batch, num_classes) -- TODO confirm.
            target: Class labels with the batch along the first dimension,
                presumably shaped (batch,).
        """
        # torch.topk raises if k exceeds the number of classes. With that few
        # classes every label is trivially in the top k, so clamp instead.
        k = min(self.k, out.shape[-1])
        _, top_indices = torch.topk(out, k, dim=-1)
        # Broadcast each label against its row of top-k predicted indices;
        # a sample is a hit if any of the k predictions matches.
        hits = (top_indices == target[:, None]).any(dim=-1)
        self.correct += int(hits.sum().item())
        self.total += target.shape[0]

    def compute(self) -> float:
        """Return the fraction of samples whose label was in the top k.

        Returns:
            Accumulated hits divided by samples seen, or NaN if no samples
            have been recorded since the last reset.
        """
        if self.total == 0:
            # Accuracy over zero samples is undefined; report NaN rather than
            # raising ZeroDivisionError (e.g. on an empty evaluation set).
            return float("nan")
        return self.correct / self.total

    def reset(self) -> None:
        """Clear all accumulated counts, e.g. at the start of a new epoch."""
        self.correct = 0
        self.total = 0
Event Timeline
Log In to Comment