Generic LoadTime Optimizations (#4011)

* Updated container passes to return false to avoid excess function validation

* Added support for nested GraphRewrite registration

* Updated passes to use MatcherPass; Reorganized CommonOptimizations pipeline

* Disable node validation when graph is not modified
This commit is contained in:
Gleb Kazantaev 2021-01-27 11:53:11 +03:00 committed by GitHub
parent 68f2507811
commit b6d9dba444
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 306 additions and 200 deletions

View File

@ -25,15 +25,6 @@ class INFERENCE_ENGINE_API_CLASS(ConvertMatMulToGemm);
} // namespace pass
} // namespace ngraph
class ngraph::pass::ConvertMatMulToFCorGemm: public ngraph::pass::GraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;
ConvertMatMulToFCorGemm() {
add_matcher<ngraph::pass::ConvertMatMulToFC>();
add_matcher<ngraph::pass::ConvertMatMulToGemm>();
}
};
class ngraph::pass::ConvertMatMulToFC: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
@ -45,3 +36,12 @@ public:
NGRAPH_RTTI_DECLARATION;
ConvertMatMulToGemm();
};
class ngraph::pass::ConvertMatMulToFCorGemm: public ngraph::pass::GraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;
ConvertMatMulToFCorGemm() {
add_matcher<ngraph::pass::ConvertMatMulToFC>();
add_matcher<ngraph::pass::ConvertMatMulToGemm>();
}
};

View File

@ -22,16 +22,6 @@ class INFERENCE_ENGINE_API_CLASS(Reshape1DMaxPool);
} // namespace pass
} // namespace ngraph
class ngraph::pass::Reshape1DOps: public ngraph::pass::GraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;
Reshape1DOps() {
add_matcher<ngraph::pass::Reshape1DConvolution>();
add_matcher<ngraph::pass::Reshape1DAvgPool>();
add_matcher<ngraph::pass::Reshape1DMaxPool>();
}
};
class ngraph::pass::Reshape1DConvolution: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
@ -48,4 +38,14 @@ class ngraph::pass::Reshape1DMaxPool: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
Reshape1DMaxPool();
};
class ngraph::pass::Reshape1DOps: public ngraph::pass::GraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;
Reshape1DOps() {
add_matcher<ngraph::pass::Reshape1DConvolution>();
add_matcher<ngraph::pass::Reshape1DAvgPool>();
add_matcher<ngraph::pass::Reshape1DMaxPool>();
}
};

View File

@ -145,5 +145,10 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph
manager.register_pass<ngraph::pass::ConstantFolding>();
manager.run_passes(f);
return true;
// Returning value is false because pass::Manager always apply Validation pass
// if function was changed. This helps to avoid excess Validations after applying
// this pass. In future when we will return more meaningful status code it will be
// replaced with real status reported by manager.run_passes() method call.
return false;
}

View File

@ -29,8 +29,8 @@ class TRANSFORMATIONS_API AlgebraicSimplification;
} // namespace pass
} // namespace ngraph
class ngraph::pass::AlgebraicSimplification : public FunctionPass {
class ngraph::pass::AlgebraicSimplification : public GraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
AlgebraicSimplification();
};

View File

@ -19,25 +19,9 @@ class TRANSFORMATIONS_API HSigmoidFusionWithReluMul;
class TRANSFORMATIONS_API HSigmoidFusionWithoutRelu;
class TRANSFORMATIONS_API HSigmoidFusionWithClamp;
} // namespace pass
} // namespace ngraph
/**
* @ingroup ie_transformation_common_api
* @brief HSigmoidFusion transformation replaces various sub-graphs with a HSigmoid op.
*/
class ngraph::pass::HSigmoidFusion: public ngraph::pass::GraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;
HSigmoidFusion() {
add_matcher<ngraph::pass::HSigmoidFusionWithReluDiv>();
add_matcher<ngraph::pass::HSigmoidFusionWithReluMul>();
add_matcher<ngraph::pass::HSigmoidFusionWithoutRelu>();
add_matcher<ngraph::pass::HSigmoidFusionWithClamp>();
}
};
/**
* @ingroup ie_transformation_common_api
* @brief HSigmoidFusion transformation replaces a sub-graph (x * (min(Relu(x + 3), 6))) / 6 with a HSigmoid op.
@ -77,3 +61,18 @@ public:
NGRAPH_RTTI_DECLARATION;
HSigmoidFusionWithClamp();
};
/**
* @ingroup ie_transformation_common_api
* @brief HSigmoidFusion transformation replaces various sub-graphs with a HSigmoid op.
*/
class ngraph::pass::HSigmoidFusion: public ngraph::pass::GraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;
HSigmoidFusion() {
add_matcher<ngraph::pass::HSigmoidFusionWithReluDiv>();
add_matcher<ngraph::pass::HSigmoidFusionWithReluMul>();
add_matcher<ngraph::pass::HSigmoidFusionWithoutRelu>();
add_matcher<ngraph::pass::HSigmoidFusionWithClamp>();
}
};

