Generic LoadTime Optimizations (#4011)
* Updated container passes to return false to avoid excess function validation * Added support for nested GraphRewrite registration * Updated passes to use MatcherPass; Reorganized CommonOptimizations pipeline * Disable node validation when graph is not modified
This commit is contained in:
parent
68f2507811
commit
b6d9dba444
@ -25,15 +25,6 @@ class INFERENCE_ENGINE_API_CLASS(ConvertMatMulToGemm);
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
class ngraph::pass::ConvertMatMulToFCorGemm: public ngraph::pass::GraphRewrite {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
ConvertMatMulToFCorGemm() {
|
||||
add_matcher<ngraph::pass::ConvertMatMulToFC>();
|
||||
add_matcher<ngraph::pass::ConvertMatMulToGemm>();
|
||||
}
|
||||
};
|
||||
|
||||
class ngraph::pass::ConvertMatMulToFC: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
@ -45,3 +36,12 @@ public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
ConvertMatMulToGemm();
|
||||
};
|
||||
|
||||
class ngraph::pass::ConvertMatMulToFCorGemm: public ngraph::pass::GraphRewrite {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
ConvertMatMulToFCorGemm() {
|
||||
add_matcher<ngraph::pass::ConvertMatMulToFC>();
|
||||
add_matcher<ngraph::pass::ConvertMatMulToGemm>();
|
||||
}
|
||||
};
|
@ -22,16 +22,6 @@ class INFERENCE_ENGINE_API_CLASS(Reshape1DMaxPool);
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
class ngraph::pass::Reshape1DOps: public ngraph::pass::GraphRewrite {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
Reshape1DOps() {
|
||||
add_matcher<ngraph::pass::Reshape1DConvolution>();
|
||||
add_matcher<ngraph::pass::Reshape1DAvgPool>();
|
||||
add_matcher<ngraph::pass::Reshape1DMaxPool>();
|
||||
}
|
||||
};
|
||||
|
||||
class ngraph::pass::Reshape1DConvolution: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
@ -48,4 +38,14 @@ class ngraph::pass::Reshape1DMaxPool: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
Reshape1DMaxPool();
|
||||
};
|
||||
|
||||
class ngraph::pass::Reshape1DOps: public ngraph::pass::GraphRewrite {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
Reshape1DOps() {
|
||||
add_matcher<ngraph::pass::Reshape1DConvolution>();
|
||||
add_matcher<ngraph::pass::Reshape1DAvgPool>();
|
||||
add_matcher<ngraph::pass::Reshape1DMaxPool>();
|
||||
}
|
||||
};
|
@ -145,5 +145,10 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph
|
||||
manager.register_pass<ngraph::pass::ConstantFolding>();
|
||||
|
||||
manager.run_passes(f);
|
||||
return true;
|
||||
|
||||
// Returning value is false because pass::Manager always apply Validation pass
|
||||
// if function was changed. This helps to avoid excess Validations after applying
|
||||
// this pass. In future when we will return more meaningful status code it will be
|
||||
// replaced with real status reported by manager.run_passes() method call.
|
||||
return false;
|
||||
}
|
||||
|
@ -29,8 +29,8 @@ class TRANSFORMATIONS_API AlgebraicSimplification;
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
class ngraph::pass::AlgebraicSimplification : public FunctionPass {
|
||||
class ngraph::pass::AlgebraicSimplification : public GraphRewrite {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
|
||||
AlgebraicSimplification();
|
||||
};
|
||||
|
@ -19,25 +19,9 @@ class TRANSFORMATIONS_API HSigmoidFusionWithReluMul;
|
||||
class TRANSFORMATIONS_API HSigmoidFusionWithoutRelu;
|
||||
class TRANSFORMATIONS_API HSigmoidFusionWithClamp;
|
||||
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief HSigmoidFusion transformation replaces various sub-graphs with a HSigmoid op.
|
||||
*/
|
||||
class ngraph::pass::HSigmoidFusion: public ngraph::pass::GraphRewrite {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
HSigmoidFusion() {
|
||||
add_matcher<ngraph::pass::HSigmoidFusionWithReluDiv>();
|
||||
add_matcher<ngraph::pass::HSigmoidFusionWithReluMul>();
|
||||
add_matcher<ngraph::pass::HSigmoidFusionWithoutRelu>();
|
||||
add_matcher<ngraph::pass::HSigmoidFusionWithClamp>();
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief HSigmoidFusion transformation replaces a sub-graph (x * (min(Relu(x + 3), 6))) / 6 with a HSigmoid op.
|
||||
@ -77,3 +61,18 @@ public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
HSigmoidFusionWithClamp();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief HSigmoidFusion transformation replaces various sub-graphs with a HSigmoid op.
|
||||
*/
|
||||
class ngraph::pass::HSigmoidFusion: public ngraph::pass::GraphRewrite {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
HSigmoidFusion() {
|
||||
add_matcher<ngraph::pass::HSigmoidFusionWithReluDiv>();
|
||||
add_matcher<ngraph::pass::HSigmoidFusionWithReluMul>();
|
||||
add_matcher<ngraph::pass::HSigmoidFusionWithoutRelu>();
|
||||
add_matcher<ngraph::pass::HSigmoidFusionWithClamp>();
|
||||
}
|
||||
};
|
@ -23,21 +23,6 @@ class TRANSFORMATIONS_API HSwishFusionWithClamp;
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief HSwishFusion transformation replaces various sub-graphs with a HSwish op.
|
||||
*/
|
||||
class ngraph::pass::HSwishFusion: public ngraph::pass::GraphRewrite {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
HSwishFusion() {
|
||||
add_matcher<ngraph::pass::HSwishFusionWithReluDiv>();
|
||||
add_matcher<ngraph::pass::HSwishFusionWithReluMul>();
|
||||
add_matcher<ngraph::pass::HSwishFusionWithoutRelu>();
|
||||
add_matcher<ngraph::pass::HSwishFusionWithClamp>();
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief HSwishFusion transformation replaces a sub-graph (x * (min(Relu(x + 3), 6))) / 6 with a HSwish op.
|
||||
@ -77,3 +62,18 @@ public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
HSwishFusionWithClamp();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief HSwishFusion transformation replaces various sub-graphs with a HSwish op.
|
||||
*/
|
||||
class ngraph::pass::HSwishFusion: public ngraph::pass::GraphRewrite {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
HSwishFusion() {
|
||||
add_matcher<ngraph::pass::HSwishFusionWithReluDiv>();
|
||||
add_matcher<ngraph::pass::HSwishFusionWithReluMul>();
|
||||
add_matcher<ngraph::pass::HSwishFusionWithoutRelu>();
|
||||
add_matcher<ngraph::pass::HSwishFusionWithClamp>();
|
||||
}
|
||||
};
|
@ -21,16 +21,6 @@ class TRANSFORMATIONS_API MultiplyMultiplyFusion;
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
class ngraph::pass::LinOpSequenceFusion: public ngraph::pass::GraphRewrite {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
LinOpSequenceFusion() {
|
||||
add_matcher<ngraph::pass::AddMultiplyFusion>();
|
||||
add_matcher<ngraph::pass::AddAddFusion>();
|
||||
add_matcher<ngraph::pass::MultiplyMultiplyFusion>();
|
||||
}
|
||||
};
|
||||
|
||||
class ngraph::pass::AddMultiplyFusion: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
@ -48,3 +38,13 @@ public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
MultiplyMultiplyFusion();
|
||||
};
|
||||
|
||||
class ngraph::pass::LinOpSequenceFusion: public ngraph::pass::GraphRewrite {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
LinOpSequenceFusion() {
|
||||
add_matcher<ngraph::pass::AddMultiplyFusion>();
|
||||
add_matcher<ngraph::pass::AddAddFusion>();
|
||||
add_matcher<ngraph::pass::MultiplyMultiplyFusion>();
|
||||
}
|
||||
};
|
@ -20,8 +20,8 @@ class TRANSFORMATIONS_API NopElimination;
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
class ngraph::pass::NopElimination: public ngraph::pass::FunctionPass {
|
||||
class ngraph::pass::NopElimination: public GraphRewrite {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
bool run_on_function(std::shared_ptr<Function>) override;
|
||||
NopElimination();
|
||||
};
|
||||
|
@ -21,19 +21,6 @@ class TRANSFORMATIONS_API NormalizeL2FusionWithAdd;
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief NormalizeL2Fusion transformation replaces various sub-graphs with a NormalizeL2 op.
|
||||
*/
|
||||
class ngraph::pass::NormalizeL2Fusion: public ngraph::pass::GraphRewrite {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
NormalizeL2Fusion() {
|
||||
add_matcher<ngraph::pass::NormalizeL2FusionWithMax>();
|
||||
add_matcher<ngraph::pass::NormalizeL2FusionWithAdd>();
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief NormalizeL2FusionWithMax transformation replaces a sub-graph
|
||||
@ -55,3 +42,16 @@ public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
NormalizeL2FusionWithAdd();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief NormalizeL2Fusion transformation replaces various sub-graphs with a NormalizeL2 op.
|
||||
*/
|
||||
class ngraph::pass::NormalizeL2Fusion: public ngraph::pass::GraphRewrite {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
NormalizeL2Fusion() {
|
||||
add_matcher<ngraph::pass::NormalizeL2FusionWithMax>();
|
||||
add_matcher<ngraph::pass::NormalizeL2FusionWithAdd>();
|
||||
}
|
||||
};
|
@ -22,21 +22,6 @@ class TRANSFORMATIONS_API SwishFusionWithoutBeta;
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief SwishFusion transformation replaces various sub-graphs with a Swish op.
|
||||
*/
|
||||
class ngraph::pass::SwishFusion: public ngraph::pass::GraphRewrite {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
SwishFusion() {
|
||||
add_matcher<ngraph::pass::SwishFusionWithSigmoid>();
|
||||
add_matcher<ngraph::pass::SwishFusionWithSigmoidWithBeta>();
|
||||
add_matcher<ngraph::pass::SwishFusionWithBeta>();
|
||||
add_matcher<ngraph::pass::SwishFusionWithoutBeta>();
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief SwishFusionWithSigmoid replaces a sub-graphs x * Sigmoid(x) with a Swish op.
|
||||
@ -76,3 +61,18 @@ public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
SwishFusionWithoutBeta();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief SwishFusion transformation replaces various sub-graphs with a Swish op.
|
||||
*/
|
||||
class ngraph::pass::SwishFusion: public ngraph::pass::GraphRewrite {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
SwishFusion() {
|
||||
add_matcher<ngraph::pass::SwishFusionWithSigmoid>();
|
||||
add_matcher<ngraph::pass::SwishFusionWithSigmoidWithBeta>();
|
||||
add_matcher<ngraph::pass::SwishFusionWithBeta>();
|
||||
add_matcher<ngraph::pass::SwishFusionWithoutBeta>();
|
||||
}
|
||||
};
|
@ -15,6 +15,8 @@
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API BidirectionalSequenceDecomposition;
|
||||
|
||||
class TRANSFORMATIONS_API BidirectionalLSTMSequenceDecomposition;
|
||||
class TRANSFORMATIONS_API BidirectionalGRUSequenceDecomposition;
|
||||
class TRANSFORMATIONS_API BidirectionalRNNSequenceDecomposition;
|
||||
@ -57,3 +59,19 @@ public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
BidirectionalRNNSequenceDecomposition();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief Container for all types of sequences decomposition.
|
||||
*
|
||||
*/
|
||||
|
||||
class ngraph::pass::BidirectionalSequenceDecomposition : public ngraph::pass::GraphRewrite {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
BidirectionalSequenceDecomposition() {
|
||||
add_matcher<ngraph::pass::BidirectionalLSTMSequenceDecomposition>();
|
||||
add_matcher<ngraph::pass::BidirectionalGRUSequenceDecomposition>();
|
||||
add_matcher<ngraph::pass::BidirectionalRNNSequenceDecomposition>();
|
||||
}
|
||||
};
|
||||
|
@ -24,17 +24,6 @@ class TRANSFORMATIONS_API ConvertGroupDeconvolution;
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
class ngraph::pass::ConvertConvolutions: public ngraph::pass::GraphRewrite {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
ConvertConvolutions() {
|
||||
add_matcher<ngraph::pass::ConvertConvolution>();
|
||||
add_matcher<ngraph::pass::ConvertGroupConvolution>();
|
||||
add_matcher<ngraph::pass::ConvertDeconvolution>();
|
||||
add_matcher<ngraph::pass::ConvertGroupDeconvolution>();
|
||||
}
|
||||
};
|
||||
|
||||
class ngraph::pass::ConvertConvolution: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
@ -57,4 +46,15 @@ class ngraph::pass::ConvertGroupDeconvolution: public ngraph::pass::MatcherPass
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
ConvertGroupDeconvolution();
|
||||
};
|
||||
};
|
||||
|
||||
class ngraph::pass::ConvertConvolutions: public ngraph::pass::GraphRewrite {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
ConvertConvolutions() {
|
||||
add_matcher<ngraph::pass::ConvertConvolution>();
|
||||
add_matcher<ngraph::pass::ConvertGroupConvolution>();
|
||||
add_matcher<ngraph::pass::ConvertDeconvolution>();
|
||||
add_matcher<ngraph::pass::ConvertGroupDeconvolution>();
|
||||
}
|
||||
};
|
||||
|
@ -35,16 +35,6 @@ public:
|
||||
ngraph::matcher_pass_callback convert_reduce_to_pooling();
|
||||
};
|
||||
|
||||
class ngraph::pass::ConvertReduceToPooling: public ngraph::pass::GraphRewrite {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
ConvertReduceToPooling() {
|
||||
add_matcher<ConvertReduceMeanToPooling>();
|
||||
add_matcher<ConvertReduceMaxToPooling>();
|
||||
add_matcher<ConvertReduceSumToPooling>();
|
||||
}
|
||||
};
|
||||
|
||||
class ngraph::pass::ConvertReduceMeanToPooling: public ConvertReduceBase {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
@ -63,6 +53,16 @@ public:
|
||||
ConvertReduceSumToPooling();
|
||||
};
|
||||
|
||||
class ngraph::pass::ConvertReduceToPooling: public ngraph::pass::GraphRewrite {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
ConvertReduceToPooling() {
|
||||
add_matcher<ConvertReduceMeanToPooling>();
|
||||
add_matcher<ConvertReduceMaxToPooling>();
|
||||
add_matcher<ConvertReduceSumToPooling>();
|
||||
}
|
||||
};
|
||||
|
||||
template <class T>
|
||||
ngraph::matcher_pass_callback ConvertReduceBase::convert_reduce_to_pooling() {
|
||||
return [&](ngraph::pattern::Matcher& m) {
|
||||
|
@ -25,6 +25,7 @@
|
||||
#include <ngraph/opsets/opset2.hpp>
|
||||
#include <ngraph/opsets/opset3.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
@ -234,31 +235,32 @@ static bool replace_transpose_with_reshape(shared_ptr<Node> transpose) {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool pass::AlgebraicSimplification::run_on_function(shared_ptr<Function> f) {
|
||||
RUN_ON_FUNCTION_SCOPE(AlgebraicSimplification);
|
||||
static const unordered_map<NodeTypeInfo, function<bool(shared_ptr<Node>)>> ops_to_simplifiers =
|
||||
{{opset3::Gather::type_info, simplify_gather},
|
||||
{opset2::ShapeOf::type_info, simplify_gather_shapeof},
|
||||
{opset3::ShapeOf::type_info, simplify_gather_shapeof},
|
||||
{opset3::Transpose::type_info, replace_transpose_with_reshape}};
|
||||
#define ECHO(NAME) #NAME
|
||||
#define STR(NAME) ECHO(NAME)
|
||||
#define SIMPLE_MATCHER_PASS_DEFINITION(NAME, OP, FUNC) \
|
||||
class NAME : public ngraph::pass::MatcherPass { \
|
||||
public: \
|
||||
NGRAPH_RTTI_DECLARATION; \
|
||||
NAME() { \
|
||||
MATCHER_SCOPE(NAME); \
|
||||
auto match_node = ngraph::pattern::wrap_type<OP>(); \
|
||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) { \
|
||||
return FUNC(m.get_match_root()); \
|
||||
}; \
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(match_node, matcher_name); \
|
||||
register_matcher(m, callback); \
|
||||
} \
|
||||
}; \
|
||||
NGRAPH_RTTI_DEFINITION(NAME, STR(NAME), 0);
|
||||
|
||||
bool replaced = false;
|
||||
for (auto n : f->get_ordered_ops()) {
|
||||
if (op::is_output(n) || op::is_parameter(n)) {
|
||||
continue;
|
||||
}
|
||||
SIMPLE_MATCHER_PASS_DEFINITION(EliminateGather, opset3::Gather, simplify_gather);
|
||||
SIMPLE_MATCHER_PASS_DEFINITION(SimplifyShapeOf2Gather, opset2::ShapeOf, simplify_gather_shapeof);
|
||||
SIMPLE_MATCHER_PASS_DEFINITION(SimplifyShapeOf3Gather, opset3::ShapeOf, simplify_gather_shapeof);
|
||||
SIMPLE_MATCHER_PASS_DEFINITION(ConvertTransposeToReshape, opset3::Transpose, replace_transpose_with_reshape);
|
||||
|
||||
// Recursively apply transformation for sub-graph based operations
|
||||
if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(n)) {
|
||||
if (auto sub_graph = sub_graph_node->get_function()) {
|
||||
replaced |= run_on_function(sub_graph);
|
||||
}
|
||||
}
|
||||
|
||||
auto eh = ops_to_simplifiers.find(n->get_type_info());
|
||||
if (eh != ops_to_simplifiers.end()) {
|
||||
replaced |= eh->second(n);
|
||||
}
|
||||
}
|
||||
return replaced;
|
||||
ngraph::pass::AlgebraicSimplification::AlgebraicSimplification() {
|
||||
add_matcher<EliminateGather>();
|
||||
add_matcher<SimplifyShapeOf2Gather>();
|
||||
add_matcher<SimplifyShapeOf3Gather>();
|
||||
add_matcher<ConvertTransposeToReshape>();
|
||||
}
|
||||
|
@ -41,7 +41,7 @@ ngraph::pass::ClampFusion::ClampFusion() {
|
||||
double min_value = min_const->cast_vector<double>()[0];
|
||||
double max_value = max_const->cast_vector<double>()[0];
|
||||
|
||||
auto clamp = std::make_shared<ngraph::opset5::Clamp>(data, min_value, max_value);
|
||||
auto clamp = register_new_node<ngraph::opset5::Clamp>(data, min_value, max_value);
|
||||
auto minimum = pattern_map.at(min_pattern);
|
||||
clamp->set_friendly_name(minimum.get_node()->get_friendly_name());
|
||||
|
||||
|
@ -64,39 +64,47 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
|
||||
// This pass must be called first in pipeline
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::RemoveFilteringBoxesBySize>(); // Resolves dynamism (replaces NonZero), CF needed
|
||||
|
||||
// TODO: move to KMB
|
||||
manager.register_pass<ngraph::pass::ConvertQuantizeDequantize>();
|
||||
|
||||
manager.register_pass<ngraph::pass::ConstantFolding>();
|
||||
manager.register_pass<ngraph::pass::StridedSliceOptimization>(); // depends on CF
|
||||
manager.register_pass<ngraph::pass::AlgebraicSimplification>(); // may introduce fake dynamism
|
||||
manager.register_pass<ngraph::pass::BroadcastElementwiseFusion>();
|
||||
manager.register_pass<ngraph::pass::EliminateUnsqueezeGather>();
|
||||
manager.register_pass<ngraph::pass::NopElimination>(); // may introduce fake dynamism
|
||||
|
||||
auto eliminations = manager.register_pass<ngraph::pass::GraphRewrite>();
|
||||
eliminations->add_matcher<ngraph::pass::EliminateUnsqueezeGather>();
|
||||
eliminations->add_matcher<ngraph::pass::AlgebraicSimplification>(); // may introduce fake dynamism
|
||||
eliminations->add_matcher<ngraph::pass::NopElimination>(); // may introduce fake dynamism
|
||||
eliminations->set_name("ngraph::pass::CommonEliminations");
|
||||
|
||||
manager.register_pass<ngraph::pass::ConstantFolding>();
|
||||
manager.register_pass<ngraph::pass::ConvertScatterElementsToScatter>(); // partially depends on CF
|
||||
manager.register_pass<ngraph::pass::DepthToSpaceFusion>();
|
||||
manager.register_pass<ngraph::pass::MishFusion>();
|
||||
manager.register_pass<ngraph::pass::SoftPlusFusion>();
|
||||
manager.register_pass<ngraph::pass::SoftPlusToMishFusion>();
|
||||
manager.register_pass<ngraph::pass::SwishFusion>();
|
||||
manager.register_pass<ngraph::pass::ClampFusion>();
|
||||
manager.register_pass<ngraph::pass::HSwishFusion>();
|
||||
manager.register_pass<ngraph::pass::HSigmoidFusion>();
|
||||
|
||||
auto common_fusions = manager.register_pass<ngraph::pass::GraphRewrite>();
|
||||
common_fusions->add_matcher<ngraph::pass::ConvertScatterElementsToScatter>();
|
||||
common_fusions->add_matcher<ngraph::pass::DepthToSpaceFusion>();
|
||||
//common_fusions->add_matcher<ngraph::pass::MishFusion>();
|
||||
common_fusions->add_matcher<ngraph::pass::SoftPlusFusion>();
|
||||
common_fusions->add_matcher<ngraph::pass::SoftPlusToMishFusion>();
|
||||
common_fusions->add_matcher<ngraph::pass::SwishFusion>();
|
||||
common_fusions->add_matcher<ngraph::pass::HSwishFusion>();
|
||||
common_fusions->add_matcher<ngraph::pass::HSigmoidFusion>();
|
||||
common_fusions->add_matcher<ngraph::pass::NormalizeL2Fusion>();
|
||||
common_fusions->add_matcher<ngraph::pass::ClampFusion>();
|
||||
common_fusions->add_matcher<ngraph::pass::PadFusion>();
|
||||
common_fusions->set_name("ngraph::pass::CommonFusions");
|
||||
|
||||
manager.register_pass<ngraph::pass::ConvertPadToGroupConvolution, false>();
|
||||
manager.register_pass<ngraph::pass::NormalizeL2Fusion>();
|
||||
manager.register_pass<ngraph::pass::PadFusion>();
|
||||
manager.register_pass<ngraph::pass::ConvertInterpolate1ToInterpolate4, false>();
|
||||
|
||||
auto decomp = manager.register_pass<ngraph::pass::GraphRewrite>();
|
||||
decomp->add_matcher<ngraph::pass::BidirectionalLSTMSequenceDecomposition>();
|
||||
decomp->add_matcher<ngraph::pass::BidirectionalRNNSequenceDecomposition>();
|
||||
decomp->add_matcher<ngraph::pass::BidirectionalGRUSequenceDecomposition>();
|
||||
decomp->add_matcher<ngraph::pass::BidirectionalSequenceDecomposition>();
|
||||
decomp->add_matcher<ngraph::pass::ReduceL1Decomposition>();
|
||||
decomp->add_matcher<ngraph::pass::ReduceL2Decomposition>();
|
||||
decomp->add_matcher<ngraph::pass::HSwishDecomposition>();
|
||||
decomp->add_matcher<ngraph::pass::HSigmoidDecomposition>();
|
||||
decomp->add_matcher<ngraph::pass::LogSoftmaxDecomposition>();
|
||||
decomp->add_matcher<ngraph::pass::ConvertReduceMeanToPooling>();
|
||||
decomp->add_matcher<ngraph::pass::ConvertReduceMaxToPooling>();
|
||||
decomp->add_matcher<ngraph::pass::ConvertReduceSumToPooling>();
|
||||
decomp->add_matcher<ngraph::pass::ConvertReduceToPooling>();
|
||||
decomp->add_matcher<ngraph::pass::ConvertBroadcastToTiles>();
|
||||
decomp->add_matcher<ngraph::pass::ConvertMod>();
|
||||
decomp->add_matcher<ngraph::pass::ConvertGELU>();
|
||||
@ -116,12 +124,14 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
|
||||
// LinOpSequenceFusion must be executed after all decompositions
|
||||
manager.register_pass<ngraph::pass::LinOpSequenceFusion>();
|
||||
|
||||
manager.register_pass<ngraph::pass::ConvolutionMultiplyFusion>();
|
||||
manager.register_pass<ngraph::pass::GroupConvolutionMultiplyFusion>();
|
||||
manager.register_pass<ngraph::pass::ConvolutionBackpropDataMultiplyFusion>();
|
||||
manager.register_pass<ngraph::pass::GroupConvolutionBackpropDataMultiplyFusion>();
|
||||
auto conv_fusions = manager.register_pass<ngraph::pass::GraphRewrite>();
|
||||
conv_fusions->add_matcher<ngraph::pass::ConvolutionMultiplyFusion>();
|
||||
conv_fusions->add_matcher<ngraph::pass::GroupConvolutionMultiplyFusion>();
|
||||
conv_fusions->add_matcher<ngraph::pass::ConvolutionBackpropDataMultiplyFusion>();
|
||||
conv_fusions->add_matcher<ngraph::pass::GroupConvolutionBackpropDataMultiplyFusion>();
|
||||
conv_fusions->set_name("ngraph::pass::ConvFusions");
|
||||
|
||||
manager.register_pass<ngraph::pass::ConstantFolding>();
|
||||
manager.register_pass<ngraph::pass::ConvertInterpolate1ToInterpolate4, false>();
|
||||
|
||||
auto fq_fusions = manager.register_pass<ngraph::pass::GraphRewrite>();
|
||||
fq_fusions->add_matcher<ngraph::pass::FakeQuantizeMulFusion>();
|
||||
@ -131,5 +141,10 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
|
||||
fq_fusions->set_name("ngraph::pass::FakeQuantizeFusions");
|
||||
|
||||
manager.run_passes(f);
|
||||
return true;
|
||||
|
||||
// Returning value is false because pass::Manager always apply Validation pass
|
||||
// if function was changed. This helps to avoid excess Validations after applying
|
||||
// this pass. In future when we will return more meaningful status code it will be
|
||||
// replaced with real status reported by manager.run_passes() method call.
|
||||
return false;
|
||||
}
|
||||
|
@ -25,12 +25,15 @@
|
||||
#include <ngraph/util.hpp>
|
||||
#include <ngraph/log.hpp>
|
||||
#include <transformations/common_optimizations/nop_elimination.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
#define TI(x) x::type_info
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::NopElimination, "NopElimination", 0);
|
||||
|
||||
static bool eliminate_nop(const std::shared_ptr<Node>& node) {
|
||||
// skip if shapes are dynamic
|
||||
if (node->get_input_partial_shape(0).is_dynamic() ||
|
||||
@ -327,33 +330,38 @@ static bool eliminate_squeeze(const std::shared_ptr<Node>& node) {
|
||||
return false;
|
||||
}
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::NopElimination, "NopElimination", 0);
|
||||
#define ECHO(NAME) #NAME
|
||||
#define STR(NAME) ECHO(NAME)
|
||||
#define SIMPLE_MATCHER_PASS_DEFINITION(NAME, OP, FUNC) \
|
||||
class NAME : public ngraph::pass::MatcherPass { \
|
||||
public: \
|
||||
NGRAPH_RTTI_DECLARATION; \
|
||||
NAME() { \
|
||||
MATCHER_SCOPE(NAME); \
|
||||
auto match_node = ngraph::pattern::wrap_type<OP>(); \
|
||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) { \
|
||||
return FUNC(m.get_match_root()); \
|
||||
}; \
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(match_node, matcher_name); \
|
||||
register_matcher(m, callback); \
|
||||
} \
|
||||
}; \
|
||||
NGRAPH_RTTI_DEFINITION(NAME, STR(NAME), 0);
|
||||
|
||||
bool pass::NopElimination::run_on_function(std::shared_ptr<Function> function) {
|
||||
RUN_ON_FUNCTION_SCOPE(NopElimination);
|
||||
static const std::unordered_map<NodeTypeInfo, std::function<bool(const std::shared_ptr<Node>&)>>
|
||||
dispatcher{{TI(opset3::Pad), &eliminate_nop},
|
||||
{TI(opset3::Convert), &eliminate_convert},
|
||||
{TI(opset3::Reshape), &eliminate_reshape_v1},
|
||||
{TI(opset3::Concat), &eliminate_concat},
|
||||
{TI(opset3::Squeeze), &eliminate_squeeze},
|
||||
{TI(op::v1::Broadcast), &eliminate_nop},
|
||||
{TI(opset3::Unsqueeze), &eliminate_unsqueeze}};
|
||||
SIMPLE_MATCHER_PASS_DEFINITION(EliminatePad, opset3::Pad, eliminate_nop);
|
||||
SIMPLE_MATCHER_PASS_DEFINITION(EliminateConvert, opset3::Convert, eliminate_convert);
|
||||
SIMPLE_MATCHER_PASS_DEFINITION(EliminateReshape, opset3::Reshape, eliminate_reshape_v1);
|
||||
SIMPLE_MATCHER_PASS_DEFINITION(EliminateConcat, opset3::Concat, eliminate_concat);
|
||||
SIMPLE_MATCHER_PASS_DEFINITION(EliminateSqueeze, opset3::Squeeze, eliminate_squeeze);
|
||||
SIMPLE_MATCHER_PASS_DEFINITION(EliminateUnsqueeze, opset3::Unsqueeze, eliminate_unsqueeze);
|
||||
SIMPLE_MATCHER_PASS_DEFINITION(EliminateBroadcast, op::v1::Broadcast, eliminate_nop);
|
||||
|
||||
bool clobbered = false;
|
||||
|
||||
for (const auto& node : function->get_ops()) {
|
||||
// Recursively apply transformation for sub-graph based operations
|
||||
if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node)) {
|
||||
if (auto sub_graph = sub_graph_node->get_function()) {
|
||||
clobbered |= run_on_function(sub_graph);
|
||||
}
|
||||
}
|
||||
auto handler = dispatcher.find(node->get_type_info());
|
||||
if (handler != dispatcher.end()) {
|
||||
clobbered |= handler->second(node);
|
||||
}
|
||||
}
|
||||
|
||||
return clobbered;
|
||||
}
|
||||
ngraph::pass::NopElimination::NopElimination() {
|
||||
add_matcher<EliminatePad>();
|
||||
add_matcher<EliminateConvert>();
|
||||
add_matcher<EliminateReshape>();
|
||||
add_matcher<EliminateConcat>();
|
||||
add_matcher<EliminateSqueeze>();
|
||||
add_matcher<EliminateUnsqueeze>();
|
||||
add_matcher<EliminateBroadcast>();
|
||||
}
|
@ -12,6 +12,7 @@
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::BidirectionalSequenceDecomposition, "BidirectionalSequenceDecomposition", 0);
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::BidirectionalLSTMSequenceDecomposition, "BidirectionalLSTMSequenceDecomposition", 0);
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::BidirectionalGRUSequenceDecomposition, "BidirectionalGRUSequenceDecomposition", 0);
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::BidirectionalRNNSequenceDecomposition, "BidirectionalRNNSequenceDecomposition", 0);
|
||||
|
@ -23,5 +23,10 @@ bool ngraph::pass::ConvertOpSet2ToOpSet1::run_on_function(std::shared_ptr<ngraph
|
||||
manager.register_pass<ngraph::pass::ConvertBatchToSpace>();
|
||||
|
||||
manager.run_passes(f);
|
||||
return true;
|
||||
|
||||
// Returning value is false because pass::Manager always apply Validation pass
|
||||
// if function was changed. This helps to avoid excess Validations after applying
|
||||
// this pass. In future when we will return more meaningful status code it will be
|
||||
// replaced with real status reported by manager.run_passes() method call.
|
||||
return false;
|
||||
}
|
||||
|
@ -29,5 +29,10 @@ bool ngraph::pass::ConvertOpSet3ToOpSet2::run_on_function(std::shared_ptr<ngraph
|
||||
manager.register_pass<ngraph::pass::SoftPlusDecomposition>();
|
||||
|
||||
manager.run_passes(f);
|
||||
return true;
|
||||
|
||||
// Returning value is false because pass::Manager always apply Validation pass
|
||||
// if function was changed. This helps to avoid excess Validations after applying
|
||||
// this pass. In future when we will return more meaningful status code it will be
|
||||
// replaced with real status reported by manager.run_passes() method call.
|
||||
return false;
|
||||
}
|
||||
|
@ -148,7 +148,7 @@ namespace ngraph
|
||||
/// auto anchor = manager.register_pass<GraphRewrite>();
|
||||
/// anchor->add_matcher<MatcherPassA>();
|
||||
/// anchor->add_matcher<MatcherPassB>();
|
||||
/// anchor->set_name("CommonMathcers");
|
||||
/// anchor->set_name("CommonMatchers");
|
||||
/// manager.run_passes(f);
|
||||
///
|
||||
/// For some purposes transformation can be registered and disabled by default.
|
||||
@ -156,7 +156,11 @@ namespace ngraph
|
||||
/// anchor->add_matcher<MatcherPassB, false>();
|
||||
///
|
||||
/// \return shared_ptr to the transformation instance
|
||||
template <typename T, bool Enabled = true, class... Args>
|
||||
template <
|
||||
typename T,
|
||||
bool Enabled = true,
|
||||
class... Args,
|
||||
typename std::enable_if<std::is_base_of<pass::MatcherPass, T>{}, bool>::type = true>
|
||||
std::shared_ptr<T> add_matcher(Args&&... args)
|
||||
{
|
||||
static_assert(std::is_base_of<pass::MatcherPass, T>::value,
|
||||
@ -171,6 +175,47 @@ namespace ngraph
|
||||
m_matchers.push_back(pass);
|
||||
return pass;
|
||||
}
|
||||
|
||||
/// \brief Register passes from GraphRewrite class that contains sequence of matcher
|
||||
/// passes registered in its ctor.
|
||||
/// For example:
|
||||
///
|
||||
/// class ngraph::pass::LinFusions: public ngraph::pass::GraphRewrite {
|
||||
/// public:
|
||||
/// NGRAPH_RTTI_DECLARATION;
|
||||
/// Fusions() {
|
||||
/// add_matcher<ngraph::pass::AddFusion>();
|
||||
/// add_matcher<ngraph::pass::MulFusion>();
|
||||
/// }
|
||||
/// };
|
||||
///
|
||||
/// pass::Manager manager;
|
||||
/// auto anchor = manager.register_pass<GraphRewrite>();
|
||||
/// anchor->add_matcher<LinFusions>();
|
||||
/// anchor->add_matcher<OtherFusions>();
|
||||
/// anchor->set_name("CommonFusions");
|
||||
/// manager.run_passes(f);
|
||||
///
|
||||
/// In this case all matcher passes from LinFusions pass will be united with other
|
||||
/// registered matchers.
|
||||
template <typename T,
|
||||
class... Args,
|
||||
typename std::enable_if<std::is_base_of<pass::GraphRewrite, T>{},
|
||||
bool>::type = true>
|
||||
void add_matcher(Args&&... args)
|
||||
{
|
||||
static_assert(std::is_base_of<pass::GraphRewrite, T>::value,
|
||||
"pass not derived from GraphRewrite");
|
||||
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
|
||||
auto pass_config = get_pass_config();
|
||||
|
||||
for (auto& matcher : pass->m_matchers)
|
||||
{
|
||||
pass->set_pass_config(pass_config);
|
||||
m_matchers.push_back(matcher);
|
||||
}
|
||||
}
|
||||
|
||||
NGRAPH_DEPRECATED("Use MatcherPass instead")
|
||||
void add_matcher(const std::shared_ptr<pattern::Matcher>& m,
|
||||
const ngraph::graph_rewrite_callback& callback,
|
||||
|
@ -29,7 +29,10 @@ bool ngraph::pass::ConstantFolding::run_on_function(std::shared_ptr<ngraph::Func
|
||||
|
||||
for (const auto& node : f->get_ordered_ops())
|
||||
{
|
||||
node->revalidate_and_infer_types();
|
||||
if (rewritten)
|
||||
{
|
||||
node->revalidate_and_infer_types();
|
||||
}
|
||||
|
||||
OutputVector replacements(node->get_output_size());
|
||||
if (node->constant_fold(replacements, node->input_values()))
|
||||
|
Loading…
Reference in New Issue
Block a user