Subgraph extraction in ONNX model editor (#4107)

This commit is contained in:
Tomasz Dołbniak 2021-03-04 10:12:37 +01:00 committed by GitHub
parent 819bdbe4eb
commit a085b68d37
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
44 changed files with 3984 additions and 9 deletions

View File

@ -0,0 +1,164 @@
//*****************************************************************************
// Copyright 2017-2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <map>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
namespace ONNX_NAMESPACE
{
class GraphProto;
class NodeProto;
class ValueInfoProto;
} // namespace ONNX_NAMESPACE
namespace ngraph
{
/// \brief Distinguishes the two kinds of graph edges handled by the editor.
enum class EdgeType
{
    INPUT,
    OUTPUT
};

/// \brief An edge of an ONNX graph: a node index paired with a tensor name.
///        The EdgeType template parameter tags the edge as either a node input
///        or a node output; the two instantiations are distinct types.
template <EdgeType>
struct Edge
{
    Edge() = delete;

    /// \brief Binds the tensor called "tensor_name" to the node at index "node_idx".
    Edge(const int node_idx, std::string tensor_name)
        : m_node_idx(node_idx)
        , m_tensor_name(std::move(tensor_name))
    {
    }

    // index of the node in the processed ONNX graph
    const int m_node_idx;
    // name of the tensor flowing through this edge
    const std::string m_tensor_name;
};
namespace onnx_import
{
/// \brief Defines an edge connected to an input of any node in the graph.
///        It consists of a node index in the processed ONNX model and the input name.
///        The index should point to a node in the topological sort of the underlying graph
///        which means it has to be in range: 0 <= node_idx < graph.node_size()
///
///        For a node number 5, with 3 inputs:
///
///            ----(in_A)----> +--------+
///            ----(in_B)----> | node 5 | ----(out)---->
///            ----(in_C)----> +--------+
///
///        there are 3 possible valid instances of this struct:
///            InputEdge(5, "in_A")
///            InputEdge(5, "in_B")
///            InputEdge(5, "in_C")
using InputEdge = Edge<EdgeType::INPUT>;

/// \brief Defines an edge connected to an output of any node in the graph.
///        It consists of a node index in the processed ONNX model and the output name.
///        Like for InputEdge, the index has to be in range: 0 <= node_idx < graph.node_size()
///
///        For a node number 5, with 2 outputs:
///
///                            +--------+ ----(out1)---->
///            ----(in_A)----> | node 5 |
///                            +--------+ ----(out2)---->
///
///        there are 2 possible valid instances of this struct:
///            OutputEdge(5, "out1")
///            OutputEdge(5, "out2")
using OutputEdge = Edge<EdgeType::OUTPUT>;
/// \brief Subgraph extraction helper structure
struct SubgraphExtractor
{
    SubgraphExtractor(ONNX_NAMESPACE::GraphProto& graph);

    /// \brief Adds new inputs to the graph and connects them to the nodes indicated by
    ///        the provided input edges.
    void add_new_inputs(const std::vector<InputEdge>& new_inputs);

    /// \brief Adds new outputs to the graph with the same name as the nodes pointed to
    ///        by the input edges "new_outputs".
    void add_new_outputs(const std::vector<OutputEdge>& new_outputs);

    /// \brief Extracts the final subgraph by traversing the original model bottom-up
    ///        starting at each of the provided output edges. The extracted subgraph
    ///        contains all previously added inputs and potentially a subset of original
    ///        model's inputs that contribute to the value calculated in the output tensors.
    ///        In the end the underlying GraphProto is modified and obsolete elements
    ///        are discarded after this method call has finished.
    ///
    /// \param subgraph_outputs A list of expected outputs of the extracted subgraph.
    void extract_subgraph(std::vector<OutputEdge> subgraph_outputs);

    /// \brief Represents a subgraph of an ONNX model by holding a subset of nodes, inputs,
    ///        outputs and initializers of the original graph. Objects of this struct can be
    ///        merged into other instances using the += operator to build a subgraph from
    ///        smaller clusters.
    struct SubgraphComponents
    {
        SubgraphComponents() = default;
        // move-only type: components are accumulated via operator+=, never copied
        SubgraphComponents(const SubgraphComponents&) = delete;
        SubgraphComponents(SubgraphComponents&&) = default;
        SubgraphComponents& operator=(const SubgraphComponents&) = delete;
        SubgraphComponents& operator=(SubgraphComponents&&) = default;

        std::set<int> nodes;                // indices of nodes belonging to the subgraph
        std::set<std::string> inputs;       // names of graph inputs to keep
        std::set<std::string> initializers; // names of initializers to keep
        std::set<std::string> outputs;      // names of graph outputs to keep

        /// \brief Merges "other" into this object (a set union of each of the 4 fields).
        ///        Note: takes its argument by rvalue reference but copies the elements.
        SubgraphComponents& operator+=(SubgraphComponents&& other)
        {
            nodes.insert(other.nodes.begin(), other.nodes.end());
            inputs.insert(other.inputs.begin(), other.inputs.end());
            initializers.insert(other.initializers.begin(), other.initializers.end());
            outputs.insert(other.outputs.begin(), other.outputs.end());
            return *this;
        }
    };

private:
    // the graph being edited; the extractor mutates it in place
    ONNX_NAMESPACE::GraphProto& m_onnx_graph;

    // Graph traversal helper: node index -> node inputs (one-to-many)
    std::unordered_multimap<int, std::string> m_node_inputs;

    // Number of consumers of all tensors in the graph
    std::map<std::string, int> m_tensor_consumers;

    /// \brief Replaces the old input edge with a new one in the helper struct.
    ///        This is used by the output contributors discovery.
    void replace_input_edge(const InputEdge& old_edge, const InputEdge& new_edge);

    /// \brief Returns a list of edges of each outputs of the graph "m_onnx_graph"
    std::vector<OutputEdge> all_output_edges() const;

    /// \brief Traverses the graph bottom-up and collects all nodes, inputs and initializers
    ///        that contribute to an output designated by the provided output edge.
    ///        A sum of such SubgraphComponents objects forms a target extracted subgraph.
    SubgraphComponents
        discover_output_contributors(const OutputEdge& output_edge,
                                     const SubgraphComponents& already_collected) const;

    /// \brief Modifies the underlying GraphProto object and discards all obsolete elements.
    ///
    /// \param subgraph An object describing the subgraph to be extracted (elems to be kept)
    void extract_subgraph_from_onnx_model(const SubgraphComponents& subgraph);
};
} // namespace onnx_import
} // namespace ngraph

View File

@ -23,6 +23,7 @@
#include "ngraph/op/constant.hpp"
#include "ngraph/partial_shape.hpp"
#include "ngraph/type/element_type.hpp"
#include "onnx_import/editor/detail/subgraph_extraction.hpp"
#include "onnx_import/utils/onnx_importer_visibility.hpp"
namespace ONNX_NAMESPACE
@ -53,7 +54,7 @@ namespace ngraph
/// \param model_path Path to the file containing the model.
ONNXModelEditor(const std::string& model_path);
/// \brief Modifies the in-memory representation of the model (m_model_proto) by setting
/// \brief Modifies the in-memory representation of the model by setting
/// custom input types for all inputs specified in the provided map.
///
/// \param input_types A collection of pairs {input_name: new_input_type} that should be
@ -62,7 +63,7 @@ namespace ngraph
/// the inputs specified in its parameter.
void set_input_types(const std::map<std::string, element::Type_t>& input_types);
/// \brief Modifies the in-memory representation of the model (m_model_proto) by setting
/// \brief Modifies the in-memory representation of the model by setting
/// custom input shapes for all inputs specified in the provided map.
///
/// \param input_shapes A collection of pairs {input_name: new_input_shape} that should
@ -71,6 +72,18 @@ namespace ngraph
/// the inputs specified in its parameter.
void set_input_shapes(const std::map<std::string, ngraph::PartialShape>& input_shapes);
/// \brief Extracts a subgraph constrained by input edges and output edges. In the end
/// the underlying ModelProto is modified - obsolete inputs, initializers, nodes
/// and outputs are removed from the in-memory model.
///
/// \note Please look at the declaration of InputEdge and OutputEdge for an explanation of
/// how those objects can be created. If the outputs parameter is empty
/// this method keeps all of the original outputs of the model.
///
/// \param inputs A collection of input edges which become new inputs to the graph
/// \param outputs A collection of output edges which become new outputs of the graph
void cut_graph_fragment(const std::vector<InputEdge>& inputs,
const std::vector<OutputEdge>& outputs);
/// \brief Modifies the in-memory representation of the model by setting custom input
/// values for inputs specified in the provided map.
///
@ -91,11 +104,20 @@ namespace ngraph
/// \return A reference to ONNX ModelProto object containing the in-memory model
ONNX_NAMESPACE::ModelProto& model() const;
/// \brief Returns a serialized ONNX model, possibly modified by the editor.
std::string model_string() const;
/// \brief Returns a list of all inputs of the in-memory model, including initializers.
/// The returned value might depend on the previous operations executed on an
/// instance of the model editor, in particular the subgraph extraction which
/// can discard some inputs and initializers from the original graph.
std::vector<std::string> model_inputs() const;
/// \brief Returns the path to the original model file
const std::string& model_path() const;
/// \brief Saves the possibly model held by this class to a file. Serializes in binary
/// mode.
/// \brief Saves the possibly modified model held by this class to a file.
/// Serializes in binary mode.
///
/// \param out_file_path A path to the file where the modified model should be dumped.
void serialize(const std::string& out_file_path) const;
@ -103,7 +125,7 @@ namespace ngraph
private:
const std::string m_model_path;
class Impl;
struct Impl;
std::unique_ptr<Impl, void (*)(Impl*)> m_pimpl;
};
} // namespace onnx_import

View File

@ -0,0 +1,498 @@
//*****************************************************************************
// Copyright 2017-2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <functional>
#include <stack>

#include <onnx/onnx_pb.h>

