[PT FE]: support aten::instance_norm (#15213)
parent b4cb4fe8c9
commit b2ce43a172
@@ -50,7 +50,7 @@ OutputVector translate_conv_transposend(NodeContext& context) {
         auto bias = context.get_input(2);
         auto bias_rank = bias.get_partial_shape().rank();
         if (bias_rank == 1) {
-            bias = reshape_conv_bias(context, bias, conv);
+            bias = reshape_channelwise(context, bias, conv);
         }
         conv = context.mark_node(std::make_shared<ov::op::v1::Add>(conv, bias));
     }
@@ -49,7 +49,7 @@ OutputVector translate_convnd(NodeContext& context) {
         auto bias = context.get_input(2);
         auto bias_rank = bias.get_partial_shape().rank();
         if (bias_rank == 1) {
-            bias = reshape_conv_bias(context, bias, conv);
+            bias = reshape_channelwise(context, bias, conv);
         }
         conv = context.mark_node(std::make_shared<opset10::Add>(conv, bias));
     }
@@ -69,7 +69,7 @@ OutputVector translate_convolution(NodeContext& context) {
         auto bias = context.get_input(2);
         auto bias_rank = bias.get_partial_shape().rank();
         if (bias_rank == 1) {
-            bias = reshape_conv_bias(context, bias, conv);
+            bias = reshape_channelwise(context, bias, conv);
         }
 
         conv = context.mark_node(std::make_shared<opset10::Add>(conv, bias));
@@ -46,7 +46,7 @@ OutputVector translate_convolution_mode(NodeContext& context) {
         auto bias = context.get_input(2);
         auto bias_rank = bias.get_partial_shape().rank();
         if (bias_rank == 1) {
-            bias = reshape_conv_bias(context, bias, conv);
+            bias = reshape_channelwise(context, bias, conv);
         }
 
         conv = context.mark_node(std::make_shared<opset10::Add>(conv, bias));
src/frontends/pytorch/src/op/instance_norm.cpp (new file, 109 lines)
@@ -0,0 +1,109 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+#include "openvino/frontend/pytorch/node_context.hpp"
+#include "openvino/op/add.hpp"
+#include "openvino/op/batch_norm.hpp"
+#include "openvino/op/broadcast.hpp"
+#include "openvino/op/concat.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/convert_like.hpp"
+#include "openvino/op/gather.hpp"
+#include "openvino/op/multiply.hpp"
+#include "openvino/op/mvn.hpp"
+#include "openvino/op/range.hpp"
+#include "openvino/op/reshape.hpp"
+#include "openvino/op/shape_of.hpp"
+#include "openvino/op/squeeze.hpp"
+#include "openvino/op/tile.hpp"
+#include "openvino/op/unsqueeze.hpp"
+#include "utils.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace op {
+
+namespace {
+OutputVector translate_instance_norm_inference(const NodeContext& context,
+                                               const Output<Node>& input,
+                                               const Output<Node>& reduction_axes,
+                                               float eps) {
+    auto norm = context.mark_node(
+        std::make_shared<ov::op::v6::MVN>(input, reduction_axes, true, eps, ov::op::MVNEpsMode::INSIDE_SQRT));
+    if (!context.input_is_none(1)) {
+        auto weight = context.get_input(1);
+        weight = reshape_channelwise(context, weight, norm);
+        norm = context.mark_node(std::make_shared<ov::op::v1::Multiply>(norm, weight));
+    }
+    if (!context.input_is_none(2)) {
+        auto bias = context.get_input(2);
+        bias = reshape_channelwise(context, bias, norm);
+        norm = context.mark_node(std::make_shared<ov::op::v1::Add>(norm, bias));
+    }
+    return {norm};
+}
+
+OutputVector translate_instance_norm_train(NodeContext& context,
+                                           const Output<Node>& input,
+                                           const Output<Node>& reduction_axes,
+                                           float eps) {
+    auto zero = context.mark_node(ov::op::v0::Constant::create(element::i64, Shape{}, {0}));
+    auto one = context.mark_node(ov::op::v0::Constant::create(element::i64, Shape{}, {1}));
+    auto input_shape = context.mark_node(std::make_shared<ov::op::v3::ShapeOf>(input));
+    auto batch_dim = context.mark_node(std::make_shared<ov::op::v8::Gather>(input_shape, zero, zero));
+    auto channel_dim = context.mark_node(std::make_shared<ov::op::v8::Gather>(input_shape, one, zero));
+    auto batch_dim_1d = context.mark_node(std::make_shared<ov::op::v0::Unsqueeze>(batch_dim, zero));
+    auto batch_norm_channels_1d = context.mark_node(std::make_shared<ov::op::v1::Multiply>(batch_dim_1d, channel_dim));
+    auto one_1d = context.mark_node(ov::op::v0::Constant::create(element::i64, Shape{1}, {1}));
+    auto tail_shape = context.mark_node(std::make_shared<ov::op::v8::Gather>(input_shape, reduction_axes, zero));
+    auto reshape_shape = context.mark_node(
+        std::make_shared<ov::op::v0::Concat>(OutputVector{one_1d, batch_norm_channels_1d, tail_shape}, 0));
+    auto reshaped_input = context.mark_node(std::make_shared<ov::op::v1::Reshape>(input, reshape_shape, false));
+    Output<Node> weight;
+    Output<Node> bias;
+    if (context.input_is_none(1)) {
+        weight = context.mark_node(std::make_shared<ov::op::v3::Broadcast>(one, batch_norm_channels_1d));
+        weight = context.mark_node(std::make_shared<ov::op::v1::ConvertLike>(weight, input));
+    } else {
+        weight = context.get_input(1);
+        weight = context.mark_node(std::make_shared<ov::op::v0::Tile>(weight, batch_dim_1d));
+    }
+    if (context.input_is_none(2)) {
+        bias = context.mark_node(std::make_shared<ov::op::v3::Broadcast>(zero, batch_norm_channels_1d));
+        bias = context.mark_node(std::make_shared<ov::op::v1::ConvertLike>(bias, input));
+    } else {
+        bias = context.get_input(2);
+        bias = context.mark_node(std::make_shared<ov::op::v0::Tile>(bias, batch_dim_1d));
+    }
+    auto running_mean = context.get_input(3);
+    running_mean = context.mark_node(std::make_shared<ov::op::v0::Tile>(running_mean, batch_dim_1d));
+    auto running_var = context.get_input(4);
+    running_var = context.mark_node(std::make_shared<ov::op::v0::Tile>(running_var, batch_dim_1d));
+    auto batch_norm = context.mark_node(
+        std::make_shared<ov::op::v5::BatchNormInference>(reshaped_input, weight, bias, running_mean, running_var, eps));
+    return {context.mark_node(std::make_shared<ov::op::v1::Reshape>(batch_norm, input_shape, true))};
+}
+
+}  // namespace
+
+OutputVector translate_instance_norm(NodeContext& context) {
+    num_inputs_check(context, 8, 9);
+    auto input = context.get_input(0);
+    auto eps = context.const_input<float>(7);
+    auto input_shape = context.mark_node(std::make_shared<ov::op::v3::ShapeOf>(input));
+    auto rank_1d = context.mark_node(std::make_shared<ov::op::v3::ShapeOf>(input_shape));
+    auto rank = context.mark_node(std::make_shared<ov::op::v0::Squeeze>(rank_1d));
+    auto one = context.mark_node(ov::op::v0::Constant::create(element::i64, Shape{}, {1}));
+    auto two = context.mark_node(ov::op::v0::Constant::create(element::i64, Shape{}, {2}));
+    auto reduction_axes = context.mark_node(std::make_shared<ov::op::v4::Range>(two, rank, one, element::i64));
+    if (context.input_is_none(3) && context.input_is_none(4)) {
+        return translate_instance_norm_inference(context, input, reduction_axes, eps);
+    }
+    return translate_instance_norm_train(context, input, reduction_axes, eps);
+};
+
+}  // namespace op
+}  // namespace pytorch
+}  // namespace frontend
+}  // namespace ov
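For reference, the decomposition above can be sketched numerically. This is a hypothetical NumPy illustration, not part of the commit: the inference path is an MVN (mean-variance normalization) over axes 2..rank-1 followed by the optional channel-wise affine, while the running-stats path folds the batch into the channel axis, reshaping (N, C, ...) to (1, N*C, ...), so per-sample statistics can be expressed through a single BatchNormInference with tiled weights and running stats.

import numpy as np

def instance_norm_no_stats(x, weight=None, bias=None, eps=1e-5):
    # MVN over every axis after (N, C): per-sample, per-channel statistics,
    # with eps added inside the square root (INSIDE_SQRT mode).
    axes = tuple(range(2, x.ndim))
    norm = (x - x.mean(axis=axes, keepdims=True)) / np.sqrt(x.var(axis=axes, keepdims=True) + eps)
    chan = (1, -1) + (1,) * (x.ndim - 2)          # reshape_channelwise: (C,) -> (1, C, 1, ...)
    if weight is not None:
        norm = norm * weight.reshape(chan)
    if bias is not None:
        norm = norm + bias.reshape(chan)
    return norm

def instance_norm_with_stats(x, weight, bias, running_mean, running_var, eps=1e-5):
    n, c = x.shape[:2]
    folded = x.reshape((1, n * c) + x.shape[2:])  # fold batch into channels
    shape = (1, n * c) + (1,) * (x.ndim - 2)
    mean = np.tile(running_mean, n).reshape(shape)  # Tile stats over the batch dim
    var = np.tile(running_var, n).reshape(shape)
    w = np.tile(weight, n).reshape(shape)
    b = np.tile(bias, n).reshape(shape)
    out = (folded - mean) / np.sqrt(var + eps) * w + b
    return out.reshape(x.shape)

x = np.random.randn(3, 6, 10, 5).astype(np.float32)
y = instance_norm_no_stats(x, np.ones(6), np.zeros(6))  # mirrors the MVN path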
@@ -49,6 +49,7 @@ OP_CONVERTER(translate_group_norm);
 OP_CONVERTER(translate_hardtanh);
 OP_CONVERTER(translate_if);
 OP_CONVERTER(translate_im2col);
+OP_CONVERTER(translate_instance_norm);
 OP_CONVERTER(translate_int);
 OP_CONVERTER(translate_layer_norm);
 OP_CONVERTER(translate_len);
@@ -196,6 +197,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"aten::Int", op::translate_int},
         {"aten::IntImplicit", op::translate_int},
         {"aten::im2col", op::translate_im2col},
+        {"aten::instance_norm", op::translate_instance_norm},
         {"aten::is_grad_enabled", op::return_false_scalar},
         {"aten::layer_norm", op::translate_layer_norm},
         {"aten::leaky_relu", op::translate_1to1_match_2_inputs<opset10::PRelu>},
@@ -14,6 +14,14 @@ namespace ov {
 namespace frontend {
 namespace pytorch {
 
+void num_inputs_check(const NodeContext& context, size_t min_inputs, size_t max_inputs) {
+    auto inputs = context.inputs();
+    FRONT_END_OP_CONVERSION_CHECK(inputs.size() > min_inputs, "Got less inputs than expected");
+    for (auto i = max_inputs; i < inputs.size(); i++) {
+        FRONT_END_OP_CONVERSION_CHECK(context.input_is_none(i), "Got more inputs than expected.");
+    }
+}
+
 Output<Node> make_optional_bias(const Output<Node>& base_op,
                                 const NodeContext& context,
                                 size_t bias_input_idx,
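A rough Python analogue of the new num_inputs_check helper may clarify its contract (hypothetical sketch, not the frontend API): the input count must strictly exceed min_inputs, and inputs past max_inputs are tolerated only when they are None.

def num_inputs_check(inputs, min_inputs, max_inputs):
    # Mirrors the C++ above: the lower-bound comparison is strict.
    assert len(inputs) > min_inputs, "Got less inputs than expected"
    # Extra trailing inputs are allowed only if they carry no value.
    for extra in inputs[max_inputs:]:
        assert extra is None, "Got more inputs than expected."

# e.g. the instance_norm converter calls num_inputs_check(context, 8, 9):
num_inputs_check([object()] * 9, 8, 9)           # ok: 9 inputs, nothing past index 9
num_inputs_check([object()] * 9 + [None], 8, 9)  # ok: the 10th input is None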
@@ -34,18 +42,18 @@ Output<Node> make_optional_bias(const Output<Node>& base_op,
     }
 }
 
-Output<ov::Node> reshape_conv_bias(const NodeContext& context, Output<ov::Node> bias, Output<ov::Node> conv) {
-    auto conv_shape = context.mark_node(std::make_shared<opset10::ShapeOf>(conv));
-    auto conv_rank = context.mark_node(std::make_shared<opset10::ShapeOf>(conv_shape));
+Output<ov::Node> reshape_channelwise(const NodeContext& context, Output<ov::Node> data, Output<ov::Node> shape_source) {
+    auto input_shape = context.mark_node(std::make_shared<opset10::ShapeOf>(shape_source));
+    auto input_rank = context.mark_node(std::make_shared<opset10::ShapeOf>(input_shape));
     auto one_const = context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {1}));
     auto two_const = context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {2}));
-    auto tail_shape_rank = context.mark_node(std::make_shared<opset10::Subtract>(conv_rank, two_const));
+    auto tail_shape_rank = context.mark_node(std::make_shared<opset10::Subtract>(input_rank, two_const));
     auto tail_shape = context.mark_node(std::make_shared<opset10::Broadcast>(one_const, tail_shape_rank));
-    auto channels_dim = context.mark_node(std::make_shared<opset10::ShapeOf>(bias));
+    auto channels_dim = context.mark_node(std::make_shared<opset10::ShapeOf>(data));
     auto new_shape =
         context.mark_node(std::make_shared<opset10::Concat>(OutputVector{one_const, channels_dim, tail_shape}, 0));
 
-    return context.mark_node(std::make_shared<opset10::Reshape>(bias, new_shape, false));
+    return context.mark_node(std::make_shared<opset10::Reshape>(data, new_shape, false));
 }
 
 std::shared_ptr<Node> get_rank_node(const Output<Node>& node) {
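The renamed helper generalizes the old conv-bias reshape: it turns a 1-D per-channel tensor of shape (C,) into shape (1, C, 1, ..., 1) so it broadcasts against an (N, C, ...) tensor, which is what both the convolution translators and the new instance_norm converter need. A minimal NumPy sketch of the same computation (hypothetical, for illustration):

import numpy as np

def reshape_channelwise(data, shape_source):
    # new_shape = [1, C] + [1] * (rank(shape_source) - 2)
    tail = [1] * (shape_source.ndim - 2)
    return data.reshape([1, data.shape[0]] + tail)

x = np.zeros((2, 6, 10, 5))                 # (N, C, H, W)
bias = np.arange(6, dtype=np.float32)       # per-channel bias, shape (6,)
y = x + reshape_channelwise(bias, x)        # broadcasts over N, H, W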
@@ -18,12 +18,16 @@ class FrameworkNode;
 namespace frontend {
 namespace pytorch {
 
+void num_inputs_check(const NodeContext& context, size_t min_inputs, size_t max_inputs);
+
 Output<Node> make_optional_bias(const Output<Node>& base_op,
                                 const NodeContext& context,
                                 size_t bias_input_idx,
                                 const std::vector<int>& unsqueeze_dims = {});
 
-Output<ov::Node> reshape_conv_bias(const NodeContext& context, Output<ov::Node> bias, Output<ngraph::Node> conv);
+Output<ov::Node> reshape_channelwise(const NodeContext& context,
+                                     Output<ov::Node> data,
+                                     Output<ngraph::Node> shape_source);
 
 std::shared_ptr<ov::Node> get_rank_node(const Output<Node>& node);
 
tests/layer_tests/pytorch_tests/test_instance_norm.py (new file, 64 lines)
@@ -0,0 +1,64 @@
+# Copyright (C) 2018-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+
+from pytorch_layer_test_class import PytorchLayerTest
+
+
+class TestInstanceNorm(PytorchLayerTest):
+    def _prepare_input(self, ndim=4):
+        import numpy as np
+        shape5d = [3, 6, 10, 5, 2]
+        shape = shape5d[:ndim]
+        return (np.random.randn(*shape).astype(np.float32),)
+
+    def create_model(self, weights=False, bias=False, mean_var=False, eps=1e-05):
+        import torch
+
+        class aten_instance_norm(torch.nn.Module):
+            def __init__(self, weights=False, bias=False, mean_var=False, eps=1e-05):
+                super(aten_instance_norm, self).__init__()
+                weights_shape = (6, )
+                self.weight = torch.randn(weights_shape) if weights else None
+                self.bias = None
+                self.use_input_stats = not mean_var
+                if bias:
+                    self.bias = torch.randn(weights_shape)
+                self.mean = None
+                self.var = None
+                if mean_var:
+                    self.mean = torch.randn(weights_shape)
+                    self.var = torch.randn(weights_shape)
+
+                self.eps = eps
+
+            def forward(self, x):
+                return torch.instance_norm(x, self.weight, self.bias, self.mean, self.var, self.use_input_stats, 0.1, self.eps, False)
+
+        ref_net = None
+
+        return aten_instance_norm(weights, bias, mean_var, eps), ref_net, "aten::instance_norm"
+
+    @pytest.mark.parametrize("params",
+                             [
+                                 {"eps": 0.0001},
+                                 {"weights": True, "eps": -0.05},
+                                 {"weights": True},
+                                 {"weights": True, "bias": True},
+                                 {"weights": True, "bias": False, "mean_var": True},
+                                 {"weights": True, "bias": True, "mean_var": True},
+                                 {"weights": False, "bias": True, "mean_var": True},
+                                 {"weights": False, "bias": False, "mean_var": True},
+                                 {"weights": False, "bias": False, "mean_var": True, "eps": 1.5}
+                             ])
+    @pytest.mark.parametrize("kwargs_to_prepare_input", [
+        {"ndim": 3},
+        {"ndim": 4},
+        {"ndim": 5}
+    ])
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_instance_norm(self, params, ie_device, precision, ir_version, kwargs_to_prepare_input):
+        self._test(*self.create_model(**params),
+                   ie_device, precision, ir_version, kwargs_to_prepare_input=kwargs_to_prepare_input, dynamic_shapes=not params.get("mean_var", False))