[ONNX] Split importing model to two phases: decode and convert (#6326)

This commit is contained in:
Mateusz Tabaka 2021-07-12 12:13:25 +02:00 committed by GitHub
parent 8fb1182670
commit 1a5bd8a510
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 462 additions and 254 deletions

View File

@ -134,6 +134,8 @@ inline Precision convertPrecision(const ::ngraph::element::Type& precision) {
return Precision(Precision::BIN);
case ::ngraph::element::Type_t::boolean:
return Precision(Precision::BOOL);
case ::ngraph::element::Type_t::dynamic:
return Precision(Precision::UNSPECIFIED);
default:
IE_THROW() << "Incorrect precision " << precision.get_type_name() << "!"; return{};
}

View File

@ -55,7 +55,7 @@ class TRANSFORMATIONS_API FrameworkNode : public Op {
public:
NGRAPH_RTTI_DECLARATION;
explicit FrameworkNode(const OutputVector& inputs);
explicit FrameworkNode(const OutputVector& inputs, size_t output_size = 1);
void validate_and_infer_types() override;

View File

@ -10,8 +10,9 @@ using namespace ngraph;
NGRAPH_RTTI_DEFINITION(op::FrameworkNode, "FrameworkNode", 0);
op::FrameworkNode::FrameworkNode(const OutputVector& inputs)
op::FrameworkNode::FrameworkNode(const OutputVector& inputs, size_t output_size)
: Op(inputs) {
set_output_size(output_size);
constructor_validate_and_infer_types();
}

View File

@ -495,6 +495,7 @@ std::string get_opset_name(
std::string get_precision_name(const ngraph::element::Type & elem_type) {
switch (elem_type) {
case ::ngraph::element::Type_t::undefined:
case ::ngraph::element::Type_t::dynamic:
return "UNSPECIFIED";
case ::ngraph::element::Type_t::f16:
return "FP16";

View File

@ -45,7 +45,7 @@ if(COMMAND ie_faster_build)
)
endif()
target_link_libraries(onnx_importer PRIVATE onnx_common ngraph::builder
target_link_libraries(onnx_importer PRIVATE onnx_common ngraph::builder inference_engine_transformations
PUBLIC ngraph)
target_include_directories(onnx_importer PUBLIC $<BUILD_INTERFACE:${ONNX_IMPORT_INCLUDE_DIR}>

View File

@ -75,9 +75,8 @@ namespace ngraph
bool has_attribute(const std::string& name) const;
Subgraph get_subgraph_from_attribute(
const std::string& name,
const std::map<std::size_t, std::string>& carried_dependencies_map) const;
bool has_subgraph() const;
std::shared_ptr<Subgraph> get_subgraph() const;
template <typename T>
T get_attribute_value(const std::string& name, T default_value) const;

View File

@ -0,0 +1,100 @@
//*****************************************************************************
// Copyright 2017-2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <core/graph.hpp>
#include <ngraph/visibility.hpp>
#include <ngraph_ops/framework_node.hpp>
#include <onnx_import/core/node.hpp>
namespace ONNX_NAMESPACE
{
// forward declaration
class ModelProto;
} // namespace ONNX_NAMESPACE
namespace ngraph
{
namespace onnx_import
{
class Model;
}
namespace frontend
{
class ONNXFrameworkNode : public op::FrameworkNode
{
public:
NGRAPH_RTTI_DECLARATION;
ONNXFrameworkNode(const onnx_import::Node& node)
: FrameworkNode(node.get_ng_inputs(), node.get_outputs_size())
, m_node(node)
{
}
ONNXFrameworkNode(const onnx_import::Node& node, const OutputVector& inputs)
: FrameworkNode(inputs, node.get_outputs_size())
, m_node(node)
{
}
const onnx_import::Node& get_onnx_node() const { return m_node; }
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& inputs) const override;
virtual bool visit_attributes(AttributeVisitor& visitor) override
{
// TODO: implement reading as well, now it work for serialization only
std::string domain = m_node.domain();
std::string op_type = m_node.op_type();
visitor.on_attribute("ONNX_META_domain", domain);
visitor.on_attribute("ONNX_META_type", op_type);
return true;
}
private:
onnx_import::Node m_node;
};
class ONNXSubgraphFrameworkNode : public ONNXFrameworkNode
{
public:
NGRAPH_RTTI_DECLARATION;
ONNXSubgraphFrameworkNode(const onnx_import::Node& node, const OutputVector& inputs)
: ONNXFrameworkNode(node, inputs)
{
}
void infer_inputs_from_parent()
{
get_onnx_node().get_subgraph()->infer_inputs_from_parent();
}
std::shared_ptr<Function> get_subgraph_body() const
{
auto subgraph = get_onnx_node().get_subgraph();
return std::make_shared<Function>(subgraph->get_ng_outputs(),
subgraph->get_ng_parameters(),
subgraph->get_name());
}
};
} // namespace frontend
} // namespace ngraph

View File

@ -11,9 +11,7 @@ namespace ngraph
{
namespace onnx_import
{
Subgraph Attribute::get_subgraph(
const Graph& parent_graph,
const std::map<std::size_t, std::string>& carried_dependencies_map) const
Subgraph Attribute::get_subgraph(const Graph& parent_graph) const
{
if (m_attribute_proto->type() != ONNX_NAMESPACE::AttributeProto_AttributeType_GRAPH)
{
@ -25,33 +23,6 @@ namespace ngraph
const auto& graph = m_attribute_proto->g();
model_proto->mutable_graph()->CopyFrom(graph);
const std::size_t subgraph_inputs_count =
static_cast<size_t>(model_proto->mutable_graph()->mutable_input()->size());
// Use the `carried_dependencies_map` to infer the types for the subgraph inputs
for (const auto& carried_dependency : carried_dependencies_map)
{
if (carried_dependency.first >= subgraph_inputs_count)
{
NGRAPH_WARN << "Input with index: '" << carried_dependency.first
<< "' was not found in the subgraph";
}
else
{
const auto& parent_in =
parent_graph.get_ng_node_from_cache(carried_dependency.second);
const auto& carried_type = parent_in.get_element_type();
auto subgraph_in =
model_proto->mutable_graph()->mutable_input(carried_dependency.first);
auto subgraph_in_tensor_type =
subgraph_in->mutable_type()->mutable_tensor_type();
if (!subgraph_in_tensor_type->has_elem_type())
{
subgraph_in_tensor_type->set_elem_type(
onnx_common::ng_to_onnx_data_type(carried_type));
}
}
}
// set opset version and domain from the parent graph
model_proto->mutable_opset_import()->CopyFrom(parent_graph.get_opset_imports());
auto model = common::make_unique<Model>(std::move(model_proto));

View File

@ -316,9 +316,7 @@ namespace ngraph
float get_float() const { return m_attribute_proto->f(); }
int64_t get_integer() const { return m_attribute_proto->i(); }
const std::string& get_string() const { return m_attribute_proto->s(); }
Subgraph get_subgraph(
const Graph& parent_graph,
const std::map<std::size_t, std::string>& carried_dependencies_map) const;
Subgraph get_subgraph(const Graph& parent_graph) const;
std::vector<Tensor> get_tensor_array() const
{

View File

@ -14,6 +14,7 @@
#include "ngraph/node.hpp"
#include "ngraph/provenance.hpp"
#include "onnx_import/core/node.hpp"
#include "onnx_import/onnx_framework_node.hpp"
#include "utils/common.hpp"
#include "utils/provenance_tag.hpp"
@ -55,25 +56,6 @@ namespace ngraph
Graph::Graph(std::unique_ptr<Model>&& model)
: Graph(std::move(model), common::make_unique<GraphCache>())
{
// Remove dangling Parameters
for (auto param_it = m_parameters.begin(); param_it != m_parameters.end();)
{
if ((*param_it)->get_output_target_inputs(0).size() == 0)
{
const auto& name = (*param_it)->get_friendly_name();
auto out_it = std::find_if(
m_outputs.begin(), m_outputs.end(), [&name](const ValueInfo& info) {
return info.get_name() == name;
});
if (out_it == m_outputs.end())
{
m_cache->remove_node(name);
param_it = m_parameters.erase(param_it);
continue;
}
}
param_it++;
}
}
Graph::Graph(std::unique_ptr<Model>&& model, std::unique_ptr<GraphCache>&& cache)
@ -174,14 +156,82 @@ namespace ngraph
NGRAPH_CHECK(unknown_operators.empty(),
"nGraph does not support the following ONNX operations: ",
detail::to_string(unknown_operators));
}
void Graph::convert_to_ngraph_nodes()
{
// Process ONNX graph nodes, convert to nGraph nodes
for (const auto& node_proto : m_model->get_graph().node())
{
m_nodes.emplace_back(node_proto, *this);
const Node& node{m_nodes.back()};
if (node.has_subgraph())
{
auto subgraph = node.get_subgraph();
auto body_func = subgraph->convert();
}
OutputVector ng_nodes{node.get_ng_nodes()};
set_friendly_names(node, ng_nodes);
for (std::size_t i{0}; i < node.get_outputs_size(); ++i)
{
m_cache->emplace_node(node.output(i), std::move(ng_nodes.at(i)));
}
}
}
void Graph::remove_dangling_parameters()
{
for (auto param_it = m_parameters.begin(); param_it != m_parameters.end();)
{
if ((*param_it)->get_output_target_inputs(0).size() == 0)
{
const auto& name = (*param_it)->get_friendly_name();
auto out_it = std::find_if(
m_outputs.begin(), m_outputs.end(), [&name](const ValueInfo& info) {
return info.get_name() == name;
});
if (out_it == m_outputs.end())
{
m_cache->remove_node(name);
param_it = m_parameters.erase(param_it);
continue;
}
}
param_it++;
}
}
std::shared_ptr<Function> Graph::convert()
{
convert_to_ngraph_nodes();
remove_dangling_parameters();
return create_function();
}
void Graph::decode_to_framework_nodes()
{
// Process ONNX graph nodes, convert to nGraph nodes
for (const auto& node_proto : m_model->get_graph().node())
{
m_nodes.emplace_back(node_proto, *this);
const Node& node{m_nodes.back()};
std::shared_ptr<frontend::ONNXFrameworkNode> framework_node;
if (node.has_subgraph())
{
auto subgraph = node.get_subgraph();
auto body_func = subgraph->decode();
auto inputs = node.get_ng_inputs();
for (const auto& input : subgraph->get_inputs_from_parent())
inputs.push_back(input);
framework_node =
std::make_shared<ngraph::frontend::ONNXSubgraphFrameworkNode>(node, inputs);
}
else
{
framework_node = std::make_shared<ngraph::frontend::ONNXFrameworkNode>(node);
}
OutputVector ng_nodes{framework_node->outputs()};
set_friendly_names(node, ng_nodes);
// Iterate over the number of outputs for given node in graph.
// Some of them may be optional and trimmed. See:
// https://github.com/onnx/onnx/blob/master/docs/IR.md#optional-inputs-and-outputs
@ -192,12 +242,24 @@ namespace ngraph
}
}
const GraphCache& Graph::get_graph_cache() const { return *m_cache.get(); }
bool Graph::is_node_in_cache(const std::string& name) const
std::shared_ptr<Function> Graph::create_function()
{
return m_cache->contains(name);
auto function = std::make_shared<Function>(get_ng_outputs(), m_parameters, get_name());
for (std::size_t i{0}; i < function->get_output_size(); ++i)
{
function->get_output_op(i)->set_friendly_name(m_outputs.at(i).get_name());
}
return function;
}
std::shared_ptr<Function> Graph::decode()
{
decode_to_framework_nodes();
return create_function();
}
const GraphCache& Graph::get_graph_cache() const { return *m_cache.get(); }
Output<ngraph::Node> Graph::get_ng_node_from_cache(const std::string& name) const
{
return m_cache->get_node(name);
@ -247,6 +309,12 @@ namespace ngraph
set_friendly_names(onnx_node, ng_node_vector);
add_provenance_tags(onnx_node, ng_node_vector);
for (std::size_t i{0}; i < onnx_node.get_outputs_size(); ++i)
{
auto ng_node = ng_node_vector.at(i);
m_cache->emplace_node(onnx_node.output(i), std::move(ng_node));
}
return ng_node_vector;
}
@ -323,9 +391,21 @@ namespace ngraph
}
Subgraph::Subgraph(std::unique_ptr<Model>&& model, const Graph& parent_graph)
: Graph(
std::move(model),
std::unique_ptr<SubgraphCache>(new SubgraphCache(parent_graph.get_graph_cache())))
: Graph(std::move(model), common::make_unique<GraphCache>())
, m_parent_graph_cache(&parent_graph.get_graph_cache())
{
}
Output<ngraph::Node> Subgraph::get_ng_node_from_cache(const std::string& name) const
{
if (m_cache->contains(name))
{
return m_cache->get_node(name);
}
return m_parent_graph_cache->get_node(name);
}
void Subgraph::find_inputs_from_parent()
{
// find all nodes on edge parent graph-subgraph
// (it means input of node from parent graph, output from subgraph)
@ -334,16 +414,16 @@ namespace ngraph
int input_index = 0;
for (const auto& in_name : node_proto.input())
{
if (m_cache->node_scope(in_name) == NodeScope::ParentGraph)
if (m_parent_graph_cache->contains(in_name))
{
const auto& from_parent_node = m_cache->get_node(in_name);
const auto& from_parent_node = m_parent_graph_cache->get_node(in_name);
// constants are skipped
if (!ngraph::is_type<ngraph::op::Constant>(
from_parent_node.get_node_shared_ptr()))
{
for (const auto& out_name : node_proto.output())
{
if (m_cache->node_scope(out_name) == NodeScope::SubGraph)
if (m_cache->contains(out_name))
{
auto out_node_to_replace_input = m_cache->get_node(out_name);
auto new_param = std::make_shared<ngraph::op::Parameter>(
@ -353,8 +433,10 @@ namespace ngraph
out_node_to_replace_input.get_node()
->input(input_index)
.replace_source_output(new_param);
m_parameter_to_parent_node_map.insert({new_param, in_name});
m_cache->emplace_node(in_name, new_param);
m_parameters.push_back(new_param);
m_outputs_from_parent.push_back(from_parent_node);
m_inputs_from_parent.push_back(in_name);
}
}
}
@ -364,11 +446,39 @@ namespace ngraph
}
}
const std::vector<Output<ngraph::Node>> Subgraph::get_outputs_from_parent() const
std::shared_ptr<Function> Subgraph::convert()
{
return m_outputs_from_parent;
convert_to_ngraph_nodes();
find_inputs_from_parent();
return create_function();
}
void Subgraph::decode_to_framework_nodes()
{
Graph::decode_to_framework_nodes();
find_inputs_from_parent();
}
const std::vector<Output<ngraph::Node>> Subgraph::get_inputs_from_parent() const
{
OutputVector result;
for (const auto& name : m_inputs_from_parent)
{
result.push_back(m_parent_graph_cache->get_node(name));
}
return result;
}
void Subgraph::infer_inputs_from_parent()
{
for (auto& it : m_parameter_to_parent_node_map)
{
const auto& node = m_parent_graph_cache->get_node(it.second);
auto& parameter = it.first;
parameter->set_element_type(node.get_element_type());
parameter->set_partial_shape(node.get_partial_shape());
}
}
} // namespace onnx_import
} // namespace ngraph

View File

@ -31,13 +31,14 @@ namespace ngraph
Graph& operator=(const Graph&) = delete;
Graph& operator=(Graph&&) = default;
virtual std::shared_ptr<Function> convert();
std::shared_ptr<Function> decode();
const std::vector<Node>& get_nodes() const { return m_nodes; }
const std::vector<ValueInfo>& get_inputs() const { return m_inputs; }
const std::vector<ValueInfo>& get_outputs() const { return m_outputs; }
OutputVector get_ng_outputs() const;
const ParameterVector& get_ng_parameters() const { return m_parameters; }
bool is_node_in_cache(const std::string& name) const;
Output<ngraph::Node> get_ng_node_from_cache(const std::string& name) const;
virtual Output<ngraph::Node> get_ng_node_from_cache(const std::string& name) const;
const std::string& get_name() const { return m_model->get_graph().name(); }
OutputVector make_ng_nodes(const Node& onnx_node) const;
const GraphCache& get_graph_cache() const;
@ -60,6 +61,11 @@ namespace ngraph
const OutputVector& ng_node_vector) const;
protected:
virtual void decode_to_framework_nodes();
void convert_to_ngraph_nodes();
void remove_dangling_parameters();
std::shared_ptr<Function> create_function();
ParameterVector m_parameters;
std::unique_ptr<Model> m_model;
std::unique_ptr<GraphCache> m_cache;
@ -82,9 +88,11 @@ namespace ngraph
/// \param[in] parent_graph The reference to the parent graph.
Subgraph(std::unique_ptr<Model>&& model, const Graph& parent_graph);
/// \brief Return outputs which are on the edge the subgraph and the parent graph.
/// \brief Return nodes which are on the edge the subgraph and the parent graph.
/// \return Vector of edge nodes from parent scope.
const std::vector<Output<ngraph::Node>> get_outputs_from_parent() const;
const std::vector<Output<ngraph::Node>> get_inputs_from_parent() const;
std::shared_ptr<Function> convert() override;
Subgraph() = delete;
@ -94,8 +102,17 @@ namespace ngraph
Subgraph& operator=(const Subgraph&) = delete;
Subgraph& operator=(Subgraph&&) = default;
Output<ngraph::Node> get_ng_node_from_cache(const std::string& name) const override;
void infer_inputs_from_parent();
private:
std::vector<Output<ngraph::Node>> m_outputs_from_parent;
void decode_to_framework_nodes() override;
void find_inputs_from_parent();
const GraphCache* m_parent_graph_cache;
std::vector<std::string> m_inputs_from_parent;
std::unordered_map<std::shared_ptr<ngraph::op::Parameter>, std::string>
m_parameter_to_parent_node_map;
};
inline std::ostream& operator<<(std::ostream& outs, const Graph& graph)

View File

@ -39,55 +39,5 @@ namespace ngraph
{
return (m_graph_cache_map.count(name) > 0);
}
NodeScope GraphCache::node_scope(const std::string& name) const
{
return contains(name) ? NodeScope::ParentGraph : NodeScope::Lack;
}
SubgraphCache::SubgraphCache(const GraphCache& parent_graph_cache)
: m_parent_graph_cache{&parent_graph_cache}
{
if (m_parent_graph_cache == nullptr)
{
throw ngraph_error("Parent graph cache is not initialized");
}
}
Output<ngraph::Node> SubgraphCache::get_node(const std::string& name) const
{
// present in subgraph scope
if (GraphCache::contains(name))
{
return GraphCache::get_node(name);
}
else // present in parent graph scope
{
return m_parent_graph_cache->get_node(name);
}
}
bool SubgraphCache::contains(const std::string& name) const
{
// the node is in subgraph or in parent graph scope
return GraphCache::contains(name) || m_parent_graph_cache->contains(name);
}
NodeScope SubgraphCache::node_scope(const std::string& name) const
{
if (GraphCache::contains(name))
{
return NodeScope::SubGraph;
}
else if (m_parent_graph_cache->contains(name))
{
return NodeScope::ParentGraph;
}
else
{
return NodeScope::Lack;
}
}
} // namespace onnx_import
} // namespace ngraph

View File

@ -14,17 +14,6 @@ namespace ngraph
{
namespace onnx_import
{
/// \brief Enum which determines scope (visibility) of nodes in GraphCache.
enum class NodeScope
{
// in parent graph scope
ParentGraph = 1,
// in subgraph scope
SubGraph,
// not available at all
Lack
};
/// \brief GraphCache stores and provides access to ONNX graph initializers.
class GraphCache
{
@ -58,58 +47,10 @@ namespace ngraph
/// \return true if the node named `name` exist in the cache, false otherwise.
virtual bool contains(const std::string& name) const;
/// \brief Return NodeScope enum which determines scope of the node.
/// \note If the method is called on GraphCache the ParentGraph enum
/// value is retunred always.
///
/// \param[in] name The name of the node.
///
/// \return SubGraph if node belongs to SubgraphCache, ParentGraph if
/// is avalible in parent_graph_cache, otherwise Lack
virtual NodeScope node_scope(const std::string& name) const;
virtual ~GraphCache() = default;
private:
std::map<std::string, Output<ngraph::Node>> m_graph_cache_map;
};
class SubgraphCache : public GraphCache
{
public:
/// \brief Constructs a SubgraphCache class object.
///
/// \param[in] parent_graph_cache The reference to the parent graph.
SubgraphCache(const GraphCache& parent_graph_cache);
/// \brief Get the node from the cache (subgraph or parent graph)
///
/// \note If the node is not found the ngraph_error exception is thrown.
///
/// \param[in] name The name of the node.
///
/// \return The node named `name` from subgraph (as present) or from parent graph.
Output<ngraph::Node> get_node(const std::string& name) const override;
/// \brief Return true if the node named `name` exist in the cache.
///
/// \param[in] name The name of the node.
///
/// \return true if the node named `name` exist in the cache
/// (subgraph or parent graph), false otherwise.
bool contains(const std::string& name) const override;
/// \brief Return NodeScope enum which determines scope of the node.
///
/// \param[in] name The name of the node.
///
/// \return SubGraph if the node belongs to SubgraphCache, ParentGraph if
/// is avalible in parent_graph_cache, otherwise Lack
NodeScope node_scope(const std::string& name) const override;
private:
const GraphCache* m_parent_graph_cache;
};
} // namespace onnx_import
} // namespace ngraph

View File

@ -6,6 +6,7 @@
#include "core/model.hpp"
#include "ngraph/log.hpp"
#include "onnx_import/onnx_framework_node.hpp"
#include "ops_bridge.hpp"
namespace ngraph

View File

@ -26,6 +26,29 @@ namespace ngraph
, m_graph{&graph}
, m_attributes{std::begin(node_proto.attribute()), std::end(node_proto.attribute())}
, m_output_names{std::begin(node_proto.output()), std::end(node_proto.output())}
{
const auto it =
std::find_if(std::begin(m_attributes),
std::end(m_attributes),
[&](const Attribute& attribute) { return attribute.is_graph(); });
m_has_subgraph = it != std::end(m_attributes);
if (m_has_subgraph)
{
m_subgraph = std::make_shared<Subgraph>(it->get_subgraph(*m_graph));
}
}
Impl(const ONNX_NAMESPACE::NodeProto& node_proto,
const Graph& graph,
std::shared_ptr<Subgraph> subgraph)
: m_node_proto{&node_proto}
, m_name{node_proto.has_name() ? node_proto.name() : ""}
, m_domain{get_node_domain(node_proto)}
, m_graph{&graph}
, m_attributes{std::begin(node_proto.attribute()), std::end(node_proto.attribute())}
, m_output_names{std::begin(node_proto.output()), std::end(node_proto.output())}
, m_has_subgraph(subgraph != nullptr)
, m_subgraph(subgraph)
{
}
@ -44,9 +67,8 @@ namespace ngraph
bool has_attribute(const std::string& name) const;
Subgraph get_subgraph_from_attribute(
const std::string& name,
const std::map<std::size_t, std::string>& carried_dependencies_map) const;
bool has_subgraph() const;
std::shared_ptr<Subgraph> get_subgraph() const;
template <typename T>
T get_attribute_value(const std::string& name, T default_value) const;
@ -58,6 +80,8 @@ namespace ngraph
const Graph& graph() const;
private:
Subgraph get_subgraph_from_attribute(const std::string& name) const;
const ONNX_NAMESPACE::NodeProto* m_node_proto;
std::string m_name;
std::string m_domain;
@ -65,6 +89,9 @@ namespace ngraph
std::vector<Attribute> m_attributes;
std::vector<std::reference_wrapper<const std::string>> m_output_names;
mutable std::string m_description;
bool m_has_subgraph;
std::shared_ptr<Subgraph> m_subgraph;
};
const ONNX_NAMESPACE::NodeProto& Node::Impl::node_proto() const { return *m_node_proto; }
@ -94,9 +121,7 @@ namespace ngraph
return it != std::end(m_attributes);
}
Subgraph Node::Impl::get_subgraph_from_attribute(
const std::string& name,
const std::map<std::size_t, std::string>& carried_dependencies_map) const
Subgraph Node::Impl::get_subgraph_from_attribute(const std::string& name) const
{
auto it = std::find_if(
std::begin(m_attributes), std::end(m_attributes), [&](const Attribute& attribute) {
@ -106,9 +131,13 @@ namespace ngraph
{
throw error::node::UnknownAttribute{this->name(), name};
}
return it->get_subgraph(graph(), carried_dependencies_map);
return it->get_subgraph(*m_graph);
}
bool Node::Impl::has_subgraph() const { return m_has_subgraph; }
std::shared_ptr<Subgraph> Node::Impl::get_subgraph() const { return m_subgraph; }
template <typename T>
T Node::Impl::get_attribute_value(const std::string& name, T default_value) const
{
@ -140,8 +169,7 @@ namespace ngraph
template <>
Subgraph Node::Impl::get_attribute_value(const std::string& name) const
{
const std::map<std::size_t, std::string> empty_map;
return get_subgraph_from_attribute(name, empty_map);
return get_subgraph_from_attribute(name);
}
OutputVector Node::Impl::get_ng_nodes(const Node& node) const
@ -196,7 +224,9 @@ namespace ngraph
}
Node::Node(const Node& other)
: m_pimpl{new Impl{other.m_pimpl->node_proto(), other.m_pimpl->graph()},
: m_pimpl{new Impl{other.m_pimpl->node_proto(),
other.m_pimpl->graph(),
other.get_subgraph()},
[](Impl* impl) { delete impl; }}
{
}
@ -219,12 +249,9 @@ namespace ngraph
return m_pimpl->has_attribute(name);
}
Subgraph Node::get_subgraph_from_attribute(
const std::string& name,
const std::map<std::size_t, std::string>& carried_dependencies_map) const
{
return m_pimpl->get_subgraph_from_attribute(name, carried_dependencies_map);
}
bool Node::has_subgraph() const { return m_pimpl->has_subgraph(); }
std::shared_ptr<Subgraph> Node::get_subgraph() const { return m_pimpl->get_subgraph(); }
std::vector<std::string> Node::get_attribute_names() const
{
@ -462,7 +489,6 @@ namespace ngraph
{
return m_pimpl->template get_attribute_value<std::vector<Graph>>(name);
}
} // namespace onnx_import
} // namespace ngraph

View File

@ -36,7 +36,10 @@ namespace ngraph
public:
static constexpr NodeTypeInfo type_info{"NullNode", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
NullNode() = default;
NullNode()
: Node(1)
{
}
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;

View File

@ -19,20 +19,6 @@ namespace ngraph
{
namespace onnx_import
{
namespace error
{
namespace value_info
{
struct unspecified_element_type : ngraph_error
{
unspecified_element_type()
: ngraph_error{"value info has no element type specified"}
{
}
};
} // namespace value_info
} // namespace error
class ValueInfo
{
public:
@ -65,12 +51,12 @@ namespace ngraph
const PartialShape& get_shape() const { return m_partial_shape; }
const element::Type& get_element_type() const
{
if (!m_value_info_proto->type().tensor_type().has_elem_type())
if (m_value_info_proto->type().tensor_type().has_elem_type())
{
throw error::value_info::unspecified_element_type{};
return common::get_ngraph_element_type(
m_value_info_proto->type().tensor_type().elem_type());
}
return common::get_ngraph_element_type(
m_value_info_proto->type().tensor_type().elem_type());
return ngraph::element::dynamic;
}
std::shared_ptr<ngraph::Node>

View File

@ -0,0 +1,34 @@
//*****************************************************************************
// Copyright 2017-2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <onnx_import/onnx_framework_node.hpp>
namespace ngraph
{
namespace frontend
{
NGRAPH_RTTI_DEFINITION(ONNXFrameworkNode, "ONNXFrameworkNode", 1);
std::shared_ptr<Node>
ONNXFrameworkNode::clone_with_new_inputs(const OutputVector& inputs) const
{
return std::make_shared<ONNXFrameworkNode>(m_node, inputs);
}
NGRAPH_RTTI_DEFINITION(ONNXSubgraphFrameworkNode, "ONNXSubgraphFrameworkNode", 1);
} // namespace frontend
} // namespace ngraph

View File

@ -77,10 +77,18 @@ namespace ngraph
loop_carried_dependencies[i].get_node()->get_friendly_name();
}
const Subgraph& body_graph{
node.get_subgraph_from_attribute("body", loop_carried_dependencies_map)};
auto body_outputs = body_graph.get_ng_outputs();
const auto& body_inputs = body_graph.get_ng_parameters();
auto body_graph = node.get_subgraph();
auto body_outputs = body_graph->get_ng_outputs();
const auto& body_inputs = body_graph->get_ng_parameters();
// Infer loop body inputs' element type based on carried dependencies
for (size_t i = 0; i < loop_carried_dependencies.size(); i++)
{
body_inputs[i + 2]->set_element_type(
loop_carried_dependencies[i].get_element_type());
body_inputs[i + 2]->set_partial_shape(
loop_carried_dependencies[i].get_partial_shape());
}
// optional inputs
Output<ngraph::Node> trip_count;
@ -190,22 +198,22 @@ namespace ngraph
final_values.push_back(loop->get_iter_value(*body_outputs_it++, -1));
}
const auto& outputs_from_parent = body_graph.get_outputs_from_parent();
const auto& inputs_from_parent = body_graph->get_inputs_from_parent();
CHECK_VALID_NODE(
node,
static_cast<size_t>(std::distance(body_inputs_it, body_inputs.end())) ==
outputs_from_parent.size(),
inputs_from_parent.size(),
"Expected number of invariant parameters is"
" not equal number of provided outputs from parent scope");
" not equal number of provided inputs from parent scope");
// Set-up parameters from parent graph which are not changed during Loop's
// iterations
for (auto out_from_parent_it = outputs_from_parent.begin();
for (auto in_from_parent_it = inputs_from_parent.begin();
body_inputs_it != body_inputs.end() &&
out_from_parent_it != outputs_from_parent.end();
++body_inputs_it, ++out_from_parent_it)
in_from_parent_it != inputs_from_parent.end();
++body_inputs_it, ++in_from_parent_it)
{
loop->set_invariant_input(*body_inputs_it, *out_from_parent_it);
loop->set_invariant_input(*body_inputs_it, *in_from_parent_it);
}
// Set-up scan outputs

View File

@ -6,7 +6,9 @@
#include "core/graph.hpp"
#include "core/model.hpp"
#include "core/null_node.hpp"
#include "core/transform.hpp"
#include "onnx_import/onnx_framework_node.hpp"
#include "onnx_import/utils/onnx_internal.hpp"
namespace ngraph
@ -15,21 +17,81 @@ namespace ngraph
{
namespace detail
{
std::shared_ptr<Function>
convert_to_ng_function(const ONNX_NAMESPACE::ModelProto& model_proto)
void remove_dangling_parameters(std::shared_ptr<Function>& function)
{
auto p_model_proto = common::make_unique<ONNX_NAMESPACE::ModelProto>(model_proto);
auto model = common::make_unique<Model>(std::move(p_model_proto));
Graph graph{std::move(model)};
auto function = std::make_shared<Function>(
graph.get_ng_outputs(), graph.get_ng_parameters(), graph.get_name());
for (std::size_t i{0}; i < function->get_output_size(); ++i)
const auto parameters = function->get_parameters();
for (auto parameter : parameters)
{
function->get_output_op(i)->set_friendly_name(
graph.get_outputs().at(i).get_name());
const auto parameter_users = parameter->get_users();
// if a Parameter is connected to a ONNXFrameworkNode that was not converted
// during convert_function it means, this Parameter is dangling and we can
// remove it from function
const bool is_dangling_parameter = std::all_of(
parameter_users.begin(),
parameter_users.end(),
[](const std::shared_ptr<ngraph::Node>& node) -> bool {
return std::dynamic_pointer_cast<frontend::ONNXFrameworkNode>(node) !=
nullptr;
});
if (is_dangling_parameter)
{
function->remove_parameter(parameter);
}
}
return function;
}
void remove_dangling_results(std::shared_ptr<Function>& function)
{
const auto results = function->get_results();
for (auto result : results)
{
// we can remove Result from function if after function conversion,
// Result is connected to NullNode only
const auto result_inputs = result->input_values();
const bool is_dangling_result =
std::all_of(result_inputs.begin(),
result_inputs.end(),
[](const Output<ngraph::Node>& node) -> bool {
return ngraph::op::is_null(node);
});
if (is_dangling_result)
{
function->remove_result(result);
}
}
}
void convert_decoded_function(std::shared_ptr<Function> function)
{
for (const auto& node : function->get_ordered_ops())
{
if (auto raw_node =
std::dynamic_pointer_cast<frontend::ONNXFrameworkNode>(node))
{
if (auto subgraph_node =
std::dynamic_pointer_cast<frontend::ONNXSubgraphFrameworkNode>(
node))
{
subgraph_node->infer_inputs_from_parent();
convert_decoded_function(subgraph_node->get_subgraph_body());
}
const auto& onnx_node = raw_node->get_onnx_node();
OutputVector ng_nodes{onnx_node.get_ng_nodes()};
if (ng_nodes.size() > raw_node->get_output_size())
{
ng_nodes.resize(raw_node->get_output_size());
}
replace_node(raw_node, ng_nodes);
}
else
{
// Have to revalidate node because new intpus can affect shape/type
// propagation for already translated nodes
node->revalidate_and_infer_types();
}
}
remove_dangling_parameters(function);
remove_dangling_results(function);
}
std::shared_ptr<Function> import_onnx_model(ONNX_NAMESPACE::ModelProto& model_proto,
@ -39,7 +101,10 @@ namespace ngraph
transform::fixup_legacy_operators(model_proto);
transform::update_external_data_paths(model_proto, model_path);
return detail::convert_to_ng_function(model_proto);
auto p_model_proto = common::make_unique<ONNX_NAMESPACE::ModelProto>(model_proto);
auto model = common::make_unique<Model>(std::move(p_model_proto));
Graph graph{std::move(model)};
return graph.convert();
}
} // namespace detail
} // namespace onnx_import

View File

@ -390,8 +390,7 @@ def test_cast_errors():
for name, value in zip(node.input, [input_data])
]
output_tensors = [
make_tensor_value_info(name, onnx.TensorProto.FLOAT16, value.shape)
for name, value in zip(node.output, ())
make_tensor_value_info(node.output[0], onnx.TensorProto.FLOAT16, input_data.shape)
] # type: ignore
graph = make_graph([node], "compute_graph", input_tensors, output_tensors)
@ -406,8 +405,7 @@ def test_cast_errors():
for name, value in zip(node.input, [input_data])
]
output_tensors = [
make_tensor_value_info(name, onnx.TensorProto.INT32, value.shape)
for name, value in zip(node.output, ())
make_tensor_value_info(node.output[0], onnx.TensorProto.INT32, input_data.shape)
] # type: ignore
graph = make_graph([node], "compute_graph", input_tensors, output_tensors)
@ -422,8 +420,7 @@ def test_cast_errors():
for name, value in zip(node.input, [input_data])
]
output_tensors = [
make_tensor_value_info(name, onnx.TensorProto.INT32, value.shape)
for name, value in zip(node.output, ())
make_tensor_value_info(node.output[0], onnx.TensorProto.INT32, input_data.shape)
] # type: ignore
graph = make_graph([node], "compute_graph", input_tensors, output_tensors)
@ -438,8 +435,7 @@ def test_cast_errors():
for name, value in zip(node.input, [input_data])
]
output_tensors = [
make_tensor_value_info(name, onnx.TensorProto.COMPLEX128, value.shape)
for name, value in zip(node.output, ())
make_tensor_value_info(node.output[0], onnx.TensorProto.COMPLEX128, input_data.shape)
] # type: ignore
graph = make_graph([node], "compute_graph", input_tensors, output_tensors)

View File

@ -2,7 +2,6 @@ ir_version: 7
producer_name: "backend-test"
graph {
node {
input: "target_shape"
output: "output"
op_type: "ConstantFill"
attribute {