[CPU Plugin][Func Test] Upgrade FuseMulAddAndEwSimpleTest to API 2.0 (#21331)

Signed-off-by: Zhai, Xuejun <xuejun.zhai@intel.com>
This commit is contained in:
Xuejun Zhai 2023-11-28 19:06:20 +08:00 committed by GitHub
parent 7ab79be0f6
commit 409ed190de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 72 additions and 66 deletions

View File

@ -4,23 +4,23 @@
#pragma once #pragma once
#include <string>
#include <tuple> #include <tuple>
#include <vector> #include <vector>
#include <string>
#include "shared_test_classes/base/ov_subgraph.hpp"
#include "test_utils/cpu_test_utils.hpp" #include "test_utils/cpu_test_utils.hpp"
#include "shared_test_classes/base/layer_test_utils.hpp"
namespace ov {
namespace test {
namespace SubgraphTestsDefinitions { using FuseMulAddAndEwSimpleParams = std::tuple<ov::Shape, // Input shape
ov::element::Type // Input precision
>;
using FuseMulAddAndEwSimpleParams = std::tuple< class FuseMulAddAndEwSimpleTest : public testing::WithParamInterface<FuseMulAddAndEwSimpleParams>,
InferenceEngine::SizeVector, // Input shape public CPUTestUtils::CPUTestsBase,
InferenceEngine::Precision // Input precision virtual public SubgraphBaseStaticTest {
>;
class FuseMulAddAndEwSimpleTest : public testing::WithParamInterface<FuseMulAddAndEwSimpleParams>, public CPUTestUtils::CPUTestsBase,
virtual public LayerTestsUtils::LayerTestsCommon {
public: public:
static std::string getTestCaseName(testing::TestParamInfo<FuseMulAddAndEwSimpleParams> obj); static std::string getTestCaseName(testing::TestParamInfo<FuseMulAddAndEwSimpleParams> obj);
@ -28,8 +28,8 @@ protected:
void SetUp() override; void SetUp() override;
virtual void CreateGraph() = 0; virtual void CreateGraph() = 0;
InferenceEngine::SizeVector inputShape; ov::Shape inputShape;
InferenceEngine::Precision inPrec; ov::element::Type inPrec;
}; };
class FuseMulAddAndEwSimpleTest1 : public FuseMulAddAndEwSimpleTest { class FuseMulAddAndEwSimpleTest1 : public FuseMulAddAndEwSimpleTest {
@ -47,4 +47,5 @@ protected:
void CreateGraph() override; void CreateGraph() override;
}; };
} // namespace SubgraphTestsDefinitions } // namespace test
} // namespace ov

View File

@ -3,23 +3,26 @@
// //
#include "subgraph_tests/include/fuse_muladd_ewsimple.hpp" #include "subgraph_tests/include/fuse_muladd_ewsimple.hpp"
#include "common_test_utils/node_builders/activation.hpp"
#include "common_test_utils/node_builders/eltwise.hpp"
#include "ov_models/builders.hpp" #include "ov_models/builders.hpp"
using namespace InferenceEngine;
using namespace CPUTestUtils; using namespace CPUTestUtils;
using ngraph::helpers::EltwiseTypes; using ov::test::utils::ActivationTypes;
using ngraph::helpers::ActivationTypes; using ov::test::utils::EltwiseTypes;
namespace SubgraphTestsDefinitions { namespace ov {
namespace test {
std::string FuseMulAddAndEwSimpleTest::getTestCaseName(testing::TestParamInfo<FuseMulAddAndEwSimpleParams> obj) { std::string FuseMulAddAndEwSimpleTest::getTestCaseName(testing::TestParamInfo<FuseMulAddAndEwSimpleParams> obj) {
std::ostringstream result; std::ostringstream result;
SizeVector inputShape; ov::Shape inputShape;
Precision inPrec; ov::element::Type inPrec;
std::tie(inputShape, inPrec) = obj.param; std::tie(inputShape, inPrec) = obj.param;
result << "IS=" << ov::test::utils::vec2str(inputShape) << "_"; result << "IS=" << inputShape << "_";
result << "Precision=" << inPrec.name(); result << "Precision=" << inPrec.get_type_name();
return result.str(); return result.str();
} }
@ -31,83 +34,85 @@ void FuseMulAddAndEwSimpleTest::SetUp() {
CreateGraph(); CreateGraph();
} }
const auto mulAddAndEwSimpleCommonParams = ::testing::Combine( const auto mulAddAndEwSimpleCommonParams =
::testing::Values(SizeVector{1, 20}), ::testing::Combine(::testing::Values(ov::Shape({1, 20})), ::testing::Values(ov::element::f32));
::testing::Values(Precision::FP32)
);
// Fused EltwiseAndSimple comes on the 3rd port into MulAdd // Fused EltwiseAndSimple comes on the 3rd port into MulAdd
void FuseMulAddAndEwSimpleTest1::CreateGraph() { void FuseMulAddAndEwSimpleTest1::CreateGraph() {
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(inPrec);
auto mulSecondInput = inputShape; auto mulSecondInput = inputShape;
mulSecondInput[0] = 1; mulSecondInput[0] = 1;
ov::ParameterVector params{std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::Shape(inputShape)), ov::ParameterVector params{std::make_shared<ov::op::v0::Parameter>(inPrec, inputShape),
std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::Shape(inputShape)), std::make_shared<ov::op::v0::Parameter>(inPrec, inputShape),
std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::Shape(mulSecondInput))}; std::make_shared<ov::op::v0::Parameter>(inPrec, mulSecondInput)};
auto clamp = ngraph::builder::makeActivation(params[0], ngPrc, ActivationTypes::Clamp, inputShape, {0, 100}); auto clamp = ov::test::utils::make_activation(params[0], inPrec, ActivationTypes::Clamp, inputShape, {0, 100});
auto tanh = ngraph::builder::makeActivation(clamp, ngPrc, ActivationTypes::Tanh); auto tanh = ov::test::utils::make_activation(clamp, inPrec, ActivationTypes::Tanh);
auto mul1 = ngraph::builder::makeEltwise(params[1], params[2], EltwiseTypes::MULTIPLY); auto mul1 = ov::test::utils::makeEltwise(params[1], params[2], EltwiseTypes::MULTIPLY);
auto add = ngraph::builder::makeEltwise(tanh, mul1, EltwiseTypes::ADD); auto add = ov::test::utils::makeEltwise(tanh, mul1, EltwiseTypes::ADD);
ngraph::ResultVector results{std::make_shared<ngraph::opset5::Result>(add)}; ov::ResultVector results{std::make_shared<ov::op::v0::Result>(add)};
function = std::make_shared<ngraph::Function>(results, params, "MulAdd_EwSimple"); function = std::make_shared<ov::Model>(results, params, "MulAdd_EwSimple");
} }
TEST_P(FuseMulAddAndEwSimpleTest1, CompareWithRefs) { TEST_P(FuseMulAddAndEwSimpleTest1, CompareWithRefs) {
Run(); run();
} }
INSTANTIATE_TEST_SUITE_P(smoke_Basic, FuseMulAddAndEwSimpleTest1, mulAddAndEwSimpleCommonParams, FuseMulAddAndEwSimpleTest::getTestCaseName); INSTANTIATE_TEST_SUITE_P(smoke_Basic,
FuseMulAddAndEwSimpleTest1,
mulAddAndEwSimpleCommonParams,
FuseMulAddAndEwSimpleTest::getTestCaseName);
// Fused EltwiseAndSimple comes on the 2nd input into MulAdd // Fused EltwiseAndSimple comes on the 2nd input into MulAdd
void FuseMulAddAndEwSimpleTest2::CreateGraph() { void FuseMulAddAndEwSimpleTest2::CreateGraph() {
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(inPrec); ov::ParameterVector params{std::make_shared<ov::op::v0::Parameter>(inPrec, inputShape),
ov::ParameterVector params{std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::Shape(inputShape)), std::make_shared<ov::op::v0::Parameter>(inPrec, inputShape),
std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::Shape(inputShape)), std::make_shared<ov::op::v0::Parameter>(inPrec, inputShape)};
std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::Shape(inputShape))};
auto clamp1 = ngraph::builder::makeActivation(params[0], ngPrc, ActivationTypes::Clamp, inputShape, {0, 100}); auto clamp1 = ov::test::utils::make_activation(params[0], inPrec, ActivationTypes::Clamp, inputShape, {0, 100});
auto tanh1 = ngraph::builder::makeActivation(clamp1, ngPrc, ActivationTypes::Tanh); auto tanh1 = ov::test::utils::make_activation(clamp1, inPrec, ActivationTypes::Tanh);
auto clamp2 = ngraph::builder::makeActivation(params[1], ngPrc, ActivationTypes::Clamp, inputShape, {0, 100}); auto clamp2 = ov::test::utils::make_activation(params[1], inPrec, ActivationTypes::Clamp, inputShape, {0, 100});
auto tanh2 = ngraph::builder::makeActivation(clamp2, ngPrc, ActivationTypes::Tanh); auto tanh2 = ov::test::utils::make_activation(clamp2, inPrec, ActivationTypes::Tanh);
auto mul1 = ngraph::builder::makeEltwise(tanh2, tanh1, EltwiseTypes::MULTIPLY); auto mul1 = ov::test::utils::makeEltwise(tanh2, tanh1, EltwiseTypes::MULTIPLY);
auto add = ngraph::builder::makeEltwise(mul1, params[2], EltwiseTypes::ADD); auto add = ov::test::utils::makeEltwise(mul1, params[2], EltwiseTypes::ADD);
ngraph::ResultVector results{std::make_shared<ngraph::opset5::Result>(add)}; ov::ResultVector results{std::make_shared<ov::op::v0::Result>(add)};
function = std::make_shared<ngraph::Function>(results, params, "MulAdd_EwSimple_2"); function = std::make_shared<ov::Model>(results, params, "MulAdd_EwSimple_2");
} }
TEST_P(FuseMulAddAndEwSimpleTest2, CompareWithRefs) { TEST_P(FuseMulAddAndEwSimpleTest2, CompareWithRefs) {
Run(); run();
} }
INSTANTIATE_TEST_SUITE_P(smoke_Basic, FuseMulAddAndEwSimpleTest2, mulAddAndEwSimpleCommonParams, FuseMulAddAndEwSimpleTest::getTestCaseName); INSTANTIATE_TEST_SUITE_P(smoke_Basic,
FuseMulAddAndEwSimpleTest2,
mulAddAndEwSimpleCommonParams,
FuseMulAddAndEwSimpleTest::getTestCaseName);
// Fused MulAdd with more than 3 inputs // Fused MulAdd with more than 3 inputs
void FuseMulAddAndEwSimpleTest3::CreateGraph() { void FuseMulAddAndEwSimpleTest3::CreateGraph() {
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(inPrec);
ov::ParameterVector params; ov::ParameterVector params;
for (auto&& shape : {inputShape, inputShape, inputShape, inputShape, inputShape}) { for (auto&& shape : {inputShape, inputShape, inputShape, inputShape, inputShape}) {
params.push_back(std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::Shape(shape))); params.push_back(std::make_shared<ov::op::v0::Parameter>(inPrec, shape));
} }
auto mul1 = ngraph::builder::makeEltwise(params[0], params[1], EltwiseTypes::MULTIPLY); auto mul1 = ov::test::utils::makeEltwise(params[0], params[1], EltwiseTypes::MULTIPLY);
auto add1 = ngraph::builder::makeEltwise(mul1, params[2], EltwiseTypes::ADD); auto add1 = ov::test::utils::makeEltwise(mul1, params[2], EltwiseTypes::ADD);
auto tanh1 = ngraph::builder::makeActivation(add1, ngPrc, ActivationTypes::Tanh); auto tanh1 = ov::test::utils::make_activation(add1, inPrec, ActivationTypes::Tanh);
auto mul2 = ngraph::builder::makeEltwise(tanh1, params[3], EltwiseTypes::MULTIPLY); auto mul2 = ov::test::utils::makeEltwise(tanh1, params[3], EltwiseTypes::MULTIPLY);
auto add2 = ngraph::builder::makeEltwise(params[4], mul2, EltwiseTypes::ADD); auto add2 = ov::test::utils::makeEltwise(params[4], mul2, EltwiseTypes::ADD);
ngraph::ResultVector results{std::make_shared<ngraph::opset5::Result>(add2)}; ov::ResultVector results{std::make_shared<ov::op::v0::Result>(add2)};
function = std::make_shared<ngraph::Function>(results, params, "MulAdd_EwSimple_3"); function = std::make_shared<ov::Model>(results, params, "MulAdd_EwSimple_3");
} }
TEST_P(FuseMulAddAndEwSimpleTest3, CompareWithRefs) { TEST_P(FuseMulAddAndEwSimpleTest3, CompareWithRefs) {
Run(); run();
} }
INSTANTIATE_TEST_SUITE_P(smoke_Basic, FuseMulAddAndEwSimpleTest3, mulAddAndEwSimpleCommonParams, FuseMulAddAndEwSimpleTest::getTestCaseName); INSTANTIATE_TEST_SUITE_P(smoke_Basic,
} // namespace SubgraphTestsDefinitions FuseMulAddAndEwSimpleTest3,
mulAddAndEwSimpleCommonParams,
FuseMulAddAndEwSimpleTest::getTestCaseName);
} // namespace test
} // namespace ov