ONNX Expand extended support (#10833)

This commit is contained in:
Tomasz Dołbniak 2022-03-15 12:21:30 +01:00 committed by GitHub
parent 840e622da5
commit 29144d3a6b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 102 additions and 3 deletions

View File

@ -0,0 +1,52 @@
ir_version: 8
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "data"
input: "shape"
output: "expanded"
op_type: "Expand"
}
name: "test_graph"
initializer {
dims: 0
data_type: 7
name: "shape"
raw_data: ""
}
input {
name: "data"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
output {
name: "expanded"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 13
}

View File

@ -4736,3 +4736,15 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_unsqueeze_ai_onnx_domain_opset13) {
test_case.add_expected_output(expected_output);
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_expand_failsafe_node) {
const auto function =
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/expand_failsafe_node.onnx"));
auto test_case = test::TestCase(function, s_device);
const auto input_data = std::vector<float>{1.0f, 2.0f, 3.0f, 4.0f};
test_case.add_input<float>(input_data);
// the target shape is an empty constant so the Expand operation should not modify the input shape
test_case.add_expected_output<float>(input_data);
test_case.run();
}

View File

@ -76,7 +76,7 @@ Graph::Graph(const std::shared_ptr<ONNX_NAMESPACE::ModelProto>& model_proto,
// invalid external data makes initializers creation impossible
throw;
} catch (const ngraph::ngraph_error&) {
ng_constant = default_opset::Constant::create(tensor.get_ng_type(), Shape{}, {0});
ng_constant = ngraph::onnx_import::common::make_failsafe_constant(tensor.get_ng_type());
}
initializers.emplace(initializer_tensor.name(), tensor);

View File

@ -24,7 +24,7 @@ inline std::shared_ptr<default_opset::Constant> make_ng_constant_impl(const elem
try {
constant = std::make_shared<default_opset::Constant>(type, tensor.get_shape(), tensor.get_data<T>());
} catch (const ngraph::ngraph_error&) {
constant = std::make_shared<default_opset::Constant>(type, Shape{}, 0);
constant = common::make_failsafe_constant(type);
}
return constant;

View File

@ -10,6 +10,7 @@
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/multiply.hpp"
#include "utils/common.hpp"
namespace ngraph {
namespace onnx_import {
@ -19,8 +20,18 @@ OutputVector expand(const Node& node) {
const Output<ngraph::Node> data{node.get_ng_inputs().at(0)};
const Output<ngraph::Node> shape{node.get_ng_inputs().at(1)};
if (common::is_failsafe_node(shape.get_node_shared_ptr())) {
// in case the "shape" input is connected to a failsafe node created in place of an invalid initializer
// the target shape should be ignored and this Expand operation should not modify its input tensor
// the Broadcast created below should be eliminated later on by an appropriate optimization pass
const auto identity_broadcast = default_opset::Constant::create(element::i64, Shape{1}, {1});
return {std::make_shared<default_opset::Broadcast>(data,
identity_broadcast,
ngraph::op::BroadcastType::BIDIRECTIONAL)};
} else {
return {std::make_shared<default_opset::Broadcast>(data, shape, ngraph::op::BroadcastType::BIDIRECTIONAL)};
}
}
} // namespace set_1

View File

@ -115,6 +115,20 @@ template OutputVector handle_opset6_binary_op<default_opset::Divide>(const Node&
template OutputVector handle_opset6_binary_op<default_opset::Multiply>(const Node& node);
template OutputVector handle_opset6_binary_op<default_opset::Subtract>(const Node& node);
const std::string FAILSAFE_NODE = "ONNX_FAILSAFE_NODE";
std::shared_ptr<default_opset::Constant> make_failsafe_constant(const ngraph::element::Type& dtype) {
const auto failsafe_constant = default_opset::Constant::create(dtype, Shape{}, {0});
auto& rt_info = failsafe_constant->get_rt_info();
rt_info[FAILSAFE_NODE] = "";
return failsafe_constant;
}
bool is_failsafe_node(const std::shared_ptr<ov::Node>& node) {
const auto& rt_info = node->get_rt_info();
return rt_info.find(FAILSAFE_NODE) != rt_info.end();
}
} // namespace common
} // namespace onnx_import
} // namespace ngraph

View File

@ -132,6 +132,16 @@ std::unique_ptr<T> make_unique(Args&&... args) {
/// \return OutputVector with binary op
template <typename T>
OutputVector handle_opset6_binary_op(const Node& node);
/// \brief Creates a "dummy" constant to be used in place of an invalid initializer
/// encountered in the original model.
/// \return A scalar constant containing a single value of zero
/// marked as "failsafe" in the runtime info object
std::shared_ptr<default_opset::Constant> make_failsafe_constant(const ngraph::element::Type& dtype);
/// \brief Checks the node's runtime info object and returns true if this node represents
/// a dummy failsafe node created instead of an incorrect node found in the original model
bool is_failsafe_node(const std::shared_ptr<ov::Node>& node);
} // namespace common
} // namespace onnx_import
} // namespace ngraph