Fix Pruning for case with INT8 GroupConvolution operation (#6872)
This commit is contained in:
parent
bfca47ad5e
commit
cc5f63d87a
@ -19,12 +19,12 @@ namespace mask_propagation {
|
|||||||
|
|
||||||
class Convolution;
|
class Convolution;
|
||||||
class GroupConvolution;
|
class GroupConvolution;
|
||||||
|
class GroupConvolutionReshape;
|
||||||
class Elementwise;
|
class Elementwise;
|
||||||
class PassThrough;
|
class PassThrough;
|
||||||
class StopPropagation;
|
class StopPropagation;
|
||||||
class FakeQuantize;
|
class FakeQuantize;
|
||||||
class Concat;
|
class Concat;
|
||||||
class Reshape;
|
|
||||||
|
|
||||||
} // namespace mask_propagation
|
} // namespace mask_propagation
|
||||||
} // namespace pass
|
} // namespace pass
|
||||||
@ -192,9 +192,9 @@ public:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class ngraph::pass::mask_propagation::Reshape : public MatcherPass {
|
class ngraph::pass::mask_propagation::GroupConvolutionReshape : public MatcherPass {
|
||||||
public:
|
public:
|
||||||
Reshape() {
|
GroupConvolutionReshape() {
|
||||||
auto input = pattern::any_input(pattern::has_static_shape());
|
auto input = pattern::any_input(pattern::has_static_shape());
|
||||||
auto shape = pattern::any_input();
|
auto shape = pattern::any_input();
|
||||||
// Working only for Reshapes on Group Convolution weights
|
// Working only for Reshapes on Group Convolution weights
|
||||||
@ -258,10 +258,12 @@ public:
|
|||||||
ngraph::replace_node(old_shape_const, new_const);
|
ngraph::replace_node(old_shape_const, new_const);
|
||||||
|
|
||||||
setMask(m_output, output_mask);
|
setMask(m_output, output_mask);
|
||||||
return true;
|
// This transformation propagates only Reshape mask and doesn't do anything with GroupConvolution.
|
||||||
|
// So, not to disable GroupConvolution mask propagation we return false here.
|
||||||
|
return false;
|
||||||
};
|
};
|
||||||
|
|
||||||
auto m = std::make_shared<ngraph::pattern::Matcher>(reshape, "ReshapeMaskPropagation");
|
auto m = std::make_shared<ngraph::pattern::Matcher>(gconv, "ReshapeMaskPropagation");
|
||||||
register_matcher(m, callback);
|
register_matcher(m, callback);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -604,11 +606,11 @@ public:
|
|||||||
|
|
||||||
ngraph::pass::PropagateMasks::PropagateMasks() {
|
ngraph::pass::PropagateMasks::PropagateMasks() {
|
||||||
add_matcher<mask_propagation::Convolution>();
|
add_matcher<mask_propagation::Convolution>();
|
||||||
|
add_matcher<mask_propagation::GroupConvolutionReshape>();
|
||||||
add_matcher<mask_propagation::GroupConvolution>();
|
add_matcher<mask_propagation::GroupConvolution>();
|
||||||
add_matcher<mask_propagation::Elementwise>();
|
add_matcher<mask_propagation::Elementwise>();
|
||||||
add_matcher<mask_propagation::PassThrough>();
|
add_matcher<mask_propagation::PassThrough>();
|
||||||
add_matcher<mask_propagation::FakeQuantize>();
|
add_matcher<mask_propagation::FakeQuantize>();
|
||||||
add_matcher<mask_propagation::Concat>();
|
add_matcher<mask_propagation::Concat>();
|
||||||
add_matcher<mask_propagation::Reshape>();
|
|
||||||
add_matcher<mask_propagation::StopPropagation>();
|
add_matcher<mask_propagation::StopPropagation>();
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user