Use ShapeOf to create a new shape for gamma, mean, bias in BatchNormDecomposition (#5157)

This commit is contained in:
Mateusz Tabaka 2021-04-13 15:00:12 +02:00 committed by GitHub
parent 0bbe9c73e6
commit 984a55ec88
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -58,12 +58,13 @@ ngraph::pass::BatchNormDecomposition::BatchNormDecomposition() {
auto gamma_div_scale = std::make_shared<opset5::Divide>(m_gamma, scale);
int64_t dims_to_add = m_input.get_partial_shape().rank().get_length() - 2;
// TODO: instead of getting full shape we can concatenate sequence of ones with ShapeOf
Shape input_aligned_shape = m_gamma.get_shape();
for (int64_t i = 0; i < dims_to_add; ++i)
input_aligned_shape.push_back(1);
auto new_shape = opset5::Constant::create(element::i64, Shape{input_aligned_shape.size()}, input_aligned_shape);
const auto one = op::Constant::create(element::i64, Shape{1}, {1});
const auto tail_shape_rank = op::Constant::create(element::i64, Shape{1}, {dims_to_add});
const auto tail_shape = std::make_shared<opset5::Broadcast>(one, tail_shape_rank);
const auto C_dim = std::make_shared<opset5::ShapeOf>(m_gamma);
// create new shape [1, C, 1, 1, ...]
const auto new_shape = std::make_shared<opset5::Concat>(
OutputVector{one, C_dim, tail_shape}, 0);
auto gamma_div_scale_aligned = make_shared<opset5::Reshape>(gamma_div_scale, new_shape, true);
auto beta_aligned = make_shared<opset5::Reshape>(m_beta, new_shape, true);