fix TS for Interpolate + codestyle

This commit is contained in:
Tikhonov Ivan
2023-03-02 12:24:12 +00:00
parent 123835c86d
commit 13f17d254b
8 changed files with 62 additions and 63 deletions

View File

@@ -2,6 +2,8 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// //
#include "transformations/common_optimizations/transpose_sinking_binary.hpp"
#include <openvino/opsets/opset10.hpp> #include <openvino/opsets/opset10.hpp>
#include <openvino/pass/pattern/op/or.hpp> #include <openvino/pass/pattern/op/or.hpp>
@@ -9,7 +11,6 @@
#include "openvino/op/util/op_types.hpp" #include "openvino/op/util/op_types.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/common_optimizations/transpose_sinking_utils.hpp" #include "transformations/common_optimizations/transpose_sinking_utils.hpp"
#include "transformations/common_optimizations/transpose_sinking_binary.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp" #include "transformations/rt_info/transpose_sinking_attr.hpp"
using namespace ov::pass::pattern; using namespace ov::pass::pattern;

View File

@@ -23,7 +23,8 @@ ov::pass::TransposeSinkingDataMovementForward::TransposeSinkingDataMovementForwa
MATCHER_SCOPE(TransposeSinkingDataMovementForward); MATCHER_SCOPE(TransposeSinkingDataMovementForward);
auto const_label = wrap_type<Constant>(); auto const_label = wrap_type<Constant>();
auto transpose_label = wrap_type<Transpose>({any_input(), const_label}); auto transpose_label = wrap_type<Transpose>({any_input(), const_label});
auto main_node_label = wrap_type<Pad, BatchToSpace, SpaceToBatch>({transpose_label, any_input(), any_input(), any_input()}); auto main_node_label =
wrap_type<Pad, BatchToSpace, SpaceToBatch>({transpose_label, any_input(), any_input(), any_input()});
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_node = m.get_pattern_map(); const auto& pattern_to_node = m.get_pattern_map();
@@ -57,7 +58,7 @@ ov::pass::TransposeSinkingDataMovementForward::TransposeSinkingDataMovementForwa
const auto& stb = std::dynamic_pointer_cast<SpaceToBatch>(main_node); const auto& stb = std::dynamic_pointer_cast<SpaceToBatch>(main_node);
if (bts || stb) { if (bts || stb) {
main_node->input(3).replace_source_output( main_node->input(3).replace_source_output(
ChangeValuesOrder(main_node->input_value(3), reversed_transpose_order, axis)); ChangeValuesOrder(main_node->input_value(3), reversed_transpose_order, axis));
} }
main_node->validate_and_infer_types(); main_node->validate_and_infer_types();
@@ -114,7 +115,7 @@ ov::pass::TransposeSinkingDataMovementBackward::TransposeSinkingDataMovementBack
const auto& stb = std::dynamic_pointer_cast<SpaceToBatch>(main_node); const auto& stb = std::dynamic_pointer_cast<SpaceToBatch>(main_node);
if (bts || stb) { if (bts || stb) {
main_node->input(3).replace_source_output( main_node->input(3).replace_source_output(
ChangeValuesOrder(main_node->input_value(3), transpose_axis_order, axis)); ChangeValuesOrder(main_node->input_value(3), transpose_axis_order, axis));
} }
main_node->validate_and_infer_types(); main_node->validate_and_infer_types();
return true; return true;

View File

