Add descriptions to the transformations, add additional checks

This commit is contained in:
Ivan 2023-03-14 17:48:46 +04:00
parent 3a96e06d4c
commit d284ac1b7a
12 changed files with 103 additions and 20 deletions

View File

@ -17,12 +17,22 @@ class TRANSFORMATIONS_API TransposeSinkingBinaryBackward;
} // namespace pass
} // namespace ov
/**
* @ingroup ie_transformation_common_api
* @brief TransposeSinkingBinaryForward transformation sinks Transpose through BinaryElementwiseArithmetic,
* BinaryElementwiseComparison, BinaryElementwiseLogical and PRelu operations in the forward direction.
*/
class ov::pass::TransposeSinkingBinaryForward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingBinaryForward", "0");
TransposeSinkingBinaryForward();
};
/**
* @ingroup ie_transformation_common_api
* @brief TransposeSinkingBinaryBackward transformation sinks Transpose through BinaryElementwiseArithmetic,
* BinaryElementwiseComparison, BinaryElementwiseLogical and PRelu operations in the backward direction.
*/
class ov::pass::TransposeSinkingBinaryBackward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingBinaryBackward", "0");

View File

@ -17,12 +17,22 @@ class TRANSFORMATIONS_API TransposeSinkingConcatBackward;
} // namespace pass
} // namespace ov
/**
* @ingroup ie_transformation_common_api
* @brief TransposeSinkingConcatForward transformation sinks Transpose through Concat operation
* in the forward direction.
*/
class ov::pass::TransposeSinkingConcatForward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingConcatForward", "0");
TransposeSinkingConcatForward();
};
/**
* @ingroup ie_transformation_common_api
* @brief TransposeSinkingConcatBackward transformation sinks Transpose through Concat operation
* in the backward direction.
*/
class ov::pass::TransposeSinkingConcatBackward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingConcatBackward", "0");

View File

@ -17,12 +17,24 @@ class TRANSFORMATIONS_API TransposeSinkingDataMovementBackward;
} // namespace pass
} // namespace ov
/**
* @ingroup ie_transformation_common_api
* @brief TransposeSinkingDataMovementForward transformation sinks Transpose through BatchToSpace, SpaceToBatch
* and Pad operations in the forward direction.
* These operations are categorized as "DataMovement" and are handled in a similar way in this transformation.
*/
class ov::pass::TransposeSinkingDataMovementForward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingDataMovementForward", "0");
TransposeSinkingDataMovementForward();
};
/**
* @ingroup ie_transformation_common_api
* @brief TransposeSinkingDataMovementBackward transformation sinks Transpose through BatchToSpace, SpaceToBatch
* and Pad operations in the backward direction.
* These operations are categorized as "DataMovement" and are handled in a similar way in this transformation.
*/
class ov::pass::TransposeSinkingDataMovementBackward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingDataMovementBackward", "0");

View File

