[PT FE]: multiple fixes for models from optimum testing scope (#18501)

* [PT FE]: multiple fixes for models from optimum testing scope

* Update src/frontends/pytorch/src/op_table.cpp
This commit is contained in:
Ekaterina Aidova
2023-07-14 10:37:53 +04:00
committed by GitHub
parent 1d9be8c76e
commit 4c49040ce6
3 changed files with 134 additions and 0 deletions

View File

@@ -0,0 +1,42 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/op/round.hpp"
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/convert.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace op {
using namespace ov::op;
OutputVector translate_round(const NodeContext& context) {
// aten::round(Tensor self) -> Tensor
// aten::round.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
// aten::round.int(int a) -> float
// aten::round.float(float a) -> float
// aten::round.Scalar(Scalar a) -> Scalar
num_inputs_check(context, 1, 2);
auto data = context.get_input(0);
auto data_rank = data.get_partial_shape().rank();
auto is_scalar = data_rank.is_static() && data_rank.get_length() == 0;
auto is_integer = !data.get_element_type().is_dynamic() && data.get_element_type().is_integral();
if (is_scalar && is_integer) {
data = context.mark_node(std::make_shared<v0::Convert>(data, element::f32));
}
auto res = context.mark_node(std::make_shared<v5::Round>(data, v5::Round::RoundMode::HALF_TO_EVEN));
if (!context.input_is_none(1)) {
context.mutate_input(1, res);
}
return {res};
};
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@@ -116,6 +116,7 @@ OP_CONVERTER(translate_reshape);
OP_CONVERTER(translate_reshape_as);
OP_CONVERTER(translate_roi_align);
OP_CONVERTER(translate_roll);
OP_CONVERTER(translate_round);
OP_CONVERTER(translate_rsqrt);
OP_CONVERTER(translate_rsub);
OP_CONVERTER(translate_scaled_dot_product_attention);
@@ -280,6 +281,10 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::leaky_relu", op::translate_1to1_match_2_inputs<opset10::PRelu>},
{"aten::leaky_relu_", op::inplace_op<op::translate_1to1_match_2_inputs<opset10::PRelu>>},
{"aten::len", op::translate_len},
// lift op is torchscript specific op responsible for tensors coping with guarantee of new memory allocation
{"aten::lift", op::skip_node},
{"aten::lift_fresh", op::skip_node},
{"aten::lift_fresh_copy", op::skip_node},
{"aten::linalg_norm", op::translate_linalg_norm},
{"aten::linalg_matrix_norm", op::translate_linalg_matrix_norm},
{"aten::linalg_vector_norm", op::translate_linalg_vector_norm},
@@ -303,6 +308,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::mm", op::translate_1to1_match_2_inputs<opset10::MatMul>},
{"aten::mul", op::translate_1to1_match_2_inputs_align_types<opset10::Multiply>},
{"aten::mul_", op::inplace_op<op::translate_1to1_match_2_inputs_align_types<opset10::Multiply>>},
{"aten::multiply", op::translate_1to1_match_2_inputs_align_types<opset10::Multiply>},
{"aten::multiply_", op::inplace_op<op::translate_1to1_match_2_inputs_align_types<opset10::Multiply>>},
{"aten::narrow", op::translate_narrow},
{"aten::ne", op::translate_1to1_match_2_inputs_align_types<opset10::NotEqual>},
{"aten::neg", op::translate_neg},
@@ -329,6 +336,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::reshape", op::translate_reshape},
{"aten::reshape_as", op::translate_reshape_as},
{"aten::roll", op::translate_roll},
{"aten::round", op::translate_round},
{"aten::rsqrt", op::translate_rsqrt},
{"aten::rsub", op::translate_rsub},
{"aten::ScalarImplicit", op::skip_node},

View File

@@ -0,0 +1,84 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import pytest
from pytorch_layer_test_class import PytorchLayerTest
class TestRound(PytorchLayerTest):
def _prepare_input(self, out=False, dtype="float32"):
import numpy as np
input = np.random.randn(1, 3, 224, 224).astype(dtype)
if not out:
return (input, )
return (input, np.zeros_like(input))
def create_model(self, out=False):
import torch
class aten_round(torch.nn.Module):
def __init__(self, out):
super(aten_round, self).__init__()
if out:
self.forward = self.forward_out
def forward(self, x):
return torch.round(x)
def forward_out(self, x, y):
return torch.round(x, out=y), y
ref_net = None
return aten_round(out), ref_net, "aten::round"
@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.parametrize("out", [True, False])
@pytest.mark.parametrize("dtype", ["float32", "float64", "int32", "int64"])
def test_round(self, out, dtype, ie_device, precision, ir_version):
self._test(*self.create_model(out), ie_device, precision, ir_version, kwargs_to_prepare_input={"out": out, "dtype": dtype})
class TestRoundScalar(PytorchLayerTest):
def _prepare_input_int(self):
import numpy as np
return (np.array(np.random.randint(low=-5, high=5)), )
def _prepare_input_float(self):
import numpy as np
return (np.array(np.random.uniform(low=-5, high=5)), )
def create_model(self, input_type="float"):
import torch
class aten_round(torch.nn.Module):
def __init__(self, input_type):
super(aten_round, self).__init__()
if input_type == "int":
self.forward = self.forward_int
else:
self.forward = self.forward_float
def forward_int(self, x:int):
return torch.round(x)
def forward_float(self, x:float):
return torch.round(x)
ref_net = None
return aten_round(input_type), ref_net, "aten::round"
@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.parametrize("input_type", ["int", "float"])
def test_round(self, input_type, ie_device, precision, ir_version):
if input_type == "int":
self._prepare_input = self._prepare_input_int
else:
self._prepare_input = self._prepare_input_float
self._test(*self.create_model(input_type), ie_device, precision, ir_version, trace_model=True)