[PT FE]: extend batch norm to support training mode (#17040)

Ekaterina Aidova 2023-04-25 11:27:00 +02:00 committed by GitHub
parent f736c71feb
commit 39ed9a624f
2 changed files with 57 additions and 9 deletions


@@ -6,9 +6,15 @@
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/reduce_mean.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/squeeze.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "utils.hpp"
@@ -39,6 +45,10 @@ OutputVector translate_batch_norm(const NodeContext& context) {
    auto input = context.get_input(0);
    Output<Node> weight;
    Output<Node> bias;
    Output<Node> running_mean;
    Output<Node> running_var;
    Output<Node> current_mean;
    Output<Node> current_var;
    if (!context.input_is_none(1)) {
        weight = context.get_input(1);
    } else {
@@ -53,10 +63,34 @@ OutputVector translate_batch_norm(const NodeContext& context) {
    }
    // index 3 running_mean and index 4 running_var can be none in the training case only; check that training is disabled before reading them
    auto training = context.const_input<bool>(5);
    FRONT_END_OP_CONVERSION_CHECK(!training, "Translation for aten::batch_norm do not support training mode.");
    auto running_mean = context.get_input(3);
    auto running_var = context.get_input(4);
    // Input with index 6 is momentum, it is used only in training mode
    // if training is enabled for batch_norm, the current batch statistics are used instead of the running statistics, even when the model is in eval mode
    if (training) {
        auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
        auto zero_1d = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
        auto one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
        auto two = context.mark_node(v0::Constant::create(element::i32, Shape{}, {2}));
        auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input, element::i32));
        auto rank_unsq = context.mark_node(std::make_shared<v3::ShapeOf>(input_shape, element::i32));
        auto rank = context.mark_node(std::make_shared<v0::Squeeze>(rank_unsq, zero));
        auto after_channel_dims = context.mark_node(std::make_shared<v0::Range>(two, rank, one));
        auto axes = context.mark_node(std::make_shared<v0::Concat>(OutputVector{zero_1d, after_channel_dims}, 0));
        current_mean = context.mark_node(std::make_shared<v1::ReduceMean>(input, axes, false));
        auto mean = context.mark_node(std::make_shared<v1::ReduceMean>(input, axes, true));
        auto sub_v = context.mark_node(std::make_shared<v1::Subtract>(input, mean));
        auto sqr_sub = context.mark_node(std::make_shared<v1::Multiply>(sub_v, sub_v));
        current_var = context.mark_node(std::make_shared<v1::ReduceMean>(sqr_sub, axes, false));
    }
    if (!training) {
        running_mean = context.get_input(3);
    } else {
        running_mean = current_mean;
    }
    if (!training) {
        running_var = context.get_input(4);
    } else {
        running_var = current_var;
    }
    // Input with index 6 is momentum; it is used only to update the running statistics during training
    auto epsilon = context.const_input<float>(7);
    // Input with index 8 is the flag "cudnn_enabled", we can ignore it
    return {context.mark_node(
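
For context on what the new training branch computes, here is a minimal NumPy sketch of the same reduction, assuming an N-D input with the channel on axis 1; the helper name and shapes are illustrative, not part of the commit. The mean and the biased variance are reduced over every axis except the channel axis, which is what the Concat of axis 0 with Range(2, rank) and the ReduceMean/Subtract/Multiply nodes above express.

import numpy as np

def batch_norm_training_stats(x):
    # illustrative helper, not part of the commit:
    # reduce over all axes except the channel axis (dim 1),
    # mirroring Concat(zero_1d, Range(two, rank, one)) built in the translator
    axes = (0,) + tuple(range(2, x.ndim))
    current_mean = x.mean(axis=axes)                      # ReduceMean, keep_dims=false
    centered = x - x.mean(axis=axes, keepdims=True)       # Subtract(input, ReduceMean keep_dims=true)
    current_var = (centered * centered).mean(axis=axes)   # Multiply + ReduceMean -> biased variance
    return current_mean, current_var

x = np.random.randn(2, 6, 4, 4).astype(np.float32)  # NCHW input with 6 channels, as in the tests
mean, var = batch_norm_training_stats(x)
assert mean.shape == (6,) and var.shape == (6,)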


@@ -12,7 +12,7 @@ class TestBatchNorm(PytorchLayerTest):
        shape = shape5d[:ndim]
        return (np.random.randn(*shape).astype(np.float32),)

    def create_model(self, weights, bias, eps):
    def create_model(self, weights, bias, eps, train, running_stats):
        import torch
        import torch.nn.functional as F
@@ -29,20 +29,34 @@ class TestBatchNorm(PytorchLayerTest):
            def forward(self, x):
                return F.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias, eps=self.eps, training=False)

        class aten_batch_norm_train(torch.nn.Module):
            def __init__(self, weights=True, bias=True, eps=1e-05, running_stats=False):
                super(aten_batch_norm_train, self).__init__()
                self.weight = torch.randn(6) if weights else None
                self.bias = torch.randn(6) if bias else None
                self.running_mean = torch.randn(6) if running_stats else None
                self.running_var = torch.randn(6) if running_stats else None
                self.eps = eps

            def forward(self, x):
                return F.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias, eps=self.eps, training=True)

        ref_net = None

        return aten_batch_norm_inference(weights, bias, eps), ref_net, "aten::batch_norm"
        return aten_batch_norm_inference(weights, bias, eps) if not train else aten_batch_norm_train(weights, bias, eps, running_stats), ref_net, "aten::batch_norm"
@pytest.mark.parametrize("weights", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("eps", [1.0, 0.00005, 0.5, 0.042])
@pytest.mark.parametrize("kwargs_to_prepare_input", [
@pytest.mark.parametrize(("train", "running_stats"), [(True, False), (True, True), (False, False)])
@pytest.mark.parametrize("kwargs_to_prepare_input",
[
{"ndim": 3},
{'ndim': 4},
{"ndim": 5}
])
@pytest.mark.nightly
@pytest.mark.precommit
def test_batch_norm(self, weights, bias, eps, ie_device, precision, ir_version, kwargs_to_prepare_input):
self._test(*self.create_model(weights, bias, eps),
def test_batch_norm(self, weights, bias, eps, train, running_stats, ie_device, precision, ir_version, kwargs_to_prepare_input):
self._test(*self.create_model(weights, bias, eps, train, running_stats),
ie_device, precision, ir_version, kwargs_to_prepare_input=kwargs_to_prepare_input, dynamic_shapes=False, use_mo_convert=False)
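
As a usage note for the new train-mode cases, a small PyTorch sketch of what F.batch_norm does when training=True and no running statistics are passed; the tensors, shapes, and tolerance below are illustrative and not part of the test suite. Normalization then uses the statistics of the current batch, which is the behavior the extended translation reproduces.

import torch
import torch.nn.functional as F

x = torch.randn(2, 6, 4, 4)  # illustrative NCHW input with 6 channels, matching the test models
weight, bias = torch.randn(6), torch.randn(6)

# training=True with running_mean/running_var set to None: only batch statistics are used
out = F.batch_norm(x, None, None, weight, bias, training=True, eps=1e-5)

# the same result computed explicitly from the batch mean and biased variance
mean = x.mean(dim=(0, 2, 3), keepdim=True)
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
expected = (x - mean) / torch.sqrt(var + 1e-5) * weight.view(1, -1, 1, 1) + bias.view(1, -1, 1, 1)
assert torch.allclose(out, expected, atol=1e-5)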