Merge branch 'itikhono/ts/refactoring' into itikhono/ts/slice

This commit is contained in:
Ivan
2023-03-15 17:10:30 +04:00
39 changed files with 720 additions and 692 deletions

View File

@@ -1,40 +0,0 @@
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace pass {
class TRANSFORMATIONS_API TransposeSinkingBinaryForward;
class TRANSFORMATIONS_API TransposeSinkingBinaryBackward;
} // namespace pass
} // namespace ov
/**
* @ingroup ie_transformation_common_api
* @brief TransposeSinkingBinaryForward transformation sinks Transpose through BinaryElementwiseArithmetic,
* BinaryElementwiseComparison, BinaryElementwiseLogical and PRelu operations in the forward direction.
*/
class ov::pass::TransposeSinkingBinaryForward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingBinaryForward", "0");
TransposeSinkingBinaryForward();
};
/**
* @ingroup ie_transformation_common_api
* @brief TransposeSinkingBinaryBackward transformation sinks Transpose through BinaryElementwiseArithmetic,
* BinaryElementwiseComparison, BinaryElementwiseLogical and PRelu operations in the backward direction.
*/
class ov::pass::TransposeSinkingBinaryBackward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingBinaryBackward", "0");
TransposeSinkingBinaryBackward();
};

View File

@@ -1,40 +0,0 @@
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace pass {
class TRANSFORMATIONS_API TransposeSinkingConcatForward;
class TRANSFORMATIONS_API TransposeSinkingConcatBackward;
} // namespace pass
} // namespace ov
/**
* @ingroup ie_transformation_common_api
* @brief TransposeSinkingConcatForward transformation sinks Transpose through Concat operation
* in the forward direction.
*/
class ov::pass::TransposeSinkingConcatForward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingConcatForward", "0");
TransposeSinkingConcatForward();
};
/**
* @ingroup ie_transformation_common_api
* @brief TransposeSinkingConcatBackward transformation sinks Transpose through Concat operation
* in the backward direction.
*/
class ov::pass::TransposeSinkingConcatBackward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingConcatBackward", "0");
TransposeSinkingConcatBackward();
};

View File

@@ -1,42 +0,0 @@
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace pass {
class TRANSFORMATIONS_API TransposeSinkingDataMovementForward;
class TRANSFORMATIONS_API TransposeSinkingDataMovementBackward;
} // namespace pass
} // namespace ov
/**
* @ingroup ie_transformation_common_api
* @brief TransposeSinkingDataMovementForward transformation sinks Transpose through BatchToSpace, SpaceToBatch
* and Pad operations in the forward direction.
* These operations are categorized as "DataMovement" and are handled in a similar way in this transformation.
*/
class ov::pass::TransposeSinkingDataMovementForward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingDataMovementForward", "0");
TransposeSinkingDataMovementForward();
};
/**
* @ingroup ie_transformation_common_api
* @brief TransposeSinkingDataMovementBackward transformation sinks Transpose through BatchToSpace, SpaceToBatch
* and Pad operations in the backward direction.
* These operations are categorized as "DataMovement" and are handled in a similar way in this transformation.
*/
class ov::pass::TransposeSinkingDataMovementBackward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingDataMovementBackward", "0");
TransposeSinkingDataMovementBackward();
};

View File

@@ -1,52 +0,0 @@
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace pass {
class TRANSFORMATIONS_API TransposeSinkingGeneralForward;
class TRANSFORMATIONS_API TransposeSinkingGeneralBackward;
class TRANSFORMATIONS_API TransposeSinkingGeneral;
} // namespace pass
} // namespace ov
/**
* @ingroup ie_transformation_common_api
* @brief TransposeSinkingGeneralForward transformation combines all TransposeSinkingForward* transformations into
* single GraphRewrite pass.
*/
class ov::pass::TransposeSinkingGeneralForward : public ov::pass::GraphRewrite {
public:
OPENVINO_RTTI("TransposeSinkingGeneralForward", "0");
TransposeSinkingGeneralForward();
};
/**
* @ingroup ie_transformation_common_api
* @brief TransposeSinkingGeneralBackward transformation combines all TransposeSinkingBackward* transformations into
* single GraphRewrite pass.
*/
class ov::pass::TransposeSinkingGeneralBackward : public ov::pass::GraphRewrite {
public:
OPENVINO_RTTI("TransposeSinkingGeneralBackward", "0");
TransposeSinkingGeneralBackward();
};
/**
* @ingroup ie_transformation_common_api
* @brief TransposeSinkingGeneral transformation combines TransposeSinkingGeneralForward and
* TransposeSinkingGeneralBackward transformations into single ModelPass pass and inserts
* ConstantFolding pass after them.
*/
class ov::pass::TransposeSinkingGeneral : public ov::pass::ModelPass {
public:
OPENVINO_RTTI("TransposeSinkingGeneral", "0");
bool run_on_model(const std::shared_ptr<ov::Model>& m) override;
};

View File

@@ -1,40 +0,0 @@
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace pass {
class TRANSFORMATIONS_API TransposeSinkingInterpolateForward;
class TRANSFORMATIONS_API TransposeSinkingInterpolateBackward;
} // namespace pass
} // namespace ov
/**
* @ingroup ie_transformation_common_api
* @brief TransposeSinkingInterpolateForward transformation sinks Transpose through Interpolate operation
* in the forward direction.
*/
class ov::pass::TransposeSinkingInterpolateForward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingInterpolateForward", "0");
TransposeSinkingInterpolateForward();
};
/**
* @ingroup ie_transformation_common_api
* @brief TransposeSinkingInterpolateBackward transformation sinks Transpose through Interpolate operation
* in the backward direction.
*/
class ov::pass::TransposeSinkingInterpolateBackward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingInterpolateBackward", "0");
TransposeSinkingInterpolateBackward();
};

View File

@@ -1,40 +0,0 @@
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace pass {
class TRANSFORMATIONS_API TransposeSinkingSplitBackward;
class TRANSFORMATIONS_API TransposeSinkingSplitForward;
} // namespace pass
} // namespace ov
/**
* @ingroup ie_transformation_common_api
* @brief TransposeSinkingSplitForward transformation sinks Transpose through Split, VariadicSplit operations
* in the forward direction.
*/
class ov::pass::TransposeSinkingSplitForward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingSplitForward", "0");
TransposeSinkingSplitForward();
};
/**
* @ingroup ie_transformation_common_api
* @brief TransposeSinkingSplitBackward transformation sinks Transpose through Split, VariadicSplit operations
* in the backward direction.
*/
class ov::pass::TransposeSinkingSplitBackward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingSplitBackward", "0");
TransposeSinkingSplitBackward();
};

View File

@@ -1,39 +0,0 @@
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace pass {
class TRANSFORMATIONS_API TransposeSinkingUnaryForward;
class TRANSFORMATIONS_API TransposeSinkingUnaryBackward;
} // namespace pass
} // namespace ov
/**
* @ingroup ie_transformation_common_api
* @brief TransposeSinkingUnaryForward transformation sinks Transpose through UnaryElementwiseArithmetic, Clamp, Elu,
* SoftPlus, LogicalNot, Convert, IsInf, IsNaN, IsFinite operations in the forward direction.
*/
class ov::pass::TransposeSinkingUnaryForward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("TransposeSinkingUnaryForward", "0");
TransposeSinkingUnaryForward();
};
/**
* @ingroup ie_transformation_common_api
* @brief TransposeSinkingUnaryBackward transformation sinks Transpose through UnaryElementwiseArithmetic, Clamp, Elu,
* SoftPlus, LogicalNot, Convert, IsInf, IsNaN, IsFinite in the backward direction.
*/
class ov::pass::TransposeSinkingUnaryBackward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("TransposeSinkingUnaryBackwardMultiConsumers", "0");
TransposeSinkingUnaryBackward();
};

View File

@@ -0,0 +1,42 @@
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace pass {
namespace transpose_sinking {
class TRANSFORMATIONS_API TSBinaryForward;
class TRANSFORMATIONS_API TSBinaryBackward;
} // namespace transpose_sinking
} // namespace pass
} // namespace ov
/**
* @ingroup ie_transformation_common_api
* @brief TSBinaryForward transformation sinks Transpose through BinaryElementwiseArithmetic,
* BinaryElementwiseComparison, BinaryElementwiseLogical and PRelu operations in the forward direction.
*/
class ov::pass::transpose_sinking::TSBinaryForward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TSBinaryForward", "0");
TSBinaryForward();
};
/**
* @ingroup ie_transformation_common_api
* @brief TSBinaryBackward transformation sinks Transpose through BinaryElementwiseArithmetic,
* BinaryElementwiseComparison, BinaryElementwiseLogical and PRelu operations in the backward direction.
*/
class ov::pass::transpose_sinking::TSBinaryBackward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TSBinaryBackward", "0");
TSBinaryBackward();
};

View File

@@ -0,0 +1,42 @@
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace pass {
namespace transpose_sinking {
class TRANSFORMATIONS_API TSConcatForward;
class TRANSFORMATIONS_API TSConcatBackward;
} // namespace transpose_sinking
} // namespace pass
} // namespace ov
/**
* @ingroup ie_transformation_common_api
* @brief TSConcatForward transformation sinks Transpose through Concat operation
* in the forward direction.
*/
class ov::pass::transpose_sinking::TSConcatForward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TSConcatForward", "0");
TSConcatForward();
};
/**
* @ingroup ie_transformation_common_api
* @brief TSConcatBackward transformation sinks Transpose through Concat operation
* in the backward direction.
*/
class ov::pass::transpose_sinking::TSConcatBackward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TSConcatBackward", "0");
TSConcatBackward();
};

View File

@@ -0,0 +1,44 @@
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace pass {
namespace transpose_sinking {
class TRANSFORMATIONS_API TSDataMovementForward;
class TRANSFORMATIONS_API TSDataMovementBackward;
} // namespace transpose_sinking
} // namespace pass
} // namespace ov
/**
* @ingroup ie_transformation_common_api
* @brief TSDataMovementForward transformation sinks Transpose through BatchToSpace, SpaceToBatch
* and Pad operations in the forward direction.
* These operations are categorized as "DataMovement" and are handled in a similar way in this transformation.
*/
class ov::pass::transpose_sinking::TSDataMovementForward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TSDataMovementForward", "0");
TSDataMovementForward();
};
/**
* @ingroup ie_transformation_common_api
* @brief TSDataMovementBackward transformation sinks Transpose through BatchToSpace, SpaceToBatch
* and Pad operations in the backward direction.
* These operations are categorized as "DataMovement" and are handled in a similar way in this transformation.
*/
class ov::pass::transpose_sinking::TSDataMovementBackward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TSDataMovementBackward", "0");
TSDataMovementBackward();
};

View File

@@ -10,19 +10,21 @@
namespace ov {
namespace pass {
namespace transpose_sinking {
class TRANSFORMATIONS_API TransposeSinkingFuse;
class TRANSFORMATIONS_API TSFuse;
} // namespace transpose_sinking
} // namespace pass
} // namespace ov
/**
* @ingroup ie_transformation_common_api
* @brief TransposeSinkingFuse transformation eliminates 2 consecutive Transposes if they result in no changes to input
* @brief TSFuse transformation eliminates 2 consecutive Transposes if they result in no changes to input
* or fuses them to single Transpose if input gets changed
*/
class ov::pass::TransposeSinkingFuse : public ov::pass::MatcherPass {
class ov::pass::transpose_sinking::TSFuse : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("TransposeSinkingFuse", "0");
TransposeSinkingFuse();
OPENVINO_RTTI("TSFuse", "0");
TSFuse();
};

View File

@@ -0,0 +1,54 @@
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace pass {
namespace transpose_sinking {
class TRANSFORMATIONS_API TSGeneralForward;
class TRANSFORMATIONS_API TSGeneralBackward;
class TRANSFORMATIONS_API TSGeneral;
} // namespace transpose_sinking
} // namespace pass
} // namespace ov
/**
* @ingroup ie_transformation_common_api
* @brief TSGeneralForward transformation combines all TransposeSinkingForward* transformations into
* single GraphRewrite pass.
*/
class ov::pass::transpose_sinking::TSGeneralForward : public ov::pass::GraphRewrite {
public:
OPENVINO_RTTI("TSGeneralForward", "0");
TSGeneralForward();
};
/**
* @ingroup ie_transformation_common_api
* @brief TSGeneralBackward transformation combines all TransposeSinkingBackward* transformations into
* single GraphRewrite pass.
*/
class ov::pass::transpose_sinking::TSGeneralBackward : public ov::pass::GraphRewrite {
public:
OPENVINO_RTTI("TSGeneralBackward", "0");
TSGeneralBackward();
};
/**
* @ingroup ie_transformation_common_api
* @brief TSGeneral transformation combines TSGeneralForward and
* TSGeneralBackward transformations into single ModelPass pass and inserts
* ConstantFolding pass after them.
*/
class ov::pass::transpose_sinking::TSGeneral : public ov::pass::ModelPass {
public:
OPENVINO_RTTI("TSGeneral", "0");
bool run_on_model(const std::shared_ptr<ov::Model>& m) override;
};

View File

