From 39ed9a624f08917d222a372fddb16cf92eb18b54 Mon Sep 17 00:00:00 2001
From: Ekaterina Aidova
Date: Tue, 25 Apr 2023 11:27:00 +0200
Subject: [PATCH] [PT FE]: extend batch norm to support training mode (#17040)

---
 src/frontends/pytorch/src/op/batch_norm.cpp | 42 +++++++++++++++++--
 .../pytorch_tests/test_batch_norm.py        | 24 ++++++++---
 2 files changed, 57 insertions(+), 9 deletions(-)

diff --git a/src/frontends/pytorch/src/op/batch_norm.cpp b/src/frontends/pytorch/src/op/batch_norm.cpp
index a306dd21832..f99852341e6 100644
--- a/src/frontends/pytorch/src/op/batch_norm.cpp
+++ b/src/frontends/pytorch/src/op/batch_norm.cpp
@@ -6,9 +6,15 @@
 
 #include "openvino/frontend/pytorch/node_context.hpp"
 #include "openvino/op/broadcast.hpp"
+#include "openvino/op/concat.hpp"
 #include "openvino/op/constant.hpp"
 #include "openvino/op/gather.hpp"
+#include "openvino/op/multiply.hpp"
+#include "openvino/op/range.hpp"
+#include "openvino/op/reduce_mean.hpp"
 #include "openvino/op/shape_of.hpp"
+#include "openvino/op/squeeze.hpp"
+#include "openvino/op/subtract.hpp"
 #include "openvino/op/unsqueeze.hpp"
 #include "utils.hpp"
 
@@ -39,6 +45,10 @@ OutputVector translate_batch_norm(const NodeContext& context) {
     auto input = context.get_input(0);
     Output<Node> weight;
     Output<Node> bias;
+    Output<Node> running_mean;
+    Output<Node> running_var;
+    Output<Node> current_mean;
+    Output<Node> current_var;
     if (!context.input_is_none(1)) {
         weight = context.get_input(1);
     } else {
@@ -53,10 +63,34 @@
     }
     // index 3 running_mean and index 4 running_var can be none for training case only, check that not training before
     auto training = context.const_input<bool>(5);
-    FRONT_END_OP_CONVERSION_CHECK(!training, "Translation for aten::batch_norm do not support training mode.");
-    auto running_mean = context.get_input(3);
-    auto running_var = context.get_input(4);
-    // Input with index 6 is momentum, it is used only in training mode
+    // if training for batch norm activated, but model in eval mode, it uses current statistics instead of running
+    if (training) {
+        auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
+        auto zero_1d = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
+        auto one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
+        auto two = context.mark_node(v0::Constant::create(element::i32, Shape{}, {2}));
+        auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input, element::i32));
+        auto rank_unsq = context.mark_node(std::make_shared<v3::ShapeOf>(input_shape, element::i32));
+        auto rank = context.mark_node(std::make_shared<v0::Squeeze>(rank_unsq, zero));
+        auto after_channel_dims = context.mark_node(std::make_shared<v0::Range>(two, rank, one));
+        auto axes = context.mark_node(std::make_shared<v0::Concat>(OutputVector{zero_1d, after_channel_dims}, 0));
+        current_mean = context.mark_node(std::make_shared<v1::ReduceMean>(input, axes, false));
+        auto mean = context.mark_node(std::make_shared<v1::ReduceMean>(input, axes, true));
+        auto sub_v = context.mark_node(std::make_shared<v1::Subtract>(input, mean));
+        auto sqr_sub = context.mark_node(std::make_shared<v1::Multiply>(sub_v, sub_v));
+        current_var = context.mark_node(std::make_shared<v1::ReduceMean>(sqr_sub, axes, false));
+    }
+    if (!training) {
+        running_mean = context.get_input(3);
+    } else {
+        running_mean = current_mean;
+    }
+    if (!training) {
+        running_var = context.get_input(4);
+    } else {
+        running_var = current_var;
+    }
+    // Input with index 6 is momentum, it is used only for updating running_mean accumulation during training
     auto epsilon = context.const_input<float>(7);
     // Input with index 8 is flag "cudnn_enabled" we can ignore it
     return {context.mark_node(
diff --git a/tests/layer_tests/pytorch_tests/test_batch_norm.py b/tests/layer_tests/pytorch_tests/test_batch_norm.py
index 07d9733ceb8..339ca2219bb 100644
--- a/tests/layer_tests/pytorch_tests/test_batch_norm.py
+++ b/tests/layer_tests/pytorch_tests/test_batch_norm.py
@@ -12,7 +12,7 @@ class TestBatchNorm(PytorchLayerTest):
         shape = shape5d[:ndim]
         return (np.random.randn(*shape).astype(np.float32),)
 
-    def create_model(self, weights, bias, eps):
+    def create_model(self, weights, bias, eps, train, running_stats):
         import torch
         import torch.nn.functional as F
 
@@ -29,20 +29,34 @@ class TestBatchNorm(PytorchLayerTest):
             def forward(self, x):
                 return F.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias, eps=self.eps, training=False)
 
+        class aten_batch_norm_train(torch.nn.Module):
+            def __init__(self, weights=True, bias=True, eps=1e-05, running_stats=False):
+                super(aten_batch_norm_train, self).__init__()
+                self.weight = torch.randn(6) if weights else None
+                self.bias = torch.randn(6) if bias else None
+                self.running_mean = torch.randn(6) if running_stats else None
+                self.running_var = torch.randn(6) if running_stats else None
+                self.eps = eps
+
+            def forward(self, x):
+                return F.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias, eps=self.eps, training=True)
+
         ref_net = None
-        return aten_batch_norm_inference(weights, bias, eps), ref_net, "aten::batch_norm"
+        return aten_batch_norm_inference(weights, bias, eps) if not train else aten_batch_norm_train(weights, bias, eps, running_stats), ref_net, "aten::batch_norm"
 
     @pytest.mark.parametrize("weights", [True, False])
     @pytest.mark.parametrize("bias", [True, False])
     @pytest.mark.parametrize("eps", [1.0, 0.00005, 0.5, 0.042])
-    @pytest.mark.parametrize("kwargs_to_prepare_input", [
+    @pytest.mark.parametrize(("train", "running_stats"), [(True, False), (True, True), (False, False)])
+    @pytest.mark.parametrize("kwargs_to_prepare_input",
+                             [
         {"ndim": 3},
         {'ndim': 4},
         {"ndim": 5}
     ])
     @pytest.mark.nightly
     @pytest.mark.precommit
-    def test_batch_norm(self, weights, bias, eps, ie_device, precision, ir_version, kwargs_to_prepare_input):
-        self._test(*self.create_model(weights, bias, eps),
+    def test_batch_norm(self, weights, bias, eps, train, running_stats, ie_device, precision, ir_version, kwargs_to_prepare_input):
+        self._test(*self.create_model(weights, bias, eps, train, running_stats),
                    ie_device, precision, ir_version, kwargs_to_prepare_input=kwargs_to_prepare_input,
                    dynamic_shapes=False, use_mo_convert=False)
\ No newline at end of file