Add descriptions to the transformations, add additional checks
This commit is contained in:
parent
3a96e06d4c
commit
d284ac1b7a
@ -17,12 +17,22 @@ class TRANSFORMATIONS_API TransposeSinkingBinaryBackward;
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeSinkingBinaryForward transformation sinks Transpose through BinaryElementwiseArithmetic,
|
||||
* BinaryElementwiseComparison, BinaryElementwiseLogical and PRelu operations in the forward direction.
|
||||
*/
|
||||
class ov::pass::TransposeSinkingBinaryForward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TransposeSinkingBinaryForward", "0");
|
||||
TransposeSinkingBinaryForward();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeSinkingBinaryBackward transformation sinks Transpose through BinaryElementwiseArithmetic,
|
||||
* BinaryElementwiseComparison, BinaryElementwiseLogical and PRelu operations in the backward direction.
|
||||
*/
|
||||
class ov::pass::TransposeSinkingBinaryBackward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TransposeSinkingBinaryBackward", "0");
|
||||
|
@ -17,12 +17,22 @@ class TRANSFORMATIONS_API TransposeSinkingConcatBackward;
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeSinkingConcatForward transformation sinks Transpose through Concat operation
|
||||
* in the forward direction.
|
||||
*/
|
||||
class ov::pass::TransposeSinkingConcatForward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TransposeSinkingConcatForward", "0");
|
||||
TransposeSinkingConcatForward();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeSinkingConcatBackward transformation sinks Transpose through Concat operation
|
||||
* in the backward direction.
|
||||
*/
|
||||
class ov::pass::TransposeSinkingConcatBackward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TransposeSinkingConcatBackward", "0");
|
||||
|
@ -17,12 +17,24 @@ class TRANSFORMATIONS_API TransposeSinkingDataMovementBackward;
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeSinkingDataMovementForward transformation sinks Transpose through BatchToSpace, SpaceToBatch
|
||||
* and Pad operations in the forward direction.
|
||||
* These operations are categorized as "DataMovement" and are handled in a similar way in this transformation.
|
||||
*/
|
||||
class ov::pass::TransposeSinkingDataMovementForward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TransposeSinkingDataMovementForward", "0");
|
||||
TransposeSinkingDataMovementForward();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeSinkingDataMovementBackward transformation sinks Transpose through BatchToSpace, SpaceToBatch
|
||||
* and Pad operations in the backward direction.
|
||||
* These operations are categorized as "DataMovement" and are handled in a similar way in this transformation.
|
||||
*/
|
||||
class ov::pass::TransposeSinkingDataMovementBackward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TransposeSinkingDataMovementBackward", "0");
|
||||
|
@ -17,18 +17,34 @@ class TRANSFORMATIONS_API TransposeSinkingGeneral;
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeSinkingGeneralForward transformation combines all TransposeSinkingForward* transformations into
|
||||
* single GraphRewrite pass.
|
||||
*/
|
||||
class ov::pass::TransposeSinkingGeneralForward : public ov::pass::GraphRewrite {
|
||||
public:
|
||||
OPENVINO_RTTI("TransposeSinkingGeneralForward", "0");
|
||||
TransposeSinkingGeneralForward();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeSinkingGeneralBackward transformation combines all TransposeSinkingBackward* transformations into
|
||||
* single GraphRewrite pass.
|
||||
*/
|
||||
class ov::pass::TransposeSinkingGeneralBackward : public ov::pass::GraphRewrite {
|
||||
public:
|
||||
OPENVINO_RTTI("TransposeSinkingGeneralBackward", "0");
|
||||
TransposeSinkingGeneralBackward();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeSinkingGeneral transformation combines TransposeSinkingGeneralForward and
|
||||
* TransposeSinkingGeneralBackward transformations into single ModelPass pass and inserts
|
||||
* ConstantFolding pass after them.
|
||||
*/
|
||||
class ov::pass::TransposeSinkingGeneral : public ov::pass::ModelPass {
|
||||
public:
|
||||
OPENVINO_RTTI("TransposeSinkingGeneral", "0");
|
||||
|
@ -17,12 +17,22 @@ class TRANSFORMATIONS_API TransposeSinkingInterpolateBackward;
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeSinkingInterpolateForward transformation sinks Transpose through Interpolate operation
|
||||
* in the forward direction.
|
||||
*/
|
||||
class ov::pass::TransposeSinkingInterpolateForward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TransposeSinkingInterpolateForward", "0");
|
||||
TransposeSinkingInterpolateForward();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeSinkingInterpolateBackward transformation sinks Transpose through Interpolate operation
|
||||
* in the backward direction.
|
||||
*/
|
||||
class ov::pass::TransposeSinkingInterpolateBackward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TransposeSinkingInterpolateBackward", "0");
|
||||
|
@ -19,7 +19,8 @@ class TRANSFORMATIONS_API TransposeSinkingReductionBackward;
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeReductionForward transformation sinks Transpose through Reduce operations
|
||||
* @brief TransposeReductionForward transformation sinks Transpose through Reduce, Squeeze, Unsqueeze operations
|
||||
* in the forward direction.
|
||||
*/
|
||||
class ov::pass::TransposeSinkingReductionForward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
@ -29,7 +30,8 @@ public:
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeReductionBackward transformation sinks Transpose through Reduce operations
|
||||
* @brief TransposeReductionBackward transformation sinks Transpose through Reduce, Squeeze, Unsqueeze operations
|
||||
* in the backward direction.
|
||||
*/
|
||||
class ov::pass::TransposeSinkingReductionBackward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
|
@ -17,12 +17,22 @@ class TRANSFORMATIONS_API TransposeSinkingSplitForward;
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeSinkingSplitForward transformation sinks Transpose through Split, VariadicSplit operations
|
||||
* in the forward direction.
|
||||
*/
|
||||
class ov::pass::TransposeSinkingSplitForward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TransposeSinkingSplitForward", "0");
|
||||
TransposeSinkingSplitForward();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeSinkingSplitBackward transformation sinks Transpose through Split, VariadicSplit operations
|
||||
* in the backward direction.
|
||||
*/
|
||||
class ov::pass::TransposeSinkingSplitBackward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TransposeSinkingSplitBackward", "0");
|
||||
|
@ -16,12 +16,22 @@ class TRANSFORMATIONS_API TransposeSinkingUnaryBackward;
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeSinkingUnaryForward transformation sinks Transpose through UnaryElementwiseArithmetic, Clamp, Elu,
|
||||
* SoftPlus, LogicalNot, Convert, IsInf, IsNaN, IsFinite operations in the forward direction.
|
||||
*/
|
||||
class ov::pass::TransposeSinkingUnaryForward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("TransposeSinkingUnaryForward", "0");
|
||||
TransposeSinkingUnaryForward();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeSinkingUnaryBackward transformation sinks Transpose through UnaryElementwiseArithmetic, Clamp, Elu,
|
||||
* SoftPlus, LogicalNot, Convert, IsInf, IsNaN, IsFinite in the backward direction.
|
||||
*/
|
||||
class ov::pass::TransposeSinkingUnaryBackward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("TransposeSinkingUnaryBackwardMultiConsumers", "0");
|
||||
|
@ -4,23 +4,18 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <openvino/pass/pattern/op/or.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <utility>
|
||||
|
||||
#include "openvino/op/util/op_types.hpp"
|
||||
#include "openvino/opsets/opset9.hpp"
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pattern/op/label.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/util/common_util.hpp"
|
||||
#include "openvino/util/log.hpp"
|
||||
|
||||
namespace transpose_sinking {
|
||||
|
||||
struct TransposeInputsInfo {
|
||||
std::shared_ptr<ov::opset9::Transpose> transpose;
|
||||
std::shared_ptr<ov::opset9::Constant> transpose_const;
|
||||
std::shared_ptr<ov::opset10::Transpose> transpose;
|
||||
std::shared_ptr<ov::opset10::Constant> transpose_const;
|
||||
size_t input_idx;
|
||||
|
||||
bool isEmpty() const {
|
||||
@ -87,7 +82,7 @@ namespace sink_backward {
|
||||
* transposes for all inputs.
|
||||
*/
|
||||
ov::NodeVector InsertTransposeBeforeNode(const std::shared_ptr<ov::Node>& main_node,
|
||||
const std::shared_ptr<ov::opset9::Constant>& transpose_const,
|
||||
const std::shared_ptr<ov::opset10::Constant>& transpose_const,
|
||||
std::vector<int> input_indexes = {});
|
||||
} // namespace sink_backward
|
||||
|
||||
@ -109,6 +104,6 @@ void RemoveSingleOutputConsumers(const std::shared_ptr<ov::Node>&);
|
||||
*/
|
||||
ov::Output<ov::Node> ChangeValuesOrder(const ov::Output<ov::Node>& input,
|
||||
const ov::AxisVector& transpose_axis_order,
|
||||
const std::shared_ptr<ov::opset9::Constant>& axis);
|
||||
const std::shared_ptr<ov::opset10::Constant>& axis);
|
||||
|
||||
} // namespace transpose_sinking
|
||||
|
@ -81,18 +81,23 @@ ov::pass::TransposeSinkingConcatBackward::TransposeSinkingConcatBackward() {
|
||||
if (concat_axis < 0) {
|
||||
return false;
|
||||
}
|
||||
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node, transpose_const)) {
|
||||
register_new_node(new_node);
|
||||
}
|
||||
|
||||
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
|
||||
const auto reversed_transpose_axis_order = ReverseTransposeOrder(transpose_axis_order);
|
||||
if (static_cast<int64_t>(reversed_transpose_axis_order.size()) <= concat_axis) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto transposed_concat_axis = reversed_transpose_axis_order[concat_axis];
|
||||
concat_node->set_axis(static_cast<int64_t>(transposed_concat_axis));
|
||||
concat_node->set_concatenation_axis(-1);
|
||||
concat_node->validate_and_infer_types();
|
||||
// remove output transposes
|
||||
RemoveSingleOutputConsumers(main_node);
|
||||
|
||||
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node, transpose_const)) {
|
||||
register_new_node(new_node);
|
||||
}
|
||||
concat_node->validate_and_infer_types();
|
||||
|
||||
RemoveSingleOutputConsumers(main_node);
|
||||
SwapNames(transpose, main_node);
|
||||
return true;
|
||||
};
|
||||
|
@ -41,6 +41,9 @@ ov::pass::TransposeSinkingFuse::TransposeSinkingFuse() {
|
||||
|
||||
bool is_ordered = true;
|
||||
for (size_t i = 0; i < order1.size(); i++) {
|
||||
if (order1.size() <= order2[i]) {
|
||||
return false;
|
||||
}
|
||||
order2[i] = order1[order2[i]];
|
||||
if (order2[i] != static_cast<int64_t>(i))
|
||||
is_ordered = false;
|
||||
@ -61,7 +64,7 @@ ov::pass::TransposeSinkingFuse::TransposeSinkingFuse() {
|
||||
new_transpose->set_friendly_name(m.get_match_root()->get_friendly_name());
|
||||
transpose_sinking::RemoveSingleOutputConsumers(transpose1);
|
||||
copy_runtime_info(transpose1, new_transpose);
|
||||
ngraph::replace_node(transpose1, new_transpose);
|
||||
ov::replace_node(transpose1, new_transpose);
|
||||
|
||||
transpose_sinking::UpdateForwardSinkingAbility(new_transpose);
|
||||
}
|
||||
|
@ -40,7 +40,7 @@ ov::pass::TransposeSinkingInterpolateForward::TransposeSinkingInterpolateForward
|
||||
}
|
||||
|
||||
// remove Transpose on 1st input:
|
||||
auto transpose_parent = main_node->input_value(0).get_node()->input_value(0);
|
||||
auto transpose_parent = transpose->input_value(0);
|
||||
main_node->input(0).replace_source_output(transpose_parent);
|
||||
|
||||
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
|
||||
|
Loading…
Reference in New Issue
Block a user