Fixed expanding ONNX function (#7747)

* fix expand_onnx_functions

* refactor + unit test

* fixed function in function case

* fixed expand_onnx_functions

* fixed default value of shape in ValueInfo

* enable xpass model

* changed MergeFrom to Swap

* added xfail with missing test data

* added more unit tests

* styles applied

* used std::rotate, review remarks

* removed debug code

* after offline discussion remarks

* fix checking input/output names on Windows

* names comparator refactor

* replace regex with custom comparison

* review remarks
This commit is contained in:
Mateusz Bencer 2021-10-15 12:40:28 +02:00 committed by GitHub
parent 96df1a14ce
commit ccdd0e61d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 866 additions and 24 deletions

View File

@ -8,6 +8,8 @@
#include <onnx/defs/schema.h>
#include <onnx/shape_inference/implementation.h>
#include <algorithm>
#include "core/model.hpp"
#include "ngraph/file_util.hpp"
#include "ngraph/log.hpp"
@ -30,6 +32,24 @@ ONNX_NAMESPACE::TypeProto get_input_type(std::string const& name, ONNX_NAMESPACE
}
return ONNX_NAMESPACE::TypeProto();
}
inline void function_expand_and_remove_original_node(const ONNX_NAMESPACE::NodeProto& node,
const ONNX_NAMESPACE::FunctionProto& func_proto,
ONNX_NAMESPACE::GraphProto* graph,
int current_node_idx) {
const auto before_expand_size = graph->node().size();
ONNX_NAMESPACE::FunctionExpandHelper(node, func_proto, *graph);
const auto added_nodes = graph->node().size() - before_expand_size;
// Remove the original node which contained the function
graph->mutable_node()->erase(graph->mutable_node()->begin() + current_node_idx);
// Move nodes from expanded function to position of removed node
std::rotate(graph->mutable_node()->begin() + current_node_idx,
graph->mutable_node()->end() - added_nodes,
graph->mutable_node()->end());
}
} // namespace detail
} // namespace transform
} // namespace onnx_import
@ -60,10 +80,8 @@ void ngraph::onnx_import::transform::expand_onnx_functions(ONNX_NAMESPACE::Model
// Check if operation schema contains a function body and expand function
if (node_op_schema->HasFunction()) {
const auto* func_proto = node_op_schema->GetFunction();
ONNX_NAMESPACE::FunctionExpandHelper(node, *func_proto, *graph_proto);
// Remove the original node which contained the function.
graph_proto->mutable_node()->erase(graph_proto->mutable_node()->begin() + i);
// Move index to the previous position because a first node of expanded function can have also function
detail::function_expand_and_remove_original_node(node, *func_proto, graph_proto, i--);
}
else if (node_op_schema->HasContextDependentFunction()) {
@ -82,10 +100,8 @@ void ngraph::onnx_import::transform::expand_onnx_functions(ONNX_NAMESPACE::Model
ONNX_NAMESPACE::FunctionBodyBuildContextImpl ctx(node, input_types);
ONNX_NAMESPACE::FunctionProto func_proto;
node_op_schema->BuildContextDependentFunction(ctx, func_proto);
ONNX_NAMESPACE::FunctionExpandHelper(node, func_proto, *graph_proto);
// Remove the original node which contained the function.
graph_proto->mutable_node()->erase(graph_proto->mutable_node()->begin() + i);
// Move index to the previous position because a first node of expanded function can have also function
detail::function_expand_and_remove_original_node(node, func_proto, graph_proto, i--);
}
}
}

View File

@ -30,8 +30,6 @@ public:
if (onnx_tensor.has_shape()) {
m_partial_shape = onnx_common::to_ng_shape(onnx_tensor.shape());
} else {
m_partial_shape = PartialShape::dynamic();
}
}
}
@ -76,7 +74,7 @@ protected:
private:
const ONNX_NAMESPACE::ValueInfoProto* m_value_info_proto;
PartialShape m_partial_shape;
PartialShape m_partial_shape = PartialShape::dynamic();
};
inline std::ostream& operator<<(std::ostream& outs, const ValueInfo& info) {

View File

@ -563,7 +563,8 @@ if (NGRAPH_ONNX_FRONTEND_ENABLE)
list(APPEND SRC
onnx/onnx_import_exceptions.cpp
onnx/onnx_import_library.cpp
onnx/onnx_tensor_names.cpp)
onnx/onnx_tensor_names.cpp
onnx/onnx_transformations.cpp)
endif()
if (NGRAPH_ONNX_FRONTEND_ENABLE)

