Files
openvino/ngraph/test/graph_rewrite.cpp
Gleb Kazantaev 022ea97f18 Fixed disable/enable logic in PassConfig (#2940)
* Fixed disable/enable logic in PassConfig

* Removed set_pass_config method for Manager; added comments
2020-11-05 17:34:32 +03:00

387 lines
11 KiB
C++

// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
#include <ngraph/pass/manager.hpp>
#include <util/test_tools.hpp>
using namespace ::testing;
using namespace std;
using namespace ngraph;
class TestPass : public ngraph::pass::MatcherPass
{
public:
NGRAPH_RTTI_DECLARATION;
TestPass()
: MatcherPass()
{
auto divide = std::make_shared<ngraph::pattern::op::Label>(
element::f32, Shape{}, pattern::has_class<opset3::Divide>());
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
if (m_transformation_callback(m.get_match_root()))
{
auto relu =
std::make_shared<ngraph::opset3::Relu>(m.get_match_root()->input_value(0));
ngraph::replace_node(m.get_match_root(), relu);
return true;
}
return false;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(divide, "TestMatcher");
this->register_matcher(m, callback);
}
};
class Anchor : public ngraph::pass::GraphRewrite
{
public:
NGRAPH_RTTI_DECLARATION;
Anchor()
: GraphRewrite()
{
}
};
NGRAPH_RTTI_DEFINITION(TestPass, "TestPass", 0);
NGRAPH_RTTI_DEFINITION(Anchor, "Anchor", 0);
std::shared_ptr<Function> get_function()
{
auto data =
std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
auto divide_constant =
ngraph::opset3::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1.5});
auto divide = std::make_shared<ngraph::opset3::Divide>(data, divide_constant);
return std::make_shared<ngraph::Function>(ngraph::NodeVector{divide},
ngraph::ParameterVector{data});
}
ngraph::pass::param_callback get_callback()
{
return [](const std::shared_ptr<const Node>& node) -> bool {
if (std::dynamic_pointer_cast<const opset3::Divide>(node))
{
return true;
}
else
{
return false;
}
};
}
TEST(GraphRewriteTest, MatcherPassCallback)
{
auto f = get_function();
Anchor anchor;
anchor.add_matcher<TestPass>()->set_callback(get_callback());
anchor.run_on_function(f);
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
}
TEST(GraphRewriteTest, GraphRewriteCallback)
{
auto f = get_function();
Anchor anchor;
anchor.add_matcher<TestPass>();
anchor.set_callback(get_callback());
anchor.run_on_function(f);
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
}
TEST(GraphRewriteTest, ManagerCallbackDeprecated)
{
auto f = get_function();
pass::Manager manager;
auto anchor = manager.register_pass<Anchor>();
anchor->add_matcher<TestPass>();
manager.set_callback(get_callback());
manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
}
TEST(GraphRewriteTest, ManagerCallback)
{
auto f = get_function();
pass::Manager manager;
auto anchor = manager.register_pass<Anchor>();
anchor->add_matcher<TestPass>();
auto pass_config = manager.get_pass_config();
pass_config->set_callback(get_callback());
manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
}
TEST(GraphRewriteTest, ManagerCallback2)
{
auto f = get_function();
pass::Manager manager;
auto anchor = manager.register_pass<TestPass>();
manager.set_callback(get_callback());
manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
}
class PrivateDivide : public ngraph::opset3::Divide
{
public:
NGRAPH_RTTI_DECLARATION;
using ngraph::opset3::Divide::Divide;
};
NGRAPH_RTTI_DEFINITION(PrivateDivide, "PrivateDivide", 0, ngraph::opset3::Divide);
std::shared_ptr<Function> get_derived_function()
{
auto data =
std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
auto divide_constant =
ngraph::opset3::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1.5});
auto divide = std::make_shared<PrivateDivide>(data, divide_constant);
return std::make_shared<ngraph::Function>(ngraph::NodeVector{divide},
ngraph::ParameterVector{data});
}
TEST(GraphRewriteTest, MatcherPassCallbackDerived)
{
auto f = get_derived_function();
Anchor anchor;
anchor.add_matcher<TestPass>()->set_callback(get_callback());
anchor.run_on_function(f);
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
}
class TypeBasedTestPass : public ngraph::pass::MatcherPass
{
public:
TypeBasedTestPass()
: MatcherPass()
{
auto divide = std::make_shared<ngraph::opset3::Divide>(
std::make_shared<ngraph::pattern::op::Label>(),
std::make_shared<ngraph::pattern::op::Label>());
// element::f32, Shape{}, pattern::has_class<opset3::Divide>());
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
if (m_transformation_callback(m.get_match_root()))
{
auto relu =
std::make_shared<ngraph::opset3::Relu>(m.get_match_root()->input_value(0));
ngraph::replace_node(m.get_match_root(), relu);
return true;
}
return false;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(divide, "TestMatcher");
this->register_matcher(m, callback);
}
};
class TypeBasedTestPassDerived : public ngraph::pass::MatcherPass
{
public:
TypeBasedTestPassDerived()
: MatcherPass()
{
auto divide =
std::make_shared<PrivateDivide>(std::make_shared<ngraph::pattern::op::Label>(),
std::make_shared<ngraph::pattern::op::Label>());
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
if (m_transformation_callback(m.get_match_root()))
{
auto tanh =
std::make_shared<ngraph::opset3::Tanh>(m.get_match_root()->input_value(0));
ngraph::replace_node(m.get_match_root(), tanh);
return true;
}
return false;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(divide, "TestMatcher");
this->register_matcher(m, callback);
}
};
TEST(GraphRewriteTest, TypeBasedMatcherPassCallback)
{
auto f = get_function();
Anchor anchor;
anchor.add_matcher<TypeBasedTestPass>()->set_callback(get_callback());
anchor.run_on_function(f);
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
}
TEST(GraphRewriteTest, TypeBasedMatcherPassCallbackDerived)
{
auto f = get_derived_function();
Anchor anchor;
anchor.add_matcher<TypeBasedTestPass>()->set_callback(get_callback());
anchor.run_on_function(f);
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
}
TEST(GraphRewriteTest, TypeBasedMatcherPassOrder1)
{
auto f = get_derived_function();
Anchor anchor;
anchor.add_matcher<TypeBasedTestPass>()->set_callback(get_callback());
anchor.add_matcher<TypeBasedTestPassDerived>()->set_callback(get_callback());
anchor.run_on_function(f);
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
}
TEST(GraphRewriteTest, TypeBasedMatcherPassOrder2)
{
auto f = get_derived_function();
Anchor anchor;
anchor.add_matcher<TypeBasedTestPassDerived>()->set_callback(get_callback());
anchor.add_matcher<TypeBasedTestPass>()->set_callback(get_callback());
anchor.run_on_function(f);
ASSERT_EQ(count_ops_of_type<opset3::Tanh>(f), 1);
}
TEST(PassConfigTest, Test1)
{
{
auto f = get_function();
pass::Manager manager;
manager.register_pass<TestPass>();
auto pass_config = manager.get_pass_config();
pass_config->set_callback(get_callback());
manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
}
{
auto f = get_function();
pass::Manager manager;
manager.register_pass<TestPass>();
auto pass_config = manager.get_pass_config();
pass_config->set_callback<TestPass>(get_callback());
manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
}
{
auto f = get_function();
auto pass_config = std::make_shared<ngraph::pass::PassConfig>();
pass::Manager manager(pass_config);
manager.register_pass<TestPass>();
pass_config->set_callback<TestPass>(get_callback());
manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
}
{
auto f = get_function();
pass::Manager manager;
auto anchor = manager.register_pass<Anchor>();
anchor->add_matcher<TestPass>();
auto pass_config = anchor->get_pass_config();
pass_config->set_callback(get_callback());
manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
}
{
auto f = get_function();
pass::Manager manager;
auto anchor = manager.register_pass<Anchor>();
anchor->add_matcher<TestPass>();
auto pass_config = anchor->get_pass_config();
pass_config->set_callback<TestPass>(get_callback());
manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
}
{
auto pass_config = std::make_shared<pass::PassConfig>();
pass::Manager manager1(pass_config);
pass::Manager manager2(pass_config);
ASSERT_EQ(pass_config.use_count(), 3);
}
{
auto f = get_function();
pass::Manager manager;
manager.register_pass<TestPass>();
auto pass_config = manager.get_pass_config();
pass_config->set_callback<TestPass>(get_callback());
pass_config->disable<TestPass>();
manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 0);
pass_config->enable<TestPass>();
manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
}
{
auto f = get_function();
pass::Manager manager;
auto anchor = manager.register_pass<Anchor>();
anchor->add_matcher<TestPass>();
auto pass_config = manager.get_pass_config();
pass_config->set_callback<TestPass>(get_callback());
pass_config->disable<TestPass>();
manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 0);
pass_config->enable<TestPass>();
manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
}
}