#include "ngraph/check.hpp"
#include "onnx_import/editor/detail/subgraph_extraction.hpp"
using namespace ngraph::onnx_import;
namespace
{
/// \brief Throws (via NGRAPH_CHECK) if "node_idx" does not point to any node in "graph".
///        Valid indices are in the range [0, graph.node_size()).
void validate_node_index(const ONNX_NAMESPACE::GraphProto& graph, const int node_idx)
{
    NGRAPH_CHECK(
        node_idx >= 0 && node_idx < graph.node_size(),
        "The specified node index is out of range of nodes in the original model(idx: ",
        std::to_string(node_idx),
        "; nodes count in the model: ",
        std::to_string(graph.node_size()),
        ")");
}
/// \brief Builds a predicate that checks whether an ONNX object's name() equals "name".
///        The name is captured by reference, so the returned functor must not outlive it.
template <typename T>
std::function<bool(const T&)> name_equals(const std::string& name)
{
    return [&name](const T& item) -> bool { return item.name() == name; };
}

// Factory for a string-equality predicate; captures "other" by reference, so the
// result is only valid within the full expression it was created in.
const auto is_equal_to = +[](const std::string& other) {
    return [&other](const std::string& candidate) { return candidate == other; };
};

/// \brief Checks if an item with name equal to "name" already exists in the specified
///        container. A container item is expected to have a name() method.
template <typename Container>
bool already_exists(const Container& items, const std::string& name)
{
    const auto first = std::begin(items);
    const auto last = std::end(items);
    return std::find_if(first, last, name_equals<typename Container::value_type>(name)) != last;
}
/// \brief Checks if a tensor with name "name" is produced by an input of the graph
bool is_graph_input(const ONNX_NAMESPACE::GraphProto& graph, const std::string& name)
{
    return already_exists(graph.input(), name);
}

/// \brief Checks if a tensor with name "name" is produced by an initializer of the graph
///        Note: a tensor may be reported both as a graph input and an initializer;
///        callers that care test for both cases separately.
bool is_graph_initializer(const ONNX_NAMESPACE::GraphProto& graph, const std::string& name)
{
    return already_exists(graph.initializer(), name);
}
/// \brief Looks up the index of a node that produces a tensor "input_name". Used to traverse
///        the graph bottom-up. Starts from a node index "current_node_idx" because it operates
///        on a topologically sorted graph.
///
/// \param graph            Graph to search; its nodes must be topologically sorted.
/// \param current_node_idx Index of the consumer node; the scan runs backwards from it.
/// \param input_name       Name of the tensor whose producer should be found.
/// \return The index of the closest preceding node that lists "input_name" in its outputs.
/// \throws ngraph_error When no node in range [0, current_node_idx) produces the tensor.
int find_source_node_idx(const ONNX_NAMESPACE::GraphProto& graph,
                         const int current_node_idx,
                         const std::string& input_name)
{
    // scan backwards: in a topologically sorted graph a tensor's producer always
    // precedes all of its consumers
    for (int i = current_node_idx - 1; i >= 0; --i)
    {
        const auto& outputs = graph.node(i).output();
        const auto output_found =
            std::any_of(std::begin(outputs), std::end(outputs), is_equal_to(input_name));

        if (output_found)
        {
            return i;
        }
    }

    throw ngraph::ngraph_error{"Source node not found in the graph for node: " +
                               std::to_string(current_node_idx) + " and input name: " +
                               input_name};
}
/// \brief Looks up a descriptor for a given tensor name. This descriptor contains inferred
///        shape information which is required to create new inputs and outputs in the graph.
///
/// \param graph       The graph whose value_info entries are searched.
/// \param tensor_name Name of an intermediate tensor; must appear in graph.value_info().
/// \return A reference to the matching ValueInfoProto; fails via NGRAPH_CHECK otherwise.
const ONNX_NAMESPACE::ValueInfoProto&
    find_tensor_descriptor(const ONNX_NAMESPACE::GraphProto& graph,
                           const std::string& tensor_name)
{
    const auto it = std::find_if(std::begin(graph.value_info()),
                                 std::end(graph.value_info()),
                                 name_equals<ONNX_NAMESPACE::ValueInfoProto>(tensor_name));

    NGRAPH_CHECK(it != std::end(graph.value_info()),
                 "Could not find a tensor descriptor for tensor '",
                 tensor_name,
                 "'. It's not possible to add a new input to the graph without the type and "
                 "shape information of the intermediate tensor.");

    return *it;
}
/// \brief Inserts a new input to the graph and removes an initializer that produced a tensor
///        specified by an input edge passed to this function.
///
///        The new input replicates the initializer's element type and static shape.
///        If an input with the same name already exists in the graph, only the
///        initializer is erased.
void replace_initializer_with_new_input(ONNX_NAMESPACE::GraphProto& graph,
                                        const InputEdge& edge)
{
    const auto it = std::find_if(std::begin(graph.initializer()),
                                 std::end(graph.initializer()),
                                 name_equals<ONNX_NAMESPACE::TensorProto>(edge.m_tensor_name));

    // fixed: the original message opened a quote around the name but never closed it
    NGRAPH_CHECK(it != std::end(graph.initializer()),
                 "Could not find an initializer in the graph: '",
                 edge.m_tensor_name,
                 "'");

    if (!already_exists(graph.input(), edge.m_tensor_name))
    {
        const auto& initializer = *it;
        auto& new_input = *(graph.add_input());

        auto& new_input_tensor_type = *(new_input.mutable_type()->mutable_tensor_type());
        new_input_tensor_type.set_elem_type(initializer.data_type());

        // copy the initializer's dimensions over to the new input's static shape
        auto& new_input_shape = *(new_input_tensor_type.mutable_shape());
        for (const auto initializer_dim : initializer.dims())
        {
            auto& new_dim = *(new_input_shape.add_dim());
            new_dim.set_dim_value(initializer_dim);
        }

        *(new_input.mutable_name()) = edge.m_tensor_name;
    }

    graph.mutable_initializer()->erase(it);
}
/// \brief Inserts a new input to the graph and connects it to the node designated by an input
///        edge passed to this function.
/// \return A new input edge (along with "true") if a new input was added to the graph,
///         false + the original edge otherwise.
std::pair<bool, InputEdge> append_new_graph_input(ONNX_NAMESPACE::GraphProto& graph,
                                                  const InputEdge& edge)
{
    if (already_exists(graph.input(), edge.m_tensor_name) &&
        !is_graph_initializer(graph, edge.m_tensor_name))
    {
        // no need to append a new input if an edge points to an existing one in the model
        return {false, edge};
    }

    auto& target_node = *(graph.mutable_node(edge.m_node_idx));
    auto& node_inputs = *(target_node.mutable_input());
    auto target_input =
        std::find(std::begin(node_inputs), std::end(node_inputs), edge.m_tensor_name);

    NGRAPH_CHECK(target_input != std::end(node_inputs),
                 "Input '",
                 edge.m_tensor_name,
                 "' not found in the inputs of node ",
                 edge.m_node_idx,
                 ". Cannot append a new graph input to this node.");

    // the new input name combines the node's first output and the cut tensor's name;
    // note: it is only used in the non-initializer branch below
    const std::string new_input_name = target_node.output(0) + ":" + edge.m_tensor_name;

    // if an edge is connected to an initializer, the initializer is removed and substituted
    // with an input
    if (is_graph_initializer(graph, edge.m_tensor_name))
    {
        replace_initializer_with_new_input(graph, edge);

        // the new input keeps the original tensor name, so the edge is unchanged
        return {false, edge};
    }
    else
    {
        auto& new_input = *(graph.add_input());
        // copy the intermediate tensor properties to the newly created input
        new_input.MergeFrom(find_tensor_descriptor(graph, edge.m_tensor_name));
        *(new_input.mutable_name()) = new_input_name;
        // attach the new graph input to the target node's input
        *target_input = new_input_name;

        return {true, InputEdge{edge.m_node_idx, new_input_name}};
    }
}
/// \brief Replaces a node or initializer (consumed by multiple nodes) with a new input
/// \return Returns an index of a removed node or -1 if an initializer was removed
int replace_source_with_new_input(ONNX_NAMESPACE::GraphProto& graph, const InputEdge& edge)
{
    if (already_exists(graph.input(), edge.m_tensor_name) &&
        !is_graph_initializer(graph, edge.m_tensor_name))
    {
        // happens when a user specifies multiple input edges pointing to the same tensor name
        return -1;
    }

    if (is_graph_initializer(graph, edge.m_tensor_name))
    {
        replace_initializer_with_new_input(graph, edge);
    }
    else
    {
        auto& new_input = *(graph.add_input());
        // copy the intermediate tensor properties to the newly created input;
        // the copied descriptor already carries the tensor's name
        new_input.MergeFrom(find_tensor_descriptor(graph, edge.m_tensor_name));

        const auto source_node_idx =
            find_source_node_idx(graph, edge.m_node_idx, edge.m_tensor_name);
        auto& source_node = *(graph.mutable_node(source_node_idx));
        auto& node_outputs = *source_node.mutable_output();
        auto target_output =
            std::find(std::begin(node_outputs), std::end(node_outputs), edge.m_tensor_name);

        NGRAPH_CHECK(target_output != std::end(node_outputs),
                     "Output '",
                     edge.m_tensor_name,
                     "' not found in the outputs of node ",
                     source_node_idx,
                     ". Cannot remove the output from this node.");

        // stop producing tensor "edge.m_tensor_name" by the source node of the processed edge
        *target_output = "";

        return source_node_idx;
    }
    return -1;
}
/// \brief Adds new outputs to the ONNX graph for an edge specified by a user
///        The shape for this output is taken from a previously executed shape inference of the
///        original model. A no-op when an output with the same tensor name already exists.
void append_new_graph_output(ONNX_NAMESPACE::GraphProto& graph, const OutputEdge& edge)
{
    if (already_exists(graph.output(), edge.m_tensor_name))
    {
        return;
    }

    auto& target_node = *(graph.mutable_node(edge.m_node_idx));
    const auto& node_outputs = target_node.output();
    const auto target_output =
        std::find(std::begin(node_outputs), std::end(node_outputs), edge.m_tensor_name);

    NGRAPH_CHECK(target_output != std::end(node_outputs),
                 "Output '",
                 edge.m_tensor_name,
                 "' not found in the outputs of node ",
                 edge.m_node_idx,
                 ". Cannot append a new graph output to this node.");

    auto& new_output = *(graph.add_output());
    // copy the intermediate tensor's properties to the newly created output
    new_output.MergeFrom(find_tensor_descriptor(graph, edge.m_tensor_name));
    *(new_output.mutable_name()) = edge.m_tensor_name;
}
/// \brief Removes all items from a container except the ones whose names are in items_to_keep
///        It's intended to work with ONNX graph inputs, outputs and initializers only.
///
/// \param all_items     A repeated field of ValueInfoProto or TensorProto objects.
/// \param items_to_keep Names of the elements that should survive the pruning.
template <typename Container>
void discard_by_name(Container& all_items, const std::set<std::string>& items_to_keep)
{
    static_assert(
        std::is_same<typename Container::value_type, ONNX_NAMESPACE::ValueInfoProto>::value ||
            std::is_same<typename Container::value_type, ONNX_NAMESPACE::TensorProto>::value,
        "Unsupported value type of the container");

    // The tested item can be discarded if its name is not found in the items_to_keep set
    const auto can_be_discarded = [&items_to_keep](const typename Container::value_type& item) {
        return items_to_keep.count(item.name()) == 0;
    };

    using std::begin;
    using std::end;

    // erase-remove idiom:
    // move the elements-to-discard to the end of the container
    const auto new_end = std::remove_if(begin(all_items), end(all_items), can_be_discarded);
    // erase all of the discarded elements past the new end of the container
    all_items.erase(new_end, end(all_items));
}
/// \brief Removes all nodes from a container keeping the ones whose index is in nodes_to_keep
///
/// \param all_nodes     A repeated field of NodeProto objects (the graph's node list).
/// \param nodes_to_keep Indices (relative to the original order) of nodes to retain.
template <typename Container>
void discard_nodes(Container& all_nodes, const std::set<int>& nodes_to_keep)
{
    static_assert(
        std::is_same<typename Container::value_type, ONNX_NAMESPACE::NodeProto>::value,
        "Unsupported value type of the container");

    // stateful predicate: "idx" tracks the original position of each tested node
    // NOTE(review): this relies on std::remove_if applying the predicate to elements
    // front-to-back, exactly once each - holds for sequential execution
    int idx = 0;
    const auto discard_node = [&idx, &nodes_to_keep](const typename Container::value_type&) {
        return nodes_to_keep.count(idx++) == 0;
    };

    using std::begin;
    using std::end;

    const auto new_end = std::remove_if(begin(all_nodes), end(all_nodes), discard_node);
    all_nodes.erase(new_end, end(all_nodes));
}
} // namespace
/* -----------------------------------------------------------------------------------------------*/
/// \brief Builds the traversal helpers for "graph": the node-index -> input-name multimap
///        and the per-tensor consumer counters.
SubgraphExtractor::SubgraphExtractor(ONNX_NAMESPACE::GraphProto& graph)
    : m_onnx_graph(graph)
{
    // gathers information about the graph - input edges of every node and number of "consumers"
    // of all tensors in the graph
    for (int i = 0; i < graph.node_size(); ++i)
    {
        for (const auto& node_input : graph.node(i).input())
        {
            m_node_inputs.insert({i, node_input});
            // operator[] value-initializes a new counter to 0 on first access
            m_tensor_consumers[node_input] += 1;
        }
    }
}
/// \brief For each edge in "new_inputs": either appends a new graph input connected to the
///        pointed-to node, or - when the edge's tensor has multiple consumers - replaces
///        the tensor's producer with a new graph input shared by all of its consumers.
void SubgraphExtractor::add_new_inputs(const std::vector<InputEdge>& new_inputs)
{
    for (const auto& input_edge : new_inputs)
    {
        validate_node_index(m_onnx_graph, input_edge.m_node_idx);

        // if a tensor has multiple consumers, its producer(source) should be replaced with a new
        // input - this way all consumers of this tensor will now be connected to a new graph input
        if (m_tensor_consumers[input_edge.m_tensor_name] > 1)
        {
            // remove a node or initializer from a model and insert a new input instead
            int idx = replace_source_with_new_input(m_onnx_graph, input_edge);
            if (idx != -1)
            {
                // if a node was replaced with an input, remove input edges from a helper multimap
                // for this node because it won't end up in the target subgraph
                // m_node_inputs stores information about existing edges in the graph,
                // when a node is removed/replaced, information about its edges should also
                // be removed (this way this node will be discarded from the original graph)
                m_node_inputs.erase(idx);
            }
        }
        else
        {
            // in case an edge is connected to a single node, a new graph input should be added
            // and connected to that node; the new edge is an edge between the node and new input
            const auto& new_edge = append_new_graph_input(m_onnx_graph, input_edge);
            if (new_edge.first)
            {
                // the original edge should be replaced with a new one in the helper multimap
                // this information will later be used during the subgraph extraction stage
                replace_input_edge(input_edge, new_edge.second);
            }
        }
    }
}
/// \brief Validates each requested output edge and appends a matching output
///        entry to the underlying ONNX graph.
void SubgraphExtractor::add_new_outputs(const std::vector<OutputEdge>& new_outputs)
{
    for (const auto& edge : new_outputs)
    {
        validate_node_index(m_onnx_graph, edge.m_node_idx);
        append_new_graph_output(m_onnx_graph, edge);
    }
}
/// \brief Replaces the old input edge with a new one in the helper multimap.
///        Erases the (node_idx, tensor_name) entry matching "old_edge" (if present)
///        and inserts the entry described by "new_edge".
void SubgraphExtractor::replace_input_edge(const InputEdge& old_edge, const InputEdge& new_edge)
{
    // old_edge = {5, "x"}; new_edge = {5, "y"}
    // for a given node index "N", find all of its inputs in the helper multimap (pair of iterators)
    // using those iterators find the name of an input tensor that needs to be replaced
    const auto node_inputs = m_node_inputs.equal_range(old_edge.m_node_idx);
    auto old_input_name = node_inputs.first;

    // find an iterator pointing to the input name that should be replaced
    // (bug fix: the end-of-range check has to happen BEFORE dereferencing the iterator,
    //  otherwise a missing tensor name caused an out-of-range dereference)
    while (old_input_name != node_inputs.second &&
           old_input_name->second != old_edge.m_tensor_name)
    {
        ++old_input_name;
    }

    // finally remove the old edge from the helper map (only if it was found)
    // and insert a new edge
    if (old_input_name != node_inputs.second)
    {
        m_node_inputs.erase(old_input_name);
    }
    m_node_inputs.insert({new_edge.m_node_idx, new_edge.m_tensor_name});
}
/// \brief Collects the contributors of every requested output (all original outputs when
///        "subgraph_outputs" is empty) and prunes the underlying GraphProto to that set.
void SubgraphExtractor::extract_subgraph(std::vector<OutputEdge> subgraph_outputs)
{
    // when the user doesn't specify any outputs, all outputs of the original graph should be kept
    if (subgraph_outputs.empty())
    {
        subgraph_outputs = all_output_edges();
    }

    SubgraphComponents subgraph;

    for (const auto& output_edge : subgraph_outputs)
    {
        // for each output edge find the nodes, inputs and initializers that contribute to the value
        // produced by this output - "output contributors"
        // a sum of all contributors of all outputs is the target subgraph
        subgraph += discover_output_contributors(output_edge, subgraph);
    }

    // using the subgraph components collected above, modify the underlying GraphProto
    extract_subgraph_from_onnx_model(subgraph);
}
/// \brief Collects the nodes, graph inputs and initializers contributing to the tensor
///        designated by "output_edge", using a reverse DFS from the output's node.
///        Nodes already present in "already_collected" are skipped - their own
///        contributors were gathered during a previous call for another output.
SubgraphExtractor::SubgraphComponents SubgraphExtractor::discover_output_contributors(
    const OutputEdge& output_edge, const SubgraphComponents& already_collected) const
{
    const auto already_visited = [&already_collected](const int node_index) {
        return already_collected.nodes.count(node_index) > 0;
    };

    SubgraphComponents output_contributors;
    output_contributors.outputs.insert(output_edge.m_tensor_name);

    // reverse DFS graph traversal
    std::stack<int> nodes_to_visit;
    nodes_to_visit.push(output_edge.m_node_idx);

    while (!nodes_to_visit.empty())
    {
        const auto n = nodes_to_visit.top();
        nodes_to_visit.pop();

        // if a node has already been visited, return early because it's already marked as
        // a node to keep in the final extracted subgraph
        if (already_visited(n))
        {
            continue;
        }

        output_contributors.nodes.insert(n);

        // check if the visitor reached any of the graph inputs
        // and/or keep looking for more contributors further up in the graph
        // when an input or initializer is reached, the visitor stops the lookup
        const auto n_inputs = m_node_inputs.equal_range(n);
        for (auto input_name = n_inputs.first; input_name != n_inputs.second; ++input_name)
        {
            if (is_graph_input(m_onnx_graph, input_name->second))
            {
                output_contributors.inputs.insert(input_name->second);
                // when an initializer has a matching graph input
                if (is_graph_initializer(m_onnx_graph, input_name->second))
                {
                    output_contributors.initializers.insert(input_name->second);
                }
            }
            else if (is_graph_initializer(m_onnx_graph, input_name->second))
            {
                // when an initializer doesn't have a corresponding input
                output_contributors.initializers.insert(input_name->second);
            }
            else
            {
                // if an edge points to another node (source node) it should be visited
                // in one of the future iterations
                nodes_to_visit.push(find_source_node_idx(m_onnx_graph, n, input_name->second));
            }
        }
    }

    return output_contributors;
}
/// \brief Applies the collected components to the underlying GraphProto: every input,
///        initializer, output and node NOT listed in "subgraph" is erased in place.
void SubgraphExtractor::extract_subgraph_from_onnx_model(const SubgraphComponents& subgraph)
{
    discard_by_name(*(m_onnx_graph.mutable_input()), subgraph.inputs);
    discard_by_name(*(m_onnx_graph.mutable_initializer()), subgraph.initializers);
    discard_by_name(*(m_onnx_graph.mutable_output()), subgraph.outputs);
    discard_nodes(*(m_onnx_graph.mutable_node()), subgraph.nodes);
}
std::vector<OutputEdge> SubgraphExtractor::all_output_edges() const
{
std::vector<OutputEdge> all_outputs;
for (const auto& graph_output : m_onnx_graph.output())
{
all_outputs.emplace_back(
find_source_node_idx(m_onnx_graph, m_onnx_graph.node_size(), graph_output.name()),
graph_output.name());
}
return all_outputs;
}

