diff --git a/src/plugins/intel_cpu/src/utils/shape_inference/shape_inference.cpp b/src/plugins/intel_cpu/src/utils/shape_inference/shape_inference.cpp index 78272257309..36f8e52a720 100644 --- a/src/plugins/intel_cpu/src/utils/shape_inference/shape_inference.cpp +++ b/src/plugins/intel_cpu/src/utils/shape_inference/shape_inference.cpp @@ -420,7 +420,8 @@ std::shared_ptr make_shape_inference(const std::shared_ptr(op) || ov::is_type(op) || ov::is_type(op) || ov::is_type(op) || ov::is_type(op) || ov::is_type(op) || - ov::is_type(op) || ov::is_type(op)) { + ov::is_type(op) || ov::is_type(op) || + ov::is_type(op)) { return std::make_shared(op); } else if (ov::is_type(op) || ov::is_type(op) || diff --git a/src/plugins/intel_cpu/tests/unit/shape_inference_test/scatter_update_shape_inference_test.cpp b/src/plugins/intel_cpu/tests/unit/shape_inference_test/scatter_update_shape_inference_test.cpp new file mode 100644 index 00000000000..3c36ea97ac9 --- /dev/null +++ b/src/plugins/intel_cpu/tests/unit/shape_inference_test/scatter_update_shape_inference_test.cpp @@ -0,0 +1,134 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include +#include +#include + +using namespace ov; +using namespace ov::intel_cpu; + +TEST(StaticShapeInferenceTest, ScatterUpdate_3D_axis_1) { + auto data_param = std::make_shared(element::i32, PartialShape{-1, -1, -1}); + auto indices_param = std::make_shared(element::i32, PartialShape{-1, -1}); + auto updates_param = std::make_shared(element::i32, PartialShape{-1, -1, -1, -1}); + auto axis_param = std::make_shared(element::i32, PartialShape{-1}); + + auto scatter_update = std::make_shared(data_param, indices_param, updates_param, axis_param); + + int32_t axis_val[] = {1}; + std::map> constant_data; + constant_data[3] = std::make_shared(element::Type_t::i32, Shape{1}, axis_val); + std::vector input_shapes = {StaticShape{2, 3, 4}, + StaticShape{2, 1}, + StaticShape{2, 2, 1, 4}, + StaticShape{1}}, + output_shapes = {StaticShape{}}; + shape_inference(scatter_update.get(), input_shapes, output_shapes, constant_data); + EXPECT_EQ(output_shapes[0], StaticShape({2, 3, 4})); +} + +TEST(StaticShapeInferenceTest, ScatterUpdate_4D_axis_2) { + auto data_param = std::make_shared(element::i32, PartialShape{-1, -1, -1, -1}); + auto indices_param = std::make_shared(element::i32, PartialShape{-1, -1}); + auto updates_param = std::make_shared(element::i32, PartialShape{-1, -1, -1, -1, -1}); + auto axis_param = std::make_shared(element::i32, PartialShape{-1}); + + auto scatter_update = std::make_shared(data_param, indices_param, updates_param, axis_param); + + int32_t axis_val[] = {2}; + std::map> constant_data; + constant_data[3] = std::make_shared(element::Type_t::i32, Shape{1}, axis_val); + std::vector input_shapes = {StaticShape{1000, 256, 10, 15}, + StaticShape{125, 20}, + StaticShape{1000, 125, 20, 10, 15}, + StaticShape{1}}, + output_shapes = {StaticShape{}}; + shape_inference(scatter_update.get(), input_shapes, output_shapes, constant_data); + EXPECT_EQ(output_shapes[0], StaticShape({1000, 256, 10, 15})); +} + +TEST(StaticShapeInferenceTest, ScatterUpdate_4D_incompatible_axis) { + auto data_param = std::make_shared(element::i32, PartialShape{-1, -1, -1, -1}); + auto indices_param = std::make_shared(element::i32, PartialShape{-1, -1}); + auto updates_param = std::make_shared(element::i32, PartialShape{-1, -1, -1, -1, -1}); + auto axis_param = std::make_shared(element::i32, PartialShape{-1}); + + auto scatter_update = std::make_shared(data_param, indices_param, updates_param, axis_param); + + int32_t axis_val[] = {1}; + std::map> constant_data; + constant_data[3] = std::make_shared(element::Type_t::i32, Shape{1}, axis_val); + std::vector input_shapes = {StaticShape{1000, 256, 10, 15}, + StaticShape{125, 20}, + StaticShape{1000, 125, 20, 10, 15}, + StaticShape{1}}, + output_shapes = {StaticShape{}}; + shape_inference(scatter_update.get(), input_shapes, output_shapes, constant_data); + EXPECT_EQ(output_shapes[0], StaticShape({1000, 256, 10, 15})); +} + +TEST(StaticShapeInferenceTest, ScatterUpdate_axis_as_const) { + auto data_param = std::make_shared(element::i32, PartialShape{-1, -1, -1, -1}); + auto indices_param = std::make_shared(element::i32, PartialShape{-1, -1}); + auto updates_param = std::make_shared(element::i32, PartialShape{-1, -1, -1, -1, -1}); + auto axis_const = std::make_shared(element::i32, Shape{1}, std::vector{1}); + + auto scatter_update = std::make_shared(data_param, indices_param, updates_param, axis_const); + + std::vector input_shapes = {StaticShape{1000, 256, 10, 15}, + StaticShape{125, 20}, + StaticShape{1000, 125, 20, 10, 15}, + StaticShape{1}}, + output_shapes = {StaticShape{}}; + shape_inference(scatter_update.get(), input_shapes, output_shapes); + EXPECT_EQ(output_shapes[0], StaticShape({1000, 256, 10, 15})); +} + +TEST(StaticShapeInferenceTest, ScatterUpdate_dynamic_rank) { + auto data_param = std::make_shared(element::i32, PartialShape::dynamic()); + auto indices_param = std::make_shared(element::i32, PartialShape::dynamic()); + auto updates_param = std::make_shared(element::i32, PartialShape::dynamic()); + auto axis_param = std::make_shared(element::i32, PartialShape::dynamic()); + + auto scatter_update = std::make_shared(data_param, indices_param, updates_param, axis_param); + + int32_t axis_val[] = {1}; + std::map> constant_data; + constant_data[3] = std::make_shared(element::Type_t::i32, Shape{1}, axis_val); + std::vector input_shapes = {StaticShape{1000, 256, 10, 15}, + StaticShape{125, 20}, + StaticShape{1000, 125, 20, 10, 15}, + StaticShape{1}}, + output_shapes = {StaticShape{}}; + shape_inference(scatter_update.get(), input_shapes, output_shapes, constant_data); + EXPECT_EQ(output_shapes[0], StaticShape({1000, 256, 10, 15})); +} + +TEST(StaticShapeInferenceTest, ScatterUpdate_params_dynamic_rank_incorrect_updates_shape) { + auto data_param = std::make_shared(element::i32, PartialShape::dynamic()); + auto indices_param = std::make_shared(element::i32, PartialShape::dynamic()); + auto updates_param = std::make_shared(element::i32, PartialShape::dynamic()); + auto axis_param = std::make_shared(element::i32, PartialShape::dynamic()); + + auto scatter_update = std::make_shared(data_param, indices_param, updates_param, axis_param); + + int32_t axis_val[] = {1}; + std::map> constant_data; + constant_data[3] = std::make_shared(element::Type_t::i32, Shape{1}, axis_val); + + // Incorrect rank of the third input shape + std::vector input_shapes = {StaticShape{1000, 256, 10, 15}, + StaticShape{125, 20, 1, 1, 1}, + StaticShape{1000, 125, 20, 10}, + StaticShape{1}}, + output_shapes = {StaticShape{}}; + + // ScatterUpdate shape_inference is implemented by usage of entryFirstPassthrough, no additional checks + shape_inference(scatter_update.get(), input_shapes, output_shapes, constant_data); + EXPECT_EQ(output_shapes[0], StaticShape({1000, 256, 10, 15})); +}