[Transformations] SkipGather transformation (#9627)

* [Transformations] SkipGather transformation implemented

* review fixes
This commit is contained in:
Vladislav Golubev 2022-01-14 12:17:03 +03:00 committed by GitHub
parent 29d57ad0a8
commit c734100ec5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 335 additions and 0 deletions

View File

@ -0,0 +1,30 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include <ngraph/pass/graph_rewrite.hpp>
#include <transformations_visibility.hpp>
#include <vector>
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API SkipGatherBeforeTransposeAndReshape;
} // namespace pass
} // namespace ngraph
/**
* @ingroup ie_transformation_common_api
* @brief SkipGatherBeforeTransposeAndReshape transformation removes Gather from the Gather->Transpose->Reshape sequence
* in case when input has batch=1 and gather has axis=0 and indices={0}.
* Also, this transformation corrects a transpose constant to save semantic.
*/
class ngraph::pass::SkipGatherBeforeTransposeAndReshape : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
SkipGatherBeforeTransposeAndReshape();
};

View File

@ -45,6 +45,7 @@
#include "transformations/common_optimizations/dilated_convolution_converter.hpp"
#include "transformations/common_optimizations/transpose_sinking.hpp"
#include "transformations/common_optimizations/split_squeeze_concat_fusion.hpp"
#include "transformations/common_optimizations/skip_gather_before_transpose_and_reshape.hpp"
#include "transformations/common_optimizations/transpose_to_reshape.hpp"
#include "transformations/common_optimizations/strides_optimization.hpp"
#include "transformations/common_optimizations/convert_nms_gather_path_to_unsigned.hpp"
@ -130,6 +131,7 @@ bool ngraph::pass::CommonOptimizations::run_on_model(const std::shared_ptr<ngrap
common_fusions->add_matcher<ngraph::pass::SpaceToBatchFusion>();
common_fusions->add_matcher<ngraph::pass::BatchToSpaceFusion>();
common_fusions->add_matcher<ngraph::pass::InterpolateSequenceFusion>();
common_fusions->add_matcher<ngraph::pass::SkipGatherBeforeTransposeAndReshape>();
common_fusions->set_name("ngraph::pass::CommonFusions");
manager.register_pass<ngraph::pass::ConcatReduceFusion>();

View File

@ -0,0 +1,90 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/skip_gather_before_transpose_and_reshape.hpp"
#include "itt.hpp"
#include <memory>
#include <vector>
#include <openvino/opsets/opset8.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>
#include "transformations/utils/utils.hpp"
NGRAPH_RTTI_DEFINITION(ngraph::pass::SkipGatherBeforeTransposeAndReshape, "SkipGatherBeforeTransposeAndReshape", 0);
ngraph::pass::SkipGatherBeforeTransposeAndReshape::SkipGatherBeforeTransposeAndReshape() {
MATCHER_SCOPE(SkipGatherBeforeTransposeAndReshape);
auto input_m = ngraph::pattern::any_input(ngraph::pattern::has_static_dim(0));
auto indices_m = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
auto axis_m = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
auto gather_m = ngraph::pattern::wrap_type<ngraph::op::util::GatherBase>({input_m, indices_m, axis_m});
auto transpose_const_m = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
auto transpose_m = ngraph::pattern::wrap_type<ngraph::opset8::Transpose>({gather_m, transpose_const_m});
auto reshape_const_m = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
auto reshape_m = ngraph::pattern::wrap_type<ngraph::opset8::Reshape>({transpose_m, reshape_const_m});
ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
const auto& input = pattern_map.at(input_m);
if (input.get_partial_shape()[0] != 1) {
return false;
}
const auto gather = pattern_map.at(gather_m).get_node_shared_ptr();
const auto indices = as_type_ptr<ngraph::opset8::Constant>(pattern_map.at(indices_m).get_node_shared_ptr());
const auto axis = as_type_ptr<ngraph::opset8::Constant>(pattern_map.at(axis_m).get_node_shared_ptr());
if (!indices || !axis) {
return false;
}
const std::vector<std::int64_t> expected_gather_value{0};
if (indices->cast_vector<std::int64_t>() != expected_gather_value ||
axis->cast_vector<std::int64_t>() != expected_gather_value) {
return false;
}
const auto transpose = pattern_map.at(transpose_m).get_node_shared_ptr();
const auto transpose_const = as_type_ptr<ngraph::opset8::Constant>(pattern_map.at(transpose_const_m).get_node_shared_ptr());
if (!transpose_const) {
return false;
}
const auto reshape_const = as_type_ptr<ngraph::opset8::Constant>(pattern_map.at(reshape_const_m).get_node_shared_ptr());
if (!reshape_const) {
return false;
}
const auto reshape_vals = reshape_const->cast_vector<std::int64_t>();
if (std::any_of(reshape_vals.begin(), reshape_vals.end(), [](const std::int64_t x) { return x == 0; })) {
return false;
}
const auto transpose_vals = transpose_const->cast_vector<std::int64_t>();
std::vector<std::int64_t> new_transpose_vals{0};
// update the transpose const to compensate for the removal of Gather
for (auto elem : transpose_vals) {
new_transpose_vals.push_back(++elem);
}
const auto new_transpose_const = ngraph::opset8::Constant::create(transpose_const->get_element_type(),
{new_transpose_vals.size()},
new_transpose_vals);
const auto new_transpose = transpose->clone_with_new_inputs({input, new_transpose_const});
new_transpose->set_friendly_name(transpose->get_friendly_name());
ngraph::copy_runtime_info({transpose, gather}, new_transpose);
ngraph::replace_node(transpose, new_transpose);
return false;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(reshape_m, matcher_name);
register_matcher(m, callback);
}

View File

@ -0,0 +1,213 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <memory>
#include <transformations/common_optimizations/skip_gather_before_transpose_and_reshape.hpp>
#include <ngraph/function.hpp>
#include <openvino/opsets/opset8.hpp>
#include <transformations/init_node_info.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/utils/utils.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
using namespace ov;
TEST_F(TransformationTestsF, SkipGatherBeforeTransposeAndReshapeStaticShapeFpData) {
PartialShape data_shape{1, 3, 12, 12};
{
auto data = std::make_shared<opset8::Parameter>(element::f32, data_shape);
auto indices_node = opset8::Constant::create(element::i64, {}, {0});
auto axis_node = opset8::Constant::create(element::i64, {}, {0});
auto gather = std::make_shared<opset8::Gather>(data, indices_node, axis_node);
auto transpose_const = opset8::Constant::create(element::i64, {3}, {1, 2, 0});
auto transpose = std::make_shared<opset8::Transpose>(gather, transpose_const);
auto reshape_const = opset8::Constant::create(element::i64, {1}, {-1});
auto reshape = std::make_shared<opset8::Reshape>(transpose, reshape_const, true);
function = std::make_shared<Model>(NodeVector{reshape}, ParameterVector{data});
manager.register_pass<ngraph::pass::SkipGatherBeforeTransposeAndReshape>();
}
{
auto data = std::make_shared<opset8::Parameter>(element::f32, data_shape);
auto transpose_const = opset8::Constant::create(element::i64, {4}, {0, 2, 3, 1});
auto transpose = std::make_shared<opset8::Transpose>(data, transpose_const);
auto reshape_const = opset8::Constant::create(element::i64, {1}, {-1});
auto reshape = std::make_shared<opset8::Reshape>(transpose, reshape_const, true);
function_ref = std::make_shared<Model>(NodeVector{reshape}, ParameterVector{data});
}
}
TEST_F(TransformationTestsF, SkipGatherBeforeTransposeAndReshapeStaticShapeIntData) {
PartialShape data_shape{1, 3, 12, 12};
{
auto data = std::make_shared<opset8::Parameter>(element::i64, data_shape);
auto indices_node = opset8::Constant::create(element::i64, {}, {0});
auto axis_node = opset8::Constant::create(element::i64, {}, {0});
auto gather = std::make_shared<opset8::Gather>(data, indices_node, axis_node);
auto transpose_const = opset8::Constant::create(element::i64, {3}, {1, 2, 0});
auto transpose = std::make_shared<opset8::Transpose>(gather, transpose_const);
auto reshape_const = opset8::Constant::create(element::i64, {1}, {-1});
auto reshape = std::make_shared<opset8::Reshape>(transpose, reshape_const, true);
function = std::make_shared<Model>(NodeVector{reshape}, ParameterVector{data});
manager.register_pass<ngraph::pass::SkipGatherBeforeTransposeAndReshape>();
}
{
auto data = std::make_shared<opset8::Parameter>(element::i64, data_shape);
auto transpose_const = opset8::Constant::create(element::i64, {4}, {0, 2, 3, 1});
auto transpose = std::make_shared<opset8::Transpose>(data, transpose_const);
auto reshape_const = opset8::Constant::create(element::i64, {1}, {-1});
auto reshape = std::make_shared<opset8::Reshape>(transpose, reshape_const, true);
function_ref = std::make_shared<Model>(NodeVector{reshape}, ParameterVector{data});
}
}
TEST_F(TransformationTestsF, SkipGatherBeforeTransposeAndReshapeDynamicShapeStaticBatch) {
PartialShape data_shape{1, -1, -1, -1};
{
auto data = std::make_shared<opset8::Parameter>(element::f32, data_shape);
auto indices_node = opset8::Constant::create(element::i64, {}, {0});
auto axis_node = opset8::Constant::create(element::i64, {}, {0});
auto gather = std::make_shared<opset8::Gather>(data, indices_node, axis_node);
auto transpose_const = opset8::Constant::create(element::i64, {3}, {1, 2, 0});
auto transpose = std::make_shared<opset8::Transpose>(gather, transpose_const);
auto reshape_const = opset8::Constant::create(element::i64, {1}, {-1});
auto reshape = std::make_shared<opset8::Reshape>(transpose, reshape_const, true);
function = std::make_shared<Model>(NodeVector{reshape}, ParameterVector{data});
manager.register_pass<ngraph::pass::SkipGatherBeforeTransposeAndReshape>();
}
{
auto data = std::make_shared<opset8::Parameter>(element::f32, data_shape);
auto transpose_const = opset8::Constant::create(element::i64, {4}, {0, 2, 3, 1});
auto transpose = std::make_shared<opset8::Transpose>(data, transpose_const);
auto reshape_const = opset8::Constant::create(element::i64, {1}, {-1});
auto reshape = std::make_shared<opset8::Reshape>(transpose, reshape_const, true);
function_ref = std::make_shared<Model>(NodeVector{reshape}, ParameterVector{data});
}
}
TEST_F(TransformationTestsF, SkipGatherBeforeTransposeAndReshapeIncorrectGatherAxis) {
PartialShape data_shape{1, 3, 12, 12};
{
auto data = std::make_shared<opset8::Parameter>(element::f32, data_shape);
auto indices_node = opset8::Constant::create(element::i64, {}, {0});
auto axis_node = opset8::Constant::create(element::i64, {}, {2});
auto gather = std::make_shared<opset8::Gather>(data, indices_node, axis_node);
auto transpose_const = opset8::Constant::create(element::i64, {3}, {1, 2, 0});
auto transpose = std::make_shared<opset8::Transpose>(gather, transpose_const);
auto reshape_const = opset8::Constant::create(element::i64, {1}, {-1});
auto reshape = std::make_shared<opset8::Reshape>(transpose, reshape_const, true);
function = std::make_shared<Model>(NodeVector{reshape}, ParameterVector{data});
manager.register_pass<ngraph::pass::SkipGatherBeforeTransposeAndReshape>();
}
}
TEST_F(TransformationTestsF, SkipGatherBeforeTransposeAndReshapeDynamicBatch) {
PartialShape data_shape{-1, -1, -1, -1};
{
auto data = std::make_shared<opset8::Parameter>(element::f32, data_shape);
auto indices_node = opset8::Constant::create(element::i64, {}, {0});
auto axis_node = opset8::Constant::create(element::i64, {}, {0});
auto gather = std::make_shared<opset8::Gather>(data, indices_node, axis_node);
auto transpose_const = opset8::Constant::create(element::i64, {3}, {1, 2, 0});
auto transpose = std::make_shared<opset8::Transpose>(gather, transpose_const);
auto reshape_const = opset8::Constant::create(element::i64, {1}, {-1});
auto reshape = std::make_shared<opset8::Reshape>(transpose, reshape_const, true);
function = std::make_shared<Model>(NodeVector{reshape}, ParameterVector{data});
manager.register_pass<ngraph::pass::SkipGatherBeforeTransposeAndReshape>();
}
}
TEST_F(TransformationTestsF, SkipGatherBeforeTransposeAndReshapeDynamicRank) {
PartialShape data_shape = PartialShape::dynamic();
{
auto data = std::make_shared<opset8::Parameter>(element::f32, data_shape);
auto indices_node = opset8::Constant::create(element::i64, {}, {0});
auto axis_node = opset8::Constant::create(element::i64, {}, {0});
auto gather = std::make_shared<opset8::Gather>(data, indices_node, axis_node);
auto transpose_const = opset8::Constant::create(element::i64, {3}, {1, 2, 0});
auto transpose = std::make_shared<opset8::Transpose>(gather, transpose_const);
auto reshape_const = opset8::Constant::create(element::i64, {1}, {-1});
auto reshape = std::make_shared<opset8::Reshape>(transpose, reshape_const, true);
function = std::make_shared<Model>(NodeVector{reshape}, ParameterVector{data});
manager.register_pass<ngraph::pass::SkipGatherBeforeTransposeAndReshape>();
}
}
TEST_F(TransformationTestsF, SkipGatherBeforeTransposeAndReshapeBatchNotEqualTo1) {
PartialShape data_shape{2, 3, 12, 12};
{
auto data = std::make_shared<opset8::Parameter>(element::f32, data_shape);
auto indices_node = opset8::Constant::create(element::i64, {}, {0});
auto axis_node = opset8::Constant::create(element::i64, {}, {0});
auto gather = std::make_shared<opset8::Gather>(data, indices_node, axis_node);
auto transpose_const = opset8::Constant::create(element::i64, {3}, {1, 2, 0});
auto transpose = std::make_shared<opset8::Transpose>(gather, transpose_const);
auto reshape_const = opset8::Constant::create(element::i64, {1}, {-1});
auto reshape = std::make_shared<opset8::Reshape>(transpose, reshape_const, true);
function = std::make_shared<Model>(NodeVector{reshape}, ParameterVector{data});
manager.register_pass<ngraph::pass::SkipGatherBeforeTransposeAndReshape>();
}
}
TEST_F(TransformationTestsF, SkipGatherBeforeTransposeAndReshapeUnsuitableReshapePattern) {
PartialShape data_shape{1, -1, -1, -1};
{
auto data = std::make_shared<opset8::Parameter>(element::f32, data_shape);
auto indices_node = opset8::Constant::create(element::i64, {}, {0});
auto axis_node = opset8::Constant::create(element::i64, {}, {0});
auto gather = std::make_shared<opset8::Gather>(data, indices_node, axis_node);
auto transpose_const = opset8::Constant::create(element::i64, {3}, {1, 2, 0});
auto transpose = std::make_shared<opset8::Transpose>(gather, transpose_const);
auto reshape_const = opset8::Constant::create(element::i64, {2}, {0, -1});
auto reshape = std::make_shared<opset8::Reshape>(transpose, reshape_const, true);
function = std::make_shared<Model>(NodeVector{reshape}, ParameterVector{data});
manager.register_pass<ngraph::pass::SkipGatherBeforeTransposeAndReshape>();
}
}