View File

@ -0,0 +1,85 @@
ir_version: 6
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "in1"
input: "in2"
output: "greater_or_equal_out"
op_type: "GreaterOrEqual"
}
node {
input: "greater_or_equal_out"
output: "cast_out"
op_type: "Cast"
attribute {
name: "to"
i: 6
type: INT
}
}
node {
input: "cast_out"
output: "y"
output: "y_scale"
output: "y_zero_point"
op_type: "DynamicQuantizeLinear"
}
node {
input: "y"
output: "abs_y"
op_type: "Abs"
}
input {
name: "in1"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 5
}
}
}
}
}
input {
name: "in2"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 5
}
}
}
}
}
output {
name: "abs_y"
type {
tensor_type {
elem_type: 6
}
}
}
output {
name: "y_scale"
type {
tensor_type {
elem_type: 1
}
}
}
output {
name: "y_zero_point"
type {
tensor_type {
elem_type: 6
}
}
}
}
opset_import {
version: 12
}

View File

@ -0,0 +1,57 @@
ir_version: 6
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "in1"
input: "in2"
output: "greater_or_equal_out"
op_type: "GreaterOrEqual"
}
node {
input: "greater_or_equal_out"
output: "cast_out"
op_type: "Cast"
attribute {
name: "to"
i: 6
type: INT
}
}
input {
name: "in1"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 5
}
}
}
}
}
input {
name: "in2"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 5
}
}
}
}
}
output {
name: "cast_out"
type {
tensor_type {
elem_type: 6
}
}
}
}
opset_import {
version: 12
}

View File

