Fixed disable/enable logic in PassConfig (#2940)

* Fixed disable/enable logic in PassConfig

* Removed set_pass_config method for Manager; added comments
This commit is contained in:
Gleb Kazantaev 2020-11-05 17:34:32 +03:00 committed by GitHub
parent 24ed4133dd
commit 022ea97f18
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 477 additions and 28 deletions

View File

@ -61,7 +61,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertOpSet1ToLegacy, "ConvertOpSet1ToLega
bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph::Function> f) { bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph::Function> f) {
OV_ITT_SCOPED_TASK(InferenceEngine::itt::domains::IELegacy, "ngraph::pass::ConvertOpSet1ToLegacy"); OV_ITT_SCOPED_TASK(InferenceEngine::itt::domains::IELegacy, "ngraph::pass::ConvertOpSet1ToLegacy");
ngraph::pass::Manager manager; ngraph::pass::Manager manager(get_pass_config());
manager.register_pass<ngraph::pass::ConstantFolding>(); manager.register_pass<ngraph::pass::ConstantFolding>();
@ -148,7 +148,6 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph
manager.register_pass<ngraph::pass::ConstantFolding>(); manager.register_pass<ngraph::pass::ConstantFolding>();
manager.set_pass_config(get_pass_config());
manager.run_passes(f); manager.run_passes(f);
return true; return true;
} }

View File

@ -53,7 +53,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::CommonOptimizations, "CommonOptimizations",
bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::Function> f) { bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::Function> f) {
OV_ITT_SCOPED_TASK(itt::domains::IETransform, "ngraph::pass::CommonOptimizations"); OV_ITT_SCOPED_TASK(itt::domains::IETransform, "ngraph::pass::CommonOptimizations");
ngraph::pass::Manager manager; ngraph::pass::Manager manager(get_pass_config());
// This pass must be called first in pipeline // This pass must be called first in pipeline
manager.register_pass<ngraph::pass::InitNodeInfo>(); manager.register_pass<ngraph::pass::InitNodeInfo>();
@ -118,8 +118,6 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
fq_fusions->add_matcher<ngraph::pass::PullTransposeThroughFQUp>(); fq_fusions->add_matcher<ngraph::pass::PullTransposeThroughFQUp>();
fq_fusions->set_name("ngraph::pass::FakeQuantizeFusions"); fq_fusions->set_name("ngraph::pass::FakeQuantizeFusions");
// Propagate local PassConfig to internal pass::Manager
manager.set_pass_config(get_pass_config());
manager.run_passes(f); manager.run_passes(f);
return true; return true;
} }

View File

@ -18,12 +18,11 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertOpSet2ToOpSet1, "ConvertOpSet2ToOpSe
bool ngraph::pass::ConvertOpSet2ToOpSet1::run_on_function(std::shared_ptr<ngraph::Function> f) { bool ngraph::pass::ConvertOpSet2ToOpSet1::run_on_function(std::shared_ptr<ngraph::Function> f) {
OV_ITT_SCOPED_TASK(itt::domains::IETransform, "ngraph::pass::ConvertOpSet2ToOpSet1"); OV_ITT_SCOPED_TASK(itt::domains::IETransform, "ngraph::pass::ConvertOpSet2ToOpSet1");
ngraph::pass::Manager manager; ngraph::pass::Manager manager(get_pass_config());
manager.register_pass<ngraph::pass::ConvertSpaceToBatch>(); manager.register_pass<ngraph::pass::ConvertSpaceToBatch>();
manager.register_pass<ngraph::pass::ConvertBatchToSpace>(); manager.register_pass<ngraph::pass::ConvertBatchToSpace>();
manager.set_pass_config(get_pass_config());
manager.run_passes(f); manager.run_passes(f);
return true; return true;
} }

View File

