[PT FE] Add aten::__xor__ (#20662)

* Add __xor__

* Add xor tests

* Add more xfail tests

* Update src/frontends/pytorch/src/op_table.cpp

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>

* Update src/frontends/pytorch/src/op_table.cpp

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>

* Fix code style

---------

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>
Karan Jakhar 2023-10-26 19:58:47 +05:30 committed by GitHub
parent 7720135f58
commit 26632d1cd9
3 changed files with 100 additions and 0 deletions
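
For context, a minimal sketch (module name hypothetical) of how the ^ operator on tensors surfaces as aten::__xor__ in a scripted TorchScript graph, which is the operation this change teaches the PyTorch frontend to convert; the layer-test harness used below scripts models the same way:

import torch

class XorModel(torch.nn.Module):
    def forward(self, a, b):
        return a ^ b

# Scripting preserves the Python-level __xor__ call, so the graph
# should contain an aten::__xor__ node for the ^ operator.
scripted = torch.jit.script(XorModel())
print(scripted.graph)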


@@ -6,6 +6,7 @@
#include "openvino/op/logical_and.hpp"
#include "openvino/op/logical_not.hpp"
#include "openvino/op/logical_or.hpp"
#include "openvino/op/logical_xor.hpp"
#include "utils.hpp"

namespace ov {
@@ -45,6 +46,16 @@ OutputVector translate_bitwise_or(const NodeContext& context) {
    return {or_x};
};

OutputVector translate_bitwise_xor(const NodeContext& context) {
    num_inputs_check(context, 2, 2);
    auto x = context.get_input(0);
    auto y = context.get_input(1);
    FRONT_END_OP_CONVERSION_CHECK(x.get_element_type().compatible(element::boolean),
                                  "aten::bitwise_xor supported only for boolean input");
    auto xor_x = context.mark_node(std::make_shared<ov::op::v1::LogicalXor>(x, y));
    return {xor_x};
};
} // namespace op
} // namespace pytorch
} // namespace frontend
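
Because the converter lowers aten::__xor__ to ov::op::v1::LogicalXor, only boolean inputs pass the conversion check; integer inputs trip the FRONT_END_OP_CONVERSION_CHECK above until a true bitwise xor is implemented. A rough end-to-end sketch, assuming the standard openvino.convert_model path for PyTorch models (module name hypothetical):

import numpy as np
import openvino as ov
import torch

class XorModel(torch.nn.Module):
    def forward(self, a, b):
        return a ^ b

example = (torch.tensor([True, False]), torch.tensor([True, True]))
ov_model = ov.convert_model(XorModel(), example_input=example)
compiled = ov.compile_model(ov_model, "CPU")

# Logical xor of the two boolean inputs: [False, True]
result = compiled([np.array([True, False]), np.array([True, True])])
print(result[0])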


@@ -131,6 +131,7 @@ OP_CONVERTER(translate_one_hot);
OP_CONVERTER(translate_ones);
OP_CONVERTER(translate_ones_like);
OP_CONVERTER(translate_or);
OP_CONVERTER(translate_bitwise_xor);
OP_CONVERTER(translate_outer);
OP_CONVERTER(translate_pad);
OP_CONVERTER(translate_pairwise_distance);
@@ -233,6 +234,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
        {"aten::__getitem__", op::translate_getitem},
        {"aten::__not__", op::translate_1to1_match_1_inputs<opset10::LogicalNot>},
        {"aten::__or__", op::translate_or},
        {"aten::__xor__", op::translate_bitwise_xor},
        {"aten::__range_length", op::translate_range_length},
        {"aten::_convolution", op::translate_convolution},
        {"aten::_convolution_mode", op::translate_convolution_mode},


@@ -0,0 +1,87 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import torch

from pytorch_layer_test_class import PytorchLayerTest


class TestXor(PytorchLayerTest):
    def _prepare_input(self):
        return self.input_data

    def create_model_tensor_input(self):
        class aten_xor_tensor(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, tensor_a, tensor_b):
                return tensor_a ^ tensor_b

        ref_net = None
        return aten_xor_tensor(), ref_net, "aten::__xor__"

    def create_model_bool_input(self):
        class aten_xor_bool(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, bool_a: bool, bool_b: bool):
                return bool_a ^ bool_b

        ref_net = None
        return aten_xor_bool(), ref_net, "aten::__xor__"

    def create_model_int_input(self):
        class aten_xor_int(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, int_a: int, int_b: int):
                return int_a ^ int_b

        ref_net = None
        return aten_xor_int(), ref_net, "aten::__xor__"

    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_xor_tensor(self, ie_device, precision, ir_version):
        self.input_data = (np.array([True, False, False], dtype=np.bool_),
                           np.array([True, True, False], dtype=np.bool_))
        self._test(*self.create_model_tensor_input(),
                   ie_device, precision, ir_version)

    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_xor_bool(self, ie_device, precision, ir_version):
        self.input_data = (np.array(True, dtype=np.bool_),
                           np.array(True, dtype=np.bool_))
        self._test(*self.create_model_bool_input(),
                   ie_device, precision, ir_version)

    @pytest.mark.xfail(reason="bitwise_xor is not implemented")
    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_xor_int(self, ie_device, precision, ir_version):
        self.input_data = (np.array(3, dtype=np.int32),
                           np.array(4, dtype=np.int32))
        self._test(*self.create_model_int_input(),
                   ie_device, precision, ir_version)

    @pytest.mark.xfail(reason="bitwise_xor is not implemented")
    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_xor_int_tensor(self, ie_device, precision, ir_version):
        self.input_data = (np.array([3, 5, 8], dtype=np.int32),
                           np.array([7, 11, 2], dtype=np.int32))
        self._test(*self.create_model_tensor_input(),
                   ie_device, precision, ir_version)
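
For reference, a quick sanity check of what the xfail integer-tensor case should produce once bitwise xor support lands, computed with plain NumPy independently of the frontend:

import numpy as np

print(np.bitwise_xor(np.array([3, 5, 8], dtype=np.int32),
                     np.array([7, 11, 2], dtype=np.int32)))  # [ 4 14 10]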