TSForwardBase class and a transformation to reset no_sinking
attribute (#17913)
* Added TSForwardBase class and a new transformation to reset no_sinking attribute * Refactoring * fix an issue with legacy_output_names * resolve review comments * Resolve review comments
This commit is contained in:
parent
2f59e5d697
commit
6f14a43ea6
@ -11,6 +11,7 @@
|
||||
namespace ov {
|
||||
|
||||
TRANSFORMATIONS_API void mark_as_no_sinking_node(const std::shared_ptr<Node>& node);
|
||||
TRANSFORMATIONS_API void reset_no_sinking_attribute(const std::shared_ptr<Node>& node);
|
||||
|
||||
TRANSFORMATIONS_API bool is_sinking_node(const std::shared_ptr<Node>& node);
|
||||
TRANSFORMATIONS_API bool is_sinking_node(const Node* node);
|
||||
|
@ -0,0 +1,63 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "transformations/transpose_sinking/ts_utils.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
namespace transpose_sinking {
|
||||
|
||||
class TRANSFORMATIONS_API TSForwardBase;
|
||||
|
||||
} // namespace transpose_sinking
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TSForwardBase is a base class for all forward transformations.
|
||||
*/
|
||||
class ov::pass::transpose_sinking::TSForwardBase : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TSForwardBase", "0");
|
||||
TSForwardBase() = default;
|
||||
|
||||
template <class... Types>
|
||||
void create_pattern(bool const_transpose_input, std::vector<size_t> transpose_indices = {}) {
|
||||
m_const_transpose_input = const_transpose_input;
|
||||
m_tranpose_indices = std::move(transpose_indices);
|
||||
m_pattern = ov::pass::pattern::wrap_type<Types...>([&](const Output<Node>& output) -> bool {
|
||||
return if_node_has_transpose_inputs(output, m_const_transpose_input, m_tranpose_indices);
|
||||
});
|
||||
}
|
||||
|
||||
using sinking_function =
|
||||
std::function<bool(const std::shared_ptr<Node>& main_node, const utils::TransposeInputsInfo& transpose_info)>;
|
||||
|
||||
void transpose_sinking(const std::string& pass_name, const sinking_function& sinking_transformation = nullptr);
|
||||
|
||||
protected:
|
||||
static bool default_inputs_update(const std::shared_ptr<Node>& main_node,
|
||||
const utils::TransposeInputsInfo& transpose_info);
|
||||
|
||||
void default_outputs_update(const std::shared_ptr<Node>& main_node,
|
||||
const utils::TransposeInputsInfo& transpose_info);
|
||||
|
||||
private:
|
||||
static bool if_node_has_transpose_inputs(const Output<Node>& output,
|
||||
bool const_transpose_input,
|
||||
const std::vector<size_t>& transpose_indices);
|
||||
|
||||
std::shared_ptr<Node> m_pattern;
|
||||
bool m_const_transpose_input = true;
|
||||
std::vector<size_t> m_tranpose_indices;
|
||||
};
|
@ -6,6 +6,7 @@
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
#include "transformations/transpose_sinking/ts_base.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -24,7 +25,7 @@ class TRANSFORMATIONS_API TSBinaryBackward;
|
||||
* @brief TSBinaryForward transformation sinks Transpose through BinaryElementwiseArithmetic,
|
||||
* BinaryElementwiseComparison, BinaryElementwiseLogical and PRelu operations in the forward direction.
|
||||
*/
|
||||
class ov::pass::transpose_sinking::TSBinaryForward : public ov::pass::MatcherPass {
|
||||
class ov::pass::transpose_sinking::TSBinaryForward : public ov::pass::transpose_sinking::TSForwardBase {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TSBinaryForward", "0");
|
||||
TSBinaryForward();
|
||||
|
@ -6,6 +6,7 @@
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
#include "transformations/transpose_sinking/ts_base.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -24,7 +25,7 @@ class TRANSFORMATIONS_API TSConcatBackward;
|
||||
* @brief TSConcatForward transformation sinks Transpose through Concat operation
|
||||
* in the forward direction.
|
||||
*/
|
||||
class ov::pass::transpose_sinking::TSConcatForward : public ov::pass::MatcherPass {
|
||||
class ov::pass::transpose_sinking::TSConcatForward : public ov::pass::transpose_sinking::TSForwardBase {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TSConcatForward", "0");
|
||||
TSConcatForward();
|
||||
|
@ -6,6 +6,7 @@
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
#include "transformations/transpose_sinking/ts_base.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -25,7 +26,7 @@ class TRANSFORMATIONS_API TSDataMovementBackward;
|
||||
* ReverseSequence and Pad operations in the forward direction.
|
||||
* These operations are categorized as "DataMovement" and are handled in a similar way in this transformation.
|
||||
*/
|
||||
class ov::pass::transpose_sinking::TSDataMovementForward : public ov::pass::MatcherPass {
|
||||
class ov::pass::transpose_sinking::TSDataMovementForward : public ov::pass::transpose_sinking::TSForwardBase {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TSDataMovementForward", "0");
|
||||
TSDataMovementForward();
|
||||
|
@ -6,6 +6,7 @@
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
#include "transformations/transpose_sinking/ts_base.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -24,7 +25,7 @@ class TRANSFORMATIONS_API TSGatherBackward;
|
||||
* @brief TSGatherForward transformation sinks Transpose through Gather operations
|
||||
* in the forward direction.
|
||||
*/
|
||||
class ov::pass::transpose_sinking::TSGatherForward : public ov::pass::MatcherPass {
|
||||
class ov::pass::transpose_sinking::TSGatherForward : public ov::pass::transpose_sinking::TSForwardBase {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TSGatherForward", "0");
|
||||
TSGatherForward();
|
||||
|
@ -6,6 +6,7 @@
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
#include "transformations/transpose_sinking/ts_base.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -24,7 +25,7 @@ class TRANSFORMATIONS_API TSInterpolateBackward;
|
||||
* @brief TSInterpolateForward transformation sinks Transpose through Interpolate operation
|
||||
* in the forward direction.
|
||||
*/
|
||||
class ov::pass::transpose_sinking::TSInterpolateForward : public ov::pass::MatcherPass {
|
||||
class ov::pass::transpose_sinking::TSInterpolateForward : public ov::pass::transpose_sinking::TSForwardBase {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TSInterpolateForward", "0");
|
||||
TSInterpolateForward();
|
||||
|
@ -6,6 +6,7 @@
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
#include "transformations/transpose_sinking/ts_base.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -24,7 +25,7 @@ class TRANSFORMATIONS_API TSReductionBackward;
|
||||
* @brief TSReductionForward transformation sinks Transpose through Reduce operations
|
||||
* in the forward direction.
|
||||
*/
|
||||
class ov::pass::transpose_sinking::TSReductionForward : public ov::pass::MatcherPass {
|
||||
class ov::pass::transpose_sinking::TSReductionForward : public ov::pass::transpose_sinking::TSForwardBase {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TSReductionForward", "0");
|
||||
TSReductionForward();
|
||||
|
@ -0,0 +1,30 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
namespace transpose_sinking {
|
||||
|
||||
class TRANSFORMATIONS_API TSResetNoSinkingAttribute;
|
||||
|
||||
} // namespace transpose_sinking
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TSResetNoSinkingAttribute transformation resets all NoSinkingAttribute runtime attributes
|
||||
* in Transpose operations.
|
||||
*/
|
||||
class ov::pass::transpose_sinking::TSResetNoSinkingAttribute : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TSResetNoSinkingAttribute", "0");
|
||||
TSResetNoSinkingAttribute();
|
||||
};
|
@ -6,6 +6,7 @@
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
#include "transformations/transpose_sinking/ts_base.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -19,7 +20,7 @@ class TRANSFORMATIONS_API TSSliceBackward;
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
class ov::pass::transpose_sinking::TSSliceForward : public ov::pass::MatcherPass {
|
||||
class ov::pass::transpose_sinking::TSSliceForward : public ov::pass::transpose_sinking::TSForwardBase {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TSSliceForward", "0");
|
||||
TSSliceForward();
|
||||
|
@ -6,6 +6,7 @@
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
#include "transformations/transpose_sinking/ts_base.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -24,7 +25,7 @@ class TRANSFORMATIONS_API TSSplitForward;
|
||||
* @brief TSSplitForward transformation sinks Transpose through Split, VariadicSplit operations
|
||||
* in the forward direction.
|
||||
*/
|
||||
class ov::pass::transpose_sinking::TSSplitForward : public ov::pass::MatcherPass {
|
||||
class ov::pass::transpose_sinking::TSSplitForward : public ov::pass::transpose_sinking::TSForwardBase {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TSSplitForward", "0");
|
||||
TSSplitForward();
|
||||
|
@ -6,6 +6,7 @@
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
#include "transformations/transpose_sinking/ts_base.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -24,7 +25,7 @@ class TRANSFORMATIONS_API TSSqueezeBackward;
|
||||
* @brief TSSqueezeForward transformation sinks Transpose through Reshape, Squeeze operations
|
||||
* in the forward direction.
|
||||
*/
|
||||
class ov::pass::transpose_sinking::TSSqueezeForward : public ov::pass::MatcherPass {
|
||||
class ov::pass::transpose_sinking::TSSqueezeForward : public ov::pass::transpose_sinking::TSForwardBase {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TSSqueezeForward", "0");
|
||||
TSSqueezeForward();
|
||||
|
@ -5,6 +5,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "transformations/transpose_sinking/ts_base.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -23,7 +24,7 @@ class TRANSFORMATIONS_API TSUnaryBackward;
|
||||
* @brief TSUnaryForward transformation sinks Transpose through UnaryElementwiseArithmetic, Clamp, Elu,
|
||||
* SoftPlus, LogicalNot, Convert, IsInf, IsNaN, IsFinite operations in the forward direction.
|
||||
*/
|
||||
class ov::pass::transpose_sinking::TSUnaryForward : public ov::pass::MatcherPass {
|
||||
class ov::pass::transpose_sinking::TSUnaryForward : public ov::pass::transpose_sinking::TSForwardBase {
|
||||
public:
|
||||
OPENVINO_RTTI("TSUnaryForward", "0");
|
||||
TSUnaryForward();
|
||||
|
@ -6,6 +6,7 @@
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
#include "transformations/transpose_sinking/ts_base.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -24,7 +25,7 @@ class TRANSFORMATIONS_API TSUnsqueezeBackward;
|
||||
* @brief TSUnsqueezeForward transformation sinks Transpose through Unsqueeze, Reshape operations
|
||||
* in the forward direction.
|
||||
*/
|
||||
class ov::pass::transpose_sinking::TSUnsqueezeForward : public ov::pass::MatcherPass {
|
||||
class ov::pass::transpose_sinking::TSUnsqueezeForward : public ov::pass::transpose_sinking::TSForwardBase {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TSUnsqueezeForward", "0");
|
||||
TSUnsqueezeForward();
|
||||
|
@ -30,12 +30,14 @@ struct TransposeInputsInfo {
|
||||
* @brief Finds node first input that is a transpose operation and returns filled TransposeInputsInfo
|
||||
* for it
|
||||
*/
|
||||
TransposeInputsInfo GetFirstTransposeInput(const std::shared_ptr<ov::Node>&);
|
||||
TransposeInputsInfo GetFirstTransposeInput(const std::shared_ptr<ov::Node>&,
|
||||
bool const_transpose_order,
|
||||
const std::vector<size_t>& indices = {});
|
||||
|
||||
/**
|
||||
* @brief Checks if @arg has any input node that is a transpose operation
|
||||
*/
|
||||
bool IfNodeHasTransposeInputs(const ov::Output<ov::Node>&);
|
||||
bool IfNodeHasTransposeInputs(const ov::Output<ov::Node>&, const std::vector<size_t>& indices = {});
|
||||
|
||||
/**
|
||||
* @brief Reverses order of transpose operation. Do it in a such way that if we had couple following one after
|
||||
@ -86,8 +88,6 @@ ov::NodeVector InsertTransposeBeforeNode(const std::shared_ptr<ov::Node>& main_n
|
||||
std::vector<size_t> input_indexes = {});
|
||||
} // namespace sink_backward
|
||||
|
||||
void UpdateForwardSinkingAbility(const std::shared_ptr<ov::Node>&);
|
||||
|
||||
/**
|
||||
* @brief Checks if @arg has consumers that are all the same Transpose operation
|
||||
* and that sinking is enabled for all these Transpose ops. Otherwise returns false.
|
||||
|
@ -11,6 +11,14 @@ void ov::mark_as_no_sinking_node(const std::shared_ptr<Node>& node) {
|
||||
rt_info[NoTransposeSinkingAttr::get_type_info_static()] = NoTransposeSinkingAttr();
|
||||
}
|
||||
|
||||
void ov::reset_no_sinking_attribute(const std::shared_ptr<Node>& node) {
|
||||
auto& rt_info = node->get_rt_info();
|
||||
auto it = rt_info.find(NoTransposeSinkingAttr::get_type_info_static());
|
||||
if (it != rt_info.end()) {
|
||||
rt_info.erase(it);
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
template <typename NodePtr>
|
||||
bool is_sinking_node_private(NodePtr node) {
|
||||
|
@ -0,0 +1,77 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/transpose_sinking/ts_base.hpp"
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/fake_quantize.hpp"
|
||||
#include "openvino/op/prelu.hpp"
|
||||
#include "openvino/op/transpose.hpp"
|
||||
#include "openvino/op/util/op_types.hpp"
|
||||
#include "openvino/pass/pattern/op/or.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "transformations/rt_info/transpose_sinking_attr.hpp"
|
||||
#include "transformations/transpose_sinking/ts_utils.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::pass::pattern;
|
||||
using namespace ov::pass::transpose_sinking;
|
||||
using namespace ov::pass::transpose_sinking::utils;
|
||||
|
||||
void TSForwardBase::transpose_sinking(const std::string& pass_name,
|
||||
const TSForwardBase::sinking_function& sinking_transformation) {
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
auto main_node = pattern_to_output.at(m_pattern).get_node_shared_ptr();
|
||||
utils::TransposeInputsInfo transpose_input_info =
|
||||
utils::GetFirstTransposeInput(main_node, m_const_transpose_input, m_tranpose_indices);
|
||||
|
||||
if (transformation_callback(main_node)) {
|
||||
mark_as_no_sinking_node(transpose_input_info.transpose);
|
||||
return false;
|
||||
}
|
||||
|
||||
bool res;
|
||||
if (sinking_transformation) {
|
||||
// use custom function to sink transpose
|
||||
res = sinking_transformation(main_node, transpose_input_info);
|
||||
} else {
|
||||
// default transpose sinking function:
|
||||
res = default_inputs_update(main_node, transpose_input_info);
|
||||
if (res) {
|
||||
default_outputs_update(main_node, transpose_input_info);
|
||||
}
|
||||
}
|
||||
if (!res) {
|
||||
mark_as_no_sinking_node(transpose_input_info.transpose);
|
||||
}
|
||||
return res;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<pattern::Matcher>(m_pattern, pass_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
||||
|
||||
bool TSForwardBase::default_inputs_update(const std::shared_ptr<Node>& main_node,
|
||||
const TransposeInputsInfo& transpose_info) {
|
||||
return utils::sink_forward::UpdateInputTransposes(main_node, transpose_info);
|
||||
}
|
||||
|
||||
void TSForwardBase::default_outputs_update(const std::shared_ptr<Node>& main_node,
|
||||
const TransposeInputsInfo& transpose_info) {
|
||||
main_node->validate_and_infer_types();
|
||||
for (auto& new_node : utils::sink_forward::InsertOutputTransposes(main_node, transpose_info)) {
|
||||
register_new_node(new_node);
|
||||
mark_as_no_sinking_node(new_node);
|
||||
}
|
||||
}
|
||||
|
||||
bool TSForwardBase::if_node_has_transpose_inputs(const Output<Node>& output,
|
||||
bool const_transpose_input,
|
||||
const std::vector<size_t>& transpose_indices) {
|
||||
utils::TransposeInputsInfo inputs_info =
|
||||
utils::GetFirstTransposeInput(output.get_node_shared_ptr(), const_transpose_input, transpose_indices);
|
||||
return !inputs_info.isEmpty();
|
||||
}
|
@ -20,42 +20,14 @@ using namespace ov::pass::pattern;
|
||||
using namespace ov::pass::transpose_sinking;
|
||||
using namespace ov::pass::transpose_sinking::utils;
|
||||
|
||||
TSBinaryForward::TSBinaryForward() {
|
||||
TSBinaryForward::TSBinaryForward() : TSForwardBase() {
|
||||
MATCHER_SCOPE(TSBinaryForward);
|
||||
|
||||
auto main_node_label = wrap_type<op::util::BinaryElementwiseArithmetic,
|
||||
create_pattern<op::util::BinaryElementwiseArithmetic,
|
||||
op::util::BinaryElementwiseComparison,
|
||||
op::util::BinaryElementwiseLogical,
|
||||
ov::op::v0::PRelu,
|
||||
ov::op::v0::FakeQuantize>([](const Output<Node>& output) -> bool {
|
||||
return has_static_rank()(output) && IfNodeHasTransposeInputs(output);
|
||||
});
|
||||
|
||||
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
auto& main_node_output = pattern_to_output.at(main_node_label);
|
||||
auto main_node = main_node_output.get_node_shared_ptr();
|
||||
if (transformation_callback(main_node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node);
|
||||
|
||||
// todo: support dynamic rank case
|
||||
bool updated = sink_forward::UpdateInputTransposes(main_node, transpose_input_info);
|
||||
if (!updated) {
|
||||
return false;
|
||||
}
|
||||
main_node->validate_and_infer_types();
|
||||
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
|
||||
register_new_node(new_node);
|
||||
UpdateForwardSinkingAbility(new_node);
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<Matcher>(main_node_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
ov::op::v0::FakeQuantize>(true);
|
||||
transpose_sinking(matcher_name);
|
||||
}
|
||||
|
||||
TSBinaryBackward::TSBinaryBackward() {
|
||||
|
@ -21,45 +21,35 @@ using namespace ov::pass::transpose_sinking::utils;
|
||||
TSConcatForward::TSConcatForward() {
|
||||
MATCHER_SCOPE(TSConcatForward);
|
||||
|
||||
auto main_node_label = wrap_type<ov::op::v0::Concat>(IfNodeHasTransposeInputs);
|
||||
create_pattern<ov::op::v0::Concat>(true);
|
||||
|
||||
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
|
||||
auto& main_node_output = pattern_to_output.at(main_node_label);
|
||||
auto main_node = main_node_output.get_node_shared_ptr();
|
||||
if (transformation_callback(main_node)) {
|
||||
auto sinking_transformation = [=](const std::shared_ptr<Node>& main_node,
|
||||
const TransposeInputsInfo& transpose_info) -> bool {
|
||||
// todo: support dynamic rank case
|
||||
auto concat_node = as_type_ptr<ov::op::v0::Concat>(main_node);
|
||||
if (!concat_node) {
|
||||
return false;
|
||||
}
|
||||
|
||||
TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node);
|
||||
auto concat_node = as_type_ptr<ov::op::v0::Concat>(main_node);
|
||||
auto concat_axis = concat_node->get_concatenation_axis();
|
||||
if (concat_axis < 0) {
|
||||
return false;
|
||||
}
|
||||
// todo: support dyn rank case
|
||||
bool updated = sink_forward::UpdateInputTransposes(main_node, transpose_input_info);
|
||||
bool updated = sink_forward::UpdateInputTransposes(main_node, transpose_info);
|
||||
if (!updated) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val();
|
||||
const auto transpose_axis_order = transpose_info.transpose_const->get_axis_vector_val();
|
||||
const int64_t transposed_concat_axis = transpose_axis_order[concat_axis];
|
||||
concat_node->set_axis(transposed_concat_axis);
|
||||
concat_node->set_concatenation_axis(-1);
|
||||
|
||||
main_node->validate_and_infer_types();
|
||||
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
|
||||
register_new_node(new_node);
|
||||
UpdateForwardSinkingAbility(new_node);
|
||||
}
|
||||
|
||||
default_outputs_update(main_node, transpose_info);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<Matcher>(main_node_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
transpose_sinking(matcher_name, sinking_transformation);
|
||||
}
|
||||
|
||||
TSConcatBackward::TSConcatBackward() {
|
||||
|
@ -38,35 +38,17 @@ std::vector<size_t> get_indices_by_op_type(const std::shared_ptr<Node>& main_nod
|
||||
|
||||
TSDataMovementForward::TSDataMovementForward() {
|
||||
MATCHER_SCOPE(TSDataMovementForward);
|
||||
auto const_label = wrap_type<ov::op::v0::Constant>();
|
||||
auto transpose_label = wrap_type<ov::op::v1::Transpose>({any_input(), const_label});
|
||||
auto main_node_label =
|
||||
wrap_type<ov::op::v1::Pad, ov::op::v1::BatchToSpace, ov::op::v1::SpaceToBatch, ov::op::v0::ReverseSequence>(
|
||||
{transpose_label, any_input(), any_input(), any_input()});
|
||||
|
||||
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||
const auto& pattern_to_node = m.get_pattern_map();
|
||||
|
||||
auto& main_node = pattern_to_node.at(main_node_label);
|
||||
if (transformation_callback(main_node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto transpose = std::dynamic_pointer_cast<ov::op::v1::Transpose>(pattern_to_node.at(transpose_label));
|
||||
if (!transpose) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto transpose_const = as_type_ptr<ov::op::v0::Constant>(pattern_to_node.at(const_label));
|
||||
if (!transpose_const) {
|
||||
return false;
|
||||
}
|
||||
create_pattern<ov::op::v1::Pad, ov::op::v1::BatchToSpace, ov::op::v1::SpaceToBatch, ov::op::v0::ReverseSequence>(
|
||||
true,
|
||||
{0});
|
||||
|
||||
auto sinking_transformation = [=](const std::shared_ptr<Node>& main_node,
|
||||
const TransposeInputsInfo& transpose_info) -> bool {
|
||||
// remove Transpose on 1st input:
|
||||
auto transpose_parent = main_node->input_value(0).get_node()->input_value(0);
|
||||
main_node->input(0).replace_source_output(transpose_parent);
|
||||
|
||||
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
|
||||
const auto transpose_axis_order = transpose_info.transpose_const->get_axis_vector_val();
|
||||
const auto reversed_transpose_order = ReverseTransposeOrder(transpose_axis_order);
|
||||
auto axis = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{}, 0);
|
||||
|
||||
@ -80,17 +62,12 @@ TSDataMovementForward::TSDataMovementForward() {
|
||||
reverse_seq->set_batch_axis(transpose_axis_order[reverse_seq->get_batch_axis()]);
|
||||
reverse_seq->set_sequence_axis(transpose_axis_order[reverse_seq->get_sequence_axis()]);
|
||||
}
|
||||
main_node->validate_and_infer_types();
|
||||
TransposeInputsInfo transpose_input_info = {transpose, transpose_const, 0};
|
||||
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
|
||||
register_new_node(new_node);
|
||||
UpdateForwardSinkingAbility(new_node);
|
||||
}
|
||||
|
||||
default_outputs_update(main_node, transpose_info);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<Matcher>(main_node_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
transpose_sinking(matcher_name, sinking_transformation);
|
||||
}
|
||||
|
||||
TSDataMovementBackward::TSDataMovementBackward() {
|
||||
|
@ -12,6 +12,7 @@
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/transpose.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "transformations/rt_info/transpose_sinking_attr.hpp"
|
||||
#include "transformations/transpose_sinking/ts_utils.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
@ -72,7 +73,7 @@ TSFuse::TSFuse() {
|
||||
copy_runtime_info(transpose1, new_transpose);
|
||||
ov::replace_node(transpose1, new_transpose);
|
||||
|
||||
UpdateForwardSinkingAbility(new_transpose);
|
||||
mark_as_no_sinking_node(new_transpose);
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
@ -23,22 +23,18 @@ using namespace ov::pass::transpose_sinking::utils;
|
||||
TSGatherForward::TSGatherForward() {
|
||||
MATCHER_SCOPE(TSGatherForward);
|
||||
|
||||
auto transpose_label = wrap_type<ov::op::v1::Transpose>({any_input(), wrap_type<ov::op::v0::Constant>()});
|
||||
auto gather_label =
|
||||
wrap_type<ov::op::v8::Gather>({transpose_label, any_input(), wrap_type<ov::op::v0::Constant>()});
|
||||
create_pattern<ov::op::v8::Gather>(true, {0});
|
||||
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_map();
|
||||
|
||||
auto transpose = as_type_ptr<ov::op::v1::Transpose>(pattern_to_output.at(transpose_label));
|
||||
auto main_node = as_type_ptr<ov::op::v8::Gather>(pattern_to_output.at(gather_label));
|
||||
if (transformation_callback(main_node) || !main_node) {
|
||||
auto sinking_transformation = [=](const std::shared_ptr<Node>& main_node,
|
||||
const TransposeInputsInfo& transpose_info) -> bool {
|
||||
auto gather = as_type_ptr<ov::op::v8::Gather>(main_node);
|
||||
if (!gather) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto transpose_order = as_type_ptr<ov::op::v0::Constant>(transpose->get_input_node_shared_ptr(1));
|
||||
auto transpose_order = transpose_info.transpose_const;
|
||||
auto gather_axis = as_type_ptr<ov::op::v0::Constant>(main_node->get_input_node_shared_ptr(2));
|
||||
if (!transpose || !transpose_order || !gather_axis) {
|
||||
if (!gather_axis) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -53,7 +49,7 @@ TSGatherForward::TSGatherForward() {
|
||||
}
|
||||
|
||||
const auto& order_val = transpose_order->cast_vector<size_t>();
|
||||
auto batch_dims = static_cast<size_t>(main_node->get_batch_dims());
|
||||
auto batch_dims = static_cast<size_t>(gather->get_batch_dims());
|
||||
for (size_t i = 0; i < batch_dims; ++i) {
|
||||
// transpose changes the order of batch dims
|
||||
if (order_val[i] != i) {
|
||||
@ -88,7 +84,7 @@ TSGatherForward::TSGatherForward() {
|
||||
auto new_order_const = ov::op::v0::Constant::create(transpose_order->get_element_type(),
|
||||
{new_transpose_order.size()},
|
||||
new_transpose_order);
|
||||
TransposeInputsInfo transpose_input_info = {transpose, new_order_const, 0};
|
||||
TransposeInputsInfo transpose_input_info = {transpose_info.transpose, new_order_const, 0};
|
||||
// deletes Transpose from 0 input
|
||||
auto success = sink_forward::UpdateInputTransposes(main_node, transpose_input_info, {0});
|
||||
if (!success) {
|
||||
@ -98,17 +94,12 @@ TSGatherForward::TSGatherForward() {
|
||||
ov::op::v0::Constant::create(gather_axis->get_element_type(), gather_axis->get_shape(), {order_val[axis]});
|
||||
main_node->input(2).replace_source_output(new_axis);
|
||||
copy_runtime_info(gather_axis, new_axis);
|
||||
main_node->validate_and_infer_types();
|
||||
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
|
||||
register_new_node(new_node);
|
||||
UpdateForwardSinkingAbility(new_node);
|
||||
}
|
||||
|
||||
default_outputs_update(main_node, transpose_input_info);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<pattern::Matcher>(gather_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
transpose_sinking(matcher_name, sinking_transformation);
|
||||
}
|
||||
|
||||
TSGatherBackward::TSGatherBackward() {
|
||||
|
@ -17,6 +17,7 @@
|
||||
#include "transformations/transpose_sinking/ts_gather.hpp"
|
||||
#include "transformations/transpose_sinking/ts_interpolate.hpp"
|
||||
#include "transformations/transpose_sinking/ts_reduction.hpp"
|
||||
#include "transformations/transpose_sinking/ts_reset_no_sinking_attribute.hpp"
|
||||
#include "transformations/transpose_sinking/ts_slice.hpp"
|
||||
#include "transformations/transpose_sinking/ts_split.hpp"
|
||||
#include "transformations/transpose_sinking/ts_squeeze.hpp"
|
||||
@ -73,6 +74,7 @@ bool TSGeneral::run_on_model(const std::shared_ptr<ov::Model>& f) {
|
||||
manager.register_pass<DisableShapeOfConstantFolding>();
|
||||
manager.register_pass<TSGeneralBackward>();
|
||||
manager.register_pass<ConstantFolding>();
|
||||
manager.register_pass<TSResetNoSinkingAttribute>();
|
||||
manager.run_passes(f);
|
||||
}
|
||||
|
||||
|
@ -22,39 +22,21 @@ using namespace ov::pass::transpose_sinking::utils;
|
||||
|
||||
TSInterpolateForward::TSInterpolateForward() {
|
||||
MATCHER_SCOPE(TSInterpolateForward);
|
||||
auto const_label = wrap_type<ov::op::v0::Constant>();
|
||||
auto transpose_label = wrap_type<ov::op::v1::Transpose>({any_input(), const_label});
|
||||
auto main_node_label = wrap_type<ov::op::v4::Interpolate>({transpose_label, any_input(), any_input(), any_input()});
|
||||
|
||||
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||
const auto& pattern_to_node = m.get_pattern_map();
|
||||
|
||||
auto& main_node = pattern_to_node.at(main_node_label);
|
||||
if (transformation_callback(main_node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto transpose = std::dynamic_pointer_cast<ov::op::v1::Transpose>(pattern_to_node.at(transpose_label));
|
||||
if (!transpose) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto transpose_const = as_type_ptr<ov::op::v0::Constant>(pattern_to_node.at(const_label));
|
||||
if (!transpose_const) {
|
||||
return false;
|
||||
}
|
||||
create_pattern<ov::op::v4::Interpolate>(true, {0});
|
||||
|
||||
auto sinking_transformation = [=](const std::shared_ptr<Node>& main_node,
|
||||
const TransposeInputsInfo& transpose_info) -> bool {
|
||||
// remove Transpose on 1st input:
|
||||
auto transpose_parent = transpose->input_value(0);
|
||||
auto transpose_parent = transpose_info.transpose->input_value(0);
|
||||
main_node->input(0).replace_source_output(transpose_parent);
|
||||
|
||||
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
|
||||
const auto transpose_axis_order = transpose_info.transpose_const->get_axis_vector_val();
|
||||
auto axis = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{}, 0);
|
||||
|
||||
const auto& interpolate = std::dynamic_pointer_cast<ov::op::v4::Interpolate>(main_node);
|
||||
const auto& new_axes = ChangeAxes(main_node->input_value(3), transpose_axis_order, axis);
|
||||
main_node->input(3).replace_source_output(new_axes);
|
||||
|
||||
const auto& interpolate = std::dynamic_pointer_cast<ov::op::v4::Interpolate>(main_node);
|
||||
if (interpolate) {
|
||||
op::v4::Interpolate::InterpolateAttrs attrs = interpolate->get_attrs();
|
||||
if (!attrs.pads_begin.empty() || !attrs.pads_end.empty()) {
|
||||
@ -72,17 +54,11 @@ TSInterpolateForward::TSInterpolateForward() {
|
||||
}
|
||||
}
|
||||
|
||||
main_node->validate_and_infer_types();
|
||||
TransposeInputsInfo transpose_input_info = {transpose, transpose_const, 0};
|
||||
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
|
||||
register_new_node(new_node);
|
||||
UpdateForwardSinkingAbility(new_node);
|
||||
}
|
||||
default_outputs_update(main_node, transpose_info);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<Matcher>(main_node_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
transpose_sinking(matcher_name, sinking_transformation);
|
||||
}
|
||||
|
||||
TSInterpolateBackward::TSInterpolateBackward() {
|
||||
|
@ -42,20 +42,11 @@ bool get_keep_dims(const std::shared_ptr<Node>& main_node) {
|
||||
TSReductionForward::TSReductionForward() {
|
||||
MATCHER_SCOPE(TSReductionForward);
|
||||
|
||||
auto transpose_label = wrap_type<ov::op::v1::Transpose>({any_input(), wrap_type<ov::op::v0::Constant>()});
|
||||
auto reduce_label = wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims>(
|
||||
{transpose_label, wrap_type<ov::op::v0::Constant>()});
|
||||
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_map();
|
||||
auto transpose = as_type_ptr<ov::op::v1::Transpose>(pattern_to_output.at(transpose_label));
|
||||
auto main_node = pattern_to_output.at(reduce_label);
|
||||
if (!transpose || transformation_callback(main_node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
create_pattern<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims>(true, {0});
|
||||
auto sinking_transformation = [=](const std::shared_ptr<Node>& main_node,
|
||||
const TransposeInputsInfo& transpose_info) -> bool {
|
||||
auto keep_dims = get_keep_dims(main_node);
|
||||
auto transpose_order = as_type_ptr<ov::op::v0::Constant>(transpose->get_input_node_shared_ptr(1));
|
||||
auto transpose_order = transpose_info.transpose_const;
|
||||
auto reduction_axes = as_type_ptr<ov::op::v0::Constant>(main_node->get_input_node_shared_ptr(1));
|
||||
if (!transpose_order || !reduction_axes)
|
||||
return false;
|
||||
@ -84,7 +75,7 @@ TSReductionForward::TSReductionForward() {
|
||||
auto new_const =
|
||||
ov::op::v0::Constant::create(reduction_axes->get_element_type(), {new_values.size()}, new_values);
|
||||
main_node->input(1).replace_source_output(new_const);
|
||||
TransposeInputsInfo transpose_input_info = {transpose, new_transpose_order, 0};
|
||||
TransposeInputsInfo transpose_input_info = {transpose_info.transpose, new_transpose_order, 0};
|
||||
// deletes Transpose from 0 input
|
||||
auto success = sink_forward::UpdateInputTransposes(main_node, transpose_input_info, {0});
|
||||
if (!success) {
|
||||
@ -92,16 +83,12 @@ TSReductionForward::TSReductionForward() {
|
||||
}
|
||||
|
||||
copy_runtime_info(reduction_axes, new_const);
|
||||
main_node->validate_and_infer_types();
|
||||
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
|
||||
register_new_node(new_node);
|
||||
UpdateForwardSinkingAbility(new_node);
|
||||
}
|
||||
|
||||
default_outputs_update(main_node, transpose_input_info);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<pattern::Matcher>(reduce_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
transpose_sinking(matcher_name, sinking_transformation);
|
||||
}
|
||||
|
||||
TSReductionBackward::TSReductionBackward() {
|
||||
|
@ -0,0 +1,33 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/transpose_sinking/ts_reset_no_sinking_attribute.hpp"
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "openvino/op/transpose.hpp"
|
||||
#include "openvino/op/util/op_types.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "transformations/rt_info/transpose_sinking_attr.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::pass::pattern;
|
||||
using namespace ov::pass::transpose_sinking;
|
||||
|
||||
TSResetNoSinkingAttribute::TSResetNoSinkingAttribute() {
|
||||
MATCHER_SCOPE(TSResetNoSinkingAttribute);
|
||||
|
||||
auto transpose_label = wrap_type<ov::op::v1::Transpose>([](const Output<Node>& output) -> bool {
|
||||
const auto& rt_info = output.get_node()->get_rt_info();
|
||||
return rt_info.find(NoTransposeSinkingAttr::get_type_info_static()) != rt_info.end();
|
||||
});
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_map();
|
||||
|
||||
const auto& transpose = pattern_to_output.at(transpose_label);
|
||||
ov::reset_no_sinking_attribute(transpose);
|
||||
return false;
|
||||
};
|
||||
auto m = std::make_shared<pattern::Matcher>(transpose_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
@ -23,34 +23,19 @@ using namespace ov::pass::transpose_sinking::utils;
|
||||
|
||||
TSSliceForward::TSSliceForward() {
|
||||
MATCHER_SCOPE(TSSliceForward);
|
||||
auto const_label = wrap_type<ov::op::v0::Constant>();
|
||||
auto transpose_label = wrap_type<ov::op::v1::Transpose>({any_input(), const_label});
|
||||
auto main_node_label =
|
||||
wrap_type<ov::op::v8::Slice>({transpose_label, any_input(), any_input(), any_input(), any_input()});
|
||||
create_pattern<ov::op::v8::Slice>(true, {0});
|
||||
|
||||
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||
const auto& pattern_to_node = m.get_pattern_map();
|
||||
|
||||
auto& main_node = pattern_to_node.at(main_node_label);
|
||||
if (transformation_callback(main_node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto transpose = std::dynamic_pointer_cast<ov::op::v1::Transpose>(pattern_to_node.at(transpose_label));
|
||||
if (!transpose || main_node->get_input_size() < 5) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto transpose_const = as_type_ptr<ov::op::v0::Constant>(pattern_to_node.at(const_label));
|
||||
if (!transpose_const) {
|
||||
auto sinking_transformation = [=](const std::shared_ptr<Node>& main_node,
|
||||
const TransposeInputsInfo& transpose_info) -> bool {
|
||||
if (main_node->get_input_size() < 5) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// remove Transpose on 1st input:
|
||||
auto transpose_parent = transpose->input_value(0);
|
||||
auto transpose_parent = transpose_info.transpose->input_value(0);
|
||||
main_node->input(0).replace_source_output(transpose_parent);
|
||||
|
||||
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
|
||||
const auto transpose_axis_order = transpose_info.transpose_const->get_axis_vector_val();
|
||||
auto axis = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{}, std::vector<int32_t>{0});
|
||||
|
||||
auto data = std::make_shared<ov::op::v0::Constant>(element::i32,
|
||||
@ -61,17 +46,11 @@ TSSliceForward::TSSliceForward() {
|
||||
|
||||
main_node->input(4).replace_source_output(new_axis);
|
||||
|
||||
main_node->validate_and_infer_types();
|
||||
TransposeInputsInfo transpose_input_info = {transpose, transpose_const, 0};
|
||||
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
|
||||
register_new_node(new_node);
|
||||
UpdateForwardSinkingAbility(new_node);
|
||||
}
|
||||
default_outputs_update(main_node, transpose_info);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<Matcher>(main_node_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
transpose_sinking(matcher_name, sinking_transformation);
|
||||
}
|
||||
|
||||
TSSliceBackward::TSSliceBackward() {
|
||||
|
@ -95,6 +95,40 @@ bool GetSplitAxis(const std::shared_ptr<ov::op::v0::Constant>& split_axis, const
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TSSplitForward::TSSplitForward() {
|
||||
MATCHER_SCOPE(TSSplitForward);
|
||||
|
||||
create_pattern<ov::op::v1::Split, ov::op::v1::VariadicSplit>(true, {0});
|
||||
|
||||
auto sinking_transformation = [=](const std::shared_ptr<Node>& main_node,
|
||||
const TransposeInputsInfo& transpose_info) -> bool {
|
||||
auto split_axis_constant = as_type_ptr<ov::op::v0::Constant>(main_node->input_value(1).get_node_shared_ptr());
|
||||
if (!split_axis_constant) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int64_t split_axis;
|
||||
if (!GetSplitAxis(split_axis_constant, main_node->input_value(0).get_partial_shape().rank(), split_axis)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
sink_forward::RemoveInputNode(main_node, /* input_idx */ 0);
|
||||
const auto transpose_axis_order = transpose_info.transpose_const->get_axis_vector_val();
|
||||
const size_t transposed_split_axis = transpose_axis_order[split_axis];
|
||||
auto new_split_axis_const = std::make_shared<ov::op::v0::Constant>(split_axis_constant->get_element_type(),
|
||||
Shape{},
|
||||
transposed_split_axis);
|
||||
main_node->input(1).replace_source_output(new_split_axis_const);
|
||||
copy_runtime_info({split_axis_constant, transpose_info.transpose, transpose_info.transpose_const},
|
||||
new_split_axis_const);
|
||||
|
||||
default_outputs_update(main_node, transpose_info);
|
||||
return true;
|
||||
};
|
||||
|
||||
transpose_sinking(matcher_name, sinking_transformation);
|
||||
}
|
||||
|
||||
/*
|
||||
* We follow Transpose operations rather than Split. We cannot create matcher pattern
|
||||
* for Split with Transpose outputs since Split can have different number of outputs.
|
||||
@ -191,51 +225,3 @@ TSSplitBackward::TSSplitBackward() {
|
||||
auto m = std::make_shared<Matcher>(transpose_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
||||
|
||||
TSSplitForward::TSSplitForward() {
|
||||
MATCHER_SCOPE(TSSplitForward);
|
||||
|
||||
auto main_node_label = wrap_type<ov::op::v1::Split, ov::op::v1::VariadicSplit>(IfNodeHasTransposeInputs);
|
||||
|
||||
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
|
||||
auto& main_node_output = pattern_to_output.at(main_node_label);
|
||||
auto main_node = main_node_output.get_node_shared_ptr();
|
||||
if (transformation_callback(main_node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto split_axis_constant = as_type_ptr<ov::op::v0::Constant>(main_node->input_value(1).get_node_shared_ptr());
|
||||
if (!split_axis_constant) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int64_t split_axis;
|
||||
if (!GetSplitAxis(split_axis_constant, main_node->input_value(0).get_partial_shape().rank(), split_axis)) {
|
||||
return false;
|
||||
}
|
||||
TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node);
|
||||
|
||||
sink_forward::RemoveInputNode(main_node, /* input_idx */ 0);
|
||||
const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val();
|
||||
const size_t transposed_split_axis = transpose_axis_order[split_axis];
|
||||
auto new_split_axis_const = std::make_shared<ov::op::v0::Constant>(split_axis_constant->get_element_type(),
|
||||
Shape{},
|
||||
transposed_split_axis);
|
||||
main_node->input(1).replace_source_output(new_split_axis_const);
|
||||
copy_runtime_info({split_axis_constant, transpose_input_info.transpose, transpose_input_info.transpose_const},
|
||||
new_split_axis_const);
|
||||
main_node->validate_and_infer_types();
|
||||
|
||||
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
|
||||
register_new_node(new_node);
|
||||
UpdateForwardSinkingAbility(new_node);
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<Matcher>(main_node_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
||||
|
@ -104,32 +104,10 @@ bool squeeze_axes_to_shape(const Output<Node>& input_node,
|
||||
TSSqueezeForward::TSSqueezeForward() {
|
||||
MATCHER_SCOPE(TSSqueezeForward);
|
||||
|
||||
auto transpose_label = wrap_type<ov::op::v1::Transpose>({any_input(), wrap_type<ov::op::v0::Constant>()});
|
||||
auto squeeze_with_1_input = wrap_type<ov::op::v0::Squeeze>({transpose_label});
|
||||
auto squeeze_label =
|
||||
wrap_type<ov::op::v0::Squeeze, ov::op::v1::Reshape>({transpose_label, wrap_type<ov::op::v0::Constant>()});
|
||||
auto pattern = std::make_shared<pattern::op::Or>(OutputVector{squeeze_with_1_input, squeeze_label});
|
||||
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_map();
|
||||
|
||||
auto transpose = as_type_ptr<ov::op::v1::Transpose>(pattern_to_output.at(transpose_label));
|
||||
std::shared_ptr<Node> main_node;
|
||||
if (pattern_to_output.count(squeeze_label)) {
|
||||
main_node = pattern_to_output.at(squeeze_label);
|
||||
} else {
|
||||
main_node = pattern_to_output.at(squeeze_with_1_input);
|
||||
}
|
||||
if (!transpose || transformation_callback(main_node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto transpose_order = as_type_ptr<ov::op::v0::Constant>(transpose->get_input_node_shared_ptr(1));
|
||||
|
||||
if (!transpose_order) {
|
||||
return false;
|
||||
}
|
||||
create_pattern<ov::op::v0::Squeeze, ov::op::v1::Reshape>(true, {0});
|
||||
|
||||
auto sinking_transformation = [=](const std::shared_ptr<Node>& main_node,
|
||||
const TransposeInputsInfo& transpose_info) -> bool {
|
||||
std::vector<size_t> non_negative_axes;
|
||||
std::shared_ptr<ov::op::v0::Constant> squeeze_axes;
|
||||
if (main_node->get_input_size() > 1) {
|
||||
@ -153,7 +131,7 @@ TSSqueezeForward::TSSqueezeForward() {
|
||||
|
||||
// if 2nd input to main_node is empty then all '1' dims will be deleted.
|
||||
if (non_negative_axes.empty()) {
|
||||
auto input_pshape = transpose->output(0).get_partial_shape();
|
||||
auto input_pshape = transpose_info.transpose->output(0).get_partial_shape();
|
||||
if (input_pshape.is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
@ -164,7 +142,7 @@ TSSqueezeForward::TSSqueezeForward() {
|
||||
}
|
||||
}
|
||||
|
||||
auto transpose_order_values = transpose_order->cast_vector<size_t>();
|
||||
auto transpose_order_values = transpose_info.transpose_const->cast_vector<size_t>();
|
||||
std::vector<size_t> new_values;
|
||||
new_values.reserve(non_negative_axes.size());
|
||||
for (const auto& axis : non_negative_axes) {
|
||||
@ -172,13 +150,13 @@ TSSqueezeForward::TSSqueezeForward() {
|
||||
}
|
||||
|
||||
transpose_order_values = GetOrderAfterReduction(non_negative_axes, transpose_order_values);
|
||||
auto new_transpose_order = ov::op::v0::Constant::create(transpose_order->get_element_type(),
|
||||
auto new_transpose_order = ov::op::v0::Constant::create(transpose_info.transpose_const->get_element_type(),
|
||||
{transpose_order_values.size()},
|
||||
transpose_order_values);
|
||||
|
||||
if (as_type_ptr<ov::op::v1::Reshape>(main_node)) {
|
||||
std::vector<size_t> to_shape;
|
||||
auto success = squeeze_axes_to_shape(transpose->input_value(0), new_values, to_shape);
|
||||
auto success = squeeze_axes_to_shape(transpose_info.transpose->input_value(0), new_values, to_shape);
|
||||
if (!success) {
|
||||
return false;
|
||||
}
|
||||
@ -192,24 +170,18 @@ TSSqueezeForward::TSSqueezeForward() {
|
||||
copy_runtime_info(squeeze_axes, new_const);
|
||||
}
|
||||
|
||||
TransposeInputsInfo transpose_input_info = {transpose, new_transpose_order, 0};
|
||||
TransposeInputsInfo transpose_input_info = {transpose_info.transpose, new_transpose_order, 0};
|
||||
// deletes Transpose from 0 input
|
||||
auto success = sink_forward::UpdateInputTransposes(main_node, transpose_input_info, {0});
|
||||
if (!success) {
|
||||
return false;
|
||||
}
|
||||
|
||||
main_node->validate_and_infer_types();
|
||||
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
|
||||
register_new_node(new_node);
|
||||
UpdateForwardSinkingAbility(new_node);
|
||||
}
|
||||
|
||||
default_outputs_update(main_node, transpose_input_info);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<pattern::Matcher>(pattern, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
transpose_sinking(matcher_name, sinking_transformation);
|
||||
}
|
||||
|
||||
TSSqueezeBackward::TSSqueezeBackward() {
|
||||
|
@ -32,38 +32,12 @@ namespace {
|
||||
using NodePtr = std::shared_ptr<ov::Node>;
|
||||
using NodePair = std::pair<NodePtr, NodePtr>;
|
||||
|
||||
/**
|
||||
* @brief SwapNodes allows to perform swapping nodes even if there are more than one consumers but has less performance
|
||||
*
|
||||
* @param first_node first node pointer
|
||||
* @param second_node first node pointer
|
||||
* @return NodePair pair of nodes in new order that allows to register them in MatcherPass
|
||||
*/
|
||||
NodePair SwapNodes(const NodePtr& first_node, const NodePtr& second_node) {
|
||||
auto second_node_inputs = second_node->input_values();
|
||||
second_node_inputs[0] = first_node->input_value(0);
|
||||
|
||||
auto new_first_node = second_node->clone_with_new_inputs(second_node_inputs);
|
||||
|
||||
auto first_node_inputs = first_node->input_values();
|
||||
first_node_inputs[0] = new_first_node;
|
||||
auto new_second_node = first_node->clone_with_new_inputs(first_node_inputs);
|
||||
|
||||
new_second_node->set_friendly_name(second_node->get_friendly_name());
|
||||
ov::copy_runtime_info({first_node, second_node}, {new_first_node, new_second_node});
|
||||
|
||||
ov::replace_node(second_node, new_second_node);
|
||||
|
||||
return std::make_pair(new_first_node, new_second_node);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TSUnaryForward::TSUnaryForward() {
|
||||
MATCHER_SCOPE(TSUnaryForward);
|
||||
|
||||
auto transpose_label = wrap_type<ov::op::v1::Transpose>({any_input(), any_input()});
|
||||
auto unary_label = wrap_type<UnaryElementwiseArithmetic,
|
||||
create_pattern<UnaryElementwiseArithmetic,
|
||||
ov::op::v0::Clamp,
|
||||
ov::op::v0::Elu,
|
||||
ov::op::v4::SoftPlus,
|
||||
@ -71,27 +45,8 @@ TSUnaryForward::TSUnaryForward() {
|
||||
ov::op::v0::Convert,
|
||||
ov::op::v10::IsInf,
|
||||
ov::op::v10::IsNaN,
|
||||
ov::op::v10::IsFinite>({transpose_label});
|
||||
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
||||
auto unary = pattern_to_output.at(unary_label).get_node_shared_ptr();
|
||||
if (transformation_callback(unary)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const NodePair new_nodes = SwapNodes(transpose, unary);
|
||||
|
||||
register_new_node(new_nodes.first);
|
||||
register_new_node(new_nodes.second);
|
||||
|
||||
UpdateForwardSinkingAbility(new_nodes.second);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<Matcher>(unary_label, "ov::pass::TSUnaryForward");
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
ov::op::v10::IsFinite>(true);
|
||||
transpose_sinking(matcher_name);
|
||||
}
|
||||
|
||||
TSUnaryBackward::TSUnaryBackward() {
|
||||
|
@ -104,22 +104,12 @@ bool unsqueeze_axes_to_shape(const Output<Node>& input_node,
|
||||
TSUnsqueezeForward::TSUnsqueezeForward() {
|
||||
MATCHER_SCOPE(TSUnsqueezeForward);
|
||||
|
||||
auto transpose_label = wrap_type<ov::op::v1::Transpose>({any_input(), wrap_type<ov::op::v0::Constant>()});
|
||||
auto unsqueeze_label =
|
||||
wrap_type<ov::op::v0::Unsqueeze, ov::op::v1::Reshape>({transpose_label, wrap_type<ov::op::v0::Constant>()});
|
||||
create_pattern<ov::op::v0::Unsqueeze, ov::op::v1::Reshape>(true, {0});
|
||||
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_map();
|
||||
|
||||
auto transpose = as_type_ptr<ov::op::v1::Transpose>(pattern_to_output.at(transpose_label));
|
||||
auto main_node = pattern_to_output.at(unsqueeze_label);
|
||||
if (!transpose || transformation_callback(main_node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto transpose_order = as_type_ptr<ov::op::v0::Constant>(transpose->get_input_node_shared_ptr(1));
|
||||
auto sinking_transformation = [=](const std::shared_ptr<Node>& main_node,
|
||||
const TransposeInputsInfo& transpose_info) -> bool {
|
||||
auto unsqueeze_axes = as_type_ptr<ov::op::v0::Constant>(main_node->get_input_node_shared_ptr(1));
|
||||
if (!transpose_order || !unsqueeze_axes) {
|
||||
if (!unsqueeze_axes) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -136,17 +126,17 @@ TSUnsqueezeForward::TSUnsqueezeForward() {
|
||||
normalize_axes(main_node->get_friendly_name(), unsqueeze_axes->cast_vector<int64_t>(), rank);
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
auto ts_order_values = transpose_order->cast_vector<size_t>();
|
||||
auto ts_order_values = transpose_info.transpose_const->cast_vector<size_t>();
|
||||
|
||||
ts_order_values = GetOrderBeforeReduction(non_negative_axes, ts_order_values);
|
||||
auto new_transpose_order = ov::op::v0::Constant::create(transpose_order->get_element_type(),
|
||||
auto new_transpose_order = ov::op::v0::Constant::create(transpose_info.transpose_const->get_element_type(),
|
||||
{ts_order_values.size()},
|
||||
ts_order_values);
|
||||
|
||||
std::shared_ptr<Node> new_unsqueeze;
|
||||
if (as_type_ptr<ov::op::v1::Reshape>(main_node)) {
|
||||
std::vector<size_t> new_values;
|
||||
auto success = unsqueeze_axes_to_shape(transpose->input_value(0), non_negative_axes, new_values);
|
||||
auto success =
|
||||
unsqueeze_axes_to_shape(transpose_info.transpose->input_value(0), non_negative_axes, new_values);
|
||||
if (!success) {
|
||||
return false;
|
||||
}
|
||||
@ -156,24 +146,18 @@ TSUnsqueezeForward::TSUnsqueezeForward() {
|
||||
copy_runtime_info(unsqueeze_axes, new_const);
|
||||
}
|
||||
|
||||
TransposeInputsInfo transpose_input_info = {transpose, new_transpose_order, 0};
|
||||
TransposeInputsInfo transpose_input_info = {transpose_info.transpose, new_transpose_order, 0};
|
||||
// deletes Transpose from 0 input
|
||||
auto success = sink_forward::UpdateInputTransposes(main_node, transpose_input_info, {0});
|
||||
if (!success) {
|
||||
return false;
|
||||
}
|
||||
|
||||
main_node->validate_and_infer_types();
|
||||
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
|
||||
register_new_node(new_node);
|
||||
UpdateForwardSinkingAbility(new_node);
|
||||
}
|
||||
|
||||
default_outputs_update(main_node, transpose_input_info);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<pattern::Matcher>(unsqueeze_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
transpose_sinking(matcher_name, sinking_transformation);
|
||||
}
|
||||
|
||||
TSUnsqueezeBackward::TSUnsqueezeBackward() {
|
||||
|
@ -66,14 +66,22 @@ Output<Node> ChangeAxes(const Output<Node>& indices,
|
||||
return ChangeAxes(indices, data, axis);
|
||||
}
|
||||
|
||||
TransposeInputsInfo GetFirstTransposeInput(const NodePtr& node) {
|
||||
for (size_t input_idx = 0; input_idx < node->get_input_size(); ++input_idx) {
|
||||
TransposeInputsInfo GetFirstTransposeInput(const NodePtr& node,
|
||||
bool const_transpose_order,
|
||||
const std::vector<size_t>& indices) {
|
||||
auto indices_to_check = indices;
|
||||
if (indices.empty()) {
|
||||
indices_to_check.resize(node->get_input_size());
|
||||
std::iota(indices_to_check.begin(), indices_to_check.end(), 0);
|
||||
}
|
||||
|
||||
for (const auto& input_idx : indices_to_check) {
|
||||
NodePtr input_node = node->get_input_node_shared_ptr(input_idx);
|
||||
auto transpose_node = as_type_ptr<ov::op::v1::Transpose>(input_node);
|
||||
if (!transpose_node)
|
||||
continue;
|
||||
auto constant_node = as_type_ptr<ov::op::v0::Constant>(transpose_node->input_value(1).get_node_shared_ptr());
|
||||
if (!constant_node)
|
||||
if (const_transpose_order && !constant_node)
|
||||
continue;
|
||||
{
|
||||
TransposeInputsInfo input_info;
|
||||
@ -87,11 +95,6 @@ TransposeInputsInfo GetFirstTransposeInput(const NodePtr& node) {
|
||||
return {};
|
||||
}
|
||||
|
||||
bool IfNodeHasTransposeInputs(const Output<Node>& output) {
|
||||
TransposeInputsInfo inputs_info = GetFirstTransposeInput(output.get_node_shared_ptr());
|
||||
return !inputs_info.isEmpty();
|
||||
}
|
||||
|
||||
AxisVector ReverseTransposeOrder(const AxisVector& axis_order) {
|
||||
AxisVector out(axis_order.size());
|
||||
for (size_t i = 0; i < axis_order.size(); i++) {
|
||||
@ -104,6 +107,12 @@ void SwapOutputNames(Output<Node> output1, Output<Node> output2) {
|
||||
const auto node2_output_names = output2.get_names();
|
||||
output2.set_names(output1.get_names());
|
||||
output1.set_names(node2_output_names);
|
||||
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
const auto node2_legacy_output_names = get_ov_tensor_legacy_name(output2.get_tensor());
|
||||
set_ov_tensor_legacy_name(output2.get_tensor(), get_ov_tensor_legacy_name(output1.get_tensor()));
|
||||
set_ov_tensor_legacy_name(output1.get_tensor(), node2_legacy_output_names);
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
void SwapFriendlyNames(const NodePtr& node1, const NodePtr& node2) {
|
||||
@ -304,67 +313,6 @@ NodeVector InsertTransposeBeforeNode(const NodePtr& main_node,
|
||||
}
|
||||
} // namespace sink_backward
|
||||
|
||||
#define CHECK_TRANSPOSE_SINKING_SUPPORTED(TYPE, node) \
|
||||
if (dynamic_cast<TYPE*>(node)) { \
|
||||
return true; \
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
bool CanPropagateForwardThrough(Node* node) {
|
||||
// todo: collect this info automatically
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(op::util::UnaryElementwiseArithmetic, node)
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v0::Clamp, node)
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v0::Elu, node)
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v4::SoftPlus, node)
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v1::LogicalNot, node)
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v0::Convert, node)
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v10::IsInf, node)
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v10::IsNaN, node)
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v10::IsFinite, node)
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(op::util::BinaryElementwiseArithmetic, node)
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(op::util::BinaryElementwiseComparison, node)
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(op::util::BinaryElementwiseLogical, node)
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v0::PRelu, node)
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v1::Pad, node)
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v1::BatchToSpace, node)
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v1::SpaceToBatch, node)
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v0::ReverseSequence, node)
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v8::Gather, node)
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v4::Interpolate, node)
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(op::util::ArithmeticReductionKeepDims, node)
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(op::util::LogicalReductionKeepDims, node)
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v8::Slice, node)
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v1::Split, node)
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v1::VariadicSplit, node)
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v0::Squeeze, node)
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v1::Reshape, node)
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v0::Unsqueeze, node)
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v1::Transpose, node)
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v0::FakeQuantize, node)
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v0::Concat, node)
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool CanPropagateForward(const NodePtr& node) {
|
||||
for (size_t i = 0; i < node->get_output_size(); ++i) {
|
||||
for (auto& consumer_input : node->output(i).get_target_inputs()) {
|
||||
if (!CanPropagateForwardThrough(consumer_input.get_node()))
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void UpdateForwardSinkingAbility(const NodePtr& node) {
|
||||
if (!CanPropagateForward(node))
|
||||
mark_as_no_sinking_node(node);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
std::shared_ptr<ov::op::v0::Constant> GetTransposeConstant(Node* node) {
|
||||
|
@ -394,13 +394,13 @@ TEST_F(TransformationTestsF, TSGeneralTestMultipleTypes) {
|
||||
{
|
||||
auto X = std::make_shared<Parameter>(input_type, input_shape);
|
||||
|
||||
auto ng_order0 = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose0 = std::make_shared<Transpose>(X, ng_order0);
|
||||
auto node0 = MakeAllNodesSubgraph(X, 1, 1);
|
||||
|
||||
auto node0 = MakeAllNodesSubgraph(transpose0, 3, 3);
|
||||
auto ng_order0 = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose0 = std::make_shared<Transpose>(node0, ng_order0);
|
||||
|
||||
auto reshape_const = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{1, 40, 55, 96});
|
||||
auto reshape = std::make_shared<Reshape>(node0, reshape_const, false);
|
||||
auto reshape = std::make_shared<Reshape>(transpose0, reshape_const, false);
|
||||
|
||||
auto node1 = MakeAllNodesSubgraph(reshape, 3, 3);
|
||||
|
||||
|
@ -0,0 +1,54 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <transformations/transpose_sinking/ts_reset_no_sinking_attribute.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include "transformations/rt_info/transpose_sinking_attr.hpp"
|
||||
|
||||
using namespace testing;
|
||||
using namespace std;
|
||||
using namespace ov;
|
||||
|
||||
using namespace op::v0;
|
||||
using namespace op::v1;
|
||||
|
||||
TEST(TransformationTests, ResetNoSinkingAttribute) {
|
||||
auto a = std::make_shared<Parameter>(element::f32, Shape{12, 3, 4, 8});
|
||||
auto b = std::make_shared<Parameter>(element::f32, Shape{12, 3, 4, 8});
|
||||
|
||||
auto transpose_a = make_shared<Transpose>(a, Constant::create(element::i64, Shape{4}, {1, 0, 2, 3}));
|
||||
auto transpose_b = make_shared<Transpose>(b, Constant::create(element::i64, Shape{4}, {1, 0, 2, 3}));
|
||||
|
||||
auto add = std::make_shared<Add>(transpose_a, transpose_b);
|
||||
auto trans_after = make_shared<Transpose>(add, Constant::create(element::i64, Shape{4}, {1, 0, 2, 3}));
|
||||
auto model = std::make_shared<Model>(NodeVector{trans_after}, ParameterVector{a, b});
|
||||
|
||||
mark_as_no_sinking_node(transpose_a);
|
||||
mark_as_no_sinking_node(transpose_b);
|
||||
mark_as_no_sinking_node(trans_after);
|
||||
|
||||
const auto& ops = model->get_ordered_ops();
|
||||
const auto cnt_before = count_if(ops.begin(), ops.end(), [](const std::shared_ptr<Node>& node) {
|
||||
const auto& rt_info = node->get_rt_info();
|
||||
return rt_info.find(NoTransposeSinkingAttr::get_type_info_static()) != rt_info.end();
|
||||
});
|
||||
|
||||
EXPECT_EQ(cnt_before, 3);
|
||||
ov::pass::Manager manager;
|
||||
manager.register_pass<pass::transpose_sinking::TSResetNoSinkingAttribute>();
|
||||
manager.run_passes(model);
|
||||
|
||||
const auto cnt_after = count_if(ops.begin(), ops.end(), [](const std::shared_ptr<Node>& node) {
|
||||
const auto& rt_info = node->get_rt_info();
|
||||
return rt_info.find(NoTransposeSinkingAttr::get_type_info_static()) != rt_info.end();
|
||||
});
|
||||
|
||||
EXPECT_EQ(cnt_after, 0);
|
||||
}
|
Loading…
Reference in New Issue
Block a user