[IE][VPU][Tests]: Support DTS for GatherElements (#3688)

* Support DTS for GatherElements
* Extract GatherBase to a common part
* Introduce tests on inference
* Introduce tests on function comparing
* Disable failing tests
This commit is contained in:
Andrew Bakalin 2021-01-20 12:16:46 +03:00 committed by GitHub
parent e08ad2989e
commit 7c4f435335
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 316 additions and 54 deletions

View File

@ -0,0 +1,13 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "ngraph/node.hpp"
namespace vpu {
void dynamicToStaticShapeGatherElements(std::shared_ptr<ngraph::Node> node);
} // namespace vpu

View File

@ -8,6 +8,7 @@
#include "vpu/ngraph/transformations/dynamic_to_static_shape_broadcast.hpp"
#include "vpu/ngraph/transformations/dynamic_to_static_shape_concat.hpp"
#include "vpu/ngraph/transformations/dynamic_to_static_shape_gather.hpp"
#include "vpu/ngraph/transformations/dynamic_to_static_shape_gather_elements.hpp"
#include "vpu/ngraph/transformations/dynamic_to_static_shape_gather_nd.hpp"
#include "vpu/ngraph/transformations/dynamic_to_static_shape_matmul.hpp"
#include "vpu/ngraph/transformations/dynamic_to_static_shape_non_max_suppression.hpp"
@ -30,6 +31,7 @@
#include "ngraph/opsets/opset3.hpp"
#include <ngraph/validation_util.hpp>
#include "ngraph/opsets/opset5.hpp"
#include "ngraph/opsets/opset6.hpp"
namespace vpu {
@ -130,6 +132,7 @@ const Transformations& getDefaultTransformations() {
{ngraph::opset3::MatMul::type_info, dynamicToStaticShapeMatMul},
{ngraph::opset5::Split::type_info, dynamicToStaticShapeSplit},
{ngraph::opset5::GatherND::type_info, dynamicToStaticShapeGatherND},
{ngraph::opset6::GatherElements::type_info, dynamicToStaticShapeGatherElements},
// reduction
{ngraph::opset3::ReduceLogicalAnd::type_info, dynamicToStaticShapeReduce},

View File

@ -0,0 +1,31 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "vpu/ngraph/transformations/dynamic_to_static_shape_gather_elements.hpp"
#include "vpu/ngraph/operations/dynamic_shape_resolver.hpp"
#include "vpu/ngraph/utilities.hpp"
#include <vpu/utils/error.hpp>
#include "ngraph/ops.hpp"
#include <memory>
namespace vpu {
void dynamicToStaticShapeGatherElements(std::shared_ptr<ngraph::Node> target) {
const auto dsr = target->input_value(1).get_node_shared_ptr();
VPU_THROW_UNLESS(ngraph::as_type_ptr<ngraph::vpu::op::DynamicShapeResolver>(dsr),
"DynamicToStaticShape transformation for {} of type {} expects {} as input with index {}",
target->get_friendly_name(), target->get_type_info(), ngraph::vpu::op::DynamicShapeResolver::type_info, 1);
const auto shape = dsr->input(1).get_source_output();
const auto copied = target->clone_with_new_inputs(target->input_values());
auto outDSR = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(copied, shape);
outDSR->set_friendly_name(target->get_friendly_name());
ngraph::replace_node(target, std::move(outDSR));
}
} // namespace vpu

View File

@ -0,0 +1,133 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <vpu/ngraph/transformations/dynamic_to_static_shape_gather_elements.hpp>
#include <vpu/ngraph/transformations/dynamic_to_static_shape.hpp>
#include <vpu/ngraph/operations/dynamic_shape_resolver.hpp>
#include <common_test_utils/test_common.hpp>
#include <ngraph_functions/utils/ngraph_helpers.hpp>
#include <ngraph/opsets/opset6.hpp>
#include <ngraph/op/op.hpp>
namespace {
using DataType = ngraph::element::Type_t;
using DataDims = ngraph::Shape;
struct GatherElementsTestCase {
ngraph::Shape dataShape, indexShape;
int64_t axis;
};
enum class DataShapeType {
DYNAMIC,
STATIC
};
const auto combinations = testing::Combine(
testing::Values(
ngraph::element::f16,
ngraph::element::f32,
ngraph::element::i32,
ngraph::element::i64,
ngraph::element::u8),
testing::Values(
ngraph::element::i32,
ngraph::element::i64),
testing::Values(
GatherElementsTestCase{{6, 4, 20, 28}, {15, 4, 20, 28}, 0},
GatherElementsTestCase{{6, 12, 10, 24}, {3, 12, 10, 24}, 0},
GatherElementsTestCase{{6, 12}, {6, 20}, 1},
GatherElementsTestCase{{6, 12, 10, 24}, {6, 12, 10, 28}, 3},
GatherElementsTestCase{{6, 12, 10, 24}, {6, 12, 10, 28}, -1},
GatherElementsTestCase{{6, 12, 10, 24}, {15, 12, 10, 24}, -4}),
testing::Values(
DataShapeType::DYNAMIC,
DataShapeType::STATIC));
class DynamicToStaticShapeGatherElements : public CommonTestUtils::TestsCommon,
public testing::WithParamInterface<std::tuple<DataType, DataType, GatherElementsTestCase, DataShapeType>> {
public:
void SetUp() override {
const auto& parameters = GetParam();
const auto& dataType = std::get<0>(parameters);
const auto& idxType = std::get<1>(parameters);
const auto& gatherElementsSetup = std::get<2>(parameters);
const auto& dataShapeType = std::get<3>(parameters);
ngraph::helpers::CompareFunctions(*transform(dataType, idxType, gatherElementsSetup, dataShapeType),
*reference(dataType, idxType, gatherElementsSetup, dataShapeType));
}
protected:
std::shared_ptr<const ngraph::Function> transform(
const ngraph::element::Type_t& dataType,
const ngraph::element::Type_t& idxType,
const GatherElementsTestCase& gatherElementsSetup,
DataShapeType dataShapeType) const {
const auto data = std::make_shared<ngraph::opset6::Parameter>(dataType, gatherElementsSetup.dataShape);
const auto indices = std::make_shared<ngraph::opset6::Parameter>(idxType, gatherElementsSetup.indexShape);
const auto indicesDims = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::i64, ngraph::Shape{gatherElementsSetup.indexShape.size()});
const auto indicesDsr = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(indices, indicesDims);
ngraph::ParameterVector params{data, indices, indicesDims};
std::shared_ptr<ngraph::Node> gatherData = data;
if (dataShapeType == DataShapeType::DYNAMIC) {
params.push_back(std::make_shared<ngraph::opset6::Parameter>(ngraph::element::i64, ngraph::Shape{gatherElementsSetup.dataShape.size()}));
gatherData = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(data, params.back());
}
const auto node = std::make_shared<ngraph::opset6::GatherElements>(gatherData, indicesDsr, gatherElementsSetup.axis);
const auto function = std::make_shared<ngraph::Function>(
ngraph::NodeVector{node},
params,
"Actual");
node->set_output_type(0, dataType, ngraph::PartialShape::dynamic(1));
const auto transformations = vpu::Transformations{{node->type_info, vpu::dynamicToStaticShapeGatherElements}};
vpu::DynamicToStaticShape(transformations).run_on_function(function);
return function;
}
std::shared_ptr<const ngraph::Function> reference(
const ngraph::element::Type_t& dataType,
const ngraph::element::Type_t& idxType,
const GatherElementsTestCase& gatherElementsSetup,
DataShapeType dataShapeType) const {
const auto data = std::make_shared<ngraph::opset6::Parameter>(dataType, gatherElementsSetup.dataShape);
const auto indices = std::make_shared<ngraph::opset6::Parameter>(idxType, gatherElementsSetup.indexShape);
const auto indicesDims = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::i64, ngraph::Shape{gatherElementsSetup.indexShape.size()});
const auto indicesDsr = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(indices, indicesDims);
ngraph::ParameterVector params{data, indices, indicesDims};
std::shared_ptr<ngraph::Node> gatherData = data;
if (dataShapeType == DataShapeType::DYNAMIC) {
params.push_back(std::make_shared<ngraph::opset6::Parameter>(ngraph::element::i64, ngraph::Shape{gatherElementsSetup.dataShape.size()}));
gatherData = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(data, params.back());
}
const auto node = std::make_shared<ngraph::op::v6::GatherElements>(gatherData, indicesDsr, gatherElementsSetup.axis);
const auto outDsr = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(node, indicesDims);
return std::make_shared<ngraph::Function>(
ngraph::NodeVector{outDsr},
params,
"Expected");
}
};
TEST_P(DynamicToStaticShapeGatherElements, CompareFunctions) {
}
INSTANTIATE_TEST_CASE_P(smoke_NGraph, DynamicToStaticShapeGatherElements, combinations);
} // namespace

View File

@ -35,6 +35,8 @@ std::vector<std::string> disabledTestPatterns() {
// TODO: Issue 43781
".*ROIPoolingLayerTest.*",
// TODO: Issue 26090
".*DSR_GatherStaticDataDynamicIdx.*f32.*1.3.200.304.*"
".*DSR_GatherStaticDataDynamicIdx.*f32.*1.3.200.304.*",
// TODO: Issue 46755
".*DSR_GatherElements.*"
};
}

