[CPU] LRN: dynamic shapes support (#8724)
This commit is contained in:
parent
bcf0879785
commit
734185c04c
@ -14,23 +14,18 @@ using namespace InferenceEngine;
|
|||||||
|
|
||||||
bool MKLDNNLrnNode::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept {
|
bool MKLDNNLrnNode::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept {
|
||||||
try {
|
try {
|
||||||
if (isDynamicNgraphNode(op)) {
|
auto lrn = ngraph::as_type_ptr<const ngraph::opset1::LRN>(op);
|
||||||
errorMessage = "Doesn't support op with dynamic shapes";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto lrn = std::dynamic_pointer_cast<const ngraph::opset1::LRN>(op);
|
|
||||||
if (!lrn) {
|
if (!lrn) {
|
||||||
errorMessage = "Only opset1 LRN operation is supported";
|
errorMessage = "Only opset1 LRN operation is supported";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto dataDims = lrn->get_input_shape(0);
|
const auto& dataDims = lrn->get_input_partial_shape(0);
|
||||||
if (dataDims.size() < 2 || dataDims.size() > 5) {
|
if (dataDims.size() < 2 || dataDims.size() > 5) {
|
||||||
errorMessage = "Doesn't support 'data' input with rank: " + std::to_string(dataDims.size());
|
errorMessage = "Doesn't support 'data' input with rank: " + std::to_string(dataDims.size());
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
const auto axesNode = std::dynamic_pointer_cast<const ngraph::opset1::Constant>(lrn->get_input_node_shared_ptr(1));
|
auto axesNode = ngraph::as_type_ptr<const ngraph::opset1::Constant>(lrn->get_input_node_shared_ptr(1));
|
||||||
if (!axesNode) {
|
if (!axesNode) {
|
||||||
errorMessage = "Only Constant operation on 'axis' input is supported";
|
errorMessage = "Only Constant operation on 'axis' input is supported";
|
||||||
return false;
|
return false;
|
||||||
@ -69,9 +64,10 @@ MKLDNNLrnNode::MKLDNNLrnNode(const std::shared_ptr<ngraph::Node>& op, const mkld
|
|||||||
if (isSupportedOperation(op, errorMessage)) {
|
if (isSupportedOperation(op, errorMessage)) {
|
||||||
errorPrefix = "LRN node with name '" + getName() + "'";
|
errorPrefix = "LRN node with name '" + getName() + "'";
|
||||||
|
|
||||||
const auto lrn = std::dynamic_pointer_cast<const ngraph::opset1::LRN>(op);
|
auto lrn = ngraph::as_type_ptr<const ngraph::opset1::LRN>(op);
|
||||||
const auto axes = std::dynamic_pointer_cast<const ngraph::opset1::Constant>(lrn->get_input_node_shared_ptr(1))->cast_vector<int64_t>();
|
auto axes = ngraph::as_type_ptr<const ngraph::opset1::Constant>(lrn->get_input_node_shared_ptr(1))->cast_vector<int64_t>();
|
||||||
isAcrossMaps = (axes.size() == 1 && axes[0] == 1);
|
bool isAcrossMaps = (axes.size() == 1 && axes[0] == 1);
|
||||||
|
alg = isAcrossMaps ? mkldnn::algorithm::lrn_across_channels : mkldnn::algorithm::lrn_within_channel;
|
||||||
alpha = static_cast<float>(lrn->get_alpha());
|
alpha = static_cast<float>(lrn->get_alpha());
|
||||||
beta = static_cast<float>(lrn->get_beta());
|
beta = static_cast<float>(lrn->get_beta());
|
||||||
k = static_cast<float>(lrn->get_bias());
|
k = static_cast<float>(lrn->get_bias());
|
||||||
@ -107,21 +103,56 @@ std::shared_ptr<MemoryDesc> MKLDNNLrnNode::getSrcMemDesc(mkldnn::primitive_desc_
|
|||||||
if (idx > 0) {
|
if (idx > 0) {
|
||||||
return std::make_shared<CpuBlockedMemoryDesc>(getOriginalInputPrecisionAtPort(idx), getInputShapeAtPort(idx));
|
return std::make_shared<CpuBlockedMemoryDesc>(getOriginalInputPrecisionAtPort(idx), getInputShapeAtPort(idx));
|
||||||
} else {
|
} else {
|
||||||
return MKLDNNExtensionUtils::makeDescriptor(primitive_desc_it.dst_desc(idx));
|
if (getInputShapeAtPort(idx).isDynamic()) {
|
||||||
|
return MKLDNNExtensionUtils::makeUndefinedDesc(primitive_desc_it.src_desc(idx), getInputShapeAtPort(idx));
|
||||||
|
}
|
||||||
|
return MKLDNNExtensionUtils::makeDescriptor(primitive_desc_it.src_desc(idx));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void MKLDNNLrnNode::createPrimitive() {
|
void MKLDNNLrnNode::createPrimitive() {
|
||||||
if (prim)
|
if (inputShapesDefined()) {
|
||||||
return;
|
if (needPrepareParams())
|
||||||
|
prepareParams();
|
||||||
|
updateLastInputDims();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
auto prim_desc = createPrimitiveDescriptor<mkldnn::lrn_forward::primitive_desc, mkldnn::lrn_forward::desc>();
|
void MKLDNNLrnNode::prepareParams() {
|
||||||
|
auto& srcMemPtr = getParentEdgeAt(0)->getMemoryPtr();
|
||||||
|
auto& dstMemPtr = getChildEdgeAt(0)->getMemoryPtr();
|
||||||
|
if (!srcMemPtr || !srcMemPtr->GetPrimitivePtr())
|
||||||
|
IE_THROW() << errorPrefix << " input memory did not allocate";
|
||||||
|
if (!dstMemPtr || !dstMemPtr->GetPrimitivePtr())
|
||||||
|
IE_THROW() << errorPrefix << "destination memory did not allocate";
|
||||||
|
|
||||||
|
const NodeDesc* selected_pd = getSelectedPrimitiveDescriptor();
|
||||||
|
if (selected_pd == nullptr)
|
||||||
|
IE_THROW() << errorPrefix << "preferable primitive descriptor did not set";
|
||||||
|
|
||||||
|
auto inpDesc = getParentEdgeAt(0)->getMemory().GetDescWithType<DnnlMemoryDesc>();
|
||||||
|
const auto& in_candidate = inpDesc->getDnnlDesc();
|
||||||
|
MKLDNNDescriptor desc(std::shared_ptr<mkldnn::lrn_forward::desc>(
|
||||||
|
new mkldnn::lrn_forward::desc(mkldnn::prop_kind::forward_scoring, alg, in_candidate, size, alpha, beta, k)));
|
||||||
|
|
||||||
|
mkldnn::lrn_forward::primitive_desc prim_desc;
|
||||||
|
dnnl::primitive_desc_iterator itpd = desc.createPrimitiveDescriptorIterator(getEngine());
|
||||||
|
while (static_cast<bool>(itpd)) {
|
||||||
|
impl_desc_type impl_type = parse_impl_name(itpd.impl_info_str());
|
||||||
|
|
||||||
|
if (impl_type == selected_pd->getImplementationType()) {
|
||||||
|
prim_desc = itpd.get();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (!itpd.next_impl())
|
||||||
|
IE_THROW() << "Primitive descriptor was not found for node " << getName() << ".";
|
||||||
|
}
|
||||||
|
|
||||||
prim.reset(new mkldnn::lrn_forward(prim_desc));
|
prim.reset(new mkldnn::lrn_forward(prim_desc));
|
||||||
|
|
||||||
auto src = getParentEdgesAtPort(0)[0]->getMemoryPtr()->GetPrimitive();
|
auto src = srcMemPtr->GetPrimitive();
|
||||||
auto dst = getChildEdgesAtPort(0)[0]->getMemoryPtr()->GetPrimitive();
|
auto dst = dstMemPtr->GetPrimitive();
|
||||||
primArgs = {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst}};
|
primArgs = { {DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst} };
|
||||||
}
|
}
|
||||||
|
|
||||||
bool MKLDNNLrnNode::created() const {
|
bool MKLDNNLrnNode::created() const {
|
||||||
@ -130,11 +161,21 @@ bool MKLDNNLrnNode::created() const {
|
|||||||
|
|
||||||
void MKLDNNLrnNode::createDescriptor(const std::vector<MemoryDescPtr> &inputDesc,
|
void MKLDNNLrnNode::createDescriptor(const std::vector<MemoryDescPtr> &inputDesc,
|
||||||
const std::vector<MemoryDescPtr> &outputDesc) {
|
const std::vector<MemoryDescPtr> &outputDesc) {
|
||||||
mkldnn::algorithm alg = isAcrossMaps ? mkldnn::algorithm::lrn_across_channels : mkldnn::algorithm::lrn_within_channel;
|
auto inpDesc = inputDesc[0]->isDefined() ? inputDesc[0] : MemoryDescUtils::makeDummyDesc(*inputDesc[0]);
|
||||||
|
DnnlMemoryDescPtr definedInpMemDesc = MemoryDescUtils::convertToDnnlMemoryDesc(inpDesc);
|
||||||
|
const auto& in_candidate = definedInpMemDesc->getDnnlDesc();
|
||||||
|
|
||||||
MKLDNNDescriptor desc(std::shared_ptr<mkldnn::lrn_forward::desc>(
|
MKLDNNDescriptor desc(std::shared_ptr<mkldnn::lrn_forward::desc>(
|
||||||
new mkldnn::lrn_forward::desc(mkldnn::prop_kind::forward_scoring, alg, MemoryDescUtils::convertToDnnlMemoryDesc(inputDesc[0])->getDnnlDesc(),
|
new mkldnn::lrn_forward::desc(mkldnn::prop_kind::forward_scoring, alg, in_candidate, size, alpha, beta, k)));
|
||||||
size, alpha, beta, k)));
|
|
||||||
descs.push_back(desc);
|
descs.push_back(desc);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<VectorDims> MKLDNNLrnNode::shapeInfer() const {
|
||||||
|
return { getParentEdgesAtPort(0).front()->getMemory().getStaticDims() };
|
||||||
|
}
|
||||||
|
|
||||||
|
void MKLDNNLrnNode::executeDynamicImpl(dnnl::stream strm) {
|
||||||
|
MKLDNNNode::execute(strm);
|
||||||
|
}
|
||||||
|
|
||||||
REG_MKLDNN_PRIM_FOR(MKLDNNLrnNode, Lrn);
|
REG_MKLDNN_PRIM_FOR(MKLDNNLrnNode, Lrn);
|
||||||
|
@ -29,10 +29,14 @@ public:
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void prepareParams() override;
|
||||||
|
void executeDynamicImpl(mkldnn::stream strm) override;
|
||||||
|
std::vector<VectorDims> shapeInfer() const override;
|
||||||
|
|
||||||
static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;
|
static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool isAcrossMaps = false;
|
mkldnn::algorithm alg;
|
||||||
size_t size = 1;
|
size_t size = 1;
|
||||||
int k = 1;
|
int k = 1;
|
||||||
float alpha = 1.0f;
|
float alpha = 1.0f;
|
||||||
|
@ -0,0 +1,116 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "functional_test_utils/ov_tensor_utils.hpp"
|
||||||
|
#include "test_utils/cpu_test_utils.hpp"
|
||||||
|
|
||||||
|
#include "shared_test_classes/base/ov_subgraph.hpp"
|
||||||
|
#include "ngraph_functions/builders.hpp"
|
||||||
|
|
||||||
|
using namespace ngraph;
|
||||||
|
using namespace CPUTestUtils;
|
||||||
|
using namespace ov::test;
|
||||||
|
|
||||||
|
namespace CPULayerTestsDefinitions {
|
||||||
|
using LRNParams = std::tuple<
|
||||||
|
ElementType, // data precision
|
||||||
|
InputShape, // data shape
|
||||||
|
double, // alpha
|
||||||
|
double, // beta
|
||||||
|
double, // bias
|
||||||
|
size_t, // size
|
||||||
|
std::vector<int64_t>>; // axes to reduction
|
||||||
|
|
||||||
|
class LRNLayerCPUTest : public testing::WithParamInterface<LRNParams>, public ov::test::SubgraphBaseTest, public CPUTestsBase {
|
||||||
|
public:
|
||||||
|
static std::string getTestCaseName(testing::TestParamInfo<LRNParams> obj) {
|
||||||
|
ElementType inputPrecision;
|
||||||
|
InputShape inputShapes;
|
||||||
|
double alpha, beta, bias;
|
||||||
|
size_t size;
|
||||||
|
std::vector<int64_t> axes;
|
||||||
|
std::tie(inputPrecision, inputShapes, alpha, beta, bias, size, axes) = obj.param;
|
||||||
|
|
||||||
|
std::ostringstream result;
|
||||||
|
result << inputPrecision << "_" << "IS=" << CommonTestUtils::partialShape2str({ inputShapes.first }) << "_" << "TS=(";
|
||||||
|
for (const auto& shape : inputShapes.second) {
|
||||||
|
result << CommonTestUtils::vec2str(shape) << "_";
|
||||||
|
}
|
||||||
|
|
||||||
|
result << ")_alpha=" << alpha << "_beta=" << beta << "_bias=" << bias << "_size=" << size << "_axes=" << CommonTestUtils::vec2str(axes);
|
||||||
|
return result.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void SetUp() override {
|
||||||
|
targetDevice = CommonTestUtils::DEVICE_CPU;
|
||||||
|
ElementType inputPrecision;
|
||||||
|
InputShape inputShapes;
|
||||||
|
double alpha, beta, bias;
|
||||||
|
size_t size;
|
||||||
|
std::vector<int64_t> axes;
|
||||||
|
std::tie(inputPrecision, inputShapes, alpha, beta, bias, size, axes) = this->GetParam();
|
||||||
|
|
||||||
|
init_input_shapes({ inputShapes });
|
||||||
|
selectedType = makeSelectedTypeStr("ref_any", inputPrecision);
|
||||||
|
|
||||||
|
auto params = ngraph::builder::makeDynamicParams(inputPrecision, { inputDynamicShapes });
|
||||||
|
auto axesNode = ngraph::opset1::Constant::create(ngraph::element::i32, { axes.size() }, axes);
|
||||||
|
auto lrn = std::make_shared<ngraph::opset3::LRN>(params[0], axesNode, alpha, beta, bias, size);
|
||||||
|
function = makeNgraphFunction(inputPrecision, params, lrn, "LRN");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_P(LRNLayerCPUTest, CompareWithRefs) {
|
||||||
|
SKIP_IF_CURRENT_TEST_IS_DISABLED()
|
||||||
|
|
||||||
|
run();
|
||||||
|
CheckPluginRelatedResults(executableNetwork, "LRN");
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::vector<ElementType> inputPrecisions = {
|
||||||
|
ngraph::element::f32,
|
||||||
|
};
|
||||||
|
|
||||||
|
const std::vector<std::vector<std::int64_t>> axes = {
|
||||||
|
{ 1 },
|
||||||
|
{ 2, 3 },
|
||||||
|
{ 3, 2 },
|
||||||
|
{ 1, 2, 3 }
|
||||||
|
};
|
||||||
|
|
||||||
|
const std::vector<double> alpha = { 9.9e-05 };
|
||||||
|
const std::vector<double> beta = { 2. };
|
||||||
|
const std::vector<double> bias = { 1. };
|
||||||
|
const std::vector<size_t> size = { 5ul };
|
||||||
|
|
||||||
|
const std::vector<InputShape> inputShapes = {
|
||||||
|
InputShape{{}, {{10, 10, 3, 8}}},
|
||||||
|
InputShape{
|
||||||
|
// dynamic
|
||||||
|
{-1, -1, -1, -1},
|
||||||
|
// static
|
||||||
|
{{15, 5, 7, 8}, {10, 10, 3, 8}, {1, 3, 5, 5}}
|
||||||
|
},
|
||||||
|
InputShape{
|
||||||
|
// dynamic
|
||||||
|
{{1, 15}, {3, 10}, {3, 7}, {5, 8}},
|
||||||
|
// static
|
||||||
|
{{15, 5, 7, 8}, {10, 10, 3, 8}, {1, 3, 5, 5}}
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
const auto testCases = ::testing::Combine(
|
||||||
|
::testing::ValuesIn(inputPrecisions),
|
||||||
|
::testing::ValuesIn(inputShapes),
|
||||||
|
::testing::ValuesIn(alpha),
|
||||||
|
::testing::ValuesIn(beta),
|
||||||
|
::testing::ValuesIn(bias),
|
||||||
|
::testing::ValuesIn(size),
|
||||||
|
::testing::ValuesIn(axes)
|
||||||
|
);
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs, LRNLayerCPUTest, testCases, LRNLayerCPUTest::getTestCaseName);
|
||||||
|
|
||||||
|
} // namespace CPULayerTestsDefinitions
|
Loading…
Reference in New Issue
Block a user