[CPU] Select via Eltwise implementation (#15740)
commit f0e12cf38b (parent 113aefa3ff)
@@ -68,6 +68,7 @@ const InferenceEngine::details::caseless_unordered_map<std::string, Type> type_t
         { "Erf", Type::Eltwise },
         { "SoftPlus", Type::Eltwise },
         { "SoftSign", Type::Eltwise },
+        { "Select", Type::Eltwise},
         { "Reshape", Type::Reshape },
         { "Squeeze", Type::Reshape },
         { "Unsqueeze", Type::Reshape },
@@ -143,7 +144,6 @@ const InferenceEngine::details::caseless_unordered_map<std::string, Type> type_t
         { "GridSample", Type::GridSample},
         { "OneHot", Type::OneHot},
         { "RegionYolo", Type::RegionYolo},
-        { "Select", Type::Select},
         { "ShuffleChannels", Type::ShuffleChannels},
         { "DFT", Type::DFT},
         { "IDFT", Type::DFT},
@@ -335,8 +335,6 @@ std::string NameFromType(const Type type) {
         return "OneHot";
     case Type::RegionYolo:
         return "RegionYolo";
-    case Type::Select:
-        return "Select";
     case Type::Roll:
         return "Roll";
     case Type::ShuffleChannels:
@@ -450,6 +448,7 @@ std::string algToString(const Algorithm alg) {
     CASE(EltwiseGelu);
     CASE(EltwiseElu);
     CASE(EltwiseTanh);
+    CASE(EltwiseSelect);
     CASE(EltwiseSigmoid);
     CASE(EltwiseAbs);
     CASE(EltwiseSqrt);
@@ -75,7 +75,6 @@ enum class Type {
     GridSample,
     OneHot,
     RegionYolo,
-    Select,
     Roll,
     Reference,
     ShuffleChannels,
@@ -165,6 +164,7 @@ enum class Algorithm {
     EltwiseTanh,
     EltwiseSigmoid,
     EltwiseAbs,
+    EltwiseSelect,
     EltwiseSqrt,
     EltwiseSoftRelu,
     EltwiseExp,
@@ -1663,8 +1663,9 @@ void GraphOptimizer::FuseEltwiseAndSimple(Graph &graph) {
             int outNum = parentNode->getParentEdges().size();
             if (remEdge) {
                 inNum = remEdge->getInputNum();
-                // Need to keep order for MulAdd
-                if (childNode->getAlgorithm() == Algorithm::EltwiseMulAdd) {
+                // Need to keep order for these algorithms
+                if (childNode->getAlgorithm() == Algorithm::EltwiseMulAdd ||
+                        childNode->getAlgorithm() == Algorithm::EltwiseSelect) {
                     outNum = initialParentInNum + remEdge->getOutputNum() - 1;
                 }
                 graph.RemoveEdge(remEdge);
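Note: the fusion fix above matters because MulAdd and Select consume their inputs positionally (x * y + z, and cond ? then : else respectively), so the optimizer must recompute the reconnected edge's index rather than simply appending it. A minimal standalone sketch of the positional semantics (hypothetical helper, not code from this commit):

#include <array>
#include <cassert>

// select3 mimics Eltwise Select's positional ports:
// port 0 = condition, port 1 = "then", port 2 = "else".
float select3(const std::array<float, 3>& in) {
    return in[0] != 0.0f ? in[1] : in[2];
}

int main() {
    assert(select3({1.0f, 10.0f, 20.0f}) == 10.0f);  // condition true  -> "then"
    assert(select3({0.0f, 10.0f, 20.0f}) == 20.0f);  // condition false -> "else"
    assert(select3({10.0f, 1.0f, 20.0f}) == 1.0f);   // reordering ports changes the result
    return 0;
}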
@@ -528,7 +528,8 @@ private:
         OV_CASE(Algorithm::EltwiseSoftSign, jit_soft_sign_emitter),
         OV_CASE(Algorithm::EltwiseIsFinite, jit_is_finite_emitter),
         OV_CASE(Algorithm::EltwiseIsInf, jit_is_inf_emitter),
-        OV_CASE(Algorithm::EltwiseIsNaN, jit_is_nan_emitter));
+        OV_CASE(Algorithm::EltwiseIsNaN, jit_is_nan_emitter),
+        OV_CASE(Algorithm::EltwiseSelect, jit_select_emitter));

         if (precisions.empty())
             IE_THROW() << "Unsupported operation type for Eltwise emitter";
@@ -589,7 +590,8 @@ private:
         OV_CASE(Algorithm::EltwiseSoftSign, jit_soft_sign_emitter),
         OV_CASE(Algorithm::EltwiseIsFinite, jit_is_finite_emitter),
         OV_CASE(Algorithm::EltwiseIsInf, jit_is_inf_emitter),
-        OV_CASE(Algorithm::EltwiseIsNaN, jit_is_nan_emitter));
+        OV_CASE(Algorithm::EltwiseIsNaN, jit_is_nan_emitter),
+        OV_CASE(Algorithm::EltwiseSelect, jit_select_emitter));

         if (!ctx.emitter)
             IE_THROW() << "Unsupported operation type for Eltwise emitter";
@@ -886,8 +888,8 @@ private:
 };

 Eltwise::BroadcastingPolicy Eltwise::determineBroadcastingPolicy(const std::shared_ptr<ngraph::Node>& op) {
-    const auto const1 = std::dynamic_pointer_cast<ngraph::opset1::Constant>(op->get_input_node_shared_ptr(0));
-    const auto const2 = std::dynamic_pointer_cast<ngraph::opset1::Constant>(op->get_input_node_shared_ptr(1));
+    const auto const1 = ov::as_type_ptr<ngraph::opset1::Constant>(op->get_input_node_shared_ptr(0));
+    const auto const2 = ov::as_type_ptr<ngraph::opset1::Constant>(op->get_input_node_shared_ptr(1));
     int constPort = -1;
     if (const2) {
         constPort = 1;
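Note: this hunk also swaps std::dynamic_pointer_cast for ov::as_type_ptr, which performs the downcast through OpenVINO's own type information instead of C++ RTTI. A hedged usage sketch (assumes a generic node pointer op; not code from this commit):

#include <ngraph/opsets/opset1.hpp>
#include <openvino/core/type.hpp>

// Like dynamic_pointer_cast, as_type_ptr yields nullptr on a type mismatch,
// but it checks the node's OpenVINO type info rather than using dynamic_cast.
void inspectFirstInput(const std::shared_ptr<ngraph::Node>& op) {
    if (auto c = ov::as_type_ptr<ngraph::opset1::Constant>(op->get_input_node_shared_ptr(0))) {
        // input 0 is a Constant; its data can now be read safely
    }
}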
@@ -1103,6 +1105,9 @@ const std::map<const ngraph::DiscreteTypeInfo, Eltwise::Initializer> Eltwise::in
     {ngraph::op::v9::SoftSign::get_type_info_static(), [](const std::shared_ptr<ngraph::Node>& op, Eltwise& node) {
         node.algorithm = Algorithm::EltwiseSoftSign;
     }},
+    {ngraph::op::v1::Select::get_type_info_static(), [](const std::shared_ptr<ngraph::Node>& op, Eltwise& node) {
+        node.algorithm = Algorithm::EltwiseSelect;
+    }},
 };


@@ -1593,6 +1598,7 @@ public:
                     (_opData.beta && (src_f[0] == std::numeric_limits<float>::infinity()));
                 break;
             case Algorithm::EltwiseIsNaN: *dst_ptr_f = std::isnan(src_f[0]); break;
+            case Algorithm::EltwiseSelect: *dst_ptr_f = src_f[0] ? src_f[1] : src_f[2]; break;
             default: IE_THROW() << "Unsupported operation type for Eltwise executor";
         }
     }
@@ -1653,13 +1659,20 @@ bool Eltwise::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op
             errorMessage = "Doesn't support Eltwise algorithm: " + std::string(op->get_type_name());
             return false;
         }
-        if (const auto binOp = std::dynamic_pointer_cast<const ov::op::util::BinaryElementwiseArithmetic>(op)) {
+        if (const auto binOp = ov::as_type_ptr<const ov::op::util::BinaryElementwiseArithmetic>(op)) {
             if (binOp->get_autob().m_type != ngraph::op::AutoBroadcastType::NONE &&
                 binOp->get_autob().m_type != ngraph::op::AutoBroadcastType::NUMPY) {
                 errorMessage = "Doesn't support broadcast type: " + ngraph::as_string(binOp->get_autob().m_type);
                 return false;
             }
         }
+        if (const auto select = ov::as_type_ptr<const ov::op::v1::Select>(op)) {
+            if (select->get_auto_broadcast().m_type != ngraph::op::AutoBroadcastType::NONE &&
+                select->get_auto_broadcast().m_type != ngraph::op::AutoBroadcastType::NUMPY) {
+                errorMessage = "Doesn't support broadcast type: " + ngraph::as_string(select->get_autob().m_type);
+                return false;
+            }
+        }
     } catch (...) {
         return false;
     }
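Note: the new branch above admits Select only with NONE or NUMPY auto-broadcast, mirroring the existing binary-eltwise check. The same test in isolation (hypothetical helper name, not code from this commit):

#include <openvino/op/select.hpp>

// True only for the broadcast modes the CPU Eltwise path can execute.
bool selectBroadcastSupported(const std::shared_ptr<const ov::op::v1::Select>& select) {
    const auto type = select->get_auto_broadcast().m_type;
    return type == ov::op::AutoBroadcastType::NONE ||
           type == ov::op::AutoBroadcastType::NUMPY;
}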
@@ -1723,6 +1736,7 @@ size_t Eltwise::getOpInputsNum() const {
     case Algorithm::EltwisePrelu:
         return 2;
     case Algorithm::EltwiseMulAdd:
+    case Algorithm::EltwiseSelect:
         return 3;
     default: IE_THROW() << "Unsupported operation for Eltwise node with name `" << getName() << "`.";
     }
@@ -2404,7 +2418,8 @@ bool Eltwise::canFuse(const NodePtr& node) const {
                     Algorithm::EltwiseGreaterEqual,
                     Algorithm::EltwiseLess,
                     Algorithm::EltwiseLessEqual,
-                    Algorithm::EltwiseMulAdd)) {
+                    Algorithm::EltwiseMulAdd,
+                    Algorithm::EltwiseSelect)) {
         return false;
     }

@@ -1,245 +0,0 @@
-// Copyright (C) 2018-2023 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
-//
-
-#include <cmath>
-#include <vector>
-#include <string>
-#include "ie_parallel.hpp"
-#include "select.h"
-#include <nodes/common/blocked_desc_creator.h>
-#include <ngraph/opsets/opset1.hpp>
-#include <utils/general_utils.h>
-#include "common/cpu_memcpy.h"
-
-using namespace InferenceEngine;
-
-namespace ov {
-namespace intel_cpu {
-namespace node {
-
-bool Select::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept {
-    try {
-        const auto select = std::dynamic_pointer_cast<const ngraph::opset1::Select>(op);
-        if (!select) {
-            errorMessage = "Only opset1 Select operation is supported";
-            return false;
-        }
-        const auto broadcast = select->get_auto_broadcast();
-        if (!one_of(broadcast.m_type, ngraph::op::AutoBroadcastType::NONE, ngraph::op::AutoBroadcastType::NUMPY)) {
-            errorMessage = "Does not support broadcast type: " + ngraph::as_string(broadcast.m_type);
-            return false;
-        }
-    } catch (...) {
-        return false;
-    }
-    return true;
-}
-
-Select::Select(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context)
-        : Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) {
-    std::string errorMessage;
-    if (!isSupportedOperation(op, errorMessage)) {
-        IE_THROW(NotImplemented) << errorMessage;
-    }
-
-    errorPrefix = "Select layer with name '" + op->get_friendly_name() + "'";
-    const auto select = std::dynamic_pointer_cast<const ngraph::opset1::Select>(op);
-
-    if (inputShapes.size() != numOfInputs || outputShapes.size() != 1)
-        IE_THROW() << errorPrefix << " has incorrect number of input/output edges!";
-
-    const auto broadcast = select->get_auto_broadcast();
-    if (broadcast.m_type == ngraph::op::AutoBroadcastType::NONE) {
-        broadcastType = SelectBroadcastType::NONE;
-    } else if (broadcast.m_type == ngraph::op::AutoBroadcastType::NUMPY) {
-        broadcastType = SelectBroadcastType::NUMPY;
-    } else {
-        IE_THROW() << errorPrefix << " has unsupported broadcast type: " + ngraph::as_string(broadcast.m_type);
-    }
-
-    const auto &inCondDims = getInputShapeAtPort(CONDITION).getDims();
-    const auto &inThenDims = getInputShapeAtPort(THEN).getDims();
-    const auto &inElseDims = getInputShapeAtPort(ELSE).getDims();
-    const auto &outputDims = getOutputShapeAtPort(0).getDims();
-
-    if (broadcastType == SelectBroadcastType::NONE && (!dimsEqualWeak(inCondDims, outputDims) || !dimsEqualWeak(inThenDims, outputDims) ||
-                                                       !dimsEqualWeak(inElseDims, outputDims))) {
-        IE_THROW() << errorPrefix << " and auto_broadcast='none' has input shapes mismatch";
-    }
-
-    if (broadcastType == SelectBroadcastType::NUMPY) {
-        if (outputDims.size() < inCondDims.size() || outputDims.size() < inThenDims.size() || outputDims.size() < inElseDims.size())
-            IE_THROW() << errorPrefix << " and auto_broadcast='numpy' has incompatible input and output shapes";
-
-        for (int condIt = inCondDims.size() - 1, outIt = outputDims.size() - 1; condIt >= 0; condIt--, outIt--)
-            if (!dimsEqualWeak(inCondDims[condIt], outputDims[outIt]) && !dimsEqualWeak(inCondDims[condIt], 1))
-                IE_THROW() << errorPrefix << " and auto_broadcast='numpy' has incompatible 'Condition' input and output shapes";
-
-        for (int thenIt = inThenDims.size() - 1, outIt = outputDims.size() - 1; thenIt >= 0; thenIt--, outIt--)
-            if (!dimsEqualWeak(inThenDims[thenIt], outputDims[outIt]) && !dimsEqualWeak(inThenDims[thenIt], 1))
-                IE_THROW() << errorPrefix << " and auto_broadcast='numpy' has incompatible 'Then' input and output shapes";
-
-        for (int elseIt = inElseDims.size() - 1, outIt = outputDims.size() - 1; elseIt >= 0; elseIt--, outIt--)
-            if (!dimsEqualWeak(inElseDims[elseIt], outputDims[outIt]) && !dimsEqualWeak(inElseDims[elseIt], 1))
-                IE_THROW() << errorPrefix << " and auto_broadcast='numpy' has incompatible 'Else' input and output shapes";
-    }
-
-    resDims.resize(numOfDims, 1);
-    if (broadcastType == SelectBroadcastType::NUMPY) {
-        resOffset.resize(numOfDims);
-        condOffset.resize(numOfDims);
-        thenOffset.resize(numOfDims);
-        elseOffset.resize(numOfDims);
-
-        condDims.resize(numOfDims, 1);
-        thenDims.resize(numOfDims, 1);
-        elseDims.resize(numOfDims, 1);
-    }
-}
-
-void Select::initSupportedPrimitiveDescriptors() {
-    if (!supportedPrimitiveDescriptors.empty())
-        return;
-
-    const auto inputThenPrecision = getOriginalInputPrecisionAtPort(THEN);
-    const auto inputElsePrecision = getOriginalInputPrecisionAtPort(ELSE);
-    auto inputPrecision = inputThenPrecision;
-    if (inputThenPrecision == Precision::BF16 || inputElsePrecision == Precision::BF16) {
-        inputPrecision = Precision::BF16;
-    } else if (inputThenPrecision != inputElsePrecision) {
-        IE_THROW() << errorPrefix << " has different precisions on 'Then' and 'Else' inputs ";
-    }
-
-    const auto conditionPrecision = getOriginalInputPrecisionAtPort(CONDITION);
-    if (conditionPrecision != Precision::BOOL && conditionPrecision != Precision::I32 && conditionPrecision != Precision::U8)
-        IE_THROW() << errorPrefix << " has unsupported precision: " << conditionPrecision << " on 'Condition' input";
-
-    const auto inputPrecisionSize = inputPrecision.size();
-    if (inputPrecisionSize != 1 && inputPrecisionSize != 2 && inputPrecisionSize != 4 && inputPrecisionSize != 8)
-        IE_THROW() << errorPrefix << " has unsupported precision: " << inputPrecision << " on 'Then' and 'Else' inputs";
-
-    addSupportedPrimDesc({{LayoutType::ncsp, conditionPrecision},
-                          {LayoutType::ncsp, inputPrecision},
-                          {LayoutType::ncsp, inputPrecision}},
-                         {{LayoutType::ncsp, inputPrecision}},
-                         impl_desc_type::ref_any);
-}
-
-void Select::prepareParams() {
-    const auto &_conditionDims = getParentEdgesAtPort(CONDITION)[0]->getMemory().getStaticDims();
-    const auto &_thenDims = getParentEdgesAtPort(THEN)[0]->getMemory().getStaticDims();
-    const auto &_elseDims = getParentEdgesAtPort(ELSE)[0]->getMemory().getStaticDims();
-    const auto &_outputDims = getChildEdgesAtPort(0)[0]->getMemory().getStaticDims();
-
-    std::fill(resDims.begin(), resDims.end(), 1);
-    std::copy(std::begin(_outputDims), std::end(_outputDims), std::begin(resDims) + (numOfDims - _outputDims.size()));
-    if (broadcastType == SelectBroadcastType::NUMPY) {
-        std::fill(resOffset.begin(), resOffset.end(), 1);
-        calcOutOffset(resOffset, resDims);
-
-        std::fill(condDims.begin(), condDims.end(), 1);
-        std::copy(std::begin(_conditionDims), std::end(_conditionDims), std::begin(condDims) + (numOfDims - _conditionDims.size()));
-        std::fill(condOffset.begin(), condOffset.end(), 1);
-        calcInOffset(condOffset, condDims, resDims);
-
-        std::fill(thenDims.begin(), thenDims.end(), 1);
-        std::copy(std::begin(_thenDims), std::end(_thenDims), std::begin(thenDims) + (numOfDims - _thenDims.size()));
-        std::fill(thenOffset.begin(), thenOffset.end(), 1);
-        calcInOffset(thenOffset, thenDims, resDims);

-        std::fill(elseDims.begin(), elseDims.end(), 1);
-        std::copy(std::begin(_elseDims), std::end(_elseDims), std::begin(elseDims) + (numOfDims - _elseDims.size()));
-        std::fill(elseOffset.begin(), elseOffset.end(), 1);
-        calcInOffset(elseOffset, elseDims, resDims);
-    }
-}
-
-void Select::calcOutOffset(VectorDims& offset, const VectorDims& dims) {
-    int k = 1;
-    for (int i = dims.size() - 1; i >= 0; i--) {
-        offset[i] = k;
-        k *= dims[i];
-    }
-}
-
-void Select::calcInOffset(VectorDims& offset, const VectorDims& inDims, const VectorDims& outDims) {
-    int k = 1;
-    for (int i = inDims.size() - 1; i >= 0; i--) {
-        offset[i] = (inDims[i] == outDims[i]) ? k : 0;
-        k *= inDims[i];
-    }
-}
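Note: the two helpers above implement row-major strides with NumPy-style broadcasting: calcOutOffset computes ordinary strides, while calcInOffset zeroes the stride of every input axis whose size differs from the output's (i.e. a broadcast axis of size 1), so that axis keeps re-reading the same element. A standalone sketch of the same idea (hypothetical names, assumes inDims is already padded to outDims' rank):

#include <cstddef>
#include <vector>

// Zero-stride broadcasting: an axis of size 1 contributes stride 0, so every
// output coordinate along it maps back to the same input element.
std::vector<size_t> broadcastStrides(const std::vector<size_t>& inDims,
                                     const std::vector<size_t>& outDims) {
    std::vector<size_t> stride(inDims.size());
    size_t k = 1;
    for (int i = static_cast<int>(inDims.size()) - 1; i >= 0; i--) {
        stride[i] = (inDims[i] == outDims[i]) ? k : 0;
        k *= inDims[i];
    }
    return stride;
}
// Example: inDims {1, 3} against outDims {4, 3} gives strides {0, 1}, so all
// four output rows read the same three input values.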
-
-template <typename COND_T, typename DATA_T>
-void Select::execute_impl() {
-    const auto *conditionData = reinterpret_cast<const COND_T *>(getParentEdgeAt(CONDITION)->getMemoryPtr()->GetPtr());
-    const auto *thenData = reinterpret_cast<const DATA_T *>(getParentEdgeAt(THEN)->getMemoryPtr()->GetPtr());
-    const auto *elseData = reinterpret_cast<const DATA_T *>(getParentEdgeAt(ELSE)->getMemoryPtr()->GetPtr());
-    auto *dstData = reinterpret_cast<DATA_T *>(getChildEdgeAt(0)->getMemoryPtr()->GetPtr());
-
-    if (broadcastType == SelectBroadcastType::NONE) {
-        size_t dstDataSize = std::accumulate(begin(resDims), end(resDims), size_t(1), std::multiplies<size_t>());
-        parallel_for(dstDataSize, [&](size_t i) {
-            dstData[i] = conditionData[i] ? thenData[i] : elseData[i];
-        });
-    } else {
-        parallel_for4d(resDims[N], resDims[C], resDims[D], resDims[H], [&](int b, int c, int d, int h) {
-            for (int w = 0; w < resDims[W]; w++) {
-                size_t indexOut = b * resOffset[N] + c * resOffset[C] + d * resOffset[D] + h * resOffset[H] + w * resOffset[W];
-                size_t indexCond = b * condOffset[N] + c * condOffset[C] + d * condOffset[D] + h * condOffset[H] + w * condOffset[W];
-                size_t indexThen = b * thenOffset[N] + c * thenOffset[C] + d * thenOffset[D] + h * thenOffset[H] + w * thenOffset[W];
-                size_t indexElse = b * elseOffset[N] + c * elseOffset[C] + d * elseOffset[D] + h * elseOffset[H] + w * elseOffset[W];
-                dstData[indexOut] = conditionData[indexCond] ? thenData[indexThen] : elseData[indexElse];
-            }
-        });
-    }
-}
-
-void Select::executeDynamicImpl(dnnl::stream strm) {
-    execute(strm);
-}
-
-void Select::execute(dnnl::stream strm) {
-    const size_t condPrecSize = getParentEdgeAt(CONDITION)->getMemory().getDesc().getPrecision().size();
-    const size_t inputsPrecSize = getParentEdgeAt(THEN)->getMemory().getDesc().getPrecision().size();
-
-    switch (condPrecSize) {
-        case 1: {
-            switch (inputsPrecSize) {
-                case 1: { execute_impl<uint8_t, uint8_t>(); break; }
-                case 2: { execute_impl<uint8_t, uint16_t>(); break; }
-                case 4: { execute_impl<uint8_t, uint32_t>(); break; }
-                case 8: { execute_impl<uint8_t, uint64_t>(); break; }
-                default:
-                    IE_THROW() << "Select layer doesn't support 'Then' and 'Else' inputs' precision: "
-                               + std::string(getParentEdgeAt(THEN)->getMemory().getDesc().getPrecision().name());
-            }
-            break;
-        }
-        case 4: {
-            switch (inputsPrecSize) {
-                case 1: { execute_impl<int32_t, uint8_t>(); break; }
-                case 2: { execute_impl<int32_t, uint16_t>(); break; }
-                case 4: { execute_impl<int32_t, uint32_t>(); break; }
-                case 8: { execute_impl<int32_t, uint64_t>(); break; }
-                default:
-                    IE_THROW() << "Select layer doesn't support 'Then' and 'Else' inputs' precision: "
-                               + std::string(getParentEdgeAt(THEN)->getMemory().getDesc().getPrecision().name());
-            }
-            break;
-        }
-        default: {
-            IE_THROW() << "Select layer doesn't support 'Condition' inputs' precision: "
-                       + std::string(getParentEdgeAt(CONDITION)->getMemory().getDesc().getPrecision().name());
-        }
-    }
-}
-
-bool Select::created() const {
-    return getType() == Type::Select;
-}
-
-}   // namespace node
-}   // namespace intel_cpu
-}   // namespace ov
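Note: the removed execute() above dispatched on precision byte width alone: Select never does arithmetic on the data values, so any 4-byte type (f32, i32, u32) can take one code path, any 2-byte type another, and so on. A minimal sketch of that idea (hypothetical helper, not code from this commit):

#include <cstdint>
#include <cstring>

// Type-erased select: values are only moved, never computed with, so the
// element byte size is all that matters.
void selectBytes(const uint8_t* cond, const void* thenP, const void* elseP,
                 void* dst, size_t count, size_t elemSize) {
    const char* thenB = static_cast<const char*>(thenP);
    const char* elseB = static_cast<const char*>(elseP);
    char* dstB = static_cast<char*>(dst);
    for (size_t i = 0; i < count; i++) {
        const char* src = cond[i] ? thenB : elseB;
        std::memcpy(dstB + i * elemSize, src + i * elemSize, elemSize);
    }
}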
@@ -1,60 +0,0 @@
-// Copyright (C) 2018-2023 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
-//
-
-#pragma once
-
-#include <ie_common.h>
-#include <node.h>
-#include <string>
-#include <memory>
-#include <vector>
-
-namespace ov {
-namespace intel_cpu {
-namespace node {
-
-class Select : public Node {
-public:
-    Select(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context);
-
-    void getSupportedDescriptors() override {};
-    void initSupportedPrimitiveDescriptors() override;
-    void execute(dnnl::stream strm) override;
-    bool created() const override;
-
-    void executeDynamicImpl(dnnl::stream strm) override;
-    void prepareParams() override;
-
-    static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;
-
-private:
-    enum { CONDITION, THEN, ELSE, numOfInputs };
-    enum { N, C, D, H, W, numOfDims };
-    enum class SelectBroadcastType {
-        NONE,
-        NUMPY
-    };
-
-    SelectBroadcastType broadcastType;
-    VectorDims resDims;
-    VectorDims resOffset;
-    VectorDims condOffset;
-    VectorDims thenOffset;
-    VectorDims elseOffset;
-
-    VectorDims condDims;
-    VectorDims thenDims;
-    VectorDims elseDims;
-
-    std::string errorPrefix;
-
-    void calcOutOffset(VectorDims& offset, const VectorDims& dims);
-    void calcInOffset(VectorDims& offset, const VectorDims& inDims, const VectorDims& outDims);
-    template <typename COND_T, typename DATA_T>
-    void execute_impl();
-};
-
-}   // namespace node
-}   // namespace intel_cpu
-}   // namespace ov
@@ -50,7 +50,6 @@
 #include "nodes/concat.h"
 #include "nodes/softmax.h"
 #include "nodes/space_to_batch.h"
-#include "nodes/select.h"
 #include "nodes/topk.h"
 #include "nodes/broadcast.h"
 #include "nodes/matrix_nms.h"
@@ -137,7 +136,6 @@ Node::NodesFactory::NodesFactory()
     INTEL_CPU_NODE(DeformableConvolution, Type::DeformableConvolution);
     INTEL_CPU_NODE(ReorgYolo, Type::ReorgYolo);
     INTEL_CPU_NODE(EmbeddingSegmentsSum, Type::EmbeddingSegmentsSum);
-    INTEL_CPU_NODE(Select, Type::Select);
     INTEL_CPU_NODE(ShapeOf, Type::ShapeOf);
     INTEL_CPU_NODE(ExperimentalDetectronGenerateProposalsSingleImage, Type::ExperimentalDetectronGenerateProposalsSingleImage);
     INTEL_CPU_NODE(GenerateProposals, Type::GenerateProposals);
@@ -1,174 +1,205 @@
-//// Copyright (C) 2018-2023 Intel Corporation
-//// SPDX-License-Identifier: Apache-2.0
-////
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
-//#include "test_utils/cpu_test_utils.hpp"
-//#include "ngraph_functions/builders.hpp"
-//
-//using namespace ngraph;
-//using namespace InferenceEngine;
-//using namespace CPUTestUtils;
-//
-//namespace CPULayerTestsDefinitions {
-//
-//using selectParams = std::tuple<
-//        std::pair<std::vector<ngraph::PartialShape>, std::vector<std::vector<ngraph::Shape>>>,  // input shapes
-//        ngraph::op::AutoBroadcastSpec>;  // broadcast
-//
-//class SelectLayerCPUTest : public testing::WithParamInterface<selectParams>, public LayerTestsUtils::LayerTestsCommon, public CPUTestsBase {
-//public:
-//    static std::string getTestCaseName(testing::TestParamInfo<selectParams> obj) {
-//        std::pair<std::vector<ngraph::PartialShape>, std::vector<std::vector<ngraph::Shape>>> shapes;
-//        ngraph::op::AutoBroadcastSpec broadcast;
-//        std::tie(shapes, broadcast) = obj.param;
-//
-//        std::ostringstream result;
-//        if (!shapes.first.empty()) {
-//            result << "IS=" << CommonTestUtils::partialShape2str(shapes.first) << "_";
-//        }
-//        result << "TS=";
-//        for (const auto& shape : shapes.second) {
-//            result << "(";
-//            for (const auto& item : shape) {
-//                result << CommonTestUtils::vec2str(item) << "_";
-//            }
-//            result << ")_";
-//        }
-//        result << "Broadcast=" << broadcast.m_type;
-//
-//        return result.str();
-//    }
-//
-//protected:
-//    void SetUp() override {
-//        targetDevice = CommonTestUtils::DEVICE_CPU;
-//
-//        std::pair<std::vector<ngraph::PartialShape>, std::vector<std::vector<ngraph::Shape>>> shapes;
-//        ngraph::op::AutoBroadcastSpec broadcast;
-//        std::tie(shapes, broadcast) = this->GetParam();
-//
-//        for (size_t i = 0; i < shapes.second.size(); i++) {
-//            targetStaticShapes.push_back(shapes.second[i]);
-//        }
-//        inputDynamicShapes = shapes.first;
-//
-//        selectedType = std::string("ref_any_") + Precision(Precision::I8).name();
-//
-//        ngraph::ParameterVector paramNodesVector;
-//        auto paramNode = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::Type_t::boolean, ngraph::Shape(targetStaticShapes[0][0]));
-//        paramNodesVector.push_back(paramNode);
-//        auto inType = ngraph::element::Type_t::f32;
-//        for (size_t i = 1; i < targetStaticShapes[0].size(); i++) {
-//            paramNode = std::make_shared<ngraph::opset1::Parameter>(inType, ngraph::Shape(targetStaticShapes[0][i]));
-//            paramNodesVector.push_back(paramNode);
-//        }
-//        auto paramOuts = ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes<ngraph::op::Parameter>(paramNodesVector));
-//
-//        auto select = ngraph::builder::makeSelect(paramOuts, broadcast);
-//
-//        function = std::make_shared<ngraph::Function>(select, paramNodesVector, "SelectLayerCPUTest");
-//        functionRefs = ngraph::clone_function(*function);
-//    }
-//};
-//
-//TEST_P(SelectLayerCPUTest, CompareWithRefs) {
-//    run();
-//    CheckPluginRelatedResults(executableNetwork, "Select");
-//}
-//
-//std::vector<std::pair<std::vector<ngraph::PartialShape>, std::vector<std::vector<ngraph::Shape>>>> inShapesDynamicNumpy = {
-//    {
-//        // dynamic
-//        {
-//            {-1, -1, -1, -1},
-//            {-1, -1, -1, -1, -1},
-//            {-1, -1, -1, -1}
-//        },
-//
-//        // target
-//        {
-//            {{5, 1, 2, 1}, {8, 1, 9, 1, 1}, {5, 1, 2, 1}},
-//            {{1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1}},
-//            {{5, 9, 8, 7}, {21, 5, 9, 8, 7}, {1, 1, 1, 1}},
-//        }
-//    },
-//    {
-//        // dynamic
-//        {
-//            {-1, -1},
-//            {-1, -1, -1, -1, -1},
-//            {-1, -1, -1}
-//        },
-//
-//        // target
-//        {
-//            {{8, 1}, {2, 1, 1, 8, 1}, {9, 1, 1}},
-//            {{10, 5}, {7, 8, 3, 10, 5}, {3, 10, 5}},
-//            {{8, 7}, {1, 1, 1, 8, 1}, {1, 1, 7}},
-//        }
-//    },
-//    {
-//        // dynamic
-//        {
-//            {{2, 8}, {3, 7}, {1, 10}, {1, 6}, {1, 10}},
-//            {-1, -1, -1, -1, -1},
-//            {{1, 5}, {1, 11}, {5, 5}, {1, 8}}
-//        },
-//
-//        // target
-//        {
-//            {{5, 4, 1, 1, 1}, {5, 1, 8, 1, 1}, {1, 1, 5, 1}},
-//            {{8, 5, 5, 5, 1}, {8, 1, 1, 1, 8}, {5, 5, 5, 8}},
-//            {{2, 3, 4, 5, 6}, {2, 3, 4, 5, 6}, {3, 4, 5, 6}},
-//        }
-//    },
-//    {
-//        // dynamic
-//        {
-//            {{1, 10}},
-//            {{1, 15}, {2, 7}, {1, 6}, {5, 12}, {1, 20}},
-//            {{2, 10}, {1, 16}}
-//        },
-//
-//        // target
-//        {
-//            {{4}, {8, 5, 6, 6, 1}, {6, 4}},
-//            {{10}, {15, 7, 6, 10, 10}, {10, 10}},
-//            {{1}, {2, 5, 4, 5, 3}, {5, 1}},
-//        }
-//    }
-//};
-//
-//const auto numpyCases = ::testing::Combine(
-//        ::testing::ValuesIn(inShapesDynamicNumpy),
-//        ::testing::Values(ngraph::op::AutoBroadcastType::NUMPY)
-//);
-//
-//INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefsNumpy_dynamic, SelectLayerCPUTest, numpyCases, SelectLayerCPUTest::getTestCaseName);
-//
-//std::vector<std::pair<std::vector<ngraph::PartialShape>, std::vector<std::vector<ngraph::Shape>>>> inShapesDynamicNone = {
-//    {
-//        // dynamic
-//        {
-//            {{1, 10}, -1, {10, 20}, {1, 5}},
-//            {-1, {16, 16}, -1, -1},
-//            {-1, -1, -1, -1}
-//        },
-//
-//        // target
-//        {
-//            {{3, 16, 15, 5}, {3, 16, 15, 5}, {3, 16, 15, 5}},
-//            {{1, 16, 10, 1}, {1, 16, 10, 1}, {1, 16, 10, 1}},
-//            {{10, 16, 20, 5}, {10, 16, 20, 5}, {10, 16, 20, 5}}
-//        }
-//    }
-//};
-//
-//const auto noneCases = ::testing::Combine(
-//        ::testing::ValuesIn(inShapesDynamicNone),
-//        ::testing::Values(ngraph::op::AutoBroadcastType::NONE)
-//);
-//
-//INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefsNone_dynamic, SelectLayerCPUTest, noneCases, SelectLayerCPUTest::getTestCaseName);
-//
-//} // namespace CPULayerTestsDefinitions
+
+#include <common_test_utils/ov_tensor_utils.hpp>
+#include "shared_test_classes/base/ov_subgraph.hpp"
+#include "test_utils/fusing_test_utils.hpp"
+#include "ngraph_functions/builders.hpp"
+
+using namespace InferenceEngine;
+using namespace CPUTestUtils;
+using namespace ov::test;
+
+namespace CPULayerTestsDefinitions {
+
+using selectParams = std::tuple<std::vector<InputShape>,         // input shapes
+                                ElementType,                     // Then/Else precision
+                                ngraph::op::AutoBroadcastSpec,   // broadcast
+                                fusingSpecificParams>;
+
+class SelectLayerCPUTest : public testing::WithParamInterface<selectParams>,
+                           virtual public SubgraphBaseTest,
+                           public CpuTestWithFusing {
+public:
+    static std::string getTestCaseName(testing::TestParamInfo<selectParams> obj) {
+        std::vector<InputShape> shapes;
+        ElementType precision;
+        ngraph::op::AutoBroadcastSpec broadcast;
+        fusingSpecificParams fusingParams;
+        std::tie(shapes, precision, broadcast, fusingParams) = obj.param;
+
+        std::ostringstream result;
+        result << "Condition_prc_" << ElementType::boolean << "_Then_Else_prc_" << precision << "_";
+        result << "IS=(";
+        for (const auto& shape : shapes) {
+            result << shape.first << "_";
+        }
+        result << ")_TS=(";
+        for (const auto& shape : shapes) {
+            for (const auto& item : shape.second) {
+                result << CommonTestUtils::vec2str(item) << "_";
+            }
+        }
+        result << "Broadcast=" << broadcast.m_type;
+        result << CpuTestWithFusing::getTestCaseName(fusingParams);
+
+        return result.str();
+    }
+
+protected:
+    void SetUp() override {
+        abs_threshold = 0;
+        targetDevice = CommonTestUtils::DEVICE_CPU;
+        std::vector<InputShape> shapes;
+        ElementType precision;
+        ngraph::op::AutoBroadcastSpec broadcast;
+        fusingSpecificParams fusingParams;
+        std::tie(shapes, precision, broadcast, fusingParams) = this->GetParam();
+        init_input_shapes(shapes);
+        std::tie(inFmts, outFmts, priority, selectedType) = emptyCPUSpec;
+        selectedType = makeSelectedTypeStr(getPrimitiveType(), ov::element::i8);
+
+        auto parameters = ngraph::builder::makeDynamicParams(ov::element::TypeVector{ov::element::boolean, precision, precision}, inputDynamicShapes);
+        auto paramOuts = ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes<ngraph::op::Parameter>(parameters));
+        auto select = ngraph::builder::makeSelect(paramOuts, broadcast);
+
+        function = makeNgraphFunction(precision, parameters, select, "Eltwise");
+    }
+
+    void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override {
+        inputs.clear();
+        const auto& modelInputs = function->inputs();
+        auto condTensor = ov::test::utils::create_and_fill_tensor(modelInputs[0].get_element_type(), targetInputStaticShapes[0], 3, -1, 2);
+        auto thenTensor = ov::test::utils::create_and_fill_tensor(modelInputs[1].get_element_type(), targetInputStaticShapes[1], 10, -10, 2);
+        auto elseTensor = ov::test::utils::create_and_fill_tensor(modelInputs[2].get_element_type(), targetInputStaticShapes[2], 10, 0, 2);
+        inputs.insert({modelInputs[0].get_node_shared_ptr(), condTensor});
+        inputs.insert({modelInputs[1].get_node_shared_ptr(), thenTensor});
+        inputs.insert({modelInputs[2].get_node_shared_ptr(), elseTensor});
+    }
+};
+
+TEST_P(SelectLayerCPUTest, CompareWithRefs) {
+    run();
+    CheckPluginRelatedResults(compiledModel, "Eltwise");
+}
+
+const std::vector<ElementType> precisions = {
+    ElementType::f32,
+    ElementType::i32,
+    ElementType::bf16,
+    ElementType::i8
+};
+
+const std::vector<fusingSpecificParams> fusingParamsSet{
+    emptyFusingSpec,
+    fusingSigmoid,
+    fusingMultiplyAddPerChannel,
+};
+
+const std::vector<std::vector<InputShape>> inShapesDynamicNumpy = {
+    {
+        // Condition
+        {
+            {-1, -1, -1, -1},
+            {{5, 1, 2, 1}, {1, 1, 1, 1}, {5, 9, 8, 7}}
+        },
+        // Then
+        {
+            {-1, -1, -1, -1, -1},
+            {{8, 1, 9, 1, 1}, {1, 1, 1, 1, 1}, {21, 5, 9, 8, 7}}
+        },
+        // Else
+        {
+            {-1, -1, -1, -1},
+            {{5, 1, 2, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}}
+        },
+    },
+    {
+        // Condition
+        {
+            {-1, -1},
+            {{8, 1}, {10, 5}, {8, 7}}
+        },
+        // Then
+        {
+            {-1, -1, -1, -1, -1},
+            {{2, 1, 1, 8, 1}, {7, 8, 3, 10, 5}, {1, 1, 1, 8, 1}}
+        },
+        // Else
+        {
+            {-1, -1, -1},
+            {{9, 1, 1}, {3, 10, 5}, {1, 1, 7}}
+        },
+    },
+    {
+        // Condition
+        {
+            {{2, 8}, {3, 7}, {1, 10}, {1, 6}, {1, 10}},
+            {{5, 4, 1, 1, 1}, {8, 5, 5, 5, 1}, {2, 3, 4, 5, 6}}
+        },
+        // Then
+        {
+            {-1, -1, -1, -1, -1},
+            {{5, 1, 8, 1, 1}, {8, 1, 1, 1, 8}, {2, 3, 4, 5, 6}}
+        },
+        // Else
+        {
+            {{1, 5}, {1, 11}, {5, 5}, {1, 8}},
+            {{1, 1, 5, 1}, {5, 5, 5, 8}, {3, 4, 5, 6}}
+        },
+    },
+    {
+        // Condition
+        {
+            {{1, 10}},
+            {{4}, {10}, {1}}
+        },
+        // Then
+        {
+            {{1, 15}, {2, 7}, {1, 6}, {5, 12}, {1, 20}},
+            {{8, 5, 6, 6, 1}, {15, 7, 6, 10, 10}, {2, 5, 4, 5, 3}}
+        },
+        // Else
+        {
+            {{2, 10}, {1, 16}},
+            {{6, 4}, {10, 10}, {5, 1}}
+        },
+    },
+};
+
+const auto numpyCases = ::testing::Combine(::testing::ValuesIn(inShapesDynamicNumpy),
+                                           ::testing::ValuesIn(precisions),
+                                           ::testing::Values(ngraph::op::AutoBroadcastType::NUMPY),
+                                           ::testing::ValuesIn(fusingParamsSet));
+
+INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefsNumpy_dynamic, SelectLayerCPUTest, numpyCases, SelectLayerCPUTest::getTestCaseName);
+
+const std::vector<std::vector<InputShape>> inShapesDynamicNone = {
+    {
+        // Condition
+        {
+            {{1, 10}, -1, {10, 20}, {1, 5}},
+            {{3, 16, 15, 5}, {1, 16, 10, 1}, {10, 16, 20, 5}}
+        },
+        // Then
+        {
+            {-1, {16, 16}, -1, -1},
+            {{3, 16, 15, 5}, {1, 16, 10, 1}, {10, 16, 20, 5}}
+        },
+        // Else
+        {
+            {-1, -1, -1, -1},
+            {{3, 16, 15, 5}, {1, 16, 10, 1}, {10, 16, 20, 5}}
+        },
+    },
+};
+
+const auto noneCases = ::testing::Combine(::testing::ValuesIn(inShapesDynamicNone),
+                                          ::testing::ValuesIn(precisions),
+                                          ::testing::Values(ngraph::op::AutoBroadcastType::NONE),
+                                          ::testing::ValuesIn(fusingParamsSet));
+
+INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefsNone_dynamic, SelectLayerCPUTest, noneCases, SelectLayerCPUTest::getTestCaseName);
+
+} // namespace CPULayerTestsDefinitions
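Note: the rewritten test drives Select through the common dynamic-shape subgraph machinery and now expects the node to execute as "Eltwise" rather than "Select". A minimal sketch of the kind of graph it builds (simplified; uses only the public opset1 API, not the test helpers):

#include <ngraph/opsets/opset1.hpp>

// cond ? then : else with NUMPY auto-broadcast, as in the test above.
std::shared_ptr<ngraph::Function> makeSelectFunction() {
    auto cond  = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::boolean, ngraph::Shape{1, 3});
    auto thenV = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
    auto elseV = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
    auto select = std::make_shared<ngraph::opset1::Select>(
        cond, thenV, elseV, ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY));
    return std::make_shared<ngraph::Function>(select, ngraph::ParameterVector{cond, thenV, elseV}, "Select");
}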