Add limited support for StridedSlice op

This commit is contained in:
Ivan 2023-03-06 18:23:40 +04:00
parent 3565ff2181
commit d71949fd09
2 changed files with 194 additions and 0 deletions

View File

@ -0,0 +1,30 @@
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace pass {
class TRANSFORMATIONS_API TransposeSinkingStridedSliceForward;
class TRANSFORMATIONS_API TransposeSinkingStridedSliceBackward;
} // namespace pass
} // namespace ov
class ov::pass::TransposeSinkingStridedSliceForward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingStridedSliceForward", "0");
TransposeSinkingStridedSliceForward();
};
class ov::pass::TransposeSinkingStridedSliceBackward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingStridedSliceBackward", "0");
TransposeSinkingStridedSliceBackward();
};

View File

@ -0,0 +1,164 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/transpose_sinking_strided_slice.hpp"
#include <openvino/pass/pattern/op/or.hpp>
#include "itt.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/util/common_util.hpp"
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
using namespace ov;
using namespace ov::opset10;
using namespace ov::pass::pattern;
using namespace transpose_sinking;
ov::pass::TransposeSinkingStridedSliceForward::TransposeSinkingStridedSliceForward() {
MATCHER_SCOPE(TransposeSinkingStridedSliceForward);
auto const_label = wrap_type<Constant>();
auto transpose_label = wrap_type<Transpose>({any_input(), const_label});
auto main_node_label =
wrap_type<StridedSlice>({transpose_label, any_input(), any_input(), any_input()});
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_node = m.get_pattern_map();
auto& main_node = pattern_to_node.at(main_node_label);
auto transpose = std::dynamic_pointer_cast<Transpose>(pattern_to_node.at(transpose_label));
if (!transpose) {
return false;
}
auto transpose_const = as_type_ptr<Constant>(pattern_to_node.at(const_label));
if (!transpose_const) {
return false;
}
const auto& strided_slice = std::dynamic_pointer_cast<StridedSlice>(main_node);
if (!strided_slice) {
return false;
}
auto elipsis_mask = strided_slice->get_ellipsis_mask();
auto new_axis_mask = strided_slice->get_new_axis_mask();
auto shrink_mask = strided_slice->get_shrink_axis_mask();
if (!elipsis_mask.empty() || !new_axis_mask.empty() || !shrink_mask.empty()) {
// not supported yet
return false;
}
// remove Transpose on 1st input:
auto transpose_parent = main_node->input_value(0).get_node()->input_value(0);
main_node->input(0).replace_source_output(transpose_parent);
// change the order of values for PadBegin and PadEng inputs
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
const auto reversed_transpose_order = ReverseTransposeOrder(transpose_axis_order);
auto axis = std::make_shared<Constant>(element::i32, Shape{}, std::vector<int32_t>{0});
main_node->input(1).replace_source_output(
ChangeValuesOrder(main_node->input_value(1), reversed_transpose_order, axis));
main_node->input(2).replace_source_output(
ChangeValuesOrder(main_node->input_value(2), reversed_transpose_order, axis));
main_node->input(3).replace_source_output(
ChangeValuesOrder(main_node->input_value(3), reversed_transpose_order, axis));
const auto& begin_mask = strided_slice->get_begin_mask();
const auto& end_mask = strided_slice->get_end_mask();
const auto& order_size = transpose_axis_order.size();
std::vector<int64_t> new_begin_mask(order_size), new_end_mask(order_size);
for (size_t i = 0; i < order_size; ++i) {
new_begin_mask[i] = begin_mask[transpose_axis_order[i]];
new_end_mask[i] = end_mask[transpose_axis_order[i]];
}
strided_slice->set_begin_mask(new_begin_mask);
strided_slice->set_begin_mask(new_end_mask);
main_node->validate_and_infer_types();
TransposeInputsInfo transpose_input_info = {transpose, transpose_const, 0};
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
register_new_node(new_node);
transpose_sinking::UpdateForwardSinkingAbility(new_node);
}
return true;
};
auto m = std::make_shared<Matcher>(main_node_label, matcher_name);
register_matcher(m, matcher_pass_callback);
}
ov::pass::TransposeSinkingStridedSliceBackward::TransposeSinkingStridedSliceBackward() {
MATCHER_SCOPE(TransposeSinkingDataMovementBackward);
auto main_node_label = wrap_type<StridedSlice>([](const Output<Node>& output) -> bool {
return has_static_rank()(output) && HasSameOutputTransposeNodes(output);
});
auto transpose_const_label = wrap_type<Constant>();
auto transpose_label =
wrap_type<Transpose>({main_node_label, transpose_const_label}, [](const Output<Node>& output) -> bool {
return has_static_rank()(output) && is_sinking_node(output);
});
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();
const auto& strided_slice = std::dynamic_pointer_cast<StridedSlice>(main_node);
if (!strided_slice) {
return false;
}
auto elipsis_mask = strided_slice->get_ellipsis_mask();
auto new_axis_mask = strided_slice->get_new_axis_mask();
auto shrink_mask = strided_slice->get_shrink_axis_mask();
if (!elipsis_mask.empty() || !new_axis_mask.empty() || !shrink_mask.empty()) {
// not supported yet
return false;
}
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node,
transpose_const,
/* input_indexes= */ {0})) {
register_new_node(new_node);
}
// remove output transposes
RemoveSingleOutputConsumers(main_node);
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
auto axis = std::make_shared<Constant>(element::i32, Shape{}, std::vector<int32_t>{0});
main_node->input(1).replace_source_output(
ChangeValuesOrder(main_node->input_value(1), transpose_axis_order, axis));
main_node->input(2).replace_source_output(
ChangeValuesOrder(main_node->input_value(2), transpose_axis_order, axis));
main_node->input(3).replace_source_output(
ChangeValuesOrder(main_node->input_value(3), transpose_axis_order, axis));
const auto reversed_transpose_order = ReverseTransposeOrder(transpose_axis_order);
const auto& begin_mask = strided_slice->get_begin_mask();
const auto& end_mask = strided_slice->get_end_mask();
const auto& order_size = reversed_transpose_order.size();
std::vector<int64_t> new_begin_mask(order_size), new_end_mask(order_size);
for (size_t i = 0; i < order_size; ++i) {
new_begin_mask[i] = begin_mask[reversed_transpose_order[i]];
new_end_mask[i] = end_mask[reversed_transpose_order[i]];
}
strided_slice->set_begin_mask(new_begin_mask);
strided_slice->set_begin_mask(new_end_mask);
main_node->validate_and_infer_types();
return true;
};
auto m = std::make_shared<Matcher>(transpose_label, matcher_name);
register_matcher(m, matcher_pass_callback);
}