[CPU] GatherElements support dynamic shape (#8663)

Luo Cheng 2021-11-23 21:53:04 +08:00 committed by GitHub
parent 7720d82366
commit fa65018773
3 changed files with 109 additions and 51 deletions


@@ -5,7 +5,6 @@
#include <cmath>
#include <vector>
#include <string>
#include <mkldnn_types.h>
#include "ie_parallel.hpp"
#include "mkldnn_gather_elements_node.h"
#include <ngraph/opsets/opset1.hpp>
@@ -16,14 +15,10 @@
using namespace MKLDNNPlugin;
using namespace InferenceEngine;
bool MKLDNNGatherElementsNode::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept {
bool MKLDNNGatherElementsNode::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept {
try {
if (isDynamicNgraphNode(op)) {
errorMessage = "Doesn't support op with dynamic shapes";
return false;
}
const auto gatherElementsOp = ngraph::as_type_ptr<const ngraph::op::v6::GatherElements>(op);
if (!gatherElementsOp) {
if (!one_of(op->get_type_info(),
ov::op::v6::GatherElements::get_type_info_static())) {
errorMessage = "Node is not an instance of the GatherElements operation from operation set v6.";
return false;
}
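As an aside (illustration only, not part of the diff): the early rejection of dynamic-shape ops is dropped, and the pointer-cast probe is replaced by comparing the node's runtime type info against the static type info of opset-6 GatherElements. A minimal sketch of that pattern, assuming a variadic one_of helper similar in spirit to the plugin's general utilities and OpenVINO 2.0 header paths:

    #include <memory>
    #include <openvino/core/node.hpp>
    #include <openvino/op/gather_elements.hpp>

    // Hypothetical helper: true if 'value' equals any of the candidates (C++17 fold).
    template <typename T, typename... Args>
    bool one_of(const T& value, Args... candidates) {
        return ((value == candidates) || ...);
    }

    // Type check only; dynamic shapes are no longer a reason to reject the op here.
    bool isGatherElementsV6(const std::shared_ptr<const ov::Node>& op) {
        return one_of(op->get_type_info(), ov::op::v6::GatherElements::get_type_info_static());
    }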
@@ -42,32 +37,43 @@ MKLDNNGatherElementsNode::MKLDNNGatherElementsNode(const std::shared_ptr<ngraph:
}
errorPrefix_ = std::string("Layer GatherElements with name '") + op->get_friendly_name() + "'";
if (op->get_input_size() != 2 || op->get_output_size() != 1)
if (inputShapes.size() != 2 || outputShapes.size() != 1)
IE_THROW() << errorPrefix_ << " has invalid number of input/output edges.";
const auto& dataDims = op->get_input_shape(dataIndex_);
const auto& indicesDims = op->get_input_shape(indicesIndex_);
if (dataDims.size() != indicesDims.size())
const auto dataRank = getInputShapeAtPort(dataIndex_).getRank();
const auto indicesRank = getInputShapeAtPort(indicesIndex_).getRank();
if (dataRank != indicesRank)
IE_THROW() << errorPrefix_ << " has invalid input shapes. Inputs 'Data' and 'Indices' must have equal ranks.";
auto gatherElementsOp = ngraph::as_type_ptr<const ngraph::op::v6::GatherElements>(op);
auto gatherElementsOp = ov::as_type_ptr<ov::op::v6::GatherElements>(op);
auto axis = gatherElementsOp->get_axis();
if (axis < 0)
axis += dataDims.size();
if (axis < 0 || axis >= static_cast<int>(dataDims.size()))
axis += dataRank;
if (axis < 0 || axis >= static_cast<int>(dataRank))
IE_THROW() << errorPrefix_ << " has invalid axis attribute: " << axis;
axis_ = axis;
}
auto outputShape = op->get_output_shape(0);
void MKLDNNGatherElementsNode::createPrimitive() {
if (inputShapesDefined()) {
if (needPrepareParams())
prepareParams();
updateLastInputDims();
}
}
void MKLDNNGatherElementsNode::prepareParams() {
const auto& dataDims = getParentEdgesAtPort(dataIndex_)[0]->getMemory().getStaticDims();
const auto& dstDims = getChildEdgesAtPort(0)[0]->getMemory().getStaticDims();
strideAxDst_ = 1;
for (int i = outputShape.size() - 1; i > axis_; i--)
strideAxDst_ *= outputShape[i];
dstAxDim_ = op->get_output_shape(0)[axis_];
for (int i = dstDims.size() - 1; i > axis_; i--)
strideAxDst_ *= dstDims[i];
dstAxDim_ = dstDims[axis_];
if (axis_ > 0) {
strideAx1Diff_ = 1;
for (int i = dataDims.size() - 1; i >= axis_; i--)
strideAx1Diff_ *= dataDims[i];
strideAx1Diff_ -= strideAxDst_ * outputShape[axis_];
strideAx1Diff_ -= strideAxDst_ * dstDims[axis_];
}
}
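For context (not part of the diff): constructor-time validation now relies on ranks from getInputShapeAtPort(), which are known even for dynamic shapes, while the stride bookkeeping moves into prepareParams(), where the concrete dims of the current request are read from the edge memory; createPrimitive() only runs it eagerly when all input shapes are already defined. A standalone sketch of the same arithmetic with a worked example (hypothetical helper; the node keeps these values in its own members, and a signed type is used here for clarity):

    #include <cstddef>
    #include <vector>

    struct GatherElementsStrides {
        size_t strideAxDst = 1;            // elements per step over the dst dims after 'axis'
        size_t dstAxDim = 0;               // output extent along the gathered axis
        std::ptrdiff_t strideAx1Diff = 0;  // correction when stepping over dims before 'axis'
    };

    GatherElementsStrides computeStrides(const std::vector<size_t>& dataDims,
                                         const std::vector<size_t>& dstDims,
                                         int axis) {
        GatherElementsStrides s;
        for (int i = static_cast<int>(dstDims.size()) - 1; i > axis; --i)
            s.strideAxDst *= dstDims[i];
        s.dstAxDim = dstDims[axis];
        if (axis > 0) {
            std::ptrdiff_t srcBlock = 1;   // elements in one data slice from 'axis' downward
            for (int i = static_cast<int>(dataDims.size()) - 1; i >= axis; --i)
                srcBlock *= static_cast<std::ptrdiff_t>(dataDims[i]);
            s.strideAx1Diff = srcBlock - static_cast<std::ptrdiff_t>(s.strideAxDst * dstDims[axis]);
        }
        return s;
    }

    // Worked example with shapes from the test below:
    //   dataDims = {2, 3, 5, 7}, dstDims = indices dims = {2, 3, 9, 7}, axis = 2
    //   strideAxDst   = 7
    //   dstAxDim      = 9
    //   strideAx1Diff = 5*7 - 9*7 = -28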


@@ -18,11 +18,15 @@ public:
void getSupportedDescriptors() override {};
void initSupportedPrimitiveDescriptors() override;
void createPrimitive() override {};
void createPrimitive() override;
void execute(mkldnn::stream strm) override;
bool created() const override;
static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;
static bool isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept;
protected:
void executeDynamicImpl(mkldnn::stream strm) override { execute(strm); }
void prepareParams() override;
private:
const size_t dataIndex_ = 0;
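In the header, the empty inline createPrimitive() stub is replaced by an out-of-line definition, and the node gains the two hooks the CPU plugin uses for shape-dependent nodes: prepareParams() to refresh the strides and executeDynamicImpl(), which simply forwards to execute() because the kernel only consumes the prepared strides. A toy analogue of that split, shown as an illustrative pattern rather than the plugin's actual base-class API:

    // Separate per-shape preparation from the shape-agnostic kernel so the same
    // execute() serves both static and dynamic shapes.
    struct ToyDynamicNode {
        virtual ~ToyDynamicNode() = default;
        virtual void prepareParams() = 0;  // recompute shape-dependent state for the current dims
        virtual void execute() = 0;        // kernel that only reads the prepared state

        // The real plugin decides in its base class whether prepareParams() must
        // rerun (e.g. when input dims changed); this toy version always reruns it.
        void executeDynamic() {
            prepareParams();
            execute();
        }
    };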


@@ -2,88 +2,136 @@
// SPDX-License-Identifier: Apache-2.0
//
#include <shared_test_classes/single_layer/gather_elements.hpp>
#include "shared_test_classes/base/ov_subgraph.hpp"
#include "ngraph_functions/builders.hpp"
#include "functional_test_utils/ov_tensor_utils.hpp"
#include "test_utils/cpu_test_utils.hpp"
using namespace InferenceEngine;
using namespace ov::test;
using namespace ngraph;
using namespace CPUTestUtils;
using namespace InferenceEngine;
using namespace ngraph::helpers;
using namespace LayerTestsDefinitions;
namespace CPULayerTestsDefinitions {
typedef std::tuple<
using GatherElementsParams = std::tuple<
std::vector<InputShape>, // Dynamic shape + Target static shapes
int, // Axis
ElementType, // Data precision
ElementType, // Indices precision
TargetDevice // Device name
>;
using GatherElementsCPUTestParamSet = std::tuple<
GatherElementsParams,
CPUSpecificParams
> GatherElementsCPUTestParamSet;
>;
class GatherElementsCPUTest : public testing::WithParamInterface<GatherElementsCPUTestParamSet>,
virtual public LayerTestsUtils::LayerTestsCommon, public CPUTestsBase {
virtual public ov::test::SubgraphBaseTest, public CPUTestsBase {
public:
static std::string getTestCaseNameCommon(const testing::TestParamInfo<GatherElementsParams>& obj) {
std::vector<InputShape> shapes;
ElementType dPrecision, iPrecision;
int axis;
std::string device;
std::tie(shapes, axis, dPrecision, iPrecision, device) = obj.param;
std::ostringstream result;
result << "IS=(";
for (const auto& shape : shapes) {
result << CommonTestUtils::partialShape2str({shape.first}) << "_";
}
result << ")_TS=(";
for (const auto& shape : shapes) {
for (const auto& item : shape.second) {
result << CommonTestUtils::vec2str(item) << "_";
}
}
result << "Ax=" << axis << "_";
result << "DP=" << dPrecision << "_";
result << "IP=" << iPrecision << "_";
result << "device=" << device;
return result.str();
}
static std::string getTestCaseName(const testing::TestParamInfo<GatherElementsCPUTestParamSet> &obj) {
GatherElementsParams basicParamsSet;
CPUSpecificParams cpuParams;
std::tie(basicParamsSet, cpuParams) = obj.param;
std::ostringstream result;
result << GatherElementsLayerTest::getTestCaseName(testing::TestParamInfo<GatherElementsParams>(basicParamsSet, 0));
result << getTestCaseNameCommon(testing::TestParamInfo<GatherElementsParams>(basicParamsSet, 0));
result << CPUTestsBase::getTestCaseName(cpuParams);
return result.str();
}
InferenceEngine::Blob::Ptr GenerateInput(const InferenceEngine::InputInfo &info) const override {
return FuncTestUtils::createAndFillBlob(info.getTensorDesc(), 15, 0, 32768);
void generate_inputs(const std::vector<ngraph::Shape>& targetInputStaticShapes) override {
inputs.clear();
const auto& funcInputs = function->inputs();
for (int i = 0; i < funcInputs.size(); ++i) {
const auto& funcInput = funcInputs[i];
ov::runtime::Tensor tensor;
tensor = ov::test::utils::create_and_fill_tensor(funcInput.get_element_type(), targetInputStaticShapes[i], 15, 0, 32768);
inputs.insert({funcInput.get_node_shared_ptr(), tensor});
}
}
protected:
void SetUp() override {
InferenceEngine::SizeVector dataShape, indicesShape;
InferenceEngine::Precision dPrecision, iPrecision;
std::vector<InputShape> shapes;
ElementType dPrecision, iPrecision;
int axis;
GatherElementsParams basicParamsSet;
CPUSpecificParams cpuParams;
std::tie(basicParamsSet, cpuParams) = this->GetParam();
std::tie(inFmts, outFmts, priority, selectedType) = cpuParams;
std::tie(dataShape, indicesShape, axis, dPrecision, iPrecision, targetDevice) = basicParamsSet;
selectedType = std::string("ref_any_") + dPrecision.name();
std::tie(shapes, axis, dPrecision, iPrecision, targetDevice) = basicParamsSet;
selectedType = std::string("ref_any_") + ov::element::Type(dPrecision).get_type_name();
init_input_shapes(shapes);
auto ngDPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(dPrecision);
auto ngIPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(iPrecision);
ngraph::ParameterVector params = {
std::make_shared<ngraph::opset1::Parameter>(dPrecision, inputDynamicShapes[0]),
std::make_shared<ngraph::opset1::Parameter>(iPrecision, inputDynamicShapes[1]),
};
auto params = ngraph::builder::makeParams(ngDPrc, {dataShape});
auto activation = ngraph::builder::makeGatherElements(params[0], indicesShape, ngIPrc, axis);
activation->get_rt_info() = getCPUInfo();
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{activation}, params, "GatherElements");
auto gather = std::make_shared<ngraph::op::v6::GatherElements>(
params[0], params[1], axis);
function = makeNgraphFunction(dPrecision, params, gather, "GatherElements");
}
};
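The test now derives from ov::test::SubgraphBaseTest: SetUp() registers the parameterized shapes through init_input_shapes(), creates Parameters on the dynamic shapes, and instantiates ngraph::op::v6::GatherElements directly instead of going through the builder, while generate_inputs() fills both data and indices tensors for every target static shape. Purely for illustration, a standalone sketch of the same subgraph built outside the harness, with fully dynamic 4D inputs matching the {-1, -1, -1, -1} case below (hypothetical helper, not part of the diff):

    #include <memory>
    #include <ngraph/ngraph.hpp>
    #include <ngraph/opsets/opset1.hpp>

    std::shared_ptr<ngraph::Function> makeGatherElementsFunction(ngraph::element::Type dataPrc,
                                                                 ngraph::element::Type idxPrc,
                                                                 int axis) {
        // Rank is fixed at 4, every dimension is left dynamic.
        auto data = std::make_shared<ngraph::opset1::Parameter>(dataPrc, ngraph::PartialShape::dynamic(4));
        auto indices = std::make_shared<ngraph::opset1::Parameter>(idxPrc, ngraph::PartialShape::dynamic(4));
        auto gather = std::make_shared<ngraph::op::v6::GatherElements>(data, indices, axis);
        return std::make_shared<ngraph::Function>(ngraph::NodeVector{gather},
                                                  ngraph::ParameterVector{data, indices},
                                                  "GatherElements");
    }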
TEST_P(GatherElementsCPUTest, CompareWithRefs) {
SKIP_IF_CURRENT_TEST_IS_DISABLED()
Run();
CheckPluginRelatedResults(executableNetwork, "GatherElements");
run();
}
namespace {
std::vector<CPUSpecificParams> cpuParams_4D = {
CPUSpecificParams({nchw}, {nchw}, {}, {})
};
const std::vector<std::vector<InputShape>> inDynamicShapeParams = {
{{{-1, -1, -1, -1}, {{2, 3, 5, 7}, {3, 4, 6, 8}}},
{{-1, -1, -1, -1}, {{2, 3, 9, 7}, {3, 4, 4, 8}}}},
{{{{1, 10}, {1, 10}, {1, 10}, {1, 10}}, {{3, 4, 6, 8}, {2, 3, 5, 7}}},
{{{1, 10}, {1, 10}, {1, 10}, {1, 10}}, {{3, 4, 4, 8}, {2, 3, 9, 7}}}}
};
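Each InputShape entry pairs the dynamic shape the model is compiled with against the list of static shapes inference is run on, so a single compiled network is exercised with several concrete sizes; the data and indices entries differ only along the gathered axis, as GatherElements requires. One such pair spelled out as a plain std::pair, assuming the first/second layout the name-building code above relies on (illustrative, not part of the diff):

    #include <utility>
    #include <vector>
    #include <openvino/core/partial_shape.hpp>
    #include <openvino/core/shape.hpp>

    // Bounded variant from the second group: compile with dims in [1, 10],
    // then infer on two concrete shapes.
    using DynamicThenStatic = std::pair<ov::PartialShape, std::vector<ov::Shape>>;

    DynamicThenStatic dataShapes{
        ov::PartialShape{{1, 10}, {1, 10}, {1, 10}, {1, 10}},    // used at compile time
        {ov::Shape{3, 4, 6, 8}, ov::Shape{2, 3, 5, 7}}           // used at inference time
    };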
INSTANTIATE_TEST_SUITE_P(smoke_set1, GatherElementsCPUTest,
::testing::Combine(
::testing::Combine(
::testing::Values(std::vector<size_t>({2, 3, 5, 7})), // Data shape
::testing::Values(std::vector<size_t>({2, 3, 9, 7})), // Indices shape
::testing::ValuesIn(inDynamicShapeParams), // shape
::testing::ValuesIn(std::vector<int>({2, -2})), // Axis
::testing::Values(Precision::BF16),
::testing::Values(Precision::I32),
::testing::ValuesIn(std::vector<ElementType>({ElementType::bf16, ElementType::f32})),
::testing::Values(ElementType::i32),
::testing::Values(CommonTestUtils::DEVICE_CPU)),
::testing::ValuesIn(filterCPUSpecificParams(cpuParams_4D))),
GatherElementsCPUTest::getTestCaseName);