[core]Migrate Multiply operator to new API (#20853)
* Migrate Multiply operator to new API * Add comment explain use of custom multiply * Update custom multiply comment Co-authored-by: Tomasz Jankowski <tomasz1.jankowski@intel.com> --------- Co-authored-by: Tomasz Jankowski <tomasz1.jankowski@intel.com>
This commit is contained in:
parent
6210deba49
commit
b8eea7bf84
@ -29,9 +29,7 @@ public:
|
||||
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
|
||||
bool has_evaluate() const override;
|
||||
};
|
||||
} // namespace v1
|
||||
|
@ -4,21 +4,36 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstddef>
|
||||
|
||||
#include "openvino/core/shape.hpp"
|
||||
#include "openvino/op/util/attr_types.hpp"
|
||||
#include "openvino/reference/autobroadcast_binop.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace reference {
|
||||
namespace func {
|
||||
// Usage of custom function instead of std::multiplies gives smaller binary size.
|
||||
template <class T>
|
||||
constexpr T multiply(const T a, const T b) {
|
||||
return a * b;
|
||||
}
|
||||
} // namespace func
|
||||
|
||||
template <typename T>
|
||||
void multiply(const T* arg0, const T* arg1, T* out, size_t count) {
|
||||
for (size_t i = 0; i < count; i++) {
|
||||
out[i] = arg0[i] * arg1[i];
|
||||
}
|
||||
void multiply(const T* arg0, const T* arg1, T* out, const size_t count) {
|
||||
std::transform(arg0, arg0 + count, arg1, out, func::multiply<T>);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Reference implementation of binary elementwise Multiply operator.
|
||||
*
|
||||
* @param arg0 Pointer to input 0 data.
|
||||
* @param arg1 Pointer to input 1 data.
|
||||
* @param out Pointer to output data.
|
||||
* @param arg_shape0 Input 0 shape.
|
||||
* @param arg_shape1 Input 1 shape.
|
||||
* @param broadcast_spec Broadcast specification mode.
|
||||
*/
|
||||
template <typename T>
|
||||
void multiply(const T* arg0,
|
||||
const T* arg1,
|
||||
@ -26,9 +41,7 @@ void multiply(const T* arg0,
|
||||
const Shape& arg0_shape,
|
||||
const Shape& arg1_shape,
|
||||
const op::AutoBroadcastSpec& broadcast_spec) {
|
||||
autobroadcast_binop(arg0, arg1, out, arg0_shape, arg1_shape, broadcast_spec, [](T x, T y) -> T {
|
||||
return x * y;
|
||||
});
|
||||
autobroadcast_binop(arg0, arg1, out, arg0_shape, arg1_shape, broadcast_spec, func::multiply<T>);
|
||||
}
|
||||
} // namespace reference
|
||||
} // namespace ov
|
||||
|
@ -2,90 +2,76 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "ngraph/op/multiply.hpp"
|
||||
#include "openvino/op/multiply.hpp"
|
||||
|
||||
#include "element_visitor.hpp"
|
||||
#include "itt.hpp"
|
||||
#include "ngraph/runtime/host_tensor.hpp"
|
||||
#include "openvino/reference/multiply.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
namespace ov {
|
||||
namespace op {
|
||||
namespace multiply {
|
||||
struct Evaluate : element::NoAction<bool> {
|
||||
using element::NoAction<bool>::visit;
|
||||
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
namespace multiplyop {
|
||||
namespace {
|
||||
template <element::Type_t ET>
|
||||
bool evaluate(const HostTensorPtr& arg0,
|
||||
const HostTensorPtr& arg1,
|
||||
const HostTensorPtr& out,
|
||||
const op::AutoBroadcastSpec& broadcast_spec) {
|
||||
ov::reference::multiply(arg0->get_data_ptr<ET>(),
|
||||
arg1->get_data_ptr<ET>(),
|
||||
out->get_data_ptr<ET>(),
|
||||
arg0->get_shape(),
|
||||
arg1->get_shape(),
|
||||
broadcast_spec);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool evaluate_multiply(const HostTensorPtr& arg0,
|
||||
const HostTensorPtr& arg1,
|
||||
const HostTensorPtr& out,
|
||||
const op::AutoBroadcastSpec& broadcast_spec) {
|
||||
bool rc = true;
|
||||
out->set_broadcast(broadcast_spec, arg0, arg1);
|
||||
switch (arg0->get_element_type()) {
|
||||
OPENVINO_TYPE_CASE(evaluate_multiply, i32, arg0, arg1, out, broadcast_spec);
|
||||
OPENVINO_TYPE_CASE(evaluate_multiply, i64, arg0, arg1, out, broadcast_spec);
|
||||
OPENVINO_TYPE_CASE(evaluate_multiply, u32, arg0, arg1, out, broadcast_spec);
|
||||
OPENVINO_TYPE_CASE(evaluate_multiply, u64, arg0, arg1, out, broadcast_spec);
|
||||
OPENVINO_TYPE_CASE(evaluate_multiply, f16, arg0, arg1, out, broadcast_spec);
|
||||
OPENVINO_TYPE_CASE(evaluate_multiply, f32, arg0, arg1, out, broadcast_spec);
|
||||
OPENVINO_TYPE_CASE(evaluate_multiply, f64, arg0, arg1, out, broadcast_spec);
|
||||
OPENVINO_TYPE_CASE(evaluate_multiply, bf16, arg0, arg1, out, broadcast_spec);
|
||||
OPENVINO_TYPE_CASE(evaluate_multiply, u8, arg0, arg1, out, broadcast_spec);
|
||||
OPENVINO_TYPE_CASE(evaluate_multiply, i16, arg0, arg1, out, broadcast_spec);
|
||||
OPENVINO_TYPE_CASE(evaluate_multiply, u16, arg0, arg1, out, broadcast_spec);
|
||||
default:
|
||||
rc = false;
|
||||
break;
|
||||
template <element::Type_t ET, class T = fundamental_type_for<ET>>
|
||||
static result_type visit(const Tensor& arg0,
|
||||
const Tensor& arg1,
|
||||
Tensor& out,
|
||||
const Shape& shape0,
|
||||
const Shape& shape1,
|
||||
const AutoBroadcastSpec& broadcast_spec) {
|
||||
reference::multiply(arg0.data<const T>(), arg1.data<const T>(), out.data<T>(), shape0, shape1, broadcast_spec);
|
||||
return true;
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
} // namespace
|
||||
} // namespace multiplyop
|
||||
};
|
||||
} // namespace multiply
|
||||
|
||||
// ------------------------------------ v1 -------------------------------------
|
||||
op::v1::Multiply::Multiply(const Output<Node>& arg0, const Output<Node>& arg1, const AutoBroadcastSpec& auto_broadcast)
|
||||
namespace v1 {
|
||||
Multiply::Multiply(const Output<Node>& arg0, const Output<Node>& arg1, const AutoBroadcastSpec& auto_broadcast)
|
||||
: BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast) {
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::v1::Multiply::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
std::shared_ptr<Node> Multiply::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
OV_OP_SCOPE(v1_Multiply_clone_with_new_inputs);
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<op::v1::Multiply>(new_args.at(0), new_args.at(1), this->get_autob());
|
||||
return std::make_shared<Multiply>(new_args.at(0), new_args.at(1), get_autob());
|
||||
}
|
||||
|
||||
bool op::v1::Multiply::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
|
||||
bool Multiply::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
|
||||
OV_OP_SCOPE(v1_Multiply_evaluate);
|
||||
return multiplyop::evaluate_multiply(inputs[0], inputs[1], outputs[0], get_autob());
|
||||
OPENVINO_ASSERT(outputs.size() == 1);
|
||||
outputs[0].set_shape(infer_broadcast_shape(this, inputs));
|
||||
|
||||
using namespace ov::element;
|
||||
return IfTypeOf<bf16, f16, f32, f64, i32, i64, u32, u64>::apply<multiply::Evaluate>(inputs[0].get_element_type(),
|
||||
inputs[0],
|
||||
inputs[1],
|
||||
outputs[0],
|
||||
inputs[0].get_shape(),
|
||||
inputs[1].get_shape(),
|
||||
get_autob());
|
||||
}
|
||||
|
||||
bool op::v1::Multiply::has_evaluate() const {
|
||||
bool Multiply::has_evaluate() const {
|
||||
OV_OP_SCOPE(v1_Multiply_has_evaluate);
|
||||
switch (get_input_element_type(0)) {
|
||||
case ngraph::element::i32:
|
||||
case ngraph::element::i64:
|
||||
case ngraph::element::u32:
|
||||
case ngraph::element::u64:
|
||||
case ngraph::element::f16:
|
||||
case ngraph::element::f32:
|
||||
case ngraph::element::f64:
|
||||
case ngraph::element::bf16:
|
||||
case element::bf16:
|
||||
case element::f16:
|
||||
case element::f32:
|
||||
case element::f64:
|
||||
case element::i32:
|
||||
case element::i64:
|
||||
case element::u32:
|
||||
case element::u64:
|
||||
return true;
|
||||
default:
|
||||
break;
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace v1
|
||||
} // namespace op
|
||||
} // namespace ov
|
||||
|
Loading…
Reference in New Issue
Block a user