[CPU] GatherElements support dynamic shape (#8663)
This commit is contained in:
parent
7720d82366
commit
fa65018773
@ -5,7 +5,6 @@
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <mkldnn_types.h>
|
||||
#include "ie_parallel.hpp"
|
||||
#include "mkldnn_gather_elements_node.h"
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
@ -16,14 +15,10 @@
|
||||
using namespace MKLDNNPlugin;
|
||||
using namespace InferenceEngine;
|
||||
|
||||
bool MKLDNNGatherElementsNode::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept {
|
||||
bool MKLDNNGatherElementsNode::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept {
|
||||
try {
|
||||
if (isDynamicNgraphNode(op)) {
|
||||
errorMessage = "Doesn't support op with dynamic shapes";
|
||||
return false;
|
||||
}
|
||||
const auto gatherElementsOp = ngraph::as_type_ptr<const ngraph::op::v6::GatherElements>(op);
|
||||
if (!gatherElementsOp) {
|
||||
if (!one_of(op->get_type_info(),
|
||||
ov::op::v6::GatherElements::get_type_info_static())) {
|
||||
errorMessage = "Node is not an instance of the GatherElements operation from operation set v6.";
|
||||
return false;
|
||||
}
|
||||
@ -42,32 +37,43 @@ MKLDNNGatherElementsNode::MKLDNNGatherElementsNode(const std::shared_ptr<ngraph:
|
||||
}
|
||||
errorPrefix_ = std::string("Layer GatherElements with name '") + op->get_friendly_name() + "'";
|
||||
|
||||
if (op->get_input_size() != 2 || op->get_output_size() != 1)
|
||||
if (inputShapes.size() != 2 || outputShapes.size() != 1)
|
||||
IE_THROW() << errorPrefix_ << " has invalid number of input/output edges.";
|
||||
|
||||
const auto& dataDims = op->get_input_shape(dataIndex_);
|
||||
const auto& indicesDims = op->get_input_shape(indicesIndex_);
|
||||
if (dataDims.size() != indicesDims.size())
|
||||
const auto dataRank = getInputShapeAtPort(dataIndex_).getRank();
|
||||
const auto indicesRank = getInputShapeAtPort(indicesIndex_).getRank();
|
||||
if (dataRank != indicesRank)
|
||||
IE_THROW() << errorPrefix_ << " has invalid input shapes. Inputs 'Data' and 'Indices' must have equal ranks.";
|
||||
|
||||
auto gatherElementsOp = ngraph::as_type_ptr<const ngraph::op::v6::GatherElements>(op);
|
||||
auto gatherElementsOp = ov::as_type_ptr<ov::op::v6::GatherElements>(op);
|
||||
auto axis = gatherElementsOp->get_axis();
|
||||
if (axis < 0)
|
||||
axis += dataDims.size();
|
||||
if (axis < 0 || axis >= static_cast<int>(dataDims.size()))
|
||||
axis += dataRank;
|
||||
if (axis < 0 || axis >= static_cast<int>(dataRank))
|
||||
IE_THROW() << errorPrefix_ << " has invalid axis attribute: " << axis;
|
||||
axis_ = axis;
|
||||
}
|
||||
|
||||
auto outputShape = op->get_output_shape(0);
|
||||
void MKLDNNGatherElementsNode::createPrimitive() {
|
||||
if (inputShapesDefined()) {
|
||||
if (needPrepareParams())
|
||||
prepareParams();
|
||||
updateLastInputDims();
|
||||
}
|
||||
}
|
||||
|
||||
void MKLDNNGatherElementsNode::prepareParams() {
|
||||
const auto& dataDims = getParentEdgesAtPort(dataIndex_)[0]->getMemory().getStaticDims();
|
||||
const auto& dstDims = getChildEdgesAtPort(0)[0]->getMemory().getStaticDims();
|
||||
strideAxDst_ = 1;
|
||||
for (int i = outputShape.size() - 1; i > axis_; i--)
|
||||
strideAxDst_ *= outputShape[i];
|
||||
dstAxDim_ = op->get_output_shape(0)[axis_];
|
||||
for (int i = dstDims.size() - 1; i > axis_; i--)
|
||||
strideAxDst_ *= dstDims[i];
|
||||
dstAxDim_ = dstDims[axis_];
|
||||
if (axis_ > 0) {
|
||||
strideAx1Diff_ = 1;
|
||||
for (int i = dataDims.size() - 1; i >= axis_; i--)
|
||||
strideAx1Diff_ *= dataDims[i];
|
||||
strideAx1Diff_ -= strideAxDst_ * outputShape[axis_];
|
||||
strideAx1Diff_ -= strideAxDst_ * dstDims[axis_];
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -18,11 +18,15 @@ public:
|
||||
|
||||
void getSupportedDescriptors() override {};
|
||||
void initSupportedPrimitiveDescriptors() override;
|
||||
void createPrimitive() override {};
|
||||
void createPrimitive() override;
|
||||
void execute(mkldnn::stream strm) override;
|
||||
bool created() const override;
|
||||
|
||||
static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;
|
||||
static bool isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept;
|
||||
|
||||
protected:
|
||||
void executeDynamicImpl(mkldnn::stream strm) override { execute(strm); }
|
||||
void prepareParams() override;
|
||||
|
||||
private:
|
||||
const size_t dataIndex_ = 0;
|
||||
|
@ -2,88 +2,136 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <shared_test_classes/single_layer/gather_elements.hpp>
|
||||
#include "shared_test_classes/base/ov_subgraph.hpp"
|
||||
#include "ngraph_functions/builders.hpp"
|
||||
#include "functional_test_utils/ov_tensor_utils.hpp"
|
||||
#include "test_utils/cpu_test_utils.hpp"
|
||||
|
||||
using namespace InferenceEngine;
|
||||
using namespace ov::test;
|
||||
using namespace ngraph;
|
||||
using namespace CPUTestUtils;
|
||||
using namespace InferenceEngine;
|
||||
using namespace ngraph::helpers;
|
||||
using namespace LayerTestsDefinitions;
|
||||
|
||||
namespace CPULayerTestsDefinitions {
|
||||
|
||||
typedef std::tuple<
|
||||
using GatherElementsParams = std::tuple<
|
||||
std::vector<InputShape>, // Dynamic shape + Target static shapes
|
||||
int, // Axis
|
||||
ElementType, // Data precision
|
||||
ElementType, // Indices precision
|
||||
TargetDevice // Device name
|
||||
>;
|
||||
|
||||
using GatherElementsCPUTestParamSet = std::tuple<
|
||||
GatherElementsParams,
|
||||
CPUSpecificParams
|
||||
> GatherElementsCPUTestParamSet;
|
||||
>;
|
||||
|
||||
class GatherElementsCPUTest : public testing::WithParamInterface<GatherElementsCPUTestParamSet>,
|
||||
virtual public LayerTestsUtils::LayerTestsCommon, public CPUTestsBase {
|
||||
virtual public ov::test::SubgraphBaseTest, public CPUTestsBase {
|
||||
public:
|
||||
static std::string getTestCaseNameCommon(const testing::TestParamInfo<GatherElementsParams>& obj) {
|
||||
std::vector<InputShape> shapes;
|
||||
ElementType dPrecision, iPrecision;
|
||||
int axis;
|
||||
std::string device;
|
||||
std::tie(shapes, axis, dPrecision, iPrecision, device) = obj.param;
|
||||
|
||||
std::ostringstream result;
|
||||
result << "IS=(";
|
||||
for (const auto& shape : shapes) {
|
||||
result << CommonTestUtils::partialShape2str({shape.first}) << "_";
|
||||
}
|
||||
result << ")_TS=(";
|
||||
for (const auto& shape : shapes) {
|
||||
for (const auto& item : shape.second) {
|
||||
result << CommonTestUtils::vec2str(item) << "_";
|
||||
}
|
||||
}
|
||||
result << "Ax=" << axis << "_";
|
||||
result << "DP=" << dPrecision << "_";
|
||||
result << "IP=" << iPrecision << "_";
|
||||
result << "device=" << device;
|
||||
|
||||
return result.str();
|
||||
}
|
||||
|
||||
static std::string getTestCaseName(const testing::TestParamInfo<GatherElementsCPUTestParamSet> &obj) {
|
||||
GatherElementsParams basicParamsSet;
|
||||
CPUSpecificParams cpuParams;
|
||||
std::tie(basicParamsSet, cpuParams) = obj.param;
|
||||
|
||||
std::ostringstream result;
|
||||
result << GatherElementsLayerTest::getTestCaseName(testing::TestParamInfo<GatherElementsParams>(basicParamsSet, 0));
|
||||
result << getTestCaseNameCommon(testing::TestParamInfo<GatherElementsParams>(basicParamsSet, 0));
|
||||
|
||||
result << CPUTestsBase::getTestCaseName(cpuParams);
|
||||
|
||||
return result.str();
|
||||
}
|
||||
|
||||
InferenceEngine::Blob::Ptr GenerateInput(const InferenceEngine::InputInfo &info) const override {
|
||||
return FuncTestUtils::createAndFillBlob(info.getTensorDesc(), 15, 0, 32768);
|
||||
void generate_inputs(const std::vector<ngraph::Shape>& targetInputStaticShapes) override {
|
||||
inputs.clear();
|
||||
const auto& funcInputs = function->inputs();
|
||||
for (int i = 0; i < funcInputs.size(); ++i) {
|
||||
const auto& funcInput = funcInputs[i];
|
||||
ov::runtime::Tensor tensor;
|
||||
|
||||
tensor = ov::test::utils::create_and_fill_tensor(funcInput.get_element_type(), targetInputStaticShapes[i], 15, 0, 32768);
|
||||
|
||||
inputs.insert({funcInput.get_node_shared_ptr(), tensor});
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
void SetUp() override {
|
||||
InferenceEngine::SizeVector dataShape, indicesShape;
|
||||
InferenceEngine::Precision dPrecision, iPrecision;
|
||||
std::vector<InputShape> shapes;
|
||||
ElementType dPrecision, iPrecision;
|
||||
int axis;
|
||||
|
||||
GatherElementsParams basicParamsSet;
|
||||
CPUSpecificParams cpuParams;
|
||||
std::tie(basicParamsSet, cpuParams) = this->GetParam();
|
||||
|
||||
std::tie(inFmts, outFmts, priority, selectedType) = cpuParams;
|
||||
|
||||
std::tie(dataShape, indicesShape, axis, dPrecision, iPrecision, targetDevice) = basicParamsSet;
|
||||
selectedType = std::string("ref_any_") + dPrecision.name();
|
||||
std::tie(shapes, axis, dPrecision, iPrecision, targetDevice) = basicParamsSet;
|
||||
selectedType = std::string("ref_any_") + ov::element::Type(dPrecision).get_type_name();
|
||||
init_input_shapes(shapes);
|
||||
|
||||
auto ngDPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(dPrecision);
|
||||
auto ngIPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(iPrecision);
|
||||
ngraph::ParameterVector params = {
|
||||
std::make_shared<ngraph::opset1::Parameter>(dPrecision, inputDynamicShapes[0]),
|
||||
std::make_shared<ngraph::opset1::Parameter>(iPrecision, inputDynamicShapes[1]),
|
||||
};
|
||||
|
||||
auto params = ngraph::builder::makeParams(ngDPrc, {dataShape});
|
||||
auto activation = ngraph::builder::makeGatherElements(params[0], indicesShape, ngIPrc, axis);
|
||||
activation->get_rt_info() = getCPUInfo();
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{activation}, params, "GatherElements");
|
||||
auto gather = std::make_shared<ngraph::op::v6::GatherElements>(
|
||||
params[0], params[1], axis);
|
||||
function = makeNgraphFunction(dPrecision, params, gather, "GatherElements");
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(GatherElementsCPUTest, CompareWithRefs) {
|
||||
SKIP_IF_CURRENT_TEST_IS_DISABLED()
|
||||
|
||||
Run();
|
||||
CheckPluginRelatedResults(executableNetwork, "GatherElements");
|
||||
run();
|
||||
}
|
||||
|
||||
|
||||
namespace {
|
||||
std::vector<CPUSpecificParams> cpuParams_4D = {
|
||||
CPUSpecificParams({nchw}, {nchw}, {}, {})
|
||||
};
|
||||
|
||||
const std::vector<std::vector<InputShape>> inDynamicShapeParams = {
|
||||
{{{-1, -1, -1, -1}, {{2, 3, 5, 7}, {3, 4, 6, 8}}},
|
||||
{{-1, -1, -1, -1}, {{2, 3, 9, 7}, {3, 4, 4, 8}}}},
|
||||
{{{{1, 10}, {1, 10}, {1, 10}, {1, 10}}, {{3, 4, 6, 8}, {2, 3, 5, 7}}},
|
||||
{{{1, 10}, {1, 10}, {1, 10}, {1, 10}}, {{3, 4, 4, 8}, {2, 3, 9, 7}}}}
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_set1, GatherElementsCPUTest,
|
||||
::testing::Combine(
|
||||
::testing::Combine(
|
||||
::testing::Values(std::vector<size_t>({2, 3, 5, 7})), // Data shape
|
||||
::testing::Values(std::vector<size_t>({2, 3, 9, 7})), // Indices shape
|
||||
::testing::ValuesIn(inDynamicShapeParams), // shape
|
||||
::testing::ValuesIn(std::vector<int>({2, -2})), // Axis
|
||||
::testing::Values(Precision::BF16),
|
||||
::testing::Values(Precision::I32),
|
||||
::testing::ValuesIn(std::vector<ElementType>({ElementType::bf16, ElementType::f32})),
|
||||
::testing::Values(ElementType::i32),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU)),
|
||||
::testing::ValuesIn(filterCPUSpecificParams(cpuParams_4D))),
|
||||
GatherElementsCPUTest::getTestCaseName);
|
||||
|
Loading…
Reference in New Issue
Block a user