Support TS for Interpolate, VariadicSplit, IsInf, IsNan, IsFinite + refactoring

This commit is contained in:
Ivan 2023-03-02 19:09:36 +04:00
parent b769d21912
commit 123835c86d
27 changed files with 745 additions and 694 deletions

View File

@ -7,6 +7,8 @@
#include <memory>
#include <openvino/pass/graph_rewrite.hpp>
#include <openvino/pass/pattern/matcher.hpp>
#include <transformations/common_optimizations/transpose_sinking_fuse.hpp>
#include <transformations/common_optimizations/transpose_sinking_reduction.hpp>
#include <transformations_visibility.hpp>
#include <vector>
@ -16,34 +18,11 @@ namespace pass {
class TRANSFORMATIONS_API TransposeSinking;
class TRANSFORMATIONS_API TransposeConvert;
class TRANSFORMATIONS_API TransposeEltwise;
class TRANSFORMATIONS_API TransposeReduction;
class TRANSFORMATIONS_API TransposeReductionBackward;
class TRANSFORMATIONS_API TransposeFQReduction;
class TRANSFORMATIONS_API TransposeFuse;
} // namespace pass
} // namespace ov
/**
 * @ingroup ie_transformation_common_api
 * @brief TransposeReduction transformation sinks Transpose through Reduce operations
 * in the forward direction: Transpose -> Reduce is rewritten as Reduce -> Transpose.
 */
class ov::pass::TransposeReduction : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("TransposeReduction", "0");
TransposeReduction();
};
/**
 * @ingroup ie_transformation_common_api
 * @brief TransposeReductionBackward transformation sinks Transpose through Reduce operations
 * in the backward direction: Reduce -> Transpose is rewritten as Transpose -> Reduce.
 */
class ov::pass::TransposeReductionBackward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("TransposeReductionBackward", "0");
TransposeReductionBackward();
};
/**
* @ingroup ie_transformation_common_api
* @brief TransposeFQReduction transformation sinks Transpose through FakeQuantize in case it is followed by reduction
@ -75,17 +54,6 @@ public:
TransposeEltwise();
};
/**
 * @ingroup ie_transformation_common_api
 * @brief TransposeFuse transformation eliminates 2 consecutive Transposes if they result in no changes to input or
 * fuses them to single Transpose if input gets changed
 */
class ov::pass::TransposeFuse : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("TransposeFuse", "0");
TransposeFuse();
};
/**
* @ingroup ie_transformation_common_api
* @brief TransposeSinking transformation sinks Transposes through known operations
@ -95,7 +63,7 @@ public:
OPENVINO_RTTI("TransposeSinking", "0");
TransposeSinking() {
add_matcher<ov::pass::TransposeFQReduction>();
add_matcher<ov::pass::TransposeReduction>();
add_matcher<ov::pass::TransposeSinkingReductionForward>();
add_matcher<ov::pass::TransposeConvert>();
add_matcher<ov::pass::TransposeEltwise>();
add_matcher<ov::pass::TransposeFuse>();

View File

@ -1,30 +0,0 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace pass {
class TRANSFORMATIONS_API TransposeSinkingBatchToSpaceForward;
class TRANSFORMATIONS_API TransposeSinkingBatchToSpaceBackward;
} // namespace pass
} // namespace ov
/**
 * @ingroup ie_transformation_common_api
 * @brief TransposeSinkingBatchToSpaceForward sinks a Transpose on the data input through
 * BatchToSpace/SpaceToBatch, permuting the shape-like inputs and re-inserting the Transpose on the output.
 */
class ov::pass::TransposeSinkingBatchToSpaceForward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingBTSForward", "0");
TransposeSinkingBatchToSpaceForward();
};
/**
 * @ingroup ie_transformation_common_api
 * @brief TransposeSinkingBatchToSpaceBackward moves an output Transpose of
 * BatchToSpace/SpaceToBatch to the node's data input, permuting the shape-like inputs accordingly.
 */
class ov::pass::TransposeSinkingBatchToSpaceBackward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingBTSBackward", "0");
TransposeSinkingBatchToSpaceBackward();
};

View File

@ -1,4 +1,4 @@
// Copyright (C) 2022 Intel Corporation
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

View File

@ -1,4 +1,4 @@
// Copyright (C) 2022 Intel Corporation
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

View File

@ -0,0 +1,30 @@
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace pass {
class TRANSFORMATIONS_API TransposeSinkingDataMovementForward;
class TRANSFORMATIONS_API TransposeSinkingDataMovementBackward;
} // namespace pass
} // namespace ov
/**
 * @ingroup ie_transformation_common_api
 * @brief TransposeSinkingDataMovementForward sinks a Transpose on the data input through data-movement
 * operations (Pad, BatchToSpace, SpaceToBatch), moving it to the operation output.
 */
class ov::pass::TransposeSinkingDataMovementForward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingDataMovementForward", "0");
TransposeSinkingDataMovementForward();
};
/**
 * @ingroup ie_transformation_common_api
 * @brief TransposeSinkingDataMovementBackward moves an output Transpose of data-movement
 * operations (Pad, BatchToSpace, SpaceToBatch) to the operation data input.
 */
class ov::pass::TransposeSinkingDataMovementBackward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingDataMovementBackward", "0");
TransposeSinkingDataMovementBackward();
};

View File

@ -0,0 +1,28 @@
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace pass {
class TRANSFORMATIONS_API TransposeFuse;
} // namespace pass
} // namespace ov
/**
 * @ingroup ie_transformation_common_api
 * @brief TransposeFuse transformation eliminates 2 consecutive Transposes if they result in no changes to input or
 * fuses them to single Transpose if input gets changed. The fusion applies when all Transpose consumers of the
 * first Transpose share the same constant order.
 */
class ov::pass::TransposeFuse : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("TransposeFuse", "0");
TransposeFuse();
};

View File

@ -1,4 +1,4 @@
// Copyright (C) 2022 Intel Corporation
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

View File

@ -0,0 +1,30 @@
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace pass {
class TRANSFORMATIONS_API TransposeSinkingInterpolateForward;
class TRANSFORMATIONS_API TransposeSinkingInterpolateBackward;
} // namespace pass
} // namespace ov
/**
 * @ingroup ie_transformation_common_api
 * @brief TransposeSinkingInterpolateForward sinks a Transpose on the data input through an Interpolate
 * operation (implementation not shown here — presumably the axes input is remapped; verify against the .cpp).
 */
class ov::pass::TransposeSinkingInterpolateForward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingInterpolateForward", "0");
TransposeSinkingInterpolateForward();
};
/**
 * @ingroup ie_transformation_common_api
 * @brief TransposeSinkingInterpolateBackward moves an output Transpose of an Interpolate
 * operation to its data input (implementation not shown here; verify against the .cpp).
 */
class ov::pass::TransposeSinkingInterpolateBackward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingInterpolateBackward", "0");
TransposeSinkingInterpolateBackward();
};

View File

@ -1,30 +0,0 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace pass {
class TRANSFORMATIONS_API TransposeSinkingPadForward;
class TRANSFORMATIONS_API TransposeSinkingPadBackward;
} // namespace pass
} // namespace ov
/**
 * @ingroup ie_transformation_common_api
 * @brief TransposeSinkingPadForward sinks a Transpose on the data input through a Pad operation,
 * permuting the pads_begin/pads_end inputs and moving the Transpose to the Pad output.
 */
class ov::pass::TransposeSinkingPadForward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingPadForward", "0");
TransposeSinkingPadForward();
};
/**
 * @ingroup ie_transformation_common_api
 * @brief TransposeSinkingPadBackward moves an output Transpose of a Pad operation to its data input,
 * permuting the pads_begin/pads_end inputs accordingly.
 */
class ov::pass::TransposeSinkingPadBackward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingPadBackward", "0");
TransposeSinkingPadBackward();
};

View File

@ -0,0 +1,38 @@
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace pass {
class TRANSFORMATIONS_API TransposeSinkingReductionForward;
class TRANSFORMATIONS_API TransposeSinkingReductionBackward;
} // namespace pass
} // namespace ov
/**
 * @ingroup ie_transformation_common_api
 * @brief TransposeSinkingReductionForward transformation sinks Transpose through Reduce operations
 * in the forward direction.
 */
class ov::pass::TransposeSinkingReductionForward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingReductionForward", "0");
TransposeSinkingReductionForward();
};
/**
 * @ingroup ie_transformation_common_api
 * @brief TransposeSinkingReductionBackward transformation sinks Transpose through Reduce operations
 * in the backward direction.
 */
class ov::pass::TransposeSinkingReductionBackward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingReductionBackward", "0");
TransposeSinkingReductionBackward();
};

View File

@ -1,4 +1,4 @@
// Copyright (C) 2022 Intel Corporation
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

View File

@ -1,4 +1,4 @@
// Copyright (C) 2022 Intel Corporation
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
@ -11,8 +11,6 @@ namespace ov {
namespace pass {
class TRANSFORMATIONS_API TransposeSinkingUnaryForward;
class TRANSFORMATIONS_API TransposeSinkingUnaryBackwardSingleConsumer;
class TRANSFORMATIONS_API TransposeSinkingUnaryBackwardMultiConsumers;
class TRANSFORMATIONS_API TransposeSinkingUnaryBackward;
} // namespace pass
@ -24,20 +22,8 @@ public:
TransposeSinkingUnaryForward();
};
class ov::pass::TransposeSinkingUnaryBackwardSingleConsumer : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("TransposeSinkingUnaryBackwardSingleConsumer", "0");
TransposeSinkingUnaryBackwardSingleConsumer();
};
class ov::pass::TransposeSinkingUnaryBackwardMultiConsumers : public ov::pass::MatcherPass {
class ov::pass::TransposeSinkingUnaryBackward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("TransposeSinkingUnaryBackwardMultiConsumers", "0");
TransposeSinkingUnaryBackwardMultiConsumers();
};
class ov::pass::TransposeSinkingUnaryBackward : public ov::pass::GraphRewrite {
public:
OPENVINO_RTTI("TransposeSinkingUnaryBackward", "0");
TransposeSinkingUnaryBackward();
};

