Add high level API to ONNX Editor (#4927)

This commit is contained in:
Mateusz Bencer 2021-05-17 10:11:49 +02:00 committed by GitHub
parent d5a8861475
commit 880b45a770
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 1017 additions and 5 deletions

View File

@ -0,0 +1,112 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <map>
#include <string>
#include <vector>
#include "onnx_editor/editor_types.hpp"
namespace ONNX_NAMESPACE
{
// Forward declaration to avoid the necessity of including paths in components
// that don't directly depend on the ONNX library
class GraphProto;
} // namespace ONNX_NAMESPACE
namespace ngraph
{
namespace onnx_editor
{
/// \brief A class which allows specifying InputEdge and OutputEdge by user-friendly ONNX
/// names.
class EdgeMapper
{
public:
EdgeMapper() = default;
/// \brief Creates an edge mapper based on a GraphProto object.
///
/// \note If state of graph_proto will be changed, the information from edge mapper
/// is outdated. In such a case the update method should be called.
///
/// \param graph_proto Reference to a GraphProto object.
EdgeMapper(const ONNX_NAMESPACE::GraphProto& graph_proto);
/// \brief Returns the InputEdge based on a node (node name or output name)
/// and an input (input name or input index).
///
/// \note The node name can be ambiguous (many ONNX nodes can have the same name).
/// In such a case the algorthim tries to match the given node name
/// with the input name (providing an input index is not enough).
/// If a unique edge is found, it will be returned.
/// If InputEdge cannot be determined based on parameter values an ngraph_error
/// exception will be thrown.
///
/// \param node An EditorNode helper structure created based on a node name
/// or a node output name.
///
/// \param input An EditorInput helper structure created based on a input name
/// or a input index.
InputEdge find_input_edge(const EditorNode& node, const EditorInput& input) const;
/// \brief Returns an OutputEdge based on a node (node name or output name)
/// and an output (output name or output index).
///
/// \note The node name can be ambiguous (many ONNX nodes can have the same name).
/// In such a case the algorthim will try to match the given node name
/// with the output name (providing an output index is not enough).
/// If after such operation a found edge is unique, it is returned.
/// If OutputEdge cannot be determined based on given params the ngraph_error
/// exception is thrown.
///
/// \param node An EditorNode helper structure created based on a node name
/// or a node output name.
///
/// \param output An EditorOutput helper structure created based on a output name
/// or a output index.
OutputEdge find_output_edge(const EditorNode& node, const EditorOutput& output) const;
/// \brief Returns an OutputEdge based on a output name.
///
/// \note The output name guarantees the uniqueness of the edge.
///
/// \param output_name A node output name.
///
OutputEdge find_output_edge(const std::string& output_name) const;
/// \brief Returns a vector of InputEdges which consume an output of a node
/// determined by provided output name.
///
/// \note The output name is deterministic in the ONNX standard.
///
/// \param output_name A node output name.
///
std::vector<InputEdge> find_output_consumers(const std::string& output_name) const;
/// \brief Returns true if a provided node is correct (exists in a graph)
/// and is not ambiguous (identification of an ONNX node can be ambiguous
/// if an only tensor name is provided).
///
/// \param node An EditorNode helper structure created based on a node name
/// or a node output name.
///
bool is_correct_and_unambiguous_node(const EditorNode& node) const;
private:
std::vector<int> find_node_indexes(const std::string& node_name,
const std::string& output_name) const;
std::string get_node_input_name(int node_index, int input_index) const;
std::string get_node_output_name(int node_index, int output_index) const;
std::vector<std::vector<std::string>> m_node_inputs;
std::vector<std::vector<std::string>> m_node_outputs;
std::multimap<std::string, int> m_node_name_to_index;
std::map<std::string, int> m_node_output_name_to_index;
std::multimap<std::string, int> m_output_consumers_index;
};
} // namespace onnx_editor
} // namespace ngraph

View File

@ -108,7 +108,69 @@ namespace ngraph
/// \param out_file_path A path to the file where the modified model should be dumped.
void serialize(const std::string& out_file_path) const;
/// \brief Returns the InputEdge based on a node (node name or output name)
/// and an input (input name or input index).
///
/// \note The node name can be ambiguous (many ONNX nodes can have the same name).
/// In such a case the algorthim tries to match the given node name
/// with the input name (providing an input index is not enough).
/// If a unique edge is found, it will be returned.
/// If InputEdge cannot be determined based on parameter values an ngraph_error
/// exception will be thrown.
///
/// \param node A node helper structure created based on a node name
/// or a node output name.
///
/// \param input An input helper structure created based on a input name
/// or a input index.
InputEdge find_input_edge(const EditorNode& node, const EditorInput& input) const;
/// \brief Returns an OutputEdge based on a node (node name or output name)
/// and an output (output name or output index).
///
/// \note The node name can be ambiguous (many ONNX nodes can have the same name).
/// In such a case the algorthim will try to match the given node name
/// with the output name (providing an output index is not enough).
/// If after such operation a found edge is unique, it is returned.
/// If OutputEdge cannot be determined based on given params the ngraph_error
/// exception is thrown.
///
/// \param node A node helper structure created based on a node name
/// or a node output name.
///
/// \param output A output helper structure created based on a output name
/// or a output index.
OutputEdge find_output_edge(const EditorNode& node, const EditorOutput& output) const;
/// \brief Returns an OutputEdge based on a output name.
///
/// \note The output name guarantees the uniqueness of the edge.
///
/// \param output_name A node output name.
///
OutputEdge find_output_edge(const std::string& output_name) const;
/// \brief Returns a vector of InputEdges which consume an output of a node
/// determined by provided output name.
///
/// \note The output name is deterministic in the ONNX standard.
///
/// \param output_name A node output name.
///
std::vector<InputEdge> find_output_consumers(const std::string& output_name) const;
/// \brief Returns a vector of InputEdges which consume an output of a node
/// determined by provided output name.
///
/// \note The output name is deterministic in the ONNX standard.
///
/// \param output_name A node output name.
///
bool is_correct_and_unambiguous_node(const EditorNode& node) const;
private:
void update_mapper_if_needed() const;
const std::string m_model_path;
struct Impl;

View File

@ -60,5 +60,74 @@ namespace ngraph
/// OutputEdge(5, "out1")
/// OutputEdge(5, "out2")
using OutputEdge = Edge<EdgeType::OUTPUT>;
/// \brief Specifies a single node input by the name or index.
///
/// For a node test_node, with 3 inputs:
///
/// ----(in_A)----> +-----------+
/// ----(in_B)----> | test_node | ----(out)---->
/// ----(in_C)----> +-----------+
/// You can indicate in_B as EditorInput("in_B") or EditorInput(1)
struct EditorInput
{
EditorInput() = delete;
EditorInput(std::string input_name)
: m_input_name{std::move(input_name)}
{
}
EditorInput(const int input_index)
: m_input_index{input_index}
{
}
const std::string m_input_name = "";
const int m_input_index = -1;
};
/// \brief Specifies a single node output by the name or index.
/// For a node test_node, with 2 outputs:
///
/// +-----------+ ---(out1)--->
/// ----(in_A)----> | test_node |
/// +-----------+ ---(out2)--->
/// You can indicate out2 as EditorOutput("out2") or EditorOutput(1)
struct EditorOutput
{
EditorOutput() = delete;
EditorOutput(std::string output_name)
: m_output_name{std::move(output_name)}
{
}
EditorOutput(const int output_index)
: m_output_index{output_index}
{
}
const std::string m_output_name = "";
const int m_output_index = -1;
};
/// \brief Specifies a single node by output name which is determinitic
/// or node name which can be ambiguous.
/// For a node test_node, with 2 outputs:
///
/// +-----------+ ---(out1)--->
/// ----(in_A)----> | test_node |
/// +-----------+ ---(out2)--->
/// You can indicate test_node by name as EditorNode("test_node")
/// or by assigned output as EditorNode(EditorOutput("out1"))
/// or EditorNode(EditorOutput("out2"))
struct EditorNode
{
EditorNode(std::string node_name)
: m_node_name{std::move(node_name)}
{
}
EditorNode(EditorOutput output)
: m_output_name{std::move(output.m_output_name)}
{
}
const std::string m_node_name = "";
const std::string m_output_name = "";
};
} // namespace onnx_editor
} // namespace ngraph

