Page Menu
Home
c4science
Search
Configure Global Search
Log In
Files
F120613556
test_configuration_common.py
No One
Temporary
Actions
Download File
Edit File
Delete File
View Transforms
Subscribe
Mute Notifications
Award Token
Subscribers
None
File Metadata
Details
File Info
Storage
Attached
Created
Sat, Jul 5, 14:53
Size
3 KB
Mime Type
text/x-python
Expires
Mon, Jul 7, 14:53 (2 d)
Engine
blob
Format
Raw Data
Handle
27211181
Attached To
R11484 ADDI
test_configuration_common.py
View Options
# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
json
import
os
import
tempfile
class ConfigTester:
    """Reusable checks that a configuration class behaves like a standard model config.

    Each check builds an instance of ``config_class`` from the keyword arguments
    given to ``__init__`` and reports its outcome through ``parent``'s
    ``assertTrue`` / ``assertEqual`` / ``assertIsNotNone`` methods (presumably a
    ``unittest.TestCase`` -- confirm against callers).
    """

    def __init__(self, parent, config_class=None, **kwargs):
        """Store the assertion host, the config class under test and its kwargs.

        Args:
            parent: Object exposing ``assertTrue``, ``assertEqual`` and
                ``assertIsNotNone``; every check reports through it.
            config_class: The configuration class to instantiate. It must be
                set (not ``None``) before any check is run.
            **kwargs: Keyword arguments forwarded to ``config_class`` by every
                check. They are also the key/value pairs expected to appear in
                the config's JSON string.
        """
        self.parent = parent
        self.config_class = config_class
        self.inputs_dict = kwargs

    def _create_config(self, **overrides):
        """Instantiate ``config_class`` from ``inputs_dict`` with ``overrides`` applied on top."""
        # Merge into a fresh dict so an override that repeats a key of
        # inputs_dict replaces it instead of raising "multiple values for
        # keyword argument". self.inputs_dict itself is never mutated.
        return self.config_class(**dict(self.inputs_dict, **overrides))

    def create_and_test_config_common_properties(self):
        """Check that the config exposes the core size attributes."""
        config = self._create_config()
        self.parent.assertTrue(hasattr(config, "vocab_size"))
        self.parent.assertTrue(hasattr(config, "hidden_size"))
        self.parent.assertTrue(hasattr(config, "num_attention_heads"))
        self.parent.assertTrue(hasattr(config, "num_hidden_layers"))

    def create_and_test_config_to_json_string(self):
        """Check that ``to_json_string`` carries every constructor kwarg unchanged."""
        config = self._create_config()
        obj = json.loads(config.to_json_string())
        for key, value in self.inputs_dict.items():
            # A missing key surfaces as a KeyError rather than an assertion failure.
            self.parent.assertEqual(obj[key], value)

    def create_and_test_config_to_json_file(self):
        """Check that a ``to_json_file`` / ``from_json_file`` round trip preserves ``to_dict()``."""
        config_first = self._create_config()

        with tempfile.TemporaryDirectory() as tmpdirname:
            json_file_path = os.path.join(tmpdirname, "config.json")
            config_first.to_json_file(json_file_path)
            config_second = self.config_class.from_json_file(json_file_path)

        self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())

    def create_and_test_config_from_and_save_pretrained(self):
        """Check that a ``save_pretrained`` / ``from_pretrained`` round trip preserves ``to_dict()``."""
        config_first = self._create_config()

        with tempfile.TemporaryDirectory() as tmpdirname:
            config_first.save_pretrained(tmpdirname)
            config_second = self.config_class.from_pretrained(tmpdirname)

        self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())

    def create_and_test_config_with_num_labels(self):
        """Check that ``num_labels`` sizes ``id2label``/``label2id``, both at init and on reassignment.

        ``num_labels=5`` is merged into the constructor kwargs, so it also works
        when ``inputs_dict`` already contains a ``num_labels`` entry.
        """
        config = self._create_config(num_labels=5)
        self.parent.assertEqual(len(config.id2label), 5)
        self.parent.assertEqual(len(config.label2id), 5)

        # Reassigning num_labels is expected to resize both label maps.
        config.num_labels = 3
        self.parent.assertEqual(len(config.id2label), 3)
        self.parent.assertEqual(len(config.label2id), 3)

    def check_config_can_be_init_without_params(self):
        """Check that ``config_class()`` can be built with no arguments.

        The check is skipped when ``config_class.is_composition`` is truthy.
        """
        if self.config_class.is_composition:
            return
        config = self.config_class()
        self.parent.assertIsNotNone(config)

    def run_common_tests(self):
        """Run every check above, in a fixed order."""
        self.create_and_test_config_common_properties()
        self.create_and_test_config_to_json_string()
        self.create_and_test_config_to_json_file()
        self.create_and_test_config_from_and_save_pretrained()
        self.create_and_test_config_with_num_labels()
        self.check_config_can_be_init_without_params()
Event Timeline
Log In to Comment