Remove ngraph::Lambda class, replace TensorIterator body with ngraph::Function (#1830)

* remove Lambda class, replace TensorIterator body with ngraph::Function

* Fix passing parameters from parent graph to subgraph

Co-authored-by: mbencer <mateusz.bencer@intel.com>
This commit is contained in:
Ivan Tikhonov
2020-08-19 07:09:32 +03:00
committed by GitHub
parent 618c6f7f23
commit c5ca8f5b51
24 changed files with 452 additions and 520 deletions

View File

@@ -64,13 +64,15 @@ namespace ngraph
void add_provenance_tags(const Node& onnx_node,
const OutputVector& ng_node_vector) const;
protected:
ParameterVector m_parameters;
private:
const ONNX_NAMESPACE::GraphProto* m_graph_proto;
std::unique_ptr<GraphCache> m_cache;
std::vector<Node> m_nodes;
std::vector<ValueInfo> m_inputs;
std::vector<ValueInfo> m_outputs;
ParameterVector m_parameters;
Model* m_model;
};

View File

@@ -20,6 +20,7 @@
#include <sstream>
#include "ngraph/log.hpp"
#include "ngraph/node.hpp"
#include "ngraph/provenance.hpp"
#include "onnx_import/core/graph.hpp"
#include "onnx_import/core/node.hpp"
@@ -291,6 +292,37 @@ namespace ngraph
model,
std::unique_ptr<SubgraphCache>(new SubgraphCache(parent_graph.get_graph_cache())))
{
std::vector<std::shared_ptr<ngraph::Node>> subgraph_root_nodes;
const auto& outputs = as_result_vector(get_ng_outputs());
for (auto& out : outputs)
{
subgraph_root_nodes.push_back(out);
}
const auto& params = get_ng_parameters();
for (auto& param : params)
{
subgraph_root_nodes.push_back(param);
}
const auto subgraph_nodes = topological_sort(subgraph_root_nodes);
const auto& parent_graph_parameters = parent_graph.get_ng_parameters();
for (const auto& node : subgraph_nodes)
{
if (op::is_parameter(node))
{
const auto sub_it = std::find(m_parameters.begin(), m_parameters.end(), node);
// not present as subgraph parameter
if (sub_it == m_parameters.end())
{
const auto parent_it = std::find(
parent_graph_parameters.begin(), parent_graph_parameters.end(), node);
if (parent_it != m_parameters.end())
{
m_parameters.push_back(*parent_it);
}
}
}
}
}
} // namespace onnx_import

View File

@@ -116,15 +116,15 @@ namespace ngraph
const auto& graph_outputs = body_graph.get_ng_outputs();
const auto& graph_inputs = body_graph.get_ng_parameters();
CHECK_VALID_NODE(
node,
graph_inputs.size() == loop_carried_dependencies.size() + 2,
"The provided loop body graph inputs size (",
graph_inputs.size(),
"), is not equal to the sum of loop carried dependencies and two mandatory"
" inputs (",
loop_carried_dependencies.size() + 2,
")");
CHECK_VALID_NODE(node,
graph_inputs.size() >= loop_carried_dependencies.size() + 2,
"The provided loop body graph inputs size (",
graph_inputs.size(),
"), is not greater than the sum of loop carried dependencies "
"and two mandatory"
" inputs (",
loop_carried_dependencies.size() + 2,
")");
CHECK_VALID_NODE(node,
graph_outputs.size() >= loop_carried_dependencies.size() + 1,
@@ -144,8 +144,8 @@ namespace ngraph
default_opset::Constant::create(element::boolean, Shape{}, {true});
// create the loop body
const auto body = std::make_shared<ngraph::op::TensorIterator::BodyLambda>(
graph_outputs, graph_inputs);
const auto body =
std::make_shared<ngraph::Function>(graph_outputs, graph_inputs);
auto tensor_iterator = std::make_shared<ngraph::op::TensorIterator>();
tensor_iterator->set_body(body);