@ -0,0 +1,209 @@
ir_version: 6
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "in1"
input: "in2"
output: "Func_GreaterOrEqual0x5601898ec4f0O1"
op_type: "Greater"
}
node {
input: "in1"
input: "in2"
output: "Func_GreaterOrEqual0x5601898ec4f0O2"
op_type: "Equal"
}
node {
input: "Func_GreaterOrEqual0x5601898ec4f0O1"
input: "Func_GreaterOrEqual0x5601898ec4f0O2"
output: "greater_or_equal_out"
op_type: "Or"
}
node {
input: "greater_or_equal_out"
output: "cast_out"
op_type: "Cast"
attribute {
name: "to"
i: 6
type: INT
}
}
node {
output: "Func_DynamicQuantizeLinear0x560189b38280Q_Min"
op_type: "Constant"
attribute {
name: "value"
t {
data_type: 1
float_data: 0
}
type: TENSOR
}
}
node {
output: "Func_DynamicQuantizeLinear0x560189b38280Q_Max"
op_type: "Constant"
attribute {
name: "value"
t {
data_type: 1
float_data: 255
}
type: TENSOR
}
}
node {
input: "cast_out"
output: "Func_DynamicQuantizeLinear0x560189b38280X_Min"
op_type: "ReduceMin"
attribute {
name: "keepdims"
i: 0
type: INT
}
}
node {
input: "Func_DynamicQuantizeLinear0x560189b38280X_Min"
input: "Func_DynamicQuantizeLinear0x560189b38280Q_Min"
output: "Func_DynamicQuantizeLinear0x560189b38280X_Min_Adjusted"
op_type: "Min"
}
node {
input: "cast_out"
output: "Func_DynamicQuantizeLinear0x560189b38280X_Max"
op_type: "ReduceMax"
attribute {
name: "keepdims"
i: 0
type: INT
}
}
node {
input: "Func_DynamicQuantizeLinear0x560189b38280X_Max"
input: "Func_DynamicQuantizeLinear0x560189b38280Q_Min"
output: "Func_DynamicQuantizeLinear0x560189b38280X_Max_Adjusted"
op_type: "Max"
}
node {
input: "Func_DynamicQuantizeLinear0x560189b38280X_Max_Adjusted"
input: "Func_DynamicQuantizeLinear0x560189b38280X_Min_Adjusted"
output: "Func_DynamicQuantizeLinear0x560189b38280X_Range"
op_type: "Sub"
}
node {
input: "Func_DynamicQuantizeLinear0x560189b38280X_Range"
input: "Func_DynamicQuantizeLinear0x560189b38280Q_Max"
output: "Func_DynamicQuantizeLinear0x560189b38280Scale"
op_type: "Div"
}
node {
input: "Func_DynamicQuantizeLinear0x560189b38280X_Min_Adjusted"
input: "Func_DynamicQuantizeLinear0x560189b38280Scale"
output: "Func_DynamicQuantizeLinear0x560189b38280Min_Scaled"
op_type: "Div"
}
node {
input: "Func_DynamicQuantizeLinear0x560189b38280Q_Min"
input: "Func_DynamicQuantizeLinear0x560189b38280Min_Scaled"
output: "Func_DynamicQuantizeLinear0x560189b38280Initial_ZeroPoint_FP"
op_type: "Sub"
}
node {
input: "Func_DynamicQuantizeLinear0x560189b38280Initial_ZeroPoint_FP"
input: "Func_DynamicQuantizeLinear0x560189b38280Q_Min"
input: "Func_DynamicQuantizeLinear0x560189b38280Q_Max"
output: "Func_DynamicQuantizeLinear0x560189b38280Clipped_ZeroPoint_FP"
op_type: "Clip"
}
node {
input: "Func_DynamicQuantizeLinear0x560189b38280Clipped_ZeroPoint_FP"
output: "Func_DynamicQuantizeLinear0x560189b38280Rounded_ZeroPoint_FP"
op_type: "Round"
}
node {
input: "Func_DynamicQuantizeLinear0x560189b38280Rounded_ZeroPoint_FP"
output: "Func_DynamicQuantizeLinear0x560189b38280Zeropoint"
op_type: "Cast"
attribute {
name: "to"
i: 2
type: INT
}
}
node {
input: "Func_DynamicQuantizeLinear0x560189b38280Scale"
output: "y_scale"
op_type: "Identity"
}
node {
input: "Func_DynamicQuantizeLinear0x560189b38280Zeropoint"
output: "y_zero_point"
op_type: "Identity"
}
node {
input: "cast_out"
input: "Func_DynamicQuantizeLinear0x560189b38280Scale"
input: "Func_DynamicQuantizeLinear0x560189b38280Zeropoint"
output: "y"
op_type: "QuantizeLinear"
}
node {
input: "y"
output: "abs_y"
op_type: "Abs"
}
input {
name: "in1"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 5
}
}
}
}
}
input {
name: "in2"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 5
}
}
}
}
}
output {
name: "abs_y"
type {
tensor_type {
elem_type: 6
}
}
}
output {
name: "y_scale"
type {
tensor_type {
elem_type: 1
}
}
}
output {
name: "y_zero_point"
type {
tensor_type {
elem_type: 6
}
}
}
}
opset_import {
version: 12
}

View File

@ -0,0 +1,69 @@
ir_version: 6
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "in1"
input: "in2"
output: "Func_GreaterOrEqual0x5562c41eca70O1"
op_type: "Greater"
}
node {
input: "in1"
input: "in2"
output: "Func_GreaterOrEqual0x5562c41eca70O2"
op_type: "Equal"
}
node {
input: "Func_GreaterOrEqual0x5562c41eca70O1"
input: "Func_GreaterOrEqual0x5562c41eca70O2"
output: "greater_or_equal_out"
op_type: "Or"
}
node {
input: "greater_or_equal_out"
output: "cast_out"
op_type: "Cast"
attribute {
name: "to"
i: 6
type: INT
}
}
input {
name: "in1"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 5
}
}
}
}
}
input {
name: "in2"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 5
}
}
}
}
}
output {
name: "cast_out"
type {
tensor_type {
elem_type: 6
}
}
}
}
opset_import {
version: 12
}

View File

