diff --git a/inference-engine/src/offline_transformations/src/pruning/propagate_masks.cpp b/inference-engine/src/offline_transformations/src/pruning/propagate_masks.cpp index 7ce6a032abd..6016a162162 100644 --- a/inference-engine/src/offline_transformations/src/pruning/propagate_masks.cpp +++ b/inference-engine/src/offline_transformations/src/pruning/propagate_masks.cpp @@ -29,7 +29,7 @@ class ngraph::pass::mask_propagation::Convolution : public MatcherPass { public: Convolution() { auto input = pattern::any_input(); - auto weights = pattern::any_input(); + auto weights = pattern::any_input(pattern::has_static_shape()); auto conv = pattern::wrap_type({input, weights}); ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) { @@ -92,8 +92,8 @@ public: class ngraph::pass::mask_propagation::GroupConvolution : public MatcherPass { public: GroupConvolution() { - auto input = pattern::any_input(); - auto weights = pattern::any_input(); + auto input = pattern::any_input(pattern::has_static_dim(1)); + auto weights = pattern::any_input(pattern::has_static_shape()); auto group_conv = pattern::wrap_type({input, weights}); ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) { @@ -104,9 +104,9 @@ public: // TODO: check static rank in pattern, use only particular dims auto weights_shape = m_weights.get_shape(); - auto input_shape = m_input.get_shape(); + auto input_shape = m_input.get_partial_shape(); // support only depthwise convolutions - if (weights_shape[0] != input_shape[1]) { + if (weights_shape[0] != static_cast(input_shape[1].get_length())) { return false; } @@ -137,7 +137,7 @@ public: } // Update output channels mask dims - auto conv_mask = std::make_shared(input_shape.size()); + auto conv_mask = std::make_shared(input_shape.rank().get_length()); conv_mask->add_callback([weights_mask](Mask::Ptr cur_mask) -> bool { cur_mask->at(1) = weights_mask->at(0); diff --git a/inference-engine/src/offline_transformations/src/pruning/pruning.cpp b/inference-engine/src/offline_transformations/src/pruning/pruning.cpp index 4bcef7e5674..3159e3db7db 100644 --- a/inference-engine/src/offline_transformations/src/pruning/pruning.cpp +++ b/inference-engine/src/offline_transformations/src/pruning/pruning.cpp @@ -51,7 +51,6 @@ bool ngraph::pass::Pruning::run_on_function(std::shared_ptr f) { #endif manager.register_pass(); - manager.register_pass(); #ifdef NGRAPH_DEBUG_ENABLE // Uncomment following line and change path to resulting svg file diff --git a/inference-engine/src/offline_transformations/src/pruning/shrink_weights.cpp b/inference-engine/src/offline_transformations/src/pruning/shrink_weights.cpp index cc9b638e874..80c2abbb709 100644 --- a/inference-engine/src/offline_transformations/src/pruning/shrink_weights.cpp +++ b/inference-engine/src/offline_transformations/src/pruning/shrink_weights.cpp @@ -11,6 +11,7 @@ #include #include #include +#include NGRAPH_RTTI_DEFINITION(ngraph::pass::ShrinkWeights, "ShrinkWeights", 0); @@ -62,14 +63,22 @@ bool ngraph::pass::ShrinkWeights::run_on_function(std::shared_ptrget_friendly_name(); last_output = std::make_shared(last_output, opset6::Constant::create(element::i64, Shape{dims_to_keep.size()}, dims_to_keep), opset6::Constant::create(element::i64, Shape{}, {dim})); - NGRAPH_DEBUG << "Transform(" << prev_name << "): " << prev_shape << " to " << last_output.get_shape(); + NGRAPH_DEBUG << "Transform(" << prev_name << "): " << prev_shape << " to " << last_output.get_partial_shape(); - reduced_weights_count += shape_size(prev_shape) - shape_size(last_output.get_shape()); + if (prev_shape.is_static() && last_output.get_partial_shape().is_static()) { + reduced_weights_count += shape_size(prev_shape.get_shape()) - shape_size(last_output.get_shape()); + } else { + NGRAPH_DEBUG << "[ WARNING ] Can not find the number of reduced elements due to dynamic shapes."; + } + } + // Trying to fold sequence of Gather ops to avoid additional constant folding. + if (auto folded_const = ngraph::get_constant_from_source(last_output)) { + last_output = folded_const; } // as we insert Gather operations after Constant we need to reconnect all // Constant consumers to the latest Gather. diff --git a/inference-engine/tests/functional/inference_engine/transformations/pruning_test.cpp b/inference-engine/tests/functional/inference_engine/transformations/pruning_test.cpp index 9464fa34f60..0f46a853cef 100644 --- a/inference-engine/tests/functional/inference_engine/transformations/pruning_test.cpp +++ b/inference-engine/tests/functional/inference_engine/transformations/pruning_test.cpp @@ -126,6 +126,66 @@ TEST(TransformationTests, PropagateMasksBasic) { compare_masks(*getMask(conv2->output(0)), Mask({{}, {}, {}, {}})); } +TEST(TransformationTests, PropagateMasksDynamicConvolution) { + PartialShape input_shape{Dimension::dynamic(), 3, 64, 64}; + Shape weights_shape{6, 3, 3, 3}; + Shape weights_shape2{6, 6, 3, 3}; + auto input = std::make_shared(element::f32, input_shape); + auto weights = opset5::Constant::create(element::f32, weights_shape, {0}); + auto conv = std::make_shared(input, weights, Strides(2, 1), + CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1)); + auto relu = std::make_shared(conv); + + auto sub_const = create_constant_with_zeros(Shape{6, 1, 1}, {{1, 2, 3}, {}, {}}); + auto sub = std::make_shared(relu, sub_const); + + auto mul_const = create_constant_with_zeros(Shape{6, 1, 1}, {{2}, {}, {}}); + auto mul = std::make_shared(sub, mul_const); + + auto weights2 = opset5::Constant::create(element::f32, weights_shape2, {0}); + auto conv2 = std::make_shared(mul, weights2, Strides(2, 1), + CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1)); + auto f = std::make_shared(NodeVector{conv2}, ParameterVector{input}); + + pass::Manager m; + m.register_pass(); + m.run_passes(f); + + compare_masks(*getMask(weights->output(0)), Mask({{2}, {}, {}, {}})); + compare_masks(*getMask(conv->output(0)), Mask({{}, {2}, {}, {}})); + compare_masks(*getMask(relu->output(0)), Mask({{}, {2}, {}, {}})); + compare_masks(*getMask(sub_const), Mask({{2}, {}, {}})); + compare_masks(*getMask(mul_const), Mask({{2}, {}, {}})); + compare_masks(*getMask(weights2->output(0)), Mask({{}, {2}, {}, {}})); + compare_masks(*getMask(conv2->output(0)), Mask({{}, {}, {}, {}})); +} + +TEST(TransformationTests, PropagateMasksDynamicGroupConvolution) { + PartialShape input_shape{Dimension::dynamic(), 3, 64, 64}; + Shape weights_shape{3, 2, 1, 3, 3}; + Shape weights_shape2{6, 1, 1, 3, 3}; + auto input = std::make_shared(element::f32, input_shape); + auto weights = opset5::Constant::create(element::f32, weights_shape, {0}); + auto conv = std::make_shared(input, weights, Strides(2, 1), + CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1)); + auto relu = std::make_shared(conv); + + auto sub_const = create_constant_with_zeros(Shape{6, 1, 1}, {{1, 2, 3}, {}, {}}); + auto sub = std::make_shared(relu, sub_const); + + auto mul_const = create_constant_with_zeros(Shape{6, 1, 1}, {{2}, {}, {}}); + auto mul = std::make_shared(sub, mul_const); + + auto weights2 = opset5::Constant::create(element::f32, weights_shape2, {0}); + auto conv2 = std::make_shared(mul, weights2, Strides(2, 1), + CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1)); + auto f = std::make_shared(NodeVector{conv2}, ParameterVector{input}); + + pass::Manager m; + m.register_pass(); + m.run_passes(f); +} + TEST(TransformationTests, PropagateMasksEmpty) { Shape input_shape{1, 3, 64, 64}; Shape weights_shape{6, 3, 3, 3};