View File

@ -23,21 +23,6 @@ class TRANSFORMATIONS_API HSwishFusionWithClamp;
} // namespace pass
} // namespace ngraph
/**
* @ingroup ie_transformation_common_api
* @brief HSwishFusion transformation replaces various sub-graphs with a HSwish op.
*/
class ngraph::pass::HSwishFusion: public ngraph::pass::GraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;
HSwishFusion() {
add_matcher<ngraph::pass::HSwishFusionWithReluDiv>();
add_matcher<ngraph::pass::HSwishFusionWithReluMul>();
add_matcher<ngraph::pass::HSwishFusionWithoutRelu>();
add_matcher<ngraph::pass::HSwishFusionWithClamp>();
}
};
/**
* @ingroup ie_transformation_common_api
* @brief HSwishFusion transformation replaces a sub-graph (x * (min(Relu(x + 3), 6))) / 6 with a HSwish op.
@ -77,3 +62,18 @@ public:
NGRAPH_RTTI_DECLARATION;
HSwishFusionWithClamp();
};
/**
* @ingroup ie_transformation_common_api
* @brief HSwishFusion transformation replaces various sub-graphs with a HSwish op.
*/
class ngraph::pass::HSwishFusion: public ngraph::pass::GraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;
HSwishFusion() {
add_matcher<ngraph::pass::HSwishFusionWithReluDiv>();
add_matcher<ngraph::pass::HSwishFusionWithReluMul>();
add_matcher<ngraph::pass::HSwishFusionWithoutRelu>();
add_matcher<ngraph::pass::HSwishFusionWithClamp>();
}
};

View File

@ -21,16 +21,6 @@ class TRANSFORMATIONS_API MultiplyMultiplyFusion;
} // namespace pass
} // namespace ngraph
class ngraph::pass::LinOpSequenceFusion: public ngraph::pass::GraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;
LinOpSequenceFusion() {
add_matcher<ngraph::pass::AddMultiplyFusion>();
add_matcher<ngraph::pass::AddAddFusion>();
add_matcher<ngraph::pass::MultiplyMultiplyFusion>();
}
};
class ngraph::pass::AddMultiplyFusion: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
@ -48,3 +38,13 @@ public:
NGRAPH_RTTI_DECLARATION;
MultiplyMultiplyFusion();
};
class ngraph::pass::LinOpSequenceFusion: public ngraph::pass::GraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;
LinOpSequenceFusion() {
add_matcher<ngraph::pass::AddMultiplyFusion>();
add_matcher<ngraph::pass::AddAddFusion>();
add_matcher<ngraph::pass::MultiplyMultiplyFusion>();
}
};

View File

@ -20,8 +20,8 @@ class TRANSFORMATIONS_API NopElimination;
} // namespace pass
} // namespace ngraph
class ngraph::pass::NopElimination: public ngraph::pass::FunctionPass {
class ngraph::pass::NopElimination: public GraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;
bool run_on_function(std::shared_ptr<Function>) override;
NopElimination();
};

View File

