diff --git a/src/frontends/pytorch/src/op/logical.cpp b/src/frontends/pytorch/src/op/logical.cpp
new file mode 100644
index 00000000000..0c5a93e2c91
--- /dev/null
+++ b/src/frontends/pytorch/src/op/logical.cpp
@@ -0,0 +1,42 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "openvino/frontend/pytorch/node_context.hpp"
+#include "openvino/op/logical_and.hpp"
+#include "openvino/op/logical_or.hpp"
+#include "utils.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace op {
+
+using namespace ov::op;
+
+OutputVector translate_or(const NodeContext& context) {
+    num_inputs_check(context, 2, 2);
+    auto x = context.get_input(0);
+    auto y = context.get_input(1);
+    x = context.mark_node(std::make_shared<v0::Convert>(x, element::boolean));
+    y = context.mark_node(std::make_shared<v0::Convert>(y, element::boolean));
+    // TODO: use bitwise op here when will be supported by openvino
+    auto or_node = context.mark_node(std::make_shared<v1::LogicalOr>(x, y));
+    return {or_node};
+};
+
+OutputVector translate_and(const NodeContext& context) {
+    num_inputs_check(context, 2, 2);
+    auto x = context.get_input(0);
+    auto y = context.get_input(1);
+    x = context.mark_node(std::make_shared<v0::Convert>(x, element::boolean));
+    y = context.mark_node(std::make_shared<v0::Convert>(y, element::boolean));
+    // TODO: use bitwise op here when will be supported by openvino
+    auto or_node = context.mark_node(std::make_shared<v1::LogicalAnd>(x, y));
+    return {or_node};
+};
+
+}  // namespace op
+}  // namespace pytorch
+}  // namespace frontend
+}  // namespace ov
diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index fbc13a99447..1fc92a4d18f 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -22,6 +22,7 @@ OP_CONVERTER(translate_add);
 OP_CONVERTER(translate_addcmul);
 OP_CONVERTER(translate_addmm);
 OP_CONVERTER(translate_all);
+OP_CONVERTER(translate_and);
 OP_CONVERTER(translate_arange);
 OP_CONVERTER(translate_argmax);
 OP_CONVERTER(translate_argsort);
@@ -111,6 +112,7 @@ OP_CONVERTER(translate_norm);
 OP_CONVERTER(translate_numel);
 OP_CONVERTER(translate_ones);
 OP_CONVERTER(translate_ones_like);
+OP_CONVERTER(translate_or);
 OP_CONVERTER(translate_outer);
 OP_CONVERTER(translate_pad);
 OP_CONVERTER(translate_pairwise_distance);
@@ -202,11 +204,11 @@ OP_CONVERTER(translate_transpose_fx);
 // Supported ops for TorchScript
 const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
     return {
-        {"aten::__and__", op::translate_1to1_match_2_inputs<opset10::LogicalAnd>},  // TODO: cover numerical cases
+        {"aten::__and__", op::translate_and},
         {"aten::__derive_index", op::translate_derive_index},
         {"aten::__getitem__", op::translate_getitem},
         {"aten::__not__", op::translate_1to1_match_1_inputs<opset10::LogicalNot>},
-        {"aten::__or__", op::translate_1to1_match_2_inputs<opset10::LogicalOr>},
+        {"aten::__or__", op::translate_or},
         {"aten::__range_length", op::translate_range_length},
         {"aten::_convolution", op::translate_convolution},
         {"aten::_convolution_mode", op::translate_convolution_mode},
diff --git a/tests/layer_tests/pytorch_tests/test_or.py b/tests/layer_tests/pytorch_tests/test_or.py
new file mode 100644
index 00000000000..c6592a11af0
--- /dev/null
+++ b/tests/layer_tests/pytorch_tests/test_or.py
@@ -0,0 +1,28 @@
+# Copyright (C) 2018-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+from pytorch_layer_test_class import PytorchLayerTest
+
+
+class TestLog(PytorchLayerTest):
+    def _prepare_input(self):
+        import numpy as np
+        return (np.random.randint(0, 255, (20, 30, 40, 50)),)
+
+    def create_model(self):
+        import torch
+
+        class aten_or(torch.nn.Module):
+            def forward(self, x):
+                res = torch.ByteTensor(x.size()).zero_()
+                res[:, :, :, 1:] = res[:, :, :, 1:] | (x[:, :, :, 1:] != x[:, :, :, :-1])
+                res[:, :, :, :-1] = res[:, :, :, :-1] | (x[:, :, :, 1:] != x[:, :, :, :-1])
+                return res.float()
+
+        return aten_or(), None, "aten::__or__"
+
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_or(self, ie_device, precision, ir_version):
+        self._test(*self.create_model(), ie_device, precision, ir_version, dynamic_shapes=False, trace_model=True)