Page Menu
Home
c4science
Search
Configure Global Search
Log In
Files
F62075259
tokenization_mbart.py
No One
Temporary
Actions
Download File
Edit File
Delete File
View Transforms
Subscribe
Mute Notifications
Award Token
Subscribers
None
File Metadata
Details
File Info
Storage
Attached
Created
Fri, May 10, 18:38
Size
8 KB
Mime Type
text/x-python
Expires
Sun, May 12, 18:38 (1 d, 23 h)
Engine
blob
Format
Raw Data
Handle
17599599
Attached To
R11484 ADDI
tokenization_mbart.py
View Options
# coding=utf-8
# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
contextlib
import
contextmanager
from
typing
import
List
,
Optional
from
...tokenization_utils
import
BatchEncoding
from
...utils
import
logging
from
..xlm_roberta.tokenization_xlm_roberta
import
XLMRobertaTokenizer
# Module-level logger, obtained through the package's own ``logging`` utility
# (imported above from ``...utils``), keyed by this module's name.
logger = logging.get_logger(__name__)
# Name of the SentencePiece model file expected for each vocabulary slot.
VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"}

# Download location of the vocabulary file for each pretrained checkpoint.
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "facebook/mbart-large-en-ro": "https://huggingface.co/facebook/mbart-large-en-ro/resolve/main/sentencepiece.bpe.model",
        "facebook/mbart-large-cc25": "https://huggingface.co/facebook/mbart-large-cc25/resolve/main/sentencepiece.bpe.model",
    }
}

# Size associated with each pretrained checkpoint; exposed on the tokenizer
# class as ``max_model_input_sizes``.
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "facebook/mbart-large-en-ro": 1024,
    "facebook/mbart-large-cc25": 1024,
}