View File

@ -16,6 +16,7 @@
#include <fstream>
#include <onnx/onnx_pb.h>
#include <onnx/shape_inference/implementation.h>
#include "ngraph/log.hpp"
#include "onnx_import/editor/editor.hpp"
@ -157,6 +158,12 @@ namespace
}
}
/// \brief Returns the name of the provided input or initializer proto object.
///        Works with any type exposing a name() accessor (ValueInfoProto, TensorProto).
template <typename T>
std::string extract_name(const T& input_or_initializer)
{
    return input_or_initializer.name();
} // fixed: removed a stray semicolon after the function body
void modify_initializer(TensorProto& initializer,
const std::string& name,
const std::shared_ptr<ngraph::op::Constant> values,
@ -211,6 +218,9 @@ struct onnx_import::ONNXModelEditor::Impl
: m_model_proto{std::move(parse_from_file(model_path))}
{
}
void infer_shapes() { ONNX_NAMESPACE::shape_inference::InferShapes(m_model_proto); }
void remove_shape_inference_info() { m_model_proto.mutable_graph()->clear_value_info(); }
};
onnx_import::ONNXModelEditor::ONNXModelEditor(const std::string& model_path)
@ -289,6 +299,49 @@ void onnx_import::ONNXModelEditor::set_input_shapes(
}
}
/// \brief Cuts out of the in-memory model a subgraph bounded by "inputs" and "outputs".
///        A no-op when both edge lists are empty. Shape inference runs first because the
///        extractor needs inferred tensor descriptors (type/shape) to create new graph
///        inputs and outputs; the inferred value_info entries are dropped afterwards.
void onnx_import::ONNXModelEditor::cut_graph_fragment(const std::vector<InputEdge>& inputs,
                                                      const std::vector<OutputEdge>& outputs)
{
    if (inputs.empty() && outputs.empty())
    {
        return;
    }

    m_pimpl->infer_shapes();

    SubgraphExtractor editor{*(m_pimpl->m_model_proto.mutable_graph())};
    editor.add_new_inputs(inputs);
    editor.add_new_outputs(outputs);
    // an empty "outputs" vector makes the extractor keep all original model outputs
    editor.extract_subgraph(outputs);

    m_pimpl->remove_shape_inference_info();
}
/// \brief Returns the names of all inputs followed by the names of all initializers
///        of the in-memory model, in graph order.
std::vector<std::string> onnx_import::ONNXModelEditor::model_inputs() const
{
    const auto& graph = m_pimpl->m_model_proto.graph();

    std::vector<std::string> inputs_and_initializers;
    // single allocation for both name groups
    inputs_and_initializers.reserve(graph.input_size() + graph.initializer_size());

    std::transform(graph.input().begin(),
                   graph.input().end(),
                   std::back_inserter(inputs_and_initializers),
                   extract_name<ONNX_NAMESPACE::ValueInfoProto>);

    std::transform(graph.initializer().begin(),
                   graph.initializer().end(),
                   std::back_inserter(inputs_and_initializers),
                   extract_name<ONNX_NAMESPACE::TensorProto>);

    return inputs_and_initializers;
}
/// \brief Serializes the in-memory ModelProto (including any modifications made by
///        the editor) to a binary protobuf string.
std::string onnx_import::ONNXModelEditor::model_string() const
{
    return m_pimpl->m_model_proto.SerializeAsString();
}
void onnx_import::ONNXModelEditor::set_input_values(
const std::map<std::string, std::shared_ptr<ngraph::op::Constant>>& input_values)
{

View File

@ -15,8 +15,6 @@
//*****************************************************************************
#include <fstream>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/text_format.h>
#include <memory>
#include "core/graph.hpp"
@ -82,8 +80,6 @@ namespace ngraph
std::shared_ptr<Function> import_onnx_model(const ONNXModelEditor& model_editor)
{
// this overload of the import_onnx_model is friended with the ONNXModelEditor
// and thus can access its private members
return detail::import_onnx_model(model_editor.model(), model_editor.model_path());
}

View File

@ -0,0 +1,132 @@
ir_version: 7
producer_name: "test_data_generator"
graph {
node {
input: "in1"
output: "relu1"
op_type: "Relu"
}
node {
input: "relu1"
input: "in2"
output: "add1"
op_type: "Add"
}
node {
input: "in3"
input: "in4"
output: "conv1"
op_type: "Conv"
}
node {
input: "add1"
input: "conv1"
output: "mul2"
op_type: "Mul"
}
name: "subgraph_extraction_testing"
initializer {
dims: 1
dims: 1
dims: 1
dims: 1
data_type: 1
float_data: 1
name: "in4"
}
input {
name: "in1"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "in2"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
input {
name: "in3"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "in4"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 1
}
}
}
}
}
output {
name: "mul2"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 13
}

