From 5ba7d9b72d93b821f3ccce07dbbc5168df159fcc Mon Sep 17 00:00:00 2001 From: Ekaterina Aidova Date: Fri, 6 Oct 2023 12:26:12 +0400 Subject: [PATCH] [PT FE]: support aten::log1p, fixes for where and linalg_norm (#20167) * [PT FE]: support aten::log1p, fixes for where and linalg_norm * clarify norm behaviour --- src/frontends/pytorch/src/op/log.cpp | 12 +++++ src/frontends/pytorch/src/op/norm.cpp | 2 +- src/frontends/pytorch/src/op/where.cpp | 1 + src/frontends/pytorch/src/op_table.cpp | 3 ++ tests/layer_tests/pytorch_tests/test_log.py | 9 +++- tests/layer_tests/pytorch_tests/test_norm.py | 17 ++++--- tests/layer_tests/pytorch_tests/test_where.py | 51 +++++++++++++++---- 7 files changed, 75 insertions(+), 20 deletions(-) diff --git a/src/frontends/pytorch/src/op/log.cpp b/src/frontends/pytorch/src/op/log.cpp index c047f9e7853..20232e31dec 100644 --- a/src/frontends/pytorch/src/op/log.cpp +++ b/src/frontends/pytorch/src/op/log.cpp @@ -5,6 +5,7 @@ #include "openvino/op/log.hpp" #include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/op/add.hpp" #include "openvino/op/constant.hpp" #include "openvino/op/convert.hpp" #include "openvino/op/divide.hpp" @@ -55,6 +56,17 @@ OutputVector translate_logsumexp(const NodeContext& context) { return {log}; }; +OutputVector translate_log1p(const NodeContext& context) { + // torch.log1p returns a tensor with the natural logarithm of the elements of input + 1. + num_inputs_check(context, 1, 1); + auto x = context.get_input(0); + x = context.mark_node(std::make_shared(x, element::f32)); + auto one = context.mark_node(v0::Constant::create(element::f32, Shape{}, {1})); + auto x_plus_one = context.mark_node(std::make_shared(x, one)); + auto log = context.mark_node(std::make_shared(x_plus_one)); + return {log}; +}; + } // namespace op } // namespace pytorch } // namespace frontend diff --git a/src/frontends/pytorch/src/op/norm.cpp b/src/frontends/pytorch/src/op/norm.cpp index 6cf30a323e4..d3136b7e76a 100644 --- a/src/frontends/pytorch/src/op/norm.cpp +++ b/src/frontends/pytorch/src/op/norm.cpp @@ -259,7 +259,7 @@ OutputVector translate_linalg_norm(const NodeContext& context) { auto input_rank = x.get_partial_shape().rank(); if (input_rank.is_static() && input_rank.get_length() == 2) { result = frobenius_norm(context, x, dim, keep_dim); - } else if (input_rank.is_static() && input_rank.get_length() == 1) { + } else if (input_rank.is_dynamic() || input_rank.get_length() == 1) { result = norm_vector(context, x, dim, 2, keep_dim); } else { FRONT_END_OP_CONVERSION_CHECK(false, diff --git a/src/frontends/pytorch/src/op/where.cpp b/src/frontends/pytorch/src/op/where.cpp index 4a9de9f69ed..3d03706970b 100644 --- a/src/frontends/pytorch/src/op/where.cpp +++ b/src/frontends/pytorch/src/op/where.cpp @@ -21,6 +21,7 @@ OutputVector translate_where(const NodeContext& context) { auto bool_cond = context.mark_node(std::make_shared(cond, element::boolean)); auto x = context.get_input(1); auto y = context.get_input(2); + align_eltwise_input_types(context, x, y, true); return {context.mark_node(std::make_shared(bool_cond, x, y))}; }; diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index bbad312b74e..41a790d1ef2 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -89,6 +89,7 @@ OP_CONVERTER(translate_linspace); OP_CONVERTER(translate_list_construct); OP_CONVERTER(translate_list_unpack); OP_CONVERTER(translate_log); +OP_CONVERTER(translate_log1p); OP_CONVERTER(translate_log_softmax); OP_CONVERTER(translate_log2); OP_CONVERTER(translate_logsumexp); @@ -353,6 +354,8 @@ const std::map get_supported_ops_ts() { {"aten::logical_not", op::translate_not}, {"aten::logical_xor", op::translate_xor}, {"aten::log_softmax", op::translate_log_softmax}, + {"aten::log1p", op::translate_log1p}, + {"aten::log1p_", op::inplace_op}, {"aten::log2", op::translate_log2}, {"aten::log2_", op::inplace_op}, {"aten::lt", op::translate_1to1_match_2_inputs_align_types}, diff --git a/tests/layer_tests/pytorch_tests/test_log.py b/tests/layer_tests/pytorch_tests/test_log.py index 2d8a87fd22d..1e4de2dd4f1 100644 --- a/tests/layer_tests/pytorch_tests/test_log.py +++ b/tests/layer_tests/pytorch_tests/test_log.py @@ -17,7 +17,9 @@ class TestLog(PytorchLayerTest): "log": torch.log, "log_": torch.log_, "log2": torch.log2, - "log2_": torch.log2_ + "log2_": torch.log2_, + "log1p": torch.log1p, + "log1p_": torch.log1p_ } op_fn = ops[op] @@ -42,7 +44,10 @@ class TestLog(PytorchLayerTest): ["log_", "float32"], ["log2", "float32"], ["log2", "int32"], - ["log2_", "float32"]]) + ["log2_", "float32"], + ["log1p", "float32"], + ["log1p", "int32"], + ["log1p_", "float32"]]) def test_log(self, op, input_dtype, ie_device, precision, ir_version): self._test(*self.create_model(op), ie_device, precision, ir_version, kwargs_to_prepare_input={"dtype": input_dtype}) \ No newline at end of file diff --git a/tests/layer_tests/pytorch_tests/test_norm.py b/tests/layer_tests/pytorch_tests/test_norm.py index fa2da2f082a..aef0a074059 100644 --- a/tests/layer_tests/pytorch_tests/test_norm.py +++ b/tests/layer_tests/pytorch_tests/test_norm.py @@ -253,11 +253,11 @@ class TestLinalgMatrixNorm(PytorchLayerTest): class TestLinalgNorm(PytorchLayerTest): - def _prepare_input(self, out=False, out_dtype=None): + def _prepare_input(self, out=False, out_dtype=None, input_shape=(3, 3)): if not out: - return (np.random.randn(3, 3).astype(np.float32),) - x = np.random.randn(3, 3).astype(np.float32) - y = np.random.randn(3, 3).astype( + return (np.random.randn(*input_shape).astype(np.float32),) + x = np.random.randn(*input_shape).astype(np.float32) + y = np.random.randn(*input_shape).astype( out_dtype if out_dtype is not None else np.float32) return (x, y) @@ -318,7 +318,12 @@ class TestLinalgNorm(PytorchLayerTest): @pytest.mark.parametrize("dtype", ["float32", "float64", None]) @pytest.mark.parametrize("out", [True, False]) @pytest.mark.parametrize("prim_dtype", [True, False]) - def test_linalg_norm(self, p, dim, keepdim, dtype, out, prim_dtype, ie_device, precision, ir_version): + @pytest.mark.parametrize("input_shape", [[1, 3], [3, 3], [1, 3, 3]]) + def test_linalg_norm(self, p, dim, keepdim, dtype, out, prim_dtype, input_shape, ie_device, precision, ir_version): self._test(*self.create_model(p, dim, keepdim, dtype, out, prim_dtype), ie_device, precision, ir_version, - kwargs_to_prepare_input={"out": out or prim_dtype, "out_dtype": dtype if prim_dtype else None}) + kwargs_to_prepare_input={ + "out": out or prim_dtype, + "out_dtype": dtype if prim_dtype else None, + "input_shape": input_shape + }) diff --git a/tests/layer_tests/pytorch_tests/test_where.py b/tests/layer_tests/pytorch_tests/test_where.py index 20d9fa1d19b..b87f3794f76 100644 --- a/tests/layer_tests/pytorch_tests/test_where.py +++ b/tests/layer_tests/pytorch_tests/test_where.py @@ -8,7 +8,7 @@ from pytorch_layer_test_class import PytorchLayerTest class Testwhere(PytorchLayerTest): - def _prepare_input(self, mask_fill='ones', mask_dtype=bool, return_x_y=False): + def _prepare_input(self, mask_fill='ones', mask_dtype=bool, return_x_y=False, x_dtype="float32", y_dtype=None): input_shape = [2, 10] mask = np.zeros(input_shape).astype(mask_dtype) if mask_fill == 'ones': @@ -16,16 +16,31 @@ class Testwhere(PytorchLayerTest): if mask_fill == 'random': idx = np.random.choice(10, 5) mask[:, idx] = 1 - x = np.random.randn(*input_shape) - y = np.random.randn(*input_shape) + x = np.random.randn(*input_shape).astype(x_dtype) + y = np.random.randn(*input_shape).astype(y_dtype or x_dtype) return (mask,) if not return_x_y else (mask, x, y) - def create_model(self, as_non_zero): + def create_model(self, as_non_zero, dtypes=None): import torch + dtype_map = { + "float32": torch.float32, + "int32": torch.int32 + } + + torch_dtypes = None + if dtypes: + torch_dtypes = (dtype_map[dtypes[0]], dtype_map[dtypes[1]]) + class aten_where(torch.nn.Module): + def __init__(self, dtypes) -> None: + super().__init__() + self.x_dtype = dtypes[0] + self.y_dtype = dtypes[1] + + def forward(self, cond, x, y): - return torch.where(cond, x, y) + return torch.where(cond, x.to(self.x_dtype), y.to(self.y_dtype)) class aten_where_as_nonzero(torch.nn.Module): def forward(self, cond): @@ -35,25 +50,39 @@ class Testwhere(PytorchLayerTest): if as_non_zero: return aten_where_as_nonzero(), ref_net, "aten::where" - return aten_where(), ref_net, "aten::where" + return aten_where(torch_dtypes), ref_net, "aten::where" @pytest.mark.parametrize( "mask_fill", ['zeros', 'ones', 'random']) @pytest.mark.parametrize("mask_dtype", [np.uint8, bool]) # np.float32 incorrectly casted to bool + @pytest.mark.parametrize("x_dtype", ["float32", "int32"]) + @pytest.mark.parametrize("y_dtype", ["float32", "int32"]) @pytest.mark.nightly @pytest.mark.precommit - def test_where(self, mask_fill, mask_dtype, ie_device, precision, ir_version): - self._test(*self.create_model(False), + def test_where(self, mask_fill, mask_dtype, x_dtype, y_dtype, ie_device, precision, ir_version): + self._test(*self.create_model(False, dtypes=(x_dtype, y_dtype)), ie_device, precision, ir_version, - kwargs_to_prepare_input={'mask_fill': mask_fill, 'mask_dtype': mask_dtype, 'return_x_y': True}) + kwargs_to_prepare_input={ + 'mask_fill': mask_fill, + 'mask_dtype': mask_dtype, + 'return_x_y': True, + "x_dtype": x_dtype, + "y_dtype": y_dtype + }) @pytest.mark.parametrize( "mask_fill", ['zeros', 'ones', 'random']) @pytest.mark.parametrize("mask_dtype", [np.uint8, bool]) # np.float32 incorrectly casted to bool + @pytest.mark.parametrize("x_dtype", ["float32", "int32"]) @pytest.mark.nightly @pytest.mark.precommit - def test_where_as_nonzero(self, mask_fill, mask_dtype, ie_device, precision, ir_version): + def test_where_as_nonzero(self, mask_fill, mask_dtype, x_dtype, ie_device, precision, ir_version): self._test(*self.create_model(True), ie_device, precision, ir_version, - kwargs_to_prepare_input={'mask_fill': mask_fill, 'mask_dtype': mask_dtype, 'return_x_y': False}, + kwargs_to_prepare_input={ + 'mask_fill': mask_fill, + 'mask_dtype': mask_dtype, + 'return_x_y': False, + "x_dtype": x_dtype, + }, trace_model=True)