View File

@ -1,4 +1,4 @@
// Copyright (C) 2022 Intel Corporation
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

View File

@ -20,65 +20,6 @@
using namespace ov;
namespace {
// Computes the transpose order that stays valid after the axes listed in
// `axes_values` have been removed from the tensor (keep_dims == false case).
// Each surviving order entry is shifted down by the number of removed order
// values that are smaller than it.
std::vector<size_t> get_updated_order_forward(const std::vector<size_t>& axes_values,
                                              const std::vector<size_t>& order_values) {
    // Collect the order values that correspond to the reduced axes; keep them
    // sorted so the shift for any surviving value is a binary search away.
    std::vector<size_t> removed_values;
    removed_values.reserve(axes_values.size());
    for (const auto axis : axes_values) {
        removed_values.push_back(order_values[axis]);
    }
    std::sort(removed_values.begin(), removed_values.end());

    std::vector<size_t> updated_order;
    updated_order.reserve(order_values.size() - axes_values.size());
    for (size_t idx = 0; idx < order_values.size(); ++idx) {
        const bool axis_is_reduced =
            std::find(axes_values.begin(), axes_values.end(), idx) != axes_values.end();
        if (axis_is_reduced) {
            continue;
        }
        const size_t value = order_values[idx];
        // Number of removed order values strictly below `value`.
        const size_t shift =
            static_cast<size_t>(std::lower_bound(removed_values.begin(), removed_values.end(), value) -
                                removed_values.begin());
        updated_order.push_back(value - shift);
    }
    return updated_order;
}
// Restores a transpose order for a tensor whose rank was reduced by the axes
// in `axes_values`: every removed axis keeps its own position (identity
// mapping), and the remaining entries of `order_values` are remapped into the
// expanded rank.
std::vector<size_t> get_updated_order_backward(const std::vector<size_t>& axes_values,
                                               const std::vector<size_t>& order_values) {
    const size_t expanded_rank = order_values.size() + axes_values.size();

    // shifted_index[i]: the index dimension i had before the removed axes were
    // re-inserted (entries at removed axes are never looked up as targets of
    // surviving order values).
    std::vector<int64_t> shifted_index(expanded_rank);
    int64_t removed_so_far = 0;
    for (int64_t i = 0; i < static_cast<int64_t>(expanded_rank); ++i) {
        if (std::find(axes_values.begin(), axes_values.end(), i) != axes_values.end()) {
            ++removed_so_far;
        }
        shifted_index[i] = i - removed_so_far;
    }

    std::vector<size_t> updated_order(expanded_rank);
    size_t next_order_value = 0;
    for (size_t i = 0; i < expanded_rank; ++i) {
        const bool axis_was_removed =
            std::find(axes_values.begin(), axes_values.end(), i) != axes_values.end();
        if (axis_was_removed) {
            updated_order[i] = i;  // removed axes stay in place
            continue;
        }
        const auto target = static_cast<int64_t>(order_values[next_order_value]);
        updated_order[i] = static_cast<size_t>(
            std::find(shifted_index.begin(), shifted_index.end(), target) - shifted_index.begin());
        ++next_order_value;
    }
    return updated_order;
}
// Returns the keep_dims flag of a reduction node. Squeeze/Unsqueeze (and any
// node type without such a flag) report false, since Squeeze always reduces
// the number of output dimensions.
bool get_keep_dims(const std::shared_ptr<Node>& reduction) {
    if (auto logical_reduce = std::dynamic_pointer_cast<op::util::LogicalReductionKeepDims>(reduction)) {
        return logical_reduce->get_keep_dims();
    }
    if (auto arithmetic_reduce = std::dynamic_pointer_cast<op::util::ArithmeticReductionKeepDims>(reduction)) {
        return arithmetic_reduce->get_keep_dims();
    }
    return false;
}
std::shared_ptr<opset6::Constant> get_reversed_order_constant(const std::shared_ptr<opset6::Constant>& order_const) {
const auto& order = order_const->cast_vector<size_t>();
@ -168,194 +109,6 @@ ov::pass::TransposeConvert::TransposeConvert() {
register_matcher(m, matcher_pass_callback);
}
// Backward sinking: moves a Transpose that consumes the output of a
// Reduce/Squeeze/Unsqueeze node up to the node's data input, rewriting the
// transpose order and the axes constant so the overall result is unchanged.
ov::pass::TransposeReductionBackward::TransposeReductionBackward() {
MATCHER_SCOPE(TransposeReductionBackward);
// Match Reduce/Squeeze/Unsqueeze (with constant axes) whose consumers are
// all Transpose nodes sharing the same order.
auto reduce_or_squeeze_label =
pattern::wrap_type<op::util::ArithmeticReductionKeepDims,
op::util::LogicalReductionKeepDims,
opset6::Squeeze,
opset6::Unsqueeze>({pattern::any_input(), pattern::wrap_type<opset6::Constant>()},
transpose_sinking::HasSameOutputTransposeNodes);
auto transpose_label =
pattern::wrap_type<opset6::Transpose>({reduce_or_squeeze_label, pattern::wrap_type<opset6::Constant>()});
ov::matcher_pass_callback matcher_pass_callback = [=](ngraph::pattern::Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
auto reduction = pattern_to_output.at(reduce_or_squeeze_label).get_node_shared_ptr();
auto keep_dims = get_keep_dims(reduction);
// Both the transpose order and the reduction axes must be constants.
auto transpose_order = std::dynamic_pointer_cast<opset6::Constant>(transpose->get_input_node_shared_ptr(1));
auto reduction_axes = std::dynamic_pointer_cast<opset6::Constant>(reduction->get_input_node_shared_ptr(1));
if (!transpose_order || !reduction_axes)
return false;
auto unsqueeze = std::dynamic_pointer_cast<opset6::Unsqueeze>(reduction);
// Unsqueeze axes are relative to its output rank; other ops' axes are
// relative to the input rank.
auto rank =
unsqueeze ? reduction->get_output_partial_shape(0).rank() : reduction->get_input_partial_shape(0).rank();
auto non_negative_axes =
normalize_axes(reduction->get_friendly_name(), reduction_axes->cast_vector<int64_t>(), rank);
auto transpose_order_values = transpose_order->cast_vector<size_t>();
auto old_transpose_order_values = transpose_order_values;
std::vector<size_t> new_values;
if (unsqueeze) {
if (non_negative_axes.size() == transpose_order_values.size()) {
// input is a scalar, we unsqueeze all dims
// it's enough to eliminate such Transpose
transpose->output(0).replace(reduction);
return true;
}
// Map each unsqueezed axis to its position in the transpose order.
for (const auto& axis : non_negative_axes) {
auto it = std::find(old_transpose_order_values.begin(), old_transpose_order_values.end(), axis);
if (it != old_transpose_order_values.end()) {
new_values.push_back(it - old_transpose_order_values.begin());
}
}
}
bool special_case = false;
if (!keep_dims) {
if (non_negative_axes.empty()) {
// Empty axes constant: the op reduces every static dimension of
// size 1; a dynamic shape cannot be handled here.
auto input_pshape = reduction->input_value(0).get_partial_shape();
if (input_pshape.is_static()) {
for (size_t i = 0; i < input_pshape.size(); ++i) {
if (input_pshape[i] == 1) {
non_negative_axes.push_back(i);
}
}
special_case = true;
} else {
return false;
}
}
// The op changes rank, so the transpose order must be realigned.
if (unsqueeze) {
transpose_order_values = get_updated_order_forward(new_values, transpose_order_values);
} else {
transpose_order_values = get_updated_order_backward(non_negative_axes, transpose_order_values);
}
}
if (!unsqueeze) {
// Express the reduction axes in the coordinates seen before the
// (moved) transpose.
auto reversed_order_values = transpose_sinking::ReverseTransposeOrder(transpose_order_values);
for (const auto& axis : non_negative_axes) {
new_values.push_back(reversed_order_values[axis]);
}
}
auto new_transpose_order = std::make_shared<opset6::Constant>(transpose_order->get_element_type(),
Shape{transpose_order_values.size()},
transpose_order_values);
if (special_case) {
// Keep the original axes input; only the Transpose moves above the node.
auto new_transpose = transpose->clone_with_new_inputs({reduction->input_value(0), new_transpose_order});
auto new_reduction = reduction->clone_with_new_inputs({new_transpose, reduction->input_value(1)});
new_reduction->set_friendly_name(transpose->get_friendly_name());
replace_node(transpose, new_reduction);
transpose_sinking::UpdateForwardSinkingAbility(new_transpose);
copy_runtime_info({transpose, reduction}, {new_transpose, new_reduction});
register_new_node(new_transpose);
} else {
// Rebuild both the axes constant and the node pair in swapped order.
auto new_const = std::make_shared<opset6::Constant>(reduction_axes->get_element_type(),
reduction_axes->get_shape(),
new_values);
auto new_transpose = transpose->clone_with_new_inputs({reduction->input_value(0), new_transpose_order});
auto new_reduction = reduction->clone_with_new_inputs({new_transpose, new_const});
replace_node(transpose, new_reduction);
copy_runtime_info({transpose, reduction}, {new_transpose, new_reduction});
transpose_sinking::UpdateForwardSinkingAbility(new_transpose);
new_reduction->set_friendly_name(transpose->get_friendly_name());
register_new_node(new_transpose);
}
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(transpose_label, matcher_name);
register_matcher(m, matcher_pass_callback);
}
// Forward sinking: moves a single-consumer Transpose that feeds a
// Reduce/Squeeze/Unsqueeze node below it, so the graph becomes
// Reduce -> Transpose instead of Transpose -> Reduce.
ov::pass::TransposeReduction::TransposeReduction() {
MATCHER_SCOPE(TransposeReduction);
auto transpose_label =
pattern::wrap_type<opset6::Transpose>({pattern::any_input(), pattern::wrap_type<opset6::Constant>()},
pattern::consumers_count(1));
auto reduce_or_squeeze_label =
pattern::wrap_type<op::util::ArithmeticReductionKeepDims,
op::util::LogicalReductionKeepDims,
opset6::Squeeze,
opset6::Unsqueeze>({transpose_label, pattern::wrap_type<opset6::Constant>()});
ov::matcher_pass_callback matcher_pass_callback = [=](ngraph::pattern::Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
auto reduction = pattern_to_output.at(reduce_or_squeeze_label).get_node_shared_ptr();
auto keep_dims = get_keep_dims(reduction);
// Both the transpose order and the reduction axes must be constants.
auto transpose_order = std::dynamic_pointer_cast<opset6::Constant>(transpose->get_input_node_shared_ptr(1));
auto reduction_axes = std::dynamic_pointer_cast<opset6::Constant>(reduction->get_input_node_shared_ptr(1));
if (!transpose_order || !reduction_axes)
return false;
auto unsqueeze = std::dynamic_pointer_cast<opset6::Unsqueeze>(reduction);
// Unsqueeze axes are relative to its output rank; other ops' axes are
// relative to the input rank.
auto rank =
unsqueeze ? reduction->get_output_partial_shape(0).rank() : reduction->get_input_partial_shape(0).rank();
auto non_negative_axes =
normalize_axes(reduction->get_friendly_name(), reduction_axes->cast_vector<int64_t>(), rank);
auto transpose_order_values = transpose_order->cast_vector<size_t>();
std::vector<size_t> new_values;
new_values.reserve(non_negative_axes.size());
// Reduction axes expressed in the pre-transpose coordinate system.
for (const auto& axis : non_negative_axes) {
new_values.push_back(transpose_order_values[axis]);
}
if (!keep_dims) {
if (non_negative_axes.empty()) {
// Empty axes constant: the op reduces every static dimension of
// size 1; a dynamic shape cannot be handled here.
auto input_pshape = transpose->input_value(0).get_partial_shape();
if (input_pshape.is_static()) {
for (size_t i = 0; i < input_pshape.size(); ++i) {
if (input_pshape[i] == 1) {
non_negative_axes.push_back(i);
}
}
} else {
return false;
}
}
// The op changes rank, so the transpose order must be realigned.
if (unsqueeze) {
transpose_order_values = get_updated_order_backward(non_negative_axes, transpose_order_values);
} else {
transpose_order_values = get_updated_order_forward(non_negative_axes, transpose_order_values);
}
}
auto new_transpose_order = std::make_shared<opset6::Constant>(transpose_order->get_element_type(),
Shape{transpose_order_values.size()},
transpose_order_values);
std::shared_ptr<Node> new_reduction;
if (!unsqueeze) {
auto new_const = std::make_shared<opset6::Constant>(reduction_axes->get_element_type(),
reduction_axes->get_shape(),
new_values);
new_reduction = reduction->clone_with_new_inputs({transpose->input_value(0), new_const});
} else {
// Unsqueeze keeps its original axes input unchanged.
new_reduction = reduction->clone_with_new_inputs({transpose->input_value(0), reduction->input_value(1)});
}
auto new_transpose = transpose->clone_with_new_inputs({new_reduction, new_transpose_order});
replace_node(reduction, new_transpose);
// Swap friendly names so the graph output keeps the reduction's name.
new_reduction->set_friendly_name(transpose->get_friendly_name());
new_transpose->set_friendly_name(reduction->get_friendly_name());
transpose_sinking::UpdateForwardSinkingAbility(new_transpose);
register_new_node(new_transpose);
copy_runtime_info({transpose, reduction}, {new_transpose, new_reduction});
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(reduce_or_squeeze_label, matcher_name);
register_matcher(m, matcher_pass_callback);
}
ov::pass::TransposeFQReduction::TransposeFQReduction() {
MATCHER_SCOPE(TransposeFQReduction);
@ -423,80 +176,3 @@ ov::pass::TransposeFQReduction::TransposeFQReduction() {
auto m = std::make_shared<ngraph::pattern::Matcher>(reduce_or_squeeze_label, matcher_name);
register_matcher(m, matcher_pass_callback);
}
// Fuses Transpose -> Transpose chains. All Transpose consumers of the first
// Transpose must carry the same constant order; if the composed permutation
// is identity the pair is removed entirely, otherwise a single fused
// Transpose replaces them.
ov::pass::TransposeFuse::TransposeFuse() {
MATCHER_SCOPE(TransposeFuse);
auto transpose_label =
pattern::wrap_type<opset7::Transpose>({pattern::any_input(), pattern::wrap_type<opset7::Constant>()});
ov::matcher_pass_callback matcher_pass_callback = [=](ngraph::pattern::Matcher& m) {
const auto& pattern_to_output = m.get_pattern_map();
auto transpose_1 = pattern_to_output.at(transpose_label);
auto order_const_1 =
std::dynamic_pointer_cast<opset7::Constant>(transpose_1->input_value(1).get_node_shared_ptr());
auto consumers = transpose_1->get_output_target_inputs(0);
// Verify every consumer is a Transpose with the identical constant order.
std::vector<int64_t> saved_order_values;
auto saved_type = order_const_1->get_element_type();
for (const auto& it : consumers) {
auto out_transpose = dynamic_cast<opset7::Transpose*>(it.get_node());
if (!out_transpose) {
return false;
}
auto order = out_transpose->input_value(1).get_node_shared_ptr();
auto order_const = std::dynamic_pointer_cast<opset7::Constant>(order);
if (!order_const) {
return false;
}
auto order_values = order_const->cast_vector<int64_t>();
if (order_values.empty()) {
return false;
}
if (saved_order_values.empty()) {
saved_order_values = order_values;
} else {
if (saved_order_values != order_values) {
return false;
}
}
// If the two order constants disagree on element type, fall back to i64.
if (order_const->get_element_type() != saved_type) {
saved_type = element::i64;
}
}
auto order1 = order_const_1->cast_vector<int64_t>();
if (order1.size() != saved_order_values.size()) {
return false;
}
// Compose the two permutations in place; track whether the result is identity.
bool is_ordered = true;
for (size_t i = 0; i < order1.size(); i++) {
saved_order_values[i] = order1[saved_order_values[i]];
if (saved_order_values[i] != (int64_t)i)
is_ordered = false;
}
if (is_ordered) {
// Identity permutation: bypass both Transposes entirely.
for (const auto& it : consumers) {
it.get_node()->output(0).replace(transpose_1->input_value(0));
}
} else {
// Replace the chain with one Transpose carrying the composed order.
auto new_order = opset7::Constant::create(saved_type, {saved_order_values.size()}, saved_order_values);
auto new_transpose = register_new_node<opset7::Transpose>(transpose_1->input_value(0), new_order);
for (const auto& it : consumers) {
new_transpose->set_friendly_name(it.get_node()->get_friendly_name());
it.get_node()->output(0).replace(new_transpose);
copy_runtime_info(transpose_1, new_transpose);
}
transpose_sinking::UpdateForwardSinkingAbility(new_transpose);
}
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(transpose_label, matcher_name);
register_matcher(m, matcher_pass_callback);
}

View File

@ -1,118 +0,0 @@
#include "transformations/common_optimizations/transpose_sinking_batch_to_space.hpp"
#include <openvino/opsets/opset9.hpp>
#include <openvino/pass/pattern/op/or.hpp>
#include <transformations/utils/utils.hpp>
#include <utility>
#include "itt.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/opsets/opset9.hpp"
#include "openvino/pass/pattern/op/label.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/util/common_util.hpp"
#include "openvino/util/log.hpp"
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
using namespace ov::pass::pattern;
using namespace ov;
using namespace ov::opset9;
using namespace transpose_sinking;
// Forward sinking through BatchToSpace/SpaceToBatch: removes the Transpose on
// the data input, permutes the three shape-like constant inputs with the
// reversed transpose order, and re-inserts Transposes on the outputs.
ov::pass::TransposeSinkingBatchToSpaceForward::TransposeSinkingBatchToSpaceForward() {
MATCHER_SCOPE(TransposeSinkingBatchToSpaceForward);
auto const_label = wrap_type<Constant>();
auto transpose_label = wrap_type<Transpose>({any_input(), const_label});
auto main_node_label =
wrap_type<BatchToSpace, SpaceToBatch>({transpose_label, any_input(), any_input(), any_input()});
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_node = m.get_pattern_map();
auto& main_node = pattern_to_node.at(main_node_label);
auto transpose = std::dynamic_pointer_cast<Transpose>(pattern_to_node.at(transpose_label));
if (!transpose) {
return false;
}
auto transpose_const = as_type_ptr<Constant>(pattern_to_node.at(const_label));
if (!transpose_const) {
return false;
}
// remove Transpose on 1st input:
auto transpose_parent = main_node->input_value(0).get_node()->input_value(0);
main_node->input(0).replace_source_output(transpose_parent);
// reorder the values of the shape-like inputs (block_shape and the two
// crops/pads inputs) using the reversed transpose order
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
const auto reversed_transpose_order = ReverseTransposeOrder(transpose_axis_order);
auto axis = std::make_shared<Constant>(element::i32, Shape{}, std::vector<int32_t>{0});
main_node->input(1).replace_source_output(
ChangeValuesOrder(main_node->input_value(1), reversed_transpose_order, axis));
main_node->input(2).replace_source_output(
ChangeValuesOrder(main_node->input_value(2), reversed_transpose_order, axis));
main_node->input(3).replace_source_output(
ChangeValuesOrder(main_node->input_value(3), reversed_transpose_order, axis));
main_node->validate_and_infer_types();
// insert Transpose on the main node outputs
TransposeInputsInfo transpose_input_info = {transpose, transpose_const, 0};
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
register_new_node(new_node);
transpose_sinking::UpdateForwardSinkingAbility(new_node);
}
return true;
};
auto m = std::make_shared<Matcher>(main_node_label, matcher_name);
register_matcher(m, matcher_pass_callback);
}
// Backward sinking through BatchToSpace/SpaceToBatch: moves the output
// Transpose to the data input and permutes the three shape-like constant
// inputs with the (non-reversed) transpose order.
ov::pass::TransposeSinkingBatchToSpaceBackward::TransposeSinkingBatchToSpaceBackward() {
MATCHER_SCOPE(TransposeSinkingBatchToSpaceBackward);
// All outputs of the main node must feed identical sinking Transposes.
auto main_node_label = wrap_type<BatchToSpace, SpaceToBatch>([](const Output<Node>& output) -> bool {
return has_static_rank()(output) && HasSameOutputTransposeNodes(output);
});
auto transpose_const_label = wrap_type<Constant>();
auto transpose_label =
wrap_type<Transpose>({main_node_label, transpose_const_label}, [](const Output<Node>& output) -> bool {
return has_static_rank()(output) && is_sinking_node(output);
});
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();
// insert the Transpose on the data (0th) input only
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node,
transpose_const,
/* input_indexes= */ {0})) {
register_new_node(new_node);
}
// remove output transposes
RemoveSingleOutputConsumers(main_node);
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
auto axis = std::make_shared<Constant>(element::i32, Shape{}, std::vector<int32_t>{0});
main_node->input(1).replace_source_output(
ChangeValuesOrder(main_node->input_value(1), transpose_axis_order, axis));
main_node->input(2).replace_source_output(
ChangeValuesOrder(main_node->input_value(2), transpose_axis_order, axis));
main_node->input(3).replace_source_output(
ChangeValuesOrder(main_node->input_value(3), transpose_axis_order, axis));
main_node->validate_and_infer_types();
return true;
};
auto m = std::make_shared<Matcher>(transpose_label, matcher_name);
register_matcher(m, matcher_pass_callback);
}

View File

@ -1,23 +1,20 @@
#include "transformations/common_optimizations/transpose_sinking_binary.hpp"
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <openvino/opsets/opset9.hpp>
#include <openvino/opsets/opset10.hpp>
#include <openvino/pass/pattern/op/or.hpp>
#include <transformations/utils/utils.hpp>
#include <utility>
#include "itt.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/opsets/opset9.hpp"
#include "openvino/pass/pattern/op/label.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/util/common_util.hpp"
#include "openvino/util/log.hpp"
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
#include "transformations/common_optimizations/transpose_sinking_binary.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
using namespace ov::pass::pattern;
using namespace ov;
using namespace ov::opset9;
using namespace ov::opset10;
using namespace transpose_sinking;
ov::pass::TransposeSinkingBinaryForward::TransposeSinkingBinaryForward() {

View File

@ -1,23 +1,19 @@
#include "transformations/common_optimizations/transpose_sinking_concat.hpp"
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <openvino/opsets/opset9.hpp>
#include <openvino/pass/pattern/op/or.hpp>
#include <transformations/utils/utils.hpp>
#include <utility>
#include "transformations/common_optimizations/transpose_sinking_concat.hpp"
#include "itt.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/opsets/opset9.hpp"
#include "openvino/pass/pattern/op/label.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/util/common_util.hpp"
#include "openvino/util/log.hpp"
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
using namespace ov::pass::pattern;
using namespace ov;
using namespace ov::opset9;
using namespace ov::opset10;
using namespace transpose_sinking;
ov::pass::TransposeSinkingConcatForward::TransposeSinkingConcatForward() {
@ -89,9 +85,9 @@ ov::pass::TransposeSinkingConcatBackward::TransposeSinkingConcatBackward() {
register_new_node(new_node);
}
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
const auto reversed_traspose_axis_order = ReverseTransposeOrder(transpose_axis_order);
const int64_t transposed_concat_axis = reversed_traspose_axis_order[concat_axis];
concat_node->set_axis(transposed_concat_axis);
const auto reversed_transpose_axis_order = ReverseTransposeOrder(transpose_axis_order);
const auto transposed_concat_axis = reversed_transpose_axis_order[concat_axis];
concat_node->set_axis(static_cast<int64_t>(transposed_concat_axis));
concat_node->set_concatenation_axis(-1);
concat_node->validate_and_infer_types();
// remove output transposes

View File

@ -1,4 +1,8 @@
#include "transformations/common_optimizations/transpose_sinking_pad.hpp"
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/transpose_sinking_data_movement.hpp"
#include <openvino/pass/pattern/op/or.hpp>
@ -10,16 +14,16 @@
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
using namespace ov::pass::pattern;
using namespace ov;
using namespace ov::opset10;
using namespace ov::pass::pattern;
using namespace transpose_sinking;
ov::pass::TransposeSinkingPadForward::TransposeSinkingPadForward() {
MATCHER_SCOPE(TransposeSinkingPadForward);
ov::pass::TransposeSinkingDataMovementForward::TransposeSinkingDataMovementForward() {
MATCHER_SCOPE(TransposeSinkingDataMovementForward);
auto const_label = wrap_type<Constant>();
auto transpose_label = wrap_type<Transpose>({any_input(), const_label});
auto main_node_label = wrap_type<Pad>({transpose_label, any_input(), any_input(), any_input()});
auto main_node_label = wrap_type<Pad, BatchToSpace, SpaceToBatch>({transpose_label, any_input(), any_input(), any_input()});
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_node = m.get_pattern_map();
@ -49,8 +53,14 @@ ov::pass::TransposeSinkingPadForward::TransposeSinkingPadForward() {
main_node->input(2).replace_source_output(
ChangeValuesOrder(main_node->input_value(2), reversed_transpose_order, axis));
const auto& bts = std::dynamic_pointer_cast<BatchToSpace>(main_node);
const auto& stb = std::dynamic_pointer_cast<SpaceToBatch>(main_node);
if (bts || stb) {
main_node->input(3).replace_source_output(
ChangeValuesOrder(main_node->input_value(3), reversed_transpose_order, axis));
}
main_node->validate_and_infer_types();
// insert Transpose for Pad output
TransposeInputsInfo transpose_input_info = {transpose, transpose_const, 0};
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
register_new_node(new_node);
@ -63,10 +73,10 @@ ov::pass::TransposeSinkingPadForward::TransposeSinkingPadForward() {
register_matcher(m, matcher_pass_callback);
}
ov::pass::TransposeSinkingPadBackward::TransposeSinkingPadBackward() {
MATCHER_SCOPE(TransposeSinkingPadBackward);
ov::pass::TransposeSinkingDataMovementBackward::TransposeSinkingDataMovementBackward() {
MATCHER_SCOPE(TransposeSinkingDataMovementBackward);
auto main_node_label = wrap_type<Pad>([](const Output<Node>& output) -> bool {
auto main_node_label = wrap_type<Pad, BatchToSpace, SpaceToBatch>([](const Output<Node>& output) -> bool {
return has_static_rank()(output) && HasSameOutputTransposeNodes(output);
});
@ -99,6 +109,13 @@ ov::pass::TransposeSinkingPadBackward::TransposeSinkingPadBackward() {
ChangeValuesOrder(main_node->input_value(1), transpose_axis_order, axis));
main_node->input(2).replace_source_output(
ChangeValuesOrder(main_node->input_value(2), transpose_axis_order, axis));
const auto& bts = std::dynamic_pointer_cast<BatchToSpace>(main_node);
const auto& stb = std::dynamic_pointer_cast<SpaceToBatch>(main_node);
if (bts || stb) {
main_node->input(3).replace_source_output(
ChangeValuesOrder(main_node->input_value(3), transpose_axis_order, axis));
}
main_node->validate_and_infer_types();
return true;
};

View File

@ -0,0 +1,95 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/transpose_sinking_fuse.hpp"
#include <memory>
#include <vector>
#include "itt.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/core/validation_util.hpp"
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
#include "transformations/utils/utils.hpp"
using namespace ov;
using namespace opset10;
/**
 * @brief Fuses a Transpose with its consumer Transposes.
 *
 * Matches a Transpose with a Constant order. If every consumer of that Transpose
 * is itself a Transpose with the same Constant order, the two permutations are
 * composed: when the composition is the identity both Transposes are dropped,
 * otherwise they are replaced with a single Transpose carrying the composed order.
 */
ov::pass::TransposeFuse::TransposeFuse() {
    MATCHER_SCOPE(TransposeFuse);
    auto transpose_label =
        pattern::wrap_type<Transpose>({pattern::any_input(), pattern::wrap_type<Constant>()});
    ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
        const auto& pattern_to_output = m.get_pattern_map();
        auto transpose_1 = pattern_to_output.at(transpose_label);
        auto order_const_1 =
            std::dynamic_pointer_cast<Constant>(transpose_1->input_value(1).get_node_shared_ptr());
        auto consumers = transpose_1->get_output_target_inputs(0);
        // Order shared by all consumer Transposes (filled from the first consumer).
        std::vector<int64_t> saved_order_values;
        auto saved_type = order_const_1->get_element_type();
        // Every consumer must be a Transpose with the same non-empty Constant order,
        // otherwise fusion is not possible.
        for (const auto& it : consumers) {
            auto out_transpose = dynamic_cast<Transpose*>(it.get_node());
            if (!out_transpose) {
                return false;
            }
            auto order = out_transpose->input_value(1).get_node_shared_ptr();
            auto order_const = std::dynamic_pointer_cast<Constant>(order);
            if (!order_const) {
                return false;
            }
            auto order_values = order_const->cast_vector<int64_t>();
            if (order_values.empty()) {
                return false;
            }
            if (saved_order_values.empty()) {
                saved_order_values = order_values;
            } else {
                if (saved_order_values != order_values) {
                    return false;
                }
            }
            // If the order constants disagree on element type, promote the fused order to i64.
            if (order_const->get_element_type() != saved_type) {
                saved_type = element::i64;
            }
        }
        auto order1 = order_const_1->cast_vector<int64_t>();
        if (order1.size() != saved_order_values.size()) {
            return false;
        }
        // Compose the two permutations in place and detect the identity permutation.
        bool is_ordered = true;
        for (size_t i = 0; i < order1.size(); i++) {
            saved_order_values[i] = order1[saved_order_values[i]];
            if (saved_order_values[i] != (int64_t)i)
                is_ordered = false;
        }
        if (is_ordered) {
            // Identity: bypass both Transposes entirely.
            for (const auto& it : consumers) {
                it.get_node()->output(0).replace(transpose_1->input_value(0));
            }
        } else {
            // Replace the pair with a single Transpose using the composed order.
            auto new_order = Constant::create(saved_type, {saved_order_values.size()}, saved_order_values);
            auto new_transpose = register_new_node<Transpose>(transpose_1->input_value(0), new_order);
            for (const auto& it : consumers) {
                new_transpose->set_friendly_name(it.get_node()->get_friendly_name());
                it.get_node()->output(0).replace(new_transpose);
                copy_runtime_info(transpose_1, new_transpose);
            }
            transpose_sinking::UpdateForwardSinkingAbility(new_transpose);
        }
        return true;
    };
    auto m = std::make_shared<pattern::Matcher>(transpose_label, matcher_name);
    register_matcher(m, matcher_pass_callback);
}

View File

@ -1,21 +1,19 @@
// Copyright (C) 2022 Intel Corporation
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/transpose_sinking_general.hpp"
#include <ngraph/pass/constant_folding.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
#include <ngraph/pass/manager.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>
#include <openvino/pass/constant_folding.hpp>
#include <openvino/pass/graph_rewrite.hpp>
#include <openvino/pass/manager.hpp>
#include <openvino/pass/pattern/op/wrap_type.hpp>
#include "itt.hpp"
#include "transformations/common_optimizations/transpose_sinking.hpp"
#include "transformations/common_optimizations/transpose_sinking_batch_to_space.hpp"
#include "transformations/common_optimizations/transpose_sinking_binary.hpp"
#include "transformations/common_optimizations/transpose_sinking_concat.hpp"
#include "transformations/common_optimizations/transpose_sinking_pad.hpp"
#include "transformations/common_optimizations/transpose_sinking_data_movement.hpp"
#include "transformations/common_optimizations/transpose_sinking_split.hpp"
#include "transformations/common_optimizations/transpose_sinking_unary.hpp"
#include "transformations/utils/utils.hpp"
@ -26,9 +24,8 @@ ov::pass::TransposeSinkingGeneralForward::TransposeSinkingGeneralForward() {
add_matcher<ov::pass::TransposeSinkingBinaryForward>();
add_matcher<ov::pass::TransposeSinkingConcatForward>();
add_matcher<ov::pass::TransposeSinkingSplitForward>();
add_matcher<ov::pass::TransposeSinkingPadForward>();
add_matcher<ov::pass::TransposeReduction>();
add_matcher<ov::pass::TransposeSinkingBatchToSpaceForward>();
add_matcher<ov::pass::TransposeSinkingDataMovementForward>();
add_matcher<ov::pass::TransposeSinkingReductionForward>();
add_matcher<ov::pass::TransposeFuse>();
}
@ -38,25 +35,24 @@ ov::pass::TransposeSinkingGeneralBackward::TransposeSinkingGeneralBackward() {
add_matcher<ov::pass::TransposeSinkingBinaryBackward>();
add_matcher<ov::pass::TransposeSinkingConcatBackward>();
add_matcher<ov::pass::TransposeSinkingSplitBackward>();
add_matcher<ov::pass::TransposeSinkingPadBackward>();
add_matcher<ov::pass::TransposeReductionBackward>();
add_matcher<ov::pass::TransposeSinkingBatchToSpaceBackward>();
add_matcher<ov::pass::TransposeSinkingDataMovementBackward>();
add_matcher<ov::pass::TransposeSinkingReductionBackward>();
add_matcher<ov::pass::TransposeFuse>();
}
bool ov::pass::TransposeSinkingGeneral::run_on_model(const std::shared_ptr<ov::Model>& f) {
RUN_ON_FUNCTION_SCOPE(TransposeSinkingGeneral);
{
ngraph::pass::Manager manager(get_pass_config());
ov::pass::Manager manager(get_pass_config());
manager.register_pass<ov::pass::TransposeSinkingGeneralForward>();
manager.register_pass<ngraph::pass::ConstantFolding>();
manager.register_pass<ov::pass::ConstantFolding>();
manager.run_passes(f);
}
{
ngraph::pass::Manager manager(get_pass_config());
ov::pass::Manager manager(get_pass_config());
manager.register_pass<ov::pass::TransposeSinkingGeneralBackward>();
manager.register_pass<ngraph::pass::ConstantFolding>();
manager.register_pass<ov::pass::ConstantFolding>();
manager.run_passes(f);
}

View File

@ -0,0 +1,145 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/transpose_sinking_interpolate.hpp"
#include <openvino/pass/pattern/op/or.hpp>
#include "itt.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/util/common_util.hpp"
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
using namespace ov;
using namespace ov::opset10;
using namespace ov::pass::pattern;
using namespace transpose_sinking;
/**
 * @brief Sinks a Transpose through Interpolate in the forward direction:
 * Transpose -> Interpolate is rewritten as Interpolate -> Transpose.
 *
 * The Interpolate 'axes' input (input 3) is remapped through the reversed
 * transpose order with a Gather, and the pads_begin/pads_end attributes
 * (laid out per input dimension) are permuted accordingly.
 */
ov::pass::TransposeSinkingInterpolateForward::TransposeSinkingInterpolateForward() {
    MATCHER_SCOPE(TransposeSinkingInterpolateForward);
    auto const_label = wrap_type<Constant>();
    auto transpose_label = wrap_type<Transpose>({any_input(), const_label});
    auto main_node_label = wrap_type<Interpolate>({transpose_label, any_input(), any_input(), any_input()});
    matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
        const auto& pattern_to_node = m.get_pattern_map();
        auto& main_node = pattern_to_node.at(main_node_label);
        auto transpose = std::dynamic_pointer_cast<Transpose>(pattern_to_node.at(transpose_label));
        if (!transpose) {
            return false;
        }
        auto transpose_const = as_type_ptr<Constant>(pattern_to_node.at(const_label));
        if (!transpose_const) {
            return false;
        }
        // remove Transpose on 1st input:
        auto transpose_parent = main_node->input_value(0).get_node()->input_value(0);
        main_node->input(0).replace_source_output(transpose_parent);
        const auto transpose_axis_order = transpose_const->get_axis_vector_val();
        const auto reversed_transpose_order = ReverseTransposeOrder(transpose_axis_order);
        auto axis = std::make_shared<Constant>(element::i32, Shape{}, std::vector<int32_t>{0});
        const auto& interpolate = std::dynamic_pointer_cast<Interpolate>(main_node);
        // Remap the 'axes' input through the reversed order: new_axes[i] = reversed_order[axes[i]].
        auto data = std::make_shared<Constant>(element::i32, Shape{reversed_transpose_order.size()}, reversed_transpose_order);
        const auto& indices = main_node->input_value(3);
        auto new_axis = std::make_shared<Gather>(data, indices, axis);
        main_node->input(3).replace_source_output(new_axis);
        if (interpolate) {
            op::v4::Interpolate::InterpolateAttrs attrs = interpolate->get_attrs();
            // Permute the per-dimension pads to follow the new (un-transposed) input layout.
            if (!attrs.pads_begin.empty() || !attrs.pads_end.empty()) {
                const auto &order_size = reversed_transpose_order.size();
                // resize first: pads may be shorter than the rank (missing entries become 0).
                attrs.pads_begin.resize(order_size);
                attrs.pads_end.resize(order_size);
                std::vector<size_t> new_pads_begin(order_size), new_pads_end(order_size);
                for (size_t i = 0; i < order_size; ++i) {
                    new_pads_begin[i] = attrs.pads_begin[reversed_transpose_order[i]];
                    new_pads_end[i] = attrs.pads_end[reversed_transpose_order[i]];
                }
                std::swap(attrs.pads_begin, new_pads_begin);
                std::swap(attrs.pads_end, new_pads_end);
                interpolate->set_attrs(attrs);
            }
        }
        main_node->validate_and_infer_types();
        // Re-insert the Transpose on every Interpolate output so it keeps sinking forward.
        TransposeInputsInfo transpose_input_info = {transpose, transpose_const, 0};
        for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
            register_new_node(new_node);
            transpose_sinking::UpdateForwardSinkingAbility(new_node);
        }
        return true;
    };
    auto m = std::make_shared<Matcher>(main_node_label, matcher_name);
    register_matcher(m, matcher_pass_callback);
}
/**
 * @brief Sinks a Transpose backward through Interpolate:
 * Interpolate -> Transpose is rewritten as Transpose -> Interpolate.
 *
 * Mirrors the forward pass: the 'axes' input (input 3) is remapped through the
 * transpose order with a Gather, and the per-dimension pads_begin/pads_end
 * attributes are permuted accordingly.
 */
ov::pass::TransposeSinkingInterpolateBackward::TransposeSinkingInterpolateBackward() {
    MATCHER_SCOPE(TransposeSinkingInterpolateBackward);
    // Interpolate whose consumers are all the same Transpose.
    auto main_node_label = wrap_type<Interpolate>([](const Output<Node>& output) -> bool {
        return has_static_rank()(output) && HasSameOutputTransposeNodes(output);
    });
    auto transpose_const_label = wrap_type<Constant>();
    auto transpose_label =
        wrap_type<Transpose>({main_node_label, transpose_const_label}, [](const Output<Node>& output) -> bool {
            return has_static_rank()(output) && is_sinking_node(output);
        });
    matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
        const auto& pattern_to_output = m.get_pattern_value_map();
        auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
        auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
        auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();
        // Move the Transpose onto the data input (input 0) of Interpolate.
        for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node,
                                                                      transpose_const,
                                                                      /* input_indexes= */ {0})) {
            register_new_node(new_node);
        }
        // remove output transposes
        RemoveSingleOutputConsumers(main_node);
        const auto transpose_axis_order = transpose_const->get_axis_vector_val();
        auto axis = std::make_shared<Constant>(element::i32, Shape{}, std::vector<int32_t>{0});
        // Remap the 'axes' input through the transpose order: new_axes[i] = order[axes[i]].
        auto data = std::make_shared<Constant>(element::i32, Shape{transpose_axis_order.size()}, transpose_axis_order);
        const auto& indices = main_node->input_value(3);
        auto new_axis = std::make_shared<Gather>(data, indices, axis);
        // Fix: wire the remapped axes into Interpolate. Previously 'new_axis' was
        // created but never connected (unlike the forward pass), so the original,
        // now-incorrect axes remained on input 3.
        main_node->input(3).replace_source_output(new_axis);
        const auto& interpolate = std::dynamic_pointer_cast<Interpolate>(main_node);
        if (interpolate) {
            op::v4::Interpolate::InterpolateAttrs attrs = interpolate->get_attrs();
            // Permute the per-dimension pads to follow the transposed input layout.
            if (!attrs.pads_begin.empty() || !attrs.pads_end.empty()) {
                const auto &order_size = transpose_axis_order.size();
                // resize first: pads may be shorter than the rank (missing entries become 0).
                attrs.pads_begin.resize(order_size);
                attrs.pads_end.resize(order_size);
                std::vector<size_t> new_pads_begin(order_size), new_pads_end(order_size);
                for (size_t i = 0; i < order_size; ++i) {
                    new_pads_begin[i] = attrs.pads_begin[transpose_axis_order[i]];
                    new_pads_end[i] = attrs.pads_end[transpose_axis_order[i]];
                }
                std::swap(attrs.pads_begin, new_pads_begin);
                std::swap(attrs.pads_end, new_pads_end);
                interpolate->set_attrs(attrs);
            }
        }
        main_node->validate_and_infer_types();
        return true;
    };
    auto m = std::make_shared<Matcher>(transpose_label, matcher_name);
    register_matcher(m, matcher_pass_callback);
}

View File

@ -0,0 +1,268 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/transpose_sinking_reduction.hpp"
#include <memory>
#include <vector>
#include "itt.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/core/validation_util.hpp"
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
#include "transformations/utils/utils.hpp"
using namespace ov;
using namespace opset10;
namespace {
// Builds the transpose order that applies after reduced dimensions are removed.
// 'axes_values' are positions (in the transposed layout) that a Reduce drops;
// the remaining order entries are kept in sequence, each shifted down by the
// number of removed order values smaller than it.
std::vector<size_t> get_updated_order_forward(const std::vector<size_t>& axes_values,
                                              const std::vector<size_t>& order_values) {
    // Order values that correspond to the reduced positions, sorted so we can
    // count how many of them precede any given value.
    std::vector<size_t> reduced_values;
    reduced_values.reserve(axes_values.size());
    for (const auto axis : axes_values) {
        reduced_values.push_back(order_values[axis]);
    }
    std::sort(reduced_values.begin(), reduced_values.end());

    std::vector<size_t> updated_order;
    updated_order.reserve(order_values.size() - axes_values.size());
    for (size_t pos = 0; pos < order_values.size(); ++pos) {
        const bool is_reduced_pos =
            std::find(axes_values.begin(), axes_values.end(), pos) != axes_values.end();
        if (is_reduced_pos) {
            continue;
        }
        // Number of removed values strictly smaller than this order entry.
        const auto smaller =
            std::lower_bound(reduced_values.begin(), reduced_values.end(), order_values[pos]) -
            reduced_values.begin();
        updated_order.push_back(order_values[pos] - static_cast<size_t>(smaller));
    }
    return updated_order;
}
// Builds the transpose order that applies after new dimensions are inserted
// (Unsqueeze case). Inserted positions ('axes_values') map to themselves; every
// original order value is translated to its index in the expanded layout.
std::vector<size_t> get_updated_order_backward(const std::vector<size_t>& axes_values,
                                               const std::vector<size_t>& order_values) {
    const size_t expanded_rank = order_values.size() + axes_values.size();
    const auto is_inserted = [&axes_values](size_t idx) {
        return std::find(axes_values.begin(), axes_values.end(), idx) != axes_values.end();
    };

    // compressed_idx[i] = the index axis i had before the new axes were inserted
    // (inserted positions repeat the previous value and are never matched first).
    std::vector<int64_t> compressed_idx(expanded_rank);
    int64_t inserted_so_far = 0;
    for (size_t i = 0; i < expanded_rank; ++i) {
        if (is_inserted(i)) {
            ++inserted_so_far;
        }
        compressed_idx[i] = static_cast<int64_t>(i) - inserted_so_far;
    }

    std::vector<size_t> updated_order(expanded_rank);
    size_t next = 0;  // index into order_values
    for (size_t i = 0; i < expanded_rank; ++i) {
        if (is_inserted(i)) {
            updated_order[i] = i;  // inserted axes stay where they are
            continue;
        }
        const auto match = std::find(compressed_idx.begin(),
                                     compressed_idx.end(),
                                     static_cast<int64_t>(order_values[next]));
        updated_order[i] = static_cast<size_t>(match - compressed_idx.begin());
        ++next;
    }
    return updated_order;
}
// Returns the keep_dims flag of a reduction node.
// Squeeze/Unsqueeze (and anything else) fall through to false, since they
// always change the number of output dimensions.
bool get_keep_dims(const std::shared_ptr<Node>& reduction) {
    if (auto logical = std::dynamic_pointer_cast<op::util::LogicalReductionKeepDims>(reduction)) {
        return logical->get_keep_dims();
    }
    if (auto arithmetic = std::dynamic_pointer_cast<op::util::ArithmeticReductionKeepDims>(reduction)) {
        return arithmetic->get_keep_dims();
    }
    return false;
}
}
/**
 * @brief Sinks a single-consumer Transpose through Reduce/Squeeze/Unsqueeze:
 * Transpose -> Reduction is rewritten as Reduction -> Transpose.
 *
 * Reduction axes are remapped through the transpose order; when keep_dims is
 * false the transpose order itself is recomputed for the changed rank.
 */
ov::pass::TransposeSinkingReductionForward::TransposeSinkingReductionForward() {
    MATCHER_SCOPE(TransposeSinkingReductionForward);
    auto transpose_label =
        pattern::wrap_type<Transpose>({pattern::any_input(), pattern::wrap_type<Constant>()},
                                      pattern::consumers_count(1));
    auto reduce_or_squeeze_label =
        pattern::wrap_type<op::util::ArithmeticReductionKeepDims,
                           op::util::LogicalReductionKeepDims,
                           Squeeze,
                           Unsqueeze>({transpose_label, pattern::wrap_type<Constant>()});
    ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
        const auto& pattern_to_output = m.get_pattern_value_map();
        auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
        auto reduction = pattern_to_output.at(reduce_or_squeeze_label).get_node_shared_ptr();
        auto keep_dims = get_keep_dims(reduction);
        auto transpose_order = std::dynamic_pointer_cast<Constant>(transpose->get_input_node_shared_ptr(1));
        auto reduction_axes = std::dynamic_pointer_cast<Constant>(reduction->get_input_node_shared_ptr(1));
        if (!transpose_order || !reduction_axes)
            return false;
        auto unsqueeze = std::dynamic_pointer_cast<Unsqueeze>(reduction);
        // For Unsqueeze the axes are expressed in the OUTPUT rank; for reductions in the INPUT rank.
        auto rank =
            unsqueeze ? reduction->get_output_partial_shape(0).rank() : reduction->get_input_partial_shape(0).rank();
        auto non_negative_axes =
            normalize_axes(reduction->get_friendly_name(), reduction_axes->cast_vector<int64_t>(), rank);
        auto transpose_order_values = transpose_order->cast_vector<size_t>();
        // Remap each reduction axis through the transpose order.
        std::vector<size_t> new_values;
        new_values.reserve(non_negative_axes.size());
        for (const auto& axis : non_negative_axes) {
            new_values.push_back(transpose_order_values[axis]);
        }
        if (!keep_dims) {
            if (non_negative_axes.empty()) {
                // Empty axes on Squeeze: all dims of size 1 are dropped, which
                // requires a static input shape to enumerate them.
                auto input_pshape = transpose->input_value(0).get_partial_shape();
                if (input_pshape.is_static()) {
                    for (size_t i = 0; i < input_pshape.size(); ++i) {
                        if (input_pshape[i] == 1) {
                            non_negative_axes.push_back(i);
                        }
                    }
                } else {
                    return false;
                }
            }
            // The output rank differs from the input rank, so recompute the order.
            if (unsqueeze) {
                transpose_order_values = get_updated_order_backward(non_negative_axes, transpose_order_values);
            } else {
                transpose_order_values = get_updated_order_forward(non_negative_axes, transpose_order_values);
            }
        }
        auto new_transpose_order = std::make_shared<Constant>(transpose_order->get_element_type(),
                                                              Shape{transpose_order_values.size()},
                                                              transpose_order_values);
        std::shared_ptr<Node> new_reduction;
        if (!unsqueeze) {
            auto new_const = std::make_shared<Constant>(reduction_axes->get_element_type(),
                                                        reduction_axes->get_shape(),
                                                        new_values);
            new_reduction = reduction->clone_with_new_inputs({transpose->input_value(0), new_const});
        } else {
            // Unsqueeze keeps its original axes input untouched.
            new_reduction = reduction->clone_with_new_inputs({transpose->input_value(0), reduction->input_value(1)});
        }
        auto new_transpose = transpose->clone_with_new_inputs({new_reduction, new_transpose_order});
        replace_node(reduction, new_transpose);
        // The two nodes swap places, so their friendly names swap as well.
        new_reduction->set_friendly_name(transpose->get_friendly_name());
        new_transpose->set_friendly_name(reduction->get_friendly_name());
        transpose_sinking::UpdateForwardSinkingAbility(new_transpose);
        register_new_node(new_transpose);
        copy_runtime_info({transpose, reduction}, {new_transpose, new_reduction});
        return true;
    };
    auto m = std::make_shared<pattern::Matcher>(reduce_or_squeeze_label, matcher_name);
    register_matcher(m, matcher_pass_callback);
}
/**
 * @brief Sinks a Transpose backward through Reduce/Squeeze/Unsqueeze:
 * Reduction -> Transpose is rewritten as Transpose -> Reduction.
 *
 * Reduction axes and the transpose order are both recomputed for the swapped
 * order of operations.
 */
ov::pass::TransposeSinkingReductionBackward::TransposeSinkingReductionBackward() {
    MATCHER_SCOPE(TransposeSinkingReductionBackward);
    auto reduce_or_squeeze_label =
        pattern::wrap_type<op::util::ArithmeticReductionKeepDims,
                           op::util::LogicalReductionKeepDims,
                           Squeeze,
                           Unsqueeze>({pattern::any_input(), pattern::wrap_type<Constant>()},
                                      transpose_sinking::HasSameOutputTransposeNodes);
    auto transpose_label =
        pattern::wrap_type<Transpose>({reduce_or_squeeze_label, pattern::wrap_type<Constant>()});
    ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
        const auto& pattern_to_output = m.get_pattern_value_map();
        auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
        auto reduction = pattern_to_output.at(reduce_or_squeeze_label).get_node_shared_ptr();
        auto keep_dims = get_keep_dims(reduction);
        auto transpose_order = std::dynamic_pointer_cast<Constant>(transpose->get_input_node_shared_ptr(1));
        auto reduction_axes = std::dynamic_pointer_cast<Constant>(reduction->get_input_node_shared_ptr(1));
        if (!transpose_order || !reduction_axes)
            return false;
        auto unsqueeze = std::dynamic_pointer_cast<Unsqueeze>(reduction);
        // For Unsqueeze the axes are expressed in the OUTPUT rank; for reductions in the INPUT rank.
        auto rank =
            unsqueeze ? reduction->get_output_partial_shape(0).rank() : reduction->get_input_partial_shape(0).rank();
        auto non_negative_axes =
            normalize_axes(reduction->get_friendly_name(), reduction_axes->cast_vector<int64_t>(), rank);
        auto transpose_order_values = transpose_order->cast_vector<size_t>();
        auto old_transpose_order_values = transpose_order_values;
        // New reduction axes, expressed in the layout seen after the swap.
        std::vector<size_t> new_values;
        if (unsqueeze) {
            if (non_negative_axes.size() == transpose_order_values.size()) {
                // input is a scalar, we unsqueeze all dims
                // it's enough to eliminate such Transpose
                transpose->output(0).replace(reduction);
                return true;
            }
            // Each unsqueezed axis moves to the position the Transpose sends it to.
            for (const auto& axis : non_negative_axes) {
                auto it = std::find(old_transpose_order_values.begin(), old_transpose_order_values.end(), axis);
                if (it != old_transpose_order_values.end()) {
                    new_values.push_back(it - old_transpose_order_values.begin());
                }
            }
        }
        bool special_case = false;
        if (!keep_dims) {
            if (non_negative_axes.empty()) {
                // Empty axes on Squeeze: all dims of size 1 are dropped, which
                // requires a static input shape to enumerate them.
                auto input_pshape = reduction->input_value(0).get_partial_shape();
                if (input_pshape.is_static()) {
                    for (size_t i = 0; i < input_pshape.size(); ++i) {
                        if (input_pshape[i] == 1) {
                            non_negative_axes.push_back(i);
                        }
                    }
                    special_case = true;
                } else {
                    return false;
                }
            }
            // The rank changes across the reduction, so recompute the order for
            // the side of the reduction the Transpose moves to.
            if (unsqueeze) {
                transpose_order_values = get_updated_order_forward(new_values, transpose_order_values);
            } else {
                transpose_order_values = get_updated_order_backward(non_negative_axes, transpose_order_values);
            }
        }
        if (!unsqueeze) {
            // Reduction axes are remapped through the reversed (new) transpose order.
            auto reversed_order_values = transpose_sinking::ReverseTransposeOrder(transpose_order_values);
            for (const auto& axis : non_negative_axes) {
                new_values.push_back(reversed_order_values[axis]);
            }
        }
        auto new_transpose_order = std::make_shared<Constant>(transpose_order->get_element_type(),
                                                              Shape{transpose_order_values.size()},
                                                              transpose_order_values);
        if (special_case) {
            // Squeeze with empty axes keeps its original (empty) axes input.
            auto new_transpose = transpose->clone_with_new_inputs({reduction->input_value(0), new_transpose_order});
            auto new_reduction = reduction->clone_with_new_inputs({new_transpose, reduction->input_value(1)});
            new_reduction->set_friendly_name(transpose->get_friendly_name());
            replace_node(transpose, new_reduction);
            transpose_sinking::UpdateForwardSinkingAbility(new_transpose);
            copy_runtime_info({transpose, reduction}, {new_transpose, new_reduction});
            register_new_node(new_transpose);
        } else {
            auto new_const = std::make_shared<Constant>(reduction_axes->get_element_type(),
                                                        reduction_axes->get_shape(),
                                                        new_values);
            auto new_transpose = transpose->clone_with_new_inputs({reduction->input_value(0), new_transpose_order});
            auto new_reduction = reduction->clone_with_new_inputs({new_transpose, new_const});
            replace_node(transpose, new_reduction);
            copy_runtime_info({transpose, reduction}, {new_transpose, new_reduction});
            transpose_sinking::UpdateForwardSinkingAbility(new_transpose);
            new_reduction->set_friendly_name(transpose->get_friendly_name());
            register_new_node(new_transpose);
        }
        return true;
    };
    auto m = std::make_shared<pattern::Matcher>(transpose_label, matcher_name);
    register_matcher(m, matcher_pass_callback);
}

View File

@ -1,23 +1,24 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/transpose_sinking_split.hpp"
#include <openvino/opsets/opset9.hpp>
#include <openvino/pass/pattern/op/or.hpp>
#include <transformations/utils/utils.hpp>
#include <utility>
#include "itt.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/opsets/opset9.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/pattern/op/label.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/util/common_util.hpp"
#include "openvino/util/log.hpp"
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
using namespace ov::pass::pattern;
using namespace ov;
using namespace ov::opset9;
using namespace ov::opset10;
using namespace transpose_sinking;
namespace {
@ -30,7 +31,7 @@ struct OutputTranspose {
Constant* transpose_const;
};
OutputTranspose GetOutputTransposes(NodePtr node) {
OutputTranspose GetOutputTransposes(const NodePtr& node) {
for (size_t output_idx = 0; output_idx < node->get_output_size(); ++output_idx) {
for (auto& input : node->get_output_target_inputs(output_idx)) {
auto transpose_node = dynamic_cast<Transpose*>(input.get_node());
@ -116,7 +117,7 @@ bool GetSplitAxis(const std::shared_ptr<Constant>& split_axis, const ov::Rank& r
* We sink Transpose through Split operation in a backward way only if all the output
* nodes are the same Transpose. We can:
* - clone Split with all outputs except Transpose
* causes perfomance problems
* causes performance problems
* - add reversed Transpose operations on all outputs except sinking Transpose
* nothing to do with new added output Transposes
*/
@ -131,6 +132,13 @@ ov::pass::TransposeSinkingSplitBackward::TransposeSinkingSplitBackward() {
auto transpose_label_node = pattern_to_output.at(transpose_label).get_node();
NodePtr split = FindInputNode<Split>(transpose_label_node);
if (!split) {
split = FindInputNode<VariadicSplit>(transpose_label_node);
}
if (!split) {
return false;
}
auto split_axis_constant = as_type_ptr<Constant>(split->input_value(1).get_node_shared_ptr());
if (!split_axis_constant) {
return false;
@ -187,21 +195,20 @@ ov::pass::TransposeSinkingSplitBackward::TransposeSinkingSplitBackward() {
ov::pass::TransposeSinkingSplitForward::TransposeSinkingSplitForward() {
MATCHER_SCOPE(TransposeSinkingSplitForward);
auto main_node_label = wrap_type<Split>(IfNodeHasTransposeInputs);
auto main_node_label = wrap_type<Split, VariadicSplit>(IfNodeHasTransposeInputs);
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
auto& main_node_output = pattern_to_output.at(main_node_label);
auto main_node = main_node_output.get_node_shared_ptr();
auto split = as_type_ptr<Split>(main_node);
auto split_axis_constant = as_type_ptr<Constant>(split->input_value(1).get_node_shared_ptr());
auto split_axis_constant = as_type_ptr<Constant>(main_node->input_value(1).get_node_shared_ptr());
if (!split_axis_constant) {
return false;
}
int64_t split_axis;
if (!GetSplitAxis(split_axis_constant, split->input_value(0).get_partial_shape().rank(), split_axis)) {
if (!GetSplitAxis(split_axis_constant, main_node->input_value(0).get_partial_shape().rank(), split_axis)) {
return false;
}
TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node);
@ -211,10 +218,10 @@ ov::pass::TransposeSinkingSplitForward::TransposeSinkingSplitForward() {
const size_t transposed_split_axis = transpose_axis_order[split_axis];
auto new_split_axis_const =
std::make_shared<Constant>(split_axis_constant->get_element_type(), Shape{}, transposed_split_axis);
split->input(1).replace_source_output(new_split_axis_const);
main_node->input(1).replace_source_output(new_split_axis_const);
copy_runtime_info({split_axis_constant, transpose_input_info.transpose, transpose_input_info.transpose_const},
new_split_axis_const);
split->validate_and_infer_types();
main_node->validate_and_infer_types();
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
register_new_node(new_node);

View File

@ -1,16 +1,20 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/transpose_sinking_unary.hpp"
#include <transformations/utils/utils.hpp>
#include <utility>
#include "itt.hpp"
#include "openvino/opsets/opset9.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
using namespace ov;
using namespace ov::opset9;
using namespace ov::opset10;
using namespace ov::pass::pattern;
using namespace ov::op::util;
using namespace transpose_sinking;
@ -45,27 +49,20 @@ NodePair SwapNodes(const NodePtr& first_node, const NodePtr& second_node) {
return std::make_pair(new_first_node, new_second_node);
}
NodePair Swap(NodePtr first_node, NodePtr second_node) {
NodePair new_nodes;
new_nodes = SwapNodes(first_node, second_node);
return new_nodes;
}
} // namespace
ov::pass::TransposeSinkingUnaryForward::TransposeSinkingUnaryForward() {
MATCHER_SCOPE(TransposeSinkingUnaryForward);
auto transpose_label = wrap_type<Transpose>({any_input(), any_input()});
auto unary_label =
wrap_type<UnaryElementwiseArithmetic, Clamp, Elu, SoftPlus, LogicalNot, Convert>({transpose_label});
auto unary_label = wrap_type<UnaryElementwiseArithmetic, Clamp, Elu, SoftPlus, LogicalNot, Convert, IsInf, IsNaN, IsFinite>({transpose_label});
ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
auto unary = pattern_to_output.at(unary_label).get_node_shared_ptr();
const NodePair new_nodes = Swap(transpose, unary);
const NodePair new_nodes = SwapNodes(transpose, unary);
register_new_node(new_nodes.first);
register_new_node(new_nodes.second);
@ -84,49 +81,16 @@ bool IfSinkingEnabled(const Output<Node>& output) {
}
} // namespace
ov::pass::TransposeSinkingUnaryBackwardSingleConsumer::TransposeSinkingUnaryBackwardSingleConsumer() {
MATCHER_SCOPE(TransposeSinkingUnaryBackwardSingleConsumer);
auto unary_label =
wrap_type<UnaryElementwiseArithmetic, Clamp, Elu, SoftPlus, LogicalNot, Convert>({any_input()},
consumers_count(1));
auto transpose_label = wrap_type<Transpose>({unary_label, any_input()}, IfSinkingEnabled);
ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
auto unary = pattern_to_output.at(unary_label).get_node_shared_ptr();
const NodePair new_nodes = Swap(unary, transpose);
register_new_node(new_nodes.first);
register_new_node(new_nodes.second);
return true;
};
auto m = std::make_shared<Matcher>(transpose_label, "ov::pass::TransposeSinkingUnaryBackwardSingleConsumer");
register_matcher(m, matcher_pass_callback);
}
namespace {
std::function<bool(Output<Node>)> consumers_more_than(size_t n) {
return [=](Output<Node> output) -> bool {
return output.get_target_inputs().size() > n;
};
}
} // namespace
ov::pass::TransposeSinkingUnaryBackwardMultiConsumers::TransposeSinkingUnaryBackwardMultiConsumers() {
ov::pass::TransposeSinkingUnaryBackward::TransposeSinkingUnaryBackward() {
MATCHER_SCOPE(TransposeSinkingUnaryBackwardMultiConsumers);
auto unary_restrictions = [](const Output<Node>& output) -> bool {
return consumers_more_than(1)(output) && HasSameOutputTransposeNodes(output);
return HasSameOutputTransposeNodes(output);
};
auto unary_label =
wrap_type<UnaryElementwiseArithmetic, Clamp, Elu, SoftPlus, LogicalNot, Convert>({any_input()},
unary_restrictions);
wrap_type<UnaryElementwiseArithmetic, Clamp, Elu, SoftPlus, LogicalNot, Convert, IsInf, IsNaN, IsFinite>(
{any_input()}, unary_restrictions);
auto transpose_const_label = wrap_type<Constant>();
@ -147,12 +111,6 @@ ov::pass::TransposeSinkingUnaryBackwardMultiConsumers::TransposeSinkingUnaryBack
return true;
};
auto m = std::make_shared<Matcher>(transpose_label, "ov::pass::TransposeSinkingUnaryBackwardMultiConsumers");
auto m = std::make_shared<Matcher>(transpose_label, "ov::pass::TransposeSinkingUnaryBackward");
register_matcher(m, matcher_pass_callback);
}
// Composite backward transpose-sinking pass for unary ops: registers both the
// single-consumer and the multi-consumer matcher variants so callers only need
// to add this one pass.
ov::pass::TransposeSinkingUnaryBackward::TransposeSinkingUnaryBackward() {
// Pass-name scope macro (presumably from itt.hpp, used for conditional
// compilation / instrumentation of matchers — confirm against project setup).
MATCHER_SCOPE(TransposeSinkingUnaryBackward);
// The two variants cover disjoint cases: SingleConsumer matches outputs with
// exactly one consumer, MultiConsumers matches outputs with more than one.
add_matcher<ov::pass::TransposeSinkingUnaryBackwardSingleConsumer>();
add_matcher<ov::pass::TransposeSinkingUnaryBackwardMultiConsumers>();
}

View File

@ -6,7 +6,7 @@
#include "itt.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/opsets/opset9.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/pattern/op/label.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/util/common_util.hpp"
@ -16,23 +16,17 @@
namespace transpose_sinking {
using namespace ov;
using namespace ov::opset9;
using namespace ov::opset10;
using NodePtr = std::shared_ptr<Node>;
/**
 * @brief Reorders the values of @p input along @p axis according to
 *        @p transpose_axis_order, using a single Gather node.
 *
 * @param input                value whose elements are permuted
 * @param transpose_axis_order permutation to apply (output[i] = input[order[i]])
 * @param axis                 constant selecting the axis to gather along
 * @return the output of the newly created Gather node
 *
 * NOTE(review): the previous Split+Concat implementation was left in the body
 * after the first return statement, making it unreachable dead code; it has
 * been removed in favor of the equivalent single-Gather form.
 */
Output<Node> ChangeValuesOrder(const Output<Node>& input,
                               const AxisVector& transpose_axis_order,
                               const std::shared_ptr<Constant>& axis) {
    // Gather indices encode the permutation directly.
    auto indices = std::make_shared<Constant>(element::i32, Shape{transpose_axis_order.size()}, transpose_axis_order);
    auto gather = std::make_shared<Gather>(input, indices, axis);
    // Propagate runtime info (e.g. friendly names/attributes) from the source
    // node onto the newly created subgraph.
    copy_runtime_info(input.get_node_shared_ptr(), gather);
    return gather;
}
TransposeInputsInfo GetFirstTransposeInput(const NodePtr& node) {

View File

@ -7,7 +7,7 @@
#include <openvino/opsets/opset10.hpp>
#include <openvino/pass/constant_folding.hpp>
#include <openvino/pass/manager.hpp>
#include <transformations/common_optimizations/transpose_sinking_pad.hpp>
#include <transformations/common_optimizations/transpose_sinking_data_movement.hpp>
#include <transformations/common_optimizations/transpose_sinking_utils.hpp>
#include <transformations/init_node_info.hpp>
@ -285,7 +285,7 @@ std::vector<size_t> pad_operations_numbers = {1, 10};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingPadForwardSingleConsumerTestSuite,
TransposeSinkingPadTestFixture,
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingPadForward)),
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingDataMovementForward)),
::testing::ValuesIn(pad_operations_numbers),
::testing::Values(forward::single_consumer::CreateFunction),
::testing::Values(forward::single_consumer::CreateReferenceFunction),
@ -294,7 +294,7 @@ INSTANTIATE_TEST_SUITE_P(TransposeSinkingPadForwardSingleConsumerTestSuite,
INSTANTIATE_TEST_SUITE_P(TransposeSinkingPadBackwardSingleConsumerTestSuite,
TransposeSinkingPadTestFixture,
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingPadBackward)),
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingDataMovementBackward)),
::testing::ValuesIn(pad_operations_numbers),
::testing::Values(backward::single_consumer::CreateFunction),
::testing::Values(backward::single_consumer::CreateReferenceFunction),
@ -304,7 +304,7 @@ INSTANTIATE_TEST_SUITE_P(TransposeSinkingPadBackwardSingleConsumerTestSuite,
INSTANTIATE_TEST_SUITE_P(
TransposeSinkingPadBackwardSingleConsumerMultiTransposesTestSuite,
TransposeSinkingPadTestFixture,
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingPadBackward)),
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingDataMovementBackward)),
::testing::ValuesIn(pad_operations_numbers),
::testing::Values(backward::output_transpose_mult_transposes::CreateFunction),
::testing::Values(backward::output_transpose_mult_transposes::CreateReferenceFunction),

View File

@ -101,7 +101,7 @@ TEST_P(TransposeSinkingFQ, TransposeFQReduce) {
manager.register_pass<ngraph::pass::InitUniqueNames>(unh);
manager.register_pass<ov::pass::InitNodeInfo>();
manager.register_pass<ov::pass::TransposeFQReduction>();
manager.register_pass<ov::pass::TransposeReduction>();
manager.register_pass<ov::pass::TransposeSinkingReductionForward>();
manager.register_pass<ngraph::pass::CheckUniqueNames>(unh);
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
@ -219,7 +219,7 @@ TEST_P(TransposeSinking, TransposeReduction) {
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitUniqueNames>(unh);
manager.register_pass<ov::pass::InitNodeInfo>();
manager.register_pass<ov::pass::TransposeReduction>();
manager.register_pass<ov::pass::TransposeSinkingReductionForward>();
manager.register_pass<ngraph::pass::CheckUniqueNames>(unh);
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
@ -340,7 +340,7 @@ TEST_F(TransformationTestsF, TransposeReduceNegative) {
auto sub = std::make_shared<opset6::Subtract>(transpose, reduce_mean);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{sub}, ngraph::ParameterVector{input});
manager.register_pass<ov::pass::TransposeReduction>();
manager.register_pass<ov::pass::TransposeSinkingReductionForward>();
}
}