Reduce number of Function Validations (#8668)

This commit is contained in:
Gleb Kazantaev 2021-11-19 13:04:49 +03:00 committed by GitHub
parent 1ca9592d75
commit 0e749b8b15
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 28 additions and 4 deletions

View File

@ -110,6 +110,8 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Function> func) {
bool enableInt8; bool enableInt8;
{ {
ngraph::pass::Manager manager; ngraph::pass::Manager manager;
manager.set_per_pass_validation(false);
enableInt8 = config.enableInt8 && ngraph::pass::low_precision::LowPrecision::isFunctionQuantized(func); enableInt8 = config.enableInt8 && ngraph::pass::low_precision::LowPrecision::isFunctionQuantized(func);
if (enableInt8) { if (enableInt8) {
manager.register_pass<ngraph::pass::DisableConvertConstantFoldingOnConstPath>( manager.register_pass<ngraph::pass::DisableConvertConstantFoldingOnConstPath>(
@ -160,6 +162,7 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Function> func) {
{ngraph::element::u4, ngraph::element::u8}, {ngraph::element::u4, ngraph::element::u8},
}; };
manager.register_pass<ngraph::pass::Validate>();
manager.register_pass<ngraph::pass::ConvertPrecision>(convert_precision_list); manager.register_pass<ngraph::pass::ConvertPrecision>(convert_precision_list);
auto pass_config = manager.get_pass_config(); auto pass_config = manager.get_pass_config();

View File

@ -129,6 +129,7 @@ Engine::~Engine() {
static void TransformationUpToCPUSpecificOpSet(std::shared_ptr<ngraph::Function> nGraphFunc, const bool _enableLPT) { static void TransformationUpToCPUSpecificOpSet(std::shared_ptr<ngraph::Function> nGraphFunc, const bool _enableLPT) {
ngraph::pass::Manager manager; ngraph::pass::Manager manager;
manager.set_per_pass_validation(false);
manager.register_pass<ngraph::pass::InitNodeInfo>(); manager.register_pass<ngraph::pass::InitNodeInfo>();
const bool useLpt = const bool useLpt =
@ -187,6 +188,7 @@ static void TransformationUpToCPUSpecificOpSet(std::shared_ptr<ngraph::Function>
manager.register_pass<ngraph::pass::low_precision::ConvertSubtractConstant>( manager.register_pass<ngraph::pass::low_precision::ConvertSubtractConstant>(
std::vector<ngraph::element::Type>{ ngraph::element::i8, ngraph::element::u8, ngraph::element::i4, ngraph::element::u4 }); std::vector<ngraph::element::Type>{ ngraph::element::i8, ngraph::element::u8, ngraph::element::i4, ngraph::element::u4 });
} }
manager.register_pass<ngraph::pass::Validate>();
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions); manager.register_pass<ngraph::pass::ConvertPrecision>(precisions);
manager.register_pass<ngraph::pass::EliminateConvert>(); manager.register_pass<ngraph::pass::EliminateConvert>();

View File

@ -96,6 +96,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::CommonOptimizations, "CommonOptimizations",
bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::Function> f) { bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::Function> f) {
RUN_ON_FUNCTION_SCOPE(CommonOptimizations); RUN_ON_FUNCTION_SCOPE(CommonOptimizations);
ngraph::pass::Manager manager(get_pass_config()); ngraph::pass::Manager manager(get_pass_config());
manager.set_per_pass_validation(false);
manager.register_pass<ov::pass::DisableDecompressionConvertConstantFolding>(); manager.register_pass<ov::pass::DisableDecompressionConvertConstantFolding>();
@ -186,6 +187,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
// because we cannot insert any MaxPools since they may prevent // because we cannot insert any MaxPools since they may prevent
// other optimizations // other optimizations
manager.register_pass<ngraph::pass::StridesOptimization>(); manager.register_pass<ngraph::pass::StridesOptimization>();
manager.register_pass<ngraph::pass::Validate>();
manager.run_passes(f); manager.run_passes(f);

View File

@ -68,6 +68,7 @@ bool ngraph::pass::MOCTransformations::run_on_function(std::shared_ptr<ngraph::F
} }
ngraph::pass::Manager manager(get_pass_config()); ngraph::pass::Manager manager(get_pass_config());
manager.set_per_pass_validation(false);
manager.register_pass<ngraph::pass::InitNodeInfo>(); manager.register_pass<ngraph::pass::InitNodeInfo>();
if (m_low_precision_enabled) { if (m_low_precision_enabled) {
@ -79,10 +80,15 @@ bool ngraph::pass::MOCTransformations::run_on_function(std::shared_ptr<ngraph::F
} }
manager.register_pass<ngraph::pass::DisableRandomUniformConstantFolding>(); manager.register_pass<ngraph::pass::DisableRandomUniformConstantFolding>();
manager.register_pass<ngraph::pass::ConstantFolding>(); manager.register_pass<ngraph::pass::ConstantFolding>();
manager.register_pass<ngraph::pass::Validate>();
// FusedFilteringBoxesBySize transformation has the complex pattern // FusedFilteringBoxesBySize transformation has the complex pattern
// which can be affected by further transformations. So we have to // which can be affected by further transformations. So we have to
// execute it at the beginning of the pipeline. // execute it at the beginning of the pipeline. Also, this pass resolves
// dynamism, so we have to execute type/shape propagation after.
manager.register_pass<ngraph::pass::FuseFilteringBoxesBySize>(); manager.register_pass<ngraph::pass::FuseFilteringBoxesBySize>();
manager.register_pass<ngraph::pass::Validate>();
manager.register_pass<ngraph::pass::ConvertQuantizeDequantize>(); manager.register_pass<ngraph::pass::ConvertQuantizeDequantize>();
manager.register_pass<ngraph::pass::SimplifyShapeOfSubGraph>(); manager.register_pass<ngraph::pass::SimplifyShapeOfSubGraph>();
if (!m_use_shapes) { if (!m_use_shapes) {

View File

@ -293,12 +293,18 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::SimplifyShapeOfSubGraph, "SimplifyShapeOfSu
bool ngraph::pass::SimplifyShapeOfSubGraph::run_on_function(std::shared_ptr<ngraph::Function> f) { bool ngraph::pass::SimplifyShapeOfSubGraph::run_on_function(std::shared_ptr<ngraph::Function> f) {
RUN_ON_FUNCTION_SCOPE(SimplifyShapeOfSubGraph); RUN_ON_FUNCTION_SCOPE(SimplifyShapeOfSubGraph);
ngraph::pass::Manager manager; ngraph::pass::Manager manager;
manager.set_per_pass_validation(false);
manager.register_pass<ngraph::pass::EliminateGatherUnsqueeze>(); manager.register_pass<ngraph::pass::EliminateGatherUnsqueeze>();
manager.register_pass<ngraph::pass::SharedShapeOf>(); manager.register_pass<ngraph::pass::SharedShapeOf>();
manager.register_pass<ngraph::pass::GroupedGatherElimination>(); manager.register_pass<ngraph::pass::GroupedGatherElimination>();
// GatherNopElimination depends on shape, so it requires shape propagation
// if previous transformations has resolved some dynamic shapes.
manager.register_pass<ngraph::pass::Validate>();
manager.register_pass<ngraph::pass::GatherNopElimination>(); manager.register_pass<ngraph::pass::GatherNopElimination>();
manager.register_pass<ngraph::pass::SimplifyGatherShapeOf>(); manager.register_pass<ngraph::pass::SimplifyGatherShapeOf>();
manager.register_pass<ngraph::pass::SimplifySecondInputOfReshape>(); manager.register_pass<ngraph::pass::SimplifySecondInputOfReshape>();
// TODO: potentially this Validate is not needed but it requires additional validation
manager.register_pass<ngraph::pass::Validate>();
manager.run_passes(f); manager.run_passes(f);
return false; return false;
} }

View File

@ -18,6 +18,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertOpSet2ToOpSet1, "ConvertOpSet2ToOpSe
bool ngraph::pass::ConvertOpSet2ToOpSet1::run_on_function(std::shared_ptr<ngraph::Function> f) { bool ngraph::pass::ConvertOpSet2ToOpSet1::run_on_function(std::shared_ptr<ngraph::Function> f) {
RUN_ON_FUNCTION_SCOPE(ConvertOpSet2ToOpSet1); RUN_ON_FUNCTION_SCOPE(ConvertOpSet2ToOpSet1);
ngraph::pass::Manager manager(get_pass_config()); ngraph::pass::Manager manager(get_pass_config());
manager.set_per_pass_validation(false);
manager.register_pass<ngraph::pass::ConvertSpaceToBatch>(); manager.register_pass<ngraph::pass::ConvertSpaceToBatch>();
manager.register_pass<ngraph::pass::ConvertBatchToSpace>(); manager.register_pass<ngraph::pass::ConvertBatchToSpace>();

View File

@ -21,6 +21,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertOpSet3ToOpSet2, "ConvertOpSet3ToOpSe
bool ngraph::pass::ConvertOpSet3ToOpSet2::run_on_function(std::shared_ptr<ngraph::Function> f) { bool ngraph::pass::ConvertOpSet3ToOpSet2::run_on_function(std::shared_ptr<ngraph::Function> f) {
RUN_ON_FUNCTION_SCOPE(ConvertOpSet3ToOpSet2); RUN_ON_FUNCTION_SCOPE(ConvertOpSet3ToOpSet2);
ngraph::pass::Manager manager(get_pass_config()); ngraph::pass::Manager manager(get_pass_config());
manager.set_per_pass_validation(false);
manager.register_pass<ngraph::pass::ConvertBroadcast3>(); manager.register_pass<ngraph::pass::ConvertBroadcast3>();
manager.register_pass<ngraph::pass::ConvertShapeOf3>(); manager.register_pass<ngraph::pass::ConvertShapeOf3>();

View File

@ -55,9 +55,8 @@ public:
/// \brief Set flag to enable/disable running Validate pass after executing /// \brief Set flag to enable/disable running Validate pass after executing
/// each registered pass /// each registered pass
/// \param new_state Value "true" enables Validate pass run; "false", otherwise /// \param new_state Value "true" enables Validate pass run; "false", otherwise
void set_per_pass_validation(bool new_state) { void set_per_pass_validation(bool new_state);
m_per_pass_validation = new_state;
}
/// \brief Callback is a lambda function that can be used by registered transformations. /// \brief Callback is a lambda function that can be used by registered transformations.
/// The main purpose of this callback is to provide a way for plugins to disable/enable /// The main purpose of this callback is to provide a way for plugins to disable/enable
/// transformations based on some conditions. In some cases plugins may want not to /// transformations based on some conditions. In some cases plugins may want not to

View File

@ -44,6 +44,10 @@ ov::pass::Manager::~Manager() = default;
ov::pass::Manager::Manager(std::shared_ptr<ov::pass::PassConfig> pass_config) : m_pass_config(std::move(pass_config)) {} ov::pass::Manager::Manager(std::shared_ptr<ov::pass::PassConfig> pass_config) : m_pass_config(std::move(pass_config)) {}
void ov::pass::Manager::set_per_pass_validation(bool new_state) {
m_per_pass_validation = new_state;
}
void ov::pass::Manager::run_passes(shared_ptr<ov::Function> func) { void ov::pass::Manager::run_passes(shared_ptr<ov::Function> func) {
NGRAPH_SUPPRESS_DEPRECATED_START NGRAPH_SUPPRESS_DEPRECATED_START
OV_ITT_SCOPED_TASK(ov::itt::domains::nGraph, "pass::Manager::run_passes"); OV_ITT_SCOPED_TASK(ov::itt::domains::nGraph, "pass::Manager::run_passes");