[CPU] Select via Eltwise implementation (#15740)

This commit is contained in:
Vladislav Golubev 2023-03-01 11:03:47 +01:00 committed by GitHub
parent 113aefa3ff
commit f0e12cf38b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 231 additions and 492 deletions

View File

@ -68,6 +68,7 @@ const InferenceEngine::details::caseless_unordered_map<std::string, Type> type_t
{ "Erf", Type::Eltwise },
{ "SoftPlus", Type::Eltwise },
{ "SoftSign", Type::Eltwise },
{ "Select", Type::Eltwise},
{ "Reshape", Type::Reshape },
{ "Squeeze", Type::Reshape },
{ "Unsqueeze", Type::Reshape },
@ -143,7 +144,6 @@ const InferenceEngine::details::caseless_unordered_map<std::string, Type> type_t
{ "GridSample", Type::GridSample},
{ "OneHot", Type::OneHot},
{ "RegionYolo", Type::RegionYolo},
{ "Select", Type::Select},
{ "ShuffleChannels", Type::ShuffleChannels},
{ "DFT", Type::DFT},
{ "IDFT", Type::DFT},
@ -335,8 +335,6 @@ std::string NameFromType(const Type type) {
return "OneHot";
case Type::RegionYolo:
return "RegionYolo";
case Type::Select:
return "Select";
case Type::Roll:
return "Roll";
case Type::ShuffleChannels:
@ -450,6 +448,7 @@ std::string algToString(const Algorithm alg) {
CASE(EltwiseGelu);
CASE(EltwiseElu);
CASE(EltwiseTanh);
CASE(EltwiseSelect);
CASE(EltwiseSigmoid);
CASE(EltwiseAbs);
CASE(EltwiseSqrt);

View File

@ -75,7 +75,6 @@ enum class Type {
GridSample,
OneHot,
RegionYolo,
Select,
Roll,
Reference,
ShuffleChannels,
@ -165,6 +164,7 @@ enum class Algorithm {
EltwiseTanh,
EltwiseSigmoid,
EltwiseAbs,
EltwiseSelect,
EltwiseSqrt,
EltwiseSoftRelu,
EltwiseExp,

View File

@ -1663,8 +1663,9 @@ void GraphOptimizer::FuseEltwiseAndSimple(Graph &graph) {
int outNum = parentNode->getParentEdges().size();
if (remEdge) {
inNum = remEdge->getInputNum();
// Need to keep order for MulAdd
if (childNode->getAlgorithm() == Algorithm::EltwiseMulAdd) {
// Need to keep order for these algorithms
if (childNode->getAlgorithm() == Algorithm::EltwiseMulAdd ||
childNode->getAlgorithm() == Algorithm::EltwiseSelect) {
outNum = initialParentInNum + remEdge->getOutputNum() - 1;
}
graph.RemoveEdge(remEdge);

View File

@ -528,7 +528,8 @@ private:
OV_CASE(Algorithm::EltwiseSoftSign, jit_soft_sign_emitter),
OV_CASE(Algorithm::EltwiseIsFinite, jit_is_finite_emitter),
OV_CASE(Algorithm::EltwiseIsInf, jit_is_inf_emitter),
OV_CASE(Algorithm::EltwiseIsNaN, jit_is_nan_emitter));
OV_CASE(Algorithm::EltwiseIsNaN, jit_is_nan_emitter),
OV_CASE(Algorithm::EltwiseSelect, jit_select_emitter));
if (precisions.empty())
IE_THROW() << "Unsupported operation type for Eltwise emitter";
@ -589,7 +590,8 @@ private:
OV_CASE(Algorithm::EltwiseSoftSign, jit_soft_sign_emitter),
OV_CASE(Algorithm::EltwiseIsFinite, jit_is_finite_emitter),
OV_CASE(Algorithm::EltwiseIsInf, jit_is_inf_emitter),
OV_CASE(Algorithm::EltwiseIsNaN, jit_is_nan_emitter));
OV_CASE(Algorithm::EltwiseIsNaN, jit_is_nan_emitter),
OV_CASE(Algorithm::EltwiseSelect, jit_select_emitter));
if (!ctx.emitter)
IE_THROW() << "Unsupported operation type for Eltwise emitter";
@ -886,8 +888,8 @@ private:
};
Eltwise::BroadcastingPolicy Eltwise::determineBroadcastingPolicy(const std::shared_ptr<ngraph::Node>& op) {
const auto const1 = std::dynamic_pointer_cast<ngraph::opset1::Constant>(op->get_input_node_shared_ptr(0));
const auto const2 = std::dynamic_pointer_cast<ngraph::opset1::Constant>(op->get_input_node_shared_ptr(1));
const auto const1 = ov::as_type_ptr<ngraph::opset1::Constant>(op->get_input_node_shared_ptr(0));
const auto const2 = ov::as_type_ptr<ngraph::opset1::Constant>(op->get_input_node_shared_ptr(1));
int constPort = -1;
if (const2) {
constPort = 1;
@ -1103,6 +1105,9 @@ const std::map<const ngraph::DiscreteTypeInfo, Eltwise::Initializer> Eltwise::in
{ngraph::op::v9::SoftSign::get_type_info_static(), [](const std::shared_ptr<ngraph::Node>& op, Eltwise& node) {
node.algorithm = Algorithm::EltwiseSoftSign;
}},
{ngraph::op::v1::Select::get_type_info_static(), [](const std::shared_ptr<ngraph::Node>& op, Eltwise& node) {
node.algorithm = Algorithm::EltwiseSelect;
}},
};
@ -1593,6 +1598,7 @@ public:
(_opData.beta && (src_f[0] == std::numeric_limits<float>::infinity()));
break;
case Algorithm::EltwiseIsNaN: *dst_ptr_f = std::isnan(src_f[0]); break;
case Algorithm::EltwiseSelect: *dst_ptr_f = src_f[0] ? src_f[1] : src_f[2]; break;
default: IE_THROW() << "Unsupported operation type for Eltwise executor";
}
}
@ -1653,13 +1659,20 @@ bool Eltwise::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op
errorMessage = "Doesn't support Eltwise algorithm: " + std::string(op->get_type_name());
return false;
}
if (const auto binOp = std::dynamic_pointer_cast<const ov::op::util::BinaryElementwiseArithmetic>(op)) {
if (const auto binOp = ov::as_type_ptr<const ov::op::util::BinaryElementwiseArithmetic>(op)) {
if (binOp->get_autob().m_type != ngraph::op::AutoBroadcastType::NONE &&
binOp->get_autob().m_type != ngraph::op::AutoBroadcastType::NUMPY) {
errorMessage = "Doesn't support broadcast type: " + ngraph::as_string(binOp->get_autob().m_type);
return false;
}
}
if (const auto select = ov::as_type_ptr<const ov::op::v1::Select>(op)) {
if (select->get_auto_broadcast().m_type != ngraph::op::AutoBroadcastType::NONE &&
select->get_auto_broadcast().m_type != ngraph::op::AutoBroadcastType::NUMPY) {
errorMessage = "Doesn't support broadcast type: " + ngraph::as_string(select->get_autob().m_type);
return false;
}
}
} catch (...) {
return false;
}
@ -1723,6 +1736,7 @@ size_t Eltwise::getOpInputsNum() const {
case Algorithm::EltwisePrelu:
return 2;
case Algorithm::EltwiseMulAdd:
case Algorithm::EltwiseSelect:
return 3;
default: IE_THROW() << "Unsupported operation for Eltwise node with name `" << getName() << "`.";
}
@ -2404,7 +2418,8 @@ bool Eltwise::canFuse(const NodePtr& node) const {
Algorithm::EltwiseGreaterEqual,
Algorithm::EltwiseLess,
Algorithm::EltwiseLessEqual,
Algorithm::EltwiseMulAdd)) {
Algorithm::EltwiseMulAdd,
Algorithm::EltwiseSelect)) {
return false;
}

