[CPU Plugin][Func Test] Upgrade FuseMulAddAndEwSimpleTest to API 2.0 (#21331)
Signed-off-by: Zhai, Xuejun <xuejun.zhai@intel.com>
This commit is contained in:
parent
7ab79be0f6
commit
409ed190de
@ -4,23 +4,23 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "shared_test_classes/base/ov_subgraph.hpp"
|
||||
#include "test_utils/cpu_test_utils.hpp"
|
||||
#include "shared_test_classes/base/layer_test_utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace test {
|
||||
|
||||
namespace SubgraphTestsDefinitions {
|
||||
using FuseMulAddAndEwSimpleParams = std::tuple<ov::Shape, // Input shape
|
||||
ov::element::Type // Input precision
|
||||
>;
|
||||
|
||||
using FuseMulAddAndEwSimpleParams = std::tuple<
|
||||
InferenceEngine::SizeVector, // Input shape
|
||||
InferenceEngine::Precision // Input precision
|
||||
>;
|
||||
|
||||
class FuseMulAddAndEwSimpleTest : public testing::WithParamInterface<FuseMulAddAndEwSimpleParams>, public CPUTestUtils::CPUTestsBase,
|
||||
virtual public LayerTestsUtils::LayerTestsCommon {
|
||||
class FuseMulAddAndEwSimpleTest : public testing::WithParamInterface<FuseMulAddAndEwSimpleParams>,
|
||||
public CPUTestUtils::CPUTestsBase,
|
||||
virtual public SubgraphBaseStaticTest {
|
||||
public:
|
||||
static std::string getTestCaseName(testing::TestParamInfo<FuseMulAddAndEwSimpleParams> obj);
|
||||
|
||||
@ -28,8 +28,8 @@ protected:
|
||||
void SetUp() override;
|
||||
virtual void CreateGraph() = 0;
|
||||
|
||||
InferenceEngine::SizeVector inputShape;
|
||||
InferenceEngine::Precision inPrec;
|
||||
ov::Shape inputShape;
|
||||
ov::element::Type inPrec;
|
||||
};
|
||||
|
||||
class FuseMulAddAndEwSimpleTest1 : public FuseMulAddAndEwSimpleTest {
|
||||
@ -47,4 +47,5 @@ protected:
|
||||
void CreateGraph() override;
|
||||
};
|
||||
|
||||
} // namespace SubgraphTestsDefinitions
|
||||
} // namespace test
|
||||
} // namespace ov
|
||||
|
@ -3,23 +3,26 @@
|
||||
//
|
||||
|
||||
#include "subgraph_tests/include/fuse_muladd_ewsimple.hpp"
|
||||
|
||||
#include "common_test_utils/node_builders/activation.hpp"
|
||||
#include "common_test_utils/node_builders/eltwise.hpp"
|
||||
#include "ov_models/builders.hpp"
|
||||
|
||||
using namespace InferenceEngine;
|
||||
using namespace CPUTestUtils;
|
||||
using ngraph::helpers::EltwiseTypes;
|
||||
using ngraph::helpers::ActivationTypes;
|
||||
using ov::test::utils::ActivationTypes;
|
||||
using ov::test::utils::EltwiseTypes;
|
||||
|
||||
namespace SubgraphTestsDefinitions {
|
||||
namespace ov {
|
||||
namespace test {
|
||||
|
||||
std::string FuseMulAddAndEwSimpleTest::getTestCaseName(testing::TestParamInfo<FuseMulAddAndEwSimpleParams> obj) {
|
||||
std::ostringstream result;
|
||||
SizeVector inputShape;
|
||||
Precision inPrec;
|
||||
ov::Shape inputShape;
|
||||
ov::element::Type inPrec;
|
||||
std::tie(inputShape, inPrec) = obj.param;
|
||||
|
||||
result << "IS=" << ov::test::utils::vec2str(inputShape) << "_";
|
||||
result << "Precision=" << inPrec.name();
|
||||
result << "IS=" << inputShape << "_";
|
||||
result << "Precision=" << inPrec.get_type_name();
|
||||
|
||||
return result.str();
|
||||
}
|
||||
@ -31,83 +34,85 @@ void FuseMulAddAndEwSimpleTest::SetUp() {
|
||||
CreateGraph();
|
||||
}
|
||||
|
||||
const auto mulAddAndEwSimpleCommonParams = ::testing::Combine(
|
||||
::testing::Values(SizeVector{1, 20}),
|
||||
::testing::Values(Precision::FP32)
|
||||
);
|
||||
|
||||
const auto mulAddAndEwSimpleCommonParams =
|
||||
::testing::Combine(::testing::Values(ov::Shape({1, 20})), ::testing::Values(ov::element::f32));
|
||||
|
||||
// Fused EltwiseAndSimple comes on the 3rd port into MulAdd
|
||||
void FuseMulAddAndEwSimpleTest1::CreateGraph() {
|
||||
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(inPrec);
|
||||
auto mulSecondInput = inputShape;
|
||||
mulSecondInput[0] = 1;
|
||||
ov::ParameterVector params{std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::Shape(inputShape)),
|
||||
std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::Shape(inputShape)),
|
||||
std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::Shape(mulSecondInput))};
|
||||
ov::ParameterVector params{std::make_shared<ov::op::v0::Parameter>(inPrec, inputShape),
|
||||
std::make_shared<ov::op::v0::Parameter>(inPrec, inputShape),
|
||||
std::make_shared<ov::op::v0::Parameter>(inPrec, mulSecondInput)};
|
||||
|
||||
auto clamp = ngraph::builder::makeActivation(params[0], ngPrc, ActivationTypes::Clamp, inputShape, {0, 100});
|
||||
auto tanh = ngraph::builder::makeActivation(clamp, ngPrc, ActivationTypes::Tanh);
|
||||
auto mul1 = ngraph::builder::makeEltwise(params[1], params[2], EltwiseTypes::MULTIPLY);
|
||||
auto add = ngraph::builder::makeEltwise(tanh, mul1, EltwiseTypes::ADD);
|
||||
auto clamp = ov::test::utils::make_activation(params[0], inPrec, ActivationTypes::Clamp, inputShape, {0, 100});
|
||||
auto tanh = ov::test::utils::make_activation(clamp, inPrec, ActivationTypes::Tanh);
|
||||
auto mul1 = ov::test::utils::makeEltwise(params[1], params[2], EltwiseTypes::MULTIPLY);
|
||||
auto add = ov::test::utils::makeEltwise(tanh, mul1, EltwiseTypes::ADD);
|
||||
|
||||
ngraph::ResultVector results{std::make_shared<ngraph::opset5::Result>(add)};
|
||||
function = std::make_shared<ngraph::Function>(results, params, "MulAdd_EwSimple");
|
||||
ov::ResultVector results{std::make_shared<ov::op::v0::Result>(add)};
|
||||
function = std::make_shared<ov::Model>(results, params, "MulAdd_EwSimple");
|
||||
}
|
||||
|
||||
TEST_P(FuseMulAddAndEwSimpleTest1, CompareWithRefs) {
|
||||
Run();
|
||||
run();
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_Basic, FuseMulAddAndEwSimpleTest1, mulAddAndEwSimpleCommonParams, FuseMulAddAndEwSimpleTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_Basic,
|
||||
FuseMulAddAndEwSimpleTest1,
|
||||
mulAddAndEwSimpleCommonParams,
|
||||
FuseMulAddAndEwSimpleTest::getTestCaseName);
|
||||
|
||||
// Fused EltwiseAndSimple comes on the 2nd input into MulAdd
|
||||
void FuseMulAddAndEwSimpleTest2::CreateGraph() {
|
||||
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(inPrec);
|
||||
ov::ParameterVector params{std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::Shape(inputShape)),
|
||||
std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::Shape(inputShape)),
|
||||
std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::Shape(inputShape))};
|
||||
ov::ParameterVector params{std::make_shared<ov::op::v0::Parameter>(inPrec, inputShape),
|
||||
std::make_shared<ov::op::v0::Parameter>(inPrec, inputShape),
|
||||
std::make_shared<ov::op::v0::Parameter>(inPrec, inputShape)};
|
||||
|
||||
auto clamp1 = ngraph::builder::makeActivation(params[0], ngPrc, ActivationTypes::Clamp, inputShape, {0, 100});
|
||||
auto tanh1 = ngraph::builder::makeActivation(clamp1, ngPrc, ActivationTypes::Tanh);
|
||||
auto clamp2 = ngraph::builder::makeActivation(params[1], ngPrc, ActivationTypes::Clamp, inputShape, {0, 100});
|
||||
auto tanh2 = ngraph::builder::makeActivation(clamp2, ngPrc, ActivationTypes::Tanh);
|
||||
auto mul1 = ngraph::builder::makeEltwise(tanh2, tanh1, EltwiseTypes::MULTIPLY);
|
||||
auto add = ngraph::builder::makeEltwise(mul1, params[2], EltwiseTypes::ADD);
|
||||
auto clamp1 = ov::test::utils::make_activation(params[0], inPrec, ActivationTypes::Clamp, inputShape, {0, 100});
|
||||
auto tanh1 = ov::test::utils::make_activation(clamp1, inPrec, ActivationTypes::Tanh);
|
||||
auto clamp2 = ov::test::utils::make_activation(params[1], inPrec, ActivationTypes::Clamp, inputShape, {0, 100});
|
||||
auto tanh2 = ov::test::utils::make_activation(clamp2, inPrec, ActivationTypes::Tanh);
|
||||
auto mul1 = ov::test::utils::makeEltwise(tanh2, tanh1, EltwiseTypes::MULTIPLY);
|
||||
auto add = ov::test::utils::makeEltwise(mul1, params[2], EltwiseTypes::ADD);
|
||||
|
||||
ngraph::ResultVector results{std::make_shared<ngraph::opset5::Result>(add)};
|
||||
function = std::make_shared<ngraph::Function>(results, params, "MulAdd_EwSimple_2");
|
||||
ov::ResultVector results{std::make_shared<ov::op::v0::Result>(add)};
|
||||
function = std::make_shared<ov::Model>(results, params, "MulAdd_EwSimple_2");
|
||||
}
|
||||
|
||||
TEST_P(FuseMulAddAndEwSimpleTest2, CompareWithRefs) {
|
||||
Run();
|
||||
run();
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_Basic, FuseMulAddAndEwSimpleTest2, mulAddAndEwSimpleCommonParams, FuseMulAddAndEwSimpleTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_Basic,
|
||||
FuseMulAddAndEwSimpleTest2,
|
||||
mulAddAndEwSimpleCommonParams,
|
||||
FuseMulAddAndEwSimpleTest::getTestCaseName);
|
||||
|
||||
// Fused MulAdd with more than 3 inputs
|
||||
void FuseMulAddAndEwSimpleTest3::CreateGraph() {
|
||||
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(inPrec);
|
||||
ov::ParameterVector params;
|
||||
for (auto&& shape : {inputShape, inputShape, inputShape, inputShape, inputShape}) {
|
||||
params.push_back(std::make_shared<ov::op::v0::Parameter>(ngPrc, ov::Shape(shape)));
|
||||
params.push_back(std::make_shared<ov::op::v0::Parameter>(inPrec, shape));
|
||||
}
|
||||
|
||||
auto mul1 = ngraph::builder::makeEltwise(params[0], params[1], EltwiseTypes::MULTIPLY);
|
||||
auto add1 = ngraph::builder::makeEltwise(mul1, params[2], EltwiseTypes::ADD);
|
||||
auto tanh1 = ngraph::builder::makeActivation(add1, ngPrc, ActivationTypes::Tanh);
|
||||
auto mul2 = ngraph::builder::makeEltwise(tanh1, params[3], EltwiseTypes::MULTIPLY);
|
||||
auto add2 = ngraph::builder::makeEltwise(params[4], mul2, EltwiseTypes::ADD);
|
||||
auto mul1 = ov::test::utils::makeEltwise(params[0], params[1], EltwiseTypes::MULTIPLY);
|
||||
auto add1 = ov::test::utils::makeEltwise(mul1, params[2], EltwiseTypes::ADD);
|
||||
auto tanh1 = ov::test::utils::make_activation(add1, inPrec, ActivationTypes::Tanh);
|
||||
auto mul2 = ov::test::utils::makeEltwise(tanh1, params[3], EltwiseTypes::MULTIPLY);
|
||||
auto add2 = ov::test::utils::makeEltwise(params[4], mul2, EltwiseTypes::ADD);
|
||||
|
||||
ngraph::ResultVector results{std::make_shared<ngraph::opset5::Result>(add2)};
|
||||
function = std::make_shared<ngraph::Function>(results, params, "MulAdd_EwSimple_3");
|
||||
ov::ResultVector results{std::make_shared<ov::op::v0::Result>(add2)};
|
||||
function = std::make_shared<ov::Model>(results, params, "MulAdd_EwSimple_3");
|
||||
}
|
||||
|
||||
TEST_P(FuseMulAddAndEwSimpleTest3, CompareWithRefs) {
|
||||
Run();
|
||||
run();
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_Basic, FuseMulAddAndEwSimpleTest3, mulAddAndEwSimpleCommonParams, FuseMulAddAndEwSimpleTest::getTestCaseName);
|
||||
} // namespace SubgraphTestsDefinitions
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_Basic,
|
||||
FuseMulAddAndEwSimpleTest3,
|
||||
mulAddAndEwSimpleCommonParams,
|
||||
FuseMulAddAndEwSimpleTest::getTestCaseName);
|
||||
} // namespace test
|
||||
} // namespace ov
|
||||
|
Loading…
Reference in New Issue
Block a user