@ -21,19 +21,6 @@ class TRANSFORMATIONS_API NormalizeL2FusionWithAdd;
} // namespace pass
} // namespace ngraph
/**
* @ingroup ie_transformation_common_api
* @brief NormalizeL2Fusion transformation replaces various sub-graphs with a NormalizeL2 op.
*/
class ngraph::pass::NormalizeL2Fusion: public ngraph::pass::GraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;
NormalizeL2Fusion() {
add_matcher<ngraph::pass::NormalizeL2FusionWithMax>();
add_matcher<ngraph::pass::NormalizeL2FusionWithAdd>();
}
};
/**
* @ingroup ie_transformation_common_api
* @brief NormalizeL2FusionWithMax transformation replaces a sub-graph
@ -55,3 +42,16 @@ public:
NGRAPH_RTTI_DECLARATION;
NormalizeL2FusionWithAdd();
};
/**
* @ingroup ie_transformation_common_api
* @brief NormalizeL2Fusion transformation replaces various sub-graphs with a NormalizeL2 op.
*/
class ngraph::pass::NormalizeL2Fusion: public ngraph::pass::GraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;
NormalizeL2Fusion() {
add_matcher<ngraph::pass::NormalizeL2FusionWithMax>();
add_matcher<ngraph::pass::NormalizeL2FusionWithAdd>();
}
};

View File

@ -22,21 +22,6 @@ class TRANSFORMATIONS_API SwishFusionWithoutBeta;
} // namespace pass
} // namespace ngraph
/**
* @ingroup ie_transformation_common_api
* @brief SwishFusion transformation replaces various sub-graphs with a Swish op.
*/
class ngraph::pass::SwishFusion: public ngraph::pass::GraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;
SwishFusion() {
add_matcher<ngraph::pass::SwishFusionWithSigmoid>();
add_matcher<ngraph::pass::SwishFusionWithSigmoidWithBeta>();
add_matcher<ngraph::pass::SwishFusionWithBeta>();
add_matcher<ngraph::pass::SwishFusionWithoutBeta>();
}
};
/**
* @ingroup ie_transformation_common_api
* @brief SwishFusionWithSigmoid replaces a sub-graphs x * Sigmoid(x) with a Swish op.
@ -76,3 +61,18 @@ public:
NGRAPH_RTTI_DECLARATION;
SwishFusionWithoutBeta();
};
/**
* @ingroup ie_transformation_common_api
* @brief SwishFusion transformation replaces various sub-graphs with a Swish op.
*/
class ngraph::pass::SwishFusion: public ngraph::pass::GraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;
SwishFusion() {
add_matcher<ngraph::pass::SwishFusionWithSigmoid>();
add_matcher<ngraph::pass::SwishFusionWithSigmoidWithBeta>();
add_matcher<ngraph::pass::SwishFusionWithBeta>();
add_matcher<ngraph::pass::SwishFusionWithoutBeta>();
}
};

View File

@ -15,6 +15,8 @@
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API BidirectionalSequenceDecomposition;
class TRANSFORMATIONS_API BidirectionalLSTMSequenceDecomposition;
class TRANSFORMATIONS_API BidirectionalGRUSequenceDecomposition;
class TRANSFORMATIONS_API BidirectionalRNNSequenceDecomposition;
@ -57,3 +59,19 @@ public:
NGRAPH_RTTI_DECLARATION;
BidirectionalRNNSequenceDecomposition();
};
/**
* @ingroup ie_transformation_common_api
* @brief Container for all types of sequences decomposition.
*
*/
class ngraph::pass::BidirectionalSequenceDecomposition : public ngraph::pass::GraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;
BidirectionalSequenceDecomposition() {
add_matcher<ngraph::pass::BidirectionalLSTMSequenceDecomposition>();
add_matcher<ngraph::pass::BidirectionalGRUSequenceDecomposition>();
add_matcher<ngraph::pass::BidirectionalRNNSequenceDecomposition>();
}
};

View File

