From d90c05aab45d5bb9316dfe08c39fd4ceb7d34eef Mon Sep 17 00:00:00 2001 From: Andrew Bakalin Date: Fri, 11 Dec 2020 12:46:26 +0300 Subject: [PATCH] [IE][VPU][Tests]: Support DTS for ScatterElementsUpdate (#3559) * Enable DTS for ScatterElementsUpdate * Update DTS tests * Update inference tests --- .../dynamic_to_static_shape.cpp | 75 +++++------ .../dynamic_to_static_shape_scatter.cpp | 118 ++++++++++++------ .../myriad/subgraph_tests/dsr_scatter.cpp | 6 + 3 files changed, 127 insertions(+), 72 deletions(-) diff --git a/inference-engine/src/vpu/common/src/ngraph/transformations/dynamic_to_static_shape.cpp b/inference-engine/src/vpu/common/src/ngraph/transformations/dynamic_to_static_shape.cpp index eef8903e35e..f0004f09616 100644 --- a/inference-engine/src/vpu/common/src/ngraph/transformations/dynamic_to_static_shape.cpp +++ b/inference-engine/src/vpu/common/src/ngraph/transformations/dynamic_to_static_shape.cpp @@ -89,43 +89,44 @@ void validateDynamicFunction(const ngraph::Function& function) { const Transformations& getDefaultTransformations() { static const Transformations transformations = { - {ngraph::opset3::Add::type_info, dynamicToStaticShapeBinaryEltwise}, - {ngraph::opset3::Multiply::type_info, dynamicToStaticShapeBinaryEltwise}, - {ngraph::opset3::Subtract::type_info, dynamicToStaticShapeBinaryEltwise}, - {ngraph::opset3::VariadicSplit::type_info, dynamicToStaticShapeVariadicSplit}, - {ngraph::opset3::Divide::type_info, dynamicToStaticShapeBinaryEltwise}, - {ngraph::opset3::Equal::type_info, dynamicToStaticShapeBinaryEltwise}, - {ngraph::opset3::Greater::type_info, dynamicToStaticShapeBinaryEltwise}, - {ngraph::opset3::Power::type_info, dynamicToStaticShapeBinaryEltwise}, - {ngraph::opset3::Maximum::type_info, dynamicToStaticShapeBinaryEltwise}, - {ngraph::opset3::Minimum::type_info, dynamicToStaticShapeBinaryEltwise}, - {ngraph::opset3::Less::type_info, dynamicToStaticShapeBinaryEltwise}, - {ngraph::opset5::NonMaxSuppression::type_info, dynamicToStaticNonMaxSuppression}, - {ngraph::opset3::NonZero::type_info, dynamicToStaticShapeNonZero}, - {ngraph::opset3::TopK::type_info, dynamicToStaticShapeTopK}, - {ngraph::opset3::Transpose::type_info, dynamicToStaticShapeTranspose}, - {ngraph::opset3::Concat::type_info, dynamicToStaticShapeConcat}, - {ngraph::opset3::Convert::type_info, dynamicToStaticUnaryElementwise}, - {ngraph::opset3::Clamp::type_info, dynamicToStaticUnaryElementwise}, - {ngraph::opset3::Floor::type_info, dynamicToStaticUnaryElementwise}, - {ngraph::opset3::Log::type_info, dynamicToStaticUnaryElementwise}, - {ngraph::opset3::Relu::type_info, dynamicToStaticUnaryElementwise}, - {ngraph::opset3::ScatterUpdate::type_info, dynamicToStaticUnaryElementwise}, - {ngraph::opset3::Sigmoid::type_info, dynamicToStaticUnaryElementwise}, - {ngraph::opset3::Softmax::type_info, dynamicToStaticUnaryElementwise}, - {ngraph::opset3::Exp::type_info, dynamicToStaticUnaryElementwise}, - {ngraph::opset3::Sqrt::type_info, dynamicToStaticUnaryElementwise}, - {ngraph::opset3::LogicalNot::type_info, dynamicToStaticUnaryElementwise}, - {ngraph::opset3::StridedSlice::type_info, dynamicToStaticShapeStridedSlice}, - {ngraph::opset3::Squeeze::type_info, dynamicToStaticShapeSqueeze}, - {ngraph::opset3::Gather::type_info, dynamicToStaticShapeGather}, - {ngraph::opset3::Unsqueeze::type_info, dynamicToStaticShapeUnsqueeze}, - {ngraph::opset3::ROIAlign::type_info, dynamicToStaticShapeROIAlign}, - {ngraph::opset3::Reshape::type_info, dynamicToStaticShapeReshape}, - {ngraph::opset3::Broadcast::type_info, dynamicToStaticShapeBroadcast}, - {ngraph::opset3::MatMul::type_info, dynamicToStaticShapeMatMul}, - {ngraph::opset5::Split::type_info, dynamicToStaticShapeSplit}, - {ngraph::opset5::GatherND::type_info, dynamicToStaticShapeGatherND}, + {ngraph::opset3::Add::type_info, dynamicToStaticShapeBinaryEltwise}, + {ngraph::opset3::Multiply::type_info, dynamicToStaticShapeBinaryEltwise}, + {ngraph::opset3::Subtract::type_info, dynamicToStaticShapeBinaryEltwise}, + {ngraph::opset3::VariadicSplit::type_info, dynamicToStaticShapeVariadicSplit}, + {ngraph::opset3::Divide::type_info, dynamicToStaticShapeBinaryEltwise}, + {ngraph::opset3::Equal::type_info, dynamicToStaticShapeBinaryEltwise}, + {ngraph::opset3::Greater::type_info, dynamicToStaticShapeBinaryEltwise}, + {ngraph::opset3::Power::type_info, dynamicToStaticShapeBinaryEltwise}, + {ngraph::opset3::Maximum::type_info, dynamicToStaticShapeBinaryEltwise}, + {ngraph::opset3::Minimum::type_info, dynamicToStaticShapeBinaryEltwise}, + {ngraph::opset3::Less::type_info, dynamicToStaticShapeBinaryEltwise}, + {ngraph::opset5::NonMaxSuppression::type_info, dynamicToStaticNonMaxSuppression}, + {ngraph::opset3::NonZero::type_info, dynamicToStaticShapeNonZero}, + {ngraph::opset3::TopK::type_info, dynamicToStaticShapeTopK}, + {ngraph::opset3::Transpose::type_info, dynamicToStaticShapeTranspose}, + {ngraph::opset3::Concat::type_info, dynamicToStaticShapeConcat}, + {ngraph::opset3::Convert::type_info, dynamicToStaticUnaryElementwise}, + {ngraph::opset3::Clamp::type_info, dynamicToStaticUnaryElementwise}, + {ngraph::opset3::Floor::type_info, dynamicToStaticUnaryElementwise}, + {ngraph::opset3::Log::type_info, dynamicToStaticUnaryElementwise}, + {ngraph::opset3::Relu::type_info, dynamicToStaticUnaryElementwise}, + {ngraph::opset3::ScatterUpdate::type_info, dynamicToStaticUnaryElementwise}, + {ngraph::opset3::Sigmoid::type_info, dynamicToStaticUnaryElementwise}, + {ngraph::opset3::Softmax::type_info, dynamicToStaticUnaryElementwise}, + {ngraph::opset3::Exp::type_info, dynamicToStaticUnaryElementwise}, + {ngraph::opset3::Sqrt::type_info, dynamicToStaticUnaryElementwise}, + {ngraph::opset3::LogicalNot::type_info, dynamicToStaticUnaryElementwise}, + {ngraph::opset5::ScatterElementsUpdate::type_info, dynamicToStaticUnaryElementwise}, + {ngraph::opset3::StridedSlice::type_info, dynamicToStaticShapeStridedSlice}, + {ngraph::opset3::Squeeze::type_info, dynamicToStaticShapeSqueeze}, + {ngraph::opset3::Gather::type_info, dynamicToStaticShapeGather}, + {ngraph::opset3::Unsqueeze::type_info, dynamicToStaticShapeUnsqueeze}, + {ngraph::opset3::ROIAlign::type_info, dynamicToStaticShapeROIAlign}, + {ngraph::opset3::Reshape::type_info, dynamicToStaticShapeReshape}, + {ngraph::opset3::Broadcast::type_info, dynamicToStaticShapeBroadcast}, + {ngraph::opset3::MatMul::type_info, dynamicToStaticShapeMatMul}, + {ngraph::opset5::Split::type_info, dynamicToStaticShapeSplit}, + {ngraph::opset5::GatherND::type_info, dynamicToStaticShapeGatherND}, // reduction {ngraph::opset3::ReduceLogicalAnd::type_info, dynamicToStaticShapeReduce}, diff --git a/inference-engine/tests/functional/plugin/myriad/ngraph/transformations/dynamic_to_static_shape_scatter.cpp b/inference-engine/tests/functional/plugin/myriad/ngraph/transformations/dynamic_to_static_shape_scatter.cpp index c83dec9418d..721309630ca 100644 --- a/inference-engine/tests/functional/plugin/myriad/ngraph/transformations/dynamic_to_static_shape_scatter.cpp +++ b/inference-engine/tests/functional/plugin/myriad/ngraph/transformations/dynamic_to_static_shape_scatter.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -19,71 +20,113 @@ using DataType = ngraph::element::Type_t; struct ScatterTestCase { - ngraph::NodeTypeInfo scatter_type_info; - ngraph::Shape data_shape, indices_shape, updates_shape; + ngraph::NodeTypeInfo scatterTypeInfo; + ngraph::Shape dataShape, indicesShape, updatesShape; int64_t axis; }; +enum class ShapeType { + DYNAMIC, + STATIC +}; + +using ScatterParameters = std::tuple< + DataType, + DataType, + ScatterTestCase, + ShapeType>; + class DynamicToStaticShapeScatter : public CommonTestUtils::TestsCommon, - public testing::WithParamInterface> { + public testing::WithParamInterface { public: void SetUp() override { const auto& parameters = GetParam(); - const auto& numeric_type = std::get<0>(parameters); - const auto& integer_type = std::get<1>(parameters); - const auto& scatter_setup = std::get<2>(parameters); + const auto& numericType = std::get<0>(parameters); + const auto& integerType = std::get<1>(parameters); + const auto& scatterSetup = std::get<2>(parameters); + const auto& indicesUpdatesShapeType = std::get<3>(parameters); - ngraph::helpers::CompareFunctions(*transform(numeric_type, integer_type, scatter_setup), - *reference(numeric_type, integer_type, scatter_setup)); + ngraph::helpers::CompareFunctions( + *transform(numericType, integerType, scatterSetup, indicesUpdatesShapeType), + *reference(numericType, integerType, scatterSetup, indicesUpdatesShapeType)); } protected: std::shared_ptr transform( - const ngraph::element::Type_t& numeric_type, - const ngraph::element::Type_t& integer_type, - const ScatterTestCase& scatter_setup) const { - const auto data = std::make_shared(numeric_type, scatter_setup.data_shape); - const auto indices = std::make_shared(integer_type, scatter_setup.indices_shape); - const auto updates = std::make_shared(numeric_type, scatter_setup.updates_shape); - const auto axis = std::make_shared(integer_type, ngraph::Shape{1}, std::vector{scatter_setup.axis}); + const ngraph::element::Type_t& numericType, + const ngraph::element::Type_t& integerType, + const ScatterTestCase& scatterSetup, + ShapeType indicesUpdatesShapeType) const { + const auto data = std::make_shared(numericType, scatterSetup.dataShape); + const auto indices = std::make_shared(integerType, scatterSetup.indicesShape); + const auto updates = std::make_shared(numericType, scatterSetup.updatesShape); + const auto axis = std::make_shared(integerType, ngraph::Shape{1}, std::vector{scatterSetup.axis}); + const auto dataDims = std::make_shared(ngraph::element::i64, ngraph::Shape{scatterSetup.dataShape.size()}); + const auto dataDSR = std::make_shared(data, dataDims); - const auto dims = std::make_shared(ngraph::element::i64, ngraph::Shape{scatter_setup.data_shape.size()}); - const auto dsr = std::make_shared(data, dims); + ngraph::ParameterVector params{data, indices, updates, dataDims}; - const auto node = ngraph::helpers::getNodeSharedPtr(scatter_setup.scatter_type_info, {dsr, indices, updates, axis}); + std::shared_ptr scatterIndices = indices; + std::shared_ptr scatterUpdates = updates; + if (indicesUpdatesShapeType == ShapeType::DYNAMIC) { + const auto indicesDims = std::make_shared(ngraph::element::i64, ngraph::Shape{scatterSetup.indicesShape.size()}); + scatterIndices = std::make_shared(indices, indicesDims); + params.push_back(indicesDims); + const auto updatesDims = std::make_shared(ngraph::element::i64, ngraph::Shape{scatterSetup.updatesShape.size()}); + scatterUpdates = std::make_shared(updates, updatesDims); + params.push_back(updatesDims); + } + + const auto node = ngraph::helpers::getNodeSharedPtr(scatterSetup.scatterTypeInfo, {dataDSR, scatterIndices, scatterUpdates, axis}); auto outputShape = node->get_output_partial_shape(0); const auto function = std::make_shared( ngraph::NodeVector{node}, - ngraph::ParameterVector{data, indices, updates, dims}, + params, "Actual"); - node->set_output_type(0, dsr->get_input_element_type(0), ngraph::PartialShape::dynamic(outputShape.rank())); + node->set_output_type(0, dataDSR->get_input_element_type(0), ngraph::PartialShape::dynamic(outputShape.rank())); - const auto transformations = vpu::Transformations{{scatter_setup.scatter_type_info, vpu::dynamicToStaticUnaryElementwise}}; + const auto transformations = vpu::Transformations{{scatterSetup.scatterTypeInfo, vpu::dynamicToStaticUnaryElementwise}}; vpu::DynamicToStaticShape(transformations).run_on_function(function); return function; } std::shared_ptr reference( - const ngraph::element::Type_t& numeric_type, - const ngraph::element::Type_t& integer_type, - const ScatterTestCase& scatter_setup) const { - const auto data = std::make_shared(numeric_type, scatter_setup.data_shape); - const auto indices = std::make_shared(integer_type, scatter_setup.indices_shape); - const auto updates = std::make_shared(numeric_type, scatter_setup.updates_shape); - const auto axis = std::make_shared(integer_type, ngraph::Shape{1}, std::vector{scatter_setup.axis}); + const ngraph::element::Type_t& numericType, + const ngraph::element::Type_t& integerType, + const ScatterTestCase& scatterSetup, + ShapeType indicesUpdatesShapeType) const { + const auto data = std::make_shared(numericType, scatterSetup.dataShape); + const auto indices = std::make_shared(integerType, scatterSetup.indicesShape); + const auto updates = std::make_shared(numericType, scatterSetup.updatesShape); + const auto axis = std::make_shared(integerType, ngraph::Shape{1}, std::vector{scatterSetup.axis}); + + const auto dataDims = std::make_shared(ngraph::element::i64, ngraph::Shape{scatterSetup.dataShape.size()}); + const auto dataDSR = std::make_shared(data, dataDims); + + ngraph::ParameterVector params{data, indices, updates, dataDims}; + + std::shared_ptr scatterIndices = indices; + std::shared_ptr scatterUpdates = updates; + if (indicesUpdatesShapeType == ShapeType::DYNAMIC) { + const auto indicesDims = std::make_shared(ngraph::element::i64, ngraph::Shape{scatterSetup.indicesShape.size()}); + scatterIndices = std::make_shared(indices, indicesDims); + params.push_back(indicesDims); + const auto updatesDims = std::make_shared(ngraph::element::i64, ngraph::Shape{scatterSetup.updatesShape.size()}); + scatterUpdates = std::make_shared(updates, updatesDims); + params.push_back(updatesDims); + } - const auto dims = std::make_shared(ngraph::element::i64, ngraph::Shape{scatter_setup.data_shape.size()}); - const auto dsr = std::make_shared(data, dims); + const auto node = ngraph::helpers::getNodeSharedPtr(scatterSetup.scatterTypeInfo, {dataDSR, scatterIndices, scatterUpdates, axis}); - const auto node = ngraph::helpers::getNodeSharedPtr(scatter_setup.scatter_type_info, {dsr, indices, updates, axis}); + std::shared_ptr outNode = node; + const auto outDSR = std::make_shared(node, dataDims); - const auto dsr1 = std::make_shared(node, dims); return std::make_shared( - ngraph::NodeVector{dsr1}, - ngraph::ParameterVector{data, indices, updates, dims}, + ngraph::NodeVector{outDSR}, + params, "Expected"); } }; @@ -103,6 +146,11 @@ INSTANTIATE_TEST_CASE_P(smoke_NGraph, DynamicToStaticShapeScatter, testing::Comb ngraph::element::i64, ngraph::element::u8), testing::Values( - ScatterTestCase{ngraph::opset3::ScatterUpdate::type_info, {1000, 256, 10, 15}, {125, 20}, {1000, 125, 20, 10, 15}, 1}))); + ScatterTestCase{ngraph::opset3::ScatterUpdate::type_info, {1000, 256, 10, 15}, {125, 20}, {1000, 125, 20, 10, 15}, 1}, + ScatterTestCase{ngraph::opset5::ScatterElementsUpdate::type_info, {300}, {300}, {300}, 0}), + testing::Values( + ShapeType::DYNAMIC, + ShapeType::STATIC) +)); } // namespace diff --git a/inference-engine/tests/functional/plugin/myriad/subgraph_tests/dsr_scatter.cpp b/inference-engine/tests/functional/plugin/myriad/subgraph_tests/dsr_scatter.cpp index 1b4f76956f9..729926659ba 100644 --- a/inference-engine/tests/functional/plugin/myriad/subgraph_tests/dsr_scatter.cpp +++ b/inference-engine/tests/functional/plugin/myriad/subgraph_tests/dsr_scatter.cpp @@ -56,6 +56,12 @@ INSTANTIATE_TEST_CASE_P(smoke_DynamicScatter, DSR_Scatter, {{84, 256, 7, 7}, {100, 256, 7, 7}}, {{84}, {100}}, {{84, 256, 7, 7}, {100, 256, 7, 7}}, + 0}, + ScatterTestCase{ + ngraph::opset5::ScatterElementsUpdate::type_info, + {{142}, {300}}, + {{80}, {300}}, + {{80}, {300}}, 0}), ::testing::Values(CommonTestUtils::DEVICE_MYRIAD)));