Use ShapeOf to create a new shape for gamma, mean, bias in BatchNormDecomposition (#5157)
This commit is contained in:
parent
0bbe9c73e6
commit
984a55ec88
@ -58,12 +58,13 @@ ngraph::pass::BatchNormDecomposition::BatchNormDecomposition() {
|
||||
auto gamma_div_scale = std::make_shared<opset5::Divide>(m_gamma, scale);
|
||||
|
||||
int64_t dims_to_add = m_input.get_partial_shape().rank().get_length() - 2;
|
||||
|
||||
// TODO: instead of getting full shape we can concatenate sequence of ones with ShapeOf
|
||||
Shape input_aligned_shape = m_gamma.get_shape();
|
||||
for (int64_t i = 0; i < dims_to_add; ++i)
|
||||
input_aligned_shape.push_back(1);
|
||||
auto new_shape = opset5::Constant::create(element::i64, Shape{input_aligned_shape.size()}, input_aligned_shape);
|
||||
const auto one = op::Constant::create(element::i64, Shape{1}, {1});
|
||||
const auto tail_shape_rank = op::Constant::create(element::i64, Shape{1}, {dims_to_add});
|
||||
const auto tail_shape = std::make_shared<opset5::Broadcast>(one, tail_shape_rank);
|
||||
const auto C_dim = std::make_shared<opset5::ShapeOf>(m_gamma);
|
||||
// create new shape [1, C, 1, 1, ...]
|
||||
const auto new_shape = std::make_shared<opset5::Concat>(
|
||||
OutputVector{one, C_dim, tail_shape}, 0);
|
||||
|
||||
auto gamma_div_scale_aligned = make_shared<opset5::Reshape>(gamma_div_scale, new_shape, true);
|
||||
auto beta_aligned = make_shared<opset5::Reshape>(m_beta, new_shape, true);
|
||||
|
Loading…
Reference in New Issue
Block a user