[PT FE] Support non-boolean inputs for __or__ and __and__ operations (#19268)
* [PT FE] Support non-boolean inputs for __or__ and __and__ operations * Add test for __or__
This commit is contained in:
42
src/frontends/pytorch/src/op/logical.cpp
Normal file
42
src/frontends/pytorch/src/op/logical.cpp
Normal file
@@ -0,0 +1,42 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"

#include "openvino/op/convert.hpp"
#include "openvino/op/logical_and.hpp"
#include "openvino/op/logical_or.hpp"
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
// Translates aten::__or__ into OpenVINO LogicalOr.
// Both operands are explicitly converted to boolean first, so non-boolean
// (e.g. integer) tensors are supported as well.
// @param context node context holding exactly two inputs.
// @return single-output vector with the LogicalOr node.
OutputVector translate_or(const NodeContext& context) {
    num_inputs_check(context, 2, 2);
    auto x = context.get_input(0);
    auto y = context.get_input(1);
    // LogicalOr requires boolean inputs; cast unconditionally to accept any
    // numeric element type coming from the PyTorch graph.
    x = context.mark_node(std::make_shared<v0::Convert>(x, element::boolean));
    y = context.mark_node(std::make_shared<v0::Convert>(y, element::boolean));
    // TODO: use bitwise op here when will be supported by openvino
    auto or_node = context.mark_node(std::make_shared<v1::LogicalOr>(x, y));
    return {or_node};
}
|
||||
|
||||
// Translates aten::__and__ into OpenVINO LogicalAnd.
// Both operands are explicitly converted to boolean first, so non-boolean
// (e.g. integer) tensors are supported as well.
// @param context node context holding exactly two inputs.
// @return single-output vector with the LogicalAnd node.
OutputVector translate_and(const NodeContext& context) {
    num_inputs_check(context, 2, 2);
    auto x = context.get_input(0);
    auto y = context.get_input(1);
    // LogicalAnd requires boolean inputs; cast unconditionally to accept any
    // numeric element type coming from the PyTorch graph.
    x = context.mark_node(std::make_shared<v0::Convert>(x, element::boolean));
    y = context.mark_node(std::make_shared<v0::Convert>(y, element::boolean));
    // TODO: use bitwise op here when will be supported by openvino
    // Renamed from copy-pasted "or_node": this translator builds LogicalAnd.
    auto and_node = context.mark_node(std::make_shared<v1::LogicalAnd>(x, y));
    return {and_node};
}
|
||||
|
||||
} // namespace op
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
||||
@@ -22,6 +22,7 @@ OP_CONVERTER(translate_add);
|
||||
OP_CONVERTER(translate_addcmul);
|
||||
OP_CONVERTER(translate_addmm);
|
||||
OP_CONVERTER(translate_all);
|
||||
OP_CONVERTER(translate_and);
|
||||
OP_CONVERTER(translate_arange);
|
||||
OP_CONVERTER(translate_argmax);
|
||||
OP_CONVERTER(translate_argsort);
|
||||
@@ -111,6 +112,7 @@ OP_CONVERTER(translate_norm);
|
||||
OP_CONVERTER(translate_numel);
|
||||
OP_CONVERTER(translate_ones);
|
||||
OP_CONVERTER(translate_ones_like);
|
||||
OP_CONVERTER(translate_or);
|
||||
OP_CONVERTER(translate_outer);
|
||||
OP_CONVERTER(translate_pad);
|
||||
OP_CONVERTER(translate_pairwise_distance);
|
||||
@@ -202,11 +204,11 @@ OP_CONVERTER(translate_transpose_fx);
|
||||
// Supported ops for TorchScript
|
||||
const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
|
||||
return {
|
||||
{"aten::__and__", op::translate_1to1_match_2_inputs<opset10::LogicalAnd>}, // TODO: cover numerical cases
|
||||
{"aten::__and__", op::translate_and},
|
||||
{"aten::__derive_index", op::translate_derive_index},
|
||||
{"aten::__getitem__", op::translate_getitem},
|
||||
{"aten::__not__", op::translate_1to1_match_1_inputs<opset10::LogicalNot>},
|
||||
{"aten::__or__", op::translate_1to1_match_2_inputs<opset10::LogicalOr>},
|
||||
{"aten::__or__", op::translate_or},
|
||||
{"aten::__range_length", op::translate_range_length},
|
||||
{"aten::_convolution", op::translate_convolution},
|
||||
{"aten::_convolution_mode", op::translate_convolution_mode},
|
||||
|
||||
28
tests/layer_tests/pytorch_tests/test_or.py
Normal file
28
tests/layer_tests/pytorch_tests/test_or.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
|
||||
class TestOr(PytorchLayerTest):
    # Layer test for aten::__or__ on non-boolean (uint8 / ByteTensor) inputs.
    # NOTE(review): class was named TestLog, a copy-paste leftover; renamed to
    # TestOr (still matches pytest's Test* discovery pattern).

    def _prepare_input(self):
        import numpy as np
        # Random integer tensor; the model ORs uint8 masks derived from it.
        return (np.random.randint(0, 255, (20, 30, 40, 50)),)

    def create_model(self):
        import torch

        class aten_or(torch.nn.Module):
            def forward(self, x):
                # Build a zeroed byte mask, then OR it with boolean
                # neighbor-difference masks — exercises aten::__or__ with
                # non-boolean (uint8) left-hand operands.
                res = torch.ByteTensor(x.size()).zero_()
                res[:, :, :, 1:] = res[:, :, :, 1:] | (x[:, :, :, 1:] != x[:, :, :, :-1])
                res[:, :, :, :-1] = res[:, :, :, :-1] | (x[:, :, :, 1:] != x[:, :, :, :-1])
                return res.float()

        return aten_or(), None, "aten::__or__"

    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_or(self, ie_device, precision, ir_version):
        # Tracing is required: the mask construction is data-dependent slicing.
        self._test(*self.create_model(), ie_device, precision, ir_version, dynamic_shapes=False, trace_model=True)
|
||||
Reference in New Issue
Block a user