TransposeSinking: add support for Slice op

This commit is contained in:
Ivan 2023-03-10 19:54:19 +04:00
parent a47a18cf55
commit 54bf0444e4
3 changed files with 244 additions and 0 deletions

View File

@ -0,0 +1,30 @@
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace pass {
class TRANSFORMATIONS_API TransposeSinkingSliceForward;
class TRANSFORMATIONS_API TransposeSinkingSliceBackward;
} // namespace pass
} // namespace ov
class ov::pass::TransposeSinkingSliceForward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingSliceForward", "0");
TransposeSinkingSliceForward();
};
class ov::pass::TransposeSinkingSliceBackward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingSliceBackward", "0");
TransposeSinkingSliceBackward();
};

View File

@ -0,0 +1,113 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/transpose_sinking_slice.hpp"
#include <openvino/pass/pattern/op/or.hpp>
#include "itt.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/util/common_util.hpp"
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
using namespace ov;
using namespace ov::opset10;
using namespace ov::pass::pattern;
using namespace transpose_sinking;
ov::pass::TransposeSinkingSliceForward::TransposeSinkingSliceForward() {
MATCHER_SCOPE(TransposeSinkingSliceForward);
auto const_label = wrap_type<Constant>();
auto transpose_label = wrap_type<Transpose>({any_input(), const_label});
auto main_node_label = wrap_type<Slice>({transpose_label, any_input(), any_input(), any_input(), any_input()});
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_node = m.get_pattern_map();
auto& main_node = pattern_to_node.at(main_node_label);
auto transpose = std::dynamic_pointer_cast<Transpose>(pattern_to_node.at(transpose_label));
if (!transpose) {
return false;
}
auto transpose_const = as_type_ptr<Constant>(pattern_to_node.at(const_label));
if (!transpose_const) {
return false;
}
// remove Transpose on 1st input:
auto transpose_parent = main_node->input_value(0).get_node()->input_value(0);
main_node->input(0).replace_source_output(transpose_parent);
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
auto axis = std::make_shared<Constant>(element::i32, Shape{}, std::vector<int32_t>{0});
auto data = std::make_shared<Constant>(element::i32, Shape{transpose_axis_order.size()}, transpose_axis_order);
const auto& indices = main_node->input_value(4);
auto new_axis = std::make_shared<Gather>(data, indices, axis);
main_node->input(4).replace_source_output(new_axis);
main_node->validate_and_infer_types();
TransposeInputsInfo transpose_input_info = {transpose, transpose_const, 0};
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
register_new_node(new_node);
transpose_sinking::UpdateForwardSinkingAbility(new_node);
}
return true;
};
auto m = std::make_shared<Matcher>(main_node_label, matcher_name);
register_matcher(m, matcher_pass_callback);
}
ov::pass::TransposeSinkingSliceBackward::TransposeSinkingSliceBackward() {
MATCHER_SCOPE(TransposeSinkingSliceBackward);
auto main_node_label = wrap_type<Slice>([](const Output<Node>& output) -> bool {
return has_static_rank()(output) && HasSameOutputTransposeNodes(output);
});
auto transpose_const_label = wrap_type<Constant>();
auto transpose_label =
wrap_type<Transpose>({main_node_label, transpose_const_label}, [](const Output<Node>& output) -> bool {
return has_static_rank()(output) && is_sinking_node(output);
});
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node,
transpose_const,
/* input_indexes= */ {0})) {
register_new_node(new_node);
}
// remove output transposes
RemoveSingleOutputConsumers(main_node);
SwapNames(main_node, transpose);
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
const auto reversed_transpose_order = ReverseTransposeOrder(transpose_axis_order);
auto axis = std::make_shared<Constant>(element::i32, Shape{}, std::vector<int32_t>{0});
auto data =
std::make_shared<Constant>(element::i32, Shape{reversed_transpose_order.size()}, reversed_transpose_order);
const auto& indices = main_node->input_value(4);
auto new_axis = std::make_shared<Gather>(data, indices, axis);
main_node->input(4).replace_source_output(new_axis);
const auto& interpolate = std::dynamic_pointer_cast<Slice>(main_node);
main_node->validate_and_infer_types();
return true;
};
auto m = std::make_shared<Matcher>(transpose_label, matcher_name);
register_matcher(m, matcher_pass_callback);
}

View File

