[CPU] LRN: dynamic shapes support (#8724)

This commit is contained in:
Vladislav Golubev 2021-11-26 17:41:48 +03:00 committed by GitHub
parent bcf0879785
commit 734185c04c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 183 additions and 22 deletions

View File

@ -14,23 +14,18 @@ using namespace InferenceEngine;
bool MKLDNNLrnNode::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept { bool MKLDNNLrnNode::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept {
try { try {
if (isDynamicNgraphNode(op)) { auto lrn = ngraph::as_type_ptr<const ngraph::opset1::LRN>(op);
errorMessage = "Doesn't support op with dynamic shapes";
return false;
}
const auto lrn = std::dynamic_pointer_cast<const ngraph::opset1::LRN>(op);
if (!lrn) { if (!lrn) {
errorMessage = "Only opset1 LRN operation is supported"; errorMessage = "Only opset1 LRN operation is supported";
return false; return false;
} }
const auto dataDims = lrn->get_input_shape(0); const auto& dataDims = lrn->get_input_partial_shape(0);
if (dataDims.size() < 2 || dataDims.size() > 5) { if (dataDims.size() < 2 || dataDims.size() > 5) {
errorMessage = "Doesn't support 'data' input with rank: " + std::to_string(dataDims.size()); errorMessage = "Doesn't support 'data' input with rank: " + std::to_string(dataDims.size());
return false; return false;
} }
const auto axesNode = std::dynamic_pointer_cast<const ngraph::opset1::Constant>(lrn->get_input_node_shared_ptr(1)); auto axesNode = ngraph::as_type_ptr<const ngraph::opset1::Constant>(lrn->get_input_node_shared_ptr(1));
if (!axesNode) { if (!axesNode) {
errorMessage = "Only Constant operation on 'axis' input is supported"; errorMessage = "Only Constant operation on 'axis' input is supported";
return false; return false;
@ -69,9 +64,10 @@ MKLDNNLrnNode::MKLDNNLrnNode(const std::shared_ptr<ngraph::Node>& op, const mkld
if (isSupportedOperation(op, errorMessage)) { if (isSupportedOperation(op, errorMessage)) {
errorPrefix = "LRN node with name '" + getName() + "'"; errorPrefix = "LRN node with name '" + getName() + "'";
const auto lrn = std::dynamic_pointer_cast<const ngraph::opset1::LRN>(op); auto lrn = ngraph::as_type_ptr<const ngraph::opset1::LRN>(op);
const auto axes = std::dynamic_pointer_cast<const ngraph::opset1::Constant>(lrn->get_input_node_shared_ptr(1))->cast_vector<int64_t>(); auto axes = ngraph::as_type_ptr<const ngraph::opset1::Constant>(lrn->get_input_node_shared_ptr(1))->cast_vector<int64_t>();
isAcrossMaps = (axes.size() == 1 && axes[0] == 1); bool isAcrossMaps = (axes.size() == 1 && axes[0] == 1);
alg = isAcrossMaps ? mkldnn::algorithm::lrn_across_channels : mkldnn::algorithm::lrn_within_channel;
alpha = static_cast<float>(lrn->get_alpha()); alpha = static_cast<float>(lrn->get_alpha());
beta = static_cast<float>(lrn->get_beta()); beta = static_cast<float>(lrn->get_beta());
k = static_cast<float>(lrn->get_bias()); k = static_cast<float>(lrn->get_bias());
@ -107,21 +103,56 @@ std::shared_ptr<MemoryDesc> MKLDNNLrnNode::getSrcMemDesc(mkldnn::primitive_desc_
if (idx > 0) { if (idx > 0) {
return std::make_shared<CpuBlockedMemoryDesc>(getOriginalInputPrecisionAtPort(idx), getInputShapeAtPort(idx)); return std::make_shared<CpuBlockedMemoryDesc>(getOriginalInputPrecisionAtPort(idx), getInputShapeAtPort(idx));
} else { } else {
return MKLDNNExtensionUtils::makeDescriptor(primitive_desc_it.dst_desc(idx)); if (getInputShapeAtPort(idx).isDynamic()) {
return MKLDNNExtensionUtils::makeUndefinedDesc(primitive_desc_it.src_desc(idx), getInputShapeAtPort(idx));
}
return MKLDNNExtensionUtils::makeDescriptor(primitive_desc_it.src_desc(idx));
} }
} }
void MKLDNNLrnNode::createPrimitive() { void MKLDNNLrnNode::createPrimitive() {
if (prim) if (inputShapesDefined()) {
return; if (needPrepareParams())
prepareParams();
updateLastInputDims();
}
}
auto prim_desc = createPrimitiveDescriptor<mkldnn::lrn_forward::primitive_desc, mkldnn::lrn_forward::desc>(); void MKLDNNLrnNode::prepareParams() {
auto& srcMemPtr = getParentEdgeAt(0)->getMemoryPtr();
auto& dstMemPtr = getChildEdgeAt(0)->getMemoryPtr();
if (!srcMemPtr || !srcMemPtr->GetPrimitivePtr())
IE_THROW() << errorPrefix << " input memory did not allocate";
if (!dstMemPtr || !dstMemPtr->GetPrimitivePtr())
IE_THROW() << errorPrefix << "destination memory did not allocate";
const NodeDesc* selected_pd = getSelectedPrimitiveDescriptor();
if (selected_pd == nullptr)
IE_THROW() << errorPrefix << "preferable primitive descriptor did not set";
auto inpDesc = getParentEdgeAt(0)->getMemory().GetDescWithType<DnnlMemoryDesc>();
const auto& in_candidate = inpDesc->getDnnlDesc();
MKLDNNDescriptor desc(std::shared_ptr<mkldnn::lrn_forward::desc>(
new mkldnn::lrn_forward::desc(mkldnn::prop_kind::forward_scoring, alg, in_candidate, size, alpha, beta, k)));
mkldnn::lrn_forward::primitive_desc prim_desc;
dnnl::primitive_desc_iterator itpd = desc.createPrimitiveDescriptorIterator(getEngine());
while (static_cast<bool>(itpd)) {
impl_desc_type impl_type = parse_impl_name(itpd.impl_info_str());
if (impl_type == selected_pd->getImplementationType()) {
prim_desc = itpd.get();
break;
}
if (!itpd.next_impl())
IE_THROW() << "Primitive descriptor was not found for node " << getName() << ".";
}
prim.reset(new mkldnn::lrn_forward(prim_desc)); prim.reset(new mkldnn::lrn_forward(prim_desc));
auto src = getParentEdgesAtPort(0)[0]->getMemoryPtr()->GetPrimitive(); auto src = srcMemPtr->GetPrimitive();
auto dst = getChildEdgesAtPort(0)[0]->getMemoryPtr()->GetPrimitive(); auto dst = dstMemPtr->GetPrimitive();
primArgs = {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst}}; primArgs = { {DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst} };
} }
bool MKLDNNLrnNode::created() const { bool MKLDNNLrnNode::created() const {
@ -130,11 +161,21 @@ bool MKLDNNLrnNode::created() const {
void MKLDNNLrnNode::createDescriptor(const std::vector<MemoryDescPtr> &inputDesc, void MKLDNNLrnNode::createDescriptor(const std::vector<MemoryDescPtr> &inputDesc,
const std::vector<MemoryDescPtr> &outputDesc) { const std::vector<MemoryDescPtr> &outputDesc) {
mkldnn::algorithm alg = isAcrossMaps ? mkldnn::algorithm::lrn_across_channels : mkldnn::algorithm::lrn_within_channel; auto inpDesc = inputDesc[0]->isDefined() ? inputDesc[0] : MemoryDescUtils::makeDummyDesc(*inputDesc[0]);
DnnlMemoryDescPtr definedInpMemDesc = MemoryDescUtils::convertToDnnlMemoryDesc(inpDesc);
const auto& in_candidate = definedInpMemDesc->getDnnlDesc();
MKLDNNDescriptor desc(std::shared_ptr<mkldnn::lrn_forward::desc>( MKLDNNDescriptor desc(std::shared_ptr<mkldnn::lrn_forward::desc>(
new mkldnn::lrn_forward::desc(mkldnn::prop_kind::forward_scoring, alg, MemoryDescUtils::convertToDnnlMemoryDesc(inputDesc[0])->getDnnlDesc(), new mkldnn::lrn_forward::desc(mkldnn::prop_kind::forward_scoring, alg, in_candidate, size, alpha, beta, k)));
size, alpha, beta, k)));
descs.push_back(desc); descs.push_back(desc);
} }
std::vector<VectorDims> MKLDNNLrnNode::shapeInfer() const {
return { getParentEdgesAtPort(0).front()->getMemory().getStaticDims() };
}
void MKLDNNLrnNode::executeDynamicImpl(dnnl::stream strm) {
MKLDNNNode::execute(strm);
}
REG_MKLDNN_PRIM_FOR(MKLDNNLrnNode, Lrn); REG_MKLDNN_PRIM_FOR(MKLDNNLrnNode, Lrn);

View File

@ -29,10 +29,14 @@ public:
return false; return false;
} }
void prepareParams() override;
void executeDynamicImpl(mkldnn::stream strm) override;
std::vector<VectorDims> shapeInfer() const override;
static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept; static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;
private: private:
bool isAcrossMaps = false; mkldnn::algorithm alg;
size_t size = 1; size_t size = 1;
int k = 1; int k = 1;
float alpha = 1.0f; float alpha = 1.0f;

View File

@ -0,0 +1,116 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "functional_test_utils/ov_tensor_utils.hpp"
#include "test_utils/cpu_test_utils.hpp"
#include "shared_test_classes/base/ov_subgraph.hpp"
#include "ngraph_functions/builders.hpp"
using namespace ngraph;
using namespace CPUTestUtils;
using namespace ov::test;
namespace CPULayerTestsDefinitions {
using LRNParams = std::tuple<
ElementType, // data precision
InputShape, // data shape
double, // alpha
double, // beta
double, // bias
size_t, // size
std::vector<int64_t>>; // axes to reduction
// CPU-plugin layer test for opset1 LRN covering static and dynamic input shapes.
// Parameters: data precision, input shape(s), alpha, beta, bias, local size and
// the axes the normalization is reduced over.
class LRNLayerCPUTest : public testing::WithParamInterface<LRNParams>, public ov::test::SubgraphBaseTest, public CPUTestsBase {
public:
    // Builds a readable test-case name out of the parameter tuple.
    static std::string getTestCaseName(testing::TestParamInfo<LRNParams> info) {
        ElementType prc = std::get<0>(info.param);
        InputShape shapes = std::get<1>(info.param);
        double alphaVal = std::get<2>(info.param);
        double betaVal = std::get<3>(info.param);
        double biasVal = std::get<4>(info.param);
        size_t sizeVal = std::get<5>(info.param);
        std::vector<int64_t> reductionAxes = std::get<6>(info.param);

        std::ostringstream name;
        name << prc << "_" << "IS=" << CommonTestUtils::partialShape2str({ shapes.first }) << "_" << "TS=(";
        for (const auto& staticShape : shapes.second) {
            name << CommonTestUtils::vec2str(staticShape) << "_";
        }
        name << ")_alpha=" << alphaVal << "_beta=" << betaVal << "_bias=" << biasVal << "_size=" << sizeVal << "_axes=" << CommonTestUtils::vec2str(reductionAxes);
        return name.str();
    }

protected:
    // Builds the tested ngraph function: Parameter -> LRN (with constant axes).
    void SetUp() override {
        targetDevice = CommonTestUtils::DEVICE_CPU;

        ElementType prc = std::get<0>(GetParam());
        InputShape shapes = std::get<1>(GetParam());
        double alphaVal = std::get<2>(GetParam());
        double betaVal = std::get<3>(GetParam());
        double biasVal = std::get<4>(GetParam());
        size_t sizeVal = std::get<5>(GetParam());
        std::vector<int64_t> reductionAxes = std::get<6>(GetParam());

        init_input_shapes({ shapes });
        // LRN on CPU is expected to run on the reference implementation.
        selectedType = makeSelectedTypeStr("ref_any", prc);

        auto params = ngraph::builder::makeDynamicParams(prc, { inputDynamicShapes });
        auto axesNode = ngraph::opset1::Constant::create(ngraph::element::i32, { reductionAxes.size() }, reductionAxes);
        auto lrn = std::make_shared<ngraph::opset3::LRN>(params[0], axesNode, alphaVal, betaVal, biasVal, sizeVal);
        function = makeNgraphFunction(prc, params, lrn, "LRN");
    }
};
// Runs inference for every parameter combination and compares the CPU plugin
// output against the reference implementation, then checks that the LRN node
// was executed with the expected ("ref_any") primitive type.
TEST_P(LRNLayerCPUTest, CompareWithRefs) {
SKIP_IF_CURRENT_TEST_IS_DISABLED()
run();
CheckPluginRelatedResults(executableNetwork, "LRN");
}
// Tested data precisions.
const std::vector<ElementType> inputPrecisions = {
ngraph::element::f32,
};
// Reduction-axes variants: across channels ({1}) and within-channel spatial
// combinations; {3, 2} checks that axis order does not matter.
const std::vector<std::vector<std::int64_t>> axes = {
{ 1 },
{ 2, 3 },
{ 3, 2 },
{ 1, 2, 3 }
};
// LRN attribute values (single representative value each).
const std::vector<double> alpha = { 9.9e-05 };
const std::vector<double> beta = { 2. };
const std::vector<double> bias = { 1. };
const std::vector<size_t> size = { 5ul };
// Shape coverage: fully static, fully dynamic (-1 per dim), and
// bounded-dynamic ({min, max} ranges) — each dynamic case is run with
// three concrete upper-bound-respecting static shapes.
const std::vector<InputShape> inputShapes = {
InputShape{{}, {{10, 10, 3, 8}}},
InputShape{
// dynamic
{-1, -1, -1, -1},
// static
{{15, 5, 7, 8}, {10, 10, 3, 8}, {1, 3, 5, 5}}
},
InputShape{
// dynamic
{{1, 15}, {3, 10}, {3, 7}, {5, 8}},
// static
{{15, 5, 7, 8}, {10, 10, 3, 8}, {1, 3, 5, 5}}
},
};
// Cartesian product of all parameter vectors above.
const auto testCases = ::testing::Combine(
::testing::ValuesIn(inputPrecisions),
::testing::ValuesIn(inputShapes),
::testing::ValuesIn(alpha),
::testing::ValuesIn(beta),
::testing::ValuesIn(bias),
::testing::ValuesIn(size),
::testing::ValuesIn(axes)
);
INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs, LRNLayerCPUTest, testCases, LRNLayerCPUTest::getTestCaseName);
} // namespace CPULayerTestsDefinitions