[PT FE]: support aten::instance_norm (#15213)
This commit is contained in:
parent
b4cb4fe8c9
commit
b2ce43a172
@ -50,7 +50,7 @@ OutputVector translate_conv_transposend(NodeContext& context) {
|
||||
auto bias = context.get_input(2);
|
||||
auto bias_rank = bias.get_partial_shape().rank();
|
||||
if (bias_rank == 1) {
|
||||
bias = reshape_conv_bias(context, bias, conv);
|
||||
bias = reshape_channelwise(context, bias, conv);
|
||||
}
|
||||
conv = context.mark_node(std::make_shared<ov::op::v1::Add>(conv, bias));
|
||||
}
|
||||
|
@ -49,7 +49,7 @@ OutputVector translate_convnd(NodeContext& context) {
|
||||
auto bias = context.get_input(2);
|
||||
auto bias_rank = bias.get_partial_shape().rank();
|
||||
if (bias_rank == 1) {
|
||||
bias = reshape_conv_bias(context, bias, conv);
|
||||
bias = reshape_channelwise(context, bias, conv);
|
||||
}
|
||||
conv = context.mark_node(std::make_shared<opset10::Add>(conv, bias));
|
||||
}
|
||||
|
@ -69,7 +69,7 @@ OutputVector translate_convolution(NodeContext& context) {
|
||||
auto bias = context.get_input(2);
|
||||
auto bias_rank = bias.get_partial_shape().rank();
|
||||
if (bias_rank == 1) {
|
||||
bias = reshape_conv_bias(context, bias, conv);
|
||||
bias = reshape_channelwise(context, bias, conv);
|
||||
}
|
||||
|
||||
conv = context.mark_node(std::make_shared<opset10::Add>(conv, bias));
|
||||
|
@ -46,7 +46,7 @@ OutputVector translate_convolution_mode(NodeContext& context) {
|
||||
auto bias = context.get_input(2);
|
||||
auto bias_rank = bias.get_partial_shape().rank();
|
||||
if (bias_rank == 1) {
|
||||
bias = reshape_conv_bias(context, bias, conv);
|
||||
bias = reshape_channelwise(context, bias, conv);
|
||||
}
|
||||
|
||||
conv = context.mark_node(std::make_shared<opset10::Add>(conv, bias));
|
||||
|
109
src/frontends/pytorch/src/op/instance_norm.cpp
Normal file
109
src/frontends/pytorch/src/op/instance_norm.cpp
Normal file
@ -0,0 +1,109 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/op/add.hpp"
|
||||
#include "openvino/op/batch_norm.hpp"
|
||||
#include "openvino/op/broadcast.hpp"
|
||||
#include "openvino/op/concat.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/convert_like.hpp"
|
||||
#include "openvino/op/gather.hpp"
|
||||
#include "openvino/op/multiply.hpp"
|
||||
#include "openvino/op/mvn.hpp"
|
||||
#include "openvino/op/range.hpp"
|
||||
#include "openvino/op/reshape.hpp"
|
||||
#include "openvino/op/shape_of.hpp"
|
||||
#include "openvino/op/squeeze.hpp"
|
||||
#include "openvino/op/tile.hpp"
|
||||
#include "openvino/op/unsqueeze.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
namespace {
|
||||
OutputVector translate_instance_norm_inference(const NodeContext& context,
|
||||
const Output<Node>& input,
|
||||
const Output<Node>& reduction_axes,
|
||||
float eps) {
|
||||
auto norm = context.mark_node(
|
||||
std::make_shared<ov::op::v6::MVN>(input, reduction_axes, true, eps, ov::op::MVNEpsMode::INSIDE_SQRT));
|
||||
if (!context.input_is_none(1)) {
|
||||
auto weight = context.get_input(1);
|
||||
weight = reshape_channelwise(context, weight, norm);
|
||||
norm = context.mark_node(std::make_shared<ov::op::v1::Multiply>(norm, weight));
|
||||
}
|
||||
if (!context.input_is_none(2)) {
|
||||
auto bias = context.get_input(2);
|
||||
bias = reshape_channelwise(context, bias, norm);
|
||||
norm = context.mark_node(std::make_shared<ov::op::v1::Add>(norm, bias));
|
||||
}
|
||||
return {norm};
|
||||
}
|
||||
|
||||
OutputVector translate_instance_norm_train(NodeContext& context,
|
||||
const Output<Node>& input,
|
||||
const Output<Node>& reduction_axes,
|
||||
float eps) {
|
||||
auto zero = context.mark_node(ov::op::v0::Constant::create(element::i64, Shape{}, {0}));
|
||||
auto one = context.mark_node(ov::op::v0::Constant::create(element::i64, Shape{}, {1}));
|
||||
auto input_shape = context.mark_node(std::make_shared<ov::op::v3::ShapeOf>(input));
|
||||
auto batch_dim = context.mark_node(std::make_shared<ov::op::v8::Gather>(input_shape, zero, zero));
|
||||
auto channel_dim = context.mark_node(std::make_shared<ov::op::v8::Gather>(input_shape, one, zero));
|
||||
auto batch_dim_1d = context.mark_node(std::make_shared<ov::op::v0::Unsqueeze>(batch_dim, zero));
|
||||
auto batch_norm_channels_1d = context.mark_node(std::make_shared<ov::op::v1::Multiply>(batch_dim_1d, channel_dim));
|
||||
auto one_1d = context.mark_node(ov::op::v0::Constant::create(element::i64, Shape{1}, {1}));
|
||||
auto tail_shape = context.mark_node(std::make_shared<ov::op::v8::Gather>(input_shape, reduction_axes, zero));
|
||||
auto reshape_shape = context.mark_node(
|
||||
std::make_shared<ov::op::v0::Concat>(OutputVector{one_1d, batch_norm_channels_1d, tail_shape}, 0));
|
||||
auto reshaped_input = context.mark_node(std::make_shared<ov::op::v1::Reshape>(input, reshape_shape, false));
|
||||
Output<Node> weight;
|
||||
Output<Node> bias;
|
||||
if (context.input_is_none(1)) {
|
||||
weight = context.mark_node(std::make_shared<ov::op::v3::Broadcast>(one, batch_norm_channels_1d));
|
||||
weight = context.mark_node(std::make_shared<ov::op::v1::ConvertLike>(weight, input));
|
||||
} else {
|
||||
weight = context.get_input(1);
|
||||
weight = context.mark_node(std::make_shared<ov::op::v0::Tile>(weight, batch_dim_1d));
|
||||
}
|
||||
if (context.input_is_none(2)) {
|
||||
bias = context.mark_node(std::make_shared<ov::op::v3::Broadcast>(zero, batch_norm_channels_1d));
|
||||
bias = context.mark_node(std::make_shared<ov::op::v1::ConvertLike>(bias, input));
|
||||
} else {
|
||||
bias = context.get_input(2);
|
||||
bias = context.mark_node(std::make_shared<ov::op::v0::Tile>(bias, batch_dim_1d));
|
||||
}
|
||||
auto running_mean = context.get_input(3);
|
||||
running_mean = context.mark_node(std::make_shared<ov::op::v0::Tile>(running_mean, batch_dim_1d));
|
||||
auto running_var = context.get_input(4);
|
||||
running_var = context.mark_node(std::make_shared<ov::op::v0::Tile>(running_var, batch_dim_1d));
|
||||
auto batch_norm = context.mark_node(
|
||||
std::make_shared<ov::op::v5::BatchNormInference>(reshaped_input, weight, bias, running_mean, running_var, eps));
|
||||
return {context.mark_node(std::make_shared<ov::op::v1::Reshape>(batch_norm, input_shape, true))};
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
OutputVector translate_instance_norm(NodeContext& context) {
|
||||
num_inputs_check(context, 8, 9);
|
||||
auto input = context.get_input(0);
|
||||
auto eps = context.const_input<float>(7);
|
||||
auto input_shape = context.mark_node(std::make_shared<ov::op::v3::ShapeOf>(input));
|
||||
auto rank_1d = context.mark_node(std::make_shared<ov::op::v3::ShapeOf>(input_shape));
|
||||
auto rank = context.mark_node(std::make_shared<ov::op::v0::Squeeze>(rank_1d));
|
||||
auto one = context.mark_node(ov::op::v0::Constant::create(element::i64, Shape{}, {1}));
|
||||
auto two = context.mark_node(ov::op::v0::Constant::create(element::i64, Shape{}, {2}));
|
||||
auto reduction_axes = context.mark_node(std::make_shared<ov::op::v4::Range>(two, rank, one, element::i64));
|
||||
if (context.input_is_none(3) && context.input_is_none(4)) {
|
||||
return translate_instance_norm_inference(context, input, reduction_axes, eps);
|
||||
}
|
||||
return translate_instance_norm_train(context, input, reduction_axes, eps);
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
@ -49,6 +49,7 @@ OP_CONVERTER(translate_group_norm);
|
||||
OP_CONVERTER(translate_hardtanh);
|
||||
OP_CONVERTER(translate_if);
|
||||
OP_CONVERTER(translate_im2col);
|
||||
OP_CONVERTER(translate_instance_norm);
|
||||
OP_CONVERTER(translate_int);
|
||||
OP_CONVERTER(translate_layer_norm);
|
||||
OP_CONVERTER(translate_len);
|
||||
@ -196,6 +197,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"aten::Int", op::translate_int},
|
||||
{"aten::IntImplicit", op::translate_int},
|
||||
{"aten::im2col", op::translate_im2col},
|
||||
{"aten::instance_norm", op::translate_instance_norm},
|
||||
{"aten::is_grad_enabled", op::return_false_scalar},
|
||||
{"aten::layer_norm", op::translate_layer_norm},
|
||||
{"aten::leaky_relu", op::translate_1to1_match_2_inputs<opset10::PRelu>},
|
||||
|
@ -14,6 +14,14 @@ namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
|
||||
void num_inputs_check(const NodeContext& context, size_t min_inputs, size_t max_inputs) {
|
||||
auto inputs = context.inputs();
|
||||
FRONT_END_OP_CONVERSION_CHECK(inputs.size() > min_inputs, "Got less inputs than expected");
|
||||
for (auto i = max_inputs; i < inputs.size(); i++) {
|
||||
FRONT_END_OP_CONVERSION_CHECK(context.input_is_none(i), "Got more inputs than expected.");
|
||||
}
|
||||
}
|
||||
|
||||
Output<Node> make_optional_bias(const Output<Node>& base_op,
|
||||
const NodeContext& context,
|
||||
size_t bias_input_idx,
|
||||
@ -34,18 +42,18 @@ Output<Node> make_optional_bias(const Output<Node>& base_op,
|
||||
}
|
||||
}
|
||||
|
||||
Output<ov::Node> reshape_conv_bias(const NodeContext& context, Output<ov::Node> bias, Output<ov::Node> conv) {
|
||||
auto conv_shape = context.mark_node(std::make_shared<opset10::ShapeOf>(conv));
|
||||
auto conv_rank = context.mark_node(std::make_shared<opset10::ShapeOf>(conv_shape));
|
||||
Output<ov::Node> reshape_channelwise(const NodeContext& context, Output<ov::Node> data, Output<ov::Node> shape_source) {
|
||||
auto input_shape = context.mark_node(std::make_shared<opset10::ShapeOf>(shape_source));
|
||||
auto input_rank = context.mark_node(std::make_shared<opset10::ShapeOf>(input_shape));
|
||||
auto one_const = context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {1}));
|
||||
auto two_const = context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {2}));
|
||||
auto tail_shape_rank = context.mark_node(std::make_shared<opset10::Subtract>(conv_rank, two_const));
|
||||
auto tail_shape_rank = context.mark_node(std::make_shared<opset10::Subtract>(input_rank, two_const));
|
||||
auto tail_shape = context.mark_node(std::make_shared<opset10::Broadcast>(one_const, tail_shape_rank));
|
||||
auto channels_dim = context.mark_node(std::make_shared<opset10::ShapeOf>(bias));
|
||||
auto channels_dim = context.mark_node(std::make_shared<opset10::ShapeOf>(data));
|
||||
auto new_shape =
|
||||
context.mark_node(std::make_shared<opset10::Concat>(OutputVector{one_const, channels_dim, tail_shape}, 0));
|
||||
|
||||
return context.mark_node(std::make_shared<opset10::Reshape>(bias, new_shape, false));
|
||||
return context.mark_node(std::make_shared<opset10::Reshape>(data, new_shape, false));
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> get_rank_node(const Output<Node>& node) {
|
||||
|
@ -18,12 +18,16 @@ class FrameworkNode;
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
|
||||
void num_inputs_check(const NodeContext& context, size_t min_inputs, size_t max_inputs);
|
||||
|
||||
Output<Node> make_optional_bias(const Output<Node>& base_op,
|
||||
const NodeContext& context,
|
||||
size_t bias_input_idx,
|
||||
const std::vector<int>& unsqueeze_dims = {});
|
||||
|
||||
Output<ov::Node> reshape_conv_bias(const NodeContext& context, Output<ov::Node> bias, Output<ngraph::Node> conv);
|
||||
Output<ov::Node> reshape_channelwise(const NodeContext& context,
|
||||
Output<ov::Node> data,
|
||||
Output<ngraph::Node> shape_source);
|
||||
|
||||
std::shared_ptr<ov::Node> get_rank_node(const Output<Node>& node);
|
||||
|
||||
|
64
tests/layer_tests/pytorch_tests/test_instance_norm.py
Normal file
64
tests/layer_tests/pytorch_tests/test_instance_norm.py
Normal file
@ -0,0 +1,64 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
|
||||
class TestInstanceNorm(PytorchLayerTest):
|
||||
def _prepare_input(self, ndim=4):
|
||||
import numpy as np
|
||||
shape5d = [3, 6, 10, 5, 2]
|
||||
shape = shape5d[:ndim]
|
||||
return (np.random.randn(*shape).astype(np.float32),)
|
||||
|
||||
def create_model(self, weights=False, bias=False, mean_var=False, eps=1e-05):
|
||||
import torch
|
||||
|
||||
class aten_instance_norm(torch.nn.Module):
|
||||
def __init__(self, weights=False, bias=False, mean_var=False, eps=1e-05):
|
||||
super(aten_instance_norm, self).__init__()
|
||||
weights_shape = (6, )
|
||||
self.weight = torch.randn(weights_shape) if weights else None
|
||||
self.bias = None
|
||||
self.use_input_stats = not mean_var
|
||||
if bias:
|
||||
self.bias = torch.randn(weights_shape)
|
||||
self.mean = None
|
||||
self.var = None
|
||||
if mean_var:
|
||||
self.mean = torch.randn(weights_shape)
|
||||
self.var = torch.randn(weights_shape)
|
||||
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x):
|
||||
return torch.instance_norm(x, self.weight, self.bias, self.mean, self.var, self.use_input_stats, 0.1, self.eps, False)
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_instance_norm(weights, bias, mean_var, eps), ref_net, "aten::instance_norm"
|
||||
|
||||
@pytest.mark.parametrize("params",
|
||||
[
|
||||
{"eps": 0.0001},
|
||||
{'weights': True, 'eps': -0.05},
|
||||
{'weights': True},
|
||||
{'weights': True, 'bias': True},
|
||||
{"weights": True, 'bias': False, "mean_var": True},
|
||||
{"weights": True, 'bias': True, "mean_var": True},
|
||||
{"weights": False, 'bias': True, "mean_var": True},
|
||||
{"weights": False, 'bias': False, "mean_var": True},
|
||||
{"weights": False, 'bias': False, "mean_var": True, "eps": 1.5}
|
||||
])
|
||||
@pytest.mark.parametrize("kwargs_to_prepare_input", [
|
||||
{"ndim": 3},
|
||||
{'ndim': 4},
|
||||
{"ndim": 5}
|
||||
])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_group_norm(self, params, ie_device, precision, ir_version, kwargs_to_prepare_input):
|
||||
self._test(*self.create_model(**params),
|
||||
ie_device, precision, ir_version, kwargs_to_prepare_input=kwargs_to_prepare_input, dynamic_shapes=not params.get("mean_var", False))
|
Loading…
Reference in New Issue
Block a user