From 234fe9293173014e0d422657f716074a12aeea6f Mon Sep 17 00:00:00 2001 From: Roman Lyamin Date: Tue, 11 Apr 2023 09:42:51 +0400 Subject: [PATCH] [GPU] MVN 1d dynamic batch case fix (#16826) --- src/plugins/intel_gpu/src/plugin/ops/parameter.cpp | 11 ++++++++++- .../plugin/gpu/single_layer_tests/dynamic/mvn.cpp | 4 ++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/plugins/intel_gpu/src/plugin/ops/parameter.cpp b/src/plugins/intel_gpu/src/plugin/ops/parameter.cpp index 10d6f47f99a..f31f2aa70d7 100644 --- a/src/plugins/intel_gpu/src/plugin/ops/parameter.cpp +++ b/src/plugins/intel_gpu/src/plugin/ops/parameter.cpp @@ -28,10 +28,19 @@ static void CreateParameterOp(Program& p, const std::shared_ptrget_friendly_name()); // first create and add the input layout const auto inputDesc = inputInfo->getTensorDesc(); - auto input_pshape = op->get_partial_shape(); InferenceEngine::Layout l = inputDesc.getLayout(); InferenceEngine::Precision ip = inputDesc.getPrecision(); + auto input_pshape = op->get_partial_shape(); + if (!p.use_new_shape_infer()) { + if (input_pshape.size() < 4) { + input_pshape.insert(input_pshape.end(), 4 - input_pshape.size(), ov::Dimension(1)); + } + if (p.m_max_batch > 1) { + input_pshape[0] = ov::Dimension(p.m_curBatch); + } + } + cldnn::format inputFormat = cldnn::format::get_default_format(input_pshape.size()); std::vector default_order(input_pshape.size()); std::iota(default_order.begin(), default_order.end(), 0); diff --git a/src/tests/functional/plugin/gpu/single_layer_tests/dynamic/mvn.cpp b/src/tests/functional/plugin/gpu/single_layer_tests/dynamic/mvn.cpp index 88f592fada5..2333344e509 100644 --- a/src/tests/functional/plugin/gpu/single_layer_tests/dynamic/mvn.cpp +++ b/src/tests/functional/plugin/gpu/single_layer_tests/dynamic/mvn.cpp @@ -241,8 +241,8 @@ const std::vector epsilon = { const std::vector emptyReductionAxes = {{}}; -std::vector inpPrc = {ElementType::i8, ElementType::bf16, ElementType::f32}; -std::vector outPrc = {ElementType::bf16, ElementType::f32}; +std::vector inpPrc = {ElementType::i8, ElementType::f16, ElementType::f32}; +std::vector outPrc = {ElementType::f16, ElementType::f32}; const auto Mvn3D = ::testing::Combine( ::testing::Combine(