[CPU Plugin][Func Test] Upgrade FuseMulAddAndEwSimpleTest to API 2.0 (#21331)

Signed-off-by: Zhai, Xuejun <xuejun.zhai@intel.com>
This commit is contained in:
Xuejun Zhai 2023-11-28 19:06:20 +08:00 committed by GitHub
parent 7ab79be0f6
commit 409ed190de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 72 additions and 66 deletions

View File

@ -4,23 +4,23 @@
#pragma once
#include <string>
#include <tuple>
#include <vector>
#include <string>
#include "shared_test_classes/base/ov_subgraph.hpp"
#include "test_utils/cpu_test_utils.hpp"
#include "shared_test_classes/base/layer_test_utils.hpp"
namespace ov {
namespace test {
namespace SubgraphTestsDefinitions {
using FuseMulAddAndEwSimpleParams = std::tuple<ov::Shape, // Input shape
ov::element::Type // Input precision
>;
using FuseMulAddAndEwSimpleParams = std::tuple<
InferenceEngine::SizeVector, // Input shape
InferenceEngine::Precision // Input precision
>;
class FuseMulAddAndEwSimpleTest : public testing::WithParamInterface<FuseMulAddAndEwSimpleParams>, public CPUTestUtils::CPUTestsBase,
virtual public LayerTestsUtils::LayerTestsCommon {
class FuseMulAddAndEwSimpleTest : public testing::WithParamInterface<FuseMulAddAndEwSimpleParams>,
public CPUTestUtils::CPUTestsBase,
virtual public SubgraphBaseStaticTest {
public:
static std::string getTestCaseName(testing::TestParamInfo<FuseMulAddAndEwSimpleParams> obj);
@ -28,8 +28,8 @@ protected:
void SetUp() override;
virtual void CreateGraph() = 0;
InferenceEngine::SizeVector inputShape;
InferenceEngine::Precision inPrec;
ov::Shape inputShape;
ov::element::Type inPrec;
};
class FuseMulAddAndEwSimpleTest1 : public FuseMulAddAndEwSimpleTest {
@ -47,4 +47,5 @@ protected:
void CreateGraph() override;
};
} // namespace SubgraphTestsDefinitions
} // namespace test
} // namespace ov

View File

@ -3,23 +3,26 @@
//
#include "subgraph_tests/include/fuse_muladd_ewsimple.hpp"
#include "common_test_utils/node_builders/activation.hpp"
#include "common_test_utils/node_builders/eltwise.hpp"
#include "ov_models/builders.hpp"
using namespace InferenceEngine;
using namespace CPUTestUtils;
using ngraph::helpers::EltwiseTypes;
using ngraph::helpers::ActivationTypes;
using ov::test::utils::ActivationTypes;
using ov::test::utils::EltwiseTypes;
namespace SubgraphTestsDefinitions {
namespace ov {
namespace test {
std::string FuseMulAddAndEwSimpleTest::getTestCaseName(testing::TestParamInfo<FuseMulAddAndEwSimpleParams> obj) {
std::ostringstream result;
SizeVector inputShape;
Precision inPrec;
ov::Shape inputShape;
ov::element::Type inPrec;
std::tie(inputShape, inPrec) = obj.param;
result << "IS=" << ov::test::utils::vec2str(inputShape) << "_";
result << "Precision=" << inPrec.name();
result << "IS=" << inputShape << "_";
result << "Precision=" << inPrec.get_type_name();
return result.str();
}
@ -31,83 +34,85 @@ void FuseMulAddAndEwSimpleTest::SetUp() {
CreateGraph();
}
const auto mulAddAndEwSimpleCommonParams = ::testing::Combine(
::testing::Values(SizeVector{1, 20}),
::testing::Values(Precision::FP32)
);
const auto mulAddAndEwSimpleCommonParams =
::testing::Combine(::testing::Values(ov::Shape({1, 20})), ::testing::Values(ov::element::f32));
// Fused EltwiseAndSimple comes on the 3rd port into MulAdd
void FuseMulAddAndEwSimpleTest1::CreateGraph() {
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(inPrec);
auto mulSecondInput = inputShape;
mulSecondInput[0] = 1;
ov::ParameterVector params{std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::Shape(inputShape)),
std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::Shape(inputShape)),
std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::Shape(mulSecondInput))};
ov::ParameterVector params{std::make_shared<ov::op::v0::Parameter>(inPrec, inputShape),
std::make_shared<ov::op::v0::Parameter>(inPrec, inputShape),
std::make_shared<ov::op::v0::Parameter>(inPrec, mulSecondInput)};
auto clamp = ngraph::builder::makeActivation(params[0], ngPrc, ActivationTypes::Clamp, inputShape, {0, 100});
auto tanh = ngraph::builder::makeActivation(clamp, ngPrc, ActivationTypes::Tanh);
auto mul1 = ngraph::builder::makeEltwise(params[1], params[2], EltwiseTypes::MULTIPLY);
auto add = ngraph::builder::makeEltwise(tanh, mul1, EltwiseTypes::ADD);
auto clamp = ov::test::utils::make_activation(params[0], inPrec, ActivationTypes::Clamp, inputShape, {0, 100});
auto tanh = ov::test::utils::make_activation(clamp, inPrec, ActivationTypes::Tanh);
auto mul1 = ov::test::utils::makeEltwise(params[1], params[2], EltwiseTypes::MULTIPLY);
auto add = ov::test::utils::makeEltwise(tanh, mul1, EltwiseTypes::ADD);
ngraph::ResultVector results{std::make_shared<ngraph::opset5::Result>(add)};
function = std::make_shared<ngraph::Function>(results, params, "MulAdd_EwSimple");
ov::ResultVector results{std::make_shared<ov::op::v0::Result>(add)};
function = std::make_shared<ov::Model>(results, params, "MulAdd_EwSimple");
}
TEST_P(FuseMulAddAndEwSimpleTest1, CompareWithRefs) {
Run();
run();
}
INSTANTIATE_TEST_SUITE_P(smoke_Basic, FuseMulAddAndEwSimpleTest1, mulAddAndEwSimpleCommonParams, FuseMulAddAndEwSimpleTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_Basic,
FuseMulAddAndEwSimpleTest1,
mulAddAndEwSimpleCommonParams,
FuseMulAddAndEwSimpleTest::getTestCaseName);
// Fused EltwiseAndSimple comes on the 2nd input into MulAdd
void FuseMulAddAndEwSimpleTest2::CreateGraph() {
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(inPrec);
ov::ParameterVector params{std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::Shape(inputShape)),
std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::Shape(inputShape)),
std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::Shape(inputShape))};
ov::ParameterVector params{std::make_shared<ov::op::v0::Parameter>(inPrec, inputShape),
std::make_shared<ov::op::v0::Parameter>(inPrec, inputShape),
std::make_shared<ov::op::v0::Parameter>(inPrec, inputShape)};
auto clamp1 = ngraph::builder::makeActivation(params[0], ngPrc, ActivationTypes::Clamp, inputShape, {0, 100});
auto tanh1 = ngraph::builder::makeActivation(clamp1, ngPrc, ActivationTypes::Tanh);
auto clamp2 = ngraph::builder::makeActivation(params[1], ngPrc, ActivationTypes::Clamp, inputShape, {0, 100});
auto tanh2 = ngraph::builder::makeActivation(clamp2, ngPrc, ActivationTypes::Tanh);
auto mul1 = ngraph::builder::makeEltwise(tanh2, tanh1, EltwiseTypes::MULTIPLY);
auto add = ngraph::builder::makeEltwise(mul1, params[2], EltwiseTypes::ADD);
auto clamp1 = ov::test::utils::make_activation(params[0], inPrec, ActivationTypes::Clamp, inputShape, {0, 100});
auto tanh1 = ov::test::utils::make_activation(clamp1, inPrec, ActivationTypes::Tanh);
auto clamp2 = ov::test::utils::make_activation(params[1], inPrec, ActivationTypes::Clamp, inputShape, {0, 100});
auto tanh2 = ov::test::utils::make_activation(clamp2, inPrec, ActivationTypes::Tanh);
auto mul1 = ov::test::utils::makeEltwise(tanh2, tanh1, EltwiseTypes::MULTIPLY);
auto add = ov::test::utils::makeEltwise(mul1, params[2], EltwiseTypes::ADD);
ngraph::ResultVector results{std::make_shared<ngraph::opset5::Result>(add)};
function = std::make_shared<ngraph::Function>(results, params, "MulAdd_EwSimple_2");
ov::ResultVector results{std::make_shared<ov::op::v0::Result>(add)};
function = std::make_shared<ov::Model>(results, params, "MulAdd_EwSimple_2");
}
TEST_P(FuseMulAddAndEwSimpleTest2, CompareWithRefs) {
Run();
run();
}
INSTANTIATE_TEST_SUITE_P(smoke_Basic, FuseMulAddAndEwSimpleTest2, mulAddAndEwSimpleCommonParams, FuseMulAddAndEwSimpleTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_Basic,
FuseMulAddAndEwSimpleTest2,
mulAddAndEwSimpleCommonParams,
FuseMulAddAndEwSimpleTest::getTestCaseName);
// Fused MulAdd with more than 3 inputs
void FuseMulAddAndEwSimpleTest3::CreateGraph() {
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(inPrec);
ov::ParameterVector params;
for (auto&& shape : {inputShape, inputShape, inputShape, inputShape, inputShape}) {
params.push_back(std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::Shape(shape)));
params.push_back(std::make_shared<ov::op::v0::Parameter>(inPrec, shape));
}
auto mul1 = ngraph::builder::makeEltwise(params[0], params[1], EltwiseTypes::MULTIPLY);
auto add1 = ngraph::builder::makeEltwise(mul1, params[2], EltwiseTypes::ADD);
auto tanh1 = ngraph::builder::makeActivation(add1, ngPrc, ActivationTypes::Tanh);
auto mul2 = ngraph::builder::makeEltwise(tanh1, params[3], EltwiseTypes::MULTIPLY);
auto add2 = ngraph::builder::makeEltwise(params[4], mul2, EltwiseTypes::ADD);
auto mul1 = ov::test::utils::makeEltwise(params[0], params[1], EltwiseTypes::MULTIPLY);
auto add1 = ov::test::utils::makeEltwise(mul1, params[2], EltwiseTypes::ADD);
auto tanh1 = ov::test::utils::make_activation(add1, inPrec, ActivationTypes::Tanh);
auto mul2 = ov::test::utils::makeEltwise(tanh1, params[3], EltwiseTypes::MULTIPLY);
auto add2 = ov::test::utils::makeEltwise(params[4], mul2, EltwiseTypes::ADD);
ngraph::ResultVector results{std::make_shared<ngraph::opset5::Result>(add2)};
function = std::make_shared<ngraph::Function>(results, params, "MulAdd_EwSimple_3");
ov::ResultVector results{std::make_shared<ov::op::v0::Result>(add2)};
function = std::make_shared<ov::Model>(results, params, "MulAdd_EwSimple_3");
}
TEST_P(FuseMulAddAndEwSimpleTest3, CompareWithRefs) {
Run();
run();
}
INSTANTIATE_TEST_SUITE_P(smoke_Basic, FuseMulAddAndEwSimpleTest3, mulAddAndEwSimpleCommonParams, FuseMulAddAndEwSimpleTest::getTestCaseName);
} // namespace SubgraphTestsDefinitions
INSTANTIATE_TEST_SUITE_P(smoke_Basic,
FuseMulAddAndEwSimpleTest3,
mulAddAndEwSimpleCommonParams,
FuseMulAddAndEwSimpleTest::getTestCaseName);
} // namespace test
} // namespace ov