[IE][VPU][Tests]: Support DTS for ScatterElementsUpdate (#3559)
* Enable DTS for ScatterElementsUpdate * Update DTS tests * Update inference tests
This commit is contained in:
parent
a0952798ba
commit
d90c05aab4
@ -89,43 +89,44 @@ void validateDynamicFunction(const ngraph::Function& function) {
|
||||
|
||||
const Transformations& getDefaultTransformations() {
|
||||
static const Transformations transformations = {
|
||||
{ngraph::opset3::Add::type_info, dynamicToStaticShapeBinaryEltwise},
|
||||
{ngraph::opset3::Multiply::type_info, dynamicToStaticShapeBinaryEltwise},
|
||||
{ngraph::opset3::Subtract::type_info, dynamicToStaticShapeBinaryEltwise},
|
||||
{ngraph::opset3::VariadicSplit::type_info, dynamicToStaticShapeVariadicSplit},
|
||||
{ngraph::opset3::Divide::type_info, dynamicToStaticShapeBinaryEltwise},
|
||||
{ngraph::opset3::Equal::type_info, dynamicToStaticShapeBinaryEltwise},
|
||||
{ngraph::opset3::Greater::type_info, dynamicToStaticShapeBinaryEltwise},
|
||||
{ngraph::opset3::Power::type_info, dynamicToStaticShapeBinaryEltwise},
|
||||
{ngraph::opset3::Maximum::type_info, dynamicToStaticShapeBinaryEltwise},
|
||||
{ngraph::opset3::Minimum::type_info, dynamicToStaticShapeBinaryEltwise},
|
||||
{ngraph::opset3::Less::type_info, dynamicToStaticShapeBinaryEltwise},
|
||||
{ngraph::opset5::NonMaxSuppression::type_info, dynamicToStaticNonMaxSuppression},
|
||||
{ngraph::opset3::NonZero::type_info, dynamicToStaticShapeNonZero},
|
||||
{ngraph::opset3::TopK::type_info, dynamicToStaticShapeTopK},
|
||||
{ngraph::opset3::Transpose::type_info, dynamicToStaticShapeTranspose},
|
||||
{ngraph::opset3::Concat::type_info, dynamicToStaticShapeConcat},
|
||||
{ngraph::opset3::Convert::type_info, dynamicToStaticUnaryElementwise},
|
||||
{ngraph::opset3::Clamp::type_info, dynamicToStaticUnaryElementwise},
|
||||
{ngraph::opset3::Floor::type_info, dynamicToStaticUnaryElementwise},
|
||||
{ngraph::opset3::Log::type_info, dynamicToStaticUnaryElementwise},
|
||||
{ngraph::opset3::Relu::type_info, dynamicToStaticUnaryElementwise},
|
||||
{ngraph::opset3::ScatterUpdate::type_info, dynamicToStaticUnaryElementwise},
|
||||
{ngraph::opset3::Sigmoid::type_info, dynamicToStaticUnaryElementwise},
|
||||
{ngraph::opset3::Softmax::type_info, dynamicToStaticUnaryElementwise},
|
||||
{ngraph::opset3::Exp::type_info, dynamicToStaticUnaryElementwise},
|
||||
{ngraph::opset3::Sqrt::type_info, dynamicToStaticUnaryElementwise},
|
||||
{ngraph::opset3::LogicalNot::type_info, dynamicToStaticUnaryElementwise},
|
||||
{ngraph::opset3::StridedSlice::type_info, dynamicToStaticShapeStridedSlice},
|
||||
{ngraph::opset3::Squeeze::type_info, dynamicToStaticShapeSqueeze},
|
||||
{ngraph::opset3::Gather::type_info, dynamicToStaticShapeGather},
|
||||
{ngraph::opset3::Unsqueeze::type_info, dynamicToStaticShapeUnsqueeze},
|
||||
{ngraph::opset3::ROIAlign::type_info, dynamicToStaticShapeROIAlign},
|
||||
{ngraph::opset3::Reshape::type_info, dynamicToStaticShapeReshape},
|
||||
{ngraph::opset3::Broadcast::type_info, dynamicToStaticShapeBroadcast},
|
||||
{ngraph::opset3::MatMul::type_info, dynamicToStaticShapeMatMul},
|
||||
{ngraph::opset5::Split::type_info, dynamicToStaticShapeSplit},
|
||||
{ngraph::opset5::GatherND::type_info, dynamicToStaticShapeGatherND},
|
||||
{ngraph::opset3::Add::type_info, dynamicToStaticShapeBinaryEltwise},
|
||||
{ngraph::opset3::Multiply::type_info, dynamicToStaticShapeBinaryEltwise},
|
||||
{ngraph::opset3::Subtract::type_info, dynamicToStaticShapeBinaryEltwise},
|
||||
{ngraph::opset3::VariadicSplit::type_info, dynamicToStaticShapeVariadicSplit},
|
||||
{ngraph::opset3::Divide::type_info, dynamicToStaticShapeBinaryEltwise},
|
||||
{ngraph::opset3::Equal::type_info, dynamicToStaticShapeBinaryEltwise},
|
||||
{ngraph::opset3::Greater::type_info, dynamicToStaticShapeBinaryEltwise},
|
||||
{ngraph::opset3::Power::type_info, dynamicToStaticShapeBinaryEltwise},
|
||||
{ngraph::opset3::Maximum::type_info, dynamicToStaticShapeBinaryEltwise},
|
||||
{ngraph::opset3::Minimum::type_info, dynamicToStaticShapeBinaryEltwise},
|
||||
{ngraph::opset3::Less::type_info, dynamicToStaticShapeBinaryEltwise},
|
||||
{ngraph::opset5::NonMaxSuppression::type_info, dynamicToStaticNonMaxSuppression},
|
||||
{ngraph::opset3::NonZero::type_info, dynamicToStaticShapeNonZero},
|
||||
{ngraph::opset3::TopK::type_info, dynamicToStaticShapeTopK},
|
||||
{ngraph::opset3::Transpose::type_info, dynamicToStaticShapeTranspose},
|
||||
{ngraph::opset3::Concat::type_info, dynamicToStaticShapeConcat},
|
||||
{ngraph::opset3::Convert::type_info, dynamicToStaticUnaryElementwise},
|
||||
{ngraph::opset3::Clamp::type_info, dynamicToStaticUnaryElementwise},
|
||||
{ngraph::opset3::Floor::type_info, dynamicToStaticUnaryElementwise},
|
||||
{ngraph::opset3::Log::type_info, dynamicToStaticUnaryElementwise},
|
||||
{ngraph::opset3::Relu::type_info, dynamicToStaticUnaryElementwise},
|
||||
{ngraph::opset3::ScatterUpdate::type_info, dynamicToStaticUnaryElementwise},
|
||||
{ngraph::opset3::Sigmoid::type_info, dynamicToStaticUnaryElementwise},
|
||||
{ngraph::opset3::Softmax::type_info, dynamicToStaticUnaryElementwise},
|
||||
{ngraph::opset3::Exp::type_info, dynamicToStaticUnaryElementwise},
|
||||
{ngraph::opset3::Sqrt::type_info, dynamicToStaticUnaryElementwise},
|
||||
{ngraph::opset3::LogicalNot::type_info, dynamicToStaticUnaryElementwise},
|
||||
{ngraph::opset5::ScatterElementsUpdate::type_info, dynamicToStaticUnaryElementwise},
|
||||
{ngraph::opset3::StridedSlice::type_info, dynamicToStaticShapeStridedSlice},
|
||||
{ngraph::opset3::Squeeze::type_info, dynamicToStaticShapeSqueeze},
|
||||
{ngraph::opset3::Gather::type_info, dynamicToStaticShapeGather},
|
||||
{ngraph::opset3::Unsqueeze::type_info, dynamicToStaticShapeUnsqueeze},
|
||||
{ngraph::opset3::ROIAlign::type_info, dynamicToStaticShapeROIAlign},
|
||||
{ngraph::opset3::Reshape::type_info, dynamicToStaticShapeReshape},
|
||||
{ngraph::opset3::Broadcast::type_info, dynamicToStaticShapeBroadcast},
|
||||
{ngraph::opset3::MatMul::type_info, dynamicToStaticShapeMatMul},
|
||||
{ngraph::opset5::Split::type_info, dynamicToStaticShapeSplit},
|
||||
{ngraph::opset5::GatherND::type_info, dynamicToStaticShapeGatherND},
|
||||
|
||||
// reduction
|
||||
{ngraph::opset3::ReduceLogicalAnd::type_info, dynamicToStaticShapeReduce},
|
||||
|
@ -5,6 +5,7 @@
|
||||
#include <common_test_utils/test_common.hpp>
|
||||
#include <ngraph_functions/utils/ngraph_helpers.hpp>
|
||||
#include <ngraph/opsets/opset3.hpp>
|
||||
#include <ngraph/opsets/opset5.hpp>
|
||||
#include <vpu/ngraph/operations/dynamic_shape_resolver.hpp>
|
||||
#include <vpu/ngraph/transformations/dynamic_to_static_shape.hpp>
|
||||
#include <vpu/ngraph/transformations/dynamic_to_static_shape_unary_elementwise.hpp>
|
||||
@ -19,71 +20,113 @@ using DataType = ngraph::element::Type_t;
|
||||
|
||||
|
||||
struct ScatterTestCase {
|
||||
ngraph::NodeTypeInfo scatter_type_info;
|
||||
ngraph::Shape data_shape, indices_shape, updates_shape;
|
||||
ngraph::NodeTypeInfo scatterTypeInfo;
|
||||
ngraph::Shape dataShape, indicesShape, updatesShape;
|
||||
int64_t axis;
|
||||
};
|
||||
|
||||
enum class ShapeType {
|
||||
DYNAMIC,
|
||||
STATIC
|
||||
};
|
||||
|
||||
using ScatterParameters = std::tuple<
|
||||
DataType,
|
||||
DataType,
|
||||
ScatterTestCase,
|
||||
ShapeType>;
|
||||
|
||||
class DynamicToStaticShapeScatter : public CommonTestUtils::TestsCommon,
|
||||
public testing::WithParamInterface<std::tuple<DataType, DataType, ScatterTestCase>> {
|
||||
public testing::WithParamInterface<ScatterParameters> {
|
||||
public:
|
||||
void SetUp() override {
|
||||
const auto& parameters = GetParam();
|
||||
const auto& numeric_type = std::get<0>(parameters);
|
||||
const auto& integer_type = std::get<1>(parameters);
|
||||
const auto& scatter_setup = std::get<2>(parameters);
|
||||
const auto& numericType = std::get<0>(parameters);
|
||||
const auto& integerType = std::get<1>(parameters);
|
||||
const auto& scatterSetup = std::get<2>(parameters);
|
||||
const auto& indicesUpdatesShapeType = std::get<3>(parameters);
|
||||
|
||||
ngraph::helpers::CompareFunctions(*transform(numeric_type, integer_type, scatter_setup),
|
||||
*reference(numeric_type, integer_type, scatter_setup));
|
||||
ngraph::helpers::CompareFunctions(
|
||||
*transform(numericType, integerType, scatterSetup, indicesUpdatesShapeType),
|
||||
*reference(numericType, integerType, scatterSetup, indicesUpdatesShapeType));
|
||||
}
|
||||
|
||||
protected:
|
||||
std::shared_ptr<const ngraph::Function> transform(
|
||||
const ngraph::element::Type_t& numeric_type,
|
||||
const ngraph::element::Type_t& integer_type,
|
||||
const ScatterTestCase& scatter_setup) const {
|
||||
const auto data = std::make_shared<ngraph::opset3::Parameter>(numeric_type, scatter_setup.data_shape);
|
||||
const auto indices = std::make_shared<ngraph::opset3::Parameter>(integer_type, scatter_setup.indices_shape);
|
||||
const auto updates = std::make_shared<ngraph::opset3::Parameter>(numeric_type, scatter_setup.updates_shape);
|
||||
const auto axis = std::make_shared<ngraph::opset3::Constant>(integer_type, ngraph::Shape{1}, std::vector<int64_t>{scatter_setup.axis});
|
||||
const ngraph::element::Type_t& numericType,
|
||||
const ngraph::element::Type_t& integerType,
|
||||
const ScatterTestCase& scatterSetup,
|
||||
ShapeType indicesUpdatesShapeType) const {
|
||||
const auto data = std::make_shared<ngraph::opset3::Parameter>(numericType, scatterSetup.dataShape);
|
||||
const auto indices = std::make_shared<ngraph::opset3::Parameter>(integerType, scatterSetup.indicesShape);
|
||||
const auto updates = std::make_shared<ngraph::opset3::Parameter>(numericType, scatterSetup.updatesShape);
|
||||
const auto axis = std::make_shared<ngraph::opset3::Constant>(integerType, ngraph::Shape{1}, std::vector<int64_t>{scatterSetup.axis});
|
||||
|
||||
const auto dataDims = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{scatterSetup.dataShape.size()});
|
||||
const auto dataDSR = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(data, dataDims);
|
||||
|
||||
const auto dims = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{scatter_setup.data_shape.size()});
|
||||
const auto dsr = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(data, dims);
|
||||
ngraph::ParameterVector params{data, indices, updates, dataDims};
|
||||
|
||||
const auto node = ngraph::helpers::getNodeSharedPtr(scatter_setup.scatter_type_info, {dsr, indices, updates, axis});
|
||||
std::shared_ptr<ngraph::Node> scatterIndices = indices;
|
||||
std::shared_ptr<ngraph::Node> scatterUpdates = updates;
|
||||
if (indicesUpdatesShapeType == ShapeType::DYNAMIC) {
|
||||
const auto indicesDims = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{scatterSetup.indicesShape.size()});
|
||||
scatterIndices = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(indices, indicesDims);
|
||||
params.push_back(indicesDims);
|
||||
const auto updatesDims = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{scatterSetup.updatesShape.size()});
|
||||
scatterUpdates = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(updates, updatesDims);
|
||||
params.push_back(updatesDims);
|
||||
}
|
||||
|
||||
const auto node = ngraph::helpers::getNodeSharedPtr(scatterSetup.scatterTypeInfo, {dataDSR, scatterIndices, scatterUpdates, axis});
|
||||
|
||||
auto outputShape = node->get_output_partial_shape(0);
|
||||
const auto function = std::make_shared<ngraph::Function>(
|
||||
ngraph::NodeVector{node},
|
||||
ngraph::ParameterVector{data, indices, updates, dims},
|
||||
params,
|
||||
"Actual");
|
||||
node->set_output_type(0, dsr->get_input_element_type(0), ngraph::PartialShape::dynamic(outputShape.rank()));
|
||||
node->set_output_type(0, dataDSR->get_input_element_type(0), ngraph::PartialShape::dynamic(outputShape.rank()));
|
||||
|
||||
const auto transformations = vpu::Transformations{{scatter_setup.scatter_type_info, vpu::dynamicToStaticUnaryElementwise}};
|
||||
const auto transformations = vpu::Transformations{{scatterSetup.scatterTypeInfo, vpu::dynamicToStaticUnaryElementwise}};
|
||||
vpu::DynamicToStaticShape(transformations).run_on_function(function);
|
||||
return function;
|
||||
}
|
||||
|
||||
std::shared_ptr<const ngraph::Function> reference(
|
||||
const ngraph::element::Type_t& numeric_type,
|
||||
const ngraph::element::Type_t& integer_type,
|
||||
const ScatterTestCase& scatter_setup) const {
|
||||
const auto data = std::make_shared<ngraph::opset3::Parameter>(numeric_type, scatter_setup.data_shape);
|
||||
const auto indices = std::make_shared<ngraph::opset3::Parameter>(integer_type, scatter_setup.indices_shape);
|
||||
const auto updates = std::make_shared<ngraph::opset3::Parameter>(numeric_type, scatter_setup.updates_shape);
|
||||
const auto axis = std::make_shared<ngraph::opset3::Constant>(integer_type, ngraph::Shape{1}, std::vector<int64_t>{scatter_setup.axis});
|
||||
const ngraph::element::Type_t& numericType,
|
||||
const ngraph::element::Type_t& integerType,
|
||||
const ScatterTestCase& scatterSetup,
|
||||
ShapeType indicesUpdatesShapeType) const {
|
||||
const auto data = std::make_shared<ngraph::opset3::Parameter>(numericType, scatterSetup.dataShape);
|
||||
const auto indices = std::make_shared<ngraph::opset3::Parameter>(integerType, scatterSetup.indicesShape);
|
||||
const auto updates = std::make_shared<ngraph::opset3::Parameter>(numericType, scatterSetup.updatesShape);
|
||||
const auto axis = std::make_shared<ngraph::opset3::Constant>(integerType, ngraph::Shape{1}, std::vector<int64_t>{scatterSetup.axis});
|
||||
|
||||
const auto dataDims = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{scatterSetup.dataShape.size()});
|
||||
const auto dataDSR = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(data, dataDims);
|
||||
|
||||
ngraph::ParameterVector params{data, indices, updates, dataDims};
|
||||
|
||||
std::shared_ptr<ngraph::Node> scatterIndices = indices;
|
||||
std::shared_ptr<ngraph::Node> scatterUpdates = updates;
|
||||
if (indicesUpdatesShapeType == ShapeType::DYNAMIC) {
|
||||
const auto indicesDims = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{scatterSetup.indicesShape.size()});
|
||||
scatterIndices = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(indices, indicesDims);
|
||||
params.push_back(indicesDims);
|
||||
const auto updatesDims = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{scatterSetup.updatesShape.size()});
|
||||
scatterUpdates = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(updates, updatesDims);
|
||||
params.push_back(updatesDims);
|
||||
}
|
||||
|
||||
|
||||
const auto dims = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{scatter_setup.data_shape.size()});
|
||||
const auto dsr = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(data, dims);
|
||||
const auto node = ngraph::helpers::getNodeSharedPtr(scatterSetup.scatterTypeInfo, {dataDSR, scatterIndices, scatterUpdates, axis});
|
||||
|
||||
const auto node = ngraph::helpers::getNodeSharedPtr(scatter_setup.scatter_type_info, {dsr, indices, updates, axis});
|
||||
std::shared_ptr<ngraph::Node> outNode = node;
|
||||
const auto outDSR = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(node, dataDims);
|
||||
|
||||
const auto dsr1 = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(node, dims);
|
||||
return std::make_shared<ngraph::Function>(
|
||||
ngraph::NodeVector{dsr1},
|
||||
ngraph::ParameterVector{data, indices, updates, dims},
|
||||
ngraph::NodeVector{outDSR},
|
||||
params,
|
||||
"Expected");
|
||||
}
|
||||
};
|
||||
@ -103,6 +146,11 @@ INSTANTIATE_TEST_CASE_P(smoke_NGraph, DynamicToStaticShapeScatter, testing::Comb
|
||||
ngraph::element::i64,
|
||||
ngraph::element::u8),
|
||||
testing::Values(
|
||||
ScatterTestCase{ngraph::opset3::ScatterUpdate::type_info, {1000, 256, 10, 15}, {125, 20}, {1000, 125, 20, 10, 15}, 1})));
|
||||
ScatterTestCase{ngraph::opset3::ScatterUpdate::type_info, {1000, 256, 10, 15}, {125, 20}, {1000, 125, 20, 10, 15}, 1},
|
||||
ScatterTestCase{ngraph::opset5::ScatterElementsUpdate::type_info, {300}, {300}, {300}, 0}),
|
||||
testing::Values(
|
||||
ShapeType::DYNAMIC,
|
||||
ShapeType::STATIC)
|
||||
));
|
||||
|
||||
} // namespace
|
||||
|
@ -56,6 +56,12 @@ INSTANTIATE_TEST_CASE_P(smoke_DynamicScatter, DSR_Scatter,
|
||||
{{84, 256, 7, 7}, {100, 256, 7, 7}},
|
||||
{{84}, {100}},
|
||||
{{84, 256, 7, 7}, {100, 256, 7, 7}},
|
||||
0},
|
||||
ScatterTestCase{
|
||||
ngraph::opset5::ScatterElementsUpdate::type_info,
|
||||
{{142}, {300}},
|
||||
{{80}, {300}},
|
||||
{{80}, {300}},
|
||||
0}),
|
||||
::testing::Values(CommonTestUtils::DEVICE_MYRIAD)));
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user