Added BackwardGraphRewrite And Input<Node> RTInfo (#5343)

* Added BackwardGraphRewrite

* Add RT Info to Input<Node> class

* Add RTInfo tests; BackwardGraphRewrite tests
This commit is contained in:
Gleb Kazantaev 2021-04-27 17:26:20 +03:00 committed by GitHub
parent 3028c78594
commit 689f8aedb6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 130 additions and 5 deletions

View File

@ -5,6 +5,7 @@
#pragma once
#include <cstring>
#include <map>
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/partial_shape.hpp"
@ -23,6 +24,8 @@ namespace ngraph
{
};
class Variant;
/// \brief A handle for one of a node's inputs.
template <>
class NGRAPH_API Input<Node>
@ -58,6 +61,12 @@ namespace ngraph
/// \param new_source_output A handle for the output that will replace this input's source.
void replace_source_output(const Output<Node>& new_source_output) const;
using RTMap = std::map<std::string, std::shared_ptr<Variant>>;
/// \return The reference to runtime info map
RTMap& get_rt_info();
/// \return The constant reference to runtime info map
const RTMap& get_rt_info() const;
bool operator==(const Input& other) const;
bool operator!=(const Input& other) const;
bool operator<(const Input& other) const;
@ -101,6 +110,10 @@ namespace ngraph
/// \return true if this input is relevant to its node's output values; else false.
bool get_is_relevant_to_values() const;
using RTMap = std::map<std::string, std::shared_ptr<Variant>>;
/// \return The constant reference to runtime info map
const RTMap& get_rt_info() const;
bool operator==(const Input& other) const;
bool operator!=(const Input& other) const;
bool operator<(const Input& other) const;

View File

@ -219,11 +219,29 @@ namespace ngraph
void set_pass_config(const std::shared_ptr<PassConfig>& pass_config) override;
protected:
bool apply_matcher_passes(std::shared_ptr<Function> f,
std::deque<std::shared_ptr<Node>> nodes_to_run);
bool m_enable_shape_inference = false;
std::vector<std::shared_ptr<ngraph::pass::MatcherPass>> m_matchers;
};
class NGRAPH_API BackwardGraphRewrite : public ngraph::pass::GraphRewrite
{
public:
NGRAPH_RTTI_DECLARATION;
BackwardGraphRewrite() = default;
explicit BackwardGraphRewrite(const std::shared_ptr<MatcherPass>& pass)
: GraphRewrite(pass)
{
}
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
};
class NGRAPH_API RecurrentGraphRewrite : public ngraph::pass::FunctionPass
{
public:

View File

@ -82,6 +82,20 @@ namespace ngraph
{
}
using RTMap = std::map<std::string, std::shared_ptr<Variant>>;
RTMap& Input<Node>::get_rt_info() { return m_node->m_outputs.at(m_index).get_rt_info(); }
const RTMap& Input<Node>::get_rt_info() const
{
return m_node->m_outputs.at(m_index).get_rt_info();
}
const RTMap& Input<const Node>::get_rt_info() const
{
return m_node->m_outputs.at(m_index).get_rt_info();
}
const Node* Input<const Node>::get_node() const { return m_node; }
size_t Input<const Node>::get_index() const { return m_index; }
const element::Type& Input<const Node>::get_element_type() const

View File

@ -54,6 +54,8 @@ using namespace ngraph;
NGRAPH_RTTI_DEFINITION(ngraph::pass::GraphRewrite, "ngraph::pass::GraphRewrite", 0);
NGRAPH_RTTI_DEFINITION(ngraph::pass::BackwardGraphRewrite, "ngraph::pass::BackwardGraphRewrite", 0);
NGRAPH_RTTI_DEFINITION(ngraph::pass::MatcherPass, "ngraph::pass::MatcherPass", 0);
namespace ngraph
@ -71,19 +73,35 @@ namespace ngraph
} // namespace pass
} // namespace ngraph
bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f)
bool pass::BackwardGraphRewrite::run_on_function(std::shared_ptr<ngraph::Function> f)
{
OV_ITT_SCOPED_TASK(itt::domains::nGraph, "pass::GraphRewrite::run_on_function");
bool rewritten = false;
const auto& pass_config = get_pass_config();
// Initialize execution queue with nodes in topological order
deque<std::shared_ptr<Node>> nodes_to_run;
for (auto& node : f->get_ordered_ops())
{
nodes_to_run.emplace_front(node);
}
return apply_matcher_passes(f, std::move(nodes_to_run));
}
bool pass::GraphRewrite::run_on_function(std::shared_ptr<ngraph::Function> f)
{
// Initialize execution queue with nodes in topological order
deque<std::shared_ptr<Node>> nodes_to_run;
for (auto& node : f->get_ordered_ops())
{
nodes_to_run.emplace_back(node);
}
return apply_matcher_passes(f, std::move(nodes_to_run));
}
bool pass::GraphRewrite::apply_matcher_passes(shared_ptr<Function> f,
deque<std::shared_ptr<Node>> nodes_to_run)
{
OV_ITT_SCOPED_TASK(itt::domains::nGraph, "pass::GraphRewrite::run_on_function");
bool rewritten = false;
const auto& pass_config = get_pass_config();
// Check that all Matchers in MatcherPasses has type bases root node
bool all_roots_has_type = true;

View File

@ -39,6 +39,23 @@ public:
}
};
class GatherNodesPass : public ngraph::pass::MatcherPass
{
public:
NGRAPH_RTTI_DECLARATION;
GatherNodesPass(NodeVector & order)
: MatcherPass()
{
ngraph::matcher_pass_callback callback = [&order](pattern::Matcher& m) {
order.push_back(m.get_match_root());
return false;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(ngraph::pattern::any_input(), "GatherNodesPass");
this->register_matcher(m, callback);
}
};
class Anchor : public ngraph::pass::GraphRewrite
{
public:
@ -51,6 +68,7 @@ public:
NGRAPH_RTTI_DEFINITION(TestPass, "TestPass", 0);
NGRAPH_RTTI_DEFINITION(Anchor, "Anchor", 0);
NGRAPH_RTTI_DEFINITION(GatherNodesPass, "GatherNodesPass", 0);
std::shared_ptr<Function> get_function()
{
@ -77,6 +95,34 @@ ngraph::pass::param_callback get_callback()
};
}
TEST(GraphRewriteOrderTest, MatcherPass)
{
auto f = get_function();
NodeVector order;
ngraph::pass::Manager m;
auto pass = m.register_pass<pass::GraphRewrite>();
pass->add_matcher<GatherNodesPass>(order);
m.run_passes(f);
ASSERT_EQ(order, f->get_ordered_ops());
}
TEST(BackwardGraphRewriteOrderTest, MatcherPass)
{
auto f = get_function();
NodeVector order;
ngraph::pass::Manager m;
auto pass = m.register_pass<pass::BackwardGraphRewrite>();
pass->add_matcher<GatherNodesPass>(order);
m.run_passes(f);
auto ref_order = f->get_ordered_ops();
std::reverse(ref_order.begin(), ref_order.end());
ASSERT_EQ(order, ref_order);
}
TEST(GraphRewriteTest, MatcherPassCallback)
{
auto f = get_function();

View File

@ -124,9 +124,25 @@ TEST(op, variant)
EXPECT_EQ(ship.y, 4);
auto node = make_shared<op::Parameter>(element::f32, Shape{1});
// Check Node RTInfo
node->get_rt_info()["A"] = var_ship;
auto node_var_ship = node->get_rt_info().at("A");
ASSERT_TRUE((is_type<VariantWrapper<Ship>>(node_var_ship)));
Ship& node_ship = as_type_ptr<VariantWrapper<Ship>>(node_var_ship)->get();
EXPECT_EQ(&node_ship, &ship);
// Check Node Input<Node> RTInfo
auto relu = make_shared<op::Relu>(node);
relu->input(0).get_rt_info()["A"] = var_ship;
auto node_input_var_ship = node->get_rt_info().at("A");
ASSERT_TRUE((is_type<VariantWrapper<Ship>>(node_input_var_ship)));
Ship& node_input_ship = as_type_ptr<VariantWrapper<Ship>>(node_input_var_ship)->get();
EXPECT_EQ(&node_input_ship, &ship);
// Check Node Input<Node> RTInfo
node->output(0).get_rt_info()["A"] = var_ship;
auto node_output_var_ship = node->get_rt_info().at("A");
ASSERT_TRUE((is_type<VariantWrapper<Ship>>(node_output_var_ship)));
Ship& node_output_ship = as_type_ptr<VariantWrapper<Ship>>(node_input_var_ship)->get();
EXPECT_EQ(&node_output_ship, &ship);
}