Add high level API to ONNX Editor (#4927)
This commit is contained in:
parent
d5a8861475
commit
880b45a770
112
ngraph/frontend/onnx_editor/include/onnx_editor/edge_mapper.hpp
Normal file
112
ngraph/frontend/onnx_editor/include/onnx_editor/edge_mapper.hpp
Normal file
@ -0,0 +1,112 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "onnx_editor/editor_types.hpp"
|
||||
|
||||
namespace ONNX_NAMESPACE
|
||||
{
|
||||
// Forward declaration to avoid the necessity of including paths in components
|
||||
// that don't directly depend on the ONNX library
|
||||
class GraphProto;
|
||||
} // namespace ONNX_NAMESPACE
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace onnx_editor
|
||||
{
|
||||
/// \brief A class which allows specifying InputEdge and OutputEdge by user-friendly ONNX
|
||||
/// names.
|
||||
class EdgeMapper
|
||||
{
|
||||
public:
|
||||
EdgeMapper() = default;
|
||||
|
||||
/// \brief Creates an edge mapper based on a GraphProto object.
|
||||
///
|
||||
/// \note If state of graph_proto will be changed, the information from edge mapper
|
||||
/// is outdated. In such a case the update method should be called.
|
||||
///
|
||||
/// \param graph_proto Reference to a GraphProto object.
|
||||
EdgeMapper(const ONNX_NAMESPACE::GraphProto& graph_proto);
|
||||
|
||||
/// \brief Returns the InputEdge based on a node (node name or output name)
|
||||
/// and an input (input name or input index).
|
||||
///
|
||||
/// \note The node name can be ambiguous (many ONNX nodes can have the same name).
|
||||
/// In such a case the algorthim tries to match the given node name
|
||||
/// with the input name (providing an input index is not enough).
|
||||
/// If a unique edge is found, it will be returned.
|
||||
/// If InputEdge cannot be determined based on parameter values an ngraph_error
|
||||
/// exception will be thrown.
|
||||
///
|
||||
/// \param node An EditorNode helper structure created based on a node name
|
||||
/// or a node output name.
|
||||
///
|
||||
/// \param input An EditorInput helper structure created based on a input name
|
||||
/// or a input index.
|
||||
InputEdge find_input_edge(const EditorNode& node, const EditorInput& input) const;
|
||||
|
||||
/// \brief Returns an OutputEdge based on a node (node name or output name)
|
||||
/// and an output (output name or output index).
|
||||
///
|
||||
/// \note The node name can be ambiguous (many ONNX nodes can have the same name).
|
||||
/// In such a case the algorthim will try to match the given node name
|
||||
/// with the output name (providing an output index is not enough).
|
||||
/// If after such operation a found edge is unique, it is returned.
|
||||
/// If OutputEdge cannot be determined based on given params the ngraph_error
|
||||
/// exception is thrown.
|
||||
///
|
||||
/// \param node An EditorNode helper structure created based on a node name
|
||||
/// or a node output name.
|
||||
///
|
||||
/// \param output An EditorOutput helper structure created based on a output name
|
||||
/// or a output index.
|
||||
OutputEdge find_output_edge(const EditorNode& node, const EditorOutput& output) const;
|
||||
|
||||
/// \brief Returns an OutputEdge based on a output name.
|
||||
///
|
||||
/// \note The output name guarantees the uniqueness of the edge.
|
||||
///
|
||||
/// \param output_name A node output name.
|
||||
///
|
||||
OutputEdge find_output_edge(const std::string& output_name) const;
|
||||
|
||||
/// \brief Returns a vector of InputEdges which consume an output of a node
|
||||
/// determined by provided output name.
|
||||
///
|
||||
/// \note The output name is deterministic in the ONNX standard.
|
||||
///
|
||||
/// \param output_name A node output name.
|
||||
///
|
||||
std::vector<InputEdge> find_output_consumers(const std::string& output_name) const;
|
||||
|
||||
/// \brief Returns true if a provided node is correct (exists in a graph)
|
||||
/// and is not ambiguous (identification of an ONNX node can be ambiguous
|
||||
/// if an only tensor name is provided).
|
||||
///
|
||||
/// \param node An EditorNode helper structure created based on a node name
|
||||
/// or a node output name.
|
||||
///
|
||||
bool is_correct_and_unambiguous_node(const EditorNode& node) const;
|
||||
|
||||
private:
|
||||
std::vector<int> find_node_indexes(const std::string& node_name,
|
||||
const std::string& output_name) const;
|
||||
std::string get_node_input_name(int node_index, int input_index) const;
|
||||
std::string get_node_output_name(int node_index, int output_index) const;
|
||||
|
||||
std::vector<std::vector<std::string>> m_node_inputs;
|
||||
std::vector<std::vector<std::string>> m_node_outputs;
|
||||
std::multimap<std::string, int> m_node_name_to_index;
|
||||
std::map<std::string, int> m_node_output_name_to_index;
|
||||
std::multimap<std::string, int> m_output_consumers_index;
|
||||
};
|
||||
} // namespace onnx_editor
|
||||
} // namespace ngraph
|
@ -108,7 +108,69 @@ namespace ngraph
|
||||
/// \param out_file_path A path to the file where the modified model should be dumped.
|
||||
void serialize(const std::string& out_file_path) const;
|
||||
|
||||
/// \brief Returns the InputEdge based on a node (node name or output name)
|
||||
/// and an input (input name or input index).
|
||||
///
|
||||
/// \note The node name can be ambiguous (many ONNX nodes can have the same name).
|
||||
/// In such a case the algorthim tries to match the given node name
|
||||
/// with the input name (providing an input index is not enough).
|
||||
/// If a unique edge is found, it will be returned.
|
||||
/// If InputEdge cannot be determined based on parameter values an ngraph_error
|
||||
/// exception will be thrown.
|
||||
///
|
||||
/// \param node A node helper structure created based on a node name
|
||||
/// or a node output name.
|
||||
///
|
||||
/// \param input An input helper structure created based on a input name
|
||||
/// or a input index.
|
||||
InputEdge find_input_edge(const EditorNode& node, const EditorInput& input) const;
|
||||
|
||||
/// \brief Returns an OutputEdge based on a node (node name or output name)
|
||||
/// and an output (output name or output index).
|
||||
///
|
||||
/// \note The node name can be ambiguous (many ONNX nodes can have the same name).
|
||||
/// In such a case the algorthim will try to match the given node name
|
||||
/// with the output name (providing an output index is not enough).
|
||||
/// If after such operation a found edge is unique, it is returned.
|
||||
/// If OutputEdge cannot be determined based on given params the ngraph_error
|
||||
/// exception is thrown.
|
||||
///
|
||||
/// \param node A node helper structure created based on a node name
|
||||
/// or a node output name.
|
||||
///
|
||||
/// \param output A output helper structure created based on a output name
|
||||
/// or a output index.
|
||||
OutputEdge find_output_edge(const EditorNode& node, const EditorOutput& output) const;
|
||||
|
||||
/// \brief Returns an OutputEdge based on a output name.
|
||||
///
|
||||
/// \note The output name guarantees the uniqueness of the edge.
|
||||
///
|
||||
/// \param output_name A node output name.
|
||||
///
|
||||
OutputEdge find_output_edge(const std::string& output_name) const;
|
||||
|
||||
/// \brief Returns a vector of InputEdges which consume an output of a node
|
||||
/// determined by provided output name.
|
||||
///
|
||||
/// \note The output name is deterministic in the ONNX standard.
|
||||
///
|
||||
/// \param output_name A node output name.
|
||||
///
|
||||
std::vector<InputEdge> find_output_consumers(const std::string& output_name) const;
|
||||
|
||||
/// \brief Returns a vector of InputEdges which consume an output of a node
|
||||
/// determined by provided output name.
|
||||
///
|
||||
/// \note The output name is deterministic in the ONNX standard.
|
||||
///
|
||||
/// \param output_name A node output name.
|
||||
///
|
||||
bool is_correct_and_unambiguous_node(const EditorNode& node) const;
|
||||
|
||||
private:
|
||||
void update_mapper_if_needed() const;
|
||||
|
||||
const std::string m_model_path;
|
||||
|
||||
struct Impl;
|
||||
|
@ -60,5 +60,74 @@ namespace ngraph
|
||||
/// OutputEdge(5, "out1")
|
||||
/// OutputEdge(5, "out2")
|
||||
using OutputEdge = Edge<EdgeType::OUTPUT>;
|
||||
|
||||
/// \brief Specifies a single node input by the name or index.
|
||||
///
|
||||
/// For a node test_node, with 3 inputs:
|
||||
///
|
||||
/// ----(in_A)----> +-----------+
|
||||
/// ----(in_B)----> | test_node | ----(out)---->
|
||||
/// ----(in_C)----> +-----------+
|
||||
/// You can indicate in_B as EditorInput("in_B") or EditorInput(1)
|
||||
struct EditorInput
|
||||
{
|
||||
EditorInput() = delete;
|
||||
EditorInput(std::string input_name)
|
||||
: m_input_name{std::move(input_name)}
|
||||
{
|
||||
}
|
||||
EditorInput(const int input_index)
|
||||
: m_input_index{input_index}
|
||||
{
|
||||
}
|
||||
const std::string m_input_name = "";
|
||||
const int m_input_index = -1;
|
||||
};
|
||||
|
||||
/// \brief Specifies a single node output by the name or index.
|
||||
/// For a node test_node, with 2 outputs:
|
||||
///
|
||||
/// +-----------+ ---(out1)--->
|
||||
/// ----(in_A)----> | test_node |
|
||||
/// +-----------+ ---(out2)--->
|
||||
/// You can indicate out2 as EditorOutput("out2") or EditorOutput(1)
|
||||
struct EditorOutput
|
||||
{
|
||||
EditorOutput() = delete;
|
||||
EditorOutput(std::string output_name)
|
||||
: m_output_name{std::move(output_name)}
|
||||
{
|
||||
}
|
||||
EditorOutput(const int output_index)
|
||||
: m_output_index{output_index}
|
||||
{
|
||||
}
|
||||
const std::string m_output_name = "";
|
||||
const int m_output_index = -1;
|
||||
};
|
||||
|
||||
/// \brief Specifies a single node by output name which is determinitic
|
||||
/// or node name which can be ambiguous.
|
||||
/// For a node test_node, with 2 outputs:
|
||||
///
|
||||
/// +-----------+ ---(out1)--->
|
||||
/// ----(in_A)----> | test_node |
|
||||
/// +-----------+ ---(out2)--->
|
||||
/// You can indicate test_node by name as EditorNode("test_node")
|
||||
/// or by assigned output as EditorNode(EditorOutput("out1"))
|
||||
/// or EditorNode(EditorOutput("out2"))
|
||||
struct EditorNode
|
||||
{
|
||||
EditorNode(std::string node_name)
|
||||
: m_node_name{std::move(node_name)}
|
||||
{
|
||||
}
|
||||
EditorNode(EditorOutput output)
|
||||
: m_output_name{std::move(output.m_output_name)}
|
||||
{
|
||||
}
|
||||
const std::string m_node_name = "";
|
||||
const std::string m_output_name = "";
|
||||
};
|
||||
} // namespace onnx_editor
|
||||
} // namespace ngraph
|
||||
|
243
ngraph/frontend/onnx_editor/src/edge_mapper.cpp
Normal file
243
ngraph/frontend/onnx_editor/src/edge_mapper.cpp
Normal file
@ -0,0 +1,243 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <algorithm>
|
||||
#include <onnx/onnx_pb.h>
|
||||
|
||||
#include "ngraph/except.hpp"
|
||||
#include "onnx_editor/edge_mapper.hpp"
|
||||
|
||||
using namespace ngraph;
|
||||
using namespace ngraph::onnx_editor;
|
||||
|
||||
onnx_editor::EdgeMapper::EdgeMapper(const ONNX_NAMESPACE::GraphProto& graph_proto)
|
||||
: m_node_inputs(graph_proto.node().size())
|
||||
, m_node_outputs(graph_proto.node().size())
|
||||
{
|
||||
int topological_index = 0;
|
||||
for (const auto& node_proto : graph_proto.node())
|
||||
{
|
||||
for (const auto& out_name : node_proto.output())
|
||||
{
|
||||
// node output name is unique
|
||||
m_node_output_name_to_index.emplace(out_name, topological_index);
|
||||
m_node_outputs[topological_index].push_back(out_name);
|
||||
}
|
||||
for (const auto& in_name : node_proto.input())
|
||||
{
|
||||
m_node_inputs[topological_index].push_back(in_name);
|
||||
m_output_consumers_index.emplace(in_name, topological_index);
|
||||
}
|
||||
if (!node_proto.name().empty())
|
||||
{
|
||||
// node name can identify node, but it can be ambiguous
|
||||
m_node_name_to_index.emplace(node_proto.name(), topological_index);
|
||||
}
|
||||
++topological_index;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int> onnx_editor::EdgeMapper::find_node_indexes(const std::string& node_name,
|
||||
const std::string& output_name) const
|
||||
{
|
||||
if (!output_name.empty())
|
||||
{
|
||||
const auto& index_iter = m_node_output_name_to_index.find(output_name);
|
||||
if (index_iter != std::end(m_node_output_name_to_index))
|
||||
{
|
||||
return std::vector<int>{index_iter->second};
|
||||
}
|
||||
}
|
||||
std::vector<int> result;
|
||||
if (!node_name.empty())
|
||||
{
|
||||
const auto matched_nodes_range = m_node_name_to_index.equal_range(node_name);
|
||||
std::transform(matched_nodes_range.first,
|
||||
matched_nodes_range.second,
|
||||
std::back_inserter(result),
|
||||
[](const std::pair<std::string, int>& iter) { return iter.second; });
|
||||
}
|
||||
return result;
|
||||
};
|
||||
|
||||
std::string onnx_editor::EdgeMapper::get_node_output_name(int node_index, int output_index) const
|
||||
{
|
||||
if (node_index >= static_cast<int>(m_node_outputs.size()))
|
||||
{
|
||||
throw ngraph_error("Node with index: " + std::to_string(node_index) +
|
||||
"is out of scope outputs list");
|
||||
}
|
||||
if (output_index >= static_cast<int>(m_node_outputs[node_index].size()))
|
||||
{
|
||||
throw ngraph_error("Node with index: " + std::to_string(node_index) +
|
||||
" has not output with index: " + std::to_string(output_index));
|
||||
}
|
||||
const auto output_name = m_node_outputs[node_index][output_index];
|
||||
return output_name;
|
||||
}
|
||||
|
||||
std::string onnx_editor::EdgeMapper::get_node_input_name(int node_index, int input_index) const
|
||||
{
|
||||
if (node_index >= static_cast<int>(m_node_inputs.size()))
|
||||
{
|
||||
throw ngraph_error("Node with index: " + std::to_string(node_index) +
|
||||
"is out of scope inputs list");
|
||||
}
|
||||
if (input_index >= static_cast<int>(m_node_inputs[node_index].size()))
|
||||
{
|
||||
throw ngraph_error("Node with index: " + std::to_string(node_index) +
|
||||
" has not input with index: " + std::to_string(input_index));
|
||||
}
|
||||
const auto input_name = m_node_inputs[node_index][input_index];
|
||||
return input_name;
|
||||
}
|
||||
|
||||
InputEdge onnx_editor::EdgeMapper::find_input_edge(const EditorNode& node,
|
||||
const EditorInput& in) const
|
||||
{
|
||||
// identification can be both based on node name and output name
|
||||
const auto& node_indexes = find_node_indexes(node.m_node_name, node.m_output_name);
|
||||
int node_index = -1;
|
||||
if (node_indexes.size() == 1)
|
||||
{
|
||||
node_index = node_indexes[0];
|
||||
}
|
||||
else if (node_indexes.empty())
|
||||
{
|
||||
throw ngraph_error(
|
||||
"Node with name: " + (node.m_node_name.empty() ? "not_given" : node.m_node_name) +
|
||||
" and output_name: " + (node.m_output_name.empty() ? "not_given" : node.m_output_name) +
|
||||
" was not found");
|
||||
}
|
||||
else if (!in.m_input_name
|
||||
.empty()) // input indexes are not deterministic if a node name is ambiguous
|
||||
{
|
||||
// many nodes with the same name
|
||||
// check if some of found index matches input name
|
||||
int matched_inputs_number = 0;
|
||||
for (const auto& index : node_indexes)
|
||||
{
|
||||
if (std::count(std::begin(m_node_inputs[index]),
|
||||
std::end(m_node_inputs[index]),
|
||||
in.m_input_name) > 0)
|
||||
{
|
||||
node_index = index;
|
||||
++matched_inputs_number;
|
||||
}
|
||||
}
|
||||
if (matched_inputs_number == 0)
|
||||
{
|
||||
throw ngraph_error("Input edge described by: " + node.m_node_name +
|
||||
" and input name: " + in.m_input_name + " was not found");
|
||||
}
|
||||
if (matched_inputs_number > 1)
|
||||
{
|
||||
throw ngraph_error("Given node name: " + node.m_node_name + " and input name: " +
|
||||
in.m_input_name + " are ambiguous to determine input edge");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw ngraph_error("Given node name: " + node.m_node_name +
|
||||
" and input index: " + std::to_string(in.m_input_index) +
|
||||
" are ambiguous to determine input edge");
|
||||
}
|
||||
if (!in.m_input_name.empty())
|
||||
{
|
||||
return InputEdge{node_index, in.m_input_name};
|
||||
}
|
||||
else if (in.m_input_index != -1) // input index is set
|
||||
{
|
||||
const auto& input_name = get_node_input_name(node_index, in.m_input_index);
|
||||
return InputEdge{node_index, input_name};
|
||||
}
|
||||
else
|
||||
{
|
||||
throw ngraph_error("Not enough information to determine input edge");
|
||||
}
|
||||
}
|
||||
|
||||
OutputEdge onnx_editor::EdgeMapper::find_output_edge(const EditorNode& node,
|
||||
const EditorOutput& out) const
|
||||
{
|
||||
// identification can be both based on node name and output name
|
||||
const auto& node_indexes = find_node_indexes(node.m_node_name, node.m_output_name);
|
||||
int node_index = -1;
|
||||
if (node_indexes.size() == 1)
|
||||
{
|
||||
node_index = node_indexes[0];
|
||||
}
|
||||
else if (node_indexes.empty())
|
||||
{
|
||||
throw ngraph_error(
|
||||
"Node with name: " + (node.m_node_name.empty() ? "not_given" : node.m_node_name) +
|
||||
" and output_name: " + (node.m_output_name.empty() ? "not_given" : node.m_output_name) +
|
||||
" was not found");
|
||||
}
|
||||
else if (!out.m_output_name
|
||||
.empty()) // output indexes are not deterministic if a node name is ambiguous
|
||||
{
|
||||
// many nodes with the same name
|
||||
// check if some of found index matches output name
|
||||
int matched_outputs_number = 0;
|
||||
for (const auto& index : node_indexes)
|
||||
{
|
||||
if (std::count(std::begin(m_node_outputs[index]),
|
||||
std::end(m_node_outputs[index]),
|
||||
out.m_output_name) > 0)
|
||||
{
|
||||
node_index = index;
|
||||
++matched_outputs_number;
|
||||
}
|
||||
}
|
||||
if (matched_outputs_number == 0)
|
||||
{
|
||||
throw ngraph_error("Output edge described by: " + node.m_node_name +
|
||||
" and output name: " + out.m_output_name + " was not found");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw ngraph_error("Given node name: " + node.m_node_name +
|
||||
" and output index: " + std::to_string(out.m_output_index) +
|
||||
" are ambiguous to determine output edge");
|
||||
}
|
||||
if (!out.m_output_name.empty())
|
||||
{
|
||||
return OutputEdge{node_index, out.m_output_name};
|
||||
}
|
||||
else if (out.m_output_index != -1) // output index is set
|
||||
{
|
||||
const auto& output_name = get_node_output_name(node_index, out.m_output_index);
|
||||
return OutputEdge{node_index, output_name};
|
||||
}
|
||||
else
|
||||
{
|
||||
throw ngraph_error("Not enough information to determine output edge");
|
||||
}
|
||||
}
|
||||
|
||||
OutputEdge onnx_editor::EdgeMapper::find_output_edge(const std::string& output_name) const
|
||||
{
|
||||
return find_output_edge(EditorNode{EditorOutput{output_name}}, EditorOutput{output_name});
|
||||
}
|
||||
|
||||
std::vector<InputEdge>
|
||||
onnx_editor::EdgeMapper::find_output_consumers(const std::string& output_name) const
|
||||
{
|
||||
const auto matched_nodes_range = m_output_consumers_index.equal_range(output_name);
|
||||
std::vector<InputEdge> input_edges;
|
||||
std::transform(matched_nodes_range.first,
|
||||
matched_nodes_range.second,
|
||||
std::back_inserter(input_edges),
|
||||
[&output_name](const std::pair<std::string, int>& iter) {
|
||||
return InputEdge{iter.second, output_name};
|
||||
});
|
||||
return input_edges;
|
||||
}
|
||||
|
||||
bool onnx_editor::EdgeMapper::is_correct_and_unambiguous_node(const EditorNode& node) const
|
||||
{
|
||||
return find_node_indexes(node.m_node_name, node.m_output_name).size() == 1;
|
||||
}
|
@ -10,10 +10,12 @@
|
||||
#include "ngraph/log.hpp"
|
||||
#include "onnx_common/parser.hpp"
|
||||
#include "onnx_common/utils.hpp"
|
||||
#include "onnx_editor/edge_mapper.hpp"
|
||||
#include "onnx_editor/editor.hpp"
|
||||
#include "onnx_import/utils/onnx_internal.hpp"
|
||||
|
||||
using namespace ngraph;
|
||||
using namespace ngraph::onnx_editor;
|
||||
|
||||
namespace
|
||||
{
|
||||
@ -186,6 +188,8 @@ namespace
|
||||
struct onnx_editor::ONNXModelEditor::Impl
|
||||
{
|
||||
ONNX_NAMESPACE::ModelProto m_model_proto;
|
||||
EdgeMapper m_edge_mapper;
|
||||
bool m_is_mapper_updated = false;
|
||||
|
||||
Impl() = delete;
|
||||
|
||||
@ -285,6 +289,7 @@ void onnx_editor::ONNXModelEditor::cut_graph_fragment(const std::vector<InputEdg
|
||||
editor.extract_subgraph(outputs);
|
||||
|
||||
m_pimpl->remove_shape_inference_info();
|
||||
m_pimpl->m_is_mapper_updated = false;
|
||||
}
|
||||
|
||||
std::vector<std::string> onnx_editor::ONNXModelEditor::model_inputs() const
|
||||
@ -344,3 +349,45 @@ void onnx_editor::ONNXModelEditor::set_input_values(
|
||||
modify_initializer(*onnx_initializer, name, values, onnx_input);
|
||||
}
|
||||
}
|
||||
|
||||
void onnx_editor::ONNXModelEditor::update_mapper_if_needed() const
|
||||
{
|
||||
if (!m_pimpl->m_is_mapper_updated)
|
||||
{
|
||||
m_pimpl->m_edge_mapper = EdgeMapper(m_pimpl->m_model_proto.graph());
|
||||
}
|
||||
m_pimpl->m_is_mapper_updated = true;
|
||||
}
|
||||
|
||||
InputEdge onnx_editor::ONNXModelEditor::find_input_edge(const EditorNode& node,
|
||||
const EditorInput& input) const
|
||||
{
|
||||
update_mapper_if_needed();
|
||||
return m_pimpl->m_edge_mapper.find_input_edge(node, input);
|
||||
}
|
||||
|
||||
OutputEdge onnx_editor::ONNXModelEditor::find_output_edge(const EditorNode& node,
|
||||
const EditorOutput& input) const
|
||||
{
|
||||
update_mapper_if_needed();
|
||||
return m_pimpl->m_edge_mapper.find_output_edge(node, input);
|
||||
}
|
||||
|
||||
OutputEdge onnx_editor::ONNXModelEditor::find_output_edge(const std::string& output_name) const
|
||||
{
|
||||
update_mapper_if_needed();
|
||||
return m_pimpl->m_edge_mapper.find_output_edge(output_name);
|
||||
}
|
||||
|
||||
std::vector<InputEdge>
|
||||
onnx_editor::ONNXModelEditor::find_output_consumers(const std::string& output_name) const
|
||||
{
|
||||
update_mapper_if_needed();
|
||||
return m_pimpl->m_edge_mapper.find_output_consumers(output_name);
|
||||
}
|
||||
|
||||
bool onnx_editor::ONNXModelEditor::is_correct_and_unambiguous_node(const EditorNode& node) const
|
||||
{
|
||||
update_mapper_if_needed();
|
||||
return m_pimpl->m_edge_mapper.is_correct_and_unambiguous_node(node);
|
||||
}
|
||||
|
@ -8,7 +8,7 @@ graph {
|
||||
input: "conv1/7x7_s2_w_0"
|
||||
input: "conv1/7x7_s2_b_0"
|
||||
output: "conv1/7x7_s2_1"
|
||||
name: ""
|
||||
name: "conv1"
|
||||
op_type: "Conv"
|
||||
attribute {
|
||||
name: "strides"
|
||||
@ -34,13 +34,13 @@ graph {
|
||||
node {
|
||||
input: "conv1/7x7_s2_1"
|
||||
output: "conv1/7x7_s2_2"
|
||||
name: ""
|
||||
name: "relu1"
|
||||
op_type: "Relu"
|
||||
}
|
||||
node {
|
||||
input: "conv1/7x7_s2_2"
|
||||
output: "pool1/3x3_s2_1"
|
||||
name: ""
|
||||
name: "maxpool1"
|
||||
op_type: "MaxPool"
|
||||
attribute {
|
||||
name: "strides"
|
||||
|
@ -5,12 +5,14 @@ graph {
|
||||
input: "in1"
|
||||
output: "relu1"
|
||||
op_type: "Relu"
|
||||
name: "relu1_name"
|
||||
}
|
||||
node {
|
||||
input: "relu1"
|
||||
input: "in2"
|
||||
output: "add1"
|
||||
op_type: "Add"
|
||||
name: "add_ambiguous_name"
|
||||
}
|
||||
node {
|
||||
input: "in3"
|
||||
@ -23,12 +25,14 @@ graph {
|
||||
input: "add1"
|
||||
output: "add2"
|
||||
op_type: "Add"
|
||||
name: "add_ambiguous_name"
|
||||
}
|
||||
node {
|
||||
input: "add1"
|
||||
input: "conv1"
|
||||
output: "mul2"
|
||||
op_type: "Mul"
|
||||
name: ""
|
||||
}
|
||||
node {
|
||||
input: "add2"
|
||||
@ -40,6 +44,7 @@ graph {
|
||||
i: 1
|
||||
type: INT
|
||||
}
|
||||
name: "split_name"
|
||||
}
|
||||
node {
|
||||
input: "relu1"
|
||||
|
@ -5,6 +5,7 @@ graph {
|
||||
input: "in1"
|
||||
output: "relu1"
|
||||
op_type: "Relu"
|
||||
name: "add1"
|
||||
}
|
||||
node {
|
||||
input: "in1"
|
||||
@ -20,12 +21,14 @@ graph {
|
||||
input: "in2"
|
||||
output: "relu4"
|
||||
op_type: "Relu"
|
||||
name: "relu4_name"
|
||||
}
|
||||
node {
|
||||
input: "relu1"
|
||||
input: "relu2"
|
||||
output: "add1"
|
||||
op_type: "Add"
|
||||
name: "add1_name"
|
||||
}
|
||||
node {
|
||||
input: "relu2"
|
||||
|
@ -21,8 +21,7 @@
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
|
||||
using namespace ngraph;
|
||||
using namespace ngraph::onnx_import;
|
||||
using namespace ngraph::onnx_editor;
|
||||
using namespace onnx_editor;
|
||||
using namespace ngraph::test;
|
||||
|
||||
static std::string s_manifest = "${MANIFEST}";
|
||||
@ -658,6 +657,478 @@ NGRAPH_TEST(onnx_editor, subgraph__inputs_getter)
|
||||
EXPECT_EQ(editor.model_inputs(), (std::vector<std::string>{"conv1/7x7_s2_1"}));
|
||||
}
|
||||
|
||||
// HIGHT LEVEL API TESTS
|
||||
// INPUT EDGES TEST
|
||||
NGRAPH_TEST(onnx_editor, editor_api_select_input_edge_by_output_name_and_input_name)
|
||||
{
|
||||
ONNXModelEditor editor{file_util::path_join(
|
||||
SERIALIZED_ZOO, "onnx/model_editor/subgraph__inception_head.prototxt")};
|
||||
|
||||
const InputEdge edge = editor.find_input_edge(EditorNode{EditorOutput{"conv1/7x7_s2_2"}},
|
||||
EditorInput{"conv1/7x7_s2_1"});
|
||||
EXPECT_EQ(edge.m_node_idx, 1);
|
||||
EXPECT_EQ(edge.m_tensor_name, "conv1/7x7_s2_1");
|
||||
|
||||
const InputEdge edge2 = editor.find_input_edge(EditorNode{EditorOutput{"conv1/7x7_s2_1"}},
|
||||
EditorInput{"data_0"});
|
||||
EXPECT_EQ(edge2.m_node_idx, 0);
|
||||
EXPECT_EQ(edge2.m_tensor_name, "data_0");
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_editor, editor_api_select_input_edge_by_output_name_and_input_index)
|
||||
{
|
||||
ONNXModelEditor editor{file_util::path_join(
|
||||
SERIALIZED_ZOO, "onnx/model_editor/subgraph__inception_head.prototxt")};
|
||||
|
||||
const InputEdge edge =
|
||||
editor.find_input_edge(EditorNode{EditorOutput{"conv1/7x7_s2_2"}}, EditorInput{0});
|
||||
EXPECT_EQ(edge.m_node_idx, 1);
|
||||
EXPECT_EQ(edge.m_tensor_name, "conv1/7x7_s2_1");
|
||||
|
||||
const InputEdge edge2 =
|
||||
editor.find_input_edge(EditorNode{EditorOutput{"conv1/7x7_s2_1"}}, EditorInput{1});
|
||||
EXPECT_EQ(edge2.m_node_idx, 0);
|
||||
EXPECT_EQ(edge2.m_tensor_name, "conv1/7x7_s2_w_0");
|
||||
|
||||
const InputEdge edge3 =
|
||||
editor.find_input_edge(EditorNode{EditorOutput{"conv1/7x7_s2_1"}}, EditorInput{2});
|
||||
EXPECT_EQ(edge3.m_node_idx, 0);
|
||||
EXPECT_EQ(edge3.m_tensor_name, "conv1/7x7_s2_b_0");
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_editor, editor_api_select_input_edge_by_node_name_and_input_name)
|
||||
{
|
||||
ONNXModelEditor editor{file_util::path_join(
|
||||
SERIALIZED_ZOO, "onnx/model_editor/subgraph__inception_head.prototxt")};
|
||||
|
||||
const InputEdge edge =
|
||||
editor.find_input_edge(EditorNode{"relu1"}, EditorInput{"conv1/7x7_s2_1"});
|
||||
EXPECT_EQ(edge.m_node_idx, 1);
|
||||
EXPECT_EQ(edge.m_tensor_name, "conv1/7x7_s2_1");
|
||||
|
||||
const InputEdge edge2 =
|
||||
editor.find_input_edge(EditorNode{"conv1"}, EditorInput{"conv1/7x7_s2_w_0"});
|
||||
EXPECT_EQ(edge2.m_node_idx, 0);
|
||||
EXPECT_EQ(edge2.m_tensor_name, "conv1/7x7_s2_w_0");
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_editor, editor_api_select_input_edge_by_node_name_and_input_index)
|
||||
{
|
||||
ONNXModelEditor editor{file_util::path_join(
|
||||
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
|
||||
|
||||
const InputEdge edge = editor.find_input_edge(EditorNode{"relu1_name"}, EditorInput{0});
|
||||
EXPECT_EQ(edge.m_node_idx, 0);
|
||||
EXPECT_EQ(edge.m_tensor_name, "in1");
|
||||
|
||||
const InputEdge edge2 = editor.find_input_edge(EditorNode{"split_name"}, EditorInput{0});
|
||||
EXPECT_EQ(edge2.m_node_idx, 5);
|
||||
EXPECT_EQ(edge2.m_tensor_name, "add2");
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_editor, editor_api_select_input_edge_empty_node_name)
|
||||
{
|
||||
ONNXModelEditor editor{file_util::path_join(
|
||||
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
|
||||
|
||||
try
|
||||
{
|
||||
const InputEdge edge =
|
||||
editor.find_input_edge(EditorNode{""}, EditorInput{"conv1/7x7_s2_1"});
|
||||
}
|
||||
catch (const std::exception& e)
|
||||
{
|
||||
std::string msg{e.what()};
|
||||
EXPECT_TRUE(
|
||||
msg.find("Node with name: not_given and output_name: not_given was not found") !=
|
||||
std::string::npos);
|
||||
}
|
||||
}
|
||||
|
||||
// OUTPUT EDGES TEST
|
||||
NGRAPH_TEST(onnx_editor, editor_api_select_output_edge_by_output_name)
|
||||
{
|
||||
ONNXModelEditor editor{file_util::path_join(
|
||||
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
|
||||
|
||||
const OutputEdge edge =
|
||||
editor.find_output_edge(EditorNode{EditorOutput{"mul2"}}, EditorOutput{"mul2"});
|
||||
EXPECT_EQ(edge.m_node_idx, 4);
|
||||
EXPECT_EQ(edge.m_tensor_name, "mul2");
|
||||
|
||||
const OutputEdge edge2 =
|
||||
editor.find_output_edge(EditorNode{EditorOutput{"split1"}}, EditorOutput{"split2"});
|
||||
EXPECT_EQ(edge2.m_node_idx, 5);
|
||||
EXPECT_EQ(edge2.m_tensor_name, "split2");
|
||||
|
||||
// simplified overload
|
||||
const OutputEdge edge3 =
|
||||
editor.find_output_edge("mul2");
|
||||
EXPECT_EQ(edge3.m_node_idx, 4);
|
||||
EXPECT_EQ(edge3.m_tensor_name, "mul2");
|
||||
|
||||
const OutputEdge edge4 =
|
||||
editor.find_output_edge("split2");
|
||||
EXPECT_EQ(edge4.m_node_idx, 5);
|
||||
EXPECT_EQ(edge4.m_tensor_name, "split2");
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_editor, editor_api_select_output_edge_by_output_name_and_output_index)
|
||||
{
|
||||
ONNXModelEditor editor{file_util::path_join(
|
||||
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
|
||||
|
||||
const OutputEdge edge =
|
||||
editor.find_output_edge(EditorNode{EditorOutput{"add2"}}, EditorOutput{0});
|
||||
EXPECT_EQ(edge.m_node_idx, 3);
|
||||
EXPECT_EQ(edge.m_tensor_name, "add2");
|
||||
|
||||
const OutputEdge edge2 =
|
||||
editor.find_output_edge(EditorNode{EditorOutput{"split1"}}, EditorOutput{1});
|
||||
EXPECT_EQ(edge2.m_node_idx, 5);
|
||||
EXPECT_EQ(edge2.m_tensor_name, "split2");
|
||||
|
||||
const OutputEdge edge3 =
|
||||
editor.find_output_edge(EditorNode{EditorOutput{"split2"}}, EditorOutput{0});
|
||||
EXPECT_EQ(edge3.m_node_idx, 5);
|
||||
EXPECT_EQ(edge3.m_tensor_name, "split1");
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_editor, editor_api_select_output_edge_by_node_name_and_output_name)
|
||||
{
|
||||
ONNXModelEditor editor{file_util::path_join(
|
||||
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
|
||||
|
||||
const OutputEdge edge =
|
||||
editor.find_output_edge(EditorNode{"relu1_name"}, EditorOutput{"relu1"});
|
||||
EXPECT_EQ(edge.m_node_idx, 0);
|
||||
EXPECT_EQ(edge.m_tensor_name, "relu1");
|
||||
|
||||
const OutputEdge edge2 =
|
||||
editor.find_output_edge(EditorNode{"split_name"}, EditorOutput{"split2"});
|
||||
EXPECT_EQ(edge2.m_node_idx, 5);
|
||||
EXPECT_EQ(edge2.m_tensor_name, "split2");
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_editor, editor_api_select_output_edge_by_node_name_and_output_index)
|
||||
{
|
||||
ONNXModelEditor editor{file_util::path_join(
|
||||
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
|
||||
|
||||
const OutputEdge edge = editor.find_output_edge(EditorNode{"relu1_name"}, EditorOutput{0});
|
||||
EXPECT_EQ(edge.m_node_idx, 0);
|
||||
EXPECT_EQ(edge.m_tensor_name, "relu1");
|
||||
|
||||
const OutputEdge edge2 = editor.find_output_edge(EditorNode{"split_name"}, EditorOutput{1});
|
||||
EXPECT_EQ(edge2.m_node_idx, 5);
|
||||
EXPECT_EQ(edge2.m_tensor_name, "split2");
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_editor, editor_api_select_edge_const_network)
|
||||
{
|
||||
ONNXModelEditor editor{file_util::path_join(
|
||||
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests_2.prototxt")};
|
||||
|
||||
const InputEdge edge =
|
||||
editor.find_input_edge(EditorNode{EditorOutput{"relu4"}}, EditorInput{0});
|
||||
EXPECT_EQ(edge.m_node_idx, 3);
|
||||
EXPECT_EQ(edge.m_tensor_name, "in2");
|
||||
|
||||
const OutputEdge edge2 = editor.find_output_edge(EditorNode{"relu4_name"}, EditorOutput{0});
|
||||
EXPECT_EQ(edge2.m_node_idx, 3);
|
||||
EXPECT_EQ(edge2.m_tensor_name, "relu4");
|
||||
|
||||
const OutputEdge edge3 = editor.find_output_edge(EditorNode{"add1_name"}, EditorOutput{0});
|
||||
EXPECT_EQ(edge3.m_node_idx, 4);
|
||||
EXPECT_EQ(edge3.m_tensor_name, "add1");
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_editor, editor_api_select_edge_error_handling)
|
||||
{
|
||||
ONNXModelEditor editor{file_util::path_join(
|
||||
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests_2.prototxt")};
|
||||
|
||||
// node with given output name not found
|
||||
try
|
||||
{
|
||||
const InputEdge edge =
|
||||
editor.find_input_edge(EditorNode{EditorOutput{"not_existed"}}, EditorInput{0});
|
||||
}
|
||||
catch (const std::exception& e)
|
||||
{
|
||||
std::string msg{e.what()};
|
||||
EXPECT_TRUE(
|
||||
msg.find("Node with name: not_given and output_name: not_existed was not found") !=
|
||||
std::string::npos);
|
||||
}
|
||||
|
||||
// node with given name not found
|
||||
try
|
||||
{
|
||||
const InputEdge edge = editor.find_input_edge(EditorNode{"not_existed"}, EditorInput{0});
|
||||
}
|
||||
catch (const std::exception& e)
|
||||
{
|
||||
std::string msg{e.what()};
|
||||
EXPECT_TRUE(
|
||||
msg.find("Node with name: not_existed and output_name: not_given was not found") !=
|
||||
std::string::npos);
|
||||
}
|
||||
|
||||
// input index out of scope
|
||||
try
|
||||
{
|
||||
const InputEdge edge = editor.find_input_edge(EditorNode{"relu4_name"}, EditorInput{1});
|
||||
}
|
||||
catch (const std::exception& e)
|
||||
{
|
||||
std::string msg{e.what()};
|
||||
EXPECT_TRUE(msg.find("Node with index: 3 has not input with index: 1") !=
|
||||
std::string::npos);
|
||||
}
|
||||
|
||||
// output index out of scope
|
||||
try
|
||||
{
|
||||
const OutputEdge edge =
|
||||
editor.find_output_edge(EditorNode{"relu4_name"}, EditorOutput{1});
|
||||
}
|
||||
catch (const std::exception& e)
|
||||
{
|
||||
std::string msg{e.what()};
|
||||
EXPECT_TRUE(msg.find("Node with index: 3 has not output with index: 1") !=
|
||||
std::string::npos);
|
||||
}
|
||||
}
|
||||
|
||||
// Nodes with ambiguous node names tests
|
||||
NGRAPH_TEST(onnx_editor, editor_api_select_input_edge_by_ambiguous_node_name_but_matched_input)
|
||||
{
|
||||
ONNXModelEditor editor{file_util::path_join(
|
||||
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
|
||||
|
||||
InputEdge edge = editor.find_input_edge(EditorNode{"add_ambiguous_name"}, EditorInput{"in2"});
|
||||
EXPECT_EQ(edge.m_node_idx, 1);
|
||||
EXPECT_EQ(edge.m_tensor_name, "in2");
|
||||
|
||||
const InputEdge edge2 = editor.find_input_edge(EditorNode{"add_ambiguous_name"}, EditorInput{"add1"});
|
||||
EXPECT_EQ(edge2.m_node_idx, 3);
|
||||
EXPECT_EQ(edge2.m_tensor_name, "add1");
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_editor, editor_api_select_input_edge_by_ambiguous_node_name_and_not_matched_input)
|
||||
{
|
||||
ONNXModelEditor editor{file_util::path_join(
|
||||
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
|
||||
|
||||
try
|
||||
{
|
||||
const InputEdge edge = editor.find_input_edge(EditorNode{"add_ambiguous_name"}, EditorInput{"in3"});
|
||||
}
|
||||
catch (const std::exception& e)
|
||||
{
|
||||
std::string msg{e.what()};
|
||||
EXPECT_TRUE(msg.find("Input edge described by: add_ambiguous_name and input name: in3 was not found") !=
|
||||
std::string::npos);
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
const InputEdge edge = editor.find_input_edge(EditorNode{"add_ambiguous_name"}, EditorInput{"relu1"});
|
||||
}
|
||||
catch (const std::exception& e)
|
||||
{
|
||||
std::string msg{e.what()};
|
||||
EXPECT_TRUE(msg.find("Given node name: add_ambiguous_name and input name: relu1 are ambiguous to determine input edge") !=
|
||||
std::string::npos);
|
||||
}
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_editor, editor_api_select_input_edge_by_ambiguous_node_name_and_input_index)
|
||||
{
|
||||
ONNXModelEditor editor{file_util::path_join(
|
||||
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
|
||||
|
||||
try
|
||||
{
|
||||
const InputEdge edge = editor.find_input_edge(EditorNode{"add_ambiguous_name"}, EditorInput{0});
|
||||
}
|
||||
catch (const std::exception& e)
|
||||
{
|
||||
std::string msg{e.what()};
|
||||
EXPECT_TRUE(msg.find("Given node name: add_ambiguous_name and input index: 0 are ambiguous to determine input edge") !=
|
||||
std::string::npos);
|
||||
}
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_editor, editor_api_select_output_edge_by_ambiguous_node_name_but_matched_output)
|
||||
{
|
||||
ONNXModelEditor editor{file_util::path_join(
|
||||
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
|
||||
|
||||
const OutputEdge edge = editor.find_output_edge(EditorNode{"add_ambiguous_name"}, EditorOutput{"add1"});
|
||||
EXPECT_EQ(edge.m_node_idx, 1);
|
||||
EXPECT_EQ(edge.m_tensor_name, "add1");
|
||||
|
||||
const OutputEdge edge2 = editor.find_output_edge(EditorNode{"add_ambiguous_name"}, EditorOutput{"add2"});
|
||||
EXPECT_EQ(edge2.m_node_idx, 3);
|
||||
EXPECT_EQ(edge2.m_tensor_name, "add2");
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_editor, editor_api_select_output_edge_by_the_same_node_name_and_output_name)
|
||||
{
|
||||
ONNXModelEditor editor{file_util::path_join(
|
||||
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests_2.prototxt")};
|
||||
|
||||
const OutputEdge edge = editor.find_output_edge(EditorNode{"add1"}, EditorOutput{0});
|
||||
EXPECT_EQ(edge.m_node_idx, 0);
|
||||
EXPECT_EQ(edge.m_tensor_name, "relu1");
|
||||
|
||||
const OutputEdge edge2 = editor.find_output_edge(EditorNode{EditorOutput{"add1"}}, EditorOutput{0});
|
||||
EXPECT_EQ(edge2.m_node_idx, 4);
|
||||
EXPECT_EQ(edge2.m_tensor_name, "add1");
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_editor, editor_api_select_output_edge_by_ambiguous_node_name_and_not_matched_output)
|
||||
{
|
||||
ONNXModelEditor editor{file_util::path_join(
|
||||
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
|
||||
|
||||
try
|
||||
{
|
||||
const OutputEdge edge = editor.find_output_edge(EditorNode{"add_ambiguous_name"}, EditorOutput{"split2"});
|
||||
}
|
||||
catch (const std::exception& e)
|
||||
{
|
||||
std::string msg{e.what()};
|
||||
EXPECT_TRUE(msg.find("Output edge described by: add_ambiguous_name and output name: split2 was not found") !=
|
||||
std::string::npos);
|
||||
}
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_editor, editor_api_select_output_edge_by_ambiguous_node_name_and_output_index)
|
||||
{
|
||||
ONNXModelEditor editor{file_util::path_join(
|
||||
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
|
||||
|
||||
try
|
||||
{
|
||||
const OutputEdge edge = editor.find_output_edge(EditorNode{"add_ambiguous_name"}, EditorOutput{0});
|
||||
}
|
||||
catch (const std::exception& e)
|
||||
{
|
||||
std::string msg{e.what()};
|
||||
EXPECT_TRUE(msg.find("Given node name: add_ambiguous_name and output index: 0 are ambiguous to determine output edge") !=
|
||||
std::string::npos);
|
||||
}
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_editor, editor_api_use_edge_mapper_with_graph_cutter)
|
||||
{
|
||||
ONNXModelEditor editor{file_util::path_join(
|
||||
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
|
||||
|
||||
// InputEdge{1, "in2"}
|
||||
const auto input_edge_1 = editor.find_input_edge(
|
||||
EditorNode(EditorOutput("add1")), EditorInput(1));
|
||||
// InputEdge{2, "in3"}
|
||||
const auto input_edge_2 = editor.find_input_edge(
|
||||
EditorNode(EditorOutput("conv1")), EditorInput(0));
|
||||
|
||||
|
||||
const auto output_edge = editor.find_output_edge(
|
||||
EditorNode(EditorOutput("mul2")), EditorOutput(0));
|
||||
// OutputEdge{4, "mul2"}
|
||||
editor.cut_graph_fragment({input_edge_1, input_edge_2}, {output_edge});
|
||||
|
||||
const auto ref_model =
|
||||
file_util::path_join(SERIALIZED_ZOO,
|
||||
"onnx/model_editor/reference/"
|
||||
"subgraph__existing_inputs_and_outputs_based_extraction.prototxt");
|
||||
|
||||
const auto result = compare_onnx_models(editor.model_string(), ref_model);
|
||||
|
||||
EXPECT_TRUE(result.is_ok) << result.error_message;
|
||||
|
||||
// check if mapper was updated after the model changed
|
||||
const auto input_edge_4 = editor.find_input_edge(
|
||||
EditorNode(EditorOutput("relu1")), EditorInput(0));
|
||||
EXPECT_EQ(input_edge_4.m_node_idx, 0);
|
||||
EXPECT_EQ(input_edge_4.m_tensor_name, "in1");
|
||||
|
||||
const auto input_edge_5 = editor.find_input_edge(
|
||||
EditorNode(EditorOutput("add1")), EditorInput(1));
|
||||
EXPECT_EQ(input_edge_5.m_node_idx, 1);
|
||||
EXPECT_EQ(input_edge_5.m_tensor_name, "in2");
|
||||
|
||||
const auto output_edge_3 = editor.find_output_edge("mul2");
|
||||
EXPECT_EQ(output_edge_3.m_node_idx, 3);
|
||||
EXPECT_EQ(output_edge_3.m_tensor_name, "mul2");
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_editor, editor_api_find_output_consumers)
|
||||
{
|
||||
ONNXModelEditor editor{file_util::path_join(
|
||||
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
|
||||
|
||||
std::vector<InputEdge> output_consumers = editor.find_output_consumers("relu1");
|
||||
EXPECT_EQ(output_consumers.size(), 3);
|
||||
EXPECT_EQ(output_consumers[0].m_node_idx, 1);
|
||||
EXPECT_EQ(output_consumers[0].m_tensor_name, "relu1");
|
||||
EXPECT_EQ(output_consumers[1].m_node_idx, 3);
|
||||
EXPECT_EQ(output_consumers[1].m_tensor_name, "relu1");
|
||||
EXPECT_EQ(output_consumers[2].m_node_idx, 6);
|
||||
EXPECT_EQ(output_consumers[2].m_tensor_name, "relu1");
|
||||
|
||||
output_consumers = editor.find_output_consumers("add1");
|
||||
EXPECT_EQ(output_consumers.size(), 2);
|
||||
EXPECT_EQ(output_consumers[0].m_node_idx, 3);
|
||||
EXPECT_EQ(output_consumers[0].m_tensor_name, "add1");
|
||||
EXPECT_EQ(output_consumers[1].m_node_idx, 4);
|
||||
EXPECT_EQ(output_consumers[1].m_tensor_name, "add1");
|
||||
|
||||
output_consumers = editor.find_output_consumers("in3");
|
||||
EXPECT_EQ(output_consumers.size(), 1);
|
||||
EXPECT_EQ(output_consumers[0].m_node_idx, 2);
|
||||
EXPECT_EQ(output_consumers[0].m_tensor_name, "in3");
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_editor, editor_api_find_output_consumers_empty_result)
|
||||
{
|
||||
ONNXModelEditor editor{file_util::path_join(
|
||||
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
|
||||
|
||||
const std::vector<InputEdge> output_consumers = editor.find_output_consumers("not_existed");
|
||||
EXPECT_EQ(output_consumers.size(), 0);
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_editor, editor_api_is_correct_and_unambiguous_node)
|
||||
{
|
||||
ONNXModelEditor editor{file_util::path_join(
|
||||
SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};
|
||||
|
||||
bool is_correct_node = editor.is_correct_and_unambiguous_node(EditorNode{EditorOutput{"relu1"}});
|
||||
EXPECT_EQ(is_correct_node, true);
|
||||
|
||||
is_correct_node = editor.is_correct_and_unambiguous_node(EditorNode{EditorOutput{"mul2"}});
|
||||
EXPECT_EQ(is_correct_node, true);
|
||||
|
||||
is_correct_node = editor.is_correct_and_unambiguous_node(EditorNode{EditorOutput{"split2"}});
|
||||
EXPECT_EQ(is_correct_node, true);
|
||||
|
||||
is_correct_node = editor.is_correct_and_unambiguous_node(EditorNode{"relu1_name"});
|
||||
EXPECT_EQ(is_correct_node, true);
|
||||
|
||||
is_correct_node = editor.is_correct_and_unambiguous_node(EditorNode{EditorOutput{"in3"}});
|
||||
EXPECT_EQ(is_correct_node, false);
|
||||
|
||||
is_correct_node = editor.is_correct_and_unambiguous_node(EditorNode{"add_ambiguous_name"});
|
||||
EXPECT_EQ(is_correct_node, false);
|
||||
|
||||
is_correct_node = editor.is_correct_and_unambiguous_node(EditorNode{"not_exist"});
|
||||
EXPECT_EQ(is_correct_node, false);
|
||||
}
|
||||
|
||||
using TestEngine = test::INTERPRETER_Engine;
|
||||
|
||||
NGRAPH_TEST(onnx_editor, values__append_one_initializer)
|
||||
|
Loading…
Reference in New Issue
Block a user