[PT FE] Add aten::__xor__ (#20662)
* Add __xor__ * Add xor tests * add more xfail tests * Update src/frontends/pytorch/src/op_table.cpp Co-authored-by: Maxim Vafin <maxim.vafin@intel.com> * Update src/frontends/pytorch/src/op_table.cpp Co-authored-by: Maxim Vafin <maxim.vafin@intel.com> * fix code style --------- Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>
This commit is contained in:
parent
7720135f58
commit
26632d1cd9
@ -6,6 +6,7 @@
|
||||
#include "openvino/op/logical_and.hpp"
|
||||
#include "openvino/op/logical_not.hpp"
|
||||
#include "openvino/op/logical_or.hpp"
|
||||
#include "openvino/op/logical_xor.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -45,6 +46,16 @@ OutputVector translate_bitwise_or(const NodeContext& context) {
|
||||
return {or_x};
|
||||
};
|
||||
|
||||
OutputVector translate_bitwise_xor(const NodeContext& context) {
|
||||
num_inputs_check(context, 2, 2);
|
||||
auto x = context.get_input(0);
|
||||
auto y = context.get_input(1);
|
||||
FRONT_END_OP_CONVERSION_CHECK(x.get_element_type().compatible(element::boolean),
|
||||
"aten::bitwise_xor supported only for boolean input");
|
||||
auto xor_x = context.mark_node(std::make_shared<ov::op::v1::LogicalXor>(x, y));
|
||||
return {xor_x};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
|
@ -131,6 +131,7 @@ OP_CONVERTER(translate_one_hot);
|
||||
OP_CONVERTER(translate_ones);
|
||||
OP_CONVERTER(translate_ones_like);
|
||||
OP_CONVERTER(translate_or);
|
||||
OP_CONVERTER(translate_bitwise_xor);
|
||||
OP_CONVERTER(translate_outer);
|
||||
OP_CONVERTER(translate_pad);
|
||||
OP_CONVERTER(translate_pairwise_distance);
|
||||
@ -233,6 +234,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
|
||||
{"aten::__getitem__", op::translate_getitem},
|
||||
{"aten::__not__", op::translate_1to1_match_1_inputs<opset10::LogicalNot>},
|
||||
{"aten::__or__", op::translate_or},
|
||||
{"aten::__xor__", op::translate_bitwise_xor},
|
||||
{"aten::__range_length", op::translate_range_length},
|
||||
{"aten::_convolution", op::translate_convolution},
|
||||
{"aten::_convolution_mode", op::translate_convolution_mode},
|
||||
|
87
tests/layer_tests/pytorch_tests/test_xor.py
Normal file
87
tests/layer_tests/pytorch_tests/test_xor.py
Normal file
@ -0,0 +1,87 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
|
||||
class TestXor(PytorchLayerTest):
|
||||
|
||||
def _prepare_input(self):
|
||||
return self.input_data
|
||||
|
||||
def create_model_tensor_input(self):
|
||||
class aten_xor_tensor(torch.nn.Module):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, tensor_a, tensor_b):
|
||||
return tensor_a ^ tensor_b
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_xor_tensor(), ref_net, "aten::__xor__"
|
||||
|
||||
def create_model_bool_input(self):
|
||||
class aten_xor_bool(torch.nn.Module):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, bool_a: bool, bool_b: bool):
|
||||
return bool_a ^ bool_b
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_xor_bool(), ref_net, "aten::__xor__"
|
||||
|
||||
def create_model_int_input(self):
|
||||
class aten_xor_int(torch.nn.Module):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, int_a: int, int_b: int):
|
||||
return int_a ^ int_b
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_xor_int(), ref_net, "aten::__xor__"
|
||||
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_xor_tensor(self, ie_device, precision, ir_version):
|
||||
self.input_data = (np.array([True, False, False], dtype=np.bool_), np.array(
|
||||
[True, True, False], dtype=np.bool_))
|
||||
self._test(*self.create_model_tensor_input(),
|
||||
ie_device, precision, ir_version)
|
||||
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_xor_bool(self, ie_device, precision, ir_version):
|
||||
self.input_data = (np.array(True, dtype=np.bool_),
|
||||
np.array(True, dtype=np.bool_))
|
||||
self._test(*self.create_model_bool_input(),
|
||||
ie_device, precision, ir_version)
|
||||
|
||||
@pytest.mark.xfail(reason="bitwise_xor is not implemented")
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_xor_int(self, ie_device, precision, ir_version):
|
||||
self.input_data = (np.array(3, dtype=np.int),
|
||||
np.array(4, dtype=np.int))
|
||||
self._test(*self.create_model_int_input(),
|
||||
ie_device, precision, ir_version)
|
||||
|
||||
@pytest.mark.xfail(reason="bitwise_xor is not implemented")
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_xor_tensor(self, ie_device, precision, ir_version):
|
||||
self.input_data = (np.array([3, 5, 8], dtype=np.int), np.array(
|
||||
[7, 11, 2], dtype=np.int))
|
||||
self._test(*self.create_model_tensor_input(),
|
||||
ie_device, precision, ir_version)
|
Loading…
Reference in New Issue
Block a user