Merge branch 'itikhono/ts/refactoring' into itikhono/ts/slice
This commit is contained in:
@@ -1,40 +0,0 @@
|
||||
// Copyright (C) 2022-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API TransposeSinkingBinaryForward;
|
||||
class TRANSFORMATIONS_API TransposeSinkingBinaryBackward;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeSinkingBinaryForward transformation sinks Transpose through BinaryElementwiseArithmetic,
|
||||
* BinaryElementwiseComparison, BinaryElementwiseLogical and PRelu operations in the forward direction.
|
||||
*/
|
||||
class ov::pass::TransposeSinkingBinaryForward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TransposeSinkingBinaryForward", "0");
|
||||
TransposeSinkingBinaryForward();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeSinkingBinaryBackward transformation sinks Transpose through BinaryElementwiseArithmetic,
|
||||
* BinaryElementwiseComparison, BinaryElementwiseLogical and PRelu operations in the backward direction.
|
||||
*/
|
||||
class ov::pass::TransposeSinkingBinaryBackward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TransposeSinkingBinaryBackward", "0");
|
||||
TransposeSinkingBinaryBackward();
|
||||
};
|
||||
@@ -1,40 +0,0 @@
|
||||
// Copyright (C) 2022-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API TransposeSinkingConcatForward;
|
||||
class TRANSFORMATIONS_API TransposeSinkingConcatBackward;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeSinkingConcatForward transformation sinks Transpose through Concat operation
|
||||
* in the forward direction.
|
||||
*/
|
||||
class ov::pass::TransposeSinkingConcatForward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TransposeSinkingConcatForward", "0");
|
||||
TransposeSinkingConcatForward();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeSinkingConcatBackward transformation sinks Transpose through Concat operation
|
||||
* in the backward direction.
|
||||
*/
|
||||
class ov::pass::TransposeSinkingConcatBackward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TransposeSinkingConcatBackward", "0");
|
||||
TransposeSinkingConcatBackward();
|
||||
};
|
||||
@@ -1,42 +0,0 @@
|
||||
// Copyright (C) 2022-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API TransposeSinkingDataMovementForward;
|
||||
class TRANSFORMATIONS_API TransposeSinkingDataMovementBackward;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeSinkingDataMovementForward transformation sinks Transpose through BatchToSpace, SpaceToBatch
|
||||
* and Pad operations in the forward direction.
|
||||
* These operations are categorized as "DataMovement" and are handled in a similar way in this transformation.
|
||||
*/
|
||||
class ov::pass::TransposeSinkingDataMovementForward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TransposeSinkingDataMovementForward", "0");
|
||||
TransposeSinkingDataMovementForward();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeSinkingDataMovementBackward transformation sinks Transpose through BatchToSpace, SpaceToBatch
|
||||
* and Pad operations in the backward direction.
|
||||
* These operations are categorized as "DataMovement" and are handled in a similar way in this transformation.
|
||||
*/
|
||||
class ov::pass::TransposeSinkingDataMovementBackward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TransposeSinkingDataMovementBackward", "0");
|
||||
TransposeSinkingDataMovementBackward();
|
||||
};
|
||||
@@ -1,52 +0,0 @@
|
||||
// Copyright (C) 2022-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API TransposeSinkingGeneralForward;
|
||||
class TRANSFORMATIONS_API TransposeSinkingGeneralBackward;
|
||||
class TRANSFORMATIONS_API TransposeSinkingGeneral;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeSinkingGeneralForward transformation combines all TransposeSinkingForward* transformations into
|
||||
* single GraphRewrite pass.
|
||||
*/
|
||||
class ov::pass::TransposeSinkingGeneralForward : public ov::pass::GraphRewrite {
|
||||
public:
|
||||
OPENVINO_RTTI("TransposeSinkingGeneralForward", "0");
|
||||
TransposeSinkingGeneralForward();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeSinkingGeneralBackward transformation combines all TransposeSinkingBackward* transformations into
|
||||
* single GraphRewrite pass.
|
||||
*/
|
||||
class ov::pass::TransposeSinkingGeneralBackward : public ov::pass::GraphRewrite {
|
||||
public:
|
||||
OPENVINO_RTTI("TransposeSinkingGeneralBackward", "0");
|
||||
TransposeSinkingGeneralBackward();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeSinkingGeneral transformation combines TransposeSinkingGeneralForward and
|
||||
* TransposeSinkingGeneralBackward transformations into single ModelPass pass and inserts
|
||||
* ConstantFolding pass after them.
|
||||
*/
|
||||
class ov::pass::TransposeSinkingGeneral : public ov::pass::ModelPass {
|
||||
public:
|
||||
OPENVINO_RTTI("TransposeSinkingGeneral", "0");
|
||||
bool run_on_model(const std::shared_ptr<ov::Model>& m) override;
|
||||
};
|
||||
@@ -1,40 +0,0 @@
|
||||
// Copyright (C) 2022-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API TransposeSinkingInterpolateForward;
|
||||
class TRANSFORMATIONS_API TransposeSinkingInterpolateBackward;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeSinkingInterpolateForward transformation sinks Transpose through Interpolate operation
|
||||
* in the forward direction.
|
||||
*/
|
||||
class ov::pass::TransposeSinkingInterpolateForward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TransposeSinkingInterpolateForward", "0");
|
||||
TransposeSinkingInterpolateForward();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeSinkingInterpolateBackward transformation sinks Transpose through Interpolate operation
|
||||
* in the backward direction.
|
||||
*/
|
||||
class ov::pass::TransposeSinkingInterpolateBackward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TransposeSinkingInterpolateBackward", "0");
|
||||
TransposeSinkingInterpolateBackward();
|
||||
};
|
||||
@@ -1,40 +0,0 @@
|
||||
// Copyright (C) 2022-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API TransposeSinkingSplitBackward;
|
||||
class TRANSFORMATIONS_API TransposeSinkingSplitForward;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeSinkingSplitForward transformation sinks Transpose through Split, VariadicSplit operations
|
||||
* in the forward direction.
|
||||
*/
|
||||
class ov::pass::TransposeSinkingSplitForward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TransposeSinkingSplitForward", "0");
|
||||
TransposeSinkingSplitForward();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeSinkingSplitBackward transformation sinks Transpose through Split, VariadicSplit operations
|
||||
* in the backward direction.
|
||||
*/
|
||||
class ov::pass::TransposeSinkingSplitBackward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TransposeSinkingSplitBackward", "0");
|
||||
TransposeSinkingSplitBackward();
|
||||
};
|
||||
@@ -1,39 +0,0 @@
|
||||
// Copyright (C) 2022-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API TransposeSinkingUnaryForward;
|
||||
class TRANSFORMATIONS_API TransposeSinkingUnaryBackward;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeSinkingUnaryForward transformation sinks Transpose through UnaryElementwiseArithmetic, Clamp, Elu,
|
||||
* SoftPlus, LogicalNot, Convert, IsInf, IsNaN, IsFinite operations in the forward direction.
|
||||
*/
|
||||
class ov::pass::TransposeSinkingUnaryForward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("TransposeSinkingUnaryForward", "0");
|
||||
TransposeSinkingUnaryForward();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeSinkingUnaryBackward transformation sinks Transpose through UnaryElementwiseArithmetic, Clamp, Elu,
|
||||
* SoftPlus, LogicalNot, Convert, IsInf, IsNaN, IsFinite in the backward direction.
|
||||
*/
|
||||
class ov::pass::TransposeSinkingUnaryBackward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("TransposeSinkingUnaryBackwardMultiConsumers", "0");
|
||||
TransposeSinkingUnaryBackward();
|
||||
};
|
||||
@@ -0,0 +1,42 @@
|
||||
// Copyright (C) 2022-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
namespace transpose_sinking {
|
||||
|
||||
class TRANSFORMATIONS_API TSBinaryForward;
|
||||
class TRANSFORMATIONS_API TSBinaryBackward;
|
||||
|
||||
} // namespace transpose_sinking
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TSBinaryForward transformation sinks Transpose through BinaryElementwiseArithmetic,
|
||||
* BinaryElementwiseComparison, BinaryElementwiseLogical and PRelu operations in the forward direction.
|
||||
*/
|
||||
class ov::pass::transpose_sinking::TSBinaryForward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TSBinaryForward", "0");
|
||||
TSBinaryForward();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TSBinaryBackward transformation sinks Transpose through BinaryElementwiseArithmetic,
|
||||
* BinaryElementwiseComparison, BinaryElementwiseLogical and PRelu operations in the backward direction.
|
||||
*/
|
||||
class ov::pass::transpose_sinking::TSBinaryBackward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TSBinaryBackward", "0");
|
||||
TSBinaryBackward();
|
||||
};
|
||||
@@ -0,0 +1,42 @@
|
||||
// Copyright (C) 2022-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
namespace transpose_sinking {
|
||||
|
||||
class TRANSFORMATIONS_API TSConcatForward;
|
||||
class TRANSFORMATIONS_API TSConcatBackward;
|
||||
|
||||
} // namespace transpose_sinking
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TSConcatForward transformation sinks Transpose through Concat operation
|
||||
* in the forward direction.
|
||||
*/
|
||||
class ov::pass::transpose_sinking::TSConcatForward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TSConcatForward", "0");
|
||||
TSConcatForward();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TSConcatBackward transformation sinks Transpose through Concat operation
|
||||
* in the backward direction.
|
||||
*/
|
||||
class ov::pass::transpose_sinking::TSConcatBackward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TSConcatBackward", "0");
|
||||
TSConcatBackward();
|
||||
};
|
||||
@@ -0,0 +1,44 @@
|
||||
// Copyright (C) 2022-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
namespace transpose_sinking {
|
||||
|
||||
class TRANSFORMATIONS_API TSDataMovementForward;
|
||||
class TRANSFORMATIONS_API TSDataMovementBackward;
|
||||
|
||||
} // namespace transpose_sinking
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TSDataMovementForward transformation sinks Transpose through BatchToSpace, SpaceToBatch
|
||||
* and Pad operations in the forward direction.
|
||||
* These operations are categorized as "DataMovement" and are handled in a similar way in this transformation.
|
||||
*/
|
||||
class ov::pass::transpose_sinking::TSDataMovementForward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TSDataMovementForward", "0");
|
||||
TSDataMovementForward();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TSDataMovementBackward transformation sinks Transpose through BatchToSpace, SpaceToBatch
|
||||
* and Pad operations in the backward direction.
|
||||
* These operations are categorized as "DataMovement" and are handled in a similar way in this transformation.
|
||||
*/
|
||||
class ov::pass::transpose_sinking::TSDataMovementBackward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TSDataMovementBackward", "0");
|
||||
TSDataMovementBackward();
|
||||
};
|
||||
@@ -10,19 +10,21 @@
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
namespace transpose_sinking {
|
||||
|
||||
class TRANSFORMATIONS_API TransposeSinkingFuse;
|
||||
class TRANSFORMATIONS_API TSFuse;
|
||||
|
||||
} // namespace transpose_sinking
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeSinkingFuse transformation eliminates 2 consecutive Transposes if they result in no changes to input
|
||||
* @brief TSFuse transformation eliminates 2 consecutive Transposes if they result in no changes to input
|
||||
* or fuses them to single Transpose if input gets changed
|
||||
*/
|
||||
class ov::pass::TransposeSinkingFuse : public ov::pass::MatcherPass {
|
||||
class ov::pass::transpose_sinking::TSFuse : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("TransposeSinkingFuse", "0");
|
||||
TransposeSinkingFuse();
|
||||
OPENVINO_RTTI("TSFuse", "0");
|
||||
TSFuse();
|
||||
};
|
||||
@@ -0,0 +1,54 @@
|
||||
// Copyright (C) 2022-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
namespace transpose_sinking {
|
||||
|
||||
class TRANSFORMATIONS_API TSGeneralForward;
|
||||
class TRANSFORMATIONS_API TSGeneralBackward;
|
||||
class TRANSFORMATIONS_API TSGeneral;
|
||||
|
||||
} // namespace transpose_sinking
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TSGeneralForward transformation combines all TransposeSinkingForward* transformations into
|
||||
* single GraphRewrite pass.
|
||||
*/
|
||||
class ov::pass::transpose_sinking::TSGeneralForward : public ov::pass::GraphRewrite {
|
||||
public:
|
||||
OPENVINO_RTTI("TSGeneralForward", "0");
|
||||
TSGeneralForward();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TSGeneralBackward transformation combines all TransposeSinkingBackward* transformations into
|
||||
* single GraphRewrite pass.
|
||||
*/
|
||||
class ov::pass::transpose_sinking::TSGeneralBackward : public ov::pass::GraphRewrite {
|
||||
public:
|
||||
OPENVINO_RTTI("TSGeneralBackward", "0");
|
||||
TSGeneralBackward();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TSGeneral transformation combines TSGeneralForward and
|
||||
* TSGeneralBackward transformations into single ModelPass pass and inserts
|
||||
* ConstantFolding pass after them.
|
||||
*/
|
||||
class ov::pass::transpose_sinking::TSGeneral : public ov::pass::ModelPass {
|
||||
public:
|
||||
OPENVINO_RTTI("TSGeneral", "0");
|
||||
bool run_on_model(const std::shared_ptr<ov::Model>& m) override;
|
||||
};
|
||||
@@ -0,0 +1,42 @@
|
||||
// Copyright (C) 2022-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
namespace transpose_sinking {
|
||||
|
||||
class TRANSFORMATIONS_API TSInterpolateForward;
|
||||
class TRANSFORMATIONS_API TSInterpolateBackward;
|
||||
|
||||
} // namespace transpose_sinking
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TSInterpolateForward transformation sinks Transpose through Interpolate operation
|
||||
* in the forward direction.
|
||||
*/
|
||||
class ov::pass::transpose_sinking::TSInterpolateForward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TSInterpolateForward", "0");
|
||||
TSInterpolateForward();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TSInterpolateBackward transformation sinks Transpose through Interpolate operation
|
||||
* in the backward direction.
|
||||
*/
|
||||
class ov::pass::transpose_sinking::TSInterpolateBackward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TSInterpolateBackward", "0");
|
||||
TSInterpolateBackward();
|
||||
};
|
||||
@@ -10,10 +10,12 @@
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
namespace transpose_sinking {
|
||||
|
||||
class TRANSFORMATIONS_API TransposeSinkingReductionForward;
|
||||
class TRANSFORMATIONS_API TransposeSinkingReductionBackward;
|
||||
class TRANSFORMATIONS_API TSReductionForward;
|
||||
class TRANSFORMATIONS_API TSReductionBackward;
|
||||
|
||||
} // namespace transpose_sinking
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
@@ -22,10 +24,10 @@ class TRANSFORMATIONS_API TransposeSinkingReductionBackward;
|
||||
* @brief TransposeReductionForward transformation sinks Transpose through Reduce, Squeeze, Unsqueeze operations
|
||||
* in the forward direction.
|
||||
*/
|
||||
class ov::pass::TransposeSinkingReductionForward : public ov::pass::MatcherPass {
|
||||
class ov::pass::transpose_sinking::TSReductionForward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TransposeSinkingReductionForward", "0");
|
||||
TransposeSinkingReductionForward();
|
||||
OPENVINO_RTTI("ov::pass::TSReductionForward", "0");
|
||||
TSReductionForward();
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -33,8 +35,8 @@ public:
|
||||
* @brief TransposeReductionBackward transformation sinks Transpose through Reduce, Squeeze, Unsqueeze operations
|
||||
* in the backward direction.
|
||||
*/
|
||||
class ov::pass::TransposeSinkingReductionBackward : public ov::pass::MatcherPass {
|
||||
class ov::pass::transpose_sinking::TSReductionBackward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TransposeSinkingReductionBackward", "0");
|
||||
TransposeSinkingReductionBackward();
|
||||
OPENVINO_RTTI("ov::pass::TSReductionBackward", "0");
|
||||
TSReductionBackward();
|
||||
};
|
||||
@@ -0,0 +1,42 @@
|
||||
// Copyright (C) 2022-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
namespace transpose_sinking {
|
||||
|
||||
class TRANSFORMATIONS_API TSSplitBackward;
|
||||
class TRANSFORMATIONS_API TSSplitForward;
|
||||
|
||||
} // namespace transpose_sinking
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TSSplitForward transformation sinks Transpose through Split, VariadicSplit operations
|
||||
* in the forward direction.
|
||||
*/
|
||||
class ov::pass::transpose_sinking::TSSplitForward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TSSplitForward", "0");
|
||||
TSSplitForward();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TSSplitBackward transformation sinks Transpose through Split, VariadicSplit operations
|
||||
* in the backward direction.
|
||||
*/
|
||||
class ov::pass::transpose_sinking::TSSplitBackward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TSSplitBackward", "0");
|
||||
TSSplitBackward();
|
||||
};
|
||||
@@ -0,0 +1,41 @@
|
||||
// Copyright (C) 2022-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
namespace transpose_sinking {
|
||||
|
||||
class TRANSFORMATIONS_API TSUnaryForward;
|
||||
class TRANSFORMATIONS_API TSUnaryBackward;
|
||||
|
||||
} // namespace transpose_sinking
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TSUnaryForward transformation sinks Transpose through UnaryElementwiseArithmetic, Clamp, Elu,
|
||||
* SoftPlus, LogicalNot, Convert, IsInf, IsNaN, IsFinite operations in the forward direction.
|
||||
*/
|
||||
class ov::pass::transpose_sinking::TSUnaryForward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("TSUnaryForward", "0");
|
||||
TSUnaryForward();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TSUnaryBackward transformation sinks Transpose through UnaryElementwiseArithmetic, Clamp, Elu,
|
||||
* SoftPlus, LogicalNot, Convert, IsInf, IsNaN, IsFinite in the backward direction.
|
||||
*/
|
||||
class ov::pass::transpose_sinking::TSUnaryBackward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("TSUnaryBackwardMultiConsumers", "0");
|
||||
TSUnaryBackward();
|
||||
};
|
||||
@@ -4,14 +4,17 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <utility>
|
||||
|
||||
#include "openvino/op/util/op_types.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/util/common_util.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
namespace transpose_sinking {
|
||||
namespace utils {
|
||||
|
||||
struct TransposeInputsInfo {
|
||||
std::shared_ptr<ov::opset10::Transpose> transpose;
|
||||
@@ -106,4 +109,7 @@ ov::Output<ov::Node> ChangeValuesOrder(const ov::Output<ov::Node>& input,
|
||||
const ov::AxisVector& transpose_axis_order,
|
||||
const std::shared_ptr<ov::opset10::Constant>& axis);
|
||||
|
||||
} // namespace utils
|
||||
} // namespace transpose_sinking
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
@@ -14,7 +14,6 @@
|
||||
#include <vector>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
using namespace ov;
|
||||
@@ -32,7 +31,7 @@ std::shared_ptr<opset6::Constant> get_reduced_order_constant(const std::shared_p
|
||||
order.erase(order.begin() + i);
|
||||
} else {
|
||||
// if 2nd input for Squeeze op is not provided, we should remove all 1 dims
|
||||
// this case will be supported in new TransposeSinkingGeneral transformation.
|
||||
// this case will be supported in new TSGeneral transformation.
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@@ -318,8 +317,6 @@ ov::pass::TransposeFuse::TransposeFuse() {
|
||||
new_transpose->set_friendly_name(m.get_match_root()->get_friendly_name());
|
||||
ngraph::copy_runtime_info({transpose1, transpose2}, new_transpose);
|
||||
ngraph::replace_node(m.get_match_root(), new_transpose);
|
||||
|
||||
transpose_sinking::UpdateForwardSinkingAbility(new_transpose);
|
||||
}
|
||||
|
||||
return true;
|
||||
|
||||
@@ -1,64 +0,0 @@
|
||||
// Copyright (C) 2022-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/common_optimizations/transpose_sinking_general.hpp"
|
||||
|
||||
#include <openvino/pass/constant_folding.hpp>
|
||||
#include <openvino/pass/graph_rewrite.hpp>
|
||||
#include <openvino/pass/manager.hpp>
|
||||
#include <openvino/pass/pattern/op/wrap_type.hpp>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_binary.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_concat.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_data_movement.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_fuse.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_interpolate.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_reduction.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_split.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_unary.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
ov::pass::TransposeSinkingGeneralForward::TransposeSinkingGeneralForward() {
|
||||
MATCHER_SCOPE(TransposeSinkingGeneralForward);
|
||||
add_matcher<ov::pass::TransposeSinkingUnaryForward>();
|
||||
add_matcher<ov::pass::TransposeSinkingBinaryForward>();
|
||||
add_matcher<ov::pass::TransposeSinkingConcatForward>();
|
||||
add_matcher<ov::pass::TransposeSinkingSplitForward>();
|
||||
add_matcher<ov::pass::TransposeSinkingDataMovementForward>();
|
||||
add_matcher<ov::pass::TransposeSinkingReductionForward>();
|
||||
add_matcher<ov::pass::TransposeSinkingInterpolateForward>();
|
||||
add_matcher<ov::pass::TransposeSinkingFuse>();
|
||||
}
|
||||
|
||||
ov::pass::TransposeSinkingGeneralBackward::TransposeSinkingGeneralBackward() {
|
||||
MATCHER_SCOPE(TransposeSinkingGeneralBackward);
|
||||
add_matcher<ov::pass::TransposeSinkingUnaryBackward>();
|
||||
add_matcher<ov::pass::TransposeSinkingBinaryBackward>();
|
||||
add_matcher<ov::pass::TransposeSinkingConcatBackward>();
|
||||
add_matcher<ov::pass::TransposeSinkingSplitBackward>();
|
||||
add_matcher<ov::pass::TransposeSinkingDataMovementBackward>();
|
||||
add_matcher<ov::pass::TransposeSinkingReductionBackward>();
|
||||
add_matcher<ov::pass::TransposeSinkingInterpolateBackward>();
|
||||
add_matcher<ov::pass::TransposeSinkingFuse>();
|
||||
}
|
||||
|
||||
bool ov::pass::TransposeSinkingGeneral::run_on_model(const std::shared_ptr<ov::Model>& f) {
|
||||
RUN_ON_FUNCTION_SCOPE(TransposeSinkingGeneral);
|
||||
{
|
||||
ov::pass::Manager manager(get_pass_config());
|
||||
manager.register_pass<ov::pass::TransposeSinkingGeneralForward>();
|
||||
manager.register_pass<ov::pass::ConstantFolding>();
|
||||
manager.run_passes(f);
|
||||
}
|
||||
|
||||
{
|
||||
ov::pass::Manager manager(get_pass_config());
|
||||
manager.register_pass<ov::pass::TransposeSinkingGeneralBackward>();
|
||||
manager.register_pass<ov::pass::ConstantFolding>();
|
||||
manager.run_passes(f);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
@@ -2,24 +2,24 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/common_optimizations/transpose_sinking_binary.hpp"
|
||||
|
||||
#include <openvino/opsets/opset10.hpp>
|
||||
#include <openvino/pass/pattern/op/or.hpp>
|
||||
#include "transformations/transpose_sinking/ts_binary.hpp"
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "openvino/op/util/op_types.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/pattern/op/or.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
|
||||
#include "transformations/rt_info/transpose_sinking_attr.hpp"
|
||||
#include "transformations/transpose_sinking/ts_utils.hpp"
|
||||
|
||||
using namespace ov::pass::pattern;
|
||||
using namespace ov;
|
||||
using namespace ov::opset10;
|
||||
using namespace transpose_sinking;
|
||||
using namespace ov::pass::pattern;
|
||||
using namespace ov::pass::transpose_sinking;
|
||||
using namespace ov::pass::transpose_sinking::utils;
|
||||
|
||||
ov::pass::TransposeSinkingBinaryForward::TransposeSinkingBinaryForward() {
|
||||
MATCHER_SCOPE(TransposeSinkingBinaryForward);
|
||||
TSBinaryForward::TSBinaryForward() {
|
||||
MATCHER_SCOPE(TSBinaryForward);
|
||||
|
||||
auto main_node_label = wrap_type<op::util::BinaryElementwiseArithmetic,
|
||||
op::util::BinaryElementwiseComparison,
|
||||
@@ -42,7 +42,7 @@ ov::pass::TransposeSinkingBinaryForward::TransposeSinkingBinaryForward() {
|
||||
main_node->validate_and_infer_types();
|
||||
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
|
||||
register_new_node(new_node);
|
||||
transpose_sinking::UpdateForwardSinkingAbility(new_node);
|
||||
UpdateForwardSinkingAbility(new_node);
|
||||
}
|
||||
return true;
|
||||
};
|
||||
@@ -51,8 +51,8 @@ ov::pass::TransposeSinkingBinaryForward::TransposeSinkingBinaryForward() {
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
||||
|
||||
ov::pass::TransposeSinkingBinaryBackward::TransposeSinkingBinaryBackward() {
|
||||
MATCHER_SCOPE(TransposeSinkingBinaryBackward);
|
||||
TSBinaryBackward::TSBinaryBackward() {
|
||||
MATCHER_SCOPE(TSBinaryBackward);
|
||||
|
||||
auto main_node_label = wrap_type<op::util::BinaryElementwiseArithmetic,
|
||||
op::util::BinaryElementwiseComparison,
|
||||
@@ -2,22 +2,23 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/common_optimizations/transpose_sinking_concat.hpp"
|
||||
#include "transformations/transpose_sinking/ts_concat.hpp"
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "openvino/op/util/op_types.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
|
||||
#include "transformations/rt_info/transpose_sinking_attr.hpp"
|
||||
#include "transformations/transpose_sinking/ts_utils.hpp"
|
||||
|
||||
using namespace ov::pass::pattern;
|
||||
using namespace ov;
|
||||
using namespace ov::opset10;
|
||||
using namespace transpose_sinking;
|
||||
using namespace ov::pass::pattern;
|
||||
using namespace ov::pass::transpose_sinking;
|
||||
using namespace ov::pass::transpose_sinking::utils;
|
||||
|
||||
ov::pass::TransposeSinkingConcatForward::TransposeSinkingConcatForward() {
|
||||
MATCHER_SCOPE(TransposeSinkingConcatForward);
|
||||
TSConcatForward::TSConcatForward() {
|
||||
MATCHER_SCOPE(TSConcatForward);
|
||||
|
||||
auto main_node_label = wrap_type<Concat>(IfNodeHasTransposeInputs);
|
||||
|
||||
@@ -47,7 +48,7 @@ ov::pass::TransposeSinkingConcatForward::TransposeSinkingConcatForward() {
|
||||
main_node->validate_and_infer_types();
|
||||
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
|
||||
register_new_node(new_node);
|
||||
transpose_sinking::UpdateForwardSinkingAbility(new_node);
|
||||
UpdateForwardSinkingAbility(new_node);
|
||||
}
|
||||
|
||||
return true;
|
||||
@@ -57,8 +58,8 @@ ov::pass::TransposeSinkingConcatForward::TransposeSinkingConcatForward() {
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
||||
|
||||
ov::pass::TransposeSinkingConcatBackward::TransposeSinkingConcatBackward() {
|
||||
MATCHER_SCOPE(TransposeSinkingConcatBackward);
|
||||
TSConcatBackward::TSConcatBackward() {
|
||||
MATCHER_SCOPE(TSConcatBackward);
|
||||
|
||||
auto main_node_label = wrap_type<Concat>([](const Output<Node>& output) -> bool {
|
||||
return has_static_rank()(output) && HasSameOutputTransposeNodes(output);
|
||||
@@ -2,25 +2,24 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/common_optimizations/transpose_sinking_data_movement.hpp"
|
||||
|
||||
#include <openvino/pass/pattern/op/or.hpp>
|
||||
#include "transformations/transpose_sinking/ts_data_movement.hpp"
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "openvino/op/util/op_types.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "openvino/util/common_util.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
|
||||
#include "transformations/rt_info/transpose_sinking_attr.hpp"
|
||||
#include "transformations/transpose_sinking/ts_utils.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::opset10;
|
||||
using namespace ov::pass::pattern;
|
||||
using namespace transpose_sinking;
|
||||
using namespace ov::pass::transpose_sinking;
|
||||
using namespace ov::pass::transpose_sinking::utils;
|
||||
|
||||
ov::pass::TransposeSinkingDataMovementForward::TransposeSinkingDataMovementForward() {
|
||||
MATCHER_SCOPE(TransposeSinkingDataMovementForward);
|
||||
TSDataMovementForward::TSDataMovementForward() {
|
||||
MATCHER_SCOPE(TSDataMovementForward);
|
||||
auto const_label = wrap_type<Constant>();
|
||||
auto transpose_label = wrap_type<Transpose>({any_input(), const_label});
|
||||
auto main_node_label =
|
||||
@@ -65,7 +64,7 @@ ov::pass::TransposeSinkingDataMovementForward::TransposeSinkingDataMovementForwa
|
||||
TransposeInputsInfo transpose_input_info = {transpose, transpose_const, 0};
|
||||
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
|
||||
register_new_node(new_node);
|
||||
transpose_sinking::UpdateForwardSinkingAbility(new_node);
|
||||
UpdateForwardSinkingAbility(new_node);
|
||||
}
|
||||
return true;
|
||||
};
|
||||
@@ -74,8 +73,8 @@ ov::pass::TransposeSinkingDataMovementForward::TransposeSinkingDataMovementForwa
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
||||
|
||||
ov::pass::TransposeSinkingDataMovementBackward::TransposeSinkingDataMovementBackward() {
|
||||
MATCHER_SCOPE(TransposeSinkingDataMovementBackward);
|
||||
TSDataMovementBackward::TSDataMovementBackward() {
|
||||
MATCHER_SCOPE(TSDataMovementBackward);
|
||||
|
||||
auto main_node_label = wrap_type<Pad, BatchToSpace, SpaceToBatch>([](const Output<Node>& output) -> bool {
|
||||
return has_static_rank()(output) && HasSameOutputTransposeNodes(output);
|
||||
@@ -2,7 +2,7 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/common_optimizations/transpose_sinking_fuse.hpp"
|
||||
#include "transformations/transpose_sinking/ts_fuse.hpp"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
@@ -11,16 +11,18 @@
|
||||
#include "openvino/core/validation_util.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
|
||||
#include "transformations/transpose_sinking/ts_utils.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace opset10;
|
||||
using namespace ov::pass::transpose_sinking;
|
||||
using namespace ov::pass::transpose_sinking::utils;
|
||||
|
||||
ov::pass::TransposeSinkingFuse::TransposeSinkingFuse() {
|
||||
TSFuse::TSFuse() {
|
||||
MATCHER_SCOPE(TransposeFuse);
|
||||
auto transpose_1_label = pattern::wrap_type<Transpose>({pattern::any_input(), pattern::wrap_type<Constant>()},
|
||||
transpose_sinking::HasSameOutputTransposeNodes);
|
||||
HasSameOutputTransposeNodes);
|
||||
auto transpose_2_label = pattern::wrap_type<Transpose>({transpose_1_label, pattern::wrap_type<Constant>()});
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_map();
|
||||
@@ -62,11 +64,11 @@ ov::pass::TransposeSinkingFuse::TransposeSinkingFuse() {
|
||||
auto new_transpose = register_new_node<Transpose>(input, new_order);
|
||||
|
||||
new_transpose->set_friendly_name(m.get_match_root()->get_friendly_name());
|
||||
transpose_sinking::RemoveSingleOutputConsumers(transpose1);
|
||||
RemoveSingleOutputConsumers(transpose1);
|
||||
copy_runtime_info(transpose1, new_transpose);
|
||||
ov::replace_node(transpose1, new_transpose);
|
||||
|
||||
transpose_sinking::UpdateForwardSinkingAbility(new_transpose);
|
||||
UpdateForwardSinkingAbility(new_transpose);
|
||||
}
|
||||
return true;
|
||||
};
|
||||
@@ -0,0 +1,65 @@
|
||||
// Copyright (C) 2022-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/transpose_sinking/ts_general.hpp"
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "openvino/pass/constant_folding.hpp"
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/manager.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "transformations/transpose_sinking/ts_binary.hpp"
|
||||
#include "transformations/transpose_sinking/ts_concat.hpp"
|
||||
#include "transformations/transpose_sinking/ts_data_movement.hpp"
|
||||
#include "transformations/transpose_sinking/ts_fuse.hpp"
|
||||
#include "transformations/transpose_sinking/ts_interpolate.hpp"
|
||||
#include "transformations/transpose_sinking/ts_reduction.hpp"
|
||||
#include "transformations/transpose_sinking/ts_split.hpp"
|
||||
#include "transformations/transpose_sinking/ts_unary.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
using namespace ov::pass::transpose_sinking;
|
||||
|
||||
TSGeneralForward::TSGeneralForward() {
|
||||
MATCHER_SCOPE(TSGeneralForward);
|
||||
add_matcher<TSUnaryForward>();
|
||||
add_matcher<TSBinaryForward>();
|
||||
add_matcher<TSConcatForward>();
|
||||
add_matcher<TSSplitForward>();
|
||||
add_matcher<TSDataMovementForward>();
|
||||
add_matcher<TSReductionForward>();
|
||||
add_matcher<TSInterpolateForward>();
|
||||
add_matcher<TSFuse>();
|
||||
}
|
||||
|
||||
TSGeneralBackward::TSGeneralBackward() {
|
||||
MATCHER_SCOPE(TSGeneralBackward);
|
||||
add_matcher<TSUnaryBackward>();
|
||||
add_matcher<TSBinaryBackward>();
|
||||
add_matcher<TSConcatBackward>();
|
||||
add_matcher<TSSplitBackward>();
|
||||
add_matcher<TSDataMovementBackward>();
|
||||
add_matcher<TSReductionBackward>();
|
||||
add_matcher<TSInterpolateBackward>();
|
||||
add_matcher<TSFuse>();
|
||||
}
|
||||
|
||||
bool TSGeneral::run_on_model(const std::shared_ptr<ov::Model>& f) {
|
||||
RUN_ON_FUNCTION_SCOPE(TSGeneral);
|
||||
{
|
||||
Manager manager(get_pass_config());
|
||||
manager.register_pass<TSGeneralForward>();
|
||||
manager.register_pass<ConstantFolding>();
|
||||
manager.run_passes(f);
|
||||
}
|
||||
|
||||
{
|
||||
Manager manager(get_pass_config());
|
||||
manager.register_pass<TSGeneralBackward>();
|
||||
manager.register_pass<ConstantFolding>();
|
||||
manager.run_passes(f);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
@@ -2,25 +2,25 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/common_optimizations/transpose_sinking_interpolate.hpp"
|
||||
|
||||
#include <openvino/pass/pattern/op/or.hpp>
|
||||
#include "transformations/transpose_sinking/ts_interpolate.hpp"
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "openvino/op/util/op_types.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/pattern/op/or.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "openvino/util/common_util.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
|
||||
#include "transformations/rt_info/transpose_sinking_attr.hpp"
|
||||
#include "transformations/transpose_sinking/ts_utils.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::opset10;
|
||||
using namespace ov::pass::pattern;
|
||||
using namespace transpose_sinking;
|
||||
using namespace ov::pass::transpose_sinking;
|
||||
using namespace ov::pass::transpose_sinking::utils;
|
||||
|
||||
ov::pass::TransposeSinkingInterpolateForward::TransposeSinkingInterpolateForward() {
|
||||
MATCHER_SCOPE(TransposeSinkingInterpolateForward);
|
||||
TSInterpolateForward::TSInterpolateForward() {
|
||||
MATCHER_SCOPE(TSInterpolateForward);
|
||||
auto const_label = wrap_type<Constant>();
|
||||
auto transpose_label = wrap_type<Transpose>({any_input(), const_label});
|
||||
auto main_node_label = wrap_type<Interpolate>({transpose_label, any_input(), any_input(), any_input()});
|
||||
@@ -74,7 +74,7 @@ ov::pass::TransposeSinkingInterpolateForward::TransposeSinkingInterpolateForward
|
||||
TransposeInputsInfo transpose_input_info = {transpose, transpose_const, 0};
|
||||
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
|
||||
register_new_node(new_node);
|
||||
transpose_sinking::UpdateForwardSinkingAbility(new_node);
|
||||
UpdateForwardSinkingAbility(new_node);
|
||||
}
|
||||
return true;
|
||||
};
|
||||
@@ -83,8 +83,8 @@ ov::pass::TransposeSinkingInterpolateForward::TransposeSinkingInterpolateForward
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
||||
|
||||
ov::pass::TransposeSinkingInterpolateBackward::TransposeSinkingInterpolateBackward() {
|
||||
MATCHER_SCOPE(TransposeSinkingInterpolateBackward);
|
||||
TSInterpolateBackward::TSInterpolateBackward() {
|
||||
MATCHER_SCOPE(TSInterpolateBackward);
|
||||
|
||||
auto main_node_label = wrap_type<Interpolate>([](const Output<Node>& output) -> bool {
|
||||
return has_static_rank()(output) && HasSameOutputTransposeNodes(output);
|
||||
@@ -2,7 +2,7 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/common_optimizations/transpose_sinking_reduction.hpp"
|
||||
#include "transformations/transpose_sinking/ts_reduction.hpp"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
@@ -13,11 +13,13 @@
|
||||
#include "openvino/op/util/logical_reduction_keep_dims.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
|
||||
#include "transformations/transpose_sinking/ts_utils.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace opset10;
|
||||
using namespace ov::pass::transpose_sinking;
|
||||
using namespace ov::pass::transpose_sinking::utils;
|
||||
|
||||
namespace {
|
||||
std::vector<size_t> get_updated_order_forward(const std::vector<size_t>& axes_values,
|
||||
@@ -80,8 +82,8 @@ bool get_keep_dims(const std::shared_ptr<Node>& reduction) {
|
||||
}
|
||||
} // namespace
|
||||
|
||||
ov::pass::TransposeSinkingReductionForward::TransposeSinkingReductionForward() {
|
||||
MATCHER_SCOPE(TransposeSinkingReductionForward);
|
||||
TSReductionForward::TSReductionForward() {
|
||||
MATCHER_SCOPE(TSReductionForward);
|
||||
|
||||
auto transpose_label = pattern::wrap_type<Transpose>({pattern::any_input(), pattern::wrap_type<Constant>()},
|
||||
pattern::consumers_count(1));
|
||||
@@ -150,7 +152,7 @@ ov::pass::TransposeSinkingReductionForward::TransposeSinkingReductionForward() {
|
||||
replace_node(reduction, new_transpose);
|
||||
new_reduction->set_friendly_name(transpose->get_friendly_name());
|
||||
new_transpose->set_friendly_name(reduction->get_friendly_name());
|
||||
transpose_sinking::UpdateForwardSinkingAbility(new_transpose);
|
||||
UpdateForwardSinkingAbility(new_transpose);
|
||||
register_new_node(new_transpose);
|
||||
copy_runtime_info({transpose, reduction}, {new_transpose, new_reduction});
|
||||
return true;
|
||||
@@ -160,13 +162,13 @@ ov::pass::TransposeSinkingReductionForward::TransposeSinkingReductionForward() {
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
||||
|
||||
ov::pass::TransposeSinkingReductionBackward::TransposeSinkingReductionBackward() {
|
||||
MATCHER_SCOPE(TransposeSinkingReductionBackward);
|
||||
TSReductionBackward::TSReductionBackward() {
|
||||
MATCHER_SCOPE(TSReductionBackward);
|
||||
|
||||
auto reduce_or_squeeze_label = pattern::
|
||||
wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims, Squeeze, Unsqueeze>(
|
||||
{pattern::any_input(), pattern::wrap_type<Constant>()},
|
||||
transpose_sinking::HasSameOutputTransposeNodes);
|
||||
HasSameOutputTransposeNodes);
|
||||
auto transpose_label = pattern::wrap_type<Transpose>({reduce_or_squeeze_label, pattern::wrap_type<Constant>()});
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
@@ -225,7 +227,7 @@ ov::pass::TransposeSinkingReductionBackward::TransposeSinkingReductionBackward()
|
||||
}
|
||||
|
||||
if (!unsqueeze) {
|
||||
auto reversed_order_values = transpose_sinking::ReverseTransposeOrder(transpose_order_values);
|
||||
auto reversed_order_values = ReverseTransposeOrder(transpose_order_values);
|
||||
for (const auto& axis : non_negative_axes) {
|
||||
new_values.push_back(reversed_order_values[axis]);
|
||||
}
|
||||
@@ -246,7 +248,7 @@ ov::pass::TransposeSinkingReductionBackward::TransposeSinkingReductionBackward()
|
||||
}
|
||||
replace_node(transpose, new_reduction);
|
||||
copy_runtime_info({transpose, reduction}, {new_transpose, new_reduction});
|
||||
transpose_sinking::UpdateForwardSinkingAbility(new_transpose);
|
||||
UpdateForwardSinkingAbility(new_transpose);
|
||||
new_reduction->set_friendly_name(transpose->get_friendly_name());
|
||||
register_new_node(new_transpose);
|
||||
return true;
|
||||
@@ -2,24 +2,22 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/common_optimizations/transpose_sinking_split.hpp"
|
||||
|
||||
#include <openvino/pass/pattern/op/or.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <utility>
|
||||
#include "transformations/transpose_sinking/ts_split.hpp"
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "openvino/op/util/op_types.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/pattern/op/label.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
|
||||
#include "transformations/rt_info/transpose_sinking_attr.hpp"
|
||||
#include "transformations/transpose_sinking/ts_utils.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
using namespace ov::pass::pattern;
|
||||
using namespace ov;
|
||||
using namespace ov::opset10;
|
||||
using namespace transpose_sinking;
|
||||
using namespace ov::pass::transpose_sinking;
|
||||
using namespace ov::pass::transpose_sinking::utils;
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -107,7 +105,7 @@ bool GetSplitAxis(const std::shared_ptr<Constant>& split_axis, const ov::Rank& r
|
||||
* Consider case Split (1) -> Split (2) -> Transpose
|
||||
* If specify Split as main searched node after first transformation work we will have
|
||||
* Split (1) -> Transpose -> Split(2)
|
||||
* Matcher pass will not call TransposeSinkingSplitBackward since
|
||||
* Matcher pass will not call TSSplitBackward since
|
||||
* - matcher pattern has no Transpose label
|
||||
* - Split (1) has already been proceeded
|
||||
* Adding Split(2) into the working queue as register_new_node(split)
|
||||
@@ -121,8 +119,8 @@ bool GetSplitAxis(const std::shared_ptr<Constant>& split_axis, const ov::Rank& r
|
||||
* - add reversed Transpose operations on all outputs except sinking Transpose
|
||||
* nothing to do with new added output Transposes
|
||||
*/
|
||||
ov::pass::TransposeSinkingSplitBackward::TransposeSinkingSplitBackward() {
|
||||
MATCHER_SCOPE(TransposeSinkingSplitBackward);
|
||||
TSSplitBackward::TSSplitBackward() {
|
||||
MATCHER_SCOPE(TSSplitBackward);
|
||||
|
||||
auto transpose_const_label = wrap_type<Constant>();
|
||||
auto transpose_label = wrap_type<Transpose>({any_input(), transpose_const_label}, IsSplitSinked);
|
||||
@@ -192,8 +190,8 @@ ov::pass::TransposeSinkingSplitBackward::TransposeSinkingSplitBackward() {
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
||||
|
||||
ov::pass::TransposeSinkingSplitForward::TransposeSinkingSplitForward() {
|
||||
MATCHER_SCOPE(TransposeSinkingSplitForward);
|
||||
TSSplitForward::TSSplitForward() {
|
||||
MATCHER_SCOPE(TSSplitForward);
|
||||
|
||||
auto main_node_label = wrap_type<Split, VariadicSplit>(IfNodeHasTransposeInputs);
|
||||
|
||||
@@ -225,7 +223,7 @@ ov::pass::TransposeSinkingSplitForward::TransposeSinkingSplitForward() {
|
||||
|
||||
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
|
||||
register_new_node(new_node);
|
||||
transpose_sinking::UpdateForwardSinkingAbility(new_node);
|
||||
UpdateForwardSinkingAbility(new_node);
|
||||
}
|
||||
|
||||
return true;
|
||||
@@ -2,22 +2,23 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/common_optimizations/transpose_sinking_unary.hpp"
|
||||
#include "transformations/transpose_sinking/ts_unary.hpp"
|
||||
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <utility>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
|
||||
#include "transformations/rt_info/transpose_sinking_attr.hpp"
|
||||
#include "transformations/transpose_sinking/ts_utils.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::opset10;
|
||||
using namespace ov::pass::pattern;
|
||||
using namespace ov::op::util;
|
||||
using namespace transpose_sinking;
|
||||
using namespace ov::pass::transpose_sinking;
|
||||
using namespace ov::pass::transpose_sinking::utils;
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -51,8 +52,8 @@ NodePair SwapNodes(const NodePtr& first_node, const NodePtr& second_node) {
|
||||
|
||||
} // namespace
|
||||
|
||||
ov::pass::TransposeSinkingUnaryForward::TransposeSinkingUnaryForward() {
|
||||
MATCHER_SCOPE(TransposeSinkingUnaryForward);
|
||||
TSUnaryForward::TSUnaryForward() {
|
||||
MATCHER_SCOPE(TSUnaryForward);
|
||||
|
||||
auto transpose_label = wrap_type<Transpose>({any_input(), any_input()});
|
||||
auto unary_label =
|
||||
@@ -73,7 +74,7 @@ ov::pass::TransposeSinkingUnaryForward::TransposeSinkingUnaryForward() {
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<Matcher>(unary_label, "ov::pass::TransposeSinkingUnaryForward");
|
||||
auto m = std::make_shared<Matcher>(unary_label, "ov::pass::TSUnaryForward");
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
||||
|
||||
@@ -83,8 +84,8 @@ bool IfSinkingEnabled(const Output<Node>& output) {
|
||||
}
|
||||
} // namespace
|
||||
|
||||
ov::pass::TransposeSinkingUnaryBackward::TransposeSinkingUnaryBackward() {
|
||||
MATCHER_SCOPE(TransposeSinkingUnaryBackwardMultiConsumers);
|
||||
TSUnaryBackward::TSUnaryBackward() {
|
||||
MATCHER_SCOPE(TSUnaryBackwardMultiConsumers);
|
||||
|
||||
auto unary_restrictions = [](const Output<Node>& output) -> bool {
|
||||
return HasSameOutputTransposeNodes(output);
|
||||
@@ -115,6 +116,6 @@ ov::pass::TransposeSinkingUnaryBackward::TransposeSinkingUnaryBackward() {
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<Matcher>(transpose_label, "ov::pass::TransposeSinkingUnaryBackward");
|
||||
auto m = std::make_shared<Matcher>(transpose_label, "ov::pass::TSUnaryBackward");
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
||||
@@ -2,17 +2,19 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
|
||||
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include "transformations/transpose_sinking/ts_utils.hpp"
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "openvino/op/util/op_types.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/util/common_util.hpp"
|
||||
#include "transformations/rt_info/transpose_sinking_attr.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
namespace transpose_sinking {
|
||||
namespace utils {
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::opset10;
|
||||
@@ -377,4 +379,7 @@ void RemoveSingleOutputConsumers(const NodePtr& node) {
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace utils
|
||||
} // namespace transpose_sinking
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
@@ -1,40 +1,24 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// Copyright (C) 2022-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/transpose_sinking/ts_binary.hpp"
|
||||
|
||||
#include <functional>
|
||||
#include <openvino/frontend/manager.hpp>
|
||||
#include <openvino/opsets/opset9.hpp>
|
||||
#include <openvino/pass/manager.hpp>
|
||||
#include <transformations/common_optimizations/transpose_sinking_binary.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
#include "transpose_sinking_test_utils.hpp"
|
||||
#include "openvino/frontend/manager.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/manager.hpp"
|
||||
#include "ts_test_utils.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::opset10;
|
||||
using namespace transpose_sinking::testing;
|
||||
|
||||
namespace transpose_sinking_binary_eltwise {
|
||||
using namespace ov::pass::transpose_sinking;
|
||||
using namespace transpose_sinking::testing::utils;
|
||||
|
||||
namespace {
|
||||
namespace {
|
||||
std::string to_string(const Shape& shape) {
|
||||
std::ostringstream result;
|
||||
result << "{";
|
||||
for (size_t idx = 0; idx < shape.size(); ++idx) {
|
||||
if (idx)
|
||||
result << ",";
|
||||
result << shape[idx];
|
||||
}
|
||||
result << "}";
|
||||
return result.str();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
template <typename BinaryT>
|
||||
class BinaryFactory : public IFactory {
|
||||
@@ -87,6 +71,10 @@ std::vector<size_t> binary_transpose_input_indexes = {0, 1};
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace transpose_sinking {
|
||||
namespace testing {
|
||||
namespace binary {
|
||||
|
||||
namespace single_consumer {
|
||||
namespace forward {
|
||||
namespace one_input_transpose {
|
||||
@@ -207,7 +195,7 @@ class TransposeSinkingBinaryTwoTransposeInputsTestFixture
|
||||
: public ::testing::WithParamInterface<TestBinaryTwoTransposeInputsParams>,
|
||||
public TransformationTestsF {
|
||||
public:
|
||||
static std::string get_test_name(const testing::TestParamInfo<TestBinaryTwoTransposeInputsParams>& obj) {
|
||||
static std::string get_test_name(const ::testing::TestParamInfo<TestBinaryTwoTransposeInputsParams>& obj) {
|
||||
FactoryPtr binary_factory;
|
||||
PassFactoryPtr pass_factory;
|
||||
size_t num_binary_ops;
|
||||
@@ -247,7 +235,7 @@ TEST_P(TransposeSinkingBinaryTwoTransposeInputsTestFixture, CompareFunctions) {
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingBinaryTwoTransposeInputsForwardTestSuite,
|
||||
TransposeSinkingBinaryTwoTransposeInputsTestFixture,
|
||||
::testing::Combine(::testing::ValuesIn(binary_factories),
|
||||
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryForward)),
|
||||
::testing::Values(CREATE_PASS_FACTORY(TSBinaryForward)),
|
||||
::testing::ValuesIn(binary_operations_numbers),
|
||||
::testing::Values(CreateFunction),
|
||||
::testing::Values(CreateReferenceFunction),
|
||||
@@ -327,7 +315,7 @@ using TestBinaryParams = std::tuple<FactoryPtr,
|
||||
class TransposeSinkingBinaryTestFixture : public ::testing::WithParamInterface<TestBinaryParams>,
|
||||
public TransformationTestsF {
|
||||
public:
|
||||
static std::string get_test_name(const testing::TestParamInfo<TestBinaryParams>& obj) {
|
||||
static std::string get_test_name(const ::testing::TestParamInfo<TestBinaryParams>& obj) {
|
||||
FactoryPtr binary_factory;
|
||||
PassFactoryPtr pass_factory;
|
||||
size_t num_binary_ops;
|
||||
@@ -377,10 +365,10 @@ TEST_P(TransposeSinkingBinaryTestFixture, CompareFunctions) {
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
TransposeSinkingBinaryForwardTestSuite,
|
||||
TSBinaryForwardTestSuite,
|
||||
TransposeSinkingBinaryTestFixture,
|
||||
::testing::Combine(::testing::ValuesIn(binary_factories),
|
||||
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryForward)),
|
||||
::testing::Values(CREATE_PASS_FACTORY(TSBinaryForward)),
|
||||
::testing::ValuesIn(binary_operations_numbers),
|
||||
::testing::Values(single_consumer::forward::one_input_transpose::CreateFunction),
|
||||
::testing::Values(single_consumer::forward::one_input_transpose::CreateReferenceFunction),
|
||||
@@ -389,10 +377,10 @@ INSTANTIATE_TEST_SUITE_P(
|
||||
TransposeSinkingBinaryTestFixture::get_test_name);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
TransposeSinkingBinaryBackwardTestSuite,
|
||||
TSBinaryBackwardTestSuite,
|
||||
TransposeSinkingBinaryTestFixture,
|
||||
::testing::Combine(::testing::ValuesIn(binary_factories),
|
||||
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryBackward)),
|
||||
::testing::Values(CREATE_PASS_FACTORY(TSBinaryBackward)),
|
||||
::testing::ValuesIn(binary_operations_numbers),
|
||||
::testing::Values(single_consumer::backward::one_input_transpose::CreateFunction),
|
||||
::testing::Values(single_consumer::backward::one_input_transpose::CreateReferenceFunction),
|
||||
@@ -421,7 +409,7 @@ class TransposeSinkingBinaryIncompatShapesTestFixture
|
||||
: public ::testing::WithParamInterface<TestBinaryIncompatShapesParams>,
|
||||
public TransformationTestsF {
|
||||
public:
|
||||
static std::string get_test_name(const testing::TestParamInfo<TestBinaryIncompatShapesParams>& obj) {
|
||||
static std::string get_test_name(const ::testing::TestParamInfo<TestBinaryIncompatShapesParams>& obj) {
|
||||
FactoryPtr binary_factory;
|
||||
PassFactoryPtr pass_factory;
|
||||
Shape input_shape;
|
||||
@@ -600,7 +588,7 @@ INSTANTIATE_TEST_SUITE_P(
|
||||
TransposeSinkingBinaryIncompatShapesBackwardTestSuite,
|
||||
TransposeSinkingBinaryIncompatShapesTestFixture,
|
||||
::testing::Combine(::testing::ValuesIn(binary_elementwise_factories),
|
||||
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryBackward)),
|
||||
::testing::Values(CREATE_PASS_FACTORY(TSBinaryBackward)),
|
||||
::testing::Values(Shape{1, 96, 55, 55}),
|
||||
::testing::ValuesIn(binary::single_consumer::backward::incompat_shapes::constant_shapes),
|
||||
::testing::Values(binary::single_consumer::backward::incompat_shapes::CreateFunction),
|
||||
@@ -613,7 +601,7 @@ INSTANTIATE_TEST_SUITE_P(
|
||||
TransposeSinkingBinaryIncompatShapesForwardTestSuite,
|
||||
TransposeSinkingBinaryIncompatShapesTestFixture,
|
||||
::testing::Combine(::testing::ValuesIn(binary_elementwise_factories),
|
||||
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryForward)),
|
||||
::testing::Values(CREATE_PASS_FACTORY(TSBinaryForward)),
|
||||
::testing::Values(Shape{1, 96, 55, 55}),
|
||||
::testing::ValuesIn(binary::single_consumer::forward::incompat_shapes::constant_shapes),
|
||||
::testing::Values(binary::single_consumer::forward::incompat_shapes::CreateFunction),
|
||||
@@ -626,7 +614,7 @@ INSTANTIATE_TEST_SUITE_P(
|
||||
TransposeSinkingPReluIncompatShapesBackwardTestSuite,
|
||||
TransposeSinkingBinaryIncompatShapesTestFixture,
|
||||
::testing::Combine(::testing::Values(CREATE_BINARY_FACTORY(PRelu)),
|
||||
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryBackward)),
|
||||
::testing::Values(CREATE_PASS_FACTORY(TSBinaryBackward)),
|
||||
::testing::Values(Shape{1, 3, 16, 16}),
|
||||
::testing::ValuesIn(std::vector<Shape>{Shape{3}}),
|
||||
::testing::Values(binary::single_consumer::backward::incompat_shapes::CreateFunction),
|
||||
@@ -639,7 +627,7 @@ INSTANTIATE_TEST_SUITE_P(
|
||||
TransposeSinkingPReluIncompatShapesForwardTestSuite,
|
||||
TransposeSinkingBinaryIncompatShapesTestFixture,
|
||||
::testing::Combine(::testing::Values(CREATE_BINARY_FACTORY(PRelu)),
|
||||
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryForward)),
|
||||
::testing::Values(CREATE_PASS_FACTORY(TSBinaryForward)),
|
||||
::testing::Values(Shape{1, 3, 16, 16}),
|
||||
::testing::ValuesIn(std::vector<Shape>{Shape{3}}),
|
||||
::testing::Values(binary::single_consumer::forward::incompat_shapes::CreateFunction),
|
||||
@@ -1090,7 +1078,7 @@ using TestBinaryParams = std::tuple<FactoryPtr,
|
||||
class TransposeBinaryMultiSinkingFixture : public ::testing::WithParamInterface<TestBinaryParams>,
|
||||
public TransformationTestsF {
|
||||
public:
|
||||
static std::string get_test_name(const testing::TestParamInfo<TestBinaryParams>& obj) {
|
||||
static std::string get_test_name(const ::testing::TestParamInfo<TestBinaryParams>& obj) {
|
||||
FactoryPtr binary_factory;
|
||||
PassFactoryPtr pass_factory;
|
||||
CreateGraphFunctionDesc function_desc;
|
||||
@@ -1139,19 +1127,19 @@ std::vector<CreateGraphFunctionDesc> backward_subtests = {
|
||||
|
||||
#undef SUBTEST
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingBinaryForwardMultiConsumersTestSuite,
|
||||
INSTANTIATE_TEST_SUITE_P(TSBinaryForwardMultiConsumersTestSuite,
|
||||
TransposeBinaryMultiSinkingFixture,
|
||||
::testing::Combine(::testing::ValuesIn(binary_factories),
|
||||
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryForward)),
|
||||
::testing::Values(CREATE_PASS_FACTORY(TSBinaryForward)),
|
||||
::testing::ValuesIn(forward_subtests),
|
||||
::testing::Values(element::f32),
|
||||
::testing::ValuesIn(binary_transpose_input_indexes)),
|
||||
TransposeBinaryMultiSinkingFixture::get_test_name);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingBinaryBackwardMultiConsumersTestSuite,
|
||||
INSTANTIATE_TEST_SUITE_P(TSBinaryBackwardMultiConsumersTestSuite,
|
||||
TransposeBinaryMultiSinkingFixture,
|
||||
::testing::Combine(::testing::ValuesIn(binary_factories),
|
||||
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryBackward)),
|
||||
::testing::Values(CREATE_PASS_FACTORY(TSBinaryBackward)),
|
||||
::testing::ValuesIn(backward_subtests),
|
||||
::testing::Values(element::f32),
|
||||
::testing::ValuesIn(binary_transpose_input_indexes)),
|
||||
@@ -1177,7 +1165,7 @@ using TestBinaryParams = std::tuple<FactoryPtr,
|
||||
class TransposeBinaryMultiSinkingBinaryMultiConsumersFixture : public ::testing::WithParamInterface<TestBinaryParams>,
|
||||
public TransformationTestsF {
|
||||
public:
|
||||
static std::string get_test_name(const testing::TestParamInfo<TestBinaryParams>& obj) {
|
||||
static std::string get_test_name(const ::testing::TestParamInfo<TestBinaryParams>& obj) {
|
||||
FactoryPtr binary_factory;
|
||||
PassFactoryPtr pass_factory;
|
||||
CreateGraphFunctionDesc function_desc;
|
||||
@@ -1219,10 +1207,10 @@ std::vector<CreateGraphFunctionDesc> backward_subtests_binary_consumers = {
|
||||
};
|
||||
#undef SUBTEST
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingBinaryBackwardBinaryMultiConsumersTestSuite,
|
||||
INSTANTIATE_TEST_SUITE_P(TSBinaryBackwardBinaryMultiConsumersTestSuite,
|
||||
TransposeBinaryMultiSinkingBinaryMultiConsumersFixture,
|
||||
::testing::Combine(::testing::ValuesIn(binary_factories),
|
||||
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryBackward)),
|
||||
::testing::Values(CREATE_PASS_FACTORY(TSBinaryBackward)),
|
||||
::testing::ValuesIn(backward_subtests_binary_consumers),
|
||||
::testing::Values(element::f32),
|
||||
::testing::ValuesIn(binary_transpose_input_indexes)),
|
||||
@@ -1232,4 +1220,6 @@ INSTANTIATE_TEST_SUITE_P(TransposeSinkingBinaryBackwardBinaryMultiConsumersTestS
|
||||
|
||||
} // namespace mult_consumers
|
||||
|
||||
} // namespace transpose_sinking_binary_eltwise
|
||||
} // namespace binary
|
||||
} // namespace testing
|
||||
} // namespace transpose_sinking
|
||||
@@ -2,27 +2,27 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <openvino/opsets/opset10.hpp>
|
||||
#include <openvino/pass/manager.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
#include "transformations/common_optimizations/transpose_sinking_binary.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_concat.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_data_movement.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_interpolate.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_reduction.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_split.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_unary.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_slice.hpp"
|
||||
#include "transpose_sinking_test_utils.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/manager.hpp"
|
||||
#include "transformations/transpose_sinking/ts_binary.hpp"
|
||||
#include "transformations/transpose_sinking/ts_concat.hpp"
|
||||
#include "transformations/transpose_sinking/ts_data_movement.hpp"
|
||||
#include "transformations/transpose_sinking/ts_interpolate.hpp"
|
||||
#include "transformations/transpose_sinking/ts_reduction.hpp"
|
||||
#include "transformations/transpose_sinking/ts_split.hpp"
|
||||
#include "transformations/transpose_sinking/ts_unary.hpp"
|
||||
#include "ts_test_utils.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ov;
|
||||
using namespace ov::opset10;
|
||||
using namespace transpose_sinking::testing;
|
||||
using namespace ov::pass::transpose_sinking;
|
||||
using namespace transpose_sinking::testing::utils;
|
||||
|
||||
namespace transpose_sinking {
|
||||
namespace testing {
|
||||
namespace common {
|
||||
|
||||
template <typename UnaryT>
|
||||
@@ -378,7 +378,7 @@ auto test_forward_unary = [](const vector<FactoryPtr>& factories, const vector<s
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryForward);
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSUnaryForward);
|
||||
test_case.num_main_ops = num_main_ops;
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {1, 96, 55, 55}),
|
||||
@@ -387,12 +387,12 @@ auto test_forward_unary = [](const vector<FactoryPtr>& factories, const vector<s
|
||||
// Test model description:
|
||||
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.main_op = factories;
|
||||
test_case.model.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model.model_template = create_model;
|
||||
|
||||
// Reference model description:
|
||||
test_case.model_ref.main_op = factories;
|
||||
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model_ref.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model_ref.model_template = create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
@@ -408,7 +408,7 @@ auto test_forward_binary = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingBinaryForward);
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSBinaryForward);
|
||||
test_case.num_main_ops = {1, 10};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {1, 96, 55, 55}),
|
||||
@@ -418,13 +418,13 @@ auto test_forward_binary = []() {
|
||||
// Test model description:
|
||||
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.main_op = binary_factories;
|
||||
test_case.model.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model.model_template = create_model;
|
||||
|
||||
// Reference model description:
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for}, {{1}}};
|
||||
test_case.model_ref.main_op = binary_factories;
|
||||
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model_ref.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model_ref.model_template = create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
@@ -435,7 +435,7 @@ auto test_forward_concat = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingConcatForward);
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSConcatForward);
|
||||
test_case.num_main_ops = {1, 3};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {1, 96, 55, 55}),
|
||||
@@ -446,13 +446,13 @@ auto test_forward_concat = []() {
|
||||
// Test model description:
|
||||
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.main_op = {CREATE_CONCAT_FACTORY(Concat)};
|
||||
test_case.model.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model.model_template = create_model;
|
||||
|
||||
// Reference model description:
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for}, {{1, 2}}};
|
||||
test_case.model_ref.main_op = {CREATE_CONCAT_REF_FACTORY(Concat)};
|
||||
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model_ref.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model_ref.model_template = create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
@@ -463,7 +463,7 @@ auto test_forward_split = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingSplitForward);
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSSplitForward);
|
||||
test_case.num_main_ops = {1, 2};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {1, 9, 55, 55}),
|
||||
@@ -473,7 +473,7 @@ auto test_forward_split = []() {
|
||||
// Test model description:
|
||||
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.main_op = {CREATE_SPLIT_FACTORY(Split)};
|
||||
test_case.model.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model.model_template = create_model;
|
||||
|
||||
// Reference model description:
|
||||
auto new_constant = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||
@@ -486,7 +486,7 @@ auto test_forward_split = []() {
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{new_constant}, {{1}}};
|
||||
test_case.model_ref.main_op = {CREATE_SPLIT_FACTORY(Split)};
|
||||
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0, 1, 2}}};
|
||||
test_case.model_ref.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model_ref.model_template = create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
@@ -497,7 +497,7 @@ auto test_forward_pad = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingDataMovementForward);
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSDataMovementForward);
|
||||
test_case.num_main_ops = {1, 2};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {1, 3, 55, 55}),
|
||||
@@ -508,13 +508,13 @@ auto test_forward_pad = []() {
|
||||
// Test model description:
|
||||
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.main_op = {CREATE_PAD_FACTORY(Pad)};
|
||||
test_case.model.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model.model_template = create_model;
|
||||
|
||||
// Reference model description:
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{set_gather_for}, {{1, 2}}};
|
||||
test_case.model_ref.main_op = {CREATE_PAD_FACTORY(Pad)};
|
||||
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model_ref.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model_ref.model_template = create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
@@ -525,7 +525,7 @@ auto test_forward_batch_to_space = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingDataMovementForward);
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSDataMovementForward);
|
||||
test_case.num_main_ops = {1, 2};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {128, 55, 3, 128}),
|
||||
@@ -537,13 +537,13 @@ auto test_forward_batch_to_space = []() {
|
||||
// Test model description:
|
||||
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.main_op = {CREATE_BATCH_TO_SPACE_FACTORY(BatchToSpace)};
|
||||
test_case.model.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model.model_template = create_model;
|
||||
|
||||
// Reference model description:
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{set_gather_for}, {{1, 2, 3}}};
|
||||
test_case.model_ref.main_op = {CREATE_BATCH_TO_SPACE_FACTORY(BatchToSpace)};
|
||||
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model_ref.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model_ref.model_template = create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
@@ -556,7 +556,7 @@ auto test_forward_space_to_batch = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingDataMovementForward);
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSDataMovementForward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {64, 9, 8, 1}),
|
||||
@@ -568,13 +568,13 @@ auto test_forward_space_to_batch = []() {
|
||||
// Test model description:
|
||||
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.main_op = {CREATE_SPACE_TO_BATCH_FACTORY(SpaceToBatch)};
|
||||
test_case.model.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model.model_template = create_model;
|
||||
|
||||
// Reference model description:
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{set_gather_for}, {{1, 2, 3}}};
|
||||
test_case.model_ref.main_op = {CREATE_SPACE_TO_BATCH_FACTORY(SpaceToBatch)};
|
||||
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model_ref.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model_ref.model_template = create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
@@ -587,7 +587,7 @@ auto test_forward_reduction = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingReductionForward);
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSReductionForward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {32, 4, 2, 1}),
|
||||
@@ -597,7 +597,7 @@ auto test_forward_reduction = []() {
|
||||
// Test model description:
|
||||
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.main_op = reduction_factories;
|
||||
test_case.model.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model.model_template = create_model;
|
||||
|
||||
// Reference model description:
|
||||
auto new_constant = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||
@@ -610,7 +610,7 @@ auto test_forward_reduction = []() {
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{new_constant}, {{1}}};
|
||||
test_case.model_ref.main_op = reduction_factories;
|
||||
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model_ref.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model_ref.model_template = create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
@@ -621,7 +621,7 @@ auto test_forward_interpolate = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingInterpolateForward);
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSInterpolateForward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {1, 2, 48, 80}),
|
||||
@@ -633,7 +633,7 @@ auto test_forward_interpolate = []() {
|
||||
// Test model description:
|
||||
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.main_op = {CREATE_INTERPOLATE_FACTORY(Interpolate, false)};
|
||||
test_case.model.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model.model_template = create_model;
|
||||
|
||||
// Reference model description:
|
||||
auto set_specific_gather_for = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||
@@ -653,7 +653,7 @@ auto test_forward_interpolate = []() {
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{set_specific_gather_for}, {{3}}};
|
||||
test_case.model_ref.main_op = {CREATE_INTERPOLATE_FACTORY(Interpolate, true)};
|
||||
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model_ref.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model_ref.model_template = create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
@@ -666,7 +666,7 @@ auto test_forward_squeeze = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingReductionForward);
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSReductionForward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {32, 1, 2, 1}),
|
||||
@@ -676,7 +676,7 @@ auto test_forward_squeeze = []() {
|
||||
// Test model description:
|
||||
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.main_op = {CREATE_BINARY_FACTORY(Squeeze)};
|
||||
test_case.model.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model.model_template = create_model;
|
||||
|
||||
// Reference model description:
|
||||
auto new_constant = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||
@@ -689,7 +689,7 @@ auto test_forward_squeeze = []() {
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{new_constant}, {{1}}};
|
||||
test_case.model_ref.main_op = {CREATE_BINARY_FACTORY(Squeeze)};
|
||||
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model_ref.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model_ref.model_template = create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
@@ -700,7 +700,7 @@ auto test_forward_unsqueeze = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingReductionForward);
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSReductionForward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {32, 3, 2, 1}),
|
||||
@@ -710,7 +710,7 @@ auto test_forward_unsqueeze = []() {
|
||||
// Test model description:
|
||||
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.main_op = {CREATE_BINARY_FACTORY(Unsqueeze)};
|
||||
test_case.model.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model.model_template = create_model;
|
||||
|
||||
// Reference model description:
|
||||
auto new_constant = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||
@@ -730,7 +730,7 @@ auto test_forward_unsqueeze = []() {
|
||||
return new_out_vec;
|
||||
};
|
||||
test_case.model_ref.preprocess_outputs_of_main = {{new_transpose}, {{0}}};
|
||||
test_case.model_ref.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model_ref.model_template = create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
@@ -786,7 +786,7 @@ auto test_backward_unary = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward);
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSUnaryBackward);
|
||||
test_case.num_main_ops = {1, 10};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {1, 96, 55, 55}),
|
||||
@@ -795,12 +795,12 @@ auto test_backward_unary = []() {
|
||||
// Test model description:
|
||||
test_case.model.main_op = unary_factories;
|
||||
test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model.model_template = create_model;
|
||||
|
||||
// Reference model description:
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model_ref.main_op = unary_factories;
|
||||
test_case.model_ref.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model_ref.model_template = create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
@@ -811,7 +811,7 @@ auto test_backward_binary = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingBinaryBackward);
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSBinaryBackward);
|
||||
test_case.num_main_ops = {1, 10};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {1, 96, 55, 55}),
|
||||
@@ -821,12 +821,12 @@ auto test_backward_binary = []() {
|
||||
// Test model description:
|
||||
test_case.model.main_op = binary_factories;
|
||||
test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model.model_template = create_model;
|
||||
|
||||
// Reference model description:
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for}, {{0, 1}}};
|
||||
test_case.model_ref.main_op = binary_factories;
|
||||
test_case.model_ref.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model_ref.model_template = create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
@@ -837,7 +837,7 @@ auto test_backward_concat = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingConcatBackward);
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSConcatBackward);
|
||||
test_case.num_main_ops = {1, 3};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {1, 96, 55, 55}),
|
||||
@@ -848,12 +848,12 @@ auto test_backward_concat = []() {
|
||||
// Test model description:
|
||||
test_case.model.main_op = {CREATE_CONCAT_FACTORY(Concat)};
|
||||
test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model.model_template = create_model;
|
||||
|
||||
// Reference model description:
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for}, {{0, 1, 2}}};
|
||||
test_case.model_ref.main_op = {CREATE_CONCAT_REF_FACTORY(Concat)};
|
||||
test_case.model_ref.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model_ref.model_template = create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
@@ -864,7 +864,7 @@ auto test_backward_split = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingSplitBackward);
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSSplitBackward);
|
||||
test_case.num_main_ops = {1, 2};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {1, 9, 55, 55}),
|
||||
@@ -874,7 +874,7 @@ auto test_backward_split = []() {
|
||||
// Test model description:
|
||||
test_case.model.main_op = {CREATE_SPLIT_FACTORY(Split)};
|
||||
test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0, 1, 2}}};
|
||||
test_case.model.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model.model_template = create_model;
|
||||
|
||||
// Reference model description:
|
||||
auto new_constant = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||
@@ -886,7 +886,7 @@ auto test_backward_split = []() {
|
||||
};
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, new_constant}, {{0}, {1}}};
|
||||
test_case.model_ref.main_op = {CREATE_SPLIT_FACTORY(Split)};
|
||||
test_case.model_ref.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model_ref.model_template = create_model;
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
@@ -896,7 +896,7 @@ auto test_backward_pad = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingDataMovementBackward);
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSDataMovementBackward);
|
||||
test_case.num_main_ops = {1, 2};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {1, 3, 55, 55}),
|
||||
@@ -907,12 +907,12 @@ auto test_backward_pad = []() {
|
||||
// Test model description:
|
||||
test_case.model.main_op = {CREATE_PAD_FACTORY(Pad)};
|
||||
test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model.model_template = create_model;
|
||||
|
||||
// Reference model description:
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, set_gather_for}, {{0}, {1, 2}}};
|
||||
test_case.model_ref.main_op = {CREATE_PAD_FACTORY(Pad)};
|
||||
test_case.model_ref.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model_ref.model_template = create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
@@ -923,7 +923,7 @@ auto test_backward_batch_to_space = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingDataMovementBackward);
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSDataMovementBackward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {128, 55, 3, 128}),
|
||||
@@ -935,12 +935,12 @@ auto test_backward_batch_to_space = []() {
|
||||
// Reference model description:
|
||||
test_case.model.main_op = {CREATE_BATCH_TO_SPACE_FACTORY(BatchToSpace)};
|
||||
test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model.model_template = create_model;
|
||||
|
||||
// Test model description:
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, set_gather_for}, {{0}, {1, 2, 3}}};
|
||||
test_case.model_ref.main_op = {CREATE_BATCH_TO_SPACE_FACTORY(BatchToSpace)};
|
||||
test_case.model_ref.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model_ref.model_template = create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
@@ -953,7 +953,7 @@ auto test_backward_space_to_batch = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingDataMovementBackward);
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSDataMovementBackward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {1, 8, 9, 64}),
|
||||
@@ -965,12 +965,12 @@ auto test_backward_space_to_batch = []() {
|
||||
// Test model description:
|
||||
test_case.model.main_op = {CREATE_SPACE_TO_BATCH_FACTORY(SpaceToBatch)};
|
||||
test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model.model_template = create_model;
|
||||
|
||||
// Reference model description:
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, set_gather_for}, {{0}, {1, 2, 3}}};
|
||||
test_case.model_ref.main_op = {CREATE_SPACE_TO_BATCH_FACTORY(SpaceToBatch)};
|
||||
test_case.model_ref.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model_ref.model_template = create_model;
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
@@ -982,7 +982,7 @@ auto test_backward_reduction = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingReductionBackward);
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSReductionBackward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {32, 4, 2, 1}),
|
||||
@@ -992,7 +992,7 @@ auto test_backward_reduction = []() {
|
||||
// Test model description:
|
||||
test_case.model.main_op = reduction_factories;
|
||||
test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model.model_template = create_model;
|
||||
|
||||
// Reference model description:
|
||||
auto new_constant = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||
@@ -1004,7 +1004,7 @@ auto test_backward_reduction = []() {
|
||||
};
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, new_constant}, {{0}, {1}}};
|
||||
test_case.model_ref.main_op = reduction_factories;
|
||||
test_case.model_ref.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model_ref.model_template = create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
@@ -1017,7 +1017,7 @@ auto test_backward_interpolate = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingInterpolateBackward);
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSInterpolateBackward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {1, 2, 48, 80}),
|
||||
@@ -1029,7 +1029,7 @@ auto test_backward_interpolate = []() {
|
||||
// Test model description:
|
||||
test_case.model.main_op = {CREATE_INTERPOLATE_FACTORY(Interpolate, true)};
|
||||
test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model.model_template = create_model;
|
||||
|
||||
// Reference model description:
|
||||
auto set_specific_gather_for = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||
@@ -1048,7 +1048,7 @@ auto test_backward_interpolate = []() {
|
||||
};
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, set_specific_gather_for}, {{0}, {3}}};
|
||||
test_case.model_ref.main_op = {CREATE_INTERPOLATE_FACTORY(Interpolate, false)};
|
||||
test_case.model_ref.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model_ref.model_template = create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
@@ -1062,7 +1062,7 @@ auto test_backward_squeeze = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingReductionBackward);
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSReductionBackward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {32, 1, 2, 1}),
|
||||
@@ -1072,7 +1072,7 @@ auto test_backward_squeeze = []() {
|
||||
// Test model description:
|
||||
test_case.model.main_op = {CREATE_BINARY_FACTORY(Squeeze)};
|
||||
test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model.model_template = create_model;
|
||||
|
||||
// Reference model description:
|
||||
auto new_transpose = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||
@@ -1084,7 +1084,7 @@ auto test_backward_squeeze = []() {
|
||||
};
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{new_transpose}, {{0}}};
|
||||
test_case.model_ref.main_op = {CREATE_BINARY_FACTORY(Squeeze)};
|
||||
test_case.model_ref.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model_ref.model_template = create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
@@ -1095,7 +1095,7 @@ auto test_backward_unsqueeze = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingReductionBackward);
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSReductionBackward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {32, 3, 2, 1}),
|
||||
@@ -1105,7 +1105,7 @@ auto test_backward_unsqueeze = []() {
|
||||
// Test model description:
|
||||
test_case.model.main_op = {CREATE_BINARY_FACTORY(Unsqueeze)};
|
||||
test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model.model_template = create_model;
|
||||
|
||||
// Reference model description:
|
||||
auto new_constant = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||
@@ -1117,7 +1117,7 @@ auto test_backward_unsqueeze = []() {
|
||||
};
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, new_constant}, {{0}, {1}}};
|
||||
test_case.model_ref.main_op = {CREATE_BINARY_FACTORY(Unsqueeze)};
|
||||
test_case.model_ref.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model_ref.model_template = create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
@@ -1166,4 +1166,5 @@ auto test_backward_slice = []() {
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSliceBackward, TransposeSinkingTestFixture, test_backward_slice());
|
||||
} // namespace common
|
||||
} // namespace testing
|
||||
} // namespace transpose_sinking
|
||||
@@ -1,27 +1,28 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// Copyright (C) 2022-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/transpose_sinking/ts_concat.hpp"
|
||||
|
||||
#include <functional>
|
||||
#include <openvino/frontend/manager.hpp>
|
||||
#include <openvino/opsets/opset10.hpp>
|
||||
#include <openvino/pass/manager.hpp>
|
||||
#include <transformations/common_optimizations/transpose_sinking_concat.hpp>
|
||||
#include <transformations/common_optimizations/transpose_sinking_utils.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
#include "transpose_sinking_test_utils.hpp"
|
||||
#include "openvino/frontend/manager.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/manager.hpp"
|
||||
#include "transformations/init_node_info.hpp"
|
||||
#include "transformations/transpose_sinking/ts_utils.hpp"
|
||||
#include "ts_test_utils.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::opset10;
|
||||
using namespace transpose_sinking::testing;
|
||||
using namespace ov::pass::transpose_sinking;
|
||||
using namespace transpose_sinking::testing::utils;
|
||||
|
||||
namespace {
|
||||
|
||||
std::vector<size_t> concat_operations_numbers = {1, 10};
|
||||
|
||||
std::vector<size_t> concat_transpose_input_indexes = {0, 2};
|
||||
|
||||
NodePtr CreateConcatChain(NodePtr input_node,
|
||||
@@ -331,9 +332,9 @@ TEST_P(TransposeSinkingConcatTestFixture, CompareFunctions) {
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
TransposeSinkingConcatForwardTestSuite,
|
||||
TSConcatForwardTestSuite,
|
||||
TransposeSinkingConcatTestFixture,
|
||||
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingConcatForward)),
|
||||
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSConcatForward)),
|
||||
::testing::ValuesIn(concat_operations_numbers),
|
||||
::testing::Values(single_consumer::forward::one_input_transpose::CreateFunction),
|
||||
::testing::Values(single_consumer::forward::one_input_transpose::CreateReferenceFunction),
|
||||
@@ -342,9 +343,9 @@ INSTANTIATE_TEST_SUITE_P(
|
||||
::testing::Values(5)),
|
||||
TransposeSinkingConcatTestFixture::get_test_name);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingConcatBackwardTestSuite,
|
||||
INSTANTIATE_TEST_SUITE_P(TSConcatBackwardTestSuite,
|
||||
TransposeSinkingConcatTestFixture,
|
||||
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingConcatBackward)),
|
||||
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSConcatBackward)),
|
||||
::testing::ValuesIn(concat_operations_numbers),
|
||||
::testing::Values(single_consumer::backward::CreateFunction),
|
||||
::testing::Values(single_consumer::backward::CreateReferenceFunction),
|
||||
@@ -408,9 +409,9 @@ TEST_P(TransposeSinkingConcatAllTransposesInputTestFixture, CompareFunctions) {
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
TransposeSinkingConcatForwardAllTransposesTestSuite,
|
||||
TSConcatForwardAllTransposesTestSuite,
|
||||
TransposeSinkingConcatAllTransposesInputTestFixture,
|
||||
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingConcatForward)),
|
||||
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSConcatForward)),
|
||||
::testing::ValuesIn(concat_operations_numbers),
|
||||
::testing::Values(single_consumer::forward::double_transpose::CreateFunction),
|
||||
::testing::Values(single_consumer::forward::double_transpose::CreateReferenceFunction),
|
||||
@@ -937,9 +938,9 @@ std::vector<CreateGraphFunctionDesc> backward_subtests = {
|
||||
|
||||
#undef SUBTEST
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingConcatForwardMultiConsumersTestSuite,
|
||||
INSTANTIATE_TEST_SUITE_P(TSConcatForwardMultiConsumersTestSuite,
|
||||
TransposeConcatMultiSinkingFixture,
|
||||
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingConcatForward)),
|
||||
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSConcatForward)),
|
||||
::testing::ValuesIn(concat_operations_numbers),
|
||||
::testing::ValuesIn(forward_subtests),
|
||||
::testing::Values(element::f32),
|
||||
@@ -947,9 +948,9 @@ INSTANTIATE_TEST_SUITE_P(TransposeSinkingConcatForwardMultiConsumersTestSuite,
|
||||
::testing::Values(5)),
|
||||
TransposeConcatMultiSinkingFixture::get_test_name);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingConcatBackwardMultiConsumersTestSuite,
|
||||
INSTANTIATE_TEST_SUITE_P(TSConcatBackwardMultiConsumersTestSuite,
|
||||
TransposeConcatMultiSinkingFixture,
|
||||
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingConcatBackward)),
|
||||
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSConcatBackward)),
|
||||
::testing::ValuesIn(concat_operations_numbers),
|
||||
::testing::ValuesIn(backward_subtests),
|
||||
::testing::Values(element::f32),
|
||||
@@ -1029,9 +1030,9 @@ std::vector<CreateGraphFunctionNoSinkingDesc> backward_subtests_no_sinking = {
|
||||
|
||||
#undef SUBTEST
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingConcatBackwardMultiConsumersTestSuite,
|
||||
INSTANTIATE_TEST_SUITE_P(TSConcatBackwardMultiConsumersTestSuite,
|
||||
TransposeConcatMultiSinkingConcatConsumersFixture,
|
||||
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingConcatBackward)),
|
||||
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSConcatBackward)),
|
||||
::testing::ValuesIn(concat_operations_numbers),
|
||||
::testing::ValuesIn(backward_subtests_no_sinking),
|
||||
::testing::Values(element::f32),
|
||||
@@ -1,26 +1,28 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// Copyright (C) 2022-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/transpose_sinking/ts_general.hpp"
|
||||
|
||||
#include <functional>
|
||||
#include <openvino/frontend/manager.hpp>
|
||||
#include <openvino/opsets/opset10.hpp>
|
||||
#include <openvino/pass/manager.hpp>
|
||||
#include <transformations/common_optimizations/transpose_sinking_general.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
#include "openvino/frontend/manager.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/manager.hpp"
|
||||
#include "transformations/init_node_info.hpp"
|
||||
|
||||
using namespace testing;
|
||||
using namespace ov::opset10;
|
||||
using namespace ov::pass::transpose_sinking;
|
||||
using NodePtr = std::shared_ptr<ov::Node>;
|
||||
|
||||
namespace transpose_sinking {
|
||||
namespace testing {
|
||||
namespace general {
|
||||
|
||||
TEST_F(TransformationTestsF, TransposeSinkingGeneralTestUnariesTransposesForward) {
|
||||
TEST_F(TransformationTestsF, TSGeneralTestUnariesTransposesForward) {
|
||||
ov::Shape input_shape = {1, 96, 55, 55};
|
||||
ov::element::Type input_type = ov::element::f32;
|
||||
size_t num_unary_ops = 10;
|
||||
@@ -53,10 +55,10 @@ TEST_F(TransformationTestsF, TransposeSinkingGeneralTestUnariesTransposesForward
|
||||
function_ref = std::make_shared<ov::Model>(in_op, ov::ParameterVector{X});
|
||||
}
|
||||
|
||||
manager.register_pass<ov::pass::TransposeSinkingGeneralForward>();
|
||||
manager.register_pass<TSGeneralForward>();
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, TransposeSinkingGeneralTestUnariesTransposesBackward) {
|
||||
TEST_F(TransformationTestsF, TSGeneralTestUnariesTransposesBackward) {
|
||||
ov::Shape input_shape = {1, 96, 55, 55};
|
||||
ov::element::Type input_type = ov::element::f32;
|
||||
size_t num_unary_ops = 10;
|
||||
@@ -88,10 +90,10 @@ TEST_F(TransformationTestsF, TransposeSinkingGeneralTestUnariesTransposesBackwar
|
||||
|
||||
function_ref = std::make_shared<ov::Model>(in_op, ov::ParameterVector{X});
|
||||
}
|
||||
manager.register_pass<ov::pass::TransposeSinkingGeneralBackward>();
|
||||
manager.register_pass<TSGeneralBackward>();
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, TransposeSinkingGeneralTestUnariesTransposesGeneral) {
|
||||
TEST_F(TransformationTestsF, TSGeneralTestUnariesTransposesGeneral) {
|
||||
ov::Shape input_shape = {1, 96, 55, 55};
|
||||
ov::element::Type input_type = ov::element::f32;
|
||||
size_t num_unary_ops = 10;
|
||||
@@ -130,10 +132,10 @@ TEST_F(TransformationTestsF, TransposeSinkingGeneralTestUnariesTransposesGeneral
|
||||
function_ref = std::make_shared<ov::Model>(transpose0, ov::ParameterVector{X});
|
||||
}
|
||||
|
||||
manager.register_pass<ov::pass::TransposeSinkingGeneral>();
|
||||
manager.register_pass<TSGeneral>();
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, TransposeSinkingGeneralTestBinaryGeneral) {
|
||||
TEST_F(TransformationTestsF, TSGeneralTestBinaryGeneral) {
|
||||
ov::Shape input_shape = {1, 96, 55, 55};
|
||||
ov::element::Type input_type = ov::element::f32;
|
||||
size_t num_binary_ops = 10;
|
||||
@@ -171,10 +173,10 @@ TEST_F(TransformationTestsF, TransposeSinkingGeneralTestBinaryGeneral) {
|
||||
function_ref = std::make_shared<ov::Model>(transpose0, ov::ParameterVector{X});
|
||||
}
|
||||
|
||||
manager.register_pass<ov::pass::TransposeSinkingGeneral>();
|
||||
manager.register_pass<TSGeneral>();
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, TransposeSinkingGeneralTestConcatGeneral) {
|
||||
TEST_F(TransformationTestsF, TSGeneralTestConcatGeneral) {
|
||||
ov::Shape input_shape = {1, 96, 55, 55};
|
||||
ov::element::Type input_type = ov::element::f32;
|
||||
const size_t num_concat_ops = 3;
|
||||
@@ -224,7 +226,7 @@ TEST_F(TransformationTestsF, TransposeSinkingGeneralTestConcatGeneral) {
|
||||
function_ref = std::make_shared<ov::Model>(transpose0, ov::ParameterVector{X});
|
||||
}
|
||||
|
||||
manager.register_pass<ov::pass::TransposeSinkingGeneral>();
|
||||
manager.register_pass<TSGeneral>();
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------------------------------------
|
||||
@@ -364,7 +366,7 @@ NodePtr MakeAllNodesSubgraph(NodePtr parent, size_t split_axis, size_t concat_ax
|
||||
return in_op;
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, TransposeSinkingGeneralTestMultipleTypes) {
|
||||
TEST_F(TransformationTestsF, TSGeneralTestMultipleTypes) {
|
||||
using namespace transpose_sinking::testing::general;
|
||||
ov::Shape input_shape = {1, 96, 40, 55};
|
||||
ov::element::Type input_type = ov::element::f32;
|
||||
@@ -407,7 +409,7 @@ TEST_F(TransformationTestsF, TransposeSinkingGeneralTestMultipleTypes) {
|
||||
function_ref = std::make_shared<ov::Model>(transpose1, ov::ParameterVector{X});
|
||||
}
|
||||
|
||||
manager.register_pass<ov::pass::TransposeSinkingGeneral>();
|
||||
manager.register_pass<TSGeneral>();
|
||||
}
|
||||
|
||||
} // namespace general
|
||||
@@ -1,20 +1,23 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// Copyright (C) 2022-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/transpose_sinking/ts_split.hpp"
|
||||
|
||||
#include <functional>
|
||||
#include <openvino/frontend/manager.hpp>
|
||||
#include <openvino/opsets/opset10.hpp>
|
||||
#include <openvino/pass/manager.hpp>
|
||||
#include <transformations/common_optimizations/transpose_sinking_split.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
#include "transpose_sinking_test_utils.hpp"
|
||||
#include "openvino/frontend/manager.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/manager.hpp"
|
||||
#include "transformations/init_node_info.hpp"
|
||||
#include "ts_test_utils.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::opset10;
|
||||
using namespace ov::pass::transpose_sinking;
|
||||
using namespace transpose_sinking::testing::utils;
|
||||
|
||||
namespace transpose_sinking {
|
||||
namespace testing {
|
||||
@@ -527,9 +530,9 @@ TEST_P(TransposeSinkingSplitTestFixture, CompareFunctions) {
|
||||
pass_factory->registerPass(manager);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingSplitForwardSingleConsumerTestSuite,
|
||||
INSTANTIATE_TEST_SUITE_P(TSSplitForwardSingleConsumerTestSuite,
|
||||
TransposeSinkingSplitTestFixture,
|
||||
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingSplitForward)),
|
||||
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSSplitForward)),
|
||||
::testing::ValuesIn(split_operations_numbers),
|
||||
::testing::ValuesIn(split_outputs_numbers),
|
||||
::testing::Values(forward::single_consumer::CreateFunction),
|
||||
@@ -538,9 +541,9 @@ INSTANTIATE_TEST_SUITE_P(TransposeSinkingSplitForwardSingleConsumerTestSuite,
|
||||
TransposeSinkingSplitTestFixture::get_test_name);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
TransposeSinkingSplitForwardMultInputNodeConsumersTestSuite,
|
||||
TSSplitForwardMultInputNodeConsumersTestSuite,
|
||||
TransposeSinkingSplitTestFixture,
|
||||
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingSplitForward)),
|
||||
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSSplitForward)),
|
||||
::testing::ValuesIn(split_operations_numbers),
|
||||
::testing::ValuesIn(split_outputs_numbers),
|
||||
::testing::Values(forward::mult_consumers::input_node_consumers::CreateFunction),
|
||||
@@ -549,9 +552,9 @@ INSTANTIATE_TEST_SUITE_P(
|
||||
TransposeSinkingSplitTestFixture::get_test_name);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
TransposeSinkingSplitForwardMultInputTransposeConsumersTestSuite,
|
||||
TSSplitForwardMultInputTransposeConsumersTestSuite,
|
||||
TransposeSinkingSplitTestFixture,
|
||||
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingSplitForward)),
|
||||
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSSplitForward)),
|
||||
::testing::ValuesIn(split_operations_numbers),
|
||||
::testing::ValuesIn(split_outputs_numbers),
|
||||
::testing::Values(forward::mult_consumers::input_transpose_consumers::CreateFunction),
|
||||
@@ -560,9 +563,9 @@ INSTANTIATE_TEST_SUITE_P(
|
||||
TransposeSinkingSplitTestFixture::get_test_name);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
TransposeSinkingSplitForwardMultOutputConsumersTestSuite,
|
||||
TSSplitForwardMultOutputConsumersTestSuite,
|
||||
TransposeSinkingSplitTestFixture,
|
||||
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingSplitForward)),
|
||||
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSSplitForward)),
|
||||
::testing::ValuesIn(split_operations_numbers),
|
||||
::testing::ValuesIn(split_outputs_numbers),
|
||||
::testing::Values(forward::mult_consumers::output_consumers::CreateFunction),
|
||||
@@ -570,9 +573,9 @@ INSTANTIATE_TEST_SUITE_P(
|
||||
::testing::Values(element::f32)),
|
||||
TransposeSinkingSplitTestFixture::get_test_name);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingSplitBackwardTestSuite,
|
||||
INSTANTIATE_TEST_SUITE_P(TSSplitBackwardTestSuite,
|
||||
TransposeSinkingSplitTestFixture,
|
||||
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingSplitBackward)),
|
||||
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSSplitBackward)),
|
||||
::testing::ValuesIn(split_tree_depth_nums),
|
||||
::testing::ValuesIn(split_outputs_numbers),
|
||||
::testing::Values(backward::single_consumer::CreateFunction),
|
||||
@@ -580,9 +583,9 @@ INSTANTIATE_TEST_SUITE_P(TransposeSinkingSplitBackwardTestSuite,
|
||||
::testing::Values(element::f32)),
|
||||
TransposeSinkingSplitTestFixture::get_test_name);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingSplitBackwardMultOutputConsumersTestSuite,
|
||||
INSTANTIATE_TEST_SUITE_P(TSSplitBackwardMultOutputConsumersTestSuite,
|
||||
TransposeSinkingSplitTestFixture,
|
||||
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingSplitBackward)),
|
||||
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSSplitBackward)),
|
||||
::testing::ValuesIn(split_tree_depth_nums),
|
||||
::testing::ValuesIn(split_outputs_numbers),
|
||||
::testing::Values(backward::mult_output_consumers::CreateFunction),
|
||||
@@ -590,9 +593,9 @@ INSTANTIATE_TEST_SUITE_P(TransposeSinkingSplitBackwardMultOutputConsumersTestSui
|
||||
::testing::Values(element::f32)),
|
||||
TransposeSinkingSplitTestFixture::get_test_name);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingSplitBackwardMultSplitConsumersTestSuite,
|
||||
INSTANTIATE_TEST_SUITE_P(TSSplitBackwardMultSplitConsumersTestSuite,
|
||||
TransposeSinkingSplitTestFixture,
|
||||
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingSplitBackward)),
|
||||
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSSplitBackward)),
|
||||
::testing::ValuesIn(split_tree_depth_nums),
|
||||
::testing::ValuesIn(split_outputs_numbers),
|
||||
::testing::Values(backward::mult_split_consumers::CreateFunction),
|
||||
@@ -764,9 +767,8 @@ using TestSplitBackwardRestrictParams = std::tuple<PassFactoryPtr,
|
||||
element::Type, /* input type */
|
||||
TransposeInsertFuncDesc>; /* insert transpose function */
|
||||
|
||||
class TransposeSinkingSplitBackwardRestrictTestFixture
|
||||
: public ::testing::WithParamInterface<TestSplitBackwardRestrictParams>,
|
||||
public TransformationTestsF {
|
||||
class TSSplitBackwardRestrictTestFixture : public ::testing::WithParamInterface<TestSplitBackwardRestrictParams>,
|
||||
public TransformationTestsF {
|
||||
public:
|
||||
static std::string get_test_name(const ::testing::TestParamInfo<TestSplitBackwardRestrictParams>& obj) {
|
||||
PassFactoryPtr pass_factory;
|
||||
@@ -794,7 +796,7 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(TransposeSinkingSplitBackwardRestrictTestFixture, CompareFunctions) {
|
||||
TEST_P(TSSplitBackwardRestrictTestFixture, CompareFunctions) {
|
||||
PassFactoryPtr pass_factory;
|
||||
size_t split_tree_depth;
|
||||
size_t num_split_outputs;
|
||||
@@ -821,15 +823,15 @@ std::vector<TransposeInsertFuncDesc> insertTransposeFactories = {FUNC(OnlyFirstT
|
||||
|
||||
#undef FUNC
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingSplitBackwardRestrictTestSuite,
|
||||
TransposeSinkingSplitBackwardRestrictTestFixture,
|
||||
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingSplitBackward)),
|
||||
INSTANTIATE_TEST_SUITE_P(TSSplitBackwardRestrictTestSuite,
|
||||
TSSplitBackwardRestrictTestFixture,
|
||||
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSSplitBackward)),
|
||||
::testing::Values(1),
|
||||
::testing::Values(5),
|
||||
::testing::Values(backward::restrictions::CreateFunction),
|
||||
::testing::Values(element::f32),
|
||||
::testing::ValuesIn(insertTransposeFactories)),
|
||||
TransposeSinkingSplitBackwardRestrictTestFixture::get_test_name);
|
||||
TSSplitBackwardRestrictTestFixture::get_test_name);
|
||||
|
||||
} // namespace restrictions
|
||||
|
||||
@@ -2,13 +2,12 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transpose_sinking_test_utils.hpp"
|
||||
|
||||
#include <openvino/frontend/manager.hpp>
|
||||
#include <openvino/opsets/opset10.hpp>
|
||||
#include <openvino/pass/manager.hpp>
|
||||
#include "ts_test_utils.hpp"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "openvino/frontend/manager.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/manager.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ov;
|
||||
@@ -16,6 +15,7 @@ using namespace ov::opset10;
|
||||
|
||||
namespace transpose_sinking {
|
||||
namespace testing {
|
||||
namespace utils {
|
||||
|
||||
shared_ptr<Node> create_main_node(const OutputVector& inputs, size_t num_ops, const FactoryPtr& creator) {
|
||||
OutputVector current_inputs = inputs;
|
||||
@@ -83,5 +83,6 @@ std::shared_ptr<ov::Node> parameter(ov::element::Type el_type, const PartialShap
|
||||
return std::make_shared<Parameter>(el_type, ps);
|
||||
}
|
||||
|
||||
} // namespace utils
|
||||
} // namespace testing
|
||||
} // namespace transpose_sinking
|
||||
@@ -4,15 +4,15 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <openvino/frontend/manager.hpp>
|
||||
#include <openvino/opsets/opset10.hpp>
|
||||
#include <openvino/pass/manager.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
#include "openvino/frontend/manager.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/manager.hpp"
|
||||
|
||||
namespace transpose_sinking {
|
||||
namespace testing {
|
||||
namespace utils {
|
||||
|
||||
using NodePtr = std::shared_ptr<ov::Node>;
|
||||
|
||||
@@ -53,7 +53,7 @@ public:
|
||||
}
|
||||
};
|
||||
using PassFactoryPtr = std::shared_ptr<IPassFactory>;
|
||||
#define CREATE_PASS_FACTORY(pass_name) std::make_shared<PassFactory<ov::pass::pass_name>>(#pass_name)
|
||||
#define CREATE_PASS_FACTORY(pass_name) std::make_shared<PassFactory<ov::pass::transpose_sinking::pass_name>>(#pass_name)
|
||||
|
||||
std::string to_string(const ov::Shape& shape);
|
||||
ov::OutputVector set_transpose_for(const std::vector<size_t>& idxs, const ov::OutputVector& out_vec);
|
||||
@@ -67,5 +67,6 @@ std::shared_ptr<ov::Node> constant(ov::element::Type el_type, const ov::Shape& s
|
||||
return ov::opset10::Constant::create<T>(el_type, shape, value);
|
||||
}
|
||||
|
||||
} // namespace utils
|
||||
} // namespace testing
|
||||
} // namespace transpose_sinking
|
||||
@@ -2,19 +2,19 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/common_optimizations/transpose_sinking_unary.hpp"
|
||||
|
||||
#include <openvino/frontend/manager.hpp>
|
||||
#include <openvino/opsets/opset10.hpp>
|
||||
#include <openvino/pass/manager.hpp>
|
||||
#include "transformations/transpose_sinking/ts_unary.hpp"
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
#include "transpose_sinking_test_utils.hpp"
|
||||
#include "openvino/frontend/manager.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/manager.hpp"
|
||||
#include "ts_test_utils.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::opset10;
|
||||
using namespace transpose_sinking::testing;
|
||||
using namespace ov::pass::transpose_sinking;
|
||||
using namespace transpose_sinking::testing::utils;
|
||||
|
||||
namespace transpose_sinking {
|
||||
namespace testing {
|
||||
@@ -407,7 +407,7 @@ auto wrapper = [](const TestCase& test_case) {
|
||||
auto test_forward = []() {
|
||||
TestCase test_case;
|
||||
test_case.main_node = unary_factories;
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryForward);
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSUnaryForward);
|
||||
test_case.num_main_ops = {1, 10};
|
||||
test_case.test_model = CreateFunctionTransposeBefore;
|
||||
test_case.ref_model = CreateFunctionTransposeAfter;
|
||||
@@ -419,7 +419,7 @@ auto test_forward = []() {
|
||||
auto test_backward = []() {
|
||||
TestCase test_case;
|
||||
test_case.main_node = unary_factories;
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward);
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSUnaryBackward);
|
||||
test_case.num_main_ops = {1, 10};
|
||||
test_case.test_model = CreateFunctionTransposeAfter;
|
||||
test_case.ref_model = CreateFunctionTransposeBefore;
|
||||
@@ -431,7 +431,7 @@ auto test_backward = []() {
|
||||
auto test_forward_multiple_consumers_reshape = []() {
|
||||
TestCase test_case;
|
||||
test_case.main_node = unary_factories;
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryForward);
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSUnaryForward);
|
||||
test_case.num_main_ops = {1, 10};
|
||||
test_case.test_model = mult_consumers_last_node::with_reshape::CreateFunctionTransposeBefore;
|
||||
test_case.ref_model = mult_consumers_last_node::with_reshape::CreateFunctionTransposeAfter;
|
||||
@@ -443,7 +443,7 @@ auto test_forward_multiple_consumers_reshape = []() {
|
||||
auto test_backward_multiple_consumers_reshape = []() {
|
||||
TestCase test_case;
|
||||
test_case.main_node = unary_factories;
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward);
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSUnaryBackward);
|
||||
test_case.num_main_ops = {1, 10};
|
||||
test_case.test_model = mult_consumers_last_node::with_reshape::CreateFunctionTransposeAfter;
|
||||
test_case.ref_model = mult_consumers_last_node::with_reshape::CreateFunctionTransposeBefore;
|
||||
@@ -456,7 +456,7 @@ auto test_backward_multiple_consumers_reshape = []() {
|
||||
auto test_forward_multiple_consumers_eltwise = []() {
|
||||
TestCase test_case;
|
||||
test_case.main_node = unary_factories;
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryForward);
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSUnaryForward);
|
||||
test_case.num_main_ops = {1, 10};
|
||||
test_case.test_model = mult_consumers_last_node::with_eltwise::CreateFunctionTransposeBefore;
|
||||
test_case.ref_model = mult_consumers_last_node::with_eltwise::CreateFunctionTransposeAfter;
|
||||
@@ -468,7 +468,7 @@ auto test_forward_multiple_consumers_eltwise = []() {
|
||||
auto test_backward_multiple_consumers_eltwise = []() {
|
||||
TestCase test_case;
|
||||
test_case.main_node = unary_factories;
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward);
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSUnaryBackward);
|
||||
test_case.num_main_ops = {1, 10};
|
||||
test_case.test_model = mult_consumers_last_node::with_eltwise::CreateFunctionTransposeAfter;
|
||||
test_case.ref_model = mult_consumers_last_node::with_eltwise::CreateFunctionTransposeBefore;
|
||||
@@ -480,7 +480,7 @@ auto test_backward_multiple_consumers_eltwise = []() {
|
||||
auto test_backward_multiple_consumers_first_node = []() {
|
||||
TestCase test_case;
|
||||
test_case.main_node = unary_factories;
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward);
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSUnaryBackward);
|
||||
test_case.num_main_ops = {1, 10};
|
||||
test_case.test_model = mult_consumers_first_node::backward::CreateFunction;
|
||||
test_case.ref_model = mult_consumers_first_node::backward::CreateFunction;
|
||||
@@ -492,7 +492,7 @@ auto test_backward_multiple_consumers_first_node = []() {
|
||||
auto test_backward_multiple_transposes_first_node = []() {
|
||||
TestCase test_case;
|
||||
test_case.main_node = unary_factories;
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward);
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSUnaryBackward);
|
||||
test_case.num_main_ops = {1, 10};
|
||||
test_case.test_model = mult_consumers_first_node::backward_mult_transposes::CreateFunction;
|
||||
test_case.ref_model = mult_consumers_first_node::backward_mult_transposes::CreateReferenceFunction;
|
||||
@@ -504,7 +504,7 @@ auto test_backward_multiple_transposes_first_node = []() {
|
||||
auto test_forward_multiple_consumers_first_node = []() {
|
||||
TestCase test_case;
|
||||
test_case.main_node = unary_factories;
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryForward);
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSUnaryForward);
|
||||
test_case.num_main_ops = {1, 10};
|
||||
test_case.test_model = mult_consumers_first_node::forward::CreateFunction;
|
||||
test_case.ref_model = mult_consumers_first_node::forward::CreateReferenceFunction;
|
||||
@@ -513,47 +513,47 @@ auto test_forward_multiple_consumers_first_node = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryForwardTestSuite,
|
||||
INSTANTIATE_TEST_SUITE_P(TSUnaryForwardTestSuite,
|
||||
TransposeSinkingUnaryTestFixture,
|
||||
transpose_sinking::testing::unary::test_forward(),
|
||||
TransposeSinkingUnaryTestFixture::get_test_name);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryBackwardTestSuite,
|
||||
INSTANTIATE_TEST_SUITE_P(TSUnaryBackwardTestSuite,
|
||||
TransposeSinkingUnaryTestFixture,
|
||||
transpose_sinking::testing::unary::test_backward(),
|
||||
TransposeSinkingUnaryTestFixture::get_test_name);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryForwardMultConsumersTestSuiteLastNodeReshape,
|
||||
INSTANTIATE_TEST_SUITE_P(TSUnaryForwardMultConsumersTestSuiteLastNodeReshape,
|
||||
TransposeSinkingUnaryTestFixture,
|
||||
transpose_sinking::testing::unary::test_forward_multiple_consumers_reshape(),
|
||||
TransposeSinkingUnaryTestFixture::get_test_name);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryBackwardMultConsumersTestSuiteLastNodeReshape,
|
||||
INSTANTIATE_TEST_SUITE_P(TSUnaryBackwardMultConsumersTestSuiteLastNodeReshape,
|
||||
TransposeSinkingUnaryTestFixture,
|
||||
transpose_sinking::testing::unary::test_backward_multiple_consumers_reshape(),
|
||||
TransposeSinkingUnaryTestFixture::get_test_name);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryForwardMultConsumersTestSuiteLastNodeEltwise,
|
||||
INSTANTIATE_TEST_SUITE_P(TSUnaryForwardMultConsumersTestSuiteLastNodeEltwise,
|
||||
TransposeSinkingUnaryTestFixture,
|
||||
transpose_sinking::testing::unary::test_forward_multiple_consumers_eltwise(),
|
||||
TransposeSinkingUnaryTestFixture::get_test_name);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryBackwardMultConsumersTestSuiteEltwise,
|
||||
INSTANTIATE_TEST_SUITE_P(TSUnaryBackwardMultConsumersTestSuiteEltwise,
|
||||
TransposeSinkingUnaryTestFixture,
|
||||
transpose_sinking::testing::unary::test_backward_multiple_consumers_eltwise(),
|
||||
TransposeSinkingUnaryTestFixture::get_test_name);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryBackwardMultConsumersTestSuiteFirstNode,
|
||||
INSTANTIATE_TEST_SUITE_P(TSUnaryBackwardMultConsumersTestSuiteFirstNode,
|
||||
TransposeSinkingUnaryTestFixture,
|
||||
transpose_sinking::testing::unary::test_backward_multiple_consumers_first_node(),
|
||||
TransposeSinkingUnaryTestFixture::get_test_name);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryBackwardMultTransposeConsumersTestSuiteFirstNode,
|
||||
INSTANTIATE_TEST_SUITE_P(TSUnaryBackwardMultTransposeConsumersTestSuiteFirstNode,
|
||||
TransposeSinkingUnaryTestFixture,
|
||||
transpose_sinking::testing::unary::test_backward_multiple_transposes_first_node(),
|
||||
TransposeSinkingUnaryTestFixture::get_test_name);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryForwardMultTransposeConsumersTestSuiteFirstNode,
|
||||
INSTANTIATE_TEST_SUITE_P(TSUnaryForwardMultTransposeConsumersTestSuiteFirstNode,
|
||||
TransposeSinkingUnaryTestFixture,
|
||||
transpose_sinking::testing::unary::test_forward_multiple_consumers_first_node(),
|
||||
TransposeSinkingUnaryTestFixture::get_test_name);
|
||||
@@ -21,7 +21,7 @@
|
||||
#include "so_extension.hpp"
|
||||
#include "tf_framework_node.hpp"
|
||||
#include "transformations/common_optimizations/reverse_shape_and_type_infer.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_general.hpp"
|
||||
#include "transformations/transpose_sinking/ts_general.hpp"
|
||||
#include "translate_session.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
@@ -239,7 +239,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
|
||||
manager.run_passes(model);
|
||||
}
|
||||
|
||||
// TODO: TransposeSinkingGeneral can fail on models with Framework nodes (not converted to OV opset)
|
||||
// TODO: TSGeneral can fail on models with Framework nodes (not converted to OV opset)
|
||||
auto unsupported_ops = get_unconverted_types_from_model(model);
|
||||
if (unsupported_ops.size() > 0) {
|
||||
return;
|
||||
@@ -248,7 +248,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
|
||||
{
|
||||
// perform transpose sinking and reverse infer if the model contains only OpenVINO operations
|
||||
ov::pass::Manager manager;
|
||||
manager.register_pass<ov::pass::TransposeSinkingGeneral>();
|
||||
manager.register_pass<ov::pass::transpose_sinking::TSGeneral>();
|
||||
manager.register_pass<ov::pass::ReverseShapeAndTypeInfer>();
|
||||
manager.run_passes(model);
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
#include "tflite_transformations/rfft2d_complex_abs.h"
|
||||
#include "tflite_transformations/tflite_quantize_resolver.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_general.hpp"
|
||||
#include "transformations/transpose_sinking/ts_general.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::frontend::tensorflow_lite;
|
||||
@@ -268,7 +268,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& function) const {
|
||||
manager.register_pass<ov::frontend::tensorflow_lite::pass::TFLQuantizeResolver>();
|
||||
manager.register_pass<ov::frontend::tensorflow_lite::pass::Rfft2dSimplifier>();
|
||||
manager.register_pass<ov::pass::TransposeSinking>();
|
||||
manager.register_pass<ov::pass::TransposeSinkingGeneral>();
|
||||
manager.register_pass<ov::pass::transpose_sinking::TSGeneral>();
|
||||
manager.run_passes(function);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user