[GPU] add shape infer in scatter elements update (#20250)

* add shape infer in scatter elements update

* output shape is same with input shape in dynamic case
This commit is contained in:
Wilson Seok 2023-10-25 16:01:52 +09:00 committed by GitHub
parent a71283ea94
commit 9d56c31581
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 48 additions and 6 deletions

View File

@ -9,6 +9,8 @@
#include "json_object.h"
#include <string>
#include "scatter_elements_update_shape_inference.hpp"
namespace cldnn {
GPU_DEFINE_PRIMITIVE_TYPE_ID(scatter_elements_update)
@ -16,11 +18,11 @@ layout scatter_elements_update_inst::calc_output_layout(scatter_elements_update_
auto desc = impl_param.typed_desc<scatter_elements_update>();
const int32_t axis = desc->axis;
const size_t input_number_of_dims = impl_param.get_input_layout().get_tensor().sizes().size();
const size_t input_number_of_dims = impl_param.get_input_layout().get_partial_shape().size();
auto input_layout = impl_param.get_input_layout();
auto output_shape = input_layout.get_tensor();
auto output_shape = input_layout.get_partial_shape();
auto input_format = input_layout.format;
auto output_type = input_layout.data_type;
@ -31,7 +33,7 @@ layout scatter_elements_update_inst::calc_output_layout(scatter_elements_update_
if (static_cast<size_t>(axis) < 0 || static_cast<size_t>(axis) >= input_number_of_dims)
CLDNN_ERROR_MESSAGE(desc->id, "Incorrect axis value for ScatterElementsUpdate: Axis must be positive and less than the input tensor dimension.");
return layout{output_type, input_format, output_shape};
return layout{output_shape, output_type, input_format};
}
std::string scatter_elements_update_inst::to_string(scatter_elements_update_node const& node) {

View File

@ -173,7 +173,7 @@ TEST_P(ScatterUpdateLayerGPUTest, CompareWithRefs) {
namespace ScatterNDUpdate {
const std::vector<ScatterUpdateLayerParams> scatterParams = {
const std::vector<ScatterUpdateLayerParams> scatterNDParams = {
ScatterUpdateLayerParams{
ScatterUpdateShapes{
{{-1, -1, -1, -1, -1}, {{10, 9, 10, 9, 10}, {10, 1, 11, 2, 5}, {10, 15, 8, 1, 7}}},
@ -212,6 +212,39 @@ const std::vector<ScatterUpdateLayerParams> scatterParams = {
},
};
const std::vector<ScatterUpdateLayerParams> scatterElementsParams = {
ScatterUpdateLayerParams{
ScatterUpdateShapes{
{{-1, -1, -1, -1, -1}, {{10, 9, 10, 9, 10}, {10, 5, 11, 4, 5}, {10, 15, 8, 1, 7}}},
{{-1, -1, -1, -1, -1 }, {{3, 2, 1, 2, 1}, {3, 2, 1, 2, 1}, {3, 2, 1, 2, 1}}},
{{-1, -1, -1, -1, -1 }, {{3, 2, 1, 2, 1}, {3, 2, 1, 2, 1}, {3, 2, 1, 2, 1}}},
{{1}, {{1}}}
},
IndicesValues{ 5, 6, 2, 8, 5, 6, 2, 8, 5, 6, 2, 8 },
Scatterupdate_type::Elements
},
ScatterUpdateLayerParams{
ScatterUpdateShapes{
{{-1, -1, -1, -1}, {{ 10, 9, 9, 11 }, { 7, 5, 3, 12 }, { 3, 4, 9, 8 }}},
{{-1, -1, -1, -1}, {{3, 1, 2, 3}, {3, 1, 2, 3}, {3, 1, 2, 3}}},
{{-1, -1, -1, -1}, {{3, 1, 2, 3}, {3, 1, 2, 3}, {3, 1, 2, 3}}},
{{1}, {{1}}}
},
IndicesValues{ 0, 1, 1, 2, 2, 2, 0, 1, 1, 2, 2, 2, 0, 1, 1, 2, 2, 2 },
Scatterupdate_type::Elements
},
ScatterUpdateLayerParams{
ScatterUpdateShapes{
{{{3, 10}, -1, {3, 9}, -1}, {{ 10, 9, 9, 11 }, { 7, 5, 3, 12 }, { 3, 4, 9, 8 }}},
{{2, -1, 3, -1}, {{2, 1, 3, 1}, {2, 1, 3, 1}, {2, 1, 3, 1}}},
{{2, -1, 3, -1}, {{2, 1, 3, 1}, {2, 1, 3, 1}, {2, 1, 3, 1}}},
{{1}, {{1}}}
},
IndicesValues{ 0, 1, 1, 2, 2, 2 },
Scatterupdate_type::Elements
},
};
const std::vector<ElementType> inputPrecisions = {
ElementType::f32,
};
@ -260,7 +293,14 @@ const std::vector<ScatterUpdateLayerParams> scatterElementsUpdate_EmptyInput1_2P
INSTANTIATE_TEST_SUITE_P(smoke_ScatterNDUpdate_CompareWithRefs_dynamic, ScatterUpdateLayerGPUTest,
::testing::Combine(
::testing::ValuesIn(scatterParams),
::testing::ValuesIn(scatterNDParams),
::testing::ValuesIn(inputPrecisions),
::testing::ValuesIn(constantPrecisions)),
ScatterUpdateLayerGPUTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_ScatterElementsUpdate_CompareWithRefs_dynamic, ScatterUpdateLayerGPUTest,
::testing::Combine(
::testing::ValuesIn(scatterElementsParams),
::testing::ValuesIn(inputPrecisions),
::testing::ValuesIn(constantPrecisions)),
ScatterUpdateLayerGPUTest::getTestCaseName);
@ -280,7 +320,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_ScatterNDUpdate_EmptyInput1_2_CompareWithRefs_dyn
ScatterUpdateLayerGPUTest::getTestCaseName);
// ScatterELementsUpdate doesn't support dynamic shape yet. Need to enable when it supports.
INSTANTIATE_TEST_SUITE_P(DISABLED_smoke_ScatterElementsUpdate_EmptyInput1_2_CompareWithRefs_dynamic, ScatterUpdateLayerGPUTest,
INSTANTIATE_TEST_SUITE_P(smoke_ScatterElementsUpdate_EmptyInput1_2_CompareWithRefs_dynamic, ScatterUpdateLayerGPUTest,
::testing::Combine(
::testing::ValuesIn(scatterElementsUpdate_EmptyInput1_2Params),
::testing::ValuesIn(inputPrecisions),