Expand ONNX functions to sub-graphs before import (#2733)

Co-authored-by: Bartosz Sledz <bartosz.sledz@intel.com>
This commit is contained in:
Michał Karzyński 2020-11-04 10:48:34 +01:00 committed by GitHub
parent df49a2b987
commit 23188e1b04
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 300 additions and 73 deletions

View File

@ -236,10 +236,11 @@ message(STATUS "NGRAPH_FORWARD_CMAKE_ARGS ${NGRAPH_FORWARD_CMAKE_ARGS}")
if (LINUX)
    include(GNUInstallDirs)
else()
    set(CMAKE_INSTALL_BINDIR "bin" CACHE STRING "User executables (bin)")
    set(CMAKE_INSTALL_LIBDIR "lib" CACHE STRING "Object code libraries (lib)")
    set(CMAKE_INSTALL_INCLUDEDIR "include" CACHE STRING "C header files (include)")
    set(CMAKE_INSTALL_DOCDIR "doc" CACHE STRING "Document files (doc)")
    # Fix: mark_as_advanced takes whitespace-separated variable names. The previous
    # stray comma produced the token "CMAKE_INSTALL_INCLUDEDIR," (comma included),
    # which names no variable, so CMAKE_INSTALL_INCLUDEDIR was never marked advanced.
    mark_as_advanced(CMAKE_INSTALL_BINDIR CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR CMAKE_INSTALL_DOCDIR)
endif()
if (DEFINED NGRAPH_INSTALL_PREFIX)

View File

@ -20,7 +20,7 @@ include(FetchContent)
# ONNX.proto definition version
#------------------------------------------------------------------------------
set(ONNX_VERSION 1.6.0)
set(ONNX_VERSION 1.7.0)
#------------------------------------------------------------------------------
# Download and install libonnx ...

View File

@ -29,6 +29,9 @@ namespace ngraph
{
std::string get_node_domain(const ONNX_NAMESPACE::NodeProto& node_proto);
std::int64_t get_opset_version(const ONNX_NAMESPACE::ModelProto& model_proto,
const std::string& domain);
class Model
{
public:

View File

@ -0,0 +1,74 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once

// Fix: this header uses std::string and std::vector but previously included only
// the ONNX protobuf header, relying on transitive includes.
#include <string>
#include <vector>

#include <onnx/onnx_pb.h>

namespace ngraph
{
    namespace onnx_import
    {
        namespace transform
        {
            /// \brief Replace external_data path in tensors with full path to data file.
            ///
            /// Paths to external data files are stored as relative to model path.
            /// This transformation replaces them with a full filesystem path.
            /// As a result in further processing data from external files can be read directly.
            ///
            /// \param model_proto Protobuf message with ONNX model to transform.
            /// \param model_path Filesystem path to the ONNX model file.
            void update_external_data_paths(ONNX_NAMESPACE::ModelProto& model_proto,
                                            const std::string& model_path);

            // Operator types whose ONNX function bodies are expanded before import.
            // NOTE(review): header-scope `static const` gives every including TU its own
            // copy of these vectors; consider `inline` variables if C++17 is available.
            static const std::vector<std::string> onnx_functions_to_expand = {
                "Celu",
                "DynamicQuantizeLinear",
                "GreaterOrEqual",
                "LessOrEqual",
                "NegativeLogLikelihoodLoss",
                "SoftmaxCrossEntropyLoss"};

            /// \brief Replace nodes with expanded body of ONNX functions
            ///
            /// Some ONNX operators are specified as functions, which can be expanded to
            /// a subgraph or more primitive operations. This function modifies the ONNX
            /// model by replacing operations of types listed in onnx_functions_to_expand
            /// with their expanded subgraphs.
            ///
            /// \param model_proto Protobuf message with ONNX model to transform.
            void expand_onnx_functions(ONNX_NAMESPACE::ModelProto& model_proto);

            // Custom operators that legacy models mistakenly register in the default domain.
            static const std::vector<std::string> legacy_ops_to_fixup = {
                "DetectionOutput", "FakeQuantize", "GroupNorm", "Normalize", "PriorBox"};

            /// \brief Add support for models with custom operators mistakenly registered in
            ///        "ai.onnx" domain.
            ///
            /// Some legacy models use custom operators (listed in legacy_ops_to_fixup vector) which
            /// were registered in the default ONNX domain. This function updates nodes with these
            /// operations to use OPENVINO_ONNX_DOMAIN in order to process them correctly
            /// in the nGraph ONNX Importer.
            ///
            /// \param model_proto Protobuf message with ONNX model to transform.
            void fixup_legacy_operators(ONNX_NAMESPACE::ModelProto& model_proto);

        } // namespace transform
    }     // namespace onnx_import
} // namespace ngraph

View File

@ -29,6 +29,20 @@ namespace ngraph
return node_proto.has_domain() ? node_proto.domain() : "";
}
// Returns the operator-set version the model declares for the given domain.
// The default "ai.onnx" domain is represented by an empty string (matching
// get_node_domain above, which yields "" when a node has no domain set).
// Throws ngraph_error if the model declares no opset_import for the domain.
std::int64_t get_opset_version(const ONNX_NAMESPACE::ModelProto& model_proto,
const std::string& domain)
{
// Linear scan is fine: models declare only a handful of opset imports.
for (const auto& opset_import : model_proto.opset_import())
{
if (domain == opset_import.domain())
{
return opset_import.version();
}
}
throw ngraph_error("Couldn't find operator set's version for domain: " + domain + ".");
}
Model::Model(const ONNX_NAMESPACE::ModelProto& model_proto)
: m_model_proto{&model_proto}
{

View File

@ -0,0 +1,123 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <onnx/defs/function.h>
#include <onnx/defs/schema.h>
#include "model.hpp"
#include "transform.hpp"
#include "ngraph/file_util.hpp"
#include "onnx_import/ops_bridge.hpp"
/// \brief Expand ONNX function nodes in the model's graph into their sub-graphs.
///
/// For each node whose op_type is listed in onnx_functions_to_expand, the matching
/// operator schema is looked up for the opset version the model imports; if the schema
/// carries a (possibly context-dependent) function body, the node is replaced by the
/// expanded sub-graph produced by ONNX's FunctionExpandHelper.
///
/// \param model_proto Protobuf message with ONNX model to transform.
void ngraph::onnx_import::transform::expand_onnx_functions(ONNX_NAMESPACE::ModelProto& model_proto)
{
    auto graph_proto = model_proto.mutable_graph();

    for (int i = 0; i < graph_proto->node().size(); ++i)
    {
        // Work on a copy: expansion/erase below may reallocate the node list.
        ONNX_NAMESPACE::NodeProto node = graph_proto->node().Get(i);

        // Check if node operation is one of the functions we want to expand
        if (std::find(onnx_functions_to_expand.begin(),
                      onnx_functions_to_expand.end(),
                      node.op_type()) == onnx_functions_to_expand.end())
        {
            continue;
        }

        // Retrieve the operation schema from ONNX library
        int opset_version = static_cast<int>(get_opset_version(model_proto, node.domain()));
        const auto* schema_registry = ONNX_NAMESPACE::OpSchemaRegistry::Instance();
        const auto node_op_schema =
            schema_registry->GetSchema(node.op_type(), opset_version, node.domain());

        // Check if operation schema found
        if (!node_op_schema)
        {
            continue;
        }

        // Check if operation schema contains a function body and expand function
        bool expanded = false;
        if (node_op_schema->HasFunction())
        {
            const auto* func_proto = node_op_schema->GetFunction();
            ONNX_NAMESPACE::FunctionExpandHelper(node, *func_proto, *graph_proto);
            expanded = true;
        }
        else if (node_op_schema->HasContextDependentFunction())
        {
            ONNX_NAMESPACE::FunctionBodyBuildContextImpl ctx(node);
            ONNX_NAMESPACE::FunctionProto func_proto;
            node_op_schema->BuildContextDependentFunction(ctx, func_proto);
            ONNX_NAMESPACE::FunctionExpandHelper(node, func_proto, *graph_proto);
            expanded = true;
        }

        if (expanded)
        {
            // Remove the original node which contained the function.
            graph_proto->mutable_node()->erase(graph_proto->mutable_node()->begin() + i);
            // Bug fix: erasing shifts the following nodes left by one, so the node now
            // at index i has not been examined yet. Without this decrement the loop's
            // ++i skipped it, leaving a function node unexpanded when two expandable
            // nodes were adjacent.
            --i;
        }
    }
}
/// \brief Rewrite each initializer's external_data location from a path relative to
///        the model file into an absolute filesystem path, so later stages can read
///        the external tensor data directly.
void ngraph::onnx_import::transform::update_external_data_paths(
    ONNX_NAMESPACE::ModelProto& model_proto, const std::string& model_path)
{
    // Without the model's on-disk location there is nothing to resolve against.
    if (model_path.empty())
    {
        return;
    }
    const auto model_dir = file_util::get_directory(model_path);

    for (auto& tensor : *model_proto.mutable_graph()->mutable_initializer())
    {
        // Index of the key/value entry in external_data that stores the file location.
        const auto location_entry = 0;

        const bool stored_externally =
            tensor.has_data_location() &&
            tensor.data_location() ==
                ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL;
        if (!stored_externally)
        {
            continue;
        }

        const auto relative_path = tensor.external_data(location_entry).value();
        auto absolute_path = file_util::path_join(model_dir, relative_path);
#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
        file_util::convert_path_win_style(absolute_path);
#endif
        // Replace the relative location with the resolved absolute one.
        tensor.mutable_external_data(location_entry)->set_value(absolute_path);
    }
}
/// \brief Move known legacy custom operators out of the default ("ai.onnx") domain
///        into OPENVINO_ONNX_DOMAIN so the importer can resolve them correctly.
void ngraph::onnx_import::transform::fixup_legacy_operators(ONNX_NAMESPACE::ModelProto& model_proto)
{
    for (auto& node : *model_proto.mutable_graph()->mutable_node())
    {
        const bool is_legacy_op =
            std::find(legacy_ops_to_fixup.begin(), legacy_ops_to_fixup.end(), node.op_type()) !=
            legacy_ops_to_fixup.end();
        // An unset or empty domain means the default ONNX domain, same as "ai.onnx".
        const bool in_default_domain =
            !node.has_domain() || node.domain().empty() || node.domain() == "ai.onnx";

        if (is_legacy_op && in_default_domain)
        {
            node.set_domain(OPENVINO_ONNX_DOMAIN);
        }
    }
}

View File

@ -20,9 +20,9 @@
#include <memory>
#include "ngraph/except.hpp"
#include "ngraph/file_util.hpp"
#include "onnx_import/core/graph.hpp"
#include "onnx_import/core/model.hpp"
#include "onnx_import/core/transform.hpp"
#include "onnx_import/onnx.hpp"
#include "onnx_import/ops_bridge.hpp"
@ -74,65 +74,6 @@ namespace ngraph
} // namespace error
static const std::vector<std::string> legacy_ops_to_fixup = {
"DetectionOutput", "FakeQuantize", "GroupNorm", "Normalize", "PriorBox"};
// There are some models with custom OPs (list above) that has the default domain set.
// So in order to load the models, we need overwrite the OPs' domain to the one they're
// registered
void fixup_legacy_operators(ONNX_NAMESPACE::GraphProto* graph_proto)
{
for (auto& node : *graph_proto->mutable_node())
{
auto it = std::find(
legacy_ops_to_fixup.begin(), legacy_ops_to_fixup.end(), node.op_type());
if (it != legacy_ops_to_fixup.end())
{
if (!node.has_domain() || node.domain().empty() ||
node.domain() == "ai.onnx")
{
node.set_domain(OPENVINO_ONNX_DOMAIN);
}
}
}
}
// The paths to external data files are stored as relative to model path.
// The helper function below combines them and replaces the original relative path.
// As a result in futher processing data from external files can be read directly.
void update_external_data_paths(ONNX_NAMESPACE::ModelProto& model_proto,
const std::string& model_path)
{
if (model_path.empty())
{
return;
}
const auto model_dir_path = file_util::get_directory(model_path);
auto graph_proto = model_proto.mutable_graph();
for (auto& initializer_tensor : *graph_proto->mutable_initializer())
{
const auto location_key_value_index = 0;
if (initializer_tensor.has_data_location() &&
initializer_tensor.data_location() ==
ONNX_NAMESPACE::TensorProto_DataLocation::
TensorProto_DataLocation_EXTERNAL)
{
const auto external_data_relative_path =
initializer_tensor.external_data(location_key_value_index).value();
auto external_data_full_path =
file_util::path_join(model_dir_path, external_data_relative_path);
#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
file_util::convert_path_win_style(external_data_full_path);
#endif
// Set full paths to the external file
initializer_tensor.mutable_external_data(location_key_value_index)
->set_value(external_data_full_path);
}
}
}
std::shared_ptr<Function>
convert_to_ng_function(const ONNX_NAMESPACE::ModelProto& model_proto)
{
@ -181,8 +122,9 @@ namespace ngraph
#endif
}
detail::fixup_legacy_operators(model_proto.mutable_graph());
detail::update_external_data_paths(model_proto, model_path);
transform::expand_onnx_functions(model_proto);
transform::fixup_legacy_operators(model_proto);
transform::update_external_data_paths(model_proto, model_path);
return detail::convert_to_ng_function(model_proto);
}

View File

@ -173,8 +173,6 @@ xfail_issue_38726 = xfail_test(reason="RuntimeError: nGraph does not support the
"LessOrEqual")
xfail_issue_38732 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations:"
"ConvInteger")
xfail_issue_38733 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations:"
"Celu")
xfail_issue_38734 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations:"
"ai.onnx.preview.training.Adam")
xfail_issue_38735 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations:"

