[ONNX] Replace subgraph's inputs from parent with Parameter before node is created (#10113)

This patch fixes case when If operator has subgraph with just Identity op,
which input comes from parent graph. Since Identity is eliminated,
its input is incorrectly pulled to this subgraph's body.
For example:
this ONNX subgraph:
```
               +-----------+
               |AveragePool|
               +-+---+-----+
                 |   |
            +----+   v
            |      .....
            |        |
            |        v
    +-------|--------------------------+
    |       |       If                 |
    |   then|branch      else branch   |
    +-------|--------+-----------------+
    |       |        |                 |
    |       v        |                 |
    |  +-----------+ |                 |
    |  | Identity  | |    .........    |
    |  +-----------+ |                 |
    |                |                 |
    |                |                 |
    +----------------+-----------------+
```
was converted to following (incorrect) nGraph representation:
```
              +-------------+
              | AveragePool |
              +--+---+------+
                 |   |
            +----+   v
            |      .....
            |        |
            |        v
    +-------|---------------------------+
    |       |        If                 |
    |   then|branch       else branch   |
    +-------|---------+-----------------+
    |       v         |                 |
    |  +-----------+  |                 |
    |  | Parameter |  |                 |
    |  +-----------+  |                 |
    |       |         |                 |
    |       v         |                 |
    | +-------------+ |                 |
    | | AveragePool | |    .........    |
    | +-------------+ |                 |
    |       |         |                 |
    |       v         |                 |
    |   +--------+    |                 |
    |   | Result |    |                 |
    |   +--------+    |                 |
    |                 |                 |
    +-----------------+-----------------+
```

With this change, subgraph's inputs from parent scope are replaced with
Parameter before nGraph node is created. In that case Identity's input
is a Parameter (and not AveragePool) and therefore 'then branch' looks like:
```
     +-----------+
     | Parameter |
     +-----------+
           |
           v
     +-----------+
     |  Result   |
     +-----------+

```

Ticket: 73895.
This commit is contained in:
Mateusz Tabaka 2022-02-04 12:23:27 +01:00 committed by GitHub
parent b7c62fcfbc
commit 72216a9b95
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 350 additions and 103 deletions

View File

@ -0,0 +1,238 @@
ir_version: 6
graph {
node {
output: "zero"
name: "Constant_6"
op_type: "Constant"
attribute {
name: "value"
t {
dims: 1
data_type: 7
int64_data: 0
}
type: TENSOR
}
}
node {
input: "input"
input: "zero"
output: "unsqueeze"
op_type: "Unsqueeze"
}
node {
output: "pads"
op_type: "Constant"
attribute {
name: "value"
t {
dims: 10
data_type: 7
int64_data: 0
int64_data: 0
int64_data: 1
int64_data: 0
int64_data: 0
int64_data: 0
int64_data: 0
int64_data: 1
int64_data: 0
int64_data: 0
}
type: TENSOR
}
}
node {
input: "unsqueeze"
input: "pads"
output: "pad"
name: "Pad_1"
op_type: "Pad"
attribute {
name: "mode"
type: STRING
s: "constant"
}
}
node {
input: "pad"
output: "avgpool"
name: "AveragePool_2"
op_type: "AveragePool"
attribute {
name: "ceil_mode"
i: 0
type: INT
}
attribute {
name: "kernel_shape"
ints: 3
ints: 1
ints: 1
type: INTS
}
attribute {
name: "pads"
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
type: INTS
}
attribute {
name: "strides"
ints: 1
ints: 1
ints: 1
type: INTS
}
}
node {
output: "index"
name: "Constant_3"
op_type: "Constant"
attribute {
name: "value"
t {
dims: 1
data_type: 7
int64_data: 1
}
type: TENSOR
}
}
node {
input: "avgpool"
output: "avgpool_shape"
name: "Shape_4"
op_type: "Shape"
}
node {
input: "avgpool_shape"
input: "index"
output: "gather"
name: "Gather_5"
op_type: "Gather"
attribute {
name: "axis"
i: 0
type: INT
}
}
node {
output: "one"
name: "Constant_6"
op_type: "Constant"
attribute {
name: "value"
t {
dims: 1
data_type: 7
int64_data: 1
}
type: TENSOR
}
}
node {
input: "gather"
input: "one"
output: "equal"
name: "Equal_7"
op_type: "Equal"
}
node {
input: "equal"
output: "if"
name: "If_8"
op_type: "If"
attribute {
name: "then_branch"
g {
node {
input: "avgpool"
input: "one"
output: "then_output"
name: "Squeeze_9"
op_type: "Squeeze"
}
name: "then"
output {
name: "then_output"
}
}
type: GRAPH
}
attribute {
name: "else_branch"
g {
node {
input: "avgpool"
output: "else_output"
name: "Identity_10"
op_type: "Identity"
}
name: "else"
output {
name: "else_output"
}
}
type: GRAPH
}
}
node {
input: "input"
input: "if"
output: "output"
name: "Add_11"
op_type: "Add"
}
input {
name: "input"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 5
}
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
output {
name: "output"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 5
}
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 13
}

