TransposeSinking: add support for Slice and Reshape ops (#16208)

* Resolve the performance issues in TransposeSinking transformation

* codestyle

* fix warning as error, fix tests failures

* fix ts for Concat and Reduce

* Fix TransposeReduceBackward

* fix the issue in TransposeFuse transformation

* fix TransposeReduce transformations

* Fix TransposeReduction, fix TransposeSinkingSplit, add unsqueeze support

* delete debug print

* Add additional validations

* fix node validation

* Fix validate for split, revert changes for concat, add BatchToSpace/SpaceToBatch

* Add SpaceToBatch/BatchToSpace

* fix TS for Interpolate + codestyle

* fix gna build

* Support TS for Interpolate, VariadicSplit, IsInf, IsNan, IsFinite + refactoring

* add the missed line

* add include

* TransposeSinking tests refactoring: part1

* TransposeSinking tests refactoring: part2

* Add limited support for StridedSlice op

* codestyle

* TransposeReduction: skip the case when 2nd input for Squeeze is not provided

* Transpose sinking tests refactoring: part 3. + Revert changes in MOC.

* fix build

* codestyle

* Add tests for TS backward transformations, update TransposeSinkingFuse transformation, delete StridedSlice transformation prototype + tests refactoring

* fix unary tests

* Fix warning as error on Windows

* Add new tests for Unsqueeze/Squeeze; refactoring; remove debug code

* TransposeSinking: add support for Slice op

* Add descriptions to the transformations, add additional checks

* fix a warning

* TransposeSinking Refactoring part2: move the transformations to a separate folder, align namespaces

* TransposeSinking refactoring: class names, namespaces

* codestyle

* resolve merge conflicts

* codestyle

* TSReduction refactoring, move Unsqueeze/Squeeze transformations to separate files, added limited support for Reshape op + tests

* fix minor mistakes

* fix warnings

* Added TSSlice transformation to TSGeneral, created TransposeSinkingGeneral alias in ov::pass namespace

* refactoring

* codestyle

* fix TSSqueeze/TSUnsqueeze transformations

* delete debug serialize

* remove TransposeSinking from MOC

* fix TSSqueeze/TSUnsqueeze transformations in case of Reshape op

* delete debug code

* codestyle

* fix unit tests, revert changes for TSSlice transformation

* fix TSSqueeze transformation

* resolve review comments

* codestyle
This commit is contained in:
Ivan Tikhonov 2023-03-24 17:01:15 +04:00 committed by GitHub
parent c5b348dd4f
commit 5a8a195dad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 1169 additions and 172 deletions

View File

@ -16,6 +16,9 @@ class TRANSFORMATIONS_API TSGeneralBackward;
class TRANSFORMATIONS_API TSGeneral;
} // namespace transpose_sinking
using TransposeSinkingGeneral = ov::pass::transpose_sinking::TSGeneral;
} // namespace pass
} // namespace ov

View File

@ -21,7 +21,7 @@ class TRANSFORMATIONS_API TSReductionBackward;
/**
* @ingroup ie_transformation_common_api
* @brief TransposeReductionForward transformation sinks Transpose through Reduce, Squeeze, Unsqueeze operations
* @brief TransposeReductionForward transformation sinks Transpose through Reduce operations
* in the forward direction.
*/
class ov::pass::transpose_sinking::TSReductionForward : public ov::pass::MatcherPass {
@ -32,7 +32,7 @@ public:
/**
* @ingroup ie_transformation_common_api
* @brief TransposeReductionBackward transformation sinks Transpose through Reduce, Squeeze, Unsqueeze operations
* @brief TransposeReductionBackward transformation sinks Transpose through Reduce operations
* in the backward direction.
*/
class ov::pass::transpose_sinking::TSReductionBackward : public ov::pass::MatcherPass {

View File

@ -0,0 +1,32 @@
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace pass {
namespace transpose_sinking {
class TRANSFORMATIONS_API TSSliceForward;
class TRANSFORMATIONS_API TSSliceBackward;
} // namespace transpose_sinking
} // namespace pass
} // namespace ov
/**
 * @ingroup ie_transformation_common_api
 * @brief TSSliceForward transformation sinks Transpose through Slice
 * in the forward direction.
 */
class ov::pass::transpose_sinking::TSSliceForward : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("ov::pass::TSSliceForward", "0");
    TSSliceForward();
};

/**
 * @ingroup ie_transformation_common_api
 * @brief TSSliceBackward transformation sinks Transpose through Slice
 * in the backward direction.
 */
class ov::pass::transpose_sinking::TSSliceBackward : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("ov::pass::TSSliceBackward", "0");
    TSSliceBackward();
};

View File

@ -0,0 +1,42 @@
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace pass {
namespace transpose_sinking {
class TRANSFORMATIONS_API TSSqueezeForward;
class TRANSFORMATIONS_API TSSqueezeBackward;
} // namespace transpose_sinking
} // namespace pass
} // namespace ov
/**
 * @ingroup ie_transformation_common_api
 * @brief TSSqueezeForward transformation sinks Transpose through Reshape, Squeeze operations
 * in the forward direction.
 * Reshape is supported only when it is equivalent to a Squeeze (it removes `1` dims only).
 */
class ov::pass::transpose_sinking::TSSqueezeForward : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("ov::pass::TSSqueezeForward", "0");
    TSSqueezeForward();
};

/**
 * @ingroup ie_transformation_common_api
 * @brief TSSqueezeBackward transformation sinks Transpose through Reshape, Squeeze operations
 * in the backward direction.
 * Reshape is supported only when it is equivalent to a Squeeze (it removes `1` dims only).
 */
class ov::pass::transpose_sinking::TSSqueezeBackward : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("ov::pass::TSSqueezeBackward", "0");
    TSSqueezeBackward();
};

View File

@ -0,0 +1,42 @@
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace pass {
namespace transpose_sinking {
class TRANSFORMATIONS_API TSUnsqueezeForward;
class TRANSFORMATIONS_API TSUnsqueezeBackward;
} // namespace transpose_sinking
} // namespace pass
} // namespace ov
/**
 * @ingroup ie_transformation_common_api
 * @brief TSUnsqueezeForward transformation sinks Transpose through Unsqueeze, Reshape operations
 * in the forward direction.
 * NOTE(review): Reshape support is presumably limited to reshapes equivalent to an
 * Unsqueeze (inserting `1` dims only) — confirm against ts_unsqueeze.cpp.
 */
class ov::pass::transpose_sinking::TSUnsqueezeForward : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("ov::pass::TSUnsqueezeForward", "0");
    TSUnsqueezeForward();
};

/**
 * @ingroup ie_transformation_common_api
 * @brief TSUnsqueezeBackward transformation sinks Transpose through Unsqueeze, Reshape operations
 * in the backward direction.
 * NOTE(review): Reshape support is presumably limited to reshapes equivalent to an
 * Unsqueeze (inserting `1` dims only) — confirm against ts_unsqueeze.cpp.
 */
class ov::pass::transpose_sinking::TSUnsqueezeBackward : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("ov::pass::TSUnsqueezeBackward", "0");
    TSUnsqueezeBackward();
};

View File

@ -98,17 +98,43 @@ void UpdateForwardSinkingAbility(const std::shared_ptr<ov::Node>&);
bool HasSameOutputTransposeNodes(const ov::Output<ov::Node>&);
/**
* Removes all direct node consumers that have one output
* @brief Removes all direct node consumers that have one output
*/
void RemoveSingleOutputConsumers(const std::shared_ptr<ov::Node>&);
/**
* Changes the order of values in @arg input according to @arg transpose_axis_order along @arg axis
* @brief Changes the order of values in @arg input according to @arg transpose_axis_order along @arg axis
*/
ov::Output<ov::Node> ChangeValuesOrder(const ov::Output<ov::Node>& input,
const ov::AxisVector& transpose_axis_order,
const std::shared_ptr<ov::opset10::Constant>& axis);
/**
* @brief Returns the updated axes order for case when the initial axes order has more elements
* than after TransposeSinking, e.g.:
*
* before: Transpose(the initial axes order) -> ReduceMax
* after : ReduceMax -> Transpose (the updated axes order)
*
* before: Unsqueeze -> Transpose (the initial axes order)
* after : Transpose (the updated axes order) -> Unsqueeze
*/
std::vector<size_t> GetOrderAfterReduction(const std::vector<size_t>& axes_values,
const std::vector<size_t>& order_values);
/**
* @brief Returns the updated axes order for case when the initial axes order has less elements
* than after TransposeSinking, e.g.:
*
* before : ReduceMax -> Transpose (the updated axes order)
* after: Transpose(the initial axes order) -> ReduceMax
*
* before: Transpose (the updated axes order) -> Unsqueeze
* after : Unsqueeze -> Transpose (the initial axes order)
*/
std::vector<size_t> GetOrderBeforeReduction(const std::vector<size_t>& axes_values,
const std::vector<size_t>& order_values);
} // namespace utils
} // namespace transpose_sinking
} // namespace pass

View File