@ -24,17 +24,6 @@ class TRANSFORMATIONS_API ConvertGroupDeconvolution;
} // namespace pass
} // namespace ngraph
class ngraph::pass::ConvertConvolutions: public ngraph::pass::GraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;
ConvertConvolutions() {
add_matcher<ngraph::pass::ConvertConvolution>();
add_matcher<ngraph::pass::ConvertGroupConvolution>();
add_matcher<ngraph::pass::ConvertDeconvolution>();
add_matcher<ngraph::pass::ConvertGroupDeconvolution>();
}
};
class ngraph::pass::ConvertConvolution: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
@ -57,4 +46,15 @@ class ngraph::pass::ConvertGroupDeconvolution: public ngraph::pass::MatcherPass
public:
NGRAPH_RTTI_DECLARATION;
ConvertGroupDeconvolution();
};
};
class ngraph::pass::ConvertConvolutions: public ngraph::pass::GraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;
ConvertConvolutions() {
add_matcher<ngraph::pass::ConvertConvolution>();
add_matcher<ngraph::pass::ConvertGroupConvolution>();
add_matcher<ngraph::pass::ConvertDeconvolution>();
add_matcher<ngraph::pass::ConvertGroupDeconvolution>();
}
};

View File

@ -35,16 +35,6 @@ public:
ngraph::matcher_pass_callback convert_reduce_to_pooling();
};
class ngraph::pass::ConvertReduceToPooling: public ngraph::pass::GraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;
ConvertReduceToPooling() {
add_matcher<ConvertReduceMeanToPooling>();
add_matcher<ConvertReduceMaxToPooling>();
add_matcher<ConvertReduceSumToPooling>();
}
};
class ngraph::pass::ConvertReduceMeanToPooling: public ConvertReduceBase {
public:
NGRAPH_RTTI_DECLARATION;
@ -63,6 +53,16 @@ public:
ConvertReduceSumToPooling();
};
class ngraph::pass::ConvertReduceToPooling: public ngraph::pass::GraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;
ConvertReduceToPooling() {
add_matcher<ConvertReduceMeanToPooling>();
add_matcher<ConvertReduceMaxToPooling>();
add_matcher<ConvertReduceSumToPooling>();
}
};
template <class T>
ngraph::matcher_pass_callback ConvertReduceBase::convert_reduce_to_pooling() {
return [&](ngraph::pattern::Matcher& m) {

View File

@ -25,6 +25,7 @@
#include <ngraph/opsets/opset2.hpp>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
using namespace std;
using namespace ngraph;
@ -234,31 +235,32 @@ static bool replace_transpose_with_reshape(shared_ptr<Node> transpose) {
return true;
}
bool pass::AlgebraicSimplification::run_on_function(shared_ptr<Function> f) {
RUN_ON_FUNCTION_SCOPE(AlgebraicSimplification);
static const unordered_map<NodeTypeInfo, function<bool(shared_ptr<Node>)>> ops_to_simplifiers =
{{opset3::Gather::type_info, simplify_gather},
{opset2::ShapeOf::type_info, simplify_gather_shapeof},
{opset3::ShapeOf::type_info, simplify_gather_shapeof},
{opset3::Transpose::type_info, replace_transpose_with_reshape}};
#define ECHO(NAME) #NAME
#define STR(NAME) ECHO(NAME)
#define SIMPLE_MATCHER_PASS_DEFINITION(NAME, OP, FUNC) \
class NAME : public ngraph::pass::MatcherPass { \
public: \
NGRAPH_RTTI_DECLARATION; \
NAME() { \
MATCHER_SCOPE(NAME); \
auto match_node = ngraph::pattern::wrap_type<OP>(); \
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) { \
return FUNC(m.get_match_root()); \
}; \
auto m = std::make_shared<ngraph::pattern::Matcher>(match_node, matcher_name); \
register_matcher(m, callback); \
} \
}; \
NGRAPH_RTTI_DEFINITION(NAME, STR(NAME), 0);
bool replaced = false;
for (auto n : f->get_ordered_ops()) {
if (op::is_output(n) || op::is_parameter(n)) {
continue;
}
SIMPLE_MATCHER_PASS_DEFINITION(EliminateGather, opset3::Gather, simplify_gather);
SIMPLE_MATCHER_PASS_DEFINITION(SimplifyShapeOf2Gather, opset2::ShapeOf, simplify_gather_shapeof);
SIMPLE_MATCHER_PASS_DEFINITION(SimplifyShapeOf3Gather, opset3::ShapeOf, simplify_gather_shapeof);
SIMPLE_MATCHER_PASS_DEFINITION(ConvertTransposeToReshape, opset3::Transpose, replace_transpose_with_reshape);
// Recursively apply transformation for sub-graph based operations
if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(n)) {
if (auto sub_graph = sub_graph_node->get_function()) {
replaced |= run_on_function(sub_graph);
}
}
auto eh = ops_to_simplifiers.find(n->get_type_info());
if (eh != ops_to_simplifiers.end()) {
replaced |= eh->second(n);
}
}
return replaced;
ngraph::pass::AlgebraicSimplification::AlgebraicSimplification() {
add_matcher<EliminateGather>();
add_matcher<SimplifyShapeOf2Gather>();
add_matcher<SimplifyShapeOf3Gather>();
add_matcher<ConvertTransposeToReshape>();
}

View File

@ -41,7 +41,7 @@ ngraph::pass::ClampFusion::ClampFusion() {
double min_value = min_const->cast_vector<double>()[0];
double max_value = max_const->cast_vector<double>()[0];
auto clamp = std::make_shared<ngraph::opset5::Clamp>(data, min_value, max_value);
auto clamp = register_new_node<ngraph::opset5::Clamp>(data, min_value, max_value);
auto minimum = pattern_map.at(min_pattern);
clamp->set_friendly_name(minimum.get_node()->get_friendly_name());

View File

@ -64,39 +64,47 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
// This pass must be called first in pipeline
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::RemoveFilteringBoxesBySize>(); // Resolves dynamism (replaces NonZero), CF needed
// TODO: move to KMB
manager.register_pass<ngraph::pass::ConvertQuantizeDequantize>();
manager.register_pass<ngraph::pass::ConstantFolding>();
manager.register_pass<ngraph::pass::StridedSliceOptimization>(); // depends on CF
manager.register_pass<ngraph::pass::AlgebraicSimplification>(); // may introduce fake dynamism
manager.register_pass<ngraph::pass::BroadcastElementwiseFusion>();
manager.register_pass<ngraph::pass::EliminateUnsqueezeGather>();
manager.register_pass<ngraph::pass::NopElimination>(); // may introduce fake dynamism
auto eliminations = manager.register_pass<ngraph::pass::GraphRewrite>();
eliminations->add_matcher<ngraph::pass::EliminateUnsqueezeGather>();
eliminations->add_matcher<ngraph::pass::AlgebraicSimplification>(); // may introduce fake dynamism
eliminations->add_matcher<ngraph::pass::NopElimination>(); // may introduce fake dynamism
eliminations->set_name("ngraph::pass::CommonEliminations");
manager.register_pass<ngraph::pass::ConstantFolding>();
manager.register_pass<ngraph::pass::ConvertScatterElementsToScatter>(); // partially depends on CF
manager.register_pass<ngraph::pass::DepthToSpaceFusion>();
manager.register_pass<ngraph::pass::MishFusion>();
manager.register_pass<ngraph::pass::SoftPlusFusion>();
manager.register_pass<ngraph::pass::SoftPlusToMishFusion>();
manager.register_pass<ngraph::pass::SwishFusion>();
manager.register_pass<ngraph::pass::ClampFusion>();
manager.register_pass<ngraph::pass::HSwishFusion>();
manager.register_pass<ngraph::pass::HSigmoidFusion>();
auto common_fusions = manager.register_pass<ngraph::pass::GraphRewrite>();
common_fusions->add_matcher<ngraph::pass::ConvertScatterElementsToScatter>();
common_fusions->add_matcher<ngraph::pass::DepthToSpaceFusion>();
//common_fusions->add_matcher<ngraph::pass::MishFusion>();
common_fusions->add_matcher<ngraph::pass::SoftPlusFusion>();
common_fusions->add_matcher<ngraph::pass::SoftPlusToMishFusion>();
common_fusions->add_matcher<ngraph::pass::SwishFusion>();
common_fusions->add_matcher<ngraph::pass::HSwishFusion>();
common_fusions->add_matcher<ngraph::pass::HSigmoidFusion>();
common_fusions->add_matcher<ngraph::pass::NormalizeL2Fusion>();
common_fusions->add_matcher<ngraph::pass::ClampFusion>();
common_fusions->add_matcher<ngraph::pass::PadFusion>();
common_fusions->set_name("ngraph::pass::CommonFusions");
manager.register_pass<ngraph::pass::ConvertPadToGroupConvolution, false>();
manager.register_pass<ngraph::pass::NormalizeL2Fusion>();
manager.register_pass<ngraph::pass::PadFusion>();
manager.register_pass<ngraph::pass::ConvertInterpolate1ToInterpolate4, false>();
auto decomp = manager.register_pass<ngraph::pass::GraphRewrite>();
decomp->add_matcher<ngraph::pass::BidirectionalLSTMSequenceDecomposition>();
decomp->add_matcher<ngraph::pass::BidirectionalRNNSequenceDecomposition>();
decomp->add_matcher<ngraph::pass::BidirectionalGRUSequenceDecomposition>();
decomp->add_matcher<ngraph::pass::BidirectionalSequenceDecomposition>();
decomp->add_matcher<ngraph::pass::ReduceL1Decomposition>();
decomp->add_matcher<ngraph::pass::ReduceL2Decomposition>();
decomp->add_matcher<ngraph::pass::HSwishDecomposition>();
decomp->add_matcher<ngraph::pass::HSigmoidDecomposition>();
decomp->add_matcher<ngraph::pass::LogSoftmaxDecomposition>();
decomp->add_matcher<ngraph::pass::ConvertReduceMeanToPooling>();
decomp->add_matcher<ngraph::pass::ConvertReduceMaxToPooling>();
decomp->add_matcher<ngraph::pass::ConvertReduceSumToPooling>();
decomp->add_matcher<ngraph::pass::ConvertReduceToPooling>();
decomp->add_matcher<ngraph::pass::ConvertBroadcastToTiles>();
decomp->add_matcher<ngraph::pass::ConvertMod>();
decomp->add_matcher<ngraph::pass::ConvertGELU>();
@ -116,12 +124,14 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
// LinOpSequenceFusion must be executed after all decompositions
manager.register_pass<ngraph::pass::LinOpSequenceFusion>();
manager.register_pass<ngraph::pass::ConvolutionMultiplyFusion>();
manager.register_pass<ngraph::pass::GroupConvolutionMultiplyFusion>();
manager.register_pass<ngraph::pass::ConvolutionBackpropDataMultiplyFusion>();
manager.register_pass<ngraph::pass::GroupConvolutionBackpropDataMultiplyFusion>();
auto conv_fusions = manager.register_pass<ngraph::pass::GraphRewrite>();
conv_fusions->add_matcher<ngraph::pass::ConvolutionMultiplyFusion>();
conv_fusions->add_matcher<ngraph::pass::GroupConvolutionMultiplyFusion>();
conv_fusions->add_matcher<ngraph::pass::ConvolutionBackpropDataMultiplyFusion>();
conv_fusions->add_matcher<ngraph::pass::GroupConvolutionBackpropDataMultiplyFusion>();
conv_fusions->set_name("ngraph::pass::ConvFusions");
manager.register_pass<ngraph::pass::ConstantFolding>();
manager.register_pass<ngraph::pass::ConvertInterpolate1ToInterpolate4, false>();
auto fq_fusions = manager.register_pass<ngraph::pass::GraphRewrite>();
fq_fusions->add_matcher<ngraph::pass::FakeQuantizeMulFusion>();
@ -131,5 +141,10 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
fq_fusions->set_name("ngraph::pass::FakeQuantizeFusions");
manager.run_passes(f);
return true;
// Returning value is false because pass::Manager always apply Validation pass
// if function was changed. This helps to avoid excess Validations after applying
// this pass. In future when we will return more meaningful status code it will be
// replaced with real status reported by manager.run_passes() method call.
return false;
}

View File

@ -25,12 +25,15 @@
#include <ngraph/util.hpp>
#include <ngraph/log.hpp>
#include <transformations/common_optimizations/nop_elimination.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
using namespace std;
using namespace ngraph;
#define TI(x) x::type_info
NGRAPH_RTTI_DEFINITION(ngraph::pass::NopElimination, "NopElimination", 0);
static bool eliminate_nop(const std::shared_ptr<Node>& node) {
// skip if shapes are dynamic
if (node->get_input_partial_shape(0).is_dynamic() ||
@ -327,33 +330,38 @@ static bool eliminate_squeeze(const std::shared_ptr<Node>& node) {
return false;
}
NGRAPH_RTTI_DEFINITION(ngraph::pass::NopElimination, "NopElimination", 0);
#define ECHO(NAME) #NAME
#define STR(NAME) ECHO(NAME)
#define SIMPLE_MATCHER_PASS_DEFINITION(NAME, OP, FUNC) \
class NAME : public ngraph::pass::MatcherPass { \
public: \
NGRAPH_RTTI_DECLARATION; \
NAME() { \
MATCHER_SCOPE(NAME); \
auto match_node = ngraph::pattern::wrap_type<OP>(); \
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) { \
return FUNC(m.get_match_root()); \
}; \
auto m = std::make_shared<ngraph::pattern::Matcher>(match_node, matcher_name); \
register_matcher(m, callback); \
} \
}; \
NGRAPH_RTTI_DEFINITION(NAME, STR(NAME), 0);
bool pass::NopElimination::run_on_function(std::shared_ptr<Function> function) {
RUN_ON_FUNCTION_SCOPE(NopElimination);
static const std::unordered_map<NodeTypeInfo, std::function<bool(const std::shared_ptr<Node>&)>>
dispatcher{{TI(opset3::Pad), &eliminate_nop},
{TI(opset3::Convert), &eliminate_convert},
{TI(opset3::Reshape), &eliminate_reshape_v1},
{TI(opset3::Concat), &eliminate_concat},
{TI(opset3::Squeeze), &eliminate_squeeze},
{TI(op::v1::Broadcast), &eliminate_nop},
{TI(opset3::Unsqueeze), &eliminate_unsqueeze}};
SIMPLE_MATCHER_PASS_DEFINITION(EliminatePad, opset3::Pad, eliminate_nop);
SIMPLE_MATCHER_PASS_DEFINITION(EliminateConvert, opset3::Convert, eliminate_convert);
SIMPLE_MATCHER_PASS_DEFINITION(EliminateReshape, opset3::Reshape, eliminate_reshape_v1);
SIMPLE_MATCHER_PASS_DEFINITION(EliminateConcat, opset3::Concat, eliminate_concat);
SIMPLE_MATCHER_PASS_DEFINITION(EliminateSqueeze, opset3::Squeeze, eliminate_squeeze);
SIMPLE_MATCHER_PASS_DEFINITION(EliminateUnsqueeze, opset3::Unsqueeze, eliminate_unsqueeze);
SIMPLE_MATCHER_PASS_DEFINITION(EliminateBroadcast, op::v1::Broadcast, eliminate_nop);
bool clobbered = false;
for (const auto& node : function->get_ops()) {
// Recursively apply transformation for sub-graph based operations
if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node)) {
if (auto sub_graph = sub_graph_node->get_function()) {
clobbered |= run_on_function(sub_graph);
}
}
auto handler = dispatcher.find(node->get_type_info());
if (handler != dispatcher.end()) {
clobbered |= handler->second(node);
}
}
return clobbered;
}
ngraph::pass::NopElimination::NopElimination() {
add_matcher<EliminatePad>();
add_matcher<EliminateConvert>();
add_matcher<EliminateReshape>();
add_matcher<EliminateConcat>();
add_matcher<EliminateSqueeze>();
add_matcher<EliminateUnsqueeze>();
add_matcher<EliminateBroadcast>();
}

