Expand ONNX functions to sub-graphs before import (#2733)
Co-authored-by: Bartosz Sledz <bartosz.sledz@intel.com>
This commit is contained in:
parent
df49a2b987
commit
23188e1b04
@ -236,10 +236,11 @@ message(STATUS "NGRAPH_FORWARD_CMAKE_ARGS ${NGRAPH_FORWARD_CMAKE_ARGS}")
|
||||
if (LINUX)
|
||||
include(GNUInstallDirs)
|
||||
else()
|
||||
set(CMAKE_INSTALL_BINDIR "bin")
|
||||
set(CMAKE_INSTALL_INCLUDEDIR "include")
|
||||
set(CMAKE_INSTALL_DOCDIR "doc")
|
||||
set(CMAKE_INSTALL_LIBDIR "lib")
|
||||
set(CMAKE_INSTALL_BINDIR "bin" CACHE STRING "User executables (bin)")
|
||||
set(CMAKE_INSTALL_LIBDIR "lib" CACHE STRING "Object code libraries (lib)")
|
||||
set(CMAKE_INSTALL_INCLUDEDIR "include" CACHE STRING "C header files (include)")
|
||||
set(CMAKE_INSTALL_DOCDIR "doc" CACHE STRING "Document files (doc)")
|
||||
mark_as_advanced(CMAKE_INSTALL_BINDIR CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR, CMAKE_INSTALL_DOCDIR)
|
||||
endif()
|
||||
|
||||
if (DEFINED NGRAPH_INSTALL_PREFIX)
|
||||
|
@ -20,7 +20,7 @@ include(FetchContent)
|
||||
# ONNX.proto definition version
|
||||
#------------------------------------------------------------------------------
|
||||
|
||||
set(ONNX_VERSION 1.6.0)
|
||||
set(ONNX_VERSION 1.7.0)
|
||||
|
||||
#------------------------------------------------------------------------------
|
||||
# Download and install libonnx ...
|
||||
|
@ -29,6 +29,9 @@ namespace ngraph
|
||||
{
|
||||
std::string get_node_domain(const ONNX_NAMESPACE::NodeProto& node_proto);
|
||||
|
||||
std::int64_t get_opset_version(const ONNX_NAMESPACE::ModelProto& model_proto,
|
||||
const std::string& domain);
|
||||
|
||||
class Model
|
||||
{
|
||||
public:
|
||||
|
@ -0,0 +1,74 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <onnx/onnx_pb.h>
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace onnx_import
|
||||
{
|
||||
namespace transform
|
||||
{
|
||||
/// \brief Replace external_data path in tensors with full path to data file.
|
||||
///
|
||||
/// Paths to external data files are stored as relative to model path.
|
||||
/// This transformation replaces them with a full filesystem path.
|
||||
/// As a result in further processing data from external files can be read directly.
|
||||
///
|
||||
/// \param model_proto Protobuf message with ONNX model to transform.
|
||||
/// \param model_path Filesystem path to the ONNX model file.
|
||||
void update_external_data_paths(ONNX_NAMESPACE::ModelProto& model_proto,
|
||||
const std::string& model_path);
|
||||
|
||||
static const std::vector<std::string> onnx_functions_to_expand = {
|
||||
"Celu",
|
||||
"DynamicQuantizeLinear",
|
||||
"GreaterOrEqual",
|
||||
"LessOrEqual",
|
||||
"NegativeLogLikelihoodLoss",
|
||||
"SoftmaxCrossEntropyLoss"};
|
||||
|
||||
/// \brief Replace nodes with expanded body of ONNX functions
|
||||
///
|
||||
/// Some ONNX operators are specified as functions, which can be expanded to
|
||||
/// a subgraph or more primitive operations. This functions modifies the ONNX
|
||||
/// model by replacing operations of types listed in onnx_functions_to_expand
|
||||
/// with their expanded subgraphs.
|
||||
///
|
||||
/// \param model_proto Protobuf message with ONNX model to transform.
|
||||
void expand_onnx_functions(ONNX_NAMESPACE::ModelProto& model_proto);
|
||||
|
||||
static const std::vector<std::string> legacy_ops_to_fixup = {
|
||||
"DetectionOutput", "FakeQuantize", "GroupNorm", "Normalize", "PriorBox"};
|
||||
|
||||
/// \brief Add support for models with custom operators mistakenly registered in
|
||||
/// "ai.onnx" domain.
|
||||
///
|
||||
/// Some legacy models use custom operators (listed in legacy_ops_to_fixup vector) which
|
||||
/// were registered in the default ONNX domain. This function updates nodes with these
|
||||
/// operations to use OPENVINO_ONNX_DOMAIN in order to process them correctly
|
||||
/// in the nGraph ONNX Importer.
|
||||
///
|
||||
/// \param model_proto Protobuf message with ONNX model to transform.
|
||||
void fixup_legacy_operators(ONNX_NAMESPACE::ModelProto& model_proto);
|
||||
|
||||
} // namespace transform
|
||||
|
||||
} // namespace onnx_import
|
||||
|
||||
} // namespace ngraph
|
@ -29,6 +29,20 @@ namespace ngraph
|
||||
return node_proto.has_domain() ? node_proto.domain() : "";
|
||||
}
|
||||
|
||||
std::int64_t get_opset_version(const ONNX_NAMESPACE::ModelProto& model_proto,
|
||||
const std::string& domain)
|
||||
{
|
||||
for (const auto& opset_import : model_proto.opset_import())
|
||||
{
|
||||
if (domain == opset_import.domain())
|
||||
{
|
||||
return opset_import.version();
|
||||
}
|
||||
}
|
||||
|
||||
throw ngraph_error("Couldn't find operator set's version for domain: " + domain + ".");
|
||||
}
|
||||
|
||||
Model::Model(const ONNX_NAMESPACE::ModelProto& model_proto)
|
||||
: m_model_proto{&model_proto}
|
||||
{
|
||||
|
123
ngraph/frontend/onnx_import/src/core/transform.cpp
Normal file
123
ngraph/frontend/onnx_import/src/core/transform.cpp
Normal file
@ -0,0 +1,123 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include <onnx/defs/function.h>
|
||||
#include <onnx/defs/schema.h>
|
||||
|
||||
#include "model.hpp"
|
||||
#include "transform.hpp"
|
||||
|
||||
#include "ngraph/file_util.hpp"
|
||||
#include "onnx_import/ops_bridge.hpp"
|
||||
|
||||
void ngraph::onnx_import::transform::expand_onnx_functions(ONNX_NAMESPACE::ModelProto& model_proto)
|
||||
{
|
||||
auto graph_proto = model_proto.mutable_graph();
|
||||
|
||||
for (int i = 0; i < graph_proto->node().size(); ++i)
|
||||
{
|
||||
ONNX_NAMESPACE::NodeProto node = graph_proto->node().Get(i);
|
||||
|
||||
// Check if node operation is one of the functions we want to expand
|
||||
if (std::find(onnx_functions_to_expand.begin(),
|
||||
onnx_functions_to_expand.end(),
|
||||
node.op_type()) == onnx_functions_to_expand.end())
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
// Retrieve the operation schema from ONNX library
|
||||
int opset_version = static_cast<int>(get_opset_version(model_proto, node.domain()));
|
||||
const auto* schema_registry = ONNX_NAMESPACE::OpSchemaRegistry::Instance();
|
||||
const auto node_op_schema =
|
||||
schema_registry->GetSchema(node.op_type(), opset_version, node.domain());
|
||||
|
||||
// Check if operation schema found
|
||||
if (!node_op_schema)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if operation schema contains a function body and expand function
|
||||
if (node_op_schema->HasFunction())
|
||||
{
|
||||
const auto* func_proto = node_op_schema->GetFunction();
|
||||
ONNX_NAMESPACE::FunctionExpandHelper(node, *func_proto, *graph_proto);
|
||||
|
||||
// Remove the original node which contained the function.
|
||||
graph_proto->mutable_node()->erase(graph_proto->mutable_node()->begin() + i);
|
||||
}
|
||||
|
||||
else if (node_op_schema->HasContextDependentFunction())
|
||||
{
|
||||
ONNX_NAMESPACE::FunctionBodyBuildContextImpl ctx(node);
|
||||
ONNX_NAMESPACE::FunctionProto func_proto;
|
||||
node_op_schema->BuildContextDependentFunction(ctx, func_proto);
|
||||
ONNX_NAMESPACE::FunctionExpandHelper(node, func_proto, *graph_proto);
|
||||
|
||||
// Remove the original node which contained the function.
|
||||
graph_proto->mutable_node()->erase(graph_proto->mutable_node()->begin() + i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ngraph::onnx_import::transform::update_external_data_paths(
|
||||
ONNX_NAMESPACE::ModelProto& model_proto, const std::string& model_path)
|
||||
{
|
||||
if (model_path.empty())
|
||||
{
|
||||
return;
|
||||
}
|
||||
const auto model_dir_path = file_util::get_directory(model_path);
|
||||
auto graph_proto = model_proto.mutable_graph();
|
||||
for (auto& initializer_tensor : *graph_proto->mutable_initializer())
|
||||
{
|
||||
const auto location_key_value_index = 0;
|
||||
if (initializer_tensor.has_data_location() &&
|
||||
initializer_tensor.data_location() ==
|
||||
ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL)
|
||||
{
|
||||
const auto external_data_relative_path =
|
||||
initializer_tensor.external_data(location_key_value_index).value();
|
||||
auto external_data_full_path =
|
||||
file_util::path_join(model_dir_path, external_data_relative_path);
|
||||
|
||||
#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
|
||||
file_util::convert_path_win_style(external_data_full_path);
|
||||
#endif
|
||||
|
||||
// Set full paths to the external file
|
||||
initializer_tensor.mutable_external_data(location_key_value_index)
|
||||
->set_value(external_data_full_path);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ngraph::onnx_import::transform::fixup_legacy_operators(ONNX_NAMESPACE::ModelProto& model_proto)
|
||||
{
|
||||
auto graph_proto = model_proto.mutable_graph();
|
||||
for (auto& node : *graph_proto->mutable_node())
|
||||
{
|
||||
auto it = std::find(legacy_ops_to_fixup.begin(), legacy_ops_to_fixup.end(), node.op_type());
|
||||
if (it != legacy_ops_to_fixup.end())
|
||||
{
|
||||
if (!node.has_domain() || node.domain().empty() || node.domain() == "ai.onnx")
|
||||
{
|
||||
node.set_domain(OPENVINO_ONNX_DOMAIN);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -20,9 +20,9 @@
|
||||
#include <memory>
|
||||
|
||||
#include "ngraph/except.hpp"
|
||||
#include "ngraph/file_util.hpp"
|
||||
#include "onnx_import/core/graph.hpp"
|
||||
#include "onnx_import/core/model.hpp"
|
||||
#include "onnx_import/core/transform.hpp"
|
||||
#include "onnx_import/onnx.hpp"
|
||||
#include "onnx_import/ops_bridge.hpp"
|
||||
|
||||
@ -74,65 +74,6 @@ namespace ngraph
|
||||
|
||||
} // namespace error
|
||||
|
||||
static const std::vector<std::string> legacy_ops_to_fixup = {
|
||||
"DetectionOutput", "FakeQuantize", "GroupNorm", "Normalize", "PriorBox"};
|
||||
|
||||
// There are some models with custom OPs (list above) that has the default domain set.
|
||||
// So in order to load the models, we need overwrite the OPs' domain to the one they're
|
||||
// registered
|
||||
void fixup_legacy_operators(ONNX_NAMESPACE::GraphProto* graph_proto)
|
||||
{
|
||||
for (auto& node : *graph_proto->mutable_node())
|
||||
{
|
||||
auto it = std::find(
|
||||
legacy_ops_to_fixup.begin(), legacy_ops_to_fixup.end(), node.op_type());
|
||||
if (it != legacy_ops_to_fixup.end())
|
||||
{
|
||||
if (!node.has_domain() || node.domain().empty() ||
|
||||
node.domain() == "ai.onnx")
|
||||
{
|
||||
node.set_domain(OPENVINO_ONNX_DOMAIN);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// The paths to external data files are stored as relative to model path.
|
||||
// The helper function below combines them and replaces the original relative path.
|
||||
// As a result in futher processing data from external files can be read directly.
|
||||
void update_external_data_paths(ONNX_NAMESPACE::ModelProto& model_proto,
|
||||
const std::string& model_path)
|
||||
{
|
||||
if (model_path.empty())
|
||||
{
|
||||
return;
|
||||
}
|
||||
const auto model_dir_path = file_util::get_directory(model_path);
|
||||
auto graph_proto = model_proto.mutable_graph();
|
||||
for (auto& initializer_tensor : *graph_proto->mutable_initializer())
|
||||
{
|
||||
const auto location_key_value_index = 0;
|
||||
if (initializer_tensor.has_data_location() &&
|
||||
initializer_tensor.data_location() ==
|
||||
ONNX_NAMESPACE::TensorProto_DataLocation::
|
||||
TensorProto_DataLocation_EXTERNAL)
|
||||
{
|
||||
const auto external_data_relative_path =
|
||||
initializer_tensor.external_data(location_key_value_index).value();
|
||||
auto external_data_full_path =
|
||||
file_util::path_join(model_dir_path, external_data_relative_path);
|
||||
|
||||
#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
|
||||
file_util::convert_path_win_style(external_data_full_path);
|
||||
#endif
|
||||
|
||||
// Set full paths to the external file
|
||||
initializer_tensor.mutable_external_data(location_key_value_index)
|
||||
->set_value(external_data_full_path);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<Function>
|
||||
convert_to_ng_function(const ONNX_NAMESPACE::ModelProto& model_proto)
|
||||
{
|
||||
@ -181,8 +122,9 @@ namespace ngraph
|
||||
#endif
|
||||
}
|
||||
|
||||
detail::fixup_legacy_operators(model_proto.mutable_graph());
|
||||
detail::update_external_data_paths(model_proto, model_path);
|
||||
transform::expand_onnx_functions(model_proto);
|
||||
transform::fixup_legacy_operators(model_proto);
|
||||
transform::update_external_data_paths(model_proto, model_path);
|
||||
|
||||
return detail::convert_to_ng_function(model_proto);
|
||||
}
|
||||
|
@ -173,8 +173,6 @@ xfail_issue_38726 = xfail_test(reason="RuntimeError: nGraph does not support the
|
||||
"LessOrEqual")
|
||||
xfail_issue_38732 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations:"
|
||||
"ConvInteger")
|
||||
xfail_issue_38733 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations:"
|
||||
"Celu")
|
||||
xfail_issue_38734 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations:"
|
||||
"ai.onnx.preview.training.Adam")
|
||||
xfail_issue_38735 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations:"
|
||||
|
@ -79,7 +79,6 @@ from tests import (BACKEND_NAME,
|
||||
xfail_issue_33644,
|
||||
xfail_issue_33515,
|
||||
xfail_issue_38732,
|
||||
xfail_issue_38733,
|
||||
xfail_issue_38734,
|
||||
xfail_issue_38735)
|
||||
|
||||
@ -631,8 +630,6 @@ tests_expected_to_fail = [
|
||||
(xfail_issue_38732,
|
||||
"OnnxBackendNodeModelTest.test_convinteger_with_padding_cpu",
|
||||
"OnnxBackendNodeModelTest.test_basic_convinteger_cpu"),
|
||||
(xfail_issue_38733,
|
||||
"OnnxBackendNodeModelTest.test_celu_cpu"),
|
||||
(xfail_issue_38734,
|
||||
"OnnxBackendNodeModelTest.test_adam_multiple_cpu",
|
||||
"OnnxBackendNodeModelTest.test_adam_cpu"),
|
||||
|
@ -0,0 +1,61 @@
|
||||
ir_version: 5
|
||||
producer_name: "backend-test"
|
||||
graph {
|
||||
node {
|
||||
input: "x"
|
||||
output: "y"
|
||||
output: "y_scale"
|
||||
output: "y_zero_point"
|
||||
op_type: "DynamicQuantizeLinear"
|
||||
}
|
||||
name: "test_dynamicquantizelinear"
|
||||
input {
|
||||
name: "x"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 6
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "y"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 2
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 6
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "y_scale"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "y_zero_point"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 2
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 11
|
||||
}
|
@ -462,6 +462,19 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_initializer_wo_input)
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_${BACKEND_NAME}, onnx_expand_function)
|
||||
{
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/quantization/dynamicquantizelinear.prototxt"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
test_case.add_input<float>({-1.f, -2.1f, -1.3f, -2.5f, -3.34f, -4.f});
|
||||
test_case.add_expected_output<uint8_t>(Shape{6}, {191, 121, 172, 96, 42, 0});
|
||||
test_case.add_expected_output<float>(Shape{}, {0.0156862754f});
|
||||
test_case.add_expected_output<uint8_t>(Shape{}, {255});
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
// ############################################################################ OPERATOR TESTS
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_addmul_abc)
|
||||
{
|
||||
|
@ -46,7 +46,7 @@ NGRAPH_TEST(onnx, check_ir_version_support)
|
||||
//
|
||||
// The last step is to also update the details::onnx::contains_onnx_model_keys() function
|
||||
// in the same file to make sure that prototxt format validation also covers the changes in ONNX
|
||||
EXPECT_EQ(ONNX_NAMESPACE::Version::IR_VERSION, 6)
|
||||
EXPECT_EQ(ONNX_NAMESPACE::Version::IR_VERSION, 7)
|
||||
<< "The IR_VERSION defined in ONNX does not match the version that OpenVINO supports. "
|
||||
"Please check the source code of this test for details and explanation how to proceed.";
|
||||
}
|
||||
|
@ -15,6 +15,7 @@ onnx_model_quantize_linear
|
||||
onnx_model_quantize_linear_zero_point
|
||||
onnx_model_quantize_linear_axis_zero
|
||||
onnx_model_quantize_linear_axis_negative
|
||||
onnx_expand_function
|
||||
|
||||
# DequantizeLinear:
|
||||
# C++ exception with description "Unsupported precisions!
|
||||
|
Loading…
Reference in New Issue
Block a user