View File

@ -0,0 +1,121 @@
ir_version: 3
producer_name: "test_data_generator"
doc_string: "This model contains the first few nodes of the ONNX Inception V1 model"
graph {
name: "Inception V1 fragment"
node {
input: "data_0"
input: "conv1/7x7_s2_w_0"
input: "conv1/7x7_s2_b_0"
output: "conv1/7x7_s2_1"
name: ""
op_type: "Conv"
attribute {
name: "strides"
ints: 2
ints: 2
type: INTS
}
attribute {
name: "pads"
ints: 3
ints: 3
ints: 3
ints: 3
type: INTS
}
attribute {
name: "kernel_shape"
ints: 7
ints: 7
type: INTS
}
}
input {
name: "data_0"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 3
}
dim {
dim_value: 224
}
dim {
dim_value: 224
}
}
}
}
}
input {
name: "conv1/7x7_s2_w_0"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 64
}
dim {
dim_value: 3
}
dim {
dim_value: 7
}
dim {
dim_value: 7
}
}
}
}
}
input {
name: "conv1/7x7_s2_b_0"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
}
}
}
}
output {
name: "conv1/7x7_s2_1"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 64
}
dim {
dim_value: 112
}
dim {
dim_value: 112
}
}
}
}
}
}
opset_import {
version: 8
}

View File

@ -0,0 +1,120 @@
ir_version: 3
producer_name: "test_data_generator"
doc_string: "This model contains the first few nodes of the ONNX Inception V1 model"
graph {
name: "Inception V1 fragment"
node {
input: "data_0"
input: "conv1/7x7_s2_w_0"
input: "conv1/7x7_s2_b_0"
output: "conv1/7x7_s2_1"
name: ""
op_type: "Conv"
attribute {
name: "strides"
ints: 2
ints: 2
type: INTS
}
attribute {
name: "pads"
ints: 3
ints: 3
ints: 3
ints: 3
type: INTS
}
attribute {
name: "kernel_shape"
ints: 7
ints: 7
type: INTS
}
}
node {
input: "conv1/7x7_s2_1"
output: "conv1/7x7_s2_2"
name: ""
op_type: "Relu"
}
input {
name: "data_0"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 3
}
dim {
dim_value: 224
}
dim {
dim_value: 224
}
}
}
}
}
input {
name: "conv1/7x7_s2_w_0"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 64
}
dim {
dim_value: 3
}
dim {
dim_value: 7
}
dim {
dim_value: 7
}
}
}
}
}
initializer {
dims: 1
data_type: 1
name: "conv1/7x7_s2_b_0"
float_data: 3.141592
}
output {
name: "conv1/7x7_s2_2"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 64
}
dim {
dim_value: 112
}
dim {
dim_value: 112
}
}
}
}
}
}
opset_import {
version: 8
}

View File

@ -0,0 +1,166 @@
ir_version: 7
producer_name: "test_data_generator"
graph {
node {
input: "relu1"
input: "in2"
output: "add1"
op_type: "Add"
}
node {
input: "in3"
input: "in4"
output: "conv1"
op_type: "Conv"
}
node {
input: "relu1"
input: "add1"
output: "add2"
op_type: "Add"
}
node {
input: "add1"
input: "conv1"
output: "mul2"
op_type: "Mul"
}
node {
input: "add2"
output: "split1"
output: "split2"
op_type: "Split"
attribute {
name: "axis"
i: 1
type: INT
}
}
node {
input: "relu1"
input: "split1"
output: "mul1"
op_type: "Mul"
}
name: "subgraph_extraction_testing"
initializer {
dims: 1
dims: 1
dims: 1
dims: 1
data_type: 1
float_data: 1
name: "in4"
}
input {
name: "in2"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
input {
name: "in3"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "in4"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 1
}
}
}
}
}
input {
name: "relu1"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
output {
name: "mul1"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
output {
name: "mul2"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 13
}

View File

@ -0,0 +1,149 @@
ir_version: 7
producer_name: "test_data_generator"
graph {
node {
input: "in3"
input: "in4"
output: "conv1"
op_type: "Conv"
}
node {
input: "relu1"
input: "add1"
output: "add2"
op_type: "Add"
}
node {
input: "add1"
input: "conv1"
output: "mul2"
op_type: "Mul"
}
name: "subgraph_extraction_testing"
initializer {
dims: 1
dims: 1
dims: 1
dims: 1
data_type: 1
float_data: 1
name: "in4"
}
input {
name: "in3"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "in4"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 1
}
}
}
}
}
input {
name: "relu1"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "add1"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
output {
name: "mul2"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
output {
name: "add2"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 13
}

View File

@ -0,0 +1,95 @@
ir_version: 7
producer_name: "test_data_generator"
graph {
node {
input: "relu1"
input: "in2"
output: "add1"
op_type: "Add"
}
node {
input: "relu1"
input: "add1"
output: "add2"
op_type: "Add"
}
node {
input: "add2"
output: "split1"
output: "split2"
op_type: "Split"
attribute {
name: "axis"
i: 1
type: INT
}
}
node {
input: "relu1"
input: "split1"
output: "mul1"
op_type: "Mul"
}
name: "subgraph_extraction_testing"
input {
name: "in2"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
input {
name: "relu1"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
output {
name: "mul1"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
output {
name: "split2"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 1
}
}
}
}
}
}
opset_import {
version: 13
}

View File

@ -0,0 +1,82 @@
ir_version: 7
producer_name: "test_data_generator"
doc_string: "This model contains the first few nodes of the ONNX Inception V1 model"
graph {
name: "Inception V1 fragment"
node {
input: "pool1/3x3_s2_1:conv1/7x7_s2_2"
output: "pool1/3x3_s2_1"
name: ""
op_type: "MaxPool"
attribute {
name: "strides"
ints: 2
ints: 2
type: INTS
}
attribute {
name: "pads"
ints: 0
ints: 0
ints: 0
ints: 0
type: INTS
}
attribute {
name: "kernel_shape"
ints: 3
ints: 3
type: INTS
}
}
input {
name: "pool1/3x3_s2_1:conv1/7x7_s2_2"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 64
}
dim {
dim_value: 112
}
dim {
dim_value: 112
}
}
}
}
}
output {
name: "pool1/3x3_s2_1"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 64
}
dim {
dim_value: 55
}
dim {
dim_value: 55
}
}
}
}
}
}
opset_import {
version: 13
}

View File

@ -0,0 +1,121 @@
ir_version: 7
producer_name: "test_data_generator"
doc_string: "This model contains the first few nodes of the ONNX Inception V1 model"
graph {
name: "Inception V1 fragment"
node {
input: "data_0"
input: "conv1/7x7_s2_w_0"
input: "conv1/7x7_s2_b_0"
output: "conv1/7x7_s2_1"
name: ""
op_type: "Conv"
attribute {
name: "strides"
ints: 2
ints: 2
type: INTS
}
attribute {
name: "pads"
ints: 3
ints: 3
ints: 3
ints: 3
type: INTS
}
attribute {
name: "kernel_shape"
ints: 7
ints: 7
type: INTS
}
}
input {
name: "data_0"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 3
}
dim {
dim_value: 224
}
dim {
dim_value: 224
}
}
}
}
}
input {
name: "conv1/7x7_s2_w_0"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 64
}
dim {
dim_value: 3
}
dim {
dim_value: 7
}
dim {
dim_value: 7
}
}
}
}
}
input {
name: "conv1/7x7_s2_b_0"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 64
}
}
}
}
}
output {
name: "conv1/7x7_s2_1"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 64
}
dim {
dim_value: 112
}
dim {
dim_value: 112
}
}
}
}
}
}
opset_import {
version: 13
}

View File

@ -0,0 +1,88 @@
ir_version: 7
producer_name: "test_data_generator"
doc_string: "This model contains the first few nodes of the ONNX Inception V1 model"
graph {
name: "Inception V1 fragment"
node {
input: "conv1/7x7_s2_2:conv1/7x7_s2_1"
output: "conv1/7x7_s2_2"
name: ""
op_type: "Relu"
}
node {
input: "conv1/7x7_s2_2"
output: "pool1/3x3_s2_1"
name: ""
op_type: "MaxPool"
attribute {
name: "strides"
ints: 2
ints: 2
type: INTS
}
attribute {
name: "pads"
ints: 0
ints: 0
ints: 0
ints: 0
type: INTS
}
attribute {
name: "kernel_shape"
ints: 3
ints: 3
type: INTS
}
}
input {
name: "conv1/7x7_s2_2:conv1/7x7_s2_1"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 64
}
dim {
dim_value: 112
}
dim {
dim_value: 112
}
}
}
}
}
output {
name: "pool1/3x3_s2_1"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 64
}
dim {
dim_value: 55
}
dim {
dim_value: 55
}
}
}
}
}
}
opset_import {
version: 13
}

View File

@ -0,0 +1,127 @@
ir_version: 7
producer_name: "test_data_generator"
doc_string: "This model contains the first few nodes of the ONNX Inception V1 model"
graph {
name: "Inception V1 fragment"
node {
input: "data_0"
input: "conv1/7x7_s2_w_0"
input: "conv1/7x7_s2_b_0"
output: "conv1/7x7_s2_1"
name: ""
op_type: "Conv"
attribute {
name: "strides"
ints: 2
ints: 2
type: INTS
}
attribute {
name: "pads"
ints: 3
ints: 3
ints: 3
ints: 3
type: INTS
}
attribute {
name: "kernel_shape"
ints: 7
ints: 7
type: INTS
}
}
node {
input: "conv1/7x7_s2_1"
output: "conv1/7x7_s2_2"
name: ""
op_type: "Relu"
}
input {
name: "data_0"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 3
}
dim {
dim_value: 224
}
dim {
dim_value: 224
}
}
}
}
}
input {
name: "conv1/7x7_s2_w_0"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 64
}
dim {
dim_value: 3
}
dim {
dim_value: 7
}
dim {
dim_value: 7
}
}
}
}
}
input {
name: "conv1/7x7_s2_b_0"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 64
}
}
}
}
}
output {
name: "conv1/7x7_s2_2"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 64
}
dim {
dim_value: 112
}
dim {
dim_value: 112
}
}
}
}
}
}
opset_import {
version: 13
}

