diff --git a/inference-engine/src/offline_transformations/src/pruning/propagate_masks.cpp b/inference-engine/src/offline_transformations/src/pruning/propagate_masks.cpp index 671cc6f8885..e944ffff57b 100644 --- a/inference-engine/src/offline_transformations/src/pruning/propagate_masks.cpp +++ b/inference-engine/src/offline_transformations/src/pruning/propagate_masks.cpp @@ -19,12 +19,12 @@ namespace mask_propagation { class Convolution; class GroupConvolution; +class GroupConvolutionReshape; class Elementwise; class PassThrough; class StopPropagation; class FakeQuantize; class Concat; -class Reshape; } // namespace mask_propagation } // namespace pass @@ -192,9 +192,9 @@ public: } }; -class ngraph::pass::mask_propagation::Reshape : public MatcherPass { +class ngraph::pass::mask_propagation::GroupConvolutionReshape : public MatcherPass { public: - Reshape() { + GroupConvolutionReshape() { auto input = pattern::any_input(pattern::has_static_shape()); auto shape = pattern::any_input(); // Working only for Reshapes on Group Convolution weights @@ -258,10 +258,12 @@ public: ngraph::replace_node(old_shape_const, new_const); setMask(m_output, output_mask); - return true; + // This transformation propagates only Reshape mask and doesn't do anything with GroupConvolution. + // So, not to disable GroupConvolution mask propagation we return false here. + return false; }; - auto m = std::make_shared(reshape, "ReshapeMaskPropagation"); + auto m = std::make_shared(gconv, "ReshapeMaskPropagation"); register_matcher(m, callback); } }; @@ -604,11 +606,11 @@ public: ngraph::pass::PropagateMasks::PropagateMasks() { add_matcher(); + add_matcher(); add_matcher(); add_matcher(); add_matcher(); add_matcher(); add_matcher(); - add_matcher(); add_matcher(); }