[PT FE]: support aten::instance_norm (#15213)

This commit is contained in:
Ekaterina Aidova 2023-01-31 12:51:02 +04:00 committed by GitHub
parent b4cb4fe8c9
commit b2ce43a172
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 198 additions and 11 deletions

View File

@ -50,7 +50,7 @@ OutputVector translate_conv_transposend(NodeContext& context) {
auto bias = context.get_input(2);
auto bias_rank = bias.get_partial_shape().rank();
if (bias_rank == 1) {
bias = reshape_conv_bias(context, bias, conv);
bias = reshape_channelwise(context, bias, conv);
}
conv = context.mark_node(std::make_shared<ov::op::v1::Add>(conv, bias));
}

View File

@ -49,7 +49,7 @@ OutputVector translate_convnd(NodeContext& context) {
auto bias = context.get_input(2);
auto bias_rank = bias.get_partial_shape().rank();
if (bias_rank == 1) {
bias = reshape_conv_bias(context, bias, conv);
bias = reshape_channelwise(context, bias, conv);
}
conv = context.mark_node(std::make_shared<opset10::Add>(conv, bias));
}

View File

@ -69,7 +69,7 @@ OutputVector translate_convolution(NodeContext& context) {
auto bias = context.get_input(2);
auto bias_rank = bias.get_partial_shape().rank();
if (bias_rank == 1) {
bias = reshape_conv_bias(context, bias, conv);
bias = reshape_channelwise(context, bias, conv);
}
conv = context.mark_node(std::make_shared<opset10::Add>(conv, bias));

View File

@ -46,7 +46,7 @@ OutputVector translate_convolution_mode(NodeContext& context) {
auto bias = context.get_input(2);
auto bias_rank = bias.get_partial_shape().rank();
if (bias_rank == 1) {
bias = reshape_conv_bias(context, bias, conv);
bias = reshape_channelwise(context, bias, conv);
}
conv = context.mark_node(std::make_shared<opset10::Add>(conv, bias));

View File

@ -0,0 +1,109 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/batch_norm.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/mvn.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/squeeze.hpp"
#include "openvino/op/tile.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace op {
namespace {
OutputVector translate_instance_norm_inference(const NodeContext& context,
const Output<Node>& input,
const Output<Node>& reduction_axes,
float eps) {
auto norm = context.mark_node(
std::make_shared<ov::op::v6::MVN>(input, reduction_axes, true, eps, ov::op::MVNEpsMode::INSIDE_SQRT));
if (!context.input_is_none(1)) {
auto weight = context.get_input(1);
weight = reshape_channelwise(context, weight, norm);
norm = context.mark_node(std::make_shared<ov::op::v1::Multiply>(norm, weight));
}
if (!context.input_is_none(2)) {
auto bias = context.get_input(2);
bias = reshape_channelwise(context, bias, norm);
norm = context.mark_node(std::make_shared<ov::op::v1::Add>(norm, bias));
}
return {norm};
}
OutputVector translate_instance_norm_train(NodeContext& context,
const Output<Node>& input,
const Output<Node>& reduction_axes,
float eps) {
auto zero = context.mark_node(ov::op::v0::Constant::create(element::i64, Shape{}, {0}));
auto one = context.mark_node(ov::op::v0::Constant::create(element::i64, Shape{}, {1}));
auto input_shape = context.mark_node(std::make_shared<ov::op::v3::ShapeOf>(input));
auto batch_dim = context.mark_node(std::make_shared<ov::op::v8::Gather>(input_shape, zero, zero));
auto channel_dim = context.mark_node(std::make_shared<ov::op::v8::Gather>(input_shape, one, zero));
auto batch_dim_1d = context.mark_node(std::make_shared<ov::op::v0::Unsqueeze>(batch_dim, zero));
auto batch_norm_channels_1d = context.mark_node(std::make_shared<ov::op::v1::Multiply>(batch_dim_1d, channel_dim));
auto one_1d = context.mark_node(ov::op::v0::Constant::create(element::i64, Shape{1}, {1}));
auto tail_shape = context.mark_node(std::make_shared<ov::op::v8::Gather>(input_shape, reduction_axes, zero));
auto reshape_shape = context.mark_node(
std::make_shared<ov::op::v0::Concat>(OutputVector{one_1d, batch_norm_channels_1d, tail_shape}, 0));
auto reshaped_input = context.mark_node(std::make_shared<ov::op::v1::Reshape>(input, reshape_shape, false));
Output<Node> weight;
Output<Node> bias;
if (context.input_is_none(1)) {
weight = context.mark_node(std::make_shared<ov::op::v3::Broadcast>(one, batch_norm_channels_1d));
weight = context.mark_node(std::make_shared<ov::op::v1::ConvertLike>(weight, input));
} else {
weight = context.get_input(1);
weight = context.mark_node(std::make_shared<ov::op::v0::Tile>(weight, batch_dim_1d));
}
if (context.input_is_none(2)) {
bias = context.mark_node(std::make_shared<ov::op::v3::Broadcast>(zero, batch_norm_channels_1d));
bias = context.mark_node(std::make_shared<ov::op::v1::ConvertLike>(bias, input));
} else {
bias = context.get_input(2);
bias = context.mark_node(std::make_shared<ov::op::v0::Tile>(bias, batch_dim_1d));
}
auto running_mean = context.get_input(3);
running_mean = context.mark_node(std::make_shared<ov::op::v0::Tile>(running_mean, batch_dim_1d));
auto running_var = context.get_input(4);
running_var = context.mark_node(std::make_shared<ov::op::v0::Tile>(running_var, batch_dim_1d));
auto batch_norm = context.mark_node(
std::make_shared<ov::op::v5::BatchNormInference>(reshaped_input, weight, bias, running_mean, running_var, eps));
return {context.mark_node(std::make_shared<ov::op::v1::Reshape>(batch_norm, input_shape, true))};
}
} // namespace
OutputVector translate_instance_norm(NodeContext& context) {
num_inputs_check(context, 8, 9);
auto input = context.get_input(0);
auto eps = context.const_input<float>(7);
auto input_shape = context.mark_node(std::make_shared<ov::op::v3::ShapeOf>(input));
auto rank_1d = context.mark_node(std::make_shared<ov::op::v3::ShapeOf>(input_shape));
auto rank = context.mark_node(std::make_shared<ov::op::v0::Squeeze>(rank_1d));
auto one = context.mark_node(ov::op::v0::Constant::create(element::i64, Shape{}, {1}));
auto two = context.mark_node(ov::op::v0::Constant::create(element::i64, Shape{}, {2}));
auto reduction_axes = context.mark_node(std::make_shared<ov::op::v4::Range>(two, rank, one, element::i64));
if (context.input_is_none(3) && context.input_is_none(4)) {
return translate_instance_norm_inference(context, input, reduction_axes, eps);
}
return translate_instance_norm_train(context, input, reduction_axes, eps);
};
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@ -49,6 +49,7 @@ OP_CONVERTER(translate_group_norm);
OP_CONVERTER(translate_hardtanh);
OP_CONVERTER(translate_if);
OP_CONVERTER(translate_im2col);
OP_CONVERTER(translate_instance_norm);
OP_CONVERTER(translate_int);
OP_CONVERTER(translate_layer_norm);
OP_CONVERTER(translate_len);
@ -196,6 +197,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::Int", op::translate_int},
{"aten::IntImplicit", op::translate_int},
{"aten::im2col", op::translate_im2col},
{"aten::instance_norm", op::translate_instance_norm},
{"aten::is_grad_enabled", op::return_false_scalar},
{"aten::layer_norm", op::translate_layer_norm},
{"aten::leaky_relu", op::translate_1to1_match_2_inputs<opset10::PRelu>},

View File

@ -14,6 +14,14 @@ namespace ov {
namespace frontend {
namespace pytorch {
void num_inputs_check(const NodeContext& context, size_t min_inputs, size_t max_inputs) {
auto inputs = context.inputs();
FRONT_END_OP_CONVERSION_CHECK(inputs.size() > min_inputs, "Got less inputs than expected");
for (auto i = max_inputs; i < inputs.size(); i++) {
FRONT_END_OP_CONVERSION_CHECK(context.input_is_none(i), "Got more inputs than expected.");
}
}
Output<Node> make_optional_bias(const Output<Node>& base_op,
const NodeContext& context,
size_t bias_input_idx,
@ -34,18 +42,18 @@ Output<Node> make_optional_bias(const Output<Node>& base_op,
}
}
Output<ov::Node> reshape_conv_bias(const NodeContext& context, Output<ov::Node> bias, Output<ov::Node> conv) {
auto conv_shape = context.mark_node(std::make_shared<opset10::ShapeOf>(conv));
auto conv_rank = context.mark_node(std::make_shared<opset10::ShapeOf>(conv_shape));
Output<ov::Node> reshape_channelwise(const NodeContext& context, Output<ov::Node> data, Output<ov::Node> shape_source) {
auto input_shape = context.mark_node(std::make_shared<opset10::ShapeOf>(shape_source));
auto input_rank = context.mark_node(std::make_shared<opset10::ShapeOf>(input_shape));
auto one_const = context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {1}));
auto two_const = context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {2}));
auto tail_shape_rank = context.mark_node(std::make_shared<opset10::Subtract>(conv_rank, two_const));
auto tail_shape_rank = context.mark_node(std::make_shared<opset10::Subtract>(input_rank, two_const));
auto tail_shape = context.mark_node(std::make_shared<opset10::Broadcast>(one_const, tail_shape_rank));
auto channels_dim = context.mark_node(std::make_shared<opset10::ShapeOf>(bias));
auto channels_dim = context.mark_node(std::make_shared<opset10::ShapeOf>(data));
auto new_shape =
context.mark_node(std::make_shared<opset10::Concat>(OutputVector{one_const, channels_dim, tail_shape}, 0));
return context.mark_node(std::make_shared<opset10::Reshape>(bias, new_shape, false));
return context.mark_node(std::make_shared<opset10::Reshape>(data, new_shape, false));
}
std::shared_ptr<Node> get_rank_node(const Output<Node>& node) {

View File

@ -18,12 +18,16 @@ class FrameworkNode;
namespace frontend {
namespace pytorch {
void num_inputs_check(const NodeContext& context, size_t min_inputs, size_t max_inputs);
Output<Node> make_optional_bias(const Output<Node>& base_op,
const NodeContext& context,
size_t bias_input_idx,
const std::vector<int>& unsqueeze_dims = {});
Output<ov::Node> reshape_conv_bias(const NodeContext& context, Output<ov::Node> bias, Output<ngraph::Node> conv);
Output<ov::Node> reshape_channelwise(const NodeContext& context,
Output<ov::Node> data,
Output<ngraph::Node> shape_source);
std::shared_ptr<ov::Node> get_rank_node(const Output<Node>& node);

View File

@ -0,0 +1,64 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import pytest
from pytorch_layer_test_class import PytorchLayerTest
class TestInstanceNorm(PytorchLayerTest):
def _prepare_input(self, ndim=4):
import numpy as np
shape5d = [3, 6, 10, 5, 2]
shape = shape5d[:ndim]
return (np.random.randn(*shape).astype(np.float32),)
def create_model(self, weights=False, bias=False, mean_var=False, eps=1e-05):
import torch
class aten_instance_norm(torch.nn.Module):
def __init__(self, weights=False, bias=False, mean_var=False, eps=1e-05):
super(aten_instance_norm, self).__init__()
weights_shape = (6, )
self.weight = torch.randn(weights_shape) if weights else None
self.bias = None
self.use_input_stats = not mean_var
if bias:
self.bias = torch.randn(weights_shape)
self.mean = None
self.var = None
if mean_var:
self.mean = torch.randn(weights_shape)
self.var = torch.randn(weights_shape)
self.eps = eps
def forward(self, x):
return torch.instance_norm(x, self.weight, self.bias, self.mean, self.var, self.use_input_stats, 0.1, self.eps, False)
ref_net = None
return aten_instance_norm(weights, bias, mean_var, eps), ref_net, "aten::instance_norm"
@pytest.mark.parametrize("params",
[
{"eps": 0.0001},
{'weights': True, 'eps': -0.05},
{'weights': True},
{'weights': True, 'bias': True},
{"weights": True, 'bias': False, "mean_var": True},
{"weights": True, 'bias': True, "mean_var": True},
{"weights": False, 'bias': True, "mean_var": True},
{"weights": False, 'bias': False, "mean_var": True},
{"weights": False, 'bias': False, "mean_var": True, "eps": 1.5}
])
@pytest.mark.parametrize("kwargs_to_prepare_input", [
{"ndim": 3},
{'ndim': 4},
{"ndim": 5}
])
@pytest.mark.nightly
@pytest.mark.precommit
def test_group_norm(self, params, ie_device, precision, ir_version, kwargs_to_prepare_input):
self._test(*self.create_model(**params),
ie_device, precision, ir_version, kwargs_to_prepare_input=kwargs_to_prepare_input, dynamic_shapes=not params.get("mean_var", False))