diff --git a/inference-engine/src/transformations/src/transformations/op_conversions/batch_norm_decomposition.cpp b/inference-engine/src/transformations/src/transformations/op_conversions/batch_norm_decomposition.cpp index 233397fe2d5..cff79f56120 100644 --- a/inference-engine/src/transformations/src/transformations/op_conversions/batch_norm_decomposition.cpp +++ b/inference-engine/src/transformations/src/transformations/op_conversions/batch_norm_decomposition.cpp @@ -58,12 +58,13 @@ ngraph::pass::BatchNormDecomposition::BatchNormDecomposition() { auto gamma_div_scale = std::make_shared(m_gamma, scale); int64_t dims_to_add = m_input.get_partial_shape().rank().get_length() - 2; - - // TODO: instead of getting full shape we can concatenate sequence of ones with ShapeOf - Shape input_aligned_shape = m_gamma.get_shape(); - for (int64_t i = 0; i < dims_to_add; ++i) - input_aligned_shape.push_back(1); - auto new_shape = opset5::Constant::create(element::i64, Shape{input_aligned_shape.size()}, input_aligned_shape); + const auto one = op::Constant::create(element::i64, Shape{1}, {1}); + const auto tail_shape_rank = op::Constant::create(element::i64, Shape{1}, {dims_to_add}); + const auto tail_shape = std::make_shared(one, tail_shape_rank); + const auto C_dim = std::make_shared(m_gamma); + // create new shape [1, C, 1, 1, ...] + const auto new_shape = std::make_shared( + OutputVector{one, C_dim, tail_shape}, 0); auto gamma_div_scale_aligned = make_shared(gamma_div_scale, new_shape, true); auto beta_aligned = make_shared(m_beta, new_shape, true);