Fine-Grain Transformation pipeline tuning (#2547)
* Initial version of transformation callback refactoring * Improved fine-grain tuning for transformation pipeline * Check disabled matchers in GraphRewrite * Avoid deprecated classes inside PassConfig * Enabled DepthToSpace fusion by default * Removed doulbe search in map * Moved back pass_config.hpp; Added doxygen documentation for new class and methods * Added doxygen comment for Manager and GraphRewrite new mthods
This commit is contained in:
parent
da625b995e
commit
2e49b4e4d8
@ -27,8 +27,21 @@
|
||||
#include <transformations/convert_opset1_to_legacy/convert_prior_to_ie_prior.hpp>
|
||||
#include <transformations/convert_opset2_to_opset1/convert_opset2_to_opset1.hpp>
|
||||
#include <transformations/convert_opset3_to_opset2/convert_opset3_to_opset2.hpp>
|
||||
#include <transformations/convert_depth_to_space.hpp>
|
||||
#include <transformations/convert_space_to_depth.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <transformations/convert_precision.hpp>
|
||||
#include <transformations/convert_opset1_to_legacy/reshape_fully_connected.hpp>
|
||||
#include <transformations/convert_gelu.hpp>
|
||||
#include <transformations/depth_to_space_fusion.hpp>
|
||||
#include <transformations/convert_batch_to_space.hpp>
|
||||
#include <transformations/convert_extract_image_patches_to_reorg_yolo.hpp>
|
||||
#include <transformations/hswish_decomposition.hpp>
|
||||
#include <transformations/reduce_l1_decomposition.hpp>
|
||||
#include <transformations/reduce_l2_decomposition.hpp>
|
||||
#include <transformations/convert_space_to_batch.hpp>
|
||||
#include <transformations/softplus_decomposition.hpp>
|
||||
#include <transformations/convert_pad_to_group_conv.hpp>
|
||||
#include <transformations/rt_info/fused_names_attribute.hpp>
|
||||
#include <ngraph/opsets/opset2.hpp>
|
||||
#include <ngraph/opsets/opset3.hpp>
|
||||
@ -63,31 +76,6 @@ Engine::~Engine() {
|
||||
static void Transformation(ICNNNetwork::Ptr& clonedNetwork) {
|
||||
OV_ITT_SCOPED_TASK(MKLDNNPlugin::itt::domains::MKLDNNPlugin, "Transformation");
|
||||
|
||||
const auto transformations_callback = [](const std::shared_ptr<const ::ngraph::Node> &node) -> bool {
|
||||
// DepthToSpace node implementation supports only equal input/output tensors with rank <= 5
|
||||
if (auto dtsOp = std::dynamic_pointer_cast<const ::ngraph::opset3::DepthToSpace>(node)) {
|
||||
return dtsOp->input_value(0).get_shape().size() <= 5lu && dtsOp->input_value(0).get_shape().size() == dtsOp->get_output_shape(0).size();
|
||||
}
|
||||
|
||||
// SpaceToDepth node implementation supports only equal input/output tensors with rank <= 5
|
||||
if (auto stdOp = std::dynamic_pointer_cast<const ::ngraph::opset3::SpaceToDepth>(node)) {
|
||||
return stdOp->input_value(0).get_shape().size() <= 5lu && stdOp->input_value(0).get_shape().size() == stdOp->get_output_shape(0).size();
|
||||
}
|
||||
|
||||
if (auto fc_op = std::dynamic_pointer_cast<const ngraph::op::FullyConnected>(node)) {
|
||||
return fc_op->input_value(0).get_shape().size() == 3ul;
|
||||
}
|
||||
|
||||
return std::dynamic_pointer_cast<const ngraph::opset2::Gelu>(node) ||
|
||||
std::dynamic_pointer_cast<const ngraph::opset2::BatchToSpace>(node) ||
|
||||
std::dynamic_pointer_cast<const ngraph::opset2::SpaceToBatch>(node) ||
|
||||
std::dynamic_pointer_cast<const ngraph::opset3::ExtractImagePatches>(node) ||
|
||||
std::dynamic_pointer_cast<const ngraph::opset4::HSwish>(node) ||
|
||||
std::dynamic_pointer_cast<const ngraph::opset4::ReduceL1>(node) ||
|
||||
std::dynamic_pointer_cast<const ngraph::opset4::ReduceL2>(node) ||
|
||||
std::dynamic_pointer_cast<const ngraph::opset4::SoftPlus>(node) ||
|
||||
std::dynamic_pointer_cast<const ngraph::opset4::Pad>(node);
|
||||
};
|
||||
auto nGraphFunc = clonedNetwork->getFunction();
|
||||
// Disable shape inference (WA for generic operations)
|
||||
ngraph::op::GenericIE::DisableReshape noReshape(nGraphFunc);
|
||||
@ -116,7 +104,41 @@ static void Transformation(ICNNNetwork::Ptr& clonedNetwork) {
|
||||
manager.register_pass<ngraph::pass::ConvertOpSet1ToLegacy>();
|
||||
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i64, ngraph::element::i32);
|
||||
|
||||
manager.set_callback(transformations_callback);
|
||||
auto pass_config = manager.get_pass_config();
|
||||
|
||||
using const_node_ptr = const std::shared_ptr<const ngraph::Node>;
|
||||
|
||||
// SpaceToDepth/ DepthToSpace node implementation supports only equal input/output tensors with rank <= 5
|
||||
pass_config->set_callback<ngraph::pass::ConvertSpaceToDepth,
|
||||
ngraph::pass::ConvertDepthToSpace>(
|
||||
[](const_node_ptr &node) -> bool {
|
||||
return node->input_value(0).get_shape().size() <= 5lu &&
|
||||
node->input_value(0).get_shape().size() == node->get_output_shape(0).size();
|
||||
});
|
||||
|
||||
// Disable FC reshaping for 3D case
|
||||
pass_config->set_callback<ngraph::pass::ReshapeFullyConnected>(
|
||||
[](const_node_ptr &node) -> bool {
|
||||
return node->input_value(0).get_shape().size() == 3ul;
|
||||
});
|
||||
|
||||
pass_config->set_callback<ngraph::pass::ConvertBatchToSpace,
|
||||
ngraph::pass::ConvertSpaceToBatch>(
|
||||
[](const_node_ptr &node) -> bool {
|
||||
const auto & rank = node->input(0).get_partial_shape().rank().get_length();
|
||||
return rank == 4lu || rank == 5lu;
|
||||
});
|
||||
|
||||
// List of enabled/disabled transformations
|
||||
pass_config->disable<ngraph::pass::ConvertGELU>();
|
||||
pass_config->disable<ngraph::pass::ConvertExtractImagePatchesToReorgYolo>();
|
||||
pass_config->disable<ngraph::pass::HSwishDecomposition>();
|
||||
pass_config->disable<ngraph::pass::ReduceL1Decomposition>();
|
||||
pass_config->disable<ngraph::pass::ReduceL2Decomposition>();
|
||||
pass_config->disable<ngraph::pass::SoftPlusDecomposition>();
|
||||
|
||||
pass_config->enable<ngraph::pass::ConvertPadToGroupConvolution>();
|
||||
|
||||
manager.run_passes(nGraphFunc);
|
||||
|
||||
clonedNetwork = InferenceEngine::details::convertFunctionToICNNNetwork(nGraphFunc, *clonedNetwork);
|
||||
|
@ -68,7 +68,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
|
||||
manager.register_pass<ngraph::pass::SoftPlusToMishFusion>();
|
||||
manager.register_pass<ngraph::pass::SwishFusion>();
|
||||
manager.register_pass<ngraph::pass::HSwishFusion>();
|
||||
manager.register_pass<ngraph::pass::ConvertPadToGroupConvolution>();
|
||||
manager.register_pass<ngraph::pass::ConvertPadToGroupConvolution, false>();
|
||||
manager.register_pass<ngraph::pass::NormalizeL2Fusion>();
|
||||
manager.register_pass<ngraph::pass::BidirectionalLSTMSequenceDecomposition>();
|
||||
manager.register_pass<ngraph::pass::BidirectionalRNNSequenceDecomposition>();
|
||||
@ -111,7 +111,8 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
|
||||
fq_fusions->add_matcher<ngraph::pass::PullTransposeThroughFQUp>();
|
||||
fq_fusions->set_name("ngraph::pass::FakeQuantizeFusions");
|
||||
|
||||
manager.set_callback(m_transformation_callback);
|
||||
// Propagate local PassConfig to internal pass::Manager
|
||||
manager.set_pass_config(get_pass_config());
|
||||
manager.run_passes(f);
|
||||
return true;
|
||||
}
|
||||
|
@ -18,7 +18,7 @@ ngraph::pass::ConvertDepthToSpace::ConvertDepthToSpace() {
|
||||
|
||||
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
|
||||
auto dts_node = std::dynamic_pointer_cast<ngraph::opset1::DepthToSpace> (m.get_match_root());
|
||||
if (!dts_node || m_transformation_callback(dts_node)) {
|
||||
if (!dts_node || transformation_callback(dts_node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -154,7 +154,7 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph
|
||||
|
||||
manager.register_pass<ngraph::pass::ConstantFolding>();
|
||||
|
||||
manager.set_callback(m_transformation_callback);
|
||||
manager.set_pass_config(get_pass_config());
|
||||
manager.run_passes(f);
|
||||
return true;
|
||||
}
|
||||
|
@ -24,7 +24,7 @@ bool ngraph::pass::ConvertOpSet2ToOpSet1::run_on_function(std::shared_ptr<ngraph
|
||||
manager.register_pass<ngraph::pass::ConvertSpaceToBatch>();
|
||||
manager.register_pass<ngraph::pass::ConvertBatchToSpace>();
|
||||
|
||||
manager.set_callback(m_transformation_callback);
|
||||
manager.set_pass_config(get_pass_config());
|
||||
manager.run_passes(f);
|
||||
return true;
|
||||
}
|
||||
|
@ -33,7 +33,7 @@ bool ngraph::pass::ConvertOpSet3ToOpSet2::run_on_function(std::shared_ptr<ngraph
|
||||
manager.register_pass<ngraph::pass::ConvertExtractImagePatchesToReorgYolo>();
|
||||
manager.register_pass<ngraph::pass::SoftPlusDecomposition>();
|
||||
|
||||
manager.set_callback(m_transformation_callback);
|
||||
manager.set_pass_config(get_pass_config());
|
||||
manager.run_passes(f);
|
||||
return true;
|
||||
}
|
||||
|
@ -19,7 +19,7 @@ ngraph::pass::ConvertPadToGroupConvolution::ConvertPadToGroupConvolution() {
|
||||
|
||||
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
|
||||
auto pad = std::dynamic_pointer_cast<ngraph::opset4::Pad> (m.get_match_root());
|
||||
if (!pad || !m_transformation_callback(pad) /* disabled by default */) {
|
||||
if (!pad) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -155,11 +155,6 @@ void ngraph::pass::DepthToSpaceFusion::depth_to_space_fusion() {
|
||||
std::make_shared<ngraph::opset3::DepthToSpace>(reshape_before->input_value(0), mode, block_size);
|
||||
depth_to_space->set_friendly_name(reshape_after->get_friendly_name());
|
||||
ngraph::copy_runtime_info({reshape_before, permute, reshape_after}, depth_to_space);
|
||||
|
||||
if (!m_transformation_callback(depth_to_space)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
ngraph::replace_node(reshape_after, depth_to_space);
|
||||
return true;
|
||||
};
|
||||
|
@ -126,12 +126,34 @@ public:
|
||||
m_matchers.push_back(pass);
|
||||
}
|
||||
|
||||
template <typename T, class... Args>
|
||||
/// \brief Register given transformation class type to GraphRewrite execution list
|
||||
/// All registered transformations will be executed in a single graph traversal.
|
||||
/// Example below show the basic usage of pass::GraphRewrite
|
||||
///
|
||||
/// pass::Manager manager;
|
||||
/// auto anchor = manager.register_pass<GraphRewrite>();
|
||||
/// anchor->add_matcher<MatcherPassA>();
|
||||
/// anchor->add_matcher<MatcherPassB>();
|
||||
/// anchor->set_name("CommonMathcers");
|
||||
/// manager.run_passes(f);
|
||||
///
|
||||
/// For some purposes transformation can be registered and disabled by default.
|
||||
///
|
||||
/// anchor->add_matcher<MatcherPassB, false>();
|
||||
///
|
||||
/// \return shared_ptr to the transformation instance
|
||||
template <typename T, bool Enabled = true, class... Args>
|
||||
std::shared_ptr<T> add_matcher(Args&&... args)
|
||||
{
|
||||
static_assert(std::is_base_of<pass::MatcherPass, T>::value,
|
||||
"pass not derived from MatcherPass");
|
||||
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
|
||||
auto pass_config = get_pass_config();
|
||||
pass->set_pass_config(pass_config);
|
||||
if (!Enabled)
|
||||
{
|
||||
pass_config->disable<T>();
|
||||
}
|
||||
m_matchers.push_back(pass);
|
||||
return pass;
|
||||
}
|
||||
|
@ -22,7 +22,6 @@
|
||||
#include <vector>
|
||||
|
||||
#include "ngraph/pass/pass.hpp"
|
||||
#include "ngraph/pass/pass_config.hpp"
|
||||
#include "ngraph/pass/validate.hpp"
|
||||
|
||||
namespace ngraph
|
||||
@ -39,14 +38,31 @@ public:
|
||||
Manager();
|
||||
~Manager();
|
||||
|
||||
template <typename T, class... Args>
|
||||
/// \brief Register given transformation class type to execution list
|
||||
/// Example below show the basic usage of pass::Manager
|
||||
///
|
||||
/// pass::Manager manager;
|
||||
/// manager.register_pass<MyTransformation>(/*transformation constructor ars*/);
|
||||
/// manager.run_passes(f);
|
||||
///
|
||||
/// For some purposes transformation can be registered and disabled by default.
|
||||
///
|
||||
/// manager.register_pass<MyTransformation, false>();
|
||||
///
|
||||
/// \return shared_ptr to the transformation instance
|
||||
template <typename T, bool Enable = true, class... Args>
|
||||
std::shared_ptr<T> register_pass(Args&&... args)
|
||||
{
|
||||
auto rc = push_pass<T>(std::forward<Args>(args)...);
|
||||
rc->set_pass_config(m_pass_config);
|
||||
if (m_per_pass_validation)
|
||||
{
|
||||
push_pass<Validate>();
|
||||
}
|
||||
if (!Enable)
|
||||
{
|
||||
m_pass_config->disable<T>();
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
|
||||
@ -59,8 +75,10 @@ public:
|
||||
void set_per_pass_validation(bool new_state) { m_per_pass_validation = new_state; }
|
||||
/// \brief Callback is a lambda function that can be used by registered transformations.
|
||||
/// The main purpose of this callback is to provide a way for plugins to disable/enable
|
||||
/// transformations. In some cases plugins may want not to execute some transformations.
|
||||
/// For example plugin can disable unpleasant decompositions because of performance reasons.
|
||||
/// transformations based on some conditions. In some cases plugins may want not to execute some
|
||||
/// transformations.
|
||||
/// For example plugin can disable unpleasant decompositions because of performance reasons for
|
||||
/// some cases.
|
||||
/// Callback example:
|
||||
/// auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
|
||||
/// return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) != nullptr;
|
||||
@ -69,15 +87,22 @@ public:
|
||||
/// decomposition pass will check is this decomposition needed or plugin can execute this
|
||||
/// operation directly. And of course on transformation side we need to have a response for this
|
||||
/// callback.
|
||||
/// if (m_transformation_callback(batch_to_space)) {
|
||||
/// if (transformation_callback(batch_to_space)) {
|
||||
/// return false;
|
||||
/// }
|
||||
/// \param callback lamda function that returns true in case if node is supported by plugin and
|
||||
/// transformation is not needed
|
||||
void set_callback(param_callback callback)
|
||||
NGRAPH_DEPRECATED("Please use get_pass_config() to configure transformation pipeline")
|
||||
void set_callback(const param_callback& callback) { m_pass_config->set_callback(callback); }
|
||||
/// \return PassConfig shared object. This object is used for transformations pipeline
|
||||
/// configuration.
|
||||
/// This object allows to disable/enable transformations execution, set callback to particular
|
||||
/// transformation. For mo details see PassConfig class.
|
||||
std::shared_ptr<PassConfig> get_pass_config() { return m_pass_config; }
|
||||
/// \brief Set external PassConfig object.
|
||||
void set_pass_config(const std::shared_ptr<PassConfig>& pass_config)
|
||||
{
|
||||
m_transformation_callback = callback;
|
||||
m_has_default_callback = false;
|
||||
*m_pass_config = *pass_config;
|
||||
}
|
||||
|
||||
protected:
|
||||
@ -91,11 +116,7 @@ protected:
|
||||
return pass;
|
||||
}
|
||||
|
||||
param_callback m_transformation_callback = [](const std::shared_ptr<const Node>&) -> bool {
|
||||
return false;
|
||||
};
|
||||
bool m_has_default_callback = true;
|
||||
|
||||
std::shared_ptr<PassConfig> m_pass_config;
|
||||
std::vector<std::shared_ptr<PassBase>> m_pass_list;
|
||||
bool m_visualize = false;
|
||||
bool m_per_pass_validation = true;
|
||||
|
@ -23,6 +23,7 @@
|
||||
#include "ngraph/deprecated.hpp"
|
||||
#include "ngraph/function.hpp"
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/pass/pass_config.hpp"
|
||||
#include "ngraph/util.hpp"
|
||||
|
||||
namespace ngraph
|
||||
@ -39,7 +40,6 @@ namespace ngraph
|
||||
|
||||
typedef EnumMask<PassProperty> PassPropertyMask;
|
||||
const PassPropertyMask all_pass_property_off;
|
||||
using param_callback = std::function<bool(const std::shared_ptr<const ::ngraph::Node>)>;
|
||||
|
||||
class NGRAPH_API PassBase
|
||||
{
|
||||
@ -54,8 +54,40 @@ namespace ngraph
|
||||
void set_name(const std::string& name) { m_name = name; }
|
||||
std::string get_name() const;
|
||||
|
||||
/// \brief Set callback for particular transformation type.
|
||||
/// This method set global callback. For more details see PassConfig class
|
||||
/// documentation.
|
||||
/// \param callback lambda function that takes node and returns bool
|
||||
void set_callback(const param_callback& callback);
|
||||
|
||||
/// \brief Set PassConfig for particular transformation instance
|
||||
/// \param pass_config is a PassConfig shared_ptr
|
||||
void set_pass_config(const std::shared_ptr<PassConfig>& pass_config)
|
||||
{
|
||||
m_pass_config = pass_config;
|
||||
}
|
||||
|
||||
/// \brief Allows to access PassConfig shared instance
|
||||
/// \return Shared instance of PassConfig class
|
||||
std::shared_ptr<PassConfig> get_pass_config() { return m_pass_config; }
|
||||
/// \brief Applies callback for given node. By default callback returns false.
|
||||
/// This method remains here only for backward compatibility and will be removed
|
||||
/// after all transformations are moved to transformation_callback() method.
|
||||
/// \return result of callback execution for given node
|
||||
NGRAPH_DEPRECATED("Please use transformation_callback method instead")
|
||||
bool m_transformation_callback(const std::shared_ptr<const Node>& node)
|
||||
{
|
||||
return m_pass_config->get_callback(get_type_info())(node);
|
||||
}
|
||||
|
||||
/// \brief Applies callback for given node. By default callback returns false.
|
||||
/// \param node which will be used inside callback
|
||||
/// \return result of callback execution for given node
|
||||
bool transformation_callback(const std::shared_ptr<const Node>& node)
|
||||
{
|
||||
return m_pass_config->get_callback(get_type_info())(node);
|
||||
}
|
||||
|
||||
using type_info_t = DiscreteTypeInfo;
|
||||
|
||||
virtual const type_info_t& get_type_info() const = 0;
|
||||
@ -63,13 +95,11 @@ namespace ngraph
|
||||
protected:
|
||||
void set_property(const PassPropertyMask& prop, bool value);
|
||||
|
||||
param_callback m_transformation_callback =
|
||||
[](const std::shared_ptr<const Node>&) -> bool { return false; };
|
||||
bool m_has_default_callback = true;
|
||||
|
||||
private:
|
||||
PassPropertyMask m_property;
|
||||
|
||||
std::string m_name;
|
||||
std::shared_ptr<PassConfig> m_pass_config;
|
||||
};
|
||||
|
||||
class NGRAPH_API FunctionPass : public PassBase
|
||||
|
@ -16,31 +16,157 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <list>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include <ngraph/ngraph_visibility.hpp>
|
||||
#include "ngraph/deprecated.hpp"
|
||||
#include "ngraph/function.hpp"
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/util.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace pass
|
||||
{
|
||||
class PassConfig;
|
||||
using param_callback = std::function<bool(const std::shared_ptr<const ::ngraph::Node>)>;
|
||||
using param_callback_map = std::map<ngraph::DiscreteTypeInfo, param_callback>;
|
||||
|
||||
/// \brief Class representing a transformations config that is used for disabling/enabling
|
||||
/// transformations registered inside pass::Manager and also allows to set callback for all
|
||||
/// transformations or for particular transformation.
|
||||
///
|
||||
/// When pass::Manager is created all passes registered inside this manager including nested
|
||||
/// passes will share the same instance of PassConfig class.
|
||||
/// To work with this class first you need to get shared instance of this class by calling
|
||||
/// manager.get_pass_config() method. Then you will be able to disable/enable passes based
|
||||
/// on transformations type_info. For example:
|
||||
///
|
||||
/// pass::Manager manager;
|
||||
/// manager.register_pass<CommonOptimizations>();
|
||||
/// auto pass_config = manager.get_pass_config();
|
||||
/// pass_config->disable<ConvertGELU>(); // this will disable nested pass inside
|
||||
/// // CommonOptimizations pipeline
|
||||
/// manager.run_passes(f);
|
||||
///
|
||||
/// Sometimes it is needed to call transformation inside other transformation manually. And
|
||||
/// for that case before running transformation you need manually check that this pass is
|
||||
/// not disabled and then you need to set current PassConfig instance to this
|
||||
/// transformation. For example:
|
||||
///
|
||||
/// // Inside MatcherPass callback or inside FunctionPass run_on_function() method
|
||||
/// // you need to call get_pass_config() method to get shared instance of PassConfig
|
||||
/// auto pass_config = get_pass_config();
|
||||
///
|
||||
/// // Before running nested transformation you need to check is it disabled or not
|
||||
/// if (!pass_config->is_disabled<ConvertGELU>()) {
|
||||
/// auto pass = ConvertGELU();
|
||||
/// pass->set_pass_config(pass_config);
|
||||
/// pass.apply(node);
|
||||
/// }
|
||||
///
|
||||
/// Following this logic inside your transformations you will guaranty that transformations
|
||||
/// will be executed in a right way.
|
||||
class NGRAPH_API PassConfig
|
||||
{
|
||||
public:
|
||||
/// \brief Disable transformation by its type_info
|
||||
/// \param type_info Transformation type_info
|
||||
void disable(const DiscreteTypeInfo& type_info) { m_disabled.insert(type_info); }
|
||||
/// \brief Disable transformation by its class type (based on type_info)
|
||||
template <typename T>
|
||||
void disable()
|
||||
{
|
||||
disable(T::type_info);
|
||||
}
|
||||
|
||||
/// \brief Enable transformation by its type_info
|
||||
/// \param type_info Transformation type_info
|
||||
void enable(const DiscreteTypeInfo& type_info) { m_disabled.erase(type_info); }
|
||||
/// \brief Enable transformation by its class type (based on type_info)
|
||||
template <typename T>
|
||||
void enable()
|
||||
{
|
||||
enable(T::type_info);
|
||||
}
|
||||
|
||||
/// \brief Set callback for all kind of transformations
|
||||
void set_callback(const param_callback& callback) { m_callback = callback; }
|
||||
template <typename... Args>
|
||||
typename std::enable_if<sizeof...(Args) == 0>::type
|
||||
set_callback(const param_callback& callback)
|
||||
{
|
||||
}
|
||||
|
||||
/// \brief Set callback for particular transformation class types
|
||||
///
|
||||
/// Example below show how to set callback for one or multiple passes using this method.
|
||||
///
|
||||
/// pass_config->set_callback<ngraph::pass::ConvertBatchToSpace,
|
||||
/// ngraph::pass::ConvertSpaceToBatch>(
|
||||
/// [](const_node_ptr &node) -> bool {
|
||||
/// // Disable transformations for cases when input shape rank is not
|
||||
/// equal to 4
|
||||
/// const auto input_shape_rank =
|
||||
/// node->get_output_partial_shape(0).rank().get_length();
|
||||
/// if (input_shape_rank != 4) {
|
||||
/// return false;
|
||||
/// }
|
||||
/// return true;
|
||||
/// });
|
||||
///
|
||||
/// Note that inside transformations you must provide code that work with this callback.
|
||||
/// See example below:
|
||||
///
|
||||
/// if (transformation_callback(node)) {
|
||||
/// return false; // exit from transformation
|
||||
/// }
|
||||
///
|
||||
template <typename T, class... Args>
|
||||
void set_callback(const param_callback& callback)
|
||||
{
|
||||
m_callback_map[T::type_info] = callback;
|
||||
set_callback<Args...>(callback);
|
||||
}
|
||||
|
||||
/// \brief Get callback for given transformation type_info
|
||||
/// \param type_info Transformation type_info
|
||||
///
|
||||
/// In case if callback wasn't set for given transformation type then global callback
|
||||
/// will be returned. But if even global callback wasn't set then default callback will
|
||||
/// be returned.
|
||||
param_callback get_callback(const DiscreteTypeInfo& type_info) const;
|
||||
|
||||
/// \brief Get callback for given transformation class type
|
||||
/// \return callback lambda function
|
||||
template <typename T>
|
||||
param_callback get_callback() const
|
||||
{
|
||||
return get_callback(T::type_info);
|
||||
}
|
||||
|
||||
/// \brief Check either transformation type is disabled or not
|
||||
/// \param type_info Transformation type_info
|
||||
/// \return true if transformation type was disabled and false otherwise
|
||||
bool is_disabled(const DiscreteTypeInfo& type_info) const
|
||||
{
|
||||
return m_disabled.count(type_info);
|
||||
}
|
||||
|
||||
/// \brief Check either transformation class type is disabled or not
|
||||
/// \return true if transformation type was disabled and false otherwise
|
||||
template <typename T>
|
||||
bool is_disabled() const
|
||||
{
|
||||
return is_disabled(T::type_info);
|
||||
}
|
||||
|
||||
private:
|
||||
param_callback m_callback = [](const std::shared_ptr<const ::ngraph::Node>&) {
|
||||
return false;
|
||||
};
|
||||
param_callback_map m_callback_map;
|
||||
std::unordered_set<DiscreteTypeInfo> m_disabled;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
class NGRAPH_API ngraph::pass::PassConfig
|
||||
{
|
||||
public:
|
||||
PassConfig();
|
||||
const std::map<std::string, bool>& get_enables() const { return m_pass_enables; }
|
||||
void set_pass_enable(const std::string& name, bool enable);
|
||||
bool get_pass_enable(const std::string& name) const;
|
||||
const std::map<std::string, bool>& get_pass_attributes() const { return m_pass_attributes; }
|
||||
void set_pass_attribute(const std::string& name, bool enable);
|
||||
bool get_pass_attribute(const std::string& name) const;
|
||||
|
||||
private:
|
||||
std::map<std::string, bool> m_pass_enables;
|
||||
std::map<std::string, bool> m_pass_attributes;
|
||||
};
|
||||
}
|
@ -72,6 +72,7 @@ bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f)
|
||||
OV_ITT_SCOPED_TASK(itt::domains::nGraph, "pass::GraphRewrite::run_on_function");
|
||||
|
||||
bool rewritten = false;
|
||||
const auto& pass_config = get_pass_config();
|
||||
|
||||
// Initialize execution queue with nodes in topological order
|
||||
deque<std::shared_ptr<Node>> nodes_to_run;
|
||||
@ -85,6 +86,10 @@ bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f)
|
||||
std::unordered_map<NodeTypeInfo, std::vector<size_t>> type_to_matcher;
|
||||
for (size_t matcher_index = 0; matcher_index < m_matchers.size(); ++matcher_index)
|
||||
{
|
||||
// Skip passes that are disabled
|
||||
if (pass_config->is_disabled(m_matchers[matcher_index]->get_type_info()))
|
||||
continue;
|
||||
|
||||
auto matcher = m_matchers[matcher_index]->get_matcher();
|
||||
if (!matcher)
|
||||
{
|
||||
@ -139,11 +144,6 @@ bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f)
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!m_has_default_callback)
|
||||
{
|
||||
m_pass->set_callback(m_transformation_callback);
|
||||
}
|
||||
|
||||
// Apply MatcherPass. In case if it returns true no other MatcherPasses will apply
|
||||
// to this node
|
||||
bool status = m_pass->apply(node);
|
||||
@ -224,6 +224,10 @@ bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f)
|
||||
{
|
||||
for (auto& m_pass : m_matchers)
|
||||
{
|
||||
// Skip passes that are disabled
|
||||
if (pass_config->is_disabled(m_pass->get_type_info()))
|
||||
continue;
|
||||
|
||||
if (run_matcher_pass(m_pass, node))
|
||||
{
|
||||
rewritten = true;
|
||||
|
@ -36,6 +36,7 @@ using namespace ngraph;
|
||||
|
||||
pass::Manager::Manager()
|
||||
: m_visualize(getenv_bool("NGRAPH_ENABLE_VISUALIZE_TRACING"))
|
||||
, m_pass_config(std::make_shared<PassConfig>())
|
||||
{
|
||||
}
|
||||
|
||||
@ -56,12 +57,14 @@ void pass::Manager::run_passes(shared_ptr<Function> func)
|
||||
bool function_changed = false;
|
||||
for (auto& pass : m_pass_list)
|
||||
{
|
||||
pass_timer.start();
|
||||
if (!m_has_default_callback)
|
||||
if (m_pass_config->is_disabled(pass->get_type_info()))
|
||||
{
|
||||
pass->set_callback(m_transformation_callback);
|
||||
NGRAPH_DEBUG << "Pass " << pass->get_name() << " is disabled";
|
||||
continue;
|
||||
}
|
||||
|
||||
pass_timer.start();
|
||||
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
if (auto matcher_pass = dynamic_pointer_cast<MatcherPass>(pass))
|
||||
{
|
||||
|
@ -33,6 +33,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::NodePass, "ngraph::pass::NodePass", 0);
|
||||
|
||||
pass::PassBase::PassBase()
|
||||
: m_property{all_pass_property_off}
|
||||
, m_pass_config(std::make_shared<PassConfig>())
|
||||
{
|
||||
}
|
||||
|
||||
@ -73,8 +74,7 @@ std::string pass::PassBase::get_name() const
|
||||
|
||||
void pass::PassBase::set_callback(const param_callback& callback)
|
||||
{
|
||||
m_transformation_callback = callback;
|
||||
m_has_default_callback = false;
|
||||
m_pass_config->set_callback(callback);
|
||||
}
|
||||
|
||||
// The symbols are requiered to be in cpp file to workaround RTTI issue on Android LLVM
|
||||
|
@ -15,100 +15,18 @@
|
||||
//*****************************************************************************
|
||||
|
||||
#include "ngraph/pass/pass_config.hpp"
|
||||
#include "ngraph/env_util.hpp"
|
||||
#include "ngraph/except.hpp"
|
||||
#include "ngraph/log.hpp"
|
||||
#include "ngraph/util.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
// TODO: Add file-based configuration support
|
||||
pass::PassConfig::PassConfig()
|
||||
pass::param_callback pass::PassConfig::get_callback(const DiscreteTypeInfo& type_info) const
|
||||
{
|
||||
//
|
||||
// Parses the semi-colon separated environment string passed through NGRAPH_PASS_ENABLES
|
||||
// and returns the pass names and whether they should be enabled or disabled in the
|
||||
// provided unordered_map. Implementation of pass selection is up to the backend
|
||||
// E.g., NGRAPH_PASS_ENABLES="CoreFusion:0;LikeReplacement:1;CPUCollapseDims" would
|
||||
// set disables on CoreFusion and enables on LikeReplacement and CPUCollapseDims
|
||||
//
|
||||
string pass_enables = getenv_string("NGRAPH_PASS_ENABLES");
|
||||
if (!pass_enables.empty())
|
||||
{
|
||||
stringstream ss;
|
||||
ss << pass_enables;
|
||||
while (ss.good())
|
||||
{
|
||||
string substr;
|
||||
getline(ss, substr, ';');
|
||||
auto split_str = split(substr, ':', false);
|
||||
switch (split_str.size())
|
||||
{
|
||||
case 1: m_pass_enables.emplace(split_str[0], true); break;
|
||||
case 2: m_pass_enables.emplace(split_str[0], parse_string<bool>(split_str[1])); break;
|
||||
default: throw ngraph_error("Unexpected string in NGRAPH_PASS_ENABLES: " + substr);
|
||||
}
|
||||
}
|
||||
}
|
||||
//
|
||||
// Parses the semi-colon separated environment string passed through NGRAPH_PASS_ATTRIBUTES
|
||||
// and returns the pass attributes and whether they should be enabled or disabled in the
|
||||
// provided unordered_map. Naming of pass attributes is up to the backends.
|
||||
//
|
||||
// For example:
|
||||
// NGRAPH_PASS_ATTRIBUTES="OptimizeForMemory=0;MemoryAssignment::ReuseMemory=1;UseDefaultLayouts"
|
||||
// would set false on "OptimizeForMemory", true on "MemoryAssignment::ReuseMemory" and true on
|
||||
// "UseDefaultLayouts"
|
||||
//
|
||||
static const string pass_attributes = getenv_string("NGRAPH_PASS_ATTRIBUTES");
|
||||
if (!pass_attributes.empty())
|
||||
{
|
||||
stringstream ss;
|
||||
ss << pass_attributes;
|
||||
while (ss.good())
|
||||
{
|
||||
string substr;
|
||||
getline(ss, substr, ';');
|
||||
auto split_str = split(substr, '=', false);
|
||||
switch (split_str.size())
|
||||
{
|
||||
case 1: m_pass_attributes.emplace(split_str[0], true); break;
|
||||
case 2:
|
||||
m_pass_attributes.emplace(split_str[0], parse_string<bool>(split_str[1]));
|
||||
break;
|
||||
default: throw ngraph_error("Unexpected string in NGRAPH_PASS_ATTRIBUTES: " + substr);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void pass::PassConfig::set_pass_enable(const string& name, bool enable)
|
||||
{
|
||||
m_pass_enables[name] = enable;
|
||||
}
|
||||
|
||||
bool pass::PassConfig::get_pass_enable(const string& name) const
|
||||
{
|
||||
auto it = m_pass_enables.find(name);
|
||||
if (it != m_pass_enables.end())
|
||||
const auto& it = m_callback_map.find(type_info);
|
||||
if (it != m_callback_map.end())
|
||||
{
|
||||
return it->second;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void pass::PassConfig::set_pass_attribute(const string& name, bool enable)
|
||||
{
|
||||
m_pass_attributes[name] = enable;
|
||||
}
|
||||
|
||||
bool pass::PassConfig::get_pass_attribute(const string& name) const
|
||||
{
|
||||
auto it = m_pass_attributes.find(name);
|
||||
if (it != m_pass_attributes.end())
|
||||
else
|
||||
{
|
||||
return it->second;
|
||||
return m_callback;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
@ -15,6 +15,7 @@ using namespace ngraph;
|
||||
class TestPass : public ngraph::pass::MatcherPass
|
||||
{
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
TestPass()
|
||||
: MatcherPass()
|
||||
{
|
||||
@ -39,12 +40,16 @@ public:
|
||||
class Anchor : public ngraph::pass::GraphRewrite
|
||||
{
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
Anchor()
|
||||
: GraphRewrite()
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(TestPass, "TestPass", 0);
|
||||
NGRAPH_RTTI_DEFINITION(Anchor, "Anchor", 0);
|
||||
|
||||
std::shared_ptr<Function> get_function()
|
||||
{
|
||||
auto data =
|
||||
@ -93,7 +98,7 @@ TEST(GraphRewriteTest, GraphRewriteCallback)
|
||||
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
|
||||
}
|
||||
|
||||
TEST(GraphRewriteTest, ManagerCallback)
|
||||
TEST(GraphRewriteTest, ManagerCallbackDeprecated)
|
||||
{
|
||||
auto f = get_function();
|
||||
|
||||
@ -106,6 +111,20 @@ TEST(GraphRewriteTest, ManagerCallback)
|
||||
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
|
||||
}
|
||||
|
||||
TEST(GraphRewriteTest, ManagerCallback)
|
||||
{
|
||||
auto f = get_function();
|
||||
|
||||
pass::Manager manager;
|
||||
auto anchor = manager.register_pass<Anchor>();
|
||||
anchor->add_matcher<TestPass>();
|
||||
auto pass_config = manager.get_pass_config();
|
||||
pass_config->set_callback(get_callback());
|
||||
manager.run_passes(f);
|
||||
|
||||
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
|
||||
}
|
||||
|
||||
TEST(GraphRewriteTest, ManagerCallback2)
|
||||
{
|
||||
auto f = get_function();
|
||||
@ -244,4 +263,127 @@ TEST(GraphRewriteTest, TypeBasedMatcherPassOrder2)
|
||||
anchor.run_on_function(f);
|
||||
|
||||
ASSERT_EQ(count_ops_of_type<opset3::Tanh>(f), 1);
|
||||
}
|
||||
|
||||
TEST(PassConfigTest, Test1)
|
||||
{
|
||||
{
|
||||
auto f = get_function();
|
||||
|
||||
pass::Manager manager;
|
||||
manager.register_pass<TestPass>();
|
||||
|
||||
auto pass_config = manager.get_pass_config();
|
||||
pass_config->set_callback(get_callback());
|
||||
|
||||
manager.run_passes(f);
|
||||
|
||||
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
|
||||
}
|
||||
|
||||
{
|
||||
auto f = get_function();
|
||||
|
||||
pass::Manager manager;
|
||||
manager.register_pass<TestPass>();
|
||||
|
||||
auto pass_config = manager.get_pass_config();
|
||||
pass_config->set_callback<TestPass>(get_callback());
|
||||
|
||||
manager.run_passes(f);
|
||||
|
||||
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
|
||||
}
|
||||
|
||||
{
|
||||
auto f = get_function();
|
||||
|
||||
pass::Manager manager;
|
||||
manager.register_pass<TestPass>();
|
||||
|
||||
auto pass_config = std::make_shared<ngraph::pass::PassConfig>();
|
||||
pass_config->set_callback<TestPass>(get_callback());
|
||||
|
||||
manager.set_pass_config(pass_config);
|
||||
manager.run_passes(f);
|
||||
|
||||
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
|
||||
}
|
||||
|
||||
{
|
||||
auto f = get_function();
|
||||
|
||||
pass::Manager manager;
|
||||
auto anchor = manager.register_pass<Anchor>();
|
||||
anchor->add_matcher<TestPass>();
|
||||
|
||||
auto pass_config = anchor->get_pass_config();
|
||||
pass_config->set_callback(get_callback());
|
||||
|
||||
manager.run_passes(f);
|
||||
|
||||
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
|
||||
}
|
||||
|
||||
{
|
||||
auto f = get_function();
|
||||
|
||||
pass::Manager manager;
|
||||
auto anchor = manager.register_pass<Anchor>();
|
||||
anchor->add_matcher<TestPass>();
|
||||
|
||||
auto pass_config = anchor->get_pass_config();
|
||||
pass_config->set_callback<TestPass>(get_callback());
|
||||
|
||||
manager.run_passes(f);
|
||||
|
||||
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
|
||||
}
|
||||
|
||||
{
|
||||
auto pass_config = std::make_shared<pass::PassConfig>();
|
||||
|
||||
pass::Manager manager1;
|
||||
pass::Manager manager2;
|
||||
manager1.set_pass_config(pass_config);
|
||||
manager2.set_pass_config(pass_config);
|
||||
ASSERT_EQ(pass_config.use_count(), 1);
|
||||
}
|
||||
|
||||
{
|
||||
auto f = get_function();
|
||||
|
||||
pass::Manager manager;
|
||||
manager.register_pass<TestPass>();
|
||||
|
||||
auto pass_config = manager.get_pass_config();
|
||||
pass_config->set_callback<TestPass>(get_callback());
|
||||
|
||||
pass_config->disable<TestPass>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 0);
|
||||
|
||||
pass_config->enable<TestPass>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
|
||||
}
|
||||
|
||||
{
|
||||
auto f = get_function();
|
||||
|
||||
pass::Manager manager;
|
||||
auto anchor = manager.register_pass<Anchor>();
|
||||
anchor->add_matcher<TestPass>();
|
||||
|
||||
auto pass_config = manager.get_pass_config();
|
||||
pass_config->set_callback<TestPass>(get_callback());
|
||||
|
||||
pass_config->disable<TestPass>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 0);
|
||||
|
||||
pass_config->enable<TestPass>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
|
||||
}
|
||||
}
|
@ -51,6 +51,8 @@ set (SRC
|
||||
pass/shape_relevance.hpp
|
||||
)
|
||||
|
||||
disable_deprecated_warnings()
|
||||
|
||||
add_library(ngraph_backend SHARED ${SRC})
|
||||
target_compile_definitions(ngraph_backend
|
||||
PRIVATE
|
||||
|
@ -102,14 +102,6 @@ std::shared_ptr<ngraph::runtime::Tensor>
|
||||
throw std::invalid_argument("This backend does not support dynamic tensors");
|
||||
}
|
||||
|
||||
std::shared_ptr<runtime::Executable>
|
||||
runtime::Backend::compile(std::shared_ptr<Function> func,
|
||||
ngraph::pass::PassConfig& /* pass_config */,
|
||||
bool enable_performance_data)
|
||||
{
|
||||
return compile(func, enable_performance_data);
|
||||
}
|
||||
|
||||
bool runtime::Backend::is_supported(const Node& /* node */) const
|
||||
{
|
||||
// The default behavior is that a backend does not support any ops. If this is not the case
|
||||
|
@ -22,7 +22,6 @@
|
||||
#include "backend_visibility.hpp"
|
||||
#include "executable.hpp"
|
||||
#include "ngraph/function.hpp"
|
||||
#include "ngraph/pass/pass_config.hpp"
|
||||
#include "ngraph/shape.hpp"
|
||||
#include "ngraph/type/element_type.hpp"
|
||||
#include "ngraph/util.hpp"
|
||||
@ -111,14 +110,6 @@ public:
|
||||
virtual std::shared_ptr<Executable> compile(std::shared_ptr<Function> func,
|
||||
bool enable_performance_data = false) = 0;
|
||||
|
||||
/// \brief Compiles a Function.
|
||||
/// \param func The function to compile
|
||||
/// \param pass_config Configuration object for defining compilation options
|
||||
/// \returns compiled function or nullptr on failure
|
||||
virtual std::shared_ptr<Executable> compile(std::shared_ptr<Function> func,
|
||||
ngraph::pass::PassConfig& pass_config,
|
||||
bool enable_performance_data = false);
|
||||
|
||||
/// \brief Loads a previously saved Executable object from a stream.
|
||||
/// \param input_stream the opened input stream containing the saved Executable
|
||||
/// \returns A compiled function or throws an exception on error
|
||||
|
Loading…
Reference in New Issue
Block a user