@ -236,7 +236,6 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ngraph::Fu
REGISTER_PASS(manager, ReverseInputChannelsFusion)
REGISTER_PASS(manager, AlignEltwiseInputRanks)
REGISTER_PASS(manager, ConstantFolding)
manager.run_passes(f);
if (!m_use_shapes) {

View File

@ -15,8 +15,11 @@
#include "transformations/transpose_sinking/ts_fuse.hpp"
#include "transformations/transpose_sinking/ts_interpolate.hpp"
#include "transformations/transpose_sinking/ts_reduction.hpp"
#include "transformations/transpose_sinking/ts_slice.hpp"
#include "transformations/transpose_sinking/ts_split.hpp"
#include "transformations/transpose_sinking/ts_squeeze.hpp"
#include "transformations/transpose_sinking/ts_unary.hpp"
#include "transformations/transpose_sinking/ts_unsqueeze.hpp"
#include "transformations/utils/utils.hpp"
using namespace ov::pass::transpose_sinking;
@ -29,7 +32,10 @@ TSGeneralForward::TSGeneralForward() {
add_matcher<TSSplitForward>();
add_matcher<TSDataMovementForward>();
add_matcher<TSReductionForward>();
add_matcher<TSSqueezeForward>();
add_matcher<TSUnsqueezeForward>();
add_matcher<TSInterpolateForward>();
add_matcher<TSSliceForward>();
add_matcher<TSFuse>();
}
@ -41,7 +47,10 @@ TSGeneralBackward::TSGeneralBackward() {
add_matcher<TSSplitBackward>();
add_matcher<TSDataMovementBackward>();
add_matcher<TSReductionBackward>();
add_matcher<TSSqueezeBackward>();
add_matcher<TSUnsqueezeBackward>();
add_matcher<TSInterpolateBackward>();
add_matcher<TSSliceBackward>();
add_matcher<TSFuse>();
}

View File

@ -18,60 +18,15 @@
using namespace ov;
using namespace opset10;
using namespace ov::pass::pattern;
using namespace ov::pass::transpose_sinking;
using namespace ov::pass::transpose_sinking::utils;
namespace {
// Computes the transpose order that remains after the axes listed in
// `axes_values` are reduced away (forward sinking): removed positions are
// dropped and the surviving order values are renumbered to stay dense.
std::vector<size_t> get_updated_order_forward(const std::vector<size_t>& axes_values,
                                              const std::vector<size_t>& order_values) {
    // Source-axis numbers that disappear, sorted for binary search below.
    std::vector<size_t> removed_src_axes;
    removed_src_axes.reserve(axes_values.size());
    for (const auto axis : axes_values) {
        removed_src_axes.push_back(order_values[axis]);
    }
    std::sort(removed_src_axes.begin(), removed_src_axes.end());

    std::vector<size_t> result(order_values.size() - axes_values.size(), 0);
    size_t out_idx = 0;
    for (size_t pos = 0; pos < order_values.size(); ++pos) {
        // Positions being reduced do not appear in the resulting order.
        if (std::find(axes_values.begin(), axes_values.end(), pos) != axes_values.end()) {
            continue;
        }
        // Shift each surviving value down by the number of removed axes below it.
        const auto shift = static_cast<size_t>(
            std::lower_bound(removed_src_axes.begin(), removed_src_axes.end(), order_values[pos]) -
            removed_src_axes.begin());
        result[out_idx++] = order_values[pos] - shift;
    }
    return result;
}
// Computes the transpose order obtained when the axes listed in `axes_values`
// are (re-)inserted into an order that was previously shrunk (backward
// sinking): inserted axes map to themselves, and the old order values are
// renumbered upward to make room for them.
std::vector<size_t> get_updated_order_backward(const std::vector<size_t>& axes_values,
                                               const std::vector<size_t>& order_values) {
    const size_t result_size = order_values.size() + axes_values.size();

    // new_to_compact[i]: index axis i had in the compacted (pre-insertion)
    // numbering; the entry is meaningless for inserted axes.
    std::vector<int64_t> new_to_compact(result_size);
    int64_t inserted_so_far = 0;
    for (int64_t i = 0; i < static_cast<int64_t>(result_size); ++i) {
        if (std::find(axes_values.begin(), axes_values.end(), i) != axes_values.end()) {
            ++inserted_so_far;
        }
        new_to_compact[i] = i - inserted_so_far;
    }

    std::vector<size_t> result(result_size);
    size_t order_idx = 0;
    for (size_t i = 0; i < result_size; ++i) {
        if (std::find(axes_values.begin(), axes_values.end(), i) != axes_values.end()) {
            result[i] = i;  // an inserted axis keeps its own position
            continue;
        }
        // Translate the compacted order value back to the expanded numbering.
        const auto it = std::find(new_to_compact.begin(),
                                  new_to_compact.end(),
                                  static_cast<int64_t>(order_values[order_idx]));
        result[i] = static_cast<size_t>(it - new_to_compact.begin());
        ++order_idx;
    }
    return result;
}
bool get_keep_dims(const std::shared_ptr<Node>& reduction) {
auto arithmetic_reduce = std::dynamic_pointer_cast<ov::op::util::ArithmeticReductionKeepDims>(reduction);
auto logical_reduce = std::dynamic_pointer_cast<ov::op::util::LogicalReductionKeepDims>(reduction);
auto arithmetic_reduce = as_type_ptr<ov::op::util::ArithmeticReductionKeepDims>(reduction);
auto logical_reduce = as_type_ptr<ov::op::util::LogicalReductionKeepDims>(reduction);
bool keep_dims = false; // squeeze/unsqueeze always reduces number of output dimensions
if (logical_reduce)
@ -80,32 +35,29 @@ bool get_keep_dims(const std::shared_ptr<Node>& reduction) {
keep_dims = arithmetic_reduce->get_keep_dims();
return keep_dims;
}
} // namespace
TSReductionForward::TSReductionForward() {
MATCHER_SCOPE(TSReductionForward);
auto transpose_label = pattern::wrap_type<Transpose>({pattern::any_input(), pattern::wrap_type<Constant>()},
pattern::consumers_count(1));
auto reduce_or_squeeze_label = pattern::
wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims, Squeeze, Unsqueeze>(
{transpose_label, pattern::wrap_type<Constant>()});
auto transpose_label = wrap_type<Transpose>({any_input(), wrap_type<Constant>()});
auto reduce_label = wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims>(
{transpose_label, wrap_type<Constant>()});
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
const auto& pattern_to_output = m.get_pattern_map();
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
auto reduction = pattern_to_output.at(reduce_or_squeeze_label).get_node_shared_ptr();
auto transpose = pattern_to_output.at(transpose_label);
auto reduction = pattern_to_output.at(reduce_label);
auto keep_dims = get_keep_dims(reduction);
auto transpose_order = std::dynamic_pointer_cast<Constant>(transpose->get_input_node_shared_ptr(1));
auto reduction_axes = std::dynamic_pointer_cast<Constant>(reduction->get_input_node_shared_ptr(1));
auto transpose_order = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
auto reduction_axes = as_type_ptr<Constant>(reduction->get_input_node_shared_ptr(1));
if (!transpose_order || !reduction_axes)
return false;
auto unsqueeze = std::dynamic_pointer_cast<Unsqueeze>(reduction);
auto rank =
unsqueeze ? reduction->get_output_partial_shape(0).rank() : reduction->get_input_partial_shape(0).rank();
auto rank = reduction->get_input_partial_shape(0).rank();
auto non_negative_axes =
normalize_axes(reduction->get_friendly_name(), reduction_axes->cast_vector<int64_t>(), rank);
@ -117,38 +69,17 @@ TSReductionForward::TSReductionForward() {
}
if (!keep_dims) {
if (non_negative_axes.empty()) {
auto input_pshape = transpose->input_value(0).get_partial_shape();
if (input_pshape.is_static()) {
for (size_t i = 0; i < input_pshape.size(); ++i) {
if (input_pshape[i] == 1) {
non_negative_axes.push_back(i);
}
}
} else {
return false;
}
}
if (unsqueeze) {
transpose_order_values = get_updated_order_backward(non_negative_axes, transpose_order_values);
} else {
transpose_order_values = get_updated_order_forward(non_negative_axes, transpose_order_values);
}
transpose_order_values = GetOrderAfterReduction(non_negative_axes, transpose_order_values);
}
auto new_transpose_order = std::make_shared<Constant>(transpose_order->get_element_type(),
Shape{transpose_order_values.size()},
transpose_order_values);
std::shared_ptr<Node> new_reduction;
if (!unsqueeze) {
auto new_const =
std::make_shared<Constant>(reduction_axes->get_element_type(), reduction_axes->get_shape(), new_values);
new_reduction = reduction->clone_with_new_inputs({transpose->input_value(0), new_const});
} else {
new_reduction = reduction->clone_with_new_inputs({transpose->input_value(0), reduction->input_value(1)});
}
auto new_transpose_order = Constant::create(transpose_order->get_element_type(),
{transpose_order_values.size()},
transpose_order_values);
auto new_const = Constant::create(reduction_axes->get_element_type(), reduction_axes->get_shape(), new_values);
auto new_reduction = reduction->clone_with_new_inputs({transpose->input_value(0), new_const});
auto new_transpose = transpose->clone_with_new_inputs({new_reduction, new_transpose_order});
replace_node(reduction, new_transpose);
new_reduction->set_friendly_name(transpose->get_friendly_name());
new_transpose->set_friendly_name(reduction->get_friendly_name());
@ -158,98 +89,55 @@ TSReductionForward::TSReductionForward() {
return true;
};
auto m = std::make_shared<pattern::Matcher>(reduce_or_squeeze_label, matcher_name);
auto m = std::make_shared<pattern::Matcher>(reduce_label, matcher_name);
register_matcher(m, matcher_pass_callback);
}
TSReductionBackward::TSReductionBackward() {
MATCHER_SCOPE(TSReductionBackward);
auto reduce_or_squeeze_label = pattern::
wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims, Squeeze, Unsqueeze>(
{pattern::any_input(), pattern::wrap_type<Constant>()},
HasSameOutputTransposeNodes);
auto transpose_label = pattern::wrap_type<Transpose>({reduce_or_squeeze_label, pattern::wrap_type<Constant>()});
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
auto reduce_label = wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims>(
{any_input(), wrap_type<Constant>()},
HasSameOutputTransposeNodes);
auto transpose_label = wrap_type<Transpose>({reduce_label, wrap_type<Constant>()});
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
auto reduction = pattern_to_output.at(reduce_or_squeeze_label).get_node_shared_ptr();
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
const auto& pattern_to_output = m.get_pattern_map();
auto transpose = pattern_to_output.at(transpose_label);
auto reduction = pattern_to_output.at(reduce_label);
auto keep_dims = get_keep_dims(reduction);
auto transpose_order = std::dynamic_pointer_cast<Constant>(transpose->get_input_node_shared_ptr(1));
auto reduction_axes = std::dynamic_pointer_cast<Constant>(reduction->get_input_node_shared_ptr(1));
auto transpose_order = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
auto reduction_axes = as_type_ptr<Constant>(reduction->get_input_node_shared_ptr(1));
if (!transpose_order || !reduction_axes)
return false;
auto unsqueeze = std::dynamic_pointer_cast<Unsqueeze>(reduction);
auto rank =
unsqueeze ? reduction->get_output_partial_shape(0).rank() : reduction->get_input_partial_shape(0).rank();
auto rank = reduction->get_input_partial_shape(0).rank();
auto non_negative_axes =
normalize_axes(reduction->get_friendly_name(), reduction_axes->cast_vector<int64_t>(), rank);
auto transpose_order_values = transpose_order->cast_vector<size_t>();
auto old_transpose_order_values = transpose_order_values;
std::vector<size_t> new_values;
if (unsqueeze) {
if (non_negative_axes.size() == transpose_order_values.size()) {
// input is a scalar, we unsqueeze all dims
// it's enough to eliminate such Transpose
transpose->output(0).replace(reduction);
return true;
}
for (const auto& axis : non_negative_axes) {
auto it = std::find(old_transpose_order_values.begin(), old_transpose_order_values.end(), axis);
if (it != old_transpose_order_values.end()) {
new_values.push_back(it - old_transpose_order_values.begin());
}
}
}
bool squeeze_all_dims = false;
if (!keep_dims) {
if (non_negative_axes.empty()) {
auto input_pshape = reduction->input_value(0).get_partial_shape();
if (input_pshape.is_static()) {
for (size_t i = 0; i < input_pshape.size(); ++i) {
if (input_pshape[i] == 1) {
non_negative_axes.push_back(i);
}
}
squeeze_all_dims = true;
} else {
return false;
}
}
if (unsqueeze) {
transpose_order_values = get_updated_order_forward(new_values, transpose_order_values);
} else {
transpose_order_values = get_updated_order_backward(non_negative_axes, transpose_order_values);
}
transpose_order_values = GetOrderBeforeReduction(non_negative_axes, transpose_order_values);
}
auto reversed_order_values = ReverseTransposeOrder(transpose_order_values);
auto new_transpose_order = Constant::create(transpose_order->get_element_type(),
{transpose_order_values.size()},
transpose_order_values);
std::vector<size_t> new_values;
for (const auto& axis : non_negative_axes) {
new_values.push_back(reversed_order_values[axis]);
}
if (!unsqueeze) {
auto reversed_order_values = ReverseTransposeOrder(transpose_order_values);
for (const auto& axis : non_negative_axes) {
new_values.push_back(reversed_order_values[axis]);
}
}
auto new_const = Constant::create(reduction_axes->get_element_type(), reduction_axes->get_shape(), new_values);
auto new_transpose = transpose->clone_with_new_inputs({reduction->input_value(0), new_transpose_order});
auto new_reduction = reduction->clone_with_new_inputs({new_transpose, new_const});
auto new_transpose_order = std::make_shared<Constant>(transpose_order->get_element_type(),
Shape{transpose_order_values.size()},
transpose_order_values);
std::shared_ptr<Node> new_transpose, new_reduction;
if (squeeze_all_dims) {
new_transpose = transpose->clone_with_new_inputs({reduction->input_value(0), new_transpose_order});
new_reduction = reduction->clone_with_new_inputs({new_transpose, reduction->input_value(1)});
} else {
auto new_const =
std::make_shared<Constant>(reduction_axes->get_element_type(), reduction_axes->get_shape(), new_values);
new_transpose = transpose->clone_with_new_inputs({reduction->input_value(0), new_transpose_order});
new_reduction = reduction->clone_with_new_inputs({new_transpose, new_const});
}
replace_node(transpose, new_reduction);
copy_runtime_info({transpose, reduction}, {new_transpose, new_reduction});
UpdateForwardSinkingAbility(new_transpose);
new_reduction->set_friendly_name(transpose->get_friendly_name());
new_transpose->set_friendly_name(reduction->get_friendly_name());
register_new_node(new_transpose);
return true;
};

View File

@ -0,0 +1,116 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/transpose_sinking/ts_slice.hpp"
#include "itt.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/util/common_util.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
#include "transformations/transpose_sinking/ts_utils.hpp"
using namespace ov;
using namespace ov::opset10;
using namespace ov::pass::pattern;
using namespace ov::pass::transpose_sinking;
using namespace ov::pass::transpose_sinking::utils;
// Sinks a Transpose through a Slice in the forward direction:
//   Transpose -> Slice   becomes   Slice -> Transpose,
// with the Slice `axes` input re-mapped through the transpose order.
TSSliceForward::TSSliceForward() {
    MATCHER_SCOPE(TSSliceForward);

    // Pattern: Transpose(any, Constant) feeding the data (0th) input of a 5-input Slice.
    auto const_label = wrap_type<Constant>();
    auto transpose_label = wrap_type<Transpose>({any_input(), const_label});
    auto main_node_label = wrap_type<Slice>({transpose_label, any_input(), any_input(), any_input(), any_input()});

    matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
        const auto& pattern_to_node = m.get_pattern_map();

        auto& main_node = pattern_to_node.at(main_node_label);
        auto transpose = std::dynamic_pointer_cast<Transpose>(pattern_to_node.at(transpose_label));
        // The `axes` input (index 4) is required below, so bail out without it.
        if (!transpose || main_node->get_input_size() < 5) {
            return false;
        }

        // The transpose order must be a constant to be re-mapped statically.
        auto transpose_const = as_type_ptr<Constant>(pattern_to_node.at(const_label));
        if (!transpose_const) {
            return false;
        }

        // remove Transpose on 1st input:
        auto transpose_parent = transpose->input_value(0);
        main_node->input(0).replace_source_output(transpose_parent);

        // Re-map the Slice `axes` input through the transpose order with a Gather:
        // new_axes[i] = transpose_order[old_axes[i]].
        const auto transpose_axis_order = transpose_const->get_axis_vector_val();
        auto axis = std::make_shared<Constant>(element::i32, Shape{}, std::vector<int32_t>{0});
        auto data = std::make_shared<Constant>(element::i32, Shape{transpose_axis_order.size()}, transpose_axis_order);
        const auto& indices = main_node->input_value(4);
        auto new_axis = std::make_shared<Gather>(data, indices, axis);
        main_node->input(4).replace_source_output(new_axis);

        main_node->validate_and_infer_types();

        // Re-attach the Transpose on every Slice output so it can keep sinking forward.
        TransposeInputsInfo transpose_input_info = {transpose, transpose_const, 0};
        for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
            register_new_node(new_node);
            UpdateForwardSinkingAbility(new_node);
        }
        return true;
    };

    auto m = std::make_shared<Matcher>(main_node_label, matcher_name);
    register_matcher(m, matcher_pass_callback);
}
// Sinks a Transpose through a Slice in the backward direction:
//   Slice -> Transpose   becomes   Transpose -> Slice,
// with the Slice `axes` input re-mapped through the reversed transpose order.
TSSliceBackward::TSSliceBackward() {
    MATCHER_SCOPE(TSSliceBackward);

    // Pattern: Slice whose consumers are all the same sinking Transpose.
    auto main_node_label = wrap_type<Slice>([](const Output<Node>& output) -> bool {
        return has_static_rank()(output) && HasSameOutputTransposeNodes(output);
    });

    auto transpose_const_label = wrap_type<Constant>();

    auto transpose_label =
        wrap_type<Transpose>({main_node_label, transpose_const_label}, [](const Output<Node>& output) -> bool {
            return has_static_rank()(output) && is_sinking_node(output);
        });

    matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
        const auto& pattern_to_output = m.get_pattern_value_map();
        auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
        auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
        auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();
        // The `axes` input (index 4) is required below, so bail out without it.
        if (main_node->get_input_size() < 5) {
            return false;
        }

        // Insert the Transpose above the Slice data input (input #0) only;
        // begin/end/step stay as they are and only `axes` is re-mapped below.
        for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node,
                                                                      transpose_const,
                                                                      /* input_indexes= */ {0})) {
            register_new_node(new_node);
        }

        // remove output transposes
        RemoveSingleOutputConsumers(main_node);
        SwapNames(main_node, transpose);

        // Re-map the Slice `axes` input through the reversed transpose order with
        // a Gather: new_axes[i] = reversed_order[old_axes[i]].
        const auto transpose_axis_order = transpose_const->get_axis_vector_val();
        const auto reversed_transpose_order = ReverseTransposeOrder(transpose_axis_order);
        auto axis = std::make_shared<Constant>(element::i32, Shape{}, std::vector<int32_t>{0});
        auto data =
            std::make_shared<Constant>(element::i32, Shape{reversed_transpose_order.size()}, reversed_transpose_order);
        const auto& indices = main_node->input_value(4);
        auto new_axis = std::make_shared<Gather>(data, indices, axis);
        main_node->input(4).replace_source_output(new_axis);
        main_node->validate_and_infer_types();
        return true;
    };

    auto m = std::make_shared<Matcher>(transpose_label, matcher_name);
    register_matcher(m, matcher_pass_callback);
}

