[PT FE] Fix constant type in aten::square (#21580)

* Fix constant type in aten::square

* Add test

* Simplify test
This commit is contained in:
Maxim Vafin
2023-12-11 20:29:02 +01:00
committed by GitHub
parent 0f6b9abee8
commit e1b9d8c167
2 changed files with 40 additions and 1 deletions

View File

@@ -4,6 +4,7 @@
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/power.hpp"
#include "utils.hpp"
@@ -15,9 +16,11 @@ namespace op {
using namespace ov::op;
OutputVector translate_square(const NodeContext& context) {
// aten::square(Tensor self) -> Tensor
num_inputs_check(context, 1, 1);
auto input_0 = context.get_input(0);
auto const_2 = context.mark_node(v0::Constant::create(input_0.get_element_type(), Shape{1}, {2}));
auto const_2 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {2}));
const_2 = context.mark_node(std::make_shared<v1::ConvertLike>(const_2, input_0));
return {context.mark_node(std::make_shared<v1::Power>(input_0, const_2))};
};

View File

@@ -0,0 +1,36 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pytest
import torch
from pytorch_layer_test_class import PytorchLayerTest
class TestSquareTypes(PytorchLayerTest):
def _prepare_input(self):
return (torch.randn(self.shape).to(self.type).numpy(),)
def create_model(self, type):
class aten_square(torch.nn.Module):
def __init__(self, type):
super().__init__()
self.type = type
def forward(self, lhs):
return torch.square(lhs.to(self.type))
return aten_square(type), None, "aten::square"
@pytest.mark.parametrize(("type"), [torch.int32, torch.int64, torch.float32])
@pytest.mark.parametrize(("shape"), [[2, 3], [],])
@pytest.mark.nightly
@pytest.mark.precommit
def test_square_types(self, ie_device, precision, ir_version, type, shape):
self.type = type
self.shape = shape
self._test(*self.create_model(type),
ie_device, precision, ir_version)