Moved ConstantResultTest to new API (#20224)

This commit is contained in:
Ilya Churaev 2023-10-04 14:18:05 +04:00 committed by GitHub
parent ee8bd33c6d
commit 3b8ac28ced
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 183 additions and 89 deletions

View File

@ -2,44 +2,38 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// //
#include "subgraph_tests/constant_result.hpp"
#include <vector> #include <vector>
#include "subgraph_tests/constant_result.hpp"
#include "common_test_utils/test_constants.hpp" #include "common_test_utils/test_constants.hpp"
using namespace SubgraphTestsDefinitions; using namespace ov::test;
using namespace InferenceEngine;
namespace { namespace {
const std::vector<ConstantSubgraphType> types = { const std::vector<ConstantSubgraphType> types = {ConstantSubgraphType::SINGLE_COMPONENT,
ConstantSubgraphType::SINGLE_COMPONENT, ConstantSubgraphType::SEVERAL_COMPONENT};
ConstantSubgraphType::SEVERAL_COMPONENT
};
const std::vector<SizeVector> shapes = { const std::vector<ov::Shape> shapes = {{1, 3, 10, 10}, {2, 3, 4, 5}};
{1, 3, 10, 10},
{2, 3, 4, 5}
};
const std::vector<Precision> precisions = { const std::vector<ov::element::Type> precisions = {ov::element::u8,
Precision::U8, ov::element::i8,
Precision::I8, ov::element::u16,
Precision::U16, ov::element::i16,
Precision::I16, ov::element::u32,
Precision::I32, ov::element::i32,
Precision::U64, ov::element::u64,
Precision::I64, ov::element::i64,
Precision::FP32, ov::element::f32,
Precision::BOOL ov::element::boolean};
};
INSTANTIATE_TEST_SUITE_P(smoke_Check, ConstantResultSubgraphTest, INSTANTIATE_TEST_SUITE_P(smoke_Check,
::testing::Combine( ConstantResultSubgraphTest,
::testing::ValuesIn(types), ::testing::Combine(::testing::ValuesIn(types),
::testing::ValuesIn(shapes), ::testing::ValuesIn(shapes),
::testing::ValuesIn(precisions), ::testing::ValuesIn(precisions),
::testing::Values(ov::test::utils::DEVICE_CPU)), ::testing::Values(ov::test::utils::DEVICE_CPU)),
ConstantResultSubgraphTest::getTestCaseName); ConstantResultSubgraphTest::getTestCaseName);
} // namespace } // namespace

View File

@ -2,11 +2,10 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// //
#include "subgraph_tests/constant_result.hpp"
#include <vector> #include <vector>
#include "common_test_utils/test_constants.hpp" #include "common_test_utils/test_constants.hpp"
#include "subgraph_tests/constant_result_legacy.hpp"
using namespace SubgraphTestsDefinitions; using namespace SubgraphTestsDefinitions;
using namespace InferenceEngine; using namespace InferenceEngine;

View File

@ -4,7 +4,7 @@
#include <vector> #include <vector>
#include "subgraph_tests/constant_result.hpp" #include "subgraph_tests/constant_result_legacy.hpp"
#include "common_test_utils/test_constants.hpp" #include "common_test_utils/test_constants.hpp"
using namespace SubgraphTestsDefinitions; using namespace SubgraphTestsDefinitions;

View File

@ -6,11 +6,12 @@
#include "shared_test_classes/subgraph/constant_result.hpp" #include "shared_test_classes/subgraph/constant_result.hpp"
namespace SubgraphTestsDefinitions { namespace ov {
namespace test {
TEST_P(ConstantResultSubgraphTest, CompareWithRefs) { TEST_P(ConstantResultSubgraphTest, CompareWithRefs) {
Run(); run();
} }
} // namespace SubgraphTestsDefinitions } // namespace test
} // namespace ov

View File

@ -0,0 +1,15 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "shared_test_classes/subgraph/constant_result.hpp"
namespace SubgraphTestsDefinitions {
TEST_P(ConstantResultSubgraphTest, CompareWithRefs) {
Run();
}
} // namespace SubgraphTestsDefinitions

View File

@ -4,35 +4,63 @@
#pragma once #pragma once
#include <tuple>
#include <string>
#include <vector>
#include <memory> #include <memory>
#include <string>
#include <tuple>
#include <vector>
#include "openvino/core/type/element_type.hpp"
#include "shared_test_classes/base/layer_test_utils.hpp" #include "shared_test_classes/base/layer_test_utils.hpp"
#include "ngraph_functions/builders.hpp" #include "shared_test_classes/base/ov_subgraph.hpp"
namespace ov {
namespace test {
enum class ConstantSubgraphType { SINGLE_COMPONENT, SEVERAL_COMPONENT };
std::ostream& operator<<(std::ostream& os, ConstantSubgraphType type);
typedef std::tuple<ConstantSubgraphType,
ov::Shape, // input shape
ov::element::Type, // input element type
std::string // Device name
>
constResultParams;
class ConstantResultSubgraphTest : public testing::WithParamInterface<constResultParams>,
virtual public ov::test::SubgraphBaseTest {
public:
static std::string getTestCaseName(const testing::TestParamInfo<constResultParams>& obj);
void createGraph(const ConstantSubgraphType& type,
const ov::Shape& input_shape,
const ov::element::Type& input_type);
protected:
void SetUp() override;
};
} // namespace test
} // namespace ov
namespace SubgraphTestsDefinitions { namespace SubgraphTestsDefinitions {
enum class ConstantSubgraphType { using ov::test::ConstantSubgraphType;
SINGLE_COMPONENT,
SEVERAL_COMPONENT
};
std::ostream& operator<<(std::ostream &os, ConstantSubgraphType type); typedef std::tuple<ConstantSubgraphType,
InferenceEngine::SizeVector, // input shape
typedef std::tuple < InferenceEngine::Precision, // input precision
ConstantSubgraphType, std::string // Device name
InferenceEngine::SizeVector, // input shape >
InferenceEngine::Precision, // input precision constResultParams;
std::string // Device name
> constResultParams;
class ConstantResultSubgraphTest : public testing::WithParamInterface<constResultParams>, class ConstantResultSubgraphTest : public testing::WithParamInterface<constResultParams>,
virtual public LayerTestsUtils::LayerTestsCommon { virtual public LayerTestsUtils::LayerTestsCommon {
public: public:
static std::string getTestCaseName(const testing::TestParamInfo<constResultParams>& obj); static std::string getTestCaseName(const testing::TestParamInfo<constResultParams>& obj);
void createGraph(const ConstantSubgraphType& type, const InferenceEngine::SizeVector &inputShape, const InferenceEngine::Precision &inputPrecision); void createGraph(const ConstantSubgraphType& type,
const InferenceEngine::SizeVector& inputShape,
const InferenceEngine::Precision& inputPrecision);
protected: protected:
void SetUp() override; void SetUp() override;
}; };

View File

@ -4,29 +4,84 @@
#include "shared_test_classes/subgraph/constant_result.hpp" #include "shared_test_classes/subgraph/constant_result.hpp"
using namespace InferenceEngine; #include "ngraph_functions/builders.hpp"
using namespace ngraph; #include "openvino/op/result.hpp"
#include "shared_test_classes/base/ov_subgraph.hpp"
namespace SubgraphTestsDefinitions { namespace ov {
namespace test {
std::ostream& operator<<(std::ostream &os, ConstantSubgraphType type) { std::ostream& operator<<(std::ostream& os, ConstantSubgraphType type) {
switch (type) { switch (type) {
case ConstantSubgraphType::SINGLE_COMPONENT: case ConstantSubgraphType::SINGLE_COMPONENT:
os << "SINGLE_COMPONENT"; os << "SINGLE_COMPONENT";
break; break;
case ConstantSubgraphType::SEVERAL_COMPONENT: case ConstantSubgraphType::SEVERAL_COMPONENT:
os << "SEVERAL_COMPONENT"; os << "SEVERAL_COMPONENT";
break; break;
default: default:
os << "UNSUPPORTED_CONST_SUBGRAPH_TYPE"; os << "UNSUPPORTED_CONST_SUBGRAPH_TYPE";
} }
return os; return os;
} }
std::string ConstantResultSubgraphTest::getTestCaseName(const testing::TestParamInfo<constResultParams>& obj) { std::string ConstantResultSubgraphTest::getTestCaseName(const testing::TestParamInfo<constResultParams>& obj) {
ConstantSubgraphType type; ConstantSubgraphType type;
SizeVector IS; ov::Shape input_shape;
Precision inputPrecision; ov::element::Type input_type;
std::string target_device;
std::tie(type, input_shape, input_type, target_device) = obj.param;
std::ostringstream result;
result << "SubgraphType=" << type << "_";
result << "IS=" << input_shape << "_";
result << "IT=" << input_type << "_";
result << "Device=" << target_device;
return result.str();
}
void ConstantResultSubgraphTest::createGraph(const ConstantSubgraphType& type,
const ov::Shape& input_shape,
const ov::element::Type& input_type) {
ParameterVector params;
ResultVector results;
switch (type) {
case ConstantSubgraphType::SINGLE_COMPONENT: {
auto input = ngraph::builder::makeConstant<float>(input_type, input_shape, {}, true);
results.push_back(std::make_shared<ov::op::v0::Result>(input));
break;
}
case ConstantSubgraphType::SEVERAL_COMPONENT: {
auto input1 = ngraph::builder::makeConstant<float>(input_type, input_shape, {}, true);
results.push_back(std::make_shared<ov::op::v0::Result>(input1));
auto input2 = ngraph::builder::makeConstant<float>(input_type, input_shape, {}, true);
results.push_back(std::make_shared<ov::op::v0::Result>(input2));
break;
}
default: {
throw std::runtime_error("Unsupported constant graph type");
}
}
function = std::make_shared<ov::Model>(results, params, "ConstResult");
}
void ConstantResultSubgraphTest::SetUp() {
ConstantSubgraphType type;
ov::Shape input_shape;
ov::element::Type input_type;
std::tie(type, input_shape, input_type, targetDevice) = this->GetParam();
createGraph(type, input_shape, input_type);
}
} // namespace test
} // namespace ov
namespace SubgraphTestsDefinitions {
std::string ConstantResultSubgraphTest::getTestCaseName(const testing::TestParamInfo<constResultParams>& obj) {
ConstantSubgraphType type;
InferenceEngine::SizeVector IS;
InferenceEngine::Precision inputPrecision;
std::string targetDevice; std::string targetDevice;
std::tie(type, IS, inputPrecision, targetDevice) = obj.param; std::tie(type, IS, inputPrecision, targetDevice) = obj.param;
@ -38,35 +93,37 @@ std::string ConstantResultSubgraphTest::getTestCaseName(const testing::TestParam
return result.str(); return result.str();
} }
void ConstantResultSubgraphTest::createGraph(const ConstantSubgraphType& type, const SizeVector &inputShape, const Precision &inputPrecision) { void ConstantResultSubgraphTest::createGraph(const ConstantSubgraphType& type,
const InferenceEngine::SizeVector& inputShape,
const InferenceEngine::Precision& inputPrecision) {
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(inputPrecision); auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(inputPrecision);
ParameterVector params; ov::ParameterVector params;
ResultVector results; ov::ResultVector results;
switch (type) { switch (type) {
case ConstantSubgraphType::SINGLE_COMPONENT: { case ConstantSubgraphType::SINGLE_COMPONENT: {
auto input = builder::makeConstant<float>(ngPrc, inputShape, {}, true); auto input = ngraph::builder::makeConstant<float>(ngPrc, inputShape, {}, true);
results.push_back(std::make_shared<opset3::Result>(input)); results.push_back(std::make_shared<ov::op::v0::Result>(input));
break; break;
}
case ConstantSubgraphType::SEVERAL_COMPONENT: {
auto input1 = builder::makeConstant<float>(ngPrc, inputShape, {}, true);
results.push_back(std::make_shared<opset3::Result>(input1));
auto input2 = builder::makeConstant<float>(ngPrc, inputShape, {}, true);
results.push_back(std::make_shared<opset3::Result>(input2));
break;
}
default: {
throw std::runtime_error("Unsupported constant graph type");
}
} }
function = std::make_shared<Function>(results, params, "ConstResult"); case ConstantSubgraphType::SEVERAL_COMPONENT: {
auto input1 = ngraph::builder::makeConstant<float>(ngPrc, inputShape, {}, true);
results.push_back(std::make_shared<ov::op::v0::Result>(input1));
auto input2 = ngraph::builder::makeConstant<float>(ngPrc, inputShape, {}, true);
results.push_back(std::make_shared<ov::op::v0::Result>(input2));
break;
}
default: {
throw std::runtime_error("Unsupported constant graph type");
}
}
function = std::make_shared<ov::Model>(results, params, "ConstResult");
} }
void ConstantResultSubgraphTest::SetUp() { void ConstantResultSubgraphTest::SetUp() {
ConstantSubgraphType type; ConstantSubgraphType type;
SizeVector IS; InferenceEngine::SizeVector IS;
Precision inputPrecision; InferenceEngine::Precision inputPrecision;
std::tie(type, IS, inputPrecision, targetDevice) = this->GetParam(); std::tie(type, IS, inputPrecision, targetDevice) = this->GetParam();
createGraph(type, IS, inputPrecision); createGraph(type, IS, inputPrecision);