@@ -0,0 +1,42 @@
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace pass {
namespace transpose_sinking {
class TRANSFORMATIONS_API TSInterpolateForward;
class TRANSFORMATIONS_API TSInterpolateBackward;
} // namespace transpose_sinking
} // namespace pass
} // namespace ov
/**
* @ingroup ie_transformation_common_api
* @brief TSInterpolateForward transformation sinks Transpose through Interpolate operation
* in the forward direction.
*/
class ov::pass::transpose_sinking::TSInterpolateForward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TSInterpolateForward", "0");
TSInterpolateForward();
};
/**
* @ingroup ie_transformation_common_api
* @brief TSInterpolateBackward transformation sinks Transpose through Interpolate operation
* in the backward direction.
*/
class ov::pass::transpose_sinking::TSInterpolateBackward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TSInterpolateBackward", "0");
TSInterpolateBackward();
};

View File

@@ -10,10 +10,12 @@
namespace ov {
namespace pass {
namespace transpose_sinking {
class TRANSFORMATIONS_API TransposeSinkingReductionForward;
class TRANSFORMATIONS_API TransposeSinkingReductionBackward;
class TRANSFORMATIONS_API TSReductionForward;
class TRANSFORMATIONS_API TSReductionBackward;
} // namespace transpose_sinking
} // namespace pass
} // namespace ov
@@ -22,10 +24,10 @@ class TRANSFORMATIONS_API TransposeSinkingReductionBackward;
* @brief TransposeReductionForward transformation sinks Transpose through Reduce, Squeeze, Unsqueeze operations
* in the forward direction.
*/
class ov::pass::TransposeSinkingReductionForward : public ov::pass::MatcherPass {
class ov::pass::transpose_sinking::TSReductionForward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingReductionForward", "0");
TransposeSinkingReductionForward();
OPENVINO_RTTI("ov::pass::TSReductionForward", "0");
TSReductionForward();
};
/**
@@ -33,8 +35,8 @@ public:
* @brief TransposeReductionBackward transformation sinks Transpose through Reduce, Squeeze, Unsqueeze operations
* in the backward direction.
*/
class ov::pass::TransposeSinkingReductionBackward : public ov::pass::MatcherPass {
class ov::pass::transpose_sinking::TSReductionBackward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingReductionBackward", "0");
TransposeSinkingReductionBackward();
OPENVINO_RTTI("ov::pass::TSReductionBackward", "0");
TSReductionBackward();
};

View File

@@ -0,0 +1,42 @@
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace pass {
namespace transpose_sinking {
class TRANSFORMATIONS_API TSSplitBackward;
class TRANSFORMATIONS_API TSSplitForward;
} // namespace transpose_sinking
} // namespace pass
} // namespace ov
/**
* @ingroup ie_transformation_common_api
* @brief TSSplitForward transformation sinks Transpose through Split, VariadicSplit operations
* in the forward direction.
*/
class ov::pass::transpose_sinking::TSSplitForward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TSSplitForward", "0");
TSSplitForward();
};
/**
* @ingroup ie_transformation_common_api
* @brief TSSplitBackward transformation sinks Transpose through Split, VariadicSplit operations
* in the backward direction.
*/
class ov::pass::transpose_sinking::TSSplitBackward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TSSplitBackward", "0");
TSSplitBackward();
};

View File

@@ -0,0 +1,41 @@
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace pass {
namespace transpose_sinking {
class TRANSFORMATIONS_API TSUnaryForward;
class TRANSFORMATIONS_API TSUnaryBackward;
} // namespace transpose_sinking
} // namespace pass
} // namespace ov
/**
* @ingroup ie_transformation_common_api
* @brief TSUnaryForward transformation sinks Transpose through UnaryElementwiseArithmetic, Clamp, Elu,
* SoftPlus, LogicalNot, Convert, IsInf, IsNaN, IsFinite operations in the forward direction.
*/
class ov::pass::transpose_sinking::TSUnaryForward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("TSUnaryForward", "0");
TSUnaryForward();
};
/**
* @ingroup ie_transformation_common_api
* @brief TSUnaryBackward transformation sinks Transpose through UnaryElementwiseArithmetic, Clamp, Elu,
* SoftPlus, LogicalNot, Convert, IsInf, IsNaN, IsFinite in the backward direction.
*/
class ov::pass::transpose_sinking::TSUnaryBackward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("TSUnaryBackwardMultiConsumers", "0");
TSUnaryBackward();
};

View File

