From bf760b663ef16e9734cca8a0840fe4631a1228e9 Mon Sep 17 00:00:00 2001 From: Ekaterina Aidova Date: Fri, 1 Dec 2023 17:59:03 +0400 Subject: [PATCH] PT FE - aten::alias, aten::alias_copy, aten::cross, aten::linalg_cross (#21265) * [PT FE]: support aten::alias, aten::alias_copy, aten::cross, aten::linalg_cross * add type alignment * fix code style --- src/frontends/pytorch/src/op/copy.cpp | 11 ++ src/frontends/pytorch/src/op/cross.cpp | 101 +++++++++++++++ src/frontends/pytorch/src/op_table.cpp | 7 + tests/layer_tests/pytorch_tests/test_copy.py | 33 +++++ tests/layer_tests/pytorch_tests/test_cross.py | 122 ++++++++++++++++++ 5 files changed, 274 insertions(+) create mode 100644 src/frontends/pytorch/src/op/cross.cpp create mode 100644 tests/layer_tests/pytorch_tests/test_cross.py diff --git a/src/frontends/pytorch/src/op/copy.cpp b/src/frontends/pytorch/src/op/copy.cpp index 271d06ffd92..f4cb61f83cb 100644 --- a/src/frontends/pytorch/src/op/copy.cpp +++ b/src/frontends/pytorch/src/op/copy.cpp @@ -29,6 +29,17 @@ OutputVector translate_copy_(const NodeContext& context) { return {res}; }; +OutputVector translate_alias_copy(const NodeContext& context) { + // aten::alias_copy(Tensor self) -> Tensor + // aten::alias_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + num_inputs_check(context, 1, 2); + auto self = context.get_input(0); + if (!context.input_is_none(1)) { + context.mutate_input(1, self); + } + return {self}; +} + } // namespace op } // namespace pytorch } // namespace frontend diff --git a/src/frontends/pytorch/src/op/cross.cpp b/src/frontends/pytorch/src/op/cross.cpp new file mode 100644 index 00000000000..06392a14e0e --- /dev/null +++ b/src/frontends/pytorch/src/op/cross.cpp @@ -0,0 +1,101 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/multiply.hpp" +#include "openvino/op/roll.hpp" +#include "openvino/op/subtract.hpp" +#include "openvino/op/unsqueeze.hpp" +#include "utils.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { +namespace op { + +using namespace ov::op; + +namespace { +Output translate_cross_base(const NodeContext& context, Output self, Output other, Output dim) { + auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1})); + auto const_2 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {2})); + auto x_roll_1 = context.mark_node(std::make_shared(self, const_2, dim)); + auto x_roll_2 = context.mark_node(std::make_shared(self, const_1, dim)); + auto y_roll_1 = context.mark_node(std::make_shared(other, const_1, dim)); + auto y_roll_2 = context.mark_node(std::make_shared(other, const_2, dim)); + auto mul_1 = context.mark_node(std::make_shared(x_roll_1, y_roll_1)); + auto mul_2 = context.mark_node(std::make_shared(x_roll_2, y_roll_2)); + return context.mark_node(std::make_shared(mul_1, mul_2)); +} + +} // namespace +OutputVector translate_linalg_cross(const NodeContext& context) { + // aten::linalg_cross(Tensor self, Tensor other, int? dim=-1) -> Tensor + // aten::linalg_cross.out(Tensor self, Tensor other, int? dim=-1, *, Tensor(a!) out) -> Tensor(a!) + num_inputs_check(context, 3, 4); + auto self = context.get_input(0); + auto other = context.get_input(1); + align_eltwise_input_types(context, self, other, true); + auto const_minus_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1})); + Output dim; + if (context.input_is_none(2)) { + dim = const_minus_1; + } else { + dim = context.get_input(2); + auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0})); + dim = context.mark_node(std::make_shared(dim, const_0)); + } + auto res = translate_cross_base(context, self, other, dim); + if (!context.input_is_none(3)) { + context.mutate_input(3, res); + } + + return {res}; +}; + +OutputVector translate_cross(const NodeContext& context) { + // aten::cross(Tensor self, Tensor other, int? dim=None) -> Tensor + // aten::cross.out(Tensor self, Tensor other, int? dim=None, *, Tensor(a!) out) -> Tensor(a!) + num_inputs_check(context, 3, 4); + auto self = context.get_input(0); + auto other = context.get_input(1); + align_eltwise_input_types(context, self, other, true); + Output dim; + if (context.input_is_none(2)) { + // If dim is not given, it defaults to the first dimension found with the size 3 + auto pshape = self.get_partial_shape(); + if (pshape.rank().is_dynamic()) { + FRONT_END_GENERAL_CHECK(false, "Rank should be known for aten::cross without explicit dim"); + } + size_t dim_id = static_cast(pshape.rank().get_length()); + size_t rank = static_cast(pshape.rank().get_length()); + for (size_t i = 0; i < rank; i++) { + if (pshape[i].is_static() && pshape[i] == ov::Dimension(3)) { + dim_id = i; + break; + } + } + if (dim_id == rank) { + FRONT_END_GENERAL_CHECK(false, "Suitable dim for aten::cross not found"); + } + dim = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {dim_id})); + + } else { + dim = context.get_input(2); + auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0})); + dim = context.mark_node(std::make_shared(dim, const_0)); + } + auto res = translate_cross_base(context, self, other, dim); + if (!context.input_is_none(3)) { + context.mutate_input(3, res); + } + + return {res}; +}; + +} // namespace op +} // namespace pytorch +} // namespace frontend +} // namespace ov \ No newline at end of file diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index 45c61f84424..230b2b4d06c 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -26,6 +26,7 @@ OP_CONVERTER(translate_add); OP_CONVERTER(translate_add_); OP_CONVERTER(translate_addcmul); OP_CONVERTER(translate_addmm); +OP_CONVERTER(translate_alias_copy); OP_CONVERTER(translate_all); OP_CONVERTER(translate_amax); OP_CONVERTER(translate_amin); @@ -54,6 +55,7 @@ OP_CONVERTER(translate_convnd); OP_CONVERTER(translate_convolution); OP_CONVERTER(translate_convolution_mode); OP_CONVERTER(translate_copy_); +OP_CONVERTER(translate_cross); OP_CONVERTER(translate_cumsum); OP_CONVERTER(translate_deform_conv); OP_CONVERTER(translate_derive_index); @@ -97,6 +99,7 @@ OP_CONVERTER(translate_int); OP_CONVERTER(translate_is_nonzero); OP_CONVERTER(translate_layer_norm); OP_CONVERTER(translate_len); +OP_CONVERTER(translate_linalg_cross); OP_CONVERTER(translate_linalg_norm); OP_CONVERTER(translate_linalg_matrix_norm); OP_CONVERTER(translate_linalg_vector_norm); @@ -266,6 +269,8 @@ const std::map get_supported_ops_ts() { {"aten::add_", op::translate_add_}, {"aten::addcmul", op::translate_addcmul}, {"aten::addmm", op::translate_addmm}, + {"aten::alias", op::skip_node}, + {"aten::alias_copy", op::translate_alias_copy}, {"aten::all", op::translate_all}, {"aten::amax", op::translate_amax}, {"aten::amin", op::translate_amin}, @@ -323,6 +328,7 @@ const std::map get_supported_ops_ts() { {"aten::cos_", op::inplace_op>}, {"aten::cosh", op::translate_1to1_match_1_inputs_with_fp32_type_alignment}, {"aten::cosh_", op::inplace_op>}, + {"aten::cross", op::translate_cross}, {"aten::cumsum", op::translate_cumsum}, {"aten::detach", op::skip_node}, {"aten::dequantize", op::skip_node}, // we convert model to fp32 using FQ, so dequantization is not needed @@ -390,6 +396,7 @@ const std::map get_supported_ops_ts() { {"aten::lift", op::skip_node}, {"aten::lift_fresh", op::skip_node}, {"aten::lift_fresh_copy", op::skip_node}, + {"aten::linalg_cross", op::translate_linalg_cross}, {"aten::linalg_norm", op::translate_linalg_norm}, {"aten::linalg_matrix_norm", op::translate_linalg_matrix_norm}, {"aten::linalg_vector_norm", op::translate_linalg_vector_norm}, diff --git a/tests/layer_tests/pytorch_tests/test_copy.py b/tests/layer_tests/pytorch_tests/test_copy.py index b78af602712..6b4969c0277 100644 --- a/tests/layer_tests/pytorch_tests/test_copy.py +++ b/tests/layer_tests/pytorch_tests/test_copy.py @@ -31,3 +31,36 @@ class TestCopy(PytorchLayerTest): @pytest.mark.parametrize("value", [1, [2.5], range(224)]) def test_copy_(self, value, ie_device, precision, ir_version): self._test(*self.create_model(value), ie_device, precision, ir_version) + + +class TestAliasCopy(PytorchLayerTest): + def _prepare_input(self, out): + import numpy as np + if not out: + return (np.random.randn(1, 3, 224, 224).astype(np.float32),) + return (np.random.randn(1, 3, 224, 224).astype(np.float32), np.zeros((1, 3, 224, 224), dtype=np.float32)) + + def create_model(self, out): + import torch + + class aten_copy(torch.nn.Module): + def __init__(self, out): + super(aten_copy, self).__init__() + if out: + self.forward = self.forward_out + + def forward(self, x): + return torch.alias_copy(x) + + def forward_out(self, x, y): + return torch.alias_copy(x, out=y), y + + ref_net = None + + return aten_copy(out), ref_net, "aten::alias_copy" + + @pytest.mark.nightly + @pytest.mark.precommit + @pytest.mark.parametrize("out", [True, False]) + def test_copy_(self, out, ie_device, precision, ir_version): + self._test(*self.create_model(out), ie_device, precision, ir_version, kwargs_to_prepare_input={"out": out}) \ No newline at end of file diff --git a/tests/layer_tests/pytorch_tests/test_cross.py b/tests/layer_tests/pytorch_tests/test_cross.py new file mode 100644 index 00000000000..ff953881138 --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_cross.py @@ -0,0 +1,122 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from pytorch_layer_test_class import PytorchLayerTest + + +class TestLinalgCross(PytorchLayerTest): + def _prepare_input(self, x_shape, y_shape, out, dtype): + import numpy as np + x = np.random.randn(*x_shape).astype(dtype) + y = np.random.randn(*y_shape).astype(dtype) + if not out: + return (x, y) + return (x, y, np.zeros(np.maximum(np.array(x_shape), np.array(y_shape)).tolist(), dtype=dtype)) + + def create_model(self, dim, out): + import torch + + class aten_linalg_cross(torch.nn.Module): + def __init__(self, dim, out): + super(aten_linalg_cross, self).__init__() + if dim is None: + self.forward = self.forward_no_dim_no_out if not out else self.forward_no_dim_out + elif out: + self.forward = self.forward_out + self.dim = dim + + def forward(self, x, y): + return torch.linalg.cross(x, y, dim=self.dim) + + def forward_out(self, x, y, out): + return torch.linalg.cross(x, y, dim=self.dim, out=out), out + + def forward_no_dim_out(self, x, y, out): + return torch.linalg.cross(x, y, out=out), out + + def forward_no_dim_no_out(self, x, y): + return torch.linalg.cross(x, y) + + ref_net = None + + return aten_linalg_cross(dim, out), ref_net, "aten::linalg_cross" + + @pytest.mark.nightly + @pytest.mark.precommit + @pytest.mark.parametrize("x_shape,y_shape,dim", [ + ((4, 3), (4, 3), None), + ((1, 3), (4, 3), -1), + ((4, 3), (1, 3), 1), + ((3, 5), (3, 5), 0), + ((2, 3, 4), (2, 3, 4), 1) + ]) + @pytest.mark.parametrize('dtype', ['float32', 'float64']) + @pytest.mark.parametrize("out", [True, False]) + def test_linalg_cross(self, x_shape, y_shape, dim, out, dtype, ie_device, precision, ir_version): + self._test( + *self.create_model(dim, out), ie_device, precision, ir_version, use_convert_model=True, + kwargs_to_prepare_input={"x_shape":x_shape, "y_shape": y_shape, "out": out, 'dtype': dtype}) + + +class TestCross(PytorchLayerTest): + def _prepare_input(self, x_shape, y_shape, out, dtype): + import numpy as np + x = np.random.randn(*x_shape).astype(dtype) + y = np.random.randn(*y_shape).astype(dtype) + if not out: + return (x, y) + return (x, y, np.zeros(np.maximum(np.array(x_shape), np.array(y_shape)).tolist(), dtype=dtype)) + + def create_model(self, dim, out, shape): + import torch + + class aten_cross(torch.nn.Module): + def __init__(self, dim, out, shape): + super(aten_cross, self).__init__() + if dim is None: + self.forward = self.forward_no_dim_no_out if not out else self.forward_no_dim_out + elif out: + self.forward = self.forward_out + self.dim = dim + self.shape = shape + + def forward(self, x, y): + return torch.cross(x, y, dim=self.dim) + + def forward_out(self, x, y, out): + return torch.cross(x, y, dim=self.dim, out=out), out + + def forward_no_dim_out(self, x, y, out): + x = torch.reshape(x, self.shape) + return torch.cross(x, y, out=out), out + + def forward_no_dim_no_out(self, x, y): + x = torch.reshape(x, self.shape) + return torch.cross(x, y) + + ref_net = None + + return aten_cross(dim, out, shape), ref_net, "aten::cross" + + @pytest.mark.nightly + @pytest.mark.precommit + @pytest.mark.parametrize("x_shape,y_shape,dim", [ + ((1, 3), (4, 3), -1), + ((4, 3), (1, 3), 1), + ((3, 5), (3, 5), 0), + ((2, 3, 4), (2, 3, 4), 1), + ((3, 1), (3, 4), None), + ((4, 3), (4, 3), None), + ((2, 3, 4), (2, 3, 4), None), + ]) + @pytest.mark.parametrize("out", [True, False]) + @pytest.mark.parametrize('dtype', ['float32', 'float64']) + def test_linalg_cross(self, x_shape, y_shape, dim, out, dtype, ie_device, precision, ir_version): + self._test(*self.create_model(dim, out, x_shape), ie_device, precision, ir_version, + use_convert_model=True, + kwargs_to_prepare_input={"x_shape":x_shape, "y_shape": y_shape, "out": out, "dtype": dtype}) \ No newline at end of file