View File

@ -0,0 +1,271 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/transpose_sinking/ts_squeeze.hpp"
#include <memory>
#include <vector>
#include "itt.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
#include "transformations/transpose_sinking/ts_utils.hpp"
#include "transformations/utils/utils.hpp"
using namespace ov;
using namespace opset10;
using namespace ov::pass::pattern;
using namespace ov::pass::transpose_sinking;
using namespace ov::pass::transpose_sinking::utils;
namespace {
/**
* @brief Checks that Reshape operation is equal to Squeeze:
* Only 1 dims are deleted, all other dims must be the same.
* Converts these 1 dims to axes format.
* @arg reshape Reshape operation.
* @arg reshape_to_shape 2nd input to Reshape op as a constant.
* @arg result_axes Contains axes which will be squeezed.
*/
// Checks that the given Reshape behaves exactly like a Squeeze: only `1` dims
// are removed and every other dim is preserved in order. On success fills
// `result_axes` with the removed axis indices and returns true; otherwise
// returns false (dynamic input shapes are rejected as well).
bool shape_to_squeeze_axes(const std::shared_ptr<Node>& reshape,
                           const std::shared_ptr<Constant>& reshape_to_shape,
                           std::vector<size_t>& result_axes) {
    result_axes.clear();
    const auto target_shape = reshape_to_shape->cast_vector<int64_t>();

    const auto& input_pshape = reshape->get_input_partial_shape(0);
    // todo: support dynamic case
    if (input_pshape.is_dynamic()) {
        return false;
    }
    const auto input_shape = input_pshape.to_shape();

    // A Squeeze-like Reshape must strictly reduce the rank.
    if (target_shape.size() >= input_shape.size()) {
        return false;
    }

    size_t matched = 0;  // how many target dims have been consumed so far
    for (size_t i = 0; i < input_shape.size(); ++i) {
        const auto dim = static_cast<int64_t>(input_shape[i]);
        if (matched < target_shape.size() && target_shape[matched] == dim) {
            ++matched;
        } else if (dim == 1) {
            result_axes.push_back(i);  // a dropped `1` dim — squeezed axis
        } else {
            // a non-`1` dim disappeared or changed: not a Squeeze
            return false;
        }
    }
    // every target dim must have been matched against the input shape
    return matched == target_shape.size();
}
/**
* @brief Converts squeezed_axes to actual shape (2nd input) for Reshape operation
* using the shape of the 1st input to Reshape.
* @arg input_node 1st input to Reshape op.
* @arg squeeze_axes In case of Reshape op is equal to squeeze, these axes indicate the places where 1 dims have
* to be deleted.
*/
// Builds the target-shape (2nd Reshape input) that corresponds to squeezing
// `squeeze_axes` out of `input_node`'s static shape. Returns false when the
// input shape is dynamic; otherwise fills `to_shape` and returns true.
bool squeeze_axes_to_shape(const Output<Node>& input_node,
                           std::vector<size_t> squeeze_axes,
                           std::vector<size_t>& to_shape) {
    to_shape.clear();

    const auto& input_pshape = input_node.get_partial_shape();
    if (input_pshape.is_dynamic()) {
        return false;
    }

    // Sort the (by-value) axes so they can be consumed in one linear pass.
    std::sort(squeeze_axes.begin(), squeeze_axes.end());

    const auto& input_shape = input_pshape.get_shape();
    size_t next_axis = 0;  // index into squeeze_axes
    for (size_t i = 0; i < input_shape.size(); ++i) {
        if (next_axis < squeeze_axes.size() && squeeze_axes[next_axis] == i) {
            ++next_axis;  // this dim is squeezed away — skip it
        } else {
            to_shape.push_back(input_shape[i]);
        }
    }
    return true;
}
} // namespace
// Sinks a Transpose through a Squeeze (or Squeeze-equivalent Reshape) in the
// forward direction: Transpose -> Squeeze becomes Squeeze -> Transpose with
// the squeeze axes mapped to the pre-Transpose layout and the transpose order
// shrunk by the removed axes.
TSSqueezeForward::TSSqueezeForward() {
    MATCHER_SCOPE(TSSqueezeForward);

    // Pattern: Transpose(any, Constant) -> Squeeze|Reshape(_, Constant).
    auto transpose_label = wrap_type<Transpose>({any_input(), wrap_type<Constant>()});
    auto squeeze_label = wrap_type<Squeeze, Reshape>({transpose_label, wrap_type<Constant>()});

    ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
        const auto& pattern_to_output = m.get_pattern_map();

        auto transpose = pattern_to_output.at(transpose_label);
        auto squeeze = pattern_to_output.at(squeeze_label);

        // Both the transpose order and the squeeze axes/shape must be constants.
        auto transpose_order = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
        auto squeeze_axes = as_type_ptr<Constant>(squeeze->get_input_node_shared_ptr(1));
        if (!transpose_order || !squeeze_axes) {
            return false;
        }

        std::vector<size_t> non_negative_axes;
        if (as_type_ptr<Reshape>(squeeze)) {
            // Reshape is handled only when it is equivalent to a Squeeze:
            // convert its target shape into the set of removed axes.
            auto success = shape_to_squeeze_axes(squeeze, squeeze_axes, non_negative_axes);
            if (!success) {
                return false;
            }
        } else {
            auto rank = squeeze->get_input_partial_shape(0).rank();
            non_negative_axes =
                normalize_axes(squeeze->get_friendly_name(), squeeze_axes->cast_vector<int64_t>(), rank);
        }

        // if 2nd input to squeeze is empty then all '1' dims will be deleted.
        if (non_negative_axes.empty()) {
            auto input_pshape = transpose->output(0).get_partial_shape();
            if (input_pshape.is_dynamic()) {
                return false;
            }
            for (size_t i = 0; i < input_pshape.size(); ++i) {
                if (input_pshape[i].get_length() == 1) {
                    non_negative_axes.push_back(i);
                }
            }
        }

        // Squeeze axes refer to the transposed layout; map each one back to the
        // pre-Transpose layout through the transpose order.
        auto transpose_order_values = transpose_order->cast_vector<size_t>();
        std::vector<size_t> new_values;
        new_values.reserve(non_negative_axes.size());
        for (const auto& axis : non_negative_axes) {
            new_values.push_back(transpose_order_values[axis]);
        }

        // Shrink the transpose order by the removed axes.
        transpose_order_values = GetOrderAfterReduction(non_negative_axes, transpose_order_values);
        auto new_transpose_order = Constant::create(transpose_order->get_element_type(),
                                                    {transpose_order_values.size()},
                                                    transpose_order_values);

        if (as_type_ptr<Reshape>(squeeze)) {
            // For Reshape the 2nd input is a target shape, not axes: convert back.
            std::vector<size_t> to_shape;
            auto success = squeeze_axes_to_shape(transpose->input_value(0), new_values, to_shape);
            if (!success) {
                return false;
            }
            new_values = to_shape;
        }

        // Rebuild Squeeze/Reshape on the un-transposed input, then re-attach the
        // Transpose after it; friendly names are swapped to keep outputs stable.
        auto new_const = Constant::create(squeeze_axes->get_element_type(), {new_values.size()}, new_values);
        auto new_squeeze = squeeze->clone_with_new_inputs({transpose->input_value(0), new_const});
        auto new_transpose = transpose->clone_with_new_inputs({new_squeeze, new_transpose_order});

        replace_node(squeeze, new_transpose);
        new_squeeze->set_friendly_name(transpose->get_friendly_name());
        new_transpose->set_friendly_name(squeeze->get_friendly_name());
        UpdateForwardSinkingAbility(new_transpose);
        register_new_node(new_transpose);
        copy_runtime_info({transpose, squeeze}, {new_transpose, new_squeeze});
        return true;
    };

    auto m = std::make_shared<pattern::Matcher>(squeeze_label, matcher_name);
    register_matcher(m, matcher_pass_callback);
}
TSSqueezeBackward::TSSqueezeBackward() {
    MATCHER_SCOPE(TSSqueezeBackward);

    // Pattern: Squeeze (or a Reshape acting as a Squeeze) whose outputs all feed
    // identical Transpose nodes; the Transpose is then moved above the Squeeze.
    auto squeeze_label = wrap_type<Squeeze, Reshape>({any_input(), wrap_type<Constant>()}, HasSameOutputTransposeNodes);
    auto transpose_label =
        wrap_type<Transpose>({squeeze_label, wrap_type<Constant>()}, [](const Output<Node>& output) -> bool {
            return has_static_rank()(output) && is_sinking_node(output);
        });

    ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
        const auto& pattern_to_output = m.get_pattern_map();
        auto transpose = pattern_to_output.at(transpose_label);
        auto squeeze = pattern_to_output.at(squeeze_label);

        // Both the transpose order and the squeeze axes/shape must be constants.
        auto transpose_order = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
        auto squeeze_axes = as_type_ptr<Constant>(squeeze->get_input_node_shared_ptr(1));
        if (!transpose_order || !squeeze_axes) {
            return false;
        }

        std::vector<size_t> non_negative_axes;
        if (as_type_ptr<Reshape>(squeeze)) {
            // Reshape case: derive the squeezed axes from the target shape; bail
            // out if this Reshape is not equivalent to a Squeeze.
            auto success = shape_to_squeeze_axes(squeeze, squeeze_axes, non_negative_axes);
            if (!success) {
                return false;
            }
        } else {
            // Squeeze case: normalize possibly negative axes against the input rank.
            auto rank = squeeze->get_input_partial_shape(0).rank();
            non_negative_axes =
                normalize_axes(squeeze->get_friendly_name(), squeeze_axes->cast_vector<int64_t>(), rank);
        }

        // Empty axes input means "squeeze every 1-dim"; that requires a static
        // input shape to enumerate them.
        bool squeeze_all_dims = false;
        if (non_negative_axes.empty()) {
            auto input_pshape = squeeze->input_value(0).get_partial_shape();
            if (input_pshape.is_dynamic()) {
                return false;
            }
            for (size_t i = 0; i < input_pshape.size(); ++i) {
                if (input_pshape[i] == 1) {
                    non_negative_axes.push_back(i);
                }
            }
            squeeze_all_dims = true;
        }

        auto transpose_order_values = transpose_order->cast_vector<size_t>();
        // Expand the transpose order back to the rank before the Squeeze, then
        // remap each squeezed axis through the inverted order.
        transpose_order_values = GetOrderBeforeReduction(non_negative_axes, transpose_order_values);
        auto reversed_order_values = ReverseTransposeOrder(transpose_order_values);

        std::vector<size_t> new_values;
        for (const auto& axis : non_negative_axes) {
            new_values.push_back(reversed_order_values[axis]);
        }

        auto new_transpose_order = Constant::create(transpose_order->get_element_type(),
                                                    {transpose_order_values.size()},
                                                    transpose_order_values);
        auto new_transpose = transpose->clone_with_new_inputs({squeeze->input_value(0), new_transpose_order});

        if (as_type_ptr<Reshape>(squeeze)) {
            // Reshape expects a target shape, not axes: convert the remapped
            // axes back into shape form for the new input.
            std::vector<size_t> to_shape;
            auto success = squeeze_axes_to_shape(new_transpose->output(0), new_values, to_shape);
            if (!success) {
                return false;
            }
            new_values = to_shape;
        }

        std::shared_ptr<Node> new_squeeze;
        if (squeeze_all_dims) {
            // Keep the original (empty) 2nd input: all 1-dims get squeezed anyway.
            new_squeeze = squeeze->clone_with_new_inputs({new_transpose, squeeze->input_value(1)});
        } else {
            auto new_const =
                std::make_shared<Constant>(squeeze_axes->get_element_type(), squeeze_axes->get_shape(), new_values);
            new_squeeze = squeeze->clone_with_new_inputs({new_transpose, new_const});
        }

        replace_node(transpose, new_squeeze);
        copy_runtime_info({transpose, squeeze}, {new_transpose, new_squeeze});
        // Swap friendly names: the node that now produces the final output keeps
        // the name its consumers knew.
        new_squeeze->set_friendly_name(transpose->get_friendly_name());
        new_transpose->set_friendly_name(squeeze->get_friendly_name());
        register_new_node(new_transpose);
        return true;
    };

    auto m = std::make_shared<pattern::Matcher>(transpose_label, matcher_name);
    register_matcher(m, matcher_pass_callback);
}

