[IE][VPU][Tests]: Support DTS for ScatterElementsUpdate (#3559)

* Enable DTS for ScatterElementsUpdate
* Update DTS tests
* Update inference tests
This commit is contained in:
Andrew Bakalin 2020-12-11 12:46:26 +03:00 committed by GitHub
parent a0952798ba
commit d90c05aab4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 127 additions and 72 deletions

View File

@ -89,43 +89,44 @@ void validateDynamicFunction(const ngraph::Function& function) {
const Transformations& getDefaultTransformations() { const Transformations& getDefaultTransformations() {
static const Transformations transformations = { static const Transformations transformations = {
{ngraph::opset3::Add::type_info, dynamicToStaticShapeBinaryEltwise}, {ngraph::opset3::Add::type_info, dynamicToStaticShapeBinaryEltwise},
{ngraph::opset3::Multiply::type_info, dynamicToStaticShapeBinaryEltwise}, {ngraph::opset3::Multiply::type_info, dynamicToStaticShapeBinaryEltwise},
{ngraph::opset3::Subtract::type_info, dynamicToStaticShapeBinaryEltwise}, {ngraph::opset3::Subtract::type_info, dynamicToStaticShapeBinaryEltwise},
{ngraph::opset3::VariadicSplit::type_info, dynamicToStaticShapeVariadicSplit}, {ngraph::opset3::VariadicSplit::type_info, dynamicToStaticShapeVariadicSplit},
{ngraph::opset3::Divide::type_info, dynamicToStaticShapeBinaryEltwise}, {ngraph::opset3::Divide::type_info, dynamicToStaticShapeBinaryEltwise},
{ngraph::opset3::Equal::type_info, dynamicToStaticShapeBinaryEltwise}, {ngraph::opset3::Equal::type_info, dynamicToStaticShapeBinaryEltwise},
{ngraph::opset3::Greater::type_info, dynamicToStaticShapeBinaryEltwise}, {ngraph::opset3::Greater::type_info, dynamicToStaticShapeBinaryEltwise},
{ngraph::opset3::Power::type_info, dynamicToStaticShapeBinaryEltwise}, {ngraph::opset3::Power::type_info, dynamicToStaticShapeBinaryEltwise},
{ngraph::opset3::Maximum::type_info, dynamicToStaticShapeBinaryEltwise}, {ngraph::opset3::Maximum::type_info, dynamicToStaticShapeBinaryEltwise},
{ngraph::opset3::Minimum::type_info, dynamicToStaticShapeBinaryEltwise}, {ngraph::opset3::Minimum::type_info, dynamicToStaticShapeBinaryEltwise},
{ngraph::opset3::Less::type_info, dynamicToStaticShapeBinaryEltwise}, {ngraph::opset3::Less::type_info, dynamicToStaticShapeBinaryEltwise},
{ngraph::opset5::NonMaxSuppression::type_info, dynamicToStaticNonMaxSuppression}, {ngraph::opset5::NonMaxSuppression::type_info, dynamicToStaticNonMaxSuppression},
{ngraph::opset3::NonZero::type_info, dynamicToStaticShapeNonZero}, {ngraph::opset3::NonZero::type_info, dynamicToStaticShapeNonZero},
{ngraph::opset3::TopK::type_info, dynamicToStaticShapeTopK}, {ngraph::opset3::TopK::type_info, dynamicToStaticShapeTopK},
{ngraph::opset3::Transpose::type_info, dynamicToStaticShapeTranspose}, {ngraph::opset3::Transpose::type_info, dynamicToStaticShapeTranspose},
{ngraph::opset3::Concat::type_info, dynamicToStaticShapeConcat}, {ngraph::opset3::Concat::type_info, dynamicToStaticShapeConcat},
{ngraph::opset3::Convert::type_info, dynamicToStaticUnaryElementwise}, {ngraph::opset3::Convert::type_info, dynamicToStaticUnaryElementwise},
{ngraph::opset3::Clamp::type_info, dynamicToStaticUnaryElementwise}, {ngraph::opset3::Clamp::type_info, dynamicToStaticUnaryElementwise},
{ngraph::opset3::Floor::type_info, dynamicToStaticUnaryElementwise}, {ngraph::opset3::Floor::type_info, dynamicToStaticUnaryElementwise},
{ngraph::opset3::Log::type_info, dynamicToStaticUnaryElementwise}, {ngraph::opset3::Log::type_info, dynamicToStaticUnaryElementwise},
{ngraph::opset3::Relu::type_info, dynamicToStaticUnaryElementwise}, {ngraph::opset3::Relu::type_info, dynamicToStaticUnaryElementwise},
{ngraph::opset3::ScatterUpdate::type_info, dynamicToStaticUnaryElementwise}, {ngraph::opset3::ScatterUpdate::type_info, dynamicToStaticUnaryElementwise},
{ngraph::opset3::Sigmoid::type_info, dynamicToStaticUnaryElementwise}, {ngraph::opset3::Sigmoid::type_info, dynamicToStaticUnaryElementwise},
{ngraph::opset3::Softmax::type_info, dynamicToStaticUnaryElementwise}, {ngraph::opset3::Softmax::type_info, dynamicToStaticUnaryElementwise},
{ngraph::opset3::Exp::type_info, dynamicToStaticUnaryElementwise}, {ngraph::opset3::Exp::type_info, dynamicToStaticUnaryElementwise},
{ngraph::opset3::Sqrt::type_info, dynamicToStaticUnaryElementwise}, {ngraph::opset3::Sqrt::type_info, dynamicToStaticUnaryElementwise},
{ngraph::opset3::LogicalNot::type_info, dynamicToStaticUnaryElementwise}, {ngraph::opset3::LogicalNot::type_info, dynamicToStaticUnaryElementwise},
{ngraph::opset3::StridedSlice::type_info, dynamicToStaticShapeStridedSlice}, {ngraph::opset5::ScatterElementsUpdate::type_info, dynamicToStaticUnaryElementwise},
{ngraph::opset3::Squeeze::type_info, dynamicToStaticShapeSqueeze}, {ngraph::opset3::StridedSlice::type_info, dynamicToStaticShapeStridedSlice},
{ngraph::opset3::Gather::type_info, dynamicToStaticShapeGather}, {ngraph::opset3::Squeeze::type_info, dynamicToStaticShapeSqueeze},
{ngraph::opset3::Unsqueeze::type_info, dynamicToStaticShapeUnsqueeze}, {ngraph::opset3::Gather::type_info, dynamicToStaticShapeGather},
{ngraph::opset3::ROIAlign::type_info, dynamicToStaticShapeROIAlign}, {ngraph::opset3::Unsqueeze::type_info, dynamicToStaticShapeUnsqueeze},
{ngraph::opset3::Reshape::type_info, dynamicToStaticShapeReshape}, {ngraph::opset3::ROIAlign::type_info, dynamicToStaticShapeROIAlign},
{ngraph::opset3::Broadcast::type_info, dynamicToStaticShapeBroadcast}, {ngraph::opset3::Reshape::type_info, dynamicToStaticShapeReshape},
{ngraph::opset3::MatMul::type_info, dynamicToStaticShapeMatMul}, {ngraph::opset3::Broadcast::type_info, dynamicToStaticShapeBroadcast},
{ngraph::opset5::Split::type_info, dynamicToStaticShapeSplit}, {ngraph::opset3::MatMul::type_info, dynamicToStaticShapeMatMul},
{ngraph::opset5::GatherND::type_info, dynamicToStaticShapeGatherND}, {ngraph::opset5::Split::type_info, dynamicToStaticShapeSplit},
{ngraph::opset5::GatherND::type_info, dynamicToStaticShapeGatherND},
// reduction // reduction
{ngraph::opset3::ReduceLogicalAnd::type_info, dynamicToStaticShapeReduce}, {ngraph::opset3::ReduceLogicalAnd::type_info, dynamicToStaticShapeReduce},

View File

@ -5,6 +5,7 @@
#include <common_test_utils/test_common.hpp> #include <common_test_utils/test_common.hpp>
#include <ngraph_functions/utils/ngraph_helpers.hpp> #include <ngraph_functions/utils/ngraph_helpers.hpp>
#include <ngraph/opsets/opset3.hpp> #include <ngraph/opsets/opset3.hpp>
#include <ngraph/opsets/opset5.hpp>
#include <vpu/ngraph/operations/dynamic_shape_resolver.hpp> #include <vpu/ngraph/operations/dynamic_shape_resolver.hpp>
#include <vpu/ngraph/transformations/dynamic_to_static_shape.hpp> #include <vpu/ngraph/transformations/dynamic_to_static_shape.hpp>
#include <vpu/ngraph/transformations/dynamic_to_static_shape_unary_elementwise.hpp> #include <vpu/ngraph/transformations/dynamic_to_static_shape_unary_elementwise.hpp>
@ -19,71 +20,113 @@ using DataType = ngraph::element::Type_t;
struct ScatterTestCase { struct ScatterTestCase {
ngraph::NodeTypeInfo scatter_type_info; ngraph::NodeTypeInfo scatterTypeInfo;
ngraph::Shape data_shape, indices_shape, updates_shape; ngraph::Shape dataShape, indicesShape, updatesShape;
int64_t axis; int64_t axis;
}; };
enum class ShapeType {
DYNAMIC,
STATIC
};
using ScatterParameters = std::tuple<
DataType,
DataType,
ScatterTestCase,
ShapeType>;
class DynamicToStaticShapeScatter : public CommonTestUtils::TestsCommon, class DynamicToStaticShapeScatter : public CommonTestUtils::TestsCommon,
public testing::WithParamInterface<std::tuple<DataType, DataType, ScatterTestCase>> { public testing::WithParamInterface<ScatterParameters> {
public: public:
void SetUp() override { void SetUp() override {
const auto& parameters = GetParam(); const auto& parameters = GetParam();
const auto& numeric_type = std::get<0>(parameters); const auto& numericType = std::get<0>(parameters);
const auto& integer_type = std::get<1>(parameters); const auto& integerType = std::get<1>(parameters);
const auto& scatter_setup = std::get<2>(parameters); const auto& scatterSetup = std::get<2>(parameters);
const auto& indicesUpdatesShapeType = std::get<3>(parameters);
ngraph::helpers::CompareFunctions(*transform(numeric_type, integer_type, scatter_setup), ngraph::helpers::CompareFunctions(
*reference(numeric_type, integer_type, scatter_setup)); *transform(numericType, integerType, scatterSetup, indicesUpdatesShapeType),
*reference(numericType, integerType, scatterSetup, indicesUpdatesShapeType));
} }
protected: protected:
std::shared_ptr<const ngraph::Function> transform( std::shared_ptr<const ngraph::Function> transform(
const ngraph::element::Type_t& numeric_type, const ngraph::element::Type_t& numericType,
const ngraph::element::Type_t& integer_type, const ngraph::element::Type_t& integerType,
const ScatterTestCase& scatter_setup) const { const ScatterTestCase& scatterSetup,
const auto data = std::make_shared<ngraph::opset3::Parameter>(numeric_type, scatter_setup.data_shape); ShapeType indicesUpdatesShapeType) const {
const auto indices = std::make_shared<ngraph::opset3::Parameter>(integer_type, scatter_setup.indices_shape); const auto data = std::make_shared<ngraph::opset3::Parameter>(numericType, scatterSetup.dataShape);
const auto updates = std::make_shared<ngraph::opset3::Parameter>(numeric_type, scatter_setup.updates_shape); const auto indices = std::make_shared<ngraph::opset3::Parameter>(integerType, scatterSetup.indicesShape);
const auto axis = std::make_shared<ngraph::opset3::Constant>(integer_type, ngraph::Shape{1}, std::vector<int64_t>{scatter_setup.axis}); const auto updates = std::make_shared<ngraph::opset3::Parameter>(numericType, scatterSetup.updatesShape);
const auto axis = std::make_shared<ngraph::opset3::Constant>(integerType, ngraph::Shape{1}, std::vector<int64_t>{scatterSetup.axis});
const auto dataDims = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{scatterSetup.dataShape.size()});
const auto dataDSR = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(data, dataDims);
const auto dims = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{scatter_setup.data_shape.size()}); ngraph::ParameterVector params{data, indices, updates, dataDims};
const auto dsr = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(data, dims);
const auto node = ngraph::helpers::getNodeSharedPtr(scatter_setup.scatter_type_info, {dsr, indices, updates, axis}); std::shared_ptr<ngraph::Node> scatterIndices = indices;
std::shared_ptr<ngraph::Node> scatterUpdates = updates;
if (indicesUpdatesShapeType == ShapeType::DYNAMIC) {
const auto indicesDims = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{scatterSetup.indicesShape.size()});
scatterIndices = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(indices, indicesDims);
params.push_back(indicesDims);
const auto updatesDims = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{scatterSetup.updatesShape.size()});
scatterUpdates = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(updates, updatesDims);
params.push_back(updatesDims);
}
const auto node = ngraph::helpers::getNodeSharedPtr(scatterSetup.scatterTypeInfo, {dataDSR, scatterIndices, scatterUpdates, axis});
auto outputShape = node->get_output_partial_shape(0); auto outputShape = node->get_output_partial_shape(0);
const auto function = std::make_shared<ngraph::Function>( const auto function = std::make_shared<ngraph::Function>(
ngraph::NodeVector{node}, ngraph::NodeVector{node},
ngraph::ParameterVector{data, indices, updates, dims}, params,
"Actual"); "Actual");
node->set_output_type(0, dsr->get_input_element_type(0), ngraph::PartialShape::dynamic(outputShape.rank())); node->set_output_type(0, dataDSR->get_input_element_type(0), ngraph::PartialShape::dynamic(outputShape.rank()));
const auto transformations = vpu::Transformations{{scatter_setup.scatter_type_info, vpu::dynamicToStaticUnaryElementwise}}; const auto transformations = vpu::Transformations{{scatterSetup.scatterTypeInfo, vpu::dynamicToStaticUnaryElementwise}};
vpu::DynamicToStaticShape(transformations).run_on_function(function); vpu::DynamicToStaticShape(transformations).run_on_function(function);
return function; return function;
} }
std::shared_ptr<const ngraph::Function> reference( std::shared_ptr<const ngraph::Function> reference(
const ngraph::element::Type_t& numeric_type, const ngraph::element::Type_t& numericType,
const ngraph::element::Type_t& integer_type, const ngraph::element::Type_t& integerType,
const ScatterTestCase& scatter_setup) const { const ScatterTestCase& scatterSetup,
const auto data = std::make_shared<ngraph::opset3::Parameter>(numeric_type, scatter_setup.data_shape); ShapeType indicesUpdatesShapeType) const {
const auto indices = std::make_shared<ngraph::opset3::Parameter>(integer_type, scatter_setup.indices_shape); const auto data = std::make_shared<ngraph::opset3::Parameter>(numericType, scatterSetup.dataShape);
const auto updates = std::make_shared<ngraph::opset3::Parameter>(numeric_type, scatter_setup.updates_shape); const auto indices = std::make_shared<ngraph::opset3::Parameter>(integerType, scatterSetup.indicesShape);
const auto axis = std::make_shared<ngraph::opset3::Constant>(integer_type, ngraph::Shape{1}, std::vector<int64_t>{scatter_setup.axis}); const auto updates = std::make_shared<ngraph::opset3::Parameter>(numericType, scatterSetup.updatesShape);
const auto axis = std::make_shared<ngraph::opset3::Constant>(integerType, ngraph::Shape{1}, std::vector<int64_t>{scatterSetup.axis});
const auto dataDims = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{scatterSetup.dataShape.size()});
const auto dataDSR = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(data, dataDims);
ngraph::ParameterVector params{data, indices, updates, dataDims};
std::shared_ptr<ngraph::Node> scatterIndices = indices;
std::shared_ptr<ngraph::Node> scatterUpdates = updates;
if (indicesUpdatesShapeType == ShapeType::DYNAMIC) {
const auto indicesDims = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{scatterSetup.indicesShape.size()});
scatterIndices = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(indices, indicesDims);
params.push_back(indicesDims);
const auto updatesDims = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{scatterSetup.updatesShape.size()});
scatterUpdates = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(updates, updatesDims);
params.push_back(updatesDims);
}
const auto dims = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{scatter_setup.data_shape.size()}); const auto node = ngraph::helpers::getNodeSharedPtr(scatterSetup.scatterTypeInfo, {dataDSR, scatterIndices, scatterUpdates, axis});
const auto dsr = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(data, dims);
const auto node = ngraph::helpers::getNodeSharedPtr(scatter_setup.scatter_type_info, {dsr, indices, updates, axis}); std::shared_ptr<ngraph::Node> outNode = node;
const auto outDSR = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(node, dataDims);
const auto dsr1 = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(node, dims);
return std::make_shared<ngraph::Function>( return std::make_shared<ngraph::Function>(
ngraph::NodeVector{dsr1}, ngraph::NodeVector{outDSR},
ngraph::ParameterVector{data, indices, updates, dims}, params,
"Expected"); "Expected");
} }
}; };
@ -103,6 +146,11 @@ INSTANTIATE_TEST_CASE_P(smoke_NGraph, DynamicToStaticShapeScatter, testing::Comb
ngraph::element::i64, ngraph::element::i64,
ngraph::element::u8), ngraph::element::u8),
testing::Values( testing::Values(
ScatterTestCase{ngraph::opset3::ScatterUpdate::type_info, {1000, 256, 10, 15}, {125, 20}, {1000, 125, 20, 10, 15}, 1}))); ScatterTestCase{ngraph::opset3::ScatterUpdate::type_info, {1000, 256, 10, 15}, {125, 20}, {1000, 125, 20, 10, 15}, 1},
ScatterTestCase{ngraph::opset5::ScatterElementsUpdate::type_info, {300}, {300}, {300}, 0}),
testing::Values(
ShapeType::DYNAMIC,
ShapeType::STATIC)
));
} // namespace } // namespace

View File

@ -56,6 +56,12 @@ INSTANTIATE_TEST_CASE_P(smoke_DynamicScatter, DSR_Scatter,
{{84, 256, 7, 7}, {100, 256, 7, 7}}, {{84, 256, 7, 7}, {100, 256, 7, 7}},
{{84}, {100}}, {{84}, {100}},
{{84, 256, 7, 7}, {100, 256, 7, 7}}, {{84, 256, 7, 7}, {100, 256, 7, 7}},
0},
ScatterTestCase{
ngraph::opset5::ScatterElementsUpdate::type_info,
{{142}, {300}},
{{80}, {300}},
{{80}, {300}},
0}), 0}),
::testing::Values(CommonTestUtils::DEVICE_MYRIAD))); ::testing::Values(CommonTestUtils::DEVICE_MYRIAD)));