[ONNX] Replace subgraph's inputs from parent with Parameter before node is created (#10113)
This patch fixes case when If operator has subgraph with just Identity op, which input comes from parent graph. Since Identity is eliminated, its input is incorrectly pulled to this subgraph's body. For example: this ONNX subgraph: ``` +-----------+ |AveragePool| +-+---+-----+ | | +----+ v | ..... | | | v +-------|--------------------------+ | | If | | then|branch else branch | +-------|--------+-----------------+ | | | | | v | | | +-----------+ | | | | Identity | | ......... | | +-----------+ | | | | | | | | +----------------+-----------------+ ``` was converted to following (incorrect) nGraph representation: ``` +-------------+ | AveragePool | +--+---+------+ | | +----+ v | ..... | | | v +-------|---------------------------+ | | If | | then|branch else branch | +-------|---------+-----------------+ | v | | | +-----------+ | | | | Parameter | | | | +-----------+ | | | | | | | v | | | +-------------+ | | | | AveragePool | | ......... | | +-------------+ | | | | | | | v | | | +--------+ | | | | Result | | | | +--------+ | | | | | +-----------------+-----------------+ ``` With this change, subgraph's inputs from parent scope are replaced with Parameter before nGraph node is created. In that case Identity's input is a Parameter (and not AveragePool) and therefore 'then branch' looks like: ``` +-----------+ | Parameter | +-----------+ | v +-----------+ | Result | +-----------+ ``` Ticket: 73895.
This commit is contained in:
parent
b7c62fcfbc
commit
72216a9b95
@ -0,0 +1,238 @@
|
||||
ir_version: 6
|
||||
graph {
|
||||
node {
|
||||
output: "zero"
|
||||
name: "Constant_6"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
t {
|
||||
dims: 1
|
||||
data_type: 7
|
||||
int64_data: 0
|
||||
}
|
||||
type: TENSOR
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "input"
|
||||
input: "zero"
|
||||
output: "unsqueeze"
|
||||
op_type: "Unsqueeze"
|
||||
}
|
||||
node {
|
||||
output: "pads"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
t {
|
||||
dims: 10
|
||||
data_type: 7
|
||||
int64_data: 0
|
||||
int64_data: 0
|
||||
int64_data: 1
|
||||
int64_data: 0
|
||||
int64_data: 0
|
||||
int64_data: 0
|
||||
int64_data: 0
|
||||
int64_data: 1
|
||||
int64_data: 0
|
||||
int64_data: 0
|
||||
}
|
||||
type: TENSOR
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "unsqueeze"
|
||||
input: "pads"
|
||||
output: "pad"
|
||||
name: "Pad_1"
|
||||
op_type: "Pad"
|
||||
attribute {
|
||||
name: "mode"
|
||||
type: STRING
|
||||
s: "constant"
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "pad"
|
||||
output: "avgpool"
|
||||
name: "AveragePool_2"
|
||||
op_type: "AveragePool"
|
||||
attribute {
|
||||
name: "ceil_mode"
|
||||
i: 0
|
||||
type: INT
|
||||
}
|
||||
attribute {
|
||||
name: "kernel_shape"
|
||||
ints: 3
|
||||
ints: 1
|
||||
ints: 1
|
||||
type: INTS
|
||||
}
|
||||
attribute {
|
||||
name: "pads"
|
||||
ints: 0
|
||||
ints: 0
|
||||
ints: 0
|
||||
ints: 0
|
||||
ints: 0
|
||||
ints: 0
|
||||
type: INTS
|
||||
}
|
||||
attribute {
|
||||
name: "strides"
|
||||
ints: 1
|
||||
ints: 1
|
||||
ints: 1
|
||||
type: INTS
|
||||
}
|
||||
}
|
||||
node {
|
||||
output: "index"
|
||||
name: "Constant_3"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
t {
|
||||
dims: 1
|
||||
data_type: 7
|
||||
int64_data: 1
|
||||
}
|
||||
type: TENSOR
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "avgpool"
|
||||
output: "avgpool_shape"
|
||||
name: "Shape_4"
|
||||
op_type: "Shape"
|
||||
}
|
||||
node {
|
||||
input: "avgpool_shape"
|
||||
input: "index"
|
||||
output: "gather"
|
||||
name: "Gather_5"
|
||||
op_type: "Gather"
|
||||
attribute {
|
||||
name: "axis"
|
||||
i: 0
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
node {
|
||||
output: "one"
|
||||
name: "Constant_6"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
t {
|
||||
dims: 1
|
||||
data_type: 7
|
||||
int64_data: 1
|
||||
}
|
||||
type: TENSOR
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "gather"
|
||||
input: "one"
|
||||
output: "equal"
|
||||
name: "Equal_7"
|
||||
op_type: "Equal"
|
||||
}
|
||||
node {
|
||||
input: "equal"
|
||||
output: "if"
|
||||
name: "If_8"
|
||||
op_type: "If"
|
||||
attribute {
|
||||
name: "then_branch"
|
||||
g {
|
||||
node {
|
||||
input: "avgpool"
|
||||
input: "one"
|
||||
output: "then_output"
|
||||
name: "Squeeze_9"
|
||||
op_type: "Squeeze"
|
||||
}
|
||||
name: "then"
|
||||
output {
|
||||
name: "then_output"
|
||||
}
|
||||
}
|
||||
type: GRAPH
|
||||
}
|
||||
attribute {
|
||||
name: "else_branch"
|
||||
g {
|
||||
node {
|
||||
input: "avgpool"
|
||||
output: "else_output"
|
||||
name: "Identity_10"
|
||||
op_type: "Identity"
|
||||
}
|
||||
name: "else"
|
||||
output {
|
||||
name: "else_output"
|
||||
}
|
||||
}
|
||||
type: GRAPH
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "input"
|
||||
input: "if"
|
||||
output: "output"
|
||||
name: "Add_11"
|
||||
op_type: "Add"
|
||||
}
|
||||
input {
|
||||
name: "input"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
dim {
|
||||
dim_value: 5
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "output"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
dim {
|
||||
dim_value: 5
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 13
|
||||
}
|
@ -692,6 +692,32 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_if_inside_loop) {
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_if_with_only_indentity_in_else_branch) {
|
||||
/*
|
||||
unsq = unsqueeze(input)
|
||||
padded = pad(unsq)
|
||||
avgpool = avgpool(padded, kernel=[3, 1, 1])
|
||||
if_output = if (avgpool.shape[1] == 1) {
|
||||
squeeze(avgpool)
|
||||
} else {
|
||||
identity(avgpool)
|
||||
}
|
||||
output = add(input, if_output)
|
||||
*/
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/controlflow/if_with_only_indentity_in_else_branch.onnx"));
|
||||
|
||||
auto test_case = test::TestCase(function, s_device);
|
||||
|
||||
std::vector<float> x(shape_size(Shape{1, 5, 2, 2}));
|
||||
std::iota(x.begin(), x.end(), 0);
|
||||
std::vector<float> expected{1.333333, 3, 4.666666, 6.333333, 8, 10, 12, 14, 16, 18,
|
||||
20, 22, 24, 26, 28, 30, 25.33333, 27, 28.666667, 30.33333};
|
||||
test_case.add_input<float>(x);
|
||||
test_case.add_expected_output<float>(expected);
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_if_dynamic_inputs) {
|
||||
/*
|
||||
if (condition) {
|
||||
|
@ -60,6 +60,9 @@ public:
|
||||
/// \return Description of Node
|
||||
const std::string& get_description() const;
|
||||
|
||||
const std::string& input(int index) const;
|
||||
std::size_t get_inputs_size() const;
|
||||
|
||||
const std::vector<std::reference_wrapper<const std::string>>& get_output_names() const;
|
||||
const std::string& output(int index) const;
|
||||
std::size_t get_outputs_size() const;
|
||||
|
@ -186,34 +186,38 @@ std::shared_ptr<Function> Graph::convert() {
|
||||
return create_function();
|
||||
}
|
||||
|
||||
OutputVector Graph::make_framework_nodes(const Node& onnx_node) {
|
||||
std::shared_ptr<frontend::ONNXFrameworkNode> framework_node;
|
||||
if (onnx_node.has_subgraphs()) {
|
||||
const auto& subgraphs = onnx_node.get_subgraphs();
|
||||
auto inputs = onnx_node.get_ng_inputs();
|
||||
std::vector<std::shared_ptr<Function>> functions;
|
||||
for (const auto& kv : subgraphs) {
|
||||
auto& subgraph = kv.second;
|
||||
functions.push_back(subgraph->decode());
|
||||
for (const auto& input : subgraph->get_inputs_from_parent()) {
|
||||
const auto& name = input.get_node()->get_friendly_name();
|
||||
if (std::find_if(inputs.begin(), inputs.end(), [&name](const Output<ngraph::Node>& n) -> bool {
|
||||
return name == n.get_node()->get_friendly_name();
|
||||
}) == inputs.end()) {
|
||||
inputs.push_back(input);
|
||||
}
|
||||
}
|
||||
}
|
||||
framework_node = std::make_shared<frontend::ONNXSubgraphFrameworkNode>(onnx_node, functions, inputs);
|
||||
} else {
|
||||
framework_node = std::make_shared<frontend::ONNXFrameworkNode>(onnx_node);
|
||||
}
|
||||
return framework_node->outputs();
|
||||
}
|
||||
|
||||
void Graph::decode_to_framework_nodes() {
|
||||
const float total = static_cast<float>(m_model->get_graph().node().size());
|
||||
unsigned int completed = 0u;
|
||||
// Process ONNX graph nodes, convert to nGraph nodes
|
||||
for (const auto& node_proto : m_model->get_graph().node()) {
|
||||
const Node node{node_proto, *this};
|
||||
std::shared_ptr<frontend::ONNXFrameworkNode> framework_node;
|
||||
if (node.has_subgraphs()) {
|
||||
const auto& subgraphs = node.get_subgraphs();
|
||||
auto inputs = node.get_ng_inputs();
|
||||
std::vector<std::shared_ptr<Function>> functions;
|
||||
for (const auto& kv : subgraphs) {
|
||||
auto& subgraph = kv.second;
|
||||
functions.push_back(subgraph->decode());
|
||||
for (const auto& input : subgraph->get_inputs_from_parent()) {
|
||||
const auto& name = input.get_node()->get_friendly_name();
|
||||
if (std::find_if(inputs.begin(), inputs.end(), [&name](const Output<ngraph::Node>& n) -> bool {
|
||||
return name == n.get_node()->get_friendly_name();
|
||||
}) == inputs.end()) {
|
||||
inputs.push_back(input);
|
||||
}
|
||||
}
|
||||
}
|
||||
framework_node = std::make_shared<frontend::ONNXSubgraphFrameworkNode>(node, functions, inputs);
|
||||
} else {
|
||||
framework_node = std::make_shared<frontend::ONNXFrameworkNode>(node);
|
||||
}
|
||||
OutputVector ng_nodes{framework_node->outputs()};
|
||||
OutputVector ng_nodes{make_framework_nodes(node)};
|
||||
set_friendly_names(node, ng_nodes);
|
||||
// Iterate over the number of outputs for given node in graph.
|
||||
// Some of them may be optional and trimmed. See:
|
||||
@ -265,7 +269,7 @@ OutputVector Graph::get_ng_outputs() const {
|
||||
return results;
|
||||
}
|
||||
|
||||
OutputVector Graph::make_ng_nodes(const Node& onnx_node) const {
|
||||
OutputVector Graph::make_ng_nodes(const Node& onnx_node) {
|
||||
const auto ng_node_factory = m_model->get_operator(onnx_node.op_type(), onnx_node.domain());
|
||||
// contains outputs of nG subgraph implementing a particular ONNX node (possibly a single output of a single node)
|
||||
OutputVector ng_subgraph_outputs;
|
||||
@ -349,80 +353,16 @@ Output<ngraph::Node> Subgraph::get_ng_node_from_cache(const std::string& name) c
|
||||
return m_parent_graph->get_ng_node_from_cache(name);
|
||||
}
|
||||
|
||||
void Subgraph::replace_input_from_parent_scope_with_parameter(const std::string& in_name,
|
||||
const Output<ngraph::Node>& from_parent_node,
|
||||
Input<ngraph::Node>&& node_to_replace_input) {
|
||||
auto new_param = std::make_shared<ngraph::op::Parameter>(from_parent_node.get_element_type(),
|
||||
from_parent_node.get_partial_shape());
|
||||
node_to_replace_input.replace_source_output(new_param);
|
||||
m_parameter_to_parent_node_map.insert({new_param, in_name});
|
||||
m_cache->emplace_node(in_name, new_param);
|
||||
m_parameters.push_back(new_param);
|
||||
m_inputs_from_parent.push_back(in_name);
|
||||
}
|
||||
|
||||
void Subgraph::find_inputs_from_parent() {
|
||||
// find all nodes on edge parent graph-subgraph
|
||||
// (it means input of node from parent graph, output from subgraph)
|
||||
for (const auto& node_proto : m_model->get_graph().node()) {
|
||||
int input_index = 0;
|
||||
for (const auto& in_name : node_proto.input()) {
|
||||
if (m_parent_graph->is_ng_node_in_cache(in_name)) {
|
||||
const auto& from_parent_node = m_parent_graph->get_ng_node_from_cache(in_name);
|
||||
// constants are skipped
|
||||
if (!ngraph::is_type<ngraph::op::Constant>(from_parent_node.get_node_shared_ptr())) {
|
||||
for (const auto& out_name : node_proto.output()) {
|
||||
if (m_cache->contains(out_name)) {
|
||||
auto node_to_replace_input = m_cache->get_node(out_name);
|
||||
replace_input_from_parent_scope_with_parameter(
|
||||
in_name,
|
||||
from_parent_node,
|
||||
node_to_replace_input.get_node()->input(input_index));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
++input_index;
|
||||
}
|
||||
// Nodes with subgraphs (like Loop or If) can have implicit inputs (so their subgraphs depend on nodes from
|
||||
// parent) Those implicit inputs are not present in `node_proto.input()` list so to get them, we need to fetch
|
||||
// node's nGraph representation and then we can match those inputs with parent nodes
|
||||
for (const auto& out_name : node_proto.output()) {
|
||||
if (m_cache->contains(out_name)) {
|
||||
auto node_to_replace_input = m_cache->get_node(out_name).get_node();
|
||||
if (!ov::is_type<op::util::MultiSubGraphOp>(node_to_replace_input) &&
|
||||
!ov::is_type<frontend::ONNXSubgraphFrameworkNode>(node_to_replace_input))
|
||||
continue;
|
||||
auto inputs = node_to_replace_input->input_values();
|
||||
for (size_t i = 0; i < inputs.size(); i++) {
|
||||
const auto& input = inputs.at(i);
|
||||
auto input_node = input.get_node();
|
||||
if (op::is_constant(input_node))
|
||||
continue;
|
||||
const auto& in_name = input_node->get_friendly_name();
|
||||
if (m_parent_graph->is_ng_node_in_cache(in_name)) {
|
||||
const auto& from_parent_node = m_parent_graph->get_ng_node_from_cache(in_name);
|
||||
replace_input_from_parent_scope_with_parameter(in_name,
|
||||
from_parent_node,
|
||||
node_to_replace_input->input(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
OutputVector Subgraph::make_ng_nodes(const Node& onnx_node) {
|
||||
replace_input_from_parent_scope_with_parameter(onnx_node);
|
||||
return Graph::make_ng_nodes(onnx_node);
|
||||
}
|
||||
|
||||
std::shared_ptr<Function> Subgraph::convert() {
|
||||
convert_to_ngraph_nodes();
|
||||
find_inputs_from_parent();
|
||||
return create_function();
|
||||
}
|
||||
|
||||
void Subgraph::decode_to_framework_nodes() {
|
||||
Graph::decode_to_framework_nodes();
|
||||
find_inputs_from_parent();
|
||||
}
|
||||
|
||||
const std::vector<Output<ngraph::Node>> Subgraph::get_inputs_from_parent() const {
|
||||
OutputVector result;
|
||||
for (const auto& name : m_inputs_from_parent) {
|
||||
@ -440,6 +380,30 @@ void Subgraph::infer_inputs_from_parent() {
|
||||
}
|
||||
}
|
||||
|
||||
OutputVector Subgraph::make_framework_nodes(const Node& onnx_node) {
|
||||
replace_input_from_parent_scope_with_parameter(onnx_node);
|
||||
return Graph::make_framework_nodes(onnx_node);
|
||||
}
|
||||
|
||||
void Subgraph::replace_input_from_parent_scope_with_parameter(const Node& onnx_node) {
|
||||
for (std::size_t i = 0; i < onnx_node.get_inputs_size(); ++i) {
|
||||
const auto& in_name = onnx_node.input(i);
|
||||
if (m_parent_graph->is_ng_node_in_cache(in_name) &&
|
||||
std::find(m_inputs_from_parent.begin(), m_inputs_from_parent.end(), in_name) ==
|
||||
m_inputs_from_parent.end()) {
|
||||
const auto& from_parent_node = m_parent_graph->get_ng_node_from_cache(in_name);
|
||||
if (op::is_constant(from_parent_node.get_node()))
|
||||
continue;
|
||||
auto new_param = std::make_shared<ngraph::op::Parameter>(from_parent_node.get_element_type(),
|
||||
from_parent_node.get_partial_shape());
|
||||
m_parameter_to_parent_node_map.insert({new_param, in_name});
|
||||
m_cache->emplace_node(in_name, new_param);
|
||||
m_parameters.push_back(new_param);
|
||||
m_inputs_from_parent.push_back(in_name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace onnx_import
|
||||
|
||||
} // namespace ngraph
|
||||
|
@ -41,7 +41,7 @@ public:
|
||||
}
|
||||
virtual bool is_ng_node_in_cache(const std::string& name) const;
|
||||
virtual Output<ngraph::Node> get_ng_node_from_cache(const std::string& name) const;
|
||||
OutputVector make_ng_nodes(const Node& onnx_node) const;
|
||||
virtual OutputVector make_ng_nodes(const Node& onnx_node);
|
||||
const OpsetImports& get_opset_imports() const;
|
||||
virtual ~Graph() = default;
|
||||
|
||||
@ -57,7 +57,8 @@ protected:
|
||||
void set_friendly_names(const Node& onnx_node, const OutputVector& ng_subgraph_outputs) const;
|
||||
|
||||
protected:
|
||||
virtual void decode_to_framework_nodes();
|
||||
virtual OutputVector make_framework_nodes(const Node& onnx_node);
|
||||
void decode_to_framework_nodes();
|
||||
void convert_to_ngraph_nodes();
|
||||
void remove_dangling_parameters();
|
||||
std::shared_ptr<Function> create_function();
|
||||
@ -98,19 +99,13 @@ public:
|
||||
|
||||
bool is_ng_node_in_cache(const std::string& name) const override;
|
||||
Output<ngraph::Node> get_ng_node_from_cache(const std::string& name) const override;
|
||||
OutputVector make_ng_nodes(const Node& onnx_node) override;
|
||||
void infer_inputs_from_parent();
|
||||
|
||||
private:
|
||||
void decode_to_framework_nodes() override;
|
||||
void find_inputs_from_parent();
|
||||
/// \brief Replaces current node's input with Parameter if that input comes from parent graph scope
|
||||
///
|
||||
/// \param[in] in_name input node name
|
||||
/// \param[in] from_parent_node nGraph node from parent scope
|
||||
/// \param[in] node_to_replace_input nGraph input node to be replaced
|
||||
void replace_input_from_parent_scope_with_parameter(const std::string& in_name,
|
||||
const Output<ngraph::Node>& from_parent_node,
|
||||
Input<ngraph::Node>&& node_to_replace_input);
|
||||
OutputVector make_framework_nodes(const Node& onnx_node) override;
|
||||
/// \brief Checks if onnx_node has inputs from parent graph and replaces those inputs with Parameters
|
||||
void replace_input_from_parent_scope_with_parameter(const Node& onnx_node);
|
||||
|
||||
const Graph* m_parent_graph;
|
||||
std::vector<std::string> m_inputs_from_parent;
|
||||
|
@ -50,6 +50,8 @@ public:
|
||||
|
||||
const std::string& description() const;
|
||||
const std::vector<std::reference_wrapper<const std::string>>& get_output_names() const;
|
||||
const std::string& input(int index) const;
|
||||
std::size_t get_inputs_size() const;
|
||||
const std::string& output(int index) const;
|
||||
std::size_t get_outputs_size() const;
|
||||
|
||||
@ -103,6 +105,14 @@ const std::vector<std::reference_wrapper<const std::string>>& Node::Impl::get_ou
|
||||
return m_output_names;
|
||||
}
|
||||
|
||||
const std::string& Node::Impl::input(int index) const {
|
||||
return m_node_proto->input(index);
|
||||
}
|
||||
|
||||
std::size_t Node::Impl::get_inputs_size() const {
|
||||
return m_node_proto->input_size();
|
||||
}
|
||||
|
||||
const std::string& Node::Impl::output(int index) const {
|
||||
return m_node_proto->output(index);
|
||||
}
|
||||
@ -110,6 +120,7 @@ const std::string& Node::Impl::output(int index) const {
|
||||
std::size_t Node::Impl::get_outputs_size() const {
|
||||
return m_output_names.size();
|
||||
}
|
||||
|
||||
bool Node::Impl::has_attribute(const std::string& name) const {
|
||||
auto it = std::find_if(std::begin(m_attributes), std::end(m_attributes), [&](const Attribute& attribute) {
|
||||
return attribute.get_name() == name;
|
||||
@ -223,12 +234,22 @@ const std::vector<std::reference_wrapper<const std::string>>& Node::get_output_n
|
||||
return m_pimpl->get_output_names();
|
||||
}
|
||||
|
||||
const std::string& Node::input(int index) const {
|
||||
return m_pimpl->input(index);
|
||||
}
|
||||
|
||||
std::size_t Node::get_inputs_size() const {
|
||||
return m_pimpl->get_inputs_size();
|
||||
}
|
||||
|
||||
const std::string& Node::output(int index) const {
|
||||
return m_pimpl->output(index);
|
||||
}
|
||||
|
||||
std::size_t Node::get_outputs_size() const {
|
||||
return m_pimpl->get_outputs_size();
|
||||
}
|
||||
|
||||
bool Node::has_attribute(const std::string& name) const {
|
||||
return m_pimpl->has_attribute(name);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user