View File

@ -692,6 +692,32 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_if_inside_loop) {
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_if_with_only_indentity_in_else_branch) {
/*
unsq = unsqueeze(input)
padded = pad(unsq)
avgpool = avgpool(padded, kernel=[3, 1, 1])
if_output = if (avgpool.shape[1] == 1) {
squeeze(avgpool)
} else {
identity(avgpool)
}
output = add(input, if_output)
*/
const auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/controlflow/if_with_only_indentity_in_else_branch.onnx"));
auto test_case = test::TestCase(function, s_device);
std::vector<float> x(shape_size(Shape{1, 5, 2, 2}));
std::iota(x.begin(), x.end(), 0);
std::vector<float> expected{1.333333, 3, 4.666666, 6.333333, 8, 10, 12, 14, 16, 18,
20, 22, 24, 26, 28, 30, 25.33333, 27, 28.666667, 30.33333};
test_case.add_input<float>(x);
test_case.add_expected_output<float>(expected);
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_if_dynamic_inputs) {
/*
if (condition) {

View File

@ -60,6 +60,9 @@ public:
/// \return Description of Node
const std::string& get_description() const;
const std::string& input(int index) const;
std::size_t get_inputs_size() const;
const std::vector<std::reference_wrapper<const std::string>>& get_output_names() const;
const std::string& output(int index) const;
std::size_t get_outputs_size() const;

View File

@ -186,16 +186,11 @@ std::shared_ptr<Function> Graph::convert() {
return create_function();
}
void Graph::decode_to_framework_nodes() {
const float total = static_cast<float>(m_model->get_graph().node().size());
unsigned int completed = 0u;
// Process ONNX graph nodes, convert to nGraph nodes
for (const auto& node_proto : m_model->get_graph().node()) {
const Node node{node_proto, *this};
OutputVector Graph::make_framework_nodes(const Node& onnx_node) {
std::shared_ptr<frontend::ONNXFrameworkNode> framework_node;
if (node.has_subgraphs()) {
const auto& subgraphs = node.get_subgraphs();
auto inputs = node.get_ng_inputs();
if (onnx_node.has_subgraphs()) {
const auto& subgraphs = onnx_node.get_subgraphs();
auto inputs = onnx_node.get_ng_inputs();
std::vector<std::shared_ptr<Function>> functions;
for (const auto& kv : subgraphs) {
auto& subgraph = kv.second;
@ -209,11 +204,20 @@ void Graph::decode_to_framework_nodes() {
}
}
}
framework_node = std::make_shared<frontend::ONNXSubgraphFrameworkNode>(node, functions, inputs);
framework_node = std::make_shared<frontend::ONNXSubgraphFrameworkNode>(onnx_node, functions, inputs);
} else {
framework_node = std::make_shared<frontend::ONNXFrameworkNode>(node);
framework_node = std::make_shared<frontend::ONNXFrameworkNode>(onnx_node);
}
OutputVector ng_nodes{framework_node->outputs()};
return framework_node->outputs();
}
void Graph::decode_to_framework_nodes() {
const float total = static_cast<float>(m_model->get_graph().node().size());
unsigned int completed = 0u;
// Process ONNX graph nodes, convert to nGraph nodes
for (const auto& node_proto : m_model->get_graph().node()) {
const Node node{node_proto, *this};
OutputVector ng_nodes{make_framework_nodes(node)};
set_friendly_names(node, ng_nodes);
// Iterate over the number of outputs for given node in graph.
// Some of them may be optional and trimmed. See:
@ -265,7 +269,7 @@ OutputVector Graph::get_ng_outputs() const {
return results;
}
OutputVector Graph::make_ng_nodes(const Node& onnx_node) const {
OutputVector Graph::make_ng_nodes(const Node& onnx_node) {
const auto ng_node_factory = m_model->get_operator(onnx_node.op_type(), onnx_node.domain());
// contains outputs of nG subgraph implementing a particular ONNX node (possibly a single output of a single node)
OutputVector ng_subgraph_outputs;
@ -349,80 +353,16 @@ Output<ngraph::Node> Subgraph::get_ng_node_from_cache(const std::string& name) c
return m_parent_graph->get_ng_node_from_cache(name);
}
void Subgraph::replace_input_from_parent_scope_with_parameter(const std::string& in_name,
const Output<ngraph::Node>& from_parent_node,
Input<ngraph::Node>&& node_to_replace_input) {
auto new_param = std::make_shared<ngraph::op::Parameter>(from_parent_node.get_element_type(),
from_parent_node.get_partial_shape());
node_to_replace_input.replace_source_output(new_param);
m_parameter_to_parent_node_map.insert({new_param, in_name});
m_cache->emplace_node(in_name, new_param);
m_parameters.push_back(new_param);
m_inputs_from_parent.push_back(in_name);
}
void Subgraph::find_inputs_from_parent() {
// find all nodes on edge parent graph-subgraph
// (it means input of node from parent graph, output from subgraph)
for (const auto& node_proto : m_model->get_graph().node()) {
int input_index = 0;
for (const auto& in_name : node_proto.input()) {
if (m_parent_graph->is_ng_node_in_cache(in_name)) {
const auto& from_parent_node = m_parent_graph->get_ng_node_from_cache(in_name);
// constants are skipped
if (!ngraph::is_type<ngraph::op::Constant>(from_parent_node.get_node_shared_ptr())) {
for (const auto& out_name : node_proto.output()) {
if (m_cache->contains(out_name)) {
auto node_to_replace_input = m_cache->get_node(out_name);
replace_input_from_parent_scope_with_parameter(
in_name,
from_parent_node,
node_to_replace_input.get_node()->input(input_index));
}
}
}
}
++input_index;
}
// Nodes with subgraphs (like Loop or If) can have implicit inputs (so their subgraphs depend on nodes from
// parent) Those implicit inputs are not present in `node_proto.input()` list so to get them, we need to fetch
// node's nGraph representation and then we can match those inputs with parent nodes
for (const auto& out_name : node_proto.output()) {
if (m_cache->contains(out_name)) {
auto node_to_replace_input = m_cache->get_node(out_name).get_node();
if (!ov::is_type<op::util::MultiSubGraphOp>(node_to_replace_input) &&
!ov::is_type<frontend::ONNXSubgraphFrameworkNode>(node_to_replace_input))
continue;
auto inputs = node_to_replace_input->input_values();
for (size_t i = 0; i < inputs.size(); i++) {
const auto& input = inputs.at(i);
auto input_node = input.get_node();
if (op::is_constant(input_node))
continue;
const auto& in_name = input_node->get_friendly_name();
if (m_parent_graph->is_ng_node_in_cache(in_name)) {
const auto& from_parent_node = m_parent_graph->get_ng_node_from_cache(in_name);
replace_input_from_parent_scope_with_parameter(in_name,
from_parent_node,
node_to_replace_input->input(i));
}
}
}
}
}
OutputVector Subgraph::make_ng_nodes(const Node& onnx_node) {
replace_input_from_parent_scope_with_parameter(onnx_node);
return Graph::make_ng_nodes(onnx_node);
}
std::shared_ptr<Function> Subgraph::convert() {
convert_to_ngraph_nodes();
find_inputs_from_parent();
return create_function();
}
void Subgraph::decode_to_framework_nodes() {
Graph::decode_to_framework_nodes();
find_inputs_from_parent();
}
const std::vector<Output<ngraph::Node>> Subgraph::get_inputs_from_parent() const {
OutputVector result;
for (const auto& name : m_inputs_from_parent) {
@ -440,6 +380,30 @@ void Subgraph::infer_inputs_from_parent() {
}
}
OutputVector Subgraph::make_framework_nodes(const Node& onnx_node) {
replace_input_from_parent_scope_with_parameter(onnx_node);
return Graph::make_framework_nodes(onnx_node);
}
void Subgraph::replace_input_from_parent_scope_with_parameter(const Node& onnx_node) {
for (std::size_t i = 0; i < onnx_node.get_inputs_size(); ++i) {
const auto& in_name = onnx_node.input(i);
if (m_parent_graph->is_ng_node_in_cache(in_name) &&
std::find(m_inputs_from_parent.begin(), m_inputs_from_parent.end(), in_name) ==
m_inputs_from_parent.end()) {
const auto& from_parent_node = m_parent_graph->get_ng_node_from_cache(in_name);
if (op::is_constant(from_parent_node.get_node()))
continue;
auto new_param = std::make_shared<ngraph::op::Parameter>(from_parent_node.get_element_type(),
from_parent_node.get_partial_shape());
m_parameter_to_parent_node_map.insert({new_param, in_name});
m_cache->emplace_node(in_name, new_param);
m_parameters.push_back(new_param);
m_inputs_from_parent.push_back(in_name);
}
}
}
} // namespace onnx_import
} // namespace ngraph

View File

@ -41,7 +41,7 @@ public:
}
virtual bool is_ng_node_in_cache(const std::string& name) const;
virtual Output<ngraph::Node> get_ng_node_from_cache(const std::string& name) const;
OutputVector make_ng_nodes(const Node& onnx_node) const;
virtual OutputVector make_ng_nodes(const Node& onnx_node);
const OpsetImports& get_opset_imports() const;
virtual ~Graph() = default;
@ -57,7 +57,8 @@ protected:
void set_friendly_names(const Node& onnx_node, const OutputVector& ng_subgraph_outputs) const;
protected:
virtual void decode_to_framework_nodes();
virtual OutputVector make_framework_nodes(const Node& onnx_node);
void decode_to_framework_nodes();
void convert_to_ngraph_nodes();
void remove_dangling_parameters();
std::shared_ptr<Function> create_function();
@ -98,19 +99,13 @@ public:
bool is_ng_node_in_cache(const std::string& name) const override;
Output<ngraph::Node> get_ng_node_from_cache(const std::string& name) const override;
OutputVector make_ng_nodes(const Node& onnx_node) override;
void infer_inputs_from_parent();
private:
void decode_to_framework_nodes() override;
void find_inputs_from_parent();
/// \brief Replaces current node's input with Parameter if that input comes from parent graph scope
///
/// \param[in] in_name input node name
/// \param[in] from_parent_node nGraph node from parent scope
/// \param[in] node_to_replace_input nGraph input node to be replaced
void replace_input_from_parent_scope_with_parameter(const std::string& in_name,
const Output<ngraph::Node>& from_parent_node,
Input<ngraph::Node>&& node_to_replace_input);
OutputVector make_framework_nodes(const Node& onnx_node) override;
/// \brief Checks if onnx_node has inputs from parent graph and replaces those inputs with Parameters
void replace_input_from_parent_scope_with_parameter(const Node& onnx_node);
const Graph* m_parent_graph;
std::vector<std::string> m_inputs_from_parent;

View File

@ -50,6 +50,8 @@ public:
const std::string& description() const;
const std::vector<std::reference_wrapper<const std::string>>& get_output_names() const;
const std::string& input(int index) const;
std::size_t get_inputs_size() const;
const std::string& output(int index) const;
std::size_t get_outputs_size() const;
@ -103,6 +105,14 @@ const std::vector<std::reference_wrapper<const std::string>>& Node::Impl::get_ou
return m_output_names;
}
const std::string& Node::Impl::input(int index) const {
return m_node_proto->input(index);
}
std::size_t Node::Impl::get_inputs_size() const {
return m_node_proto->input_size();
}
const std::string& Node::Impl::output(int index) const {
return m_node_proto->output(index);
}
@ -110,6 +120,7 @@ const std::string& Node::Impl::output(int index) const {
std::size_t Node::Impl::get_outputs_size() const {
return m_output_names.size();
}
bool Node::Impl::has_attribute(const std::string& name) const {
auto it = std::find_if(std::begin(m_attributes), std::end(m_attributes), [&](const Attribute& attribute) {
return attribute.get_name() == name;
@ -223,12 +234,22 @@ const std::vector<std::reference_wrapper<const std::string>>& Node::get_output_n
return m_pimpl->get_output_names();
}
const std::string& Node::input(int index) const {
return m_pimpl->input(index);
}
std::size_t Node::get_inputs_size() const {
return m_pimpl->get_inputs_size();
}
const std::string& Node::output(int index) const {
return m_pimpl->output(index);
}
std::size_t Node::get_outputs_size() const {
return m_pimpl->get_outputs_size();
}
bool Node::has_attribute(const std::string& name) const {
return m_pimpl->has_attribute(name);
}