@ -17,18 +17,34 @@ class TRANSFORMATIONS_API TransposeSinkingGeneral;
} // namespace pass
} // namespace ov
/**
* @ingroup ie_transformation_common_api
* @brief TransposeSinkingGeneralForward transformation combines all TransposeSinkingForward* transformations into
* single GraphRewrite pass.
*/
class ov::pass::TransposeSinkingGeneralForward : public ov::pass::GraphRewrite {
public:
OPENVINO_RTTI("TransposeSinkingGeneralForward", "0");
TransposeSinkingGeneralForward();
};
/**
* @ingroup ie_transformation_common_api
* @brief TransposeSinkingGeneralBackward transformation combines all TransposeSinkingBackward* transformations into
* single GraphRewrite pass.
*/
class ov::pass::TransposeSinkingGeneralBackward : public ov::pass::GraphRewrite {
public:
OPENVINO_RTTI("TransposeSinkingGeneralBackward", "0");
TransposeSinkingGeneralBackward();
};
/**
* @ingroup ie_transformation_common_api
* @brief TransposeSinkingGeneral transformation combines TransposeSinkingGeneralForward and
* TransposeSinkingGeneralBackward transformations into single ModelPass pass and inserts
* ConstantFolding pass after them.
*/
class ov::pass::TransposeSinkingGeneral : public ov::pass::ModelPass {
public:
OPENVINO_RTTI("TransposeSinkingGeneral", "0");

View File

@ -17,12 +17,22 @@ class TRANSFORMATIONS_API TransposeSinkingInterpolateBackward;
} // namespace pass
} // namespace ov
/**
* @ingroup ie_transformation_common_api
* @brief TransposeSinkingInterpolateForward transformation sinks Transpose through Interpolate operation
* in the forward direction.
*/
class ov::pass::TransposeSinkingInterpolateForward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingInterpolateForward", "0");
TransposeSinkingInterpolateForward();
};
/**
* @ingroup ie_transformation_common_api
* @brief TransposeSinkingInterpolateBackward transformation sinks Transpose through Interpolate operation
* in the backward direction.
*/
class ov::pass::TransposeSinkingInterpolateBackward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingInterpolateBackward", "0");

View File

@ -19,7 +19,8 @@ class TRANSFORMATIONS_API TransposeSinkingReductionBackward;
/**
* @ingroup ie_transformation_common_api
* @brief TransposeReductionForward transformation sinks Transpose through Reduce operations
* @brief TransposeReductionForward transformation sinks Transpose through Reduce, Squeeze, Unsqueeze operations
* in the forward direction.
*/
class ov::pass::TransposeSinkingReductionForward : public ov::pass::MatcherPass {
public:
@ -29,7 +30,8 @@ public:
/**
* @ingroup ie_transformation_common_api
* @brief TransposeReductionBackward transformation sinks Transpose through Reduce operations
* @brief TransposeReductionBackward transformation sinks Transpose through Reduce, Squeeze, Unsqueeze operations
* in the backward direction.
*/
class ov::pass::TransposeSinkingReductionBackward : public ov::pass::MatcherPass {
public:

View File

@ -17,12 +17,22 @@ class TRANSFORMATIONS_API TransposeSinkingSplitForward;
} // namespace pass
} // namespace ov
/**
* @ingroup ie_transformation_common_api
* @brief TransposeSinkingSplitForward transformation sinks Transpose through Split, VariadicSplit operations
* in the forward direction.
*/
class ov::pass::TransposeSinkingSplitForward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingSplitForward", "0");
TransposeSinkingSplitForward();
};
/**
* @ingroup ie_transformation_common_api
* @brief TransposeSinkingSplitBackward transformation sinks Transpose through Split, VariadicSplit operations
* in the backward direction.
*/
class ov::pass::TransposeSinkingSplitBackward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingSplitBackward", "0");

View File

@ -16,12 +16,22 @@ class TRANSFORMATIONS_API TransposeSinkingUnaryBackward;
} // namespace pass
} // namespace ov
/**
* @ingroup ie_transformation_common_api
* @brief TransposeSinkingUnaryForward transformation sinks Transpose through UnaryElementwiseArithmetic, Clamp, Elu,
* SoftPlus, LogicalNot, Convert, IsInf, IsNaN, IsFinite operations in the forward direction.
*/
class ov::pass::TransposeSinkingUnaryForward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("TransposeSinkingUnaryForward", "0");
TransposeSinkingUnaryForward();
};
/**
* @ingroup ie_transformation_common_api
* @brief TransposeSinkingUnaryBackward transformation sinks Transpose through UnaryElementwiseArithmetic, Clamp, Elu,
* SoftPlus, LogicalNot, Convert, IsInf, IsNaN, IsFinite in the backward direction.
*/
class ov::pass::TransposeSinkingUnaryBackward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("TransposeSinkingUnaryBackwardMultiConsumers", "0");

View File

