Page MenuHomec4science

tokenization_auto.py
No OneTemporary

File Metadata

Created
Tue, May 7, 04:25

tokenization_auto.py

# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Auto Tokenizer class. """
from collections import OrderedDict
from ...configuration_utils import PretrainedConfig
from ...file_utils import is_sentencepiece_available, is_tokenizers_available
from ...utils import logging
from ..bart.tokenization_bart import BartTokenizer
from ..bert.tokenization_bert import BertTokenizer
from ..bert_japanese.tokenization_bert_japanese import BertJapaneseTokenizer
from ..bertweet.tokenization_bertweet import BertweetTokenizer
from ..blenderbot.tokenization_blenderbot import BlenderbotTokenizer
from ..blenderbot_small.tokenization_blenderbot_small import BlenderbotSmallTokenizer
from ..convbert.tokenization_convbert import ConvBertTokenizer
from ..ctrl.tokenization_ctrl import CTRLTokenizer
from ..deberta.tokenization_deberta import DebertaTokenizer
from ..distilbert.tokenization_distilbert import DistilBertTokenizer
from ..dpr.tokenization_dpr import DPRQuestionEncoderTokenizer
from ..electra.tokenization_electra import ElectraTokenizer
from ..flaubert.tokenization_flaubert import FlaubertTokenizer
from ..fsmt.tokenization_fsmt import FSMTTokenizer
from ..funnel.tokenization_funnel import FunnelTokenizer
from ..gpt2.tokenization_gpt2 import GPT2Tokenizer
from ..herbert.tokenization_herbert import HerbertTokenizer
from ..layoutlm.tokenization_layoutlm import LayoutLMTokenizer
from ..led.tokenization_led import LEDTokenizer
from ..longformer.tokenization_longformer import LongformerTokenizer
from ..lxmert.tokenization_lxmert import LxmertTokenizer
from ..mobilebert.tokenization_mobilebert import MobileBertTokenizer
from ..mpnet.tokenization_mpnet import MPNetTokenizer
from ..openai.tokenization_openai import OpenAIGPTTokenizer
from ..phobert.tokenization_phobert import PhobertTokenizer
from ..prophetnet.tokenization_prophetnet import ProphetNetTokenizer
from ..rag.tokenization_rag import RagTokenizer
from ..retribert.tokenization_retribert import RetriBertTokenizer
from ..roberta.tokenization_roberta import RobertaTokenizer
from ..squeezebert.tokenization_squeezebert import SqueezeBertTokenizer
from ..tapas.tokenization_tapas import TapasTokenizer
from ..transfo_xl.tokenization_transfo_xl import TransfoXLTokenizer
from ..wav2vec2.tokenization_wav2vec2 import Wav2Vec2CTCTokenizer
from ..xlm.tokenization_xlm import XLMTokenizer
from .configuration_auto import (
AlbertConfig,
AutoConfig,
BartConfig,
BertConfig,
BertGenerationConfig,
BlenderbotConfig,
BlenderbotSmallConfig,
CamembertConfig,
ConvBertConfig,
CTRLConfig,
DebertaConfig,
DebertaV2Config,
DistilBertConfig,
DPRConfig,
ElectraConfig,
EncoderDecoderConfig,
FlaubertConfig,
FSMTConfig,
FunnelConfig,
GPT2Config,
IBertConfig,
LayoutLMConfig,
LEDConfig,
LongformerConfig,
LxmertConfig,
M2M100Config,
MarianConfig,
MBartConfig,
MobileBertConfig,
MPNetConfig,
MT5Config,
OpenAIGPTConfig,
PegasusConfig,
ProphetNetConfig,
RagConfig,
ReformerConfig,
RetriBertConfig,
RobertaConfig,
Speech2TextConfig,
SqueezeBertConfig,
T5Config,
TapasConfig,
TransfoXLConfig,
Wav2Vec2Config,
XLMConfig,
XLMProphetNetConfig,
XLMRobertaConfig,
XLNetConfig,
replace_list_option_in_docstrings,
)
if is_sentencepiece_available():
from ..albert.tokenization_albert import AlbertTokenizer
from ..barthez.tokenization_barthez import BarthezTokenizer
from ..bert_generation.tokenization_bert_generation import BertGenerationTokenizer
from ..camembert.tokenization_camembert import CamembertTokenizer
from ..deberta_v2.tokenization_deberta_v2 import DebertaV2Tokenizer
from ..m2m_100 import M2M100Tokenizer
from ..marian.tokenization_marian import MarianTokenizer
from ..mbart.tokenization_mbart import MBartTokenizer
from ..mbart.tokenization_mbart50 import MBart50Tokenizer
from ..mt5 import MT5Tokenizer
from ..pegasus.tokenization_pegasus import PegasusTokenizer
from ..reformer.tokenization_reformer import ReformerTokenizer
from ..speech_to_text import Speech2TextTokenizer
from ..t5.tokenization_t5 import T5Tokenizer
from ..xlm_prophetnet.tokenization_xlm_prophetnet import XLMProphetNetTokenizer
from ..xlm_roberta.tokenization_xlm_roberta import XLMRobertaTokenizer
from ..xlnet.tokenization_xlnet import XLNetTokenizer
else:
AlbertTokenizer = None
BarthezTokenizer = None
BertGenerationTokenizer = None
CamembertTokenizer = None
DebertaV2Tokenizer = None
MarianTokenizer = None
MBartTokenizer = None
MBart50Tokenizer = None
MT5Tokenizer = None
PegasusTokenizer = None
ReformerTokenizer = None
T5Tokenizer = None
XLMRobertaTokenizer = None
XLNetTokenizer = None
XLMProphetNetTokenizer = None
M2M100Tokenizer = None
Speech2TextTokenizer = None
if is_tokenizers_available():
from ..albert.tokenization_albert_fast import AlbertTokenizerFast
from ..bart.tokenization_bart_fast import BartTokenizerFast
from ..barthez.tokenization_barthez_fast import BarthezTokenizerFast
from ..bert.tokenization_bert_fast import BertTokenizerFast
from ..camembert.tokenization_camembert_fast import CamembertTokenizerFast
from ..convbert.tokenization_convbert_fast import ConvBertTokenizerFast
from ..distilbert.tokenization_distilbert_fast import DistilBertTokenizerFast
from ..dpr.tokenization_dpr_fast import DPRQuestionEncoderTokenizerFast
from ..electra.tokenization_electra_fast import ElectraTokenizerFast
from ..funnel.tokenization_funnel_fast import FunnelTokenizerFast
from ..gpt2.tokenization_gpt2_fast import GPT2TokenizerFast
from ..herbert.tokenization_herbert_fast import HerbertTokenizerFast
from ..layoutlm.tokenization_layoutlm_fast import LayoutLMTokenizerFast
from ..led.tokenization_led_fast import LEDTokenizerFast
from ..longformer.tokenization_longformer_fast import LongformerTokenizerFast
from ..lxmert.tokenization_lxmert_fast import LxmertTokenizerFast
from ..mbart.tokenization_mbart50_fast import MBart50TokenizerFast
from ..mbart.tokenization_mbart_fast import MBartTokenizerFast
from ..mobilebert.tokenization_mobilebert_fast import MobileBertTokenizerFast
from ..mpnet.tokenization_mpnet_fast import MPNetTokenizerFast
from ..mt5 import MT5TokenizerFast
from ..openai.tokenization_openai_fast import OpenAIGPTTokenizerFast
from ..pegasus.tokenization_pegasus_fast import PegasusTokenizerFast
from ..reformer.tokenization_reformer_fast import ReformerTokenizerFast
from ..retribert.tokenization_retribert_fast import RetriBertTokenizerFast
from ..roberta.tokenization_roberta_fast import RobertaTokenizerFast
from ..squeezebert.tokenization_squeezebert_fast import SqueezeBertTokenizerFast
from ..t5.tokenization_t5_fast import T5TokenizerFast
from ..xlm_roberta.tokenization_xlm_roberta_fast import XLMRobertaTokenizerFast
from ..xlnet.tokenization_xlnet_fast import XLNetTokenizerFast
else:
AlbertTokenizerFast = None
BartTokenizerFast = None
BarthezTokenizerFast = None
BertTokenizerFast = None
CamembertTokenizerFast = None
ConvBertTokenizerFast = None
DistilBertTokenizerFast = None
DPRQuestionEncoderTokenizerFast = None
ElectraTokenizerFast = None
FunnelTokenizerFast = None
GPT2TokenizerFast = None
HerbertTokenizerFast = None
LayoutLMTokenizerFast = None
LEDTokenizerFast = None
LongformerTokenizerFast = None
LxmertTokenizerFast = None
MBartTokenizerFast = None
MBart50TokenizerFast = None
MobileBertTokenizerFast = None
MPNetTokenizerFast = None
MT5TokenizerFast = None
OpenAIGPTTokenizerFast = None
PegasusTokenizerFast = None
ReformerTokenizerFast = None
RetriBertTokenizerFast = None
RobertaTokenizerFast = None
SqueezeBertTokenizerFast = None
T5TokenizerFast = None
XLMRobertaTokenizerFast = None
XLNetTokenizerFast = None
logger = logging.get_logger(__name__)
# Maps each configuration class to its ``(slow_tokenizer, fast_tokenizer)`` pair.
#
# * ``AutoTokenizer.from_pretrained`` looks entries up by the exact ``type(config)``.
# * Either element of a pair may be ``None``: explicitly when no such implementation is listed, or
#   because the import guards above bind the name to ``None`` when `sentencepiece` / `tokenizers`
#   is not available.
# * Each configuration class must appear only once: a repeated key in the ``OrderedDict`` constructor
#   silently overwrites the earlier value while keeping the earlier position.
TOKENIZER_MAPPING = OrderedDict(
    [
        (RetriBertConfig, (RetriBertTokenizer, RetriBertTokenizerFast)),
        (T5Config, (T5Tokenizer, T5TokenizerFast)),
        (MT5Config, (MT5Tokenizer, MT5TokenizerFast)),
        (MobileBertConfig, (MobileBertTokenizer, MobileBertTokenizerFast)),
        (DistilBertConfig, (DistilBertTokenizer, DistilBertTokenizerFast)),
        (AlbertConfig, (AlbertTokenizer, AlbertTokenizerFast)),
        (CamembertConfig, (CamembertTokenizer, CamembertTokenizerFast)),
        (PegasusConfig, (PegasusTokenizer, PegasusTokenizerFast)),
        (MBartConfig, (MBartTokenizer, MBartTokenizerFast)),
        (XLMRobertaConfig, (XLMRobertaTokenizer, XLMRobertaTokenizerFast)),
        (MarianConfig, (MarianTokenizer, None)),
        (BlenderbotSmallConfig, (BlenderbotSmallTokenizer, None)),
        (BlenderbotConfig, (BlenderbotTokenizer, None)),
        (LongformerConfig, (LongformerTokenizer, LongformerTokenizerFast)),
        (BartConfig, (BartTokenizer, BartTokenizerFast)),
        (RobertaConfig, (RobertaTokenizer, RobertaTokenizerFast)),
        (ReformerConfig, (ReformerTokenizer, ReformerTokenizerFast)),
        (ElectraConfig, (ElectraTokenizer, ElectraTokenizerFast)),
        (FunnelConfig, (FunnelTokenizer, FunnelTokenizerFast)),
        (LxmertConfig, (LxmertTokenizer, LxmertTokenizerFast)),
        (LayoutLMConfig, (LayoutLMTokenizer, LayoutLMTokenizerFast)),
        (DPRConfig, (DPRQuestionEncoderTokenizer, DPRQuestionEncoderTokenizerFast)),
        (SqueezeBertConfig, (SqueezeBertTokenizer, SqueezeBertTokenizerFast)),
        (BertConfig, (BertTokenizer, BertTokenizerFast)),
        (OpenAIGPTConfig, (OpenAIGPTTokenizer, OpenAIGPTTokenizerFast)),
        (GPT2Config, (GPT2Tokenizer, GPT2TokenizerFast)),
        (TransfoXLConfig, (TransfoXLTokenizer, None)),
        (XLNetConfig, (XLNetTokenizer, XLNetTokenizerFast)),
        (FlaubertConfig, (FlaubertTokenizer, None)),
        (XLMConfig, (XLMTokenizer, None)),
        (CTRLConfig, (CTRLTokenizer, None)),
        (FSMTConfig, (FSMTTokenizer, None)),
        (BertGenerationConfig, (BertGenerationTokenizer, None)),
        (DebertaConfig, (DebertaTokenizer, None)),
        (DebertaV2Config, (DebertaV2Tokenizer, None)),
        (RagConfig, (RagTokenizer, None)),
        (XLMProphetNetConfig, (XLMProphetNetTokenizer, None)),
        (Speech2TextConfig, (Speech2TextTokenizer, None)),
        (M2M100Config, (M2M100Tokenizer, None)),
        (ProphetNetConfig, (ProphetNetTokenizer, None)),
        (MPNetConfig, (MPNetTokenizer, MPNetTokenizerFast)),
        (TapasConfig, (TapasTokenizer, None)),
        (LEDConfig, (LEDTokenizer, LEDTokenizerFast)),
        (ConvBertConfig, (ConvBertTokenizer, ConvBertTokenizerFast)),
        # I-BERT has no tokenizer of its own listed here; it reuses the RoBERTa tokenizers.
        (IBertConfig, (RobertaTokenizer, RobertaTokenizerFast)),
        (Wav2Vec2Config, (Wav2Vec2CTCTokenizer, None)),
    ]
)
# Tokenizers which are not directly mapped from a config class in TOKENIZER_MAPPING.
# Within this file the list is consulted only by `tokenizer_class_from_name`, i.e. these classes are
# found by their class name.
# NOTE: the entries imported under the `is_sentencepiece_available()` / `is_tokenizers_available()`
# guards above (BarthezTokenizer, MBart50Tokenizer, HerbertTokenizerFast, BarthezTokenizerFast,
# MBart50TokenizerFast) are bound to None when that backend is unavailable, so this list may
# contain None placeholders.
NO_CONFIG_TOKENIZER = [
    BertJapaneseTokenizer,
    BertweetTokenizer,
    HerbertTokenizer,
    HerbertTokenizerFast,
    PhobertTokenizer,
    BarthezTokenizer,
    BarthezTokenizerFast,
    MBart50Tokenizer,
    MBart50TokenizerFast,
]
# One representative tokenizer class per configuration class: the slow implementation when it is
# set, otherwise the fast one. Configurations whose slow and fast entries are both None are left out.
SLOW_TOKENIZER_MAPPING = {
    config_class: (fast_class if slow_class is None else slow_class)
    for config_class, (slow_class, fast_class) in TOKENIZER_MAPPING.items()
    if not (slow_class is None and fast_class is None)
}
def tokenizer_class_from_name(class_name: str):
    """
    Look up a tokenizer class by its ``__name__``.

    The search covers, in this order, every non-``None`` slow tokenizer in ``TOKENIZER_MAPPING``, every
    non-``None`` fast tokenizer in ``TOKENIZER_MAPPING``, and every non-``None`` entry of
    ``NO_CONFIG_TOKENIZER``.

    Args:
        class_name (:obj:`str`): Exact class name to search for, e.g. ``"BertTokenizerFast"``.

    Returns:
        The first class whose ``__name__`` equals :obj:`class_name`, or :obj:`None` if there is none.
    """
    candidates = (
        [slow for slow, _ in TOKENIZER_MAPPING.values() if slow is not None]
        + [fast for _, fast in TOKENIZER_MAPPING.values() if fast is not None]
        # NO_CONFIG_TOKENIZER can hold None placeholders for tokenizers whose optional backend
        # (`sentencepiece` / `tokenizers`) is missing; skip them so `__name__` is never read on None.
        + [tokenizer for tokenizer in NO_CONFIG_TOKENIZER if tokenizer is not None]
    )
    for candidate in candidates:
        if candidate.__name__ == class_name:
            return candidate
    # Callers test the result against None to detect an unknown / unavailable class.
    return None
