Files
openvino/ngraph/test/graph_rewrite.cpp
Sergey Lyalin a069e39906 Hierarchical extension to nGraph RTTI (#1245)
* RTTI base for ngraph::Node; cherry-pick from another branch, draft

* Added comments, moved code, switched to custom RTTI-based version of is_type

* Move rtti definitions in ngraph op class to the beginning of each class definition as a preparation for the next replacement

* Migrate part of operations to new RTTI

* Migrate GroupConvolution and Concat to new RTTI

* Apply code style for ngraph part

* Rename RTTI_DECLARATION/DEFINITION to NGRAPH_RTTI_DECLARATION/DEFINITION

* Reverted accidentally updated version of mkldnn

* TMP: rewrite RTTI back to constexpr expressions as an attempt to fix static objects initialization order issue

* Apply ngraph code style

* Finalize move back to constexpr for RTTI

* Applied code-style

* Fix in fast algorithm in GraphRewrite, add new tests for this and other cases

* Make parent optional parameter for NGRAPH_RTTI_DECLARATION and remove Node::type_info; remove ability to have Node as a parent for type_info

* Try to resolve compilation error on Windows

* The next attempt to fix Windows build: re-introduce get_type_info_static

* Removed file that was removed in master and kept in this branch by mistake

* Next attempt to fix Windows build: externConstexpr

* Attempt to fix win build: extra public (suspect icc bug), remove get_type_info_static as useless.

* Next attempt to fix Windows: proxy const and constexpr

* Fixed constexpr

* Next attempt: move get_type_info to cpp file

* Code style fix

* Re-implemented RTTI without use of constexpr; run-time initialization is used; removed global definitions to avoid issues with order of static objects initialization

* Remove the now-unnecessary compiler flag for Windows

* get_type_info_static initializes static local constant with type_info that is used for CLASS::type_info and CLASS::get_type_info

* Rewrite comments for NGRAPH_RTTI_... macros, remove unused header
2020-08-04 06:35:58 +03:00

247 lines
7.1 KiB
C++

// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
#include <ngraph/pass/manager.hpp>
#include <util/test_tools.hpp>
using namespace ::testing;
using namespace std;
using namespace ngraph;
// MatcherPass used by the tests in this file. Its pattern is a Label restricted by
// pattern::has_class<opset3::Divide>(); when the transformation callback accepts the
// matched root, the root is replaced with a Relu fed by the root's first input.
class TestPass : public ngraph::pass::MatcherPass
{
public:
    TestPass()
        : MatcherPass()
    {
        auto divide_label = std::make_shared<ngraph::pattern::op::Label>(
            element::f32, Shape{}, pattern::has_class<opset3::Divide>());

        ngraph::graph_rewrite_callback on_match = [this](pattern::Matcher& m) {
            auto root = m.get_match_root();
            // Leave the graph untouched unless the transformation callback accepts the node.
            if (!m_transformation_callback(root))
            {
                return false;
            }
            auto relu = std::make_shared<ngraph::opset3::Relu>(root->input_value(0));
            ngraph::replace_node(root, relu);
            return true;
        };

        auto matcher = std::make_shared<ngraph::pattern::Matcher>(divide_label, "TestMatcher");
        this->register_matcher(matcher, on_match);
    }
};
// Empty GraphRewrite container: the tests attach matcher passes to it via add_matcher.
class Anchor : public ngraph::pass::GraphRewrite
{
public:
    Anchor() : GraphRewrite() {}
};
// Builds a function with a single f32 Parameter of shape {3, 1, 2} whose only result is
// opset3::Divide(parameter, constant 1.5 of shape {1}).
std::shared_ptr<Function> get_function()
{
    auto input =
        std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
    auto divisor =
        ngraph::opset3::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1.5});
    auto quotient = std::make_shared<ngraph::opset3::Divide>(input, divisor);

    return std::make_shared<ngraph::Function>(ngraph::NodeVector{quotient},
                                              ngraph::ParameterVector{input});
}
// Returns a param_callback that accepts a node only when it can be dynamically cast to
// opset3::Divide (which also holds for classes derived from it).
ngraph::pass::param_callback get_callback()
{
    return [](const std::shared_ptr<const Node>& node) -> bool {
        return std::dynamic_pointer_cast<const opset3::Divide>(node) != nullptr;
    };
}
// The param callback is set directly on the matcher pass returned by add_matcher;
// the Divide in the function is expected to be replaced by exactly one Relu.
TEST(GraphRewriteTest, MatcherPassCallback)
{
    auto f = get_function();

    Anchor rewrite;
    auto test_pass = rewrite.add_matcher<TestPass>();
    test_pass->set_callback(get_callback());
    rewrite.run_on_function(f);

    ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
}
// The param callback is set on the GraphRewrite container (after the matcher was added)
// instead of on the matcher pass itself; one Relu is still expected.
TEST(GraphRewriteTest, GraphRewriteCallback)
{
    auto f = get_function();

    Anchor rewrite;
    rewrite.add_matcher<TestPass>();
    rewrite.set_callback(get_callback());
    rewrite.run_on_function(f);

    ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
}
// The param callback is set on the pass::Manager that owns an Anchor holding TestPass;
// one Relu is expected after the manager runs its passes.
TEST(GraphRewriteTest, ManagerCallback)
{
    auto f = get_function();

    pass::Manager manager;
    manager.register_pass<Anchor>()->add_matcher<TestPass>();
    manager.set_callback(get_callback());
    manager.run_passes(f);

    ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
}
// TestPass is registered in the pass::Manager directly (no Anchor wrapper) and the
// param callback is set on the manager; one Relu is expected.
TEST(GraphRewriteTest, ManagerCallback2)
{
    auto f = get_function();

    pass::Manager manager;
    auto test_pass = manager.register_pass<TestPass>();
    manager.set_callback(get_callback());
    manager.run_passes(f);

    ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
}
// Test-only operation derived from opset3::Divide that declares its own RTTI entry.
// The tests below build graphs and patterns with it to exercise matching of a type
// derived from Divide.
class PrivateDivide : public ngraph::opset3::Divide
{
public:
NGRAPH_RTTI_DECLARATION;
// Reuse the constructors of opset3::Divide.
using ngraph::opset3::Divide::Divide;
};
// NOTE(review): the arguments look like (class, type name, version, parent type) --
// confirm against the NGRAPH_RTTI_DEFINITION macro.
NGRAPH_RTTI_DEFINITION(PrivateDivide, "PrivateDivide", 0, ngraph::opset3::Divide);
// Builds a function with a single f32 Parameter of shape {3, 1, 2} whose only result is
// PrivateDivide(parameter, constant 1.5 of shape {1}).
std::shared_ptr<Function> get_derived_function()
{
    auto input =
        std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
    auto divisor =
        ngraph::opset3::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1.5});
    auto quotient = std::make_shared<PrivateDivide>(input, divisor);

    return std::make_shared<ngraph::Function>(ngraph::NodeVector{quotient},
                                              ngraph::ParameterVector{input});
}
// Runs TestPass (has_class<opset3::Divide> label) on a graph that contains a
// PrivateDivide node; the test expects the node to be replaced by one Relu.
TEST(GraphRewriteTest, MatcherPassCallbackDerived)
{
    auto f = get_derived_function();

    Anchor rewrite;
    auto test_pass = rewrite.add_matcher<TestPass>();
    test_pass->set_callback(get_callback());
    rewrite.run_on_function(f);

    ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
}
// MatcherPass whose pattern root is an actual opset3::Divide node with two
// default-constructed Labels as inputs (instead of a Label carrying a has_class
// predicate). When the transformation callback accepts the matched root, the root is
// replaced with a Relu fed by the root's first input.
class TypeBasedTestPass : public ngraph::pass::MatcherPass
{
public:
    TypeBasedTestPass()
        : MatcherPass()
    {
        auto lhs = std::make_shared<ngraph::pattern::op::Label>();
        auto rhs = std::make_shared<ngraph::pattern::op::Label>();
        auto divide_pattern = std::make_shared<ngraph::opset3::Divide>(lhs, rhs);

        ngraph::graph_rewrite_callback on_match = [this](pattern::Matcher& m) {
            auto root = m.get_match_root();
            // Leave the graph untouched unless the transformation callback accepts the node.
            if (!m_transformation_callback(root))
            {
                return false;
            }
            auto relu = std::make_shared<ngraph::opset3::Relu>(root->input_value(0));
            ngraph::replace_node(root, relu);
            return true;
        };

        auto matcher = std::make_shared<ngraph::pattern::Matcher>(divide_pattern, "TestMatcher");
        this->register_matcher(matcher, on_match);
    }
};
// MatcherPass whose pattern root is a PrivateDivide node with two default-constructed
// Labels as inputs. When the transformation callback accepts the matched root, the root
// is replaced with a Tanh fed by the root's first input.
class TypeBasedTestPassDerived : public ngraph::pass::MatcherPass
{
public:
    TypeBasedTestPassDerived()
        : MatcherPass()
    {
        auto lhs = std::make_shared<ngraph::pattern::op::Label>();
        auto rhs = std::make_shared<ngraph::pattern::op::Label>();
        auto divide_pattern = std::make_shared<PrivateDivide>(lhs, rhs);

        ngraph::graph_rewrite_callback on_match = [this](pattern::Matcher& m) {
            auto root = m.get_match_root();
            // Leave the graph untouched unless the transformation callback accepts the node.
            if (!m_transformation_callback(root))
            {
                return false;
            }
            auto tanh = std::make_shared<ngraph::opset3::Tanh>(root->input_value(0));
            ngraph::replace_node(root, tanh);
            return true;
        };

        auto matcher = std::make_shared<ngraph::pattern::Matcher>(divide_pattern, "TestMatcher");
        this->register_matcher(matcher, on_match);
    }
};
// Runs TypeBasedTestPass (pattern rooted at opset3::Divide) on a graph containing an
// opset3::Divide; one Relu is expected afterwards.
TEST(GraphRewriteTest, TypeBasedMatcherPassCallback)
{
    auto f = get_function();

    Anchor rewrite;
    auto typed_pass = rewrite.add_matcher<TypeBasedTestPass>();
    typed_pass->set_callback(get_callback());
    rewrite.run_on_function(f);

    ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
}
// Runs TypeBasedTestPass (pattern rooted at opset3::Divide) on a graph containing a
// PrivateDivide; the test expects the derived node to be replaced by one Relu.
TEST(GraphRewriteTest, TypeBasedMatcherPassCallbackDerived)
{
    auto f = get_derived_function();

    Anchor rewrite;
    auto typed_pass = rewrite.add_matcher<TypeBasedTestPass>();
    typed_pass->set_callback(get_callback());
    rewrite.run_on_function(f);

    ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
}
// Registers the Divide-rooted matcher first and the PrivateDivide-rooted matcher second,
// then runs them on a PrivateDivide graph. The test expects one Relu, i.e. the
// replacement produced by TypeBasedTestPass.
TEST(GraphRewriteTest, TypeBasedMatcherPassOrder1)
{
    auto f = get_derived_function();

    Anchor rewrite;
    auto base_pass = rewrite.add_matcher<TypeBasedTestPass>();
    base_pass->set_callback(get_callback());
    auto derived_pass = rewrite.add_matcher<TypeBasedTestPassDerived>();
    derived_pass->set_callback(get_callback());
    rewrite.run_on_function(f);

    ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
}
// Registers the PrivateDivide-rooted matcher first and the Divide-rooted matcher second,
// then runs them on a PrivateDivide graph. The test expects one Tanh, i.e. the
// replacement produced by TypeBasedTestPassDerived.
TEST(GraphRewriteTest, TypeBasedMatcherPassOrder2)
{
    auto f = get_derived_function();

    Anchor rewrite;
    auto derived_pass = rewrite.add_matcher<TypeBasedTestPassDerived>();
    derived_pass->set_callback(get_callback());
    auto base_pass = rewrite.add_matcher<TypeBasedTestPass>();
    base_pass->set_callback(get_callback());
    rewrite.run_on_function(f);

    ASSERT_EQ(count_ops_of_type<opset3::Tanh>(f), 1);
}