Fix Pruning for case with INT8 GroupConvolution operation (#6872)

This commit is contained in:
Gleb Kazantaev 2021-08-02 12:36:51 +03:00 committed by GitHub
parent bfca47ad5e
commit cc5f63d87a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -19,12 +19,12 @@ namespace mask_propagation {
class Convolution; class Convolution;
class GroupConvolution; class GroupConvolution;
class GroupConvolutionReshape;
class Elementwise; class Elementwise;
class PassThrough; class PassThrough;
class StopPropagation; class StopPropagation;
class FakeQuantize; class FakeQuantize;
class Concat; class Concat;
class Reshape;
} // namespace mask_propagation } // namespace mask_propagation
} // namespace pass } // namespace pass
@ -192,9 +192,9 @@ public:
} }
}; };
class ngraph::pass::mask_propagation::Reshape : public MatcherPass { class ngraph::pass::mask_propagation::GroupConvolutionReshape : public MatcherPass {
public: public:
Reshape() { GroupConvolutionReshape() {
auto input = pattern::any_input(pattern::has_static_shape()); auto input = pattern::any_input(pattern::has_static_shape());
auto shape = pattern::any_input(); auto shape = pattern::any_input();
// Working only for Reshapes on Group Convolution weights // Working only for Reshapes on Group Convolution weights
@ -258,10 +258,12 @@ public:
ngraph::replace_node(old_shape_const, new_const); ngraph::replace_node(old_shape_const, new_const);
setMask(m_output, output_mask); setMask(m_output, output_mask);
return true; // This transformation propagates only Reshape mask and doesn't do anything with GroupConvolution.
// So, not to disable GroupConvolution mask propagation we return false here.
return false;
}; };
auto m = std::make_shared<ngraph::pattern::Matcher>(reshape, "ReshapeMaskPropagation"); auto m = std::make_shared<ngraph::pattern::Matcher>(gconv, "ReshapeMaskPropagation");
register_matcher(m, callback); register_matcher(m, callback);
} }
}; };
@ -604,11 +606,11 @@ public:
ngraph::pass::PropagateMasks::PropagateMasks() { ngraph::pass::PropagateMasks::PropagateMasks() {
add_matcher<mask_propagation::Convolution>(); add_matcher<mask_propagation::Convolution>();
add_matcher<mask_propagation::GroupConvolutionReshape>();
add_matcher<mask_propagation::GroupConvolution>(); add_matcher<mask_propagation::GroupConvolution>();
add_matcher<mask_propagation::Elementwise>(); add_matcher<mask_propagation::Elementwise>();
add_matcher<mask_propagation::PassThrough>(); add_matcher<mask_propagation::PassThrough>();
add_matcher<mask_propagation::FakeQuantize>(); add_matcher<mask_propagation::FakeQuantize>();
add_matcher<mask_propagation::Concat>(); add_matcher<mask_propagation::Concat>();
add_matcher<mask_propagation::Reshape>();
add_matcher<mask_propagation::StopPropagation>(); add_matcher<mask_propagation::StopPropagation>();
} }