[PT FE]: support aten::Bool, add tests for aten::add_ (#15590)

Ekaterina Aidova authored on 2023-02-14 02:29:43 +04:00; committed by GitHub
parent efe3b27f5b
commit 609dee0abc
4 changed files with 72 additions and 7 deletions


@@ -0,0 +1,22 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "openvino/frontend/pytorch/node_context.hpp"
+#include "openvino/op/convert.hpp"
+#include "utils.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace op {
+
+OutputVector translate_bool(NodeContext& context) {
+    num_inputs_check(context, 1, 1);
+    return {context.mark_node(std::make_shared<ov::op::v0::Convert>(context.get_input(0), element::boolean))};
+};
+
+}  // namespace op
+}  // namespace pytorch
+}  // namespace frontend
+}  // namespace ov
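For context: a minimal sketch (not part of the diff, module name illustrative) of the PyTorch construct this converter targets. Calling bool() on a tensor inside a scripted model lowers to an aten::Bool node, which translate_bool above maps to an OpenVINO Convert to element::boolean.

import torch

class BoolCast(torch.nn.Module):  # illustrative module, not from the commit
    def forward(self, x):
        # bool() on a single-element tensor becomes aten::Bool in TorchScript
        return bool(x)

scripted = torch.jit.script(BoolCast())
print(scripted.graph)  # the printed graph contains an aten::Bool node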


@@ -22,6 +22,7 @@ OP_CONVERTER(translate_addmm);
 OP_CONVERTER(translate_arange);
 OP_CONVERTER(translate_as_tensor);
 OP_CONVERTER(translate_avg_poolnd);
+OP_CONVERTER(translate_bool);
 OP_CONVERTER(translate_batch_norm);
 OP_CONVERTER(translate_clamp);
 OP_CONVERTER(translate_constant);
@@ -149,6 +150,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"aten::avg_pool3d", op::translate_avg_poolnd},
         {"aten::batch_norm", op::translate_batch_norm},
         {"aten::bmm", op::translate_1to1_match_2_inputs<opset10::MatMul>},
+        {"aten::Bool", op::translate_bool},
         // {"aten::cat", done as transformation},
         {"aten::ceil", op::translate_1to1_match_1_inputs<opset10::Ceiling>},
         {"aten::ceil_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Ceiling>>},


@@ -18,25 +18,30 @@ class TestAdd(PytorchLayerTest):
     def _prepare_input(self):
         return (np.random.randn(2, 5, 3, 4).astype(np.float32), self.input_rhs)
 
-    def create_model(self, alpha):
+    def create_model(self, alpha, op_type):
         class aten_add(torch.nn.Module):
-            def __init__(self, alpha) -> None:
+            def __init__(self, alpha, op) -> None:
                 super().__init__()
                 self.alpha = alpha
+                self.forward = self.forward1 if op == "add" else self.forward2
 
-            def forward(self, lhs, rhs):
+            def forward1(self, lhs, rhs):
                 return torch.add(lhs, rhs, alpha=self.alpha)
 
+            def forward2(self, lhs, rhs):
+                return lhs.add_(rhs, alpha=self.alpha)
+
         ref_net = None
 
-        return aten_add(alpha), ref_net, "aten::add"
+        return aten_add(alpha, op_type), ref_net, f"aten::{op_type}"
 
     @pytest.mark.nightly
     @pytest.mark.precommit
-    def test_add(self, ie_device, precision, ir_version, alpha, input_rhs):
+    @pytest.mark.parametrize("op_type", ["add", "add_"])
+    def test_add(self, ie_device, precision, ir_version, alpha, input_rhs, op_type):
         self.input_rhs = input_rhs
-        self._test(*self.create_model(alpha), ie_device, precision, ir_version)
+        self._test(*self.create_model(alpha, op_type), ie_device, precision, ir_version)
 
 
 class TestAddTypes(PytorchLayerTest):
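The new forward2 path exercises the in-place variant: Tensor.add_ mutates its receiver, while torch.add allocates a new tensor; in both, alpha scales the second operand. Plain-PyTorch illustration (standard API, not part of the test):

import torch

lhs = torch.ones(2, 3)
rhs = torch.full((2, 3), 2.0)

out = torch.add(lhs, rhs, alpha=0.5)  # out-of-place: lhs is unchanged
lhs.add_(rhs, alpha=0.5)              # in-place: lhs itself becomes lhs + 0.5 * rhs
assert torch.equal(lhs, out)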


@@ -0,0 +1,36 @@
+# Copyright (C) 2018-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+
+from pytorch_layer_test_class import PytorchLayerTest
+
+
+class TestBool(PytorchLayerTest):
+    def _prepare_input(self):
+        import numpy as np
+        return (np.random.randn(1).astype(np.int32),)
+
+    def create_model(self, input_type):
+        import torch
+
+        class prim_bool(torch.nn.Module):
+            def __init__(self, input_type):
+                super(prim_bool, self).__init__()
+                self.forward = self.forward_tensor if input_type != "scalar" else self.forward_scalar
+
+            def forward_tensor(self, x):
+                return bool(x)
+
+            def forward_scalar(self, x: int):
+                return bool(x)
+
+        ref_net = None
+
+        return prim_bool(input_type), ref_net, "aten::Bool"
+
+    @pytest.mark.parametrize("input_type", ["tensor", "scalar"])
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_bool(self, ie_device, precision, ir_version, input_type):
+        self._test(*self.create_model(input_type), ie_device, precision, ir_version)
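For reference, the standard PyTorch semantics the tensor path relies on: bool() is only defined for single-element tensors, which is why _prepare_input returns a shape-(1,) array. Illustration (not part of the test):

import torch

print(bool(torch.tensor([0])))  # False: a single-element tensor is truthy iff nonzero
print(bool(torch.tensor([7])))  # True
# bool() on a tensor with more than one element raises RuntimeError
# ("Boolean value of Tensor with more than one element is ambiguous").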