From 0050643e9b097829e57af1be8f171c6ee0acbeaa Mon Sep 17 00:00:00 2001 From: Katarzyna Mitrus Date: Mon, 14 Feb 2022 17:42:51 +0100 Subject: [PATCH] Add BroadcastConstRangeReplacement transformation (#10318) --- .../broadcast_const_range_replacement.hpp | 28 ++ .../broadcast_const_range_replacement.cpp | 97 ++++++ .../smart_reshape/smart_reshape.cpp | 2 + .../broadcast_const_range_replacement.cpp | 318 ++++++++++++++++++ 4 files changed, 445 insertions(+) create mode 100644 src/common/transformations/include/transformations/smart_reshape/broadcast_const_range_replacement.hpp create mode 100644 src/common/transformations/src/transformations/smart_reshape/broadcast_const_range_replacement.cpp create mode 100644 src/tests/functional/inference_engine/transformations/smart_reshape/broadcast_const_range_replacement.cpp diff --git a/src/common/transformations/include/transformations/smart_reshape/broadcast_const_range_replacement.hpp b/src/common/transformations/include/transformations/smart_reshape/broadcast_const_range_replacement.hpp new file mode 100644 index 00000000000..6c0f449f9ea --- /dev/null +++ b/src/common/transformations/include/transformations/smart_reshape/broadcast_const_range_replacement.hpp @@ -0,0 +1,28 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#include + +namespace ngraph { +namespace pass { + +class TRANSFORMATIONS_API BroadcastConstRangeReplacement; + +} // namespace pass +} // namespace ngraph + +/** + * @ingroup ie_transformation_common_api + * @brief BroadcastConstRangeReplacement replaces Constant filled with range values starting from 0 and replaces it with Range op + */ + +class ngraph::pass::BroadcastConstRangeReplacement: public ngraph::pass::MatcherPass { +public: + NGRAPH_RTTI_DECLARATION; + BroadcastConstRangeReplacement(); +}; diff --git a/src/common/transformations/src/transformations/smart_reshape/broadcast_const_range_replacement.cpp b/src/common/transformations/src/transformations/smart_reshape/broadcast_const_range_replacement.cpp new file mode 100644 index 00000000000..bc987f6e40f --- /dev/null +++ b/src/common/transformations/src/transformations/smart_reshape/broadcast_const_range_replacement.cpp @@ -0,0 +1,97 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/utils/utils.hpp" + +#include +#include + +#include +#include +#include +#include +#include + +#include "itt.hpp" + + +NGRAPH_RTTI_DEFINITION(ngraph::pass::BroadcastConstRangeReplacement, "BroadcastConstRangeReplacement", 0); + +ngraph::pass::BroadcastConstRangeReplacement::BroadcastConstRangeReplacement() { + MATCHER_SCOPE(BroadcastConstRangeReplacement); + auto data_input = pattern::wrap_type(); + auto target_shape = pattern::any_input(); + auto broadcast_pattern_node = pattern::wrap_type({data_input, target_shape}); + + matcher_pass_callback callback = [=](pattern::Matcher& m) { + const auto broadcast = m.get_match_root(); + // The transformation was requested only for models with BroadcastType::BIDIRECTIONAL + // Further analysis is needed for other broadcast modes enablement + const auto broadcast_ptr = std::dynamic_pointer_cast(broadcast); + if (!broadcast_ptr || broadcast_ptr->get_broadcast_spec().m_type != ngraph::op::BroadcastType::BIDIRECTIONAL) + return false; + + const auto data_const_out = broadcast->get_input_source_output(0); + const auto target_shape_out = broadcast->get_input_source_output(1); + + const auto const_node = std::dynamic_pointer_cast(data_const_out.get_node_shared_ptr()); + if (!const_node || !const_node->get_element_type().is_integral_number()) + return false; + + const auto& const_node_shape = const_node->get_output_shape(0); + const auto elem_count = shape_size(const_node_shape); + const auto one_dims_count = std::count(const_node_shape.cbegin(), const_node_shape.cend(), 1); + + constexpr size_t dim_low_limit = 5; + constexpr size_t dim_up_limit = 500; + + // To affect less models, the transformation is applied to Constants with elements count in range (5:500) + if (const_node_shape.size() - one_dims_count != 1 || elem_count <= dim_low_limit || elem_count >= dim_up_limit) + return false; + + std::vector sequence_pattern(elem_count); + std::iota(sequence_pattern.begin(), sequence_pattern.end(), 0); + const auto &const_values = const_node->cast_vector(); + + // Check if the value sequence is contiguous + if (const_values != sequence_pattern) + return false; + + const auto data_elem_type = data_const_out.get_element_type(); + const auto target_dim_index = std::distance(const_node_shape.cbegin(), std::find(const_node_shape.cbegin(), const_node_shape.cend(), elem_count)); + const int64_t target_dim_neg_index = target_dim_index - const_node_shape.size(); + + const auto axis_node = ngraph::opset8::Constant::create(ngraph::element::i32, {}, {0}); + const auto target_dim_index_node = ngraph::opset8::Constant::create(ngraph::element::i64, {}, {target_dim_neg_index}); + const auto gather_dim = std::make_shared(target_shape_out, target_dim_index_node, axis_node); + + // If the corresponding target dim is 1, use the original end of range + const auto one_dim_const = ngraph::opset8::Constant::create(target_shape_out.get_element_type(), {}, {1}); + const auto dim_check_one = std::make_shared(gather_dim, one_dim_const); + + const auto start = ngraph::opset8::Constant::create(data_elem_type, {}, {0}); + const auto original_end = ngraph::opset8::Constant::create(data_elem_type, {}, {elem_count}); + + const auto cast_gather_dim = std::make_shared(gather_dim, data_elem_type); + const auto select_end = std::make_shared(dim_check_one, original_end, cast_gather_dim); + + const auto default_range_step = ngraph::opset8::Constant::create(data_elem_type, {}, {1}); + const auto range = std::make_shared(start, select_end, default_range_step, data_elem_type); + + // Unsqueeze the output of the Range op to the original shape of data input + std::vector final_shape_axes(const_node_shape.size()); + std::iota(final_shape_axes.begin(), final_shape_axes.end(), 0); + final_shape_axes.erase(final_shape_axes.begin() + target_dim_index); + const auto axes_to_unsqueeze = ngraph::opset8::Constant::create(ngraph::element::i64, {final_shape_axes.size()}, final_shape_axes); + const auto unsqueeze_range = std::make_shared(range, axes_to_unsqueeze); + + copy_runtime_info(const_node, {axis_node, target_dim_index_node, gather_dim, cast_gather_dim, one_dim_const, dim_check_one, + start, original_end, select_end, default_range_step, range, axes_to_unsqueeze, unsqueeze_range}); + broadcast->input(0).replace_source_output(unsqueeze_range); + return false; + }; + + auto m = std::make_shared(broadcast_pattern_node, matcher_name); + this->register_matcher(m, callback); +} diff --git a/src/common/transformations/src/transformations/smart_reshape/smart_reshape.cpp b/src/common/transformations/src/transformations/smart_reshape/smart_reshape.cpp index 96e8aa611a9..2cb63d5fb5e 100644 --- a/src/common/transformations/src/transformations/smart_reshape/smart_reshape.cpp +++ b/src/common/transformations/src/transformations/smart_reshape/smart_reshape.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -30,6 +31,7 @@ bool ngraph::pass::SmartReshape::run_on_model(const std::shared_ptr(); static_manager.register_pass(); static_manager.register_pass(); + static_manager.register_pass(); static_manager.run_passes(f); ngraph::pass::Manager dynamic_manager; diff --git a/src/tests/functional/inference_engine/transformations/smart_reshape/broadcast_const_range_replacement.cpp b/src/tests/functional/inference_engine/transformations/smart_reshape/broadcast_const_range_replacement.cpp new file mode 100644 index 00000000000..e75dfda9995 --- /dev/null +++ b/src/tests/functional/inference_engine/transformations/smart_reshape/broadcast_const_range_replacement.cpp @@ -0,0 +1,318 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include +#include + +#include +#include +#include + +#include "common_test_utils/ngraph_test_utils.hpp" + + +using namespace testing; +using namespace ngraph; + +TEST_F(TransformationTestsF, BroadcastConstRangeReplacement_dim_match) { + { + constexpr auto elem_count = 236; + constexpr auto data_elem_type = element::i32; + constexpr auto target_shape_elem_type = element::i64; + + std::vector sequence_pattern(elem_count); + std::iota(sequence_pattern.begin(), sequence_pattern.end(), 0); + auto data_to_broadcast = ngraph::opset8::Constant::create(data_elem_type, {1, elem_count}, sequence_pattern); + auto target_shape = ngraph::opset8::Constant::create(target_shape_elem_type, {4}, {2, 3, 4, elem_count}); + auto broadcast_node = std::make_shared(data_to_broadcast, target_shape, ngraph::op::BroadcastType::BIDIRECTIONAL); + + function = std::make_shared(OutputVector{broadcast_node}, ParameterVector{}); + + manager.register_pass(); + } + { + constexpr auto elem_count = 236; + constexpr auto data_elem_type = element::i32; + constexpr auto target_shape_elem_type = element::i64; + + auto target_shape = ngraph::opset8::Constant::create(element::i64, {4}, {2, 3, 4, elem_count}); + + const auto target_dim_neg_index = -1; + const auto axis_node = ngraph::opset8::Constant::create(ngraph::element::i32, Shape{}, {0}); + const auto target_dim_index_node = ngraph::opset8::Constant::create(ngraph::element::i64, Shape{}, {target_dim_neg_index}); + const auto gather_dim = std::make_shared(target_shape, target_dim_index_node, axis_node); + + const auto one_dim_const = ngraph::opset8::Constant::create(target_shape_elem_type, {}, {1}); + const auto dim_check_one = std::make_shared(gather_dim, one_dim_const); + + const auto start = ngraph::opset8::Constant::create(data_elem_type, {}, {0}); + const auto original_end = ngraph::opset8::Constant::create(data_elem_type, {}, {elem_count}); + + const auto cast_gather_dim = std::make_shared(gather_dim, data_elem_type); + const auto select_end = std::make_shared(dim_check_one, original_end, cast_gather_dim); + + const auto default_range_step = ngraph::opset8::Constant::create(data_elem_type, {}, {1}); + const auto range = std::make_shared(start, select_end, default_range_step, data_elem_type); + const auto axes_to_unsqueeze = ngraph::opset8::Constant::create(ngraph::element::i64, Shape{1}, {0}); + const auto unsqueeze_range = std::make_shared(range, axes_to_unsqueeze); + + const auto broadcast_node = std::make_shared(unsqueeze_range, target_shape, ngraph::op::BroadcastType::BIDIRECTIONAL); + + function_ref = std::make_shared(OutputVector{broadcast_node}, ParameterVector{}); + } +} + +TEST_F(TransformationTestsF, BroadcastConstRangeReplacement_dim_one) { + { + constexpr auto elem_count = 236; + constexpr auto data_elem_type = element::i32; + constexpr auto target_shape_elem_type = element::i64; + + std::vector sequence_pattern(elem_count); + std::iota(sequence_pattern.begin(), sequence_pattern.end(), 0); + auto data_to_broadcast = ngraph::opset8::Constant::create(data_elem_type, {1, elem_count}, sequence_pattern); + auto target_shape = ngraph::opset8::Constant::create(target_shape_elem_type, {4}, {2, 3, 4, 1}); + auto broadcast_node = std::make_shared(data_to_broadcast, target_shape, ngraph::op::BroadcastType::BIDIRECTIONAL); + + function = std::make_shared(OutputVector{broadcast_node}, ParameterVector{}); + + manager.register_pass(); + } + { + constexpr auto elem_count = 236; + constexpr auto data_elem_type = element::i32; + constexpr auto target_shape_elem_type = element::i64; + + auto target_shape = ngraph::opset8::Constant::create(element::i64, {4}, {2, 3, 4, 1}); + + const auto target_dim_neg_index = -1; + const auto axis_node = ngraph::opset8::Constant::create(ngraph::element::i32, Shape{}, {0}); + const auto target_dim_index_node = ngraph::opset8::Constant::create(ngraph::element::i64, Shape{}, {target_dim_neg_index}); + const auto gather_dim = std::make_shared(target_shape, target_dim_index_node, axis_node); + + // If the corresponding target dim is 1, use the original end of range + const auto one_dim_const = ngraph::opset8::Constant::create(target_shape_elem_type, {}, {1}); + const auto dim_check_one = std::make_shared(gather_dim, one_dim_const); + + const auto start = ngraph::opset8::Constant::create(data_elem_type, {}, {0}); + const auto original_end = ngraph::opset8::Constant::create(data_elem_type, {}, {elem_count}); + + const auto cast_gather_dim = std::make_shared(gather_dim, data_elem_type); + const auto select_end = std::make_shared(dim_check_one, original_end, cast_gather_dim); + + const auto default_range_step = ngraph::opset8::Constant::create(data_elem_type, {}, {1}); + const auto range = std::make_shared(start, select_end, default_range_step, data_elem_type); + const auto axes_to_unsqueeze = ngraph::opset8::Constant::create(ngraph::element::i64, Shape{1}, {0}); + const auto unsqueeze_range = std::make_shared(range, axes_to_unsqueeze); + + const auto broadcast_node = std::make_shared(unsqueeze_range, target_shape, ngraph::op::BroadcastType::BIDIRECTIONAL); + + function_ref = std::make_shared(OutputVector{broadcast_node}, ParameterVector{}); + } +} + +TEST_F(TransformationTestsF, BroadcastConstRangeReplacement_target_shapeof) { + { + constexpr auto elem_count = 236; + constexpr auto data_elem_type = element::i32; + constexpr auto target_shape_elem_type = element::i64; + + std::vector sequence_pattern(elem_count); + std::iota(sequence_pattern.begin(), sequence_pattern.end(), 0); + + auto data_param = std::make_shared(data_elem_type, Shape{2, 3, 4, elem_count}); + auto target_shape = std::make_shared(data_param); + + auto data_to_broadcast = ngraph::opset8::Constant::create(data_elem_type, {1, elem_count}, sequence_pattern); + auto broadcast_node = std::make_shared(data_to_broadcast, target_shape, ngraph::op::BroadcastType::BIDIRECTIONAL); + + function = std::make_shared(OutputVector{broadcast_node}, ParameterVector{data_param}); + + manager.register_pass(); + } + { + constexpr auto elem_count = 236; + constexpr auto data_elem_type = element::i32; + constexpr auto target_shape_elem_type = element::i64; + + auto data_param = std::make_shared(data_elem_type, Shape{2, 3, 4, elem_count}); + auto target_shape = std::make_shared(data_param); + + const auto target_dim_neg_index = -1; + const auto axis_node = ngraph::opset8::Constant::create(ngraph::element::i32, Shape{}, {0}); + const auto target_dim_index_node = ngraph::opset8::Constant::create(ngraph::element::i64, Shape{}, {target_dim_neg_index}); + const auto gather_dim = std::make_shared(target_shape, target_dim_index_node, axis_node); + + // If the corresponding target dim is 1, use the original end of range + const auto one_dim_const = ngraph::opset8::Constant::create(target_shape_elem_type, {}, {1}); + const auto dim_check_one = std::make_shared(gather_dim, one_dim_const); + + const auto start = ngraph::opset8::Constant::create(data_elem_type, {}, {0}); + const auto original_end = ngraph::opset8::Constant::create(data_elem_type, {}, {elem_count}); + + const auto cast_gather_dim = std::make_shared(gather_dim, data_elem_type); + const auto select_end = std::make_shared(dim_check_one, original_end, cast_gather_dim); + + const auto default_range_step = ngraph::opset8::Constant::create(data_elem_type, {}, {1}); + const auto range = std::make_shared(start, select_end, default_range_step, data_elem_type); + const auto axes_to_unsqueeze = ngraph::opset8::Constant::create(ngraph::element::i64, Shape{1}, {0}); + const auto unsqueeze_range = std::make_shared(range, axes_to_unsqueeze); + + const auto broadcast_node = std::make_shared(unsqueeze_range, target_shape, ngraph::op::BroadcastType::BIDIRECTIONAL); + + function_ref = std::make_shared(OutputVector{broadcast_node}, ParameterVector{data_param}); + } +} + +TEST_F(TransformationTestsF, BroadcastConstRangeReplacement_target_shapeof_mixed_dims) { + { + constexpr auto elem_count = 236; + constexpr auto data_elem_type = element::i32; + constexpr auto target_shape_elem_type = element::i64; + + std::vector sequence_pattern(elem_count); + std::iota(sequence_pattern.begin(), sequence_pattern.end(), 0); + + auto data_param = std::make_shared(data_elem_type, Shape{2, 3, elem_count, 4}); + auto target_shape = std::make_shared(data_param); + + auto data_to_broadcast = ngraph::opset8::Constant::create(data_elem_type, {1, 1, elem_count, 1}, sequence_pattern); + auto broadcast_node = std::make_shared(data_to_broadcast, target_shape, ngraph::op::BroadcastType::BIDIRECTIONAL); + + function = std::make_shared(OutputVector{broadcast_node}, ParameterVector{data_param}); + + manager.register_pass(); + } + { + constexpr auto elem_count = 236; + constexpr auto data_elem_type = element::i32; + constexpr auto target_shape_elem_type = element::i64; + + auto data_param = std::make_shared(data_elem_type, Shape{2, 3, elem_count, 4}); + auto target_shape = std::make_shared(data_param); + + const auto target_dim_neg_index = -2; + const auto axis_node = ngraph::opset8::Constant::create(ngraph::element::i32, Shape{}, {0}); + const auto target_dim_index_node = ngraph::opset8::Constant::create(ngraph::element::i64, Shape{}, {target_dim_neg_index}); + const auto gather_dim = std::make_shared(target_shape, target_dim_index_node, axis_node); + + const auto one_dim_const = ngraph::opset8::Constant::create(target_shape_elem_type, {}, {1}); + const auto dim_check_one = std::make_shared(gather_dim, one_dim_const); + + const auto start = ngraph::opset8::Constant::create(data_elem_type, {}, {0}); + const auto original_end = ngraph::opset8::Constant::create(data_elem_type, {}, {elem_count}); + + const auto cast_gather_dim = std::make_shared(gather_dim, data_elem_type); + const auto select_end = std::make_shared(dim_check_one, original_end, cast_gather_dim); + + const auto default_range_step = ngraph::opset8::Constant::create(data_elem_type, {}, {1}); + const auto range = std::make_shared(start, select_end, default_range_step, data_elem_type); + + // Axes to unsqueeze without target dim index + const auto axes_to_unsqueeze = ngraph::opset8::Constant::create(ngraph::element::i64, Shape{3}, {0, 1, 3}); + const auto unsqueeze_range = std::make_shared(range, axes_to_unsqueeze); + + const auto broadcast_node = std::make_shared(unsqueeze_range, target_shape, ngraph::op::BroadcastType::BIDIRECTIONAL); + + function_ref = std::make_shared(OutputVector{broadcast_node}, ParameterVector{data_param}); + } +} + +TEST_F(TransformationTestsF, BroadcastConstRangeReplacementNeg_other_mode) { + { + constexpr auto elem_count = 236; + constexpr auto data_elem_type = element::i32; + constexpr auto target_shape_elem_type = element::i64; + + std::vector sequence_pattern(elem_count); + std::iota(sequence_pattern.begin(), sequence_pattern.end(), 0); + auto data_to_broadcast = ngraph::opset8::Constant::create(data_elem_type, {1, elem_count}, sequence_pattern); + auto target_shape = ngraph::opset8::Constant::create(target_shape_elem_type, {4}, {2, 3, 4, elem_count}); + auto broadcast_node = std::make_shared(data_to_broadcast, target_shape, ngraph::op::BroadcastType::NUMPY); + + function = std::make_shared(OutputVector{broadcast_node}, ParameterVector{}); + + manager.register_pass(); + } +} + +TEST_F(TransformationTestsF, BroadcastConstRangeReplacementNeg_reversed_sequence) { + { + constexpr auto elem_count = 236; + constexpr auto data_elem_type = element::i32; + constexpr auto target_shape_elem_type = element::i64; + + std::vector sequence_pattern(elem_count); + std::iota(sequence_pattern.rbegin(), sequence_pattern.rend(), 0); + auto data_to_broadcast = ngraph::opset8::Constant::create(data_elem_type, {1, elem_count}, sequence_pattern); + auto target_shape = ngraph::opset8::Constant::create(target_shape_elem_type, {4}, {2, 3, 4, elem_count}); + auto broadcast_node = std::make_shared(data_to_broadcast, target_shape, ngraph::op::BroadcastType::BIDIRECTIONAL); + + function = std::make_shared(OutputVector{broadcast_node}, ParameterVector{}); + + manager.register_pass(); + } +} + +TEST_F(TransformationTestsF, BroadcastConstRangeReplacementNeg_too_small) { + { + constexpr auto elem_count = 4; + constexpr auto data_elem_type = element::i32; + constexpr auto target_shape_elem_type = element::i64; + + std::vector sequence_pattern(elem_count); + std::iota(sequence_pattern.begin(), sequence_pattern.end(), 0); + auto data_to_broadcast = ngraph::opset8::Constant::create(data_elem_type, {1, elem_count}, sequence_pattern); + auto target_shape = ngraph::opset8::Constant::create(target_shape_elem_type, {4}, {2, 3, 4, elem_count}); + auto broadcast_node = std::make_shared(data_to_broadcast, target_shape, ngraph::op::BroadcastType::NUMPY); + + function = std::make_shared(OutputVector{broadcast_node}, ParameterVector{}); + + manager.register_pass(); + } +} + +TEST_F(TransformationTestsF, BroadcastConstRangeReplacementNeg_too_big) { + { + constexpr auto elem_count = 1024; + constexpr auto data_elem_type = element::i32; + constexpr auto target_shape_elem_type = element::i64; + + std::vector sequence_pattern(elem_count); + std::iota(sequence_pattern.begin(), sequence_pattern.end(), 0); + auto data_to_broadcast = ngraph::opset8::Constant::create(data_elem_type, {1, elem_count}, sequence_pattern); + auto target_shape = ngraph::opset8::Constant::create(target_shape_elem_type, {4}, {2, 3, 4, elem_count}); + auto broadcast_node = std::make_shared(data_to_broadcast, target_shape, ngraph::op::BroadcastType::NUMPY); + + function = std::make_shared(OutputVector{broadcast_node}, ParameterVector{}); + + manager.register_pass(); + } +} + +// Model reshape call test +TEST(SmartReshapeTests, BroadcastConstRangeReplacement_reshape) { + { + constexpr auto elem_count = 236; + constexpr auto data_elem_type = element::i32; + constexpr auto target_shape_elem_type = element::i64; + + std::vector sequence_pattern(elem_count); + std::iota(sequence_pattern.begin(), sequence_pattern.end(), 0); + + auto data_param = std::make_shared(data_elem_type, Shape{2, 3, 4, elem_count}); + auto target_shape = std::make_shared(data_param); + + auto data_to_broadcast = ngraph::opset8::Constant::create(data_elem_type, {1, elem_count}, sequence_pattern); + auto broadcast_node = std::make_shared(data_to_broadcast, target_shape, ngraph::op::BroadcastType::BIDIRECTIONAL); + + auto function = std::make_shared(OutputVector{broadcast_node}, ParameterVector{data_param}); + + // BroadcastConstRangeReplacement is called as a part of SmartReshape + EXPECT_NO_THROW(function->reshape(PartialShape{1, 189})); + } +}