@@ -4,14 +4,17 @@
#pragma once
#include <transformations/utils/utils.hpp>
#include <utility>
#include "openvino/op/util/op_types.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/util/common_util.hpp"
#include "transformations/utils/utils.hpp"
namespace ov {
namespace pass {
namespace transpose_sinking {
namespace utils {
struct TransposeInputsInfo {
std::shared_ptr<ov::opset10::Transpose> transpose;
@@ -106,4 +109,7 @@ ov::Output<ov::Node> ChangeValuesOrder(const ov::Output<ov::Node>& input,
const ov::AxisVector& transpose_axis_order,
const std::shared_ptr<ov::opset10::Constant>& axis);
} // namespace utils
} // namespace transpose_sinking
} // namespace pass
} // namespace ov

View File

@@ -14,7 +14,6 @@
#include <vector>
#include "itt.hpp"
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
#include "transformations/utils/utils.hpp"
using namespace ov;
@@ -32,7 +31,7 @@ std::shared_ptr<opset6::Constant> get_reduced_order_constant(const std::shared_p
order.erase(order.begin() + i);
} else {
// if 2nd input for Squeeze op is not provided, we should remove all 1 dims
// this case will be supported in new TransposeSinkingGeneral transformation.
// this case will be supported in new TSGeneral transformation.
return nullptr;
}
@@ -318,8 +317,6 @@ ov::pass::TransposeFuse::TransposeFuse() {
new_transpose->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::copy_runtime_info({transpose1, transpose2}, new_transpose);
ngraph::replace_node(m.get_match_root(), new_transpose);
transpose_sinking::UpdateForwardSinkingAbility(new_transpose);
}
return true;

View File

@@ -1,64 +0,0 @@
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/transpose_sinking_general.hpp"
#include <openvino/pass/constant_folding.hpp>
#include <openvino/pass/graph_rewrite.hpp>
#include <openvino/pass/manager.hpp>
#include <openvino/pass/pattern/op/wrap_type.hpp>
#include "itt.hpp"
#include "transformations/common_optimizations/transpose_sinking_binary.hpp"
#include "transformations/common_optimizations/transpose_sinking_concat.hpp"
#include "transformations/common_optimizations/transpose_sinking_data_movement.hpp"
#include "transformations/common_optimizations/transpose_sinking_fuse.hpp"
#include "transformations/common_optimizations/transpose_sinking_interpolate.hpp"
#include "transformations/common_optimizations/transpose_sinking_reduction.hpp"
#include "transformations/common_optimizations/transpose_sinking_split.hpp"
#include "transformations/common_optimizations/transpose_sinking_unary.hpp"
#include "transformations/utils/utils.hpp"
ov::pass::TransposeSinkingGeneralForward::TransposeSinkingGeneralForward() {
MATCHER_SCOPE(TransposeSinkingGeneralForward);
add_matcher<ov::pass::TransposeSinkingUnaryForward>();
add_matcher<ov::pass::TransposeSinkingBinaryForward>();
add_matcher<ov::pass::TransposeSinkingConcatForward>();
add_matcher<ov::pass::TransposeSinkingSplitForward>();
add_matcher<ov::pass::TransposeSinkingDataMovementForward>();
add_matcher<ov::pass::TransposeSinkingReductionForward>();
add_matcher<ov::pass::TransposeSinkingInterpolateForward>();
add_matcher<ov::pass::TransposeSinkingFuse>();
}
ov::pass::TransposeSinkingGeneralBackward::TransposeSinkingGeneralBackward() {
MATCHER_SCOPE(TransposeSinkingGeneralBackward);
add_matcher<ov::pass::TransposeSinkingUnaryBackward>();
add_matcher<ov::pass::TransposeSinkingBinaryBackward>();
add_matcher<ov::pass::TransposeSinkingConcatBackward>();
add_matcher<ov::pass::TransposeSinkingSplitBackward>();
add_matcher<ov::pass::TransposeSinkingDataMovementBackward>();
add_matcher<ov::pass::TransposeSinkingReductionBackward>();
add_matcher<ov::pass::TransposeSinkingInterpolateBackward>();
add_matcher<ov::pass::TransposeSinkingFuse>();
}
bool ov::pass::TransposeSinkingGeneral::run_on_model(const std::shared_ptr<ov::Model>& f) {
RUN_ON_FUNCTION_SCOPE(TransposeSinkingGeneral);
{
ov::pass::Manager manager(get_pass_config());
manager.register_pass<ov::pass::TransposeSinkingGeneralForward>();
manager.register_pass<ov::pass::ConstantFolding>();
manager.run_passes(f);
}
{
ov::pass::Manager manager(get_pass_config());
manager.register_pass<ov::pass::TransposeSinkingGeneralBackward>();
manager.register_pass<ov::pass::ConstantFolding>();
manager.run_passes(f);
}
return false;
}

View File

@@ -2,24 +2,24 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/transpose_sinking_binary.hpp"
#include <openvino/opsets/opset10.hpp>
#include <openvino/pass/pattern/op/or.hpp>
#include "transformations/transpose_sinking/ts_binary.hpp"
#include "itt.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
#include "transformations/transpose_sinking/ts_utils.hpp"
using namespace ov::pass::pattern;
using namespace ov;
using namespace ov::opset10;
using namespace transpose_sinking;
using namespace ov::pass::pattern;
using namespace ov::pass::transpose_sinking;
using namespace ov::pass::transpose_sinking::utils;
ov::pass::TransposeSinkingBinaryForward::TransposeSinkingBinaryForward() {
MATCHER_SCOPE(TransposeSinkingBinaryForward);
TSBinaryForward::TSBinaryForward() {
MATCHER_SCOPE(TSBinaryForward);
auto main_node_label = wrap_type<op::util::BinaryElementwiseArithmetic,
op::util::BinaryElementwiseComparison,
@@ -42,7 +42,7 @@ ov::pass::TransposeSinkingBinaryForward::TransposeSinkingBinaryForward() {
main_node->validate_and_infer_types();
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
register_new_node(new_node);
transpose_sinking::UpdateForwardSinkingAbility(new_node);
UpdateForwardSinkingAbility(new_node);
}
return true;
};
@@ -51,8 +51,8 @@ ov::pass::TransposeSinkingBinaryForward::TransposeSinkingBinaryForward() {
register_matcher(m, matcher_pass_callback);
}
ov::pass::TransposeSinkingBinaryBackward::TransposeSinkingBinaryBackward() {
MATCHER_SCOPE(TransposeSinkingBinaryBackward);
TSBinaryBackward::TSBinaryBackward() {
MATCHER_SCOPE(TSBinaryBackward);
auto main_node_label = wrap_type<op::util::BinaryElementwiseArithmetic,
op::util::BinaryElementwiseComparison,

View File

@@ -2,22 +2,23 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/transpose_sinking_concat.hpp"
#include "transformations/transpose_sinking/ts_concat.hpp"
#include "itt.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
#include "transformations/transpose_sinking/ts_utils.hpp"
using namespace ov::pass::pattern;
using namespace ov;
using namespace ov::opset10;
using namespace transpose_sinking;
using namespace ov::pass::pattern;
using namespace ov::pass::transpose_sinking;
using namespace ov::pass::transpose_sinking::utils;
ov::pass::TransposeSinkingConcatForward::TransposeSinkingConcatForward() {
MATCHER_SCOPE(TransposeSinkingConcatForward);
TSConcatForward::TSConcatForward() {
MATCHER_SCOPE(TSConcatForward);
auto main_node_label = wrap_type<Concat>(IfNodeHasTransposeInputs);
@@ -47,7 +48,7 @@ ov::pass::TransposeSinkingConcatForward::TransposeSinkingConcatForward() {
main_node->validate_and_infer_types();
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
register_new_node(new_node);
transpose_sinking::UpdateForwardSinkingAbility(new_node);
UpdateForwardSinkingAbility(new_node);
}
return true;
@@ -57,8 +58,8 @@ ov::pass::TransposeSinkingConcatForward::TransposeSinkingConcatForward() {
register_matcher(m, matcher_pass_callback);
}
ov::pass::TransposeSinkingConcatBackward::TransposeSinkingConcatBackward() {
MATCHER_SCOPE(TransposeSinkingConcatBackward);
TSConcatBackward::TSConcatBackward() {
MATCHER_SCOPE(TSConcatBackward);
auto main_node_label = wrap_type<Concat>([](const Output<Node>& output) -> bool {
return has_static_rank()(output) && HasSameOutputTransposeNodes(output);

View File

@@ -2,25 +2,24 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/transpose_sinking_data_movement.hpp"
#include <openvino/pass/pattern/op/or.hpp>
#include "transformations/transpose_sinking/ts_data_movement.hpp"
#include "itt.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/util/common_util.hpp"
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
#include "transformations/transpose_sinking/ts_utils.hpp"
using namespace ov;
using namespace ov::opset10;
using namespace ov::pass::pattern;
using namespace transpose_sinking;
using namespace ov::pass::transpose_sinking;
using namespace ov::pass::transpose_sinking::utils;
ov::pass::TransposeSinkingDataMovementForward::TransposeSinkingDataMovementForward() {
MATCHER_SCOPE(TransposeSinkingDataMovementForward);
TSDataMovementForward::TSDataMovementForward() {
MATCHER_SCOPE(TSDataMovementForward);
auto const_label = wrap_type<Constant>();
auto transpose_label = wrap_type<Transpose>({any_input(), const_label});
auto main_node_label =
@@ -65,7 +64,7 @@ ov::pass::TransposeSinkingDataMovementForward::TransposeSinkingDataMovementForwa
TransposeInputsInfo transpose_input_info = {transpose, transpose_const, 0};
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
register_new_node(new_node);
transpose_sinking::UpdateForwardSinkingAbility(new_node);
UpdateForwardSinkingAbility(new_node);
}
return true;
};
@@ -74,8 +73,8 @@ ov::pass::TransposeSinkingDataMovementForward::TransposeSinkingDataMovementForwa
register_matcher(m, matcher_pass_callback);
}
ov::pass::TransposeSinkingDataMovementBackward::TransposeSinkingDataMovementBackward() {
MATCHER_SCOPE(TransposeSinkingDataMovementBackward);
TSDataMovementBackward::TSDataMovementBackward() {
MATCHER_SCOPE(TSDataMovementBackward);
auto main_node_label = wrap_type<Pad, BatchToSpace, SpaceToBatch>([](const Output<Node>& output) -> bool {
return has_static_rank()(output) && HasSameOutputTransposeNodes(output);

View File

@@ -2,7 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/transpose_sinking_fuse.hpp"
#include "transformations/transpose_sinking/ts_fuse.hpp"
#include <memory>
#include <vector>
@@ -11,16 +11,18 @@
#include "openvino/core/validation_util.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
#include "transformations/transpose_sinking/ts_utils.hpp"
#include "transformations/utils/utils.hpp"
using namespace ov;
using namespace opset10;
using namespace ov::pass::transpose_sinking;
using namespace ov::pass::transpose_sinking::utils;
ov::pass::TransposeSinkingFuse::TransposeSinkingFuse() {
TSFuse::TSFuse() {
MATCHER_SCOPE(TransposeFuse);
auto transpose_1_label = pattern::wrap_type<Transpose>({pattern::any_input(), pattern::wrap_type<Constant>()},
transpose_sinking::HasSameOutputTransposeNodes);
HasSameOutputTransposeNodes);
auto transpose_2_label = pattern::wrap_type<Transpose>({transpose_1_label, pattern::wrap_type<Constant>()});
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
const auto& pattern_to_output = m.get_pattern_map();
@@ -62,11 +64,11 @@ ov::pass::TransposeSinkingFuse::TransposeSinkingFuse() {
auto new_transpose = register_new_node<Transpose>(input, new_order);
new_transpose->set_friendly_name(m.get_match_root()->get_friendly_name());
transpose_sinking::RemoveSingleOutputConsumers(transpose1);
RemoveSingleOutputConsumers(transpose1);
copy_runtime_info(transpose1, new_transpose);
ov::replace_node(transpose1, new_transpose);
transpose_sinking::UpdateForwardSinkingAbility(new_transpose);
UpdateForwardSinkingAbility(new_transpose);
}
return true;
};

View File

@@ -0,0 +1,65 @@
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/transpose_sinking/ts_general.hpp"
#include "itt.hpp"
#include "openvino/pass/constant_folding.hpp"
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/manager.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/transpose_sinking/ts_binary.hpp"
#include "transformations/transpose_sinking/ts_concat.hpp"
#include "transformations/transpose_sinking/ts_data_movement.hpp"
#include "transformations/transpose_sinking/ts_fuse.hpp"
#include "transformations/transpose_sinking/ts_interpolate.hpp"
#include "transformations/transpose_sinking/ts_reduction.hpp"
#include "transformations/transpose_sinking/ts_split.hpp"
#include "transformations/transpose_sinking/ts_unary.hpp"
#include "transformations/utils/utils.hpp"
using namespace ov::pass::transpose_sinking;
TSGeneralForward::TSGeneralForward() {
MATCHER_SCOPE(TSGeneralForward);
add_matcher<TSUnaryForward>();
add_matcher<TSBinaryForward>();
add_matcher<TSConcatForward>();
add_matcher<TSSplitForward>();
add_matcher<TSDataMovementForward>();
add_matcher<TSReductionForward>();
add_matcher<TSInterpolateForward>();
add_matcher<TSFuse>();
}
TSGeneralBackward::TSGeneralBackward() {
MATCHER_SCOPE(TSGeneralBackward);
add_matcher<TSUnaryBackward>();
add_matcher<TSBinaryBackward>();
add_matcher<TSConcatBackward>();
add_matcher<TSSplitBackward>();
add_matcher<TSDataMovementBackward>();
add_matcher<TSReductionBackward>();
add_matcher<TSInterpolateBackward>();
add_matcher<TSFuse>();
}
bool TSGeneral::run_on_model(const std::shared_ptr<ov::Model>& f) {
RUN_ON_FUNCTION_SCOPE(TSGeneral);
{
Manager manager(get_pass_config());
manager.register_pass<TSGeneralForward>();
manager.register_pass<ConstantFolding>();
manager.run_passes(f);
}
{
Manager manager(get_pass_config());
manager.register_pass<TSGeneralBackward>();
manager.register_pass<ConstantFolding>();
manager.run_passes(f);
}
return false;
}

View File

@@ -2,25 +2,25 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/transpose_sinking_interpolate.hpp"
#include <openvino/pass/pattern/op/or.hpp>
#include "transformations/transpose_sinking/ts_interpolate.hpp"
#include "itt.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/util/common_util.hpp"
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
#include "transformations/transpose_sinking/ts_utils.hpp"
using namespace ov;
using namespace ov::opset10;
using namespace ov::pass::pattern;
using namespace transpose_sinking;
using namespace ov::pass::transpose_sinking;
using namespace ov::pass::transpose_sinking::utils;
ov::pass::TransposeSinkingInterpolateForward::TransposeSinkingInterpolateForward() {
MATCHER_SCOPE(TransposeSinkingInterpolateForward);
TSInterpolateForward::TSInterpolateForward() {
MATCHER_SCOPE(TSInterpolateForward);
auto const_label = wrap_type<Constant>();
auto transpose_label = wrap_type<Transpose>({any_input(), const_label});
auto main_node_label = wrap_type<Interpolate>({transpose_label, any_input(), any_input(), any_input()});
@@ -74,7 +74,7 @@ ov::pass::TransposeSinkingInterpolateForward::TransposeSinkingInterpolateForward
TransposeInputsInfo transpose_input_info = {transpose, transpose_const, 0};
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
register_new_node(new_node);
transpose_sinking::UpdateForwardSinkingAbility(new_node);
UpdateForwardSinkingAbility(new_node);
}
return true;
};
@@ -83,8 +83,8 @@ ov::pass::TransposeSinkingInterpolateForward::TransposeSinkingInterpolateForward
register_matcher(m, matcher_pass_callback);
}
ov::pass::TransposeSinkingInterpolateBackward::TransposeSinkingInterpolateBackward() {
MATCHER_SCOPE(TransposeSinkingInterpolateBackward);
TSInterpolateBackward::TSInterpolateBackward() {
MATCHER_SCOPE(TSInterpolateBackward);
auto main_node_label = wrap_type<Interpolate>([](const Output<Node>& output) -> bool {
return has_static_rank()(output) && HasSameOutputTransposeNodes(output);

View File

@@ -2,7 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/transpose_sinking_reduction.hpp"
#include "transformations/transpose_sinking/ts_reduction.hpp"
#include <memory>
#include <vector>
@@ -13,11 +13,13 @@
#include "openvino/op/util/logical_reduction_keep_dims.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
#include "transformations/transpose_sinking/ts_utils.hpp"
#include "transformations/utils/utils.hpp"
using namespace ov;
using namespace opset10;
using namespace ov::pass::transpose_sinking;
using namespace ov::pass::transpose_sinking::utils;
namespace {
std::vector<size_t> get_updated_order_forward(const std::vector<size_t>& axes_values,
@@ -80,8 +82,8 @@ bool get_keep_dims(const std::shared_ptr<Node>& reduction) {
}
} // namespace
ov::pass::TransposeSinkingReductionForward::TransposeSinkingReductionForward() {
MATCHER_SCOPE(TransposeSinkingReductionForward);
TSReductionForward::TSReductionForward() {
MATCHER_SCOPE(TSReductionForward);
auto transpose_label = pattern::wrap_type<Transpose>({pattern::any_input(), pattern::wrap_type<Constant>()},
pattern::consumers_count(1));
@@ -150,7 +152,7 @@ ov::pass::TransposeSinkingReductionForward::TransposeSinkingReductionForward() {
replace_node(reduction, new_transpose);
new_reduction->set_friendly_name(transpose->get_friendly_name());
new_transpose->set_friendly_name(reduction->get_friendly_name());
transpose_sinking::UpdateForwardSinkingAbility(new_transpose);
UpdateForwardSinkingAbility(new_transpose);
register_new_node(new_transpose);
copy_runtime_info({transpose, reduction}, {new_transpose, new_reduction});
return true;
@@ -160,13 +162,13 @@ ov::pass::TransposeSinkingReductionForward::TransposeSinkingReductionForward() {
register_matcher(m, matcher_pass_callback);
}
ov::pass::TransposeSinkingReductionBackward::TransposeSinkingReductionBackward() {
MATCHER_SCOPE(TransposeSinkingReductionBackward);
TSReductionBackward::TSReductionBackward() {
MATCHER_SCOPE(TSReductionBackward);
auto reduce_or_squeeze_label = pattern::
wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims, Squeeze, Unsqueeze>(
{pattern::any_input(), pattern::wrap_type<Constant>()},
transpose_sinking::HasSameOutputTransposeNodes);
HasSameOutputTransposeNodes);
auto transpose_label = pattern::wrap_type<Transpose>({reduce_or_squeeze_label, pattern::wrap_type<Constant>()});
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
@@ -225,7 +227,7 @@ ov::pass::TransposeSinkingReductionBackward::TransposeSinkingReductionBackward()
}
if (!unsqueeze) {
auto reversed_order_values = transpose_sinking::ReverseTransposeOrder(transpose_order_values);
auto reversed_order_values = ReverseTransposeOrder(transpose_order_values);
for (const auto& axis : non_negative_axes) {
new_values.push_back(reversed_order_values[axis]);
}
@@ -246,7 +248,7 @@ ov::pass::TransposeSinkingReductionBackward::TransposeSinkingReductionBackward()
}
replace_node(transpose, new_reduction);
copy_runtime_info({transpose, reduction}, {new_transpose, new_reduction});
transpose_sinking::UpdateForwardSinkingAbility(new_transpose);
UpdateForwardSinkingAbility(new_transpose);
new_reduction->set_friendly_name(transpose->get_friendly_name());
register_new_node(new_transpose);
return true;

View File

@@ -2,24 +2,22 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/transpose_sinking_split.hpp"
#include <openvino/pass/pattern/op/or.hpp>
#include <transformations/utils/utils.hpp>
#include <utility>
#include "transformations/transpose_sinking/ts_split.hpp"
#include "itt.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/pattern/op/label.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
#include "transformations/transpose_sinking/ts_utils.hpp"
#include "transformations/utils/utils.hpp"
using namespace ov::pass::pattern;
using namespace ov;
using namespace ov::opset10;
using namespace transpose_sinking;
using namespace ov::pass::transpose_sinking;
using namespace ov::pass::transpose_sinking::utils;
namespace {
@@ -107,7 +105,7 @@ bool GetSplitAxis(const std::shared_ptr<Constant>& split_axis, const ov::Rank& r
* Consider case Split (1) -> Split (2) -> Transpose
* If specify Split as main searched node after first transformation work we will have
* Split (1) -> Transpose -> Split(2)
* Matcher pass will not call TransposeSinkingSplitBackward since
* Matcher pass will not call TSSplitBackward since
* - matcher pattern has no Transpose label
* - Split (1) has already been proceeded
* Adding Split(2) into the working queue as register_new_node(split)
@@ -121,8 +119,8 @@ bool GetSplitAxis(const std::shared_ptr<Constant>& split_axis, const ov::Rank& r
* - add reversed Transpose operations on all outputs except sinking Transpose
* nothing to do with new added output Transposes
*/
ov::pass::TransposeSinkingSplitBackward::TransposeSinkingSplitBackward() {
MATCHER_SCOPE(TransposeSinkingSplitBackward);
TSSplitBackward::TSSplitBackward() {
MATCHER_SCOPE(TSSplitBackward);
auto transpose_const_label = wrap_type<Constant>();
auto transpose_label = wrap_type<Transpose>({any_input(), transpose_const_label}, IsSplitSinked);
@@ -192,8 +190,8 @@ ov::pass::TransposeSinkingSplitBackward::TransposeSinkingSplitBackward() {
register_matcher(m, matcher_pass_callback);
}
ov::pass::TransposeSinkingSplitForward::TransposeSinkingSplitForward() {
MATCHER_SCOPE(TransposeSinkingSplitForward);
TSSplitForward::TSSplitForward() {
MATCHER_SCOPE(TSSplitForward);
auto main_node_label = wrap_type<Split, VariadicSplit>(IfNodeHasTransposeInputs);
@@ -225,7 +223,7 @@ ov::pass::TransposeSinkingSplitForward::TransposeSinkingSplitForward() {
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
register_new_node(new_node);
transpose_sinking::UpdateForwardSinkingAbility(new_node);
UpdateForwardSinkingAbility(new_node);
}
return true;

View File

@@ -2,22 +2,23 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/transpose_sinking_unary.hpp"
#include "transformations/transpose_sinking/ts_unary.hpp"
#include <transformations/utils/utils.hpp>
#include <utility>
#include "itt.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
#include "transformations/transpose_sinking/ts_utils.hpp"
#include "transformations/utils/utils.hpp"
using namespace ov;
using namespace ov::opset10;
using namespace ov::pass::pattern;
using namespace ov::op::util;
using namespace transpose_sinking;
using namespace ov::pass::transpose_sinking;
using namespace ov::pass::transpose_sinking::utils;
namespace {
@@ -51,8 +52,8 @@ NodePair SwapNodes(const NodePtr& first_node, const NodePtr& second_node) {
} // namespace
ov::pass::TransposeSinkingUnaryForward::TransposeSinkingUnaryForward() {
MATCHER_SCOPE(TransposeSinkingUnaryForward);
TSUnaryForward::TSUnaryForward() {
MATCHER_SCOPE(TSUnaryForward);
auto transpose_label = wrap_type<Transpose>({any_input(), any_input()});
auto unary_label =
@@ -73,7 +74,7 @@ ov::pass::TransposeSinkingUnaryForward::TransposeSinkingUnaryForward() {
return true;
};
auto m = std::make_shared<Matcher>(unary_label, "ov::pass::TransposeSinkingUnaryForward");
auto m = std::make_shared<Matcher>(unary_label, "ov::pass::TSUnaryForward");
register_matcher(m, matcher_pass_callback);
}
@@ -83,8 +84,8 @@ bool IfSinkingEnabled(const Output<Node>& output) {
}
} // namespace
ov::pass::TransposeSinkingUnaryBackward::TransposeSinkingUnaryBackward() {
MATCHER_SCOPE(TransposeSinkingUnaryBackwardMultiConsumers);
TSUnaryBackward::TSUnaryBackward() {
MATCHER_SCOPE(TSUnaryBackwardMultiConsumers);
auto unary_restrictions = [](const Output<Node>& output) -> bool {
return HasSameOutputTransposeNodes(output);
@@ -115,6 +116,6 @@ ov::pass::TransposeSinkingUnaryBackward::TransposeSinkingUnaryBackward() {
return true;
};
auto m = std::make_shared<Matcher>(transpose_label, "ov::pass::TransposeSinkingUnaryBackward");
auto m = std::make_shared<Matcher>(transpose_label, "ov::pass::TSUnaryBackward");
register_matcher(m, matcher_pass_callback);
}

View File

@@ -2,17 +2,19 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
#include <transformations/utils/utils.hpp>
#include "transformations/transpose_sinking/ts_utils.hpp"
#include "itt.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/util/common_util.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
#include "transformations/utils/utils.hpp"
namespace ov {
namespace pass {
namespace transpose_sinking {
namespace utils {
using namespace ov;
using namespace ov::opset10;
@@ -377,4 +379,7 @@ void RemoveSingleOutputConsumers(const NodePtr& node) {
}
}
} // namespace utils
} // namespace transpose_sinking
} // namespace pass
} // namespace ov

View File

@@ -1,40 +1,24 @@
// Copyright (C) 2022 Intel Corporation
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/transpose_sinking/ts_binary.hpp"
#include <functional>
#include <openvino/frontend/manager.hpp>
#include <openvino/opsets/opset9.hpp>
#include <openvino/pass/manager.hpp>
#include <transformations/common_optimizations/transpose_sinking_binary.hpp>
#include <transformations/init_node_info.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
#include "gtest/gtest.h"
#include "transpose_sinking_test_utils.hpp"
#include "openvino/frontend/manager.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/manager.hpp"
#include "ts_test_utils.hpp"
using namespace ov;
using namespace ov::opset10;
using namespace transpose_sinking::testing;
namespace transpose_sinking_binary_eltwise {
using namespace ov::pass::transpose_sinking;
using namespace transpose_sinking::testing::utils;
namespace {
namespace {
std::string to_string(const Shape& shape) {
std::ostringstream result;
result << "{";
for (size_t idx = 0; idx < shape.size(); ++idx) {
if (idx)
result << ",";
result << shape[idx];
}
result << "}";
return result.str();
}
} // namespace
// ----------------------------------------------------------------------------
template <typename BinaryT>
class BinaryFactory : public IFactory {
@@ -87,6 +71,10 @@ std::vector<size_t> binary_transpose_input_indexes = {0, 1};
} // namespace
namespace transpose_sinking {
namespace testing {
namespace binary {
namespace single_consumer {
namespace forward {
namespace one_input_transpose {
@@ -207,7 +195,7 @@ class TransposeSinkingBinaryTwoTransposeInputsTestFixture
: public ::testing::WithParamInterface<TestBinaryTwoTransposeInputsParams>,
public TransformationTestsF {
public:
static std::string get_test_name(const testing::TestParamInfo<TestBinaryTwoTransposeInputsParams>& obj) {
static std::string get_test_name(const ::testing::TestParamInfo<TestBinaryTwoTransposeInputsParams>& obj) {
FactoryPtr binary_factory;
PassFactoryPtr pass_factory;
size_t num_binary_ops;
@@ -247,7 +235,7 @@ TEST_P(TransposeSinkingBinaryTwoTransposeInputsTestFixture, CompareFunctions) {
INSTANTIATE_TEST_SUITE_P(TransposeSinkingBinaryTwoTransposeInputsForwardTestSuite,
TransposeSinkingBinaryTwoTransposeInputsTestFixture,
::testing::Combine(::testing::ValuesIn(binary_factories),
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryForward)),
::testing::Values(CREATE_PASS_FACTORY(TSBinaryForward)),
::testing::ValuesIn(binary_operations_numbers),
::testing::Values(CreateFunction),
::testing::Values(CreateReferenceFunction),
@@ -327,7 +315,7 @@ using TestBinaryParams = std::tuple<FactoryPtr,
class TransposeSinkingBinaryTestFixture : public ::testing::WithParamInterface<TestBinaryParams>,
public TransformationTestsF {
public:
static std::string get_test_name(const testing::TestParamInfo<TestBinaryParams>& obj) {
static std::string get_test_name(const ::testing::TestParamInfo<TestBinaryParams>& obj) {
FactoryPtr binary_factory;
PassFactoryPtr pass_factory;
size_t num_binary_ops;
@@ -377,10 +365,10 @@ TEST_P(TransposeSinkingBinaryTestFixture, CompareFunctions) {
}
INSTANTIATE_TEST_SUITE_P(
TransposeSinkingBinaryForwardTestSuite,
TSBinaryForwardTestSuite,
TransposeSinkingBinaryTestFixture,
::testing::Combine(::testing::ValuesIn(binary_factories),
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryForward)),
::testing::Values(CREATE_PASS_FACTORY(TSBinaryForward)),
::testing::ValuesIn(binary_operations_numbers),
::testing::Values(single_consumer::forward::one_input_transpose::CreateFunction),
::testing::Values(single_consumer::forward::one_input_transpose::CreateReferenceFunction),
@@ -389,10 +377,10 @@ INSTANTIATE_TEST_SUITE_P(
TransposeSinkingBinaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(
TransposeSinkingBinaryBackwardTestSuite,
TSBinaryBackwardTestSuite,
TransposeSinkingBinaryTestFixture,
::testing::Combine(::testing::ValuesIn(binary_factories),
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryBackward)),
::testing::Values(CREATE_PASS_FACTORY(TSBinaryBackward)),
::testing::ValuesIn(binary_operations_numbers),
::testing::Values(single_consumer::backward::one_input_transpose::CreateFunction),
::testing::Values(single_consumer::backward::one_input_transpose::CreateReferenceFunction),
@@ -421,7 +409,7 @@ class TransposeSinkingBinaryIncompatShapesTestFixture
: public ::testing::WithParamInterface<TestBinaryIncompatShapesParams>,
public TransformationTestsF {
public:
static std::string get_test_name(const testing::TestParamInfo<TestBinaryIncompatShapesParams>& obj) {
static std::string get_test_name(const ::testing::TestParamInfo<TestBinaryIncompatShapesParams>& obj) {
FactoryPtr binary_factory;
PassFactoryPtr pass_factory;
Shape input_shape;
@@ -600,7 +588,7 @@ INSTANTIATE_TEST_SUITE_P(
TransposeSinkingBinaryIncompatShapesBackwardTestSuite,
TransposeSinkingBinaryIncompatShapesTestFixture,
::testing::Combine(::testing::ValuesIn(binary_elementwise_factories),
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryBackward)),
::testing::Values(CREATE_PASS_FACTORY(TSBinaryBackward)),
::testing::Values(Shape{1, 96, 55, 55}),
::testing::ValuesIn(binary::single_consumer::backward::incompat_shapes::constant_shapes),
::testing::Values(binary::single_consumer::backward::incompat_shapes::CreateFunction),
@@ -613,7 +601,7 @@ INSTANTIATE_TEST_SUITE_P(
TransposeSinkingBinaryIncompatShapesForwardTestSuite,
TransposeSinkingBinaryIncompatShapesTestFixture,
::testing::Combine(::testing::ValuesIn(binary_elementwise_factories),
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryForward)),
::testing::Values(CREATE_PASS_FACTORY(TSBinaryForward)),
::testing::Values(Shape{1, 96, 55, 55}),
::testing::ValuesIn(binary::single_consumer::forward::incompat_shapes::constant_shapes),
::testing::Values(binary::single_consumer::forward::incompat_shapes::CreateFunction),
@@ -626,7 +614,7 @@ INSTANTIATE_TEST_SUITE_P(
TransposeSinkingPReluIncompatShapesBackwardTestSuite,
TransposeSinkingBinaryIncompatShapesTestFixture,
::testing::Combine(::testing::Values(CREATE_BINARY_FACTORY(PRelu)),
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryBackward)),
::testing::Values(CREATE_PASS_FACTORY(TSBinaryBackward)),
::testing::Values(Shape{1, 3, 16, 16}),
::testing::ValuesIn(std::vector<Shape>{Shape{3}}),
::testing::Values(binary::single_consumer::backward::incompat_shapes::CreateFunction),
@@ -639,7 +627,7 @@ INSTANTIATE_TEST_SUITE_P(
TransposeSinkingPReluIncompatShapesForwardTestSuite,
TransposeSinkingBinaryIncompatShapesTestFixture,
::testing::Combine(::testing::Values(CREATE_BINARY_FACTORY(PRelu)),
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryForward)),
::testing::Values(CREATE_PASS_FACTORY(TSBinaryForward)),
::testing::Values(Shape{1, 3, 16, 16}),
::testing::ValuesIn(std::vector<Shape>{Shape{3}}),
::testing::Values(binary::single_consumer::forward::incompat_shapes::CreateFunction),
@@ -1090,7 +1078,7 @@ using TestBinaryParams = std::tuple<FactoryPtr,
class TransposeBinaryMultiSinkingFixture : public ::testing::WithParamInterface<TestBinaryParams>,
public TransformationTestsF {
public:
static std::string get_test_name(const testing::TestParamInfo<TestBinaryParams>& obj) {
static std::string get_test_name(const ::testing::TestParamInfo<TestBinaryParams>& obj) {
FactoryPtr binary_factory;
PassFactoryPtr pass_factory;
CreateGraphFunctionDesc function_desc;
@@ -1139,19 +1127,19 @@ std::vector<CreateGraphFunctionDesc> backward_subtests = {
#undef SUBTEST
INSTANTIATE_TEST_SUITE_P(TransposeSinkingBinaryForwardMultiConsumersTestSuite,
INSTANTIATE_TEST_SUITE_P(TSBinaryForwardMultiConsumersTestSuite,
TransposeBinaryMultiSinkingFixture,
::testing::Combine(::testing::ValuesIn(binary_factories),
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryForward)),
::testing::Values(CREATE_PASS_FACTORY(TSBinaryForward)),
::testing::ValuesIn(forward_subtests),
::testing::Values(element::f32),
::testing::ValuesIn(binary_transpose_input_indexes)),
TransposeBinaryMultiSinkingFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(TransposeSinkingBinaryBackwardMultiConsumersTestSuite,
INSTANTIATE_TEST_SUITE_P(TSBinaryBackwardMultiConsumersTestSuite,
TransposeBinaryMultiSinkingFixture,
::testing::Combine(::testing::ValuesIn(binary_factories),
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryBackward)),
::testing::Values(CREATE_PASS_FACTORY(TSBinaryBackward)),
::testing::ValuesIn(backward_subtests),
::testing::Values(element::f32),
::testing::ValuesIn(binary_transpose_input_indexes)),
@@ -1177,7 +1165,7 @@ using TestBinaryParams = std::tuple<FactoryPtr,
class TransposeBinaryMultiSinkingBinaryMultiConsumersFixture : public ::testing::WithParamInterface<TestBinaryParams>,
public TransformationTestsF {
public:
static std::string get_test_name(const testing::TestParamInfo<TestBinaryParams>& obj) {
static std::string get_test_name(const ::testing::TestParamInfo<TestBinaryParams>& obj) {
FactoryPtr binary_factory;
PassFactoryPtr pass_factory;
CreateGraphFunctionDesc function_desc;
@@ -1219,10 +1207,10 @@ std::vector<CreateGraphFunctionDesc> backward_subtests_binary_consumers = {
};
#undef SUBTEST
INSTANTIATE_TEST_SUITE_P(TransposeSinkingBinaryBackwardBinaryMultiConsumersTestSuite,
INSTANTIATE_TEST_SUITE_P(TSBinaryBackwardBinaryMultiConsumersTestSuite,
TransposeBinaryMultiSinkingBinaryMultiConsumersFixture,
::testing::Combine(::testing::ValuesIn(binary_factories),
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryBackward)),
::testing::Values(CREATE_PASS_FACTORY(TSBinaryBackward)),
::testing::ValuesIn(backward_subtests_binary_consumers),
::testing::Values(element::f32),
::testing::ValuesIn(binary_transpose_input_indexes)),
@@ -1232,4 +1220,6 @@ INSTANTIATE_TEST_SUITE_P(TransposeSinkingBinaryBackwardBinaryMultiConsumersTestS
} // namespace mult_consumers
} // namespace transpose_sinking_binary_eltwise
} // namespace binary
} // namespace testing
} // namespace transpose_sinking

View File

@@ -2,27 +2,27 @@
// SPDX-License-Identifier: Apache-2.0
//
#include <openvino/opsets/opset10.hpp>
#include <openvino/pass/manager.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
#include "gtest/gtest.h"
#include "transformations/common_optimizations/transpose_sinking_binary.hpp"
#include "transformations/common_optimizations/transpose_sinking_concat.hpp"
#include "transformations/common_optimizations/transpose_sinking_data_movement.hpp"
#include "transformations/common_optimizations/transpose_sinking_interpolate.hpp"
#include "transformations/common_optimizations/transpose_sinking_reduction.hpp"
#include "transformations/common_optimizations/transpose_sinking_split.hpp"
#include "transformations/common_optimizations/transpose_sinking_unary.hpp"
#include "transformations/common_optimizations/transpose_sinking_slice.hpp"
#include "transpose_sinking_test_utils.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/transpose_sinking/ts_binary.hpp"
#include "transformations/transpose_sinking/ts_concat.hpp"
#include "transformations/transpose_sinking/ts_data_movement.hpp"
#include "transformations/transpose_sinking/ts_interpolate.hpp"
#include "transformations/transpose_sinking/ts_reduction.hpp"
#include "transformations/transpose_sinking/ts_split.hpp"
#include "transformations/transpose_sinking/ts_unary.hpp"
#include "ts_test_utils.hpp"
using namespace std;
using namespace ov;
using namespace ov::opset10;
using namespace transpose_sinking::testing;
using namespace ov::pass::transpose_sinking;
using namespace transpose_sinking::testing::utils;
namespace transpose_sinking {
namespace testing {
namespace common {
template <typename UnaryT>
@@ -378,7 +378,7 @@ auto test_forward_unary = [](const vector<FactoryPtr>& factories, const vector<s
TestCase test_case;
// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryForward);
test_case.transformation = CREATE_PASS_FACTORY(TSUnaryForward);
test_case.num_main_ops = num_main_ops;
test_case.inputs_to_main = {
parameter(element::f32, {1, 96, 55, 55}),
@@ -387,12 +387,12 @@ auto test_forward_unary = [](const vector<FactoryPtr>& factories, const vector<s
// Test model description:
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
test_case.model.main_op = factories;
test_case.model.model_template = transpose_sinking::common::create_model;
test_case.model.model_template = create_model;
// Reference model description:
test_case.model_ref.main_op = factories;
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
test_case.model_ref.model_template = transpose_sinking::common::create_model;
test_case.model_ref.model_template = create_model;
return wrapper(test_case);
};
@@ -408,7 +408,7 @@ auto test_forward_binary = []() {
TestCase test_case;
// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingBinaryForward);
test_case.transformation = CREATE_PASS_FACTORY(TSBinaryForward);
test_case.num_main_ops = {1, 10};
test_case.inputs_to_main = {
parameter(element::f32, {1, 96, 55, 55}),
@@ -418,13 +418,13 @@ auto test_forward_binary = []() {
// Test model description:
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
test_case.model.main_op = binary_factories;
test_case.model.model_template = transpose_sinking::common::create_model;
test_case.model.model_template = create_model;
// Reference model description:
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for}, {{1}}};
test_case.model_ref.main_op = binary_factories;
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
test_case.model_ref.model_template = transpose_sinking::common::create_model;
test_case.model_ref.model_template = create_model;
return wrapper(test_case);
};
@@ -435,7 +435,7 @@ auto test_forward_concat = []() {
TestCase test_case;
// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingConcatForward);
test_case.transformation = CREATE_PASS_FACTORY(TSConcatForward);
test_case.num_main_ops = {1, 3};
test_case.inputs_to_main = {
parameter(element::f32, {1, 96, 55, 55}),
@@ -446,13 +446,13 @@ auto test_forward_concat = []() {
// Test model description:
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
test_case.model.main_op = {CREATE_CONCAT_FACTORY(Concat)};
test_case.model.model_template = transpose_sinking::common::create_model;
test_case.model.model_template = create_model;
// Reference model description:
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for}, {{1, 2}}};
test_case.model_ref.main_op = {CREATE_CONCAT_REF_FACTORY(Concat)};
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
test_case.model_ref.model_template = transpose_sinking::common::create_model;
test_case.model_ref.model_template = create_model;
return wrapper(test_case);
};
@@ -463,7 +463,7 @@ auto test_forward_split = []() {
TestCase test_case;
// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingSplitForward);
test_case.transformation = CREATE_PASS_FACTORY(TSSplitForward);
test_case.num_main_ops = {1, 2};
test_case.inputs_to_main = {
parameter(element::f32, {1, 9, 55, 55}),
@@ -473,7 +473,7 @@ auto test_forward_split = []() {
// Test model description:
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
test_case.model.main_op = {CREATE_SPLIT_FACTORY(Split)};
test_case.model.model_template = transpose_sinking::common::create_model;
test_case.model.model_template = create_model;
// Reference model description:
auto new_constant = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
@@ -486,7 +486,7 @@ auto test_forward_split = []() {
test_case.model_ref.preprocess_inputs_to_main = {{new_constant}, {{1}}};
test_case.model_ref.main_op = {CREATE_SPLIT_FACTORY(Split)};
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0, 1, 2}}};
test_case.model_ref.model_template = transpose_sinking::common::create_model;
test_case.model_ref.model_template = create_model;
return wrapper(test_case);
};
@@ -497,7 +497,7 @@ auto test_forward_pad = []() {
TestCase test_case;
// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingDataMovementForward);
test_case.transformation = CREATE_PASS_FACTORY(TSDataMovementForward);
test_case.num_main_ops = {1, 2};
test_case.inputs_to_main = {
parameter(element::f32, {1, 3, 55, 55}),
@@ -508,13 +508,13 @@ auto test_forward_pad = []() {
// Test model description:
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
test_case.model.main_op = {CREATE_PAD_FACTORY(Pad)};
test_case.model.model_template = transpose_sinking::common::create_model;
test_case.model.model_template = create_model;
// Reference model description:
test_case.model_ref.preprocess_inputs_to_main = {{set_gather_for}, {{1, 2}}};
test_case.model_ref.main_op = {CREATE_PAD_FACTORY(Pad)};
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
test_case.model_ref.model_template = transpose_sinking::common::create_model;
test_case.model_ref.model_template = create_model;
return wrapper(test_case);
};
@@ -525,7 +525,7 @@ auto test_forward_batch_to_space = []() {
TestCase test_case;
// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingDataMovementForward);
test_case.transformation = CREATE_PASS_FACTORY(TSDataMovementForward);
test_case.num_main_ops = {1, 2};
test_case.inputs_to_main = {
parameter(element::f32, {128, 55, 3, 128}),
@@ -537,13 +537,13 @@ auto test_forward_batch_to_space = []() {
// Test model description:
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
test_case.model.main_op = {CREATE_BATCH_TO_SPACE_FACTORY(BatchToSpace)};
test_case.model.model_template = transpose_sinking::common::create_model;
test_case.model.model_template = create_model;
// Reference model description:
test_case.model_ref.preprocess_inputs_to_main = {{set_gather_for}, {{1, 2, 3}}};
test_case.model_ref.main_op = {CREATE_BATCH_TO_SPACE_FACTORY(BatchToSpace)};
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
test_case.model_ref.model_template = transpose_sinking::common::create_model;
test_case.model_ref.model_template = create_model;
return wrapper(test_case);
};
@@ -556,7 +556,7 @@ auto test_forward_space_to_batch = []() {
TestCase test_case;
// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingDataMovementForward);
test_case.transformation = CREATE_PASS_FACTORY(TSDataMovementForward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {64, 9, 8, 1}),
@@ -568,13 +568,13 @@ auto test_forward_space_to_batch = []() {
// Test model description:
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
test_case.model.main_op = {CREATE_SPACE_TO_BATCH_FACTORY(SpaceToBatch)};
test_case.model.model_template = transpose_sinking::common::create_model;
test_case.model.model_template = create_model;
// Reference model description:
test_case.model_ref.preprocess_inputs_to_main = {{set_gather_for}, {{1, 2, 3}}};
test_case.model_ref.main_op = {CREATE_SPACE_TO_BATCH_FACTORY(SpaceToBatch)};
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
test_case.model_ref.model_template = transpose_sinking::common::create_model;
test_case.model_ref.model_template = create_model;
return wrapper(test_case);
};
@@ -587,7 +587,7 @@ auto test_forward_reduction = []() {
TestCase test_case;
// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingReductionForward);
test_case.transformation = CREATE_PASS_FACTORY(TSReductionForward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {32, 4, 2, 1}),
@@ -597,7 +597,7 @@ auto test_forward_reduction = []() {
// Test model description:
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
test_case.model.main_op = reduction_factories;
test_case.model.model_template = transpose_sinking::common::create_model;
test_case.model.model_template = create_model;
// Reference model description:
auto new_constant = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
@@ -610,7 +610,7 @@ auto test_forward_reduction = []() {
test_case.model_ref.preprocess_inputs_to_main = {{new_constant}, {{1}}};
test_case.model_ref.main_op = reduction_factories;
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
test_case.model_ref.model_template = transpose_sinking::common::create_model;
test_case.model_ref.model_template = create_model;
return wrapper(test_case);
};
@@ -621,7 +621,7 @@ auto test_forward_interpolate = []() {
TestCase test_case;
// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingInterpolateForward);
test_case.transformation = CREATE_PASS_FACTORY(TSInterpolateForward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {1, 2, 48, 80}),
@@ -633,7 +633,7 @@ auto test_forward_interpolate = []() {
// Test model description:
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
test_case.model.main_op = {CREATE_INTERPOLATE_FACTORY(Interpolate, false)};
test_case.model.model_template = transpose_sinking::common::create_model;
test_case.model.model_template = create_model;
// Reference model description:
auto set_specific_gather_for = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
@@ -653,7 +653,7 @@ auto test_forward_interpolate = []() {
test_case.model_ref.preprocess_inputs_to_main = {{set_specific_gather_for}, {{3}}};
test_case.model_ref.main_op = {CREATE_INTERPOLATE_FACTORY(Interpolate, true)};
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
test_case.model_ref.model_template = transpose_sinking::common::create_model;
test_case.model_ref.model_template = create_model;
return wrapper(test_case);
};
@@ -666,7 +666,7 @@ auto test_forward_squeeze = []() {
TestCase test_case;
// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingReductionForward);
test_case.transformation = CREATE_PASS_FACTORY(TSReductionForward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {32, 1, 2, 1}),
@@ -676,7 +676,7 @@ auto test_forward_squeeze = []() {
// Test model description:
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
test_case.model.main_op = {CREATE_BINARY_FACTORY(Squeeze)};
test_case.model.model_template = transpose_sinking::common::create_model;
test_case.model.model_template = create_model;
// Reference model description:
auto new_constant = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
@@ -689,7 +689,7 @@ auto test_forward_squeeze = []() {
test_case.model_ref.preprocess_inputs_to_main = {{new_constant}, {{1}}};
test_case.model_ref.main_op = {CREATE_BINARY_FACTORY(Squeeze)};
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
test_case.model_ref.model_template = transpose_sinking::common::create_model;
test_case.model_ref.model_template = create_model;
return wrapper(test_case);
};
@@ -700,7 +700,7 @@ auto test_forward_unsqueeze = []() {
TestCase test_case;
// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingReductionForward);
test_case.transformation = CREATE_PASS_FACTORY(TSReductionForward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {32, 3, 2, 1}),
@@ -710,7 +710,7 @@ auto test_forward_unsqueeze = []() {
// Test model description:
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
test_case.model.main_op = {CREATE_BINARY_FACTORY(Unsqueeze)};
test_case.model.model_template = transpose_sinking::common::create_model;
test_case.model.model_template = create_model;
// Reference model description:
auto new_constant = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
@@ -730,7 +730,7 @@ auto test_forward_unsqueeze = []() {
return new_out_vec;
};
test_case.model_ref.preprocess_outputs_of_main = {{new_transpose}, {{0}}};
test_case.model_ref.model_template = transpose_sinking::common::create_model;
test_case.model_ref.model_template = create_model;
return wrapper(test_case);
};
@@ -786,7 +786,7 @@ auto test_backward_unary = []() {
TestCase test_case;
// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward);
test_case.transformation = CREATE_PASS_FACTORY(TSUnaryBackward);
test_case.num_main_ops = {1, 10};
test_case.inputs_to_main = {
parameter(element::f32, {1, 96, 55, 55}),
@@ -795,12 +795,12 @@ auto test_backward_unary = []() {
// Test model description:
test_case.model.main_op = unary_factories;
test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
test_case.model.model_template = transpose_sinking::common::create_model;
test_case.model.model_template = create_model;
// Reference model description:
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
test_case.model_ref.main_op = unary_factories;
test_case.model_ref.model_template = transpose_sinking::common::create_model;
test_case.model_ref.model_template = create_model;
return wrapper(test_case);
};
@@ -811,7 +811,7 @@ auto test_backward_binary = []() {
TestCase test_case;
// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingBinaryBackward);
test_case.transformation = CREATE_PASS_FACTORY(TSBinaryBackward);
test_case.num_main_ops = {1, 10};
test_case.inputs_to_main = {
parameter(element::f32, {1, 96, 55, 55}),
@@ -821,12 +821,12 @@ auto test_backward_binary = []() {
// Test model description:
test_case.model.main_op = binary_factories;
test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
test_case.model.model_template = transpose_sinking::common::create_model;
test_case.model.model_template = create_model;
// Reference model description:
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for}, {{0, 1}}};
test_case.model_ref.main_op = binary_factories;
test_case.model_ref.model_template = transpose_sinking::common::create_model;
test_case.model_ref.model_template = create_model;
return wrapper(test_case);
};
@@ -837,7 +837,7 @@ auto test_backward_concat = []() {
TestCase test_case;
// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingConcatBackward);
test_case.transformation = CREATE_PASS_FACTORY(TSConcatBackward);
test_case.num_main_ops = {1, 3};
test_case.inputs_to_main = {
parameter(element::f32, {1, 96, 55, 55}),
@@ -848,12 +848,12 @@ auto test_backward_concat = []() {
// Test model description:
test_case.model.main_op = {CREATE_CONCAT_FACTORY(Concat)};
test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
test_case.model.model_template = transpose_sinking::common::create_model;
test_case.model.model_template = create_model;
// Reference model description:
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for}, {{0, 1, 2}}};
test_case.model_ref.main_op = {CREATE_CONCAT_REF_FACTORY(Concat)};
test_case.model_ref.model_template = transpose_sinking::common::create_model;
test_case.model_ref.model_template = create_model;
return wrapper(test_case);
};
@@ -864,7 +864,7 @@ auto test_backward_split = []() {
TestCase test_case;
// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingSplitBackward);
test_case.transformation = CREATE_PASS_FACTORY(TSSplitBackward);
test_case.num_main_ops = {1, 2};
test_case.inputs_to_main = {
parameter(element::f32, {1, 9, 55, 55}),
@@ -874,7 +874,7 @@ auto test_backward_split = []() {
// Test model description:
test_case.model.main_op = {CREATE_SPLIT_FACTORY(Split)};
test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0, 1, 2}}};
test_case.model.model_template = transpose_sinking::common::create_model;
test_case.model.model_template = create_model;
// Reference model description:
auto new_constant = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
@@ -886,7 +886,7 @@ auto test_backward_split = []() {
};
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, new_constant}, {{0}, {1}}};
test_case.model_ref.main_op = {CREATE_SPLIT_FACTORY(Split)};
test_case.model_ref.model_template = transpose_sinking::common::create_model;
test_case.model_ref.model_template = create_model;
return wrapper(test_case);
};
@@ -896,7 +896,7 @@ auto test_backward_pad = []() {
TestCase test_case;
// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingDataMovementBackward);
test_case.transformation = CREATE_PASS_FACTORY(TSDataMovementBackward);
test_case.num_main_ops = {1, 2};
test_case.inputs_to_main = {
parameter(element::f32, {1, 3, 55, 55}),
@@ -907,12 +907,12 @@ auto test_backward_pad = []() {
// Test model description:
test_case.model.main_op = {CREATE_PAD_FACTORY(Pad)};
test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
test_case.model.model_template = transpose_sinking::common::create_model;
test_case.model.model_template = create_model;
// Reference model description:
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, set_gather_for}, {{0}, {1, 2}}};
test_case.model_ref.main_op = {CREATE_PAD_FACTORY(Pad)};
test_case.model_ref.model_template = transpose_sinking::common::create_model;
test_case.model_ref.model_template = create_model;
return wrapper(test_case);
};
@@ -923,7 +923,7 @@ auto test_backward_batch_to_space = []() {
TestCase test_case;
// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingDataMovementBackward);
test_case.transformation = CREATE_PASS_FACTORY(TSDataMovementBackward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {128, 55, 3, 128}),
@@ -935,12 +935,12 @@ auto test_backward_batch_to_space = []() {
// Reference model description:
test_case.model.main_op = {CREATE_BATCH_TO_SPACE_FACTORY(BatchToSpace)};
test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
test_case.model.model_template = transpose_sinking::common::create_model;
test_case.model.model_template = create_model;
// Test model description:
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, set_gather_for}, {{0}, {1, 2, 3}}};
test_case.model_ref.main_op = {CREATE_BATCH_TO_SPACE_FACTORY(BatchToSpace)};
test_case.model_ref.model_template = transpose_sinking::common::create_model;
test_case.model_ref.model_template = create_model;
return wrapper(test_case);
};
@@ -953,7 +953,7 @@ auto test_backward_space_to_batch = []() {
TestCase test_case;
// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingDataMovementBackward);
test_case.transformation = CREATE_PASS_FACTORY(TSDataMovementBackward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {1, 8, 9, 64}),
@@ -965,12 +965,12 @@ auto test_backward_space_to_batch = []() {
// Test model description:
test_case.model.main_op = {CREATE_SPACE_TO_BATCH_FACTORY(SpaceToBatch)};
test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
test_case.model.model_template = transpose_sinking::common::create_model;
test_case.model.model_template = create_model;
// Reference model description:
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, set_gather_for}, {{0}, {1, 2, 3}}};
test_case.model_ref.main_op = {CREATE_SPACE_TO_BATCH_FACTORY(SpaceToBatch)};
test_case.model_ref.model_template = transpose_sinking::common::create_model;
test_case.model_ref.model_template = create_model;
return wrapper(test_case);
};
@@ -982,7 +982,7 @@ auto test_backward_reduction = []() {
TestCase test_case;
// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingReductionBackward);
test_case.transformation = CREATE_PASS_FACTORY(TSReductionBackward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {32, 4, 2, 1}),
@@ -992,7 +992,7 @@ auto test_backward_reduction = []() {
// Test model description:
test_case.model.main_op = reduction_factories;
test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
test_case.model.model_template = transpose_sinking::common::create_model;
test_case.model.model_template = create_model;
// Reference model description:
auto new_constant = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
@@ -1004,7 +1004,7 @@ auto test_backward_reduction = []() {
};
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, new_constant}, {{0}, {1}}};
test_case.model_ref.main_op = reduction_factories;
test_case.model_ref.model_template = transpose_sinking::common::create_model;
test_case.model_ref.model_template = create_model;
return wrapper(test_case);
};
@@ -1017,7 +1017,7 @@ auto test_backward_interpolate = []() {
TestCase test_case;
// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingInterpolateBackward);
test_case.transformation = CREATE_PASS_FACTORY(TSInterpolateBackward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {1, 2, 48, 80}),
@@ -1029,7 +1029,7 @@ auto test_backward_interpolate = []() {
// Test model description:
test_case.model.main_op = {CREATE_INTERPOLATE_FACTORY(Interpolate, true)};
test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
test_case.model.model_template = transpose_sinking::common::create_model;
test_case.model.model_template = create_model;
// Reference model description:
auto set_specific_gather_for = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
@@ -1048,7 +1048,7 @@ auto test_backward_interpolate = []() {
};
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, set_specific_gather_for}, {{0}, {3}}};
test_case.model_ref.main_op = {CREATE_INTERPOLATE_FACTORY(Interpolate, false)};
test_case.model_ref.model_template = transpose_sinking::common::create_model;
test_case.model_ref.model_template = create_model;
return wrapper(test_case);
};
@@ -1062,7 +1062,7 @@ auto test_backward_squeeze = []() {
TestCase test_case;
// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingReductionBackward);
test_case.transformation = CREATE_PASS_FACTORY(TSReductionBackward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {32, 1, 2, 1}),
@@ -1072,7 +1072,7 @@ auto test_backward_squeeze = []() {
// Test model description:
test_case.model.main_op = {CREATE_BINARY_FACTORY(Squeeze)};
test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
test_case.model.model_template = transpose_sinking::common::create_model;
test_case.model.model_template = create_model;
// Reference model description:
auto new_transpose = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
@@ -1084,7 +1084,7 @@ auto test_backward_squeeze = []() {
};
test_case.model_ref.preprocess_inputs_to_main = {{new_transpose}, {{0}}};
test_case.model_ref.main_op = {CREATE_BINARY_FACTORY(Squeeze)};
test_case.model_ref.model_template = transpose_sinking::common::create_model;
test_case.model_ref.model_template = create_model;
return wrapper(test_case);
};
@@ -1095,7 +1095,7 @@ auto test_backward_unsqueeze = []() {
TestCase test_case;
// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingReductionBackward);
test_case.transformation = CREATE_PASS_FACTORY(TSReductionBackward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {32, 3, 2, 1}),
@@ -1105,7 +1105,7 @@ auto test_backward_unsqueeze = []() {
// Test model description:
test_case.model.main_op = {CREATE_BINARY_FACTORY(Unsqueeze)};
test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
test_case.model.model_template = transpose_sinking::common::create_model;
test_case.model.model_template = create_model;
// Reference model description:
auto new_constant = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
@@ -1117,7 +1117,7 @@ auto test_backward_unsqueeze = []() {
};
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, new_constant}, {{0}, {1}}};
test_case.model_ref.main_op = {CREATE_BINARY_FACTORY(Unsqueeze)};
test_case.model_ref.model_template = transpose_sinking::common::create_model;
test_case.model_ref.model_template = create_model;
return wrapper(test_case);
};
@@ -1166,4 +1166,5 @@ auto test_backward_slice = []() {
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSliceBackward, TransposeSinkingTestFixture, test_backward_slice());
} // namespace common
} // namespace testing
} // namespace transpose_sinking

View File

@@ -1,27 +1,28 @@
// Copyright (C) 2022 Intel Corporation
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/transpose_sinking/ts_concat.hpp"
#include <functional>
#include <openvino/frontend/manager.hpp>
#include <openvino/opsets/opset10.hpp>
#include <openvino/pass/manager.hpp>
#include <transformations/common_optimizations/transpose_sinking_concat.hpp>
#include <transformations/common_optimizations/transpose_sinking_utils.hpp>
#include <transformations/init_node_info.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
#include "gtest/gtest.h"
#include "transpose_sinking_test_utils.hpp"
#include "openvino/frontend/manager.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/init_node_info.hpp"
#include "transformations/transpose_sinking/ts_utils.hpp"
#include "ts_test_utils.hpp"
using namespace ov;
using namespace ov::opset10;
using namespace transpose_sinking::testing;
using namespace ov::pass::transpose_sinking;
using namespace transpose_sinking::testing::utils;
namespace {
std::vector<size_t> concat_operations_numbers = {1, 10};
std::vector<size_t> concat_transpose_input_indexes = {0, 2};
NodePtr CreateConcatChain(NodePtr input_node,
@@ -331,9 +332,9 @@ TEST_P(TransposeSinkingConcatTestFixture, CompareFunctions) {
}
INSTANTIATE_TEST_SUITE_P(
TransposeSinkingConcatForwardTestSuite,
TSConcatForwardTestSuite,
TransposeSinkingConcatTestFixture,
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingConcatForward)),
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSConcatForward)),
::testing::ValuesIn(concat_operations_numbers),
::testing::Values(single_consumer::forward::one_input_transpose::CreateFunction),
::testing::Values(single_consumer::forward::one_input_transpose::CreateReferenceFunction),
@@ -342,9 +343,9 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Values(5)),
TransposeSinkingConcatTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(TransposeSinkingConcatBackwardTestSuite,
INSTANTIATE_TEST_SUITE_P(TSConcatBackwardTestSuite,
TransposeSinkingConcatTestFixture,
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingConcatBackward)),
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSConcatBackward)),
::testing::ValuesIn(concat_operations_numbers),
::testing::Values(single_consumer::backward::CreateFunction),
::testing::Values(single_consumer::backward::CreateReferenceFunction),
@@ -408,9 +409,9 @@ TEST_P(TransposeSinkingConcatAllTransposesInputTestFixture, CompareFunctions) {
}
INSTANTIATE_TEST_SUITE_P(
TransposeSinkingConcatForwardAllTransposesTestSuite,
TSConcatForwardAllTransposesTestSuite,
TransposeSinkingConcatAllTransposesInputTestFixture,
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingConcatForward)),
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSConcatForward)),
::testing::ValuesIn(concat_operations_numbers),
::testing::Values(single_consumer::forward::double_transpose::CreateFunction),
::testing::Values(single_consumer::forward::double_transpose::CreateReferenceFunction),
@@ -937,9 +938,9 @@ std::vector<CreateGraphFunctionDesc> backward_subtests = {
#undef SUBTEST
INSTANTIATE_TEST_SUITE_P(TransposeSinkingConcatForwardMultiConsumersTestSuite,
INSTANTIATE_TEST_SUITE_P(TSConcatForwardMultiConsumersTestSuite,
TransposeConcatMultiSinkingFixture,
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingConcatForward)),
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSConcatForward)),
::testing::ValuesIn(concat_operations_numbers),
::testing::ValuesIn(forward_subtests),
::testing::Values(element::f32),
@@ -947,9 +948,9 @@ INSTANTIATE_TEST_SUITE_P(TransposeSinkingConcatForwardMultiConsumersTestSuite,
::testing::Values(5)),
TransposeConcatMultiSinkingFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(TransposeSinkingConcatBackwardMultiConsumersTestSuite,
INSTANTIATE_TEST_SUITE_P(TSConcatBackwardMultiConsumersTestSuite,
TransposeConcatMultiSinkingFixture,
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingConcatBackward)),
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSConcatBackward)),
::testing::ValuesIn(concat_operations_numbers),
::testing::ValuesIn(backward_subtests),
::testing::Values(element::f32),
@@ -1029,9 +1030,9 @@ std::vector<CreateGraphFunctionNoSinkingDesc> backward_subtests_no_sinking = {
#undef SUBTEST
INSTANTIATE_TEST_SUITE_P(TransposeSinkingConcatBackwardMultiConsumersTestSuite,
INSTANTIATE_TEST_SUITE_P(TSConcatBackwardMultiConsumersTestSuite,
TransposeConcatMultiSinkingConcatConsumersFixture,
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingConcatBackward)),
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSConcatBackward)),
::testing::ValuesIn(concat_operations_numbers),
::testing::ValuesIn(backward_subtests_no_sinking),
::testing::Values(element::f32),

View File

@@ -1,26 +1,28 @@
// Copyright (C) 2022 Intel Corporation
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/transpose_sinking/ts_general.hpp"
#include <functional>
#include <openvino/frontend/manager.hpp>
#include <openvino/opsets/opset10.hpp>
#include <openvino/pass/manager.hpp>
#include <transformations/common_optimizations/transpose_sinking_general.hpp>
#include <transformations/init_node_info.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
#include "gtest/gtest.h"
#include "openvino/frontend/manager.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/init_node_info.hpp"
using namespace testing;
using namespace ov::opset10;
using namespace ov::pass::transpose_sinking;
using NodePtr = std::shared_ptr<ov::Node>;
namespace transpose_sinking {
namespace testing {
namespace general {
TEST_F(TransformationTestsF, TransposeSinkingGeneralTestUnariesTransposesForward) {
TEST_F(TransformationTestsF, TSGeneralTestUnariesTransposesForward) {
ov::Shape input_shape = {1, 96, 55, 55};
ov::element::Type input_type = ov::element::f32;
size_t num_unary_ops = 10;
@@ -53,10 +55,10 @@ TEST_F(TransformationTestsF, TransposeSinkingGeneralTestUnariesTransposesForward
function_ref = std::make_shared<ov::Model>(in_op, ov::ParameterVector{X});
}
manager.register_pass<ov::pass::TransposeSinkingGeneralForward>();
manager.register_pass<TSGeneralForward>();
}
TEST_F(TransformationTestsF, TransposeSinkingGeneralTestUnariesTransposesBackward) {
TEST_F(TransformationTestsF, TSGeneralTestUnariesTransposesBackward) {
ov::Shape input_shape = {1, 96, 55, 55};
ov::element::Type input_type = ov::element::f32;
size_t num_unary_ops = 10;
@@ -88,10 +90,10 @@ TEST_F(TransformationTestsF, TransposeSinkingGeneralTestUnariesTransposesBackwar
function_ref = std::make_shared<ov::Model>(in_op, ov::ParameterVector{X});
}
manager.register_pass<ov::pass::TransposeSinkingGeneralBackward>();
manager.register_pass<TSGeneralBackward>();
}
TEST_F(TransformationTestsF, TransposeSinkingGeneralTestUnariesTransposesGeneral) {
TEST_F(TransformationTestsF, TSGeneralTestUnariesTransposesGeneral) {
ov::Shape input_shape = {1, 96, 55, 55};
ov::element::Type input_type = ov::element::f32;
size_t num_unary_ops = 10;
@@ -130,10 +132,10 @@ TEST_F(TransformationTestsF, TransposeSinkingGeneralTestUnariesTransposesGeneral
function_ref = std::make_shared<ov::Model>(transpose0, ov::ParameterVector{X});
}
manager.register_pass<ov::pass::TransposeSinkingGeneral>();
manager.register_pass<TSGeneral>();
}
TEST_F(TransformationTestsF, TransposeSinkingGeneralTestBinaryGeneral) {
TEST_F(TransformationTestsF, TSGeneralTestBinaryGeneral) {
ov::Shape input_shape = {1, 96, 55, 55};
ov::element::Type input_type = ov::element::f32;
size_t num_binary_ops = 10;
@@ -171,10 +173,10 @@ TEST_F(TransformationTestsF, TransposeSinkingGeneralTestBinaryGeneral) {
function_ref = std::make_shared<ov::Model>(transpose0, ov::ParameterVector{X});
}
manager.register_pass<ov::pass::TransposeSinkingGeneral>();
manager.register_pass<TSGeneral>();
}
TEST_F(TransformationTestsF, TransposeSinkingGeneralTestConcatGeneral) {
TEST_F(TransformationTestsF, TSGeneralTestConcatGeneral) {
ov::Shape input_shape = {1, 96, 55, 55};
ov::element::Type input_type = ov::element::f32;
const size_t num_concat_ops = 3;
@@ -224,7 +226,7 @@ TEST_F(TransformationTestsF, TransposeSinkingGeneralTestConcatGeneral) {
function_ref = std::make_shared<ov::Model>(transpose0, ov::ParameterVector{X});
}
manager.register_pass<ov::pass::TransposeSinkingGeneral>();
manager.register_pass<TSGeneral>();
}
// ----------------------------------------------------------------------------------------------------------------------
@@ -364,7 +366,7 @@ NodePtr MakeAllNodesSubgraph(NodePtr parent, size_t split_axis, size_t concat_ax
return in_op;
}
TEST_F(TransformationTestsF, TransposeSinkingGeneralTestMultipleTypes) {
TEST_F(TransformationTestsF, TSGeneralTestMultipleTypes) {
using namespace transpose_sinking::testing::general;
ov::Shape input_shape = {1, 96, 40, 55};
ov::element::Type input_type = ov::element::f32;
@@ -407,7 +409,7 @@ TEST_F(TransformationTestsF, TransposeSinkingGeneralTestMultipleTypes) {
function_ref = std::make_shared<ov::Model>(transpose1, ov::ParameterVector{X});
}
manager.register_pass<ov::pass::TransposeSinkingGeneral>();
manager.register_pass<TSGeneral>();
}
} // namespace general

View File

@@ -1,20 +1,23 @@
// Copyright (C) 2022 Intel Corporation
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/transpose_sinking/ts_split.hpp"
#include <functional>
#include <openvino/frontend/manager.hpp>
#include <openvino/opsets/opset10.hpp>
#include <openvino/pass/manager.hpp>
#include <transformations/common_optimizations/transpose_sinking_split.hpp>
#include <transformations/init_node_info.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
#include "gtest/gtest.h"
#include "transpose_sinking_test_utils.hpp"
#include "openvino/frontend/manager.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/init_node_info.hpp"
#include "ts_test_utils.hpp"
using namespace ov;
using namespace ov::opset10;
using namespace ov::pass::transpose_sinking;
using namespace transpose_sinking::testing::utils;
namespace transpose_sinking {
namespace testing {
@@ -527,9 +530,9 @@ TEST_P(TransposeSinkingSplitTestFixture, CompareFunctions) {
pass_factory->registerPass(manager);
}
INSTANTIATE_TEST_SUITE_P(TransposeSinkingSplitForwardSingleConsumerTestSuite,
INSTANTIATE_TEST_SUITE_P(TSSplitForwardSingleConsumerTestSuite,
TransposeSinkingSplitTestFixture,
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingSplitForward)),
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSSplitForward)),
::testing::ValuesIn(split_operations_numbers),
::testing::ValuesIn(split_outputs_numbers),
::testing::Values(forward::single_consumer::CreateFunction),
@@ -538,9 +541,9 @@ INSTANTIATE_TEST_SUITE_P(TransposeSinkingSplitForwardSingleConsumerTestSuite,
TransposeSinkingSplitTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(
TransposeSinkingSplitForwardMultInputNodeConsumersTestSuite,
TSSplitForwardMultInputNodeConsumersTestSuite,
TransposeSinkingSplitTestFixture,
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingSplitForward)),
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSSplitForward)),
::testing::ValuesIn(split_operations_numbers),
::testing::ValuesIn(split_outputs_numbers),
::testing::Values(forward::mult_consumers::input_node_consumers::CreateFunction),
@@ -549,9 +552,9 @@ INSTANTIATE_TEST_SUITE_P(
TransposeSinkingSplitTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(
TransposeSinkingSplitForwardMultInputTransposeConsumersTestSuite,
TSSplitForwardMultInputTransposeConsumersTestSuite,
TransposeSinkingSplitTestFixture,
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingSplitForward)),
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSSplitForward)),
::testing::ValuesIn(split_operations_numbers),
::testing::ValuesIn(split_outputs_numbers),
::testing::Values(forward::mult_consumers::input_transpose_consumers::CreateFunction),
@@ -560,9 +563,9 @@ INSTANTIATE_TEST_SUITE_P(
TransposeSinkingSplitTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(
TransposeSinkingSplitForwardMultOutputConsumersTestSuite,
TSSplitForwardMultOutputConsumersTestSuite,
TransposeSinkingSplitTestFixture,
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingSplitForward)),
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSSplitForward)),
::testing::ValuesIn(split_operations_numbers),
::testing::ValuesIn(split_outputs_numbers),
::testing::Values(forward::mult_consumers::output_consumers::CreateFunction),
@@ -570,9 +573,9 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Values(element::f32)),
TransposeSinkingSplitTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(TransposeSinkingSplitBackwardTestSuite,
INSTANTIATE_TEST_SUITE_P(TSSplitBackwardTestSuite,
TransposeSinkingSplitTestFixture,
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingSplitBackward)),
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSSplitBackward)),
::testing::ValuesIn(split_tree_depth_nums),
::testing::ValuesIn(split_outputs_numbers),
::testing::Values(backward::single_consumer::CreateFunction),
@@ -580,9 +583,9 @@ INSTANTIATE_TEST_SUITE_P(TransposeSinkingSplitBackwardTestSuite,
::testing::Values(element::f32)),
TransposeSinkingSplitTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(TransposeSinkingSplitBackwardMultOutputConsumersTestSuite,
INSTANTIATE_TEST_SUITE_P(TSSplitBackwardMultOutputConsumersTestSuite,
TransposeSinkingSplitTestFixture,
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingSplitBackward)),
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSSplitBackward)),
::testing::ValuesIn(split_tree_depth_nums),
::testing::ValuesIn(split_outputs_numbers),
::testing::Values(backward::mult_output_consumers::CreateFunction),
@@ -590,9 +593,9 @@ INSTANTIATE_TEST_SUITE_P(TransposeSinkingSplitBackwardMultOutputConsumersTestSui
::testing::Values(element::f32)),
TransposeSinkingSplitTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(TransposeSinkingSplitBackwardMultSplitConsumersTestSuite,
INSTANTIATE_TEST_SUITE_P(TSSplitBackwardMultSplitConsumersTestSuite,
TransposeSinkingSplitTestFixture,
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingSplitBackward)),
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSSplitBackward)),
::testing::ValuesIn(split_tree_depth_nums),
::testing::ValuesIn(split_outputs_numbers),
::testing::Values(backward::mult_split_consumers::CreateFunction),
@@ -764,9 +767,8 @@ using TestSplitBackwardRestrictParams = std::tuple<PassFactoryPtr,
element::Type, /* input type */
TransposeInsertFuncDesc>; /* insert transpose function */
class TransposeSinkingSplitBackwardRestrictTestFixture
: public ::testing::WithParamInterface<TestSplitBackwardRestrictParams>,
public TransformationTestsF {
class TSSplitBackwardRestrictTestFixture : public ::testing::WithParamInterface<TestSplitBackwardRestrictParams>,
public TransformationTestsF {
public:
static std::string get_test_name(const ::testing::TestParamInfo<TestSplitBackwardRestrictParams>& obj) {
PassFactoryPtr pass_factory;
@@ -794,7 +796,7 @@ public:
}
};
TEST_P(TransposeSinkingSplitBackwardRestrictTestFixture, CompareFunctions) {
TEST_P(TSSplitBackwardRestrictTestFixture, CompareFunctions) {
PassFactoryPtr pass_factory;
size_t split_tree_depth;
size_t num_split_outputs;
@@ -821,15 +823,15 @@ std::vector<TransposeInsertFuncDesc> insertTransposeFactories = {FUNC(OnlyFirstT
#undef FUNC
INSTANTIATE_TEST_SUITE_P(TransposeSinkingSplitBackwardRestrictTestSuite,
TransposeSinkingSplitBackwardRestrictTestFixture,
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingSplitBackward)),
INSTANTIATE_TEST_SUITE_P(TSSplitBackwardRestrictTestSuite,
TSSplitBackwardRestrictTestFixture,
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSSplitBackward)),
::testing::Values(1),
::testing::Values(5),
::testing::Values(backward::restrictions::CreateFunction),
::testing::Values(element::f32),
::testing::ValuesIn(insertTransposeFactories)),
TransposeSinkingSplitBackwardRestrictTestFixture::get_test_name);
TSSplitBackwardRestrictTestFixture::get_test_name);
} // namespace restrictions

View File

@@ -2,13 +2,12 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "transpose_sinking_test_utils.hpp"
#include <openvino/frontend/manager.hpp>
#include <openvino/opsets/opset10.hpp>
#include <openvino/pass/manager.hpp>
#include "ts_test_utils.hpp"
#include "gtest/gtest.h"
#include "openvino/frontend/manager.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/manager.hpp"
using namespace std;
using namespace ov;
@@ -16,6 +15,7 @@ using namespace ov::opset10;
namespace transpose_sinking {
namespace testing {
namespace utils {
shared_ptr<Node> create_main_node(const OutputVector& inputs, size_t num_ops, const FactoryPtr& creator) {
OutputVector current_inputs = inputs;
@@ -83,5 +83,6 @@ std::shared_ptr<ov::Node> parameter(ov::element::Type el_type, const PartialShap
return std::make_shared<Parameter>(el_type, ps);
}
} // namespace utils
} // namespace testing
} // namespace transpose_sinking

View File

@@ -4,15 +4,15 @@
#pragma once
#include <openvino/frontend/manager.hpp>
#include <openvino/opsets/opset10.hpp>
#include <openvino/pass/manager.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
#include "gtest/gtest.h"
#include "openvino/frontend/manager.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/manager.hpp"
namespace transpose_sinking {
namespace testing {
namespace utils {
using NodePtr = std::shared_ptr<ov::Node>;
@@ -53,7 +53,7 @@ public:
}
};
using PassFactoryPtr = std::shared_ptr<IPassFactory>;
#define CREATE_PASS_FACTORY(pass_name) std::make_shared<PassFactory<ov::pass::pass_name>>(#pass_name)
#define CREATE_PASS_FACTORY(pass_name) std::make_shared<PassFactory<ov::pass::transpose_sinking::pass_name>>(#pass_name)
std::string to_string(const ov::Shape& shape);
ov::OutputVector set_transpose_for(const std::vector<size_t>& idxs, const ov::OutputVector& out_vec);
@@ -67,5 +67,6 @@ std::shared_ptr<ov::Node> constant(ov::element::Type el_type, const ov::Shape& s
return ov::opset10::Constant::create<T>(el_type, shape, value);
}
} // namespace utils
} // namespace testing
} // namespace transpose_sinking

View File

@@ -2,19 +2,19 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/transpose_sinking_unary.hpp"
#include <openvino/frontend/manager.hpp>
#include <openvino/opsets/opset10.hpp>
#include <openvino/pass/manager.hpp>
#include "transformations/transpose_sinking/ts_unary.hpp"
#include "common_test_utils/ngraph_test_utils.hpp"
#include "gtest/gtest.h"
#include "transpose_sinking_test_utils.hpp"
#include "openvino/frontend/manager.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/manager.hpp"
#include "ts_test_utils.hpp"
using namespace ov;
using namespace ov::opset10;
using namespace transpose_sinking::testing;
using namespace ov::pass::transpose_sinking;
using namespace transpose_sinking::testing::utils;
namespace transpose_sinking {
namespace testing {
@@ -407,7 +407,7 @@ auto wrapper = [](const TestCase& test_case) {
auto test_forward = []() {
TestCase test_case;
test_case.main_node = unary_factories;
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryForward);
test_case.transformation = CREATE_PASS_FACTORY(TSUnaryForward);
test_case.num_main_ops = {1, 10};
test_case.test_model = CreateFunctionTransposeBefore;
test_case.ref_model = CreateFunctionTransposeAfter;
@@ -419,7 +419,7 @@ auto test_forward = []() {
auto test_backward = []() {
TestCase test_case;
test_case.main_node = unary_factories;
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward);
test_case.transformation = CREATE_PASS_FACTORY(TSUnaryBackward);
test_case.num_main_ops = {1, 10};
test_case.test_model = CreateFunctionTransposeAfter;
test_case.ref_model = CreateFunctionTransposeBefore;
@@ -431,7 +431,7 @@ auto test_backward = []() {
auto test_forward_multiple_consumers_reshape = []() {
TestCase test_case;
test_case.main_node = unary_factories;
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryForward);
test_case.transformation = CREATE_PASS_FACTORY(TSUnaryForward);
test_case.num_main_ops = {1, 10};
test_case.test_model = mult_consumers_last_node::with_reshape::CreateFunctionTransposeBefore;
test_case.ref_model = mult_consumers_last_node::with_reshape::CreateFunctionTransposeAfter;
@@ -443,7 +443,7 @@ auto test_forward_multiple_consumers_reshape = []() {
auto test_backward_multiple_consumers_reshape = []() {
TestCase test_case;
test_case.main_node = unary_factories;
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward);
test_case.transformation = CREATE_PASS_FACTORY(TSUnaryBackward);
test_case.num_main_ops = {1, 10};
test_case.test_model = mult_consumers_last_node::with_reshape::CreateFunctionTransposeAfter;
test_case.ref_model = mult_consumers_last_node::with_reshape::CreateFunctionTransposeBefore;
@@ -456,7 +456,7 @@ auto test_backward_multiple_consumers_reshape = []() {
auto test_forward_multiple_consumers_eltwise = []() {
TestCase test_case;
test_case.main_node = unary_factories;
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryForward);
test_case.transformation = CREATE_PASS_FACTORY(TSUnaryForward);
test_case.num_main_ops = {1, 10};
test_case.test_model = mult_consumers_last_node::with_eltwise::CreateFunctionTransposeBefore;
test_case.ref_model = mult_consumers_last_node::with_eltwise::CreateFunctionTransposeAfter;
@@ -468,7 +468,7 @@ auto test_forward_multiple_consumers_eltwise = []() {
auto test_backward_multiple_consumers_eltwise = []() {
TestCase test_case;
test_case.main_node = unary_factories;
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward);
test_case.transformation = CREATE_PASS_FACTORY(TSUnaryBackward);
test_case.num_main_ops = {1, 10};
test_case.test_model = mult_consumers_last_node::with_eltwise::CreateFunctionTransposeAfter;
test_case.ref_model = mult_consumers_last_node::with_eltwise::CreateFunctionTransposeBefore;
@@ -480,7 +480,7 @@ auto test_backward_multiple_consumers_eltwise = []() {
auto test_backward_multiple_consumers_first_node = []() {
TestCase test_case;
test_case.main_node = unary_factories;
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward);
test_case.transformation = CREATE_PASS_FACTORY(TSUnaryBackward);
test_case.num_main_ops = {1, 10};
test_case.test_model = mult_consumers_first_node::backward::CreateFunction;
test_case.ref_model = mult_consumers_first_node::backward::CreateFunction;
@@ -492,7 +492,7 @@ auto test_backward_multiple_consumers_first_node = []() {
auto test_backward_multiple_transposes_first_node = []() {
TestCase test_case;
test_case.main_node = unary_factories;
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward);
test_case.transformation = CREATE_PASS_FACTORY(TSUnaryBackward);
test_case.num_main_ops = {1, 10};
test_case.test_model = mult_consumers_first_node::backward_mult_transposes::CreateFunction;
test_case.ref_model = mult_consumers_first_node::backward_mult_transposes::CreateReferenceFunction;
@@ -504,7 +504,7 @@ auto test_backward_multiple_transposes_first_node = []() {
auto test_forward_multiple_consumers_first_node = []() {
TestCase test_case;
test_case.main_node = unary_factories;
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryForward);
test_case.transformation = CREATE_PASS_FACTORY(TSUnaryForward);
test_case.num_main_ops = {1, 10};
test_case.test_model = mult_consumers_first_node::forward::CreateFunction;
test_case.ref_model = mult_consumers_first_node::forward::CreateReferenceFunction;
@@ -513,47 +513,47 @@ auto test_forward_multiple_consumers_first_node = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryForwardTestSuite,
INSTANTIATE_TEST_SUITE_P(TSUnaryForwardTestSuite,
TransposeSinkingUnaryTestFixture,
transpose_sinking::testing::unary::test_forward(),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryBackwardTestSuite,
INSTANTIATE_TEST_SUITE_P(TSUnaryBackwardTestSuite,
TransposeSinkingUnaryTestFixture,
transpose_sinking::testing::unary::test_backward(),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryForwardMultConsumersTestSuiteLastNodeReshape,
INSTANTIATE_TEST_SUITE_P(TSUnaryForwardMultConsumersTestSuiteLastNodeReshape,
TransposeSinkingUnaryTestFixture,
transpose_sinking::testing::unary::test_forward_multiple_consumers_reshape(),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryBackwardMultConsumersTestSuiteLastNodeReshape,
INSTANTIATE_TEST_SUITE_P(TSUnaryBackwardMultConsumersTestSuiteLastNodeReshape,
TransposeSinkingUnaryTestFixture,
transpose_sinking::testing::unary::test_backward_multiple_consumers_reshape(),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryForwardMultConsumersTestSuiteLastNodeEltwise,
INSTANTIATE_TEST_SUITE_P(TSUnaryForwardMultConsumersTestSuiteLastNodeEltwise,
TransposeSinkingUnaryTestFixture,
transpose_sinking::testing::unary::test_forward_multiple_consumers_eltwise(),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryBackwardMultConsumersTestSuiteEltwise,
INSTANTIATE_TEST_SUITE_P(TSUnaryBackwardMultConsumersTestSuiteEltwise,
TransposeSinkingUnaryTestFixture,
transpose_sinking::testing::unary::test_backward_multiple_consumers_eltwise(),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryBackwardMultConsumersTestSuiteFirstNode,
INSTANTIATE_TEST_SUITE_P(TSUnaryBackwardMultConsumersTestSuiteFirstNode,
TransposeSinkingUnaryTestFixture,
transpose_sinking::testing::unary::test_backward_multiple_consumers_first_node(),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryBackwardMultTransposeConsumersTestSuiteFirstNode,
INSTANTIATE_TEST_SUITE_P(TSUnaryBackwardMultTransposeConsumersTestSuiteFirstNode,
TransposeSinkingUnaryTestFixture,
transpose_sinking::testing::unary::test_backward_multiple_transposes_first_node(),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryForwardMultTransposeConsumersTestSuiteFirstNode,
INSTANTIATE_TEST_SUITE_P(TSUnaryForwardMultTransposeConsumersTestSuiteFirstNode,
TransposeSinkingUnaryTestFixture,
transpose_sinking::testing::unary::test_forward_multiple_consumers_first_node(),
TransposeSinkingUnaryTestFixture::get_test_name);

View File

@@ -21,7 +21,7 @@
#include "so_extension.hpp"
#include "tf_framework_node.hpp"
#include "transformations/common_optimizations/reverse_shape_and_type_infer.hpp"
#include "transformations/common_optimizations/transpose_sinking_general.hpp"
#include "transformations/transpose_sinking/ts_general.hpp"
#include "translate_session.hpp"
#include "utils.hpp"
@@ -239,7 +239,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
manager.run_passes(model);
}
// TODO: TransposeSinkingGeneral can fail on models with Framework nodes (not converted to OV opset)
// TODO: TSGeneral can fail on models with Framework nodes (not converted to OV opset)
auto unsupported_ops = get_unconverted_types_from_model(model);
if (unsupported_ops.size() > 0) {
return;
@@ -248,7 +248,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
{
// perform transpose sinking and reverse infer if the model contains only OpenVINO operations
ov::pass::Manager manager;
manager.register_pass<ov::pass::TransposeSinkingGeneral>();
manager.register_pass<ov::pass::transpose_sinking::TSGeneral>();
manager.register_pass<ov::pass::ReverseShapeAndTypeInfer>();
manager.run_passes(model);
}

View File

@@ -17,7 +17,7 @@
#include "tflite_transformations/rfft2d_complex_abs.h"
#include "tflite_transformations/tflite_quantize_resolver.hpp"
#include "transformations/common_optimizations/transpose_sinking.hpp"
#include "transformations/common_optimizations/transpose_sinking_general.hpp"
#include "transformations/transpose_sinking/ts_general.hpp"
using namespace ov;
using namespace ov::frontend::tensorflow_lite;
@@ -268,7 +268,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& function) const {
manager.register_pass<ov::frontend::tensorflow_lite::pass::TFLQuantizeResolver>();
manager.register_pass<ov::frontend::tensorflow_lite::pass::Rfft2dSimplifier>();
manager.register_pass<ov::pass::TransposeSinking>();
manager.register_pass<ov::pass::TransposeSinkingGeneral>();
manager.register_pass<ov::pass::transpose_sinking::TSGeneral>();
manager.run_passes(function);
}