# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import shutil
import tempfile
from unittest import TestCase

from transformers import BartTokenizer, BartTokenizerFast, DPRQuestionEncoderTokenizer, DPRQuestionEncoderTokenizerFast
from transformers.file_utils import is_datasets_available, is_faiss_available, is_torch_available
from transformers.models.bart.configuration_bart import BartConfig
from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
from transformers.models.dpr.configuration_dpr import DPRConfig
from transformers.models.roberta.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
from transformers.testing_utils import require_datasets, require_faiss, require_tokenizers, require_torch, slow
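
# RagConfig and RagTokenizer sit behind optional dependencies: they are only
# importable when torch, datasets and faiss are all installed, so the import
# is guarded to keep test collection from failing in lighter environments.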
if is_torch_available() and is_datasets_available() and is_faiss_available():
    from transformers.models.rag.configuration_rag import RagConfig
    from transformers.models.rag.tokenization_rag import RagTokenizer
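

# The tests below build tiny DPR (WordPiece) and BART (byte-level BPE) vocab
# files on disk, wrap them in a RagTokenizer, and check that save_pretrained /
# from_pretrained round-trips both sub-tokenizers.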
@require_faiss
@require_datasets
@require_torch
class RagTokenizerTest(TestCase):
    def setUp(self):
        self.tmpdirname = tempfile.mkdtemp()
        self.retrieval_vector_size = 8

        # DPR tokenizer: write a minimal WordPiece vocab, one token per line
        vocab_tokens = [
            "[UNK]",
            "[CLS]",
            "[SEP]",
            "[PAD]",
            "[MASK]",
            "want",
            "##want",
            "##ed",
            "wa",
            "un",
            "runn",
            "##ing",
            ",",
            "low",
            "lowest",
        ]
        dpr_tokenizer_path = os.path.join(self.tmpdirname, "dpr_tokenizer")
        os.makedirs(dpr_tokenizer_path, exist_ok=True)
        self.vocab_file = os.path.join(dpr_tokenizer_path, DPR_VOCAB_FILES_NAMES["vocab_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
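
        # DPR reuses BERT's tokenizer file layout and BART reuses RoBERTa's,
        # which is why VOCAB_FILES_NAMES comes from the bert/roberta modules.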
        # BART tokenizer: minimal byte-level BPE vocab (JSON) plus merges file
        vocab = [
            "l",
            "o",
            "w",
            "e",
            "r",
            "s",
            "t",
            "i",
            "d",
            "n",
            "\u0120",
            "\u0120l",
            "\u0120n",
            "\u0120lo",
            "\u0120low",
            "er",
            "\u0120lowest",
            "\u0120newer",
            "\u0120wider",
            "<unk>",
        ]
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
        self.special_tokens_map = {"unk_token": "<unk>"}

        bart_tokenizer_path = os.path.join(self.tmpdirname, "bart_tokenizer")
        os.makedirs(bart_tokenizer_path, exist_ok=True)
        self.vocab_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["vocab_file"])
        self.merges_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["merges_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as fp:
            fp.write(json.dumps(vocab_tokens) + "\n")
        with open(self.merges_file, "w", encoding="utf-8") as fp:
            fp.write("\n".join(merges))
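
    # Helpers that reload the toy tokenizers written to disk in setUp.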
    def get_dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer:
        return DPRQuestionEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer"))

    def get_bart_tokenizer(self) -> BartTokenizer:
        return BartTokenizer.from_pretrained(os.path.join(self.tmpdirname, "bart_tokenizer"))

    def tearDown(self):
        shutil.rmtree(self.tmpdirname)
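
    # Saving the config and the tokenizer into the same directory should let
    # RagTokenizer.from_pretrained rebuild both sub-tokenizers; the reloaded
    # instances are expected to come back as the "fast" Rust-backed variants.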
    @require_tokenizers
    def test_save_load_pretrained_with_saved_config(self):
        save_dir = os.path.join(self.tmpdirname, "rag_tokenizer")
        rag_config = RagConfig(question_encoder=DPRConfig().to_dict(), generator=BartConfig().to_dict())
        rag_tokenizer = RagTokenizer(question_encoder=self.get_dpr_tokenizer(), generator=self.get_bart_tokenizer())
        rag_config.save_pretrained(save_dir)
        rag_tokenizer.save_pretrained(save_dir)
        new_rag_tokenizer = RagTokenizer.from_pretrained(save_dir, config=rag_config)
        self.assertIsInstance(new_rag_tokenizer.question_encoder, DPRQuestionEncoderTokenizerFast)
        self.assertEqual(new_rag_tokenizer.question_encoder.get_vocab(), rag_tokenizer.question_encoder.get_vocab())
        self.assertIsInstance(new_rag_tokenizer.generator, BartTokenizerFast)
        self.assertEqual(new_rag_tokenizer.generator.get_vocab(), rag_tokenizer.generator.get_vocab())
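
    # Smoke tests against the published checkpoints: these download the real
    # facebook/rag-*-nq tokenizers, so they are marked @slow and only assert
    # that batch tokenization of the NQ-style questions succeeds.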
    @slow
    def test_pretrained_token_nq_tokenizer(self):
        tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
        input_strings = [
            "who got the first nobel prize in physics",
            "when is the next deadpool movie being released",
            "which mode is used for short wave broadcast service",
            "who is the owner of reading football club",
            "when is the next scandal episode coming out",
            "when is the last time the philadelphia won the superbowl",
            "what is the most current adobe flash player version",
            "how many episodes are there in dragon ball z",
            "what is the first step in the evolution of the eye",
            "where is gall bladder situated in human body",
            "what is the main mineral in lithium batteries",
            "who is the president of usa right now",
            "where do the greasers live in the outsiders",
            "panda is a national animal of which country",
            "what is the name of manchester united stadium",
        ]
        input_dict = tokenizer(input_strings)
        self.assertIsNotNone(input_dict)

    @slow
    def test_pretrained_sequence_nq_tokenizer(self):
        tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
        input_strings = [
            "who got the first nobel prize in physics",
            "when is the next deadpool movie being released",
            "which mode is used for short wave broadcast service",
            "who is the owner of reading football club",
            "when is the next scandal episode coming out",
            "when is the last time the philadelphia won the superbowl",
            "what is the most current adobe flash player version",
            "how many episodes are there in dragon ball z",
            "what is the first step in the evolution of the eye",
            "where is gall bladder situated in human body",
            "what is the main mineral in lithium batteries",
            "who is the president of usa right now",
            "where do the greasers live in the outsiders",
            "panda is a national animal of which country",
            "what is the name of manchester united stadium",
        ]
        input_dict = tokenizer(input_strings)
        self.assertIsNotNone(input_dict)
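
# A typical invocation, assuming this file sits under tests/ in a transformers
# checkout (the path is illustrative): the default run skips the @slow
# downloads, and setting RUN_SLOW=1 opts into them.
#
#   python -m pytest tests/test_tokenization_rag.py
#   RUN_SLOW=1 python -m pytest tests/test_tokenization_rag.py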