@@ -8,9 +8,9 @@
#include <vector> #include <vector>
#include "itt.hpp" #include "itt.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/opsets/opset10.hpp" #include "openvino/opsets/opset10.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/core/validation_util.hpp"
#include "transformations/common_optimizations/transpose_sinking_utils.hpp" #include "transformations/common_optimizations/transpose_sinking_utils.hpp"
#include "transformations/utils/utils.hpp" #include "transformations/utils/utils.hpp"
@@ -19,13 +19,11 @@ using namespace opset10;
ov::pass::TransposeFuse::TransposeFuse() { ov::pass::TransposeFuse::TransposeFuse() {
MATCHER_SCOPE(TransposeFuse); MATCHER_SCOPE(TransposeFuse);
auto transpose_label = auto transpose_label = pattern::wrap_type<Transpose>({pattern::any_input(), pattern::wrap_type<Constant>()});
pattern::wrap_type<Transpose>({pattern::any_input(), pattern::wrap_type<Constant>()});
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) { ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
const auto& pattern_to_output = m.get_pattern_map(); const auto& pattern_to_output = m.get_pattern_map();
auto transpose_1 = pattern_to_output.at(transpose_label); auto transpose_1 = pattern_to_output.at(transpose_label);
auto order_const_1 = auto order_const_1 = std::dynamic_pointer_cast<Constant>(transpose_1->input_value(1).get_node_shared_ptr());
std::dynamic_pointer_cast<Constant>(transpose_1->input_value(1).get_node_shared_ptr());
auto consumers = transpose_1->get_output_target_inputs(0); auto consumers = transpose_1->get_output_target_inputs(0);
std::vector<int64_t> saved_order_values; std::vector<int64_t> saved_order_values;

View File

@@ -14,6 +14,7 @@
#include "transformations/common_optimizations/transpose_sinking_binary.hpp" #include "transformations/common_optimizations/transpose_sinking_binary.hpp"
#include "transformations/common_optimizations/transpose_sinking_concat.hpp" #include "transformations/common_optimizations/transpose_sinking_concat.hpp"
#include "transformations/common_optimizations/transpose_sinking_data_movement.hpp" #include "transformations/common_optimizations/transpose_sinking_data_movement.hpp"
#include "transformations/common_optimizations/transpose_sinking_interpolate.hpp"
#include "transformations/common_optimizations/transpose_sinking_split.hpp" #include "transformations/common_optimizations/transpose_sinking_split.hpp"
#include "transformations/common_optimizations/transpose_sinking_unary.hpp" #include "transformations/common_optimizations/transpose_sinking_unary.hpp"
#include "transformations/utils/utils.hpp" #include "transformations/utils/utils.hpp"
@@ -26,6 +27,7 @@ ov::pass::TransposeSinkingGeneralForward::TransposeSinkingGeneralForward() {
add_matcher<ov::pass::TransposeSinkingSplitForward>(); add_matcher<ov::pass::TransposeSinkingSplitForward>();
add_matcher<ov::pass::TransposeSinkingDataMovementForward>(); add_matcher<ov::pass::TransposeSinkingDataMovementForward>();
add_matcher<ov::pass::TransposeSinkingReductionForward>(); add_matcher<ov::pass::TransposeSinkingReductionForward>();
add_matcher<ov::pass::TransposeSinkingInterpolateForward>();
add_matcher<ov::pass::TransposeFuse>(); add_matcher<ov::pass::TransposeFuse>();
} }
@@ -37,6 +39,7 @@ ov::pass::TransposeSinkingGeneralBackward::TransposeSinkingGeneralBackward() {
add_matcher<ov::pass::TransposeSinkingSplitBackward>(); add_matcher<ov::pass::TransposeSinkingSplitBackward>();
add_matcher<ov::pass::TransposeSinkingDataMovementBackward>(); add_matcher<ov::pass::TransposeSinkingDataMovementBackward>();
add_matcher<ov::pass::TransposeSinkingReductionBackward>(); add_matcher<ov::pass::TransposeSinkingReductionBackward>();
add_matcher<ov::pass::TransposeSinkingInterpolateBackward>();
add_matcher<ov::pass::TransposeFuse>(); add_matcher<ov::pass::TransposeFuse>();
} }

View File

