Fine-Grain Transformation pipeline tuning (#2547)

* Initial version of transformation callback refactoring

* Improved fine-grain tuning for transformation pipeline

* Check disabled matchers in GraphRewrite

* Avoid deprecated classes inside PassConfig

* Enabled DepthToSpace fusion by default

* Removed double search in map

* Moved back pass_config.hpp; Added doxygen documentation for new class and methods

* Added doxygen comments for the new Manager and GraphRewrite methods
Gleb Kazantaev 2020-10-09 15:33:19 +03:00 committed by GitHub
parent da625b995e
commit 2e49b4e4d8
20 changed files with 462 additions and 193 deletions

View File

@ -27,8 +27,21 @@
#include <transformations/convert_opset1_to_legacy/convert_prior_to_ie_prior.hpp>
#include <transformations/convert_opset2_to_opset1/convert_opset2_to_opset1.hpp>
#include <transformations/convert_opset3_to_opset2/convert_opset3_to_opset2.hpp>
#include <transformations/convert_depth_to_space.hpp>
#include <transformations/convert_space_to_depth.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/convert_precision.hpp>
#include <transformations/convert_opset1_to_legacy/reshape_fully_connected.hpp>
#include <transformations/convert_gelu.hpp>
#include <transformations/depth_to_space_fusion.hpp>
#include <transformations/convert_batch_to_space.hpp>
#include <transformations/convert_extract_image_patches_to_reorg_yolo.hpp>
#include <transformations/hswish_decomposition.hpp>
#include <transformations/reduce_l1_decomposition.hpp>
#include <transformations/reduce_l2_decomposition.hpp>
#include <transformations/convert_space_to_batch.hpp>
#include <transformations/softplus_decomposition.hpp>
#include <transformations/convert_pad_to_group_conv.hpp>
#include <transformations/rt_info/fused_names_attribute.hpp>
#include <ngraph/opsets/opset2.hpp>
#include <ngraph/opsets/opset3.hpp>
@ -63,31 +76,6 @@ Engine::~Engine() {
static void Transformation(ICNNNetwork::Ptr& clonedNetwork) {
OV_ITT_SCOPED_TASK(MKLDNNPlugin::itt::domains::MKLDNNPlugin, "Transformation");
const auto transformations_callback = [](const std::shared_ptr<const ::ngraph::Node> &node) -> bool {
// DepthToSpace node implementation supports only equal input/output tensors with rank <= 5
if (auto dtsOp = std::dynamic_pointer_cast<const ::ngraph::opset3::DepthToSpace>(node)) {
return dtsOp->input_value(0).get_shape().size() <= 5lu && dtsOp->input_value(0).get_shape().size() == dtsOp->get_output_shape(0).size();
}
// SpaceToDepth node implementation supports only equal input/output tensors with rank <= 5
if (auto stdOp = std::dynamic_pointer_cast<const ::ngraph::opset3::SpaceToDepth>(node)) {
return stdOp->input_value(0).get_shape().size() <= 5lu && stdOp->input_value(0).get_shape().size() == stdOp->get_output_shape(0).size();
}
if (auto fc_op = std::dynamic_pointer_cast<const ngraph::op::FullyConnected>(node)) {
return fc_op->input_value(0).get_shape().size() == 3ul;
}
return std::dynamic_pointer_cast<const ngraph::opset2::Gelu>(node) ||
std::dynamic_pointer_cast<const ngraph::opset2::BatchToSpace>(node) ||
std::dynamic_pointer_cast<const ngraph::opset2::SpaceToBatch>(node) ||
std::dynamic_pointer_cast<const ngraph::opset3::ExtractImagePatches>(node) ||
std::dynamic_pointer_cast<const ngraph::opset4::HSwish>(node) ||
std::dynamic_pointer_cast<const ngraph::opset4::ReduceL1>(node) ||
std::dynamic_pointer_cast<const ngraph::opset4::ReduceL2>(node) ||
std::dynamic_pointer_cast<const ngraph::opset4::SoftPlus>(node) ||
std::dynamic_pointer_cast<const ngraph::opset4::Pad>(node);
};
auto nGraphFunc = clonedNetwork->getFunction();
// Disable shape inference (WA for generic operations)
ngraph::op::GenericIE::DisableReshape noReshape(nGraphFunc);
@ -116,7 +104,41 @@ static void Transformation(ICNNNetwork::Ptr& clonedNetwork) {
manager.register_pass<ngraph::pass::ConvertOpSet1ToLegacy>();
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i64, ngraph::element::i32);
manager.set_callback(transformations_callback);
auto pass_config = manager.get_pass_config();
using const_node_ptr = const std::shared_ptr<const ngraph::Node>;
// SpaceToDepth/DepthToSpace node implementations support only equal input/output tensors with rank <= 5
pass_config->set_callback<ngraph::pass::ConvertSpaceToDepth,
ngraph::pass::ConvertDepthToSpace>(
[](const_node_ptr &node) -> bool {
return node->input_value(0).get_shape().size() <= 5lu &&
node->input_value(0).get_shape().size() == node->get_output_shape(0).size();
});
// Disable FC reshaping for 3D case
pass_config->set_callback<ngraph::pass::ReshapeFullyConnected>(
[](const_node_ptr &node) -> bool {
return node->input_value(0).get_shape().size() == 3ul;
});
pass_config->set_callback<ngraph::pass::ConvertBatchToSpace,
ngraph::pass::ConvertSpaceToBatch>(
[](const_node_ptr &node) -> bool {
const auto & rank = node->input(0).get_partial_shape().rank().get_length();
return rank == 4lu || rank == 5lu;
});
// List of enabled/disabled transformations
pass_config->disable<ngraph::pass::ConvertGELU>();
pass_config->disable<ngraph::pass::ConvertExtractImagePatchesToReorgYolo>();
pass_config->disable<ngraph::pass::HSwishDecomposition>();
pass_config->disable<ngraph::pass::ReduceL1Decomposition>();
pass_config->disable<ngraph::pass::ReduceL2Decomposition>();
pass_config->disable<ngraph::pass::SoftPlusDecomposition>();
pass_config->enable<ngraph::pass::ConvertPadToGroupConvolution>();
manager.run_passes(nGraphFunc);
clonedNetwork = InferenceEngine::details::convertFunctionToICNNNetwork(nGraphFunc, *clonedNetwork);
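The plugin changes above reduce to a small, reusable pattern: register passes on a Manager, fetch its shared PassConfig, then attach per-pass callbacks or flip individual passes on and off. Below is a minimal sketch of that pattern, assuming the transformation headers from the include list in this diff; the function name tune_pipeline and the particular choice of passes are illustrative, not part of the change.

#include <memory>

#include <ngraph/function.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/convert_depth_to_space.hpp>
#include <transformations/convert_gelu.hpp>
#include <transformations/convert_pad_to_group_conv.hpp>

void tune_pipeline(std::shared_ptr<ngraph::Function> f) {
    ngraph::pass::Manager manager;
    manager.register_pass<ngraph::pass::ConvertDepthToSpace>();
    manager.register_pass<ngraph::pass::ConvertGELU>();
    manager.register_pass<ngraph::pass::ConvertPadToGroupConvolution>();

    auto pass_config = manager.get_pass_config();

    // Per-pass callback: return true for DepthToSpace ops the plugin executes natively,
    // so the conversion pass backs off for them.
    pass_config->set_callback<ngraph::pass::ConvertDepthToSpace>(
        [](const std::shared_ptr<const ngraph::Node>& node) -> bool {
            return node->input_value(0).get_shape().size() <= 5;
        });

    // Coarse on/off switches.
    pass_config->disable<ngraph::pass::ConvertGELU>();
    pass_config->enable<ngraph::pass::ConvertPadToGroupConvolution>();

    manager.run_passes(f);
}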

View File

@ -68,7 +68,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
manager.register_pass<ngraph::pass::SoftPlusToMishFusion>();
manager.register_pass<ngraph::pass::SwishFusion>();
manager.register_pass<ngraph::pass::HSwishFusion>();
manager.register_pass<ngraph::pass::ConvertPadToGroupConvolution>();
manager.register_pass<ngraph::pass::ConvertPadToGroupConvolution, false>();
manager.register_pass<ngraph::pass::NormalizeL2Fusion>();
manager.register_pass<ngraph::pass::BidirectionalLSTMSequenceDecomposition>();
manager.register_pass<ngraph::pass::BidirectionalRNNSequenceDecomposition>();
@ -111,7 +111,8 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
fq_fusions->add_matcher<ngraph::pass::PullTransposeThroughFQUp>();
fq_fusions->set_name("ngraph::pass::FakeQuantizeFusions");
manager.set_callback(m_transformation_callback);
// Propagate local PassConfig to internal pass::Manager
manager.set_pass_config(get_pass_config());
manager.run_passes(f);
return true;
}
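The set_pass_config(get_pass_config()) line above is the key to fine-grain tuning of composite pipelines: the nested Manager created inside run_on_function() now shares the outer PassConfig, so disable()/enable() calls and callbacks made by a plugin reach the nested passes as well. A minimal sketch of a composite pass following the same pattern and the NGRAPH_RTTI idiom from this PR's unit tests; MyPipeline is a hypothetical name, ConvertGELU is one of the passes touched by this PR.

#include <memory>

#include <ngraph/function.hpp>
#include <ngraph/pass/manager.hpp>
#include <ngraph/pass/pass.hpp>
#include <transformations/convert_gelu.hpp>

class MyPipeline : public ngraph::pass::FunctionPass {
public:
    NGRAPH_RTTI_DECLARATION;
    bool run_on_function(std::shared_ptr<ngraph::Function> f) override {
        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::ConvertGELU>();
        // Propagate the local PassConfig so that disable()/enable() calls and callbacks
        // configured on the outer Manager also apply to the nested pipeline.
        manager.set_pass_config(get_pass_config());
        manager.run_passes(f);
        return true;
    }
};

NGRAPH_RTTI_DEFINITION(MyPipeline, "MyPipeline", 0);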

View File

@ -18,7 +18,7 @@ ngraph::pass::ConvertDepthToSpace::ConvertDepthToSpace() {
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
auto dts_node = std::dynamic_pointer_cast<ngraph::opset1::DepthToSpace> (m.get_match_root());
if (!dts_node || m_transformation_callback(dts_node)) {
if (!dts_node || transformation_callback(dts_node)) {
return false;
}

View File

@ -154,7 +154,7 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph
manager.register_pass<ngraph::pass::ConstantFolding>();
manager.set_callback(m_transformation_callback);
manager.set_pass_config(get_pass_config());
manager.run_passes(f);
return true;
}

View File

@ -24,7 +24,7 @@ bool ngraph::pass::ConvertOpSet2ToOpSet1::run_on_function(std::shared_ptr<ngraph
manager.register_pass<ngraph::pass::ConvertSpaceToBatch>();
manager.register_pass<ngraph::pass::ConvertBatchToSpace>();
manager.set_callback(m_transformation_callback);
manager.set_pass_config(get_pass_config());
manager.run_passes(f);
return true;
}

View File

@ -33,7 +33,7 @@ bool ngraph::pass::ConvertOpSet3ToOpSet2::run_on_function(std::shared_ptr<ngraph
manager.register_pass<ngraph::pass::ConvertExtractImagePatchesToReorgYolo>();
manager.register_pass<ngraph::pass::SoftPlusDecomposition>();
manager.set_callback(m_transformation_callback);
manager.set_pass_config(get_pass_config());
manager.run_passes(f);
return true;
}

View File

@ -19,7 +19,7 @@ ngraph::pass::ConvertPadToGroupConvolution::ConvertPadToGroupConvolution() {
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
auto pad = std::dynamic_pointer_cast<ngraph::opset4::Pad> (m.get_match_root());
if (!pad || !m_transformation_callback(pad) /* disabled by default */) {
if (!pad) {
return false;
}

View File

@ -155,11 +155,6 @@ void ngraph::pass::DepthToSpaceFusion::depth_to_space_fusion() {
std::make_shared<ngraph::opset3::DepthToSpace>(reshape_before->input_value(0), mode, block_size);
depth_to_space->set_friendly_name(reshape_after->get_friendly_name());
ngraph::copy_runtime_info({reshape_before, permute, reshape_after}, depth_to_space);
if (!m_transformation_callback(depth_to_space)) {
return false;
}
ngraph::replace_node(reshape_after, depth_to_space);
return true;
};

View File

@ -126,12 +126,34 @@ public:
m_matchers.push_back(pass);
}
template <typename T, class... Args>
/// \brief Register a given transformation class type in the GraphRewrite execution list.
/// All registered transformations will be executed in a single graph traversal.
/// The example below shows the basic usage of pass::GraphRewrite
///
/// pass::Manager manager;
/// auto anchor = manager.register_pass<GraphRewrite>();
/// anchor->add_matcher<MatcherPassA>();
/// anchor->add_matcher<MatcherPassB>();
/// anchor->set_name("CommonMatchers");
/// manager.run_passes(f);
///
/// If needed, a transformation can be registered but disabled by default.
///
/// anchor->add_matcher<MatcherPassB, false>();
///
/// \return shared_ptr to the transformation instance
template <typename T, bool Enabled = true, class... Args>
std::shared_ptr<T> add_matcher(Args&&... args)
{
static_assert(std::is_base_of<pass::MatcherPass, T>::value,
"pass not derived from MatcherPass");
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
auto pass_config = get_pass_config();
pass->set_pass_config(pass_config);
if (!Enabled)
{
pass_config->disable<T>();
}
m_matchers.push_back(pass);
return pass;
}
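A compilable sketch of the new Enabled template parameter in action, mirroring the TestPass/Anchor setup from the unit tests later in this diff; SwapTanhForRelu is an illustrative matcher, not part of the change.

#include <memory>

#include <ngraph/function.hpp>
#include <ngraph/graph_util.hpp>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
#include <ngraph/pass/manager.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>

// Toy matcher that replaces Tanh with Relu.
class SwapTanhForRelu : public ngraph::pass::MatcherPass {
public:
    NGRAPH_RTTI_DECLARATION;
    SwapTanhForRelu() {
        auto pattern = ngraph::pattern::wrap_type<ngraph::opset3::Tanh>();
        ngraph::matcher_pass_callback callback = [](ngraph::pattern::Matcher& m) {
            auto relu = std::make_shared<ngraph::opset3::Relu>(m.get_match_root()->input_value(0));
            ngraph::replace_node(m.get_match_root(), relu);
            return true;
        };
        register_matcher(std::make_shared<ngraph::pattern::Matcher>(pattern, "SwapTanhForRelu"), callback);
    }
};
NGRAPH_RTTI_DEFINITION(SwapTanhForRelu, "SwapTanhForRelu", 0);

void run_matchers(std::shared_ptr<ngraph::Function> f) {
    ngraph::pass::Manager manager;
    auto anchor = manager.register_pass<ngraph::pass::GraphRewrite>();
    anchor->add_matcher<SwapTanhForRelu, false>(); // registered but disabled by default
    anchor->set_name("CommonMatchers");

    // Nothing happens until the matcher is switched back on through the shared PassConfig.
    manager.get_pass_config()->enable<SwapTanhForRelu>();
    manager.run_passes(f);
}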

View File

@ -22,7 +22,6 @@
#include <vector>
#include "ngraph/pass/pass.hpp"
#include "ngraph/pass/pass_config.hpp"
#include "ngraph/pass/validate.hpp"
namespace ngraph
@ -39,14 +38,31 @@ public:
Manager();
~Manager();
template <typename T, class... Args>
/// \brief Register a given transformation class type in the execution list.
/// The example below shows the basic usage of pass::Manager
///
/// pass::Manager manager;
/// manager.register_pass<MyTransformation>(/*transformation constructor args*/);
/// manager.run_passes(f);
///
/// If needed, a transformation can be registered but disabled by default.
///
/// manager.register_pass<MyTransformation, false>();
///
/// \return shared_ptr to the transformation instance
template <typename T, bool Enable = true, class... Args>
std::shared_ptr<T> register_pass(Args&&... args)
{
auto rc = push_pass<T>(std::forward<Args>(args)...);
rc->set_pass_config(m_pass_config);
if (m_per_pass_validation)
{
push_pass<Validate>();
}
if (!Enable)
{
m_pass_config->disable<T>();
}
return rc;
}
@ -59,8 +75,10 @@ public:
void set_per_pass_validation(bool new_state) { m_per_pass_validation = new_state; }
/// \brief Callback is a lambda function that can be used by registered transformations.
/// The main purpose of this callback is to provide a way for plugins to disable/enable
/// transformations. In some cases plugins may want not to execute some transformations.
/// For example plugin can disable unpleasant decompositions because of performance reasons.
/// transformations based on certain conditions. In some cases plugins may prefer not to
/// execute certain transformations; for example, a plugin can disable unwanted
/// decompositions for performance reasons.
/// Callback example:
/// auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
/// return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) != nullptr;
@ -69,15 +87,22 @@ public:
/// decomposition pass will check whether this decomposition is needed or whether the plugin can
/// execute this operation directly. On the transformation side there must be a matching check for
/// this callback.
/// if (m_transformation_callback(batch_to_space)) {
/// if (transformation_callback(batch_to_space)) {
/// return false;
/// }
/// \param callback lambda function that returns true if the node is supported by the plugin and
/// the transformation is not needed
void set_callback(param_callback callback)
NGRAPH_DEPRECATED("Please use get_pass_config() to configure transformation pipeline")
void set_callback(const param_callback& callback) { m_pass_config->set_callback(callback); }
/// \return The shared PassConfig object used for transformation pipeline configuration.
/// This object allows disabling/enabling transformation execution and setting callbacks for
/// particular transformations. For more details see the PassConfig class.
std::shared_ptr<PassConfig> get_pass_config() { return m_pass_config; }
/// \brief Set an external PassConfig object; its state is copied into the Manager's own shared instance.
void set_pass_config(const std::shared_ptr<PassConfig>& pass_config)
{
m_transformation_callback = callback;
m_has_default_callback = false;
*m_pass_config = *pass_config;
}
protected:
@ -91,11 +116,7 @@ protected:
return pass;
}
param_callback m_transformation_callback = [](const std::shared_ptr<const Node>&) -> bool {
return false;
};
bool m_has_default_callback = true;
std::shared_ptr<PassConfig> m_pass_config;
std::vector<std::shared_ptr<PassBase>> m_pass_list;
bool m_visualize = false;
bool m_per_pass_validation = true;
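Usage-wise, the new Manager API looks as follows. A minimal sketch assuming the ConvertGELU pass from this PR; note that set_pass_config() copies the state of the given PassConfig into the Manager's own shared instance, which is what the use_count() == 1 assertion in the unit tests later in this diff checks.

#include <memory>

#include <ngraph/function.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/convert_gelu.hpp>

void configure_manager(std::shared_ptr<ngraph::Function> f) {
    ngraph::pass::Manager manager;
    manager.register_pass<ngraph::pass::ConvertGELU>();

    // An externally prepared config can be applied; its contents are copied into the
    // Manager's own shared PassConfig, so later changes to `external` have no effect here.
    auto external = std::make_shared<ngraph::pass::PassConfig>();
    external->disable<ngraph::pass::ConvertGELU>();
    manager.set_pass_config(external);

    // Further tuning goes through the Manager's shared PassConfig; this replaces the
    // deprecated Manager::set_callback().
    auto pass_config = manager.get_pass_config();
    pass_config->enable<ngraph::pass::ConvertGELU>();
    pass_config->set_callback(
        [](const std::shared_ptr<const ngraph::Node>&) { return false; });

    manager.run_passes(f);
}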

View File

@ -23,6 +23,7 @@
#include "ngraph/deprecated.hpp"
#include "ngraph/function.hpp"
#include "ngraph/node.hpp"
#include "ngraph/pass/pass_config.hpp"
#include "ngraph/util.hpp"
namespace ngraph
@ -39,7 +40,6 @@ namespace ngraph
typedef EnumMask<PassProperty> PassPropertyMask;
const PassPropertyMask all_pass_property_off;
using param_callback = std::function<bool(const std::shared_ptr<const ::ngraph::Node>)>;
class NGRAPH_API PassBase
{
@ -54,8 +54,40 @@ namespace ngraph
void set_name(const std::string& name) { m_name = name; }
std::string get_name() const;
/// \brief Set the global callback on this pass's PassConfig.
/// For more details see the PassConfig class documentation.
/// \param callback lambda function that takes a node and returns bool
void set_callback(const param_callback& callback);
/// \brief Set PassConfig for particular transformation instance
/// \param pass_config is a PassConfig shared_ptr
void set_pass_config(const std::shared_ptr<PassConfig>& pass_config)
{
m_pass_config = pass_config;
}
/// \brief Allows access to the shared PassConfig instance
/// \return Shared instance of the PassConfig class
std::shared_ptr<PassConfig> get_pass_config() { return m_pass_config; }
/// \brief Applies the callback for the given node. By default the callback returns false.
/// This method remains here only for backward compatibility and will be removed
/// once all transformations are moved to the transformation_callback() method.
/// \return result of the callback execution for the given node
NGRAPH_DEPRECATED("Please use transformation_callback method instead")
bool m_transformation_callback(const std::shared_ptr<const Node>& node)
{
return m_pass_config->get_callback(get_type_info())(node);
}
/// \brief Applies the callback for the given node. By default the callback returns false.
/// \param node node that will be passed to the callback
/// \return result of the callback execution for the given node
bool transformation_callback(const std::shared_ptr<const Node>& node)
{
return m_pass_config->get_callback(get_type_info())(node);
}
using type_info_t = DiscreteTypeInfo;
virtual const type_info_t& get_type_info() const = 0;
@ -63,13 +95,11 @@ namespace ngraph
protected:
void set_property(const PassPropertyMask& prop, bool value);
param_callback m_transformation_callback =
[](const std::shared_ptr<const Node>&) -> bool { return false; };
bool m_has_default_callback = true;
private:
PassPropertyMask m_property;
std::string m_name;
std::shared_ptr<PassConfig> m_pass_config;
};
class NGRAPH_API FunctionPass : public PassBase
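On the transformation side, the new transformation_callback() helper (and the deprecated m_transformation_callback() shim above) is how a pass consults the callback configured through PassConfig. Below is a minimal sketch of a decomposition pass that backs off when the plugin reports native support; MyHSwishDecomposition and its decomposition formula are illustrative and are not the HSwishDecomposition pass shipped with this PR.

#include <memory>

#include <ngraph/graph_util.hpp>
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>

class MyHSwishDecomposition : public ngraph::pass::MatcherPass {
public:
    NGRAPH_RTTI_DECLARATION;
    MyHSwishDecomposition() {
        auto pattern = ngraph::pattern::wrap_type<ngraph::opset4::HSwish>();
        ngraph::matcher_pass_callback callback = [this](ngraph::pattern::Matcher& m) {
            auto hswish = m.get_match_root();
            // Response to the callback configured via PassConfig: true means the plugin
            // executes HSwish directly, so the decomposition must not run.
            if (transformation_callback(hswish)) {
                return false;
            }
            // x * Clamp(x + 3, 0, 6) / 6
            auto x = hswish->input_value(0);
            auto add = std::make_shared<ngraph::opset4::Add>(
                x, ngraph::opset4::Constant::create(x.get_element_type(), ngraph::Shape{}, {3.0f}));
            auto clamp = std::make_shared<ngraph::opset4::Clamp>(add, 0.0f, 6.0f);
            auto mul = std::make_shared<ngraph::opset4::Multiply>(x, clamp);
            auto div = std::make_shared<ngraph::opset4::Divide>(
                mul, ngraph::opset4::Constant::create(x.get_element_type(), ngraph::Shape{}, {6.0f}));
            div->set_friendly_name(hswish->get_friendly_name());
            ngraph::replace_node(hswish, div);
            return true;
        };
        register_matcher(std::make_shared<ngraph::pattern::Matcher>(pattern, "MyHSwishDecomposition"), callback);
    }
};
NGRAPH_RTTI_DEFINITION(MyHSwishDecomposition, "MyHSwishDecomposition", 0);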

View File

@ -16,31 +16,157 @@
#pragma once
#include <map>
#include <string>
#include <list>
#include <memory>
#include <vector>
#include <ngraph/ngraph_visibility.hpp>
#include "ngraph/deprecated.hpp"
#include "ngraph/function.hpp"
#include "ngraph/node.hpp"
#include "ngraph/util.hpp"
namespace ngraph
{
namespace pass
{
class PassConfig;
using param_callback = std::function<bool(const std::shared_ptr<const ::ngraph::Node>)>;
using param_callback_map = std::map<ngraph::DiscreteTypeInfo, param_callback>;
/// \brief Class representing a transformation configuration that is used for disabling/enabling
/// transformations registered inside pass::Manager and also allows setting a callback for all
/// transformations or for a particular transformation.
///
/// When a pass::Manager is created, all passes registered inside it, including nested
/// passes, share the same PassConfig instance.
/// To work with this class, first obtain the shared instance by calling the
/// manager.get_pass_config() method. Then you can disable/enable passes based
/// on the transformation type_info. For example:
///
/// pass::Manager manager;
/// manager.register_pass<CommonOptimizations>();
/// auto pass_config = manager.get_pass_config();
/// pass_config->disable<ConvertGELU>(); // this will disable nested pass inside
/// // CommonOptimizations pipeline
/// manager.run_passes(f);
///
/// Sometimes a transformation needs to be called manually inside another transformation. In
/// that case, before running it, check that the pass is not disabled and then set the
/// current PassConfig instance on that transformation. For example:
///
/// // Inside MatcherPass callback or inside FunctionPass run_on_function() method
/// // you need to call get_pass_config() method to get shared instance of PassConfig
/// auto pass_config = get_pass_config();
///
/// // Before running the nested transformation check whether it is disabled or not
/// if (!pass_config->is_disabled<ConvertGELU>()) {
/// auto pass = ConvertGELU();
/// pass.set_pass_config(pass_config);
/// pass.apply(node);
/// }
///
/// Following this logic inside your transformations guarantees that they
/// will be executed in the right way.
class NGRAPH_API PassConfig
{
public:
/// \brief Disable transformation by its type_info
/// \param type_info Transformation type_info
void disable(const DiscreteTypeInfo& type_info) { m_disabled.insert(type_info); }
/// \brief Disable transformation by its class type (based on type_info)
template <typename T>
void disable()
{
disable(T::type_info);
}
/// \brief Enable transformation by its type_info
/// \param type_info Transformation type_info
void enable(const DiscreteTypeInfo& type_info) { m_disabled.erase(type_info); }
/// \brief Enable transformation by its class type (based on type_info)
template <typename T>
void enable()
{
enable(T::type_info);
}
/// \brief Set callback for all kinds of transformations
void set_callback(const param_callback& callback) { m_callback = callback; }
template <typename... Args>
typename std::enable_if<sizeof...(Args) == 0>::type
set_callback(const param_callback& callback)
{
}
/// \brief Set callback for particular transformation class types
///
/// The example below shows how to set a callback for one or multiple passes using this method.
///
/// pass_config->set_callback<ngraph::pass::ConvertBatchToSpace,
/// ngraph::pass::ConvertSpaceToBatch>(
/// [](const_node_ptr &node) -> bool {
/// // Disable transformations for cases when input shape rank is not
/// equal to 4
/// const auto input_shape_rank =
/// node->get_output_partial_shape(0).rank().get_length();
/// if (input_shape_rank != 4) {
/// return false;
/// }
/// return true;
/// });
///
/// Note that inside transformations you must provide code that works with this callback.
/// See example below:
///
/// if (transformation_callback(node)) {
/// return false; // exit from transformation
/// }
///
template <typename T, class... Args>
void set_callback(const param_callback& callback)
{
m_callback_map[T::type_info] = callback;
set_callback<Args...>(callback);
}
/// \brief Get callback for a given transformation type_info
/// \param type_info Transformation type_info
///
/// If no callback was set for the given transformation type, the global callback is
/// returned; if the global callback was not set either, the default callback is
/// returned.
param_callback get_callback(const DiscreteTypeInfo& type_info) const;
/// \brief Get callback for given transformation class type
/// \return callback lambda function
template <typename T>
param_callback get_callback() const
{
return get_callback(T::type_info);
}
/// \brief Check whether the given transformation type is disabled or not
/// \param type_info Transformation type_info
/// \return true if transformation type was disabled and false otherwise
bool is_disabled(const DiscreteTypeInfo& type_info) const
{
return m_disabled.count(type_info);
}
/// \brief Check whether the given transformation class type is disabled or not
/// \return true if transformation type was disabled and false otherwise
template <typename T>
bool is_disabled() const
{
return is_disabled(T::type_info);
}
private:
param_callback m_callback = [](const std::shared_ptr<const ::ngraph::Node>&) {
return false;
};
param_callback_map m_callback_map;
std::unordered_set<DiscreteTypeInfo> m_disabled;
};
}
}
class NGRAPH_API ngraph::pass::PassConfig
{
public:
PassConfig();
const std::map<std::string, bool>& get_enables() const { return m_pass_enables; }
void set_pass_enable(const std::string& name, bool enable);
bool get_pass_enable(const std::string& name) const;
const std::map<std::string, bool>& get_pass_attributes() const { return m_pass_attributes; }
void set_pass_attribute(const std::string& name, bool enable);
bool get_pass_attribute(const std::string& name) const;
private:
std::map<std::string, bool> m_pass_enables;
std::map<std::string, bool> m_pass_attributes;
};
}
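The callback resolution order implemented by get_callback() is worth spelling out: a callback registered for a specific pass type wins over the global one, and the built-in default (which always returns false) is used only when neither has been set. A short sketch assuming the ConvertGELU and ConvertDepthToSpace passes referenced elsewhere in this diff.

#include <memory>

#include <ngraph/pass/pass_config.hpp>
#include <transformations/convert_depth_to_space.hpp>
#include <transformations/convert_gelu.hpp>

void show_callback_precedence(const std::shared_ptr<ngraph::pass::PassConfig>& pass_config) {
    // Global callback: applies to every transformation without a dedicated entry.
    pass_config->set_callback(
        [](const std::shared_ptr<const ngraph::Node>&) { return false; });

    // Per-pass callback: overrides the global one for ConvertGELU only.
    pass_config->set_callback<ngraph::pass::ConvertGELU>(
        [](const std::shared_ptr<const ngraph::Node>&) { return true; });

    auto gelu_cb = pass_config->get_callback<ngraph::pass::ConvertGELU>();        // the per-pass lambda
    auto dts_cb = pass_config->get_callback<ngraph::pass::ConvertDepthToSpace>(); // falls back to the global lambda
    (void)gelu_cb;
    (void)dts_cb;
}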

View File

@ -72,6 +72,7 @@ bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f)
OV_ITT_SCOPED_TASK(itt::domains::nGraph, "pass::GraphRewrite::run_on_function");
bool rewritten = false;
const auto& pass_config = get_pass_config();
// Initialize execution queue with nodes in topological order
deque<std::shared_ptr<Node>> nodes_to_run;
@ -85,6 +86,10 @@ bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f)
std::unordered_map<NodeTypeInfo, std::vector<size_t>> type_to_matcher;
for (size_t matcher_index = 0; matcher_index < m_matchers.size(); ++matcher_index)
{
// Skip passes that are disabled
if (pass_config->is_disabled(m_matchers[matcher_index]->get_type_info()))
continue;
auto matcher = m_matchers[matcher_index]->get_matcher();
if (!matcher)
{
@ -139,11 +144,6 @@ bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f)
return false;
}
if (!m_has_default_callback)
{
m_pass->set_callback(m_transformation_callback);
}
// Apply MatcherPass. In case if it returns true no other MatcherPasses will apply
// to this node
bool status = m_pass->apply(node);
@ -224,6 +224,10 @@ bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f)
{
for (auto& m_pass : m_matchers)
{
// Skip passes that are disabled
if (pass_config->is_disabled(m_pass->get_type_info()))
continue;
if (run_matcher_pass(m_pass, node))
{
rewritten = true;

View File

@ -36,6 +36,7 @@ using namespace ngraph;
pass::Manager::Manager()
: m_visualize(getenv_bool("NGRAPH_ENABLE_VISUALIZE_TRACING"))
, m_pass_config(std::make_shared<PassConfig>())
{
}
@ -56,12 +57,14 @@ void pass::Manager::run_passes(shared_ptr<Function> func)
bool function_changed = false;
for (auto& pass : m_pass_list)
{
pass_timer.start();
if (!m_has_default_callback)
if (m_pass_config->is_disabled(pass->get_type_info()))
{
pass->set_callback(m_transformation_callback);
NGRAPH_DEBUG << "Pass " << pass->get_name() << " is disabled";
continue;
}
pass_timer.start();
NGRAPH_SUPPRESS_DEPRECATED_START
if (auto matcher_pass = dynamic_pointer_cast<MatcherPass>(pass))
{

View File

@ -33,6 +33,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::NodePass, "ngraph::pass::NodePass", 0);
pass::PassBase::PassBase()
: m_property{all_pass_property_off}
, m_pass_config(std::make_shared<PassConfig>())
{
}
@ -73,8 +74,7 @@ std::string pass::PassBase::get_name() const
void pass::PassBase::set_callback(const param_callback& callback)
{
m_transformation_callback = callback;
m_has_default_callback = false;
m_pass_config->set_callback(callback);
}
// The symbols are required to be in a cpp file to work around an RTTI issue on Android LLVM

View File

@ -15,100 +15,18 @@
//*****************************************************************************
#include "ngraph/pass/pass_config.hpp"
#include "ngraph/env_util.hpp"
#include "ngraph/except.hpp"
#include "ngraph/log.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
// TODO: Add file-based configuration support
pass::PassConfig::PassConfig()
pass::param_callback pass::PassConfig::get_callback(const DiscreteTypeInfo& type_info) const
{
//
// Parses the semi-colon separated environment string passed through NGRAPH_PASS_ENABLES
// and returns the pass names and whether they should be enabled or disabled in the
// provided unordered_map. Implementation of pass selection is up to the backend
// E.g., NGRAPH_PASS_ENABLES="CoreFusion:0;LikeReplacement:1;CPUCollapseDims" would
// set disables on CoreFusion and enables on LikeReplacement and CPUCollapseDims
//
string pass_enables = getenv_string("NGRAPH_PASS_ENABLES");
if (!pass_enables.empty())
{
stringstream ss;
ss << pass_enables;
while (ss.good())
{
string substr;
getline(ss, substr, ';');
auto split_str = split(substr, ':', false);
switch (split_str.size())
{
case 1: m_pass_enables.emplace(split_str[0], true); break;
case 2: m_pass_enables.emplace(split_str[0], parse_string<bool>(split_str[1])); break;
default: throw ngraph_error("Unexpected string in NGRAPH_PASS_ENABLES: " + substr);
}
}
}
//
// Parses the semi-colon separated environment string passed through NGRAPH_PASS_ATTRIBUTES
// and returns the pass attributes and whether they should be enabled or disabled in the
// provided unordered_map. Naming of pass attributes is up to the backends.
//
// For example:
// NGRAPH_PASS_ATTRIBUTES="OptimizeForMemory=0;MemoryAssignment::ReuseMemory=1;UseDefaultLayouts"
// would set false on "OptimizeForMemory", true on "MemoryAssignment::ReuseMemory" and true on
// "UseDefaultLayouts"
//
static const string pass_attributes = getenv_string("NGRAPH_PASS_ATTRIBUTES");
if (!pass_attributes.empty())
{
stringstream ss;
ss << pass_attributes;
while (ss.good())
{
string substr;
getline(ss, substr, ';');
auto split_str = split(substr, '=', false);
switch (split_str.size())
{
case 1: m_pass_attributes.emplace(split_str[0], true); break;
case 2:
m_pass_attributes.emplace(split_str[0], parse_string<bool>(split_str[1]));
break;
default: throw ngraph_error("Unexpected string in NGRAPH_PASS_ATTRIBUTES: " + substr);
}
}
}
}
void pass::PassConfig::set_pass_enable(const string& name, bool enable)
{
m_pass_enables[name] = enable;
}
bool pass::PassConfig::get_pass_enable(const string& name) const
{
auto it = m_pass_enables.find(name);
if (it != m_pass_enables.end())
const auto& it = m_callback_map.find(type_info);
if (it != m_callback_map.end())
{
return it->second;
}
return false;
}
void pass::PassConfig::set_pass_attribute(const string& name, bool enable)
{
m_pass_attributes[name] = enable;
}
bool pass::PassConfig::get_pass_attribute(const string& name) const
{
auto it = m_pass_attributes.find(name);
if (it != m_pass_attributes.end())
else
{
return it->second;
return m_callback;
}
return false;
}

View File

@ -15,6 +15,7 @@ using namespace ngraph;
class TestPass : public ngraph::pass::MatcherPass
{
public:
NGRAPH_RTTI_DECLARATION;
TestPass()
: MatcherPass()
{
@ -39,12 +40,16 @@ public:
class Anchor : public ngraph::pass::GraphRewrite
{
public:
NGRAPH_RTTI_DECLARATION;
Anchor()
: GraphRewrite()
{
}
};
NGRAPH_RTTI_DEFINITION(TestPass, "TestPass", 0);
NGRAPH_RTTI_DEFINITION(Anchor, "Anchor", 0);
std::shared_ptr<Function> get_function()
{
auto data =
@ -93,7 +98,7 @@ TEST(GraphRewriteTest, GraphRewriteCallback)
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
}
TEST(GraphRewriteTest, ManagerCallback)
TEST(GraphRewriteTest, ManagerCallbackDeprecated)
{
auto f = get_function();
@ -106,6 +111,20 @@ TEST(GraphRewriteTest, ManagerCallback)
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
}
TEST(GraphRewriteTest, ManagerCallback)
{
auto f = get_function();
pass::Manager manager;
auto anchor = manager.register_pass<Anchor>();
anchor->add_matcher<TestPass>();
auto pass_config = manager.get_pass_config();
pass_config->set_callback(get_callback());
manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
}
TEST(GraphRewriteTest, ManagerCallback2)
{
auto f = get_function();
@ -244,4 +263,127 @@ TEST(GraphRewriteTest, TypeBasedMatcherPassOrder2)
anchor.run_on_function(f);
ASSERT_EQ(count_ops_of_type<opset3::Tanh>(f), 1);
}
TEST(PassConfigTest, Test1)
{
{
auto f = get_function();
pass::Manager manager;
manager.register_pass<TestPass>();
auto pass_config = manager.get_pass_config();
pass_config->set_callback(get_callback());
manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
}
{
auto f = get_function();
pass::Manager manager;
manager.register_pass<TestPass>();
auto pass_config = manager.get_pass_config();
pass_config->set_callback<TestPass>(get_callback());
manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
}
{
auto f = get_function();
pass::Manager manager;
manager.register_pass<TestPass>();
auto pass_config = std::make_shared<ngraph::pass::PassConfig>();
pass_config->set_callback<TestPass>(get_callback());
manager.set_pass_config(pass_config);
manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
}
{
auto f = get_function();
pass::Manager manager;
auto anchor = manager.register_pass<Anchor>();
anchor->add_matcher<TestPass>();
auto pass_config = anchor->get_pass_config();
pass_config->set_callback(get_callback());
manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
}
{
auto f = get_function();
pass::Manager manager;
auto anchor = manager.register_pass<Anchor>();
anchor->add_matcher<TestPass>();
auto pass_config = anchor->get_pass_config();
pass_config->set_callback<TestPass>(get_callback());
manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
}
{
auto pass_config = std::make_shared<pass::PassConfig>();
pass::Manager manager1;
pass::Manager manager2;
manager1.set_pass_config(pass_config);
manager2.set_pass_config(pass_config);
ASSERT_EQ(pass_config.use_count(), 1);
}
{
auto f = get_function();
pass::Manager manager;
manager.register_pass<TestPass>();
auto pass_config = manager.get_pass_config();
pass_config->set_callback<TestPass>(get_callback());
pass_config->disable<TestPass>();
manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 0);
pass_config->enable<TestPass>();
manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
}
{
auto f = get_function();
pass::Manager manager;
auto anchor = manager.register_pass<Anchor>();
anchor->add_matcher<TestPass>();
auto pass_config = manager.get_pass_config();
pass_config->set_callback<TestPass>(get_callback());
pass_config->disable<TestPass>();
manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 0);
pass_config->enable<TestPass>();
manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
}
}

View File

@ -51,6 +51,8 @@ set (SRC
pass/shape_relevance.hpp
)
disable_deprecated_warnings()
add_library(ngraph_backend SHARED ${SRC})
target_compile_definitions(ngraph_backend
PRIVATE

View File

@ -102,14 +102,6 @@ std::shared_ptr<ngraph::runtime::Tensor>
throw std::invalid_argument("This backend does not support dynamic tensors");
}
std::shared_ptr<runtime::Executable>
runtime::Backend::compile(std::shared_ptr<Function> func,
ngraph::pass::PassConfig& /* pass_config */,
bool enable_performance_data)
{
return compile(func, enable_performance_data);
}
bool runtime::Backend::is_supported(const Node& /* node */) const
{
// The default behavior is that a backend does not support any ops. If this is not the case

View File

@ -22,7 +22,6 @@
#include "backend_visibility.hpp"
#include "executable.hpp"
#include "ngraph/function.hpp"
#include "ngraph/pass/pass_config.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
#include "ngraph/util.hpp"
@ -111,14 +110,6 @@ public:
virtual std::shared_ptr<Executable> compile(std::shared_ptr<Function> func,
bool enable_performance_data = false) = 0;
/// \brief Compiles a Function.
/// \param func The function to compile
/// \param pass_config Configuration object for defining compilation options
/// \returns compiled function or nullptr on failure
virtual std::shared_ptr<Executable> compile(std::shared_ptr<Function> func,
ngraph::pass::PassConfig& pass_config,
bool enable_performance_data = false);
/// \brief Loads a previously saved Executable object from a stream.
/// \param input_stream the opened input stream containing the saved Executable
/// \returns A compiled function or throws an exception on error