Fixed expanding ONNX function (#7747)
* fix expand_onnx_functions * refactor + unit test * fixed function in function case * fixed expand_onnx_functions * fixed default value of shape in ValueInfo * enable xpass model * changed MergeFrom to Swap * added xfail with missing test data * added more unit tests * styles applied * used std::rotate, review remarks * removed debug code * after offline discussion remarks * fix checking input/output names on Windows * names comparator refactor * replace regex with custom comparison * review remarks
This commit is contained in:
parent
96df1a14ce
commit
ccdd0e61d5
@ -8,6 +8,8 @@
|
||||
#include <onnx/defs/schema.h>
|
||||
#include <onnx/shape_inference/implementation.h>
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "core/model.hpp"
|
||||
#include "ngraph/file_util.hpp"
|
||||
#include "ngraph/log.hpp"
|
||||
@ -30,6 +32,24 @@ ONNX_NAMESPACE::TypeProto get_input_type(std::string const& name, ONNX_NAMESPACE
|
||||
}
|
||||
return ONNX_NAMESPACE::TypeProto();
|
||||
}
|
||||
|
||||
inline void function_expand_and_remove_original_node(const ONNX_NAMESPACE::NodeProto& node,
|
||||
const ONNX_NAMESPACE::FunctionProto& func_proto,
|
||||
ONNX_NAMESPACE::GraphProto* graph,
|
||||
int current_node_idx) {
|
||||
const auto before_expand_size = graph->node().size();
|
||||
ONNX_NAMESPACE::FunctionExpandHelper(node, func_proto, *graph);
|
||||
const auto added_nodes = graph->node().size() - before_expand_size;
|
||||
|
||||
// Remove the original node which contained the function
|
||||
graph->mutable_node()->erase(graph->mutable_node()->begin() + current_node_idx);
|
||||
|
||||
// Move nodes from expanded function to position of removed node
|
||||
std::rotate(graph->mutable_node()->begin() + current_node_idx,
|
||||
graph->mutable_node()->end() - added_nodes,
|
||||
graph->mutable_node()->end());
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
} // namespace transform
|
||||
} // namespace onnx_import
|
||||
@ -60,10 +80,8 @@ void ngraph::onnx_import::transform::expand_onnx_functions(ONNX_NAMESPACE::Model
|
||||
// Check if operation schema contains a function body and expand function
|
||||
if (node_op_schema->HasFunction()) {
|
||||
const auto* func_proto = node_op_schema->GetFunction();
|
||||
ONNX_NAMESPACE::FunctionExpandHelper(node, *func_proto, *graph_proto);
|
||||
|
||||
// Remove the original node which contained the function.
|
||||
graph_proto->mutable_node()->erase(graph_proto->mutable_node()->begin() + i);
|
||||
// Move index to the previous position because a first node of expanded function can have also function
|
||||
detail::function_expand_and_remove_original_node(node, *func_proto, graph_proto, i--);
|
||||
}
|
||||
|
||||
else if (node_op_schema->HasContextDependentFunction()) {
|
||||
@ -82,10 +100,8 @@ void ngraph::onnx_import::transform::expand_onnx_functions(ONNX_NAMESPACE::Model
|
||||
ONNX_NAMESPACE::FunctionBodyBuildContextImpl ctx(node, input_types);
|
||||
ONNX_NAMESPACE::FunctionProto func_proto;
|
||||
node_op_schema->BuildContextDependentFunction(ctx, func_proto);
|
||||
ONNX_NAMESPACE::FunctionExpandHelper(node, func_proto, *graph_proto);
|
||||
|
||||
// Remove the original node which contained the function.
|
||||
graph_proto->mutable_node()->erase(graph_proto->mutable_node()->begin() + i);
|
||||
// Move index to the previous position because a first node of expanded function can have also function
|
||||
detail::function_expand_and_remove_original_node(node, func_proto, graph_proto, i--);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -30,8 +30,6 @@ public:
|
||||
|
||||
if (onnx_tensor.has_shape()) {
|
||||
m_partial_shape = onnx_common::to_ng_shape(onnx_tensor.shape());
|
||||
} else {
|
||||
m_partial_shape = PartialShape::dynamic();
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -76,7 +74,7 @@ protected:
|
||||
|
||||
private:
|
||||
const ONNX_NAMESPACE::ValueInfoProto* m_value_info_proto;
|
||||
PartialShape m_partial_shape;
|
||||
PartialShape m_partial_shape = PartialShape::dynamic();
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& outs, const ValueInfo& info) {
|
||||
|
@ -563,7 +563,8 @@ if (NGRAPH_ONNX_FRONTEND_ENABLE)
|
||||
list(APPEND SRC
|
||||
onnx/onnx_import_exceptions.cpp
|
||||
onnx/onnx_import_library.cpp
|
||||
onnx/onnx_tensor_names.cpp)
|
||||
onnx/onnx_tensor_names.cpp
|
||||
onnx/onnx_transformations.cpp)
|
||||
endif()
|
||||
|
||||
if (NGRAPH_ONNX_FRONTEND_ENABLE)
|
||||
|
@ -0,0 +1,85 @@
|
||||
ir_version: 6
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
node {
|
||||
input: "in1"
|
||||
input: "in2"
|
||||
output: "greater_or_equal_out"
|
||||
op_type: "GreaterOrEqual"
|
||||
}
|
||||
node {
|
||||
input: "greater_or_equal_out"
|
||||
output: "cast_out"
|
||||
op_type: "Cast"
|
||||
attribute {
|
||||
name: "to"
|
||||
i: 6
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "cast_out"
|
||||
output: "y"
|
||||
output: "y_scale"
|
||||
output: "y_zero_point"
|
||||
op_type: "DynamicQuantizeLinear"
|
||||
}
|
||||
node {
|
||||
input: "y"
|
||||
output: "abs_y"
|
||||
op_type: "Abs"
|
||||
}
|
||||
input {
|
||||
name: "in1"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 5
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "in2"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 5
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "abs_y"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 6
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "y_scale"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "y_zero_point"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 6
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 12
|
||||
}
|
@ -0,0 +1,57 @@
|
||||
ir_version: 6
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
node {
|
||||
input: "in1"
|
||||
input: "in2"
|
||||
output: "greater_or_equal_out"
|
||||
op_type: "GreaterOrEqual"
|
||||
}
|
||||
node {
|
||||
input: "greater_or_equal_out"
|
||||
output: "cast_out"
|
||||
op_type: "Cast"
|
||||
attribute {
|
||||
name: "to"
|
||||
i: 6
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "in1"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 5
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "in2"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 5
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "cast_out"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 6
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 12
|
||||
}
|
@ -0,0 +1,209 @@
|
||||
ir_version: 6
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
node {
|
||||
input: "in1"
|
||||
input: "in2"
|
||||
output: "Func_GreaterOrEqual0x5601898ec4f0O1"
|
||||
op_type: "Greater"
|
||||
}
|
||||
node {
|
||||
input: "in1"
|
||||
input: "in2"
|
||||
output: "Func_GreaterOrEqual0x5601898ec4f0O2"
|
||||
op_type: "Equal"
|
||||
}
|
||||
node {
|
||||
input: "Func_GreaterOrEqual0x5601898ec4f0O1"
|
||||
input: "Func_GreaterOrEqual0x5601898ec4f0O2"
|
||||
output: "greater_or_equal_out"
|
||||
op_type: "Or"
|
||||
}
|
||||
node {
|
||||
input: "greater_or_equal_out"
|
||||
output: "cast_out"
|
||||
op_type: "Cast"
|
||||
attribute {
|
||||
name: "to"
|
||||
i: 6
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
node {
|
||||
output: "Func_DynamicQuantizeLinear0x560189b38280Q_Min"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
t {
|
||||
data_type: 1
|
||||
float_data: 0
|
||||
}
|
||||
type: TENSOR
|
||||
}
|
||||
}
|
||||
node {
|
||||
output: "Func_DynamicQuantizeLinear0x560189b38280Q_Max"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
t {
|
||||
data_type: 1
|
||||
float_data: 255
|
||||
}
|
||||
type: TENSOR
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "cast_out"
|
||||
output: "Func_DynamicQuantizeLinear0x560189b38280X_Min"
|
||||
op_type: "ReduceMin"
|
||||
attribute {
|
||||
name: "keepdims"
|
||||
i: 0
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "Func_DynamicQuantizeLinear0x560189b38280X_Min"
|
||||
input: "Func_DynamicQuantizeLinear0x560189b38280Q_Min"
|
||||
output: "Func_DynamicQuantizeLinear0x560189b38280X_Min_Adjusted"
|
||||
op_type: "Min"
|
||||
}
|
||||
node {
|
||||
input: "cast_out"
|
||||
output: "Func_DynamicQuantizeLinear0x560189b38280X_Max"
|
||||
op_type: "ReduceMax"
|
||||
attribute {
|
||||
name: "keepdims"
|
||||
i: 0
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "Func_DynamicQuantizeLinear0x560189b38280X_Max"
|
||||
input: "Func_DynamicQuantizeLinear0x560189b38280Q_Min"
|
||||
output: "Func_DynamicQuantizeLinear0x560189b38280X_Max_Adjusted"
|
||||
op_type: "Max"
|
||||
}
|
||||
node {
|
||||
input: "Func_DynamicQuantizeLinear0x560189b38280X_Max_Adjusted"
|
||||
input: "Func_DynamicQuantizeLinear0x560189b38280X_Min_Adjusted"
|
||||
output: "Func_DynamicQuantizeLinear0x560189b38280X_Range"
|
||||
op_type: "Sub"
|
||||
}
|
||||
node {
|
||||
input: "Func_DynamicQuantizeLinear0x560189b38280X_Range"
|
||||
input: "Func_DynamicQuantizeLinear0x560189b38280Q_Max"
|
||||
output: "Func_DynamicQuantizeLinear0x560189b38280Scale"
|
||||
op_type: "Div"
|
||||
}
|
||||
node {
|
||||
input: "Func_DynamicQuantizeLinear0x560189b38280X_Min_Adjusted"
|
||||
input: "Func_DynamicQuantizeLinear0x560189b38280Scale"
|
||||
output: "Func_DynamicQuantizeLinear0x560189b38280Min_Scaled"
|
||||
op_type: "Div"
|
||||
}
|
||||
node {
|
||||
input: "Func_DynamicQuantizeLinear0x560189b38280Q_Min"
|
||||
input: "Func_DynamicQuantizeLinear0x560189b38280Min_Scaled"
|
||||
output: "Func_DynamicQuantizeLinear0x560189b38280Initial_ZeroPoint_FP"
|
||||
op_type: "Sub"
|
||||
}
|
||||
node {
|
||||
input: "Func_DynamicQuantizeLinear0x560189b38280Initial_ZeroPoint_FP"
|
||||
input: "Func_DynamicQuantizeLinear0x560189b38280Q_Min"
|
||||
input: "Func_DynamicQuantizeLinear0x560189b38280Q_Max"
|
||||
output: "Func_DynamicQuantizeLinear0x560189b38280Clipped_ZeroPoint_FP"
|
||||
op_type: "Clip"
|
||||
}
|
||||
node {
|
||||
input: "Func_DynamicQuantizeLinear0x560189b38280Clipped_ZeroPoint_FP"
|
||||
output: "Func_DynamicQuantizeLinear0x560189b38280Rounded_ZeroPoint_FP"
|
||||
op_type: "Round"
|
||||
}
|
||||
node {
|
||||
input: "Func_DynamicQuantizeLinear0x560189b38280Rounded_ZeroPoint_FP"
|
||||
output: "Func_DynamicQuantizeLinear0x560189b38280Zeropoint"
|
||||
op_type: "Cast"
|
||||
attribute {
|
||||
name: "to"
|
||||
i: 2
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "Func_DynamicQuantizeLinear0x560189b38280Scale"
|
||||
output: "y_scale"
|
||||
op_type: "Identity"
|
||||
}
|
||||
node {
|
||||
input: "Func_DynamicQuantizeLinear0x560189b38280Zeropoint"
|
||||
output: "y_zero_point"
|
||||
op_type: "Identity"
|
||||
}
|
||||
node {
|
||||
input: "cast_out"
|
||||
input: "Func_DynamicQuantizeLinear0x560189b38280Scale"
|
||||
input: "Func_DynamicQuantizeLinear0x560189b38280Zeropoint"
|
||||
output: "y"
|
||||
op_type: "QuantizeLinear"
|
||||
}
|
||||
node {
|
||||
input: "y"
|
||||
output: "abs_y"
|
||||
op_type: "Abs"
|
||||
}
|
||||
input {
|
||||
name: "in1"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 5
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "in2"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 5
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "abs_y"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 6
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "y_scale"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "y_zero_point"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 6
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 12
|
||||
}
|
@ -0,0 +1,69 @@
|
||||
ir_version: 6
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
node {
|
||||
input: "in1"
|
||||
input: "in2"
|
||||
output: "Func_GreaterOrEqual0x5562c41eca70O1"
|
||||
op_type: "Greater"
|
||||
}
|
||||
node {
|
||||
input: "in1"
|
||||
input: "in2"
|
||||
output: "Func_GreaterOrEqual0x5562c41eca70O2"
|
||||
op_type: "Equal"
|
||||
}
|
||||
node {
|
||||
input: "Func_GreaterOrEqual0x5562c41eca70O1"
|
||||
input: "Func_GreaterOrEqual0x5562c41eca70O2"
|
||||
output: "greater_or_equal_out"
|
||||
op_type: "Or"
|
||||
}
|
||||
node {
|
||||
input: "greater_or_equal_out"
|
||||
output: "cast_out"
|
||||
op_type: "Cast"
|
||||
attribute {
|
||||
name: "to"
|
||||
i: 6
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "in1"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 5
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "in2"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 5
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "cast_out"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 6
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 12
|
||||
}
|
@ -0,0 +1,192 @@
|
||||
ir_version: 7
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
node {
|
||||
output: "Func_SoftmaxCrossEntropyLoss0x557617acabe0axes"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
t {
|
||||
dims: 1
|
||||
data_type: 7
|
||||
int64_data: 1
|
||||
}
|
||||
type: TENSOR
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "x"
|
||||
output: "Func_SoftmaxCrossEntropyLoss0x557617acabe0X_Max"
|
||||
op_type: "ReduceMax"
|
||||
attribute {
|
||||
name: "axes"
|
||||
ints: 1
|
||||
type: INTS
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "x"
|
||||
input: "Func_SoftmaxCrossEntropyLoss0x557617acabe0X_Max"
|
||||
output: "Func_SoftmaxCrossEntropyLoss0x557617acabe0X_Sub"
|
||||
op_type: "Sub"
|
||||
}
|
||||
node {
|
||||
input: "Func_SoftmaxCrossEntropyLoss0x557617acabe0X_Sub"
|
||||
output: "Func_SoftmaxCrossEntropyLoss0x557617acabe0X_Exp"
|
||||
op_type: "Exp"
|
||||
}
|
||||
node {
|
||||
input: "Func_SoftmaxCrossEntropyLoss0x557617acabe0X_Exp"
|
||||
input: "Func_SoftmaxCrossEntropyLoss0x557617acabe0axes"
|
||||
output: "Func_SoftmaxCrossEntropyLoss0x557617acabe0X_RS"
|
||||
op_type: "ReduceSum"
|
||||
}
|
||||
node {
|
||||
input: "Func_SoftmaxCrossEntropyLoss0x557617acabe0X_Exp"
|
||||
input: "Func_SoftmaxCrossEntropyLoss0x557617acabe0X_RS"
|
||||
output: "Func_SoftmaxCrossEntropyLoss0x557617acabe0X_Div"
|
||||
op_type: "Div"
|
||||
}
|
||||
node {
|
||||
input: "Func_SoftmaxCrossEntropyLoss0x557617acabe0X_Div"
|
||||
output: "Func_SoftmaxCrossEntropyLoss0x557617acabe0X_Log"
|
||||
op_type: "Log"
|
||||
}
|
||||
node {
|
||||
output: "Func_NegativeLogLikelihoodLoss0x557617d1bba0const_zero"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
t {
|
||||
dims: 1
|
||||
data_type: 6
|
||||
int32_data: 0
|
||||
}
|
||||
type: TENSOR
|
||||
}
|
||||
}
|
||||
node {
|
||||
output: "Func_NegativeLogLikelihoodLoss0x557617d1bba0const_one"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
t {
|
||||
dims: 1
|
||||
data_type: 6
|
||||
int32_data: 1
|
||||
}
|
||||
type: TENSOR
|
||||
}
|
||||
}
|
||||
node {
|
||||
output: "Func_NegativeLogLikelihoodLoss0x557617d1bba0axes"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
t {
|
||||
dims: 1
|
||||
data_type: 7
|
||||
int64_data: 1
|
||||
}
|
||||
type: TENSOR
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "y"
|
||||
input: "Func_NegativeLogLikelihoodLoss0x557617d1bba0axes"
|
||||
output: "Func_NegativeLogLikelihoodLoss0x557617d1bba0expanded_target"
|
||||
op_type: "Unsqueeze"
|
||||
}
|
||||
node {
|
||||
input: "Func_SoftmaxCrossEntropyLoss0x557617acabe0X_Log"
|
||||
input: "Func_NegativeLogLikelihoodLoss0x557617d1bba0expanded_target"
|
||||
output: "Func_NegativeLogLikelihoodLoss0x557617d1bba0input_gather_element"
|
||||
op_type: "GatherElements"
|
||||
attribute {
|
||||
name: "axis"
|
||||
i: 1
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "Func_NegativeLogLikelihoodLoss0x557617d1bba0input_gather_element"
|
||||
output: "Func_NegativeLogLikelihoodLoss0x557617d1bba0loss_NCdd"
|
||||
op_type: "Neg"
|
||||
}
|
||||
node {
|
||||
input: "Func_NegativeLogLikelihoodLoss0x557617d1bba0loss_NCdd"
|
||||
input: "Func_NegativeLogLikelihoodLoss0x557617d1bba0const_zero"
|
||||
input: "Func_NegativeLogLikelihoodLoss0x557617d1bba0const_one"
|
||||
input: "Func_NegativeLogLikelihoodLoss0x557617d1bba0const_one"
|
||||
output: "Func_NegativeLogLikelihoodLoss0x557617d1bba0loss_N1dd"
|
||||
op_type: "Slice"
|
||||
}
|
||||
node {
|
||||
input: "Func_NegativeLogLikelihoodLoss0x557617d1bba0loss_N1dd"
|
||||
input: "Func_NegativeLogLikelihoodLoss0x557617d1bba0axes"
|
||||
output: "Func_NegativeLogLikelihoodLoss0x557617d1bba0loss_Ndd"
|
||||
op_type: "Squeeze"
|
||||
}
|
||||
node {
|
||||
input: "Func_NegativeLogLikelihoodLoss0x557617d1bba0loss_Ndd"
|
||||
output: "z"
|
||||
op_type: "ReduceMean"
|
||||
attribute {
|
||||
name: "keepdims"
|
||||
i: 0
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "z"
|
||||
output: "cast_out"
|
||||
op_type: "Cast"
|
||||
attribute {
|
||||
name: "to"
|
||||
i: 6
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "x"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
dim {
|
||||
dim_value: 5
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "y"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 7
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "cast_out"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 6
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 13
|
||||
}
|
@ -0,0 +1,67 @@
|
||||
ir_version: 7
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
node {
|
||||
input: "x"
|
||||
input: "y"
|
||||
output: "z"
|
||||
op_type: "SoftmaxCrossEntropyLoss"
|
||||
attribute {
|
||||
name: "reduction"
|
||||
s: "mean"
|
||||
type: STRING
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "z"
|
||||
output: "cast_out"
|
||||
op_type: "Cast"
|
||||
attribute {
|
||||
name: "to"
|
||||
i: 6
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "x"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
dim {
|
||||
dim_value: 5
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "y"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 7
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "cast_out"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 6
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 13
|
||||
}
|
@ -393,6 +393,43 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, onnx_expand_function) {
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_${BACKEND_NAME}, onnx_expand_function_dependency_to_created_subgraph) {
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/transformations/greater_or_equal.onnx"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
test_case.add_input<float>(Shape{5}, {3.f, 5.f, 3.f, 3.f, 6.f});
|
||||
test_case.add_input<float>(Shape{5}, {1.f, 4.f, 3.f, 7.f, 8.f});
|
||||
test_case.add_expected_output<int32_t>(Shape{5}, {1, 1, 1, 0, 0});
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_${BACKEND_NAME}, onnx_expand_context_dependent_function) {
|
||||
auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/transformations/softmax_crossentropy_consumed.onnx"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
test_case.add_input<float>(Shape{3, 5},
|
||||
{0.54881352186203,
|
||||
0.7151893377304077,
|
||||
0.6027633547782898,
|
||||
0.5448831915855408,
|
||||
0.42365479469299316,
|
||||
0.6458941102027893,
|
||||
0.4375872015953064,
|
||||
0.891772985458374,
|
||||
0.9636627435684204,
|
||||
0.3834415078163147,
|
||||
0.7917250394821167,
|
||||
0.5288949012756348,
|
||||
0.5680445432662964,
|
||||
0.9255966544151306,
|
||||
0.07103605568408966});
|
||||
test_case.add_input<int64_t>(Shape{3}, {1, 4, 3});
|
||||
test_case.add_expected_output<int32_t>(Shape{}, {1});
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
// ############################################################################ OPERATOR TESTS
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_addmul_abc) {
|
||||
auto function = onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/addmul_abc.onnx"));
|
||||
|
99
ngraph/test/onnx/onnx_transformations.cpp
Normal file
99
ngraph/test/onnx/onnx_transformations.cpp
Normal file
@ -0,0 +1,99 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "editor.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
#include "ngraph/file_util.hpp"
|
||||
#include "onnx_test_util.hpp"
|
||||
#include "util/test_control.hpp"
|
||||
|
||||
static std::string s_manifest = "${MANIFEST}";
|
||||
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
|
||||
using namespace ngraph;
|
||||
using namespace onnx_editor;
|
||||
using namespace ngraph::test;
|
||||
|
||||
namespace {
|
||||
// Names of input and output names of nodes after a function expanding have names based on a node address.
|
||||
// As a result, the names are different during each tests execution.
|
||||
// It requires custom way of input/output names comparison.
|
||||
// https://github.com/onnx/onnx/blob/767f752829f83dbc9bd0a364d6138890f667fc38/onnx/defs/function.cc#L23
|
||||
bool after_func_expand_name_comp(std::string lhs, std::string rhs) {
|
||||
// it is equivalent (simplified) to (0x)?[0-9A-Fa-f]{8,} regex, but GCC 4.8 has limited support
|
||||
auto cut_hex_address = [](std::string& name) {
|
||||
auto is_hex_symbol = [](const char s) {
|
||||
if ((s >= 'a' && s <= 'f') || (s >= 'A' && s <= 'F') || (s >= '0' && s <= '9') ||
|
||||
(s == 'x')) { // if begin with "0x"
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
// minimum address length (32 bit platforms)
|
||||
const auto min_address = 8;
|
||||
auto cut_begin = -1;
|
||||
auto cut_length = -1;
|
||||
|
||||
auto founded_hex = 0;
|
||||
for (int i = 0; i < name.size(); ++i) {
|
||||
if (is_hex_symbol(name[i])) {
|
||||
++founded_hex;
|
||||
if (cut_begin == -1) {
|
||||
cut_begin = i;
|
||||
}
|
||||
if (founded_hex >= min_address) {
|
||||
cut_length = founded_hex;
|
||||
}
|
||||
} else if (founded_hex < min_address) {
|
||||
cut_begin = -1;
|
||||
cut_length = -1;
|
||||
founded_hex = 0;
|
||||
}
|
||||
}
|
||||
if (cut_begin > 0 && cut_length > 0) {
|
||||
return name.erase(cut_begin, cut_length);
|
||||
}
|
||||
return name;
|
||||
};
|
||||
return cut_hex_address(lhs) == cut_hex_address(rhs);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
NGRAPH_TEST(onnx_transformations, expand_function_greater_or_equal) {
|
||||
ONNXModelEditor editor{file_util::path_join(SERIALIZED_ZOO, "onnx/transformations/greater_or_equal.onnx")};
|
||||
editor.decode(); // onnx transformations are applied
|
||||
|
||||
const auto ref_model = file_util::path_join(SERIALIZED_ZOO,
|
||||
"onnx/transformations/reference/"
|
||||
"greater_or_equal_expanded.onnx");
|
||||
|
||||
const auto result = compare_onnx_models(editor.model_string(), ref_model, after_func_expand_name_comp);
|
||||
EXPECT_TRUE(result.is_ok) << result.error_message;
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_transformations, expand_function_softmax_crossentropy) {
|
||||
ONNXModelEditor editor{
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/transformations/softmax_crossentropy_consumed.onnx")};
|
||||
editor.decode(); // onnx transformations are applied
|
||||
|
||||
const auto ref_model = file_util::path_join(SERIALIZED_ZOO,
|
||||
"onnx/transformations/reference/"
|
||||
"softmax_crossentropy_consumed_expanded.onnx");
|
||||
|
||||
const auto result = compare_onnx_models(editor.model_string(), ref_model, after_func_expand_name_comp);
|
||||
EXPECT_TRUE(result.is_ok) << result.error_message;
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_transformations, expand_function_dynamic_quantize_linear) {
|
||||
ONNXModelEditor editor{file_util::path_join(SERIALIZED_ZOO, "onnx/transformations/dynamic_quantize_linear.onnx")};
|
||||
editor.decode(); // onnx transformations are applied
|
||||
|
||||
const auto ref_model = file_util::path_join(SERIALIZED_ZOO,
|
||||
"onnx/transformations/reference/"
|
||||
"dynamic_quantize_linear_expanded.onnx");
|
||||
|
||||
const auto result = compare_onnx_models(editor.model_string(), ref_model, after_func_expand_name_comp);
|
||||
EXPECT_TRUE(result.is_ok) << result.error_message;
|
||||
}
|
@ -16,7 +16,9 @@ using namespace ngraph;
|
||||
using namespace ngraph::test;
|
||||
|
||||
namespace {
|
||||
ComparisonResult compare_nodes(const ONNX_NAMESPACE::GraphProto& graph, const ONNX_NAMESPACE::GraphProto& ref_graph) {
|
||||
ComparisonResult compare_nodes(const ONNX_NAMESPACE::GraphProto& graph,
|
||||
const ONNX_NAMESPACE::GraphProto& ref_graph,
|
||||
CompType comp) {
|
||||
if (graph.node_size() != ref_graph.node_size()) {
|
||||
return ComparisonResult::fail("The number of nodes in compared models doesn't match");
|
||||
} else {
|
||||
@ -30,14 +32,14 @@ ComparisonResult compare_nodes(const ONNX_NAMESPACE::GraphProto& graph, const ON
|
||||
}
|
||||
|
||||
for (int j = 0; j < lhs.input_size(); ++j) {
|
||||
if (lhs.input(j) != rhs.input(j)) {
|
||||
if (!comp(lhs.input(j), rhs.input(j))) {
|
||||
return ComparisonResult::fail("Input names don't match for nodes at index " + std::to_string(i) +
|
||||
": " + lhs.input(j) + " vs " + rhs.input(j));
|
||||
}
|
||||
}
|
||||
|
||||
for (int j = 0; j < lhs.output_size(); ++j) {
|
||||
if (lhs.output(j) != rhs.output(j)) {
|
||||
if (!comp(lhs.output(j), rhs.output(j))) {
|
||||
return ComparisonResult::fail("Output names don't match for nodes at index " + std::to_string(i) +
|
||||
": " + lhs.output(j) + " vs " + rhs.output(j));
|
||||
}
|
||||
@ -169,7 +171,8 @@ ComparisonResult compare_initializers(const ONNX_NAMESPACE::GraphProto& graph,
|
||||
}
|
||||
|
||||
ComparisonResult compare_onnx_graphs(const ONNX_NAMESPACE::GraphProto& graph,
|
||||
const ONNX_NAMESPACE::GraphProto& ref_graph) {
|
||||
const ONNX_NAMESPACE::GraphProto& ref_graph,
|
||||
CompType comp = default_name_comparator) {
|
||||
ComparisonResult comparison = compare_inputs(graph, ref_graph);
|
||||
if (!comparison.is_ok) {
|
||||
return comparison;
|
||||
@ -185,16 +188,21 @@ ComparisonResult compare_onnx_graphs(const ONNX_NAMESPACE::GraphProto& graph,
|
||||
return comparison;
|
||||
}
|
||||
|
||||
return compare_nodes(graph, ref_graph);
|
||||
return compare_nodes(graph, ref_graph, comp);
|
||||
}
|
||||
} // namespace
|
||||
namespace ngraph {
|
||||
namespace test {
|
||||
ComparisonResult compare_onnx_models(const std::string& model, const std::string& reference_model_path) {
|
||||
|
||||
bool default_name_comparator(std::string lhs, std::string rhs) {
|
||||
return lhs == rhs;
|
||||
}
|
||||
|
||||
ComparisonResult compare_onnx_models(const std::string& model, const std::string& reference_model_path, CompType comp) {
|
||||
std::stringstream model_stream{model};
|
||||
const auto model_proto = onnx_common::parse_from_istream(model_stream);
|
||||
const auto ref_model = onnx_common::parse_from_file(reference_model_path);
|
||||
return compare_onnx_graphs(model_proto.graph(), ref_model.graph());
|
||||
return compare_onnx_graphs(model_proto.graph(), ref_model.graph(), comp);
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace ngraph
|
||||
|
@ -4,6 +4,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <functional>
|
||||
#include <string>
|
||||
|
||||
namespace ngraph {
|
||||
@ -27,7 +28,13 @@ struct ComparisonResult {
|
||||
}
|
||||
};
|
||||
|
||||
ComparisonResult compare_onnx_models(const std::string& model, const std::string& reference_model_path);
|
||||
bool default_name_comparator(std::string lhs, std::string rhs);
|
||||
|
||||
// comp is a function to compare inputs and outputs names (as default it is a usual std::string comparison)
|
||||
using CompType = std::function<bool(std::string, std::string)>;
|
||||
ComparisonResult compare_onnx_models(const std::string& model,
|
||||
const std::string& reference_model_path,
|
||||
CompType comp = default_name_comparator);
|
||||
|
||||
} // namespace test
|
||||
} // namespace ngraph
|
||||
|
@ -75,8 +75,6 @@ xfail_issue_38724 = xfail_test(reason="RuntimeError: While validating ONNX node
|
||||
"half_pixel")
|
||||
xfail_issue_38725 = xfail_test(reason="RuntimeError: While validating ONNX node '<Node(Loop): "
|
||||
"value info has no element type specified")
|
||||
xfail_issue_38726 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations: "
|
||||
"LessOrEqual")
|
||||
xfail_issue_38732 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations: "
|
||||
"ConvInteger")
|
||||
xfail_issue_38734 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations: "
|
||||
|
@ -16,7 +16,6 @@ from tests import (
|
||||
xfail_issue_37957,
|
||||
xfail_issue_38084,
|
||||
xfail_issue_39669,
|
||||
xfail_issue_38726,
|
||||
xfail_issue_37973,
|
||||
xfail_issue_47430,
|
||||
xfail_issue_47495,
|
||||
@ -145,7 +144,6 @@ if len(zoo_models) > 0:
|
||||
import_xfail_list = [
|
||||
# ONNX Model Zoo
|
||||
(xfail_issue_38701, "test_onnx_model_zoo_text_machine_comprehension_bidirectional_attention_flow_model_bidaf_9_bidaf_bidaf_cpu"),
|
||||
(xfail_issue_38726, "test_onnx_model_zoo_text_machine_comprehension_t5_model_t5_decoder_with_lm_head_12_t5_decoder_with_lm_head_cpu"),
|
||||
|
||||
# Model MSFT
|
||||
(xfail_issue_37957, "test_MSFT_opset10_mask_rcnn_keras_mask_rcnn_keras_cpu"),
|
||||
@ -161,6 +159,7 @@ if len(zoo_models) > 0:
|
||||
execution_xfail_list = [
|
||||
# ONNX Model Zoo
|
||||
(xfail_issue_39669, "test_onnx_model_zoo_text_machine_comprehension_t5_model_t5_encoder_12_t5_encoder_cpu"),
|
||||
(xfail_issue_39669, "test_onnx_model_zoo_text_machine_comprehension_t5_model_t5_decoder_with_lm_head_12_t5_decoder_with_lm_head_cpu"),
|
||||
(xfail_issue_38084, "test_onnx_model_zoo_vision_object_detection_segmentation_mask_rcnn_model_MaskRCNN_10_mask_rcnn_R_50_FPN_1x_cpu"),
|
||||
(xfail_issue_38084, "test_onnx_model_zoo_vision_object_detection_segmentation_faster_rcnn_model_FasterRCNN_10_faster_rcnn_R_50_FPN_1x_cpu"),
|
||||
(xfail_issue_47430, "test_onnx_model_zoo_vision_object_detection_segmentation_fcn_model_fcn_resnet50_11_fcn_resnet50_11_model_cpu"),
|
||||
|
Loading…
Reference in New Issue
Block a user