Migrate LogicalXor to new API (#19913)

This commit is contained in:
Pawel Raasz
2023-09-21 12:09:49 +02:00
committed by GitHub
parent d90667c190
commit 5bf5e488b7
4 changed files with 99 additions and 92 deletions

View File

@@ -34,9 +34,7 @@ public:
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
OPENVINO_SUPPRESS_DEPRECATED_START
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
OPENVINO_SUPPRESS_DEPRECATED_END
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
bool has_evaluate() const override;
};
} // namespace v1

View File

@@ -34,9 +34,7 @@ public:
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
OPENVINO_SUPPRESS_DEPRECATED_START
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
OPENVINO_SUPPRESS_DEPRECATED_END
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
bool has_evaluate() const override;
};
} // namespace v0

View File

@@ -4,21 +4,36 @@
#pragma once
#include <algorithm>
#include <cstddef>
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/shape.hpp"
#include "openvino/reference/autobroadcast_binop.hpp"
namespace ov {
namespace reference {
namespace func {
template <class T>
T logical_xor(const T a, const T b) {
return static_cast<T>((a || b) && !(a && b));
}
} // namespace func
template <typename T>
void logical_xor(const T* arg0, const T* arg1, T* out, size_t count) {
for (size_t i = 0; i < count; i++) {
out[i] = static_cast<T>((arg0[i] || arg1[i]) && !(arg0[i] && arg1[i]));
}
void logical_xor(const T* arg0, const T* arg1, T* out, const size_t count) {
std::transform(arg0, std::next(arg0, count), arg1, out, &func::logical_xor<T>);
}
/**
* @brief Reference implementation of binary elementwise LogicalXor operator.
*
* @param arg0 Pointer to input 0 data.
* @param arg1 Pointer to input 1 data.
* @param out Pointer to output data.
* @param arg_shape0 Input 0 shape.
* @param arg_shape1 Input 1 shape.
* @param broadcast_spec Broadcast specification mode.
*/
template <typename T>
void logical_xor(const T* arg0,
const T* arg1,
@@ -26,9 +41,7 @@ void logical_xor(const T* arg0,
const Shape& arg0_shape,
const Shape& arg1_shape,
const op::AutoBroadcastSpec& broadcast_spec) {
autobroadcast_binop(arg0, arg1, out, arg0_shape, arg1_shape, broadcast_spec, [](T x, T y) -> T {
return static_cast<T>((x || y) && !(x && y));
});
autobroadcast_binop(arg0, arg1, out, arg0_shape, arg1_shape, broadcast_spec, &func::logical_xor<T>);
}
} // namespace reference
} // namespace ov

View File

@@ -2,105 +2,103 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "ngraph/op/xor.hpp"
#include "openvino/op/xor.hpp"
#include "element_visitor.hpp"
#include "itt.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/validation_util.hpp"
#include "openvino/op/logical_xor.hpp"
#include "openvino/reference/xor.hpp"
#include "shape_util.hpp"
using namespace std;
using namespace ngraph;
op::v1::LogicalXor::LogicalXor(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseLogical(arg0, arg1, auto_broadcast) {
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::v1::LogicalXor::clone_with_new_inputs(const OutputVector& new_args) const {
OV_OP_SCOPE(v1_LogicalXor_clone_with_new_inputs);
check_new_args_count(this, new_args);
return make_shared<v1::LogicalXor>(new_args.at(0), new_args.at(1), this->get_autob());
}
OPENVINO_SUPPRESS_DEPRECATED_START
namespace ov {
namespace op {
namespace logxor {
namespace {
template <element::Type_t ET>
bool evaluate(const HostTensorPtr& arg0,
const HostTensorPtr& arg1,
const HostTensorPtr& out,
const op::AutoBroadcastSpec& broadcast_spec) {
ov::reference::logical_xor(arg0->get_data_ptr<ET>(),
arg1->get_data_ptr<ET>(),
out->get_data_ptr<ET>(),
arg0->get_shape(),
arg1->get_shape(),
struct Evaluate : element::NoAction<bool> {
using element::NoAction<bool>::visit;
template <element::Type_t ET>
static result_type visit(const Tensor& arg0,
const Tensor& arg1,
Tensor& out,
const AutoBroadcastSpec& broadcast_spec) {
using T = typename element_type_traits<ET>::value_type;
reference::logical_xor(arg0.data<const T>(),
arg1.data<const T>(),
out.data<T>(),
arg0.get_shape(),
arg1.get_shape(),
broadcast_spec);
return true;
return true;
}
};
namespace {
bool input_supported_type(const element::Type& et) {
return et == element::boolean;
}
bool evaluate_logxor(const HostTensorPtr& arg0,
const HostTensorPtr& arg1,
const HostTensorPtr& out,
const op::AutoBroadcastSpec& broadcast_spec) {
bool rc = true;
out->set_broadcast(broadcast_spec, arg0, arg1);
switch (arg0->get_element_type()) {
NGRAPH_TYPE_CASE(evaluate_logxor, boolean, arg0, arg1, out, broadcast_spec);
default:
rc = false;
break;
}
return rc;
bool evaluate(TensorVector& outputs, const TensorVector& inputs, const AutoBroadcastSpec& broadcast_spec) {
OPENVINO_ASSERT(outputs.size() == 1);
OPENVINO_ASSERT(inputs.size() == 2);
outputs[0].set_shape(ov::util::get_broadcast_shape(inputs[0].get_shape(), inputs[1].get_shape(), broadcast_spec));
using namespace ov::element;
return IfTypeOf<boolean>::apply<logxor::Evaluate>(inputs[0].get_element_type(),
inputs[0],
inputs[1],
outputs[0],
broadcast_spec);
}
} // namespace
} // namespace logxor
bool op::v1::LogicalXor::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
OV_OP_SCOPE(v1_LogicalXor_evaluate);
OPENVINO_SUPPRESS_DEPRECATED_START
NGRAPH_CHECK(validate_host_tensor_vector(outputs, 1) && validate_host_tensor_vector(inputs, 2));
OPENVINO_SUPPRESS_DEPRECATED_END
return logxor::evaluate_logxor(inputs[0], inputs[1], outputs[0], get_autob());
}
bool op::v1::LogicalXor::has_evaluate() const {
OV_OP_SCOPE(v1_LogicalXor_has_evaluate);
switch (get_input_element_type(0)) {
case ngraph::element::boolean:
return true;
default:
break;
}
return false;
}
op::v0::Xor::Xor(const Output<Node>& arg0, const Output<Node>& arg1, const AutoBroadcastSpec& auto_broadcast)
namespace v0 {
Xor::Xor(const Output<Node>& arg0, const Output<Node>& arg1, const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseLogical(arg0, arg1, auto_broadcast) {
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::v0::Xor::clone_with_new_inputs(const OutputVector& new_args) const {
std::shared_ptr<Node> Xor::clone_with_new_inputs(const OutputVector& new_args) const {
OV_OP_SCOPE(v0_Xor_clone_with_new_inputs);
check_new_args_count(this, new_args);
return make_shared<v0::Xor>(new_args.at(0), new_args.at(1), this->get_autob());
return std::make_shared<Xor>(new_args.at(0), new_args.at(1), this->get_autob());
}
bool op::v0::Xor::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
bool Xor::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
OV_OP_SCOPE(v0_Xor_evaluate);
return logxor::evaluate_logxor(inputs[0], inputs[1], outputs[0], get_autob());
return logxor::evaluate(outputs, inputs, get_autob());
}
bool op::v0::Xor::has_evaluate() const {
bool Xor::has_evaluate() const {
OV_OP_SCOPE(v0_Xor_has_evaluate);
switch (get_input_element_type(0)) {
case ngraph::element::boolean:
return true;
default:
break;
}
return false;
return logxor::input_supported_type(get_input_element_type(0));
}
} // namespace v0
namespace v1 {
LogicalXor::LogicalXor(const Output<Node>& arg0, const Output<Node>& arg1, const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseLogical(arg0, arg1, auto_broadcast) {
constructor_validate_and_infer_types();
}
std::shared_ptr<Node> LogicalXor::clone_with_new_inputs(const OutputVector& new_args) const {
OV_OP_SCOPE(v1_LogicalXor_clone_with_new_inputs);
check_new_args_count(this, new_args);
return std::make_shared<LogicalXor>(new_args.at(0), new_args.at(1), this->get_autob());
}
bool LogicalXor::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
OV_OP_SCOPE(v1_LogicalXor_evaluate);
return logxor::evaluate(outputs, inputs, get_autob());
}
bool LogicalXor::has_evaluate() const {
OV_OP_SCOPE(v1_LogicalXor_has_evaluate);
return logxor::input_supported_type(get_input_element_type(0));
}
} // namespace v1
} // namespace op
} // namespace ov