View File

@ -0,0 +1,243 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <algorithm>
#include <onnx/onnx_pb.h>
#include "ngraph/except.hpp"
#include "onnx_editor/edge_mapper.hpp"
using namespace ngraph;
using namespace ngraph::onnx_editor;
onnx_editor::EdgeMapper::EdgeMapper(const ONNX_NAMESPACE::GraphProto& graph_proto)
: m_node_inputs(graph_proto.node().size())
, m_node_outputs(graph_proto.node().size())
{
int topological_index = 0;
for (const auto& node_proto : graph_proto.node())
{
for (const auto& out_name : node_proto.output())
{
// node output name is unique
m_node_output_name_to_index.emplace(out_name, topological_index);
m_node_outputs[topological_index].push_back(out_name);
}
for (const auto& in_name : node_proto.input())
{
m_node_inputs[topological_index].push_back(in_name);
m_output_consumers_index.emplace(in_name, topological_index);
}
if (!node_proto.name().empty())
{
// node name can identify node, but it can be ambiguous
m_node_name_to_index.emplace(node_proto.name(), topological_index);
}
++topological_index;
}
}
std::vector<int> onnx_editor::EdgeMapper::find_node_indexes(const std::string& node_name,
const std::string& output_name) const
{
if (!output_name.empty())
{
const auto& index_iter = m_node_output_name_to_index.find(output_name);
if (index_iter != std::end(m_node_output_name_to_index))
{
return std::vector<int>{index_iter->second};
}
}
std::vector<int> result;
if (!node_name.empty())
{
const auto matched_nodes_range = m_node_name_to_index.equal_range(node_name);
std::transform(matched_nodes_range.first,
matched_nodes_range.second,
std::back_inserter(result),
[](const std::pair<std::string, int>& iter) { return iter.second; });
}
return result;
};
std::string onnx_editor::EdgeMapper::get_node_output_name(int node_index, int output_index) const
{
if (node_index >= static_cast<int>(m_node_outputs.size()))
{
throw ngraph_error("Node with index: " + std::to_string(node_index) +
"is out of scope outputs list");
}
if (output_index >= static_cast<int>(m_node_outputs[node_index].size()))
{
throw ngraph_error("Node with index: " + std::to_string(node_index) +
" has not output with index: " + std::to_string(output_index));
}
const auto output_name = m_node_outputs[node_index][output_index];
return output_name;
}
std::string onnx_editor::EdgeMapper::get_node_input_name(int node_index, int input_index) const
{
if (node_index >= static_cast<int>(m_node_inputs.size()))
{
throw ngraph_error("Node with index: " + std::to_string(node_index) +
"is out of scope inputs list");
}
if (input_index >= static_cast<int>(m_node_inputs[node_index].size()))
{
throw ngraph_error("Node with index: " + std::to_string(node_index) +
" has not input with index: " + std::to_string(input_index));
}
const auto input_name = m_node_inputs[node_index][input_index];
return input_name;
}
InputEdge onnx_editor::EdgeMapper::find_input_edge(const EditorNode& node,
const EditorInput& in) const
{
// identification can be both based on node name and output name
const auto& node_indexes = find_node_indexes(node.m_node_name, node.m_output_name);
int node_index = -1;
if (node_indexes.size() == 1)
{
node_index = node_indexes[0];
}
else if (node_indexes.empty())
{
throw ngraph_error(
"Node with name: " + (node.m_node_name.empty() ? "not_given" : node.m_node_name) +
" and output_name: " + (node.m_output_name.empty() ? "not_given" : node.m_output_name) +
" was not found");
}
else if (!in.m_input_name
.empty()) // input indexes are not deterministic if a node name is ambiguous
{
// many nodes with the same name
// check if some of found index matches input name
int matched_inputs_number = 0;
for (const auto& index : node_indexes)
{
if (std::count(std::begin(m_node_inputs[index]),
std::end(m_node_inputs[index]),
in.m_input_name) > 0)
{
node_index = index;
++matched_inputs_number;
}
}
if (matched_inputs_number == 0)
{
throw ngraph_error("Input edge described by: " + node.m_node_name +
" and input name: " + in.m_input_name + " was not found");
}
if (matched_inputs_number > 1)
{
throw ngraph_error("Given node name: " + node.m_node_name + " and input name: " +
in.m_input_name + " are ambiguous to determine input edge");
}
}
else
{
throw ngraph_error("Given node name: " + node.m_node_name +
" and input index: " + std::to_string(in.m_input_index) +
" are ambiguous to determine input edge");
}
if (!in.m_input_name.empty())
{
return InputEdge{node_index, in.m_input_name};
}
else if (in.m_input_index != -1) // input index is set
{
const auto& input_name = get_node_input_name(node_index, in.m_input_index);
return InputEdge{node_index, input_name};
}
else
{
throw ngraph_error("Not enough information to determine input edge");
}
}
OutputEdge onnx_editor::EdgeMapper::find_output_edge(const EditorNode& node,
const EditorOutput& out) const
{
// identification can be both based on node name and output name
const auto& node_indexes = find_node_indexes(node.m_node_name, node.m_output_name);
int node_index = -1;
if (node_indexes.size() == 1)
{
node_index = node_indexes[0];
}
else if (node_indexes.empty())
{
throw ngraph_error(
"Node with name: " + (node.m_node_name.empty() ? "not_given" : node.m_node_name) +
" and output_name: " + (node.m_output_name.empty() ? "not_given" : node.m_output_name) +
" was not found");
}
else if (!out.m_output_name
.empty()) // output indexes are not deterministic if a node name is ambiguous
{
// many nodes with the same name
// check if some of found index matches output name
int matched_outputs_number = 0;
for (const auto& index : node_indexes)
{
if (std::count(std::begin(m_node_outputs[index]),
std::end(m_node_outputs[index]),
out.m_output_name) > 0)
{
node_index = index;
++matched_outputs_number;
}
}
if (matched_outputs_number == 0)
{
throw ngraph_error("Output edge described by: " + node.m_node_name +
" and output name: " + out.m_output_name + " was not found");
}
}
else
{
throw ngraph_error("Given node name: " + node.m_node_name +
" and output index: " + std::to_string(out.m_output_index) +
" are ambiguous to determine output edge");
}
if (!out.m_output_name.empty())
{
return OutputEdge{node_index, out.m_output_name};
}
else if (out.m_output_index != -1) // output index is set
{
const auto& output_name = get_node_output_name(node_index, out.m_output_index);
return OutputEdge{node_index, output_name};
}
else
{
throw ngraph_error("Not enough information to determine output edge");
}
}
OutputEdge onnx_editor::EdgeMapper::find_output_edge(const std::string& output_name) const
{
return find_output_edge(EditorNode{EditorOutput{output_name}}, EditorOutput{output_name});
}
std::vector<InputEdge>
onnx_editor::EdgeMapper::find_output_consumers(const std::string& output_name) const
{
const auto matched_nodes_range = m_output_consumers_index.equal_range(output_name);
std::vector<InputEdge> input_edges;
std::transform(matched_nodes_range.first,
matched_nodes_range.second,
std::back_inserter(input_edges),
[&output_name](const std::pair<std::string, int>& iter) {
return InputEdge{iter.second, output_name};
});
return input_edges;
}
bool onnx_editor::EdgeMapper::is_correct_and_unambiguous_node(const EditorNode& node) const
{
return find_node_indexes(node.m_node_name, node.m_output_name).size() == 1;
}

