Remove ngraph::Lambda class, replace TensorIterator body with ngraph::Function (#1830)
* remove Lambda class, replace TensorIterator body with ngraph::Function * Fix passing parameters from parent graph to subgraph Co-authored-by: mbencer <mateusz.bencer@intel.com>
This commit is contained in:
@@ -64,13 +64,15 @@ namespace ngraph
|
||||
void add_provenance_tags(const Node& onnx_node,
|
||||
const OutputVector& ng_node_vector) const;
|
||||
|
||||
protected:
|
||||
ParameterVector m_parameters;
|
||||
|
||||
private:
|
||||
const ONNX_NAMESPACE::GraphProto* m_graph_proto;
|
||||
std::unique_ptr<GraphCache> m_cache;
|
||||
std::vector<Node> m_nodes;
|
||||
std::vector<ValueInfo> m_inputs;
|
||||
std::vector<ValueInfo> m_outputs;
|
||||
ParameterVector m_parameters;
|
||||
Model* m_model;
|
||||
};
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@
|
||||
#include <sstream>
|
||||
|
||||
#include "ngraph/log.hpp"
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/provenance.hpp"
|
||||
#include "onnx_import/core/graph.hpp"
|
||||
#include "onnx_import/core/node.hpp"
|
||||
@@ -291,6 +292,37 @@ namespace ngraph
|
||||
model,
|
||||
std::unique_ptr<SubgraphCache>(new SubgraphCache(parent_graph.get_graph_cache())))
|
||||
{
|
||||
std::vector<std::shared_ptr<ngraph::Node>> subgraph_root_nodes;
|
||||
const auto& outputs = as_result_vector(get_ng_outputs());
|
||||
for (auto& out : outputs)
|
||||
{
|
||||
subgraph_root_nodes.push_back(out);
|
||||
}
|
||||
const auto& params = get_ng_parameters();
|
||||
for (auto& param : params)
|
||||
{
|
||||
subgraph_root_nodes.push_back(param);
|
||||
}
|
||||
const auto subgraph_nodes = topological_sort(subgraph_root_nodes);
|
||||
|
||||
const auto& parent_graph_parameters = parent_graph.get_ng_parameters();
|
||||
for (const auto& node : subgraph_nodes)
|
||||
{
|
||||
if (op::is_parameter(node))
|
||||
{
|
||||
const auto sub_it = std::find(m_parameters.begin(), m_parameters.end(), node);
|
||||
// not present as subgraph parameter
|
||||
if (sub_it == m_parameters.end())
|
||||
{
|
||||
const auto parent_it = std::find(
|
||||
parent_graph_parameters.begin(), parent_graph_parameters.end(), node);
|
||||
if (parent_it != m_parameters.end())
|
||||
{
|
||||
m_parameters.push_back(*parent_it);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace onnx_import
|
||||
|
||||
@@ -116,15 +116,15 @@ namespace ngraph
|
||||
const auto& graph_outputs = body_graph.get_ng_outputs();
|
||||
const auto& graph_inputs = body_graph.get_ng_parameters();
|
||||
|
||||
CHECK_VALID_NODE(
|
||||
node,
|
||||
graph_inputs.size() == loop_carried_dependencies.size() + 2,
|
||||
"The provided loop body graph inputs size (",
|
||||
graph_inputs.size(),
|
||||
"), is not equal to the sum of loop carried dependencies and two mandatory"
|
||||
" inputs (",
|
||||
loop_carried_dependencies.size() + 2,
|
||||
")");
|
||||
CHECK_VALID_NODE(node,
|
||||
graph_inputs.size() >= loop_carried_dependencies.size() + 2,
|
||||
"The provided loop body graph inputs size (",
|
||||
graph_inputs.size(),
|
||||
"), is not greater than the sum of loop carried dependencies "
|
||||
"and two mandatory"
|
||||
" inputs (",
|
||||
loop_carried_dependencies.size() + 2,
|
||||
")");
|
||||
|
||||
CHECK_VALID_NODE(node,
|
||||
graph_outputs.size() >= loop_carried_dependencies.size() + 1,
|
||||
@@ -144,8 +144,8 @@ namespace ngraph
|
||||
default_opset::Constant::create(element::boolean, Shape{}, {true});
|
||||
|
||||
// create the loop body
|
||||
const auto body = std::make_shared<ngraph::op::TensorIterator::BodyLambda>(
|
||||
graph_outputs, graph_inputs);
|
||||
const auto body =
|
||||
std::make_shared<ngraph::Function>(graph_outputs, graph_inputs);
|
||||
auto tensor_iterator = std::make_shared<ngraph::op::TensorIterator>();
|
||||
tensor_iterator->set_body(body);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user