[CPU Plugin][Func Test] Upgrade FuseMulAddAndEwSimpleTest to API 2.0 (#21331)
Signed-off-by: Zhai, Xuejun <xuejun.zhai@intel.com>
This commit is contained in:
parent
7ab79be0f6
commit
409ed190de
@ -4,23 +4,23 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <string>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
|
||||||
|
|
||||||
|
#include "shared_test_classes/base/ov_subgraph.hpp"
|
||||||
#include "test_utils/cpu_test_utils.hpp"
|
#include "test_utils/cpu_test_utils.hpp"
|
||||||
#include "shared_test_classes/base/layer_test_utils.hpp"
|
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace test {
|
||||||
|
|
||||||
namespace SubgraphTestsDefinitions {
|
using FuseMulAddAndEwSimpleParams = std::tuple<ov::Shape, // Input shape
|
||||||
|
ov::element::Type // Input precision
|
||||||
|
>;
|
||||||
|
|
||||||
using FuseMulAddAndEwSimpleParams = std::tuple<
|
class FuseMulAddAndEwSimpleTest : public testing::WithParamInterface<FuseMulAddAndEwSimpleParams>,
|
||||||
InferenceEngine::SizeVector, // Input shape
|
public CPUTestUtils::CPUTestsBase,
|
||||||
InferenceEngine::Precision // Input precision
|
virtual public SubgraphBaseStaticTest {
|
||||||
>;
|
|
||||||
|
|
||||||
class FuseMulAddAndEwSimpleTest : public testing::WithParamInterface<FuseMulAddAndEwSimpleParams>, public CPUTestUtils::CPUTestsBase,
|
|
||||||
virtual public LayerTestsUtils::LayerTestsCommon {
|
|
||||||
public:
|
public:
|
||||||
static std::string getTestCaseName(testing::TestParamInfo<FuseMulAddAndEwSimpleParams> obj);
|
static std::string getTestCaseName(testing::TestParamInfo<FuseMulAddAndEwSimpleParams> obj);
|
||||||
|
|
||||||
@ -28,8 +28,8 @@ protected:
|
|||||||
void SetUp() override;
|
void SetUp() override;
|
||||||
virtual void CreateGraph() = 0;
|
virtual void CreateGraph() = 0;
|
||||||
|
|
||||||
InferenceEngine::SizeVector inputShape;
|
ov::Shape inputShape;
|
||||||
InferenceEngine::Precision inPrec;
|
ov::element::Type inPrec;
|
||||||
};
|
};
|
||||||
|
|
||||||
class FuseMulAddAndEwSimpleTest1 : public FuseMulAddAndEwSimpleTest {
|
class FuseMulAddAndEwSimpleTest1 : public FuseMulAddAndEwSimpleTest {
|
||||||
@ -47,4 +47,5 @@ protected:
|
|||||||
void CreateGraph() override;
|
void CreateGraph() override;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace SubgraphTestsDefinitions
|
} // namespace test
|
||||||
|
} // namespace ov
|
||||||
|
@ -3,23 +3,26 @@
|
|||||||
//
|
//
|
||||||
|
|
||||||
#include "subgraph_tests/include/fuse_muladd_ewsimple.hpp"
|
#include "subgraph_tests/include/fuse_muladd_ewsimple.hpp"
|
||||||
|
|
||||||
|
#include "common_test_utils/node_builders/activation.hpp"
|
||||||
|
#include "common_test_utils/node_builders/eltwise.hpp"
|
||||||
#include "ov_models/builders.hpp"
|
#include "ov_models/builders.hpp"
|
||||||
|
|
||||||
using namespace InferenceEngine;
|
|
||||||
using namespace CPUTestUtils;
|
using namespace CPUTestUtils;
|
||||||
using ngraph::helpers::EltwiseTypes;
|
using ov::test::utils::ActivationTypes;
|
||||||
using ngraph::helpers::ActivationTypes;
|
using ov::test::utils::EltwiseTypes;
|
||||||
|
|
||||||
namespace SubgraphTestsDefinitions {
|
namespace ov {
|
||||||
|
namespace test {
|
||||||
|
|
||||||
std::string FuseMulAddAndEwSimpleTest::getTestCaseName(testing::TestParamInfo<FuseMulAddAndEwSimpleParams> obj) {
|
std::string FuseMulAddAndEwSimpleTest::getTestCaseName(testing::TestParamInfo<FuseMulAddAndEwSimpleParams> obj) {
|
||||||
std::ostringstream result;
|
std::ostringstream result;
|
||||||
SizeVector inputShape;
|
ov::Shape inputShape;
|
||||||
Precision inPrec;
|
ov::element::Type inPrec;
|
||||||
std::tie(inputShape, inPrec) = obj.param;
|
std::tie(inputShape, inPrec) = obj.param;
|
||||||
|
|
||||||
result << "IS=" << ov::test::utils::vec2str(inputShape) << "_";
|
result << "IS=" << inputShape << "_";
|
||||||
result << "Precision=" << inPrec.name();
|
result << "Precision=" << inPrec.get_type_name();
|
||||||
|
|
||||||
return result.str();
|
return result.str();
|
||||||
}
|
}
|
||||||
@ -31,83 +34,85 @@ void FuseMulAddAndEwSimpleTest::SetUp() {
|
|||||||
CreateGraph();
|
CreateGraph();
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto mulAddAndEwSimpleCommonParams = ::testing::Combine(
|
const auto mulAddAndEwSimpleCommonParams =
|
||||||
::testing::Values(SizeVector{1, 20}),
|
::testing::Combine(::testing::Values(ov::Shape({1, 20})), ::testing::Values(ov::element::f32));
|
||||||
::testing::Values(Precision::FP32)
|
|
||||||
);
|
|
||||||
|
|
||||||
|
|
||||||
// Fused EltwiseAndSimple comes on the 3rd port into MulAdd
|
// Fused EltwiseAndSimple comes on the 3rd port into MulAdd
|
||||||
void FuseMulAddAndEwSimpleTest1::CreateGraph() {
|
void FuseMulAddAndEwSimpleTest1::CreateGraph() {
|
||||||
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(inPrec);
|
|
||||||
auto mulSecondInput = inputShape;
|
auto mulSecondInput = inputShape;
|
||||||
mulSecondInput[0] = 1;
|
mulSecondInput[0] = 1;
|
||||||
ov::ParameterVector params{std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::Shape(inputShape)),
|
ov::ParameterVector params{std::make_shared<ov::op::v0::Parameter>(inPrec, inputShape),
|
||||||
std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::Shape(inputShape)),
|
std::make_shared<ov::op::v0::Parameter>(inPrec, inputShape),
|
||||||
std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::Shape(mulSecondInput))};
|
std::make_shared<ov::op::v0::Parameter>(inPrec, mulSecondInput)};
|
||||||
|
|
||||||
auto clamp = ngraph::builder::makeActivation(params[0], ngPrc, ActivationTypes::Clamp, inputShape, {0, 100});
|
auto clamp = ov::test::utils::make_activation(params[0], inPrec, ActivationTypes::Clamp, inputShape, {0, 100});
|
||||||
auto tanh = ngraph::builder::makeActivation(clamp, ngPrc, ActivationTypes::Tanh);
|
auto tanh = ov::test::utils::make_activation(clamp, inPrec, ActivationTypes::Tanh);
|
||||||
auto mul1 = ngraph::builder::makeEltwise(params[1], params[2], EltwiseTypes::MULTIPLY);
|
auto mul1 = ov::test::utils::makeEltwise(params[1], params[2], EltwiseTypes::MULTIPLY);
|
||||||
auto add = ngraph::builder::makeEltwise(tanh, mul1, EltwiseTypes::ADD);
|
auto add = ov::test::utils::makeEltwise(tanh, mul1, EltwiseTypes::ADD);
|
||||||
|
|
||||||
ngraph::ResultVector results{std::make_shared<ngraph::opset5::Result>(add)};
|
ov::ResultVector results{std::make_shared<ov::op::v0::Result>(add)};
|
||||||
function = std::make_shared<ngraph::Function>(results, params, "MulAdd_EwSimple");
|
function = std::make_shared<ov::Model>(results, params, "MulAdd_EwSimple");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(FuseMulAddAndEwSimpleTest1, CompareWithRefs) {
|
TEST_P(FuseMulAddAndEwSimpleTest1, CompareWithRefs) {
|
||||||
Run();
|
run();
|
||||||
}
|
}
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(smoke_Basic, FuseMulAddAndEwSimpleTest1, mulAddAndEwSimpleCommonParams, FuseMulAddAndEwSimpleTest::getTestCaseName);
|
INSTANTIATE_TEST_SUITE_P(smoke_Basic,
|
||||||
|
FuseMulAddAndEwSimpleTest1,
|
||||||
|
mulAddAndEwSimpleCommonParams,
|
||||||
|
FuseMulAddAndEwSimpleTest::getTestCaseName);
|
||||||
|
|
||||||
// Fused EltwiseAndSimple comes on the 2nd input into MulAdd
|
// Fused EltwiseAndSimple comes on the 2nd input into MulAdd
|
||||||
void FuseMulAddAndEwSimpleTest2::CreateGraph() {
|
void FuseMulAddAndEwSimpleTest2::CreateGraph() {
|
||||||
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(inPrec);
|
ov::ParameterVector params{std::make_shared<ov::op::v0::Parameter>(inPrec, inputShape),
|
||||||
ov::ParameterVector params{std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::Shape(inputShape)),
|
std::make_shared<ov::op::v0::Parameter>(inPrec, inputShape),
|
||||||
std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::Shape(inputShape)),
|
std::make_shared<ov::op::v0::Parameter>(inPrec, inputShape)};
|
||||||
std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::Shape(inputShape))};
|
|
||||||
|
|
||||||
auto clamp1 = ngraph::builder::makeActivation(params[0], ngPrc, ActivationTypes::Clamp, inputShape, {0, 100});
|
auto clamp1 = ov::test::utils::make_activation(params[0], inPrec, ActivationTypes::Clamp, inputShape, {0, 100});
|
||||||
auto tanh1 = ngraph::builder::makeActivation(clamp1, ngPrc, ActivationTypes::Tanh);
|
auto tanh1 = ov::test::utils::make_activation(clamp1, inPrec, ActivationTypes::Tanh);
|
||||||
auto clamp2 = ngraph::builder::makeActivation(params[1], ngPrc, ActivationTypes::Clamp, inputShape, {0, 100});
|
auto clamp2 = ov::test::utils::make_activation(params[1], inPrec, ActivationTypes::Clamp, inputShape, {0, 100});
|
||||||
auto tanh2 = ngraph::builder::makeActivation(clamp2, ngPrc, ActivationTypes::Tanh);
|
auto tanh2 = ov::test::utils::make_activation(clamp2, inPrec, ActivationTypes::Tanh);
|
||||||
auto mul1 = ngraph::builder::makeEltwise(tanh2, tanh1, EltwiseTypes::MULTIPLY);
|
auto mul1 = ov::test::utils::makeEltwise(tanh2, tanh1, EltwiseTypes::MULTIPLY);
|
||||||
auto add = ngraph::builder::makeEltwise(mul1, params[2], EltwiseTypes::ADD);
|
auto add = ov::test::utils::makeEltwise(mul1, params[2], EltwiseTypes::ADD);
|
||||||
|
|
||||||
ngraph::ResultVector results{std::make_shared<ngraph::opset5::Result>(add)};
|
ov::ResultVector results{std::make_shared<ov::op::v0::Result>(add)};
|
||||||
function = std::make_shared<ngraph::Function>(results, params, "MulAdd_EwSimple_2");
|
function = std::make_shared<ov::Model>(results, params, "MulAdd_EwSimple_2");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(FuseMulAddAndEwSimpleTest2, CompareWithRefs) {
|
TEST_P(FuseMulAddAndEwSimpleTest2, CompareWithRefs) {
|
||||||
Run();
|
run();
|
||||||
}
|
}
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(smoke_Basic, FuseMulAddAndEwSimpleTest2, mulAddAndEwSimpleCommonParams, FuseMulAddAndEwSimpleTest::getTestCaseName);
|
INSTANTIATE_TEST_SUITE_P(smoke_Basic,
|
||||||
|
FuseMulAddAndEwSimpleTest2,
|
||||||
|
mulAddAndEwSimpleCommonParams,
|
||||||
|
FuseMulAddAndEwSimpleTest::getTestCaseName);
|
||||||
|
|
||||||
// Fused MulAdd with more than 3 inputs
|
// Fused MulAdd with more than 3 inputs
|
||||||
void FuseMulAddAndEwSimpleTest3::CreateGraph() {
|
void FuseMulAddAndEwSimpleTest3::CreateGraph() {
|
||||||
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(inPrec);
|
|
||||||
ov::ParameterVector params;
|
ov::ParameterVector params;
|
||||||
for (auto&& shape : {inputShape, inputShape, inputShape, inputShape, inputShape}) {
|
for (auto&& shape : {inputShape, inputShape, inputShape, inputShape, inputShape}) {
|
||||||
params.push_back(std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::Shape(shape)));
|
params.push_back(std::make_shared<ov::op::v0::Parameter>(inPrec, shape));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto mul1 = ngraph::builder::makeEltwise(params[0], params[1], EltwiseTypes::MULTIPLY);
|
auto mul1 = ov::test::utils::makeEltwise(params[0], params[1], EltwiseTypes::MULTIPLY);
|
||||||
auto add1 = ngraph::builder::makeEltwise(mul1, params[2], EltwiseTypes::ADD);
|
auto add1 = ov::test::utils::makeEltwise(mul1, params[2], EltwiseTypes::ADD);
|
||||||
auto tanh1 = ngraph::builder::makeActivation(add1, ngPrc, ActivationTypes::Tanh);
|
auto tanh1 = ov::test::utils::make_activation(add1, inPrec, ActivationTypes::Tanh);
|
||||||
auto mul2 = ngraph::builder::makeEltwise(tanh1, params[3], EltwiseTypes::MULTIPLY);
|
auto mul2 = ov::test::utils::makeEltwise(tanh1, params[3], EltwiseTypes::MULTIPLY);
|
||||||
auto add2 = ngraph::builder::makeEltwise(params[4], mul2, EltwiseTypes::ADD);
|
auto add2 = ov::test::utils::makeEltwise(params[4], mul2, EltwiseTypes::ADD);
|
||||||
|
|
||||||
ngraph::ResultVector results{std::make_shared<ngraph::opset5::Result>(add2)};
|
ov::ResultVector results{std::make_shared<ov::op::v0::Result>(add2)};
|
||||||
function = std::make_shared<ngraph::Function>(results, params, "MulAdd_EwSimple_3");
|
function = std::make_shared<ov::Model>(results, params, "MulAdd_EwSimple_3");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(FuseMulAddAndEwSimpleTest3, CompareWithRefs) {
|
TEST_P(FuseMulAddAndEwSimpleTest3, CompareWithRefs) {
|
||||||
Run();
|
run();
|
||||||
}
|
}
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(smoke_Basic, FuseMulAddAndEwSimpleTest3, mulAddAndEwSimpleCommonParams, FuseMulAddAndEwSimpleTest::getTestCaseName);
|
INSTANTIATE_TEST_SUITE_P(smoke_Basic,
|
||||||
} // namespace SubgraphTestsDefinitions
|
FuseMulAddAndEwSimpleTest3,
|
||||||
|
mulAddAndEwSimpleCommonParams,
|
||||||
|
FuseMulAddAndEwSimpleTest::getTestCaseName);
|
||||||
|
} // namespace test
|
||||||
|
} // namespace ov
|
||||||
|
Loading…
Reference in New Issue
Block a user