[GPU] MVN 1d dynamic batch case fix (#16826)

This commit is contained in:
Roman Lyamin 2023-04-11 09:42:51 +04:00 committed by GitHub
parent efc647a512
commit 234fe92931
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 3 deletions

View File

@ -28,10 +28,19 @@ static void CreateParameterOp(Program& p, const std::shared_ptr<ngraph::op::v0::
auto inputInfo = networkInputs.at(op->get_friendly_name());
// first create and add the input layout
const auto inputDesc = inputInfo->getTensorDesc();
auto input_pshape = op->get_partial_shape();
InferenceEngine::Layout l = inputDesc.getLayout();
InferenceEngine::Precision ip = inputDesc.getPrecision();
auto input_pshape = op->get_partial_shape();
if (!p.use_new_shape_infer()) {
if (input_pshape.size() < 4) {
input_pshape.insert(input_pshape.end(), 4 - input_pshape.size(), ov::Dimension(1));
}
if (p.m_max_batch > 1) {
input_pshape[0] = ov::Dimension(p.m_curBatch);
}
}
cldnn::format inputFormat = cldnn::format::get_default_format(input_pshape.size());
std::vector<size_t> default_order(input_pshape.size());
std::iota(default_order.begin(), default_order.end(), 0);

View File

@ -241,8 +241,8 @@ const std::vector<double> epsilon = {
const std::vector<ngraph::AxisSet> emptyReductionAxes = {{}};
std::vector<ElementType> inpPrc = {ElementType::i8, ElementType::bf16, ElementType::f32};
std::vector<ElementType> outPrc = {ElementType::bf16, ElementType::f32};
std::vector<ElementType> inpPrc = {ElementType::i8, ElementType::f16, ElementType::f32};
std::vector<ElementType> outPrc = {ElementType::f16, ElementType::f32};
const auto Mvn3D = ::testing::Combine(
::testing::Combine(