# Language codes known to the tokenizer. The order matters: the tokenizer
# assigns each code an id based on its position in this list.
FAIRSEQ_LANGUAGE_CODES = [
    "ar_AR",
    "cs_CZ",
    "de_DE",
    "en_XX",
    "es_XX",
    "et_EE",
    "fi_FI",
    "fr_XX",
    "gu_IN",
    "hi_IN",
    "it_IT",
    "ja_XX",
    "kk_KZ",
    "ko_KR",
    "lt_LT",
    "lv_LV",
    "my_MM",
    "ne_NP",
    "nl_XX",
    "ro_RO",
    "ru_RU",
    "si_LK",
    "tr_TR",
    "vi_VN",
    "zh_CN",
]
class MBartTokenizer(XLMRobertaTokenizer):
    """
    Construct an MBART tokenizer.

    :class:`~transformers.MBartTokenizer` is a subclass of :class:`~transformers.XLMRobertaTokenizer`. Refer to
    superclass :class:`~transformers.XLMRobertaTokenizer` for usage examples and documentation concerning the
    initialization parameters and other methods.

    The tokenization method is ``<tokens> <eos> <language code>`` for both source and target language documents;
    only the language code differs (see :meth:`set_src_lang_special_tokens` and
    :meth:`set_tgt_lang_special_tokens`).

    Examples::

        >>> from transformers import MBartTokenizer
        >>> tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-en-ro', src_lang="en_XX", tgt_lang="ro_RO")
        >>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
        >>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
        >>> inputs = tokenizer(example_english_phrase, return_tensors="pt")
        >>> with tokenizer.as_target_tokenizer():
        ...     labels = tokenizer(expected_translation_romanian, return_tensors="pt")
        >>> inputs["labels"] = labels["input_ids"]
    """

    vocab_files_names = VOCAB_FILES_NAMES
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP

    # Class-level defaults only: both lists are rebound (never mutated in place) on each instance by
    # ``set_src_lang_special_tokens`` / ``set_tgt_lang_special_tokens``.
    prefix_tokens: List[int] = []
    suffix_tokens: List[int] = []

    def __init__(self, *args, tokenizer_file=None, src_lang=None, tgt_lang=None, **kwargs):
        """
        Initialize the parent tokenizer, then register the language codes as extra vocabulary entries.

        Args:
            *args: Positional arguments forwarded unchanged to the parent constructor.
            tokenizer_file: Forwarded unchanged to the parent constructor.
            src_lang (:obj:`str`, `optional`): Source language code; ``"en_XX"`` is used when :obj:`None`.
            tgt_lang (:obj:`str`, `optional`): Target language code, used by :meth:`as_target_tokenizer`.
            **kwargs: Keyword arguments forwarded unchanged to the parent constructor.

        Raises:
            KeyError: If the effective source language is not one of ``FAIRSEQ_LANGUAGE_CODES``.
        """
        super().__init__(*args, tokenizer_file=tokenizer_file, src_lang=src_lang, tgt_lang=tgt_lang, **kwargs)

        # NOTE: ``sp_model``, ``fairseq_offset`` and ``fairseq_tokens_to_ids`` are not defined in this class;
        # they are expected to be provided by the parent tokenizer.
        self.sp_model_size = len(self.sp_model)

        # Language codes are placed right after the SentencePiece vocabulary (shifted by the fairseq offset),
        # in the order in which they appear in ``FAIRSEQ_LANGUAGE_CODES``.
        self.lang_code_to_id = {
            code: self.sp_model_size + i + self.fairseq_offset for i, code in enumerate(FAIRSEQ_LANGUAGE_CODES)
        }
        self.id_to_lang_code = {v: k for k, v in self.lang_code_to_id.items()}

        # The mask token takes the first id after the language codes.
        self.fairseq_tokens_to_ids["<mask>"] = self.sp_model_size + len(self.lang_code_to_id) + self.fairseq_offset

        self.fairseq_tokens_to_ids.update(self.lang_code_to_id)
        self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
        self._additional_special_tokens = list(self.lang_code_to_id.keys())

        self._src_lang = src_lang if src_lang is not None else "en_XX"
        self.cur_lang_code_id = self.lang_code_to_id[self._src_lang]
        self.tgt_lang = tgt_lang
        self.set_src_lang_special_tokens(self._src_lang)

    @property
    def vocab_size(self):
        """Total vocabulary size: SentencePiece pieces, language codes, fairseq offset and the mask token."""
        return len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset + 1  # Plus 1 for the mask token

    @property
    def src_lang(self) -> str:
        """The current source language code."""
        return self._src_lang

    @src_lang.setter
    def src_lang(self, new_src_lang: str) -> None:
        # Changing the source language also refreshes the special tokens appended to every sequence.
        self._src_lang = new_src_lang
        self.set_src_lang_special_tokens(self._src_lang)

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer ``prepare_for_model`` method.

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs.
            token_ids_1 (:obj:`List[int]`, `optional`):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.

        Raises:
            ValueError: If ``already_has_special_tokens`` is set and ``token_ids_1`` is also supplied.
        """
        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formatted with special tokens for the model."
                )
            # The language code appended by this tokenizer is a special token too; flag it so that this branch
            # agrees with the branch below, which marks the whole ``[eos, lang_code]`` suffix as special.
            special_ids = {self.sep_token_id, self.cls_token_id}
            special_ids.update(self.lang_code_to_id.values())
            return [1 if token_id in special_ids else 0 for token_id in token_ids_0]

        prefix_ones = [1] * len(self.prefix_tokens)
        suffix_ones = [1] * len(self.suffix_tokens)
        if token_ids_1 is None:
            return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
        return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. An MBART sequence has the following format, where ``X`` represents the sequence:

        - ``input_ids`` (for encoder) ``X [eos, src_lang_code]``
        - ``decoder_input_ids``: (for decoder) ``X [eos, tgt_lang_code]``

        BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
        separator.

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (:obj:`List[int]`, `optional`):
                Optional second list of IDs for sequence pairs.

        Returns:
            :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return self.prefix_tokens + token_ids_0 + self.suffix_tokens
        # We don't expect to process pairs, but leave the pair logic for API consistency
        return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens

    def prepare_seq2seq_batch(
        self,
        src_texts: List[str],
        src_lang: str = "en_XX",
        tgt_texts: Optional[List[str]] = None,
        tgt_lang: str = "ro_RO",
        **kwargs,
    ) -> BatchEncoding:
        """
        Record the source/target languages on the tokenizer, then delegate batching to the parent class.

        Note that the languages stay set on the tokenizer after the call returns.

        Args:
            src_texts (:obj:`List[str]`): Source texts.
            src_lang (:obj:`str`, `optional`, defaults to ``"en_XX"``): Source language code.
            tgt_texts (:obj:`List[str]`, `optional`): Target texts.
            tgt_lang (:obj:`str`, `optional`, defaults to ``"ro_RO"``): Target language code.
            **kwargs: Forwarded unchanged to the parent ``prepare_seq2seq_batch``.

        Returns:
            :class:`BatchEncoding`: Whatever the parent ``prepare_seq2seq_batch`` returns.

        Raises:
            KeyError: If ``src_lang`` is not a known language code.
        """
        # Assigning ``src_lang`` goes through the property setter, which refreshes the special tokens.
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)

    @contextmanager
    def as_target_tokenizer(self):
        """
        Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
        sequence-to-sequence models that need a slightly different processing for the labels.

        The source-language special tokens are restored when the ``with`` block exits, including when it exits
        through an exception.

        Raises:
            KeyError: If ``self.tgt_lang`` is not a known language code (for instance when it is still :obj:`None`).
        """
        self.set_tgt_lang_special_tokens(self.tgt_lang)
        try:
            yield
        finally:
            # Always switch back, otherwise an error inside the block would leave the tokenizer in target mode.
            self.set_src_lang_special_tokens(self.src_lang)

    def set_src_lang_special_tokens(self, src_lang: str) -> None:
        """Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
        self.cur_lang_code = self.lang_code_to_id[src_lang]
        self.prefix_tokens = []
        self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]

    def set_tgt_lang_special_tokens(self, lang: str) -> None:
        """Reset the special tokens to the target language setting. No prefix and suffix=[eos, tgt_lang_code]."""
        self.cur_lang_code = self.lang_code_to_id[lang]
        self.prefix_tokens = []
        self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
Event Timeline
Log In to Comment