Adding sinks to ngraph::Function (#2741)

* initial draft of adding sinks to ngraph::Function

* style fixes

* code style fixes

* code style fixes

* code style fix

* review fix+build fix

* code style fix

* fix build

* API changed according to latest discussion

* review fixes

* review fixes + tests

* added 1 more ctor

* style fixes

* used new api in ir parser

* fixed build

* review fixes

* remove validate_nodes_and_infer_types from remove_sink/remove_result

* removed validate_.. after discussion

* style fix
This commit is contained in:
Svetlana Dolinina 2020-10-31 19:41:05 +03:00 committed by GitHub
parent 28de789993
commit 32732a1f29
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 384 additions and 17 deletions

View File

@ -28,6 +28,8 @@ void ngraph::op::GenericIE::addExtension(std::shared_ptr<const ngraph::Function>
for (auto r : func->get_results())
nodes.emplace_back(r);
for (auto s : func->get_sinks())
nodes.emplace_back(s);
for (auto param : func->get_parameters())
nodes.emplace_back(param);

View File

@ -133,7 +133,7 @@ std::shared_ptr<ICNNNetwork> V10Parser::parse(const pugi::xml_node& root, std::i
ngraph::ParameterVector parameter_nodes;
ngraph::ResultVector result_nodes;
ngraph::NodeVector allNodes;
std::vector<std::shared_ptr<ngraph::op::Assign>> assign_nodes;
ngraph::SinkVector assign_nodes;
std::map<std::string, std::shared_ptr<ngraph::Node>> variable_id_to_read_value;
// Following topological order create nGraph operations
@ -187,15 +187,12 @@ std::shared_ptr<ICNNNetwork> V10Parser::parse(const pugi::xml_node& root, std::i
}
::ngraph::op::GenericIE::DisableReshape noReshape(allNodes);
auto function = std::make_shared<ngraph::Function>(result_nodes, parameter_nodes, GetStrAttr(root, "name", ""));
if (!result_nodes.empty()) {
for (const auto& assign : assign_nodes) {
assign->add_control_dependency(variable_id_to_read_value.at(assign->get_variable_id()));
// often Assign node is a leaf of the graph, we add control_dependency for one of the results
// to make Assign node visible for traversals get_ops(), get_ordered_ops()
result_nodes[0]->add_control_dependency(assign);
}
auto function = std::make_shared<ngraph::Function>(result_nodes, assign_nodes, parameter_nodes, GetStrAttr(root, "name", ""));
for (const auto& assign : assign_nodes) {
assign->add_control_dependency(
variable_id_to_read_value.at(std::dynamic_pointer_cast<ngraph::op::Assign>(assign)->get_variable_id()));
}
CNNNetwork net(function, _exts);
parsePreProcess(net, root, binStream);
return net;

View File

@ -27,6 +27,7 @@
#include "ngraph/node.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/op/sink.hpp"
namespace ngraph
{
@ -52,6 +53,16 @@ namespace ngraph
const ParameterVector& parameters,
const std::string& name = "");
Function(const ResultVector& results,
const SinkVector& sinks,
const ParameterVector& parameters,
const std::string& name = "");
Function(const OutputVector& results,
const SinkVector& sinks,
const ParameterVector& parameters,
const std::string& name = "");
virtual ~Function() {}
/// Return the number of outputs for this function.
size_t get_output_size() const;
@ -138,6 +149,27 @@ namespace ngraph
bool evaluate(const HostTensorVector& output_tensors,
const HostTensorVector& input_tensors) const;
/// \brief Return a list of function's sinks.
const SinkVector& get_sinks() const { return m_sinks; }
/// \brief Add new sink nodes to the list. Method doesn't validate graph, it should be done
/// manually after all changes.
/// \param sinks new sink nodes
void add_sinks(const SinkVector& sinks);
/// \brief Delete sink node from the list of sinks. Method doesn't delete node from graph.
/// \param sink Sink to delete
void remove_sink(const std::shared_ptr<op::Sink>& sink);
/// \brief Add new Result nodes to the list. Method doesn't validate graph, it should be
/// done manually after all changes.
/// \param results new Result nodes
void add_results(const ResultVector& results);
/// \brief Delete Result node from the list of results. Method will not delete node from
/// graph.
/// \param result Result node to delete
void remove_result(const std::shared_ptr<op::Result>& result);
private:
Function(const Function&) = delete;
Function(const Function&&) = delete;
@ -150,6 +182,9 @@ namespace ngraph
topological_sort_t m_topological_sorter;
ResultVector m_results;
// List of the nodes with side effect in graph.
// These nodes are not outputs of graph but should not be removed even if have no children.
SinkVector m_sinks;
ParameterVector m_parameters;
};

View File

@ -16,7 +16,7 @@
#pragma once
#include "ngraph/op/op.hpp"
#include "ngraph/op/sink.hpp"
#include "ngraph/op/util/variable.hpp"
namespace ngraph
@ -26,11 +26,10 @@ namespace ngraph
namespace v3
{
/// \brief Assign operation sets an input value to the variable with `variable_id`
class NGRAPH_API Assign : public Op
class NGRAPH_API Assign : public Sink
{
public:
static constexpr NodeTypeInfo type_info{"Assign", 3};
const NodeTypeInfo& get_type_info() const override { return type_info; }
NGRAPH_RTTI_DECLARATION;
Assign() = default;
/// \brief Constructs an Assign operation.

View File

@ -0,0 +1,46 @@
//*****************************************************************************
// Copyright 2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <vector>
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
/// Root of nodes that can be sink nodes
class NGRAPH_API Sink : public Op
{
public:
virtual ~Sink() = 0;
NGRAPH_RTTI_DECLARATION;
protected:
Sink()
: Op()
{
}
Sink(const OutputVector& arguments)
: Op(arguments)
{
}
};
}
using SinkVector = std::vector<std::shared_ptr<op::Sink>>;
}

View File

@ -77,6 +77,28 @@ Function::Function(const std::shared_ptr<Node>& result,
{
}
Function::Function(const ResultVector& results,
const SinkVector& sinks,
const ParameterVector& parameters,
const std::string& name)
: m_results(results)
, m_sinks(sinks)
, m_parameters(parameters)
, m_name(name)
, m_unique_name("Function_" + to_string(m_next_instance_id.fetch_add(1)))
, m_topological_sorter(topological_sort<std::vector<std::shared_ptr<Node>>>)
{
validate_nodes_and_infer_types();
}
Function::Function(const OutputVector& results,
const SinkVector& sinks,
const ParameterVector& parameters,
const std::string& name)
: Function(as_result_vector(results), sinks, parameters, name)
{
}
void Function::validate_nodes_and_infer_types()
{
OV_ITT_SCOPED_TASK(itt::domains::nGraph, "Function::validate_nodes_and_infer_types");
@ -106,6 +128,10 @@ std::vector<shared_ptr<Node>> Function::get_ordered_ops() const
{
nodes.push_back(r);
}
for (auto& r : get_sinks())
{
nodes.emplace_back(r);
}
for (auto& param : get_parameters())
{
nodes.push_back(param);
@ -122,6 +148,11 @@ void Function::map_unordered_ops(std::function<void(Node*)> f) const
{
remaining_ops.push(r.get());
}
for (auto& r : get_sinks())
{
remaining_ops.push(r.get());
}
for (auto& param : get_parameters())
{
remaining_ops.push(param.get());
@ -347,6 +378,33 @@ bool Function::visit_attributes(AttributeVisitor& visitor)
return true;
}
void Function::add_sinks(const SinkVector& sinks)
{
m_sinks.insert(m_sinks.end(), sinks.begin(), sinks.end());
}
void Function::remove_sink(const std::shared_ptr<op::Sink>& sink)
{
m_sinks.erase(std::remove_if(m_sinks.begin(),
m_sinks.end(),
[&sink](std::shared_ptr<op::Sink>& s) { return s == sink; }),
m_sinks.end());
}
void Function::add_results(const ResultVector& results)
{
m_results.insert(m_results.end(), results.begin(), results.end());
}
void Function::remove_result(const std::shared_ptr<op::Result>& result)
{
m_results.erase(
std::remove_if(m_results.begin(),
m_results.end(),
[&result](std::shared_ptr<op::v0::Result>& r) { return r == result; }),
m_results.end());
}
constexpr DiscreteTypeInfo AttributeAdapter<shared_ptr<Function>>::type_info;
AttributeAdapter<shared_ptr<Function>>::AttributeAdapter(shared_ptr<Function>& ref)
@ -401,6 +459,10 @@ bool AttributeAdapter<shared_ptr<Function>>::visit_attributes(AttributeVisitor&
{
results.push_back(result);
}
for (auto sink : m_ref->get_sinks())
{
results.push_back(sink);
}
int64_t i = 0;
ostringstream index;

View File

@ -57,6 +57,10 @@ void ngraph::traverse_nodes(const Function* p, std::function<void(std::shared_pt
{
nodes.push_back(r);
}
for (auto s : p->get_sinks())
{
nodes.emplace_back(s);
}
for (auto param : p->get_parameters())
{
@ -419,7 +423,7 @@ std::shared_ptr<ngraph::Function> ngraph::clone_function(const ngraph::Function&
// clone function operations
clone_nodes(func.get_ops(), node_map);
// get cloned function results and parameters
// get cloned function results and sinks and parameters
ResultVector cloned_results;
for (shared_ptr<Node> node : func.get_results())
{
@ -430,6 +434,12 @@ std::shared_ptr<ngraph::Function> ngraph::clone_function(const ngraph::Function&
}
cloned_results.push_back(result);
}
SinkVector cloned_sinks;
for (auto node : func.get_sinks())
{
cloned_sinks.push_back(static_pointer_cast<op::Sink>(node_map.at(node.get())));
}
std::vector<std::shared_ptr<op::Parameter>> cloned_params;
for (auto param : func.get_parameters())
{
@ -439,6 +449,7 @@ std::shared_ptr<ngraph::Function> ngraph::clone_function(const ngraph::Function&
// create and return cloned function
auto result = std::make_shared<ngraph::Function>(cloned_results, cloned_params);
result->set_friendly_name(func.get_friendly_name());
result->add_sinks(cloned_sinks);
return result;
}
@ -863,6 +874,18 @@ bool ngraph::check_for_cycles(const ngraph::Function* func,
}
}
for (auto res : func->get_sinks())
{
std::deque<std::shared_ptr<Node>> path;
// mirror of path stack for faster cycle check
std::unordered_set<std::shared_ptr<Node>> path_set;
if (check_for_cycles_bkwd(res, path, path_set, cycle_nodes))
{
is_bkwd_cycle = true;
return true;
}
}
for (auto param : func->get_parameters())
{
std::deque<std::shared_ptr<Node>> path;

View File

@ -21,10 +21,10 @@
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::v3::Assign::type_info;
NGRAPH_RTTI_DEFINITION(op::v3::Assign, "Assign", 3, op::Sink);
op::v3::Assign::Assign(const Output<Node>& new_value, const std::string& variable_id)
: Op({new_value})
: Sink({new_value})
, m_variable_id(variable_id)
{
constructor_validate_and_infer_types();

View File

@ -0,0 +1,25 @@
//*****************************************************************************
// Copyright 2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/sink.hpp"
using namespace ngraph;
NGRAPH_RTTI_DEFINITION(op::Sink, "Sink", 0);
op::Sink::~Sink()
{
}

View File

@ -120,6 +120,11 @@ std::shared_ptr<Function>
new_results[i] = std::static_pointer_cast<op::Result>(m[new_results[i].get()]);
new_results[i]->set_friendly_name(name);
}
SinkVector new_sinks = f->get_sinks();
for (size_t i = 0; i < new_sinks.size(); i++)
{
new_sinks[i] = std::static_pointer_cast<op::Sink>(m[new_sinks[i].get()]);
}
return std::make_shared<Function>(new_results, new_parameters);
return std::make_shared<Function>(new_results, new_sinks, new_parameters);
}

View File

@ -204,3 +204,176 @@ TEST(build_graph, default_output_checks)
FAIL() << "nullptr initialization of Output failed";
}
}
TEST(build_graph, build_graph_with_sink)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto init_const = op::Constant::create(element::f32, Shape{2, 2}, {0, 0, 0, 0});
auto read = make_shared<op::ReadValue>(init_const, "v0");
std::vector<shared_ptr<Node>> args = {arg, read};
auto pattern = make_shared<op::Concat>(args, 1);
auto res = make_shared<op::Result>(pattern);
const auto axis = op::Constant::create(element::i64, Shape{}, {1});
auto crop = make_shared<op::Split>(pattern, axis, 3);
auto assign = make_shared<op::Assign>(crop, "v0");
auto f = make_shared<Function>(ResultVector({res}), SinkVector({assign}), ParameterVector{arg});
SinkVector sinks = f->get_sinks();
EXPECT_EQ(sinks.size(), 1);
EXPECT_EQ(sinks[0], assign);
NodeVector nodes = f->get_ops();
EXPECT_EQ(nodes.size(), 8);
}
TEST(build_graph, build_graph_with_sink_output_ctor)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto init_const = op::Constant::create(element::f32, Shape{2, 2}, {0, 0, 0, 0});
auto read = make_shared<op::ReadValue>(init_const, "v0");
std::vector<shared_ptr<Node>> args = {arg, read};
auto pattern = make_shared<op::Concat>(args, 1);
auto res = make_shared<op::Result>(pattern);
const auto axis = op::Constant::create(element::i64, Shape{}, {1});
auto crop = make_shared<op::Split>(pattern, axis, 3);
auto assign = make_shared<op::Assign>(crop, "v0");
auto f = make_shared<Function>(
OutputVector({pattern->output(0)}), SinkVector({assign}), ParameterVector{arg});
SinkVector sinks = f->get_sinks();
EXPECT_EQ(sinks.size(), 1);
EXPECT_EQ(sinks[0], assign);
NodeVector nodes = f->get_ops();
EXPECT_EQ(nodes.size(), 8);
}
TEST(build_graph, build_graph_with_add_sink)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto init_const = op::Constant::create(element::f32, Shape{2, 2}, {0, 0, 0, 0});
auto read = make_shared<op::ReadValue>(init_const, "v0");
std::vector<shared_ptr<Node>> args = {arg, read};
auto pattern = make_shared<op::Concat>(args, 1);
auto res = make_shared<op::Result>(pattern);
const auto axis = op::Constant::create(element::i64, Shape{}, {1});
auto crop = make_shared<op::Split>(pattern, axis, 3);
auto assign = make_shared<op::Assign>(crop, "v0");
auto f = make_shared<Function>(ResultVector({res}), ParameterVector{arg});
NodeVector nodes = f->get_ops();
EXPECT_EQ(nodes.size(), 5);
SinkVector sinks = f->get_sinks();
EXPECT_EQ(sinks.size(), 0);
f->add_sinks(SinkVector({assign}));
sinks = f->get_sinks();
EXPECT_EQ(sinks.size(), 1);
EXPECT_EQ(sinks[0], assign);
nodes = f->get_ops();
EXPECT_EQ(nodes.size(), 8);
}
TEST(build_graph, build_graph_with_wrong_remove_sink)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto init_const = op::Constant::create(element::f32, Shape{2, 2}, {0, 0, 0, 0});
auto read = make_shared<op::ReadValue>(init_const, "v0");
std::vector<shared_ptr<Node>> args = {arg, read};
auto pattern = make_shared<op::Concat>(args, 1);
auto res = make_shared<op::Result>(pattern);
const auto axis = op::Constant::create(element::i64, Shape{}, {1});
auto crop = make_shared<op::Split>(pattern, axis, 3);
auto assign = make_shared<op::Assign>(crop, "v0");
auto f = make_shared<Function>(ResultVector({res}), SinkVector({assign}), ParameterVector{arg});
SinkVector sinks = f->get_sinks();
EXPECT_EQ(sinks.size(), 1);
EXPECT_EQ(sinks[0], assign);
f->remove_sink(assign);
sinks = f->get_sinks();
EXPECT_EQ(sinks.size(), 0);
auto nodes = f->get_ops();
EXPECT_EQ(nodes.size(), 5);
}
TEST(build_graph, build_graph_with_remove_sink)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto init_const = op::Constant::create(element::f32, Shape{2, 2}, {0, 0, 0, 0});
auto read = make_shared<op::ReadValue>(init_const, "v0");
std::vector<shared_ptr<Node>> args = {arg, read};
auto pattern = make_shared<op::Concat>(args, 1);
auto res = make_shared<op::Result>(pattern);
const auto axis = op::Constant::create(element::i64, Shape{}, {1});
auto crop = make_shared<op::Split>(pattern, axis, 3);
auto assign = make_shared<op::Assign>(crop, "v0");
auto f = make_shared<Function>(ResultVector({res}), SinkVector({assign}), ParameterVector{arg});
pattern->input(1).replace_source_output(arg);
SinkVector sinks = f->get_sinks();
EXPECT_EQ(sinks.size(), 1);
EXPECT_EQ(sinks[0], assign);
f->remove_sink(assign);
sinks = f->get_sinks();
EXPECT_EQ(sinks.size(), 0);
auto nodes = f->get_ops();
EXPECT_EQ(nodes.size(), 3);
}
TEST(build_graph, build_graph_with_add_result)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto init_const = op::Constant::create(element::f32, Shape{2, 2}, {0, 0, 0, 0});
auto read = make_shared<op::ReadValue>(init_const, "v0");
std::vector<shared_ptr<Node>> args = {arg, read};
auto pattern = make_shared<op::Concat>(args, 1);
auto res = make_shared<op::Result>(pattern);
const auto axis = op::Constant::create(element::i64, Shape{}, {1});
auto crop = make_shared<op::Split>(pattern, axis, 3);
auto res2 = make_shared<op::Result>(crop, "v0");
auto f = make_shared<Function>(ResultVector({res}), ParameterVector{arg});
NodeVector nodes = f->get_ops();
EXPECT_EQ(nodes.size(), 5);
ResultVector results = f->get_results();
EXPECT_EQ(results.size(), 1);
f->add_results(ResultVector({res2}));
results = f->get_results();
EXPECT_EQ(results.size(), 2);
EXPECT_EQ(results[1], res2);
nodes = f->get_ops();
EXPECT_EQ(nodes.size(), 8);
}
TEST(build_graph, build_graph_with_remove_result)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto init_const = op::Constant::create(element::f32, Shape{2, 2}, {0, 0, 0, 0});
auto read = make_shared<op::ReadValue>(init_const, "v0");
std::vector<shared_ptr<Node>> args = {arg, read};
auto pattern = make_shared<op::Concat>(args, 1);
auto res = make_shared<op::Result>(pattern);
const auto axis = op::Constant::create(element::i64, Shape{}, {1});
auto crop = make_shared<op::Split>(pattern, axis, 3);
auto res2 = make_shared<op::Result>(crop, "v0");
auto f = make_shared<Function>(ResultVector({res, res2}), ParameterVector{arg});
NodeVector nodes = f->get_ops();
EXPECT_EQ(nodes.size(), 8);
ResultVector results = f->get_results();
EXPECT_EQ(results.size(), 2);
f->remove_result(res2);
results = f->get_results();
EXPECT_EQ(results.size(), 1);
nodes = f->get_ops();
EXPECT_EQ(nodes.size(), 5);
}