TransposeSinking tests refactoring: part1
This commit is contained in:
parent
c5991f0b06
commit
34c89eb962
@ -1,16 +1,15 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
|
||||
|
||||
#include <openvino/pass/pattern/op/or.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <utility>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "openvino/op/util/op_types.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/pattern/op/label.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "openvino/util/common_util.hpp"
|
||||
#include "openvino/util/log.hpp"
|
||||
#include "transformations/rt_info/transpose_sinking_attr.hpp"
|
||||
|
||||
namespace transpose_sinking {
|
||||
|
@ -1,19 +1,18 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// Copyright (C) 2022-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/common_optimizations//transpose_sinking_unary.hpp"
|
||||
|
||||
#include <openvino/frontend/manager.hpp>
|
||||
#include <openvino/opsets/opset9.hpp>
|
||||
#include <openvino/opsets/opset10.hpp>
|
||||
#include <openvino/pass/manager.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::opset9;
|
||||
using namespace ov::opset10;
|
||||
|
||||
namespace {
|
||||
std::string to_string(const Shape& shape) {
|
||||
@ -449,95 +448,182 @@ TEST_P(TransposeSinkingUnaryTestFixture, CompareFunctions) {
|
||||
pass_factory->registerPass(manager);
|
||||
}
|
||||
|
||||
struct TestCase {
|
||||
std::vector<UnaryFactoryPtr> main_node;
|
||||
PassFactoryPtr transformation;
|
||||
std::vector<size_t> num_main_ops;
|
||||
CreateGraphF test_model;
|
||||
CreateGraphF ref_model;
|
||||
Shape input_shape;
|
||||
element::Type type;
|
||||
};
|
||||
|
||||
auto wrapper = [](const TestCase& test_case) {
|
||||
return ::testing::Combine(::testing::ValuesIn(test_case.main_node),
|
||||
::testing::Values(test_case.transformation),
|
||||
::testing::ValuesIn(test_case.num_main_ops),
|
||||
::testing::Values(test_case.test_model),
|
||||
::testing::Values(test_case.ref_model),
|
||||
::testing::Values(test_case.input_shape),
|
||||
::testing::Values(test_case.type));
|
||||
};
|
||||
|
||||
auto test_forward = []() {
|
||||
TestCase test_case;
|
||||
test_case.main_node = unary_factories;
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryForward);
|
||||
test_case.num_main_ops = {1, 10};
|
||||
test_case.test_model = CreateFunctionTransposeBefore;
|
||||
test_case.ref_model = CreateFunctionTransposeAfter;
|
||||
test_case.input_shape = {1, 96, 55, 55};
|
||||
test_case.type = element::f32;
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
auto test_backward = []() {
|
||||
TestCase test_case;
|
||||
test_case.main_node = unary_factories;
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward);
|
||||
test_case.num_main_ops = {1, 10};
|
||||
test_case.test_model = CreateFunctionTransposeAfter;
|
||||
test_case.ref_model = CreateFunctionTransposeBefore;
|
||||
test_case.input_shape = {1, 96, 55, 55};
|
||||
test_case.type = element::f32;
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
auto test_forward_multiple_consumers_reshape = []() {
|
||||
TestCase test_case;
|
||||
test_case.main_node = unary_factories;
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryForward);
|
||||
test_case.num_main_ops = {1, 10};
|
||||
test_case.test_model = mult_consumers_last_node::with_reshape::CreateFunctionTransposeBefore;
|
||||
test_case.ref_model = mult_consumers_last_node::with_reshape::CreateFunctionTransposeAfter;
|
||||
test_case.input_shape = {1, 96, 55, 55};
|
||||
test_case.type = element::f32;
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
auto test_backward_multiple_consumers_reshape = []() {
|
||||
TestCase test_case;
|
||||
test_case.main_node = unary_factories;
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward);
|
||||
test_case.num_main_ops = {1, 10};
|
||||
test_case.test_model = mult_consumers_last_node::with_reshape::CreateFunctionTransposeAfter;
|
||||
test_case.ref_model = mult_consumers_last_node::with_reshape::CreateFunctionTransposeBefore;;
|
||||
test_case.input_shape = {1, 96, 55, 55};
|
||||
test_case.type = element::f32;
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
auto test_forward_multiple_consumers_eltwise = []() {
|
||||
TestCase test_case;
|
||||
test_case.main_node = unary_factories;
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryForward);
|
||||
test_case.num_main_ops = {1, 10};
|
||||
test_case.test_model = mult_consumers_last_node::with_eltwise::CreateFunctionTransposeBefore;
|
||||
test_case.ref_model = mult_consumers_last_node::with_eltwise::CreateFunctionTransposeAfter;
|
||||
test_case.input_shape = {1, 96, 55, 55};
|
||||
test_case.type = element::f32;
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
auto test_backward_multiple_consumers_eltwise = []() {
|
||||
TestCase test_case;
|
||||
test_case.main_node = unary_factories;
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward);
|
||||
test_case.num_main_ops = {1, 10};
|
||||
test_case.test_model = mult_consumers_last_node::with_eltwise::CreateFunctionTransposeAfter;
|
||||
test_case.ref_model = mult_consumers_last_node::with_eltwise::CreateFunctionTransposeBefore;
|
||||
test_case.input_shape = {1, 96, 55, 55};
|
||||
test_case.type = element::f32;
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
auto test_backward_multiple_consumers_first_node = []() {
|
||||
TestCase test_case;
|
||||
test_case.main_node = unary_factories;
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward);
|
||||
test_case.num_main_ops = {1, 10};
|
||||
test_case.test_model = mult_consumers_first_node::backward::CreateFunction;
|
||||
test_case.ref_model = mult_consumers_first_node::backward::CreateFunction;
|
||||
test_case.input_shape = {1, 96, 55, 55};
|
||||
test_case.type = element::f32;
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
auto test_backward_multiple_transposes_first_node = []() {
|
||||
TestCase test_case;
|
||||
test_case.main_node = unary_factories;
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward);
|
||||
test_case.num_main_ops = {1, 10};
|
||||
test_case.test_model = mult_consumers_first_node::backward_mult_transposes::CreateFunction;
|
||||
test_case.ref_model = mult_consumers_first_node::backward_mult_transposes::CreateReferenceFunction;
|
||||
test_case.input_shape = {1, 96, 55, 55};
|
||||
test_case.type = element::f32;
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
auto test_forward_multiple_consumers_first_node = []() {
|
||||
TestCase test_case;
|
||||
test_case.main_node = unary_factories;
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward);
|
||||
test_case.num_main_ops = {1, 10};
|
||||
test_case.test_model = mult_consumers_first_node::forward::CreateFunction;
|
||||
test_case.ref_model = mult_consumers_first_node::forward::CreateReferenceFunction;
|
||||
test_case.input_shape = {1, 96, 55, 55};
|
||||
test_case.type = element::f32;
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryForwardTestSuite,
|
||||
TransposeSinkingUnaryTestFixture,
|
||||
::testing::Combine(::testing::ValuesIn(unary_factories),
|
||||
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryForward)),
|
||||
::testing::ValuesIn(unary_operations_numbers),
|
||||
::testing::Values(CreateFunctionTransposeBefore),
|
||||
::testing::Values(CreateFunctionTransposeAfter),
|
||||
::testing::Values(Shape{1, 96, 55, 55}),
|
||||
::testing::Values(element::f32)),
|
||||
test_forward(),
|
||||
TransposeSinkingUnaryTestFixture::get_test_name);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryBackwardTestSuite,
|
||||
TransposeSinkingUnaryTestFixture,
|
||||
::testing::Combine(::testing::ValuesIn(unary_factories),
|
||||
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward)),
|
||||
::testing::ValuesIn(unary_operations_numbers),
|
||||
::testing::Values(CreateFunctionTransposeAfter),
|
||||
::testing::Values(CreateFunctionTransposeBefore),
|
||||
::testing::Values(Shape{1, 96, 55, 55}),
|
||||
::testing::Values(element::f32)),
|
||||
test_backward(),
|
||||
TransposeSinkingUnaryTestFixture::get_test_name);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
TransposeSinkingUnaryForwardMultConsumersTestSuiteLastNodeReshape,
|
||||
TransposeSinkingUnaryTestFixture,
|
||||
::testing::Combine(::testing::ValuesIn(unary_factories),
|
||||
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryForward)),
|
||||
::testing::ValuesIn(unary_operations_numbers),
|
||||
::testing::Values(mult_consumers_last_node::with_reshape::CreateFunctionTransposeBefore),
|
||||
::testing::Values(mult_consumers_last_node::with_reshape::CreateFunctionTransposeAfter),
|
||||
::testing::Values(Shape{1, 96, 55, 55}),
|
||||
::testing::Values(element::f32)),
|
||||
test_forward_multiple_consumers_reshape(),
|
||||
TransposeSinkingUnaryTestFixture::get_test_name);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
TransposeSinkingUnaryBackwardMultConsumersTestSuiteLastNodeReshape,
|
||||
TransposeSinkingUnaryTestFixture,
|
||||
::testing::Combine(::testing::ValuesIn(unary_factories),
|
||||
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward)),
|
||||
::testing::ValuesIn(unary_operations_numbers),
|
||||
::testing::Values(mult_consumers_last_node::with_reshape::CreateFunctionTransposeAfter),
|
||||
::testing::Values(mult_consumers_last_node::with_reshape::CreateFunctionTransposeBefore),
|
||||
::testing::Values(Shape{1, 96, 55, 55}),
|
||||
::testing::Values(element::f32)),
|
||||
test_backward_multiple_consumers_reshape(),
|
||||
TransposeSinkingUnaryTestFixture::get_test_name);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
TransposeSinkingUnaryForwardMultConsumersTestSuiteLastNodeEltwise,
|
||||
TransposeSinkingUnaryTestFixture,
|
||||
::testing::Combine(::testing::ValuesIn(unary_factories),
|
||||
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryForward)),
|
||||
::testing::ValuesIn(unary_operations_numbers),
|
||||
::testing::Values(mult_consumers_last_node::with_eltwise::CreateFunctionTransposeBefore),
|
||||
::testing::Values(mult_consumers_last_node::with_eltwise::CreateFunctionTransposeAfter),
|
||||
::testing::Values(Shape{1, 96, 55, 55}),
|
||||
::testing::Values(element::f32)),
|
||||
test_forward_multiple_consumers_eltwise(),
|
||||
TransposeSinkingUnaryTestFixture::get_test_name);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
TransposeSinkingUnaryForwardMultConsumersTestSuiteFirstNode,
|
||||
TransposeSinkingUnaryTestFixture,
|
||||
::testing::Combine(::testing::ValuesIn(unary_factories),
|
||||
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryForward)),
|
||||
::testing::ValuesIn(unary_operations_numbers),
|
||||
::testing::Values(mult_consumers_first_node::forward::CreateFunction),
|
||||
::testing::Values(mult_consumers_first_node::forward::CreateReferenceFunction),
|
||||
::testing::Values(Shape{1, 96, 55, 55}),
|
||||
::testing::Values(element::f32)),
|
||||
test_backward_multiple_consumers_eltwise(),
|
||||
TransposeSinkingUnaryTestFixture::get_test_name);
|
||||
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryBackwardMultConsumersTestSuiteFirstNode,
|
||||
TransposeSinkingUnaryTestFixture,
|
||||
::testing::Combine(::testing::ValuesIn(unary_factories),
|
||||
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward)),
|
||||
::testing::ValuesIn(unary_operations_numbers),
|
||||
::testing::Values(mult_consumers_first_node::backward::CreateFunction),
|
||||
::testing::Values(mult_consumers_first_node::backward::CreateFunction),
|
||||
::testing::Values(Shape{1, 96, 55, 55}),
|
||||
::testing::Values(element::f32)),
|
||||
test_backward_multiple_consumers_first_node(),
|
||||
TransposeSinkingUnaryTestFixture::get_test_name);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
TransposeSinkingUnaryBackwardMultTransposeConsumersTestSuiteFirstNode,
|
||||
TransposeSinkingUnaryTestFixture,
|
||||
::testing::Combine(::testing::ValuesIn(unary_factories),
|
||||
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward)),
|
||||
::testing::ValuesIn(unary_operations_numbers),
|
||||
::testing::Values(mult_consumers_first_node::backward_mult_transposes::CreateFunction),
|
||||
::testing::Values(mult_consumers_first_node::backward_mult_transposes::CreateReferenceFunction),
|
||||
::testing::Values(Shape{1, 96, 55, 55}),
|
||||
::testing::Values(element::f32)),
|
||||
test_backward_multiple_transposes_first_node(),
|
||||
TransposeSinkingUnaryTestFixture::get_test_name);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
TransposeSinkingUnaryForwardMultTransposeConsumersTestSuiteFirstNode,
|
||||
TransposeSinkingUnaryTestFixture,
|
||||
test_forward_multiple_consumers_first_node(),
|
||||
TransposeSinkingUnaryTestFixture::get_test_name);
|
Loading…
Reference in New Issue
Block a user