// * RTTI base for ngraph::Node; cherry-pick from another branch, draft * Added comments, moved code, switched to custom RTTI-based version of is_type * Move RTTI definitions in ngraph op classes to the beginning of each class definition as a preparation for the next replacement * Migrate part of operations to new RTTI * Migrate GroupConvolution and Concat to new RTTI * Apply code style for ngraph part * Rename RTTI_DECLARATION/DEFINITION to NGRAPH_RTTI_DECLARATION/DEFINITION * Reverted accidentally updated version of mkldnn * TMP: rewrite RTTI back to constexpr expressions as an attempt to fix static objects initialization order issue * Apply ngraph code style * Finalize move back to constexpr for RTTI * Applied code-style * Fix in fast algorithm in GraphRewrite, add new tests for this and other cases * Make parent optional parameter for NGRAPH_RTTI_DECLARATION and remove Node::type_info; remove ability to have Node as a parent for type_info * Try to resolve compilation error on Windows * The next attempt to fix Windows build: re-introduce get_type_info_static * Removed file that was removed in master and kept in this branch by mistake * Next attempt to fix Windows build: externConstexpr * Attempt to fix Windows build: extra public (suspect icc bug), remove get_type_info_static as useless. * Next attempt to fix Windows: proxy const and constexpr * Fixed constexpr * Next attempts: move get_type_info to cpp file * Code style fix * Re-implemented RTTI without use of constexpr; run-time initialization is used; removed global definitions to avoid issues with order of static objects initialization * Remove already unnecessary compiler flag for Windows * get_type_info_static initializes static local constant with type_info that is used for CLASS::type_info and CLASS::get_type_info * Rewrite comments for NGRAPH_RTTI_... macros, remove unused header
247 lines
7.1 KiB
C++
247 lines
7.1 KiB
C++
// Copyright (C) 2018-2020 Intel Corporation
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
//
|
|
|
|
#include <gtest/gtest.h>
|
|
#include <ngraph/opsets/opset3.hpp>
|
|
#include <ngraph/pass/graph_rewrite.hpp>
|
|
#include <ngraph/pass/manager.hpp>
|
|
#include <util/test_tools.hpp>
|
|
|
|
using namespace ::testing;
|
|
using namespace std;
|
|
using namespace ngraph;
|
|
|
|
class TestPass : public ngraph::pass::MatcherPass
|
|
{
|
|
public:
|
|
TestPass()
|
|
: MatcherPass()
|
|
{
|
|
auto divide = std::make_shared<ngraph::pattern::op::Label>(
|
|
element::f32, Shape{}, pattern::has_class<opset3::Divide>());
|
|
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
|
|
if (m_transformation_callback(m.get_match_root()))
|
|
{
|
|
auto relu =
|
|
std::make_shared<ngraph::opset3::Relu>(m.get_match_root()->input_value(0));
|
|
ngraph::replace_node(m.get_match_root(), relu);
|
|
return true;
|
|
}
|
|
return false;
|
|
};
|
|
|
|
auto m = std::make_shared<ngraph::pattern::Matcher>(divide, "TestMatcher");
|
|
this->register_matcher(m, callback);
|
|
}
|
|
};
|
|
|
|
class Anchor : public ngraph::pass::GraphRewrite
|
|
{
|
|
public:
|
|
Anchor()
|
|
: GraphRewrite()
|
|
{
|
|
}
|
|
};
|
|
|
|
std::shared_ptr<Function> get_function()
|
|
{
|
|
auto data =
|
|
std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
|
|
auto divide_constant =
|
|
ngraph::opset3::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1.5});
|
|
auto divide = std::make_shared<ngraph::opset3::Divide>(data, divide_constant);
|
|
return std::make_shared<ngraph::Function>(ngraph::NodeVector{divide},
|
|
ngraph::ParameterVector{data});
|
|
}
|
|
|
|
// Returns a transformation callback that approves a node exactly when it can be
// dynamic-cast to opset3::Divide. Because the check is a dynamic cast, classes
// derived from Divide (e.g. PrivateDivide) are approved too.
ngraph::pass::param_callback get_callback()
{
    return [](const std::shared_ptr<const Node>& node) -> bool {
        return std::dynamic_pointer_cast<const opset3::Divide>(node) != nullptr;
    };
}
|
|
|
|
// The callback is attached directly to the TestPass returned by add_matcher; it
// approves the Divide, so exactly one Relu is expected afterwards.
TEST(GraphRewriteTest, MatcherPassCallback)
{
    auto func = get_function();

    Anchor rewrite;
    auto matcher_pass = rewrite.add_matcher<TestPass>();
    matcher_pass->set_callback(get_callback());
    rewrite.run_on_function(func);

    ASSERT_EQ(count_ops_of_type<opset3::Relu>(func), 1);
}
|
|
|
|
// The callback is set on the enclosing GraphRewrite after the matcher was added;
// the test expects it to still govern the TestPass, yielding one Relu.
TEST(GraphRewriteTest, GraphRewriteCallback)
{
    auto func = get_function();

    Anchor rewrite;
    rewrite.add_matcher<TestPass>();
    rewrite.set_callback(get_callback());
    rewrite.run_on_function(func);

    const auto relu_count = count_ops_of_type<opset3::Relu>(func);
    ASSERT_EQ(relu_count, 1);
}
|
|
|
|
// The GraphRewrite is owned by a pass::Manager and the callback is set on the
// Manager; the test expects it to reach the nested TestPass, yielding one Relu.
TEST(GraphRewriteTest, ManagerCallback)
{
    auto func = get_function();

    pass::Manager manager;
    manager.register_pass<Anchor>()->add_matcher<TestPass>();
    manager.set_callback(get_callback());
    manager.run_passes(func);

    ASSERT_EQ(count_ops_of_type<opset3::Relu>(func), 1);
}
|
|
|
|
// Like ManagerCallback, but the MatcherPass is registered on the Manager
// directly, without an enclosing GraphRewrite. The test expects the
// Manager-level callback to be applied to it, yielding one Relu.
TEST(GraphRewriteTest, ManagerCallback2)
{
    auto f = get_function();

    pass::Manager manager;
    // The returned pointer is not needed; the Manager keeps the pass alive.
    manager.register_pass<TestPass>();
    manager.set_callback(get_callback());
    manager.run_passes(f);

    ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
}
|
|
|
|
class PrivateDivide : public ngraph::opset3::Divide
|
|
{
|
|
public:
|
|
NGRAPH_RTTI_DECLARATION;
|
|
using ngraph::opset3::Divide::Divide;
|
|
};
|
|
|
|
NGRAPH_RTTI_DEFINITION(PrivateDivide, "PrivateDivide", 0, ngraph::opset3::Divide);
|
|
|
|
std::shared_ptr<Function> get_derived_function()
|
|
{
|
|
auto data =
|
|
std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
|
|
auto divide_constant =
|
|
ngraph::opset3::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1.5});
|
|
auto divide = std::make_shared<PrivateDivide>(data, divide_constant);
|
|
return std::make_shared<ngraph::Function>(ngraph::NodeVector{divide},
|
|
ngraph::ParameterVector{data});
|
|
}
|
|
|
|
// Predicate-based TestPass applied to a graph whose root op is PrivateDivide;
// the test expects the derived op to be matched and replaced by one Relu.
TEST(GraphRewriteTest, MatcherPassCallbackDerived)
{
    auto func = get_derived_function();

    Anchor rewrite;
    auto matcher_pass = rewrite.add_matcher<TestPass>();
    matcher_pass->set_callback(get_callback());
    rewrite.run_on_function(func);

    ASSERT_EQ(count_ops_of_type<opset3::Relu>(func), 1);
}
|
|
|
|
class TypeBasedTestPass : public ngraph::pass::MatcherPass
|
|
{
|
|
public:
|
|
TypeBasedTestPass()
|
|
: MatcherPass()
|
|
{
|
|
auto divide = std::make_shared<ngraph::opset3::Divide>(
|
|
std::make_shared<ngraph::pattern::op::Label>(),
|
|
std::make_shared<ngraph::pattern::op::Label>());
|
|
// element::f32, Shape{}, pattern::has_class<opset3::Divide>());
|
|
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
|
|
if (m_transformation_callback(m.get_match_root()))
|
|
{
|
|
auto relu =
|
|
std::make_shared<ngraph::opset3::Relu>(m.get_match_root()->input_value(0));
|
|
ngraph::replace_node(m.get_match_root(), relu);
|
|
return true;
|
|
}
|
|
return false;
|
|
};
|
|
|
|
auto m = std::make_shared<ngraph::pattern::Matcher>(divide, "TestMatcher");
|
|
this->register_matcher(m, callback);
|
|
}
|
|
};
|
|
|
|
class TypeBasedTestPassDerived : public ngraph::pass::MatcherPass
|
|
{
|
|
public:
|
|
TypeBasedTestPassDerived()
|
|
: MatcherPass()
|
|
{
|
|
auto divide =
|
|
std::make_shared<PrivateDivide>(std::make_shared<ngraph::pattern::op::Label>(),
|
|
std::make_shared<ngraph::pattern::op::Label>());
|
|
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
|
|
if (m_transformation_callback(m.get_match_root()))
|
|
{
|
|
auto tanh =
|
|
std::make_shared<ngraph::opset3::Tanh>(m.get_match_root()->input_value(0));
|
|
ngraph::replace_node(m.get_match_root(), tanh);
|
|
return true;
|
|
}
|
|
return false;
|
|
};
|
|
|
|
auto m = std::make_shared<ngraph::pattern::Matcher>(divide, "TestMatcher");
|
|
this->register_matcher(m, callback);
|
|
}
|
|
};
|
|
|
|
// Type-based pattern (concrete Divide root) on a plain Divide graph; the
// approving callback should turn the Divide into exactly one Relu.
TEST(GraphRewriteTest, TypeBasedMatcherPassCallback)
{
    auto func = get_function();

    Anchor rewrite;
    auto matcher_pass = rewrite.add_matcher<TypeBasedTestPass>();
    matcher_pass->set_callback(get_callback());
    rewrite.run_on_function(func);

    ASSERT_EQ(count_ops_of_type<opset3::Relu>(func), 1);
}
|
|
|
|
// Type-based Divide pattern applied to a PrivateDivide graph; the test expects
// the derived op to match the Divide pattern and become one Relu.
TEST(GraphRewriteTest, TypeBasedMatcherPassCallbackDerived)
{
    auto func = get_derived_function();

    Anchor rewrite;
    auto matcher_pass = rewrite.add_matcher<TypeBasedTestPass>();
    matcher_pass->set_callback(get_callback());
    rewrite.run_on_function(func);

    ASSERT_EQ(count_ops_of_type<opset3::Relu>(func), 1);
}
|
|
|
|
// Both type-based passes can match a PrivateDivide. With TypeBasedTestPass
// registered first, the test expects its Relu replacement to be the one applied.
TEST(GraphRewriteTest, TypeBasedMatcherPassOrder1)
{
    auto func = get_derived_function();

    Anchor rewrite;
    auto base_pass = rewrite.add_matcher<TypeBasedTestPass>();
    base_pass->set_callback(get_callback());
    auto derived_pass = rewrite.add_matcher<TypeBasedTestPassDerived>();
    derived_pass->set_callback(get_callback());
    rewrite.run_on_function(func);

    ASSERT_EQ(count_ops_of_type<opset3::Relu>(func), 1);
}
|
|
|
|
// Reverse registration order of TypeBasedMatcherPassOrder1: with
// TypeBasedTestPassDerived first, the test expects its Tanh replacement to win.
TEST(GraphRewriteTest, TypeBasedMatcherPassOrder2)
{
    auto func = get_derived_function();

    Anchor rewrite;
    auto derived_pass = rewrite.add_matcher<TypeBasedTestPassDerived>();
    derived_pass->set_callback(get_callback());
    auto base_pass = rewrite.add_matcher<TypeBasedTestPass>();
    base_pass->set_callback(get_callback());
    rewrite.run_on_function(func);

    ASSERT_EQ(count_ops_of_type<opset3::Tanh>(func), 1);
}