@ -14,6 +14,7 @@
#include "transformations/common_optimizations/transpose_sinking_reduction.hpp"
#include "transformations/common_optimizations/transpose_sinking_split.hpp"
#include "transformations/common_optimizations/transpose_sinking_unary.hpp"
#include "transformations/common_optimizations/transpose_sinking_slice.hpp"
#include "transpose_sinking_test_utils.hpp"
using namespace std;
@ -182,6 +183,18 @@ private:
FactoryPtr CreateInterpolateFactory(const std::string& type_name, bool is_reference) {
return std::make_shared<InterpolateFactory>(type_name, is_reference);
}
class SliceFactory : public IFactory {
public:
explicit SliceFactory(const std::string& type_name) : IFactory(type_name) {}
NodePtr create(const OutputVector& parent_nodes) const override {
return std::make_shared<Slice>(parent_nodes[0], parent_nodes[1], parent_nodes[2], parent_nodes[3], parent_nodes[4]);
}
};
FactoryPtr CreateSliceFactory(const std::string& type_name) {
return std::make_shared<SliceFactory>(type_name);
}
// ----------------------------------------------------------------------------
#undef CREATE_UNARY_FACTORY
@ -213,6 +226,9 @@ FactoryPtr CreateInterpolateFactory(const std::string& type_name, bool is_refere
#undef CREATE_INTERPOLATE_FACTORY
#define CREATE_INTERPOLATE_FACTORY(type_name, reference_flag) CreateInterpolateFactory(#type_name, reference_flag)
#undef CREATE_SLICE_FACTORY
#define CREATE_SLICE_FACTORY(type_name) CreateSliceFactory(#type_name)
// ----------------------------------------------------------------------------
struct Preprocessing {
@ -721,6 +737,49 @@ auto test_forward_unsqueeze = []() {
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnsqueezeForward, TransposeSinkingTestFixture, test_forward_unsqueeze());
auto test_forward_slice = []() {
TestCase test_case;
// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingSliceForward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {6, 4, 5, 3}),
constant<int64_t>(element::i32, {3}, {1, 2, 3}),
constant<int64_t>(element::i32, {3}, {0, 4, 11}),
constant<int64_t>(element::i32, {3}, {1, 2, -1}),
constant<int64_t>(element::i32, {3}, {0, 1, 2}),
};
// Test model description:
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
test_case.model.main_op = {CREATE_SLICE_FACTORY(SliceFactory)};
test_case.model.model_template = transpose_sinking::common::create_model;
// Reference model description:
auto set_specific_gather_for = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
OutputVector result = out_vec;
for (const auto& idx : idxs) {
const auto& out = out_vec[idx];
vector<int64_t> transpose_order(out_vec[0].get_shape().size());
iota(transpose_order.begin(), transpose_order.end(), 0);
reverse(transpose_order.begin(), transpose_order.end());
auto data = make_shared<Constant>(element::i32, Shape{transpose_order.size()}, transpose_order);
auto axis = make_shared<Constant>(element::i32, Shape{}, 0);
auto transpose = make_shared<Gather>(data, out, axis);
result[idx] = transpose;
}
return result;
};
test_case.model_ref.preprocess_inputs_to_main = {{set_specific_gather_for}, {{4}}};
test_case.model_ref.main_op = {CREATE_SLICE_FACTORY(Slice)};
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
test_case.model_ref.model_template = transpose_sinking::common::create_model;
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSliceForward, TransposeSinkingTestFixture, test_forward_slice());
// ------------------ BACKWARD --------------------
auto test_backward_unary = []() {
@ -1064,5 +1123,47 @@ auto test_backward_unsqueeze = []() {
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnsqueezeBackward, TransposeSinkingTestFixture, test_backward_unsqueeze());
auto test_backward_slice = []() {
TestCase test_case;
// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingSliceBackward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {6, 4, 5, 3}),
constant<int64_t>(element::i32, {3}, {1, 2, 3}),
constant<int64_t>(element::i32, {3}, {0, 4, 11}),
constant<int64_t>(element::i32, {3}, {1, 2, -1}),
constant<int64_t>(element::i32, {3}, {0, 1, 2}),
};
// Test model description:
test_case.model.main_op = {CREATE_SLICE_FACTORY(Slice)};
test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
test_case.model.model_template = transpose_sinking::common::create_model;
// Reference model description:
auto set_specific_gather_for = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
OutputVector result = out_vec;
for (const auto& idx : idxs) {
const auto& out = out_vec[idx];
vector<int64_t> transpose_order(out_vec[0].get_shape().size());
iota(transpose_order.begin(), transpose_order.end(), 0);
reverse(transpose_order.begin(), transpose_order.end());
auto data = make_shared<Constant>(element::i32, Shape{transpose_order.size()}, transpose_order);
auto axis = make_shared<Constant>(element::i32, Shape{}, 0);
auto transpose = make_shared<Gather>(data, out, axis);
result[idx] = transpose;
}
return result;
};
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, set_specific_gather_for}, {{0}, {4}}};
test_case.model_ref.main_op = {CREATE_SLICE_FACTORY(SliceFactory)};
test_case.model_ref.model_template = transpose_sinking::common::create_model;
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSliceBackward, TransposeSinkingTestFixture, test_backward_slice());
} // namespace common
} // namespace transpose_sinking