Fix Pruning for case with INT8 GroupConvolution operation (#6872)

This commit is contained in:
Gleb Kazantaev 2021-08-02 12:36:51 +03:00 committed by GitHub
parent bfca47ad5e
commit cc5f63d87a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -19,12 +19,12 @@ namespace mask_propagation {
class Convolution;
class GroupConvolution;
class GroupConvolutionReshape;
class Elementwise;
class PassThrough;
class StopPropagation;
class FakeQuantize;
class Concat;
class Reshape;
} // namespace mask_propagation
} // namespace pass
@ -192,9 +192,9 @@ public:
}
};
class ngraph::pass::mask_propagation::Reshape : public MatcherPass {
class ngraph::pass::mask_propagation::GroupConvolutionReshape : public MatcherPass {
public:
Reshape() {
GroupConvolutionReshape() {
auto input = pattern::any_input(pattern::has_static_shape());
auto shape = pattern::any_input();
// Working only for Reshapes on Group Convolution weights
@ -258,10 +258,12 @@ public:
ngraph::replace_node(old_shape_const, new_const);
setMask(m_output, output_mask);
return true;
// This transformation propagates only Reshape mask and doesn't do anything with GroupConvolution.
// So, not to disable GroupConvolution mask propagation we return false here.
return false;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(reshape, "ReshapeMaskPropagation");
auto m = std::make_shared<ngraph::pattern::Matcher>(gconv, "ReshapeMaskPropagation");
register_matcher(m, callback);
}
};
@ -604,11 +606,11 @@ public:
ngraph::pass::PropagateMasks::PropagateMasks() {
add_matcher<mask_propagation::Convolution>();
add_matcher<mask_propagation::GroupConvolutionReshape>();
add_matcher<mask_propagation::GroupConvolution>();
add_matcher<mask_propagation::Elementwise>();
add_matcher<mask_propagation::PassThrough>();
add_matcher<mask_propagation::FakeQuantize>();
add_matcher<mask_propagation::Concat>();
add_matcher<mask_propagation::Reshape>();
add_matcher<mask_propagation::StopPropagation>();
}