[GPU] add shape infer in scatter elements update (#20250)
* add shape infer in scatter elements update * output shape is same with input shape in dynamic case
This commit is contained in:
parent
a71283ea94
commit
9d56c31581
@ -9,6 +9,8 @@
|
||||
#include "json_object.h"
|
||||
#include <string>
|
||||
|
||||
#include "scatter_elements_update_shape_inference.hpp"
|
||||
|
||||
namespace cldnn {
|
||||
GPU_DEFINE_PRIMITIVE_TYPE_ID(scatter_elements_update)
|
||||
|
||||
@ -16,11 +18,11 @@ layout scatter_elements_update_inst::calc_output_layout(scatter_elements_update_
|
||||
auto desc = impl_param.typed_desc<scatter_elements_update>();
|
||||
|
||||
const int32_t axis = desc->axis;
|
||||
const size_t input_number_of_dims = impl_param.get_input_layout().get_tensor().sizes().size();
|
||||
const size_t input_number_of_dims = impl_param.get_input_layout().get_partial_shape().size();
|
||||
|
||||
auto input_layout = impl_param.get_input_layout();
|
||||
|
||||
auto output_shape = input_layout.get_tensor();
|
||||
auto output_shape = input_layout.get_partial_shape();
|
||||
auto input_format = input_layout.format;
|
||||
auto output_type = input_layout.data_type;
|
||||
|
||||
@ -31,7 +33,7 @@ layout scatter_elements_update_inst::calc_output_layout(scatter_elements_update_
|
||||
if (static_cast<size_t>(axis) < 0 || static_cast<size_t>(axis) >= input_number_of_dims)
|
||||
CLDNN_ERROR_MESSAGE(desc->id, "Incorrect axis value for ScatterElementsUpdate: Axis must be positive and less than the input tensor dimension.");
|
||||
|
||||
return layout{output_type, input_format, output_shape};
|
||||
return layout{output_shape, output_type, input_format};
|
||||
}
|
||||
|
||||
std::string scatter_elements_update_inst::to_string(scatter_elements_update_node const& node) {
|
||||
|
@ -173,7 +173,7 @@ TEST_P(ScatterUpdateLayerGPUTest, CompareWithRefs) {
|
||||
|
||||
namespace ScatterNDUpdate {
|
||||
|
||||
const std::vector<ScatterUpdateLayerParams> scatterParams = {
|
||||
const std::vector<ScatterUpdateLayerParams> scatterNDParams = {
|
||||
ScatterUpdateLayerParams{
|
||||
ScatterUpdateShapes{
|
||||
{{-1, -1, -1, -1, -1}, {{10, 9, 10, 9, 10}, {10, 1, 11, 2, 5}, {10, 15, 8, 1, 7}}},
|
||||
@ -212,6 +212,39 @@ const std::vector<ScatterUpdateLayerParams> scatterParams = {
|
||||
},
|
||||
};
|
||||
|
||||
const std::vector<ScatterUpdateLayerParams> scatterElementsParams = {
|
||||
ScatterUpdateLayerParams{
|
||||
ScatterUpdateShapes{
|
||||
{{-1, -1, -1, -1, -1}, {{10, 9, 10, 9, 10}, {10, 5, 11, 4, 5}, {10, 15, 8, 1, 7}}},
|
||||
{{-1, -1, -1, -1, -1 }, {{3, 2, 1, 2, 1}, {3, 2, 1, 2, 1}, {3, 2, 1, 2, 1}}},
|
||||
{{-1, -1, -1, -1, -1 }, {{3, 2, 1, 2, 1}, {3, 2, 1, 2, 1}, {3, 2, 1, 2, 1}}},
|
||||
{{1}, {{1}}}
|
||||
},
|
||||
IndicesValues{ 5, 6, 2, 8, 5, 6, 2, 8, 5, 6, 2, 8 },
|
||||
Scatterupdate_type::Elements
|
||||
},
|
||||
ScatterUpdateLayerParams{
|
||||
ScatterUpdateShapes{
|
||||
{{-1, -1, -1, -1}, {{ 10, 9, 9, 11 }, { 7, 5, 3, 12 }, { 3, 4, 9, 8 }}},
|
||||
{{-1, -1, -1, -1}, {{3, 1, 2, 3}, {3, 1, 2, 3}, {3, 1, 2, 3}}},
|
||||
{{-1, -1, -1, -1}, {{3, 1, 2, 3}, {3, 1, 2, 3}, {3, 1, 2, 3}}},
|
||||
{{1}, {{1}}}
|
||||
},
|
||||
IndicesValues{ 0, 1, 1, 2, 2, 2, 0, 1, 1, 2, 2, 2, 0, 1, 1, 2, 2, 2 },
|
||||
Scatterupdate_type::Elements
|
||||
},
|
||||
ScatterUpdateLayerParams{
|
||||
ScatterUpdateShapes{
|
||||
{{{3, 10}, -1, {3, 9}, -1}, {{ 10, 9, 9, 11 }, { 7, 5, 3, 12 }, { 3, 4, 9, 8 }}},
|
||||
{{2, -1, 3, -1}, {{2, 1, 3, 1}, {2, 1, 3, 1}, {2, 1, 3, 1}}},
|
||||
{{2, -1, 3, -1}, {{2, 1, 3, 1}, {2, 1, 3, 1}, {2, 1, 3, 1}}},
|
||||
{{1}, {{1}}}
|
||||
},
|
||||
IndicesValues{ 0, 1, 1, 2, 2, 2 },
|
||||
Scatterupdate_type::Elements
|
||||
},
|
||||
};
|
||||
|
||||
const std::vector<ElementType> inputPrecisions = {
|
||||
ElementType::f32,
|
||||
};
|
||||
@ -260,7 +293,14 @@ const std::vector<ScatterUpdateLayerParams> scatterElementsUpdate_EmptyInput1_2P
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_ScatterNDUpdate_CompareWithRefs_dynamic, ScatterUpdateLayerGPUTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(scatterParams),
|
||||
::testing::ValuesIn(scatterNDParams),
|
||||
::testing::ValuesIn(inputPrecisions),
|
||||
::testing::ValuesIn(constantPrecisions)),
|
||||
ScatterUpdateLayerGPUTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_ScatterElementsUpdate_CompareWithRefs_dynamic, ScatterUpdateLayerGPUTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(scatterElementsParams),
|
||||
::testing::ValuesIn(inputPrecisions),
|
||||
::testing::ValuesIn(constantPrecisions)),
|
||||
ScatterUpdateLayerGPUTest::getTestCaseName);
|
||||
@ -280,7 +320,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_ScatterNDUpdate_EmptyInput1_2_CompareWithRefs_dyn
|
||||
ScatterUpdateLayerGPUTest::getTestCaseName);
|
||||
|
||||
// ScatterELementsUpdate doesn't support dynamic shape yet. Need to enable when it supports.
|
||||
INSTANTIATE_TEST_SUITE_P(DISABLED_smoke_ScatterElementsUpdate_EmptyInput1_2_CompareWithRefs_dynamic, ScatterUpdateLayerGPUTest,
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_ScatterElementsUpdate_EmptyInput1_2_CompareWithRefs_dynamic, ScatterUpdateLayerGPUTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(scatterElementsUpdate_EmptyInput1_2Params),
|
||||
::testing::ValuesIn(inputPrecisions),
|
||||
|
Loading…
Reference in New Issue
Block a user