Page Menu
Home
c4science
Search
Configure Global Search
Log In
Files
F61783338
modeling_flax_auto.py
No One
Temporary
Actions
Download File
Edit File
Delete File
View Transforms
Subscribe
Mute Notifications
Award Token
Subscribers
None
File Metadata
Details
File Info
Storage
Attached
Created
Wed, May 8, 22:47
Size
8 KB
Mime Type
text/x-python
Expires
Fri, May 10, 22:47 (2 d)
Engine
blob
Format
Raw Data
Handle
17558156
Attached To
R11484 ADDI
modeling_flax_auto.py
View Options
# coding=utf-8
# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Auto Model class. """
from
collections
import
OrderedDict
from
...configuration_utils
import
PretrainedConfig
from
...utils
import
logging
from
..bert.modeling_flax_bert
import
FlaxBertModel
from
..roberta.modeling_flax_roberta
import
FlaxRobertaModel
from
.configuration_auto
import
AutoConfig
,
BertConfig
,
RobertaConfig
# Module-level logger, named after this module.
logger = logging.get_logger(__name__)

# Ordered table pairing each configuration class with the Flax model class
# that should be built for it. Consumers walk the entries in insertion order
# and stop at the first ``isinstance`` match, so entry order is significant.
# NOTE(review): RoBERTa is listed ahead of BERT, presumably because
# RobertaConfig derives from BertConfig -- confirm before reordering.
FLAX_MODEL_MAPPING = OrderedDict(
    (
        (RobertaConfig, FlaxRobertaModel),
        (BertConfig, FlaxBertModel),
    )
)
class FlaxAutoModel:
    r"""
    :class:`~transformers.FlaxAutoModel` is a generic model class that will be instantiated as one of the base model
    classes of the library when created with the `FlaxAutoModel.from_pretrained(pretrained_model_name_or_path)` or the
    `FlaxAutoModel.from_config(config)` class methods.

    This class cannot be instantiated using `__init__()` (throws an error).
    """

    def __init__(self):
        # Direct construction is never valid: the concrete class is only known
        # once a configuration has been inspected by one of the factories.
        raise EnvironmentError(
            "FlaxAutoModel is designed to be instantiated "
            "using the `FlaxAutoModel.from_pretrained(pretrained_model_name_or_path)` or "
            "`FlaxAutoModel.from_config(config)` methods."
        )

    @classmethod
    def _get_model_class(cls, config):
        r"""
        Return the model class registered in ``FLAX_MODEL_MAPPING`` for ``config``.

        The mapping is scanned in insertion order and the first entry whose configuration class ``config`` is an
        instance of wins, so more specific configuration classes must be registered before their parents.

        Args:
            config: The configuration object to dispatch on.

        Returns:
            The model class paired with the first matching configuration class.

        Raises:
            ValueError: If ``config`` is not an instance of any registered configuration class.
        """
        for config_class, model_class in FLAX_MODEL_MAPPING.items():
            if isinstance(config, config_class):
                return model_class

        supported = ", ".join(c.__name__ for c in FLAX_MODEL_MAPPING)
        raise ValueError(
            f"Unrecognized configuration class {config.__class__} "
            f"for this kind of FlaxAutoModel: {cls.__name__}.\n"
            f"Model type should be one of {supported}."
        )

    @classmethod
    def from_config(cls, config):
        r"""
        Instantiates one of the base model classes of the library from a configuration.

        Args:
            config (:class:`~transformers.PretrainedConfig`):
                The model class to instantiate is selected based on the configuration class:

                - isInstance of `roberta` configuration class: :class:`~transformers.FlaxRobertaModel` (RoBERTa model)
                - isInstance of `bert` configuration class: :class:`~transformers.FlaxBertModel` (Bert model)

        Returns:
            The result of calling the selected model class with ``config`` as its only argument.

        Raises:
            ValueError: If the configuration class is not registered in ``FLAX_MODEL_MAPPING``.

        Examples::

            # Download configuration from huggingface.co and cache.
            config = BertConfig.from_pretrained('bert-base-uncased')
            model = FlaxAutoModel.from_config(config)
        """
        return cls._get_model_class(config)(config)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r"""
        Instantiates one of the base model classes of the library from a pre-trained model configuration.

        The configuration is taken from the ``config`` keyword argument when it is a
        :class:`~transformers.PretrainedConfig` instance; otherwise it is loaded with
        ``AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)``. The base model class is then the
        first entry of ``FLAX_MODEL_MAPPING`` whose configuration class matches (in the following order):

            - `roberta` configuration: :class:`~transformers.FlaxRobertaModel` (RoBERTa model)
            - `bert` configuration: :class:`~transformers.FlaxBertModel` (Bert model)

        Args:
            pretrained_model_name_or_path: either:

                - a string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co. Valid
                  model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under a user or
                  organization name, like ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing model weights saved using
                  :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `pytorch index checkpoint file` (e.g. `./pt_model/pytorch_model.bin`). In this
                  case, ``from_pt`` should be set to True and a configuration object should be provided as ``config``
                  argument.

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments are forwarded to the selected model class' ``from_pretrained``.

            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can
                be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a
                  pretrained model), or
                - the model was saved using :func:`~transformers.FlaxPreTrainedModel.save_pretrained` and is reloaded
                  by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
                  configuration JSON file named `config.json` is found in the directory.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model configuration should be cached if the
                standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if
                they exist.

            resume_download: (`optional`) boolean, default False:
                Do not delete incompletely received file. Attempt to resume the download if such a file exists.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error
                messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                These arguments will be passed to the configuration and the model.

        Returns:
            Whatever the selected model class' ``from_pretrained`` returns.

        Raises:
            ValueError: If the resolved configuration class is not registered in ``FLAX_MODEL_MAPPING``.

        Examples::

            # Download model and configuration from huggingface.co and cache.
            model = FlaxAutoModel.from_pretrained('bert-base-uncased')
            # E.g. model was saved using `save_pretrained('./test/bert_model/')`
            model = FlaxAutoModel.from_pretrained('./test/bert_model/')
        """
        # Remove `config` from kwargs so it is not passed twice further down.
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            # Nothing usable was supplied: resolve the configuration from the
            # model identifier, forwarding the remaining keyword arguments.
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        model_class = cls._get_model_class(config)
        return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
Event Timeline
Log In to Comment