@ -0,0 +1,192 @@
ir_version: 7
producer_name: "nGraph ONNX Importer"
graph {
node {
output: "Func_SoftmaxCrossEntropyLoss0x557617acabe0axes"
op_type: "Constant"
attribute {
name: "value"
t {
dims: 1
data_type: 7
int64_data: 1
}
type: TENSOR
}
}
node {
input: "x"
output: "Func_SoftmaxCrossEntropyLoss0x557617acabe0X_Max"
op_type: "ReduceMax"
attribute {
name: "axes"
ints: 1
type: INTS
}
}
node {
input: "x"
input: "Func_SoftmaxCrossEntropyLoss0x557617acabe0X_Max"
output: "Func_SoftmaxCrossEntropyLoss0x557617acabe0X_Sub"
op_type: "Sub"
}
node {
input: "Func_SoftmaxCrossEntropyLoss0x557617acabe0X_Sub"
output: "Func_SoftmaxCrossEntropyLoss0x557617acabe0X_Exp"
op_type: "Exp"
}
node {
input: "Func_SoftmaxCrossEntropyLoss0x557617acabe0X_Exp"
input: "Func_SoftmaxCrossEntropyLoss0x557617acabe0axes"
output: "Func_SoftmaxCrossEntropyLoss0x557617acabe0X_RS"
op_type: "ReduceSum"
}
node {
input: "Func_SoftmaxCrossEntropyLoss0x557617acabe0X_Exp"
input: "Func_SoftmaxCrossEntropyLoss0x557617acabe0X_RS"
output: "Func_SoftmaxCrossEntropyLoss0x557617acabe0X_Div"
op_type: "Div"
}
node {
input: "Func_SoftmaxCrossEntropyLoss0x557617acabe0X_Div"
output: "Func_SoftmaxCrossEntropyLoss0x557617acabe0X_Log"
op_type: "Log"
}
node {
output: "Func_NegativeLogLikelihoodLoss0x557617d1bba0const_zero"
op_type: "Constant"
attribute {
name: "value"
t {
dims: 1
data_type: 6
int32_data: 0
}
type: TENSOR
}
}
node {
output: "Func_NegativeLogLikelihoodLoss0x557617d1bba0const_one"
op_type: "Constant"
attribute {
name: "value"
t {
dims: 1
data_type: 6
int32_data: 1
}
type: TENSOR
}
}
node {
output: "Func_NegativeLogLikelihoodLoss0x557617d1bba0axes"
op_type: "Constant"
attribute {
name: "value"
t {
dims: 1
data_type: 7
int64_data: 1
}
type: TENSOR
}
}
node {
input: "y"
input: "Func_NegativeLogLikelihoodLoss0x557617d1bba0axes"
output: "Func_NegativeLogLikelihoodLoss0x557617d1bba0expanded_target"
op_type: "Unsqueeze"
}
node {
input: "Func_SoftmaxCrossEntropyLoss0x557617acabe0X_Log"
input: "Func_NegativeLogLikelihoodLoss0x557617d1bba0expanded_target"
output: "Func_NegativeLogLikelihoodLoss0x557617d1bba0input_gather_element"
op_type: "GatherElements"
attribute {
name: "axis"
i: 1
type: INT
}
}
node {
input: "Func_NegativeLogLikelihoodLoss0x557617d1bba0input_gather_element"
output: "Func_NegativeLogLikelihoodLoss0x557617d1bba0loss_NCdd"
op_type: "Neg"
}
node {
input: "Func_NegativeLogLikelihoodLoss0x557617d1bba0loss_NCdd"
input: "Func_NegativeLogLikelihoodLoss0x557617d1bba0const_zero"
input: "Func_NegativeLogLikelihoodLoss0x557617d1bba0const_one"
input: "Func_NegativeLogLikelihoodLoss0x557617d1bba0const_one"
output: "Func_NegativeLogLikelihoodLoss0x557617d1bba0loss_N1dd"
op_type: "Slice"
}
node {
input: "Func_NegativeLogLikelihoodLoss0x557617d1bba0loss_N1dd"
input: "Func_NegativeLogLikelihoodLoss0x557617d1bba0axes"
output: "Func_NegativeLogLikelihoodLoss0x557617d1bba0loss_Ndd"
op_type: "Squeeze"
}
node {
input: "Func_NegativeLogLikelihoodLoss0x557617d1bba0loss_Ndd"
output: "z"
op_type: "ReduceMean"
attribute {
name: "keepdims"
i: 0
type: INT
}
}
node {
input: "z"
output: "cast_out"
op_type: "Cast"
attribute {
name: "to"
i: 6
type: INT
}
}
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 5
}
}
}
}
}
input {
name: "y"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 3
}
}
}
}
}
output {
name: "cast_out"
type {
tensor_type {
elem_type: 6
shape {
}
}
}
}
}
opset_import {
version: 13
}

