[PT FE] Fix constant type in aten::square (#21580)
* Fix constant type in aten::square * Add test * Simplify test
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/convert_like.hpp"
|
||||
#include "openvino/op/power.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
@@ -15,9 +16,11 @@ namespace op {
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_square(const NodeContext& context) {
|
||||
// aten::square(Tensor self) -> Tensor
|
||||
num_inputs_check(context, 1, 1);
|
||||
auto input_0 = context.get_input(0);
|
||||
auto const_2 = context.mark_node(v0::Constant::create(input_0.get_element_type(), Shape{1}, {2}));
|
||||
auto const_2 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {2}));
|
||||
const_2 = context.mark_node(std::make_shared<v1::ConvertLike>(const_2, input_0));
|
||||
return {context.mark_node(std::make_shared<v1::Power>(input_0, const_2))};
|
||||
};
|
||||
|
||||
|
||||
36
tests/layer_tests/pytorch_tests/test_square.py
Normal file
36
tests/layer_tests/pytorch_tests/test_square.py
Normal file
@@ -0,0 +1,36 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
|
||||
class TestSquareTypes(PytorchLayerTest):
|
||||
|
||||
def _prepare_input(self):
|
||||
return (torch.randn(self.shape).to(self.type).numpy(),)
|
||||
|
||||
def create_model(self, type):
|
||||
|
||||
class aten_square(torch.nn.Module):
|
||||
def __init__(self, type):
|
||||
super().__init__()
|
||||
self.type = type
|
||||
|
||||
def forward(self, lhs):
|
||||
return torch.square(lhs.to(self.type))
|
||||
|
||||
return aten_square(type), None, "aten::square"
|
||||
|
||||
@pytest.mark.parametrize(("type"), [torch.int32, torch.int64, torch.float32])
|
||||
@pytest.mark.parametrize(("shape"), [[2, 3], [],])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_square_types(self, ie_device, precision, ir_version, type, shape):
|
||||
self.type = type
|
||||
self.shape = shape
|
||||
self._test(*self.create_model(type),
|
||||
ie_device, precision, ir_version)
|
||||
Reference in New Issue
Block a user