[CPU] Removed custom ShapeInference impl for RandomUniform (#20599)

Author: Nikolay Shchegolev, 2023-10-25 11:15:03 +04:00, committed by GitHub
Parent: 706d657637
Commit: dc4240bc61
9 changed files with 34 additions and 127 deletions

View File

@@ -2,12 +2,9 @@
 // SPDX-License-Identifier: Apache-2.0
 //
-#include <string>
-#include <vector>
 #include "grid_sample.hpp"
 #include "ie_parallel.hpp"
-#include <ngraph/opsets/opset1.hpp>
+#include "openvino/op/grid_sample.hpp"
 using namespace InferenceEngine;
 using namespace ov::intel_cpu;
@@ -16,8 +13,6 @@ using namespace ov::intel_cpu::node;
 using namespace dnnl::impl::cpu;
 #endif // OPENVINO_ARCH_X86_64
-#define THROW_ERROR IE_THROW() << getTypeStr() << " node with name '" << getName() << "' "
 bool GridSample::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept {
     try {
@@ -46,21 +41,21 @@ GridSample::GridSample(const std::shared_ptr<ov::Node>& op, const GraphContext::
         : Node(op, context, NgraphShapeInferFactory(op, PortMask(1))) {
     std::string errorMessage;
     if (!isSupportedOperation(op, errorMessage)) {
-        IE_THROW(NotImplemented) << errorMessage;
+        THROW_CPU_NODE_ERR(errorMessage);
     }
     if (op->get_input_size() != 2 || op->get_output_size() != 1)
-        THROW_ERROR << "has incorrect number of input/output ports.";
+        THROW_CPU_NODE_ERR("has incorrect number of input/output ports.");
     const auto& dataShape = getInputShapeAtPort(IN_DATA);
     if (dataShape.getRank() != 4)
-        THROW_ERROR << "has incorrect rank of the Data input.";
+        THROW_CPU_NODE_ERR("has incorrect rank of the Data input.");
     const auto& gridShape = getInputShapeAtPort(IN_GRID);
     if (gridShape.getRank() != 4)
-        THROW_ERROR << "has incorrect rank of the Grid input.";
+        THROW_CPU_NODE_ERR("has incorrect rank of the Grid input.");
     if (gridShape.isStatic() && gridShape.getDims()[3] != 2)
-        THROW_ERROR << "has incorrect shape of the Grid input. The 4th dimension should be equal to 2.";
+        THROW_CPU_NODE_ERR("has incorrect shape of the Grid input. The 4th dimension should be equal to 2.");
     const auto& attributes = ov::as_type_ptr<ov::op::v9::GridSample>(op)->get_attributes();
     alignCorners = attributes.align_corners;
@@ -75,7 +70,7 @@ GridSample::GridSample(const std::shared_ptr<ov::Node>& op, const GraphContext::
             interpolationMode = GridSampleInterpolationMode::NEAREST;
             break;
         default:
-            THROW_ERROR << "supports only BILINEAR, BICUBIC, NEAREST interpolation modes.";
+            THROW_CPU_NODE_ERR("supports only BILINEAR, BICUBIC, NEAREST interpolation modes.");
     }
     switch (attributes.padding_mode) {
         case op::v9::GridSample::PaddingMode::ZEROS:
@@ -88,7 +83,7 @@ GridSample::GridSample(const std::shared_ptr<ov::Node>& op, const GraphContext::
             paddingMode = GridSamplePaddingMode::REFLECTION;
             break;
         default:
-            THROW_ERROR << "supports only BORDER, REFLECTION, ZEROS paddings modes.";
+            THROW_CPU_NODE_ERR("supports only BORDER, REFLECTION, ZEROS paddings modes.");
     }
 }
@@ -149,7 +144,7 @@ void GridSample::createPrimitive() {
         jitKernel.reset(new kernel::GridSampleKernel<x64::sse41>(jcp));
     }
     if (!jitKernel) {
-        THROW_ERROR << " could not create JIT kernel.";
+        THROW_CPU_NODE_ERR("could not create JIT kernel.");
     }
     jitKernel->create_ker();
@@ -187,15 +182,15 @@ void GridSample::createPrimitive() {
 void GridSample::prepareParams() {
     auto dataMemPtr = getParentEdgeAt(IN_DATA)->getMemoryPtr();
     if (!dataMemPtr || !dataMemPtr->isAllocated())
-        THROW_ERROR << " has not allocated input data memory.";
+        THROW_CPU_NODE_ERR("has not allocated input data memory.");
     auto gridMemPtr = getParentEdgeAt(IN_GRID)->getMemoryPtr();
     if (!gridMemPtr || !gridMemPtr->isAllocated())
-        THROW_ERROR << " has not allocated input grid memory.";
+        THROW_CPU_NODE_ERR("has not allocated input grid memory.");
     auto dstMemPtr = getChildEdgeAt(0)->getMemoryPtr();
     if (!dstMemPtr || !dstMemPtr->isAllocated())
-        THROW_ERROR << " has not allocated output memory.";
+        THROW_CPU_NODE_ERR("has not allocated output memory.");
     if (getSelectedPrimitiveDescriptor() == nullptr)
-        THROW_ERROR << " has unidentified preferable primitive descriptor.";
+        THROW_CPU_NODE_ERR("has unidentified preferable primitive descriptor.");
     const uint64_t dataElPerVec = jitKernel->getDataElPerVec();
     const auto& srcDataShape = dataMemPtr->getStaticDims();
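Note on the file above: the file-local THROW_ERROR macro, which prefixed every message with the node type and name, is dropped in favor of the shared THROW_CPU_NODE_ERR macro used across the CPU plugin. The prefixing idea stays the same; a self-contained sketch of such a node-scoped error helper is shown below (ToyNode, throwNodeError and the exact message format are illustrative, not the plugin's actual code):

    #include <stdexcept>
    #include <string>
    #include <utility>

    // Toy model of a plugin node, only to show how an error helper can prepend
    // "<Type> node with name '<name>'" to every message it throws.
    class ToyNode {
    public:
        ToyNode(std::string type, std::string name) : m_type(std::move(type)), m_name(std::move(name)) {}

        [[noreturn]] void throwNodeError(const std::string& message) const {
            throw std::runtime_error(m_type + " node with name '" + m_name + "' " + message);
        }

    private:
        std::string m_type;
        std::string m_name;
    };

    // Usage mirroring the diff above:
    // ToyNode("GridSample", "gs1").throwNodeError("has incorrect rank of the Data input.");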

View File

@@ -7,10 +7,6 @@
 #include <node.h>
 #include "kernels/x64/grid_sample.hpp"
-#include <memory>
-#include <string>
-#include <vector>
 namespace ov {
 namespace intel_cpu {
 namespace node {

View File

@@ -48,7 +48,7 @@ template <x64::cpu_isa_t isa>
 void jitUniGatherKernel<isa>::create_ker() {
     auto code = x64::jit_generator::create_kernel();
     if (code != dnnl::impl::status::success)
-        IE_THROW() << "Could not create Gather kernel. Error code: " << std::to_string(code);
+        OPENVINO_THROW("Could not create Gather kernel. Error code: ", std::to_string(code));
     ker_ = (decltype(ker_))jit_ker();
 }
@@ -154,7 +154,7 @@ void jitUniGatherKernel<isa>::generate() {
                 process(true, true);
             } else { // Long case.
-                IE_THROW() << "Gather kernel does not support static shape with after axis size greater than elements in vector.";
+                OPENVINO_THROW("Gather kernel does not support static shape with after axis size greater than elements in vector.");
             }
         }
     } else { // Dynamic shapes.
@@ -526,7 +526,7 @@ template <x64::cpu_isa_t isa>
 void jitUniGatherKernel<isa>::calcSrcShiftLongBlock(Vmm* vAuxPool, bool shiftFirst) {
     // Most likely there will no significant performance gain vs memcpy in reference implementation on big blocks after axis,
     // therefore no time was invested to this case yet.
-    IE_THROW() << "Unsupported case.";
+    OPENVINO_THROW("Unsupported case.");
 }
 // Requires vAuxPool length 3.

View File

@@ -22,6 +22,7 @@
 #pragma once
+#include "jit_kernel_base.hpp"
 #include "cpu/x64/jit_generator.hpp"
 #include <dnnl_types.h>

View File

@@ -30,7 +30,7 @@ template <x64::cpu_isa_t isa>
 void GridSampleKernel<isa>::create_ker() {
     auto code = x64::jit_generator::create_kernel();
     if (code != dnnl::impl::status::success)
-        IE_THROW() << "Could not create GridSample kernel. Error code: " << std::to_string(code);
+        OPENVINO_THROW("Could not create GridSample kernel. Error code: ", std::to_string(code));
     ker_ = (decltype(ker_))jit_ker();
 }

View File

@@ -89,13 +89,13 @@ void JitKernelBase::uni_vpaddd(const Xbyak::Ymm& v_dst,
             paddd(xmmDst, ptr[op.getAddress().getRegExp() + vlen]);
             vperm2f128(v_dst, v_dst, v_dst, 0x1);
         } else {
-            IE_THROW() << "Not supported operand type.";
+            OPENVINO_THROW("Not supported operand type.");
         }
     } else if (isValidIsa(x64::sse41)) {
         assert(v_dst.getIdx() != v_src.getIdx());
         paddd(v_dst, op);
     } else {
-        IE_THROW() << "Not defined behavior for instruction 'vpaddd' in current instructions set.";
+        OPENVINO_THROW("Not defined behavior for instruction 'vpaddd' in current instructions set.");
     }
 }
@@ -136,13 +136,13 @@ void JitKernelBase::uni_vpsubd(const Xbyak::Ymm& v_dst,
             psubd(xmmDst, ptr[op.getAddress().getRegExp() + vlen]);
             vperm2f128(v_dst, v_dst, v_dst, 0x1);
         } else {
-            IE_THROW() << "Not supported operand type.";
+            OPENVINO_THROW("Not supported operand type.");
         }
     } else if (isValidIsa(x64::sse41)) {
         assert(v_dst.getIdx() != v_src.getIdx());
         psubd(v_dst, op);
     } else {
-        IE_THROW() << "Not defined behavior for instruction 'vpsubd' in current instructions set.";
+        OPENVINO_THROW("Not defined behavior for instruction 'vpsubd' in current instructions set.");
     }
 }
@@ -244,7 +244,7 @@ void JitKernelBase::gatherdd(const Xbyak::Xmm& v_dst,
                              const bool useMask,
                              const bool zeroFill) {
     if (kReadMask.getIdx() == 0) {
-        IE_THROW() << "The vpgatherdd instruction cannot use the register k0 as mask.";
+        OPENVINO_THROW("The vpgatherdd instruction cannot use the register k0 as mask.");
     }
     if (!useMask)
         kxnord(kReadMask, kReadMask, kReadMask);
@@ -261,7 +261,7 @@ void JitKernelBase::gatherdd(const Xbyak::Xmm& v_dst,
                              const bool useMask,
                              const bool zeroFill) {
     if (v_dst.getIdx() == vSrcShift.getIdx() || v_dst.getIdx() == vReadMask.getIdx() || vSrcShift.getIdx() == vReadMask.getIdx()) {
-        IE_THROW() << "Any pair of the index, mask, or destination registers cannot be the same.";
+        OPENVINO_THROW("Any pair of the index, mask, or destination registers cannot be the same.");
     }
     if (zeroFill)
         pxor(v_dst, v_dst); // Don't use vpxor. It zeros the rest of the YMM register.
@@ -299,7 +299,7 @@ void JitKernelBase::gatherdd(const Xbyak::Ymm& v_dst,
                              const bool useMask,
                              const bool zeroFill) {
     if (v_dst.getIdx() == vSrcShift.getIdx() || v_dst.getIdx() == vReadMask.getIdx() || vSrcShift.getIdx() == vReadMask.getIdx()) {
-        IE_THROW() << "Any pair of the index, mask, or destination registers cannot be the same.";
+        OPENVINO_THROW("Any pair of the index, mask, or destination registers cannot be the same.");
     }
     if (isValidIsa(x64::avx2)) {
         if (!useMask)
@@ -430,7 +430,7 @@ void JitKernelBase::fillRestWorkMask(const Xbyak::Xmm& xmmDstMask,
                                      const Xbyak::Reg64& rWorkRest,
                                      const uint64_t typeSize) {
     if (!one_of(typeSize, 1u, 2u, 4u, 8u)) {
-        IE_THROW() << "Could not fill data with type size " << typeSize;
+        OPENVINO_THROW("Could not fill data with type size ", typeSize);
     }
     Xbyak::Label lEnd;
     auto r32Ones = getReg32();
@@ -459,7 +459,7 @@ void JitKernelBase::fillRestWorkMask(const Xbyak::Ymm& ymmDstMask,
                                      const Xbyak::Reg64& rWorkRest,
                                      const uint64_t typeSize) {
     if (!one_of(typeSize, 1u, 2u, 4u, 8u)) {
-        IE_THROW() << "Could not fill data with type size " << typeSize;
+        OPENVINO_THROW("Could not fill data with type size ", typeSize);
     }
     Xbyak::Label lEnd;
     auto elPerVec = x64::cpu_isa_traits<x64::sse41>::vlen / typeSize;
@@ -499,7 +499,7 @@ void JitKernelBase::load(const Xbyak::Xmm& v_dst,
                          const size_t typeSize,
                          const bool zeroFilling) {
     if (!one_of(typeSize, 1u, 2u, 4u, 8u)) {
-        IE_THROW() << "Could not load data with type size " << typeSize;
+        OPENVINO_THROW("Could not load data with type size ", typeSize);
     }
     const uint8_t elPerVec = x64::cpu_isa_traits<x64::sse41>::vlen / typeSize;
     Xbyak::Label lEnd;
@@ -529,7 +529,7 @@ void JitKernelBase::load(const Xbyak::Ymm& v_dst,
                          const size_t typeSize,
                          const bool zeroFilling) {
     if (!one_of(typeSize, 1u, 2u, 4u, 8u)) {
-        IE_THROW() << "Could not load data with type size " << typeSize;
+        OPENVINO_THROW("Could not load data with type size ", typeSize);
     }
     const size_t elPerXmm = x64::cpu_isa_traits<x64::sse41>::vlen / typeSize;
     Xbyak::Label lEnd;
@@ -568,7 +568,7 @@ void JitKernelBase::store(const Xbyak::Address& dstAddr,
                           const Xbyak::Reg64& rToStoreNum,
                           const size_t typeSize) {
     if (!one_of(typeSize, 1u, 2u, 4u, 8u)) {
-        IE_THROW() << "Could not store data with type size " << typeSize;
+        OPENVINO_THROW("Could not store data with type size ", typeSize);
     }
     Xbyak::Label lEnd;
     const size_t elPerVec = x64::cpu_isa_traits<x64::sse41>::vlen / typeSize;
@@ -596,7 +596,7 @@ void JitKernelBase::store(const Xbyak::Address& dstAddr,
                           const Xbyak::Reg64& rToStoreNum,
                           const size_t typeSize) {
     if (!one_of(typeSize, 1u, 2u, 4u, 8u)) {
-        IE_THROW() << "Could not store data with type size " << typeSize;
+        OPENVINO_THROW("Could not store data with type size ", typeSize);
     }
     Xbyak::Label lEnd;
     Xbyak::Xmm xmmSrc(v_src.getIdx());
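Note on the file above: the legacy stream-style IE_THROW() << a << b calls become variadic OPENVINO_THROW(a, b) calls, so values that used to be streamed (for example typeSize) are now passed as extra arguments and concatenated into the exception message. A minimal sketch of such a variadic throwing helper is shown below (throw_with_message is a hypothetical stand-in, not the real macro):

    #include <sstream>
    #include <stdexcept>
    #include <utility>

    // Every argument is streamed into one message, which is why "..." << x in the
    // old code turns into "...", x in the new code.
    template <typename... Args>
    [[noreturn]] void throw_with_message(Args&&... args) {
        std::ostringstream oss;
        (oss << ... << std::forward<Args>(args));  // C++17 fold: stream the arguments in order
        throw std::runtime_error(oss.str());
    }

    // Usage mirroring the diff above:
    // throw_with_message("Could not fill data with type size ", typeSize);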

View File

@@ -6,9 +6,8 @@
 #include "ie_parallel.hpp"
 #include "ie_ngraph_utils.hpp"
-#include <openvino/op/constant.hpp>
-#include <openvino/op/random_uniform.hpp>
-#include "shape_inference/custom/random_uniform.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/random_uniform.hpp"
 namespace ov {
 namespace intel_cpu {
@@ -27,7 +26,7 @@ bool RandomUniform::isSupportedOperation(const std::shared_ptr<const ov::Node>&
 }
 RandomUniform::RandomUniform(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr& context)
-        : Node(op, context, RandomUniformShapeInferFactory(op)) {
+        : Node(op, context, NgraphShapeInferFactory(op, PortMask(0, 1, 2))) {
     std::string errorMessage;
     if (!isSupportedOperation(op, errorMessage)) {
         THROW_CPU_NODE_ERR(errorMessage);
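The constructor change above is the core of the commit: the hand-written RandomUniformShapeInferFactory is replaced with the generic NgraphShapeInferFactory, and PortMask(0, 1, 2) marks the Shape, Min and Max inputs as data dependencies of shape inference, whereas the removed custom class requested only port 0 (see the deleted files below). As a rough illustration of the port-mask idea, a bitmask builder along the following lines behaves as one would expect (port_mask is a hypothetical helper, not OpenVINO's PortMask):

    #include <cstdint>

    // Illustrative only: pack input-port indices into a bitmask, in the spirit of
    // PortMask(0, 1, 2) used in the diff above.
    template <typename... Ports>
    constexpr uint32_t port_mask(Ports... ports) {
        uint32_t mask = 0;
        ((mask |= (1u << static_cast<uint32_t>(ports))), ...);  // one bit per requested port
        return mask;
    }

    static_assert(port_mask(0, 1, 2) == 0b111, "ports 0, 1 and 2 marked as data dependencies");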

View File

@@ -1,47 +0,0 @@
-// Copyright (C) 2018-2023 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
-//
-#include "random_uniform.hpp"
-#include <openvino/op/random_uniform.hpp>
-namespace ov {
-namespace intel_cpu {
-namespace node {
-// TODO: remove after fixing the issue 123011
-IShapeInfer::Result RandomUniformShapeInfer::infer(
-        const std::vector<std::reference_wrapper<const VectorDims>>& input_shapes,
-        const std::unordered_map<size_t, MemoryPtr>& data_dependency) {
-    VectorDims dims;
-    const auto& mem = data_dependency.at(0);
-    const auto rank = mem->getShape().getElementsCount();
-    auto shape_prc = mem->getDesc().getPrecision();
-    switch (shape_prc) {
-        case InferenceEngine::Precision::I32: {
-            auto data = reinterpret_cast<const int32_t*>(mem->getData());
-            dims.assign(data, data + rank);
-        } break;
-        case InferenceEngine::Precision::I64: {
-            auto data = reinterpret_cast<const int64_t*>(mem->getData());
-            dims.assign(data, data + rank);
-        } break;
-        default:
-            OPENVINO_THROW("Unexpected Shape input precision: ", shape_prc);
-    }
-    return {{dims}, ShapeInferStatus::success};
-}
-RandomUniformShapeInferFactory::RandomUniformShapeInferFactory(const std::shared_ptr<ov::Node>& op) : m_op(op) {
-    OPENVINO_ASSERT(ov::is_type<const op::v8::RandomUniform>(m_op),
-            "Unexpected op type in RandomUniform shape inference factory: ", m_op->get_type_name());
-}
-ShapeInferPtr RandomUniformShapeInferFactory::makeShapeInfer() const {
-    return std::make_shared<RandomUniformShapeInfer>();
-}
-}  // namespace node
-}  // namespace intel_cpu
-}  // namespace ov

View File

@@ -1,37 +0,0 @@
-// Copyright (C) 2018-2023 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
-//
-#include "shape_inference/shape_inference_cpu.hpp"
-#include <node.h>
-#pragma once
-namespace ov {
-namespace intel_cpu {
-namespace node {
-class RandomUniformShapeInfer : public ShapeInferEmptyPads {
-public:
-    explicit RandomUniformShapeInfer() {}
-    IShapeInfer::Result infer(
-            const std::vector<std::reference_wrapper<const VectorDims>>& input_shapes,
-            const std::unordered_map<size_t, MemoryPtr>& data_dependency) override;
-    port_mask_t get_port_mask() const override {
-        return PortMask(0);
-    }
-};
-class RandomUniformShapeInferFactory : public ShapeInferFactory {
-public:
-    explicit RandomUniformShapeInferFactory(const std::shared_ptr<ov::Node>& op);
-    ShapeInferPtr makeShapeInfer() const override;
-private:
-    std::shared_ptr<ov::Node> m_op;
-};
-}  // namespace node
-}  // namespace intel_cpu
-}  // namespace ov