[PT FE]: support aten:Bool, add tests for aten::add_ (#15590)
This commit is contained in:
parent
efe3b27f5b
commit
609dee0abc
22
src/frontends/pytorch/src/op/bool.cpp
Normal file
22
src/frontends/pytorch/src/op/bool.cpp
Normal file
@ -0,0 +1,22 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/op/convert.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_bool(NodeContext& context) {
|
||||
num_inputs_check(context, 1, 1);
|
||||
return {context.mark_node(std::make_shared<ov::op::v0::Convert>(context.get_input(0), element::boolean))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
@ -22,6 +22,7 @@ OP_CONVERTER(translate_addmm);
|
||||
OP_CONVERTER(translate_arange);
|
||||
OP_CONVERTER(translate_as_tensor);
|
||||
OP_CONVERTER(translate_avg_poolnd);
|
||||
OP_CONVERTER(translate_bool);
|
||||
OP_CONVERTER(translate_batch_norm);
|
||||
OP_CONVERTER(translate_clamp);
|
||||
OP_CONVERTER(translate_constant);
|
||||
@ -149,6 +150,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"aten::avg_pool3d", op::translate_avg_poolnd},
|
||||
{"aten::batch_norm", op::translate_batch_norm},
|
||||
{"aten::bmm", op::translate_1to1_match_2_inputs<opset10::MatMul>},
|
||||
{"aten::Bool", op::translate_bool},
|
||||
// {"aten::cat", done as transformation},
|
||||
{"aten::ceil", op::translate_1to1_match_1_inputs<opset10::Ceiling>},
|
||||
{"aten::ceil_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Ceiling>>},
|
||||
|
@ -18,25 +18,30 @@ class TestAdd(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
return (np.random.randn(2, 5, 3, 4).astype(np.float32), self.input_rhs)
|
||||
|
||||
def create_model(self, alpha):
|
||||
def create_model(self, alpha, op_type):
|
||||
class aten_add(torch.nn.Module):
|
||||
|
||||
def __init__(self, alpha) -> None:
|
||||
def __init__(self, alpha, op) -> None:
|
||||
super().__init__()
|
||||
self.alpha = alpha
|
||||
self.forward = self.forward1 if op == "add" else self.forward2
|
||||
|
||||
def forward(self, lhs, rhs):
|
||||
def forward1(self, lhs, rhs):
|
||||
return torch.add(lhs, rhs, alpha=self.alpha)
|
||||
|
||||
def forward2(self, lhs, rhs):
|
||||
return lhs.add_(rhs, alpha=self.alpha)
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_add(alpha), ref_net, "aten::add"
|
||||
return aten_add(alpha, op_type), ref_net, f"aten::{op_type}"
|
||||
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_add(self, ie_device, precision, ir_version, alpha, input_rhs):
|
||||
@pytest.mark.parametrize("op_type", ["add", "add_"])
|
||||
def test_add(self, ie_device, precision, ir_version, alpha, input_rhs, op_type):
|
||||
self.input_rhs = input_rhs
|
||||
self._test(*self.create_model(alpha), ie_device, precision, ir_version)
|
||||
self._test(*self.create_model(alpha, op_type), ie_device, precision, ir_version)
|
||||
|
||||
|
||||
class TestAddTypes(PytorchLayerTest):
|
||||
|
36
tests/layer_tests/pytorch_tests/test_bool.py
Normal file
36
tests/layer_tests/pytorch_tests/test_bool.py
Normal file
@ -0,0 +1,36 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
|
||||
class TestBool(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
import numpy as np
|
||||
return (np.random.randn(1).astype(np.int32),)
|
||||
|
||||
def create_model(self, input_type):
|
||||
import torch
|
||||
|
||||
class prim_bool(torch.nn.Module):
|
||||
def __init__(self, input_type):
|
||||
super(prim_bool, self).__init__()
|
||||
self.forward = self.forward_tensor if input_type != "scalar" else self.forward_scalar
|
||||
|
||||
def forward_tensor(self, x):
|
||||
return bool(x)
|
||||
|
||||
def forward_scalar(self, x:int):
|
||||
return bool(x)
|
||||
|
||||
ref_net = None
|
||||
|
||||
return prim_bool(input_type), ref_net, "aten::Bool"
|
||||
|
||||
@pytest.mark.parametrize("input_type", ["tensor", "scalar"])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_ceil(self, ie_device, precision, ir_version, input_type):
|
||||
self._test(*self.create_model(input_type), ie_device, precision, ir_version)
|
Loading…
Reference in New Issue
Block a user