@ -4,23 +4,18 @@
#pragma once
#include <openvino/pass/pattern/op/or.hpp>
#include <transformations/utils/utils.hpp>
#include <utility>
#include "openvino/op/util/op_types.hpp"
#include "openvino/opsets/opset9.hpp"
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pattern/op/label.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/util/common_util.hpp"
#include "openvino/util/log.hpp"
namespace transpose_sinking {
struct TransposeInputsInfo {
std::shared_ptr<ov::opset9::Transpose> transpose;
std::shared_ptr<ov::opset9::Constant> transpose_const;
std::shared_ptr<ov::opset10::Transpose> transpose;
std::shared_ptr<ov::opset10::Constant> transpose_const;
size_t input_idx;
bool isEmpty() const {
@ -87,7 +82,7 @@ namespace sink_backward {
* transposes for all inputs.
*/
ov::NodeVector InsertTransposeBeforeNode(const std::shared_ptr<ov::Node>& main_node,
const std::shared_ptr<ov::opset9::Constant>& transpose_const,
const std::shared_ptr<ov::opset10::Constant>& transpose_const,
std::vector<int> input_indexes = {});
} // namespace sink_backward
@ -109,6 +104,6 @@ void RemoveSingleOutputConsumers(const std::shared_ptr<ov::Node>&);
*/
ov::Output<ov::Node> ChangeValuesOrder(const ov::Output<ov::Node>& input,
const ov::AxisVector& transpose_axis_order,
const std::shared_ptr<ov::opset9::Constant>& axis);
const std::shared_ptr<ov::opset10::Constant>& axis);
} // namespace transpose_sinking

View File

@ -81,18 +81,23 @@ ov::pass::TransposeSinkingConcatBackward::TransposeSinkingConcatBackward() {
if (concat_axis < 0) {
return false;
}
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node, transpose_const)) {
register_new_node(new_node);
}
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
const auto reversed_transpose_axis_order = ReverseTransposeOrder(transpose_axis_order);
if (static_cast<int64_t>(reversed_transpose_axis_order.size()) <= concat_axis) {
return false;
}
const auto transposed_concat_axis = reversed_transpose_axis_order[concat_axis];
concat_node->set_axis(static_cast<int64_t>(transposed_concat_axis));
concat_node->set_concatenation_axis(-1);
concat_node->validate_and_infer_types();
// remove output transposes
RemoveSingleOutputConsumers(main_node);
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node, transpose_const)) {
register_new_node(new_node);
}
concat_node->validate_and_infer_types();
RemoveSingleOutputConsumers(main_node);
SwapNames(transpose, main_node);
return true;
};

View File

@ -41,6 +41,9 @@ ov::pass::TransposeSinkingFuse::TransposeSinkingFuse() {
bool is_ordered = true;
for (size_t i = 0; i < order1.size(); i++) {
if (order1.size() <= order2[i]) {
return false;
}
order2[i] = order1[order2[i]];
if (order2[i] != static_cast<int64_t>(i))
is_ordered = false;
@ -61,7 +64,7 @@ ov::pass::TransposeSinkingFuse::TransposeSinkingFuse() {
new_transpose->set_friendly_name(m.get_match_root()->get_friendly_name());
transpose_sinking::RemoveSingleOutputConsumers(transpose1);
copy_runtime_info(transpose1, new_transpose);
ngraph::replace_node(transpose1, new_transpose);
ov::replace_node(transpose1, new_transpose);
transpose_sinking::UpdateForwardSinkingAbility(new_transpose);
}

View File

@ -40,7 +40,7 @@ ov::pass::TransposeSinkingInterpolateForward::TransposeSinkingInterpolateForward
}
// remove Transpose on 1st input:
auto transpose_parent = main_node->input_value(0).get_node()->input_value(0);
auto transpose_parent = transpose->input_value(0);
main_node->input(0).replace_source_output(transpose_parent);
const auto transpose_axis_order = transpose_const->get_axis_vector_val();