Fixed disable/enable logic in PassConfig (#2940)
* Fixed disable/enable logic in PassConfig * Removed set_pass_config method for Manager; added comments
This commit is contained in:
parent
24ed4133dd
commit
022ea97f18
@ -61,7 +61,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertOpSet1ToLegacy, "ConvertOpSet1ToLega
|
|||||||
bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph::Function> f) {
|
bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph::Function> f) {
|
||||||
OV_ITT_SCOPED_TASK(InferenceEngine::itt::domains::IELegacy, "ngraph::pass::ConvertOpSet1ToLegacy");
|
OV_ITT_SCOPED_TASK(InferenceEngine::itt::domains::IELegacy, "ngraph::pass::ConvertOpSet1ToLegacy");
|
||||||
|
|
||||||
ngraph::pass::Manager manager;
|
ngraph::pass::Manager manager(get_pass_config());
|
||||||
|
|
||||||
manager.register_pass<ngraph::pass::ConstantFolding>();
|
manager.register_pass<ngraph::pass::ConstantFolding>();
|
||||||
|
|
||||||
@ -148,7 +148,6 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph
|
|||||||
|
|
||||||
manager.register_pass<ngraph::pass::ConstantFolding>();
|
manager.register_pass<ngraph::pass::ConstantFolding>();
|
||||||
|
|
||||||
manager.set_pass_config(get_pass_config());
|
|
||||||
manager.run_passes(f);
|
manager.run_passes(f);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -53,7 +53,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::CommonOptimizations, "CommonOptimizations",
|
|||||||
bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::Function> f) {
|
bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::Function> f) {
|
||||||
OV_ITT_SCOPED_TASK(itt::domains::IETransform, "ngraph::pass::CommonOptimizations");
|
OV_ITT_SCOPED_TASK(itt::domains::IETransform, "ngraph::pass::CommonOptimizations");
|
||||||
|
|
||||||
ngraph::pass::Manager manager;
|
ngraph::pass::Manager manager(get_pass_config());
|
||||||
|
|
||||||
// This pass must be called first in pipeline
|
// This pass must be called first in pipeline
|
||||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
@ -118,8 +118,6 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
|
|||||||
fq_fusions->add_matcher<ngraph::pass::PullTransposeThroughFQUp>();
|
fq_fusions->add_matcher<ngraph::pass::PullTransposeThroughFQUp>();
|
||||||
fq_fusions->set_name("ngraph::pass::FakeQuantizeFusions");
|
fq_fusions->set_name("ngraph::pass::FakeQuantizeFusions");
|
||||||
|
|
||||||
// Propagate local PassConfig to internal pass::Manager
|
|
||||||
manager.set_pass_config(get_pass_config());
|
|
||||||
manager.run_passes(f);
|
manager.run_passes(f);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -18,12 +18,11 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertOpSet2ToOpSet1, "ConvertOpSet2ToOpSe
|
|||||||
bool ngraph::pass::ConvertOpSet2ToOpSet1::run_on_function(std::shared_ptr<ngraph::Function> f) {
|
bool ngraph::pass::ConvertOpSet2ToOpSet1::run_on_function(std::shared_ptr<ngraph::Function> f) {
|
||||||
OV_ITT_SCOPED_TASK(itt::domains::IETransform, "ngraph::pass::ConvertOpSet2ToOpSet1");
|
OV_ITT_SCOPED_TASK(itt::domains::IETransform, "ngraph::pass::ConvertOpSet2ToOpSet1");
|
||||||
|
|
||||||
ngraph::pass::Manager manager;
|
ngraph::pass::Manager manager(get_pass_config());
|
||||||
|
|
||||||
manager.register_pass<ngraph::pass::ConvertSpaceToBatch>();
|
manager.register_pass<ngraph::pass::ConvertSpaceToBatch>();
|
||||||
manager.register_pass<ngraph::pass::ConvertBatchToSpace>();
|
manager.register_pass<ngraph::pass::ConvertBatchToSpace>();
|
||||||
|
|
||||||
manager.set_pass_config(get_pass_config());
|
|
||||||
manager.run_passes(f);
|
manager.run_passes(f);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -22,7 +22,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertOpSet3ToOpSet2, "ConvertOpSet3ToOpSe
|
|||||||
bool ngraph::pass::ConvertOpSet3ToOpSet2::run_on_function(std::shared_ptr<ngraph::Function> f) {
|
bool ngraph::pass::ConvertOpSet3ToOpSet2::run_on_function(std::shared_ptr<ngraph::Function> f) {
|
||||||
OV_ITT_SCOPED_TASK(itt::domains::IETransform, "ngraph::pass::ConvertOpSet3ToOpSet2");
|
OV_ITT_SCOPED_TASK(itt::domains::IETransform, "ngraph::pass::ConvertOpSet3ToOpSet2");
|
||||||
|
|
||||||
ngraph::pass::Manager manager;
|
ngraph::pass::Manager manager(get_pass_config());
|
||||||
|
|
||||||
manager.register_pass<ngraph::pass::ConvertBroadcast3>();
|
manager.register_pass<ngraph::pass::ConvertBroadcast3>();
|
||||||
manager.register_pass<ngraph::pass::ConvertNMS1ToNMS3>();
|
manager.register_pass<ngraph::pass::ConvertNMS1ToNMS3>();
|
||||||
@ -31,7 +31,6 @@ bool ngraph::pass::ConvertOpSet3ToOpSet2::run_on_function(std::shared_ptr<ngraph
|
|||||||
manager.register_pass<ngraph::pass::ConvertTopK3>();
|
manager.register_pass<ngraph::pass::ConvertTopK3>();
|
||||||
manager.register_pass<ngraph::pass::SoftPlusDecomposition>();
|
manager.register_pass<ngraph::pass::SoftPlusDecomposition>();
|
||||||
|
|
||||||
manager.set_pass_config(get_pass_config());
|
|
||||||
manager.run_passes(f);
|
manager.run_passes(f);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -150,7 +150,7 @@ public:
|
|||||||
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
|
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
|
||||||
auto pass_config = get_pass_config();
|
auto pass_config = get_pass_config();
|
||||||
pass->set_pass_config(pass_config);
|
pass->set_pass_config(pass_config);
|
||||||
if (!Enabled)
|
if (!Enabled && !pass_config->is_enabled<T>())
|
||||||
{
|
{
|
||||||
pass_config->disable<T>();
|
pass_config->disable<T>();
|
||||||
}
|
}
|
||||||
@ -168,6 +168,8 @@ public:
|
|||||||
|
|
||||||
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
|
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
|
||||||
|
|
||||||
|
void set_pass_config(const std::shared_ptr<PassConfig>& pass_config) override;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
bool m_enable_shape_inference = false;
|
bool m_enable_shape_inference = false;
|
||||||
|
|
||||||
|
@ -38,6 +38,9 @@ public:
|
|||||||
Manager();
|
Manager();
|
||||||
~Manager();
|
~Manager();
|
||||||
|
|
||||||
|
//// \brief Construct Manager with shared PassConfig instance
|
||||||
|
explicit Manager(std::shared_ptr<PassConfig> pass_config);
|
||||||
|
|
||||||
/// \brief Register given transformation class type to execution list
|
/// \brief Register given transformation class type to execution list
|
||||||
/// Example below show the basic usage of pass::Manager
|
/// Example below show the basic usage of pass::Manager
|
||||||
///
|
///
|
||||||
@ -59,7 +62,7 @@ public:
|
|||||||
{
|
{
|
||||||
push_pass<Validate>();
|
push_pass<Validate>();
|
||||||
}
|
}
|
||||||
if (!Enable)
|
if (!Enable && !m_pass_config->is_enabled<T>())
|
||||||
{
|
{
|
||||||
m_pass_config->disable<T>();
|
m_pass_config->disable<T>();
|
||||||
}
|
}
|
||||||
@ -99,12 +102,6 @@ public:
|
|||||||
/// This object allows to disable/enable transformations execution, set callback to particular
|
/// This object allows to disable/enable transformations execution, set callback to particular
|
||||||
/// transformation. For mo details see PassConfig class.
|
/// transformation. For mo details see PassConfig class.
|
||||||
std::shared_ptr<PassConfig> get_pass_config() { return m_pass_config; }
|
std::shared_ptr<PassConfig> get_pass_config() { return m_pass_config; }
|
||||||
/// \brief Set external PassConfig object.
|
|
||||||
void set_pass_config(const std::shared_ptr<PassConfig>& pass_config)
|
|
||||||
{
|
|
||||||
*m_pass_config = *pass_config;
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
template <typename T, class... Args>
|
template <typename T, class... Args>
|
||||||
std::shared_ptr<T> push_pass(Args&&... args)
|
std::shared_ptr<T> push_pass(Args&&... args)
|
||||||
|
@ -62,7 +62,7 @@ namespace ngraph
|
|||||||
|
|
||||||
/// \brief Set PassConfig for particular transformation instance
|
/// \brief Set PassConfig for particular transformation instance
|
||||||
/// \param pass_config is a PassConfig shared_ptr
|
/// \param pass_config is a PassConfig shared_ptr
|
||||||
void set_pass_config(const std::shared_ptr<PassConfig>& pass_config)
|
virtual void set_pass_config(const std::shared_ptr<PassConfig>& pass_config)
|
||||||
{
|
{
|
||||||
m_pass_config = pass_config;
|
m_pass_config = pass_config;
|
||||||
}
|
}
|
||||||
|
@ -72,7 +72,7 @@ namespace ngraph
|
|||||||
public:
|
public:
|
||||||
/// \brief Disable transformation by its type_info
|
/// \brief Disable transformation by its type_info
|
||||||
/// \param type_info Transformation type_info
|
/// \param type_info Transformation type_info
|
||||||
void disable(const DiscreteTypeInfo& type_info) { m_disabled.insert(type_info); }
|
void disable(const DiscreteTypeInfo& type_info);
|
||||||
/// \brief Disable transformation by its class type (based on type_info)
|
/// \brief Disable transformation by its class type (based on type_info)
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void disable()
|
void disable()
|
||||||
@ -82,7 +82,7 @@ namespace ngraph
|
|||||||
|
|
||||||
/// \brief Enable transformation by its type_info
|
/// \brief Enable transformation by its type_info
|
||||||
/// \param type_info Transformation type_info
|
/// \param type_info Transformation type_info
|
||||||
void enable(const DiscreteTypeInfo& type_info) { m_disabled.erase(type_info); }
|
void enable(const DiscreteTypeInfo& type_info);
|
||||||
/// \brief Enable transformation by its class type (based on type_info)
|
/// \brief Enable transformation by its class type (based on type_info)
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void enable()
|
void enable()
|
||||||
@ -161,12 +161,31 @@ namespace ngraph
|
|||||||
return is_disabled(T::type_info);
|
return is_disabled(T::type_info);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// \brief Check either transformation type is force enabled or not
|
||||||
|
/// \param type_info Transformation type_info
|
||||||
|
/// \return true if transformation type was force enabled and false otherwise
|
||||||
|
bool is_enabled(const DiscreteTypeInfo& type_info) const
|
||||||
|
{
|
||||||
|
return m_enabled.count(type_info);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// \brief Check either transformation class type is force enabled or not
|
||||||
|
/// \return true if transformation type was force enabled and false otherwise
|
||||||
|
template <typename T>
|
||||||
|
bool is_enabled() const
|
||||||
|
{
|
||||||
|
return is_enabled(T::type_info);
|
||||||
|
}
|
||||||
|
|
||||||
|
void add_disabled_passes(const PassConfig& rhs);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
param_callback m_callback = [](const std::shared_ptr<const ::ngraph::Node>&) {
|
param_callback m_callback = [](const std::shared_ptr<const ::ngraph::Node>&) {
|
||||||
return false;
|
return false;
|
||||||
};
|
};
|
||||||
param_callback_map m_callback_map;
|
param_callback_map m_callback_map;
|
||||||
std::unordered_set<DiscreteTypeInfo> m_disabled;
|
std::unordered_set<DiscreteTypeInfo> m_disabled;
|
||||||
|
std::unordered_set<DiscreteTypeInfo> m_enabled;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -272,6 +272,35 @@ void pass::GraphRewrite::add_matcher(const shared_ptr<pattern::Matcher>& m,
|
|||||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void pass::GraphRewrite::set_pass_config(const std::shared_ptr<PassConfig>& rhs)
|
||||||
|
{
|
||||||
|
auto pass_config = get_pass_config();
|
||||||
|
// We have to preserve disabled passes because in case when we register matchers inside
|
||||||
|
// GraphRewrite c-tor we work with local PassConfig instance.
|
||||||
|
// For example:
|
||||||
|
//
|
||||||
|
// class ExampleGraphRewrite: public pass::GraphRewrite {
|
||||||
|
// ExampleGraphRewrite() {
|
||||||
|
// add_mather<TestMatcher1, false /* disabled by default */>();
|
||||||
|
// add_mather<TestMatcher2>();
|
||||||
|
// }
|
||||||
|
// };
|
||||||
|
//
|
||||||
|
// When we call add_matcher inside c-tor we automatically work with locally created PassConfig
|
||||||
|
// instance that is not shared. So when instance of this pass is being created in pass::Manager
|
||||||
|
// we set shared PassConfig but we will override already existing rules inside local config. To
|
||||||
|
// resolve this we have to copy disabled passes from local PassConfig to shared but we take into
|
||||||
|
// account that if passes were manually enabled we do not add them.
|
||||||
|
rhs->add_disabled_passes(*pass_config);
|
||||||
|
PassBase::set_pass_config(rhs);
|
||||||
|
|
||||||
|
// update nested transformations with new shared pass_config
|
||||||
|
for (auto& pass : m_matchers)
|
||||||
|
{
|
||||||
|
pass->set_pass_config(rhs);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void pass::RecurrentGraphRewrite::add_matcher(
|
void pass::RecurrentGraphRewrite::add_matcher(
|
||||||
const std::shared_ptr<pattern::RecurrentMatcher>& m,
|
const std::shared_ptr<pattern::RecurrentMatcher>& m,
|
||||||
const ngraph::recurrent_graph_rewrite_callback& callback,
|
const ngraph::recurrent_graph_rewrite_callback& callback,
|
||||||
|
@ -44,6 +44,11 @@ pass::Manager::~Manager()
|
|||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pass::Manager::Manager(std::shared_ptr<ngraph::pass::PassConfig> pass_config)
|
||||||
|
: m_pass_config(std::move(pass_config))
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
void pass::Manager::run_passes(shared_ptr<Function> func)
|
void pass::Manager::run_passes(shared_ptr<Function> func)
|
||||||
{
|
{
|
||||||
OV_ITT_SCOPED_TASK(itt::domains::nGraph, "pass::Manager::run_passes");
|
OV_ITT_SCOPED_TASK(itt::domains::nGraph, "pass::Manager::run_passes");
|
||||||
|
@ -30,3 +30,25 @@ pass::param_callback pass::PassConfig::get_callback(const DiscreteTypeInfo& type
|
|||||||
return m_callback;
|
return m_callback;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void pass::PassConfig::enable(const ngraph::DiscreteTypeInfo& type_info)
|
||||||
|
{
|
||||||
|
m_disabled.erase(type_info);
|
||||||
|
m_enabled.insert(type_info);
|
||||||
|
}
|
||||||
|
|
||||||
|
void pass::PassConfig::disable(const ngraph::DiscreteTypeInfo& type_info)
|
||||||
|
{
|
||||||
|
m_enabled.erase(type_info);
|
||||||
|
m_disabled.insert(type_info);
|
||||||
|
}
|
||||||
|
|
||||||
|
void pass::PassConfig::add_disabled_passes(const PassConfig& rhs)
|
||||||
|
{
|
||||||
|
for (const auto& pass : rhs.m_disabled)
|
||||||
|
{
|
||||||
|
if (is_enabled(pass))
|
||||||
|
continue;
|
||||||
|
disable(pass);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -88,6 +88,7 @@ set(SRC
|
|||||||
op_is.cpp
|
op_is.cpp
|
||||||
opset1.cpp
|
opset1.cpp
|
||||||
partial_shape.cpp
|
partial_shape.cpp
|
||||||
|
pass_config.cpp
|
||||||
pass_liveness.cpp
|
pass_liveness.cpp
|
||||||
pass_manager.cpp
|
pass_manager.cpp
|
||||||
pass_shape_relevance.cpp
|
pass_shape_relevance.cpp
|
||||||
|
@ -298,13 +298,13 @@ TEST(PassConfigTest, Test1)
|
|||||||
{
|
{
|
||||||
auto f = get_function();
|
auto f = get_function();
|
||||||
|
|
||||||
pass::Manager manager;
|
auto pass_config = std::make_shared<ngraph::pass::PassConfig>();
|
||||||
|
pass::Manager manager(pass_config);
|
||||||
|
|
||||||
manager.register_pass<TestPass>();
|
manager.register_pass<TestPass>();
|
||||||
|
|
||||||
auto pass_config = std::make_shared<ngraph::pass::PassConfig>();
|
|
||||||
pass_config->set_callback<TestPass>(get_callback());
|
pass_config->set_callback<TestPass>(get_callback());
|
||||||
|
|
||||||
manager.set_pass_config(pass_config);
|
|
||||||
manager.run_passes(f);
|
manager.run_passes(f);
|
||||||
|
|
||||||
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
|
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
|
||||||
@ -343,11 +343,9 @@ TEST(PassConfigTest, Test1)
|
|||||||
{
|
{
|
||||||
auto pass_config = std::make_shared<pass::PassConfig>();
|
auto pass_config = std::make_shared<pass::PassConfig>();
|
||||||
|
|
||||||
pass::Manager manager1;
|
pass::Manager manager1(pass_config);
|
||||||
pass::Manager manager2;
|
pass::Manager manager2(pass_config);
|
||||||
manager1.set_pass_config(pass_config);
|
ASSERT_EQ(pass_config.use_count(), 3);
|
||||||
manager2.set_pass_config(pass_config);
|
|
||||||
ASSERT_EQ(pass_config.use_count(), 1);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
381
ngraph/test/pass_config.cpp
Normal file
381
ngraph/test/pass_config.cpp
Normal file
@ -0,0 +1,381 @@
|
|||||||
|
// Copyright (C) 2018-2020 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include <ngraph/opsets/opset3.hpp>
|
||||||
|
#include <ngraph/pass/graph_rewrite.hpp>
|
||||||
|
#include <ngraph/pass/manager.hpp>
|
||||||
|
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||||
|
#include <util/test_tools.hpp>
|
||||||
|
|
||||||
|
using namespace ::testing;
|
||||||
|
using namespace std;
|
||||||
|
using namespace ngraph;
|
||||||
|
|
||||||
|
class RenameReLU : public ngraph::pass::MatcherPass
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
RenameReLU()
|
||||||
|
: MatcherPass()
|
||||||
|
{
|
||||||
|
auto relu = pattern::wrap_type<opset3::Relu>();
|
||||||
|
ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
|
||||||
|
auto relu = m.get_match_root();
|
||||||
|
relu->set_friendly_name("renamed");
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto m = std::make_shared<ngraph::pattern::Matcher>(relu, "RenameReLU");
|
||||||
|
this->register_matcher(m, callback);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
NGRAPH_RTTI_DEFINITION(RenameReLU, "RenameReLU", 0);
|
||||||
|
|
||||||
|
class RenameSigmoid : public ngraph::pass::MatcherPass
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
RenameSigmoid()
|
||||||
|
: MatcherPass()
|
||||||
|
{
|
||||||
|
auto sigmoid = pattern::wrap_type<opset3::Sigmoid>();
|
||||||
|
ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
|
||||||
|
auto sigmoid = m.get_match_root();
|
||||||
|
sigmoid->set_friendly_name("renamed");
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto m = std::make_shared<ngraph::pattern::Matcher>(sigmoid, "RenameSigmoid");
|
||||||
|
this->register_matcher(m, callback);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
NGRAPH_RTTI_DEFINITION(RenameSigmoid, "RenameSigmoid", 0);
|
||||||
|
|
||||||
|
class TestFunctionPass : public ngraph::pass::FunctionPass
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
|
bool run_on_function(std::shared_ptr<Function> f) override
|
||||||
|
{
|
||||||
|
pass::Manager manager(get_pass_config());
|
||||||
|
|
||||||
|
manager.register_pass<RenameReLU, false /*disabled by default*/>();
|
||||||
|
manager.register_pass<RenameSigmoid>();
|
||||||
|
|
||||||
|
manager.run_passes(f);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
NGRAPH_RTTI_DEFINITION(TestFunctionPass, "TestFunctionPass", 0);
|
||||||
|
|
||||||
|
class TestGraphRewritePass : public ngraph::pass::GraphRewrite
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
TestGraphRewritePass()
|
||||||
|
{
|
||||||
|
add_matcher<RenameReLU, false /*disabled by default*/>();
|
||||||
|
add_matcher<RenameSigmoid>();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
NGRAPH_RTTI_DEFINITION(TestGraphRewritePass, "TestGraphRewritePass", 0);
|
||||||
|
|
||||||
|
std::tuple<std::shared_ptr<Function>, std::shared_ptr<Node>, std::shared_ptr<Node>>
|
||||||
|
get_test_function()
|
||||||
|
{
|
||||||
|
auto data =
|
||||||
|
std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
|
||||||
|
auto relu = std::make_shared<ngraph::opset3::Relu>(data);
|
||||||
|
relu->set_friendly_name("relu");
|
||||||
|
auto sigmoid = std::make_shared<ngraph::opset3::Sigmoid>(relu);
|
||||||
|
sigmoid->set_friendly_name("sigmoid");
|
||||||
|
auto f = std::make_shared<ngraph::Function>(ngraph::NodeVector{sigmoid},
|
||||||
|
ngraph::ParameterVector{data});
|
||||||
|
return std::tuple<std::shared_ptr<Function>, std::shared_ptr<Node>, std::shared_ptr<Node>>(
|
||||||
|
f, relu, sigmoid);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(PassConfig, EnableDisablePasses1)
|
||||||
|
{
|
||||||
|
std::shared_ptr<Function> f;
|
||||||
|
std::shared_ptr<Node> relu, sigmoid;
|
||||||
|
std::tie(f, relu, sigmoid) = get_test_function();
|
||||||
|
|
||||||
|
pass::Manager manager;
|
||||||
|
manager.register_pass<TestFunctionPass>();
|
||||||
|
manager.run_passes(f);
|
||||||
|
|
||||||
|
ASSERT_EQ(relu->get_friendly_name(), "relu");
|
||||||
|
ASSERT_EQ(sigmoid->get_friendly_name(), "renamed");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(PassConfig, EnableDisablePasses2)
|
||||||
|
{
|
||||||
|
std::shared_ptr<Function> f;
|
||||||
|
std::shared_ptr<Node> relu, sigmoid;
|
||||||
|
std::tie(f, relu, sigmoid) = get_test_function();
|
||||||
|
|
||||||
|
pass::Manager manager;
|
||||||
|
manager.register_pass<TestFunctionPass>();
|
||||||
|
|
||||||
|
auto pass_config = manager.get_pass_config();
|
||||||
|
pass_config->disable<RenameSigmoid>();
|
||||||
|
manager.run_passes(f);
|
||||||
|
|
||||||
|
ASSERT_EQ(relu->get_friendly_name(), "relu");
|
||||||
|
ASSERT_EQ(sigmoid->get_friendly_name(), "sigmoid");
|
||||||
|
|
||||||
|
pass_config->enable<RenameSigmoid>();
|
||||||
|
pass_config->enable<RenameReLU>();
|
||||||
|
manager.run_passes(f);
|
||||||
|
|
||||||
|
ASSERT_EQ(relu->get_friendly_name(), "renamed");
|
||||||
|
ASSERT_EQ(sigmoid->get_friendly_name(), "renamed");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(PassConfig, EnableDisablePasses3)
|
||||||
|
{
|
||||||
|
std::shared_ptr<Function> f;
|
||||||
|
std::shared_ptr<Node> relu, sigmoid;
|
||||||
|
std::tie(f, relu, sigmoid) = get_test_function();
|
||||||
|
|
||||||
|
pass::Manager manager;
|
||||||
|
manager.register_pass<TestFunctionPass>();
|
||||||
|
|
||||||
|
auto pass_config = manager.get_pass_config();
|
||||||
|
pass_config->enable<RenameReLU>();
|
||||||
|
manager.run_passes(f);
|
||||||
|
|
||||||
|
ASSERT_EQ(relu->get_friendly_name(), "renamed");
|
||||||
|
ASSERT_EQ(sigmoid->get_friendly_name(), "renamed");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(PassConfig, EnableDisablePasses4)
|
||||||
|
{
|
||||||
|
std::shared_ptr<Function> f;
|
||||||
|
std::shared_ptr<Node> relu, sigmoid;
|
||||||
|
std::tie(f, relu, sigmoid) = get_test_function();
|
||||||
|
|
||||||
|
pass::Manager manager;
|
||||||
|
manager.register_pass<TestFunctionPass>();
|
||||||
|
|
||||||
|
auto pass_config = manager.get_pass_config();
|
||||||
|
pass_config->enable<RenameReLU>();
|
||||||
|
pass_config->disable<RenameSigmoid>();
|
||||||
|
manager.run_passes(f);
|
||||||
|
|
||||||
|
ASSERT_EQ(relu->get_friendly_name(), "renamed");
|
||||||
|
ASSERT_EQ(sigmoid->get_friendly_name(), "sigmoid");
|
||||||
|
|
||||||
|
pass_config->enable<RenameSigmoid>();
|
||||||
|
manager.run_passes(f);
|
||||||
|
ASSERT_EQ(relu->get_friendly_name(), "renamed");
|
||||||
|
ASSERT_EQ(sigmoid->get_friendly_name(), "renamed");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(PassConfig, EnableDisablePasses5)
|
||||||
|
{
|
||||||
|
std::shared_ptr<Function> f;
|
||||||
|
std::shared_ptr<Node> relu, sigmoid;
|
||||||
|
std::tie(f, relu, sigmoid) = get_test_function();
|
||||||
|
|
||||||
|
pass::Manager manager;
|
||||||
|
manager.register_pass<TestGraphRewritePass>();
|
||||||
|
manager.run_passes(f);
|
||||||
|
|
||||||
|
ASSERT_EQ(relu->get_friendly_name(), "relu");
|
||||||
|
ASSERT_EQ(sigmoid->get_friendly_name(), "renamed");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(PassConfig, EnableDisablePasses6)
|
||||||
|
{
|
||||||
|
std::shared_ptr<Function> f;
|
||||||
|
std::shared_ptr<Node> relu, sigmoid;
|
||||||
|
std::tie(f, relu, sigmoid) = get_test_function();
|
||||||
|
|
||||||
|
pass::Manager manager;
|
||||||
|
manager.register_pass<TestGraphRewritePass>();
|
||||||
|
|
||||||
|
auto pass_config = manager.get_pass_config();
|
||||||
|
pass_config->disable<RenameSigmoid>();
|
||||||
|
manager.run_passes(f);
|
||||||
|
|
||||||
|
ASSERT_EQ(relu->get_friendly_name(), "relu");
|
||||||
|
ASSERT_EQ(sigmoid->get_friendly_name(), "sigmoid");
|
||||||
|
|
||||||
|
pass_config->enable<RenameSigmoid>();
|
||||||
|
pass_config->enable<RenameReLU>();
|
||||||
|
manager.run_passes(f);
|
||||||
|
|
||||||
|
ASSERT_EQ(relu->get_friendly_name(), "renamed");
|
||||||
|
ASSERT_EQ(sigmoid->get_friendly_name(), "renamed");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(PassConfig, EnableDisablePasses7)
|
||||||
|
{
|
||||||
|
std::shared_ptr<Function> f;
|
||||||
|
std::shared_ptr<Node> relu, sigmoid;
|
||||||
|
std::tie(f, relu, sigmoid) = get_test_function();
|
||||||
|
|
||||||
|
pass::Manager manager;
|
||||||
|
manager.register_pass<TestGraphRewritePass>();
|
||||||
|
|
||||||
|
auto pass_config = manager.get_pass_config();
|
||||||
|
pass_config->enable<RenameReLU>();
|
||||||
|
manager.run_passes(f);
|
||||||
|
|
||||||
|
ASSERT_EQ(relu->get_friendly_name(), "renamed");
|
||||||
|
ASSERT_EQ(sigmoid->get_friendly_name(), "renamed");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(PassConfig, EnableDisablePasses8)
|
||||||
|
{
|
||||||
|
std::shared_ptr<Function> f;
|
||||||
|
std::shared_ptr<Node> relu, sigmoid;
|
||||||
|
std::tie(f, relu, sigmoid) = get_test_function();
|
||||||
|
|
||||||
|
pass::Manager manager;
|
||||||
|
manager.register_pass<TestGraphRewritePass>();
|
||||||
|
|
||||||
|
auto pass_config = manager.get_pass_config();
|
||||||
|
pass_config->enable<RenameReLU>();
|
||||||
|
pass_config->disable<RenameSigmoid>();
|
||||||
|
manager.run_passes(f);
|
||||||
|
|
||||||
|
ASSERT_EQ(relu->get_friendly_name(), "renamed");
|
||||||
|
ASSERT_EQ(sigmoid->get_friendly_name(), "sigmoid");
|
||||||
|
|
||||||
|
pass_config->enable<RenameSigmoid>();
|
||||||
|
manager.run_passes(f);
|
||||||
|
|
||||||
|
ASSERT_EQ(relu->get_friendly_name(), "renamed");
|
||||||
|
ASSERT_EQ(sigmoid->get_friendly_name(), "renamed");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(PassConfig, EnableDisablePasses9)
|
||||||
|
{
|
||||||
|
std::shared_ptr<Function> f;
|
||||||
|
std::shared_ptr<Node> relu, sigmoid;
|
||||||
|
std::tie(f, relu, sigmoid) = get_test_function();
|
||||||
|
|
||||||
|
pass::Manager manager;
|
||||||
|
auto anchor = manager.register_pass<pass::GraphRewrite>();
|
||||||
|
anchor->add_matcher<RenameReLU, false>();
|
||||||
|
anchor->add_matcher<RenameSigmoid>();
|
||||||
|
|
||||||
|
auto pass_config = manager.get_pass_config();
|
||||||
|
pass_config->enable<RenameReLU>();
|
||||||
|
pass_config->disable<RenameSigmoid>();
|
||||||
|
manager.run_passes(f);
|
||||||
|
|
||||||
|
ASSERT_EQ(relu->get_friendly_name(), "renamed");
|
||||||
|
ASSERT_EQ(sigmoid->get_friendly_name(), "sigmoid");
|
||||||
|
|
||||||
|
pass_config->enable<RenameSigmoid>();
|
||||||
|
manager.run_passes(f);
|
||||||
|
|
||||||
|
ASSERT_EQ(relu->get_friendly_name(), "renamed");
|
||||||
|
ASSERT_EQ(sigmoid->get_friendly_name(), "renamed");
|
||||||
|
}
|
||||||
|
|
||||||
|
class TestNestedMatcher : public ngraph::pass::MatcherPass
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
TestNestedMatcher()
|
||||||
|
: MatcherPass()
|
||||||
|
{
|
||||||
|
auto any_op = pattern::any_input();
|
||||||
|
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
|
||||||
|
auto root = m.get_match_root();
|
||||||
|
auto pass_config = this->get_pass_config();
|
||||||
|
if (std::dynamic_pointer_cast<opset3::Relu>(root) &&
|
||||||
|
!pass_config->is_disabled<RenameReLU>())
|
||||||
|
{
|
||||||
|
auto pass = std::make_shared<RenameReLU>();
|
||||||
|
pass->set_pass_config(pass_config);
|
||||||
|
pass->apply(root);
|
||||||
|
}
|
||||||
|
else if (std::dynamic_pointer_cast<opset3::Sigmoid>(root) &&
|
||||||
|
!pass_config->is_disabled<RenameSigmoid>())
|
||||||
|
{
|
||||||
|
auto pass = std::make_shared<RenameSigmoid>();
|
||||||
|
pass->set_pass_config(pass_config);
|
||||||
|
pass->apply(root);
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto m = std::make_shared<ngraph::pattern::Matcher>(any_op, "TestNestedMatcher");
|
||||||
|
this->register_matcher(m, callback);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
NGRAPH_RTTI_DEFINITION(TestNestedMatcher, "TestNestedMatcher", 0);
|
||||||
|
|
||||||
|
class TestNestedGraphRewrite : public pass::GraphRewrite
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
TestNestedGraphRewrite() { add_matcher<TestNestedMatcher>(); }
|
||||||
|
};
|
||||||
|
|
||||||
|
NGRAPH_RTTI_DEFINITION(TestNestedGraphRewrite, "TestNestedGraphRewrite", 0);
|
||||||
|
|
||||||
|
TEST(PassConfig, EnableDisablePasses10)
|
||||||
|
{
|
||||||
|
std::shared_ptr<Function> f;
|
||||||
|
std::shared_ptr<Node> relu, sigmoid;
|
||||||
|
std::tie(f, relu, sigmoid) = get_test_function();
|
||||||
|
|
||||||
|
pass::Manager manager;
|
||||||
|
manager.register_pass<TestNestedGraphRewrite>();
|
||||||
|
|
||||||
|
auto pass_config = manager.get_pass_config();
|
||||||
|
pass_config->disable<RenameReLU>();
|
||||||
|
pass_config->disable<RenameSigmoid>();
|
||||||
|
manager.run_passes(f);
|
||||||
|
|
||||||
|
ASSERT_EQ(relu->get_friendly_name(), "relu");
|
||||||
|
ASSERT_EQ(sigmoid->get_friendly_name(), "sigmoid");
|
||||||
|
|
||||||
|
pass_config->enable<RenameReLU>();
|
||||||
|
pass_config->enable<RenameSigmoid>();
|
||||||
|
manager.run_passes(f);
|
||||||
|
|
||||||
|
ASSERT_EQ(relu->get_friendly_name(), "renamed");
|
||||||
|
ASSERT_EQ(sigmoid->get_friendly_name(), "renamed");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(PassConfig, EnableDisablePasses11)
|
||||||
|
{
|
||||||
|
std::shared_ptr<Function> f;
|
||||||
|
std::shared_ptr<Node> relu, sigmoid;
|
||||||
|
std::tie(f, relu, sigmoid) = get_test_function();
|
||||||
|
|
||||||
|
pass::Manager manager;
|
||||||
|
auto anchor = manager.register_pass<pass::GraphRewrite>();
|
||||||
|
anchor->add_matcher<TestNestedMatcher>();
|
||||||
|
|
||||||
|
auto pass_config = manager.get_pass_config();
|
||||||
|
pass_config->disable<RenameReLU>();
|
||||||
|
pass_config->disable<RenameSigmoid>();
|
||||||
|
manager.run_passes(f);
|
||||||
|
|
||||||
|
ASSERT_EQ(relu->get_friendly_name(), "relu");
|
||||||
|
ASSERT_EQ(sigmoid->get_friendly_name(), "sigmoid");
|
||||||
|
|
||||||
|
pass_config->enable<RenameReLU>();
|
||||||
|
pass_config->enable<RenameSigmoid>();
|
||||||
|
manager.run_passes(f);
|
||||||
|
|
||||||
|
ASSERT_EQ(relu->get_friendly_name(), "renamed");
|
||||||
|
ASSERT_EQ(sigmoid->get_friendly_name(), "renamed");
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user