View File

@ -0,0 +1,243 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/transpose_sinking/ts_unsqueeze.hpp"
#include <memory>
#include <vector>
#include "itt.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
#include "transformations/transpose_sinking/ts_utils.hpp"
#include "transformations/utils/utils.hpp"
using namespace ov;
using namespace opset10;
using namespace ov::pass::pattern;
using namespace ov::pass::transpose_sinking;
using namespace ov::pass::transpose_sinking::utils;
namespace {
/**
 * @brief Checks whether a Reshape is semantically an Unsqueeze:
 * the target shape must be the input shape with some extra 1-dims inserted,
 * all the remaining dims staying equal and in the same order.
 * On success the positions of the inserted 1-dims are returned in axes form.
 * @arg reshape Reshape operation.
 * @arg reshape_to_shape 2nd input to Reshape op as a constant.
 * @arg result_axes contains axes which will be unsqueezed.
 */
bool shape_to_unsqueeze_axes(const std::shared_ptr<Node>& reshape,
                             const std::shared_ptr<Constant>& reshape_to_shape,
                             std::vector<size_t>& result_axes) {
    result_axes.clear();
    const auto target_shape = reshape_to_shape->cast_vector<int64_t>();

    const auto& input_pshape = reshape->get_input_partial_shape(0);
    // todo: support dynamic case
    if (input_pshape.is_dynamic()) {
        return false;
    }
    const auto input_shape = input_pshape.to_shape();

    // A pure Unsqueeze can only increase the rank.
    // todo: move this checks in the pattern
    if (target_shape.size() <= input_shape.size()) {
        return false;
    }

    size_t input_idx = 0;
    for (size_t out_idx = 0; out_idx < target_shape.size(); ++out_idx) {
        if (input_idx < input_shape.size() && static_cast<int64_t>(input_shape[input_idx]) == target_shape[out_idx]) {
            // greedily match the next input dim
            ++input_idx;
        } else if (target_shape[out_idx] == 1) {
            // an inserted 1-dim -> unsqueeze axis
            result_axes.push_back(out_idx);
        } else {
            // a dim that is neither 1 nor the expected input dim: not an Unsqueeze
            return false;
        }
    }
    // every input dim must have been consumed by the match above
    return input_idx == input_shape.size();
}
/**
 * @brief Converts unsqueeze axes back into the target-shape form expected by
 * the 2nd input of a Reshape, using the static shape of the 1st input.
 * @arg input_node 1st input to Reshape op.
 * @arg unsqueeze_axes positions at which 1-dims have to be inserted.
 * @arg to_shape receives the resulting target shape.
 */
bool unsqueeze_axes_to_shape(const Output<Node>& input_node,
                             std::vector<size_t> unsqueeze_axes,
                             std::vector<size_t>& to_shape) {
    to_shape.clear();
    const auto& input_pshape = input_node.get_partial_shape();
    // the input shape must be fully static to materialize the target shape
    if (input_pshape.is_dynamic()) {
        return false;
    }
    const auto& input_shape = input_pshape.get_shape();

    std::sort(unsqueeze_axes.begin(), unsqueeze_axes.end());
    to_shape.resize(input_shape.size() + unsqueeze_axes.size());

    size_t axes_pos = 0;
    size_t input_pos = 0;
    for (size_t out_pos = 0; out_pos < to_shape.size(); ++out_pos) {
        if (axes_pos < unsqueeze_axes.size() && out_pos == unsqueeze_axes[axes_pos]) {
            // an inserted dim is always 1
            to_shape[out_pos] = 1;
            ++axes_pos;
        } else if (input_pos < input_shape.size()) {
            // otherwise copy the next input dim
            to_shape[out_pos] = input_shape[input_pos];
            ++input_pos;
        }
    }
    return true;
}
} // namespace
TSUnsqueezeForward::TSUnsqueezeForward() {
    MATCHER_SCOPE(TSUnsqueezeForward);

    // Pattern: Transpose(const order) feeding an Unsqueeze (or a Reshape acting
    // as an Unsqueeze) with a constant axes/shape input.
    auto transpose_label = wrap_type<Transpose>({any_input(), wrap_type<Constant>()});
    auto unsqueeze_label = wrap_type<Unsqueeze, Reshape>({transpose_label, wrap_type<Constant>()});

    ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
        const auto& pattern_to_output = m.get_pattern_map();
        auto transpose = pattern_to_output.at(transpose_label);
        auto unsqueeze = pattern_to_output.at(unsqueeze_label);

        auto transpose_order = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
        auto unsqueeze_axes = as_type_ptr<Constant>(unsqueeze->get_input_node_shared_ptr(1));
        if (!transpose_order || !unsqueeze_axes) {
            return false;
        }

        std::vector<size_t> non_negative_axes;
        if (as_type_ptr<Reshape>(unsqueeze)) {
            // Reshape case: derive the inserted 1-dims from the target shape;
            // fail if this Reshape is not equivalent to an Unsqueeze.
            auto success = shape_to_unsqueeze_axes(unsqueeze, unsqueeze_axes, non_negative_axes);
            if (!success) {
                return false;
            }
        } else {
            // Unsqueeze case: normalize possibly negative axes against the output rank.
            auto rank = unsqueeze->get_output_partial_shape(0).rank();
            non_negative_axes =
                normalize_axes(unsqueeze->get_friendly_name(), unsqueeze_axes->cast_vector<int64_t>(), rank);
        }

        // Expand the transpose order to the post-Unsqueeze rank so the Transpose
        // can be moved below the Unsqueeze; inserted axes map to themselves.
        auto ts_order_values = transpose_order->cast_vector<size_t>();
        ts_order_values = GetOrderBeforeReduction(non_negative_axes, ts_order_values);
        auto new_transpose_order =
            Constant::create(transpose_order->get_element_type(), {ts_order_values.size()}, ts_order_values);

        std::shared_ptr<Node> new_unsqueeze;
        if (as_type_ptr<Reshape>(unsqueeze)) {
            // Rebuild the Reshape target shape for the no-longer-transposed input.
            std::vector<size_t> new_values;
            auto success = unsqueeze_axes_to_shape(transpose->input_value(0), non_negative_axes, new_values);
            if (!success) {
                return false;
            }
            auto new_const = Constant::create(unsqueeze_axes->get_element_type(), {new_values.size()}, new_values);
            new_unsqueeze = unsqueeze->clone_with_new_inputs({transpose->input_value(0), new_const});
        } else {
            // Plain Unsqueeze keeps its original axes input.
            new_unsqueeze = unsqueeze->clone_with_new_inputs({transpose->input_value(0), unsqueeze->input_value(1)});
        }

        auto new_transpose = transpose->clone_with_new_inputs({new_unsqueeze, new_transpose_order});
        replace_node(unsqueeze, new_transpose);
        // Swap friendly names so the graph output keeps its original name.
        new_unsqueeze->set_friendly_name(transpose->get_friendly_name());
        new_transpose->set_friendly_name(unsqueeze->get_friendly_name());
        UpdateForwardSinkingAbility(new_transpose);
        register_new_node(new_transpose);
        copy_runtime_info({transpose, unsqueeze}, {new_transpose, new_unsqueeze});
        return true;
    };

    auto m = std::make_shared<pattern::Matcher>(unsqueeze_label, matcher_name);
    register_matcher(m, matcher_pass_callback);
}
TSUnsqueezeBackward::TSUnsqueezeBackward() {
    MATCHER_SCOPE(TSUnsqueezeBackward);

    // Pattern: Unsqueeze (or a Reshape acting as an Unsqueeze) whose outputs all
    // feed identical, sinkable Transpose nodes.
    auto unsqueeze_label =
        wrap_type<Unsqueeze, Reshape>({any_input(), wrap_type<Constant>()}, HasSameOutputTransposeNodes);
    auto transpose_label =
        wrap_type<Transpose>({unsqueeze_label, wrap_type<Constant>()}, [](const Output<Node>& output) -> bool {
            return has_static_rank()(output) && is_sinking_node(output);
        });

    ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
        const auto& pattern_to_output = m.get_pattern_map();
        auto transpose = pattern_to_output.at(transpose_label);
        auto unsqueeze = pattern_to_output.at(unsqueeze_label);

        // Use as_type_ptr (OV type info) instead of dynamic_pointer_cast for
        // consistency with the rest of the transpose-sinking transformations.
        auto transpose_order = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
        auto unsqueeze_axes = as_type_ptr<Constant>(unsqueeze->get_input_node_shared_ptr(1));
        if (!transpose_order || !unsqueeze_axes) {
            return false;
        }

        std::vector<size_t> non_negative_axes;
        if (as_type_ptr<Reshape>(unsqueeze)) {
            // Reshape case: derive the inserted 1-dims from the target shape;
            // fail if this Reshape is not equivalent to an Unsqueeze.
            auto success = shape_to_unsqueeze_axes(unsqueeze, unsqueeze_axes, non_negative_axes);
            if (!success) {
                return false;
            }
        } else {
            // Unsqueeze case: normalize possibly negative axes against the output rank.
            auto rank = unsqueeze->get_output_partial_shape(0).rank();
            non_negative_axes =
                normalize_axes(unsqueeze->get_friendly_name(), unsqueeze_axes->cast_vector<int64_t>(), rank);
        }

        auto transpose_order_values = transpose_order->cast_vector<size_t>();

        if (non_negative_axes.size() == transpose_order_values.size()) {
            // input is a scalar, we unsqueeze all dims
            // it's enough to eliminate such Transpose
            transpose->output(0).replace(unsqueeze);
            return true;
        }

        // Map each unsqueezed axis to its position in the transposed layout;
        // transpose_order_values is still the original order at this point.
        std::vector<size_t> new_values;
        new_values.reserve(non_negative_axes.size());
        for (const auto& axis : non_negative_axes) {
            auto it = std::find(transpose_order_values.begin(), transpose_order_values.end(), axis);
            if (it != transpose_order_values.end()) {
                new_values.push_back(static_cast<size_t>(it - transpose_order_values.begin()));
            }
        }

        // Shrink the transpose order to the rank before the Unsqueeze.
        transpose_order_values = GetOrderAfterReduction(new_values, transpose_order_values);
        auto new_transpose_order = std::make_shared<Constant>(transpose_order->get_element_type(),
                                                              Shape{transpose_order_values.size()},
                                                              transpose_order_values);
        auto new_transpose = transpose->clone_with_new_inputs({unsqueeze->input_value(0), new_transpose_order});

        if (as_type_ptr<Reshape>(unsqueeze)) {
            // Reshape expects a target shape, not axes: convert the remapped
            // axes back into shape form for the new (transposed) input.
            std::vector<size_t> to_shape;
            auto success = unsqueeze_axes_to_shape(new_transpose->output(0), new_values, to_shape);
            if (!success) {
                return false;
            }
            new_values = to_shape;
        }

        auto new_const = Constant::create(unsqueeze_axes->get_element_type(), unsqueeze_axes->get_shape(), new_values);
        auto new_unsqueeze = unsqueeze->clone_with_new_inputs({new_transpose, new_const});

        replace_node(transpose, new_unsqueeze);
        copy_runtime_info({transpose, unsqueeze}, {new_transpose, new_unsqueeze});
        // Swap friendly names: the node that now produces the final output keeps
        // the name its consumers knew.
        new_unsqueeze->set_friendly_name(transpose->get_friendly_name());
        new_transpose->set_friendly_name(unsqueeze->get_friendly_name());
        register_new_node(new_transpose);
        return true;
    };

    auto m = std::make_shared<pattern::Matcher>(transpose_label, matcher_name);
    register_matcher(m, matcher_pass_callback);
}

