TransposeSinking tests refactoring: part1

This commit is contained in:
Ivan 2023-03-03 04:34:48 +04:00
parent c5991f0b06
commit 34c89eb962
2 changed files with 150 additions and 65 deletions

View File

@ -1,16 +1,15 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
#include <openvino/pass/pattern/op/or.hpp>
#include <transformations/utils/utils.hpp>
#include <utility>
#include "itt.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/pattern/op/label.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/util/common_util.hpp"
#include "openvino/util/log.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
namespace transpose_sinking {

View File

@ -1,19 +1,18 @@
// Copyright (C) 2022 Intel Corporation
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations//transpose_sinking_unary.hpp"
#include <openvino/frontend/manager.hpp>
#include <openvino/opsets/opset9.hpp>
#include <openvino/opsets/opset10.hpp>
#include <openvino/pass/manager.hpp>
#include <transformations/init_node_info.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
#include "gtest/gtest.h"
using namespace ov;
using namespace ov::opset9;
using namespace ov::opset10;
namespace {
std::string to_string(const Shape& shape) {
@ -449,95 +448,182 @@ TEST_P(TransposeSinkingUnaryTestFixture, CompareFunctions) {
pass_factory->registerPass(manager);
}
struct TestCase {
std::vector<UnaryFactoryPtr> main_node;
PassFactoryPtr transformation;
std::vector<size_t> num_main_ops;
CreateGraphF test_model;
CreateGraphF ref_model;
Shape input_shape;
element::Type type;
};
auto wrapper = [](const TestCase& test_case) {
return ::testing::Combine(::testing::ValuesIn(test_case.main_node),
::testing::Values(test_case.transformation),
::testing::ValuesIn(test_case.num_main_ops),
::testing::Values(test_case.test_model),
::testing::Values(test_case.ref_model),
::testing::Values(test_case.input_shape),
::testing::Values(test_case.type));
};
auto test_forward = []() {
TestCase test_case;
test_case.main_node = unary_factories;
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryForward);
test_case.num_main_ops = {1, 10};
test_case.test_model = CreateFunctionTransposeBefore;
test_case.ref_model = CreateFunctionTransposeAfter;
test_case.input_shape = {1, 96, 55, 55};
test_case.type = element::f32;
return wrapper(test_case);
};
auto test_backward = []() {
TestCase test_case;
test_case.main_node = unary_factories;
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward);
test_case.num_main_ops = {1, 10};
test_case.test_model = CreateFunctionTransposeAfter;
test_case.ref_model = CreateFunctionTransposeBefore;
test_case.input_shape = {1, 96, 55, 55};
test_case.type = element::f32;
return wrapper(test_case);
};
auto test_forward_multiple_consumers_reshape = []() {
TestCase test_case;
test_case.main_node = unary_factories;
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryForward);
test_case.num_main_ops = {1, 10};
test_case.test_model = mult_consumers_last_node::with_reshape::CreateFunctionTransposeBefore;
test_case.ref_model = mult_consumers_last_node::with_reshape::CreateFunctionTransposeAfter;
test_case.input_shape = {1, 96, 55, 55};
test_case.type = element::f32;
return wrapper(test_case);
};
auto test_backward_multiple_consumers_reshape = []() {
TestCase test_case;
test_case.main_node = unary_factories;
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward);
test_case.num_main_ops = {1, 10};
test_case.test_model = mult_consumers_last_node::with_reshape::CreateFunctionTransposeAfter;
test_case.ref_model = mult_consumers_last_node::with_reshape::CreateFunctionTransposeBefore;;
test_case.input_shape = {1, 96, 55, 55};
test_case.type = element::f32;
return wrapper(test_case);
};
auto test_forward_multiple_consumers_eltwise = []() {
TestCase test_case;
test_case.main_node = unary_factories;
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryForward);
test_case.num_main_ops = {1, 10};
test_case.test_model = mult_consumers_last_node::with_eltwise::CreateFunctionTransposeBefore;
test_case.ref_model = mult_consumers_last_node::with_eltwise::CreateFunctionTransposeAfter;
test_case.input_shape = {1, 96, 55, 55};
test_case.type = element::f32;
return wrapper(test_case);
};
auto test_backward_multiple_consumers_eltwise = []() {
TestCase test_case;
test_case.main_node = unary_factories;
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward);
test_case.num_main_ops = {1, 10};
test_case.test_model = mult_consumers_last_node::with_eltwise::CreateFunctionTransposeAfter;
test_case.ref_model = mult_consumers_last_node::with_eltwise::CreateFunctionTransposeBefore;
test_case.input_shape = {1, 96, 55, 55};
test_case.type = element::f32;
return wrapper(test_case);
};
auto test_backward_multiple_consumers_first_node = []() {
TestCase test_case;
test_case.main_node = unary_factories;
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward);
test_case.num_main_ops = {1, 10};
test_case.test_model = mult_consumers_first_node::backward::CreateFunction;
test_case.ref_model = mult_consumers_first_node::backward::CreateFunction;
test_case.input_shape = {1, 96, 55, 55};
test_case.type = element::f32;
return wrapper(test_case);
};
auto test_backward_multiple_transposes_first_node = []() {
TestCase test_case;
test_case.main_node = unary_factories;
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward);
test_case.num_main_ops = {1, 10};
test_case.test_model = mult_consumers_first_node::backward_mult_transposes::CreateFunction;
test_case.ref_model = mult_consumers_first_node::backward_mult_transposes::CreateReferenceFunction;
test_case.input_shape = {1, 96, 55, 55};
test_case.type = element::f32;
return wrapper(test_case);
};
auto test_forward_multiple_consumers_first_node = []() {
TestCase test_case;
test_case.main_node = unary_factories;
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward);
test_case.num_main_ops = {1, 10};
test_case.test_model = mult_consumers_first_node::forward::CreateFunction;
test_case.ref_model = mult_consumers_first_node::forward::CreateReferenceFunction;
test_case.input_shape = {1, 96, 55, 55};
test_case.type = element::f32;
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryForwardTestSuite,
TransposeSinkingUnaryTestFixture,
::testing::Combine(::testing::ValuesIn(unary_factories),
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryForward)),
::testing::ValuesIn(unary_operations_numbers),
::testing::Values(CreateFunctionTransposeBefore),
::testing::Values(CreateFunctionTransposeAfter),
::testing::Values(Shape{1, 96, 55, 55}),
::testing::Values(element::f32)),
test_forward(),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryBackwardTestSuite,
TransposeSinkingUnaryTestFixture,
::testing::Combine(::testing::ValuesIn(unary_factories),
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward)),
::testing::ValuesIn(unary_operations_numbers),
::testing::Values(CreateFunctionTransposeAfter),
::testing::Values(CreateFunctionTransposeBefore),
::testing::Values(Shape{1, 96, 55, 55}),
::testing::Values(element::f32)),
test_backward(),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(
TransposeSinkingUnaryForwardMultConsumersTestSuiteLastNodeReshape,
TransposeSinkingUnaryTestFixture,
::testing::Combine(::testing::ValuesIn(unary_factories),
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryForward)),
::testing::ValuesIn(unary_operations_numbers),
::testing::Values(mult_consumers_last_node::with_reshape::CreateFunctionTransposeBefore),
::testing::Values(mult_consumers_last_node::with_reshape::CreateFunctionTransposeAfter),
::testing::Values(Shape{1, 96, 55, 55}),
::testing::Values(element::f32)),
test_forward_multiple_consumers_reshape(),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(
TransposeSinkingUnaryBackwardMultConsumersTestSuiteLastNodeReshape,
TransposeSinkingUnaryTestFixture,
::testing::Combine(::testing::ValuesIn(unary_factories),
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward)),
::testing::ValuesIn(unary_operations_numbers),
::testing::Values(mult_consumers_last_node::with_reshape::CreateFunctionTransposeAfter),
::testing::Values(mult_consumers_last_node::with_reshape::CreateFunctionTransposeBefore),
::testing::Values(Shape{1, 96, 55, 55}),
::testing::Values(element::f32)),
test_backward_multiple_consumers_reshape(),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(
TransposeSinkingUnaryForwardMultConsumersTestSuiteLastNodeEltwise,
TransposeSinkingUnaryTestFixture,
::testing::Combine(::testing::ValuesIn(unary_factories),
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryForward)),
::testing::ValuesIn(unary_operations_numbers),
::testing::Values(mult_consumers_last_node::with_eltwise::CreateFunctionTransposeBefore),
::testing::Values(mult_consumers_last_node::with_eltwise::CreateFunctionTransposeAfter),
::testing::Values(Shape{1, 96, 55, 55}),
::testing::Values(element::f32)),
test_forward_multiple_consumers_eltwise(),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(
TransposeSinkingUnaryForwardMultConsumersTestSuiteFirstNode,
TransposeSinkingUnaryTestFixture,
::testing::Combine(::testing::ValuesIn(unary_factories),
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryForward)),
::testing::ValuesIn(unary_operations_numbers),
::testing::Values(mult_consumers_first_node::forward::CreateFunction),
::testing::Values(mult_consumers_first_node::forward::CreateReferenceFunction),
::testing::Values(Shape{1, 96, 55, 55}),
::testing::Values(element::f32)),
test_backward_multiple_consumers_eltwise(),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryBackwardMultConsumersTestSuiteFirstNode,
TransposeSinkingUnaryTestFixture,
::testing::Combine(::testing::ValuesIn(unary_factories),
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward)),
::testing::ValuesIn(unary_operations_numbers),
::testing::Values(mult_consumers_first_node::backward::CreateFunction),
::testing::Values(mult_consumers_first_node::backward::CreateFunction),
::testing::Values(Shape{1, 96, 55, 55}),
::testing::Values(element::f32)),
test_backward_multiple_consumers_first_node(),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(
TransposeSinkingUnaryBackwardMultTransposeConsumersTestSuiteFirstNode,
TransposeSinkingUnaryTestFixture,
::testing::Combine(::testing::ValuesIn(unary_factories),
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward)),
::testing::ValuesIn(unary_operations_numbers),
::testing::Values(mult_consumers_first_node::backward_mult_transposes::CreateFunction),
::testing::Values(mult_consumers_first_node::backward_mult_transposes::CreateReferenceFunction),
::testing::Values(Shape{1, 96, 55, 55}),
::testing::Values(element::f32)),
test_backward_multiple_transposes_first_node(),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(
TransposeSinkingUnaryForwardMultTransposeConsumersTestSuiteFirstNode,
TransposeSinkingUnaryTestFixture,
test_forward_multiple_consumers_first_node(),
TransposeSinkingUnaryTestFixture::get_test_name);