PT FE - aten::alias, aten::alias_copy, aten::cross, aten::linalg_cross (#21265)

* [PT FE]: support aten::alias, aten::alias_copy, aten::cross, aten::linalg_cross

* add type alignment

* fix code style
Ekaterina Aidova 2023-12-01 17:59:03 +04:00 committed by GitHub
parent 0e642e984b
commit bf760b663e
5 changed files with 274 additions and 0 deletions


@ -29,6 +29,17 @@ OutputVector translate_copy_(const NodeContext& context) {
    return {res};
};

OutputVector translate_alias_copy(const NodeContext& context) {
    // aten::alias_copy(Tensor self) -> Tensor
    // aten::alias_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
    num_inputs_check(context, 1, 2);
    auto self = context.get_input(0);
    if (!context.input_is_none(1)) {
        context.mutate_input(1, self);
    }
    return {self};
}
} // namespace op
} // namespace pytorch
} // namespace frontend
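
Since the converted graph is functional, aten::alias can simply be skipped and aten::alias_copy can forward its input unchanged, as done above. A minimal sketch of the behaviour this relies on from the PyTorch side (illustrative only, not part of the change):

import torch

# torch.alias_copy returns a tensor with the same values as its input,
# so mapping the op onto its input is value-preserving in the converted graph.
x = torch.randn(2, 3)
y = torch.alias_copy(x)
assert torch.equal(x, y)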


@ -0,0 +1,101 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/roll.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace op {
using namespace ov::op;
namespace {
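// Shared helper for aten::cross and aten::linalg_cross: computes the cross product along `dim`
// (an axis of size 3) using the roll identity
//   cross(a, b) = roll(a, 2) * roll(b, 1) - roll(a, 1) * roll(b, 2),
// where roll(t, s) shifts elements of `t` by `s` positions along `dim`.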
Output<Node> translate_cross_base(const NodeContext& context,
                                  Output<Node> self,
                                  Output<Node> other,
                                  Output<Node> dim) {
    auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
    auto const_2 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {2}));
    auto x_roll_1 = context.mark_node(std::make_shared<v7::Roll>(self, const_2, dim));
    auto x_roll_2 = context.mark_node(std::make_shared<v7::Roll>(self, const_1, dim));
    auto y_roll_1 = context.mark_node(std::make_shared<v7::Roll>(other, const_1, dim));
    auto y_roll_2 = context.mark_node(std::make_shared<v7::Roll>(other, const_2, dim));
    auto mul_1 = context.mark_node(std::make_shared<v1::Multiply>(x_roll_1, y_roll_1));
    auto mul_2 = context.mark_node(std::make_shared<v1::Multiply>(x_roll_2, y_roll_2));
    return context.mark_node(std::make_shared<v1::Subtract>(mul_1, mul_2));
}
} // namespace

OutputVector translate_linalg_cross(const NodeContext& context) {
    // aten::linalg_cross(Tensor self, Tensor other, int? dim=-1) -> Tensor
    // aten::linalg_cross.out(Tensor self, Tensor other, int? dim=-1, *, Tensor(a!) out) -> Tensor(a!)
    num_inputs_check(context, 3, 4);
    auto self = context.get_input(0);
    auto other = context.get_input(1);
    align_eltwise_input_types(context, self, other, true);
    auto const_minus_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
    Output<Node> dim;
    if (context.input_is_none(2)) {
        dim = const_minus_1;
    } else {
        dim = context.get_input(2);
        auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
        dim = context.mark_node(std::make_shared<v0::Unsqueeze>(dim, const_0));
    }
    auto res = translate_cross_base(context, self, other, dim);
    if (!context.input_is_none(3)) {
        context.mutate_input(3, res);
    }
    return {res};
};

OutputVector translate_cross(const NodeContext& context) {
    // aten::cross(Tensor self, Tensor other, int? dim=None) -> Tensor
    // aten::cross.out(Tensor self, Tensor other, int? dim=None, *, Tensor(a!) out) -> Tensor(a!)
    num_inputs_check(context, 3, 4);
    auto self = context.get_input(0);
    auto other = context.get_input(1);
    align_eltwise_input_types(context, self, other, true);
    Output<Node> dim;
    if (context.input_is_none(2)) {
        // If dim is not given, it defaults to the first dimension found with the size 3
        auto pshape = self.get_partial_shape();
        if (pshape.rank().is_dynamic()) {
            FRONT_END_GENERAL_CHECK(false, "Rank should be known for aten::cross without explicit dim");
        }
        size_t dim_id = static_cast<size_t>(pshape.rank().get_length());
        size_t rank = static_cast<size_t>(pshape.rank().get_length());
        for (size_t i = 0; i < rank; i++) {
            if (pshape[i].is_static() && pshape[i] == ov::Dimension(3)) {
                dim_id = i;
                break;
            }
        }
        if (dim_id == rank) {
            FRONT_END_GENERAL_CHECK(false, "Suitable dim for aten::cross not found");
        }
        dim = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {dim_id}));
    } else {
        dim = context.get_input(2);
        auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
        dim = context.mark_node(std::make_shared<v0::Unsqueeze>(dim, const_0));
    }
    auto res = translate_cross_base(context, self, other, dim);
    if (!context.input_is_none(3)) {
        context.mutate_input(3, res);
    }
    return {res};
};

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
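
The roll identity used by translate_cross_base can be sanity-checked numerically. The snippet below is an illustrative check only (names and shapes are made up for the example); it relies on np.roll shifting elements toward higher indices along the chosen axis, matching the Roll-7 semantics assumed above:

import numpy as np

# cross(a, b) == roll(a, 2) * roll(b, 1) - roll(a, 1) * roll(b, 2) along the size-3 axis.
rng = np.random.default_rng(0)
a = rng.standard_normal((2, 3, 4))
b = rng.standard_normal((2, 3, 4))
dim = 1  # the first axis of size 3, i.e. what translate_cross picks when dim is None

expected = np.cross(a, b, axis=dim)
actual = np.roll(a, 2, axis=dim) * np.roll(b, 1, axis=dim) - np.roll(a, 1, axis=dim) * np.roll(b, 2, axis=dim)
assert np.allclose(expected, actual)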


@ -26,6 +26,7 @@ OP_CONVERTER(translate_add);
OP_CONVERTER(translate_add_);
OP_CONVERTER(translate_addcmul);
OP_CONVERTER(translate_addmm);
OP_CONVERTER(translate_alias_copy);
OP_CONVERTER(translate_all);
OP_CONVERTER(translate_amax);
OP_CONVERTER(translate_amin);
@ -54,6 +55,7 @@ OP_CONVERTER(translate_convnd);
OP_CONVERTER(translate_convolution);
OP_CONVERTER(translate_convolution_mode);
OP_CONVERTER(translate_copy_);
OP_CONVERTER(translate_cross);
OP_CONVERTER(translate_cumsum);
OP_CONVERTER(translate_deform_conv);
OP_CONVERTER(translate_derive_index);
@ -97,6 +99,7 @@ OP_CONVERTER(translate_int);
OP_CONVERTER(translate_is_nonzero);
OP_CONVERTER(translate_layer_norm);
OP_CONVERTER(translate_len);
OP_CONVERTER(translate_linalg_cross);
OP_CONVERTER(translate_linalg_norm);
OP_CONVERTER(translate_linalg_matrix_norm);
OP_CONVERTER(translate_linalg_vector_norm);
@ -266,6 +269,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::add_", op::translate_add_},
{"aten::addcmul", op::translate_addcmul},
{"aten::addmm", op::translate_addmm},
{"aten::alias", op::skip_node},
{"aten::alias_copy", op::translate_alias_copy},
{"aten::all", op::translate_all},
{"aten::amax", op::translate_amax},
{"aten::amin", op::translate_amin},
@ -323,6 +328,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::cos_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Cos>>},
{"aten::cosh", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Cosh>},
{"aten::cosh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Cosh>>},
{"aten::cross", op::translate_cross},
{"aten::cumsum", op::translate_cumsum},
{"aten::detach", op::skip_node},
{"aten::dequantize", op::skip_node}, // we convert model to fp32 using FQ, so dequantization is not needed
@ -390,6 +396,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::lift", op::skip_node},
{"aten::lift_fresh", op::skip_node},
{"aten::lift_fresh_copy", op::skip_node},
{"aten::linalg_cross", op::translate_linalg_cross},
{"aten::linalg_norm", op::translate_linalg_norm},
{"aten::linalg_matrix_norm", op::translate_linalg_matrix_norm},
{"aten::linalg_vector_norm", op::translate_linalg_vector_norm},


@ -31,3 +31,36 @@ class TestCopy(PytorchLayerTest):
@pytest.mark.parametrize("value", [1, [2.5], range(224)])
def test_copy_(self, value, ie_device, precision, ir_version):
self._test(*self.create_model(value), ie_device, precision, ir_version)
class TestAliasCopy(PytorchLayerTest):
def _prepare_input(self, out):
import numpy as np
if not out:
return (np.random.randn(1, 3, 224, 224).astype(np.float32),)
return (np.random.randn(1, 3, 224, 224).astype(np.float32), np.zeros((1, 3, 224, 224), dtype=np.float32))
def create_model(self, out):
import torch
class aten_copy(torch.nn.Module):
def __init__(self, out):
super(aten_copy, self).__init__()
if out:
self.forward = self.forward_out
def forward(self, x):
return torch.alias_copy(x)
def forward_out(self, x, y):
return torch.alias_copy(x, out=y), y
ref_net = None
return aten_copy(out), ref_net, "aten::alias_copy"
@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.parametrize("out", [True, False])
def test_copy_(self, out, ie_device, precision, ir_version):
self._test(*self.create_model(out), ie_device, precision, ir_version, kwargs_to_prepare_input={"out": out})


@ -0,0 +1,122 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import pytest
from pytorch_layer_test_class import PytorchLayerTest


class TestLinalgCross(PytorchLayerTest):
    def _prepare_input(self, x_shape, y_shape, out, dtype):
        import numpy as np

        x = np.random.randn(*x_shape).astype(dtype)
        y = np.random.randn(*y_shape).astype(dtype)
        if not out:
            return (x, y)
        return (x, y, np.zeros(np.maximum(np.array(x_shape), np.array(y_shape)).tolist(), dtype=dtype))

    def create_model(self, dim, out):
        import torch

        class aten_linalg_cross(torch.nn.Module):
            def __init__(self, dim, out):
                super(aten_linalg_cross, self).__init__()
                if dim is None:
                    self.forward = self.forward_no_dim_no_out if not out else self.forward_no_dim_out
                elif out:
                    self.forward = self.forward_out
                self.dim = dim

            def forward(self, x, y):
                return torch.linalg.cross(x, y, dim=self.dim)

            def forward_out(self, x, y, out):
                return torch.linalg.cross(x, y, dim=self.dim, out=out), out

            def forward_no_dim_out(self, x, y, out):
                return torch.linalg.cross(x, y, out=out), out

            def forward_no_dim_no_out(self, x, y):
                return torch.linalg.cross(x, y)

        ref_net = None
        return aten_linalg_cross(dim, out), ref_net, "aten::linalg_cross"

    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.parametrize("x_shape,y_shape,dim", [
        ((4, 3), (4, 3), None),
        ((1, 3), (4, 3), -1),
        ((4, 3), (1, 3), 1),
        ((3, 5), (3, 5), 0),
        ((2, 3, 4), (2, 3, 4), 1)
    ])
    @pytest.mark.parametrize('dtype', ['float32', 'float64'])
    @pytest.mark.parametrize("out", [True, False])
    def test_linalg_cross(self, x_shape, y_shape, dim, out, dtype, ie_device, precision, ir_version):
        self._test(
            *self.create_model(dim, out), ie_device, precision, ir_version, use_convert_model=True,
            kwargs_to_prepare_input={"x_shape": x_shape, "y_shape": y_shape, "out": out, 'dtype': dtype})


class TestCross(PytorchLayerTest):
    def _prepare_input(self, x_shape, y_shape, out, dtype):
        import numpy as np

        x = np.random.randn(*x_shape).astype(dtype)
        y = np.random.randn(*y_shape).astype(dtype)
        if not out:
            return (x, y)
        return (x, y, np.zeros(np.maximum(np.array(x_shape), np.array(y_shape)).tolist(), dtype=dtype))

    def create_model(self, dim, out, shape):
        import torch

        class aten_cross(torch.nn.Module):
            def __init__(self, dim, out, shape):
                super(aten_cross, self).__init__()
                if dim is None:
                    self.forward = self.forward_no_dim_no_out if not out else self.forward_no_dim_out
                elif out:
                    self.forward = self.forward_out
                self.dim = dim
                self.shape = shape

            def forward(self, x, y):
                return torch.cross(x, y, dim=self.dim)

            def forward_out(self, x, y, out):
                return torch.cross(x, y, dim=self.dim, out=out), out

            def forward_no_dim_out(self, x, y, out):
                x = torch.reshape(x, self.shape)
                return torch.cross(x, y, out=out), out

            def forward_no_dim_no_out(self, x, y):
                x = torch.reshape(x, self.shape)
                return torch.cross(x, y)

        ref_net = None
        return aten_cross(dim, out, shape), ref_net, "aten::cross"

    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.parametrize("x_shape,y_shape,dim", [
        ((1, 3), (4, 3), -1),
        ((4, 3), (1, 3), 1),
        ((3, 5), (3, 5), 0),
        ((2, 3, 4), (2, 3, 4), 1),
        ((3, 1), (3, 4), None),
        ((4, 3), (4, 3), None),
        ((2, 3, 4), (2, 3, 4), None),
    ])
    @pytest.mark.parametrize("out", [True, False])
    @pytest.mark.parametrize('dtype', ['float32', 'float64'])
    def test_linalg_cross(self, x_shape, y_shape, dim, out, dtype, ie_device, precision, ir_version):
        self._test(*self.create_model(dim, out, x_shape), ie_device, precision, ir_version,
                   use_convert_model=True,
                   kwargs_to_prepare_input={"x_shape": x_shape, "y_shape": y_shape, "out": out, "dtype": dtype})