Page Menu
Home
c4science
Search
Configure Global Search
Log In
Files
F60642561
convert_albert_original_tf_checkpoint_to_pytorch.py
No One
Temporary
Actions
Download File
Edit File
Delete File
View Transforms
Subscribe
Mute Notifications
Award Token
Subscribers
None
File Metadata
Details
File Info
Storage
Attached
Created
Wed, May 1, 16:19
Size
2 KB
Mime Type
text/x-python
Expires
Fri, May 3, 16:19 (2 d)
Engine
blob
Format
Raw Data
Handle
17388941
Attached To
R11484 ADDI
convert_albert_original_tf_checkpoint_to_pytorch.py
View Options
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert ALBERT checkpoint."""
import
argparse
import
torch
from
transformers
import
AlbertConfig
,
AlbertForPreTraining
,
load_tf_weights_in_albert
from
transformers.utils
import
logging
# Raise the transformers logging helper (imported above from transformers.utils)
# to INFO verbosity. Because this call sits at module level rather than under
# the __main__ guard, it also takes effect whenever this module is imported.
logging.set_verbosity_info()
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pytorch_dump_path):
    """Build an ALBERT pre-training model, load TF checkpoint weights into it, and save them.

    Args:
        tf_checkpoint_path: TensorFlow checkpoint location, forwarded to
            ``load_tf_weights_in_albert``.
        albert_config_file: JSON configuration file read with
            ``AlbertConfig.from_json_file``; it defines the model architecture.
        pytorch_dump_path: Destination passed to ``torch.save``. Only the model's
            ``state_dict()`` is written; the configuration is not saved with it.

    Returns:
        None. Progress messages are printed to stdout.
    """
    albert_config = AlbertConfig.from_json_file(albert_config_file)
    print("Building PyTorch model from configuration: " + str(albert_config))
    pt_model = AlbertForPreTraining(albert_config)

    # The loader's return value is discarded, so it presumably fills pt_model's
    # parameters in place — confirm against the transformers implementation.
    load_tf_weights_in_albert(pt_model, albert_config, tf_checkpoint_path)

    print("Save PyTorch model to {}".format(pytorch_dump_path))
    weights = pt_model.state_dict()
    torch.save(weights, pytorch_dump_path)
if __name__ == "__main__":
    # Command-line entry point: every flag is a mandatory string option.
    cli = argparse.ArgumentParser()

    # Required parameters, declared as (flag, help text) pairs.
    required_options = (
        ("--tf_checkpoint_path", "Path to the TensorFlow checkpoint path."),
        (
            "--albert_config_file",
            "The config json file corresponding to the pre-trained ALBERT model. \n"
            "This specifies the model architecture.",
        ),
        ("--pytorch_dump_path", "Path to the output PyTorch model."),
    )
    for option_flag, option_help in required_options:
        cli.add_argument(option_flag, default=None, type=str, required=True, help=option_help)

    cli_args = cli.parse_args()
    convert_tf_checkpoint_to_pytorch(
        cli_args.tf_checkpoint_path,
        cli_args.albert_config_file,
        cli_args.pytorch_dump_path,
    )
Event Timeline
Log In to Comment