View File

@ -79,7 +79,6 @@ from tests import (BACKEND_NAME,
xfail_issue_33644,
xfail_issue_33515,
xfail_issue_38732,
xfail_issue_38733,
xfail_issue_38734,
xfail_issue_38735)
@ -631,8 +630,6 @@ tests_expected_to_fail = [
(xfail_issue_38732,
"OnnxBackendNodeModelTest.test_convinteger_with_padding_cpu",
"OnnxBackendNodeModelTest.test_basic_convinteger_cpu"),
(xfail_issue_38733,
"OnnxBackendNodeModelTest.test_celu_cpu"),
(xfail_issue_38734,
"OnnxBackendNodeModelTest.test_adam_multiple_cpu",
"OnnxBackendNodeModelTest.test_adam_cpu"),

View File

@ -0,0 +1,61 @@
ir_version: 5
producer_name: "backend-test"
graph {
node {
input: "x"
output: "y"
output: "y_scale"
output: "y_zero_point"
op_type: "DynamicQuantizeLinear"
}
name: "test_dynamicquantizelinear"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 6
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 2
shape {
dim {
dim_value: 6
}
}
}
}
}
output {
name: "y_scale"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
output {
name: "y_zero_point"
type {
tensor_type {
elem_type: 2
shape {
}
}
}
}
}
opset_import {
version: 11
}