View File

@ -0,0 +1,67 @@
ir_version: 7
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "x"
input: "y"
output: "z"
op_type: "SoftmaxCrossEntropyLoss"
attribute {
name: "reduction"
s: "mean"
type: STRING
}
}
node {
input: "z"
output: "cast_out"
op_type: "Cast"
attribute {
name: "to"
i: 6
type: INT
}
}
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 5
}
}
}
}
}
input {
name: "y"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 3
}
}
}
}
}
output {
name: "cast_out"
type {
tensor_type {
elem_type: 6
shape {
}
}
}
}
}
opset_import {
version: 13
}

View File

@ -393,6 +393,43 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, onnx_expand_function) {
test_case.run();
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, onnx_expand_function_dependency_to_created_subgraph) {
const auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/transformations/greater_or_equal.onnx"));
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_input<float>(Shape{5}, {3.f, 5.f, 3.f, 3.f, 6.f});
test_case.add_input<float>(Shape{5}, {1.f, 4.f, 3.f, 7.f, 8.f});
test_case.add_expected_output<int32_t>(Shape{5}, {1, 1, 1, 0, 0});
test_case.run();
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, onnx_expand_context_dependent_function) {
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/transformations/softmax_crossentropy_consumed.onnx"));
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_input<float>(Shape{3, 5},
{0.54881352186203,
0.7151893377304077,
0.6027633547782898,
0.5448831915855408,
0.42365479469299316,
0.6458941102027893,
0.4375872015953064,
0.891772985458374,
0.9636627435684204,
0.3834415078163147,
0.7917250394821167,
0.5288949012756348,
0.5680445432662964,
0.9255966544151306,
0.07103605568408966});
test_case.add_input<int64_t>(Shape{3}, {1, 4, 3});
test_case.add_expected_output<int32_t>(Shape{}, {1});
test_case.run();
}
// ############################################################################ OPERATOR TESTS
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_addmul_abc) {
auto function = onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/addmul_abc.onnx"));

View File

@ -0,0 +1,99 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "editor.hpp"
#include "gtest/gtest.h"
#include "ngraph/file_util.hpp"
#include "onnx_test_util.hpp"
#include "util/test_control.hpp"
static std::string s_manifest = "${MANIFEST}";
NGRAPH_SUPPRESS_DEPRECATED_START
using namespace ngraph;
using namespace onnx_editor;
using namespace ngraph::test;
namespace {
// Names of input and output names of nodes after a function expanding have names based on a node address.
// As a result, the names are different during each tests execution.
// It requires custom way of input/output names comparison.
// https://github.com/onnx/onnx/blob/767f752829f83dbc9bd0a364d6138890f667fc38/onnx/defs/function.cc#L23
bool after_func_expand_name_comp(std::string lhs, std::string rhs) {
// it is equivalent (simplified) to (0x)?[0-9A-Fa-f]{8,} regex, but GCC 4.8 has limited support
auto cut_hex_address = [](std::string& name) {
auto is_hex_symbol = [](const char s) {
if ((s >= 'a' && s <= 'f') || (s >= 'A' && s <= 'F') || (s >= '0' && s <= '9') ||
(s == 'x')) { // if begin with "0x"
return true;
}
return false;
};
// minimum address length (32 bit platforms)
const auto min_address = 8;
auto cut_begin = -1;
auto cut_length = -1;
auto founded_hex = 0;
for (int i = 0; i < name.size(); ++i) {
if (is_hex_symbol(name[i])) {
++founded_hex;
if (cut_begin == -1) {
cut_begin = i;
}
if (founded_hex >= min_address) {
cut_length = founded_hex;
}
} else if (founded_hex < min_address) {
cut_begin = -1;
cut_length = -1;
founded_hex = 0;
}
}
if (cut_begin > 0 && cut_length > 0) {
return name.erase(cut_begin, cut_length);
}
return name;
};
return cut_hex_address(lhs) == cut_hex_address(rhs);
}
} // namespace
NGRAPH_TEST(onnx_transformations, expand_function_greater_or_equal) {
ONNXModelEditor editor{file_util::path_join(SERIALIZED_ZOO, "onnx/transformations/greater_or_equal.onnx")};
editor.decode(); // onnx transformations are applied
const auto ref_model = file_util::path_join(SERIALIZED_ZOO,
"onnx/transformations/reference/"
"greater_or_equal_expanded.onnx");
const auto result = compare_onnx_models(editor.model_string(), ref_model, after_func_expand_name_comp);
EXPECT_TRUE(result.is_ok) << result.error_message;
}
NGRAPH_TEST(onnx_transformations, expand_function_softmax_crossentropy) {
ONNXModelEditor editor{
file_util::path_join(SERIALIZED_ZOO, "onnx/transformations/softmax_crossentropy_consumed.onnx")};
editor.decode(); // onnx transformations are applied
const auto ref_model = file_util::path_join(SERIALIZED_ZOO,
"onnx/transformations/reference/"
"softmax_crossentropy_consumed_expanded.onnx");
const auto result = compare_onnx_models(editor.model_string(), ref_model, after_func_expand_name_comp);
EXPECT_TRUE(result.is_ok) << result.error_message;
}
NGRAPH_TEST(onnx_transformations, expand_function_dynamic_quantize_linear) {
ONNXModelEditor editor{file_util::path_join(SERIALIZED_ZOO, "onnx/transformations/dynamic_quantize_linear.onnx")};
editor.decode(); // onnx transformations are applied
const auto ref_model = file_util::path_join(SERIALIZED_ZOO,
"onnx/transformations/reference/"
"dynamic_quantize_linear_expanded.onnx");
const auto result = compare_onnx_models(editor.model_string(), ref_model, after_func_expand_name_comp);
EXPECT_TRUE(result.is_ok) << result.error_message;
}

View File

@ -16,7 +16,9 @@ using namespace ngraph;
using namespace ngraph::test;
namespace {
ComparisonResult compare_nodes(const ONNX_NAMESPACE::GraphProto& graph, const ONNX_NAMESPACE::GraphProto& ref_graph) {
ComparisonResult compare_nodes(const ONNX_NAMESPACE::GraphProto& graph,
const ONNX_NAMESPACE::GraphProto& ref_graph,
CompType comp) {
if (graph.node_size() != ref_graph.node_size()) {
return ComparisonResult::fail("The number of nodes in compared models doesn't match");
} else {
@ -30,14 +32,14 @@ ComparisonResult compare_nodes(const ONNX_NAMESPACE::GraphProto& graph, const ON
}
for (int j = 0; j < lhs.input_size(); ++j) {
if (lhs.input(j) != rhs.input(j)) {
if (!comp(lhs.input(j), rhs.input(j))) {
return ComparisonResult::fail("Input names don't match for nodes at index " + std::to_string(i) +
": " + lhs.input(j) + " vs " + rhs.input(j));
}
}
for (int j = 0; j < lhs.output_size(); ++j) {
if (lhs.output(j) != rhs.output(j)) {
if (!comp(lhs.output(j), rhs.output(j))) {
return ComparisonResult::fail("Output names don't match for nodes at index " + std::to_string(i) +
": " + lhs.output(j) + " vs " + rhs.output(j));
}
@ -169,7 +171,8 @@ ComparisonResult compare_initializers(const ONNX_NAMESPACE::GraphProto& graph,
}
ComparisonResult compare_onnx_graphs(const ONNX_NAMESPACE::GraphProto& graph,
const ONNX_NAMESPACE::GraphProto& ref_graph) {
const ONNX_NAMESPACE::GraphProto& ref_graph,
CompType comp = default_name_comparator) {
ComparisonResult comparison = compare_inputs(graph, ref_graph);
if (!comparison.is_ok) {
return comparison;
@ -185,16 +188,21 @@ ComparisonResult compare_onnx_graphs(const ONNX_NAMESPACE::GraphProto& graph,
return comparison;
}
return compare_nodes(graph, ref_graph);
return compare_nodes(graph, ref_graph, comp);
}
} // namespace
namespace ngraph {
namespace test {
ComparisonResult compare_onnx_models(const std::string& model, const std::string& reference_model_path) {
bool default_name_comparator(std::string lhs, std::string rhs) {
return lhs == rhs;
}
ComparisonResult compare_onnx_models(const std::string& model, const std::string& reference_model_path, CompType comp) {
std::stringstream model_stream{model};
const auto model_proto = onnx_common::parse_from_istream(model_stream);
const auto ref_model = onnx_common::parse_from_file(reference_model_path);
return compare_onnx_graphs(model_proto.graph(), ref_model.graph());
return compare_onnx_graphs(model_proto.graph(), ref_model.graph(), comp);
}
} // namespace test
} // namespace ngraph

View File

@ -4,6 +4,7 @@
#pragma once
#include <functional>
#include <string>
namespace ngraph {
@ -27,7 +28,13 @@ struct ComparisonResult {
}
};
ComparisonResult compare_onnx_models(const std::string& model, const std::string& reference_model_path);
bool default_name_comparator(std::string lhs, std::string rhs);
// comp is a function to compare inputs and outputs names (as default it is a usual std::string comparison)
using CompType = std::function<bool(std::string, std::string)>;
ComparisonResult compare_onnx_models(const std::string& model,
const std::string& reference_model_path,
CompType comp = default_name_comparator);
} // namespace test
} // namespace ngraph

View File

@ -75,8 +75,6 @@ xfail_issue_38724 = xfail_test(reason="RuntimeError: While validating ONNX node
"half_pixel")
xfail_issue_38725 = xfail_test(reason="RuntimeError: While validating ONNX node '<Node(Loop): "
"value info has no element type specified")
xfail_issue_38726 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations: "
"LessOrEqual")
xfail_issue_38732 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations: "
"ConvInteger")
xfail_issue_38734 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations: "

View File

@ -16,7 +16,6 @@ from tests import (
xfail_issue_37957,
xfail_issue_38084,
xfail_issue_39669,
xfail_issue_38726,
xfail_issue_37973,
xfail_issue_47430,
xfail_issue_47495,
@ -145,7 +144,6 @@ if len(zoo_models) > 0:
import_xfail_list = [
# ONNX Model Zoo
(xfail_issue_38701, "test_onnx_model_zoo_text_machine_comprehension_bidirectional_attention_flow_model_bidaf_9_bidaf_bidaf_cpu"),
(xfail_issue_38726, "test_onnx_model_zoo_text_machine_comprehension_t5_model_t5_decoder_with_lm_head_12_t5_decoder_with_lm_head_cpu"),
# Model MSFT
(xfail_issue_37957, "test_MSFT_opset10_mask_rcnn_keras_mask_rcnn_keras_cpu"),
@ -161,6 +159,7 @@ if len(zoo_models) > 0:
execution_xfail_list = [
# ONNX Model Zoo
(xfail_issue_39669, "test_onnx_model_zoo_text_machine_comprehension_t5_model_t5_encoder_12_t5_encoder_cpu"),
(xfail_issue_39669, "test_onnx_model_zoo_text_machine_comprehension_t5_model_t5_decoder_with_lm_head_12_t5_decoder_with_lm_head_cpu"),
(xfail_issue_38084, "test_onnx_model_zoo_vision_object_detection_segmentation_mask_rcnn_model_MaskRCNN_10_mask_rcnn_R_50_FPN_1x_cpu"),
(xfail_issue_38084, "test_onnx_model_zoo_vision_object_detection_segmentation_faster_rcnn_model_FasterRCNN_10_faster_rcnn_R_50_FPN_1x_cpu"),
(xfail_issue_47430, "test_onnx_model_zoo_vision_object_detection_segmentation_fcn_model_fcn_resnet50_11_fcn_resnet50_11_model_cpu"),