View File

@ -379,6 +379,53 @@ void RemoveSingleOutputConsumers(const NodePtr& node) {
}
}
// Computes the transpose order that remains after reducing (removing) the given
// axes. The order entries located at reduced axes disappear; every surviving
// entry is shifted down by the number of removed order values smaller than it,
// so the result is again a dense permutation of the reduced rank.
std::vector<size_t> GetOrderAfterReduction(const std::vector<size_t>& axes_values,
                                           const std::vector<size_t>& order_values) {
    // Order values that vanish together with the reduced axes.
    std::vector<size_t> removed_entries;
    removed_entries.reserve(axes_values.size());
    for (const auto axis : axes_values) {
        removed_entries.push_back(order_values[axis]);
    }

    const auto is_reduced_axis = [&axes_values](size_t idx) {
        return std::find(axes_values.begin(), axes_values.end(), idx) != axes_values.end();
    };

    std::vector<size_t> result;
    result.reserve(order_values.size() - axes_values.size());
    for (size_t i = 0; i < order_values.size(); ++i) {
        if (is_reduced_axis(i)) {
            continue;
        }
        const size_t value = order_values[i];
        // Number of removed order values strictly smaller than this entry.
        const auto shift =
            static_cast<size_t>(std::count_if(removed_entries.begin(), removed_entries.end(), [value](size_t v) {
                return v < value;
            }));
        result.push_back(value - shift);
    }
    return result;
}
// Inverse of GetOrderAfterReduction: expands a transpose order to the rank it
// had before the given axes were reduced. Each re-inserted axis maps to itself,
// while the reduced-rank order entries are re-expressed in terms of the
// positions that survived the reduction.
std::vector<size_t> GetOrderBeforeReduction(const std::vector<size_t>& axes_values,
                                            const std::vector<size_t>& order_values) {
    const size_t expanded_rank = order_values.size() + axes_values.size();

    const auto is_inserted_axis = [&axes_values](size_t idx) {
        return std::find(axes_values.begin(), axes_values.end(), idx) != axes_values.end();
    };

    // kept[r] is the expanded-rank position whose reduced-rank index is r.
    std::vector<size_t> kept;
    kept.reserve(order_values.size());
    for (size_t i = 0; i < expanded_rank; ++i) {
        if (!is_inserted_axis(i)) {
            kept.push_back(i);
        }
    }

    std::vector<size_t> expanded_order(expanded_rank);
    size_t j = 0;
    for (size_t i = 0; i < expanded_rank; ++i) {
        // Inserted axes stay in place; other entries translate the old
        // reduced-rank order value into its expanded-rank position.
        expanded_order[i] = is_inserted_axis(i) ? i : kept[order_values[j++]];
    }
    return expanded_order;
}
} // namespace utils
} // namespace transpose_sinking
} // namespace pass