View File

@ -1,245 +0,0 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <cmath>
#include <vector>
#include <string>
#include "ie_parallel.hpp"
#include "select.h"
#include <nodes/common/blocked_desc_creator.h>
#include <ngraph/opsets/opset1.hpp>
#include <utils/general_utils.h>
#include "common/cpu_memcpy.h"
using namespace InferenceEngine;
namespace ov {
namespace intel_cpu {
namespace node {
bool Select::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept {
try {
const auto select = std::dynamic_pointer_cast<const ngraph::opset1::Select>(op);
if (!select) {
errorMessage = "Only opset1 Select operation is supported";
return false;
}
const auto broadcast = select->get_auto_broadcast();
if (!one_of(broadcast.m_type, ngraph::op::AutoBroadcastType::NONE, ngraph::op::AutoBroadcastType::NUMPY)) {
errorMessage = "Does not support broadcast type: " + ngraph::as_string(broadcast.m_type);
return false;
}
} catch (...) {
return false;
}
return true;
}
Select::Select(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context)
: Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) {
std::string errorMessage;
if (!isSupportedOperation(op, errorMessage)) {
IE_THROW(NotImplemented) << errorMessage;
}
errorPrefix = "Select layer with name '" + op->get_friendly_name() + "'";
const auto select = std::dynamic_pointer_cast<const ngraph::opset1::Select>(op);
if (inputShapes.size() != numOfInputs || outputShapes.size() != 1)
IE_THROW() << errorPrefix << " has incorrect number of input/output edges!";
const auto broadcast = select->get_auto_broadcast();
if (broadcast.m_type == ngraph::op::AutoBroadcastType::NONE) {
broadcastType = SelectBroadcastType::NONE;
} else if (broadcast.m_type == ngraph::op::AutoBroadcastType::NUMPY) {
broadcastType = SelectBroadcastType::NUMPY;
} else {
IE_THROW() << errorPrefix << " has unsupported broadcast type: " + ngraph::as_string(broadcast.m_type);
}
const auto &inCondDims = getInputShapeAtPort(CONDITION).getDims();
const auto &inThenDims = getInputShapeAtPort(THEN).getDims();
const auto &inElseDims = getInputShapeAtPort(ELSE).getDims();
const auto &outputDims = getOutputShapeAtPort(0).getDims();
if (broadcastType == SelectBroadcastType::NONE && (!dimsEqualWeak(inCondDims, outputDims) || !dimsEqualWeak(inThenDims, outputDims) ||
!dimsEqualWeak(inElseDims, outputDims))) {
IE_THROW() << errorPrefix << " and auto_broadcast='none' has input shapes mismatch";
}
if (broadcastType == SelectBroadcastType::NUMPY) {
if (outputDims.size() < inCondDims.size() || outputDims.size() < inThenDims.size() || outputDims.size() < inElseDims.size())
IE_THROW() << errorPrefix << " and auto_broadcast='numpy' has incompatible input and output shapes";
for (int condIt = inCondDims.size() - 1, outIt = outputDims.size() - 1; condIt >= 0; condIt--, outIt--)
if (!dimsEqualWeak(inCondDims[condIt], outputDims[outIt]) && !dimsEqualWeak(inCondDims[condIt], 1))
IE_THROW() << errorPrefix << " and auto_broadcast='numpy' has incompatible 'Condition' input and output shapes";
for (int thenIt = inThenDims.size() - 1, outIt = outputDims.size() - 1; thenIt >= 0; thenIt--, outIt--)
if (!dimsEqualWeak(inThenDims[thenIt], outputDims[outIt]) && !dimsEqualWeak(inThenDims[thenIt], 1))
IE_THROW() << errorPrefix << " and auto_broadcast='numpy' has incompatible 'Then' input and output shapes";
for (int elseIt = inElseDims.size() - 1, outIt = outputDims.size() - 1; elseIt >= 0; elseIt--, outIt--)
if (!dimsEqualWeak(inElseDims[elseIt], outputDims[outIt]) && !dimsEqualWeak(inElseDims[elseIt], 1))
IE_THROW() << errorPrefix << " and auto_broadcast='numpy' has incompatible 'Else' input and output shapes";
}
resDims.resize(numOfDims, 1);
if (broadcastType == SelectBroadcastType::NUMPY) {
resOffset.resize(numOfDims);
condOffset.resize(numOfDims);
thenOffset.resize(numOfDims);
elseOffset.resize(numOfDims);
condDims.resize(numOfDims, 1);
thenDims.resize(numOfDims, 1);
elseDims.resize(numOfDims, 1);
}
}
void Select::initSupportedPrimitiveDescriptors() {
if (!supportedPrimitiveDescriptors.empty())
return;
const auto inputThenPrecision = getOriginalInputPrecisionAtPort(THEN);
const auto inputElsePrecision = getOriginalInputPrecisionAtPort(ELSE);
auto inputPrecision = inputThenPrecision;
if (inputThenPrecision == Precision::BF16 || inputElsePrecision == Precision::BF16) {
inputPrecision = Precision::BF16;
} else if (inputThenPrecision != inputElsePrecision) {
IE_THROW() << errorPrefix << " has different precisions on 'Then' and 'Else' inputs ";
}
const auto conditionPrecision = getOriginalInputPrecisionAtPort(CONDITION);
if (conditionPrecision != Precision::BOOL && conditionPrecision != Precision::I32 && conditionPrecision != Precision::U8)
IE_THROW() << errorPrefix << " has unsupported precision: " << conditionPrecision << " on 'Condition' input";
const auto inputPrecisionSize = inputPrecision.size();
if (inputPrecisionSize != 1 && inputPrecisionSize != 2 && inputPrecisionSize != 4 && inputPrecisionSize != 8)
IE_THROW() << errorPrefix << " has unsupported precision: " << inputPrecision << " on 'Then' and 'Else' inputs";
addSupportedPrimDesc({{LayoutType::ncsp, conditionPrecision},
{LayoutType::ncsp, inputPrecision},
{LayoutType::ncsp, inputPrecision}},
{{LayoutType::ncsp, inputPrecision}},
impl_desc_type::ref_any);
}
void Select::prepareParams() {
const auto &_conditionDims = getParentEdgesAtPort(CONDITION)[0]->getMemory().getStaticDims();
const auto &_thenDims = getParentEdgesAtPort(THEN)[0]->getMemory().getStaticDims();
const auto &_elseDims = getParentEdgesAtPort(ELSE)[0]->getMemory().getStaticDims();
const auto &_outputDims = getChildEdgesAtPort(0)[0]->getMemory().getStaticDims();
std::fill(resDims.begin(), resDims.end(), 1);
std::copy(std::begin(_outputDims), std::end(_outputDims), std::begin(resDims) + (numOfDims - _outputDims.size()));
if (broadcastType == SelectBroadcastType::NUMPY) {
std::fill(resOffset.begin(), resOffset.end(), 1);
calcOutOffset(resOffset, resDims);
std::fill(condDims.begin(), condDims.end(), 1);
std::copy(std::begin(_conditionDims), std::end(_conditionDims), std::begin(condDims) + (numOfDims - _conditionDims.size()));
std::fill(condOffset.begin(), condOffset.end(), 1);
calcInOffset(condOffset, condDims, resDims);
std::fill(thenDims.begin(), thenDims.end(), 1);
std::copy(std::begin(_thenDims), std::end(_thenDims), std::begin(thenDims) + (numOfDims - _thenDims.size()));
std::fill(thenOffset.begin(), thenOffset.end(), 1);
calcInOffset(thenOffset, thenDims, resDims);
std::fill(elseDims.begin(), elseDims.end(), 1);
std::copy(std::begin(_elseDims), std::end(_elseDims), std::begin(elseDims) + (numOfDims - _elseDims.size()));
std::fill(elseOffset.begin(), elseOffset.end(), 1);
calcInOffset(elseOffset, elseDims, resDims);
}
}
void Select::calcOutOffset(VectorDims& offset, const VectorDims& dims) {
int k = 1;
for (int i = dims.size() - 1; i >= 0; i--) {
offset[i] = k;
k *= dims[i];
}
}
void Select::calcInOffset(VectorDims& offset, const VectorDims& inDims, const VectorDims& outDims) {
int k = 1;
for (int i = inDims.size() - 1; i >= 0; i--) {
offset[i] = (inDims[i] == outDims[i]) ? k : 0;
k *= inDims[i];
}
}
template <typename COND_T, typename DATA_T>
void Select::execute_impl() {
const auto *conditionData = reinterpret_cast<const COND_T *>(getParentEdgeAt(CONDITION)->getMemoryPtr()->GetPtr());
const auto *thenData = reinterpret_cast<const DATA_T *>(getParentEdgeAt(THEN)->getMemoryPtr()->GetPtr());
const auto *elseData = reinterpret_cast<const DATA_T *>(getParentEdgeAt(ELSE)->getMemoryPtr()->GetPtr());
auto *dstData = reinterpret_cast<DATA_T *>(getChildEdgeAt(0)->getMemoryPtr()->GetPtr());
if (broadcastType == SelectBroadcastType::NONE) {
size_t dstDataSize = std::accumulate(begin(resDims), end(resDims), size_t(1), std::multiplies<size_t>());
parallel_for(dstDataSize, [&](size_t i) {
dstData[i] = conditionData[i] ? thenData[i] : elseData[i];
});
} else {
parallel_for4d(resDims[N], resDims[C], resDims[D], resDims[H], [&](int b, int c, int d, int h) {
for (int w = 0; w < resDims[W]; w++) {
size_t indexOut = b * resOffset[N] + c * resOffset[C] + d * resOffset[D] + h * resOffset[H] + w * resOffset[W];
size_t indexCond = b * condOffset[N] + c * condOffset[C] + d * condOffset[D] + h * condOffset[H] + w * condOffset[W];
size_t indexThen = b * thenOffset[N] + c * thenOffset[C] + d * thenOffset[D] + h * thenOffset[H] + w * thenOffset[W];
size_t indexElse = b * elseOffset[N] + c * elseOffset[C] + d * elseOffset[D] + h * elseOffset[H] + w * elseOffset[W];
dstData[indexOut] = conditionData[indexCond] ? thenData[indexThen] : elseData[indexElse];
}
});
}
}
void Select::executeDynamicImpl(dnnl::stream strm) {
execute(strm);
}
void Select::execute(dnnl::stream strm) {
const size_t condPrecSize = getParentEdgeAt(CONDITION)->getMemory().getDesc().getPrecision().size();
const size_t inputsPrecSize = getParentEdgeAt(THEN)->getMemory().getDesc().getPrecision().size();
switch (condPrecSize) {
case 1: {
switch (inputsPrecSize) {
case 1: { execute_impl<uint8_t, uint8_t>(); break; }
case 2: { execute_impl<uint8_t, uint16_t>(); break; }
case 4: { execute_impl<uint8_t, uint32_t>(); break; }
case 8: { execute_impl<uint8_t, uint64_t>(); break; }
default:
IE_THROW() << "Select layer doesn't support 'Then' and 'Else' inputs' precision: "
+ std::string(getParentEdgeAt(THEN)->getMemory().getDesc().getPrecision().name());
}
break;
}
case 4: {
switch (inputsPrecSize) {
case 1: { execute_impl<int32_t, uint8_t>(); break; }
case 2: { execute_impl<int32_t, uint16_t>(); break; }
case 4: { execute_impl<int32_t, uint32_t>(); break; }
case 8: { execute_impl<int32_t, uint64_t>(); break; }
default:
IE_THROW() << "Select layer doesn't support 'Then' and 'Else' inputs' precision: "
+ std::string(getParentEdgeAt(THEN)->getMemory().getDesc().getPrecision().name());
}
break;
}
default: {
IE_THROW() << "Select layer doesn't support 'Condition' inputs' precision: "
+ std::string(getParentEdgeAt(CONDITION)->getMemory().getDesc().getPrecision().name());
}
}
}
bool Select::created() const {
return getType() == Type::Select;
}
} // namespace node
} // namespace intel_cpu
} // namespace ov

