test_trainer_tpu.py
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This test is meant to be run on an instance with TPUs like this:
#
# python examples/xla_spawn.py --num_cores=8 tests/test_trainer_tpu.py
#
# Replace 8 with the number of TPU cores you have.
#
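# (xla_spawn.py is the launcher script under examples/ in the transformers
# repo; it forwards to torch_xla's multiprocessing spawn so that main() runs
# once per TPU core.)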
import sys
from typing import Dict

from transformers import EvalPrediction, HfArgumentParser, TrainingArguments, is_torch_available
from transformers.utils import logging


logger = logging.get_logger(__name__)

if is_torch_available():
    import torch
    from torch import nn
    from torch.utils.data.dataset import Dataset

    from transformers import Trainer
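
    # A trivial map-style dataset: item i is just the integer i, so after a
    # distributed evaluation we can check that every index comes back exactly
    # once and in the right order.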
    class DummyDataset(Dataset):
        def __init__(self, length: int = 101):
            self.length = length

        def __len__(self):
            return self.length

        def __getitem__(self, i) -> int:
            return i
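
    # Collates a list of integer indices into the dict of tensors the Trainer
    # expects; both input_ids and labels carry the raw indices.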
    class DummyDataCollator:
        def __call__(self, features):
            return {"input_ids": torch.tensor(features), "labels": torch.tensor(features)}
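
    # An identity model: forward() echoes input_ids back as the "logits"
    # (plus a dummy zero loss when labels are given), so the gathered
    # predictions are the sample indices themselves.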
    class DummyModel(nn.Module):
        def __init__(self):
            super().__init__()
            # Add some (unused) params otherwise DDP will complain.
            self.fc = nn.Linear(120, 80)

        def forward(self, input_ids, labels=None):
            if labels is not None:
                return torch.tensor(0.0, device=input_ids.device), input_ids
            else:
                return input_ids
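

# Runs evaluation and prediction over datasets of several lengths and exits
# with code 1 if the gathered predictions are not exactly the sequence
# 0..len(dataset)-1.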
def main():
    parser = HfArgumentParser((TrainingArguments,))
    sys.argv += ["--output_dir", "./examples"]
    training_args = parser.parse_args_into_dataclasses()[0]

    logger.warning(
        "Process rank: %s, device: %s, tpu_num_cores: %s",
        training_args.local_rank,
        training_args.device,
        training_args.tpu_num_cores,
    )

    # Essentially, what we want to verify in the distributed case is
    # that we get all samples back, in the right order.
    # (this is crucial for prediction for instance)
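    # Dataset lengths: 1001 (odd), 256 (a multiple of 8 cores), and 15 (fewer
    # samples than cores), presumably to exercise the distributed sampler's
    # padding and truncation paths.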
    for dataset_length in [1001, 256, 15]:
        dataset = DummyDataset(dataset_length)
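
        # Success means the gathered predictions and labels are exactly the
        # sequence [0, 1, ..., len(dataset) - 1].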
        def compute_metrics(p: EvalPrediction) -> Dict:
            sequential = list(range(len(dataset)))
            success = p.predictions.tolist() == sequential and p.label_ids.tolist() == sequential
            return {"success": success}
        trainer = Trainer(
            model=DummyModel(),
            args=training_args,
            data_collator=DummyDataCollator(),
            eval_dataset=dataset,
            compute_metrics=compute_metrics,
        )
        metrics = trainer.evaluate()
        logger.info(metrics)
        if metrics["eval_success"] is not True:
            logger.error(metrics)
            exit(1)

        p = trainer.predict(dataset)
        logger.info(p.metrics)
        if p.metrics["eval_success"] is not True:
            logger.error(p.metrics)
            exit(1)
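
        # Re-run the same checks with eval_accumulation_steps=2, so the output
        # tensors are moved to the CPU every two prediction steps instead of
        # staying on device; the results should be identical.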
        trainer.args.eval_accumulation_steps = 2

        metrics = trainer.evaluate()
        logger.info(metrics)
        if metrics["eval_success"] is not True:
            logger.error(metrics)
            exit(1)

        p = trainer.predict(dataset)
        logger.info(p.metrics)
        if p.metrics["eval_success"] is not True:
            logger.error(p.metrics)
            exit(1)

        trainer.args.eval_accumulation_steps = None

    logger.info("🔥 All distributed tests successful")
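

# xla_spawn calls _mp_fn once per TPU process, passing the process index.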
def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()