View File

@ -3,6 +3,7 @@
//
#include "dsr_tests_common.hpp"
#include "dsr_gather_base.hpp"
#include <shared_test_classes/base/layer_test_utils.hpp>
#include <ngraph_functions/builders.hpp>
@ -23,59 +24,6 @@ const std::vector<ngraph::element::Type> idxTypeVector = {
ngraph::element::i32,
};
struct GatherTestCase {
DataShapeWithUpperBound inputShapes;
DataShapeWithUpperBound indexShape;
int64_t axis;
};
using GatherParameters = std::tuple<
DataType,
DataType,
GatherTestCase,
LayerTestsUtils::TargetDevice
>;
class DSR_GatherBase : public testing::WithParamInterface<GatherParameters>,
public DSR_TestsCommon {
public:
static std::string getTestCaseName(testing::TestParamInfo<GatherParameters> obj) {
DataType dataType, idxType;
GatherTestCase gatherTestCase;
LayerTestsUtils::TargetDevice targetDevice;
std::tie(dataType, idxType, gatherTestCase, targetDevice) = obj.param;
std::ostringstream result;
result << "DT=" << dataType << "_";
result << "IT=" << idxType << "_";
result << "DataRealShape=" << CommonTestUtils::vec2str(gatherTestCase.inputShapes.shape) << "_";
result << "DataUBShape=" << CommonTestUtils::vec2str(gatherTestCase.inputShapes.upperBoundShape) << "_";
result << "IdxRealShape=" << CommonTestUtils::vec2str(gatherTestCase.inputShapes.shape) << "_";
result << "IdxUBShape=" << CommonTestUtils::vec2str(gatherTestCase.inputShapes.upperBoundShape) << "_";
result << "Axis=" << gatherTestCase.axis << "_";
result << "trgDev=" << targetDevice;
return result.str();
}
protected:
std::set<std::string> m_indicesInputNames;
InferenceEngine::Blob::Ptr GenerateInput(const InferenceEngine::InputInfo& info) const override {
const auto& name = info.name();
if (m_indicesInputNames.count(name)) {
const auto& parameters = GetParam();
const auto& gatherSetup = std::get<2>(parameters);
const auto& inputRank = gatherSetup.inputShapes.shape.size();
const auto axis = gatherSetup.axis < 0 ? gatherSetup.axis + inputRank : gatherSetup.axis;
const auto endValue = gatherSetup.inputShapes.shape[axis] - 1;
return FuncTestUtils::createAndFillBlob(info.getTensorDesc(), endValue, 0);
}
return DSR_TestsCommon::GenerateInput(info);
}
};
class DSR_GatherDynamicDataStaticIdx : public DSR_GatherBase {
protected:
std::shared_ptr<ngraph::Node> createTestedOp() override {

View File

@ -0,0 +1,66 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "dsr_tests_common.hpp"
namespace LayerTestsUtils {
namespace vpu {
struct GatherTestCase {
DataShapeWithUpperBound inputShapes;
DataShapeWithUpperBound indexShape;
int64_t axis;
};
using GatherParameters = std::tuple<
DataType,
DataType,
GatherTestCase,
LayerTestsUtils::TargetDevice
>;
class DSR_GatherBase : public testing::WithParamInterface<GatherParameters>,
public DSR_TestsCommon {
public:
static std::string getTestCaseName(testing::TestParamInfo<GatherParameters> obj) {
DataType dataType, idxType;
GatherTestCase gatherTestCase;
LayerTestsUtils::TargetDevice targetDevice;
std::tie(dataType, idxType, gatherTestCase, targetDevice) = obj.param;
std::ostringstream result;
result << "DT=" << dataType << "_";
result << "IT=" << idxType << "_";
result << "DataRealShape=" << CommonTestUtils::vec2str(gatherTestCase.inputShapes.shape) << "_";
result << "DataUBShape=" << CommonTestUtils::vec2str(gatherTestCase.inputShapes.upperBoundShape) << "_";
result << "IdxRealShape=" << CommonTestUtils::vec2str(gatherTestCase.inputShapes.shape) << "_";
result << "IdxUBShape=" << CommonTestUtils::vec2str(gatherTestCase.inputShapes.upperBoundShape) << "_";
result << "Axis=" << gatherTestCase.axis << "_";
result << "trgDev=" << targetDevice;
return result.str();
}
protected:
std::set<std::string> m_indicesInputNames;
InferenceEngine::Blob::Ptr GenerateInput(const InferenceEngine::InputInfo& info) const override {
const auto& name = info.name();
if (m_indicesInputNames.count(name)) {
const auto& parameters = GetParam();
const auto& gatherSetup = std::get<2>(parameters);
const auto& inputRank = gatherSetup.inputShapes.shape.size();
const auto axis = gatherSetup.axis < 0 ? gatherSetup.axis + inputRank : gatherSetup.axis;
const auto endValue = gatherSetup.inputShapes.shape[axis] - 1;
return FuncTestUtils::createAndFillBlob(info.getTensorDesc(), endValue, 0);
}
return DSR_TestsCommon::GenerateInput(info);
}
};
} // namespace vpu
} // namespace LayerTestsUtils

View File

@ -0,0 +1,66 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "dsr_tests_common.hpp"
#include "dsr_gather_base.hpp"
#include <shared_test_classes/base/layer_test_utils.hpp>
#include <vpu/ngraph/operations/dynamic_shape_resolver.hpp>
#include <ngraph/opsets/opset6.hpp>
namespace {
using namespace LayerTestsUtils::vpu;
const std::vector<ngraph::element::Type> dataTypeVector = {
ngraph::element::f16,
ngraph::element::f32,
ngraph::element::i32,
};
const std::vector<ngraph::element::Type> idxTypeVector = {
ngraph::element::i32,
};
class DSR_GatherElements : public DSR_GatherBase {
protected:
std::shared_ptr<ngraph::Node> createTestedOp() override {
SetRefMode(LayerTestsUtils::RefMode::INTERPRETER);
const auto& parameters = GetParam();
const auto& inDataType = std::get<0>(parameters);
const auto& idxType = std::get<1>(parameters);
const auto& gatherSetup = std::get<2>(parameters);
targetDevice = std::get<3>(parameters);
const auto dataParam = std::make_shared<ngraph::opset6::Parameter>(inDataType, gatherSetup.inputShapes.shape);
m_parameterVector.push_back(dataParam);
const auto inputIdxSubgraph = createInputSubgraphWithDSR(idxType, gatherSetup.indexShape);
m_indicesInputNames.insert(inputIdxSubgraph->get_input_node_shared_ptr(0)->get_friendly_name());
const auto gather = std::make_shared<ngraph::opset6::GatherElements>(dataParam, inputIdxSubgraph, gatherSetup.axis);
return gather;
}
};
TEST_P(DSR_GatherElements, CompareWithReference) {
Run();
}
INSTANTIATE_TEST_CASE_P(smoke_DynamicGatherElements, DSR_GatherElements,
testing::Combine(
testing::ValuesIn(dataTypeVector),
testing::ValuesIn(idxTypeVector),
testing::Values(
GatherTestCase{DataShapeWithUpperBound{{1000}, {}}, DataShapeWithUpperBound{{800}, {1000}}, 0},
GatherTestCase{DataShapeWithUpperBound{{1000, 4}, {}}, DataShapeWithUpperBound{{100, 4}, {800, 4}}, 0},
GatherTestCase{DataShapeWithUpperBound{{4, 1000}, {}}, DataShapeWithUpperBound{{4, 100}, {4, 800}}, 1},
GatherTestCase{DataShapeWithUpperBound{{300, 3, 64, 608}, {}}, DataShapeWithUpperBound{{300, 3, 64, 60}, {300, 3, 64, 64}}, 3},
GatherTestCase{DataShapeWithUpperBound{{800}, {1000}}, DataShapeWithUpperBound{{200}, {800}}, 0},
GatherTestCase{DataShapeWithUpperBound{{800, 4}, {1000, 4}}, DataShapeWithUpperBound{{300, 4}, {800, 4}}, 0},
GatherTestCase{DataShapeWithUpperBound{{4, 800}, {4, 1000}}, DataShapeWithUpperBound{{4, 700}, {4, 750}}, 1}),
testing::Values(CommonTestUtils::DEVICE_MYRIAD)));
} // namespace