View File

@ -462,6 +462,19 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_initializer_wo_input)
test_case.run();
}
// Verifies that a DynamicQuantizeLinear node (an ONNX function op) is expanded to a
// sub-graph at import time and produces the reference quantized outputs.
// NOTE(review): the first macro argument is "onnx_${BACKEND_NAME}" while the
// surrounding tests use "${BACKEND_NAME}"; confirm this prefix is intentional and that
// the generated test name still matches the "onnx_expand_function" manifest entry.
NGRAPH_TEST(onnx_${BACKEND_NAME}, onnx_expand_function)
{
const auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/quantization/dynamicquantizelinear.prototxt"));
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_input<float>({-1.f, -2.1f, -1.3f, -2.5f, -3.34f, -4.f});
// Expected y, y_scale, y_zero_point per the ONNX DynamicQuantizeLinear reference.
test_case.add_expected_output<uint8_t>(Shape{6}, {191, 121, 172, 96, 42, 0});
test_case.add_expected_output<float>(Shape{}, {0.0156862754f});
test_case.add_expected_output<uint8_t>(Shape{}, {255});
test_case.run();
}
// ############################################################################ OPERATOR TESTS
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_addmul_abc)
{

View File

@ -46,7 +46,7 @@ NGRAPH_TEST(onnx, check_ir_version_support)
//
// The last step is to also update the details::onnx::contains_onnx_model_keys() function
// in the same file to make sure that prototxt format validation also covers the changes in ONNX
EXPECT_EQ(ONNX_NAMESPACE::Version::IR_VERSION, 6)
EXPECT_EQ(ONNX_NAMESPACE::Version::IR_VERSION, 7)
<< "The IR_VERSION defined in ONNX does not match the version that OpenVINO supports. "
"Please check the source code of this test for details and explanation how to proceed.";
}

View File

@ -15,6 +15,7 @@ onnx_model_quantize_linear
onnx_model_quantize_linear_zero_point
onnx_model_quantize_linear_axis_zero
onnx_model_quantize_linear_axis_negative
onnx_expand_function
# DequantizeLinear:
# C++ exception with description "Unsupported precisions!