[PT FE]: extend logical operations support (#19981)
* [PT FE]: extend logical operations support * tests * more tests
This commit is contained in:
parent
3f3d89678e
commit
8d59fcd34f
@ -4,7 +4,9 @@
|
|||||||
|
|
||||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||||
#include "openvino/op/logical_and.hpp"
|
#include "openvino/op/logical_and.hpp"
|
||||||
|
#include "openvino/op/logical_not.hpp"
|
||||||
#include "openvino/op/logical_or.hpp"
|
#include "openvino/op/logical_or.hpp"
|
||||||
|
#include "openvino/op/logical_xor.hpp"
|
||||||
#include "utils.hpp"
|
#include "utils.hpp"
|
||||||
|
|
||||||
namespace ov {
|
namespace ov {
|
||||||
@ -15,25 +17,57 @@ namespace op {
|
|||||||
using namespace ov::op;
|
using namespace ov::op;
|
||||||
|
|
||||||
OutputVector translate_or(const NodeContext& context) {
|
OutputVector translate_or(const NodeContext& context) {
|
||||||
num_inputs_check(context, 2, 2);
|
num_inputs_check(context, 2, 3);
|
||||||
auto x = context.get_input(0);
|
auto x = context.get_input(0);
|
||||||
auto y = context.get_input(1);
|
auto y = context.get_input(1);
|
||||||
x = context.mark_node(std::make_shared<v0::Convert>(x, element::boolean));
|
x = context.mark_node(std::make_shared<v0::Convert>(x, element::boolean));
|
||||||
y = context.mark_node(std::make_shared<v0::Convert>(y, element::boolean));
|
y = context.mark_node(std::make_shared<v0::Convert>(y, element::boolean));
|
||||||
// TODO: use bitwise op here when will be supported by openvino
|
// TODO: use bitwise op here when will be supported by openvino
|
||||||
auto or_node = context.mark_node(std::make_shared<v1::LogicalOr>(x, y));
|
auto or_node = context.mark_node(std::make_shared<v1::LogicalOr>(x, y));
|
||||||
|
if (!context.input_is_none(2)) {
|
||||||
|
context.mutate_input(2, or_node);
|
||||||
|
}
|
||||||
return {or_node};
|
return {or_node};
|
||||||
};
|
};
|
||||||
|
|
||||||
OutputVector translate_and(const NodeContext& context) {
|
OutputVector translate_and(const NodeContext& context) {
|
||||||
num_inputs_check(context, 2, 2);
|
num_inputs_check(context, 2, 3);
|
||||||
auto x = context.get_input(0);
|
auto x = context.get_input(0);
|
||||||
auto y = context.get_input(1);
|
auto y = context.get_input(1);
|
||||||
x = context.mark_node(std::make_shared<v0::Convert>(x, element::boolean));
|
x = context.mark_node(std::make_shared<v0::Convert>(x, element::boolean));
|
||||||
y = context.mark_node(std::make_shared<v0::Convert>(y, element::boolean));
|
y = context.mark_node(std::make_shared<v0::Convert>(y, element::boolean));
|
||||||
// TODO: use bitwise op here when will be supported by openvino
|
// TODO: use bitwise op here when will be supported by openvino
|
||||||
auto or_node = context.mark_node(std::make_shared<v1::LogicalAnd>(x, y));
|
auto and_node = context.mark_node(std::make_shared<v1::LogicalAnd>(x, y));
|
||||||
return {or_node};
|
if (!context.input_is_none(2)) {
|
||||||
|
context.mutate_input(2, and_node);
|
||||||
|
}
|
||||||
|
return {and_node};
|
||||||
|
};
|
||||||
|
|
||||||
|
OutputVector translate_not(const NodeContext& context) {
|
||||||
|
num_inputs_check(context, 1, 2);
|
||||||
|
auto x = context.get_input(0);
|
||||||
|
x = context.mark_node(std::make_shared<v0::Convert>(x, element::boolean));
|
||||||
|
// TODO: use bitwise op here when will be supported by openvino
|
||||||
|
auto not_node = context.mark_node(std::make_shared<v1::LogicalNot>(x));
|
||||||
|
if (!context.input_is_none(1)) {
|
||||||
|
context.mutate_input(1, not_node);
|
||||||
|
}
|
||||||
|
return {not_node};
|
||||||
|
};
|
||||||
|
|
||||||
|
OutputVector translate_xor(const NodeContext& context) {
|
||||||
|
num_inputs_check(context, 2, 3);
|
||||||
|
auto x = context.get_input(0);
|
||||||
|
auto y = context.get_input(1);
|
||||||
|
x = context.mark_node(std::make_shared<v0::Convert>(x, element::boolean));
|
||||||
|
y = context.mark_node(std::make_shared<v0::Convert>(y, element::boolean));
|
||||||
|
// TODO: use bitwise op here when will be supported by openvino
|
||||||
|
auto xor_node = context.mark_node(std::make_shared<v1::LogicalXor>(x, y));
|
||||||
|
if (!context.input_is_none(2)) {
|
||||||
|
context.mutate_input(2, xor_node);
|
||||||
|
}
|
||||||
|
return {xor_node};
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace op
|
} // namespace op
|
||||||
|
@ -111,6 +111,7 @@ OP_CONVERTER(translate_new_zeros);
|
|||||||
OP_CONVERTER(translate_nms);
|
OP_CONVERTER(translate_nms);
|
||||||
OP_CONVERTER(translate_nonzero);
|
OP_CONVERTER(translate_nonzero);
|
||||||
OP_CONVERTER(translate_norm);
|
OP_CONVERTER(translate_norm);
|
||||||
|
OP_CONVERTER(translate_not);
|
||||||
OP_CONVERTER(translate_numel);
|
OP_CONVERTER(translate_numel);
|
||||||
OP_CONVERTER(translate_one_hot);
|
OP_CONVERTER(translate_one_hot);
|
||||||
OP_CONVERTER(translate_ones);
|
OP_CONVERTER(translate_ones);
|
||||||
@ -188,6 +189,7 @@ OP_CONVERTER(translate_quantized_cat);
|
|||||||
OP_CONVERTER(translate_quantized_convnd);
|
OP_CONVERTER(translate_quantized_convnd);
|
||||||
OP_CONVERTER(translate_quantized_convnd_relu);
|
OP_CONVERTER(translate_quantized_convnd_relu);
|
||||||
OP_CONVERTER(translate_quantized_linear);
|
OP_CONVERTER(translate_quantized_linear);
|
||||||
|
OP_CONVERTER(translate_xor);
|
||||||
// Torch FX Translations
|
// Torch FX Translations
|
||||||
OP_CONVERTER(translate_arange_fx);
|
OP_CONVERTER(translate_arange_fx);
|
||||||
OP_CONVERTER(translate_batch_norm_fx);
|
OP_CONVERTER(translate_batch_norm_fx);
|
||||||
@ -343,6 +345,10 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
|
|||||||
{"aten::linspace", op::translate_linspace},
|
{"aten::linspace", op::translate_linspace},
|
||||||
{"aten::log", op::translate_log},
|
{"aten::log", op::translate_log},
|
||||||
{"aten::log_", op::inplace_op<op::translate_log>},
|
{"aten::log_", op::inplace_op<op::translate_log>},
|
||||||
|
{"aten::logical_and", op::translate_and},
|
||||||
|
{"aten::logical_or", op::translate_or},
|
||||||
|
{"aten::logical_not", op::translate_not},
|
||||||
|
{"aten::logical_xor", op::translate_xor},
|
||||||
{"aten::log_softmax", op::translate_log_softmax},
|
{"aten::log_softmax", op::translate_log_softmax},
|
||||||
{"aten::log2", op::translate_log2},
|
{"aten::log2", op::translate_log2},
|
||||||
{"aten::log2_", op::inplace_op<op::translate_log2>},
|
{"aten::log2_", op::inplace_op<op::translate_log2>},
|
||||||
|
64
tests/layer_tests/pytorch_tests/test_logical_ops.py
Normal file
64
tests/layer_tests/pytorch_tests/test_logical_ops.py
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from pytorch_layer_test_class import PytorchLayerTest
|
||||||
|
|
||||||
|
class TestLogicalOp(PytorchLayerTest):
|
||||||
|
|
||||||
|
def _prepare_input(self, out, unary, first_dtype, second_dtype):
|
||||||
|
x = np.random.randint(1, 5, (1, 10)).astype(first_dtype)
|
||||||
|
if unary:
|
||||||
|
return (x, ) if not out else (x, np.zeros_like(x).astype(bool))
|
||||||
|
y = np.random.randint(1, 5, (1, 10)).astype(second_dtype)
|
||||||
|
if not out:
|
||||||
|
return x, y
|
||||||
|
return x, y, np.zeros_like(x).astype(bool)
|
||||||
|
|
||||||
|
def create_model(self, op_name, out):
|
||||||
|
import torch
|
||||||
|
|
||||||
|
ops = {
|
||||||
|
"and": torch.logical_and,
|
||||||
|
"or": torch.logical_or,
|
||||||
|
"xor": torch.logical_xor,
|
||||||
|
"not": torch.logical_not
|
||||||
|
}
|
||||||
|
op = ops[op_name]
|
||||||
|
class aten_logical(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, op, out) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.op = op
|
||||||
|
if op == torch.logical_not:
|
||||||
|
self.forward = self.forward_not
|
||||||
|
if out:
|
||||||
|
self.forward = self.forward_out if not op == torch.logical_not else self.forward_not_out
|
||||||
|
|
||||||
|
def forward(self, tensor_a, tensor_b):
|
||||||
|
return self.op(tensor_a, tensor_b)
|
||||||
|
|
||||||
|
|
||||||
|
def forward_out(self, tensor_a, tensor_b, out):
|
||||||
|
return self.op(tensor_a, tensor_b, out=out), out
|
||||||
|
|
||||||
|
def forward_not(self, tensor_a):
|
||||||
|
return self.op(tensor_a)
|
||||||
|
|
||||||
|
def forward_not_out(self, tensor_a, out):
|
||||||
|
return self.op(tensor_a, out=out), out
|
||||||
|
|
||||||
|
ref_net = None
|
||||||
|
|
||||||
|
return aten_logical(op, out), ref_net, f"aten::logical_{op_name}"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.nightly
|
||||||
|
@pytest.mark.precommit
|
||||||
|
@pytest.mark.parametrize("op_type", ["and", "or", "not", "xor"])
|
||||||
|
@pytest.mark.parametrize("first_dtype", ["bool", "int32", 'int8', 'float32'])
|
||||||
|
@pytest.mark.parametrize("second_dtype", ["bool", "int32", 'int8', 'float32'])
|
||||||
|
@pytest.mark.parametrize("out", [True, False])
|
||||||
|
def test_logical(self, op_type, out, first_dtype, second_dtype, ie_device, precision, ir_version):
|
||||||
|
self._test(*self.create_model(op_type, out),
|
||||||
|
ie_device, precision, ir_version,
|
||||||
|
kwargs_to_prepare_input={"out": out, "unary": op_type == "not",
|
||||||
|
"first_dtype": first_dtype, "second_dtype": second_dtype})
|
Loading…
Reference in New Issue
Block a user