@@ -44,12 +44,11 @@ ov::pass::TransposeSinkingInterpolateForward::TransposeSinkingInterpolateForward
main_node->input(0).replace_source_output(transpose_parent); main_node->input(0).replace_source_output(transpose_parent);
const auto transpose_axis_order = transpose_const->get_axis_vector_val(); const auto transpose_axis_order = transpose_const->get_axis_vector_val();
const auto reversed_transpose_order = ReverseTransposeOrder(transpose_axis_order);
auto axis = std::make_shared<Constant>(element::i32, Shape{}, std::vector<int32_t>{0}); auto axis = std::make_shared<Constant>(element::i32, Shape{}, std::vector<int32_t>{0});
const auto& interpolate = std::dynamic_pointer_cast<Interpolate>(main_node); const auto& interpolate = std::dynamic_pointer_cast<Interpolate>(main_node);
auto data = std::make_shared<Constant>(element::i32, Shape{reversed_transpose_order.size()}, reversed_transpose_order); auto data = std::make_shared<Constant>(element::i32, Shape{transpose_axis_order.size()}, transpose_axis_order);
const auto& indices = main_node->input_value(3); const auto& indices = main_node->input_value(3);
auto new_axis = std::make_shared<Gather>(data, indices, axis); auto new_axis = std::make_shared<Gather>(data, indices, axis);
@@ -57,13 +56,13 @@ ov::pass::TransposeSinkingInterpolateForward::TransposeSinkingInterpolateForward
if (interpolate) { if (interpolate) {
op::v4::Interpolate::InterpolateAttrs attrs = interpolate->get_attrs(); op::v4::Interpolate::InterpolateAttrs attrs = interpolate->get_attrs();
if (!attrs.pads_begin.empty() || !attrs.pads_end.empty()) { if (!attrs.pads_begin.empty() || !attrs.pads_end.empty()) {
const auto &order_size = reversed_transpose_order.size(); const auto& order_size = transpose_axis_order.size();
attrs.pads_begin.resize(order_size); attrs.pads_begin.resize(order_size);
attrs.pads_end.resize(order_size); attrs.pads_end.resize(order_size);
std::vector<size_t> new_pads_begin(order_size), new_pads_end(order_size); std::vector<size_t> new_pads_begin(order_size), new_pads_end(order_size);
for (size_t i = 0; i < order_size; ++i) { for (size_t i = 0; i < order_size; ++i) {
new_pads_begin[i] = attrs.pads_begin[reversed_transpose_order[i]]; new_pads_begin[i] = attrs.pads_begin[transpose_axis_order[i]];
new_pads_end[i] = attrs.pads_end[reversed_transpose_order[i]]; new_pads_end[i] = attrs.pads_end[transpose_axis_order[i]];
} }
std::swap(attrs.pads_begin, new_pads_begin); std::swap(attrs.pads_begin, new_pads_begin);
std::swap(attrs.pads_end, new_pads_end); std::swap(attrs.pads_end, new_pads_end);
@@ -71,7 +70,6 @@ ov::pass::TransposeSinkingInterpolateForward::TransposeSinkingInterpolateForward
} }
} }
main_node->validate_and_infer_types(); main_node->validate_and_infer_types();
TransposeInputsInfo transpose_input_info = {transpose, transpose_const, 0}; TransposeInputsInfo transpose_input_info = {transpose, transpose_const, 0};
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) { for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
@@ -95,9 +93,9 @@ ov::pass::TransposeSinkingInterpolateBackward::TransposeSinkingInterpolateBackwa
auto transpose_const_label = wrap_type<Constant>(); auto transpose_const_label = wrap_type<Constant>();
auto transpose_label = auto transpose_label =
wrap_type<Transpose>({main_node_label, transpose_const_label}, [](const Output<Node>& output) -> bool { wrap_type<Transpose>({main_node_label, transpose_const_label}, [](const Output<Node>& output) -> bool {
return has_static_rank()(output) && is_sinking_node(output); return has_static_rank()(output) && is_sinking_node(output);
}); });
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map(); const auto& pattern_to_output = m.get_pattern_value_map();
@@ -107,7 +105,7 @@ ov::pass::TransposeSinkingInterpolateBackward::TransposeSinkingInterpolateBackwa
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node, for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node,
transpose_const, transpose_const,
/* input_indexes= */ {0})) { /* input_indexes= */ {0})) {
register_new_node(new_node); register_new_node(new_node);
} }
@@ -115,21 +113,23 @@ ov::pass::TransposeSinkingInterpolateBackward::TransposeSinkingInterpolateBackwa
RemoveSingleOutputConsumers(main_node); RemoveSingleOutputConsumers(main_node);
const auto transpose_axis_order = transpose_const->get_axis_vector_val(); const auto transpose_axis_order = transpose_const->get_axis_vector_val();
const auto reversed_transpose_order = ReverseTransposeOrder(transpose_axis_order);
auto axis = std::make_shared<Constant>(element::i32, Shape{}, std::vector<int32_t>{0}); auto axis = std::make_shared<Constant>(element::i32, Shape{}, std::vector<int32_t>{0});
auto data = std::make_shared<Constant>(element::i32, Shape{transpose_axis_order.size()}, transpose_axis_order); auto data =
std::make_shared<Constant>(element::i32, Shape{reversed_transpose_order.size()}, reversed_transpose_order);
const auto& indices = main_node->input_value(3); const auto& indices = main_node->input_value(3);
auto new_axis = std::make_shared<Gather>(data, indices, axis); auto new_axis = std::make_shared<Gather>(data, indices, axis);
const auto& interpolate = std::dynamic_pointer_cast<Interpolate>(main_node); const auto& interpolate = std::dynamic_pointer_cast<Interpolate>(main_node);
if (interpolate) { if (interpolate) {
op::v4::Interpolate::InterpolateAttrs attrs = interpolate->get_attrs(); op::v4::Interpolate::InterpolateAttrs attrs = interpolate->get_attrs();
if (!attrs.pads_begin.empty() || !attrs.pads_end.empty()) { if (!attrs.pads_begin.empty() || !attrs.pads_end.empty()) {
const auto &order_size = transpose_axis_order.size(); const auto& order_size = reversed_transpose_order.size();
attrs.pads_begin.resize(order_size); attrs.pads_begin.resize(order_size);
attrs.pads_end.resize(order_size); attrs.pads_end.resize(order_size);
std::vector<size_t> new_pads_begin(order_size), new_pads_end(order_size); std::vector<size_t> new_pads_begin(order_size), new_pads_end(order_size);
for (size_t i = 0; i < order_size; ++i) { for (size_t i = 0; i < order_size; ++i) {
new_pads_begin[i] = attrs.pads_begin[transpose_axis_order[i]]; new_pads_begin[i] = attrs.pads_begin[reversed_transpose_order[i]];
new_pads_end[i] = attrs.pads_end[transpose_axis_order[i]]; new_pads_end[i] = attrs.pads_end[reversed_transpose_order[i]];
} }
std::swap(attrs.pads_begin, new_pads_begin); std::swap(attrs.pads_begin, new_pads_begin);
std::swap(attrs.pads_end, new_pads_end); std::swap(attrs.pads_end, new_pads_end);

View File

@@ -8,9 +8,9 @@
#include <vector> #include <vector>
#include "itt.hpp" #include "itt.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/opsets/opset10.hpp" #include "openvino/opsets/opset10.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/core/validation_util.hpp"
#include "transformations/common_optimizations/transpose_sinking_utils.hpp" #include "transformations/common_optimizations/transpose_sinking_utils.hpp"
#include "transformations/utils/utils.hpp" #include "transformations/utils/utils.hpp"
@@ -77,19 +77,16 @@ bool get_keep_dims(const std::shared_ptr<Node>& reduction) {
keep_dims = arithmetic_reduce->get_keep_dims(); keep_dims = arithmetic_reduce->get_keep_dims();
return keep_dims; return keep_dims;
} }
} } // namespace
ov::pass::TransposeSinkingReductionForward::TransposeSinkingReductionForward() { ov::pass::TransposeSinkingReductionForward::TransposeSinkingReductionForward() {
MATCHER_SCOPE(TransposeSinkingReductionForward); MATCHER_SCOPE(TransposeSinkingReductionForward);
auto transpose_label = auto transpose_label = pattern::wrap_type<Transpose>({pattern::any_input(), pattern::wrap_type<Constant>()},
pattern::wrap_type<Transpose>({pattern::any_input(), pattern::wrap_type<Constant>()}, pattern::consumers_count(1));
pattern::consumers_count(1)); auto reduce_or_squeeze_label = pattern::
auto reduce_or_squeeze_label = wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims, Squeeze, Unsqueeze>(
pattern::wrap_type<op::util::ArithmeticReductionKeepDims, {transpose_label, pattern::wrap_type<Constant>()});
op::util::LogicalReductionKeepDims,
Squeeze,
Unsqueeze>({transpose_label, pattern::wrap_type<Constant>()});
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) { ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map(); const auto& pattern_to_output = m.get_pattern_value_map();
@@ -104,9 +101,9 @@ ov::pass::TransposeSinkingReductionForward::TransposeSinkingReductionForward() {
return false; return false;
auto unsqueeze = std::dynamic_pointer_cast<Unsqueeze>(reduction); auto unsqueeze = std::dynamic_pointer_cast<Unsqueeze>(reduction);
auto rank = auto rank =
unsqueeze ? reduction->get_output_partial_shape(0).rank() : reduction->get_input_partial_shape(0).rank(); unsqueeze ? reduction->get_output_partial_shape(0).rank() : reduction->get_input_partial_shape(0).rank();
auto non_negative_axes = auto non_negative_axes =
normalize_axes(reduction->get_friendly_name(), reduction_axes->cast_vector<int64_t>(), rank); normalize_axes(reduction->get_friendly_name(), reduction_axes->cast_vector<int64_t>(), rank);
auto transpose_order_values = transpose_order->cast_vector<size_t>(); auto transpose_order_values = transpose_order->cast_vector<size_t>();
std::vector<size_t> new_values; std::vector<size_t> new_values;
@@ -136,14 +133,13 @@ ov::pass::TransposeSinkingReductionForward::TransposeSinkingReductionForward() {
} }
} }
auto new_transpose_order = std::make_shared<Constant>(transpose_order->get_element_type(), auto new_transpose_order = std::make_shared<Constant>(transpose_order->get_element_type(),
Shape{transpose_order_values.size()}, Shape{transpose_order_values.size()},
transpose_order_values); transpose_order_values);
std::shared_ptr<Node> new_reduction; std::shared_ptr<Node> new_reduction;
if (!unsqueeze) { if (!unsqueeze) {
auto new_const = std::make_shared<Constant>(reduction_axes->get_element_type(), auto new_const =
reduction_axes->get_shape(), std::make_shared<Constant>(reduction_axes->get_element_type(), reduction_axes->get_shape(), new_values);
new_values);
new_reduction = reduction->clone_with_new_inputs({transpose->input_value(0), new_const}); new_reduction = reduction->clone_with_new_inputs({transpose->input_value(0), new_const});
} else { } else {
new_reduction = reduction->clone_with_new_inputs({transpose->input_value(0), reduction->input_value(1)}); new_reduction = reduction->clone_with_new_inputs({transpose->input_value(0), reduction->input_value(1)});
@@ -165,14 +161,11 @@ ov::pass::TransposeSinkingReductionForward::TransposeSinkingReductionForward() {
ov::pass::TransposeSinkingReductionBackward::TransposeSinkingReductionBackward() { ov::pass::TransposeSinkingReductionBackward::TransposeSinkingReductionBackward() {
MATCHER_SCOPE(TransposeSinkingReductionBackward); MATCHER_SCOPE(TransposeSinkingReductionBackward);
auto reduce_or_squeeze_label = auto reduce_or_squeeze_label = pattern::
pattern::wrap_type<op::util::ArithmeticReductionKeepDims, wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims, Squeeze, Unsqueeze>(
op::util::LogicalReductionKeepDims, {pattern::any_input(), pattern::wrap_type<Constant>()},
Squeeze, transpose_sinking::HasSameOutputTransposeNodes);
Unsqueeze>({pattern::any_input(), pattern::wrap_type<Constant>()}, auto transpose_label = pattern::wrap_type<Transpose>({reduce_or_squeeze_label, pattern::wrap_type<Constant>()});
transpose_sinking::HasSameOutputTransposeNodes);
auto transpose_label =
pattern::wrap_type<Transpose>({reduce_or_squeeze_label, pattern::wrap_type<Constant>()});
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) { ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map(); const auto& pattern_to_output = m.get_pattern_value_map();
@@ -186,9 +179,9 @@ ov::pass::TransposeSinkingReductionBackward::TransposeSinkingReductionBackward()
auto unsqueeze = std::dynamic_pointer_cast<Unsqueeze>(reduction); auto unsqueeze = std::dynamic_pointer_cast<Unsqueeze>(reduction);
auto rank = auto rank =
unsqueeze ? reduction->get_output_partial_shape(0).rank() : reduction->get_input_partial_shape(0).rank(); unsqueeze ? reduction->get_output_partial_shape(0).rank() : reduction->get_input_partial_shape(0).rank();
auto non_negative_axes = auto non_negative_axes =
normalize_axes(reduction->get_friendly_name(), reduction_axes->cast_vector<int64_t>(), rank); normalize_axes(reduction->get_friendly_name(), reduction_axes->cast_vector<int64_t>(), rank);
auto transpose_order_values = transpose_order->cast_vector<size_t>(); auto transpose_order_values = transpose_order->cast_vector<size_t>();
auto old_transpose_order_values = transpose_order_values; auto old_transpose_order_values = transpose_order_values;
@@ -249,9 +242,8 @@ ov::pass::TransposeSinkingReductionBackward::TransposeSinkingReductionBackward()
copy_runtime_info({transpose, reduction}, {new_transpose, new_reduction}); copy_runtime_info({transpose, reduction}, {new_transpose, new_reduction});
register_new_node(new_transpose); register_new_node(new_transpose);
} else { } else {
auto new_const = std::make_shared<Constant>(reduction_axes->get_element_type(), auto new_const =
reduction_axes->get_shape(), std::make_shared<Constant>(reduction_axes->get_element_type(), reduction_axes->get_shape(), new_values);
new_values);
auto new_transpose = transpose->clone_with_new_inputs({reduction->input_value(0), new_transpose_order}); auto new_transpose = transpose->clone_with_new_inputs({reduction->input_value(0), new_transpose_order});
auto new_reduction = reduction->clone_with_new_inputs({new_transpose, new_const}); auto new_reduction = reduction->clone_with_new_inputs({new_transpose, new_const});
replace_node(transpose, new_reduction); replace_node(transpose, new_reduction);

View File

@@ -55,7 +55,9 @@ ov::pass::TransposeSinkingUnaryForward::TransposeSinkingUnaryForward() {
MATCHER_SCOPE(TransposeSinkingUnaryForward); MATCHER_SCOPE(TransposeSinkingUnaryForward);
auto transpose_label = wrap_type<Transpose>({any_input(), any_input()}); auto transpose_label = wrap_type<Transpose>({any_input(), any_input()});
auto unary_label = wrap_type<UnaryElementwiseArithmetic, Clamp, Elu, SoftPlus, LogicalNot, Convert, IsInf, IsNaN, IsFinite>({transpose_label}); auto unary_label =
wrap_type<UnaryElementwiseArithmetic, Clamp, Elu, SoftPlus, LogicalNot, Convert, IsInf, IsNaN, IsFinite>(
{transpose_label});
ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map(); const auto& pattern_to_output = m.get_pattern_value_map();
@@ -90,7 +92,8 @@ ov::pass::TransposeSinkingUnaryBackward::TransposeSinkingUnaryBackward() {
auto unary_label = auto unary_label =
wrap_type<UnaryElementwiseArithmetic, Clamp, Elu, SoftPlus, LogicalNot, Convert, IsInf, IsNaN, IsFinite>( wrap_type<UnaryElementwiseArithmetic, Clamp, Elu, SoftPlus, LogicalNot, Convert, IsInf, IsNaN, IsFinite>(
{any_input()}, unary_restrictions); {any_input()},
unary_restrictions);
auto transpose_const_label = wrap_type<Constant>(); auto transpose_const_label = wrap_type<Constant>();

View File

@@ -292,14 +292,15 @@ INSTANTIATE_TEST_SUITE_P(TransposeSinkingPadForwardSingleConsumerTestSuite,
::testing::Values(element::f32)), ::testing::Values(element::f32)),
TransposeSinkingPadTestFixture::get_test_name); TransposeSinkingPadTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(TransposeSinkingPadBackwardSingleConsumerTestSuite, INSTANTIATE_TEST_SUITE_P(
TransposeSinkingPadTestFixture, TransposeSinkingPadBackwardSingleConsumerTestSuite,
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingDataMovementBackward)), TransposeSinkingPadTestFixture,
::testing::ValuesIn(pad_operations_numbers), ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingDataMovementBackward)),
::testing::Values(backward::single_consumer::CreateFunction), ::testing::ValuesIn(pad_operations_numbers),
::testing::Values(backward::single_consumer::CreateReferenceFunction), ::testing::Values(backward::single_consumer::CreateFunction),
::testing::Values(element::f32)), ::testing::Values(backward::single_consumer::CreateReferenceFunction),
TransposeSinkingPadTestFixture::get_test_name); ::testing::Values(element::f32)),
TransposeSinkingPadTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(
TransposeSinkingPadBackwardSingleConsumerMultiTransposesTestSuite, TransposeSinkingPadBackwardSingleConsumerMultiTransposesTestSuite,