View File

@ -12,6 +12,7 @@
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
NGRAPH_RTTI_DEFINITION(ngraph::pass::BidirectionalSequenceDecomposition, "BidirectionalSequenceDecomposition", 0);
NGRAPH_RTTI_DEFINITION(ngraph::pass::BidirectionalLSTMSequenceDecomposition, "BidirectionalLSTMSequenceDecomposition", 0);
NGRAPH_RTTI_DEFINITION(ngraph::pass::BidirectionalGRUSequenceDecomposition, "BidirectionalGRUSequenceDecomposition", 0);
NGRAPH_RTTI_DEFINITION(ngraph::pass::BidirectionalRNNSequenceDecomposition, "BidirectionalRNNSequenceDecomposition", 0);

View File

@ -23,5 +23,10 @@ bool ngraph::pass::ConvertOpSet2ToOpSet1::run_on_function(std::shared_ptr<ngraph
manager.register_pass<ngraph::pass::ConvertBatchToSpace>();
manager.run_passes(f);
return true;
// Returning value is false because pass::Manager always apply Validation pass
// if function was changed. This helps to avoid excess Validations after applying
// this pass. In future when we will return more meaningful status code it will be
// replaced with real status reported by manager.run_passes() method call.
return false;
}

View File

@ -29,5 +29,10 @@ bool ngraph::pass::ConvertOpSet3ToOpSet2::run_on_function(std::shared_ptr<ngraph
manager.register_pass<ngraph::pass::SoftPlusDecomposition>();
manager.run_passes(f);
return true;
// Returning value is false because pass::Manager always apply Validation pass
// if function was changed. This helps to avoid excess Validations after applying
// this pass. In future when we will return more meaningful status code it will be
// replaced with real status reported by manager.run_passes() method call.
return false;
}