View File

@ -0,0 +1,134 @@
ir_version: 3
producer_name: "test_data_generator"
doc_string: "This model contains the first few nodes of the ONNX Inception V1 model"
graph {
name: "Inception V1 fragment"
node {
input: "data_0"
input: "conv1/7x7_s2_w_0"
input: "conv1/7x7_s2_b_0"
output: "conv1/7x7_s2_1"
name: ""
op_type: "Conv"
attribute {
name: "strides"
ints: 2
ints: 2
type: INTS
}
attribute {
name: "pads"
ints: 3
ints: 3
ints: 3
ints: 3
type: INTS
}
attribute {
name: "kernel_shape"
ints: 7
ints: 7
type: INTS
}
}
node {
input: "conv1/7x7_s2_1"
output: "conv1/7x7_s2_2"
name: ""
op_type: "Relu"
}
input {
name: "data_0"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 3
}
dim {
dim_value: 224
}
dim {
dim_value: 224
}
}
}
}
}
input {
name: "conv1/7x7_s2_w_0"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 64
}
dim {
dim_value: 3
}
dim {
dim_value: 7
}
dim {
dim_value: 7
}
}
}
}
}
input {
name: "conv1/7x7_s2_b_0"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
}
}
}
}
initializer {
dims: 1
data_type: 1
name: "conv1/7x7_s2_b_0"
float_data: 3.141592
}
output {
name: "conv1/7x7_s2_2"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 64
}
dim {
dim_value: 112
}
dim {
dim_value: 112
}
}
}
}
}
}
opset_import {
version: 8
}

View File

@ -0,0 +1,78 @@
ir_version: 7
producer_name: "test_data_generator"
graph {
node {
input: "in1"
output: "relu1"
op_type: "Relu"
}
node {
input: "relu1"
input: "in2"
output: "add1"
op_type: "Add"
}
node {
input: "relu1"
input: "add1"
output: "add2"
op_type: "Add"
}
node {
input: "add2"
output: "split1"
output: "split2"
op_type: "Split"
attribute {
name: "axis"
i: 1
type: INT
}
}
name: "subgraph_extraction_testing"
input {
name: "in1"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "in2"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
output {
name: "split2"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 1
}
}
}
}
}
}
opset_import {
version: 13
}

View File

@ -0,0 +1,90 @@
ir_version: 7
producer_name: "test_data_generator"
graph {
node {
input: "in1"
output: "relu1"
op_type: "Relu"
}
node {
input: "in1"
output: "relu2"
op_type: "Relu"
}
node {
input: "in2"
output: "relu3"
op_type: "Relu"
}
node {
input: "in2"
output: "relu4"
op_type: "Relu"
}
node {
input: "relu1"
input: "relu2"
output: "add1"
op_type: "Add"
}
node {
input: "relu2"
input: "relu3"
output: "add2"
op_type: "Add"
}
name: "subgraph_extraction_testing"
input {
name: "in1"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
input {
name: "in2"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
output {
name: "add1"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
output {
name: "add2"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
output {
name: "relu4"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
}
opset_import {
version: 13
}

View File

@ -0,0 +1,95 @@
ir_version: 7
producer_name: "test_data_generator"
graph {
node {
input: "in1"
output: "relu1"
op_type: "Relu"
}
node {
input: "in2"
output: "relu3"
op_type: "Relu"
}
node {
input: "in2"
output: "relu4"
op_type: "Relu"
}
node {
input: "relu1"
input: "relu2"
output: "add1"
op_type: "Add"
}
node {
input: "relu2"
input: "relu3"
output: "add2"
op_type: "Add"
}
name: "subgraph_extraction_testing"
input {
name: "in1"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
input {
name: "in2"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
input {
name: "relu2"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
output {
name: "add1"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
output {
name: "add2"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
output {
name: "relu4"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
}
opset_import {
version: 13
}

View File

@ -0,0 +1,100 @@
ir_version: 7
producer_name: "test_data_generator"
graph {
node {
input: "in1"
output: "relu1"
op_type: "Relu"
}
node {
input: "in2"
output: "relu3"
op_type: "Relu"
}
node {
input: "in2"
output: "relu4"
op_type: "Relu"
}
node {
input: "relu1"
input: "relu2"
output: "add1"
op_type: "Add"
}
node {
input: "relu2"
input: "relu3"
output: "add2"
op_type: "Add"
}
name: "subgraph_extraction_testing"
initializer {
data_type: 1
float_data: 1
name: "in2"
}
input {
name: "in1"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
input {
name: "in2"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
input {
name: "relu2"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
output {
name: "add1"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
output {
name: "add2"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
output {
name: "relu4"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
}
opset_import {
version: 13
}

View File

@ -0,0 +1,153 @@
ir_version: 7
producer_name: "test_data_generator"
doc_string: "This model contains the first few nodes of the ONNX Inception V1 model"
graph {
name: "Inception V1 fragment"
node {
input: "data_0"
input: "conv1/7x7_s2_w_0"
input: "conv1/7x7_s2_b_0"
output: "conv1/7x7_s2_1"
name: ""
op_type: "Conv"
attribute {
name: "strides"
ints: 2
ints: 2
type: INTS
}
attribute {
name: "pads"
ints: 3
ints: 3
ints: 3
ints: 3
type: INTS
}
attribute {
name: "kernel_shape"
ints: 7
ints: 7
type: INTS
}
}
node {
input: "conv1/7x7_s2_1"
output: "conv1/7x7_s2_2"
name: ""
op_type: "Relu"
}
node {
input: "conv1/7x7_s2_2"
output: "pool1/3x3_s2_1"
name: ""
op_type: "MaxPool"
attribute {
name: "strides"
ints: 2
ints: 2
type: INTS
}
attribute {
name: "pads"
ints: 0
ints: 0
ints: 0
ints: 0
type: INTS
}
attribute {
name: "kernel_shape"
ints: 3
ints: 3
type: INTS
}
}
input {
name: "data_0"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 3
}
dim {
dim_value: 224
}
dim {
dim_value: 224
}
}
}
}
}
input {
name: "conv1/7x7_s2_w_0"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 64
}
dim {
dim_value: 3
}
dim {
dim_value: 7
}
dim {
dim_value: 7
}
}
}
}
}
input {
name: "conv1/7x7_s2_b_0"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 64
}
}
}
}
}
output {
name: "pool1/3x3_s2_1"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 64
}
dim {
dim_value: 55
}
dim {
dim_value: 55
}
}
}
}
}
}
opset_import {
version: 13
}

View File

@ -0,0 +1,160 @@
ir_version: 3
producer_name: "test_data_generator"
doc_string: "This model contains the first few nodes of the ONNX Inception V1 model"
graph {
name: "Inception V1 fragment"
node {
input: "data_0"
input: "conv1/7x7_s2_w_0"
input: "conv1/7x7_s2_b_0"
output: "conv1/7x7_s2_1"
name: ""
op_type: "Conv"
attribute {
name: "strides"
ints: 2
ints: 2
type: INTS
}
attribute {
name: "pads"
ints: 3
ints: 3
ints: 3
ints: 3
type: INTS
}
attribute {
name: "kernel_shape"
ints: 7
ints: 7
type: INTS
}
}
node {
input: "conv1/7x7_s2_1"
output: "conv1/7x7_s2_2"
name: ""
op_type: "Relu"
}
node {
input: "conv1/7x7_s2_2"
output: "pool1/3x3_s2_1"
name: ""
op_type: "MaxPool"
attribute {
name: "strides"
ints: 2
ints: 2
type: INTS
}
attribute {
name: "pads"
ints: 0
ints: 0
ints: 0
ints: 0
type: INTS
}
attribute {
name: "kernel_shape"
ints: 3
ints: 3
type: INTS
}
}
input {
name: "data_0"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 3
}
dim {
dim_value: 224
}
dim {
dim_value: 224
}
}
}
}
}
input {
name: "conv1/7x7_s2_w_0"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 64
}
dim {
dim_value: 3
}
dim {
dim_value: 7
}
dim {
dim_value: 7
}
}
}
}
}
input {
name: "conv1/7x7_s2_b_0"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
}
}
}
}
initializer {
dims: 1
data_type: 1
name: "conv1/7x7_s2_b_0"
float_data: 3.141592
}
output {
name: "pool1/3x3_s2_1"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 64
}
dim {
dim_value: 55
}
dim {
dim_value: 55
}
}
}
}
}
}
opset_import {
version: 8
}

View File

@ -0,0 +1,146 @@
ir_version: 3
producer_name: "test_data_generator"
doc_string: "This model contains the first few nodes of the ONNX Inception V1 model"
graph {
name: "Inception V1 fragment"
node {
input: "data_0"
input: "conv1/7x7_s2_w_0"
input: "conv1/7x7_s2_b_0"
output: "conv1/7x7_s2_1"
name: ""
op_type: "Conv"
attribute {
name: "strides"
ints: 2
ints: 2
type: INTS
}
attribute {
name: "pads"
ints: 3
ints: 3
ints: 3
ints: 3
type: INTS
}
attribute {
name: "kernel_shape"
ints: 7
ints: 7
type: INTS
}
}
node {
input: "conv1/7x7_s2_1"
output: "conv1/7x7_s2_2"
name: ""
op_type: "Relu"
}
node {
input: "conv1/7x7_s2_2"
output: "pool1/3x3_s2_1"
name: ""
op_type: "MaxPool"
attribute {
name: "strides"
ints: 2
ints: 2
type: INTS
}
attribute {
name: "pads"
ints: 0
ints: 0
ints: 0
ints: 0
type: INTS
}
attribute {
name: "kernel_shape"
ints: 3
ints: 3
type: INTS
}
}
input {
name: "data_0"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 3
}
dim {
dim_value: 224
}
dim {
dim_value: 224
}
}
}
}
}
input {
name: "conv1/7x7_s2_w_0"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 64
}
dim {
dim_value: 3
}
dim {
dim_value: 7
}
dim {
dim_value: 7
}
}
}
}
}
initializer {
dims: 1
data_type: 1
name: "conv1/7x7_s2_b_0"
float_data: 3.141592
}
output {
name: "pool1/3x3_s2_1"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 64
}
dim {
dim_value: 55
}
dim {
dim_value: 55
}
}
}
}
}
}
opset_import {
version: 8
}

View File

@ -0,0 +1,187 @@
ir_version: 7
producer_name: "test_data_generator"
graph {
node {
input: "in1"
output: "relu1"
op_type: "Relu"
}
node {
input: "relu1"
input: "in2"
output: "add1"
op_type: "Add"
}
node {
input: "in3"
input: "in4"
output: "conv1"
op_type: "Conv"
}
node {
input: "relu1"
input: "add1"
output: "add2"
op_type: "Add"
}
node {
input: "add1"
input: "conv1"
output: "mul2"
op_type: "Mul"
}
node {
input: "add2"
output: "split1"
output: "split2"
op_type: "Split"
attribute {
name: "axis"
i: 1
type: INT
}
}
node {
input: "relu1"
input: "split1"
output: "mul1"
op_type: "Mul"
}
name: "subgraph_extraction_testing"
initializer {
dims: 1
dims: 1
dims: 1
dims: 1
data_type: 1
float_data: 1
name: "in4"
}
input {
name: "in1"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "in2"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
input {
name: "in3"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "in4"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 1
}
}
}
}
}
output {
name: "mul1"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
output {
name: "split2"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 1
}
}
}
}
}
output {
name: "mul2"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 13
}

