From c734100ec5ffa5c06240037862056232bb8cafce Mon Sep 17 00:00:00 2001 From: Vladislav Golubev Date: Fri, 14 Jan 2022 12:17:03 +0300 Subject: [PATCH] [Transformations] SkipGather transformation (#9627) * [Transformations] SkipGather transformation implemented * review fixes --- ...ip_gather_before_transpose_and_reshape.hpp | 30 +++ .../common_optimizations.cpp | 2 + ...ip_gather_before_transpose_and_reshape.cpp | 90 ++++++++ ...ther_before_transpose_and_reshape_test.cpp | 213 ++++++++++++++++++ 4 files changed, 335 insertions(+) create mode 100644 src/common/transformations/include/transformations/common_optimizations/skip_gather_before_transpose_and_reshape.hpp create mode 100644 src/common/transformations/src/transformations/common_optimizations/skip_gather_before_transpose_and_reshape.cpp create mode 100644 src/tests/functional/inference_engine/transformations/skip_gather_before_transpose_and_reshape_test.cpp diff --git a/src/common/transformations/include/transformations/common_optimizations/skip_gather_before_transpose_and_reshape.hpp b/src/common/transformations/include/transformations/common_optimizations/skip_gather_before_transpose_and_reshape.hpp new file mode 100644 index 00000000000..1259f96def6 --- /dev/null +++ b/src/common/transformations/include/transformations/common_optimizations/skip_gather_before_transpose_and_reshape.hpp @@ -0,0 +1,30 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include + +namespace ngraph { +namespace pass { + +class TRANSFORMATIONS_API SkipGatherBeforeTransposeAndReshape; + +} // namespace pass +} // namespace ngraph + +/** + * @ingroup ie_transformation_common_api + * @brief SkipGatherBeforeTransposeAndReshape transformation removes Gather from the Gather->Transpose->Reshape sequence + * in case when input has batch=1 and gather has axis=0 and indices={0}. + * Also, this transformation adjusts the transpose order constant to preserve the original semantics. 
+ */ +class ngraph::pass::SkipGatherBeforeTransposeAndReshape : public ngraph::pass::MatcherPass { +public: + NGRAPH_RTTI_DECLARATION; + SkipGatherBeforeTransposeAndReshape(); +}; \ No newline at end of file diff --git a/src/common/transformations/src/transformations/common_optimizations/common_optimizations.cpp b/src/common/transformations/src/transformations/common_optimizations/common_optimizations.cpp index 5eb2d1c3f93..7592839baab 100644 --- a/src/common/transformations/src/transformations/common_optimizations/common_optimizations.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/common_optimizations.cpp @@ -45,6 +45,7 @@ #include "transformations/common_optimizations/dilated_convolution_converter.hpp" #include "transformations/common_optimizations/transpose_sinking.hpp" #include "transformations/common_optimizations/split_squeeze_concat_fusion.hpp" +#include "transformations/common_optimizations/skip_gather_before_transpose_and_reshape.hpp" #include "transformations/common_optimizations/transpose_to_reshape.hpp" #include "transformations/common_optimizations/strides_optimization.hpp" #include "transformations/common_optimizations/convert_nms_gather_path_to_unsigned.hpp" @@ -130,6 +131,7 @@ bool ngraph::pass::CommonOptimizations::run_on_model(const std::shared_ptradd_matcher(); common_fusions->add_matcher(); common_fusions->add_matcher(); + common_fusions->add_matcher(); common_fusions->set_name("ngraph::pass::CommonFusions"); manager.register_pass(); diff --git a/src/common/transformations/src/transformations/common_optimizations/skip_gather_before_transpose_and_reshape.cpp b/src/common/transformations/src/transformations/common_optimizations/skip_gather_before_transpose_and_reshape.cpp new file mode 100644 index 00000000000..530e4537fe6 --- /dev/null +++ b/src/common/transformations/src/transformations/common_optimizations/skip_gather_before_transpose_and_reshape.cpp @@ -0,0 +1,90 @@ +// Copyright (C) 2022 Intel Corporation +// 
SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/common_optimizations/skip_gather_before_transpose_and_reshape.hpp" +#include "itt.hpp" + +#include +#include + +#include +#include +#include + +#include "transformations/utils/utils.hpp" + +NGRAPH_RTTI_DEFINITION(ngraph::pass::SkipGatherBeforeTransposeAndReshape, "SkipGatherBeforeTransposeAndReshape", 0); + +ngraph::pass::SkipGatherBeforeTransposeAndReshape::SkipGatherBeforeTransposeAndReshape() { + MATCHER_SCOPE(SkipGatherBeforeTransposeAndReshape); + + auto input_m = ngraph::pattern::any_input(ngraph::pattern::has_static_dim(0)); + + auto indices_m = ngraph::pattern::wrap_type(); + auto axis_m = ngraph::pattern::wrap_type(); + auto gather_m = ngraph::pattern::wrap_type({input_m, indices_m, axis_m}); + + auto transpose_const_m = ngraph::pattern::wrap_type(); + auto transpose_m = ngraph::pattern::wrap_type({gather_m, transpose_const_m}); + + auto reshape_const_m = ngraph::pattern::wrap_type(); + auto reshape_m = ngraph::pattern::wrap_type({transpose_m, reshape_const_m}); + + ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) { + const auto& pattern_map = m.get_pattern_value_map(); + const auto& input = pattern_map.at(input_m); + if (input.get_partial_shape()[0] != 1) { + return false; + } + + const auto gather = pattern_map.at(gather_m).get_node_shared_ptr(); + const auto indices = as_type_ptr(pattern_map.at(indices_m).get_node_shared_ptr()); + const auto axis = as_type_ptr(pattern_map.at(axis_m).get_node_shared_ptr()); + if (!indices || !axis || !indices->get_shape().empty()) { /* non-scalar indices keep the gathered axis, so the rank would not drop and the extended transpose order built below would be invalid */ + return false; + } + + const std::vector expected_gather_value{0}; + if (indices->cast_vector() != expected_gather_value || + axis->cast_vector() != expected_gather_value) { + return false; + } + + const auto transpose = pattern_map.at(transpose_m).get_node_shared_ptr(); + const auto transpose_const = as_type_ptr(pattern_map.at(transpose_const_m).get_node_shared_ptr()); + if (!transpose_const) { + return false; + } + + const auto 
reshape_const = as_type_ptr(pattern_map.at(reshape_const_m).get_node_shared_ptr()); + if (!reshape_const) { + return false; + } + + const auto reshape_vals = reshape_const->cast_vector(); + if (std::any_of(reshape_vals.begin(), reshape_vals.end(), [](const std::int64_t x) { return x == 0; })) { /* NOTE(review): a 0 in the reshape pattern presumably means "copy the input dim at this index", which would refer to a different dim once the batch axis is kept - confirm */ + return false; + } + + const auto transpose_vals = transpose_const->cast_vector(); + std::vector new_transpose_vals{0}; + // Gather (scalar index, axis 0) dropped the batch axis; without Gather the input keeps it, so axis 0 stays first and every original permutation entry is shifted by one + for (auto elem : transpose_vals) { + new_transpose_vals.push_back(++elem); + } + + const auto new_transpose_const = ngraph::opset8::Constant::create(transpose_const->get_element_type(), + {new_transpose_vals.size()}, + new_transpose_vals); + const auto new_transpose = transpose->clone_with_new_inputs({input, new_transpose_const}); + new_transpose->set_friendly_name(transpose->get_friendly_name()); + ngraph::copy_runtime_info({transpose, gather}, new_transpose); + ngraph::replace_node(transpose, new_transpose); + + return false; /* NOTE(review): false is returned although the graph was modified; presumably intentional because the matched root (Reshape) itself is not replaced - confirm */ + }; + + auto m = std::make_shared(reshape_m, matcher_name); + register_matcher(m, callback); +} diff --git a/src/tests/functional/inference_engine/transformations/skip_gather_before_transpose_and_reshape_test.cpp b/src/tests/functional/inference_engine/transformations/skip_gather_before_transpose_and_reshape_test.cpp new file mode 100644 index 00000000000..11612e26c77 --- /dev/null +++ b/src/tests/functional/inference_engine/transformations/skip_gather_before_transpose_and_reshape_test.cpp @@ -0,0 +1,213 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include + +#include + +#include +#include +#include +#include +#include + +#include "common_test_utils/ngraph_test_utils.hpp" + +using namespace testing; +using namespace ov; + +TEST_F(TransformationTestsF, SkipGatherBeforeTransposeAndReshapeStaticShapeFpData) { + PartialShape data_shape{1, 3, 12, 12}; + { + auto data = std::make_shared(element::f32, 
data_shape); + + auto indices_node = opset8::Constant::create(element::i64, {}, {0}); + auto axis_node = opset8::Constant::create(element::i64, {}, {0}); + auto gather = std::make_shared(data, indices_node, axis_node); + + auto transpose_const = opset8::Constant::create(element::i64, {3}, {1, 2, 0}); + auto transpose = std::make_shared(gather, transpose_const); + + auto reshape_const = opset8::Constant::create(element::i64, {1}, {-1}); + auto reshape = std::make_shared(transpose, reshape_const, true); + + function = std::make_shared(NodeVector{reshape}, ParameterVector{data}); + manager.register_pass(); + } + { + auto data = std::make_shared(element::f32, data_shape); + + auto transpose_const = opset8::Constant::create(element::i64, {4}, {0, 2, 3, 1}); /* expected order: 0 prepended and each entry of {1, 2, 0} incremented by one */ + auto transpose = std::make_shared(data, transpose_const); + + auto reshape_const = opset8::Constant::create(element::i64, {1}, {-1}); + auto reshape = std::make_shared(transpose, reshape_const, true); + + function_ref = std::make_shared(NodeVector{reshape}, ParameterVector{data}); + } +} + +TEST_F(TransformationTestsF, SkipGatherBeforeTransposeAndReshapeStaticShapeIntData) { + PartialShape data_shape{1, 3, 12, 12}; + { + auto data = std::make_shared(element::i64, data_shape); + + auto indices_node = opset8::Constant::create(element::i64, {}, {0}); + auto axis_node = opset8::Constant::create(element::i64, {}, {0}); + auto gather = std::make_shared(data, indices_node, axis_node); + + auto transpose_const = opset8::Constant::create(element::i64, {3}, {1, 2, 0}); + auto transpose = std::make_shared(gather, transpose_const); + + auto reshape_const = opset8::Constant::create(element::i64, {1}, {-1}); + auto reshape = std::make_shared(transpose, reshape_const, true); + + function = std::make_shared(NodeVector{reshape}, ParameterVector{data}); + manager.register_pass(); + } + { + auto data = std::make_shared(element::i64, data_shape); + + auto transpose_const = opset8::Constant::create(element::i64, {4}, {0, 2, 3, 1}); + 
auto transpose = std::make_shared(data, transpose_const); + + auto reshape_const = opset8::Constant::create(element::i64, {1}, {-1}); + auto reshape = std::make_shared(transpose, reshape_const, true); + + function_ref = std::make_shared(NodeVector{reshape}, ParameterVector{data}); + } +} + +TEST_F(TransformationTestsF, SkipGatherBeforeTransposeAndReshapeDynamicShapeStaticBatch) { + PartialShape data_shape{1, -1, -1, -1}; /* only the batch dim is static; the pass checks dim 0 only */ + { + auto data = std::make_shared(element::f32, data_shape); + + auto indices_node = opset8::Constant::create(element::i64, {}, {0}); + auto axis_node = opset8::Constant::create(element::i64, {}, {0}); + auto gather = std::make_shared(data, indices_node, axis_node); + + auto transpose_const = opset8::Constant::create(element::i64, {3}, {1, 2, 0}); + auto transpose = std::make_shared(gather, transpose_const); + + auto reshape_const = opset8::Constant::create(element::i64, {1}, {-1}); + auto reshape = std::make_shared(transpose, reshape_const, true); + + function = std::make_shared(NodeVector{reshape}, ParameterVector{data}); + manager.register_pass(); + } + { + auto data = std::make_shared(element::f32, data_shape); + + auto transpose_const = opset8::Constant::create(element::i64, {4}, {0, 2, 3, 1}); + auto transpose = std::make_shared(data, transpose_const); + + auto reshape_const = opset8::Constant::create(element::i64, {1}, {-1}); + auto reshape = std::make_shared(transpose, reshape_const, true); + + function_ref = std::make_shared(NodeVector{reshape}, ParameterVector{data}); + } +} + +TEST_F(TransformationTestsF, SkipGatherBeforeTransposeAndReshapeIncorrectGatherAxis) { + PartialShape data_shape{1, 3, 12, 12}; + { + auto data = std::make_shared(element::f32, data_shape); + + auto indices_node = opset8::Constant::create(element::i64, {}, {0}); + auto axis_node = opset8::Constant::create(element::i64, {}, {2}); /* axis is 2; the pass callback bails out unless axis is 0 */ + auto gather = std::make_shared(data, indices_node, axis_node); + + auto transpose_const = opset8::Constant::create(element::i64, {3}, 
{1, 2, 0}); + auto transpose = std::make_shared(gather, transpose_const); + + auto reshape_const = opset8::Constant::create(element::i64, {1}, {-1}); + auto reshape = std::make_shared(transpose, reshape_const, true); + + function = std::make_shared(NodeVector{reshape}, ParameterVector{data}); + manager.register_pass(); /* NOTE(review): no function_ref is built in this and the following tests; presumably the TransformationTestsF fixture then expects the function to stay unchanged - confirm */ + } +} + +TEST_F(TransformationTestsF, SkipGatherBeforeTransposeAndReshapeDynamicBatch) { + PartialShape data_shape{-1, -1, -1, -1}; /* batch dim is dynamic; the pass pattern requires has_static_dim(0) on the input */ + { + auto data = std::make_shared(element::f32, data_shape); + + auto indices_node = opset8::Constant::create(element::i64, {}, {0}); + auto axis_node = opset8::Constant::create(element::i64, {}, {0}); + auto gather = std::make_shared(data, indices_node, axis_node); + + auto transpose_const = opset8::Constant::create(element::i64, {3}, {1, 2, 0}); + auto transpose = std::make_shared(gather, transpose_const); + + auto reshape_const = opset8::Constant::create(element::i64, {1}, {-1}); + auto reshape = std::make_shared(transpose, reshape_const, true); + + function = std::make_shared(NodeVector{reshape}, ParameterVector{data}); + manager.register_pass(); + } +} + +TEST_F(TransformationTestsF, SkipGatherBeforeTransposeAndReshapeDynamicRank) { + PartialShape data_shape = PartialShape::dynamic(); /* fully dynamic shape; the pass pattern requires has_static_dim(0) on the input */ + { + auto data = std::make_shared(element::f32, data_shape); + + auto indices_node = opset8::Constant::create(element::i64, {}, {0}); + auto axis_node = opset8::Constant::create(element::i64, {}, {0}); + auto gather = std::make_shared(data, indices_node, axis_node); + + auto transpose_const = opset8::Constant::create(element::i64, {3}, {1, 2, 0}); + auto transpose = std::make_shared(gather, transpose_const); + + auto reshape_const = opset8::Constant::create(element::i64, {1}, {-1}); + auto reshape = std::make_shared(transpose, reshape_const, true); + + function = std::make_shared(NodeVector{reshape}, ParameterVector{data}); + manager.register_pass(); + } +} + +TEST_F(TransformationTestsF, 
SkipGatherBeforeTransposeAndReshapeBatchNotEqualTo1) { + PartialShape data_shape{2, 3, 12, 12}; /* batch is 2; the pass callback bails out unless dim 0 equals 1 */ + { + auto data = std::make_shared(element::f32, data_shape); + + auto indices_node = opset8::Constant::create(element::i64, {}, {0}); + auto axis_node = opset8::Constant::create(element::i64, {}, {0}); + auto gather = std::make_shared(data, indices_node, axis_node); + + auto transpose_const = opset8::Constant::create(element::i64, {3}, {1, 2, 0}); + auto transpose = std::make_shared(gather, transpose_const); + + auto reshape_const = opset8::Constant::create(element::i64, {1}, {-1}); + auto reshape = std::make_shared(transpose, reshape_const, true); + + function = std::make_shared(NodeVector{reshape}, ParameterVector{data}); + manager.register_pass(); + } +} + +TEST_F(TransformationTestsF, SkipGatherBeforeTransposeAndReshapeUnsuitableReshapePattern) { + PartialShape data_shape{1, -1, -1, -1}; + { + auto data = std::make_shared(element::f32, data_shape); + + auto indices_node = opset8::Constant::create(element::i64, {}, {0}); + auto axis_node = opset8::Constant::create(element::i64, {}, {0}); + auto gather = std::make_shared(data, indices_node, axis_node); + + auto transpose_const = opset8::Constant::create(element::i64, {3}, {1, 2, 0}); + auto transpose = std::make_shared(gather, transpose_const); + + auto reshape_const = opset8::Constant::create(element::i64, {2}, {0, -1}); /* pattern contains 0; the pass callback bails out on any 0 in the reshape constant */ + auto reshape = std::make_shared(transpose, reshape_const, true); + + function = std::make_shared(NodeVector{reshape}, ParameterVector{data}); + manager.register_pass(); + } +} \ No newline at end of file