View File

@ -11,8 +11,11 @@
#include "transformations/transpose_sinking/ts_data_movement.hpp"
#include "transformations/transpose_sinking/ts_interpolate.hpp"
#include "transformations/transpose_sinking/ts_reduction.hpp"
#include "transformations/transpose_sinking/ts_slice.hpp"
#include "transformations/transpose_sinking/ts_split.hpp"
#include "transformations/transpose_sinking/ts_squeeze.hpp"
#include "transformations/transpose_sinking/ts_unary.hpp"
#include "transformations/transpose_sinking/ts_unsqueeze.hpp"
#include "ts_test_utils.hpp"
using namespace std;
@ -183,6 +186,34 @@ private:
// Helper to wrap an InterpolateFactory into the common FactoryPtr handle.
FactoryPtr CreateInterpolateFactory(const std::string& type_name, bool is_reference) {
    return std::make_shared<InterpolateFactory>(type_name, is_reference);
}
// Factory producing Slice nodes for the common transpose-sinking test harness.
class SliceFactory : public IFactory {
public:
    explicit SliceFactory(const std::string& type_name) : IFactory(type_name) {}
    // parent_nodes layout: [data, start, stop, step, axes]
    NodePtr create(const OutputVector& parent_nodes) const override {
        return std::make_shared<Slice>(parent_nodes[0],
                                       parent_nodes[1],
                                       parent_nodes[2],
                                       parent_nodes[3],
                                       parent_nodes[4]);
    }
};
// Helper to wrap a SliceFactory into the common FactoryPtr handle.
FactoryPtr CreateSliceFactory(const std::string& type_name) {
    return std::make_shared<SliceFactory>(type_name);
}
// Factory producing Reshape nodes (special_zero = false) for the tests.
class ReshapeFactory : public IFactory {
public:
    explicit ReshapeFactory(const std::string& type_name) : IFactory(type_name) {}
    // parent_nodes layout: [data, target_shape]
    NodePtr create(const OutputVector& parent_nodes) const override {
        return std::make_shared<Reshape>(parent_nodes[0], parent_nodes[1], false);
    }
};
// Helper to wrap a ReshapeFactory into the common FactoryPtr handle.
FactoryPtr CreateReshapeFactory(const std::string& type_name) {
    return std::make_shared<ReshapeFactory>(type_name);
}
// ----------------------------------------------------------------------------
#undef CREATE_UNARY_FACTORY
@ -214,6 +245,12 @@ FactoryPtr CreateInterpolateFactory(const std::string& type_name, bool is_refere
#undef CREATE_INTERPOLATE_FACTORY
#define CREATE_INTERPOLATE_FACTORY(type_name, reference_flag) CreateInterpolateFactory(#type_name, reference_flag)
#undef CREATE_SLICE_FACTORY
#define CREATE_SLICE_FACTORY(type_name) CreateSliceFactory(#type_name)
#undef CREATE_RESHAPE_FACTORY
#define CREATE_RESHAPE_FACTORY(type_name) CreateReshapeFactory(#type_name)
// ----------------------------------------------------------------------------
struct Preprocessing {
@ -651,7 +688,7 @@ auto test_forward_squeeze = []() {
TestCase test_case;
// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TSReductionForward);
test_case.transformation = CREATE_PASS_FACTORY(TSSqueezeForward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {32, 1, 2, 1}),
@ -685,7 +722,7 @@ auto test_forward_unsqueeze = []() {
TestCase test_case;
// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TSReductionForward);
test_case.transformation = CREATE_PASS_FACTORY(TSUnsqueezeForward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {32, 3, 2, 1}),
@ -722,6 +759,128 @@ auto test_forward_unsqueeze = []() {
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnsqueezeForward, TransposeSinkingTestFixture, test_forward_unsqueeze());
// Forward sinking of Transpose through Slice: the Transpose on the data input
// is expected to move below the Slice, while the 1-D `axes` input is remapped.
auto test_forward_slice = []() {
    TestCase test_case;

    // Initialize common attributes
    test_case.transformation = CREATE_PASS_FACTORY(TSSliceForward);
    test_case.num_main_ops = {1};
    // inputs: data, start, stop, step, axes
    test_case.inputs_to_main = {
        parameter(element::f32, {6, 4, 5, 3}),
        constant<int64_t>(element::i32, {3}, {1, 2, 3}),
        constant<int64_t>(element::i32, {3}, {0, 4, 11}),
        constant<int64_t>(element::i32, {3}, {1, 2, -1}),
        constant<int64_t>(element::i32, {3}, {0, 1, 2}),
    };

    // Test model description:
    test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
    test_case.model.main_op = {CREATE_SLICE_FACTORY(SliceFactory)};
    test_case.model.model_template = create_model;

    // Reference model description:
    // The 1-D `axes` input cannot be transposed; its values are remapped with a
    // Gather over the reversed-order index table instead.
    auto set_specific_gather_for = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
        OutputVector result = out_vec;
        for (const auto& idx : idxs) {
            const auto& out = out_vec[idx];
            // Build the reversed identity order based on the data input rank.
            vector<int64_t> transpose_order(out_vec[0].get_shape().size());
            iota(transpose_order.begin(), transpose_order.end(), 0);
            reverse(transpose_order.begin(), transpose_order.end());
            auto data = make_shared<Constant>(element::i32, Shape{transpose_order.size()}, transpose_order);
            auto axis = make_shared<Constant>(element::i32, Shape{}, 0);
            auto transpose = make_shared<Gather>(data, out, axis);
            result[idx] = transpose;
        }
        return result;
    };
    test_case.model_ref.preprocess_inputs_to_main = {{set_specific_gather_for}, {{4}}};
    test_case.model_ref.main_op = {CREATE_SLICE_FACTORY(Slice)};
    test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
    test_case.model_ref.model_template = create_model;
    return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSliceForward, TransposeSinkingTestFixture, test_forward_slice());
// Forward sinking through a Reshape that acts as a Squeeze: the reference model
// keeps the Reshape but with a target shape permuted by the sunk Transpose.
auto test_forward_reshape_squeeze = []() {
    TestCase test_case;

    // Initialize common attributes
    test_case.transformation = CREATE_PASS_FACTORY(TSSqueezeForward);
    test_case.num_main_ops = {1};
    test_case.inputs_to_main = {
        parameter(element::f32, {6, 1, 5, 1, 4}),
        constant<int64_t>(element::i32, {3}, {4, 5, 6}),
    };

    // Test model description:
    test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
    test_case.model.main_op = {CREATE_RESHAPE_FACTORY(Reshape)};
    test_case.model.model_template = create_model;

    // Reference model description:
    // Target shape {4, 5, 6} becomes {6, 5, 4} once the Transpose is sunk below.
    auto new_constant = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
        OutputVector new_out_vec(out_vec.size());
        new_out_vec[0] = out_vec[0];
        new_out_vec[1] =
            make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{6, 5, 4});
        return new_out_vec;
    };
    test_case.model_ref.preprocess_inputs_to_main = {{new_constant}, {{1}}};
    test_case.model_ref.main_op = {CREATE_RESHAPE_FACTORY(Reshape)};
    test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
    test_case.model_ref.model_template = create_model;
    return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeSqueezeForward,
TransposeSinkingTestFixture,
test_forward_reshape_squeeze());
// Forward sinking through a Reshape that acts as an Unsqueeze: the reference
// model carries an expanded-rank Transpose after the Reshape and a permuted
// target shape on the 2nd input.
auto test_forward_reshape_unsqueeze = []() {
    TestCase test_case;

    // Initialize common attributes
    test_case.transformation = CREATE_PASS_FACTORY(TSUnsqueezeForward);
    test_case.num_main_ops = {1};
    test_case.inputs_to_main = {
        parameter(element::f32, {6, 5, 4}),
        constant<int64_t>(element::i32, {5}, {4, 1, 5, 1, 6}),
    };

    // Test model description:
    test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
    test_case.model.main_op = {CREATE_RESHAPE_FACTORY(Reshape)};
    test_case.model.model_template = create_model;

    // Reference model description:
    // The sunk Transpose works on the 5-D result; the inserted 1-dims (axes 1
    // and 3) stay in place in the new order {4, 1, 2, 3, 0}.
    auto new_transpose = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
        OutputVector new_out_vec(out_vec.size());
        auto order = make_shared<Constant>(element::i32, Shape{5}, std::vector<int64_t>{4, 1, 2, 3, 0});
        new_out_vec[0] = make_shared<Transpose>(out_vec[0], order);
        return new_out_vec;
    };
    // Target shape {4, 1, 5, 1, 6} becomes {6, 1, 5, 1, 4} for the raw input.
    auto new_constant = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
        OutputVector new_out_vec(out_vec.size());
        new_out_vec[0] = out_vec[0];
        new_out_vec[1] = make_shared<Constant>(out_vec[1].get_element_type(),
                                               out_vec[1].get_shape(),
                                               std::vector<int64_t>{6, 1, 5, 1, 4});
        return new_out_vec;
    };
    test_case.model_ref.preprocess_inputs_to_main = {{new_constant}, {{1}}};
    test_case.model_ref.main_op = {CREATE_RESHAPE_FACTORY(Reshape)};
    test_case.model_ref.preprocess_outputs_of_main = {{new_transpose}, {{0}}};
    test_case.model_ref.model_template = create_model;
    return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeUnsqueezeForward,