View File

@ -0,0 +1,95 @@
ir_version: 7
producer_name: "test_data_generator"
graph {
node {
input: "in1"
output: "relu1"
op_type: "Relu"
}
node {
input: "in1"
output: "relu2"
op_type: "Relu"
}
node {
input: "in2"
output: "relu3"
op_type: "Relu"
}
node {
input: "in2"
output: "relu4"
op_type: "Relu"
}
node {
input: "relu1"
input: "relu2"
output: "add1"
op_type: "Add"
}
node {
input: "relu2"
input: "relu3"
output: "add2"
op_type: "Add"
}
name: "subgraph_extraction_testing"
initializer {
data_type: 1
float_data: 1
name: "in2"
}
input {
name: "in1"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
input {
name: "in2"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
output {
name: "add1"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
output {
name: "add2"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
output {
name: "relu4"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
}
opset_import {
version: 13
}

View File

@ -15,20 +15,24 @@
//*****************************************************************************
#include <algorithm>
#include "gtest/gtest.h"
#include "default_opset.hpp"
#include "ngraph/file_util.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/opsets/opset1.hpp"
#include "onnx_import/editor/editor.hpp"
#include "onnx_import/onnx.hpp"
#include "util/engine/interpreter_engine.hpp"
#include "util/test_case.hpp"
#include "util/test_control.hpp"
// #include "utils/onnx_test_util.hpp"
NGRAPH_SUPPRESS_DEPRECATED_START
using namespace ngraph;
using namespace ngraph::onnx_import;
static std::string s_manifest = "${MANIFEST}";
@ -54,6 +58,23 @@ namespace
return *input_pos;
}
/// \brief Reads the whole file at \p path in binary mode and returns its bytes.
///
/// \param path  Path to the file to read.
/// \return      The file's entire contents as a std::string (may contain NUL bytes).
/// \throws std::runtime_error if the file cannot be opened or fully read.
std::string read_binary_file(const std::string& path)
{
    std::ifstream inputs_fs{path, std::ios::in | std::ios::binary};
    if (!inputs_fs)
    {
        throw std::runtime_error("Failed to open the file: " + path);
    }

    inputs_fs.seekg(0, std::ios::end);
    const auto size = inputs_fs.tellg();
    inputs_fs.seekg(0, std::ios::beg);

    // Read straight into the returned string - avoids the intermediate
    // std::vector<char> buffer and the extra copy the old code performed.
    std::string file_content(static_cast<std::size_t>(size), '\0');
    inputs_fs.read(&file_content[0], size);
    // The previous implementation ignored read failures and could silently
    // return a zero-filled/truncated buffer; fail loudly instead.
    if (!inputs_fs)
    {
        throw std::runtime_error("Failed to read the file: " + path);
    }
    return file_content;
}
} // namespace
NGRAPH_TEST(onnx_editor, types__single_input_type_substitution)
@ -287,6 +308,411 @@ NGRAPH_TEST(onnx_editor, shapes__static_to_dynamic_rank_substitution)
}
}
// Cuts off the model's head by declaring the Conv output tensor (consumed by node 1)
// as a new graph input; the remaining subgraph must match the stored reference model.
NGRAPH_TEST(onnx_editor, subgraph__linear_model_head_cut)
{
    onnx_import::ONNXModelEditor editor{file_util::path_join(
        SERIALIZED_ZOO, "onnx/model_editor/subgraph__inception_head.prototxt")};

    editor.cut_graph_fragment({{InputEdge(1, "conv1/7x7_s2_1")}}, {});

    const auto ref_model = file_util::path_join(
        SERIALIZED_ZOO, "onnx/model_editor/reference/subgraph__linear_model_head_cut.onnx");

    // Byte-level comparison of serialized models; a structural comparison helper
    // is pending (see the commented-out compare_onnx_models call below).
    EXPECT_EQ(editor.model_string(), read_binary_file(ref_model));

    // const auto result = compare_onnx_models(editor.model_string(), ref_model);

    // EXPECT_TRUE(result.is_ok) << result.error_message;
}
// Same head cut as above, but with the graph's existing output also given
// explicitly; specifying an already-existing output must not change the result.
NGRAPH_TEST(onnx_editor, subgraph__linear_model_head_cut_ins_and_outs)
{
    onnx_import::ONNXModelEditor editor{file_util::path_join(
        SERIALIZED_ZOO, "onnx/model_editor/subgraph__inception_head.prototxt")};

    editor.cut_graph_fragment({{InputEdge(1, "conv1/7x7_s2_1")}},
                              {{OutputEdge(2, "pool1/3x3_s2_1")}});

    // expected to behave the same way as subgraph__linear_model_head_cut
    const auto ref_model = file_util::path_join(
        SERIALIZED_ZOO, "onnx/model_editor/reference/subgraph__linear_model_head_cut.onnx");

    EXPECT_EQ(editor.model_string(), read_binary_file(ref_model));

    // const auto result = compare_onnx_models(editor.model_string(), ref_model);

    // EXPECT_TRUE(result.is_ok) << result.error_message;
}
// Cuts the head one node deeper (at the Relu output consumed by node 2),
// so both the Conv and the Relu nodes are removed from the model.
NGRAPH_TEST(onnx_editor, subgraph__linear_model_deeper_head_cut)
{
    onnx_import::ONNXModelEditor editor{file_util::path_join(
        SERIALIZED_ZOO, "onnx/model_editor/subgraph__inception_head.prototxt")};

    editor.cut_graph_fragment({{InputEdge(2, "conv1/7x7_s2_2")}}, {});

    const auto ref_model = file_util::path_join(
        SERIALIZED_ZOO, "onnx/model_editor/reference/subgraph__linear_model_deeper_head_cut.onnx");

    EXPECT_EQ(editor.model_string(), read_binary_file(ref_model));

    // const auto result = compare_onnx_models(editor.model_string(), ref_model);

    // EXPECT_TRUE(result.is_ok) << result.error_message;
}
// Cuts off the model's tail by making the Relu output (produced by node 1)
// the new graph output; the trailing MaxPool node should be dropped.
NGRAPH_TEST(onnx_editor, subgraph__linear_model_tail_cut)
{
    onnx_import::ONNXModelEditor editor{file_util::path_join(
        SERIALIZED_ZOO, "onnx/model_editor/subgraph__inception_head.prototxt")};

    editor.cut_graph_fragment({}, {{OutputEdge{1, "conv1/7x7_s2_2"}}});

    const auto ref_model = file_util::path_join(
        SERIALIZED_ZOO, "onnx/model_editor/reference/subgraph__linear_model_tail_cut.onnx");

    EXPECT_EQ(editor.model_string(), read_binary_file(ref_model));

    // const auto result = compare_onnx_models(editor.model_string(), ref_model);

    // EXPECT_TRUE(result.is_ok) << result.error_message;
}
// Same tail cut as above, but additionally naming the already-existing graph
// input "data_0" explicitly; that must not affect the extraction result.
NGRAPH_TEST(onnx_editor, subgraph__linear_model_tail_cut_ins_and_outs)
{
    onnx_import::ONNXModelEditor editor{file_util::path_join(
        SERIALIZED_ZOO, "onnx/model_editor/subgraph__inception_head.prototxt")};

    editor.cut_graph_fragment({{InputEdge{0, "data_0"}}}, {{OutputEdge{1, "conv1/7x7_s2_2"}}});

    // expected to behave the same way as subgraph__linear_model_tail_cut
    const auto ref_model = file_util::path_join(
        SERIALIZED_ZOO, "onnx/model_editor/reference/subgraph__linear_model_tail_cut.onnx");

    EXPECT_EQ(editor.model_string(), read_binary_file(ref_model));

    // const auto result = compare_onnx_models(editor.model_string(), ref_model);

    // EXPECT_TRUE(result.is_ok) << result.error_message;
}
// Tail cut on a model whose Conv bias is supplied by an initializer; the
// initializer must be preserved in the extracted subgraph.
NGRAPH_TEST(onnx_editor, subgraph__linear_model_with_initializer_tail_cut)
{
    onnx_import::ONNXModelEditor editor{file_util::path_join(
        SERIALIZED_ZOO, "onnx/model_editor/subgraph__inception_head_with_initializer.prototxt")};

    editor.cut_graph_fragment({}, {{OutputEdge{1, "conv1/7x7_s2_2"}}});

    const auto ref_model = file_util::path_join(
        SERIALIZED_ZOO,
        "onnx/model_editor/reference/subgraph__linear_model_with_initializer_tail_cut.onnx");

    EXPECT_EQ(editor.model_string(), read_binary_file(ref_model));

    // const auto result = compare_onnx_models(editor.model_string(), ref_model);

    // EXPECT_TRUE(result.is_ok) << result.error_message;
}
// Tail cut on a model where an initializer has no matching graph input entry;
// extraction must still keep the initializer intact.
NGRAPH_TEST(onnx_editor, subgraph__initializer_without_matching_input_tail_cut)
{
    onnx_import::ONNXModelEditor editor{file_util::path_join(
        SERIALIZED_ZOO, "onnx/model_editor/subgraph__initializer_without_matching_input.prototxt")};

    editor.cut_graph_fragment({}, {{OutputEdge{1, "conv1/7x7_s2_2"}}});

    const auto ref_model =
        file_util::path_join(SERIALIZED_ZOO,
                             "onnx/model_editor/reference/"
                             "subgraph__initializer_without_matching_input_tail_cut.onnx");

    EXPECT_EQ(editor.model_string(), read_binary_file(ref_model));

    // const auto result = compare_onnx_models(editor.model_string(), ref_model);

    // EXPECT_TRUE(result.is_ok) << result.error_message;
}
// Cuts the tail one node deeper (right after the Conv, node 0), removing
// both the Relu and the MaxPool nodes from the model.
NGRAPH_TEST(onnx_editor, subgraph__linear_model_deeper_tail_cut)
{
    onnx_import::ONNXModelEditor editor{file_util::path_join(
        SERIALIZED_ZOO, "onnx/model_editor/subgraph__inception_head.prototxt")};

    editor.cut_graph_fragment({}, {{OutputEdge{0, "conv1/7x7_s2_1"}}});

    const auto ref_model = file_util::path_join(
        SERIALIZED_ZOO, "onnx/model_editor/reference/subgraph__linear_model_deeper_tail_cut.onnx");

    EXPECT_EQ(editor.model_string(), read_binary_file(ref_model));

    // const auto result = compare_onnx_models(editor.model_string(), ref_model);

    // EXPECT_TRUE(result.is_ok) << result.error_message;
}
// Calling cut_graph_fragment with no input and no output edges is expected to
// leave the model unchanged (the reference equals the original model).
NGRAPH_TEST(onnx_editor, subgraph__no_input_params)
{
    const auto model_path =
        file_util::path_join(SERIALIZED_ZOO, "onnx/model_editor/subgraph__inception_head.prototxt");

    onnx_import::ONNXModelEditor editor{model_path};

    editor.cut_graph_fragment({}, {});

    const auto ref_model = file_util::path_join(
        SERIALIZED_ZOO, "onnx/model_editor/reference/subgraph__no_input_params.onnx");

    EXPECT_EQ(editor.model_string(), read_binary_file(ref_model));

    // const auto result = compare_onnx_models(editor.model_string(), model_path);

    // EXPECT_TRUE(result.is_ok) << result.error_message;
}
// Declares an initializer-backed tensor (the Conv bias) as a subgraph input;
// the initializer should be replaced by a regular graph input in the result.
NGRAPH_TEST(onnx_editor, subgraph__initializer_to_input_replacement)
{
    onnx_import::ONNXModelEditor editor{file_util::path_join(
        SERIALIZED_ZOO, "onnx/model_editor/subgraph__inception_head_with_initializer.prototxt")};

    editor.cut_graph_fragment({{InputEdge{0, "conv1/7x7_s2_b_0"}}},
                              {{OutputEdge{0, "conv1/7x7_s2_1"}}});

    const auto ref_model = file_util::path_join(
        SERIALIZED_ZOO,
        "onnx/model_editor/reference/subgraph__initializer_to_input_replacement.onnx");

    EXPECT_EQ(editor.model_string(), read_binary_file(ref_model));

    // const auto result = compare_onnx_models(editor.model_string(), ref_model);

    // EXPECT_TRUE(result.is_ok) << result.error_message;
}
// Same initializer-to-input replacement, but starting from a model where the
// initializer has no matching graph input entry; the reference is shared with
// subgraph__initializer_to_input_replacement.
NGRAPH_TEST(onnx_editor, subgraph__initializer_to_input_replacement_2)
{
    onnx_import::ONNXModelEditor editor{file_util::path_join(
        SERIALIZED_ZOO, "onnx/model_editor/subgraph__initializer_without_matching_input.prototxt")};

    editor.cut_graph_fragment({{InputEdge{0, "conv1/7x7_s2_b_0"}}},
                              {{OutputEdge{0, "conv1/7x7_s2_1"}}});

    const auto ref_model = file_util::path_join(
        SERIALIZED_ZOO,
        "onnx/model_editor/reference/subgraph__initializer_to_input_replacement.onnx");

    EXPECT_EQ(editor.model_string(), read_binary_file(ref_model));

    // const auto result = compare_onnx_models(editor.model_string(), ref_model);

    // EXPECT_TRUE(result.is_ok) << result.error_message;
}
// Extracts a subgraph ending at one output ("split2") of a multi-output
// Split node (node 5); only that branch should remain a graph output.
NGRAPH_TEST(onnx_editor, subgraph__multiout_op_output_edge)
{
    onnx_import::ONNXModelEditor editor{file_util::path_join(
        SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};

    editor.cut_graph_fragment({}, {{OutputEdge{5, "split2"}}});

    const auto ref_model = file_util::path_join(
        SERIALIZED_ZOO, "onnx/model_editor/reference/subgraph__multiout_op_output_edge.onnx");

    EXPECT_EQ(editor.model_string(), read_binary_file(ref_model));

    // const auto result = compare_onnx_models(editor.model_string(), ref_model);

    // EXPECT_TRUE(result.is_ok) << result.error_message;
}
// Extraction where the requested edges reference tensors ("in2", "in3", "mul2")
// that are already graph inputs/outputs of the original model.
NGRAPH_TEST(onnx_editor, subgraph__existing_inputs_and_outputs_based_extraction)
{
    onnx_import::ONNXModelEditor editor{file_util::path_join(
        SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};

    editor.cut_graph_fragment({{InputEdge{1, "in2"}, InputEdge{2, "in3"}}},
                              {{OutputEdge{4, "mul2"}}});

    const auto ref_model =
        file_util::path_join(SERIALIZED_ZOO,
                             "onnx/model_editor/reference/"
                             "subgraph__existing_inputs_and_outputs_based_extraction.onnx");

    EXPECT_EQ(editor.model_string(), read_binary_file(ref_model));

    // const auto result = compare_onnx_models(editor.model_string(), ref_model);

    // EXPECT_TRUE(result.is_ok) << result.error_message;
}
// The tensor "relu1" feeds several nodes; input edges are given for two of its
// consumers (nodes 1 and 6), and outputs are taken from two different nodes.
NGRAPH_TEST(onnx_editor, subgraph__input_edge_from_tensor_with_multiple_consumers)
{
    onnx_import::ONNXModelEditor editor{file_util::path_join(
        SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};

    editor.cut_graph_fragment({{InputEdge{1, "relu1"}, InputEdge{6, "relu1"}}},
                              {{OutputEdge{6, "mul1"}, OutputEdge{4, "mul2"}}});

    const auto ref_model =
        file_util::path_join(SERIALIZED_ZOO,
                             "onnx/model_editor/reference/"
                             "subgraph__input_edge_from_tensor_with_multiple_consumers.onnx");

    EXPECT_EQ(editor.model_string(), read_binary_file(ref_model));

    // const auto result = compare_onnx_models(editor.model_string(), ref_model);

    // EXPECT_TRUE(result.is_ok) << result.error_message;
}
// Variant with both input edges pointing at the same node (node 3) through two
// different tensors, and two output edges from different nodes.
NGRAPH_TEST(onnx_editor, subgraph__input_edge_from_tensor_with_multiple_consumers_2)
{
    onnx_import::ONNXModelEditor editor{file_util::path_join(
        SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};

    editor.cut_graph_fragment({{InputEdge{3, "relu1"}, InputEdge{3, "add1"}}},
                              {{OutputEdge{3, "add2"}, OutputEdge{4, "mul2"}}});

    const auto ref_model =
        file_util::path_join(SERIALIZED_ZOO,
                             "onnx/model_editor/reference/"
                             "subgraph__input_edge_from_tensor_with_multiple_consumers_2.onnx");

    EXPECT_EQ(editor.model_string(), read_binary_file(ref_model));

    // const auto result = compare_onnx_models(editor.model_string(), ref_model);

    // EXPECT_TRUE(result.is_ok) << result.error_message;
}
NGRAPH_TEST(onnx_editor, subgraph__input_edge_from_tensor_with_multiple_consumers_3)
{
    // Two consumers of "relu1" (nodes 3 and 6) are severed explicitly.
    onnx_import::ONNXModelEditor editor{file_util::path_join(
        SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};

    editor.cut_graph_fragment({{InputEdge{3, "relu1"}, InputEdge{6, "relu1"}}},
                              {{OutputEdge{6, "mul1"}, OutputEdge{5, "split2"}}});

    const auto expected_model = file_util::path_join(
        SERIALIZED_ZOO,
        "onnx/model_editor/reference/"
        "subgraph__input_edge_from_tensor_with_multiple_consumers_3.onnx");

    // Byte-for-byte comparison against the serialized reference model.
    EXPECT_EQ(editor.model_string(), read_binary_file(expected_model));
}
NGRAPH_TEST(onnx_editor, subgraph__input_edge_from_tensor_with_multiple_consumers_4)
{
    // Only one of the "relu1" consumers is given as an input edge; the cut is
    // expected to behave exactly like the previous test, hence the shared
    // reference model (..._3.onnx).
    onnx_import::ONNXModelEditor editor{file_util::path_join(
        SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")};

    editor.cut_graph_fragment({{InputEdge{3, "relu1"}}},
                              {{OutputEdge{6, "mul1"}, OutputEdge{5, "split2"}}});

    const auto expected_model = file_util::path_join(
        SERIALIZED_ZOO,
        "onnx/model_editor/reference/"
        "subgraph__input_edge_from_tensor_with_multiple_consumers_3.onnx");

    // Byte-for-byte comparison against the serialized reference model.
    EXPECT_EQ(editor.model_string(), read_binary_file(expected_model));
}
NGRAPH_TEST(onnx_editor, subgraph__multiple_consumers_of_graph_input_relu2)
{
    // No output edges: everything downstream of the cut input is kept.
    onnx_import::ONNXModelEditor editor{file_util::path_join(
        SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests_2.prototxt")};

    editor.cut_graph_fragment({{InputEdge{4, "relu2"}}}, {});

    const auto expected_model = file_util::path_join(
        SERIALIZED_ZOO,
        "onnx/model_editor/reference/subgraph__multiple_consumers_of_graph_input_relu2.onnx");

    // Byte-for-byte comparison against the serialized reference model.
    EXPECT_EQ(editor.model_string(), read_binary_file(expected_model));
}
NGRAPH_TEST(onnx_editor, subgraph__multiple_consumers_of_graph_initializer)
{
    // The severed tensor "in2" is an initializer consumed by multiple nodes.
    onnx_import::ONNXModelEditor editor{file_util::path_join(
        SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests_2.prototxt")};

    editor.cut_graph_fragment({{InputEdge{2, "in2"}}}, {});

    const auto expected_model = file_util::path_join(
        SERIALIZED_ZOO,
        "onnx/model_editor/reference/subgraph__multiple_consumers_of_graph_initializer.onnx");

    // Byte-for-byte comparison against the serialized reference model.
    EXPECT_EQ(editor.model_string(), read_binary_file(expected_model));
}
NGRAPH_TEST(onnx_editor, subgraph__multiple_consumers_of_graph_initializer_2)
{
    // Listing both consumers of the initializer explicitly must yield the same
    // result as the previous test, hence the shared reference model.
    onnx_import::ONNXModelEditor editor{file_util::path_join(
        SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests_2.prototxt")};

    editor.cut_graph_fragment({{InputEdge{2, "in2"}, InputEdge{3, "in2"}}}, {});

    const auto expected_model = file_util::path_join(
        SERIALIZED_ZOO,
        "onnx/model_editor/reference/subgraph__multiple_consumers_of_graph_initializer.onnx");

    // Byte-for-byte comparison against the serialized reference model.
    EXPECT_EQ(editor.model_string(), read_binary_file(expected_model));
}
NGRAPH_TEST(onnx_editor, subgraph__multiple_consumers_of_graph_initializer_relu2_and_init)
{
    // Mixed cut: one edge severs a regular tensor, the other an initializer.
    onnx_import::ONNXModelEditor editor{file_util::path_join(
        SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests_2.prototxt")};

    editor.cut_graph_fragment({{InputEdge{5, "relu2"}, InputEdge{3, "in2"}}}, {});

    const auto expected_model = file_util::path_join(
        SERIALIZED_ZOO,
        "onnx/model_editor/reference/"
        "subgraph__multiple_consumers_of_graph_initializer_relu2_and_init.onnx");

    // Byte-for-byte comparison against the serialized reference model.
    EXPECT_EQ(editor.model_string(), read_binary_file(expected_model));
}
NGRAPH_TEST(onnx_editor, subgraph__invalid_edge_idx)
{
    // Node index 15 is out of range for this model - the cut must be rejected.
    onnx_import::ONNXModelEditor editor{file_util::path_join(
        SERIALIZED_ZOO, "onnx/model_editor/subgraph__inception_head.prototxt")};

    EXPECT_THROW(editor.cut_graph_fragment({{InputEdge{15, "x"}}}, {}), ngraph::ngraph_error);
}
NGRAPH_TEST(onnx_editor, subgraph__invalid_edge_name)
{
    // Node 0 exists but has no tensor named "x" - the cut must be rejected.
    onnx_import::ONNXModelEditor editor{file_util::path_join(
        SERIALIZED_ZOO, "onnx/model_editor/subgraph__inception_head.prototxt")};

    EXPECT_THROW(editor.cut_graph_fragment({{InputEdge{0, "x"}}}, {}), ngraph::ngraph_error);
}
NGRAPH_TEST(onnx_editor, subgraph__inputs_getter)
{
    onnx_import::ONNXModelEditor editor{file_util::path_join(
        SERIALIZED_ZOO, "onnx/model_editor/subgraph__inception_head.prototxt")};

    const std::vector<std::string> original_inputs{
        "data_0", "conv1/7x7_s2_w_0", "conv1/7x7_s2_b_0"};
    EXPECT_EQ(editor.model_inputs(), original_inputs);

    // After the cut the extracted subgraph exposes a single, freshly-named input.
    editor.cut_graph_fragment({{InputEdge(1, "conv1/7x7_s2_1")}}, {});
    EXPECT_EQ(editor.model_inputs(), (std::vector<std::string>{"conv1/7x7_s2_2:conv1/7x7_s2_1"}));
}
// Backend engine used by the value-setting tests in this file.
using TestEngine = test::INTERPRETER_Engine;
NGRAPH_TEST(onnx_editor, values__append_one_initializer)

View File

@ -0,0 +1,325 @@
//*****************************************************************************
// Copyright 2017-2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <exception>
#include <fstream>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/text_format.h>
#include <onnx/onnx_pb.h>
#include <sstream>
#include "ngraph/except.hpp"
#include "onnx_test_util.hpp"
using namespace ngraph;
using namespace ngraph::test;
namespace
{
void parse_from_istream(std::istream& model_stream, ONNX_NAMESPACE::ModelProto& model_proto)
{
if (!model_stream.good())
{
model_stream.clear();
model_stream.seekg(0);
if (!model_stream.good())
{
throw ngraph_error("Provided input stream has incorrect state.");
}
}
if (!model_proto.ParseFromIstream(&model_stream))
{
#ifdef NGRAPH_USE_PROTOBUF_LITE
throw ngraph_error(
"Error during import of ONNX model provided as input stream "
" with binary protobuf message.");
#else
// Rewind to the beginning and clear stream state.
model_stream.clear();
model_stream.seekg(0);
google::protobuf::io::IstreamInputStream iistream(&model_stream);
// Try parsing input as a prototxt message
if (!google::protobuf::TextFormat::Parse(&iistream, &model_proto))
{
throw ngraph_error(
"Error during import of ONNX model provided as input stream with prototxt "
"protobuf message.");
}
#endif
}
}
void parse_from_file(const std::string& file_path, ONNX_NAMESPACE::ModelProto& model_proto)
{
std::ifstream file_stream{file_path, std::ios::in | std::ios::binary};
if (!file_stream.is_open())
{
throw ngraph_error("Could not open the file: " + file_path);
};
parse_from_istream(file_stream, model_proto);
}
ComparisonResult compare_nodes(const ONNX_NAMESPACE::GraphProto& graph,
const ONNX_NAMESPACE::GraphProto& ref_graph)
{
if (graph.node_size() != ref_graph.node_size())
{
return ComparisonResult::fail("The number of nodes in compared models doesn't match");
}
else
{
for (int i = 0; i < graph.node_size(); ++i)
{
const auto& lhs = graph.node(i);
const auto& rhs = ref_graph.node(i);
if (lhs.op_type() != rhs.op_type())
{
return ComparisonResult::fail("Operation types are different at index " +
std::to_string(i) + ": " + lhs.op_type() +
" vs " + rhs.op_type());
}
for (int j = 0; j < lhs.input_size(); ++j)
{
if (lhs.input(j) != rhs.input(j))
{
return ComparisonResult::fail(
"Input names don't match for nodes at index " + std::to_string(i) +
": " + lhs.input(j) + " vs " + rhs.input(j));
}
}
for (int j = 0; j < lhs.output_size(); ++j)
{
if (lhs.output(j) != rhs.output(j))
{
return ComparisonResult::fail(
"Output names don't match for nodes at index " + std::to_string(i) +
": " + lhs.output(j) + " vs " + rhs.output(j));
}
}
}
}
return ComparisonResult::pass();
}
ComparisonResult compare_value_info(const ONNX_NAMESPACE::ValueInfoProto& lhs,
const ONNX_NAMESPACE::ValueInfoProto& rhs,
const std::string& item_type)
{
if (lhs.name() != rhs.name())
{
return ComparisonResult::fail(item_type + " names in the graph don't match: " +
lhs.name() + " vs " + rhs.name());
}
const auto& lhs_tensor = lhs.type().tensor_type();
const auto& rhs_tensor = rhs.type().tensor_type();
if (lhs_tensor.elem_type() != rhs_tensor.elem_type())
{
return ComparisonResult::fail("Element types don't match for " + item_type + " " +
lhs.name() + ": " +
std::to_string(lhs_tensor.elem_type()) + " vs " +
std::to_string(rhs_tensor.elem_type()));
}
const auto& lhs_shape = lhs_tensor.shape();
const auto& rhs_shape = rhs_tensor.shape();
if (lhs_shape.dim_size() != rhs_shape.dim_size())
{
return ComparisonResult::fail("Tensor ranks don't match for " + item_type + " " +
lhs.name() + ": " + std::to_string(lhs_shape.dim_size()) +
" vs " + std::to_string(rhs_shape.dim_size()));
}
else
{
for (int j = 0; j < lhs_shape.dim_size(); ++j)
{
const auto& lhs_dim = lhs_shape.dim(j);
const auto& rhs_dim = rhs_shape.dim(j);
if ((lhs_dim.has_dim_value() && rhs_dim.has_dim_param()) ||
(rhs_dim.has_dim_value() && lhs_dim.has_dim_param()))
{
return ComparisonResult::fail("Dynamic vs static dimension mismatch for " +
item_type + " " + lhs.name() + " at index: " +
std::to_string(j));
}
else if (lhs_dim.has_dim_value() && lhs_dim.dim_value() != rhs_dim.dim_value())
{
return ComparisonResult::fail("Shape dimensions don't match for " + item_type +
" " + lhs.name() + " at index: " +
std::to_string(j) + ". " +
std::to_string(lhs_dim.dim_value()) + " vs " +
std::to_string(rhs_dim.dim_value()));
}
}
}
return ComparisonResult::pass();
}
ComparisonResult compare_inputs(const ONNX_NAMESPACE::GraphProto& graph,
const ONNX_NAMESPACE::GraphProto& ref_graph)
{
if (graph.input_size() != ref_graph.input_size())
{
return ComparisonResult::fail(
"The number of inputs in compared models doesn't match: " +
std::to_string(graph.input_size()) + " vs " +
std::to_string(ref_graph.input_size()));
}
else
{
for (int i = 0; i < graph.input_size(); ++i)
{
const auto& lhs = graph.input(i);
const auto& rhs = ref_graph.input(i);
const auto res = compare_value_info(lhs, rhs, "input");
if (!res.is_ok)
{
return res;
}
}
return ComparisonResult::pass();
}
}
ComparisonResult compare_outputs(const ONNX_NAMESPACE::GraphProto& graph,
const ONNX_NAMESPACE::GraphProto& ref_graph)
{
if (graph.output_size() != ref_graph.output_size())
{
return ComparisonResult::fail("The number of outputs in compared models doesn't match" +
std::to_string(graph.output_size()) + " vs " +
std::to_string(ref_graph.output_size()));
}
else
{
for (int i = 0; i < graph.output_size(); ++i)
{
const auto& lhs = graph.output(i);
const auto& rhs = ref_graph.output(i);
const auto res = compare_value_info(lhs, rhs, "output");
if (!res.is_ok)
{
return res;
}
}
return ComparisonResult::pass();
}
}
ComparisonResult compare_initializers(const ONNX_NAMESPACE::GraphProto& graph,
const ONNX_NAMESPACE::GraphProto& ref_graph)
{
if (graph.initializer_size() != ref_graph.initializer_size())
{
return ComparisonResult::fail(
"The number of initializers in compared models doesn't match" +
std::to_string(graph.initializer_size()) + " vs " +
std::to_string(ref_graph.initializer_size()));
}
else
{
for (int i = 0; i < graph.initializer_size(); ++i)
{
const auto& lhs = graph.initializer(i);
const auto& rhs = ref_graph.initializer(i);
if (lhs.name() != rhs.name())
{
return ComparisonResult::fail("Initializer names in the graph don't match: " +
lhs.name() + " vs " + rhs.name());
}
else if (lhs.data_type() != rhs.data_type())
{
return ComparisonResult::fail(
"Initializer data types in the graph don't match: " +
std::to_string(lhs.data_type()) + " vs " + std::to_string(rhs.data_type()));
}
else if (lhs.dims_size() != rhs.dims_size())
{
return ComparisonResult::fail("Initializer ranks in the graph don't match: " +
std::to_string(lhs.dims_size()) + " vs " +
std::to_string(rhs.dims_size()));
}
else
{
for (int j = 0; j < lhs.dims_size(); ++j)
{
if (lhs.dims(j) != rhs.dims(j))
{
return ComparisonResult::fail(
"Shape dimensions don't match for initializer " + lhs.name() +
" at index: " + std::to_string(j) + ". " +
std::to_string(lhs.dims(j)) + " vs " + std::to_string(rhs.dims(j)));
}
}
}
}
return ComparisonResult::pass();
}
}
ComparisonResult compare_onnx_graphs(const ONNX_NAMESPACE::GraphProto& graph,
const ONNX_NAMESPACE::GraphProto& ref_graph)
{
ComparisonResult comparison = compare_inputs(graph, ref_graph);
if (!comparison.is_ok)
{
return comparison;
}
comparison = compare_outputs(graph, ref_graph);
if (!comparison.is_ok)
{
return comparison;
}
comparison = compare_initializers(graph, ref_graph);
if (!comparison.is_ok)
{
return comparison;
}
return compare_nodes(graph, ref_graph);
}
} // namespace
namespace ngraph
{
    namespace test
    {
        ComparisonResult compare_onnx_models(const std::string& model,
                                             const std::string& reference_model_path)
        {
            // The model under test arrives serialized in memory; the reference
            // model is read from disk. Parse both and compare their graphs.
            ONNX_NAMESPACE::ModelProto parsed_model;
            ONNX_NAMESPACE::ModelProto parsed_reference;

            std::stringstream model_stream{model};
            parse_from_istream(model_stream, parsed_model);
            parse_from_file(reference_model_path, parsed_reference);

            return compare_onnx_graphs(parsed_model.graph(), parsed_reference.graph());
        }
    } // namespace test
} // namespace ngraph

View File

@ -0,0 +1,52 @@
//*****************************************************************************
// Copyright 2017-2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <string>
namespace ngraph
{
namespace test
{
/// Outcome of comparing a model against a reference model: a pass, or a
/// failure carrying a human-readable diagnostic in error_message.
struct ComparisonResult
{
    // Defaults to a passing result with an empty message.
    bool is_ok = true;
    std::string error_message;

    ComparisonResult() = default;
    ComparisonResult(std::string error)
        : is_ok{false}
        , error_message{std::move(error)}
    {
    }
    ComparisonResult(const ComparisonResult&) = default;
    ComparisonResult(ComparisonResult&&) = default;
    ComparisonResult& operator=(const ComparisonResult&) = default;
    ComparisonResult& operator=(ComparisonResult&&) = default;

    /// A successful comparison.
    static ComparisonResult pass() { return ComparisonResult{}; }
    /// A failed comparison with the given diagnostic message.
    static ComparisonResult fail(std::string error) { return {std::move(error)}; }
};
ComparisonResult compare_onnx_models(const std::string& model,
const std::string& reference_model_path);
} // namespace test
} // namespace ngraph