@ -22,7 +22,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertOpSet3ToOpSet2, "ConvertOpSet3ToOpSe
bool ngraph::pass::ConvertOpSet3ToOpSet2::run_on_function(std::shared_ptr<ngraph::Function> f) { bool ngraph::pass::ConvertOpSet3ToOpSet2::run_on_function(std::shared_ptr<ngraph::Function> f) {
OV_ITT_SCOPED_TASK(itt::domains::IETransform, "ngraph::pass::ConvertOpSet3ToOpSet2"); OV_ITT_SCOPED_TASK(itt::domains::IETransform, "ngraph::pass::ConvertOpSet3ToOpSet2");
ngraph::pass::Manager manager; ngraph::pass::Manager manager(get_pass_config());
manager.register_pass<ngraph::pass::ConvertBroadcast3>(); manager.register_pass<ngraph::pass::ConvertBroadcast3>();
manager.register_pass<ngraph::pass::ConvertNMS1ToNMS3>(); manager.register_pass<ngraph::pass::ConvertNMS1ToNMS3>();
@ -31,7 +31,6 @@ bool ngraph::pass::ConvertOpSet3ToOpSet2::run_on_function(std::shared_ptr<ngraph
manager.register_pass<ngraph::pass::ConvertTopK3>(); manager.register_pass<ngraph::pass::ConvertTopK3>();
manager.register_pass<ngraph::pass::SoftPlusDecomposition>(); manager.register_pass<ngraph::pass::SoftPlusDecomposition>();
manager.set_pass_config(get_pass_config());
manager.run_passes(f); manager.run_passes(f);
return true; return true;
} }

View File

@ -150,7 +150,7 @@ public:
auto pass = std::make_shared<T>(std::forward<Args>(args)...); auto pass = std::make_shared<T>(std::forward<Args>(args)...);
auto pass_config = get_pass_config(); auto pass_config = get_pass_config();
pass->set_pass_config(pass_config); pass->set_pass_config(pass_config);
if (!Enabled) if (!Enabled && !pass_config->is_enabled<T>())
{ {
pass_config->disable<T>(); pass_config->disable<T>();
} }
@ -168,6 +168,8 @@ public:
bool run_on_function(std::shared_ptr<ngraph::Function> f) override; bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
void set_pass_config(const std::shared_ptr<PassConfig>& pass_config) override;
protected: protected:
bool m_enable_shape_inference = false; bool m_enable_shape_inference = false;

View File

@ -38,6 +38,9 @@ public:
Manager(); Manager();
~Manager(); ~Manager();
//// \brief Construct Manager with shared PassConfig instance
explicit Manager(std::shared_ptr<PassConfig> pass_config);
/// \brief Register given transformation class type to execution list /// \brief Register given transformation class type to execution list
/// Example below show the basic usage of pass::Manager /// Example below show the basic usage of pass::Manager
/// ///
@ -59,7 +62,7 @@ public:
{ {
push_pass<Validate>(); push_pass<Validate>();
} }
if (!Enable) if (!Enable && !m_pass_config->is_enabled<T>())
{ {
m_pass_config->disable<T>(); m_pass_config->disable<T>();
} }
@ -99,12 +102,6 @@ public:
/// This object allows to disable/enable transformations execution, set callback to particular /// This object allows to disable/enable transformations execution, set callback to particular
/// transformation. For mo details see PassConfig class. /// transformation. For mo details see PassConfig class.
std::shared_ptr<PassConfig> get_pass_config() { return m_pass_config; } std::shared_ptr<PassConfig> get_pass_config() { return m_pass_config; }
/// \brief Set external PassConfig object.
void set_pass_config(const std::shared_ptr<PassConfig>& pass_config)
{
*m_pass_config = *pass_config;
}
protected: protected:
template <typename T, class... Args> template <typename T, class... Args>
std::shared_ptr<T> push_pass(Args&&... args) std::shared_ptr<T> push_pass(Args&&... args)

View File

@ -62,7 +62,7 @@ namespace ngraph
/// \brief Set PassConfig for particular transformation instance /// \brief Set PassConfig for particular transformation instance
/// \param pass_config is a PassConfig shared_ptr /// \param pass_config is a PassConfig shared_ptr
void set_pass_config(const std::shared_ptr<PassConfig>& pass_config) virtual void set_pass_config(const std::shared_ptr<PassConfig>& pass_config)
{ {
m_pass_config = pass_config; m_pass_config = pass_config;
} }

View File

@ -72,7 +72,7 @@ namespace ngraph
public: public:
/// \brief Disable transformation by its type_info /// \brief Disable transformation by its type_info
/// \param type_info Transformation type_info /// \param type_info Transformation type_info
void disable(const DiscreteTypeInfo& type_info) { m_disabled.insert(type_info); } void disable(const DiscreteTypeInfo& type_info);
/// \brief Disable transformation by its class type (based on type_info) /// \brief Disable transformation by its class type (based on type_info)
template <typename T> template <typename T>
void disable() void disable()
@ -82,7 +82,7 @@ namespace ngraph
/// \brief Enable transformation by its type_info /// \brief Enable transformation by its type_info
/// \param type_info Transformation type_info /// \param type_info Transformation type_info
void enable(const DiscreteTypeInfo& type_info) { m_disabled.erase(type_info); } void enable(const DiscreteTypeInfo& type_info);
/// \brief Enable transformation by its class type (based on type_info) /// \brief Enable transformation by its class type (based on type_info)
template <typename T> template <typename T>
void enable() void enable()
@ -161,12 +161,31 @@ namespace ngraph
return is_disabled(T::type_info); return is_disabled(T::type_info);
} }
/// \brief Check either transformation type is force enabled or not
/// \param type_info Transformation type_info
/// \return true if transformation type was force enabled and false otherwise
bool is_enabled(const DiscreteTypeInfo& type_info) const
{
return m_enabled.count(type_info);
}
/// \brief Check either transformation class type is force enabled or not
/// \return true if transformation type was force enabled and false otherwise
template <typename T>
bool is_enabled() const
{
return is_enabled(T::type_info);
}
void add_disabled_passes(const PassConfig& rhs);
private: private:
param_callback m_callback = [](const std::shared_ptr<const ::ngraph::Node>&) { param_callback m_callback = [](const std::shared_ptr<const ::ngraph::Node>&) {
return false; return false;
}; };
param_callback_map m_callback_map; param_callback_map m_callback_map;
std::unordered_set<DiscreteTypeInfo> m_disabled; std::unordered_set<DiscreteTypeInfo> m_disabled;
std::unordered_set<DiscreteTypeInfo> m_enabled;
}; };
} }
} }