View File

@ -10,10 +10,12 @@
#include "ngraph/log.hpp"
#include "onnx_common/parser.hpp"
#include "onnx_common/utils.hpp"
#include "onnx_editor/edge_mapper.hpp"
#include "onnx_editor/editor.hpp"
#include "onnx_import/utils/onnx_internal.hpp"
using namespace ngraph;
using namespace ngraph::onnx_editor;
namespace
{
@ -186,6 +188,8 @@ namespace
struct onnx_editor::ONNXModelEditor::Impl
{
ONNX_NAMESPACE::ModelProto m_model_proto;
EdgeMapper m_edge_mapper;
bool m_is_mapper_updated = false;
Impl() = delete;
@ -285,6 +289,7 @@ void onnx_editor::ONNXModelEditor::cut_graph_fragment(const std::vector<InputEdg
editor.extract_subgraph(outputs);
m_pimpl->remove_shape_inference_info();
m_pimpl->m_is_mapper_updated = false;
}
std::vector<std::string> onnx_editor::ONNXModelEditor::model_inputs() const
@ -344,3 +349,45 @@ void onnx_editor::ONNXModelEditor::set_input_values(
modify_initializer(*onnx_initializer, name, values, onnx_input);
}
}
void onnx_editor::ONNXModelEditor::update_mapper_if_needed() const
{
if (!m_pimpl->m_is_mapper_updated)
{
m_pimpl->m_edge_mapper = EdgeMapper(m_pimpl->m_model_proto.graph());
}
m_pimpl->m_is_mapper_updated = true;
}
InputEdge onnx_editor::ONNXModelEditor::find_input_edge(const EditorNode& node,
const EditorInput& input) const
{
update_mapper_if_needed();
return m_pimpl->m_edge_mapper.find_input_edge(node, input);
}
OutputEdge onnx_editor::ONNXModelEditor::find_output_edge(const EditorNode& node,
const EditorOutput& input) const
{
update_mapper_if_needed();
return m_pimpl->m_edge_mapper.find_output_edge(node, input);
}
OutputEdge onnx_editor::ONNXModelEditor::find_output_edge(const std::string& output_name) const
{
update_mapper_if_needed();
return m_pimpl->m_edge_mapper.find_output_edge(output_name);
}
std::vector<InputEdge>
onnx_editor::ONNXModelEditor::find_output_consumers(const std::string& output_name) const
{
update_mapper_if_needed();
return m_pimpl->m_edge_mapper.find_output_consumers(output_name);
}
bool onnx_editor::ONNXModelEditor::is_correct_and_unambiguous_node(const EditorNode& node) const
{
update_mapper_if_needed();
return m_pimpl->m_edge_mapper.is_correct_and_unambiguous_node(node);
}

View File

@ -8,7 +8,7 @@ graph {
input: "conv1/7x7_s2_w_0"
input: "conv1/7x7_s2_b_0"
output: "conv1/7x7_s2_1"
name: ""
name: "conv1"
op_type: "Conv"
attribute {
name: "strides"
@ -34,13 +34,13 @@ graph {
node {
input: "conv1/7x7_s2_1"
output: "conv1/7x7_s2_2"
name: ""
name: "relu1"
op_type: "Relu"
}
node {
input: "conv1/7x7_s2_2"
output: "pool1/3x3_s2_1"
name: ""
name: "maxpool1"
op_type: "MaxPool"
attribute {
name: "strides"

View File

@ -5,12 +5,14 @@ graph {
input: "in1"
output: "relu1"
op_type: "Relu"
name: "relu1_name"
}
node {
input: "relu1"
input: "in2"
output: "add1"
op_type: "Add"
name: "add_ambiguous_name"
}
node {
input: "in3"
@ -23,12 +25,14 @@ graph {
input: "add1"
output: "add2"
op_type: "Add"
name: "add_ambiguous_name"
}
node {
input: "add1"
input: "conv1"
output: "mul2"
op_type: "Mul"
name: ""
}
node {
input: "add2"
@ -40,6 +44,7 @@ graph {
i: 1
type: INT
}
name: "split_name"
}
node {
input: "relu1"

View File

@ -5,6 +5,7 @@ graph {
input: "in1"
output: "relu1"
op_type: "Relu"
name: "add1"
}
node {
input: "in1"
@ -20,12 +21,14 @@ graph {
input: "in2"
output: "relu4"
op_type: "Relu"
name: "relu4_name"
}
node {
input: "relu1"
input: "relu2"
output: "add1"
op_type: "Add"
name: "add1_name"
}
node {
input: "relu2"

View File

@ -21,8 +21,7 @@
NGRAPH_SUPPRESS_DEPRECATED_START
using namespace ngraph;
using namespace ngraph::onnx_import;
using namespace ngraph::onnx_editor;
using namespace onnx_editor;
using namespace ngraph::test;
static std::string s_manifest = "${MANIFEST}";
@ -658,6 +657,478 @@ NGRAPH_TEST(onnx_editor, subgraph__inputs_getter)
EXPECT_EQ(editor.model_inputs(), (std::vector<std::string>{"conv1/7x7_s2_1"}));
}
// HIGHT LEVEL API TESTS
// INPUT EDGES TEST
NGRAPH_TEST(onnx_editor, editor_api_select_input_edge_by_output_name_and_input_name)
{
ONNXModelEditor editor{file_util::path_join(
SERIALIZED_ZOO, "onnx/model_editor/subgraph__inception_head.prototxt")};
const InputEdge edge = editor.find_input_edge(EditorNode{EditorOutput{"conv1/7x7_s2_2"}},
EditorInput{"conv1/7x7_s2_1"});
EXPECT_EQ(edge.m_node_idx, 1);
EXPECT_EQ(edge.m_tensor_name, "conv1/7x7_s2_1");
const InputEdge edge2 = editor.find_input_edge(EditorNode{EditorOutput{"conv1/7x7_s2_1"}},
EditorInput{"data_0"});
EXPECT_EQ(edge2.m_node_idx, 0);
EXPECT_EQ(edge2.m_tensor_name, "data_0");
}
NGRAPH_TEST(onnx_editor, editor_api_select_input_edge_by_output_name_and_input_index)
{
ONNXModelEditor editor{file_util::path_join(
SERIALIZED_ZOO, "onnx/model_editor/subgraph__inception_head.prototxt")};
const InputEdge edge =
editor.find_input_edge(EditorNode{EditorOutput{"conv1/7x7_s2_2"}}, EditorInput{0});
EXPECT_EQ(edge.m_node_idx, 1);
EXPECT_EQ(edge.m_tensor_name, "conv1/7x7_s2_1");
const InputEdge edge2 =
editor.find_input_edge(EditorNode{EditorOutput{"conv1/7x7_s2_1"}}, EditorInput{1});
EXPECT_EQ(edge2.m_node_idx, 0);
EXPECT_EQ(edge2.m_tensor_name, "conv1/7x7_s2_w_0");
const InputEdge edge3 =
editor.find_input_edge(EditorNode{EditorOutput{"conv1/7x7_s2_1"}}, EditorInput{2});
EXPECT_EQ(edge3.m_node_idx, 0);
EXPECT_EQ(edge3.m_tensor_name, "conv1/7x7_s2_b_0");
}
NGRAPH_TEST(onnx_editor, editor_api_select_input_edge_by_node_name_and_input_name)
{
ONNXModelEditor editor{file_util::path_join(
SERIALIZED_ZOO, "onnx/model_editor/subgraph__inception_head.prototxt")};
const InputEdge edge =
editor.find_input_edge(EditorNode{"relu1"}, EditorInput{"conv1/7x7_s2_1"});
EXPECT_EQ(edge.m_node_idx, 1);
EXPECT_EQ(edge.m_tensor_name, "conv1/7x7_s2_1");
const InputEdge edge2 =
editor.find_input_edge(EditorNode{"conv1"}, EditorInput{"conv1/7x7_s2_w_0"});
EXPECT_EQ(edge2.m_node_idx, 0);
EXPECT_EQ(edge2.m_tensor_name, "conv1/7x7_s2_w_0");
}
NGRAPH_TEST(onnx_editor, editor_api_select_input_edge_by_node_name_and_input_index)
{
ONNXModelEditor editor{file_util::path_join(
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
const InputEdge edge = editor.find_input_edge(EditorNode{"relu1_name"}, EditorInput{0});
EXPECT_EQ(edge.m_node_idx, 0);
EXPECT_EQ(edge.m_tensor_name, "in1");
const InputEdge edge2 = editor.find_input_edge(EditorNode{"split_name"}, EditorInput{0});
EXPECT_EQ(edge2.m_node_idx, 5);
EXPECT_EQ(edge2.m_tensor_name, "add2");
}
NGRAPH_TEST(onnx_editor, editor_api_select_input_edge_empty_node_name)
{
ONNXModelEditor editor{file_util::path_join(
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
try
{
const InputEdge edge =
editor.find_input_edge(EditorNode{""}, EditorInput{"conv1/7x7_s2_1"});
}
catch (const std::exception& e)
{
std::string msg{e.what()};
EXPECT_TRUE(
msg.find("Node with name: not_given and output_name: not_given was not found") !=
std::string::npos);
}
}
// OUTPUT EDGES TEST
NGRAPH_TEST(onnx_editor, editor_api_select_output_edge_by_output_name)
{
ONNXModelEditor editor{file_util::path_join(
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
const OutputEdge edge =
editor.find_output_edge(EditorNode{EditorOutput{"mul2"}}, EditorOutput{"mul2"});
EXPECT_EQ(edge.m_node_idx, 4);
EXPECT_EQ(edge.m_tensor_name, "mul2");
const OutputEdge edge2 =
editor.find_output_edge(EditorNode{EditorOutput{"split1"}}, EditorOutput{"split2"});
EXPECT_EQ(edge2.m_node_idx, 5);
EXPECT_EQ(edge2.m_tensor_name, "split2");
// simplified overload
const OutputEdge edge3 =
editor.find_output_edge("mul2");
EXPECT_EQ(edge3.m_node_idx, 4);
EXPECT_EQ(edge3.m_tensor_name, "mul2");
const OutputEdge edge4 =
editor.find_output_edge("split2");
EXPECT_EQ(edge4.m_node_idx, 5);
EXPECT_EQ(edge4.m_tensor_name, "split2");
}
NGRAPH_TEST(onnx_editor, editor_api_select_output_edge_by_output_name_and_output_index)
{
ONNXModelEditor editor{file_util::path_join(
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
const OutputEdge edge =
editor.find_output_edge(EditorNode{EditorOutput{"add2"}}, EditorOutput{0});
EXPECT_EQ(edge.m_node_idx, 3);
EXPECT_EQ(edge.m_tensor_name, "add2");
const OutputEdge edge2 =
editor.find_output_edge(EditorNode{EditorOutput{"split1"}}, EditorOutput{1});
EXPECT_EQ(edge2.m_node_idx, 5);
EXPECT_EQ(edge2.m_tensor_name, "split2");
const OutputEdge edge3 =
editor.find_output_edge(EditorNode{EditorOutput{"split2"}}, EditorOutput{0});
EXPECT_EQ(edge3.m_node_idx, 5);
EXPECT_EQ(edge3.m_tensor_name, "split1");
}
NGRAPH_TEST(onnx_editor, editor_api_select_output_edge_by_node_name_and_output_name)
{
ONNXModelEditor editor{file_util::path_join(
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
const OutputEdge edge =
editor.find_output_edge(EditorNode{"relu1_name"}, EditorOutput{"relu1"});
EXPECT_EQ(edge.m_node_idx, 0);
EXPECT_EQ(edge.m_tensor_name, "relu1");
const OutputEdge edge2 =
editor.find_output_edge(EditorNode{"split_name"}, EditorOutput{"split2"});
EXPECT_EQ(edge2.m_node_idx, 5);
EXPECT_EQ(edge2.m_tensor_name, "split2");
}
NGRAPH_TEST(onnx_editor, editor_api_select_output_edge_by_node_name_and_output_index)
{
ONNXModelEditor editor{file_util::path_join(
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
const OutputEdge edge = editor.find_output_edge(EditorNode{"relu1_name"}, EditorOutput{0});
EXPECT_EQ(edge.m_node_idx, 0);
EXPECT_EQ(edge.m_tensor_name, "relu1");
const OutputEdge edge2 = editor.find_output_edge(EditorNode{"split_name"}, EditorOutput{1});
EXPECT_EQ(edge2.m_node_idx, 5);
EXPECT_EQ(edge2.m_tensor_name, "split2");
}
NGRAPH_TEST(onnx_editor, editor_api_select_edge_const_network)
{
ONNXModelEditor editor{file_util::path_join(
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests_2.prototxt")};
const InputEdge edge =
editor.find_input_edge(EditorNode{EditorOutput{"relu4"}}, EditorInput{0});
EXPECT_EQ(edge.m_node_idx, 3);
EXPECT_EQ(edge.m_tensor_name, "in2");
const OutputEdge edge2 = editor.find_output_edge(EditorNode{"relu4_name"}, EditorOutput{0});
EXPECT_EQ(edge2.m_node_idx, 3);
EXPECT_EQ(edge2.m_tensor_name, "relu4");
const OutputEdge edge3 = editor.find_output_edge(EditorNode{"add1_name"}, EditorOutput{0});
EXPECT_EQ(edge3.m_node_idx, 4);
EXPECT_EQ(edge3.m_tensor_name, "add1");
}
NGRAPH_TEST(onnx_editor, editor_api_select_edge_error_handling)
{
ONNXModelEditor editor{file_util::path_join(
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests_2.prototxt")};
// node with given output name not found
try
{
const InputEdge edge =
editor.find_input_edge(EditorNode{EditorOutput{"not_existed"}}, EditorInput{0});
}
catch (const std::exception& e)
{
std::string msg{e.what()};
EXPECT_TRUE(
msg.find("Node with name: not_given and output_name: not_existed was not found") !=
std::string::npos);
}
// node with given name not found
try
{
const InputEdge edge = editor.find_input_edge(EditorNode{"not_existed"}, EditorInput{0});
}
catch (const std::exception& e)
{
std::string msg{e.what()};
EXPECT_TRUE(
msg.find("Node with name: not_existed and output_name: not_given was not found") !=
std::string::npos);
}
// input index out of scope
try
{
const InputEdge edge = editor.find_input_edge(EditorNode{"relu4_name"}, EditorInput{1});
}
catch (const std::exception& e)
{
std::string msg{e.what()};
EXPECT_TRUE(msg.find("Node with index: 3 has not input with index: 1") !=
std::string::npos);
}
// output index out of scope
try
{
const OutputEdge edge =
editor.find_output_edge(EditorNode{"relu4_name"}, EditorOutput{1});
}
catch (const std::exception& e)
{
std::string msg{e.what()};
EXPECT_TRUE(msg.find("Node with index: 3 has not output with index: 1") !=
std::string::npos);
}
}
// Nodes with ambiguous node names tests
NGRAPH_TEST(onnx_editor, editor_api_select_input_edge_by_ambiguous_node_name_but_matched_input)
{
ONNXModelEditor editor{file_util::path_join(
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
InputEdge edge = editor.find_input_edge(EditorNode{"add_ambiguous_name"}, EditorInput{"in2"});
EXPECT_EQ(edge.m_node_idx, 1);
EXPECT_EQ(edge.m_tensor_name, "in2");
const InputEdge edge2 = editor.find_input_edge(EditorNode{"add_ambiguous_name"}, EditorInput{"add1"});
EXPECT_EQ(edge2.m_node_idx, 3);
EXPECT_EQ(edge2.m_tensor_name, "add1");
}
NGRAPH_TEST(onnx_editor, editor_api_select_input_edge_by_ambiguous_node_name_and_not_matched_input)
{
ONNXModelEditor editor{file_util::path_join(
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
try
{
const InputEdge edge = editor.find_input_edge(EditorNode{"add_ambiguous_name"}, EditorInput{"in3"});
}
catch (const std::exception& e)
{
std::string msg{e.what()};
EXPECT_TRUE(msg.find("Input edge described by: add_ambiguous_name and input name: in3 was not found") !=
std::string::npos);
}
try
{
const InputEdge edge = editor.find_input_edge(EditorNode{"add_ambiguous_name"}, EditorInput{"relu1"});
}
catch (const std::exception& e)
{
std::string msg{e.what()};
EXPECT_TRUE(msg.find("Given node name: add_ambiguous_name and input name: relu1 are ambiguous to determine input edge") !=
std::string::npos);
}
}
NGRAPH_TEST(onnx_editor, editor_api_select_input_edge_by_ambiguous_node_name_and_input_index)
{
ONNXModelEditor editor{file_util::path_join(
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
try
{
const InputEdge edge = editor.find_input_edge(EditorNode{"add_ambiguous_name"}, EditorInput{0});
}
catch (const std::exception& e)
{
std::string msg{e.what()};
EXPECT_TRUE(msg.find("Given node name: add_ambiguous_name and input index: 0 are ambiguous to determine input edge") !=
std::string::npos);
}
}
NGRAPH_TEST(onnx_editor, editor_api_select_output_edge_by_ambiguous_node_name_but_matched_output)
{
ONNXModelEditor editor{file_util::path_join(
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
const OutputEdge edge = editor.find_output_edge(EditorNode{"add_ambiguous_name"}, EditorOutput{"add1"});
EXPECT_EQ(edge.m_node_idx, 1);
EXPECT_EQ(edge.m_tensor_name, "add1");
const OutputEdge edge2 = editor.find_output_edge(EditorNode{"add_ambiguous_name"}, EditorOutput{"add2"});
EXPECT_EQ(edge2.m_node_idx, 3);
EXPECT_EQ(edge2.m_tensor_name, "add2");
}
NGRAPH_TEST(onnx_editor, editor_api_select_output_edge_by_the_same_node_name_and_output_name)
{
ONNXModelEditor editor{file_util::path_join(
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests_2.prototxt")};
const OutputEdge edge = editor.find_output_edge(EditorNode{"add1"}, EditorOutput{0});
EXPECT_EQ(edge.m_node_idx, 0);
EXPECT_EQ(edge.m_tensor_name, "relu1");
const OutputEdge edge2 = editor.find_output_edge(EditorNode{EditorOutput{"add1"}}, EditorOutput{0});
EXPECT_EQ(edge2.m_node_idx, 4);
EXPECT_EQ(edge2.m_tensor_name, "add1");
}
NGRAPH_TEST(onnx_editor, editor_api_select_output_edge_by_ambiguous_node_name_and_not_matched_output)
{
ONNXModelEditor editor{file_util::path_join(
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
try
{
const OutputEdge edge = editor.find_output_edge(EditorNode{"add_ambiguous_name"}, EditorOutput{"split2"});
}
catch (const std::exception& e)
{
std::string msg{e.what()};
EXPECT_TRUE(msg.find("Output edge described by: add_ambiguous_name and output name: split2 was not found") !=
std::string::npos);
}
}
NGRAPH_TEST(onnx_editor, editor_api_select_output_edge_by_ambiguous_node_name_and_output_index)
{
ONNXModelEditor editor{file_util::path_join(
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
try
{
const OutputEdge edge = editor.find_output_edge(EditorNode{"add_ambiguous_name"}, EditorOutput{0});
}
catch (const std::exception& e)
{
std::string msg{e.what()};
EXPECT_TRUE(msg.find("Given node name: add_ambiguous_name and output index: 0 are ambiguous to determine output edge") !=
std::string::npos);
}
}
NGRAPH_TEST(onnx_editor, editor_api_use_edge_mapper_with_graph_cutter)
{
ONNXModelEditor editor{file_util::path_join(
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
// InputEdge{1, "in2"}
const auto input_edge_1 = editor.find_input_edge(
EditorNode(EditorOutput("add1")), EditorInput(1));
// InputEdge{2, "in3"}
const auto input_edge_2 = editor.find_input_edge(
EditorNode(EditorOutput("conv1")), EditorInput(0));
const auto output_edge = editor.find_output_edge(
EditorNode(EditorOutput("mul2")), EditorOutput(0));
// OutputEdge{4, "mul2"}
editor.cut_graph_fragment({input_edge_1, input_edge_2}, {output_edge});
const auto ref_model =
file_util::path_join(SERIALIZED_ZOO,
"onnx/model_editor/reference/"
"subgraph__existing_inputs_and_outputs_based_extraction.prototxt");
const auto result = compare_onnx_models(editor.model_string(), ref_model);
EXPECT_TRUE(result.is_ok) << result.error_message;
// check if mapper was updated after the model changed
const auto input_edge_4 = editor.find_input_edge(
EditorNode(EditorOutput("relu1")), EditorInput(0));
EXPECT_EQ(input_edge_4.m_node_idx, 0);
EXPECT_EQ(input_edge_4.m_tensor_name, "in1");
const auto input_edge_5 = editor.find_input_edge(
EditorNode(EditorOutput("add1")), EditorInput(1));
EXPECT_EQ(input_edge_5.m_node_idx, 1);
EXPECT_EQ(input_edge_5.m_tensor_name, "in2");
const auto output_edge_3 = editor.find_output_edge("mul2");
EXPECT_EQ(output_edge_3.m_node_idx, 3);
EXPECT_EQ(output_edge_3.m_tensor_name, "mul2");
}
NGRAPH_TEST(onnx_editor, editor_api_find_output_consumers)
{
ONNXModelEditor editor{file_util::path_join(
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
std::vector<InputEdge> output_consumers = editor.find_output_consumers("relu1");
EXPECT_EQ(output_consumers.size(), 3);
EXPECT_EQ(output_consumers[0].m_node_idx, 1);
EXPECT_EQ(output_consumers[0].m_tensor_name, "relu1");
EXPECT_EQ(output_consumers[1].m_node_idx, 3);
EXPECT_EQ(output_consumers[1].m_tensor_name, "relu1");
EXPECT_EQ(output_consumers[2].m_node_idx, 6);
EXPECT_EQ(output_consumers[2].m_tensor_name, "relu1");
output_consumers = editor.find_output_consumers("add1");
EXPECT_EQ(output_consumers.size(), 2);
EXPECT_EQ(output_consumers[0].m_node_idx, 3);
EXPECT_EQ(output_consumers[0].m_tensor_name, "add1");
EXPECT_EQ(output_consumers[1].m_node_idx, 4);
EXPECT_EQ(output_consumers[1].m_tensor_name, "add1");
output_consumers = editor.find_output_consumers("in3");
EXPECT_EQ(output_consumers.size(), 1);
EXPECT_EQ(output_consumers[0].m_node_idx, 2);
EXPECT_EQ(output_consumers[0].m_tensor_name, "in3");
}
NGRAPH_TEST(onnx_editor, editor_api_find_output_consumers_empty_result)
{
ONNXModelEditor editor{file_util::path_join(
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
const std::vector<InputEdge> output_consumers = editor.find_output_consumers("not_existed");
EXPECT_EQ(output_consumers.size(), 0);
}
NGRAPH_TEST(onnx_editor, editor_api_is_correct_and_unambiguous_node)
{
ONNXModelEditor editor{file_util::path_join(
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
bool is_correct_node = editor.is_correct_and_unambiguous_node(EditorNode{EditorOutput{"relu1"}});
EXPECT_EQ(is_correct_node, true);
is_correct_node = editor.is_correct_and_unambiguous_node(EditorNode{EditorOutput{"mul2"}});
EXPECT_EQ(is_correct_node, true);
is_correct_node = editor.is_correct_and_unambiguous_node(EditorNode{EditorOutput{"split2"}});
EXPECT_EQ(is_correct_node, true);
is_correct_node = editor.is_correct_and_unambiguous_node(EditorNode{"relu1_name"});
EXPECT_EQ(is_correct_node, true);
is_correct_node = editor.is_correct_and_unambiguous_node(EditorNode{EditorOutput{"in3"}});
EXPECT_EQ(is_correct_node, false);
is_correct_node = editor.is_correct_and_unambiguous_node(EditorNode{"add_ambiguous_name"});
EXPECT_EQ(is_correct_node, false);
is_correct_node = editor.is_correct_and_unambiguous_node(EditorNode{"not_exist"});
EXPECT_EQ(is_correct_node, false);
}
using TestEngine = test::INTERPRETER_Engine;
NGRAPH_TEST(onnx_editor, values__append_one_initializer)