Page Menu
Home
c4science
Search
Configure Global Search
Log In
Files
F62997530
_test_bash_script.py
No One
Temporary
Actions
Download File
Edit File
Delete File
View Transforms
Subscribe
Mute Notifications
Award Token
Subscribers
None
File Metadata
Details
File Info
Storage
Attached
Created
Fri, May 17, 01:07
Size
8 KB
Mime Type
text/x-python
Expires
Sun, May 19, 01:07 (2 d)
Engine
blob
Format
Raw Data
Handle
17718066
Attached To
R11484 ADDI
_test_bash_script.py
View Options
#!/usr/bin/env python
import
argparse
import
os
import
sys
from
unittest.mock
import
patch
import
pytorch_lightning
as
pl
import
timeout_decorator
import
torch
from
distillation
import
SummarizationDistiller
,
distill_main
from
finetune
import
SummarizationModule
,
main
from
transformers
import
MarianMTModel
from
transformers.file_utils
import
cached_path
from
transformers.testing_utils
import
TestCasePlus
,
require_torch_gpu
,
slow
from
utils
import
load_json
# Hub identifier of the Marian student checkpoint used by TestMbartCc25Enro:
# it is pre-downloaded in test_model_download and substituted for
# "facebook/mbart-large-cc25" when the training script's arguments are rewritten.
MARIAN_MODEL = "sshleifer/mar_enro_6_3_student"
class TestMbartCc25Enro(TestCasePlus):
    """Run the arguments of ``train_mbart_cc25_enro.sh`` through ``finetune.main``.

    The shell script is not executed. Its ``finetune.py`` argument list is
    extracted, its placeholders are substituted with small test values, and the
    result is fed to ``finetune.main`` via a patched ``sys.argv``.
    """

    def setUp(self):
        """Resolve the en-ro dataset through ``cached_path`` and store its directory."""
        super().setUp()

        data_cached = cached_path(
            "https://cdn-datasets.huggingface.co/translation/wmt_en_ro-tr40k-va0.5k-te0.5k.tar.gz",
            extract_compressed_file=True,
        )
        self.data_dir = f"{data_cached}/wmt_en_ro-tr40k-va0.5k-te0.5k"

    @slow
    @require_torch_gpu
    def test_model_download(self):
        """This warms up the cache so that we can time the next test without including download time, which varies between machines."""
        MarianMTModel.from_pretrained(MARIAN_MODEL)

    # @timeout_decorator.timeout(1200)
    @slow
    @require_torch_gpu
    def test_train_mbart_cc25_enro_script(self):
        """Train for one epoch with the script's arguments and check metrics and outputs."""
        env_vars_to_replace = {
            "$MAX_LEN": 64,
            "$BS": 64,
            "$GAS": 1,
            "$ENRO_DIR": self.data_dir,
            "facebook/mbart-large-cc25": MARIAN_MODEL,
            # "val_check_interval=0.25": "val_check_interval=1.0",
            "--learning_rate=3e-5": "--learning_rate 3e-4",
            "--num_train_epochs 6": "--num_train_epochs 1",
        }

        # Clean up bash script: keep only what follows "finetune.py", drop the
        # backslash-newline continuations and the '"$@"' passthrough.
        # read_text() closes the file, unlike the previous open().read() chain.
        script_text = (self.test_file_dir / "train_mbart_cc25_enro.sh").read_text()
        bash_script = script_text.split("finetune.py")[1].strip()
        bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
        for k, v in env_vars_to_replace.items():
            bash_script = bash_script.replace(k, str(v))
        output_dir = self.get_auto_remove_tmp_dir()

        # bash_script = bash_script.replace("--fp16 ", "")
        # Kept under its own name so it is not shadowed by the parsed namespace below.
        extra_args = f"""
            --output_dir {output_dir}
            --tokenizer_name Helsinki-NLP/opus-mt-en-ro
            --sortish_sampler
            --do_predict
            --gpus 1
            --freeze_encoder
            --n_train 40000
            --n_val 500
            --n_test 500
            --fp16_opt_level O1
            --num_sanity_val_steps 0
            --eval_beams 2
        """.split()
        # XXX: args.gpus > 1 : handle multi_gpu in the future

        testargs = ["finetune.py"] + bash_script.split() + extra_args
        with patch.object(sys, "argv", testargs):
            parser = argparse.ArgumentParser()
            parser = pl.Trainer.add_argparse_args(parser)
            parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
            args = parser.parse_args()
            model = main(args)

        # Check metrics
        metrics = load_json(model.metrics_save_path)
        first_step_stats = metrics["val"][0]
        last_step_stats = metrics["val"][-1]
        self.assertEqual(len(metrics["val"]), (args.max_epochs / args.val_check_interval))
        self.assertIsInstance(last_step_stats[f"val_avg_{model.val_metric}"], float)

        self.assertGreater(last_step_stats["val_avg_gen_time"], 0.01)
        # model hanging on generate. Maybe bad config was saved. (XXX: old comment/assert?)
        self.assertLessEqual(last_step_stats["val_avg_gen_time"], 1.0)

        # test learning requirements:

        # 1. BLEU improves over the course of training by more than 2 pts
        self.assertGreater(last_step_stats["val_avg_bleu"] - first_step_stats["val_avg_bleu"], 2)

        # 2. BLEU finishes above 17
        self.assertGreater(last_step_stats["val_avg_bleu"], 17)

        # 3. test BLEU and val BLEU within ~1.1 pt.
        self.assertLess(abs(metrics["val"][-1]["val_avg_bleu"] - metrics["test"][-1]["test_avg_bleu"]), 1.1)

        # check lightning ckpt can be loaded and has a reasonable statedict
        contents = os.listdir(output_dir)
        ckpt_files = [x for x in contents if x.endswith(".ckpt")]
        # Fail with a readable message instead of a bare IndexError.
        self.assertTrue(ckpt_files, f"no .ckpt file found in {output_dir}")
        full_path = os.path.join(args.output_dir, ckpt_files[0])
        ckpt = torch.load(full_path, map_location="cpu")
        expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
        self.assertIn(expected_key, ckpt["state_dict"])
        self.assertEqual(ckpt["state_dict"][expected_key].dtype, torch.float32)

        # TODO: turn on args.do_predict when PL bug fixed.
        if args.do_predict:
            contents = {os.path.basename(p) for p in contents}
            self.assertIn("test_generations.txt", contents)
            self.assertIn("test_results.txt", contents)
            # assert len(metrics["val"]) == desired_n_evals
            self.assertEqual(len(metrics["test"]), 1)
