Fix constant sub-graph folding and dynamic shape issues inside Pruning transformation (#5768)

* Fix Pruning to avoid a standalone ConstantFolding pass; added checks for static dims

* Fix shape_size call
Author: Gleb Kazantaev, 2021-05-24 23:03:16 +03:00 (committed by GitHub)
parent eba2410411
commit cce849b90d
4 changed files with 78 additions and 10 deletions


@@ -29,7 +29,7 @@ class ngraph::pass::mask_propagation::Convolution : public MatcherPass {
 public:
     Convolution() {
         auto input = pattern::any_input();
-        auto weights = pattern::any_input();
+        auto weights = pattern::any_input(pattern::has_static_shape());
         auto conv = pattern::wrap_type<opset6::Convolution>({input, weights});

         ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
@@ -92,8 +92,8 @@ public:
 class ngraph::pass::mask_propagation::GroupConvolution : public MatcherPass {
 public:
     GroupConvolution() {
-        auto input = pattern::any_input();
-        auto weights = pattern::any_input();
+        auto input = pattern::any_input(pattern::has_static_dim(1));
+        auto weights = pattern::any_input(pattern::has_static_shape());
         auto group_conv = pattern::wrap_type<opset6::GroupConvolution>({input, weights});

         ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
@@ -104,9 +104,9 @@ public:
             // TODO: check static rank in pattern, use only particular dims
             auto weights_shape = m_weights.get_shape();
-            auto input_shape = m_input.get_shape();
+            auto input_shape = m_input.get_partial_shape();
             // support only depthwise convolutions
-            if (weights_shape[0] != input_shape[1]) {
+            if (weights_shape[0] != static_cast<size_t>(input_shape[1].get_length())) {
                 return false;
             }
@@ -137,7 +137,7 @@ public:
             }
             // Update output channels mask dims
-            auto conv_mask = std::make_shared<Mask>(input_shape.size());
+            auto conv_mask = std::make_shared<Mask>(input_shape.rank().get_length());
             conv_mask->add_callback([weights_mask](Mask::Ptr cur_mask) -> bool {
                 cur_mask->at(1) = weights_mask->at(0);

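For reference, the two guards added above are existing nGraph pattern predicates: pattern::has_static_shape() rejects match candidates whose shape is not fully static, while pattern::has_static_dim(1) only requires the channel dimension to be known. A minimal sketch of a matcher built with these guards (the function name and callback comments are illustrative, not the actual pass):

#include <ngraph/ngraph.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>

using namespace ngraph;

// Sketch: with these predicates the matcher never fires on inputs where the
// callback would need a shape that is dynamic.
std::shared_ptr<pattern::Matcher> make_guarded_matcher() {
    auto input = pattern::any_input(pattern::has_static_dim(1));     // dim 1 (channels) must be static
    auto weights = pattern::any_input(pattern::has_static_shape());  // whole shape must be static
    auto conv = pattern::wrap_type<opset6::GroupConvolution>({input, weights});
    // Inside the callback it is then safe to call:
    //   weights_output.get_shape();                        // static by construction
    //   input_output.get_partial_shape()[1].get_length();  // dim 1 static by construction
    return std::make_shared<pattern::Matcher>(conv, "GuardedGroupConvolution");
}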

@@ -51,7 +51,6 @@ bool ngraph::pass::Pruning::run_on_function(std::shared_ptr<Function> f) {
 #endif

     manager.register_pass<ShrinkWeights>();
-    manager.register_pass<ConstantFolding>();

 #ifdef NGRAPH_DEBUG_ENABLE
     // Uncomment following line and change path to resulting svg file

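Dropping the standalone ConstantFolding registration is the "avoid ConstantFolding" part of the commit message: running full-graph constant folding after ShrinkWeights duplicated work and could misbehave on graphs with dynamic shapes. ShrinkWeights now folds the constant Gather chains it creates by itself (see the next file). A rough sketch of the resulting pipeline, assuming the surrounding body of Pruning::run_on_function (the PropagateMasks registration is inferred from the tests below):

bool ngraph::pass::Pruning::run_on_function(std::shared_ptr<Function> f) {
    Manager manager;
    manager.register_pass<PropagateMasks>();  // assumed preceding step: propagate pruning masks
    manager.register_pass<ShrinkWeights>();   // shrinks constants and folds its own Gather chains
    // Note: no manager.register_pass<ConstantFolding>() anymore.
    manager.run_passes(f);
    return true;
}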

@@ -11,6 +11,7 @@
 #include <ngraph/pattern/op/wrap_type.hpp>
 #include <ngraph/opsets/opset6.hpp>
 #include <ngraph/log.hpp>
+#include <ngraph/ngraph.hpp>

 NGRAPH_RTTI_DEFINITION(ngraph::pass::ShrinkWeights, "ShrinkWeights", 0);
@@ -62,14 +63,22 @@ bool ngraph::pass::ShrinkWeights::run_on_function(std::shared_ptr<ngraph::Function> f) {
                 }
             }
-            const auto & prev_shape = last_output.get_shape();
+            const auto & prev_shape = last_output.get_partial_shape();
             const auto & prev_name = last_output.get_node()->get_friendly_name();
             last_output = std::make_shared<opset6::Gather>(last_output,
                     opset6::Constant::create(element::i64, Shape{dims_to_keep.size()}, dims_to_keep),
                     opset6::Constant::create(element::i64, Shape{}, {dim}));
-            NGRAPH_DEBUG << "Transform(" << prev_name << "): " << prev_shape << " to " << last_output.get_shape();
+            NGRAPH_DEBUG << "Transform(" << prev_name << "): " << prev_shape << " to " << last_output.get_partial_shape();

-            reduced_weights_count += shape_size(prev_shape) - shape_size(last_output.get_shape());
+            if (prev_shape.is_static() && last_output.get_partial_shape().is_static()) {
+                reduced_weights_count += shape_size(prev_shape.get_shape()) - shape_size(last_output.get_shape());
+            } else {
+                NGRAPH_DEBUG << "[ WARNING ] Can not find the number of reduced elements due to dynamic shapes.";
+            }
         }
+
+        // Trying to fold sequence of Gather ops to avoid additional constant folding.
+        if (auto folded_const = ngraph::get_constant_from_source(last_output)) {
+            last_output = folded_const;
+        }
         // as we insert Gather operations after Constant we need to reconnect all
         // Constant consumers to the latest Gather.

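The helper used for that folding, ngraph::get_constant_from_source (declared in ngraph/validation_util.hpp), evaluates an output and returns a Constant when the feeding sub-graph is constant, or nullptr otherwise, in which case the Gather chain simply stays in the graph. A minimal usage sketch with an illustrative Constant->Gather chain like the one ShrinkWeights builds:

#include <ngraph/ngraph.hpp>
#include <ngraph/validation_util.hpp>

using namespace ngraph;

int main() {
    // Constant -> Gather, keeping rows 0 and 2 of a 4x3 weight matrix.
    auto weights = opset6::Constant::create(element::f32, Shape{4, 3}, std::vector<float>(12, 1.0f));
    auto indices = opset6::Constant::create(element::i64, Shape{2}, {0, 2});
    auto axis = opset6::Constant::create(element::i64, Shape{}, {0});
    Output<Node> out = std::make_shared<opset6::Gather>(weights, indices, axis);

    // The whole chain is constant, so this yields a single {2, 3} Constant;
    // for a non-constant source it would return nullptr.
    if (auto folded = get_constant_from_source(out)) {
        out = folded;
    }
    return 0;
}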

@@ -126,6 +126,66 @@ TEST(TransformationTests, PropagateMasksBasic) {
     compare_masks(*getMask(conv2->output(0)), Mask({{}, {}, {}, {}}));
 }

+TEST(TransformationTests, PropagateMasksDynamicConvolution) {
+    PartialShape input_shape{Dimension::dynamic(), 3, 64, 64};
+    Shape weights_shape{6, 3, 3, 3};
+    Shape weights_shape2{6, 6, 3, 3};
+    auto input = std::make_shared<opset5::Parameter>(element::f32, input_shape);
+    auto weights = opset5::Constant::create(element::f32, weights_shape, {0});
+    auto conv = std::make_shared<opset5::Convolution>(input, weights, Strides(2, 1),
+                                                      CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
+    auto relu = std::make_shared<opset5::Relu>(conv);
+
+    auto sub_const = create_constant_with_zeros(Shape{6, 1, 1}, {{1, 2, 3}, {}, {}});
+    auto sub = std::make_shared<opset5::Subtract>(relu, sub_const);
+
+    auto mul_const = create_constant_with_zeros(Shape{6, 1, 1}, {{2}, {}, {}});
+    auto mul = std::make_shared<opset5::Subtract>(sub, mul_const);
+
+    auto weights2 = opset5::Constant::create(element::f32, weights_shape2, {0});
+    auto conv2 = std::make_shared<opset5::Convolution>(mul, weights2, Strides(2, 1),
+                                                       CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
+    auto f = std::make_shared<Function>(NodeVector{conv2}, ParameterVector{input});
+
+    pass::Manager m;
+    m.register_pass<pass::PropagateMasks>();
+    m.run_passes(f);
+
+    compare_masks(*getMask(weights->output(0)), Mask({{2}, {}, {}, {}}));
+    compare_masks(*getMask(conv->output(0)), Mask({{}, {2}, {}, {}}));
+    compare_masks(*getMask(relu->output(0)), Mask({{}, {2}, {}, {}}));
+    compare_masks(*getMask(sub_const), Mask({{2}, {}, {}}));
+    compare_masks(*getMask(mul_const), Mask({{2}, {}, {}}));
+    compare_masks(*getMask(weights2->output(0)), Mask({{}, {2}, {}, {}}));
+    compare_masks(*getMask(conv2->output(0)), Mask({{}, {}, {}, {}}));
+}
+
+TEST(TransformationTests, PropagateMasksDynamicGroupConvolution) {
+    PartialShape input_shape{Dimension::dynamic(), 3, 64, 64};
+    Shape weights_shape{3, 2, 1, 3, 3};
+    Shape weights_shape2{6, 1, 1, 3, 3};
+    auto input = std::make_shared<opset5::Parameter>(element::f32, input_shape);
+    auto weights = opset5::Constant::create(element::f32, weights_shape, {0});
+    auto conv = std::make_shared<opset5::GroupConvolution>(input, weights, Strides(2, 1),
+                                                           CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
+    auto relu = std::make_shared<opset5::Relu>(conv);
+
+    auto sub_const = create_constant_with_zeros(Shape{6, 1, 1}, {{1, 2, 3}, {}, {}});
+    auto sub = std::make_shared<opset5::Subtract>(relu, sub_const);
+
+    auto mul_const = create_constant_with_zeros(Shape{6, 1, 1}, {{2}, {}, {}});
+    auto mul = std::make_shared<opset5::Subtract>(sub, mul_const);
+
+    auto weights2 = opset5::Constant::create(element::f32, weights_shape2, {0});
+    auto conv2 = std::make_shared<opset5::GroupConvolution>(mul, weights2, Strides(2, 1),
+                                                            CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
+    auto f = std::make_shared<Function>(NodeVector{conv2}, ParameterVector{input});
+
+    pass::Manager m;
+    m.register_pass<pass::PropagateMasks>();
+    m.run_passes(f);
+}
+
 TEST(TransformationTests, PropagateMasksEmpty) {
     Shape input_shape{1, 3, 64, 64};
     Shape weights_shape{6, 3, 3, 3};