codestye
This commit is contained in:
parent
3c5f62c013
commit
176686318f
@ -18,8 +18,8 @@ class TRANSFORMATIONS_API TransposeSinkingFuse;
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeSinkingFuse transformation eliminates 2 consecutive Transposes if they result in no changes to input or
|
||||
* fuses them to single Transpose if input gets changed
|
||||
* @brief TransposeSinkingFuse transformation eliminates 2 consecutive Transposes if they result in no changes to input
|
||||
* or fuses them to single Transpose if input gets changed
|
||||
*/
|
||||
class ov::pass::TransposeSinkingFuse : public ov::pass::MatcherPass {
|
||||
public:
|
||||
|
@ -13,11 +13,11 @@
|
||||
#include "transformations/common_optimizations/transpose_sinking_binary.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_concat.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_data_movement.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_fuse.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_interpolate.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_reduction.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_split.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_unary.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_fuse.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_reduction.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
ov::pass::TransposeSinkingGeneralForward::TransposeSinkingGeneralForward() {
|
||||
|
@ -9,9 +9,9 @@
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "openvino/core/validation_util.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/util/arithmetic_reductions_keep_dims.hpp"
|
||||
#include "openvino/op/util/logical_reduction_keep_dims.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
@ -23,8 +23,7 @@ ov::pass::TransposeSinkingStridedSliceForward::TransposeSinkingStridedSliceForwa
|
||||
MATCHER_SCOPE(TransposeSinkingStridedSliceForward);
|
||||
auto const_label = wrap_type<Constant>();
|
||||
auto transpose_label = wrap_type<Transpose>({any_input(), const_label});
|
||||
auto main_node_label =
|
||||
wrap_type<StridedSlice>({transpose_label, any_input(), any_input(), any_input()});
|
||||
auto main_node_label = wrap_type<StridedSlice>({transpose_label, any_input(), any_input(), any_input()});
|
||||
|
||||
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||
const auto& pattern_to_node = m.get_pattern_map();
|
||||
|
@ -65,25 +65,25 @@ FactoryPtr CreateBinaryFactory(const std::string& type_name) {
|
||||
* Unsqueeze insertion
|
||||
*/
|
||||
std::vector<FactoryPtr> binary_elementwise_factories = {CREATE_BINARY_FACTORY(Add),
|
||||
CREATE_BINARY_FACTORY(Divide),
|
||||
CREATE_BINARY_FACTORY(Maximum),
|
||||
CREATE_BINARY_FACTORY(Minimum),
|
||||
CREATE_BINARY_FACTORY(Mod),
|
||||
CREATE_BINARY_FACTORY(Multiply),
|
||||
CREATE_BINARY_FACTORY(Power),
|
||||
CREATE_BINARY_FACTORY(SquaredDifference),
|
||||
CREATE_BINARY_FACTORY(Subtract)};
|
||||
CREATE_BINARY_FACTORY(Divide),
|
||||
CREATE_BINARY_FACTORY(Maximum),
|
||||
CREATE_BINARY_FACTORY(Minimum),
|
||||
CREATE_BINARY_FACTORY(Mod),
|
||||
CREATE_BINARY_FACTORY(Multiply),
|
||||
CREATE_BINARY_FACTORY(Power),
|
||||
CREATE_BINARY_FACTORY(SquaredDifference),
|
||||
CREATE_BINARY_FACTORY(Subtract)};
|
||||
|
||||
std::vector<FactoryPtr> binary_factories = {CREATE_BINARY_FACTORY(Add),
|
||||
CREATE_BINARY_FACTORY(Divide),
|
||||
CREATE_BINARY_FACTORY(Maximum),
|
||||
CREATE_BINARY_FACTORY(Minimum),
|
||||
CREATE_BINARY_FACTORY(Mod),
|
||||
CREATE_BINARY_FACTORY(Multiply),
|
||||
CREATE_BINARY_FACTORY(Power),
|
||||
CREATE_BINARY_FACTORY(SquaredDifference),
|
||||
CREATE_BINARY_FACTORY(Subtract),
|
||||
CREATE_BINARY_FACTORY(PRelu)};
|
||||
CREATE_BINARY_FACTORY(Divide),
|
||||
CREATE_BINARY_FACTORY(Maximum),
|
||||
CREATE_BINARY_FACTORY(Minimum),
|
||||
CREATE_BINARY_FACTORY(Mod),
|
||||
CREATE_BINARY_FACTORY(Multiply),
|
||||
CREATE_BINARY_FACTORY(Power),
|
||||
CREATE_BINARY_FACTORY(SquaredDifference),
|
||||
CREATE_BINARY_FACTORY(Subtract),
|
||||
CREATE_BINARY_FACTORY(PRelu)};
|
||||
|
||||
std::vector<size_t> binary_operations_numbers = {1, 10};
|
||||
|
||||
@ -150,9 +150,7 @@ std::shared_ptr<Model> CreateReferenceFunction(FactoryPtr binary_factory,
|
||||
} // namespace one_input_transpose
|
||||
|
||||
namespace double_transpose {
|
||||
std::shared_ptr<Model> CreateFunction(FactoryPtr binary_factory,
|
||||
size_t num_binary_ops,
|
||||
element::Type input_type) {
|
||||
std::shared_ptr<Model> CreateFunction(FactoryPtr binary_factory, size_t num_binary_ops, element::Type input_type) {
|
||||
const Shape input_shape{1, 96, 55, 55};
|
||||
|
||||
auto X = std::make_shared<Parameter>(input_type, input_shape);
|
||||
@ -198,8 +196,8 @@ std::shared_ptr<Model> CreateReferenceFunction(FactoryPtr binary_factory,
|
||||
return std::make_shared<Model>(ov::OutputVector{transpose0}, ov::ParameterVector{X});
|
||||
}
|
||||
|
||||
using CreateGraphBinaryTwoTransposeInputsF = std::function<
|
||||
std::shared_ptr<Model>(FactoryPtr binary_factory, size_t num_binary_ops, element::Type input_type)>;
|
||||
using CreateGraphBinaryTwoTransposeInputsF =
|
||||
std::function<std::shared_ptr<Model>(FactoryPtr binary_factory, size_t num_binary_ops, element::Type input_type)>;
|
||||
|
||||
using TestBinaryTwoTransposeInputsParams =
|
||||
std::tuple<FactoryPtr,
|
||||
@ -1071,9 +1069,8 @@ std::shared_ptr<Model> CreateReferenceFunction(FactoryPtr binary_factory,
|
||||
|
||||
} // namespace backward
|
||||
|
||||
using CreateGraphF = std::function<std::shared_ptr<Model>(FactoryPtr binary_factory,
|
||||
element::Type input_type,
|
||||
size_t binary_transpose_input_idx)>;
|
||||
using CreateGraphF = std::function<
|
||||
std::shared_ptr<Model>(FactoryPtr binary_factory, element::Type input_type, size_t binary_transpose_input_idx)>;
|
||||
|
||||
struct CreateGraphFunctionDesc {
|
||||
CreateGraphFunctionDesc() = default;
|
||||
|
@ -2,19 +2,18 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/common_optimizations/transpose_sinking_unary.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_binary.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_concat.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_split.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_data_movement.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_reduction.hpp"
|
||||
#include "transpose_sinking_test_utils.hpp"
|
||||
|
||||
#include <openvino/opsets/opset10.hpp>
|
||||
#include <openvino/pass/manager.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
#include "transformations/common_optimizations/transpose_sinking_binary.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_concat.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_data_movement.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_reduction.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_split.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_unary.hpp"
|
||||
#include "transpose_sinking_test_utils.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ov;
|
||||
@ -89,7 +88,6 @@ FactoryPtr CreateConcatRefFactory(const std::string& type_name) {
|
||||
return std::make_shared<ConcatFactoryRef>(type_name);
|
||||
}
|
||||
|
||||
|
||||
class SplitFactory : public IFactory {
|
||||
public:
|
||||
explicit SplitFactory(const std::string& type_name) : IFactory(type_name) {}
|
||||
@ -194,7 +192,8 @@ struct Preprocessing {
|
||||
struct TestCase;
|
||||
struct ModelDescription;
|
||||
using TestParams = tuple<size_t /* idx num_main_ops */, size_t /* idx main_op */, TestCase>;
|
||||
using CreateGraphF = function<shared_ptr<ov::Model>(size_t main_op_idx, const ModelDescription&, size_t, const OutputVector&)>;
|
||||
using CreateGraphF =
|
||||
function<shared_ptr<ov::Model>(size_t main_op_idx, const ModelDescription&, size_t, const OutputVector&)>;
|
||||
|
||||
// Describes a model to test.
|
||||
// Expects to be used in such a scenario:
|
||||
@ -242,16 +241,16 @@ public:
|
||||
};
|
||||
|
||||
vector<FactoryPtr> unary_factories = {
|
||||
CREATE_UNARY_FACTORY(Clamp), CREATE_UNARY_FACTORY(Elu), CREATE_UNARY_FACTORY(SoftPlus),
|
||||
CREATE_UNARY_FACTORY(LogicalNot), CREATE_UNARY_FACTORY(Convert), CREATE_UNARY_FACTORY(Abs),
|
||||
CREATE_UNARY_FACTORY(Acos), CREATE_UNARY_FACTORY(Asin), CREATE_UNARY_FACTORY(Asinh),
|
||||
CREATE_UNARY_FACTORY(Atan), CREATE_UNARY_FACTORY(Ceiling), CREATE_UNARY_FACTORY(Cos),
|
||||
CREATE_UNARY_FACTORY(Cosh), CREATE_UNARY_FACTORY(Erf), CREATE_UNARY_FACTORY(Exp),
|
||||
CREATE_UNARY_FACTORY(Gelu), CREATE_UNARY_FACTORY(HSigmoid), CREATE_UNARY_FACTORY(HSwish),
|
||||
CREATE_UNARY_FACTORY(Log), CREATE_UNARY_FACTORY(Negative), CREATE_UNARY_FACTORY(Relu),
|
||||
CREATE_UNARY_FACTORY(Sigmoid), CREATE_UNARY_FACTORY(Sign), CREATE_UNARY_FACTORY(Sin),
|
||||
CREATE_UNARY_FACTORY(Sinh), CREATE_UNARY_FACTORY(SoftSign), CREATE_UNARY_FACTORY(Sqrt),
|
||||
CREATE_UNARY_FACTORY(Tan), CREATE_UNARY_FACTORY(Tanh)};
|
||||
CREATE_UNARY_FACTORY(Clamp), CREATE_UNARY_FACTORY(Elu), CREATE_UNARY_FACTORY(SoftPlus),
|
||||
CREATE_UNARY_FACTORY(LogicalNot), CREATE_UNARY_FACTORY(Convert), CREATE_UNARY_FACTORY(Abs),
|
||||
CREATE_UNARY_FACTORY(Acos), CREATE_UNARY_FACTORY(Asin), CREATE_UNARY_FACTORY(Asinh),
|
||||
CREATE_UNARY_FACTORY(Atan), CREATE_UNARY_FACTORY(Ceiling), CREATE_UNARY_FACTORY(Cos),
|
||||
CREATE_UNARY_FACTORY(Cosh), CREATE_UNARY_FACTORY(Erf), CREATE_UNARY_FACTORY(Exp),
|
||||
CREATE_UNARY_FACTORY(Gelu), CREATE_UNARY_FACTORY(HSigmoid), CREATE_UNARY_FACTORY(HSwish),
|
||||
CREATE_UNARY_FACTORY(Log), CREATE_UNARY_FACTORY(Negative), CREATE_UNARY_FACTORY(Relu),
|
||||
CREATE_UNARY_FACTORY(Sigmoid), CREATE_UNARY_FACTORY(Sign), CREATE_UNARY_FACTORY(Sin),
|
||||
CREATE_UNARY_FACTORY(Sinh), CREATE_UNARY_FACTORY(SoftSign), CREATE_UNARY_FACTORY(Sqrt),
|
||||
CREATE_UNARY_FACTORY(Tan), CREATE_UNARY_FACTORY(Tanh)};
|
||||
|
||||
std::vector<FactoryPtr> binary_factories = {CREATE_BINARY_FACTORY(Add),
|
||||
CREATE_BINARY_FACTORY(Divide),
|
||||
@ -264,19 +263,19 @@ std::vector<FactoryPtr> binary_factories = {CREATE_BINARY_FACTORY(Add),
|
||||
CREATE_BINARY_FACTORY(Subtract),
|
||||
CREATE_BINARY_FACTORY(PRelu)};
|
||||
|
||||
std::vector<FactoryPtr> reduction_factories = {CREATE_BINARY_FACTORY(ReduceMax),
|
||||
CREATE_BINARY_FACTORY(ReduceMin),
|
||||
CREATE_BINARY_FACTORY(ReduceMean),
|
||||
CREATE_BINARY_FACTORY(ReduceSum),
|
||||
CREATE_BINARY_FACTORY(ReduceProd),
|
||||
//CREATE_BINARY_FACTORY(ReduceLogicalOr),
|
||||
//CREATE_BINARY_FACTORY(ReduceLogicalAnd),
|
||||
CREATE_BINARY_FACTORY(ReduceL1),
|
||||
CREATE_BINARY_FACTORY(ReduceL2),
|
||||
//CREATE_BINARY_FACTORY(Squeeze),
|
||||
//CREATE_BINARY_FACTORY(Unsqueeze),
|
||||
};
|
||||
|
||||
std::vector<FactoryPtr> reduction_factories = {
|
||||
CREATE_BINARY_FACTORY(ReduceMax),
|
||||
CREATE_BINARY_FACTORY(ReduceMin),
|
||||
CREATE_BINARY_FACTORY(ReduceMean),
|
||||
CREATE_BINARY_FACTORY(ReduceSum),
|
||||
CREATE_BINARY_FACTORY(ReduceProd),
|
||||
// CREATE_BINARY_FACTORY(ReduceLogicalOr),
|
||||
// CREATE_BINARY_FACTORY(ReduceLogicalAnd),
|
||||
CREATE_BINARY_FACTORY(ReduceL1),
|
||||
CREATE_BINARY_FACTORY(ReduceL2),
|
||||
// CREATE_BINARY_FACTORY(Squeeze),
|
||||
// CREATE_BINARY_FACTORY(Unsqueeze),
|
||||
};
|
||||
|
||||
TEST_P(TransposeSinkingTestFixture, CompareFunctions) {
|
||||
int num_main_ops_idx;
|
||||
@ -311,8 +310,8 @@ shared_ptr<ov::Model> create_model(size_t main_node_idx,
|
||||
auto outputs = model_desc.preprocess_outputs_of_main.apply(main_node->outputs());
|
||||
return make_shared<ov::Model>(outputs, filter_parameters(inputs_to_main));
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace transpose_sinking
|
||||
|
||||
auto wrapper = [](const TestCase& test_case) {
|
||||
OPENVINO_ASSERT(test_case.model.main_op.size() == test_case.model_ref.main_op.size(),
|
||||
@ -330,7 +329,7 @@ auto test_forward_unary = []() {
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryForward);
|
||||
test_case.num_main_ops = {1, 10};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {1, 96, 55, 55}),
|
||||
parameter(element::f32, {1, 96, 55, 55}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
@ -346,9 +345,7 @@ auto test_forward_unary = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnaryForward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_forward_unary());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnaryForward, TransposeSinkingTestFixture, test_forward_unary());
|
||||
|
||||
auto test_forward_binary = []() {
|
||||
TestCase test_case;
|
||||
@ -357,8 +354,8 @@ auto test_forward_binary = []() {
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingBinaryForward);
|
||||
test_case.num_main_ops = {1, 10};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {1, 96, 55, 55}),
|
||||
parameter(element::f32, {55, 55, 96, 1}),
|
||||
parameter(element::f32, {1, 96, 55, 55}),
|
||||
parameter(element::f32, {55, 55, 96, 1}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
@ -375,10 +372,7 @@ auto test_forward_binary = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonBinaryForward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_forward_binary());
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonBinaryForward, TransposeSinkingTestFixture, test_forward_binary());
|
||||
|
||||
auto test_forward_concat = []() {
|
||||
TestCase test_case;
|
||||
@ -387,9 +381,9 @@ auto test_forward_concat = []() {
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingConcatForward);
|
||||
test_case.num_main_ops = {1, 3};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {1, 96, 55, 55}),
|
||||
parameter(element::f32, {55, 55, 96, 1}),
|
||||
parameter(element::f32, {55, 55, 96, 1}),
|
||||
parameter(element::f32, {1, 96, 55, 55}),
|
||||
parameter(element::f32, {55, 55, 96, 1}),
|
||||
parameter(element::f32, {55, 55, 96, 1}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
@ -406,9 +400,7 @@ auto test_forward_concat = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonConcatForward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_forward_concat());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonConcatForward, TransposeSinkingTestFixture, test_forward_concat());
|
||||
|
||||
auto test_forward_split = []() {
|
||||
TestCase test_case;
|
||||
@ -417,8 +409,8 @@ auto test_forward_split = []() {
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingSplitForward);
|
||||
test_case.num_main_ops = {1, 2};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {1, 9, 55, 55}),
|
||||
constant(element::i32, {}, {2}),
|
||||
parameter(element::f32, {1, 9, 55, 55}),
|
||||
constant(element::i32, {}, {2}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
@ -430,7 +422,8 @@ auto test_forward_split = []() {
|
||||
auto new_constant = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||
OutputVector new_out_vec(out_vec.size());
|
||||
new_out_vec[0] = out_vec[0];
|
||||
new_out_vec[1] = make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{1});
|
||||
new_out_vec[1] =
|
||||
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{1});
|
||||
return new_out_vec;
|
||||
};
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{new_constant}, {{1}}};
|
||||
@ -441,9 +434,7 @@ auto test_forward_split = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSplitForward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_forward_split());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSplitForward, TransposeSinkingTestFixture, test_forward_split());
|
||||
|
||||
auto test_forward_pad = []() {
|
||||
TestCase test_case;
|
||||
@ -452,9 +443,9 @@ auto test_forward_pad = []() {
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingDataMovementForward);
|
||||
test_case.num_main_ops = {1, 2};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {1, 3, 55, 55}),
|
||||
constant(element::i32, {4}, {1, 2, 3, 4}),
|
||||
constant(element::i32, {4}, {1, 2, 3, 4}),
|
||||
parameter(element::f32, {1, 3, 55, 55}),
|
||||
constant(element::i32, {4}, {1, 2, 3, 4}),
|
||||
constant(element::i32, {4}, {1, 2, 3, 4}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
@ -471,9 +462,7 @@ auto test_forward_pad = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonPadForward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_forward_pad());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonPadForward, TransposeSinkingTestFixture, test_forward_pad());
|
||||
|
||||
auto test_forward_batch_to_space = []() {
|
||||
TestCase test_case;
|
||||
@ -482,10 +471,10 @@ auto test_forward_batch_to_space = []() {
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingDataMovementForward);
|
||||
test_case.num_main_ops = {1, 2};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {128, 55, 3, 128}),
|
||||
constant(element::i32, {4}, {1, 2, 2, 2}),
|
||||
constant(element::i32, {4}, {1, 2, 2, 2}),
|
||||
constant(element::i32, {4}, {1, 2, 2, 2}),
|
||||
parameter(element::f32, {128, 55, 3, 128}),
|
||||
constant(element::i32, {4}, {1, 2, 2, 2}),
|
||||
constant(element::i32, {4}, {1, 2, 2, 2}),
|
||||
constant(element::i32, {4}, {1, 2, 2, 2}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
@ -513,10 +502,10 @@ auto test_forward_space_to_batch = []() {
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingDataMovementForward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {64, 9, 8, 1}),
|
||||
constant(element::i32, {4}, {1, 2, 3, 4}),
|
||||
constant(element::i32, {4}, {1, 2, 3, 4}),
|
||||
constant(element::i32, {4}, {1, 2, 3, 4}),
|
||||
parameter(element::f32, {64, 9, 8, 1}),
|
||||
constant(element::i32, {4}, {1, 2, 3, 4}),
|
||||
constant(element::i32, {4}, {1, 2, 3, 4}),
|
||||
constant(element::i32, {4}, {1, 2, 3, 4}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
@ -544,8 +533,8 @@ auto test_forward_reduction = []() {
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingReductionForward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {32, 4, 2, 1}),
|
||||
constant(element::i32, {2}, {1, 3}),
|
||||
parameter(element::f32, {32, 4, 2, 1}),
|
||||
constant(element::i32, {2}, {1, 3}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
@ -557,7 +546,8 @@ auto test_forward_reduction = []() {
|
||||
auto new_constant = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||
OutputVector new_out_vec(out_vec.size());
|
||||
new_out_vec[0] = out_vec[0];
|
||||
new_out_vec[1] = make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{2, 0});
|
||||
new_out_vec[1] =
|
||||
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{2, 0});
|
||||
return new_out_vec;
|
||||
};
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{new_constant}, {{1}}};
|
||||
@ -568,6 +558,4 @@ auto test_forward_reduction = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReductionForward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_forward_reduction());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReductionForward, TransposeSinkingTestFixture, test_forward_reduction());
|
||||
|
@ -11,8 +11,8 @@
|
||||
#include <transformations/init_node_info.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include "transpose_sinking_test_utils.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
#include "transpose_sinking_test_utils.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::opset10;
|
||||
|
@ -9,11 +9,11 @@
|
||||
#include <openvino/pass/manager.hpp>
|
||||
#include <transformations/common_optimizations/transpose_sinking_data_movement.hpp>
|
||||
#include <transformations/common_optimizations/transpose_sinking_utils.hpp>
|
||||
#include "transpose_sinking_test_utils.hpp"
|
||||
#include <transformations/init_node_info.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
#include "transpose_sinking_test_utils.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ov;
|
||||
|
@ -2,18 +2,19 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transpose_sinking_test_utils.hpp"
|
||||
|
||||
#include <openvino/frontend/manager.hpp>
|
||||
#include <openvino/opsets/opset10.hpp>
|
||||
#include <openvino/pass/manager.hpp>
|
||||
|
||||
#include "transpose_sinking_test_utils.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
using namespace std;
|
||||
using namespace ov;
|
||||
using namespace ov::opset10;
|
||||
|
||||
shared_ptr <Node> create_main_node(const OutputVector &inputs, size_t num_ops, const FactoryPtr &creator) {
|
||||
shared_ptr<Node> create_main_node(const OutputVector& inputs, size_t num_ops, const FactoryPtr& creator) {
|
||||
OutputVector current_inputs = inputs;
|
||||
for (size_t i = 0; i < num_ops; ++i) {
|
||||
auto op = creator->create(current_inputs);
|
||||
@ -22,7 +23,7 @@ shared_ptr <Node> create_main_node(const OutputVector &inputs, size_t num_ops, c
|
||||
return current_inputs[0].get_node_shared_ptr();
|
||||
}
|
||||
|
||||
ParameterVector filter_parameters(const OutputVector &out_vec) {
|
||||
ParameterVector filter_parameters(const OutputVector& out_vec) {
|
||||
ParameterVector parameters;
|
||||
for (const auto& out : out_vec) {
|
||||
auto node = out.get_node_shared_ptr();
|
||||
@ -36,7 +37,7 @@ ParameterVector filter_parameters(const OutputVector &out_vec) {
|
||||
OutputVector set_transpose_for(const vector<size_t>& idxs, const OutputVector& out_vec) {
|
||||
OutputVector result = out_vec;
|
||||
for (const auto& idx : idxs) {
|
||||
const auto &out = out_vec[idx];
|
||||
const auto& out = out_vec[idx];
|
||||
auto rank = out.get_partial_shape().rank().get_length();
|
||||
vector<int64_t> axes(rank);
|
||||
iota(axes.begin(), axes.end(), 0);
|
||||
@ -51,7 +52,7 @@ OutputVector set_transpose_for(const vector<size_t>& idxs, const OutputVector& o
|
||||
OutputVector set_gather_for(const vector<size_t>& idxs, const OutputVector& out_vec) {
|
||||
OutputVector result = out_vec;
|
||||
for (const auto& idx : idxs) {
|
||||
const auto &out = out_vec[idx];
|
||||
const auto& out = out_vec[idx];
|
||||
vector<int64_t> axes(out.get_shape()[0]);
|
||||
iota(axes.begin(), axes.end(), 0);
|
||||
reverse(axes.begin(), axes.end());
|
||||
@ -63,7 +64,7 @@ OutputVector set_gather_for(const vector<size_t>& idxs, const OutputVector& out_
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string to_string(const Shape &shape) {
|
||||
std::string to_string(const Shape& shape) {
|
||||
ostringstream result;
|
||||
result << "{";
|
||||
for (size_t idx = 0; idx < shape.size(); ++idx) {
|
||||
@ -75,10 +76,10 @@ std::string to_string(const Shape &shape) {
|
||||
return result.str();
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Node> parameter(ov::element::Type el_type, const PartialShape &ps) {
|
||||
std::shared_ptr<ov::Node> parameter(ov::element::Type el_type, const PartialShape& ps) {
|
||||
return std::make_shared<Parameter>(el_type, ps);
|
||||
}
|
||||
|
||||
shared_ptr<ov::Node> constant(ov::element::Type el_type, const Shape &shape, const vector<int64_t> &value) {
|
||||
shared_ptr<ov::Node> constant(ov::element::Type el_type, const Shape& shape, const vector<int64_t>& value) {
|
||||
return make_shared<Constant>(el_type, shape, value);
|
||||
}
|
||||
|
@ -64,7 +64,8 @@ ov::ParameterVector filter_parameters(const ov::OutputVector& out_vec);
|
||||
|
||||
std::shared_ptr<ov::Node> create_main_node(const ov::OutputVector& inputs, size_t num_ops, const FactoryPtr& creator);
|
||||
|
||||
|
||||
std::shared_ptr<ov::Node> parameter(ov::element::Type el_type, const ov::PartialShape& ps);
|
||||
|
||||
std::shared_ptr<ov::Node> constant(ov::element::Type el_type, const ov::Shape& shape, const std::vector<int64_t>& value);
|
||||
std::shared_ptr<ov::Node> constant(ov::element::Type el_type,
|
||||
const ov::Shape& shape,
|
||||
const std::vector<int64_t>& value);
|
@ -9,26 +9,24 @@
|
||||
#include <openvino/pass/manager.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include "transpose_sinking_test_utils.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
#include "transpose_sinking_test_utils.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::opset10;
|
||||
|
||||
using NodePtr = std::shared_ptr<ov::Node>;
|
||||
|
||||
using CreateGraphF = std::function<std::shared_ptr<ov::Model>(FactoryPtr unary_factory,
|
||||
size_t num_unary_ops,
|
||||
const Shape& input_shape,
|
||||
element::Type input_type)>;
|
||||
using CreateGraphF = std::function<std::shared_ptr<
|
||||
ov::Model>(FactoryPtr unary_factory, size_t num_unary_ops, const Shape& input_shape, element::Type input_type)>;
|
||||
|
||||
using TestParams = std::tuple<FactoryPtr,
|
||||
PassFactoryPtr,
|
||||
size_t, /* num_unary_ops */
|
||||
CreateGraphF, /* model_factory */
|
||||
CreateGraphF, /* reference_model_factory */
|
||||
Shape, /* input shape */
|
||||
element::Type>; /* input type */
|
||||
PassFactoryPtr,
|
||||
size_t, /* num_unary_ops */
|
||||
CreateGraphF, /* model_factory */
|
||||
CreateGraphF, /* reference_model_factory */
|
||||
Shape, /* input shape */
|
||||
element::Type>; /* input type */
|
||||
|
||||
class TransposeSinkingUnaryTestFixture : public ::testing::WithParamInterface<TestParams>, public TransformationTestsF {
|
||||
public:
|
||||
@ -60,7 +58,6 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
namespace transpose_sinking {
|
||||
namespace testing {
|
||||
namespace unary {
|
||||
@ -404,12 +401,12 @@ struct TestCase {
|
||||
|
||||
auto wrapper = [](const TestCase& test_case) {
|
||||
return ::testing::Combine(::testing::ValuesIn(test_case.main_node),
|
||||
::testing::Values(test_case.transformation),
|
||||
::testing::ValuesIn(test_case.num_main_ops),
|
||||
::testing::Values(test_case.test_model),
|
||||
::testing::Values(test_case.ref_model),
|
||||
::testing::Values(test_case.input_shape),
|
||||
::testing::Values(test_case.type));
|
||||
::testing::Values(test_case.transformation),
|
||||
::testing::ValuesIn(test_case.num_main_ops),
|
||||
::testing::Values(test_case.test_model),
|
||||
::testing::Values(test_case.ref_model),
|
||||
::testing::Values(test_case.input_shape),
|
||||
::testing::Values(test_case.type));
|
||||
};
|
||||
|
||||
auto test_forward = []() {
|
||||
@ -454,7 +451,8 @@ auto test_backward_multiple_consumers_reshape = []() {
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward);
|
||||
test_case.num_main_ops = {1, 10};
|
||||
test_case.test_model = mult_consumers_last_node::with_reshape::CreateFunctionTransposeAfter;
|
||||
test_case.ref_model = mult_consumers_last_node::with_reshape::CreateFunctionTransposeBefore;;
|
||||
test_case.ref_model = mult_consumers_last_node::with_reshape::CreateFunctionTransposeBefore;
|
||||
;
|
||||
test_case.input_shape = {1, 96, 55, 55};
|
||||
test_case.type = element::f32;
|
||||
return wrapper(test_case);
|
||||
@ -519,9 +517,9 @@ auto test_forward_multiple_consumers_first_node = []() {
|
||||
test_case.type = element::f32;
|
||||
return wrapper(test_case);
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace unary
|
||||
} // namespace testing
|
||||
} // namespace transpose_sinking
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryForwardTestSuite,
|
||||
TransposeSinkingUnaryTestFixture,
|
||||
@ -533,44 +531,37 @@ INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryBackwardTestSuite,
|
||||
transpose_sinking::testing::unary::test_backward(),
|
||||
TransposeSinkingUnaryTestFixture::get_test_name);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
TransposeSinkingUnaryForwardMultConsumersTestSuiteLastNodeReshape,
|
||||
TransposeSinkingUnaryTestFixture,
|
||||
transpose_sinking::testing::unary::test_forward_multiple_consumers_reshape(),
|
||||
TransposeSinkingUnaryTestFixture::get_test_name);
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryForwardMultConsumersTestSuiteLastNodeReshape,
|
||||
TransposeSinkingUnaryTestFixture,
|
||||
transpose_sinking::testing::unary::test_forward_multiple_consumers_reshape(),
|
||||
TransposeSinkingUnaryTestFixture::get_test_name);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
TransposeSinkingUnaryBackwardMultConsumersTestSuiteLastNodeReshape,
|
||||
TransposeSinkingUnaryTestFixture,
|
||||
transpose_sinking::testing::unary::test_backward_multiple_consumers_reshape(),
|
||||
TransposeSinkingUnaryTestFixture::get_test_name);
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryBackwardMultConsumersTestSuiteLastNodeReshape,
|
||||
TransposeSinkingUnaryTestFixture,
|
||||
transpose_sinking::testing::unary::test_backward_multiple_consumers_reshape(),
|
||||
TransposeSinkingUnaryTestFixture::get_test_name);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
TransposeSinkingUnaryForwardMultConsumersTestSuiteLastNodeEltwise,
|
||||
TransposeSinkingUnaryTestFixture,
|
||||
transpose_sinking::testing::unary::test_forward_multiple_consumers_eltwise(),
|
||||
TransposeSinkingUnaryTestFixture::get_test_name);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
TransposeSinkingUnaryForwardMultConsumersTestSuiteFirstNode,
|
||||
TransposeSinkingUnaryTestFixture,
|
||||
transpose_sinking::testing::unary::test_backward_multiple_consumers_eltwise(),
|
||||
TransposeSinkingUnaryTestFixture::get_test_name);
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryForwardMultConsumersTestSuiteLastNodeEltwise,
|
||||
TransposeSinkingUnaryTestFixture,
|
||||
transpose_sinking::testing::unary::test_forward_multiple_consumers_eltwise(),
|
||||
TransposeSinkingUnaryTestFixture::get_test_name);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryForwardMultConsumersTestSuiteFirstNode,
|
||||
TransposeSinkingUnaryTestFixture,
|
||||
transpose_sinking::testing::unary::test_backward_multiple_consumers_eltwise(),
|
||||
TransposeSinkingUnaryTestFixture::get_test_name);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryBackwardMultConsumersTestSuiteFirstNode,
|
||||
TransposeSinkingUnaryTestFixture,
|
||||
transpose_sinking::testing::unary::test_backward_multiple_consumers_first_node(),
|
||||
TransposeSinkingUnaryTestFixture::get_test_name);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
TransposeSinkingUnaryBackwardMultTransposeConsumersTestSuiteFirstNode,
|
||||
TransposeSinkingUnaryTestFixture,
|
||||
transpose_sinking::testing::unary::test_backward_multiple_transposes_first_node(),
|
||||
TransposeSinkingUnaryTestFixture::get_test_name);
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryBackwardMultTransposeConsumersTestSuiteFirstNode,
|
||||
TransposeSinkingUnaryTestFixture,
|
||||
transpose_sinking::testing::unary::test_backward_multiple_transposes_first_node(),
|
||||
TransposeSinkingUnaryTestFixture::get_test_name);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
TransposeSinkingUnaryForwardMultTransposeConsumersTestSuiteFirstNode,
|
||||
TransposeSinkingUnaryTestFixture,
|
||||
transpose_sinking::testing::unary::test_forward_multiple_consumers_first_node(),
|
||||
TransposeSinkingUnaryTestFixture::get_test_name);
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryForwardMultTransposeConsumersTestSuiteFirstNode,
|
||||
TransposeSinkingUnaryTestFixture,
|
||||
transpose_sinking::testing::unary::test_forward_multiple_consumers_first_node(),
|
||||
TransposeSinkingUnaryTestFixture::get_test_name);
|
Loading…
Reference in New Issue
Block a user