class TestDistilMarianNoTeacher(TestCasePlus):
    """Run the arguments of ``distil_marian_no_teacher.sh`` through ``distill_main``.

    The shell script is not executed. Its ``distillation.py`` argument list is
    extracted, its placeholders are substituted, and the result is fed to
    ``distill_main`` via a patched ``sys.argv``.
    """

    @timeout_decorator.timeout(600)
    @slow
    @require_torch_gpu
    def test_opus_mt_distill_script(self):
        """Train with the script's arguments (fp16 flags removed) and check metrics and outputs."""
        data_dir = f"{self.test_file_dir_str}/test_data/wmt_en_ro"
        env_vars_to_replace = {
            "--fp16_opt_level=O1": "",
            "$MAX_LEN": 128,
            "$BS": 16,
            "$GAS": 1,
            "$ENRO_DIR": data_dir,
            "$m": "sshleifer/student_marian_en_ro_6_1",
            "val_check_interval=0.25": "val_check_interval=1.0",
        }

        # Clean up bash script: keep only what follows "distillation.py", drop the
        # backslash-newline continuations and the '"$@"' passthrough.
        # read_text() closes the file, unlike the previous open().read() chain.
        script_text = (self.test_file_dir / "distil_marian_no_teacher.sh").read_text()
        bash_script = script_text.split("distillation.py")[1].strip()
        bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
        # Remove the standalone flag first; "--fp16_opt_level=O1" has no trailing
        # space after "--fp16", so it survives for the dict substitution below.
        bash_script = bash_script.replace("--fp16 ", " ")

        for k, v in env_vars_to_replace.items():
            bash_script = bash_script.replace(k, str(v))
        output_dir = self.get_auto_remove_tmp_dir()
        # Catch any remaining "--fp16" occurrence (e.g. one not followed by a space).
        bash_script = bash_script.replace("--fp16", "")
        epochs = 6
        testargs = (
            ["distillation.py"]
            + bash_script.split()
            + [
                f"--output_dir={output_dir}",
                "--gpus=1",
                "--learning_rate=1e-3",
                f"--num_train_epochs={epochs}",
                "--warmup_steps=10",
                "--val_check_interval=1.0",
                "--do_predict",
            ]
        )
        with patch.object(sys, "argv", testargs):
            parser = argparse.ArgumentParser()
            parser = pl.Trainer.add_argparse_args(parser)
            parser = SummarizationDistiller.add_model_specific_args(parser, os.getcwd())
            args = parser.parse_args()
            # assert args.gpus == gpus THIS BREAKS for multi_gpu

            model = distill_main(args)

        # Check metrics
        metrics = load_json(model.metrics_save_path)
        first_step_stats = metrics["val"][0]
        last_step_stats = metrics["val"][-1]
        assert len(metrics["val"]) >= (args.max_epochs / args.val_check_interval)  # +1 accounts for val_sanity_check

        assert last_step_stats["val_avg_gen_time"] >= 0.01

        assert first_step_stats["val_avg_bleu"] < last_step_stats["val_avg_bleu"]  # model learned nothing
        assert 1.0 >= last_step_stats["val_avg_gen_time"]  # model hanging on generate. Maybe bad config was saved.
        assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)

        # check lightning ckpt can be loaded and has a reasonable statedict
        contents = os.listdir(output_dir)
        ckpt_files = [x for x in contents if x.endswith(".ckpt")]
        # Fail with a readable message instead of a bare IndexError.
        assert ckpt_files, f"no .ckpt file found in {output_dir}"
        full_path = os.path.join(args.output_dir, ckpt_files[0])
        ckpt = torch.load(full_path, map_location="cpu")
        expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
        assert expected_key in ckpt["state_dict"]
        assert ckpt["state_dict"][expected_key].dtype == torch.float32

        # TODO: turn on args.do_predict when PL bug fixed.
        if args.do_predict:
            contents = {os.path.basename(p) for p in contents}
            assert "test_generations.txt" in contents
            assert "test_results.txt" in contents
            # assert len(metrics["val"]) == desired_n_evals
            assert len(metrics["test"]) == 1
Event Timeline
Log In to Comment