View File

@ -148,7 +148,7 @@ namespace ngraph
/// auto anchor = manager.register_pass<GraphRewrite>();
/// anchor->add_matcher<MatcherPassA>();
/// anchor->add_matcher<MatcherPassB>();
/// anchor->set_name("CommonMathcers");
/// anchor->set_name("CommonMatchers");
/// manager.run_passes(f);
///
/// For some purposes transformation can be registered and disabled by default.
@ -156,7 +156,11 @@ namespace ngraph
/// anchor->add_matcher<MatcherPassB, false>();
///
/// \return shared_ptr to the transformation instance
template <typename T, bool Enabled = true, class... Args>
template <
typename T,
bool Enabled = true,
class... Args,
typename std::enable_if<std::is_base_of<pass::MatcherPass, T>{}, bool>::type = true>
std::shared_ptr<T> add_matcher(Args&&... args)
{
static_assert(std::is_base_of<pass::MatcherPass, T>::value,
@ -171,6 +175,47 @@ namespace ngraph
m_matchers.push_back(pass);
return pass;
}
/// \brief Register passes from GraphRewrite class that contains sequence of matcher
/// passes registered in its ctor.
/// For example:
///
/// class ngraph::pass::LinFusions: public ngraph::pass::GraphRewrite {
/// public:
/// NGRAPH_RTTI_DECLARATION;
/// Fusions() {
/// add_matcher<ngraph::pass::AddFusion>();
/// add_matcher<ngraph::pass::MulFusion>();
/// }
/// };
///
/// pass::Manager manager;
/// auto anchor = manager.register_pass<GraphRewrite>();
/// anchor->add_matcher<LinFusions>();
/// anchor->add_matcher<OtherFusions>();
/// anchor->set_name("CommonFusions");
/// manager.run_passes(f);
///
/// In this case all matcher passes from LinFusions pass will be united with other
/// registered matchers.
template <typename T,
class... Args,
typename std::enable_if<std::is_base_of<pass::GraphRewrite, T>{},
bool>::type = true>
void add_matcher(Args&&... args)
{
static_assert(std::is_base_of<pass::GraphRewrite, T>::value,
"pass not derived from GraphRewrite");
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
auto pass_config = get_pass_config();
for (auto& matcher : pass->m_matchers)
{
pass->set_pass_config(pass_config);
m_matchers.push_back(matcher);
}
}
NGRAPH_DEPRECATED("Use MatcherPass instead")
void add_matcher(const std::shared_ptr<pattern::Matcher>& m,
const ngraph::graph_rewrite_callback& callback,

View File

@ -29,7 +29,10 @@ bool ngraph::pass::ConstantFolding::run_on_function(std::shared_ptr<ngraph::Func
for (const auto& node : f->get_ordered_ops())
{
node->revalidate_and_infer_types();
if (rewritten)
{
node->revalidate_and_infer_types();
}
OutputVector replacements(node->get_output_size());
if (node->constant_fold(replacements, node->input_values()))