fix ov_transformation_tests duplicates (#18383)

This commit is contained in:
Evgeny Kotov 2023-07-05 16:16:00 +02:00 committed by GitHub
parent 6a0c6a1a60
commit fb676a9e76
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 28 additions and 20 deletions

View File

@ -72,12 +72,6 @@ PadFactoryPtr CreatePadFactory(const std::string& type_name) {
return std::make_shared<PadFactory<PadT>>(type_name);
}
#undef CREATE_PAD_FACTORY
#define CREATE_PAD_FACTORY(type_name, type_str) CreatePadFactory<type_name>(type_str)
std::vector<PadFactoryPtr> pad_factories = {CREATE_PAD_FACTORY(ov::op::v1::Pad, "op_v1_Pad"),
CREATE_PAD_FACTORY(ov::op::v12::Pad, "op_v12_Pad")};
struct ITestModelFactory {
explicit ITestModelFactory(const std::string& a_test_name) : test_name(a_test_name) {}
virtual ~ITestModelFactory() = default;
@ -90,7 +84,7 @@ using TestModelFactoryPtr = std::shared_ptr<ITestModelFactory>;
using TestParams = std::tuple<PadFactoryPtr, TestModelFactoryPtr>;
class PadTestFixture : public ::testing::WithParamInterface<TestParams>, public TransformationTestsF {
class PadFusionTestFixture : public ::testing::WithParamInterface<TestParams>, public TransformationTestsF {
public:
static std::string get_test_name(const ::testing::TestParamInfo<TestParams>& obj) {
PadFactoryPtr pad_factory;
@ -105,7 +99,7 @@ public:
}
};
TEST_P(PadTestFixture, CompareFunctions) {
TEST_P(PadFusionTestFixture, CompareFunctions) {
PadFactoryPtr pad_factory;
TestModelFactoryPtr model_factory;
std::tie(pad_factory, model_factory) = this->GetParam();
@ -806,6 +800,8 @@ TEST_BODY(NegativePadPreservation) {
// Reference function is equal to function
}
namespace {
#undef CREATE_MODEL_FACTORY
#define CREATE_MODEL_FACTORY(type_name) std::make_shared<type_name>()
@ -834,7 +830,15 @@ std::vector<TestModelFactoryPtr> model_factories = {
CREATE_MODEL_FACTORY(NegativePadFusionConvolution),
CREATE_MODEL_FACTORY(NegativePadFusionGroupConvolution)};
#undef CREATE_PAD_FACTORY
#define CREATE_PAD_FACTORY(type_name, type_str) CreatePadFactory<type_name>(type_str)
std::vector<PadFactoryPtr> pad_factories = {CREATE_PAD_FACTORY(ov::op::v1::Pad, "op_v1_Pad"),
CREATE_PAD_FACTORY(ov::op::v12::Pad, "op_v12_Pad")};
} // namespace
INSTANTIATE_TEST_SUITE_P(PadTestSuite,
PadTestFixture,
PadFusionTestFixture,
::testing::Combine(::testing::ValuesIn(pad_factories), ::testing::ValuesIn(model_factories)),
PadTestFixture::get_test_name);
PadFusionTestFixture::get_test_name);

View File

@ -70,12 +70,6 @@ PadFactoryPtr CreatePadFactory(const std::string& type_name) {
return std::make_shared<PadFactory<PadT>>(type_name);
}
#undef CREATE_PAD_FACTORY
#define CREATE_PAD_FACTORY(type_name, type_str) CreatePadFactory<type_name>(type_str)
std::vector<PadFactoryPtr> pad_factories = {CREATE_PAD_FACTORY(ov::op::v1::Pad, "op_v1_Pad"),
CREATE_PAD_FACTORY(ov::op::v12::Pad, "op_v12_Pad")};
struct ITestModelFactory {
explicit ITestModelFactory(const std::string& a_test_name) : test_name(a_test_name) {}
virtual ~ITestModelFactory() = default;
@ -88,7 +82,7 @@ using TestModelFactoryPtr = std::shared_ptr<ITestModelFactory>;
using TestParams = std::tuple<PadFactoryPtr, TestModelFactoryPtr>;
class PadTestFixture : public ::testing::WithParamInterface<TestParams>, public TransformationTestsF {
class ConvertPadGroupConvTestFixture : public ::testing::WithParamInterface<TestParams>, public TransformationTestsF {
public:
static std::string get_test_name(const ::testing::TestParamInfo<TestParams>& obj) {
PadFactoryPtr pad_factory;
@ -103,7 +97,7 @@ public:
}
};
TEST_P(PadTestFixture, CompareFunctions) {
TEST_P(ConvertPadGroupConvTestFixture, CompareFunctions) {
PadFactoryPtr pad_factory;
TestModelFactoryPtr model_factory;
std::tie(pad_factory, model_factory) = this->GetParam();
@ -215,6 +209,8 @@ TEST_BODY(ConvertPadToConvNeg4) {
manager.register_pass<ov::pass::ConvertPadToGroupConvolution>();
}
namespace {
#undef CREATE_MODEL_FACTORY
#define CREATE_MODEL_FACTORY(type_name) std::make_shared<type_name>()
@ -225,7 +221,15 @@ std::vector<TestModelFactoryPtr> model_factories = {CREATE_MODEL_FACTORY(Convert
CREATE_MODEL_FACTORY(ConvertPadToConvNeg4),
CREATE_MODEL_FACTORY(NegativeConvertPadToConv)};
#undef CREATE_PAD_FACTORY
#define CREATE_PAD_FACTORY(type_name, type_str) CreatePadFactory<type_name>(type_str)
std::vector<PadFactoryPtr> pad_factories = {CREATE_PAD_FACTORY(ov::op::v1::Pad, "op_v1_Pad"),
CREATE_PAD_FACTORY(ov::op::v12::Pad, "op_v12_Pad")};
} // namespace
INSTANTIATE_TEST_SUITE_P(ConvertPadToGroupConvolutionTestSuite,
PadTestFixture,
ConvertPadGroupConvTestFixture,
::testing::Combine(::testing::ValuesIn(pad_factories), ::testing::ValuesIn(model_factories)),
PadTestFixture::get_test_name);
ConvertPadGroupConvTestFixture::get_test_name);