F96774527: trainer.py
File Metadata
Created: Mon, Dec 30, 17:34
Size: 4 KB
Mime Type: text/x-python
Expires: Wed, Jan 1, 17:34 (2 d)
Storage Engine: blob
Storage Format: Raw Data
Storage Handle: 23258269
Attached To: R11789 DED Contrastive Learning
import torch
import numpy as np


def fit(train_loader, val_loader, model, loss_fn, optimizer, scheduler, n_epochs, cuda, log_interval,
        metrics=[], start_epoch=0):
    """
    Loaders, model, loss function and metrics should work together for a given task,
    i.e. the model should be able to process the data output of the loaders, and the
    loss function should process the target output of the loaders and the outputs of the model.

    Examples:
    Classification: batch loader, classification model, NLL loss, accuracy metric
    Siamese network: siamese loader, siamese model, contrastive loss
    Online triplet learning: batch loader, embedding model, online triplet loss
    """
    train_losses = []
    val_losses = []

    # Advance the scheduler to the starting epoch when resuming training
    for epoch in range(0, start_epoch):
        scheduler.step()

    for epoch in range(start_epoch, n_epochs):
        scheduler.step()

        # Train stage
        train_loss, metrics = train_epoch(train_loader, model, loss_fn, optimizer, cuda, log_interval, metrics)

        message = 'Epoch: {}/{}. Train set: Average loss: {:.4f}'.format(epoch + 1, n_epochs, train_loss)
        train_losses.append(train_loss)
        for metric in metrics:
            message += '\t{}: {}'.format(metric.name(), metric.value())

        # Validation stage
        val_loss, metrics = test_epoch(val_loader, model, loss_fn, cuda, metrics)
        val_loss /= len(val_loader)

        message += '\nEpoch: {}/{}. Validation set: Average loss: {:.4f}'.format(epoch + 1, n_epochs, val_loss)
        val_losses.append(val_loss)
        for metric in metrics:
            message += '\t{}: {}'.format(metric.name(), metric.value())

        print(message)

    return train_losses, val_losses


def train_epoch(train_loader, model, loss_fn, optimizer, cuda, log_interval, metrics):
    for metric in metrics:
        metric.reset()

    model.train()
    losses = []
    total_loss = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        # Treat an empty target as "no labels" (e.g. loaders that already build pairs/triplets)
        target = target if len(target) > 0 else None
        # Wrap a single input tensor so multi-input models are handled uniformly
        if not type(data) in (tuple, list):
            data = (data,)
        if cuda:
            data = tuple(d.cuda() for d in data)
            if target is not None:
                target = target.cuda()

        optimizer.zero_grad()
        outputs = model(*data)

        if type(outputs) not in (tuple, list):
            outputs = (outputs,)

        loss_inputs = outputs
        if target is not None:
            target = (target,)
            loss_inputs += target

        loss_outputs = loss_fn(*loss_inputs)
        # Some loss functions return (loss, extra outputs); keep only the loss value
        loss = loss_outputs[0] if type(loss_outputs) in (tuple, list) else loss_outputs
        losses.append(loss.item())
        total_loss += loss.item()
        loss.backward()
        optimizer.step()

        for metric in metrics:
            metric(outputs, target, loss_outputs)

        if batch_idx % log_interval == 0:
            message = 'Train: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                batch_idx * len(data[0]), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), np.mean(losses))
            for metric in metrics:
                message += '\t{}: {}'.format(metric.name(), metric.value())

            print(message)
            losses = []

    total_loss /= (batch_idx + 1)
    return total_loss, metrics


def test_epoch(val_loader, model, loss_fn, cuda, metrics):
    with torch.no_grad():
        for metric in metrics:
            metric.reset()
        model.eval()
        val_loss = 0

        for batch_idx, (data, target) in enumerate(val_loader):
            target = target if len(target) > 0 else None
            if not type(data) in (tuple, list):
                data = (data,)
            if cuda:
                data = tuple(d.cuda() for d in data)
                if target is not None:
                    target = target.cuda()

            outputs = model(*data)

            if type(outputs) not in (tuple, list):
                outputs = (outputs,)
            loss_inputs = outputs
            if target is not None:
                target = (target,)
                loss_inputs += target

            loss_outputs = loss_fn(*loss_inputs)
            loss = loss_outputs[0] if type(loss_outputs) in (tuple, list) else loss_outputs
            val_loss += loss.item()

            for metric in metrics:
                metric(outputs, target, loss_outputs)

    return val_loss, metrics
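
As the docstring notes, the loaders, model, loss function and metrics only need to agree with each other; in addition, each entry in metrics must implement reset(), name(), value() and be callable as metric(outputs, target, loss_outputs). The sketch below is a minimal, self-contained usage example following the docstring's classification case. The synthetic data, the small nn.Sequential model and the AccuracyMetric class are illustrative stand-ins for this sketch only, not part of this repository.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


class AccuracyMetric:
    """Illustrative metric implementing the interface the trainer expects:
    reset(), name(), value(), and __call__(outputs, target, loss_outputs)."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.correct = 0
        self.total = 0

    def __call__(self, outputs, target, loss_outputs):
        # The trainer passes outputs and target as tuples
        pred = outputs[0].argmax(dim=1)
        self.correct += (pred == target[0]).sum().item()
        self.total += target[0].size(0)
        return self.value()

    def name(self):
        return 'Accuracy'

    def value(self):
        return 100.0 * self.correct / max(self.total, 1)


# Synthetic 10-dimensional data with 3 classes, just to exercise the loop
x = torch.randn(256, 10)
y = torch.randint(0, 3, (256,))
train_loader = DataLoader(TensorDataset(x[:192], y[:192]), batch_size=32, shuffle=True)
val_loader = DataLoader(TensorDataset(x[192:], y[192:]), batch_size=32)

model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 3), nn.LogSoftmax(dim=1))
loss_fn = nn.NLLLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

cuda = torch.cuda.is_available()
if cuda:
    model.cuda()

train_losses, val_losses = fit(train_loader, val_loader, model, loss_fn, optimizer, scheduler,
                               n_epochs=10, cuda=cuda, log_interval=2, metrics=[AccuracyMetric()])

For the repository's contrastive-learning use case, the same call shape applies with a siamese or embedding model, the corresponding loader, and a contrastive or online triplet loss in place of the classification pieces above.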