TransposeSinking: add support for Slice op
This commit is contained in:
parent
a47a18cf55
commit
54bf0444e4
@ -0,0 +1,30 @@
|
||||
// Copyright (C) 2022-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API TransposeSinkingSliceForward;
|
||||
class TRANSFORMATIONS_API TransposeSinkingSliceBackward;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
class ov::pass::TransposeSinkingSliceForward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TransposeSinkingSliceForward", "0");
|
||||
TransposeSinkingSliceForward();
|
||||
};
|
||||
|
||||
class ov::pass::TransposeSinkingSliceBackward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TransposeSinkingSliceBackward", "0");
|
||||
TransposeSinkingSliceBackward();
|
||||
};
|
@ -0,0 +1,113 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/common_optimizations/transpose_sinking_slice.hpp"
|
||||
|
||||
#include <openvino/pass/pattern/op/or.hpp>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "openvino/op/util/op_types.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "openvino/util/common_util.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
|
||||
#include "transformations/rt_info/transpose_sinking_attr.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::opset10;
|
||||
using namespace ov::pass::pattern;
|
||||
using namespace transpose_sinking;
|
||||
|
||||
ov::pass::TransposeSinkingSliceForward::TransposeSinkingSliceForward() {
|
||||
MATCHER_SCOPE(TransposeSinkingSliceForward);
|
||||
auto const_label = wrap_type<Constant>();
|
||||
auto transpose_label = wrap_type<Transpose>({any_input(), const_label});
|
||||
auto main_node_label = wrap_type<Slice>({transpose_label, any_input(), any_input(), any_input(), any_input()});
|
||||
|
||||
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||
const auto& pattern_to_node = m.get_pattern_map();
|
||||
|
||||
auto& main_node = pattern_to_node.at(main_node_label);
|
||||
auto transpose = std::dynamic_pointer_cast<Transpose>(pattern_to_node.at(transpose_label));
|
||||
if (!transpose) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto transpose_const = as_type_ptr<Constant>(pattern_to_node.at(const_label));
|
||||
if (!transpose_const) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// remove Transpose on 1st input:
|
||||
auto transpose_parent = main_node->input_value(0).get_node()->input_value(0);
|
||||
main_node->input(0).replace_source_output(transpose_parent);
|
||||
|
||||
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
|
||||
auto axis = std::make_shared<Constant>(element::i32, Shape{}, std::vector<int32_t>{0});
|
||||
|
||||
auto data = std::make_shared<Constant>(element::i32, Shape{transpose_axis_order.size()}, transpose_axis_order);
|
||||
const auto& indices = main_node->input_value(4);
|
||||
auto new_axis = std::make_shared<Gather>(data, indices, axis);
|
||||
|
||||
main_node->input(4).replace_source_output(new_axis);
|
||||
|
||||
main_node->validate_and_infer_types();
|
||||
TransposeInputsInfo transpose_input_info = {transpose, transpose_const, 0};
|
||||
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
|
||||
register_new_node(new_node);
|
||||
transpose_sinking::UpdateForwardSinkingAbility(new_node);
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<Matcher>(main_node_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
||||
|
||||
ov::pass::TransposeSinkingSliceBackward::TransposeSinkingSliceBackward() {
|
||||
MATCHER_SCOPE(TransposeSinkingSliceBackward);
|
||||
|
||||
auto main_node_label = wrap_type<Slice>([](const Output<Node>& output) -> bool {
|
||||
return has_static_rank()(output) && HasSameOutputTransposeNodes(output);
|
||||
});
|
||||
|
||||
auto transpose_const_label = wrap_type<Constant>();
|
||||
|
||||
auto transpose_label =
|
||||
wrap_type<Transpose>({main_node_label, transpose_const_label}, [](const Output<Node>& output) -> bool {
|
||||
return has_static_rank()(output) && is_sinking_node(output);
|
||||
});
|
||||
|
||||
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
|
||||
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
||||
auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();
|
||||
|
||||
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node,
|
||||
transpose_const,
|
||||
/* input_indexes= */ {0})) {
|
||||
register_new_node(new_node);
|
||||
}
|
||||
|
||||
// remove output transposes
|
||||
RemoveSingleOutputConsumers(main_node);
|
||||
SwapNames(main_node, transpose);
|
||||
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
|
||||
const auto reversed_transpose_order = ReverseTransposeOrder(transpose_axis_order);
|
||||
auto axis = std::make_shared<Constant>(element::i32, Shape{}, std::vector<int32_t>{0});
|
||||
auto data =
|
||||
std::make_shared<Constant>(element::i32, Shape{reversed_transpose_order.size()}, reversed_transpose_order);
|
||||
const auto& indices = main_node->input_value(4);
|
||||
auto new_axis = std::make_shared<Gather>(data, indices, axis);
|
||||
main_node->input(4).replace_source_output(new_axis);
|
||||
|
||||
const auto& interpolate = std::dynamic_pointer_cast<Slice>(main_node);
|
||||
main_node->validate_and_infer_types();
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<Matcher>(transpose_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
@ -14,6 +14,7 @@
|
||||
#include "transformations/common_optimizations/transpose_sinking_reduction.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_split.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_unary.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_slice.hpp"
|
||||
#include "transpose_sinking_test_utils.hpp"
|
||||
|
||||
using namespace std;
|
||||
@ -182,6 +183,18 @@ private:
|
||||
FactoryPtr CreateInterpolateFactory(const std::string& type_name, bool is_reference) {
|
||||
return std::make_shared<InterpolateFactory>(type_name, is_reference);
|
||||
}
|
||||
|
||||
class SliceFactory : public IFactory {
|
||||
public:
|
||||
explicit SliceFactory(const std::string& type_name) : IFactory(type_name) {}
|
||||
NodePtr create(const OutputVector& parent_nodes) const override {
|
||||
return std::make_shared<Slice>(parent_nodes[0], parent_nodes[1], parent_nodes[2], parent_nodes[3], parent_nodes[4]);
|
||||
}
|
||||
};
|
||||
|
||||
FactoryPtr CreateSliceFactory(const std::string& type_name) {
|
||||
return std::make_shared<SliceFactory>(type_name);
|
||||
}
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
#undef CREATE_UNARY_FACTORY
|
||||
@ -213,6 +226,9 @@ FactoryPtr CreateInterpolateFactory(const std::string& type_name, bool is_refere
|
||||
|
||||
#undef CREATE_INTERPOLATE_FACTORY
|
||||
#define CREATE_INTERPOLATE_FACTORY(type_name, reference_flag) CreateInterpolateFactory(#type_name, reference_flag)
|
||||
|
||||
#undef CREATE_SLICE_FACTORY
|
||||
#define CREATE_SLICE_FACTORY(type_name) CreateSliceFactory(#type_name)
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
struct Preprocessing {
|
||||
@ -721,6 +737,49 @@ auto test_forward_unsqueeze = []() {
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnsqueezeForward, TransposeSinkingTestFixture, test_forward_unsqueeze());
|
||||
|
||||
auto test_forward_slice = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingSliceForward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {6, 4, 5, 3}),
|
||||
constant<int64_t>(element::i32, {3}, {1, 2, 3}),
|
||||
constant<int64_t>(element::i32, {3}, {0, 4, 11}),
|
||||
constant<int64_t>(element::i32, {3}, {1, 2, -1}),
|
||||
constant<int64_t>(element::i32, {3}, {0, 1, 2}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.main_op = {CREATE_SLICE_FACTORY(SliceFactory)};
|
||||
test_case.model.model_template = transpose_sinking::common::create_model;
|
||||
|
||||
// Reference model description:
|
||||
auto set_specific_gather_for = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||
OutputVector result = out_vec;
|
||||
for (const auto& idx : idxs) {
|
||||
const auto& out = out_vec[idx];
|
||||
vector<int64_t> transpose_order(out_vec[0].get_shape().size());
|
||||
iota(transpose_order.begin(), transpose_order.end(), 0);
|
||||
reverse(transpose_order.begin(), transpose_order.end());
|
||||
auto data = make_shared<Constant>(element::i32, Shape{transpose_order.size()}, transpose_order);
|
||||
auto axis = make_shared<Constant>(element::i32, Shape{}, 0);
|
||||
auto transpose = make_shared<Gather>(data, out, axis);
|
||||
result[idx] = transpose;
|
||||
}
|
||||
return result;
|
||||
};
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{set_specific_gather_for}, {{4}}};
|
||||
test_case.model_ref.main_op = {CREATE_SLICE_FACTORY(Slice)};
|
||||
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model_ref.model_template = transpose_sinking::common::create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSliceForward, TransposeSinkingTestFixture, test_forward_slice());
|
||||
// ------------------ BACKWARD --------------------
|
||||
|
||||
auto test_backward_unary = []() {
|
||||
@ -1064,5 +1123,47 @@ auto test_backward_unsqueeze = []() {
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnsqueezeBackward, TransposeSinkingTestFixture, test_backward_unsqueeze());
|
||||
|
||||
auto test_backward_slice = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingSliceBackward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {6, 4, 5, 3}),
|
||||
constant<int64_t>(element::i32, {3}, {1, 2, 3}),
|
||||
constant<int64_t>(element::i32, {3}, {0, 4, 11}),
|
||||
constant<int64_t>(element::i32, {3}, {1, 2, -1}),
|
||||
constant<int64_t>(element::i32, {3}, {0, 1, 2}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
test_case.model.main_op = {CREATE_SLICE_FACTORY(Slice)};
|
||||
test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.model_template = transpose_sinking::common::create_model;
|
||||
|
||||
// Reference model description:
|
||||
auto set_specific_gather_for = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||
OutputVector result = out_vec;
|
||||
for (const auto& idx : idxs) {
|
||||
const auto& out = out_vec[idx];
|
||||
vector<int64_t> transpose_order(out_vec[0].get_shape().size());
|
||||
iota(transpose_order.begin(), transpose_order.end(), 0);
|
||||
reverse(transpose_order.begin(), transpose_order.end());
|
||||
auto data = make_shared<Constant>(element::i32, Shape{transpose_order.size()}, transpose_order);
|
||||
auto axis = make_shared<Constant>(element::i32, Shape{}, 0);
|
||||
auto transpose = make_shared<Gather>(data, out, axis);
|
||||
result[idx] = transpose;
|
||||
}
|
||||
return result;
|
||||
};
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, set_specific_gather_for}, {{0}, {4}}};
|
||||
test_case.model_ref.main_op = {CREATE_SLICE_FACTORY(SliceFactory)};
|
||||
test_case.model_ref.model_template = transpose_sinking::common::create_model;
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSliceBackward, TransposeSinkingTestFixture, test_backward_slice());
|
||||
} // namespace common
|
||||
} // namespace transpose_sinking
|
Loading…
Reference in New Issue
Block a user