[Transformations] SkipGather transformation (#9627)
* [Transformations] SkipGather transformation implemented * review fixes
This commit is contained in:
parent
29d57ad0a8
commit
c734100ec5
@ -0,0 +1,30 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <ngraph/pass/graph_rewrite.hpp>
|
||||
#include <transformations_visibility.hpp>
|
||||
#include <vector>
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API SkipGatherBeforeTransposeAndReshape;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief SkipGatherBeforeTransposeAndReshape transformation removes Gather from the Gather->Transpose->Reshape sequence
|
||||
* in case when input has batch=1 and gather has axis=0 and indices={0}.
|
||||
* Also, this transformation corrects a transpose constant to save semantic.
|
||||
*/
|
||||
class ngraph::pass::SkipGatherBeforeTransposeAndReshape : public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
SkipGatherBeforeTransposeAndReshape();
|
||||
};
|
@ -45,6 +45,7 @@
|
||||
#include "transformations/common_optimizations/dilated_convolution_converter.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking.hpp"
|
||||
#include "transformations/common_optimizations/split_squeeze_concat_fusion.hpp"
|
||||
#include "transformations/common_optimizations/skip_gather_before_transpose_and_reshape.hpp"
|
||||
#include "transformations/common_optimizations/transpose_to_reshape.hpp"
|
||||
#include "transformations/common_optimizations/strides_optimization.hpp"
|
||||
#include "transformations/common_optimizations/convert_nms_gather_path_to_unsigned.hpp"
|
||||
@ -130,6 +131,7 @@ bool ngraph::pass::CommonOptimizations::run_on_model(const std::shared_ptr<ngrap
|
||||
common_fusions->add_matcher<ngraph::pass::SpaceToBatchFusion>();
|
||||
common_fusions->add_matcher<ngraph::pass::BatchToSpaceFusion>();
|
||||
common_fusions->add_matcher<ngraph::pass::InterpolateSequenceFusion>();
|
||||
common_fusions->add_matcher<ngraph::pass::SkipGatherBeforeTransposeAndReshape>();
|
||||
common_fusions->set_name("ngraph::pass::CommonFusions");
|
||||
|
||||
manager.register_pass<ngraph::pass::ConcatReduceFusion>();
|
||||
|
@ -0,0 +1,90 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/common_optimizations/skip_gather_before_transpose_and_reshape.hpp"
#include "itt.hpp"

#include <algorithm>
#include <cstdint>
#include <memory>
#include <vector>

#include <openvino/opsets/opset8.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>

#include "transformations/utils/utils.hpp"
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::SkipGatherBeforeTransposeAndReshape, "SkipGatherBeforeTransposeAndReshape", 0);
|
||||
|
||||
ngraph::pass::SkipGatherBeforeTransposeAndReshape::SkipGatherBeforeTransposeAndReshape() {
|
||||
MATCHER_SCOPE(SkipGatherBeforeTransposeAndReshape);
|
||||
|
||||
auto input_m = ngraph::pattern::any_input(ngraph::pattern::has_static_dim(0));
|
||||
|
||||
auto indices_m = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
|
||||
auto axis_m = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
|
||||
auto gather_m = ngraph::pattern::wrap_type<ngraph::op::util::GatherBase>({input_m, indices_m, axis_m});
|
||||
|
||||
auto transpose_const_m = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
|
||||
auto transpose_m = ngraph::pattern::wrap_type<ngraph::opset8::Transpose>({gather_m, transpose_const_m});
|
||||
|
||||
auto reshape_const_m = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
|
||||
auto reshape_m = ngraph::pattern::wrap_type<ngraph::opset8::Reshape>({transpose_m, reshape_const_m});
|
||||
|
||||
ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||
const auto& pattern_map = m.get_pattern_value_map();
|
||||
const auto& input = pattern_map.at(input_m);
|
||||
if (input.get_partial_shape()[0] != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto gather = pattern_map.at(gather_m).get_node_shared_ptr();
|
||||
const auto indices = as_type_ptr<ngraph::opset8::Constant>(pattern_map.at(indices_m).get_node_shared_ptr());
|
||||
const auto axis = as_type_ptr<ngraph::opset8::Constant>(pattern_map.at(axis_m).get_node_shared_ptr());
|
||||
if (!indices || !axis) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const std::vector<std::int64_t> expected_gather_value{0};
|
||||
if (indices->cast_vector<std::int64_t>() != expected_gather_value ||
|
||||
axis->cast_vector<std::int64_t>() != expected_gather_value) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto transpose = pattern_map.at(transpose_m).get_node_shared_ptr();
|
||||
const auto transpose_const = as_type_ptr<ngraph::opset8::Constant>(pattern_map.at(transpose_const_m).get_node_shared_ptr());
|
||||
if (!transpose_const) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto reshape_const = as_type_ptr<ngraph::opset8::Constant>(pattern_map.at(reshape_const_m).get_node_shared_ptr());
|
||||
if (!reshape_const) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto reshape_vals = reshape_const->cast_vector<std::int64_t>();
|
||||
if (std::any_of(reshape_vals.begin(), reshape_vals.end(), [](const std::int64_t x) { return x == 0; })) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto transpose_vals = transpose_const->cast_vector<std::int64_t>();
|
||||
std::vector<std::int64_t> new_transpose_vals{0};
|
||||
// update the transpose const to compensate for the removal of Gather
|
||||
for (auto elem : transpose_vals) {
|
||||
new_transpose_vals.push_back(++elem);
|
||||
}
|
||||
|
||||
const auto new_transpose_const = ngraph::opset8::Constant::create(transpose_const->get_element_type(),
|
||||
{new_transpose_vals.size()},
|
||||
new_transpose_vals);
|
||||
const auto new_transpose = transpose->clone_with_new_inputs({input, new_transpose_const});
|
||||
new_transpose->set_friendly_name(transpose->get_friendly_name());
|
||||
ngraph::copy_runtime_info({transpose, gather}, new_transpose);
|
||||
ngraph::replace_node(transpose, new_transpose);
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(reshape_m, matcher_name);
|
||||
register_matcher(m, callback);
|
||||
}
|
@ -0,0 +1,213 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include <transformations/common_optimizations/skip_gather_before_transpose_and_reshape.hpp>
|
||||
|
||||
#include <ngraph/function.hpp>
|
||||
#include <openvino/opsets/opset8.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
using namespace testing;
|
||||
using namespace ov;
|
||||
|
||||
// Fully static f32 input with batch 1: the reference model has no Gather and
// a 4-element Transpose order {0, 2, 3, 1}.
TEST_F(TransformationTestsF, SkipGatherBeforeTransposeAndReshapeStaticShapeFpData) {
    const PartialShape input_shape{1, 3, 12, 12};

    // Model under test: Parameter -> Gather -> Transpose -> Reshape.
    {
        const auto param = std::make_shared<opset8::Parameter>(element::f32, input_shape);

        const auto gather_indices = opset8::Constant::create(element::i64, {}, {0});
        const auto gather_axis = opset8::Constant::create(element::i64, {}, {0});
        const auto gather_op = std::make_shared<opset8::Gather>(param, gather_indices, gather_axis);

        const auto order = opset8::Constant::create(element::i64, {3}, {1, 2, 0});
        const auto transpose_op = std::make_shared<opset8::Transpose>(gather_op, order);

        const auto target_shape = opset8::Constant::create(element::i64, {1}, {-1});
        const auto reshape_op = std::make_shared<opset8::Reshape>(transpose_op, target_shape, true);

        function = std::make_shared<Model>(NodeVector{reshape_op}, ParameterVector{param});
        manager.register_pass<ngraph::pass::SkipGatherBeforeTransposeAndReshape>();
    }

    // Reference model: Parameter -> Transpose -> Reshape.
    {
        const auto param = std::make_shared<opset8::Parameter>(element::f32, input_shape);

        const auto order = opset8::Constant::create(element::i64, {4}, {0, 2, 3, 1});
        const auto transpose_op = std::make_shared<opset8::Transpose>(param, order);

        const auto target_shape = opset8::Constant::create(element::i64, {1}, {-1});
        const auto reshape_op = std::make_shared<opset8::Reshape>(transpose_op, target_shape, true);

        function_ref = std::make_shared<Model>(NodeVector{reshape_op}, ParameterVector{param});
    }
}
|
||||
|
||||
// Same graph as the static-shape f32 case, but the data element type is i64.
TEST_F(TransformationTestsF, SkipGatherBeforeTransposeAndReshapeStaticShapeIntData) {
    const PartialShape input_shape{1, 3, 12, 12};

    // Model under test: Parameter -> Gather -> Transpose -> Reshape.
    {
        const auto param = std::make_shared<opset8::Parameter>(element::i64, input_shape);

        const auto gather_indices = opset8::Constant::create(element::i64, {}, {0});
        const auto gather_axis = opset8::Constant::create(element::i64, {}, {0});
        const auto gather_op = std::make_shared<opset8::Gather>(param, gather_indices, gather_axis);

        const auto order = opset8::Constant::create(element::i64, {3}, {1, 2, 0});
        const auto transpose_op = std::make_shared<opset8::Transpose>(gather_op, order);

        const auto target_shape = opset8::Constant::create(element::i64, {1}, {-1});
        const auto reshape_op = std::make_shared<opset8::Reshape>(transpose_op, target_shape, true);

        function = std::make_shared<Model>(NodeVector{reshape_op}, ParameterVector{param});
        manager.register_pass<ngraph::pass::SkipGatherBeforeTransposeAndReshape>();
    }

    // Reference model: Parameter -> Transpose -> Reshape.
    {
        const auto param = std::make_shared<opset8::Parameter>(element::i64, input_shape);

        const auto order = opset8::Constant::create(element::i64, {4}, {0, 2, 3, 1});
        const auto transpose_op = std::make_shared<opset8::Transpose>(param, order);

        const auto target_shape = opset8::Constant::create(element::i64, {1}, {-1});
        const auto reshape_op = std::make_shared<opset8::Reshape>(transpose_op, target_shape, true);

        function_ref = std::make_shared<Model>(NodeVector{reshape_op}, ParameterVector{param});
    }
}
|
||||
|
||||
// Batch is static and equal to 1 while the other three dimensions are dynamic;
// the reference model still has Gather removed.
TEST_F(TransformationTestsF, SkipGatherBeforeTransposeAndReshapeDynamicShapeStaticBatch) {
    const PartialShape input_shape{1, -1, -1, -1};

    // Model under test: Parameter -> Gather -> Transpose -> Reshape.
    {
        const auto param = std::make_shared<opset8::Parameter>(element::f32, input_shape);

        const auto gather_indices = opset8::Constant::create(element::i64, {}, {0});
        const auto gather_axis = opset8::Constant::create(element::i64, {}, {0});
        const auto gather_op = std::make_shared<opset8::Gather>(param, gather_indices, gather_axis);

        const auto order = opset8::Constant::create(element::i64, {3}, {1, 2, 0});
        const auto transpose_op = std::make_shared<opset8::Transpose>(gather_op, order);

        const auto target_shape = opset8::Constant::create(element::i64, {1}, {-1});
        const auto reshape_op = std::make_shared<opset8::Reshape>(transpose_op, target_shape, true);

        function = std::make_shared<Model>(NodeVector{reshape_op}, ParameterVector{param});
        manager.register_pass<ngraph::pass::SkipGatherBeforeTransposeAndReshape>();
    }

    // Reference model: Parameter -> Transpose -> Reshape.
    {
        const auto param = std::make_shared<opset8::Parameter>(element::f32, input_shape);

        const auto order = opset8::Constant::create(element::i64, {4}, {0, 2, 3, 1});
        const auto transpose_op = std::make_shared<opset8::Transpose>(param, order);

        const auto target_shape = opset8::Constant::create(element::i64, {1}, {-1});
        const auto reshape_op = std::make_shared<opset8::Reshape>(transpose_op, target_shape, true);

        function_ref = std::make_shared<Model>(NodeVector{reshape_op}, ParameterVector{param});
    }
}
|
||||
|
||||
// Gather axis is 2 instead of 0. No reference model is assigned here;
// presumably the fixture then compares against the untouched original model -
// verify against TransformationTestsF.
TEST_F(TransformationTestsF, SkipGatherBeforeTransposeAndReshapeIncorrectGatherAxis) {
    const PartialShape input_shape{1, 3, 12, 12};

    // Model under test: Parameter -> Gather(axis=2) -> Transpose -> Reshape.
    {
        const auto param = std::make_shared<opset8::Parameter>(element::f32, input_shape);

        const auto gather_indices = opset8::Constant::create(element::i64, {}, {0});
        const auto gather_axis = opset8::Constant::create(element::i64, {}, {2});
        const auto gather_op = std::make_shared<opset8::Gather>(param, gather_indices, gather_axis);

        const auto order = opset8::Constant::create(element::i64, {3}, {1, 2, 0});
        const auto transpose_op = std::make_shared<opset8::Transpose>(gather_op, order);

        const auto target_shape = opset8::Constant::create(element::i64, {1}, {-1});
        const auto reshape_op = std::make_shared<opset8::Reshape>(transpose_op, target_shape, true);

        function = std::make_shared<Model>(NodeVector{reshape_op}, ParameterVector{param});
        manager.register_pass<ngraph::pass::SkipGatherBeforeTransposeAndReshape>();
    }
}
|
||||
|
||||
// All four input dimensions, including batch, are dynamic. No reference model
// is assigned here; presumably the fixture then compares against the untouched
// original model - verify against TransformationTestsF.
TEST_F(TransformationTestsF, SkipGatherBeforeTransposeAndReshapeDynamicBatch) {
    const PartialShape input_shape{-1, -1, -1, -1};

    // Model under test: Parameter -> Gather -> Transpose -> Reshape.
    {
        const auto param = std::make_shared<opset8::Parameter>(element::f32, input_shape);

        const auto gather_indices = opset8::Constant::create(element::i64, {}, {0});
        const auto gather_axis = opset8::Constant::create(element::i64, {}, {0});
        const auto gather_op = std::make_shared<opset8::Gather>(param, gather_indices, gather_axis);

        const auto order = opset8::Constant::create(element::i64, {3}, {1, 2, 0});
        const auto transpose_op = std::make_shared<opset8::Transpose>(gather_op, order);

        const auto target_shape = opset8::Constant::create(element::i64, {1}, {-1});
        const auto reshape_op = std::make_shared<opset8::Reshape>(transpose_op, target_shape, true);

        function = std::make_shared<Model>(NodeVector{reshape_op}, ParameterVector{param});
        manager.register_pass<ngraph::pass::SkipGatherBeforeTransposeAndReshape>();
    }
}
|
||||
|
||||
// The input shape has dynamic rank. No reference model is assigned here;
// presumably the fixture then compares against the untouched original model -
// verify against TransformationTestsF.
TEST_F(TransformationTestsF, SkipGatherBeforeTransposeAndReshapeDynamicRank) {
    const PartialShape input_shape = PartialShape::dynamic();

    // Model under test: Parameter -> Gather -> Transpose -> Reshape.
    {
        const auto param = std::make_shared<opset8::Parameter>(element::f32, input_shape);

        const auto gather_indices = opset8::Constant::create(element::i64, {}, {0});
        const auto gather_axis = opset8::Constant::create(element::i64, {}, {0});
        const auto gather_op = std::make_shared<opset8::Gather>(param, gather_indices, gather_axis);

        const auto order = opset8::Constant::create(element::i64, {3}, {1, 2, 0});
        const auto transpose_op = std::make_shared<opset8::Transpose>(gather_op, order);

        const auto target_shape = opset8::Constant::create(element::i64, {1}, {-1});
        const auto reshape_op = std::make_shared<opset8::Reshape>(transpose_op, target_shape, true);

        function = std::make_shared<Model>(NodeVector{reshape_op}, ParameterVector{param});
        manager.register_pass<ngraph::pass::SkipGatherBeforeTransposeAndReshape>();
    }
}
|
||||
|
||||
// Static input whose batch dimension is 2 rather than 1. No reference model is
// assigned here; presumably the fixture then compares against the untouched
// original model - verify against TransformationTestsF.
TEST_F(TransformationTestsF, SkipGatherBeforeTransposeAndReshapeBatchNotEqualTo1) {
    const PartialShape input_shape{2, 3, 12, 12};

    // Model under test: Parameter -> Gather -> Transpose -> Reshape.
    {
        const auto param = std::make_shared<opset8::Parameter>(element::f32, input_shape);

        const auto gather_indices = opset8::Constant::create(element::i64, {}, {0});
        const auto gather_axis = opset8::Constant::create(element::i64, {}, {0});
        const auto gather_op = std::make_shared<opset8::Gather>(param, gather_indices, gather_axis);

        const auto order = opset8::Constant::create(element::i64, {3}, {1, 2, 0});
        const auto transpose_op = std::make_shared<opset8::Transpose>(gather_op, order);

        const auto target_shape = opset8::Constant::create(element::i64, {1}, {-1});
        const auto reshape_op = std::make_shared<opset8::Reshape>(transpose_op, target_shape, true);

        function = std::make_shared<Model>(NodeVector{reshape_op}, ParameterVector{param});
        manager.register_pass<ngraph::pass::SkipGatherBeforeTransposeAndReshape>();
    }
}
|
||||
|
||||
// The Reshape pattern {0, -1} contains a zero entry. No reference model is
// assigned here; presumably the fixture then compares against the untouched
// original model - verify against TransformationTestsF.
TEST_F(TransformationTestsF, SkipGatherBeforeTransposeAndReshapeUnsuitableReshapePattern) {
    const PartialShape input_shape{1, -1, -1, -1};

    // Model under test: Parameter -> Gather -> Transpose -> Reshape({0, -1}).
    {
        const auto param = std::make_shared<opset8::Parameter>(element::f32, input_shape);

        const auto gather_indices = opset8::Constant::create(element::i64, {}, {0});
        const auto gather_axis = opset8::Constant::create(element::i64, {}, {0});
        const auto gather_op = std::make_shared<opset8::Gather>(param, gather_indices, gather_axis);

        const auto order = opset8::Constant::create(element::i64, {3}, {1, 2, 0});
        const auto transpose_op = std::make_shared<opset8::Transpose>(gather_op, order);

        const auto target_shape = opset8::Constant::create(element::i64, {2}, {0, -1});
        const auto reshape_op = std::make_shared<opset8::Reshape>(transpose_op, target_shape, true);

        function = std::make_shared<Model>(NodeVector{reshape_op}, ParameterVector{param});
        manager.register_pass<ngraph::pass::SkipGatherBeforeTransposeAndReshape>();
    }
}
|
Loading…
Reference in New Issue
Block a user