Page Menu
Home
c4science
Search
Configure Global Search
Log In
Files
F120676238
trainer.py
No One
Temporary
Actions
Download File
Edit File
Delete File
View Transforms
Subscribe
Mute Notifications
Award Token
Subscribers
None
File Metadata
Details
File Info
Storage
Attached
Created
Sun, Jul 6, 05:22
Size
1 KB
Mime Type
text/x-python
Expires
Tue, Jul 8, 05:22 (2 d)
Engine
blob
Format
Raw Data
Handle
27227515
Attached To
R8206 networkTraining
trainer.py
View Options
import
torch
from
torch.autograd
import
Variable
import
torch.backends.cudnn
as
cudnn
import
torch.optim
as
optim
import
numpy
as
np
import
sys
import
time
class trainer:
    """Iteration-budgeted training loop for a PyTorch model.

    Pulls ``(img, lbl)`` batches from ``train_loader``, runs one optimizer
    step per batch and hands each batch's loss/output/labels to ``logger``.
    When the loader is exhausted the epoch is closed: ``logger.logEpoch`` is
    called, ``tester.test`` runs every ``test_every`` epochs, the LR scheduler
    is optionally stepped, and a fresh iterator over the loader is created.
    The data iterator and counters persist across ``train()`` calls, so
    consecutive calls resume mid-epoch.
    """

    def __init__(self, net, train_loader, optimizer, loss_function, logger,
                 tester, test_every, lr_scheduler=None, lrStepPer='batch',
                 device=None):
        """
        Args:
            net: model to train; ``train()`` puts it in training mode and
                calls it on each input batch.
            train_loader: iterable yielding ``(img, lbl)`` pairs; ``iter()``
                is called on it again at the start of every epoch, so it must
                be re-iterable.
            optimizer: optimizer stepped once per batch.
            loss_function: criterion called as ``loss_function(out, lbl)``;
                moved to ``device``.
            logger: object with ``add(loss, out, lbl)`` (called per batch) and
                ``logEpoch(net)`` (called per epoch; its return value is what
                the scheduler receives when ``lrStepPer == 'epoch'``).
            tester: object with ``test(net)``, called every ``test_every``
                epochs.
            test_every: epoch interval for testing; 0 or None disables it.
            lr_scheduler: optional scheduler; None disables LR stepping.
            lrStepPer: ``'batch'`` -> ``lr_scheduler.step()`` after every
                batch; ``'epoch'`` -> ``lr_scheduler.step(lastLoss)`` after
                every epoch (presumably a metric-driven scheduler such as
                ReduceLROnPlateau -- confirm with callers). Any other value
                never steps the scheduler.
            device: target for inputs, labels and the loss module. None
                (default) keeps the original hard-coded ``.cuda()`` calls;
                otherwise ``.to(device)`` is used (e.g. ``'cpu'``).
        """
        # Must be set before _to_device() is used below.
        self.device = device
        self.net = net
        self.dataLoader = train_loader
        self.optimizer = optimizer
        self.crit = self._to_device(loss_function)
        self.logger = logger
        self.di = iter(self.dataLoader)
        self.epoch = 0
        self.tot_iter = 0
        self.prev_iter = self.tot_iter
        # Batches drawn since the current epoch began; lets _end_epoch detect
        # an empty/exhausted loader instead of spinning forever.
        self._epoch_batches = 0
        self.test_every = test_every
        self.tester = tester
        self.lr_scheduler = lr_scheduler
        self.lrStepPer = lrStepPer

    def _to_device(self, obj):
        """Move a tensor or module to ``self.device`` (``.cuda()`` if None)."""
        if self.device is None:
            return obj.cuda()
        return obj.to(self.device)

    def _train_step(self, img, lbl):
        """Run one forward/backward/optimizer step on a batch and log it."""
        self.optimizer.zero_grad()
        # Variable wrapping kept from the original for older PyTorch; on
        # recent versions Variable(t) simply returns a tensor.
        img = Variable(self._to_device(img))
        # Call the module instead of .forward() so registered hooks run.
        out = self.net(img)
        loss = self.crit(out, self._to_device(lbl))
        loss.backward()
        self.optimizer.step()
        # The logger receives the labels as yielded by the loader (not moved).
        self.logger.add(loss.data.cpu().numpy(), out, lbl)

    def _end_epoch(self):
        """Close the current epoch: log, reset the iterator, test, step LR.

        Raises:
            ValueError: if the epoch produced no batches at all.
        """
        if self._epoch_batches == 0:
            raise ValueError('train_loader yielded no batches in an epoch; '
                             'it must be non-empty and re-iterable')
        self._epoch_batches = 0
        lastLoss = self.logger.logEpoch(self.net)
        self.epoch += 1
        self.di = iter(self.dataLoader)
        if self.test_every and self.epoch % self.test_every == 0:
            self.tester.test(self.net)
            # The tester may switch the net to eval mode; restore training
            # mode before the remaining steps of this train() call.
            self.net.train()
        if self.lr_scheduler is not None and self.lrStepPer == 'epoch':
            self.lr_scheduler.step(lastLoss)

    def train(self, numiter):
        """Run ``numiter`` optimizer steps, crossing epoch boundaries as needed.

        Prints a carriage-return progress line (iteration, epoch, seconds per
        iteration) at most every ~3 seconds.

        Raises:
            ValueError: if an epoch ends without yielding any batch (empty or
                one-shot loader), instead of looping forever.
        """
        self.net.train()
        local_iter = 0
        t0 = time.time()
        # Open the timing window here; a stale prev_iter from a previous call
        # would count iterations outside this window and skew Time/iter.
        self.prev_iter = self.tot_iter
        while local_iter < numiter:
            # Guard only the fetch: a StopIteration escaping the model, loss
            # or logger must not be mistaken for the end of an epoch.
            try:
                img, lbl = next(self.di)
            except StopIteration:
                self._end_epoch()
                continue
            self._train_step(img, lbl)
            self._epoch_batches += 1
            local_iter += 1
            self.tot_iter += 1
            if self.lr_scheduler is not None and self.lrStepPer == 'batch':
                self.lr_scheduler.step()
            t1 = time.time()
            if t1 - t0 > 3:
                sys.stdout.write('\rIter: %8d\tEpoch: %6d\tTime/iter: %6f' % (
                    self.tot_iter, self.epoch,
                    (t1 - t0) / (self.tot_iter - self.prev_iter)))
                # No newline is written, so flush explicitly or the progress
                # line may sit in a line-buffered stdout.
                sys.stdout.flush()
                t0 = t1
                self.prev_iter = self.tot_iter
Event Timeline
Log In to Comment