[IE][VPU][Tests]: Support DTS for ScatterElementsUpdate (#3559)

* Enable DTS for ScatterElementsUpdate
* Update DTS tests
* Update inference tests
This commit is contained in:
Andrew Bakalin 2020-12-11 12:46:26 +03:00 committed by GitHub
parent a0952798ba
commit d90c05aab4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 127 additions and 72 deletions

View File

@ -116,6 +116,7 @@ const Transformations& getDefaultTransformations() {
{ngraph::opset3::Exp::type_info, dynamicToStaticUnaryElementwise},
{ngraph::opset3::Sqrt::type_info, dynamicToStaticUnaryElementwise},
{ngraph::opset3::LogicalNot::type_info, dynamicToStaticUnaryElementwise},
{ngraph::opset5::ScatterElementsUpdate::type_info, dynamicToStaticUnaryElementwise},
{ngraph::opset3::StridedSlice::type_info, dynamicToStaticShapeStridedSlice},
{ngraph::opset3::Squeeze::type_info, dynamicToStaticShapeSqueeze},
{ngraph::opset3::Gather::type_info, dynamicToStaticShapeGather},

View File

@ -5,6 +5,7 @@
#include <common_test_utils/test_common.hpp>
#include <ngraph_functions/utils/ngraph_helpers.hpp>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/opsets/opset5.hpp>
#include <vpu/ngraph/operations/dynamic_shape_resolver.hpp>
#include <vpu/ngraph/transformations/dynamic_to_static_shape.hpp>
#include <vpu/ngraph/transformations/dynamic_to_static_shape_unary_elementwise.hpp>
@ -19,71 +20,113 @@ using DataType = ngraph::element::Type_t;
struct ScatterTestCase {
ngraph::NodeTypeInfo scatter_type_info;
ngraph::Shape data_shape, indices_shape, updates_shape;
ngraph::NodeTypeInfo scatterTypeInfo;
ngraph::Shape dataShape, indicesShape, updatesShape;
int64_t axis;
};
enum class ShapeType {
DYNAMIC,
STATIC
};
using ScatterParameters = std::tuple<
DataType,
DataType,
ScatterTestCase,
ShapeType>;
class DynamicToStaticShapeScatter : public CommonTestUtils::TestsCommon,
public testing::WithParamInterface<std::tuple<DataType, DataType, ScatterTestCase>> {
public testing::WithParamInterface<ScatterParameters> {
public:
void SetUp() override {
const auto& parameters = GetParam();
const auto& numeric_type = std::get<0>(parameters);
const auto& integer_type = std::get<1>(parameters);
const auto& scatter_setup = std::get<2>(parameters);
const auto& numericType = std::get<0>(parameters);
const auto& integerType = std::get<1>(parameters);
const auto& scatterSetup = std::get<2>(parameters);
const auto& indicesUpdatesShapeType = std::get<3>(parameters);
ngraph::helpers::CompareFunctions(*transform(numeric_type, integer_type, scatter_setup),
*reference(numeric_type, integer_type, scatter_setup));
ngraph::helpers::CompareFunctions(
*transform(numericType, integerType, scatterSetup, indicesUpdatesShapeType),
*reference(numericType, integerType, scatterSetup, indicesUpdatesShapeType));
}
protected:
std::shared_ptr<const ngraph::Function> transform(
const ngraph::element::Type_t& numeric_type,
const ngraph::element::Type_t& integer_type,
const ScatterTestCase& scatter_setup) const {
const auto data = std::make_shared<ngraph::opset3::Parameter>(numeric_type, scatter_setup.data_shape);
const auto indices = std::make_shared<ngraph::opset3::Parameter>(integer_type, scatter_setup.indices_shape);
const auto updates = std::make_shared<ngraph::opset3::Parameter>(numeric_type, scatter_setup.updates_shape);
const auto axis = std::make_shared<ngraph::opset3::Constant>(integer_type, ngraph::Shape{1}, std::vector<int64_t>{scatter_setup.axis});
const ngraph::element::Type_t& numericType,
const ngraph::element::Type_t& integerType,
const ScatterTestCase& scatterSetup,
ShapeType indicesUpdatesShapeType) const {
const auto data = std::make_shared<ngraph::opset3::Parameter>(numericType, scatterSetup.dataShape);
const auto indices = std::make_shared<ngraph::opset3::Parameter>(integerType, scatterSetup.indicesShape);
const auto updates = std::make_shared<ngraph::opset3::Parameter>(numericType, scatterSetup.updatesShape);
const auto axis = std::make_shared<ngraph::opset3::Constant>(integerType, ngraph::Shape{1}, std::vector<int64_t>{scatterSetup.axis});
const auto dataDims = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{scatterSetup.dataShape.size()});
const auto dataDSR = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(data, dataDims);
const auto dims = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{scatter_setup.data_shape.size()});
const auto dsr = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(data, dims);
ngraph::ParameterVector params{data, indices, updates, dataDims};
const auto node = ngraph::helpers::getNodeSharedPtr(scatter_setup.scatter_type_info, {dsr, indices, updates, axis});
std::shared_ptr<ngraph::Node> scatterIndices = indices;
std::shared_ptr<ngraph::Node> scatterUpdates = updates;
if (indicesUpdatesShapeType == ShapeType::DYNAMIC) {
const auto indicesDims = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{scatterSetup.indicesShape.size()});
scatterIndices = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(indices, indicesDims);
params.push_back(indicesDims);
const auto updatesDims = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{scatterSetup.updatesShape.size()});
scatterUpdates = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(updates, updatesDims);
params.push_back(updatesDims);
}
const auto node = ngraph::helpers::getNodeSharedPtr(scatterSetup.scatterTypeInfo, {dataDSR, scatterIndices, scatterUpdates, axis});
auto outputShape = node->get_output_partial_shape(0);
const auto function = std::make_shared<ngraph::Function>(
ngraph::NodeVector{node},
ngraph::ParameterVector{data, indices, updates, dims},
params,
"Actual");
node->set_output_type(0, dsr->get_input_element_type(0), ngraph::PartialShape::dynamic(outputShape.rank()));
node->set_output_type(0, dataDSR->get_input_element_type(0), ngraph::PartialShape::dynamic(outputShape.rank()));
const auto transformations = vpu::Transformations{{scatter_setup.scatter_type_info, vpu::dynamicToStaticUnaryElementwise}};
const auto transformations = vpu::Transformations{{scatterSetup.scatterTypeInfo, vpu::dynamicToStaticUnaryElementwise}};
vpu::DynamicToStaticShape(transformations).run_on_function(function);
return function;
}
std::shared_ptr<const ngraph::Function> reference(
const ngraph::element::Type_t& numeric_type,
const ngraph::element::Type_t& integer_type,
const ScatterTestCase& scatter_setup) const {
const auto data = std::make_shared<ngraph::opset3::Parameter>(numeric_type, scatter_setup.data_shape);
const auto indices = std::make_shared<ngraph::opset3::Parameter>(integer_type, scatter_setup.indices_shape);
const auto updates = std::make_shared<ngraph::opset3::Parameter>(numeric_type, scatter_setup.updates_shape);
const auto axis = std::make_shared<ngraph::opset3::Constant>(integer_type, ngraph::Shape{1}, std::vector<int64_t>{scatter_setup.axis});
const ngraph::element::Type_t& numericType,
const ngraph::element::Type_t& integerType,
const ScatterTestCase& scatterSetup,
ShapeType indicesUpdatesShapeType) const {
const auto data = std::make_shared<ngraph::opset3::Parameter>(numericType, scatterSetup.dataShape);
const auto indices = std::make_shared<ngraph::opset3::Parameter>(integerType, scatterSetup.indicesShape);
const auto updates = std::make_shared<ngraph::opset3::Parameter>(numericType, scatterSetup.updatesShape);
const auto axis = std::make_shared<ngraph::opset3::Constant>(integerType, ngraph::Shape{1}, std::vector<int64_t>{scatterSetup.axis});
const auto dataDims = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{scatterSetup.dataShape.size()});
const auto dataDSR = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(data, dataDims);
ngraph::ParameterVector params{data, indices, updates, dataDims};
std::shared_ptr<ngraph::Node> scatterIndices = indices;
std::shared_ptr<ngraph::Node> scatterUpdates = updates;
if (indicesUpdatesShapeType == ShapeType::DYNAMIC) {
const auto indicesDims = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{scatterSetup.indicesShape.size()});
scatterIndices = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(indices, indicesDims);
params.push_back(indicesDims);
const auto updatesDims = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{scatterSetup.updatesShape.size()});
scatterUpdates = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(updates, updatesDims);
params.push_back(updatesDims);
}
const auto dims = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{scatter_setup.data_shape.size()});
const auto dsr = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(data, dims);
const auto node = ngraph::helpers::getNodeSharedPtr(scatterSetup.scatterTypeInfo, {dataDSR, scatterIndices, scatterUpdates, axis});
const auto node = ngraph::helpers::getNodeSharedPtr(scatter_setup.scatter_type_info, {dsr, indices, updates, axis});
std::shared_ptr<ngraph::Node> outNode = node;
const auto outDSR = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(node, dataDims);
const auto dsr1 = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(node, dims);
return std::make_shared<ngraph::Function>(
ngraph::NodeVector{dsr1},
ngraph::ParameterVector{data, indices, updates, dims},
ngraph::NodeVector{outDSR},
params,
"Expected");
}
};
@ -103,6 +146,11 @@ INSTANTIATE_TEST_CASE_P(smoke_NGraph, DynamicToStaticShapeScatter, testing::Comb
ngraph::element::i64,
ngraph::element::u8),
testing::Values(
ScatterTestCase{ngraph::opset3::ScatterUpdate::type_info, {1000, 256, 10, 15}, {125, 20}, {1000, 125, 20, 10, 15}, 1})));
ScatterTestCase{ngraph::opset3::ScatterUpdate::type_info, {1000, 256, 10, 15}, {125, 20}, {1000, 125, 20, 10, 15}, 1},
ScatterTestCase{ngraph::opset5::ScatterElementsUpdate::type_info, {300}, {300}, {300}, 0}),
testing::Values(
ShapeType::DYNAMIC,
ShapeType::STATIC)
));
} // namespace

View File

@ -56,6 +56,12 @@ INSTANTIATE_TEST_CASE_P(smoke_DynamicScatter, DSR_Scatter,
{{84, 256, 7, 7}, {100, 256, 7, 7}},
{{84}, {100}},
{{84, 256, 7, 7}, {100, 256, 7, 7}},
0},
ScatterTestCase{
ngraph::opset5::ScatterElementsUpdate::type_info,
{{142}, {300}},
{{80}, {300}},
{{80}, {300}},
0}),
::testing::Values(CommonTestUtils::DEVICE_MYRIAD)));