TransposeSinkingTestFixture,
test_forward_reshape_unsqueeze());
// ------------------ BACKWARD --------------------
auto test_backward_unary = []() {
@ -1003,7 +1162,7 @@ auto test_backward_squeeze = []() {
TestCase test_case;
// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TSReductionBackward);
test_case.transformation = CREATE_PASS_FACTORY(TSSqueezeBackward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {32, 1, 2, 1}),
@ -1036,7 +1195,7 @@ auto test_backward_unsqueeze = []() {
TestCase test_case;
// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TSReductionBackward);
test_case.transformation = CREATE_PASS_FACTORY(TSUnsqueezeBackward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {32, 3, 2, 1}),
@ -1066,6 +1225,126 @@ auto test_backward_unsqueeze = []() {
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnsqueezeBackward,
TransposeSinkingTestFixture,
test_backward_unsqueeze());
// Backward sinking of Transpose through Slice: the Transpose after the Slice
// is expected to move to the data input; the 1-D `axes` input is permuted via
// a Gather instead of a Transpose.
auto test_backward_slice = []() {
    TestCase test_case;

    // Initialize common attributes
    test_case.transformation = CREATE_PASS_FACTORY(TSSliceBackward);
    test_case.num_main_ops = {1};
    // inputs: data, start, stop, step, axes
    test_case.inputs_to_main = {
        parameter(element::f32, {6, 4, 5, 3}),
        constant<int64_t>(element::i32, {3}, {1, 2, 3}),
        constant<int64_t>(element::i32, {3}, {0, 4, 11}),
        constant<int64_t>(element::i32, {3}, {1, 2, -1}),
        constant<int64_t>(element::i32, {3}, {0, 1, 2}),
    };

    // Test model description:
    test_case.model.main_op = {CREATE_SLICE_FACTORY(Slice)};
    test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
    test_case.model.model_template = create_model;

    // Reference model description:
    // Remap the `axes` values with a Gather over the reversed-order index table.
    auto set_specific_gather_for = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
        OutputVector result = out_vec;
        for (const auto& idx : idxs) {
            const auto& out = out_vec[idx];
            // Build the reversed identity order based on the data input rank.
            vector<int64_t> transpose_order(out_vec[0].get_shape().size());
            iota(transpose_order.begin(), transpose_order.end(), 0);
            reverse(transpose_order.begin(), transpose_order.end());
            auto data = make_shared<Constant>(element::i32, Shape{transpose_order.size()}, transpose_order);
            auto axis = make_shared<Constant>(element::i32, Shape{}, 0);
            auto transpose = make_shared<Gather>(data, out, axis);
            result[idx] = transpose;
        }
        return result;
    };
    test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, set_specific_gather_for}, {{0}, {4}}};
    test_case.model_ref.main_op = {CREATE_SLICE_FACTORY(SliceFactory)};
    test_case.model_ref.model_template = create_model;
    return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSliceBackward, TransposeSinkingTestFixture, test_backward_slice());