View File

@ -1,60 +0,0 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <ie_common.h>
#include <node.h>
#include <string>
#include <memory>
#include <vector>
namespace ov {
namespace intel_cpu {
namespace node {
class Select : public Node {
public:
Select(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
void getSupportedDescriptors() override {};
void initSupportedPrimitiveDescriptors() override;
void execute(dnnl::stream strm) override;
bool created() const override;
void executeDynamicImpl(dnnl::stream strm) override;
void prepareParams() override;
static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;
private:
enum { CONDITION, THEN, ELSE, numOfInputs };
enum { N, C, D, H, W, numOfDims };
enum class SelectBroadcastType {
NONE,
NUMPY
};
SelectBroadcastType broadcastType;
VectorDims resDims;
VectorDims resOffset;
VectorDims condOffset;
VectorDims thenOffset;
VectorDims elseOffset;
VectorDims condDims;
VectorDims thenDims;
VectorDims elseDims;
std::string errorPrefix;
void calcOutOffset(VectorDims& offset, const VectorDims& dims);
void calcInOffset(VectorDims& offset, const VectorDims& inDims, const VectorDims& outDims);
template <typename COND_T, typename DATA_T>
void execute_impl();
};
} // namespace node
} // namespace intel_cpu
} // namespace ov

View File

@ -50,7 +50,6 @@
#include "nodes/concat.h"
#include "nodes/softmax.h"
#include "nodes/space_to_batch.h"
#include "nodes/select.h"
#include "nodes/topk.h"
#include "nodes/broadcast.h"
#include "nodes/matrix_nms.h"
@ -137,7 +136,6 @@ Node::NodesFactory::NodesFactory()
INTEL_CPU_NODE(DeformableConvolution, Type::DeformableConvolution);
INTEL_CPU_NODE(ReorgYolo, Type::ReorgYolo);
INTEL_CPU_NODE(EmbeddingSegmentsSum, Type::EmbeddingSegmentsSum);
INTEL_CPU_NODE(Select, Type::Select);
INTEL_CPU_NODE(ShapeOf, Type::ShapeOf);
INTEL_CPU_NODE(ExperimentalDetectronGenerateProposalsSingleImage, Type::ExperimentalDetectronGenerateProposalsSingleImage);
INTEL_CPU_NODE(GenerateProposals, Type::GenerateProposals);

View File

@ -1,174 +1,205 @@
//// Copyright (C) 2018-2023 Intel Corporation
//// SPDX-License-Identifier: Apache-2.0
////
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
//#include "test_utils/cpu_test_utils.hpp"
//#include "ngraph_functions/builders.hpp"
//
//using namespace ngraph;
//using namespace InferenceEngine;
//using namespace CPUTestUtils;
//
//namespace CPULayerTestsDefinitions {
//
//using selectParams = std::tuple<
// std::pair<std::vector<ngraph::PartialShape>, std::vector<std::vector<ngraph::Shape>>>, // input shapes
// ngraph::op::AutoBroadcastSpec>; // broadcast
//
//class SelectLayerCPUTest : public testing::WithParamInterface<selectParams>, public LayerTestsUtils::LayerTestsCommon, public CPUTestsBase {
//public:
// static std::string getTestCaseName(testing::TestParamInfo<selectParams> obj) {
// std::pair<std::vector<ngraph::PartialShape>, std::vector<std::vector<ngraph::Shape>>> shapes;
// ngraph::op::AutoBroadcastSpec broadcast;
// std::tie(shapes, broadcast) = obj.param;
//
// std::ostringstream result;
// if (!shapes.first.empty()) {
// result << "IS=" << CommonTestUtils::partialShape2str(shapes.first) << "_";
// }
// result << "TS=";
// for (const auto& shape : shapes.second) {
// result << "(";
// for (const auto& item : shape) {
// result << CommonTestUtils::vec2str(item) << "_";
// }
// result << ")_";
// }
// result << "Broadcast=" << broadcast.m_type;
//
// return result.str();
// }
//
//protected:
// void SetUp() override {
// targetDevice = CommonTestUtils::DEVICE_CPU;
//
// std::pair<std::vector<ngraph::PartialShape>, std::vector<std::vector<ngraph::Shape>>> shapes;
// ngraph::op::AutoBroadcastSpec broadcast;
// std::tie(shapes, broadcast) = this->GetParam();
//
// for (size_t i = 0; i < shapes.second.size(); i++) {
// targetStaticShapes.push_back(shapes.second[i]);
// }
// inputDynamicShapes = shapes.first;
//
// selectedType = std::string("ref_any_") + Precision(Precision::I8).name();
//
// ngraph::ParameterVector paramNodesVector;
// auto paramNode = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::Type_t::boolean, ngraph::Shape(targetStaticShapes[0][0]));
// paramNodesVector.push_back(paramNode);
// auto inType = ngraph::element::Type_t::f32;
// for (size_t i = 1; i < targetStaticShapes[0].size(); i++) {
// paramNode = std::make_shared<ngraph::opset1::Parameter>(inType, ngraph::Shape(targetStaticShapes[0][i]));
// paramNodesVector.push_back(paramNode);
// }
// auto paramOuts = ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes<ngraph::op::Parameter>(paramNodesVector));
//
// auto select = ngraph::builder::makeSelect(paramOuts, broadcast);
//
// function = std::make_shared<ngraph::Function>(select, paramNodesVector, "SelectLayerCPUTest");
// functionRefs = ngraph::clone_function(*function);
// }
//};
//
//TEST_P(SelectLayerCPUTest, CompareWithRefs) {
// run();
// CheckPluginRelatedResults(executableNetwork, "Select");
//}
//
//std::vector<std::pair<std::vector<ngraph::PartialShape>, std::vector<std::vector<ngraph::Shape>>>> inShapesDynamicNumpy = {
// {
// // dynamic
// {
// {-1, -1, -1, -1},
// {-1, -1, -1, -1, -1},
// {-1, -1, -1, -1}
// },
//
// // target
// {
// {{5, 1, 2, 1}, {8, 1, 9, 1, 1}, {5, 1, 2, 1}},
// {{1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1}},
// {{5, 9, 8, 7}, {21, 5, 9, 8, 7}, {1, 1, 1, 1}},
// }
// },
// {
// // dynamic
// {
// {-1, -1},
// {-1, -1, -1, -1, -1},
// {-1, -1, -1}
// },
//
// // target
// {
// {{8, 1}, {2, 1, 1, 8, 1}, {9, 1, 1}},
// {{10, 5}, {7, 8, 3, 10, 5}, {3, 10, 5}},
// {{8, 7}, {1, 1, 1, 8, 1}, {1, 1, 7}},
// }
// },
// {
// // dynamic
// {
// {{2, 8}, {3, 7}, {1, 10}, {1, 6}, {1, 10}},
// {-1, -1, -1, -1, -1},
// {{1, 5}, {1, 11}, {5, 5}, {1, 8}}
// },
//
// // target
// {
// {{5, 4, 1, 1, 1}, {5, 1, 8, 1, 1}, {1, 1, 5, 1}},
// {{8, 5, 5, 5, 1}, {8, 1, 1, 1, 8}, {5, 5, 5, 8}},
// {{2, 3, 4, 5, 6}, {2, 3, 4, 5, 6}, {3, 4, 5, 6}},
// }
// },
// {
// // dynamic
// {
// {{1, 10}},
// {{1, 15}, {2, 7}, {1, 6}, {5, 12}, {1, 20}},
// {{2, 10}, {1, 16}}
// },
//
// // target
// {
// {{4}, {8, 5, 6, 6, 1}, {6, 4}},
// {{10}, {15, 7, 6, 10, 10}, {10, 10}},
// {{1}, {2, 5, 4, 5, 3}, {5, 1}},
// }
// }
//};
//
//const auto numpyCases = ::testing::Combine(
// ::testing::ValuesIn(inShapesDynamicNumpy),
// ::testing::Values(ngraph::op::AutoBroadcastType::NUMPY)
//);
//
//INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefsNumpy_dynamic, SelectLayerCPUTest, numpyCases, SelectLayerCPUTest::getTestCaseName);
//
//std::vector<std::pair<std::vector<ngraph::PartialShape>, std::vector<std::vector<ngraph::Shape>>>> inShapesDynamicNone = {
// {
// // dynamic
// {
// {{1, 10}, -1, {10, 20}, {1, 5}},
// {-1, {16, 16}, -1, -1},
// {-1, -1, -1, -1}
// },
//
// // target
// {
// {{3, 16, 15, 5}, {3, 16, 15, 5}, {3, 16, 15, 5}},
// {{1, 16, 10, 1}, {1, 16, 10, 1}, {1, 16, 10, 1}},
// {{10, 16, 20, 5}, {10, 16, 20, 5}, {10, 16, 20, 5}}
// }
// }
//};
//
//const auto noneCases = ::testing::Combine(
// ::testing::ValuesIn(inShapesDynamicNone),
// ::testing::Values(ngraph::op::AutoBroadcastType::NONE)
//);
//
//INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefsNone_dynamic, SelectLayerCPUTest, noneCases, SelectLayerCPUTest::getTestCaseName);
//
//} // namespace CPULayerTestsDefinitions
#include <common_test_utils/ov_tensor_utils.hpp>
#include "shared_test_classes/base/ov_subgraph.hpp"
#include "test_utils/fusing_test_utils.hpp"
#include "ngraph_functions/builders.hpp"
using namespace InferenceEngine;
using namespace CPUTestUtils;
using namespace ov::test;
namespace CPULayerTestsDefinitions {
using selectParams = std::tuple<std::vector<InputShape>, // input shapes
ElementType, // Then/Else precision
ngraph::op::AutoBroadcastSpec, // broadcast
fusingSpecificParams>;
class SelectLayerCPUTest : public testing::WithParamInterface<selectParams>,
virtual public SubgraphBaseTest,
public CpuTestWithFusing {
public:
static std::string getTestCaseName(testing::TestParamInfo<selectParams> obj) {
std::vector<InputShape> shapes;
ElementType precision;
ngraph::op::AutoBroadcastSpec broadcast;
fusingSpecificParams fusingParams;
std::tie(shapes, precision, broadcast, fusingParams) = obj.param;
std::ostringstream result;
result << "Condition_prc_" << ElementType::boolean << "_Then_Else_prc_" << precision << "_";
result << "IS=(";
for (const auto& shape : shapes) {
result << shape.first << "_";
}
result << ")_TS=(";
for (const auto& shape : shapes) {
for (const auto& item : shape.second) {
result << CommonTestUtils::vec2str(item) << "_";
}
}
result << "Broadcast=" << broadcast.m_type;
result << CpuTestWithFusing::getTestCaseName(fusingParams);
return result.str();
}
protected:
void SetUp() override {
abs_threshold = 0;
targetDevice = CommonTestUtils::DEVICE_CPU;
std::vector<InputShape> shapes;
ElementType precision;
ngraph::op::AutoBroadcastSpec broadcast;
fusingSpecificParams fusingParams;
std::tie(shapes, precision, broadcast, fusingParams) = this->GetParam();
init_input_shapes(shapes);
std::tie(inFmts, outFmts, priority, selectedType) = emptyCPUSpec;
selectedType = makeSelectedTypeStr(getPrimitiveType(), ov::element::i8);
auto parameters = ngraph::builder::makeDynamicParams(ov::element::TypeVector{ov::element::boolean, precision, precision}, inputDynamicShapes);
auto paramOuts = ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes<ngraph::op::Parameter>(parameters));
auto select = ngraph::builder::makeSelect(paramOuts, broadcast);
function = makeNgraphFunction(precision, parameters, select, "Eltwise");
}
void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override {
inputs.clear();
const auto& modelInputs = function->inputs();
auto condTensor = ov::test::utils::create_and_fill_tensor(modelInputs[0].get_element_type(), targetInputStaticShapes[0], 3, -1, 2);
auto thenTensor = ov::test::utils::create_and_fill_tensor(modelInputs[1].get_element_type(), targetInputStaticShapes[1], 10, -10, 2);
auto elseTensor = ov::test::utils::create_and_fill_tensor(modelInputs[2].get_element_type(), targetInputStaticShapes[2], 10, 0, 2);
inputs.insert({modelInputs[0].get_node_shared_ptr(), condTensor});
inputs.insert({modelInputs[1].get_node_shared_ptr(), thenTensor});
inputs.insert({modelInputs[2].get_node_shared_ptr(), elseTensor});
}
};
TEST_P(SelectLayerCPUTest, CompareWithRefs) {
run();
CheckPluginRelatedResults(compiledModel, "Eltwise");
}
const std::vector<ElementType> precisions = {
ElementType::f32,
ElementType::i32,
ElementType::bf16,
ElementType::i8
};
const std::vector<fusingSpecificParams> fusingParamsSet{
emptyFusingSpec,
fusingSigmoid,
fusingMultiplyAddPerChannel,
};
const std::vector<std::vector<InputShape>> inShapesDynamicNumpy = {
{
// Condition
{
{-1, -1, -1, -1},
{{5, 1, 2, 1}, {1, 1, 1, 1}, {5, 9, 8, 7}}
},
// Then
{
{-1, -1, -1, -1, -1},
{{8, 1, 9, 1, 1}, {1, 1, 1, 1, 1}, {21, 5, 9, 8, 7}}
},
// Else
{
{-1, -1, -1, -1},
{{5, 1, 2, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}}
},
},
{
// Condition
{
{-1, -1},
{{8, 1}, {10, 5}, {8, 7}}
},
// Then
{
{-1, -1, -1, -1, -1},
{{2, 1, 1, 8, 1}, {7, 8, 3, 10, 5}, {1, 1, 1, 8, 1}}
},
// Else
{
{-1, -1, -1},
{{9, 1, 1}, {3, 10, 5}, {1, 1, 7}}
},
},
{
// Condition
{
{{2, 8}, {3, 7}, {1, 10}, {1, 6}, {1, 10}},
{{5, 4, 1, 1, 1}, {8, 5, 5, 5, 1}, {2, 3, 4, 5, 6}}
},
// Then
{
{-1, -1, -1, -1, -1},
{{5, 1, 8, 1, 1}, {8, 1, 1, 1, 8}, {2, 3, 4, 5, 6}}
},
// Else
{
{{1, 5}, {1, 11}, {5, 5}, {1, 8}},
{{1, 1, 5, 1}, {5, 5, 5, 8}, {3, 4, 5, 6}}
},
},
{
// Condition
{
{{1, 10}},
{{4}, {10}, {1}}
},
// Then
{
{{1, 15}, {2, 7}, {1, 6}, {5, 12}, {1, 20}},
{{8, 5, 6, 6, 1}, {15, 7, 6, 10, 10}, {2, 5, 4, 5, 3}}
},
// Else
{
{{2, 10}, {1, 16}},
{{6, 4}, {10, 10}, {5, 1}}
},
},
};
const auto numpyCases = ::testing::Combine(::testing::ValuesIn(inShapesDynamicNumpy),
::testing::ValuesIn(precisions),
::testing::Values(ngraph::op::AutoBroadcastType::NUMPY),
::testing::ValuesIn(fusingParamsSet));
INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefsNumpy_dynamic, SelectLayerCPUTest, numpyCases, SelectLayerCPUTest::getTestCaseName);
const std::vector<std::vector<InputShape>> inShapesDynamicNone = {
{
// Condition
{
{{1, 10}, -1, {10, 20}, {1, 5}},
{{3, 16, 15, 5}, {1, 16, 10, 1}, {10, 16, 20, 5}}
},
// Then
{
{-1, {16, 16}, -1, -1},
{{3, 16, 15, 5}, {1, 16, 10, 1}, {10, 16, 20, 5}}
},
// Else
{
{-1, -1, -1, -1},
{{3, 16, 15, 5}, {1, 16, 10, 1}, {10, 16, 20, 5}}
},
},
};
const auto noneCases = ::testing::Combine(::testing::ValuesIn(inShapesDynamicNone),
::testing::ValuesIn(precisions),
::testing::Values(ngraph::op::AutoBroadcastType::NONE),
::testing::ValuesIn(fusingParamsSet));
INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefsNone_dynamic, SelectLayerCPUTest, noneCases, SelectLayerCPUTest::getTestCaseName);
} // namespace CPULayerTestsDefinitions