[PT FE]: support frobenius norm and fix aten::norm (#17701)

* [PT FE]: support frobenius norm and fix aten::norm

* fix code style
Ekaterina Aidova 2023-05-29 11:14:16 +04:00 committed by GitHub
parent 00f94426f1
commit 3300543eac
3 changed files with 108 additions and 11 deletions


@@ -111,9 +111,10 @@ Output<Node> frobenius_norm(const NodeContext& context, Output<Node> x, Output<N
 }; // namespace

 OutputVector translate_norm(const NodeContext& context) {
-    num_inputs_check(context, 4, 6);
+    num_inputs_check(context, 2, 6);
     auto input_tensor = context.get_input(0);
     auto p_node_type = context.get_input_type(1);
+    bool keep_dim = false;
     Output<Node> dim;
     if (context.input_is_none(2)) {
         auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
@@ -125,7 +126,9 @@ OutputVector translate_norm(const NodeContext& context) {
     } else {
         dim = context.get_input(2);
     }
-    auto keep_dim = context.const_input<bool>(3);
+    if (!context.input_is_none(3)) {
+        keep_dim = context.const_input<bool>(3);
+    }
     if (!context.input_is_none(4)) {
         input_tensor = apply_dtype(context, 4, input_tensor);
     }
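
Note on the relaxed check above: TorchScript emits aten::norm with as few as two inputs when dim and keepdim are omitted (the overload the new tests cover with input_data.norm(p)), so the minimum drops from 4 to 2 and keep_dim defaults to false. A minimal PyTorch sketch; the module names are illustrative only:

import torch

# Illustrative: the two aten::norm call shapes the translator now has to accept.
class NormOnlyP(torch.nn.Module):
    def forward(self, x):
        return x.norm(2.5)                 # aten::norm(self, p): only 2 inputs

class NormFull(torch.nn.Module):
    def forward(self, x):
        return x.norm(2.5, [0, 1], True)   # aten::norm(self, p, dim, keepdim): 4 inputs

# Scripting either module shows the corresponding aten::norm node in the graph.
print(torch.jit.script(NormOnlyP()).graph)
print(torch.jit.script(NormFull()).graph)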
@@ -295,6 +298,26 @@ OutputVector translate_linalg_norm(const NodeContext& context) {
     return {result};
 };

+OutputVector translate_frobenius_norm(const NodeContext& context) {
+    // aten::frobenius_norm.dim(Tensor self, int[1] dim, bool keepdim=False) -> Tensor
+    // aten::frobenius_norm.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    num_inputs_check(context, 3, 4);
+    auto x = context.get_input(0);
+    bool keep_dim = context.const_input<bool>(2);
+    Output<Node> dim;
+    if (context.input_is_none(1)) {
+        dim = get_axes_range(context, 0);
+    } else {
+        dim = context.get_input(1);
+    }
+    auto result = frobenius_norm(context, x, dim, keep_dim);
+    if (!context.input_is_none(3)) {
+        context.mutate_input(3, result);
+    }
+    return {result};
+}
+
 } // namespace op
 } // namespace pytorch
 } // namespace frontend
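
For reference, the frobenius_norm helper called above is expected to reduce to sqrt(sum(x^2)) over the requested dims. A small numpy/torch sketch of that equivalence; frobenius_norm_ref and the shapes are illustrative, not part of this change:

import numpy as np
import torch

# Reference math for aten::frobenius_norm: sqrt of the sum of squares over `dim`.
def frobenius_norm_ref(x, dim, keepdim):
    return np.sqrt(np.sum(np.square(x), axis=tuple(dim), keepdims=keepdim))

x = np.random.randn(10, 12, 14).astype(np.float32)
expected = torch._VF.frobenius_norm(torch.from_numpy(x), (0, 1), True).numpy()
assert np.allclose(frobenius_norm_ref(x, (0, 1), True), expected, atol=1e-5)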


@@ -51,6 +51,7 @@ OP_CONVERTER(translate_fill_);
 OP_CONVERTER(translate_flatten);
 OP_CONVERTER(translate_floor_divide);
 OP_CONVERTER(translate_floordiv);
+OP_CONVERTER(translate_frobenius_norm);
 OP_CONVERTER(translate_full);
 OP_CONVERTER(translate_full_like);
 OP_CONVERTER(translate_gather);
@@ -229,6 +230,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"aten::floor_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Floor>>},
         {"aten::floor_divide", op::translate_floor_divide},
         {"aten::floordiv", op::translate_floordiv},
+        {"aten::frobenius_norm", op::translate_frobenius_norm},
         {"aten::full", op::translate_full},
         {"aten::full_like", op::translate_full_like},
         {"aten::gather", op::translate_gather},
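
With the op registered, a model whose forward hits aten::frobenius_norm should convert through the PyTorch frontend without an unsupported-op error. A hedged end-to-end sketch; openvino.convert_model (the 2023.1+ Python API) and the shapes here are assumptions about the reader's environment, not part of this diff:

import torch
import openvino as ov  # assumption: an OpenVINO build exposing openvino.convert_model

class FroNorm(torch.nn.Module):
    def forward(self, x):
        # traces to aten::frobenius_norm, now routed to op::translate_frobenius_norm
        return torch._VF.frobenius_norm(x, (0, 1), True)

ov_model = ov.convert_model(FroNorm(), example_input=torch.randn(10, 12, 14))
print([n.get_type_name() for n in ov_model.get_ops()])  # the decomposed ops in the converted graph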


@@ -8,9 +8,6 @@ import torch
 from pytorch_layer_test_class import PytorchLayerTest


-@pytest.mark.parametrize('p', [-2, -1, 0, 1, 2, 2.5, float('inf'), float('-inf')])
-@pytest.mark.parametrize('dim', [[0], [0, 1], [0, 1, 2]])
-@pytest.mark.parametrize('keepdim', [True, False])
 class TestNorm(PytorchLayerTest):
     def _prepare_input(self):
@@ -32,12 +29,87 @@ class TestNorm(PytorchLayerTest):
         return aten_norm(p, dim, keepdim), ref_net, "aten::norm"

+    def create_model_tensor_norm(self, p, dim, keepdim):
+        class aten_norm(torch.nn.Module):
+            def __init__(self, p, dim, keepdim) -> None:
+                super().__init__()
+                self.p = p
+                self.dim = dim
+                self.keepdim = keepdim
+                if self.keepdim is None or self.dim is None:
+                    self.forward = self.forward2
+                else:
+                    self.forward = self.forward4
+
+            def forward4(self, input_data):
+                return input_data.norm(self.p, self.dim, self.keepdim)
+
+            def forward2(self, input_data):
+                return input_data.norm(self.p)
+
+        ref_net = None
+
+        return aten_norm(p, dim, keepdim), ref_net, "aten::norm"
+
     @pytest.mark.nightly
     @pytest.mark.precommit
+    @pytest.mark.parametrize('p', [-2, -1, 0, 1, 2, 2.5, float('inf'), float('-inf')])
+    @pytest.mark.parametrize('dim', [[0], [0, 1], [0, 1, 2]])
+    @pytest.mark.parametrize('keepdim', [True, False])
     def test_norm(self, ie_device, precision, ir_version, p, dim, keepdim):
         self._test(*self.create_model(p, dim, keepdim),
                    ie_device, precision, ir_version)
+
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    @pytest.mark.parametrize('p', [-2, -1, 0, 1, 2, 2.5, float('inf'), float('-inf')])
+    @pytest.mark.parametrize('dim', [None, [0], [0, 1], [0, 1, 2]])
+    @pytest.mark.parametrize('keepdim', [None, True, False])
+    def test_norm_tensor(self, ie_device, precision, ir_version, p, dim, keepdim):
+        self._test(*self.create_model_tensor_norm(p, dim, keepdim),
+                   ie_device, precision, ir_version)
+
+
+class TestFrobeniusNorm(PytorchLayerTest):
+    def _prepare_input(self, out=False, dtype="float32"):
+        x = np.random.randn(10, 12, 14).astype(dtype)
+        if not out:
+            return (x,)
+        y = np.zeros_like(x)
+        return (x, y)
+
+    def create_model(self, dim, keepdim, out):
+        class aten_frobenius_norm(torch.nn.Module):
+            def __init__(self, dim, keepdim, out) -> None:
+                super().__init__()
+                self.dim = dim
+                self.keepdim = keepdim
+                if out:
+                    self.forward = self.forward_out
+
+            def forward(self, input_data):
+                return torch._VF.frobenius_norm(input_data, self.dim, self.keepdim)
+
+            def forward_out(self, input_data, out):
+                return torch._VF.frobenius_norm(input_data, self.dim, self.keepdim, out=out), out
+
+        ref_net = None
+
+        return aten_frobenius_norm(dim, keepdim, out), ref_net, "aten::frobenius_norm"
+
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    @pytest.mark.parametrize('dim', [(1, ), (0, ), (-1, ), (0, 1), (1, 0)])
+    @pytest.mark.parametrize('keepdim', [True, False])
+    @pytest.mark.parametrize("out", [False, True])
+    @pytest.mark.parametrize("dtype", ["float32", "float64"])
+    def test_frobenius_norm(self, ie_device, precision, ir_version, dim, keepdim, out, dtype):
+        self._test(*self.create_model(dim, keepdim, out), ie_device, precision, ir_version,
+                   kwargs_to_prepare_input={"out": out, "dtype": dtype}
+                   )
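
The out=True parametrization above exercises the aten::frobenius_norm.out overload, which the translator mirrors with context.mutate_input(3, result). A quick standalone check of that aliasing behaviour; the shapes are illustrative:

import torch

x = torch.randn(10, 12, 14)
out = torch.zeros(12, 14)  # shape of x reduced over dim 0 with keepdim=False
res = torch._VF.frobenius_norm(x, (0,), False, out=out)
assert res.data_ptr() == out.data_ptr()                  # the result is written into `out`
assert torch.allclose(out, x.pow(2).sum(dim=0).sqrt())   # matches sqrt(sum(x**2)) over dim 0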
class TestLinalgVectorNorm(PytorchLayerTest):
@@ -105,7 +177,7 @@ class TestLinalgVectorNorm(PytorchLayerTest):
     @pytest.mark.parametrize("prim_dtype", [True, False])
     def test_linalg_vector_norm(self, p, dim, keepdim, dtype, out, prim_dtype, ie_device, precision, ir_version):
         self._test(*self.create_model(p, dim, keepdim, dtype, out, prim_dtype),
-                   ie_device, precision, ir_version,
+                   ie_device, precision, ir_version,
                    kwargs_to_prepare_input={"out": out or prim_dtype, "out_dtype": dtype if prim_dtype else None})
@@ -175,7 +247,7 @@ class TestLinalgMatrixNorm(PytorchLayerTest):
     @pytest.mark.parametrize("prim_dtype", [True, False])
     def test_linalg_matrix_norm(self, p, dim, keepdim, dtype, out, prim_dtype, ie_device, precision, ir_version):
         self._test(*self.create_model(p, dim, keepdim, dtype, out, prim_dtype),
-                   ie_device, precision, ir_version,
+                   ie_device, precision, ir_version,
                    kwargs_to_prepare_input={"out": out or prim_dtype, "out_dtype": dtype if prim_dtype else None})
@@ -238,9 +310,9 @@ class TestLinalgNorm(PytorchLayerTest):
     @pytest.mark.nightly
     @pytest.mark.precommit
     @pytest.mark.parametrize('p,dim', [
-        (-1, [0, 1]), (1, [-1, -2]), (float('inf'), [1, 0]),
-        (float('-inf'), [-2, -1]), (0, 1), (1, -1),
-        (None, None), (2.5, 0), (-1, 1), (2, 0),
+        (-1, [0, 1]), (1, [-1, -2]), (float('inf'), [1, 0]),
+        (float('-inf'), [-2, -1]), (0, 1), (1, -1),
+        (None, None), (2.5, 0), (-1, 1), (2, 0),
         (float('inf'), 1), (float('-inf'), 1), ("fro", (0, 1))])
     @pytest.mark.parametrize('keepdim', [True, False])
     @pytest.mark.parametrize("dtype", ["float32", "float64", None])
@@ -248,5 +320,5 @@ class TestLinalgNorm(PytorchLayerTest):
     @pytest.mark.parametrize("prim_dtype", [True, False])
     def test_linalg_norm(self, p, dim, keepdim, dtype, out, prim_dtype, ie_device, precision, ir_version):
         self._test(*self.create_model(p, dim, keepdim, dtype, out, prim_dtype),
-                   ie_device, precision, ir_version,
+                   ie_device, precision, ir_version,
                    kwargs_to_prepare_input={"out": out or prim_dtype, "out_dtype": dtype if prim_dtype else None})