View File

@ -272,6 +272,35 @@ void pass::GraphRewrite::add_matcher(const shared_ptr<pattern::Matcher>& m,
NGRAPH_SUPPRESS_DEPRECATED_END NGRAPH_SUPPRESS_DEPRECATED_END
} }
void pass::GraphRewrite::set_pass_config(const std::shared_ptr<PassConfig>& rhs)
{
auto pass_config = get_pass_config();
// We have to preserve disabled passes because in case when we register matchers inside
// GraphRewrite c-tor we work with local PassConfig instance.
// For example:
//
// class ExampleGraphRewrite: public pass::GraphRewrite {
// ExampleGraphRewrite() {
// add_mather<TestMatcher1, false /* disabled by default */>();
// add_mather<TestMatcher2>();
// }
// };
//
// When we call add_matcher inside c-tor we automatically work with locally created PassConfig
// instance that is not shared. So when instance of this pass is being created in pass::Manager
// we set shared PassConfig but we will override already existing rules inside local config. To
// resolve this we have to copy disabled passes from local PassConfig to shared but we take into
// account that if passes were manually enabled we do not add them.
rhs->add_disabled_passes(*pass_config);
PassBase::set_pass_config(rhs);
// update nested transformations with new shared pass_config
for (auto& pass : m_matchers)
{
pass->set_pass_config(rhs);
}
}
void pass::RecurrentGraphRewrite::add_matcher( void pass::RecurrentGraphRewrite::add_matcher(
const std::shared_ptr<pattern::RecurrentMatcher>& m, const std::shared_ptr<pattern::RecurrentMatcher>& m,
const ngraph::recurrent_graph_rewrite_callback& callback, const ngraph::recurrent_graph_rewrite_callback& callback,

View File

@ -44,6 +44,11 @@ pass::Manager::~Manager()
{ {
} }
pass::Manager::Manager(std::shared_ptr<ngraph::pass::PassConfig> pass_config)
: m_pass_config(std::move(pass_config))
{
}
void pass::Manager::run_passes(shared_ptr<Function> func) void pass::Manager::run_passes(shared_ptr<Function> func)
{ {
OV_ITT_SCOPED_TASK(itt::domains::nGraph, "pass::Manager::run_passes"); OV_ITT_SCOPED_TASK(itt::domains::nGraph, "pass::Manager::run_passes");

View File

@ -30,3 +30,25 @@ pass::param_callback pass::PassConfig::get_callback(const DiscreteTypeInfo& type
return m_callback; return m_callback;
} }
} }
void pass::PassConfig::enable(const ngraph::DiscreteTypeInfo& type_info)
{
m_disabled.erase(type_info);
m_enabled.insert(type_info);
}
void pass::PassConfig::disable(const ngraph::DiscreteTypeInfo& type_info)
{
m_enabled.erase(type_info);
m_disabled.insert(type_info);
}
void pass::PassConfig::add_disabled_passes(const PassConfig& rhs)
{
for (const auto& pass : rhs.m_disabled)
{
if (is_enabled(pass))
continue;
disable(pass);
}
}

View File

@ -88,6 +88,7 @@ set(SRC
op_is.cpp op_is.cpp
opset1.cpp opset1.cpp
partial_shape.cpp partial_shape.cpp
pass_config.cpp
pass_liveness.cpp pass_liveness.cpp
pass_manager.cpp pass_manager.cpp
pass_shape_relevance.cpp pass_shape_relevance.cpp

View File

@ -298,13 +298,13 @@ TEST(PassConfigTest, Test1)
{ {
auto f = get_function(); auto f = get_function();
pass::Manager manager; auto pass_config = std::make_shared<ngraph::pass::PassConfig>();
pass::Manager manager(pass_config);
manager.register_pass<TestPass>(); manager.register_pass<TestPass>();
auto pass_config = std::make_shared<ngraph::pass::PassConfig>();
pass_config->set_callback<TestPass>(get_callback()); pass_config->set_callback<TestPass>(get_callback());
manager.set_pass_config(pass_config);
manager.run_passes(f); manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1); ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
@ -343,11 +343,9 @@ TEST(PassConfigTest, Test1)
{ {
auto pass_config = std::make_shared<pass::PassConfig>(); auto pass_config = std::make_shared<pass::PassConfig>();
pass::Manager manager1; pass::Manager manager1(pass_config);
pass::Manager manager2; pass::Manager manager2(pass_config);
manager1.set_pass_config(pass_config); ASSERT_EQ(pass_config.use_count(), 3);
manager2.set_pass_config(pass_config);
ASSERT_EQ(pass_config.use_count(), 1);
} }
{ {

381
ngraph/test/pass_config.cpp Normal file
View File

@ -0,0 +1,381 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
#include <ngraph/pass/manager.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <util/test_tools.hpp>
using namespace ::testing;
using namespace std;
using namespace ngraph;
class RenameReLU : public ngraph::pass::MatcherPass
{
public:
NGRAPH_RTTI_DECLARATION;
RenameReLU()
: MatcherPass()
{
auto relu = pattern::wrap_type<opset3::Relu>();
ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto relu = m.get_match_root();
relu->set_friendly_name("renamed");
return false;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(relu, "RenameReLU");
this->register_matcher(m, callback);
}
};
NGRAPH_RTTI_DEFINITION(RenameReLU, "RenameReLU", 0);
class RenameSigmoid : public ngraph::pass::MatcherPass
{
public:
NGRAPH_RTTI_DECLARATION;
RenameSigmoid()
: MatcherPass()
{
auto sigmoid = pattern::wrap_type<opset3::Sigmoid>();
ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto sigmoid = m.get_match_root();
sigmoid->set_friendly_name("renamed");
return false;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(sigmoid, "RenameSigmoid");
this->register_matcher(m, callback);
}
};
NGRAPH_RTTI_DEFINITION(RenameSigmoid, "RenameSigmoid", 0);
class TestFunctionPass : public ngraph::pass::FunctionPass
{
public:
NGRAPH_RTTI_DECLARATION;
bool run_on_function(std::shared_ptr<Function> f) override
{
pass::Manager manager(get_pass_config());
manager.register_pass<RenameReLU, false /*disabled by default*/>();
manager.register_pass<RenameSigmoid>();
manager.run_passes(f);
return true;
}
};
NGRAPH_RTTI_DEFINITION(TestFunctionPass, "TestFunctionPass", 0);
class TestGraphRewritePass : public ngraph::pass::GraphRewrite
{
public:
NGRAPH_RTTI_DECLARATION;
TestGraphRewritePass()
{
add_matcher<RenameReLU, false /*disabled by default*/>();
add_matcher<RenameSigmoid>();
}
};
NGRAPH_RTTI_DEFINITION(TestGraphRewritePass, "TestGraphRewritePass", 0);
std::tuple<std::shared_ptr<Function>, std::shared_ptr<Node>, std::shared_ptr<Node>>
get_test_function()
{
auto data =
std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
auto relu = std::make_shared<ngraph::opset3::Relu>(data);
relu->set_friendly_name("relu");
auto sigmoid = std::make_shared<ngraph::opset3::Sigmoid>(relu);
sigmoid->set_friendly_name("sigmoid");
auto f = std::make_shared<ngraph::Function>(ngraph::NodeVector{sigmoid},
ngraph::ParameterVector{data});
return std::tuple<std::shared_ptr<Function>, std::shared_ptr<Node>, std::shared_ptr<Node>>(
f, relu, sigmoid);
}
TEST(PassConfig, EnableDisablePasses1)
{
std::shared_ptr<Function> f;
std::shared_ptr<Node> relu, sigmoid;
std::tie(f, relu, sigmoid) = get_test_function();
pass::Manager manager;
manager.register_pass<TestFunctionPass>();
manager.run_passes(f);
ASSERT_EQ(relu->get_friendly_name(), "relu");
ASSERT_EQ(sigmoid->get_friendly_name(), "renamed");
}
TEST(PassConfig, EnableDisablePasses2)
{
std::shared_ptr<Function> f;
std::shared_ptr<Node> relu, sigmoid;
std::tie(f, relu, sigmoid) = get_test_function();
pass::Manager manager;
manager.register_pass<TestFunctionPass>();
auto pass_config = manager.get_pass_config();
pass_config->disable<RenameSigmoid>();
manager.run_passes(f);
ASSERT_EQ(relu->get_friendly_name(), "relu");
ASSERT_EQ(sigmoid->get_friendly_name(), "sigmoid");
pass_config->enable<RenameSigmoid>();
pass_config->enable<RenameReLU>();
manager.run_passes(f);
ASSERT_EQ(relu->get_friendly_name(), "renamed");
ASSERT_EQ(sigmoid->get_friendly_name(), "renamed");
}
TEST(PassConfig, EnableDisablePasses3)
{
std::shared_ptr<Function> f;
std::shared_ptr<Node> relu, sigmoid;
std::tie(f, relu, sigmoid) = get_test_function();
pass::Manager manager;
manager.register_pass<TestFunctionPass>();
auto pass_config = manager.get_pass_config();
pass_config->enable<RenameReLU>();
manager.run_passes(f);
ASSERT_EQ(relu->get_friendly_name(), "renamed");
ASSERT_EQ(sigmoid->get_friendly_name(), "renamed");
}
TEST(PassConfig, EnableDisablePasses4)
{
std::shared_ptr<Function> f;
std::shared_ptr<Node> relu, sigmoid;
std::tie(f, relu, sigmoid) = get_test_function();
pass::Manager manager;
manager.register_pass<TestFunctionPass>();
auto pass_config = manager.get_pass_config();
pass_config->enable<RenameReLU>();
pass_config->disable<RenameSigmoid>();
manager.run_passes(f);
ASSERT_EQ(relu->get_friendly_name(), "renamed");
ASSERT_EQ(sigmoid->get_friendly_name(), "sigmoid");
pass_config->enable<RenameSigmoid>();
manager.run_passes(f);
ASSERT_EQ(relu->get_friendly_name(), "renamed");
ASSERT_EQ(sigmoid->get_friendly_name(), "renamed");
}
TEST(PassConfig, EnableDisablePasses5)
{
std::shared_ptr<Function> f;
std::shared_ptr<Node> relu, sigmoid;
std::tie(f, relu, sigmoid) = get_test_function();
pass::Manager manager;
manager.register_pass<TestGraphRewritePass>();
manager.run_passes(f);
ASSERT_EQ(relu->get_friendly_name(), "relu");
ASSERT_EQ(sigmoid->get_friendly_name(), "renamed");
}
TEST(PassConfig, EnableDisablePasses6)
{
std::shared_ptr<Function> f;
std::shared_ptr<Node> relu, sigmoid;
std::tie(f, relu, sigmoid) = get_test_function();
pass::Manager manager;
manager.register_pass<TestGraphRewritePass>();
auto pass_config = manager.get_pass_config();
pass_config->disable<RenameSigmoid>();
manager.run_passes(f);
ASSERT_EQ(relu->get_friendly_name(), "relu");
ASSERT_EQ(sigmoid->get_friendly_name(), "sigmoid");
pass_config->enable<RenameSigmoid>();
pass_config->enable<RenameReLU>();
manager.run_passes(f);
ASSERT_EQ(relu->get_friendly_name(), "renamed");
ASSERT_EQ(sigmoid->get_friendly_name(), "renamed");
}
TEST(PassConfig, EnableDisablePasses7)
{
std::shared_ptr<Function> f;
std::shared_ptr<Node> relu, sigmoid;
std::tie(f, relu, sigmoid) = get_test_function();
pass::Manager manager;
manager.register_pass<TestGraphRewritePass>();
auto pass_config = manager.get_pass_config();
pass_config->enable<RenameReLU>();
manager.run_passes(f);
ASSERT_EQ(relu->get_friendly_name(), "renamed");
ASSERT_EQ(sigmoid->get_friendly_name(), "renamed");
}
TEST(PassConfig, EnableDisablePasses8)
{
std::shared_ptr<Function> f;
std::shared_ptr<Node> relu, sigmoid;
std::tie(f, relu, sigmoid) = get_test_function();
pass::Manager manager;
manager.register_pass<TestGraphRewritePass>();
auto pass_config = manager.get_pass_config();
pass_config->enable<RenameReLU>();
pass_config->disable<RenameSigmoid>();
manager.run_passes(f);
ASSERT_EQ(relu->get_friendly_name(), "renamed");
ASSERT_EQ(sigmoid->get_friendly_name(), "sigmoid");
pass_config->enable<RenameSigmoid>();
manager.run_passes(f);
ASSERT_EQ(relu->get_friendly_name(), "renamed");
ASSERT_EQ(sigmoid->get_friendly_name(), "renamed");
}
TEST(PassConfig, EnableDisablePasses9)
{
std::shared_ptr<Function> f;
std::shared_ptr<Node> relu, sigmoid;
std::tie(f, relu, sigmoid) = get_test_function();
pass::Manager manager;
auto anchor = manager.register_pass<pass::GraphRewrite>();
anchor->add_matcher<RenameReLU, false>();
anchor->add_matcher<RenameSigmoid>();
auto pass_config = manager.get_pass_config();
pass_config->enable<RenameReLU>();
pass_config->disable<RenameSigmoid>();
manager.run_passes(f);
ASSERT_EQ(relu->get_friendly_name(), "renamed");
ASSERT_EQ(sigmoid->get_friendly_name(), "sigmoid");
pass_config->enable<RenameSigmoid>();
manager.run_passes(f);
ASSERT_EQ(relu->get_friendly_name(), "renamed");
ASSERT_EQ(sigmoid->get_friendly_name(), "renamed");
}
class TestNestedMatcher : public ngraph::pass::MatcherPass
{
public:
NGRAPH_RTTI_DECLARATION;
TestNestedMatcher()
: MatcherPass()
{
auto any_op = pattern::any_input();
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
auto root = m.get_match_root();
auto pass_config = this->get_pass_config();
if (std::dynamic_pointer_cast<opset3::Relu>(root) &&
!pass_config->is_disabled<RenameReLU>())
{
auto pass = std::make_shared<RenameReLU>();
pass->set_pass_config(pass_config);
pass->apply(root);
}
else if (std::dynamic_pointer_cast<opset3::Sigmoid>(root) &&
!pass_config->is_disabled<RenameSigmoid>())
{
auto pass = std::make_shared<RenameSigmoid>();
pass->set_pass_config(pass_config);
pass->apply(root);
}
return false;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(any_op, "TestNestedMatcher");
this->register_matcher(m, callback);
}
};
NGRAPH_RTTI_DEFINITION(TestNestedMatcher, "TestNestedMatcher", 0);
class TestNestedGraphRewrite : public pass::GraphRewrite
{
public:
NGRAPH_RTTI_DECLARATION;
TestNestedGraphRewrite() { add_matcher<TestNestedMatcher>(); }
};
NGRAPH_RTTI_DEFINITION(TestNestedGraphRewrite, "TestNestedGraphRewrite", 0);
TEST(PassConfig, EnableDisablePasses10)
{
std::shared_ptr<Function> f;
std::shared_ptr<Node> relu, sigmoid;
std::tie(f, relu, sigmoid) = get_test_function();
pass::Manager manager;
manager.register_pass<TestNestedGraphRewrite>();
auto pass_config = manager.get_pass_config();
pass_config->disable<RenameReLU>();
pass_config->disable<RenameSigmoid>();
manager.run_passes(f);
ASSERT_EQ(relu->get_friendly_name(), "relu");
ASSERT_EQ(sigmoid->get_friendly_name(), "sigmoid");
pass_config->enable<RenameReLU>();
pass_config->enable<RenameSigmoid>();
manager.run_passes(f);
ASSERT_EQ(relu->get_friendly_name(), "renamed");
ASSERT_EQ(sigmoid->get_friendly_name(), "renamed");
}
TEST(PassConfig, EnableDisablePasses11)
{
std::shared_ptr<Function> f;
std::shared_ptr<Node> relu, sigmoid;
std::tie(f, relu, sigmoid) = get_test_function();
pass::Manager manager;
auto anchor = manager.register_pass<pass::GraphRewrite>();
anchor->add_matcher<TestNestedMatcher>();
auto pass_config = manager.get_pass_config();
pass_config->disable<RenameReLU>();
pass_config->disable<RenameSigmoid>();
manager.run_passes(f);
ASSERT_EQ(relu->get_friendly_name(), "relu");
ASSERT_EQ(sigmoid->get_friendly_name(), "sigmoid");
pass_config->enable<RenameReLU>();
pass_config->enable<RenameSigmoid>();
manager.run_passes(f);
ASSERT_EQ(relu->get_friendly_name(), "renamed");
ASSERT_EQ(sigmoid->get_friendly_name(), "renamed");
}