ONNX Expand extended support (#10833)
This commit is contained in:
parent
840e622da5
commit
29144d3a6b
52
src/core/tests/models/onnx/expand_failsafe_node.prototxt
Normal file
52
src/core/tests/models/onnx/expand_failsafe_node.prototxt
Normal file
@ -0,0 +1,52 @@
|
||||
ir_version: 8
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
node {
|
||||
input: "data"
|
||||
input: "shape"
|
||||
output: "expanded"
|
||||
op_type: "Expand"
|
||||
}
|
||||
name: "test_graph"
|
||||
initializer {
|
||||
dims: 0
|
||||
data_type: 7
|
||||
name: "shape"
|
||||
raw_data: ""
|
||||
}
|
||||
input {
|
||||
name: "data"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "expanded"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 13
|
||||
}
|
@ -4736,3 +4736,15 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_unsqueeze_ai_onnx_domain_opset13) {
|
||||
test_case.add_expected_output(expected_output);
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_expand_failsafe_node) {
|
||||
const auto function =
|
||||
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/expand_failsafe_node.onnx"));
|
||||
|
||||
auto test_case = test::TestCase(function, s_device);
|
||||
const auto input_data = std::vector<float>{1.0f, 2.0f, 3.0f, 4.0f};
|
||||
test_case.add_input<float>(input_data);
|
||||
// the target shape is an empty constant so the Expand operation should not modify the input shape
|
||||
test_case.add_expected_output<float>(input_data);
|
||||
test_case.run();
|
||||
}
|
||||
|
@ -76,7 +76,7 @@ Graph::Graph(const std::shared_ptr<ONNX_NAMESPACE::ModelProto>& model_proto,
|
||||
// invalid external data makes initializers creation impossible
|
||||
throw;
|
||||
} catch (const ngraph::ngraph_error&) {
|
||||
ng_constant = default_opset::Constant::create(tensor.get_ng_type(), Shape{}, {0});
|
||||
ng_constant = ngraph::onnx_import::common::make_failsafe_constant(tensor.get_ng_type());
|
||||
}
|
||||
|
||||
initializers.emplace(initializer_tensor.name(), tensor);
|
||||
|
@ -24,7 +24,7 @@ inline std::shared_ptr<default_opset::Constant> make_ng_constant_impl(const elem
|
||||
try {
|
||||
constant = std::make_shared<default_opset::Constant>(type, tensor.get_shape(), tensor.get_data<T>());
|
||||
} catch (const ngraph::ngraph_error&) {
|
||||
constant = std::make_shared<default_opset::Constant>(type, Shape{}, 0);
|
||||
constant = common::make_failsafe_constant(type);
|
||||
}
|
||||
|
||||
return constant;
|
||||
|
@ -10,6 +10,7 @@
|
||||
#include "ngraph/op/broadcast.hpp"
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "ngraph/op/multiply.hpp"
|
||||
#include "utils/common.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace onnx_import {
|
||||
@ -19,7 +20,17 @@ OutputVector expand(const Node& node) {
|
||||
const Output<ngraph::Node> data{node.get_ng_inputs().at(0)};
|
||||
const Output<ngraph::Node> shape{node.get_ng_inputs().at(1)};
|
||||
|
||||
if (common::is_failsafe_node(shape.get_node_shared_ptr())) {
|
||||
// in case the "shape" input is connected to a failsafe node created in place of an invalid initializer
|
||||
// the target shape should be ignored and this Expand operation should not modify its input tensor
|
||||
// the Broadcast created below should be eliminated later on by an appropriate optimization pass
|
||||
const auto identity_broadcast = default_opset::Constant::create(element::i64, Shape{1}, {1});
|
||||
return {std::make_shared<default_opset::Broadcast>(data,
|
||||
identity_broadcast,
|
||||
ngraph::op::BroadcastType::BIDIRECTIONAL)};
|
||||
} else {
|
||||
return {std::make_shared<default_opset::Broadcast>(data, shape, ngraph::op::BroadcastType::BIDIRECTIONAL)};
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace set_1
|
||||
|
@ -115,6 +115,20 @@ template OutputVector handle_opset6_binary_op<default_opset::Divide>(const Node&
|
||||
template OutputVector handle_opset6_binary_op<default_opset::Multiply>(const Node& node);
|
||||
template OutputVector handle_opset6_binary_op<default_opset::Subtract>(const Node& node);
|
||||
|
||||
const std::string FAILSAFE_NODE = "ONNX_FAILSAFE_NODE";
|
||||
|
||||
std::shared_ptr<default_opset::Constant> make_failsafe_constant(const ngraph::element::Type& dtype) {
|
||||
const auto failsafe_constant = default_opset::Constant::create(dtype, Shape{}, {0});
|
||||
auto& rt_info = failsafe_constant->get_rt_info();
|
||||
rt_info[FAILSAFE_NODE] = "";
|
||||
return failsafe_constant;
|
||||
}
|
||||
|
||||
bool is_failsafe_node(const std::shared_ptr<ov::Node>& node) {
|
||||
const auto& rt_info = node->get_rt_info();
|
||||
return rt_info.find(FAILSAFE_NODE) != rt_info.end();
|
||||
}
|
||||
|
||||
} // namespace common
|
||||
} // namespace onnx_import
|
||||
} // namespace ngraph
|
||||
|
@ -132,6 +132,16 @@ std::unique_ptr<T> make_unique(Args&&... args) {
|
||||
/// \return OutputVector with binary op
|
||||
template <typename T>
|
||||
OutputVector handle_opset6_binary_op(const Node& node);
|
||||
|
||||
/// \brief Creates a "dummy" constant to be used in place of an invalid initializer
|
||||
/// encountered in the original model.
|
||||
/// \return A scalar constant containing a single value of zero
|
||||
/// marked as "failsafe" in the runtime info object
|
||||
std::shared_ptr<default_opset::Constant> make_failsafe_constant(const ngraph::element::Type& dtype);
|
||||
|
||||
/// \brief Checks the node's runtime info object and returns true if this node represents
|
||||
/// a dummy failsafe node created instead of an incorrect node found in the original model
|
||||
bool is_failsafe_node(const std::shared_ptr<ov::Node>& node);
|
||||
} // namespace common
|
||||
} // namespace onnx_import
|
||||
} // namespace ngraph
|
||||
|
Loading…
Reference in New Issue
Block a user