From 29144d3a6b47577d8409b6db4065ae16479686c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tomasz=20Do=C5=82bniak?= Date: Tue, 15 Mar 2022 12:21:30 +0100 Subject: [PATCH] ONNX Expand extended support (#10833) --- .../models/onnx/expand_failsafe_node.prototxt | 52 +++++++++++++++++++ src/core/tests/onnx/onnx_import.in.cpp | 12 +++++ .../onnx/frontend/src/core/graph.cpp | 2 +- .../onnx/frontend/src/op/constant.cpp | 2 +- src/frontends/onnx/frontend/src/op/expand.cpp | 13 ++++- .../onnx/frontend/src/utils/common.cpp | 14 +++++ .../onnx/frontend/src/utils/common.hpp | 10 ++++ 7 files changed, 102 insertions(+), 3 deletions(-) create mode 100644 src/core/tests/models/onnx/expand_failsafe_node.prototxt diff --git a/src/core/tests/models/onnx/expand_failsafe_node.prototxt b/src/core/tests/models/onnx/expand_failsafe_node.prototxt new file mode 100644 index 00000000000..3a0c9de1d2e --- /dev/null +++ b/src/core/tests/models/onnx/expand_failsafe_node.prototxt @@ -0,0 +1,52 @@ +ir_version: 8 +producer_name: "nGraph ONNX Importer" +graph { + node { + input: "data" + input: "shape" + output: "expanded" + op_type: "Expand" + } + name: "test_graph" + initializer { + dims: 0 + data_type: 7 + name: "shape" + raw_data: "" + } + input { + name: "data" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + } + } + } + } + output { + name: "expanded" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + } + } + } + } +} +opset_import { + version: 13 +} diff --git a/src/core/tests/onnx/onnx_import.in.cpp b/src/core/tests/onnx/onnx_import.in.cpp index fd5f04c36d7..e9765f9d91d 100644 --- a/src/core/tests/onnx/onnx_import.in.cpp +++ b/src/core/tests/onnx/onnx_import.in.cpp @@ -4736,3 +4736,15 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_unsqueeze_ai_onnx_domain_opset13) { test_case.add_expected_output(expected_output); test_case.run(); } + +NGRAPH_TEST(${BACKEND_NAME}, onnx_model_expand_failsafe_node) { + const auto function = + onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/expand_failsafe_node.onnx")); + + auto test_case = test::TestCase(function, s_device); + const auto input_data = std::vector{1.0f, 2.0f, 3.0f, 4.0f}; + test_case.add_input(input_data); + // the target shape is an empty constant so the Expand operation should not modify the input shape + test_case.add_expected_output(input_data); + test_case.run(); +} diff --git a/src/frontends/onnx/frontend/src/core/graph.cpp b/src/frontends/onnx/frontend/src/core/graph.cpp index e1fa9a4c032..c7971d840d2 100644 --- a/src/frontends/onnx/frontend/src/core/graph.cpp +++ b/src/frontends/onnx/frontend/src/core/graph.cpp @@ -76,7 +76,7 @@ Graph::Graph(const std::shared_ptr& model_proto, // invalid external data makes initializers creation impossible throw; } catch (const ngraph::ngraph_error&) { - ng_constant = default_opset::Constant::create(tensor.get_ng_type(), Shape{}, {0}); + ng_constant = ngraph::onnx_import::common::make_failsafe_constant(tensor.get_ng_type()); } initializers.emplace(initializer_tensor.name(), tensor); diff --git a/src/frontends/onnx/frontend/src/op/constant.cpp b/src/frontends/onnx/frontend/src/op/constant.cpp index e5993dfe0cb..d5ae5144d30 100644 --- a/src/frontends/onnx/frontend/src/op/constant.cpp +++ b/src/frontends/onnx/frontend/src/op/constant.cpp @@ -24,7 +24,7 @@ inline std::shared_ptr make_ng_constant_impl(const elem try { constant = std::make_shared(type, tensor.get_shape(), tensor.get_data()); } catch (const ngraph::ngraph_error&) { - constant = std::make_shared(type, Shape{}, 0); + constant = common::make_failsafe_constant(type); } return constant; diff --git a/src/frontends/onnx/frontend/src/op/expand.cpp b/src/frontends/onnx/frontend/src/op/expand.cpp index 7c6b1ab287e..f8481244440 100644 --- a/src/frontends/onnx/frontend/src/op/expand.cpp +++ b/src/frontends/onnx/frontend/src/op/expand.cpp @@ -10,6 +10,7 @@ #include "ngraph/op/broadcast.hpp" #include "ngraph/op/constant.hpp" #include "ngraph/op/multiply.hpp" +#include "utils/common.hpp" namespace ngraph { namespace onnx_import { @@ -19,7 +20,17 @@ OutputVector expand(const Node& node) { const Output data{node.get_ng_inputs().at(0)}; const Output shape{node.get_ng_inputs().at(1)}; - return {std::make_shared(data, shape, ngraph::op::BroadcastType::BIDIRECTIONAL)}; + if (common::is_failsafe_node(shape.get_node_shared_ptr())) { + // in case the "shape" input is connected to a failsafe node created in place of an invalid initializer + // the target shape should be ignored and this Expand operation should not modify its input tensor + // the Broadcast created below should be eliminated later on by an appropriate optimization pass + const auto identity_broadcast = default_opset::Constant::create(element::i64, Shape{1}, {1}); + return {std::make_shared(data, + identity_broadcast, + ngraph::op::BroadcastType::BIDIRECTIONAL)}; + } else { + return {std::make_shared(data, shape, ngraph::op::BroadcastType::BIDIRECTIONAL)}; + } } } // namespace set_1 diff --git a/src/frontends/onnx/frontend/src/utils/common.cpp b/src/frontends/onnx/frontend/src/utils/common.cpp index f1aff3a239d..914a7f8ecd9 100644 --- a/src/frontends/onnx/frontend/src/utils/common.cpp +++ b/src/frontends/onnx/frontend/src/utils/common.cpp @@ -115,6 +115,20 @@ template OutputVector handle_opset6_binary_op(const Node& template OutputVector handle_opset6_binary_op(const Node& node); template OutputVector handle_opset6_binary_op(const Node& node); +const std::string FAILSAFE_NODE = "ONNX_FAILSAFE_NODE"; + +std::shared_ptr make_failsafe_constant(const ngraph::element::Type& dtype) { + const auto failsafe_constant = default_opset::Constant::create(dtype, Shape{}, {0}); + auto& rt_info = failsafe_constant->get_rt_info(); + rt_info[FAILSAFE_NODE] = ""; + return failsafe_constant; +} + +bool is_failsafe_node(const std::shared_ptr& node) { + const auto& rt_info = node->get_rt_info(); + return rt_info.find(FAILSAFE_NODE) != rt_info.end(); +} + } // namespace common } // namespace onnx_import } // namespace ngraph diff --git a/src/frontends/onnx/frontend/src/utils/common.hpp b/src/frontends/onnx/frontend/src/utils/common.hpp index d2e49094801..43f17743886 100644 --- a/src/frontends/onnx/frontend/src/utils/common.hpp +++ b/src/frontends/onnx/frontend/src/utils/common.hpp @@ -132,6 +132,16 @@ std::unique_ptr make_unique(Args&&... args) { /// \return OutputVector with binary op template OutputVector handle_opset6_binary_op(const Node& node); + +/// \brief Creates a "dummy" constant to be used in place of an invalid initializer +/// encountered in the original model. +/// \return A scalar constant containing a single value of zero +/// marked as "failsafe" in the runtime info object +std::shared_ptr make_failsafe_constant(const ngraph::element::Type& dtype); + +/// \brief Checks the node's runtime info object and returns true if this node represents +/// a dummy failsafe node created instead of an incorrect node found in the original model +bool is_failsafe_node(const std::shared_ptr& node); } // namespace common } // namespace onnx_import } // namespace ngraph