[CPU][ARM] Perf fixes for FP16 precision (#18973)

This commit is contained in:
Egor Duplenskii 2023-08-14 07:22:03 +02:00 committed by GitHub
parent 4e96b6ba9d
commit f09d2e2666
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 512 additions and 388 deletions

View File

@ -19,6 +19,7 @@
#include "graph_optimizer.h"
#include "dnnl_extension_utils.h"
#include "extension_mngr.h"
#include "ie_ngraph_utils.hpp"
#include "memory_solver.hpp"
#include "itt.h"
#include "infer_request.h"
@ -198,8 +199,8 @@ void Graph::Replicate(const std::shared_ptr<const ov::Model> &subgraph) {
void Graph::Replicate(const CNNNetwork &network) {
OV_ITT_SCOPE_CHAIN(FIRST_INFERENCE, taskChain, itt::domains::intel_cpu_LT, "Graph::Replicate", "CNNNetwork");
InputsDataMap inputsInfo = network.getInputsInfo();
OutputsDataMap outputsInfo = network.getOutputsInfo();
const InputsDataMap& inputsInfo = network.getInputsInfo();
const OutputsDataMap& outputsInfo = network.getOutputsInfo();
this->_name = network.getName();
@ -290,8 +291,6 @@ void Graph::Replicate(const CNNNetwork &network) {
graphNodes.push_back(outNode);
}
EnforceInferencePrecision();
auto hasSubgraphConsumers = [] (const NodePtr& node) -> bool {
const auto & childEdges = node->getChildEdges();
return std::any_of(childEdges.begin(), childEdges.end(),
@ -302,13 +301,23 @@ void Graph::Replicate(const CNNNetwork &network) {
return edgePtr->getChild()->getType() == Type::Subgraph;
});
};
// change precision for input/output nodes to avoid extra data conversion when set input/output blobs
// also we need to change input/output precisions for consumers/producers to avoid inserting reorder
for (auto &input : inputNodesMap) {
const auto precToSet = normalizeToSupportedPrecision(inputsInfo.at(input.first)->getPrecision());
input.second->setOriginalOutputPrecisionAtPort(0, precToSet);
const auto childEdges = input.second->getChildEdgesAtPort(0);
}
for (auto &output : outputNodesMap) {
const auto precToSet = normalizeToSupportedPrecision(outputsInfo.at(output.first)->getPrecision());
output.second->setOriginalInputPrecisionAtPort(0, precToSet);
}
// enforce must be performed after inputs and outputs info are taken into account
EnforceInferencePrecision();
// also we need to change input/output precisions for consumers/producers to avoid inserting reorder
for (auto &input : inputNodesMap) {
const auto& inputNode = input.second;
const auto precToSet = inputNode->getOriginalOutputPrecisionAtPort(0);
const auto childEdges = inputNode->getChildEdgesAtPort(0);
for (size_t i = 0; i < childEdges.size(); i++) {
const auto child = childEdges[i]->getChild();
const auto child_prec = child->getOriginalInputPrecisionAtPort(childEdges[i]->getOutputNum());
@ -320,9 +329,9 @@ void Graph::Replicate(const CNNNetwork &network) {
}
for (auto &output : outputNodesMap) {
const auto precToSet = normalizeToSupportedPrecision(outputsInfo.at(output.first)->getPrecision());
output.second->setOriginalInputPrecisionAtPort(0, precToSet);
const auto parentEdges = output.second->getParentEdgesAtPort(0);
const auto& outputNode = output.second;
const auto precToSet = outputNode->getOriginalInputPrecisionAtPort(0);
const auto parentEdges = outputNode->getParentEdgesAtPort(0);
for (size_t i = 0; i < parentEdges.size(); i++) {
const auto parent = parentEdges[i]->getParent();
parent->setOriginalOutputPrecisionAtPort(parentEdges[i]->getInputNum(), precToSet);
@ -337,7 +346,7 @@ void Graph::Replicate(const CNNNetwork &network) {
} else {
outShape = inputNodesMap[input.first]->outputShapes.front();
}
InputInfo::Ptr ii = inputsInfo[input.first];
InputInfo::Ptr ii = input.second;
if (ii && ii->getPreProcess().getNumberOfChannels()) {
_normalizePreprocMap[input.first].Load(outShape, ii);
}
@ -1685,21 +1694,14 @@ bool Graph::InsertNode(NodePtr parent, NodePtr child, NodePtr node, int parentPo
return true;
}
// Set all non const data paths precision to BF16
// Apply inference precision configuration
void Graph::EnforceInferencePrecision() {
CPU_DEBUG_CAP_ENABLE(static EnforceInferPrcDebug inferPrecDebug);
auto inferPrec = InferenceEngine::Precision::FP32;
switch (getConfig().inferencePrecision) {
case ov::element::bf16:
inferPrec = InferenceEngine::Precision::BF16;
break;
case ov::element::f16:
inferPrec = InferenceEngine::Precision::FP16;
break;
default:
return;
break;
}
const auto inferPrec = convertPrecision(getConfig().inferencePrecision);
if (inferPrec == Precision::FP32)
return; // nothing to do, only precision reduction is currently allowed
std::function<void(const NodePtr&, std::unordered_set<NodePtr>& skipNodes)> searchForNodesToSkip;
searchForNodesToSkip = [&](const NodePtr& node, std::unordered_set<NodePtr>& skipNodes) -> void {
@ -1743,44 +1745,60 @@ void Graph::EnforceInferencePrecision() {
std::unordered_set<NodePtr> nodesToSkip;
// starting from output nodes
for (const auto& entry : outputNodesMap) {
const auto& node = entry.second;
if (node->getOriginalInputPrecisionAtPort(0) == Precision::BF16)
const auto& output = entry.second;
// do not skip outputs which precisions are explicitly set equal to inferPrec
if (output->getOriginalInputPrecisionAtPort(0) == inferPrec)
continue;
searchForNodesToSkip(node, nodesToSkip);
searchForNodesToSkip(output, nodesToSkip);
}
for (const auto& node : graphNodes) {
if (nodesToSkip.count(node) && !node->enforceBF16evenForGraphTail)
continue;
if (node->getType() != Type::Input && node->getType() != Type::Output) {
if (one_of(node->getType(), Type::Input, Type::Output))
continue;
#ifdef CPU_DEBUG_CAPS
if (!inferPrecDebug.enabled(NameFromType(node->getType()), node->getName()))
continue;
if (!inferPrecDebug.enabled(NameFromType(node->getType()), node->getName()))
continue;
#endif
DEBUG_LOG("#", node->getExecIndex(), " ", node->getName(), " is enforced to use", inferPrec);
DEBUG_LOG("#", node->getExecIndex(),
" ", node->getName(),
" is enforced to use", inferPrec);
for (size_t i = 0; i < node->getOriginalInputsNumber(); i++) {
auto keepOriginalInputPrecisionAtPort = [](const NodePtr& node, const size_t inPort) {
// keep non-float precisions
if (node->getOriginalInputPrecisionAtPort(inPort) != Precision::FP32)
return true;
for (size_t i = 0; i < node->getOriginalInputsNumber(); i++) {
const auto &parent = node->getParentEdgesAtPort(i)[0]->getParent();
const auto &parent = node->getParentEdgesAtPort(inPort)[0]->getParent();
/* Skip BF16 enforcement for nodes after Constant Inputs for maintaining precision for fusing.
* Precision conversion to BF16 does automatically, if convolution follows up after Constant Inputs
* and if activation is BF16 */
if (!(parent->getType() == Type::Input && parent->isConstant() &&
* Precision conversion to BF16 is done automatically, if convolution follows up after Constant Inputs
* and activation is BF16 */
if (parent->getType() == Type::Input && parent->isConstant() &&
// Concatenation node is exception because it doesn't change an accuracy for BF16 activation
node->getType() != Type::Concatenation) &&
// exclude Eltwise after Input since it supports conversion to BF16
!(parent->getType() == Type::Input && (node->getType() == Type::Eltwise || node->getType() == Type::Subgraph)) &&
node->getOriginalInputPrecisionAtPort(i) == Precision::FP32)
node->setOriginalInputPrecisionAtPort(i, inferPrec);
}
node->getType() != Type::Concatenation)
return true;
// Eltwise and Subgraph (snippets) nodes support precision conversion
if (parent->getType() == Type::Input && one_of(node->getType(), Type::Eltwise, Type::Subgraph))
return true;
for (size_t i = 0; i < node->getOriginalOutputsNumber(); i++) {
if (node->getOriginalOutputPrecisionAtPort(i) == Precision::FP32)
node->setOriginalOutputPrecisionAtPort(i, inferPrec);
}
return false;
};
if (keepOriginalInputPrecisionAtPort(node, i))
continue;
node->setOriginalInputPrecisionAtPort(i, inferPrec);
}
for (size_t i = 0; i < node->getOriginalOutputsNumber(); i++) {
// keep non-float precisions
if (node->getOriginalOutputPrecisionAtPort(i) != Precision::FP32)
continue;
node->setOriginalOutputPrecisionAtPort(i, inferPrec);
}
}
}

View File

@ -5,6 +5,7 @@
#include "eltwise.h"
#include <common/float16.hpp>
#include <map>
#include <set>
@ -1537,6 +1538,13 @@ public:
static const int optimalTensorRank = 6;
};
/* enabled only for float and float16_t at the moment
* can be extended in the future */
template<typename T,
typename std::enable_if<
std::is_same<T, float>::value ||
std::is_same<T, dnnl::impl::float16_t>::value>
::type* = nullptr>
class EltwiseRefExecutor : public Eltwise::IEltwiseExecutor {
public:
EltwiseRefExecutor(Eltwise::EltwiseData opData,
@ -1571,30 +1579,30 @@ public:
_dst_offsets.resize(input_size, 1);
EltwiseJitExecutor::offset_out_calc(_dst_offsets, _dims);
for (size_t j = 0; j < input_size; j++) {
_dst_offsets[j] *= sizeof(float); // only FP32 out prc is supported
_dst_offsets[j] *= sizeof(T);
}
for (size_t i = 0; i < _inputNum; i++) {
_src_offsets[i].resize(input_size, 1);
EltwiseJitExecutor::offset_in_calc(_src_offsets[i], inpDims[i], _dims);
for (size_t j = 0; j < input_size; j++) {
_src_offsets[i][j] *= sizeof(float); // only FP32 inp prcs are supported
_src_offsets[i][j] *= sizeof(T);
}
}
}
void exec(const jit_eltwise_call_args_ptrs &args_ptrs, const VectorDims &dims_out) override {
if (_opData.algo == Algorithm::EltwiseLog) {
const float* src_ptr_f = reinterpret_cast<const float*>(args_ptrs.src_ptr[0]);
float* dst_ptr_f = reinterpret_cast<float*>(args_ptrs.dst_ptr);
const T* src_ptr_f = reinterpret_cast<const T*>(args_ptrs.src_ptr[0]);
T* dst_ptr_f = reinterpret_cast<T*>(args_ptrs.dst_ptr);
parallel_for(_fullWorkAmount, [&](size_t i) {
dst_ptr_f[i] = logf(src_ptr_f[i]);
});
return;
}
if (_opData.algo == Algorithm::EltwisePowerStatic) {
const float* src_ptr_f = reinterpret_cast<const float*>(args_ptrs.src_ptr[0]);
float* dst_ptr_f = reinterpret_cast<float*>(args_ptrs.dst_ptr);
const T* src_ptr_f = reinterpret_cast<const T*>(args_ptrs.src_ptr[0]);
T* dst_ptr_f = reinterpret_cast<T*>(args_ptrs.dst_ptr);
if (_opData.alpha == 2) {
parallel_for(_fullWorkAmount, [&](size_t i) {
dst_ptr_f[i] = (_opData.beta * src_ptr_f[i] + _opData.gamma) *
@ -1608,9 +1616,9 @@ public:
return;
}
if (_opData.algo == Algorithm::EltwisePowerDynamic) {
const float* src_ptr_f = reinterpret_cast<const float*>(args_ptrs.src_ptr[0]);
const float* src_ptr_f_pow = reinterpret_cast<const float*>(args_ptrs.src_ptr[1]);
float* dst_ptr_f = reinterpret_cast<float*>(args_ptrs.dst_ptr);
const T* src_ptr_f = reinterpret_cast<const T*>(args_ptrs.src_ptr[0]);
const T* src_ptr_f_pow = reinterpret_cast<const T*>(args_ptrs.src_ptr[1]);
T* dst_ptr_f = reinterpret_cast<T*>(args_ptrs.dst_ptr);
uint32_t count_of_power_values = 1;
for (unsigned long i : _inpDims[1]) {
@ -1656,20 +1664,20 @@ public:
for (size_t j = 0; j < counters.size(); j++) {
index_in[i] += counters[j] * _src_offsets[i][j];
}
index_in[i] /= sizeof(float);
index_in[i] /= sizeof(T);
}
size_t index_out = 0;
for (size_t j = 0; j < counters.size(); j++) {
index_out += counters[j] * _dst_offsets[j];
}
index_out /= sizeof(float);
index_out /= sizeof(T);
std::vector<float> src_f(_inputNum);
std::vector<T> src_f(_inputNum);
for (size_t i = 0; i < _inputNum; i++) {
src_f[i] = (reinterpret_cast<const float*>(args_ptrs.src_ptr[i]) + index_in[i])[0];
src_f[i] = (reinterpret_cast<const T*>(args_ptrs.src_ptr[i]) + index_in[i])[0];
}
float* dst_ptr_f = reinterpret_cast<float*>(args_ptrs.dst_ptr) + index_out;
T* dst_ptr_f = reinterpret_cast<T*>(args_ptrs.dst_ptr) + index_out;
switch (_opData.algo) {
case Algorithm::EltwiseRelu:
@ -1712,13 +1720,14 @@ public:
case Algorithm::EltwiseLogicalOr: *dst_ptr_f = src_f[0] || src_f[1]; break;
case Algorithm::EltwiseLogicalXor: *dst_ptr_f = (src_f[0] || src_f[1]) - (src_f[0] && src_f[1]); break;
case Algorithm::EltwiseLogicalNot: *dst_ptr_f = !src_f[0]; break;
case Algorithm::EltwisePrelu: *dst_ptr_f = src_f[0] > 0 ? src_f[0] : src_f[0] * src_f[1]; break;
case Algorithm::EltwisePrelu: *dst_ptr_f = src_f[0] > 0 ? src_f[0] : static_cast<T>(src_f[0] * src_f[1]); break;
case Algorithm::EltwiseErf: *dst_ptr_f = std::erf(src_f[0]); break;
case Algorithm::EltwiseSoftSign: *dst_ptr_f = src_f[0] / (1 + std::fabs(src_f[0])); break;
case Algorithm::EltwiseIsFinite: *dst_ptr_f = std::isfinite(src_f[0]); break;
// @todo implement proper isinfinite for non-float precisions
case Algorithm::EltwiseIsFinite: *dst_ptr_f = std::isfinite(static_cast<float>(src_f[0])); break;
case Algorithm::EltwiseIsInf:
*dst_ptr_f = (_opData.alpha && (src_f[0] == -std::numeric_limits<float>::infinity())) ||
(_opData.beta && (src_f[0] == std::numeric_limits<float>::infinity()));
*dst_ptr_f = (_opData.alpha && (src_f[0] == -std::numeric_limits<T>::infinity())) ||
(_opData.beta && (src_f[0] == std::numeric_limits<T>::infinity()));
break;
case Algorithm::EltwiseIsNaN: *dst_ptr_f = std::isnan(src_f[0]); break;
case Algorithm::EltwiseSelect: *dst_ptr_f = src_f[0] ? src_f[1] : src_f[2]; break;
@ -1757,24 +1766,32 @@ bool Eltwise::EltwiseData::operator==(const EltwiseData &rhs) const noexcept {
gamma == rhs.gamma;
}
static Eltwise::executorPtr buildExecutor(const EltwiseKey& key) {
Eltwise::executorPtr execPtr;
if (key.implType != EltwiseImplType::reference) {
execPtr = std::make_shared<EltwiseJitExecutor>(key.eltwise_data,
key.ops_list,
key.outBlkDims,
key.outOrder,
key.inpDims,
key.inpPrc,
key.outPrc,
key.postOps,
key.implType == EltwiseImplType::optimizedShapeAgnostic);
} else {
execPtr = std::make_shared<EltwiseRefExecutor>(key.eltwise_data.front(),
/**
 * @brief Creates the reference (non-JIT) eltwise executor for the given key.
 *
 * The element type of the executor is selected from the key's output precision:
 * FP16 output gets the float16_t instantiation, everything else gets the float one.
 * Only the first entry of key.eltwise_data is passed to the executor.
 */
static Eltwise::executorPtr buildRefExecutor(const EltwiseKey& key) {
    const bool isHalfPrecisionOutput = (key.outPrc == Precision::FP16);

    if (!isHalfPrecisionOutput) {
        // use float reference executor for any other precision for now
        return std::make_shared<EltwiseRefExecutor<float>>(key.eltwise_data.front(),
                                                           key.outBlkDims,
                                                           key.inpDims);
    }

    return std::make_shared<EltwiseRefExecutor<dnnl::impl::float16_t>>(key.eltwise_data.front(),
                                                                       key.outBlkDims,
                                                                       key.inpDims);
}
static Eltwise::executorPtr buildExecutor(const EltwiseKey& key) {
if (key.implType == EltwiseImplType::reference) {
return buildRefExecutor(key);
}
return execPtr;
return std::make_shared<EltwiseJitExecutor>(key.eltwise_data,
key.ops_list,
key.outBlkDims,
key.outOrder,
key.inpDims,
key.inpPrc,
key.outPrc,
key.postOps,
key.implType == EltwiseImplType::optimizedShapeAgnostic);
}
bool Eltwise::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept {
@ -1965,8 +1982,9 @@ void Eltwise::initSupportedPrimitiveDescriptors() {
}
#if defined(OV_CPU_WITH_ACL)
Precision forcedPrec;
//ACL implementation supports only identical precisions on inputs/outputs so they are aligned it to highest one
// Use original output precision as a reference point since some eltwise algorithms have non-float inputs (i.e. EltwiseSelect)
Precision forcedPrec = getOriginalOutputPrecisionAtPort(0) == Precision::FP16 ? Precision::FP16 : Precision::FP32;
// ACL implementation supports only identical precisions on inputs/outputs, so they are aligned to the highest one
if (AclEltwiseExecutor::isEltwiseAlgorithmSupported(getAlgorithm())) {
for (size_t i = 0; i < getParentEdges().size(); i++) {
if (!getParentEdgeAt(i)->getParent()->isConstant()) {
@ -1978,9 +1996,8 @@ void Eltwise::initSupportedPrimitiveDescriptors() {
if (!forcedPrec.is_float()) {
forcedPrec = Precision::FP32;
}
} else {
forcedPrec = Precision::FP32;
}
for (size_t i = 0; i < inputPrecisions.size(); i++) {
inputPrecisions[i] = forcedPrec;
}

View File

@ -217,12 +217,14 @@ void FullyConnected::getSupportedDescriptors() {
outputDataType = memory::data_type::bf16;
}
} else if (inputDataType == memory::data_type::f16) {
#if defined(OV_CPU_WITH_ACL)
// acl fc does not support precisions conversion
outputDataType = weightsDataType = memory::data_type::f16;
#else
// f16 input only supports f16/f32 output, even if FQ is fused as post-ops
if (!one_of(outputDataType , memory::data_type::f32, memory::data_type::f16)) {
outputDataType = memory::data_type::f16;
}
#if defined(OV_CPU_WITH_ACL)
weightsDataType = memory::data_type::f16;
#endif
} else if (one_of(inputDataType, memory::data_type::u8, memory::data_type::s8)) {
if (weightsDataType != memory::data_type::s8) {

View File

@ -256,7 +256,7 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
{ov::element::i4, ov::element::i8},
{ov::element::u4, ov::element::u8}
};
// @todo should we always convert to f32 regardless of hardware support, as it is done for f16?
if (!dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core))
map.insert({ov::element::bf16, ov::element::f32});

View File

@ -11,6 +11,7 @@
#include "ie_common.h"
#include "ie_layouts.h"
#include "general_utils.h"
#include "precision_support.h"
namespace ov {
namespace intel_cpu {
@ -98,10 +99,14 @@ inline bool isEmptyTensorDesc(const InferenceEngine::TensorDesc &td) {
*/
inline InferenceEngine::Precision normalizeToSupportedPrecision(InferenceEngine::Precision precision) {
switch (precision) {
case InferenceEngine::Precision::BF16:
case InferenceEngine::Precision::FP16: {
if (!hasHardwareSupport(precision))
precision = InferenceEngine::Precision::FP32;
}
case InferenceEngine::Precision::U8:
case InferenceEngine::Precision::I8:
case InferenceEngine::Precision::I32:
case InferenceEngine::Precision::BF16:
case InferenceEngine::Precision::FP32: {
break;
}
@ -121,14 +126,11 @@ inline InferenceEngine::Precision normalizeToSupportedPrecision(InferenceEngine:
precision = InferenceEngine::Precision::I32;
break;
}
case InferenceEngine::Precision::FP16: {
precision = InferenceEngine::Precision::FP32;
break;
}
default: {
precision = InferenceEngine::Precision::UNSPECIFIED;
}
}
return precision;
}
@ -161,6 +163,5 @@ inline std::vector<float> makeAlignedBuffer(size_t targetSize, const std::vector
}
return alignedBuffer;
}
} // namespace intel_cpu
} // namespace ov

View File

@ -0,0 +1,42 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "precision_support.h"
#include "ie_precision.hpp"
#include "cpu/x64/cpu_isa_traits.hpp"
#include "openvino/core/visibility.hpp"
namespace ov {
namespace intel_cpu {
// Tells whether the given precision is considered natively supported on the current target.
//  - FP16: on x86-64 builds the answer comes from a mayiuse() query for avx512_core_fp16;
//          on builds with OV_CPU_ARM_ENABLE_FP16 it is unconditionally true; otherwise false.
//  - BF16: on x86-64 builds the answer comes from a mayiuse() query for avx512_core; otherwise false.
//  - any other precision: always true.
bool hasHardwareSupport(const InferenceEngine::Precision& precision) {
    if (precision == InferenceEngine::Precision::FP16) {
#if defined(OPENVINO_ARCH_X86_64)
        return dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_fp16);
#elif defined(OV_CPU_ARM_ENABLE_FP16)
        return true; // @todo add runtime check for arm as well
#else
        return false;
#endif
    }

    if (precision == InferenceEngine::Precision::BF16) {
#if defined(OPENVINO_ARCH_X86_64)
        return dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core);
#else
        return false;
#endif
    }

    // every remaining precision is reported as supported
    return true;
}
} // namespace intel_cpu
} // namespace ov

View File

@ -0,0 +1,15 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "ie_precision.hpp"
namespace ov {
namespace intel_cpu {
/**
 * @brief Checks whether the given precision is treated as hardware-supported.
 *
 * Per the implementation in precision_support.cpp:
 *  - FP16: on x86-64 builds requires avx512_core_fp16 (queried via mayiuse);
 *    on builds with OV_CPU_ARM_ENABLE_FP16 always true (compile-time only, no runtime check yet);
 *    false otherwise.
 *  - BF16: on x86-64 builds requires avx512_core (queried via mayiuse); false otherwise.
 *  - all other precisions: true.
 *
 * @param precision precision to be checked
 * @return true if the precision is reported as supported, false otherwise
 */
bool hasHardwareSupport(const InferenceEngine::Precision& precision);
} // namespace intel_cpu
} // namespace ov

View File

@ -88,12 +88,12 @@ std::vector<std::string> disabledTestPatterns() {
R"(.*Hetero.*Behavior.*ExecutableNetworkBaseTest.*ExecGraphInfo.*)",
R"(.*Hetero.*Behavior.*OVCompiledModelBaseTest.*ExecGraphInfo.*)",
R"(.*Hetero.*Behavior.*ExecutableNetworkBaseTest.*CanCreateTwoExeNetworksAndCheckFunction.*)",
// TODO: CVS-104942
// TODO: 104942
R"(.*(Auto|Multi).*Behavior.*ExecutableNetworkBaseTest.*canLoadCorrectNetworkToGetExecutableAndCheckConfig.*)",
R"(.*(Auto|Multi).*SetPropLoadNetWorkGetPropTests.*)",
R"(.*Hetero.*Behavior.*OVCompiledModelBaseTest.*canCreateTwoCompiledModelAndCheckTheir.*)",
// CPU does not support dynamic rank
// Issue: CVS-66778
// Issue: 66778
R"(.*smoke_BehaviorTests.*InferFullyDynamicNetworkWith(S|G)etTensor.*)",
R"(.*smoke_Hetero_BehaviorTests.*InferFullyDynamicNetworkWith(S|G)etTensor.*)",
R"(.*smoke_Auto_BehaviorTests.*InferFullyDynamicNetworkWith(S|G)etTensor.*)",
@ -111,7 +111,7 @@ std::vector<std::string> disabledTestPatterns() {
// Issue 67214
R"(smoke_PrePostProcess.*resize_and_convert_layout_i8.*)",
// TODO: CVS-67255
// TODO: 67255
R"(smoke_If.*SimpleIf2OutTest.*)",
// Issue: 69086
@ -188,7 +188,7 @@ std::vector<std::string> disabledTestPatterns() {
// New plugin API doesn't support changes of pre-processing
R"(.*(Auto|Multi|Hetero).*InferRequestPreprocessTest.*SetPreProcessToInputInfo.*)",
R"(.*(Auto|Multi|Hetero).*InferRequestPreprocessTest.*SetPreProcessToInferRequest.*)",
// TODO: for 22.2 (CVS-68949)
// TODO: for 22.2 (Issue 68949)
R"(.*smoke_AutoBatching_CPU/AutoBatching_Test_DetectionOutput.*)",
};
@ -212,7 +212,7 @@ std::vector<std::string> disabledTestPatterns() {
retVector.emplace_back(R"(smoke_NegativeQuantizedMatMulMultiplyFusion.*)");
// int8 specific
retVector.emplace_back(R"(smoke_Quantized.*)");
// TODO: fix CVS-115961
// TODO: Issue 115961
retVector.emplace_back(R"(.*compareAutoBatchingToSingleBatch/CPU_get_blob_batch_size_4_num_streams_1_num_req_64*)");
retVector.emplace_back(R"(.*compareAutoBatchingToSingleBatch/CPU_get_blob_batch_size_4_num_streams_2_num_req_64*)");
retVector.emplace_back(R"(.*compareAutoBatchingToSingleBatch/CPU_set_blob_batch_size_4_num_streams_1_num_req_64*)");
@ -230,9 +230,9 @@ std::vector<std::string> disabledTestPatterns() {
// TODO: generate new 'expected' runtime graph for non-x64 CPU
retVector.emplace_back(R"(smoke_serialization/ExecGraphSerializationTest.ExecutionGraph.*)");
retVector.emplace_back(R"(smoke_ExecGraph/ExecGraphRuntimePrecision.CheckRuntimePrecision/Function=(EltwiseWithTwoDynamicInputs|FakeQuantizeRelu).*)");
// CVS-108803: bug in CPU scalar implementation
// Issue 108803: bug in CPU scalar implementation
retVector.emplace_back(R"(smoke_TestsDFT_(1|2|3|4)d/DFTLayerTest.CompareWithRefs.*)");
// CVS-88764, CVS-91647, CVS-108802: accuracy issue
// Issue 88764, 91647, 108802: accuracy issue
retVector.emplace_back(R"(MultipleLSTMCellTest/MultipleLSTMCellTest.CompareWithRefs.*)");
// int8 / code-generation specific
retVector.emplace_back(R"(smoke_LPT.*)");
@ -258,10 +258,14 @@ std::vector<std::string> disabledTestPatterns() {
// Skip fp16 tests for platforms that don't support fp16 precision
retVector.emplace_back(R"(.*INFERENCE_PRECISION_HINT=(F|f)16.*)");
}
#endif
#if defined(OV_CPU_ARM_ENABLE_FP16)
// Skip fp16 tests for platforms that don't support fp16 precision
retVector.emplace_back(R"(.*INFERENCE_PRECISION_HINT=(F|f)16.*)");
#elif defined(OPENVINO_ARCH_ARM64) || defined(OPENVINO_ARCH_ARM)
#if !defined(OV_CPU_ARM_ENABLE_FP16)
// Skip fp16 tests for platforms that don't support fp16 precision
retVector.emplace_back(R"(.*INFERENCE_PRECISION_HINT=(F|f)16.*)");
#else
// Issue 117407
retVector.emplace_back(R"(.*EltwiseLayerCPUTest.*IS=\(\[1\.\.10\.2\.5\.6\]_\).*eltwiseOpType=SqDiff.*_configItem=INFERENCE_PRECISION_HINT=f16.*)");
#endif // OV_CPU_ARM_ENABLE_FP16
#endif
if (!InferenceEngine::with_cpu_x86_avx512_core_vnni() && !InferenceEngine::with_cpu_x86_avx512_core_amx_int8()) {
// MatMul in Snippets uses BRGEMM that supports i8 only on platforms with VNNI or AMX instructions

View File

@ -4,6 +4,8 @@
#include "eltwise.hpp"
#include "gtest/gtest.h"
#include "openvino/core/type/element_type.hpp"
#include "openvino/runtime/properties.hpp"
#include "test_utils/cpu_test_utils.hpp"
using namespace InferenceEngine;
@ -14,144 +16,149 @@ using namespace ov::test;
namespace CPULayerTestsDefinitions {
std::string EltwiseLayerCPUTest::getTestCaseName(testing::TestParamInfo<EltwiseLayerCPUTestParamsSet> obj) {
subgraph::EltwiseTestParams basicParamsSet;
CPUSpecificParams cpuParams;
fusingSpecificParams fusingParams;
std::tie(basicParamsSet, cpuParams, fusingParams) = obj.param;
subgraph::EltwiseTestParams basicParamsSet;
CPUSpecificParams cpuParams;
fusingSpecificParams fusingParams;
std::tie(basicParamsSet, cpuParams, fusingParams) = obj.param;
std::ostringstream result;
result << subgraph::EltwiseLayerTest::getTestCaseName(testing::TestParamInfo<subgraph::EltwiseTestParams>(
basicParamsSet, 0));
result << CPUTestsBase::getTestCaseName(cpuParams);
result << CpuTestWithFusing::getTestCaseName(fusingParams);
std::ostringstream result;
result << subgraph::EltwiseLayerTest::getTestCaseName(testing::TestParamInfo<subgraph::EltwiseTestParams>(
basicParamsSet, 0));
result << CPUTestsBase::getTestCaseName(cpuParams);
result << CpuTestWithFusing::getTestCaseName(fusingParams);
return result.str();
return result.str();
}
ov::Tensor EltwiseLayerCPUTest::generate_eltwise_input(const ov::element::Type& type, const ngraph::Shape& shape) {
struct gen_params {
uint32_t range;
int32_t start_from;
int32_t resolution;
struct gen_params {
uint32_t range;
int32_t start_from;
int32_t resolution;
gen_params(uint32_t range = 10, int32_t start_from = 0, int32_t resolution = 1)
: range(range), start_from(start_from), resolution(resolution) {}
};
gen_params(uint32_t range = 10, int32_t start_from = 0, int32_t resolution = 1)
: range(range), start_from(start_from), resolution(resolution) {}
};
gen_params params = gen_params();
if (type.is_real()) {
switch (eltwiseType) {
case ngraph::helpers::EltwiseTypes::POWER:
params = gen_params(6, -3);
case ngraph::helpers::EltwiseTypes::MOD:
case ngraph::helpers::EltwiseTypes::FLOOR_MOD:
params = gen_params(2, 2, 8);
break;
case ngraph::helpers::EltwiseTypes::DIVIDE:
params = gen_params(2, 2, 8);
break;
case ngraph::helpers::EltwiseTypes::ERF:
params = gen_params(6, -3);
break;
default:
params = gen_params(80, 0, 8);
break;
}
} else {
params = gen_params(INT32_MAX, INT32_MIN);
gen_params params = gen_params();
if (type.is_real()) {
switch (eltwiseType) {
case ngraph::helpers::EltwiseTypes::POWER:
params = gen_params(6, -3);
case ngraph::helpers::EltwiseTypes::MOD:
case ngraph::helpers::EltwiseTypes::FLOOR_MOD:
params = gen_params(2, 2, 8);
break;
case ngraph::helpers::EltwiseTypes::DIVIDE:
params = gen_params(2, 2, 8);
break;
case ngraph::helpers::EltwiseTypes::ERF:
params = gen_params(6, -3);
break;
default:
params = gen_params(80, 0, 8);
break;
}
return ov::test::utils::create_and_fill_tensor(type, shape, params.range, params.start_from, params.resolution);
} else {
params = gen_params(INT32_MAX, INT32_MIN);
}
return ov::test::utils::create_and_fill_tensor(type, shape, params.range, params.start_from, params.resolution);
}
void EltwiseLayerCPUTest::generate_inputs(const std::vector<ngraph::Shape>& targetInputStaticShapes) {
inputs.clear();
const auto& funcInputs = function->inputs();
for (size_t i = 0; i < funcInputs.size(); ++i) {
const auto& funcInput = funcInputs[i];
inputs.insert({funcInput.get_node_shared_ptr(), generate_eltwise_input(funcInput.get_element_type(), targetInputStaticShapes[i])});
}
inputs.clear();
const auto& funcInputs = function->inputs();
for (size_t i = 0; i < funcInputs.size(); ++i) {
const auto& funcInput = funcInputs[i];
inputs.insert({funcInput.get_node_shared_ptr(), generate_eltwise_input(funcInput.get_element_type(), targetInputStaticShapes[i])});
}
}
void EltwiseLayerCPUTest::SetUp() {
subgraph::EltwiseTestParams basicParamsSet;
CPUSpecificParams cpuParams;
fusingSpecificParams fusingParams;
std::tie(basicParamsSet, cpuParams, fusingParams) = this->GetParam();
std::vector<InputShape> shapes;
ElementType netType;
ngraph::helpers::InputLayerType secondaryInputType;
ov::test::utils::OpType opType;
ov::AnyMap additional_config;
std::tie(shapes, eltwiseType, secondaryInputType, opType, netType, inType, outType, targetDevice, additional_config) = basicParamsSet;
subgraph::EltwiseTestParams basicParamsSet;
CPUSpecificParams cpuParams;
fusingSpecificParams fusingParams;
std::tie(basicParamsSet, cpuParams, fusingParams) = this->GetParam();
std::vector<InputShape> shapes;
ElementType netType;
ngraph::helpers::InputLayerType secondaryInputType;
ov::test::utils::OpType opType;
ov::AnyMap additionalConfig;
std::tie(shapes, eltwiseType, secondaryInputType, opType, netType, inType, outType, targetDevice, additionalConfig) = basicParamsSet;
// we have to change model precision as well, otherwise inference precision won't affect single-node graph
// due to enforce inference precision optimization for the eltwise as first node of the model
if (ov::element::Type(netType).is_real() && additionalConfig.count(ov::hint::inference_precision.name())) {
netType = additionalConfig[ov::hint::inference_precision.name()].as<ov::element::Type>();
}
if (ElementType::bf16 == netType) {
rel_threshold = 2e-2f;
} else if (ElementType::i32 == netType) {
abs_threshold = 0;
if (ElementType::bf16 == netType) {
rel_threshold = 2e-2f;
} else if (ElementType::i32 == netType) {
abs_threshold = 0;
}
std::tie(inFmts, outFmts, priority, selectedType) = cpuParams;
std::tie(postOpMgrPtr, fusedOps) = fusingParams;
shapes.resize(2);
switch (opType) {
case ov::test::utils::OpType::SCALAR: {
std::vector<ngraph::Shape> identityShapes(shapes[0].second.size(), {1});
shapes[1] = {{}, identityShapes};
break;
}
case ov::test::utils::OpType::VECTOR:
if (shapes[1].second.empty()) {
shapes[1] = shapes[0];
}
break;
default:
FAIL() << "Unsupported Secondary operation type";
}
std::tie(inFmts, outFmts, priority, selectedType) = cpuParams;
std::tie(postOpMgrPtr, fusedOps) = fusingParams;
init_input_shapes(shapes);
configuration.insert(additionalConfig.begin(), additionalConfig.end());
updateSelectedType(getPrimitiveType(), netType, configuration);
// selectedType = makeSelectedTypeStr(getPrimitiveType(), netType);
#if defined(OPENVINO_ARCH_ARM) || defined(OPENVINO_ARCH_ARM64)
if (eltwiseType == POWER) {
selectedType = std::regex_replace(selectedType, std::regex("acl"), "ref");
}
#endif
selectedType = makeSelectedTypeStr(getPrimitiveType(), netType);
#if defined(OPENVINO_ARCH_ARM) || defined(OPENVINO_ARCH_ARM64)
if (eltwiseType == POWER) {
selectedType = std::regex_replace(selectedType, std::regex("acl"), "ref");
}
#endif
shapes.resize(2);
switch (opType) {
case ov::test::utils::OpType::SCALAR: {
std::vector<ngraph::Shape> identityShapes(shapes[0].second.size(), {1});
shapes[1] = {{}, identityShapes};
break;
}
case ov::test::utils::OpType::VECTOR:
if (shapes[1].second.empty()) {
shapes[1] = shapes[0];
}
break;
default:
FAIL() << "Unsupported Secondary operation type";
}
init_input_shapes(shapes);
configuration.insert(additional_config.begin(), additional_config.end());
auto parameters = ngraph::builder::makeDynamicParams(netType, {inputDynamicShapes.front()});
std::shared_ptr<ngraph::Node> secondaryInput;
if (secondaryInputType == ngraph::helpers::InputLayerType::PARAMETER) {
secondaryInput = ngraph::builder::makeDynamicParams(netType, {inputDynamicShapes.back()}).front();
parameters.push_back(std::dynamic_pointer_cast<ngraph::opset3::Parameter>(secondaryInput));
auto parameters = ngraph::builder::makeDynamicParams(netType, {inputDynamicShapes.front()});
std::shared_ptr<ngraph::Node> secondaryInput;
if (secondaryInputType == ngraph::helpers::InputLayerType::PARAMETER) {
secondaryInput = ngraph::builder::makeDynamicParams(netType, {inputDynamicShapes.back()}).front();
parameters.push_back(std::dynamic_pointer_cast<ngraph::opset3::Parameter>(secondaryInput));
} else {
auto pShape = inputDynamicShapes.back();
ngraph::Shape shape;
if (pShape.is_static()) {
shape = pShape.get_shape();
} else {
auto pShape = inputDynamicShapes.back();
ngraph::Shape shape;
if (pShape.is_static()) {
shape = pShape.get_shape();
} else {
ASSERT_TRUE(pShape.rank().is_static());
shape = std::vector<size_t>(pShape.rank().get_length(), 1);
for (size_t i = 0; i < pShape.size(); ++i) {
if (pShape[i].is_static()) {
shape[i] = pShape[i].get_length();
}
ASSERT_TRUE(pShape.rank().is_static());
shape = std::vector<size_t>(pShape.rank().get_length(), 1);
for (size_t i = 0; i < pShape.size(); ++i) {
if (pShape[i].is_static()) {
shape[i] = pShape[i].get_length();
}
}
if (netType == ElementType::i32) {
auto data_tensor = generate_eltwise_input(ElementType::i32, shape);
auto data_ptr = reinterpret_cast<int32_t*>(data_tensor.data());
std::vector<int32_t> data(data_ptr, data_ptr + ngraph::shape_size(shape));
secondaryInput = ngraph::builder::makeConstant(netType, shape, data);
} else {
auto data_tensor = generate_eltwise_input(ElementType::f32, shape);
auto data_ptr = reinterpret_cast<float*>(data_tensor.data());
std::vector<float> data(data_ptr, data_ptr + ngraph::shape_size(shape));
secondaryInput = ngraph::builder::makeConstant(netType, shape, data);
}
}
auto eltwise = ngraph::builder::makeEltwise(parameters[0], secondaryInput, eltwiseType);
function = makeNgraphFunction(netType, parameters, eltwise, "Eltwise");
if (netType == ElementType::i32) {
auto data_tensor = generate_eltwise_input(ElementType::i32, shape);
auto data_ptr = reinterpret_cast<int32_t*>(data_tensor.data());
std::vector<int32_t> data(data_ptr, data_ptr + ngraph::shape_size(shape));
secondaryInput = ngraph::builder::makeConstant(netType, shape, data);
} else {
auto data_tensor = generate_eltwise_input(ElementType::f32, shape);
auto data_ptr = reinterpret_cast<float*>(data_tensor.data());
std::vector<float> data(data_ptr, data_ptr + ngraph::shape_size(shape));
secondaryInput = ngraph::builder::makeConstant(netType, shape, data);
}
}
auto eltwise = ngraph::builder::makeEltwise(parameters[0], secondaryInput, eltwiseType);
function = makeNgraphFunction(netType, parameters, eltwise, "Eltwise");
}
TEST_P(EltwiseLayerCPUTest, CompareWithRefs) {
@ -161,250 +168,250 @@ TEST_P(EltwiseLayerCPUTest, CompareWithRefs) {
namespace Eltwise {
const std::vector<ov::AnyMap>& additional_config() {
static const std::vector<ov::AnyMap> additional_config = {
static const std::vector<ov::AnyMap> additionalConfig = {
{{ov::hint::inference_precision.name(), ov::element::f32}},
{{ov::hint::inference_precision.name(), ov::element::f16}}
};
return additional_config;
};
return additionalConfig;
}
// Returns the network element types used by the eltwise tests (currently only f32).
// The list is a function-local static, so the returned reference stays valid.
const std::vector<ElementType>& netType() {
    static const std::vector<ElementType> netType = {
        ElementType::f32};
    return netType;
}
// Returns the secondary-operand kinds used by the eltwise tests (currently only VECTOR).
// The list is a function-local static, so the returned reference stays valid.
const std::vector<ov::test::utils::OpType>& opTypes() {
    static const std::vector<ov::test::utils::OpType> opTypes = {
        ov::test::utils::OpType::VECTOR,
    };
    return opTypes;
}
// Returns the `BinInp` set of eltwise operation types: ADD, MULTIPLY and SQUARED_DIFF
// on every architecture, plus SUBTRACT, DIVIDE and FLOOR_MOD when building for
// x86 / x86-64 (the TODOs below track why they are excluded elsewhere).
const std::vector<ngraph::helpers::EltwiseTypes>& eltwiseOpTypesBinInp() {
    static const std::vector<ngraph::helpers::EltwiseTypes> eltwiseOpTypesBinInp = {
        ngraph::helpers::EltwiseTypes::ADD,
        ngraph::helpers::EltwiseTypes::MULTIPLY,
#if defined(OPENVINO_ARCH_X86) || defined(OPENVINO_ARCH_X86_64)
        ngraph::helpers::EltwiseTypes::SUBTRACT,     // TODO: Fix CVS-105430
        ngraph::helpers::EltwiseTypes::DIVIDE,       // TODO: Fix CVS-105430
        ngraph::helpers::EltwiseTypes::FLOOR_MOD,    // TODO: Fix CVS-111875
#endif
        ngraph::helpers::EltwiseTypes::SQUARED_DIFF,
    };
    return eltwiseOpTypesBinInp;
}
// Returns the `DiffInp` set of eltwise operation types, i.e. operations whose number
// of input nodes differs depending on optimizations. Only POWER is enabled; MOD is
// kept commented out because it does not execute because of transformations.
const std::vector<ngraph::helpers::EltwiseTypes>& eltwiseOpTypesDiffInp() {
    static const std::vector<ngraph::helpers::EltwiseTypes> eltwiseOpTypesDiffInp = { // Different number of input nodes depending on optimizations
        ngraph::helpers::EltwiseTypes::POWER,
        // ngraph::helpers::EltwiseTypes::MOD // Does not execute because of transformations
    };
    return eltwiseOpTypesDiffInp;
}
// Returns the `BinDyn` set of eltwise operation types: ADD, MULTIPLY and SQUARED_DIFF
// on every architecture, plus SUBTRACT when building for x86 / x86-64.
const std::vector<ngraph::helpers::EltwiseTypes>& eltwiseOpTypesBinDyn() {
    static const std::vector<ngraph::helpers::EltwiseTypes> eltwiseOpTypesBinDyn = {
        ngraph::helpers::EltwiseTypes::ADD,
        ngraph::helpers::EltwiseTypes::MULTIPLY,
#if defined(OPENVINO_ARCH_X86) || defined(OPENVINO_ARCH_X86_64) // TODO: Fix CVS-105430
        ngraph::helpers::EltwiseTypes::SUBTRACT,
#endif
        ngraph::helpers::EltwiseTypes::SQUARED_DIFF,
    };
    return eltwiseOpTypesBinDyn;
}
// Returns the CPU-specific parameter sets for 4D inputs: one with nhwc and one with
// nchw memory formats on both inputs and the output. The last two fields are left empty.
const std::vector<CPUSpecificParams>& cpuParams_4D() {
    static const std::vector<CPUSpecificParams> cpuParams_4D = {
        CPUSpecificParams({nhwc, nhwc}, {nhwc}, {}, {}),
        CPUSpecificParams({nchw, nchw}, {nchw}, {}, {})
    };
    return cpuParams_4D;
}
// Returns the CPU-specific parameter sets for 5D inputs: one with ndhwc and one with
// ncdhw memory formats on both inputs and the output. The last two fields are left empty.
const std::vector<CPUSpecificParams>& cpuParams_5D() {
    static const std::vector<CPUSpecificParams> cpuParams_5D = {
        CPUSpecificParams({ndhwc, ndhwc}, {ndhwc}, {}, {}),
        CPUSpecificParams({ncdhw, ncdhw}, {ncdhw}, {}, {})
    };
    return cpuParams_5D;
}
// Returns the static 4D input shape sets. Each entry lists either one shape or
// a pair of shapes; in the pairs the second shape has 1s in some dimensions.
const std::vector<std::vector<ov::Shape>>& inShapes_4D() {
    static const std::vector<std::vector<ov::Shape>> inShapes_4D = {
        {{2, 4, 4, 1}},
        {{2, 17, 5, 4}},
        {{2, 17, 5, 4}, {1, 17, 1, 1}},
        {{2, 17, 5, 1}, {1, 17, 1, 4}},
    };
    return inShapes_4D;
}
// Returns the static 5D input shape sets. Each entry lists either one shape or
// a pair of shapes; in the pairs the second shape has 1s in some dimensions.
const std::vector<std::vector<ov::Shape>>& inShapes_5D() {
    static const std::vector<std::vector<ov::Shape>> inShapes_5D = {
        {{2, 4, 3, 4, 1}},
        {{2, 17, 7, 5, 4}},
        {{2, 17, 6, 5, 4}, {1, 17, 6, 1, 1}},
        {{2, 17, 6, 5, 1}, {1, 17, 1, 1, 4}},
    };
    return inShapes_5D;
}
// Returns the `I32` set of eltwise operation types: ADD, MULTIPLY and SQUARED_DIFF
// on every architecture, plus SUBTRACT and DIVIDE when building for x86 / x86-64.
const std::vector<ngraph::helpers::EltwiseTypes>& eltwiseOpTypesI32() {
    static const std::vector<ngraph::helpers::EltwiseTypes> eltwiseOpTypesI32 = {
        ngraph::helpers::EltwiseTypes::ADD,
        ngraph::helpers::EltwiseTypes::MULTIPLY,
#if defined(OPENVINO_ARCH_X86) || defined(OPENVINO_ARCH_X86_64) // TODO: Fix CVS-105430
        ngraph::helpers::EltwiseTypes::SUBTRACT,
        ngraph::helpers::EltwiseTypes::DIVIDE,
#endif
        ngraph::helpers::EltwiseTypes::SQUARED_DIFF,
    };
    return eltwiseOpTypesI32;
}
// Returns the kinds of layer the secondary eltwise input may come from:
// a CONSTANT or a PARAMETER.
const std::vector<ngraph::helpers::InputLayerType>& secondaryInputTypes() {
    static const std::vector<ngraph::helpers::InputLayerType> secondaryInputTypes = {
        ngraph::helpers::InputLayerType::CONSTANT,
        ngraph::helpers::InputLayerType::PARAMETER,
    };
    return secondaryInputTypes;
}
// Returns shape pairs combining a 4D first input with a 1D second input whose
// length equals the innermost dimension of the first.
const std::vector<std::vector<ngraph::Shape>>& inShapes_4D_1D() {
    static const std::vector<std::vector<ngraph::Shape>> inShapes_4D_1D = {
        {{2, 17, 5, 4}, {4}},
        {{1, 3, 3, 3}, {3}},
    };
    return inShapes_4D_1D;
}
// Returns the CPU-specific parameter sets for the `4D_1D_Constant_mode` case:
// nhwc and nchw memory formats on both inputs and the output. The last two fields are left empty.
const std::vector<CPUSpecificParams> & cpuParams_4D_1D_Constant_mode() {
    static const std::vector<CPUSpecificParams> cpuParams_4D_1D_Constant_mode = {
        CPUSpecificParams({nhwc, nhwc}, {nhwc}, {}, {}),
        CPUSpecificParams({nchw, nchw}, {nchw}, {}, {})
    };
    return cpuParams_4D_1D_Constant_mode;
}
// Returns the CPU-specific parameter set for the `4D_1D_Parameter_mode` case:
// nchw for the first input, x for the second input, nchw for the output.
// The last two fields are left empty.
const std::vector<CPUSpecificParams>& cpuParams_4D_1D_Parameter_mode() {
    static const std::vector<CPUSpecificParams> cpuParams_4D_1D_Parameter_mode = {
        CPUSpecificParams({nchw, x}, {nchw}, {}, {})
    };
    return cpuParams_4D_1D_Parameter_mode;
}
// Returns shape pairs combining a 5D first input with a 1D second input whose
// length equals the innermost dimension of the first.
const std::vector<std::vector<ngraph::Shape>>& inShapes_5D_1D() {
    static const std::vector<std::vector<ngraph::Shape>> inShapes_5D_1D = {
        {{2, 17, 5, 4, 10}, {10}},
        {{1, 3, 3, 3, 3}, {3}},
    };
    return inShapes_5D_1D;
}
// Returns the CPU-specific parameter set for the `5D_1D_parameter` case:
// ncdhw for the first input, x for the second input, ncdhw for the output.
// The last two fields are left empty.
const std::vector<CPUSpecificParams>& cpuParams_5D_1D_parameter() {
    static const std::vector<CPUSpecificParams> cpuParams_5D_1D_parameter = {
        CPUSpecificParams({ncdhw, x}, {ncdhw}, {}, {})
    };
    return cpuParams_5D_1D_parameter;
}
const std::vector<InputShape>& inShapes_4D_dyn_param() {
static const std::vector<InputShape> inShapes_4D_dyn_param = {
static const std::vector<InputShape> inShapes_4D_dyn_param = {
{
// dynamic
{-1, {2, 15}, -1, -1},
// target
{
// dynamic
{-1, {2, 15}, -1, -1},
// target
{
{3, 2, 1, 1},
{1, 7, 5, 1},
{3, 3, 4, 11},
}
}
},
{
// dynamic
{-1, {2, 25}, -1, -1},
// target
{
// dynamic
{-1, {2, 25}, -1, -1},
// target
{
{1, 2, 5, 1},
{3, 7, 1, 10},
{3, 3, 4, 11}
}
}
}
};
return inShapes_4D_dyn_param;
};
return inShapes_4D_dyn_param;
}
const std::vector<InputShape>& inShapes_5D_dyn_param() {
static const std::vector<InputShape> inShapes_5D_dyn_param = {
static const std::vector<InputShape> inShapes_5D_dyn_param = {
{
// dynamic
{-1, {2, 15}, -1, -1, -1},
// target
{
// dynamic
{-1, {2, 15}, -1, -1, -1},
// target
{
{3, 2, 1, 1, 1},
{1, 7, 5, 1, 12},
{3, 3, 4, 11, 6},
}
}
},
{
// dynamic
{-1, {2, 25}, -1, -1, -1},
// target
{
// dynamic
{-1, {2, 25}, -1, -1, -1},
// target
{
{1, 2, 5, 1, 5},
{3, 7, 1, 10, 1},
{3, 3, 4, 11, 6}
}
}
}
};
return inShapes_5D_dyn_param;
};
return inShapes_5D_dyn_param;
}
// Returns a single dynamic 5D input shape whose first two dimensions are fixed (3, 2)
// and whose remaining dimensions are dynamic, paired with its concrete target shapes.
const std::vector<InputShape>& inShapes_5D_dyn_const() {
    static const std::vector<InputShape> inShapes_5D_dyn_const = {
        {
            // dynamic
            {3, 2, -1, -1, -1},
            // target
            {
                {3, 2, 1, 1, 1},
                {3, 2, 5, 1, 7},
                {3, 2, 1, 6, 1},
                {3, 2, 4, 11, 2},
            }
        },
    };
    return inShapes_5D_dyn_const;
}
const std::vector<std::vector<InputShape>>& inShapes_4D_dyn_const() {
static const std::vector<std::vector<InputShape>> inShapes_4D_dyn_const = {
{
{
{
// dynamic
{3, 2, -1, -1},
// target
{
{3, 2, 1, 1},
{3, 2, 5, 1},
{3, 2, 1, 6},
{3, 2, 4, 11},
}
{3, 2, 1, 1},
{3, 2, 5, 1},
{3, 2, 1, 6},
{3, 2, 4, 11},
}
}
},
{
{
{
// dynamic
{{1, 10}, 2, 5, 6},
// target
{
{3, 2, 5, 6},
{1, 2, 5, 6},
{2, 2, 5, 6},
}
{3, 2, 5, 6},
{1, 2, 5, 6},
{2, 2, 5, 6},
}
}
},
};
return inShapes_4D_dyn_const;

View File

@ -114,9 +114,7 @@ void MvnLayerCPUTest::SetUp() {
rel_threshold = 250.f;
}
configuration.insert(additionalConfig.begin(), additionalConfig.end());
selectedType = getPrimitiveType();
selectedType = makeSelectedTypeStr(selectedType, netPrecision);
updateSelectedType(getPrimitiveType(), netPrecision, configuration);
function = makeNgraphFunction(netPrecision, param, mvn, "mvn");
}
@ -358,4 +356,4 @@ const std::vector<double>& epsilon() {
}
} // namespace MVN
} // namespace CPULayerTestsDefinitions
} // namespace CPULayerTestsDefinitions

View File

@ -5,6 +5,7 @@
#include "reduce.hpp"
#include "gtest/gtest.h"
#include "shared_test_classes/base/ov_subgraph.hpp"
#include "test_utils/cpu_test_utils.hpp"
using namespace InferenceEngine;
@ -86,16 +87,10 @@ void ReduceCPULayerTest::SetUp() {
std::tie(axes, opType, keepDims, reductionType, netPrecision, inPrc, outPrc, inputShapes) = basicParams;
if (netPrecision == ElementType::boolean) {
inPrc = outPrc = netPrecision;
} else {
if (additionalConfig[ov::hint::inference_precision.name()] == ov::element::bf16) {
inPrc = outPrc = netPrecision = ElementType::bf16;
} else if (additionalConfig[ov::hint::inference_precision.name()] == ov::element::f16) {
inPrc = outPrc = netPrecision = ElementType::f16;
} else {
inPrc = outPrc = netPrecision;
}
}
configuration.insert(additionalConfig.begin(), additionalConfig.end());
updateSelectedType(getPrimitiveType(), netPrecision == ElementType::boolean ? ElementType::i8 : netPrecision, configuration);
init_input_shapes(inputShapes);
@ -120,9 +115,6 @@ void ReduceCPULayerTest::SetUp() {
const auto reduce = ngraph::builder::makeReduce(paramOuts[0], reductionAxesNode, keepDims, reductionType);
selectedType = getPrimitiveType() + "_" +
(inPrc == ElementType::boolean ? "I8" : InferenceEngine::details::convertPrecision(inPrc).name());
// hybrid layouts
if (inFmts.size() != 0 && outFmts.size() == 0) {
size_t outShapeSize = inputDynamicShapes[0].size() - axes.size();
@ -281,4 +273,4 @@ const std::vector<ngraph::helpers::ReductionType>& reductionTypesInt32() {
}
} // namespace Reduce
} // namespace CPULayerTestsDefinitions
} // namespace CPULayerTestsDefinitions

View File

@ -58,4 +58,4 @@ const std::vector<std::map<std::string, ov::element::Type>> additionalConfigFP32
const std::vector<ngraph::helpers::ReductionType>& reductionTypesInt32();
} // namespace Reduce
} // namespace CPULayerTestsDefinitions
} // namespace CPULayerTestsDefinitions

View File

@ -55,7 +55,7 @@ void TransposeLayerCPUTest::SetUp() {
std::tie(inFmts, outFmts, priority, selectedType) = cpuParams;
selectedType = makeSelectedTypeStr("unknown", inType);
updateSelectedType("unknown", inType, configuration);
init_input_shapes({inputShapes});
@ -123,4 +123,4 @@ const std::vector<std::vector<size_t>>& inputOrder4D() {
return inputOrder4D;
}
} // namespace Transpose
} // namespace CPULayerTestsDefinitions
} // namespace CPULayerTestsDefinitions

View File

@ -4,7 +4,9 @@
#include "cpu_test_utils.hpp"
#include "ie_ngraph_utils.hpp"
#include "openvino/core/type/element_type.hpp"
#include "utils/rt_info/memory_formats_attribute.hpp"
#include "utils/general_utils.h"
#include <cstdint>
namespace CPUTestUtils {
@ -271,7 +273,6 @@ std::string CPUTestsBase::getPrimitiveType() const {
}
return isaType;
}
#endif
std::string CPUTestsBase::getISA(bool skip_amx) const {
@ -352,6 +353,32 @@ std::string CPUTestsBase::makeSelectedTypeStr(std::string implString, ngraph::el
return implString;
}
// Sets `selectedType` to "<primitiveType>_<precision name>".
// The precision is `netType` unless `netType` is a real (floating point) type and
// `config` carries an inference_precision hint that is not wider than `netType`;
// in that case the hinted precision is used instead.
void CPUTestsBase::updateSelectedType(const std::string& primitiveType, const ov::element::Type netType, const ov::AnyMap& config) {
    // Start from the network type; it is replaced only if the hint below applies.
    ov::element::Type execType = netType;

    // inference_precision affects only floating point type networks
    if (netType.is_real()) {
        const auto hintEntry = config.find(ov::hint::inference_precision.name());
        if (hintEntry != config.end()) {
            const auto hintedType = hintEntry->second.as<ov::element::Type>();
            // currently plugin only allows to change precision from higher to lower (i.e. f32 -> f16 or f32 -> bf16)
            if (netType.bitwidth() >= hintedType.bitwidth()) {
                execType = hintedType;
            }
        }
    }

    selectedType = primitiveType + "_" + InferenceEngine::details::convertPrecision(execType).name();
}
std::vector<CPUSpecificParams> filterCPUSpecificParams(const std::vector<CPUSpecificParams> &paramsVector) {
auto adjustBlockedFormatByIsa = [](std::vector<cpu_memory_format_t>& formats) {
for (auto& format : formats) {

View File

@ -128,6 +128,7 @@ public:
const std::vector<std::string>& priority);
//TODO: change to setter method
static std::string makeSelectedTypeStr(std::string implString, ngraph::element::Type_t elType);
void updateSelectedType(const std::string& primitiveType, const ov::element::Type netType, const ov::AnyMap& config);
CPUInfo getCPUInfo() const;
std::shared_ptr<ngraph::Function> makeNgraphFunction(const ngraph::element::Type &ngPrc,

View File

@ -22,7 +22,7 @@ typedef std::tuple<
ElementType, // In precision
ElementType, // Out precision
TargetDevice, // Device name
ov::AnyMap // Additional network configuration
ov::AnyMap // Additional network configuration
> EltwiseTestParams;
class EltwiseLayerTest : public testing::WithParamInterface<EltwiseTestParams>,

View File

@ -41,7 +41,7 @@ std::string EltwiseLayerTest::getTestCaseName(const testing::TestParamInfo<Eltwi
results << "OutType=" << outType << "_";
results << "trgDev=" << targetName;
for (auto const& configItem : additional_config) {
results << "_configItem=" << configItem.first << "_";
results << "_configItem=" << configItem.first << "=";
configItem.second.print(results);
}
return results.str();