[PT FE]: support aten::log1p, fixes for where and linalg_norm (#20167)

* [PT FE]: support aten::log1p, fixes for where and linalg_norm

* clarify norm behaviour
Ekaterina Aidova 2023-10-06 12:26:12 +04:00 committed by Alexander Nesterov
parent aba9956e3b
commit 5ba7d9b72d
7 changed files with 75 additions and 20 deletions


@@ -5,6 +5,7 @@
#include "openvino/op/log.hpp"
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/divide.hpp"
@@ -55,6 +56,17 @@ OutputVector translate_logsumexp(const NodeContext& context) {
return {log};
};
+OutputVector translate_log1p(const NodeContext& context) {
+// torch.log1p returns a new tensor with the natural logarithm of (1 + input), computed elementwise.
+num_inputs_check(context, 1, 1);
+auto x = context.get_input(0);
+x = context.mark_node(std::make_shared<v0::Convert>(x, element::f32));
+auto one = context.mark_node(v0::Constant::create(element::f32, Shape{}, {1}));
+auto x_plus_one = context.mark_node(std::make_shared<v1::Add>(x, one));
+auto log = context.mark_node(std::make_shared<v0::Log>(x_plus_one));
+return {log};
+};
} // namespace op
} // namespace pytorch
} // namespace frontend
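The new translate_log1p above decomposes log1p into Convert(f32) -> Add(1) -> Log. A minimal sanity check of that equivalence, purely illustrative and not part of the patch:

```python
import torch

# torch.log1p(x) computes log(1 + x); the translator lowers it to
# Convert(f32) -> Add(1) -> Log, which matches for int and float inputs.
x_int = torch.tensor([0, 1, 2, 3], dtype=torch.int32)
x_f32 = torch.tensor([0.25, 0.5, 2.0], dtype=torch.float32)

for x in (x_int, x_f32):
    decomposed = torch.log(x.to(torch.float32) + 1.0)
    assert torch.allclose(torch.log1p(x), decomposed)

# Caveat: a native log1p is more accurate than log(x + 1) for x very close
# to zero, so the decomposition gives up some precision in that regime.
```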


@@ -259,7 +259,7 @@ OutputVector translate_linalg_norm(const NodeContext& context) {
auto input_rank = x.get_partial_shape().rank();
if (input_rank.is_static() && input_rank.get_length() == 2) {
result = frobenius_norm(context, x, dim, keep_dim);
-} else if (input_rank.is_static() && input_rank.get_length() == 1) {
+} else if (input_rank.is_dynamic() || input_rank.get_length() == 1) {
result = norm_vector(context, x, dim, 2, keep_dim);
} else {
FRONT_END_OP_CONVERSION_CHECK(false,
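The relaxed condition lets inputs with dynamic rank fall through to the vector 2-norm instead of failing conversion. This matches the torch.linalg.norm default: with ord=None and dim=None the input is flattened and its 2-norm is returned, which for a matrix coincides with the Frobenius norm. A quick illustration, not from the patch:

```python
import torch

a = torch.randn(3, 3)

# Default linalg.norm (ord=None, dim=None) is the 2-norm of the flattened input.
assert torch.allclose(torch.linalg.norm(a), torch.linalg.norm(a.flatten(), ord=2))

# For a 2-D input that value equals the Frobenius norm, so sending rank-2 inputs
# to frobenius_norm and everything else (including dynamic rank) to the vector
# 2-norm gives the same default result.
assert torch.allclose(torch.linalg.norm(a), torch.linalg.matrix_norm(a, ord="fro"))
```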


@@ -21,6 +21,7 @@ OutputVector translate_where(const NodeContext& context) {
auto bool_cond = context.mark_node(std::make_shared<v0::Convert>(cond, element::boolean));
auto x = context.get_input(1);
auto y = context.get_input(2);
+align_eltwise_input_types(context, x, y, true);
return {context.mark_node(std::make_shared<v1::Select>(bool_cond, x, y))};
};
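The added align_eltwise_input_types call promotes x and y to a common element type before Select, mirroring PyTorch's type promotion in torch.where (OpenVINO's Select expects matching 'then'/'else' types). For example:

```python
import torch

cond = torch.tensor([True, False, True])
x = torch.tensor([1, 2, 3], dtype=torch.int32)
y = torch.tensor([0.5, 1.5, 2.5], dtype=torch.float32)

# PyTorch promotes the int32 branch to float32 before selecting, so the
# converted graph has to align the branch types the same way.
out = torch.where(cond, x, y)
assert out.dtype == torch.float32
```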


@@ -89,6 +89,7 @@ OP_CONVERTER(translate_linspace);
OP_CONVERTER(translate_list_construct);
OP_CONVERTER(translate_list_unpack);
OP_CONVERTER(translate_log);
+OP_CONVERTER(translate_log1p);
OP_CONVERTER(translate_log_softmax);
OP_CONVERTER(translate_log2);
OP_CONVERTER(translate_logsumexp);
@@ -353,6 +354,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::logical_not", op::translate_not},
{"aten::logical_xor", op::translate_xor},
{"aten::log_softmax", op::translate_log_softmax},
{"aten::log1p", op::translate_log1p},
{"aten::log1p_", op::inplace_op<op::translate_log1p>},
{"aten::log2", op::translate_log2},
{"aten::log2_", op::inplace_op<op::translate_log2>},
{"aten::lt", op::translate_1to1_match_2_inputs_align_types<opset10::Less>},


@@ -17,7 +17,9 @@ class TestLog(PytorchLayerTest):
"log": torch.log,
"log_": torch.log_,
"log2": torch.log2,
"log2_": torch.log2_
"log2_": torch.log2_,
"log1p": torch.log1p,
"log1p_": torch.log1p_
}
op_fn = ops[op]
@@ -42,7 +44,10 @@ class TestLog(PytorchLayerTest):
["log_", "float32"],
["log2", "float32"],
["log2", "int32"],
["log2_", "float32"]])
["log2_", "float32"],
["log1p", "float32"],
["log1p", "int32"],
["log1p_", "float32"]])
def test_log(self, op, input_dtype, ie_device, precision, ir_version):
self._test(*self.create_model(op), ie_device, precision,
ir_version, kwargs_to_prepare_input={"dtype": input_dtype})


@@ -253,11 +253,11 @@ class TestLinalgMatrixNorm(PytorchLayerTest):
class TestLinalgNorm(PytorchLayerTest):
-def _prepare_input(self, out=False, out_dtype=None):
+def _prepare_input(self, out=False, out_dtype=None, input_shape=(3, 3)):
if not out:
-return (np.random.randn(3, 3).astype(np.float32),)
-x = np.random.randn(3, 3).astype(np.float32)
-y = np.random.randn(3, 3).astype(
+return (np.random.randn(*input_shape).astype(np.float32),)
+x = np.random.randn(*input_shape).astype(np.float32)
+y = np.random.randn(*input_shape).astype(
out_dtype if out_dtype is not None else np.float32)
return (x, y)
@@ -318,7 +318,12 @@ class TestLinalgNorm(PytorchLayerTest):
@pytest.mark.parametrize("dtype", ["float32", "float64", None])
@pytest.mark.parametrize("out", [True, False])
@pytest.mark.parametrize("prim_dtype", [True, False])
-def test_linalg_norm(self, p, dim, keepdim, dtype, out, prim_dtype, ie_device, precision, ir_version):
+@pytest.mark.parametrize("input_shape", [[1, 3], [3, 3], [1, 3, 3]])
+def test_linalg_norm(self, p, dim, keepdim, dtype, out, prim_dtype, input_shape, ie_device, precision, ir_version):
self._test(*self.create_model(p, dim, keepdim, dtype, out, prim_dtype),
ie_device, precision, ir_version,
kwargs_to_prepare_input={"out": out or prim_dtype, "out_dtype": dtype if prim_dtype else None})
kwargs_to_prepare_input={
"out": out or prim_dtype,
"out_dtype": dtype if prim_dtype else None,
"input_shape": input_shape
})


@@ -8,7 +8,7 @@ from pytorch_layer_test_class import PytorchLayerTest
class Testwhere(PytorchLayerTest):
-def _prepare_input(self, mask_fill='ones', mask_dtype=bool, return_x_y=False):
+def _prepare_input(self, mask_fill='ones', mask_dtype=bool, return_x_y=False, x_dtype="float32", y_dtype=None):
input_shape = [2, 10]
mask = np.zeros(input_shape).astype(mask_dtype)
if mask_fill == 'ones':
@@ -16,16 +16,31 @@ class Testwhere(PytorchLayerTest):
if mask_fill == 'random':
idx = np.random.choice(10, 5)
mask[:, idx] = 1
-x = np.random.randn(*input_shape)
-y = np.random.randn(*input_shape)
+x = np.random.randn(*input_shape).astype(x_dtype)
+y = np.random.randn(*input_shape).astype(y_dtype or x_dtype)
return (mask,) if not return_x_y else (mask, x, y)
-def create_model(self, as_non_zero):
+def create_model(self, as_non_zero, dtypes=None):
import torch
+dtype_map = {
+"float32": torch.float32,
+"int32": torch.int32
+}
+torch_dtypes = None
+if dtypes:
+torch_dtypes = (dtype_map[dtypes[0]], dtype_map[dtypes[1]])
class aten_where(torch.nn.Module):
+def __init__(self, dtypes) -> None:
+super().__init__()
+self.x_dtype = dtypes[0]
+self.y_dtype = dtypes[1]
def forward(self, cond, x, y):
-return torch.where(cond, x, y)
+return torch.where(cond, x.to(self.x_dtype), y.to(self.y_dtype))
class aten_where_as_nonzero(torch.nn.Module):
def forward(self, cond):
@@ -35,25 +50,39 @@ class Testwhere(PytorchLayerTest):
if as_non_zero:
return aten_where_as_nonzero(), ref_net, "aten::where"
-return aten_where(), ref_net, "aten::where"
+return aten_where(torch_dtypes), ref_net, "aten::where"
@pytest.mark.parametrize(
"mask_fill", ['zeros', 'ones', 'random'])
@pytest.mark.parametrize("mask_dtype", [np.uint8, bool]) # np.float32 incorrectly casted to bool
@pytest.mark.parametrize("x_dtype", ["float32", "int32"])
@pytest.mark.parametrize("y_dtype", ["float32", "int32"])
@pytest.mark.nightly
@pytest.mark.precommit
-def test_where(self, mask_fill, mask_dtype, ie_device, precision, ir_version):
-self._test(*self.create_model(False),
+def test_where(self, mask_fill, mask_dtype, x_dtype, y_dtype, ie_device, precision, ir_version):
+self._test(*self.create_model(False, dtypes=(x_dtype, y_dtype)),
ie_device, precision, ir_version,
-kwargs_to_prepare_input={'mask_fill': mask_fill, 'mask_dtype': mask_dtype, 'return_x_y': True})
+kwargs_to_prepare_input={
+'mask_fill': mask_fill,
+'mask_dtype': mask_dtype,
+'return_x_y': True,
+"x_dtype": x_dtype,
+"y_dtype": y_dtype
+})
@pytest.mark.parametrize(
"mask_fill", ['zeros', 'ones', 'random'])
@pytest.mark.parametrize("mask_dtype", [np.uint8, bool]) # np.float32 incorrectly casted to bool
@pytest.mark.parametrize("x_dtype", ["float32", "int32"])
@pytest.mark.nightly
@pytest.mark.precommit
-def test_where_as_nonzero(self, mask_fill, mask_dtype, ie_device, precision, ir_version):
+def test_where_as_nonzero(self, mask_fill, mask_dtype, x_dtype, ie_device, precision, ir_version):
self._test(*self.create_model(True),
ie_device, precision, ir_version,
-kwargs_to_prepare_input={'mask_fill': mask_fill, 'mask_dtype': mask_dtype, 'return_x_y': False},
+kwargs_to_prepare_input={
+'mask_fill': mask_fill,
+'mask_dtype': mask_dtype,
+'return_x_y': False,
+"x_dtype": x_dtype,
+},
trace_model=True)