class AutoTokenizer:
    r"""
    This is a generic tokenizer class that will be instantiated as one of the tokenizer classes of the library when
    created with the :meth:`AutoTokenizer.from_pretrained` class method.

    This class cannot be instantiated directly using ``__init__()`` (throws an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoTokenizer is designed to be instantiated "
            "using the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)` method."
        )

    @classmethod
    @replace_list_option_in_docstrings(SLOW_TOKENIZER_MAPPING)
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        r"""
        Instantiate one of the tokenizer classes of the library from a pretrained model vocabulary.

        The tokenizer class to instantiate is selected based on the :obj:`model_type` property of the config object
        (either passed as an argument or loaded from :obj:`pretrained_model_name_or_path` if possible), or when it's
        missing, by falling back to using pattern matching on :obj:`pretrained_model_name_or_path`:

        List options

        Params:
            pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
                Can be either:

                    - A string, the `model id` of a predefined tokenizer hosted inside a model repo on huggingface.co.
                      Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
                      a user or organization name, like ``dbmdz/bert-base-german-cased``.
                    - A path to a `directory` containing vocabulary files required by the tokenizer, for instance saved
                      using the :func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g.,
                      ``./my_model_directory/``.
                    - A path or url to a single saved vocabulary file if and only if the tokenizer only requires a
                      single vocabulary file (like Bert or XLNet), e.g.: ``./my_model_directory/vocab.txt``. (Not
                      applicable to all derived classes)
            inputs (additional positional arguments, `optional`):
                Will be passed along to the Tokenizer ``__init__()`` method.
            config (:class:`~transformers.PretrainedConfig`, `optional`)
                The configuration object used to determine the tokenizer class to instantiate.
            cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to force the (re-)download the model weights and configuration files and override the
                cached versions if they exist.
            resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (:obj:`Dict[str, str]`, `optional`):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
                identifier allowed by git.
            subfolder (:obj:`str`, `optional`):
                In case the relevant files are located inside a subfolder of the model repo on huggingface.co (e.g. for
                facebook/rag-token-base), specify it here.
            use_fast (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to try to load the fast version of the tokenizer.
            kwargs (additional keyword arguments, `optional`):
                Will be passed to the Tokenizer ``__init__()`` method. Can be used to set special tokens like
                ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``,
                ``mask_token``, ``additional_special_tokens``. See parameters in the ``__init__()`` for more details.

        Returns:
            The object returned by the selected tokenizer class's own ``from_pretrained`` method.

        Raises:
            ValueError: If ``config.tokenizer_class`` names a class that cannot be found, if the configuration class
                is mapped but neither a slow nor a fast tokenizer class is available for it, or if the configuration
                class is not in the tokenizer mapping at all.

        Examples::

            >>> from transformers import AutoTokenizer

            >>> # Download vocabulary from huggingface.co and cache.
            >>> tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

            >>> # Download vocabulary from huggingface.co (user-uploaded) and cache.
            >>> tokenizer = AutoTokenizer.from_pretrained('dbmdz/bert-base-german-cased')

            >>> # If vocabulary files are in a directory (e.g. tokenizer was saved using `save_pretrained('./test/saved_model/')`)
            >>> tokenizer = AutoTokenizer.from_pretrained('./test/bert_saved_model/')

        """
        # Use the caller's config when it is a real PretrainedConfig, otherwise load one.
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            # NOTE(review): `use_fast` (if given) is still in `kwargs` here and is therefore forwarded to
            # `AutoConfig.from_pretrained`; the original statement order is kept on purpose - confirm it is intended.
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        use_fast = kwargs.pop("use_fast", True)

        # 1) An explicit `tokenizer_class` on the config takes precedence over the config-type mapping.
        if config.tokenizer_class is not None:
            tokenizer_class = None
            # Try the "<Name>Fast" variant first when a fast tokenizer is wanted and a slow name was given.
            if use_fast and not config.tokenizer_class.endswith("Fast"):
                tokenizer_class_candidate = f"{config.tokenizer_class}Fast"
                tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
            # Fall back to the class name exactly as written in the config.
            if tokenizer_class is None:
                tokenizer_class_candidate = config.tokenizer_class
                tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)

            if tokenizer_class is None:
                raise ValueError(
                    "Tokenizer class {} does not exist or is not currently imported.".format(tokenizer_class_candidate)
                )
            return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)

        # 2) if model is an encoder decoder, the encoder tokenizer class is used by default
        if isinstance(config, EncoderDecoderConfig):
            if type(config.decoder) is not type(config.encoder):  # noqa: E721
                # `__class__` must be spelled in full: inside a class body a name such as `__class`
                # (two leading underscores, no trailing ones) is name-mangled to `_AutoTokenizer__class`
                # and would raise AttributeError instead of emitting this warning.
                logger.warning(
                    f"The encoder model config class: {config.encoder.__class__} is different from the decoder model "
                    f"config class: {config.decoder.__class__}. It is not recommended to use the "
                    "`AutoTokenizer.from_pretrained()` method in this case. Please use the encoder and decoder "
                    "specific tokenizer classes."
                )
            config = config.encoder

        # 3) Select the tokenizer from the exact configuration type.
        if type(config) in TOKENIZER_MAPPING:
            tokenizer_class_py, tokenizer_class_fast = TOKENIZER_MAPPING[type(config)]
            # Use the fast class when it exists and either it was requested or there is no slow class.
            if tokenizer_class_fast and (use_fast or tokenizer_class_py is None):
                return tokenizer_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
            else:
                if tokenizer_class_py is not None:
                    return tokenizer_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
                else:
                    raise ValueError(
                        "This tokenizer cannot be instantiated. Please make sure you have `sentencepiece` installed "
                        "in order to use this tokenizer."
                    )

        raise ValueError(
            "Unrecognized configuration class {} to build an AutoTokenizer.\n"
            "Model type should be one of {}.".format(
                config.__class__, ", ".join(c.__name__ for c in TOKENIZER_MAPPING.keys())
            )
        )

Event Timeline