// Backward sinking through a Reshape that acts as a Squeeze: the reference
// model applies an expanded-rank Transpose on the data input and a permuted
// target shape on the 2nd input.
auto test_backward_reshape_squeeze = []() {
    TestCase test_case;

    // Initialize common attributes
    test_case.transformation = CREATE_PASS_FACTORY(TSSqueezeBackward);
    test_case.num_main_ops = {1};
    test_case.inputs_to_main = {
        parameter(element::f32, {4, 1, 5, 1, 6}),
        constant<int64_t>(element::i32, {3}, {4, 5, 6}),
    };

    // Test model description:
    test_case.model.main_op = {CREATE_RESHAPE_FACTORY(Reshape)};
    test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
    test_case.model.model_template = create_model;

    // Reference model description:
    // The moved Transpose works on the 5-D input; the squeezed 1-dims (axes 1
    // and 3) stay in place in the new order {4, 1, 2, 3, 0}.
    auto new_transpose = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
        OutputVector new_out_vec(out_vec.size());
        auto order = make_shared<Constant>(element::i32, Shape{5}, std::vector<int64_t>{4, 1, 2, 3, 0});
        new_out_vec[0] = make_shared<Transpose>(out_vec[0], order);
        new_out_vec[1] = out_vec[1];
        return new_out_vec;
    };
    // Target shape {4, 5, 6} becomes {6, 5, 4} for the transposed input.
    auto new_constant = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
        OutputVector new_out_vec(out_vec.size());
        new_out_vec[0] = out_vec[0];
        new_out_vec[1] =
            make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{6, 5, 4});
        return new_out_vec;
    };
    test_case.model_ref.preprocess_inputs_to_main = {{new_transpose, new_constant}, {{0}, {1}}};
    test_case.model_ref.main_op = {CREATE_RESHAPE_FACTORY(Reshape)};
    test_case.model_ref.model_template = create_model;
    return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeSqueezeBackward,
TransposeSinkingTestFixture,
test_backward_reshape_squeeze());
// Backward sinking through a Reshape that acts as an Unsqueeze: the Transpose
// moves to the 3-D data input; the target shape is permuted accordingly.
auto test_backward_reshape_unsqueeze = []() {
    TestCase test_case;

    // Initialize common attributes
    test_case.transformation = CREATE_PASS_FACTORY(TSUnsqueezeBackward);
    test_case.num_main_ops = {1};
    test_case.inputs_to_main = {
        parameter(element::f32, {4, 5, 6}),
        constant<int64_t>(element::i32, {5}, {4, 1, 5, 1, 6}),
    };

    // Test model description:
    test_case.model.main_op = {CREATE_RESHAPE_FACTORY(Reshape)};
    test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
    test_case.model.model_template = create_model;

    // Reference model description:
    // Target shape {4, 1, 5, 1, 6} becomes {6, 1, 5, 1, 4} once the Transpose
    // is applied to the data input first.
    auto new_constant = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
        OutputVector new_out_vec(out_vec.size());
        new_out_vec[0] = out_vec[0];
        new_out_vec[1] = make_shared<Constant>(out_vec[1].get_element_type(),
                                               out_vec[1].get_shape(),
                                               std::vector<int64_t>{6, 1, 5, 1, 4});
        return new_out_vec;
    };
    test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, new_constant}, {{0}, {1}}};
    test_case.model_ref.main_op = {CREATE_RESHAPE_FACTORY(Reshape)};
    test_case.model_ref.model_template = create_model;
    return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeUnsqueezeBackward,
TransposeSinkingTestFixture,
test_backward_reshape_unsqueeze());
} // namespace common
} // namespace testing
} // namespace transpose_sinking
} // namespace transpose_sinking

View File

@ -281,7 +281,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
{
// perform transpose sinking and reverse infer if the model contains only OpenVINO operations
ov::pass::Manager manager;
manager.register_pass<ov::pass::transpose_sinking::TSGeneral>();
manager.register_pass<ov::pass::TransposeSinkingGeneral>();
manager.register_pass<ov::pass::ReverseShapeAndTypeInfer>();
manager.run_passes(model);
}

View File

@ -268,7 +268,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& function) const {
manager.register_pass<ov::frontend::tensorflow_lite::pass::TFLQuantizeResolver>();
manager.register_pass<ov::frontend::tensorflow_lite::pass::Rfft2dSimplifier>();
manager.register_pass<ov::pass::TransposeSinking>();
manager.register_pass<ov::pass::transpose_sinking::TSGeneral>();
manager.register_pass<ov::pass::TransposeSinkingGeneral>();
manager.run_passes(function);
}