[PT FE]: extend batch norm to support training mode (#17040)
This commit is contained in:
parent
f736c71feb
commit
39ed9a624f
@ -6,9 +6,15 @@
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/op/broadcast.hpp"
|
||||
#include "openvino/op/concat.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/gather.hpp"
|
||||
#include "openvino/op/multiply.hpp"
|
||||
#include "openvino/op/range.hpp"
|
||||
#include "openvino/op/reduce_mean.hpp"
|
||||
#include "openvino/op/shape_of.hpp"
|
||||
#include "openvino/op/squeeze.hpp"
|
||||
#include "openvino/op/subtract.hpp"
|
||||
#include "openvino/op/unsqueeze.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
@ -39,6 +45,10 @@ OutputVector translate_batch_norm(const NodeContext& context) {
|
||||
auto input = context.get_input(0);
|
||||
Output<Node> weight;
|
||||
Output<Node> bias;
|
||||
Output<Node> running_mean;
|
||||
Output<Node> running_var;
|
||||
Output<Node> current_mean;
|
||||
Output<Node> current_var;
|
||||
if (!context.input_is_none(1)) {
|
||||
weight = context.get_input(1);
|
||||
} else {
|
||||
@ -53,10 +63,34 @@ OutputVector translate_batch_norm(const NodeContext& context) {
|
||||
}
|
||||
// index 3 running_mean and index 4 running_var can be none for training case only, check that not training before
|
||||
auto training = context.const_input<bool>(5);
|
||||
FRONT_END_OP_CONVERSION_CHECK(!training, "Translation for aten::batch_norm do not support training mode.");
|
||||
auto running_mean = context.get_input(3);
|
||||
auto running_var = context.get_input(4);
|
||||
// Input with index 6 is momentum, it is used only in training mode
|
||||
// if training for batch norm activated, but model in eval mode, it uses current statistics instead of running
|
||||
if (training) {
|
||||
auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
|
||||
auto zero_1d = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
|
||||
auto one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
|
||||
auto two = context.mark_node(v0::Constant::create(element::i32, Shape{}, {2}));
|
||||
auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input, element::i32));
|
||||
auto rank_unsq = context.mark_node(std::make_shared<v3::ShapeOf>(input_shape, element::i32));
|
||||
auto rank = context.mark_node(std::make_shared<v0::Squeeze>(rank_unsq, zero));
|
||||
auto after_channel_dims = context.mark_node(std::make_shared<v0::Range>(two, rank, one));
|
||||
auto axes = context.mark_node(std::make_shared<v0::Concat>(OutputVector{zero_1d, after_channel_dims}, 0));
|
||||
current_mean = context.mark_node(std::make_shared<v1::ReduceMean>(input, axes, false));
|
||||
auto mean = context.mark_node(std::make_shared<v1::ReduceMean>(input, axes, true));
|
||||
auto sub_v = context.mark_node(std::make_shared<v1::Subtract>(input, mean));
|
||||
auto sqr_sub = context.mark_node(std::make_shared<v1::Multiply>(sub_v, sub_v));
|
||||
current_var = context.mark_node(std::make_shared<v1::ReduceMean>(sqr_sub, axes, false));
|
||||
}
|
||||
if (!training) {
|
||||
running_mean = context.get_input(3);
|
||||
} else {
|
||||
running_mean = current_mean;
|
||||
}
|
||||
if (!training) {
|
||||
running_var = context.get_input(4);
|
||||
} else {
|
||||
running_var = current_var;
|
||||
}
|
||||
// Input with index 6 is momentum, it is used only for updating running_mean accumulation during training
|
||||
auto epsilon = context.const_input<float>(7);
|
||||
// Input with index 8 is flag "cudnn_enabled" we can ignore it
|
||||
return {context.mark_node(
|
||||
|
@ -12,7 +12,7 @@ class TestBatchNorm(PytorchLayerTest):
|
||||
shape = shape5d[:ndim]
|
||||
return (np.random.randn(*shape).astype(np.float32),)
|
||||
|
||||
def create_model(self, weights, bias, eps):
|
||||
def create_model(self, weights, bias, eps, train, running_stats):
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -29,20 +29,34 @@ class TestBatchNorm(PytorchLayerTest):
|
||||
def forward(self, x):
|
||||
return F.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias, eps=self.eps, training=False)
|
||||
|
||||
class aten_batch_norm_train(torch.nn.Module):
|
||||
def __init__(self, weights=True, bias=True, eps=1e-05, running_stats=False):
|
||||
super(aten_batch_norm_train, self).__init__()
|
||||
self.weight = torch.randn(6) if weights else None
|
||||
self.bias = torch.randn(6) if bias else None
|
||||
self.running_mean = torch.randn(6) if running_stats else None
|
||||
self.running_var = torch.randn(6) if running_stats else None
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x):
|
||||
return F.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias, eps=self.eps, training=True)
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_batch_norm_inference(weights, bias, eps), ref_net, "aten::batch_norm"
|
||||
return aten_batch_norm_inference(weights, bias, eps) if not train else aten_batch_norm_train(weights, bias, eps, running_stats), ref_net, "aten::batch_norm"
|
||||
|
||||
@pytest.mark.parametrize("weights", [True, False])
|
||||
@pytest.mark.parametrize("bias", [True, False])
|
||||
@pytest.mark.parametrize("eps", [1.0, 0.00005, 0.5, 0.042])
|
||||
@pytest.mark.parametrize("kwargs_to_prepare_input", [
|
||||
@pytest.mark.parametrize(("train", "running_stats"), [(True, False), (True, True), (False, False)])
|
||||
@pytest.mark.parametrize("kwargs_to_prepare_input",
|
||||
[
|
||||
{"ndim": 3},
|
||||
{'ndim': 4},
|
||||
{"ndim": 5}
|
||||
])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_batch_norm(self, weights, bias, eps, ie_device, precision, ir_version, kwargs_to_prepare_input):
|
||||
self._test(*self.create_model(weights, bias, eps),
|
||||
def test_batch_norm(self, weights, bias, eps, train, running_stats, ie_device, precision, ir_version, kwargs_to_prepare_input):
|
||||
self._test(*self.create_model(weights, bias, eps, train, running_stats),
|
||||
ie_device, precision, ir_version, kwargs_to_prepare_input=kwargs_to_prepare_input, dynamic_shapes=False, use_mo_convert=False)
|
Loading…
Reference in New Issue
Block a user