diff --git a/src/common/offline_transformations/include/mask_attribute.hpp b/src/common/offline_transformations/include/mask_attribute.hpp index 0d57808949b..338e32a620b 100644 --- a/src/common/offline_transformations/include/mask_attribute.hpp +++ b/src/common/offline_transformations/include/mask_attribute.hpp @@ -149,19 +149,24 @@ public: m_dependencies.push_back(mask.get()); } + /* Modify state of this mask by corresponding callback, + which returns modifying success status (bool) and then + modify all dependent masks by their corresponding callbacks*/ bool apply_callback(Mask::Ptr mask) { // TODO: in case if callback returns false we need to propagate original value const auto & ref_state = Mask(*this); + // Modify this mask by recived mask if (!m_callbacks.at(mask.get())(shared_from_this())) { return false; } - + // In case this mask already visited and didn't change by + // callback call - stop recursion if (!m_need_initialization && *this == ref_state) { return true; } - + // Mark mask as visited m_need_initialization = false; - + // recursively apply callbacks for each dependent mask for (const auto & m_dependency : m_dependencies) { if (!m_dependency->apply_callback(shared_from_this())) { return false; @@ -185,13 +190,21 @@ public: } } + /* Ask mask to update ther dependencies + even if mask value wasn't changed on callback*/ + void initialize_dependencies() { + m_need_initialization = true; + } + private: bool m_is_shape_like{false}; + // Masks dependent on this mask vs methods, specifying how + // this mask will be modifed by correspondent dependent mask std::map> m_callbacks; - + // Vector of all dependent masks std::vector m_dependencies; - + // Param used like visiting label (visited or not) during mask applying call bool m_need_initialization{true}; }; @@ -203,4 +216,14 @@ Mask::Ptr getMask(const Output & output); void setMask(Output output, const Mask::Ptr & mask); +void setMask(Input node, const Mask::Ptr & mask); + +#ifdef ENABLE_OPENVINO_DEBUG +/* Get mask which was defined on InitMasks matcher pass*/ +Mask::Ptr getInitMask(const Output & output); + +/* Set mask which was defined on InitMasks matcher pass*/ +void setInitMask(Output output, const Mask::Ptr & mask); +#endif + } // namespace ngraph diff --git a/src/common/offline_transformations/src/pruning/init_const_mask.cpp b/src/common/offline_transformations/src/pruning/init_const_mask.cpp index b27c28e772d..5e73c41188a 100644 --- a/src/common/offline_transformations/src/pruning/init_const_mask.cpp +++ b/src/common/offline_transformations/src/pruning/init_const_mask.cpp @@ -56,6 +56,9 @@ ngraph::pass::InitConstMask::InitConstMask(const ngraph::AxisSet & dims, } setMask(const_node, mask); +#ifdef ENABLE_OPENVINO_DEBUG + setInitMask(const_node, mask); +#endif if (!mask->all_dims_are_empty()) { NGRAPH_DEBUG << "MASK (" << const_node->get_friendly_name() << ") " << *mask << std::endl; } diff --git a/src/common/offline_transformations/src/pruning/mask_attribute.cpp b/src/common/offline_transformations/src/pruning/mask_attribute.cpp index 90f79e049f9..fd1e353da34 100644 --- a/src/common/offline_transformations/src/pruning/mask_attribute.cpp +++ b/src/common/offline_transformations/src/pruning/mask_attribute.cpp @@ -14,16 +14,21 @@ namespace ngraph { Mask::Ptr getMask(const Output & output) { auto &rtInfo = output.get_rt_info(); - if (!rtInfo.count(Mask::get_type_info_static())) return nullptr; - const auto &attr = rtInfo.at(Mask::get_type_info_static()); + const auto attr_it = rtInfo.find(Mask::get_type_info_static()); + if (attr_it == rtInfo.end()) return nullptr; + + const auto &attr = attr_it->second; return attr.as(); } Mask::Ptr getMask(const Output & output) { auto &rtInfo = output.get_rt_info(); - if (!rtInfo.count(Mask::get_type_info_static())) return nullptr; - const auto &attr = rtInfo.at(Mask::get_type_info_static()); + + const auto attr_it = rtInfo.find(Mask::get_type_info_static()); + if (attr_it == rtInfo.end()) return nullptr; + + const auto &attr = attr_it->second; return attr.as(); } @@ -32,6 +37,31 @@ void setMask(Output output, const Mask::Ptr & mask) { rtInfo[Mask::get_type_info_static()] = mask; } +void setMask(Input node, const Mask::Ptr & mask) { + auto &rtInfo = node.get_rt_info(); + rtInfo[Mask::get_type_info_static()] = mask; +} + +#ifdef ENABLE_OPENVINO_DEBUG +static const char g_init_mask_key[] = "InitMask"; +Mask::Ptr getInitMask(const Output & output) { + auto &rtInfo = output.get_rt_info(); + + const auto attr_it = rtInfo.find(g_init_mask_key); + if (attr_it == rtInfo.end()) return nullptr; + + const auto &attr = attr_it->second; + return attr.as(); +} + +void setInitMask(Output output, const Mask::Ptr & mask) { + auto &rtInfo = output.get_rt_info(); + auto copy_mask = std::make_shared(); + std::copy(mask->begin(), mask->end(), std::back_inserter(*copy_mask)); + rtInfo[g_init_mask_key] = copy_mask; +} +#endif + std::ostream & operator<< (std::ostream & out, const Mask & mask) { out << "[ "; for (auto & dim : mask) { diff --git a/src/common/offline_transformations/src/pruning/propagate_masks.cpp b/src/common/offline_transformations/src/pruning/propagate_masks.cpp index 765f51c23bc..f386f12c5c6 100644 --- a/src/common/offline_transformations/src/pruning/propagate_masks.cpp +++ b/src/common/offline_transformations/src/pruning/propagate_masks.cpp @@ -5,6 +5,8 @@ #include "pruning.hpp" #include "mask_attribute.hpp" +#include + #include #include #include @@ -22,6 +24,7 @@ class GroupConvolution; class GroupConvolutionReshape; class Elementwise; class PassThrough; +class Reduce; class StopPropagation; class FakeQuantize; class Concat; @@ -228,7 +231,11 @@ public: return false; } auto input_mask_row = input_mask.get(); - auto output_mask = std::make_shared(m_output.get_partial_shape().rank().get_length()); + // Check reshape mask already initialized during StopPropagation pass + auto output_mask = getMask(m_output); + if (!output_mask) + output_mask = std::make_shared(m_output.get_partial_shape().rank().get_length()); + auto output_mask_row = output_mask.get(); // Depthwise Convolution pruned only by input channels (== groups) -> @@ -301,7 +308,6 @@ public: } InitConstMask({0, 1}).apply(m_weights.get_node_shared_ptr()); - auto weights_mask = getMask(m_weights); if (!weights_mask) { NGRAPH_DEBUG << "No weights mask for: " << m_output.get_node()->get_friendly_name() << std::endl; @@ -561,9 +567,12 @@ public: class ngraph::pass::mask_propagation::PassThrough : public MatcherPass { public: PassThrough() { - auto unary_op = pattern::wrap_type(); + auto unary_op = pattern::wrap_type(); + ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) { const auto & pattern_map = m.get_pattern_value_map(); @@ -582,19 +591,92 @@ public: } }; +class ngraph::pass::mask_propagation::Reduce : public MatcherPass { +public: + Reduce() { + auto inputs = pattern::any_input(); + auto weights = pattern::wrap_type(); + auto pooling_by_reduce = pattern::wrap_type({inputs, weights}); + + ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) { + const auto & pattern_map = m.get_pattern_value_map(); + const auto m_weights = pattern_map.at(weights); + const auto & m_input = pattern_map.at(inputs); + const auto & m_output = pattern_map.at(pooling_by_reduce); + + + // Check reduce operation reduces only dimension without masks + if (auto input_mask = getMask(m_input)) { + auto output_mask = std::make_shared(m_output.get_partial_shape().rank().get_length()); + const auto constant = std::dynamic_pointer_cast(m_weights.get_node_shared_ptr()); + const auto reduce_dims = constant->cast_vector(); + + auto input_mask_row = input_mask.get(); + auto output_mask_row = output_mask.get(); + input_mask->add_callback([output_mask_row](Mask::Ptr cur_mask) -> bool { + cur_mask->copy_value_from_mask(output_mask_row); + return true; + }, output_mask); + output_mask->add_callback([input_mask_row, reduce_dims](Mask::Ptr cur_mask) -> bool{ + // Propagate masks through dimension only if this dimension isn't reduced + for (size_t dim = 0; dim < std::min(cur_mask->size(), input_mask_row->size()); ++dim) + if (std::find(reduce_dims.begin(), reduce_dims.end(), dim) == reduce_dims.end()) + cur_mask->at(dim) = input_mask_row->at(dim); + else if (cur_mask->at(dim) != input_mask_row->at(dim)) + cur_mask->initialize_dependencies(); + return true; + }, input_mask); + + // Invalidate current mask and its parent masks + output_mask->apply_callback(input_mask); + setMask(m_output, output_mask); + } + + return true; + }; + + auto m = std::make_shared(pooling_by_reduce, "PassThroughReduceMaskPropagation"); + register_matcher(m, callback); + } +}; + class ngraph::pass::mask_propagation::StopPropagation : public MatcherPass { public: StopPropagation() { auto any_node = pattern::any_input(); - ngraph::matcher_pass_callback callback = [](ngraph::pattern::Matcher& m) { + ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) { + const auto & pattern_map = m.get_pattern_value_map(); + const auto & m_output = pattern_map.at(any_node); const auto & node = m.get_match_root(); + + auto output_mask = std::make_shared(m_output.get_partial_shape().rank().get_length()); + bool any_input_with_masks = false; for (const auto & input : node->input_values()) { - if (auto mask = getMask(input)) { - // Invalidate current mask and its parent masks - mask->invalidate(); - NGRAPH_DEBUG << "Invalidate masks for " << *input.get_node() << " because " << node << " is unknown\n"; + if (auto input_mask = getMask(input)) { + auto input_mask_row = input_mask.get(); + input_mask->add_callback([](Mask::Ptr cur_mask) -> bool { + cur_mask->clean_dim_values(); + return true; + }, output_mask); + output_mask->add_callback([input_mask_row](Mask::Ptr cur_mask) -> bool{ + cur_mask->copy_value_from_mask(input_mask_row); + return true; + }, input_mask); + + // Invalidate current mask and its parent masks + output_mask->apply_callback(input_mask); + NGRAPH_DEBUG << "Invalidate masks for " << *input.get_node() << " because " << node << " is in scope of stop ops.\n"; + any_input_with_masks = true; + } } + if (any_input_with_masks) { + // Set mask to stop op first input tensor to prevent mask rewriting for + // nodes which share output tensor with previous node. + if (ngraph::is_type(m_output.get_node_shared_ptr())) + setMask(*m_output.get_node()->inputs().begin(), output_mask); + else + setMask(m_output, output_mask); } return true; }; @@ -610,6 +692,7 @@ ngraph::pass::PropagateMasks::PropagateMasks() { add_matcher(); add_matcher(); add_matcher(); + add_matcher(); add_matcher(); add_matcher(); add_matcher(); diff --git a/src/common/offline_transformations/src/pruning/shrink_weights.cpp b/src/common/offline_transformations/src/pruning/shrink_weights.cpp index 1a02133c7dd..290f0582197 100644 --- a/src/common/offline_transformations/src/pruning/shrink_weights.cpp +++ b/src/common/offline_transformations/src/pruning/shrink_weights.cpp @@ -31,6 +31,25 @@ bool ngraph::pass::ShrinkWeights::run_on_model(const std::shared_ptroutput(0)); + +#ifdef ENABLE_OPENVINO_DEBUG + auto init_mask = getInitMask(const_node->output(0)); + if (!mask && init_mask) + NGRAPH_DEBUG << "Mask was ruined for node:" << const_node->get_friendly_name() << "\nInit mask: " << *init_mask; + if (mask && init_mask) { + for (size_t dim = 0; dim < init_mask->size(); ++dim) { + auto& dim_init_set = (*init_mask)[dim]; + auto& dim_current_set = (*mask)[dim]; + if (!dim_init_set.empty() && !std::includes(dim_current_set.begin(), dim_current_set.end(), + dim_init_set.begin(), dim_init_set.end())) { + NGRAPH_DEBUG << "Mask was ruined for node:" << const_node->get_friendly_name() + << "\nInit mask: " << *init_mask << "\nCurrent mask: " << *mask; + break; + } + } + } +#endif + if (!mask) continue; auto last_output = const_node->output(0); diff --git a/src/inference/src/file_utils.cpp b/src/inference/src/file_utils.cpp index a150e1346d7..460e1991a82 100644 --- a/src/inference/src/file_utils.cpp +++ b/src/inference/src/file_utils.cpp @@ -2,6 +2,9 @@ // SPDX-License-Identifier: Apache-2.0 // +#ifndef FILE_UTILS_CPP +#define FILE_UTILS_CPP + #include #include #include @@ -130,3 +133,5 @@ std::string getIELibraryPath() { } } // namespace InferenceEngine + +#endif diff --git a/src/tests/functional/inference_engine/transformations/pruning_test.cpp b/src/tests/functional/inference_engine/transformations/pruning_test.cpp index d87a9271f8e..4b0890ea1cf 100644 --- a/src/tests/functional/inference_engine/transformations/pruning_test.cpp +++ b/src/tests/functional/inference_engine/transformations/pruning_test.cpp @@ -8,6 +8,7 @@ #include #include +#include #include #include #include @@ -17,6 +18,12 @@ #include #include + +#include "common_test_utils/ngraph_test_utils.hpp" + +#define VISUALIZE_TESTS_TREE false +#define VISUALIZE_TREE_ROOT "/tmp/" + using namespace testing; using namespace ngraph; @@ -46,6 +53,7 @@ Output create_constant_with_zeros(const Shape & shape, const Mask & mask) return std::make_shared(element::f32, shape, values); } + TEST(TransformationTests, InitMasksOI) { Shape weights_shape{6, 3, 3, 3}; auto weights = opset5::Constant::create(element::f32, weights_shape, {0}); @@ -54,10 +62,10 @@ TEST(TransformationTests, InitMasksOI) { compare_masks(*getMask(weights->output(0)), {{0, 1, 2, 3, 4, 5}, {0, 1, 2}, {}, {}}); } + TEST(TransformationTests, InitMasksOutputChannel) { Shape input_shape{1, 3, 64, 64}; Shape weights_shape{6, 3, 3, 3}; - std::vector values(shape_size(weights_shape), 1); NGRAPH_SUPPRESS_DEPRECATED_START CoordinateTransform iter(weights_shape, {0, 1, 0, 0}, {6, 2, 3, 3}); @@ -72,6 +80,7 @@ TEST(TransformationTests, InitMasksOutputChannel) { compare_masks(*getMask(weights->output(0)), {{}, {1}, {}, {}}); } + // TODO: add test init masks with subgraph TEST(TransformationTests, TestInitMasks) { Shape weights_shape{6, 3, 3, 3}; @@ -89,6 +98,7 @@ TEST(TransformationTests, TestInitMasks) { compare_masks(*getMask(weights.get_node_shared_ptr()->output(0)), {{1, 2, 3}, {}, {}, {}}); } + TEST(TransformationTests, InitMasksNegative) { Shape weights_shape{6, 3, 3, 3}; auto weights = opset5::Constant::create(element::f32, weights_shape, {0.5}); @@ -97,6 +107,7 @@ TEST(TransformationTests, InitMasksNegative) { compare_masks(*getMask(weights->output(0)), {{}, {}, {}, {}}); } + TEST(TransformationTests, PropagateMasksNegative) { Shape input_shape{1, 3, 64, 64}; Shape weights_shape{6, 3, 3, 3}; @@ -115,7 +126,8 @@ TEST(TransformationTests, PropagateMasksNegative) { compare_masks(*getMask(conv->output(0)), {{}, {}, {}, {}}); } -TEST(TransformationTests, PropagateMasksBasic) { + +TEST_F(TransformationTestsF, PropagateMasksBasic) { Shape input_shape{1, 3, 64, 64}; Shape weights_shape{6, 3, 3, 3}; Shape weights_shape2{6, 6, 3, 3}; @@ -132,18 +144,43 @@ TEST(TransformationTests, PropagateMasksBasic) { auto sub = std::make_shared(add, sub_const); auto mul_const = create_constant_with_zeros(Shape{1, 6, 1, 1}, {{}, {4}, {}, {}}); - auto mul = std::make_shared(sub, mul_const); + auto mul = std::make_shared(sub, mul_const); auto weights2 = create_constant_with_zeros(weights_shape2, {{1, 2}, {1, 2, 3}, {}, {}}); auto conv2 = std::make_shared(mul, weights2, Strides(2, 1), CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1)); - auto f = std::make_shared(NodeVector{conv2}, ParameterVector{input}); + function = std::make_shared(NodeVector{conv2}, ParameterVector{input}); - pass::Manager m; - m.register_pass(); - m.register_pass(); - m.run_passes(f); + { + auto input = std::make_shared(element::f32, input_shape); + auto weights = opset5::Constant::create(element::f32, {weights_shape[0] - 4, weights_shape[1], weights_shape[2] , weights_shape[3]}, {0}); + auto conv = std::make_shared(input, weights, Strides(2, 1), + CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1)); + auto relu = std::make_shared(conv); + + auto add_const = opset5::Constant::create(element::f32, Shape{1, 2, 1, 1}, {1}); + auto add = std::make_shared(relu, add_const); + + auto sub_const = opset5::Constant::create(element::f32, Shape{2, 1, 1}, {1}); + auto sub = std::make_shared(add, sub_const); + + auto mul_const = opset5::Constant::create(element::f32, Shape{1, 2, 1, 1}, {1}); + auto mul = std::make_shared(sub, mul_const); + + auto weights2 = opset5::Constant::create(element::f32, {weights_shape2[0], weights_shape2[1] - 4, weights_shape2[2], weights_shape2[3]}, {1}); + auto conv2 = std::make_shared(mul, weights2, Strides(2, 1), + CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1)); + function_ref = std::make_shared(NodeVector{conv2}, ParameterVector{input}); + } + if (VISUALIZE_TESTS_TREE) + ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PropagateMasksBasic.svg").run_on_function(function); + { + pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(function); + } compare_masks(*getMask(weights->output(0)), Mask({{1, 2, 3, 4}, {}, {}, {}})); compare_masks(*getMask(conv->output(0)), Mask({{}, {1, 2, 3, 4}, {}, {}})); compare_masks(*getMask(relu->output(0)), Mask({{}, {1, 2, 3, 4}, {}, {}})); @@ -155,9 +192,17 @@ TEST(TransformationTests, PropagateMasksBasic) { compare_masks(*getMask(mul->output(0)), Mask({{}, {1, 2, 3, 4}, {}, {}})); compare_masks(*getMask(weights2.get_node_shared_ptr()->output(0)), Mask({{}, {1, 2, 3, 4}, {}, {}})); compare_masks(*getMask(conv2->output(0)), Mask({{}, {}, {}, {}})); + { + pass::Manager m; + m.register_pass(); + m.run_passes(function); + } + disable_rt_info_check(); + enable_accuracy_check(); } -TEST(TransformationTests, PropagateMasksDynamicConvolution) { + +TEST_F(TransformationTestsF, PropagateMasksDynamicConvolution) { PartialShape input_shape{Dimension::dynamic(), 3, 64, 64}; Shape weights_shape{6, 3, 3, 3}; Shape weights_shape2{6, 6, 3, 3}; @@ -176,12 +221,34 @@ TEST(TransformationTests, PropagateMasksDynamicConvolution) { auto weights2 = opset5::Constant::create(element::f32, weights_shape2, {0}); auto conv2 = std::make_shared(mul, weights2, Strides(2, 1), CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1)); - auto f = std::make_shared(NodeVector{conv2}, ParameterVector{input}); + function = std::make_shared(NodeVector{conv2}, ParameterVector{input}); + { + auto input = std::make_shared(element::f32, input_shape); + auto weights = opset5::Constant::create(element::f32, {weights_shape[0] - 1, weights_shape[1], weights_shape[2], weights_shape[3]}, {0}); + auto conv = std::make_shared(input, weights, Strides(2, 1), + CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1)); + auto relu = std::make_shared(conv); - pass::Manager m; - m.register_pass(); - m.register_pass(); - m.run_passes(f); + auto sub_const = create_constant_with_zeros(Shape{5, 1, 1}, {{}, {}, {}}); + auto sub = std::make_shared(relu, sub_const); + + auto mul_const = create_constant_with_zeros(Shape{5, 1, 1}, {{2}, {}, {}}); + auto mul = std::make_shared(sub, mul_const); + + auto weights2 = opset5::Constant::create(element::f32, {weights_shape2[0], weights_shape2[1] - 1, weights_shape2[2], weights_shape2[3]}, {0}); + auto conv2 = std::make_shared(mul, weights2, Strides(2, 1), + CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1)); + function_ref = std::make_shared(NodeVector{conv2}, ParameterVector{input}); + } + if (VISUALIZE_TESTS_TREE) + ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PropagateMasksDynamicConvolution.svg").run_on_function(function); + + { + pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(function); + } compare_masks(*getMask(weights->output(0)), Mask({{2}, {}, {}, {}})); compare_masks(*getMask(conv->output(0)), Mask({{}, {2}, {}, {}})); @@ -190,8 +257,16 @@ TEST(TransformationTests, PropagateMasksDynamicConvolution) { compare_masks(*getMask(mul_const), Mask({{2}, {}, {}})); compare_masks(*getMask(weights2->output(0)), Mask({{}, {2}, {}, {}})); compare_masks(*getMask(conv2->output(0)), Mask({{}, {}, {}, {}})); + { + pass::Manager m; + m.register_pass(); + m.run_passes(function); + } + disable_rt_info_check(); + enable_accuracy_check(); } + TEST(TransformationTests, PropagateMasksDynamicGroupConvolution) { PartialShape input_shape{Dimension::dynamic(), 3, 64, 64}; Shape weights_shape{3, 2, 1, 3, 3}; @@ -213,12 +288,16 @@ TEST(TransformationTests, PropagateMasksDynamicGroupConvolution) { CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1)); auto f = std::make_shared(NodeVector{conv2}, ParameterVector{input}); + if (VISUALIZE_TESTS_TREE) + ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PropagateMasksDynamicGroupConvolution.svg").run_on_function(f); + pass::Manager m; m.register_pass(); m.register_pass(); m.run_passes(f); } + TEST(TransformationTests, PropagateMasksEmpty) { Shape input_shape{1, 3, 64, 64}; Shape weights_shape{6, 3, 3, 3}; @@ -240,6 +319,9 @@ TEST(TransformationTests, PropagateMasksEmpty) { CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1)); auto f = std::make_shared(NodeVector{conv2}, ParameterVector{input}); + if (VISUALIZE_TESTS_TREE) + ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PropagateMasksEmpty.svg").run_on_function(f); + pass::Manager m; m.register_pass(); m.register_pass(); @@ -254,7 +336,8 @@ TEST(TransformationTests, PropagateMasksEmpty) { compare_masks(*getMask(conv2->output(0)), Mask({{}, {}, {}, {}})); } -TEST(TransformationTests, PropagateMaskPassThrough) { + +TEST_F(TransformationTestsF, PropagateMaskPassThrough) { Shape input_shape{1, 3, 64, 64}; Shape weights_shape{8, 3, 3, 3}; Shape weight_shape2{3, 8, 3, 3}; @@ -284,20 +367,53 @@ TEST(TransformationTests, PropagateMaskPassThrough) { auto weights2 = opset5::Constant::create(element::f32, weight_shape2, {0}); auto conv2 = std::make_shared(max_pool, weights2, Strides(2, 1), CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1)); - auto f = std::make_shared(NodeVector{conv2}, ParameterVector{input}); + function = std::make_shared(NodeVector{conv2}, ParameterVector{input}); + { + auto input = std::make_shared(element::f32, input_shape); + auto weights_const_1 = create_constant_with_zeros({weights_shape[0] - 3, weights_shape[1], weights_shape[2], weights_shape[3]} , {{}, {}, {}, {}}); + weights_const_1.get_node_shared_ptr()->set_friendly_name("weights_1"); - pass::Manager m; - m.register_pass(); - m.register_pass(); - m.run_passes(f); + auto conv_1 = std::make_shared(input, weights_const_1, Strides(2, 1), + CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1)); + // Adding a couple of PassThrough operations + auto relu = std::make_shared(conv_1); + auto clamp = std::make_shared(relu, 0, 6); + + auto pads_begin = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 1, 1}); + auto pads_end = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 2, 2}); + auto pad = std::make_shared(clamp, pads_begin, pads_end, op::PadMode::CONSTANT); + auto max_pool = std::make_shared(pad, Strides{1, 1}, + Shape{0, 0}, Shape{1, 1}, Shape{4, 4}); + + auto weights2 = opset5::Constant::create(element::f32, {weight_shape2[0], weight_shape2[1] - 3, weight_shape2[2], weight_shape2[3]}, {0}); + auto conv2 = std::make_shared(max_pool, weights2, Strides(2, 1), + CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1)); + function_ref = std::make_shared(NodeVector{conv2}, ParameterVector{input}); + } + if (VISUALIZE_TESTS_TREE) + ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PropagateMaskPassThrough.svg").run_on_function(function); + { + pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(function); + } compare_masks(*getMask(weights_const_1.get_node_shared_ptr()->output(0)), Mask({{1, 2, 3}, {}, {}, {}})); compare_masks(*getMask(conv_1->output(0)), Mask({{}, {1, 2, 3}, {}, {}})); compare_masks(*getMask(relu->output(0)), Mask({{}, {1, 2, 3}, {}, {}})); compare_masks(*getMask(clamp->output(0)), Mask({{}, {1, 2, 3}, {}, {}})); compare_masks(*getMask(max_pool->output(0)), Mask({{}, {1, 2, 3}, {}, {}})); + { + pass::Manager m; + m.register_pass(); + m.run_passes(function); + } + disable_rt_info_check(); + enable_accuracy_check(); } + TEST(TransformationTests, PropagateMasksHardDependencies) { Shape input_shape{1, 3, 3, 3}; @@ -348,19 +464,34 @@ TEST(TransformationTests, PropagateMasksHardDependencies) { auto f = std::make_shared(NodeVector{matmul, conv3}, ParameterVector{input1, input2}); + if (VISUALIZE_TESTS_TREE) + ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PropagateMasksHardDependencies.svg").run_on_function(f); + pass::Manager m; m.register_pass(); m.run_passes(f); + compare_masks(*getMask(weights1.get_node_shared_ptr()->output(0)), Mask({{}, {}, {}, {}})); + compare_masks(*getMask(conv1->output(0)), Mask({{}, {}, {}, {}})); + + compare_masks(*getMask(weights2.get_node_shared_ptr()->output(0)), Mask({{}, {}, {}, {}})); + compare_masks(*getMask(conv2->output(0)), Mask({{}, {}, {}, {}})); + + compare_masks(*getMask(add1->output(0)), Mask({{}, {}, {}, {}})); + compare_masks(*getMask(add2->output(0)), Mask({{}, {}, {}, {}})); + + compare_masks(*getMask(matmul->output(0)), Mask({{}, {}})); + // TODO: add checks after MatMul/Reshape/Pooling mask propagation is ready -// compare_masks(*getMask(weights), Mask({{0, 1, 2, 3, 4, 5}, {}, {}, {}})); -// compare_masks(*getMask(conv), Mask({{}, {0, 1, 2, 3, 4, 5}, {}, {}})); -// compare_masks(*getMask(relu), Mask({{}, {0, 1, 2, 3, 4, 5}, {}, {}})); -// compare_masks(*getMask(weights2), Mask({{}, {0, 1, 2, 3, 4, 5}, {}, {}})); -// compare_masks(*getMask(conv2), Mask({{}, {}, {}, {}})); + //compare_masks(*getMask(weights), Mask({{0, 1, 2, 3, 4, 5}, {}, {}, {}})); + //compare_masks(*getMask(conv), Mask({{}, {0, 1, 2, 3, 4, 5}, {}, {}})); + //compare_masks(*getMask(relu), Mask({{}, {0, 1, 2, 3, 4, 5}, {}, {}})); + //compare_masks(*getMask(weights2), Mask({{}, {0, 1, 2, 3, 4, 5}, {}, {}})); + //compare_masks(*getMask(conv2), Mask({{}, {}, {}, {}})); } -TEST(TransformationTests, PropagateMasksQuantizedGroupConvolution) { + +TEST_F(TransformationTestsF, PropagateMasksQuantizedGroupConvolution) { Shape input_shape{1, 3, 64, 64}; Shape weights_shape{8, 3, 3, 3}; Shape weights_group_shape{8, 1, 3, 3}; @@ -398,30 +529,79 @@ TEST(TransformationTests, PropagateMasksQuantizedGroupConvolution) { auto weights_2 = opset5::Constant::create(element::f32, weight_shape2, {0}); auto conv2 = std::make_shared(add, weights_2, Strides(2, 1), CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1)); - auto f = std::make_shared(NodeVector{conv2}, ParameterVector{input}); + function = std::make_shared(NodeVector{conv2}, ParameterVector{input}); + { + auto input = std::make_shared(element::f32, input_shape); - pass::Manager m; - m.register_pass(); - m.run_passes(f); + auto weights1 = create_constant_with_zeros({weights_shape[0] - 4, weights_shape[1], weights_shape[2], weights_shape[3]}, {{}, {}, {}, {}}); + auto conv1 = std::make_shared(input, weights1, Strides(2, 1), + CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1)); + auto weights_group = opset5::Constant::create(element::i8, + { + weights_group_shape[0] - 4, + weights_group_shape[1], + weights_group_shape[2], + weights_group_shape[3] + }, {0}); - compare_masks(*getMask(weights1.get_node_shared_ptr()->output(0)), Mask({{0 , 1, 2, 3}, {}, {}, {}})); + auto convert = std::make_shared(weights_group, element::f32); + + auto sub_const = create_constant_with_zeros(Shape{4, 1, 1, 1}, {{}, {}, {}, {}}); + + auto sub = std::make_shared(convert, sub_const); + + auto mul_const = create_constant_with_zeros(Shape{4, 1, 1, 1}, {{}, {}, {}, {}}); + auto mul = std::make_shared(sub, mul_const); + + auto reshape = std::make_shared(mul, opset5::Constant::create(element::i64, Shape{5}, {4, 1, 1, 3, 3}), false); + + auto conv_group = std::make_shared(conv1, reshape, Strides(2, 1), + CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1)); + + auto add_const = create_constant_with_zeros(Shape{1, 4, 1, 1}, {{}, {}, {}, {}});; + auto add = std::make_shared(conv_group, add_const); + + auto weights_2 = opset5::Constant::create(element::f32, {weight_shape2[0], weight_shape2[1] - 4, weight_shape2[2], weight_shape2[3]}, {0}); + auto conv2 = std::make_shared(add, weights_2, Strides(2, 1), + CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1)); + function_ref = std::make_shared(NodeVector{conv2}, ParameterVector{input}); + } + if (VISUALIZE_TESTS_TREE) + ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PropagateMasksQuantizedGroupConvolution.svg").run_on_function(function); + { + pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(function); + } + + compare_masks(*getMask(weights1.get_node_shared_ptr()->output(0)), Mask({{0, 1, 2, 3}, {}, {}, {}})); compare_masks(*getMask(conv1->output(0)), Mask({{}, {0 , 1, 2, 3}, {}, {}})); - compare_masks(*getMask(weights_group->output(0)), Mask({{0 , 1, 2, 3}, {}, {}, {}})); - compare_masks(*getMask(sub->output(0)), Mask({{0 , 1, 2, 3}, {}, {}, {}})); - compare_masks(*getMask(sub_const.get_node_shared_ptr()->output(0)), Mask({{0 , 1, 2, 3}, {}, {}, {}})); - compare_masks(*getMask(mul->output(0)), Mask({{0 , 1, 2, 3}, {}, {}, {}})); - compare_masks(*getMask(mul_const.get_node_shared_ptr()->output(0)), Mask({{0 , 1, 2, 3}, {}, {}, {}})); + compare_masks(*getMask(weights_group->output(0)), Mask({{0, 1, 2, 3}, {}, {}, {}})); + compare_masks(*getMask(sub->output(0)), Mask({{0, 1, 2, 3}, {}, {}, {}})); + compare_masks(*getMask(sub_const.get_node_shared_ptr()->output(0)), Mask({{0, 1, 2, 3}, {}, {}, {}})); + compare_masks(*getMask(mul->output(0)), Mask({{0, 1, 2, 3}, {}, {}, {}})); + compare_masks(*getMask(mul_const.get_node_shared_ptr()->output(0)), Mask({{0, 1, 2, 3}, {}, {}, {}})); - compare_masks(*getMask(reshape->output(0)), Mask({{0 , 1, 2, 3}, {}, {}, {}, {}})); + compare_masks(*getMask(reshape->output(0)), Mask({{0, 1, 2, 3}, {}, {}, {}, {}})); - compare_masks(*getMask(conv_group->output(0)), Mask({{}, {0 , 1, 2, 3}, {}, {}})); + compare_masks(*getMask(conv_group->output(0)), Mask({{}, {0, 1, 2, 3}, {}, {}})); compare_masks(*getMask(conv2->output(0)), Mask({{}, {}, {}, {}})); + compare_masks(*getMask(weights_2->output(0)), Mask({{}, {0, 1, 2, 3}, {}, {}})); + { + pass::Manager m; + m.register_pass(); + m.run_passes(function); + } + disable_rt_info_check(); + enable_accuracy_check(); } -TEST(TransformationTests, PropagateMasksFakeQuantizePerTensor) { + +TEST_F(TransformationTestsF, PropagateMasksFakeQuantizePerTensor) { Shape input_shape{1, 3, 64, 64}; Shape weights_shape{8, 3, 3, 3}; Shape weight_shape2{3, 8, 3, 3}; @@ -459,11 +639,59 @@ TEST(TransformationTests, PropagateMasksFakeQuantizePerTensor) { auto weights_2 = opset5::Constant::create(element::f32, weight_shape2, {0}); auto conv2 = std::make_shared(fq, weights_2, Strides(2, 1), CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1)); - auto f = std::make_shared(NodeVector{conv2}, ParameterVector{input}); + function = std::make_shared(NodeVector{conv2}, ParameterVector{input}); + { + auto input = std::make_shared(element::f32, input_shape); + auto weights_1 = opset5::Constant::create(element::i8, { + weights_shape[0] - 5, + weights_shape[1], + weights_shape[2], + weights_shape[3], + }, {0}); - pass::Manager m; - m.register_pass(); - m.run_passes(f); + auto convert = std::make_shared(weights_1, element::f32); + + auto sub_const = create_constant_with_zeros(Shape{3, 1, 1, 1}, {{}, {}, {}, {}}); + + auto sub = std::make_shared(convert, sub_const); + + auto mul_const = create_constant_with_zeros(Shape{3, 1, 1, 1}, {{}, {}, {}, {}}); + auto mul = std::make_shared(sub, mul_const); + + auto conv1 = std::make_shared(input, mul, Strides(2, 1), + CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1)); + + auto add_const = create_constant_with_zeros(Shape{1, 3, 1, 1}, {{}, {}, {}, {}});; + auto add = std::make_shared(conv1, add_const); + + auto input_low = opset5::Constant::create(element::f32, Shape{1}, {0}); + auto input_high = opset5::Constant::create(element::f32, Shape{1, 1, 1, 1}, {20}); + auto output_low = opset5::Constant::create(element::f32, Shape{}, {1}); + auto output_high = opset5::Constant::create(element::f32, Shape{}, {10}); + auto fq = std::make_shared(add, input_low, input_high, output_low, output_high, 8); + + auto weights_2 = opset5::Constant::create(element::f32, { + weight_shape2[0], + weight_shape2[1] - 5, + weight_shape2[2], + weight_shape2[3], + }, {0}); + auto conv2 = std::make_shared(fq, weights_2, Strides(2, 1), + CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1)); + function_ref = std::make_shared(NodeVector{conv2}, ParameterVector{input}); + } + if (VISUALIZE_TESTS_TREE) + ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PropagateMasksFakeQuantizePerTensor.svg").run_on_function(function); + + { + pass::Manager m; + // Masks for fq input parammeters didn't saved after + // ShrinkWeights pass so pruning transformation is splitted + // on propagation and shrinking passes + m.register_pass(); + m.register_pass(); + m.run_passes(function); + } pass::Manager m; compare_masks(*getMask(weights_1->output(0)), Mask({{0 , 1, 2, 3, 4}, {}, {}, {}})); compare_masks(*getMask(sub_const.get_node_shared_ptr()->output(0)), Mask({{0 , 1, 2, 3, 4}, {}, {}, {}})); @@ -481,9 +709,17 @@ TEST(TransformationTests, PropagateMasksFakeQuantizePerTensor) { compare_masks(*getMask(weights_2->output(0)), Mask({{}, {0 , 1, 2, 3, 4}, {}, {}})); compare_masks(*getMask(conv2->output(0)), Mask({{}, {}, {}, {}})); + { + pass::Manager m; + m.register_pass(); + m.run_passes(function); + } + disable_rt_info_check(); + enable_accuracy_check(); } -TEST(TransformationTests, PropagateMasksFakeQuantizePerChannel) { + +TEST_F(TransformationTestsF, PropagateMasksFakeQuantizePerChannel) { Shape input_shape{1, 3, 64, 64}; Shape weights_shape{8, 3, 3, 3}; Shape weight_shape2{3, 8, 3, 3}; @@ -521,37 +757,90 @@ TEST(TransformationTests, PropagateMasksFakeQuantizePerChannel) { auto weights_2 = opset5::Constant::create(element::f32, weight_shape2, {0}); auto conv2 = std::make_shared(fq, weights_2, Strides(2, 1), CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1)); - auto f = std::make_shared(NodeVector{conv2}, ParameterVector{input}); + function = std::make_shared(NodeVector{conv2}, ParameterVector{input}); + { + auto input = std::make_shared(element::f32, input_shape); + auto weights_1 = opset5::Constant::create(element::i8, { + weights_shape[0] - 5, + weights_shape[1], + weights_shape[2], + weights_shape[3] + }, {0}); - pass::Manager m; - m.register_pass(); - m.register_pass(); - m.run_passes(f); + auto convert = std::make_shared(weights_1, element::f32); - compare_masks(*getMask(weights_1->output(0)), Mask({{0 , 1, 2, 3, 4}, {}, {}, {}})); + auto sub_const = create_constant_with_zeros(Shape{3, 1, 1, 1}, {{}, {}, {}, {}}); + + auto sub = std::make_shared(convert, sub_const); + + auto mul_const = create_constant_with_zeros(Shape{3, 1, 1, 1}, {{}, {}, {}, {}}); + auto mul = std::make_shared(sub, mul_const); + + auto conv1 = std::make_shared(input, mul, Strides(2, 1), + CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1)); + + auto add_const = create_constant_with_zeros(Shape{1, 3, 1, 1}, {{}, {}, {}, {}});; + auto add = std::make_shared(conv1, add_const); + + auto input_low = opset5::Constant::create(element::f32, Shape{1, 3, 1, 1}, {0}); + auto input_high = opset5::Constant::create(element::f32, Shape{1, 3, 1, 1}, {20}); + auto output_low = opset5::Constant::create(element::f32, Shape{1, 3, 1, 1}, {1}); + auto output_high = opset5::Constant::create(element::f32, Shape{1, 3, 1, 1}, {10}); + auto fq = std::make_shared(add, input_low, input_high, output_low, output_high, 8); + + auto weights_2 = opset5::Constant::create(element::f32, { + weight_shape2[0], + weight_shape2[1] - 5, + weight_shape2[2], + weight_shape2[3] + } , {0}); + auto conv2 = std::make_shared(fq, weights_2, Strides(2, 1), + CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1)); + function_ref = std::make_shared(NodeVector{conv2}, ParameterVector{input}); + } + if (VISUALIZE_TESTS_TREE) + ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PropagateMasksFakeQuantizePerChannel.svg").run_on_function(function); + { + pass::Manager m; + // Masks for fq input parammeters didn't saved after + // ShrinkWeights pass so pruning transformation is splitted + // on propagation and shrinking passes + m.register_pass(); + m.register_pass(); + m.run_passes(function); + } + compare_masks(*getMask(weights_1->output(0)), Mask({{0, 1, 2, 3, 4}, {}, {}, {}})); compare_masks(*getMask(sub_const.get_node_shared_ptr()->output(0)), Mask({{0 , 1, 2, 3, 4}, {}, {}, {}})); - compare_masks(*getMask(sub->output(0)), Mask({{0 , 1, 2, 3, 4}, {}, {}, {}})); + compare_masks(*getMask(sub->output(0)), Mask({{0, 1, 2, 3, 4}, {}, {}, {}})); compare_masks(*getMask(mul_const.get_node_shared_ptr()->output(0)), Mask({{0 , 1, 2, 3, 4}, {}, {}, {}})); - compare_masks(*getMask(mul->output(0)), Mask({{0 , 1, 2, 3, 4}, {}, {}, {}})); + compare_masks(*getMask(mul->output(0)), Mask({{0, 1, 2, 3, 4}, {}, {}, {}})); - compare_masks(*getMask(conv1->output(0)), Mask({{}, {0 , 1, 2, 3, 4}, {}, {}})); + compare_masks(*getMask(conv1->output(0)), Mask({{}, {0, 1, 2, 3, 4}, {}, {}})); compare_masks(*getMask(add_const.get_node_shared_ptr()->output(0)), Mask({{}, {0 , 1, 2, 3, 4}, {}, {}})); - compare_masks(*getMask(add->output(0)), Mask({{}, {0 , 1, 2, 3, 4}, {}, {}})); + compare_masks(*getMask(add->output(0)), Mask({{}, {0, 1, 2, 3, 4}, {}, {}})); - compare_masks(*getMask(fq->output(0)), Mask({{}, {0 , 1, 2, 3, 4}, {}, {}})); + compare_masks(*getMask(fq->output(0)), Mask({{}, {0, 1, 2, 3, 4}, {}, {}})); - compare_masks(*getMask(weights_2->output(0)), Mask({{}, {0 , 1, 2, 3, 4}, {}, {}})); + compare_masks(*getMask(weights_2->output(0)), Mask({{}, {0, 1, 2, 3, 4}, {}, {}})); compare_masks(*getMask(conv2->output(0)), Mask({{}, {}, {}, {}})); - compare_masks(*getMask(fq->input(1).get_source_output()), Mask({{}, {0 , 1, 2, 3, 4}, {}, {}})); - compare_masks(*getMask(fq->input(2).get_source_output()), Mask({{}, {0 , 1, 2, 3, 4}, {}, {}})); - compare_masks(*getMask(fq->input(3).get_source_output()), Mask({{}, {0 , 1, 2, 3, 4}, {}, {}})); - compare_masks(*getMask(fq->input(4).get_source_output()), Mask({{}, {0 , 1, 2, 3, 4}, {}, {}})); + compare_masks(*getMask(fq->input(1).get_source_output()), Mask({{}, {0, 1, 2, 3, 4}, {}, {}})); + compare_masks(*getMask(fq->input(2).get_source_output()), Mask({{}, {0, 1, 2, 3, 4}, {}, {}})); + compare_masks(*getMask(fq->input(3).get_source_output()), Mask({{}, {0, 1, 2, 3, 4}, {}, {}})); + compare_masks(*getMask(fq->input(4).get_source_output()), Mask({{}, {0, 1, 2, 3, 4}, {}, {}})); + { + pass::Manager m; + m.register_pass(); + m.run_passes(function); + } + disable_rt_info_check(); + enable_accuracy_check(); } -TEST(TransformationTests, TestConcatMaskPropagation) { + +TEST_F(TransformationTestsF, TestConcatMaskPropagation) { Shape input_shape{1, 3, 64, 64}; Shape weights_shape1{8, 3, 3, 3}; Shape weights_shape2{16, 3, 3, 3}; @@ -576,14 +865,57 @@ TEST(TransformationTests, TestConcatMaskPropagation) { auto weights_out_conv = create_constant_with_zeros(weight_shape_out_conv, {{}, {}, {}, {}}); auto conv_out = std::make_shared(concat, weights_out_conv, Strides(2, 1), CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1)); + function = std::make_shared(NodeVector{conv_out}, ParameterVector{input}); + { + auto input = std::make_shared(element::f32, input_shape); + auto weights_1 = create_constant_with_zeros({ + weights_shape1[0] - 4, + weights_shape1[1], + weights_shape1[2], + weights_shape1[3] + }, {{}, {}, {}, {}}); + auto conv1 = std::make_shared(input, weights_1, Strides(2, 1), + CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1)); - auto f = std::make_shared(NodeVector{conv_out}, ParameterVector{input}); + auto weights_2 = create_constant_with_zeros({ + weights_shape2[0] - 4, + weights_shape2[1], + weights_shape2[2], + weights_shape2[3], + }, {{}, {}, {}, {}}); + auto conv2 = std::make_shared(input, weights_2, Strides(2, 1), + CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1)); - pass::Manager m; - m.register_pass(); - m.register_pass(); - m.run_passes(f); + auto weights_3 = create_constant_with_zeros({ + weights_shape3[0] - 4, + weights_shape3[1], + weights_shape3[2], + weights_shape3[3], + }, {{}, {}, {}, {}}); + auto conv3 = std::make_shared(input, weights_3, Strides(2, 1), + CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1)); + auto concat = std::make_shared(OutputVector{conv1->output(0), conv2->output(0), conv3->output(0)}, 1); + + auto weights_out_conv = create_constant_with_zeros({ + weight_shape_out_conv[0], + weight_shape_out_conv[1] - 12, + weight_shape_out_conv[2], + weight_shape_out_conv[3], + }, {{}, {}, {}, {}}); + auto conv_out = std::make_shared(concat, weights_out_conv, Strides(2, 1), + CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1)); + function_ref = std::make_shared(NodeVector{conv_out}, ParameterVector{input}); + } + + if (VISUALIZE_TESTS_TREE) + ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "TestConcatMaskPropagation.svg").run_on_function(function); + { + pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(function); + } compare_masks(*getMask(weights_1.get_node_shared_ptr()->output(0)), Mask({{0, 1, 2, 3}, {}, {}, {}})); compare_masks(*getMask(conv1->output(0)), Mask({{}, {0, 1, 2, 3}, {}, {}})); @@ -595,10 +927,17 @@ TEST(TransformationTests, TestConcatMaskPropagation) { compare_masks(*getMask(concat->output(0)), Mask({{}, {0, 1, 2, 3, 15, 16, 17, 18, 28, 29, 30, 31}, {}, {}})); compare_masks(*getMask(weights_out_conv.get_node_shared_ptr()->output(0)), Mask({{}, {0, 1, 2, 3, 15, 16, 17, 18, 28, 29, 30, 31}, {}, {}})); + { + pass::Manager m; + m.register_pass(); + m.run_passes(function); + } + disable_rt_info_check(); + enable_accuracy_check(); } -TEST(TransformationTests, TestConcatMaskPropagationUp) { +TEST_F(TransformationTestsF, TestConcatMaskPropagationUp) { Shape input_shape{1, 3, 64, 64}; Shape weights_shape1{8, 3, 3, 3}; Shape weights_shape2{16, 3, 3, 3}; @@ -626,14 +965,60 @@ TEST(TransformationTests, TestConcatMaskPropagationUp) { auto weights_out_conv = create_constant_with_zeros(weight_shape_out_conv, {{}, {}, {}, {}}); auto conv_out = std::make_shared(add, weights_out_conv, Strides(2, 1), CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1)); + function = std::make_shared(NodeVector{conv_out}, ParameterVector{input}); + { + auto input = std::make_shared(element::f32, input_shape); + auto weights_1 = create_constant_with_zeros({ + weights_shape1[0] - 4, + weights_shape1[1], + weights_shape1[2], + weights_shape1[3], + }, {{}, {}, {}, {}}); + auto conv1 = std::make_shared(input, weights_1, Strides(2, 1), + CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1)); - auto f = std::make_shared(NodeVector{conv_out}, ParameterVector{input}); + auto weights_2 = create_constant_with_zeros({ + weights_shape2[0] - 4, + weights_shape2[1], + weights_shape2[2], + weights_shape2[3], + }, {{}, {}, {}, {}}); + auto conv2 = std::make_shared(input, weights_2, Strides(2, 1), + CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1)); - pass::Manager m; - m.register_pass(); - m.register_pass(); - m.run_passes(f); + auto weights_3 = create_constant_with_zeros({ + weights_shape3[0] - 4, + weights_shape3[1], + weights_shape3[2], + weights_shape3[3], + }, {{}, {}, {}, {}}); + auto conv3 = std::make_shared(input, weights_3, Strides(2, 1), + CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1)); + auto concat = std::make_shared(OutputVector{conv1->output(0), conv2->output(0), conv3->output(0)}, 1); + + auto add_const = create_constant_with_zeros(Shape{1, 20, 1, 1}, {{}, {}, {}, {}}); + auto add = std::make_shared(concat, add_const); + + auto weights_out_conv = create_constant_with_zeros({ + weight_shape_out_conv[0], + weight_shape_out_conv[1] - 12, + weight_shape_out_conv[2], + weight_shape_out_conv[3], + }, {{}, {}, {}, {}}); + auto conv_out = std::make_shared(add, weights_out_conv, Strides(2, 1), + CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1)); + + function_ref = std::make_shared(NodeVector{conv_out}, ParameterVector{input}); + } + if (VISUALIZE_TESTS_TREE) + ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "TestConcatMaskPropagationUp.svg").run_on_function(function); + { + pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(function); + } compare_masks(*getMask(weights_1.get_node_shared_ptr()->output(0)), Mask({{0, 1, 2, 3}, {}, {}, {}})); compare_masks(*getMask(conv1->output(0)), Mask({{}, {0, 1, 2, 3}, {}, {}})); @@ -649,6 +1034,13 @@ TEST(TransformationTests, TestConcatMaskPropagationUp) { compare_masks(*getMask(concat->output(0)), Mask({{}, {0, 1, 2, 3, 15, 16, 17, 18, 28, 29, 30, 31}, {}, {}})); compare_masks(*getMask(weights_out_conv.get_node_shared_ptr()->output(0)), Mask({{}, {0, 1, 2, 3, 15, 16, 17, 18, 28, 29, 30, 31}, {}, {}})); + { + pass::Manager m; + m.register_pass(); + m.run_passes(function); + } + disable_rt_info_check(); + enable_accuracy_check(); } @@ -679,6 +1071,9 @@ TEST(TransformationTests, TestConcatMaskPropagationUpEmpty) { auto f = std::make_shared(NodeVector{add}, ParameterVector{input}); + if (VISUALIZE_TESTS_TREE) + ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "TestConcatMaskPropagationUpEmpty.svg").run_on_function(f); + pass::Manager m; m.register_pass(); m.register_pass(); @@ -699,3 +1094,405 @@ TEST(TransformationTests, TestConcatMaskPropagationUpEmpty) { compare_masks(*getMask(concat->output(0)), Mask({{}, {}, {}, {}})); } + + +TEST_F(TransformationTestsF, PruneConvIsClosingAndInGroup) { + auto inputShapes = PartialShape{1, 6, 16, 16}; + auto weightsShape = Shape{6, 6, 1, 1}; + + auto input = std::make_shared(element::f32, inputShapes); + auto weights = create_constant_with_zeros(weightsShape, {{1, 2, 3}, {}, {}, {}}); + auto conv = std::make_shared(input, weights, Strides(2, 1), + CoordinateDiff(2, 0), + CoordinateDiff(2, 0), + Strides(2, 1)); + + auto add_const = create_constant_with_zeros(Shape{1, 6, 1, 1}, {{}, {1, 2, 3, 4, 5}, {}, {}}); + auto add = std::make_shared(conv, add_const); + + auto conv_1_shape = Shape{weightsShape[0], weightsShape[0], 1, 1}; + auto conv_1_weights = create_constant_with_zeros(conv_1_shape, {{1, 2, 3}, {}, {}, {}}); + auto conv_1 = std::make_shared(add, conv_1_weights, Strides(2, 1), + CoordinateDiff(2, 0), + CoordinateDiff(2, 0), + Strides(2, 1)); + + auto add_1 = std::make_shared(conv_1, conv); + + auto end_conv_shape = Shape{weightsShape[1], weightsShape[0], 1, 1}; + auto weights_end_conv = create_constant_with_zeros(end_conv_shape, {{1, 2, 3}, {}, {}, {}}); + auto end_conv = std::make_shared(add_1, weights_end_conv, Strides(2, 1), + CoordinateDiff(2, 0), + CoordinateDiff(2, 0), + Strides(2, 1)); + + function = std::make_shared(OutputVector{end_conv}, ParameterVector{input}, "ReshapeMulBranching"); + + if (VISUALIZE_TESTS_TREE) + ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PruneConvIsClosingAndInGroup.svg").run_on_function(function); + { + auto input = std::make_shared(element::f32, inputShapes); + auto weights = create_constant_with_zeros({ + weightsShape[0] - 3, + weightsShape[1], + weightsShape[2], + weightsShape[3], + }, {{}, {}, {}, {}}); + auto conv = std::make_shared(input, weights, Strides(2, 1), + CoordinateDiff(2, 0), + CoordinateDiff(2, 0), + Strides(2, 1)); + + auto add_const = create_constant_with_zeros(Shape{1, 3, 1, 1}, {{}, {}, {}, {}}); + auto add = std::make_shared(conv, add_const); + + auto conv_1_shape = Shape{weightsShape[0] - 3, weightsShape[0] - 3, 1, 1}; + auto conv_1_weights = create_constant_with_zeros(conv_1_shape, {{}, {}, {}, {}}); + auto conv_1 = std::make_shared(add, conv_1_weights, Strides(2, 1), + CoordinateDiff(2, 0), + CoordinateDiff(2, 0), + Strides(2, 1)); + + + auto add_1 = std::make_shared(conv_1, conv); + + auto end_conv_shape = Shape{weightsShape[1], weightsShape[0] - 3, 1, 1}; + auto weights_end_conv = create_constant_with_zeros(end_conv_shape, {{}, {}, {}, {}}); + auto end_conv = std::make_shared(add_1, weights_end_conv, Strides(2, 1), + CoordinateDiff(2, 0), + CoordinateDiff(2, 0), + Strides(2, 1)); + function_ref = std::make_shared(OutputVector{end_conv}, ParameterVector{input}, "ReshapeMulBranching"); + } + { + pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(function); + } + compare_masks(*getMask(weights.get_node_shared_ptr()->output(0)), Mask({{1, 2, 3}, {}, {}, {}})); + compare_masks(*getMask(conv->output(0)), Mask({{}, {1, 2, 3}, {}, {}})); + + compare_masks(*getMask(conv_1_weights.get_node_shared_ptr()->output(0)), Mask({{1, 2, 3}, {1, 2, 3}, {}, {}})); + compare_masks(*getMask(conv_1->output(0)), Mask({{}, {1, 2, 3}, {}, {}})); + + compare_masks(*getMask(add_const.get_node_shared_ptr()->output(0)), Mask({{}, {1, 2, 3}, {}, {}})); + compare_masks(*getMask(add->output(0)), Mask({{}, {1, 2, 3}, {}, {}})); + compare_masks(*getMask(add_1->output(0)), Mask({{}, {1, 2, 3}, {}, {}})); + + compare_masks(*getMask(weights_end_conv.get_node_shared_ptr()->output(0)), Mask({{}, {1, 2, 3}, {}, {}})); + compare_masks(*getMask(end_conv->output(0)), Mask({{}, {}, {}, {}})); + { + pass::Manager m; + m.register_pass(); + m.run_passes(function); + } + disable_rt_info_check(); + enable_accuracy_check(); +} + + +TEST(TransformationTests, PruneBranchingStopOp) { + // Checks case of branching with stop op + auto inputShapes = PartialShape{1, 6, 16, 16}; + auto weightsShape = Shape{6, 6, 1, 1}; + + auto input = std::make_shared(element::f32, inputShapes); + auto weights = create_constant_with_zeros(weightsShape, {{1, 2, 3}, {}, {}, {}}); + auto conv = std::make_shared(input, weights, Strides(2, 1), + CoordinateDiff(2, 0), + CoordinateDiff(2, 0), + Strides(2, 1)); + // Branching stop op + Shape group_conv_weights_shape{3, 2, 2, 1, 1}; + auto group_conv_weights = opset5::Constant::create(element::f32, group_conv_weights_shape, {0}); + auto group_conv = std::make_shared(conv, group_conv_weights, Strides(2, 1), + CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1)); + + auto conv_1_shape = Shape{weightsShape[0], 6, 1, 1}; + auto conv_1_weights = create_constant_with_zeros(conv_1_shape, {{1, 2, 3}, {}, {}, {}}); + auto conv_1 = std::make_shared(group_conv, conv_1_weights, Strides(2, 1), + CoordinateDiff(2, 0), + CoordinateDiff(2, 0), + Strides(2, 1)); + + // Multiply will try to propagate a non zero masks of the conv_1 up + // and the mask should be invalidated by group conv stop op mask + auto mul = std::make_shared(conv_1, conv); + + auto end_conv_shape = Shape{weightsShape[1], weightsShape[0], 1, 1}; + auto weights_end_conv = create_constant_with_zeros(end_conv_shape, {{1, 2, 3}, {}, {}, {}}); + auto end_conv = std::make_shared(mul, weights_end_conv, Strides(2, 1), + CoordinateDiff(2, 0), + CoordinateDiff(2, 0), + Strides(2, 1)); + + auto function = std::make_shared(OutputVector{end_conv}, ParameterVector{input}, "RestrictedReduceMeanBranching"); + + if (VISUALIZE_TESTS_TREE) + ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PruneBranchingStopOp.svg").run_on_function(function); + + pass::Manager m; + m.register_pass(); + m.run_passes(function); + + compare_masks(*getMask(weights.get_node_shared_ptr()->output(0)), Mask({{}, {}, {}, {}})); + compare_masks(*getMask(conv->output(0)), Mask({{}, {}, {}, {}})); + + compare_masks(*getMask(conv_1_weights.get_node_shared_ptr()->output(0)), Mask({{}, {}, {}, {}})); + compare_masks(*getMask(conv_1->output(0)), Mask({{}, {}, {}, {}})); + + compare_masks(*getMask(weights_end_conv.get_node_shared_ptr()->output(0)), Mask({{}, {}, {}, {}})); + compare_masks(*getMask(end_conv->output(0)), Mask({{}, {}, {}, {}})); +} + + +TEST_F(TransformationTestsF, PruneReducelayerUp) { + auto inputShapes = PartialShape{1, 6, 16, 16}; + auto weightsShape = Shape{6, 6, 1, 1}; + + auto input = std::make_shared(element::f32, inputShapes); + auto weights = create_constant_with_zeros(weightsShape, {{1, 2, 3}, {}, {}, {}}); + auto conv = std::make_shared(input, weights, Strides(2, 1), + CoordinateDiff(2, 0), + CoordinateDiff(2, 0), + Strides(2, 1)); + + auto reduce_const = opset5::Constant::create(element::i64, Shape{2}, {2, 3}); + auto reduce_mean = std::make_shared(conv, reduce_const, true); + + auto conv_1_shape = Shape{12, 6, 1, 1}; + auto conv_1_weights = create_constant_with_zeros(conv_1_shape, {{1, 2, 3}, {}, {}, {}}); + auto conv_1 = std::make_shared(reduce_mean, conv_1_weights, Strides(2, 1), + CoordinateDiff(2, 0), + CoordinateDiff(2, 0), + Strides(2, 1)); + + function = std::make_shared(OutputVector{conv_1}, ParameterVector{input}, "GoodReshapeModel"); + { + auto input = std::make_shared(element::f32, inputShapes); + auto weights = create_constant_with_zeros({ + weightsShape[0] - 3, + weightsShape[1], + weightsShape[2], + weightsShape[3] + }, {{}, {}, {}, {}}); + auto conv = std::make_shared(input, weights, Strides(2, 1), + CoordinateDiff(2, 0), + CoordinateDiff(2, 0), + Strides(2, 1)); + + auto reduce_const = opset5::Constant::create(element::i64, Shape{2}, {2, 3}); + auto reduce_mean = std::make_shared(conv, reduce_const, true); + + auto conv_1_shape = Shape{12, 3, 1, 1}; + auto conv_1_weights = create_constant_with_zeros({ + conv_1_shape[0], + conv_1_shape[1], + conv_1_shape[2], + conv_1_shape[3] + }, {{}, {}, {}, {}}); + auto conv_1 = std::make_shared(reduce_mean, conv_1_weights, Strides(2, 1), + CoordinateDiff(2, 0), + CoordinateDiff(2, 0), + Strides(2, 1)); + function_ref = std::make_shared(OutputVector{conv_1}, ParameterVector{input}, "GoodReshapeModel"); + } + if (VISUALIZE_TESTS_TREE) + ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PruneReducelayerUp.svg").run_on_function(function); + + pass::Manager m; + m.register_pass(); + m.run_passes(function); + + compare_masks(*getMask(weights.get_node_shared_ptr()->output(0)), Mask({{1, 2, 3}, {}, {}, {}})); + compare_masks(*getMask(conv->output(0)), Mask({{}, {1, 2, 3}, {}, {}})); + + compare_masks(*getMask(conv_1_weights.get_node_shared_ptr()->output(0)), Mask({{}, {1, 2, 3}, {}, {}})); + compare_masks(*getMask(conv_1->output(0)), Mask({{}, {}, {}, {}})); +} + + +TEST_F(TransformationTestsF, PruneReduceLayerDown) { + auto inputShapes = PartialShape{1, 6, 16, 16}; + auto weightsShape = Shape{6, 6, 1, 1}; + + auto input = std::make_shared(element::f32, inputShapes); + auto weights = create_constant_with_zeros(weightsShape, {{1, 2, 3}, {}, {}, {}}); + auto conv = std::make_shared(input, weights, Strides(2, 1), + CoordinateDiff(2, 0), + CoordinateDiff(2, 0), + Strides(2, 1)); + + auto reduce_const = opset5::Constant::create(element::i64, Shape{2}, {2, 3}); + auto reduce_mean = std::make_shared(conv, reduce_const, true); + + auto conv_1_shape = Shape{weightsShape[0], weightsShape[0], 1, 1}; + auto conv_1_weights = create_constant_with_zeros(conv_1_shape, {{1, 2, 3}, {}, {}, {}}); + auto conv_1 = std::make_shared(reduce_mean, conv_1_weights, Strides(2, 1), + CoordinateDiff(2, 0), + CoordinateDiff(2, 0), + Strides(2, 1)); + + auto add_1 = std::make_shared(conv_1, conv); + + auto end_conv_shape = Shape{weightsShape[1], weightsShape[0], 1, 1}; + auto weights_end_conv = create_constant_with_zeros(end_conv_shape, {{1, 2, 3}, {}, {}, {}}); + auto end_conv = std::make_shared(add_1, weights_end_conv, Strides(2, 1), + CoordinateDiff(2, 0), + CoordinateDiff(2, 0), + Strides(2, 1)); + + function = std::make_shared(OutputVector{end_conv}, ParameterVector{input}, "GoodReshapeDonw"); + { + auto input = std::make_shared(element::f32, inputShapes); + auto weights = create_constant_with_zeros({ + weightsShape[0] - 3, + weightsShape[1], + weightsShape[2], + weightsShape[3], + }, {{}, {}, {}, {}}); + auto conv = std::make_shared(input, weights, Strides(2, 1), + CoordinateDiff(2, 0), + CoordinateDiff(2, 0), + Strides(2, 1)); + + auto reduce_const = opset5::Constant::create(element::i64, Shape{2}, {2, 3}); + auto reduce_mean = std::make_shared(conv, reduce_const, true); + + auto conv_1_shape = Shape{weightsShape[0] - 3, weightsShape[0] - 3, 1, 1}; + auto conv_1_weights = create_constant_with_zeros(conv_1_shape, {{}, {}, {}, {}}); + auto conv_1 = std::make_shared(reduce_mean, conv_1_weights, Strides(2, 1), + CoordinateDiff(2, 0), + CoordinateDiff(2, 0), + Strides(2, 1)); + + + auto add_1 = std::make_shared(conv_1, conv); + + auto end_conv_shape = Shape{weightsShape[1], weightsShape[0] - 3, 1, 1}; + auto weights_end_conv = create_constant_with_zeros(end_conv_shape, {{}, {}, {}, {}}); + auto end_conv = std::make_shared(add_1, weights_end_conv, Strides(2, 1), + CoordinateDiff(2, 0), + CoordinateDiff(2, 0), + Strides(2, 1)); + function_ref = std::make_shared(OutputVector{end_conv}, ParameterVector{input}, "GoodReshapeDown"); + } + if (VISUALIZE_TESTS_TREE) + ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PruneReduceLayerDown.svg").run_on_function(function); + { + pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(function); + } + compare_masks(*getMask(weights.get_node_shared_ptr()->output(0)), Mask({{1, 2, 3}, {}, {}, {}})); + compare_masks(*getMask(conv->output(0)), Mask({{}, {1, 2, 3}, {}, {}})); + + compare_masks(*getMask(conv_1_weights.get_node_shared_ptr()->output(0)), Mask({{1, 2, 3}, {1, 2, 3}, {}, {}})); + compare_masks(*getMask(conv_1->output(0)), Mask({{}, {1, 2, 3}, {}, {}})); + + compare_masks(*getMask(reduce_mean->output(0)), Mask({{}, {1, 2, 3}, {}, {}})); + compare_masks(*getMask(add_1->output(0)), Mask({{}, {1, 2, 3}, {}, {}})); + + compare_masks(*getMask(weights_end_conv.get_node_shared_ptr()->output(0)), Mask({{}, {1, 2, 3}, {}, {}})); + compare_masks(*getMask(end_conv->output(0)), Mask({{}, {}, {}, {}})); + { + pass::Manager m; + m.register_pass(); + m.run_passes(function); + } + disable_rt_info_check(); + enable_accuracy_check(); +} + + +TEST(TransformationTests, PruneStopReducelayerUp) { + auto inputShapes = PartialShape{1, 6, 16, 16}; + auto weightsShape = Shape{6, 6, 1, 1}; + + auto input = std::make_shared(element::f32, inputShapes); + auto weights = create_constant_with_zeros(weightsShape, {{1, 2, 3}, {}, {}, {}}); + auto conv = std::make_shared(input, weights, Strides(2, 1), + CoordinateDiff(2, 0), + CoordinateDiff(2, 0), + Strides(2, 1)); + + auto reduce_const = opset5::Constant::create(element::i64, Shape{3}, {1, 2, 3}); + auto reduce_mean = std::make_shared(conv, reduce_const, true); + + auto conv_1_shape = Shape{12, 1, 1, 1}; + auto conv_1_weights = create_constant_with_zeros(conv_1_shape, {{1, 2, 3}, {}, {}, {}}); + auto conv_1 = std::make_shared(reduce_mean, conv_1_weights, Strides(2, 1), + CoordinateDiff(2, 0), + CoordinateDiff(2, 0), + Strides(2, 1)); + + auto function = std::make_shared(OutputVector{conv_1}, ParameterVector{input}, "BadReshapeModel"); + + if (VISUALIZE_TESTS_TREE) + ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PruneStopReducelayerUp.svg").run_on_function(function); + + pass::Manager m; + m.register_pass(); + m.run_passes(function); + + compare_masks(*getMask(weights.get_node_shared_ptr()->output(0)), Mask({{}, {}, {}, {}})); + compare_masks(*getMask(conv->output(0)), Mask({{}, {}, {}, {}})); + + compare_masks(*getMask(conv_1_weights.get_node_shared_ptr()->output(0)), Mask({{}, {}, {}, {}})); + compare_masks(*getMask(conv_1->output(0)), Mask({{}, {}, {}, {}})); +} + + +TEST(TransformationTests, PruneStopReduceLayerDown) { + // Checks case of branching with stop op + auto inputShapes = PartialShape{1, 6, 16, 16}; + auto weightsShape = Shape{6, 6, 1, 1}; + + auto input = std::make_shared(element::f32, inputShapes); + auto weights = create_constant_with_zeros(weightsShape, {{1, 2, 3}, {}, {}, {}}); + auto conv = std::make_shared(input, weights, Strides(2, 1), + CoordinateDiff(2, 0), + CoordinateDiff(2, 0), + Strides(2, 1)); + // Branching stop op + auto reduce_const = opset5::Constant::create(element::i64, Shape{3}, {1, 2, 3}); + auto reduce_mean = std::make_shared(conv, reduce_const, true); + + auto conv_1_shape = Shape{weightsShape[0], 1, 1, 1}; + auto conv_1_weights = create_constant_with_zeros(conv_1_shape, {{1, 2, 3}, {}, {}, {}}); + auto conv_1 = std::make_shared(reduce_mean, conv_1_weights, Strides(2, 1), + CoordinateDiff(2, 0), + CoordinateDiff(2, 0), + Strides(2, 1)); + + // Multiply will try to propagate a non zero masks of the conv_1 up + // and the mask should be invalidated by reduce_mean stop op mask + auto mul = std::make_shared(conv_1, conv); + + auto end_conv_shape = Shape{weightsShape[1], weightsShape[0], 1, 1}; + auto weights_end_conv = create_constant_with_zeros(end_conv_shape, {{1, 2, 3}, {}, {}, {}}); + auto end_conv = std::make_shared(mul, weights_end_conv, Strides(2, 1), + CoordinateDiff(2, 0), + CoordinateDiff(2, 0), + Strides(2, 1)); + + auto function = std::make_shared(OutputVector{end_conv}, ParameterVector{input}, "RestrictedReduceMeanBranching"); + + if (VISUALIZE_TESTS_TREE) + ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PruneStopReduceLayerDown.svg").run_on_function(function); + + pass::Manager m; + m.register_pass(); + m.run_passes(function); + + compare_masks(*getMask(weights.get_node_shared_ptr()->output(0)), Mask({{}, {}, {}, {}})); + compare_masks(*getMask(conv->output(0)), Mask({{}, {}, {}, {}})); + + compare_masks(*getMask(conv_1_weights.get_node_shared_ptr()->output(0)), Mask({{}, {}, {}, {}})); + compare_masks(*getMask(conv_1->output(0)), Mask({{}, {}, {}, {}})); + + compare_masks(*getMask(weights_end_conv.get_node_shared_ptr()->output(0)), Mask({{}, {}, {}, {}})); + compare_masks(*getMask(end_conv->output(0)), Mask({{}, {}, {}, {}})); +}