Page Menu
Home
c4science
Search
Configure Global Search
Log In
Files
F120532317
check_tf_ops.py
No One
Temporary
Actions
Download File
Edit File
Delete File
View Transforms
Subscribe
Mute Notifications
Award Token
Subscribers
None
File Metadata
Details
File Info
Storage
Attached
Created
Sat, Jul 5, 01:34
Size
3 KB
Mime Type
text/x-python
Expires
Mon, Jul 7, 01:34 (2 d)
Engine
blob
Format
Raw Data
Handle
27184998
Attached To
R11484 ADDI
check_tf_ops.py
View Options
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
import
json
import
os
from
tensorflow.core.protobuf.saved_model_pb2
import
SavedModel
# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_copies.py
REPO_PATH
=
"."
# Internal TensorFlow ops that can be safely ignored (mostly specific to a saved model)
INTERNAL_OPS
=
[
"Assert"
,
"AssignVariableOp"
,
"EmptyTensorList"
,
"MergeV2Checkpoints"
,
"ReadVariableOp"
,
"ResourceGather"
,
"RestoreV2"
,
"SaveV2"
,
"ShardedFilename"
,
"StatefulPartitionedCall"
,
"StaticRegexFullMatch"
,
"VarHandleOp"
,
]
def
onnx_compliancy
(
saved_model_path
,
strict
,
opset
):
saved_model
=
SavedModel
()
onnx_ops
=
[]
with
open
(
os
.
path
.
join
(
REPO_PATH
,
"utils"
,
"tf_ops"
,
"onnx.json"
))
as
f
:
onnx_opsets
=
json
.
load
(
f
)[
"opsets"
]
for
i
in
range
(
1
,
opset
+
1
):
onnx_ops
.
extend
(
onnx_opsets
[
str
(
i
)])
with
open
(
saved_model_path
,
"rb"
)
as
f
:
saved_model
.
ParseFromString
(
f
.
read
())
model_op_names
=
set
()
# Iterate over every metagraph in case there is more than one (a saved model can contain multiple graphs)
for
meta_graph
in
saved_model
.
meta_graphs
:
# Add operations in the graph definition
model_op_names
.
update
(
node
.
op
for
node
in
meta_graph
.
graph_def
.
node
)
# Go through the functions in the graph definition
for
func
in
meta_graph
.
graph_def
.
library
.
function
:
# Add operations in each function
model_op_names
.
update
(
node
.
op
for
node
in
func
.
node_def
)
# Convert to list, sorted if you want
model_op_names
=
sorted
(
model_op_names
)
incompatible_ops
=
[]
for
op
in
model_op_names
:
if
op
not
in
onnx_ops
and
op
not
in
INTERNAL_OPS
:
incompatible_ops
.
append
(
op
)
if
strict
and
len
(
incompatible_ops
)
>
0
:
raise
Exception
(
f
"Found the following incompatible ops for the opset {opset}:
\n
"
+
incompatible_ops
)
elif
len
(
incompatible_ops
)
>
0
:
print
(
f
"Found the following incompatible ops for the opset {opset}:"
)
print
(
*
incompatible_ops
,
sep
=
"
\n
"
)
else
:
print
(
f
"The saved model {saved_model_path} can properly be converted with ONNX."
)
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--saved_model_path"
,
help
=
"Path of the saved model to check (the .pb file)."
)
parser
.
add_argument
(
"--opset"
,
default
=
12
,
type
=
int
,
help
=
"The ONNX opset against which the model has to be tested."
)
parser
.
add_argument
(
"--framework"
,
choices
=
[
"onnx"
],
default
=
"onnx"
,
help
=
"Frameworks against which to test the saved model."
)
parser
.
add_argument
(
"--strict"
,
action
=
"store_true"
,
help
=
"Whether make the checking strict (raise errors) or not (raise warnings)"
)
args
=
parser
.
parse_args
()
if
args
.
framework
==
"onnx"
:
onnx_compliancy
(
args
.
saved_model_path
,
args
.
strict
,
args
.
opset
)
Event Timeline
Log In to Comment