[core]Migrate TopK operator to new API (#20254)
* Migrate TopK to new API * Refactor compare_max for TopK * Unify check of k for const and non-const input * Update src/core/include/openvino/op/util/evaluate_helpers.hpp Co-authored-by: Tomasz Jankowski <tomasz1.jankowski@intel.com> * Move `get_tensors_partial_shapes` to dev API --------- Co-authored-by: Tomasz Jankowski <tomasz1.jankowski@intel.com>
This commit is contained in:
parent
146b0c0be8
commit
08ab7da931
@ -78,5 +78,10 @@ bool try_apply_auto_padding(const PartialShape& image_shape,
|
||||
CoordinateDiff& padding_above,
|
||||
CoordinateDiff& padding_below);
|
||||
|
||||
/// @brief Get the tensors shapes as ov::PartialShape.
|
||||
///
|
||||
/// @param tensors Input tensors vector to get their shapes.
|
||||
/// @return Vector of partial shapes same size as input tensor vector.
|
||||
OPENVINO_API std::vector<PartialShape> get_tensors_partial_shapes(const TensorVector& tensors);
|
||||
} // namespace util
|
||||
} // namespace ov
|
||||
|
@ -36,7 +36,7 @@ public:
|
||||
/// the biggest element of two.
|
||||
/// \param sort Specifies order of output elements and/or indices
|
||||
/// Accepted values: none, index, value
|
||||
/// \param index_element_type Specyfies type of produced indices
|
||||
/// \param index_element_type Specifies type of produced indices
|
||||
TopK(const Output<Node>& data,
|
||||
const Output<Node>& k,
|
||||
const int64_t axis,
|
||||
@ -53,9 +53,7 @@ public:
|
||||
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
|
||||
bool has_evaluate() const override;
|
||||
|
||||
protected:
|
||||
@ -83,7 +81,7 @@ public:
|
||||
/// the biggest element of two.
|
||||
/// \param sort Specifies order of output elements and/or indices
|
||||
/// Accepted values: none, index, value
|
||||
/// \param index_element_type Specyfies type of produced indices
|
||||
/// \param index_element_type Specifies type of produced indices
|
||||
TopK(const Output<Node>& data,
|
||||
const Output<Node>& k,
|
||||
const int64_t axis,
|
||||
@ -99,9 +97,7 @@ public:
|
||||
const element::Type& index_element_type = element::i32);
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
|
||||
bool has_evaluate() const override;
|
||||
};
|
||||
} // namespace v3
|
||||
@ -153,9 +149,7 @@ public:
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
|
||||
bool has_evaluate() const override;
|
||||
|
||||
bool get_stable() const {
|
||||
|
@ -1,23 +0,0 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/core/partial_shape.hpp"
|
||||
#include "openvino/runtime/tensor.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace op {
|
||||
namespace util {
|
||||
|
||||
/**
|
||||
* @brief Get the tensors shapes as ov::PartialShape.
|
||||
*
|
||||
* @param tensors Input tensors vector to get its shapes.
|
||||
* @return Vector of partial shapes sam size as input tensor vector.
|
||||
*/
|
||||
std::vector<PartialShape> get_tensors_partial_shapes(const TensorVector& tensors);
|
||||
} // namespace util
|
||||
} // namespace op
|
||||
} // namespace ov
|
@ -8,7 +8,7 @@
|
||||
#include <cmath>
|
||||
#include <numeric>
|
||||
|
||||
#include "openvino/op/topk.hpp"
|
||||
#include "openvino/op/util/attr_types.hpp"
|
||||
#include "openvino/reference/utils/coordinate_index.hpp"
|
||||
#include "openvino/reference/utils/coordinate_transform.hpp"
|
||||
|
||||
@ -17,23 +17,11 @@ namespace reference {
|
||||
// This used to be lambda expressions but MSVC had difficulty compiling it. This way is more explicit.
|
||||
template <bool D, typename T, typename U>
|
||||
inline bool compare_max(const std::tuple<T, U>& a, const std::tuple<T, U>& b) {
|
||||
// this is intentional to be able to compare floats directly
|
||||
// without using relative or absolute tolerance
|
||||
#if defined(__GNUC__)
|
||||
# pragma GCC diagnostic push
|
||||
# pragma GCC diagnostic ignored "-Wfloat-equal"
|
||||
#endif
|
||||
if (std::get<0>(a) == std::get<0>(b)) {
|
||||
if (std::get<0>(a) != std::get<0>(b)) {
|
||||
return D ? std::get<0>(a) > std::get<0>(b) : std::get<0>(a) < std::get<0>(b);
|
||||
} else {
|
||||
return std::get<1>(a) < std::get<1>(b);
|
||||
}
|
||||
#if defined(__GNUC__)
|
||||
# pragma GCC diagnostic pop
|
||||
#endif
|
||||
|
||||
if (D)
|
||||
return std::get<0>(a) > std::get<0>(b);
|
||||
else
|
||||
return std::get<0>(a) < std::get<0>(b);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
@ -41,63 +29,76 @@ inline bool compare_indices_ascending(const std::tuple<T, U>& a, const std::tupl
|
||||
return std::get<1>(a) < std::get<1>(b);
|
||||
}
|
||||
|
||||
// TopK reference implementation provides stable indices output
|
||||
/**
|
||||
* @brief Reference implementation fo TopK operator
|
||||
*
|
||||
* @param arg Pointer to input data.
|
||||
* @param out_indices Pointer to output indicies.
|
||||
* @param out_values Pointer to output values.
|
||||
* @param in_shape Input data shape.
|
||||
* @param out_shape Output data (values, indicies) shape.
|
||||
* @param axis Axis for search of top K elements.
|
||||
* @param k Number to find of top elements.
|
||||
* @param compute_max Select mode of find max or min.
|
||||
* @param sort Sorting type.
|
||||
*/
|
||||
template <typename T, typename U>
|
||||
void topk(const T* arg,
|
||||
U* out_indices,
|
||||
T* out_values,
|
||||
const Shape& in_shape,
|
||||
const Shape& out_shape,
|
||||
size_t axis,
|
||||
size_t k,
|
||||
bool compute_max,
|
||||
op::TopKSortType sort = op::TopKSortType::NONE) {
|
||||
using namespace std;
|
||||
const size_t axis,
|
||||
const size_t k,
|
||||
const bool compute_max,
|
||||
const op::TopKSortType sort = op::TopKSortType::NONE) {
|
||||
// Create temp vector for sorting.
|
||||
vector<tuple<T, U>> workspace(in_shape[axis]);
|
||||
vector<size_t> in_strides = row_major_strides(in_shape);
|
||||
vector<size_t> out_strides = row_major_strides(out_shape);
|
||||
auto in_axis_stride = in_strides[axis];
|
||||
auto out_axis_stride = out_strides[axis];
|
||||
std::vector<std::tuple<T, U>> workspace(in_shape[axis]);
|
||||
const auto in_strides = row_major_strides(in_shape);
|
||||
const auto out_strides = row_major_strides(out_shape);
|
||||
const auto in_axis_stride = in_strides[axis];
|
||||
const auto out_axis_stride = out_strides[axis];
|
||||
|
||||
// Iterate over elements with 0 index at "axis" dimension
|
||||
auto traverse_shape = in_shape;
|
||||
traverse_shape[axis] = 1;
|
||||
CoordinateTransformBasic traverse_transform(traverse_shape);
|
||||
for (const Coordinate& coord : traverse_transform) {
|
||||
for (const auto& coord : traverse_transform) {
|
||||
auto arg_index = coordinate_index(coord, in_shape);
|
||||
auto out_index = coordinate_index(coord, out_shape);
|
||||
// Fill the temp vector
|
||||
U i = 0;
|
||||
for (tuple<T, U>& entry : workspace) {
|
||||
get<0>(entry) = arg[arg_index];
|
||||
get<1>(entry) = i;
|
||||
for (auto& entry : workspace) {
|
||||
std::get<0>(entry) = arg[arg_index];
|
||||
std::get<1>(entry) = i;
|
||||
arg_index += in_axis_stride;
|
||||
i++;
|
||||
++i;
|
||||
}
|
||||
// Sort the temp vector
|
||||
if (compute_max) {
|
||||
nth_element(workspace.begin(), workspace.begin() + k, workspace.end(), compare_max<true, T, U>);
|
||||
} else {
|
||||
nth_element(workspace.begin(), workspace.begin() + k, workspace.end(), compare_max<false, T, U>);
|
||||
}
|
||||
// Write temp vector to output
|
||||
|
||||
const auto cmp_func = compute_max ? compare_max<true, T, U> : compare_max<false, T, U>;
|
||||
|
||||
typename std::decay<decltype(cmp_func)>::type sort_func;
|
||||
switch (sort) {
|
||||
case op::TopKSortType::NONE:
|
||||
break;
|
||||
case op::TopKSortType::SORT_INDICES:
|
||||
std::sort(workspace.begin(), workspace.begin() + k, compare_indices_ascending<T, U>);
|
||||
sort_func = compare_indices_ascending<T, U>;
|
||||
break;
|
||||
case op::TopKSortType::SORT_VALUES:
|
||||
if (compute_max)
|
||||
std::sort(workspace.begin(), workspace.begin() + k, compare_max<true, T, U>);
|
||||
else
|
||||
std::sort(workspace.begin(), workspace.begin() + k, compare_max<false, T, U>);
|
||||
sort_func = cmp_func;
|
||||
break;
|
||||
default:
|
||||
sort_func = nullptr;
|
||||
break;
|
||||
}
|
||||
for (size_t j = 0; j < k; j++) {
|
||||
|
||||
std::nth_element(workspace.begin(), workspace.begin() + k, workspace.end(), cmp_func);
|
||||
if (sort_func) {
|
||||
std::sort(workspace.begin(), workspace.begin() + k, sort_func);
|
||||
}
|
||||
|
||||
for (size_t j = 0; j < k; ++j) {
|
||||
const auto& entry = workspace[j];
|
||||
out_values[out_index] = get<0>(entry);
|
||||
out_indices[out_index] = get<1>(entry);
|
||||
out_values[out_index] = std::get<0>(entry);
|
||||
out_indices[out_index] = std::get<1>(entry);
|
||||
out_index += out_axis_stride;
|
||||
}
|
||||
}
|
||||
|
@ -8,7 +8,6 @@
|
||||
#include "eye_shape_inference.hpp"
|
||||
#include "itt.hpp"
|
||||
#include "openvino/core/validation_util.hpp"
|
||||
#include "openvino/op/util/evaluate_helpers.hpp"
|
||||
#include "openvino/reference/eye.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -107,7 +106,7 @@ bool Eye::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
|
||||
OPENVINO_ASSERT(outputs.size() == 1);
|
||||
|
||||
// Inputs size and shapes checked by shape_infer
|
||||
const auto input_shapes = util::get_tensors_partial_shapes(inputs);
|
||||
const auto input_shapes = ov::util::get_tensors_partial_shapes(inputs);
|
||||
const auto output_shape = shape_infer(this, input_shapes, make_tensor_accessor(inputs)).front().to_shape();
|
||||
|
||||
int64_t diagonal_index;
|
||||
|
@ -4,163 +4,153 @@
|
||||
|
||||
#include "openvino/op/topk.hpp"
|
||||
|
||||
#include <memory>
|
||||
#include <topk_shape_inference.hpp>
|
||||
|
||||
#include "element_visitor.hpp"
|
||||
#include "itt.hpp"
|
||||
#include "openvino/core/attribute_visitor.hpp"
|
||||
#include "openvino/core/axis_vector.hpp"
|
||||
#include "openvino/core/dimension_tracker.hpp"
|
||||
#include "openvino/core/shape.hpp"
|
||||
#include "openvino/core/validation_util.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/util/op_types.hpp"
|
||||
#include "openvino/reference/topk.hpp"
|
||||
|
||||
using namespace std;
|
||||
#include "topk_shape_inference.hpp"
|
||||
|
||||
namespace ov {
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
namespace op {
|
||||
namespace topk {
|
||||
namespace validate {
|
||||
namespace {
|
||||
template <element::Type_t INPUT_ET, element::Type_t INDEX_ET>
|
||||
inline bool evaluate_execute(const ngraph::HostTensorPtr& arg0,
|
||||
const ngraph::HostTensorPtr& out_indices,
|
||||
const ngraph::HostTensorPtr& out_values,
|
||||
const ov::Shape out_shape,
|
||||
bool data_type(const element::Type& et) {
|
||||
switch (et) {
|
||||
case element::f16:
|
||||
case element::f32:
|
||||
case element::i32:
|
||||
case element::i64:
|
||||
case element::u32:
|
||||
case element::u64:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool k_type(const element::Type& et) {
|
||||
switch (et) {
|
||||
case element::i8:
|
||||
case element::i16:
|
||||
case element::i32:
|
||||
case element::i64:
|
||||
case element::u8:
|
||||
case element::u16:
|
||||
case element::u32:
|
||||
case element::u64:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
} // namespace validate
|
||||
|
||||
struct Evaluate : public element::NoAction<bool> {
|
||||
using element::NoAction<bool>::visit;
|
||||
|
||||
template <element::Type_t ET, class T = fundamental_type_for<ET>>
|
||||
static result_type visit(const Tensor& in,
|
||||
Tensor& out_values,
|
||||
Tensor& out_indices,
|
||||
const Shape& out_shape,
|
||||
const size_t axis,
|
||||
const size_t k,
|
||||
const bool compute_max,
|
||||
const op::v1::TopK::SortType sort) {
|
||||
using T = typename element_type_traits<INPUT_ET>::value_type;
|
||||
using U = typename element_type_traits<INDEX_ET>::value_type;
|
||||
const ov::Shape in_shape = arg0->get_shape();
|
||||
out_indices->set_shape(out_shape);
|
||||
out_indices->set_element_type(INDEX_ET);
|
||||
|
||||
out_values->set_shape(out_shape);
|
||||
out_values->set_element_type(arg0->get_element_type());
|
||||
|
||||
ov::reference::topk<T, U>(arg0->get_data_ptr<INPUT_ET>(),
|
||||
out_indices->get_data_ptr<INDEX_ET>(),
|
||||
out_values->get_data_ptr<INPUT_ET>(),
|
||||
in_shape,
|
||||
out_shape,
|
||||
axis,
|
||||
k,
|
||||
compute_max,
|
||||
sort);
|
||||
return true;
|
||||
}
|
||||
|
||||
#define EXECUTE_EVALUATE_TOPK(a, ...) \
|
||||
case element::Type_t::a: { \
|
||||
OV_OP_SCOPE(OV_PP_CAT3(exec_topk_eval, _, a)); \
|
||||
rc = evaluate_execute<INPUT_ET, element::Type_t::a>(__VA_ARGS__); \
|
||||
} break
|
||||
|
||||
template <element::Type_t INPUT_ET>
|
||||
bool evaluate(const ngraph::HostTensorPtr& arg,
|
||||
const ngraph::HostTensorPtr& out_indices,
|
||||
const ngraph::HostTensorPtr& out_values,
|
||||
const ov::Shape out_shape,
|
||||
const size_t axis,
|
||||
const size_t k,
|
||||
const bool max,
|
||||
const op::v1::TopK::SortType sort,
|
||||
const element::Type index_et) {
|
||||
bool rc = true;
|
||||
switch (index_et) {
|
||||
EXECUTE_EVALUATE_TOPK(i32, arg, out_indices, out_values, out_shape, axis, k, max, sort);
|
||||
EXECUTE_EVALUATE_TOPK(i64, arg, out_indices, out_values, out_shape, axis, k, max, sort);
|
||||
default:
|
||||
rc = false;
|
||||
break;
|
||||
const TopKSortType sort) {
|
||||
using namespace ov::element;
|
||||
return IfTypeOf<i32, i64>::apply<EvalByIdxType>(out_indices.get_element_type(),
|
||||
in.data<const T>(),
|
||||
out_values.data<T>(),
|
||||
out_indices,
|
||||
in.get_shape(),
|
||||
out_shape,
|
||||
axis,
|
||||
out_shape[axis],
|
||||
compute_max,
|
||||
sort);
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
|
||||
bool evaluate_topk(const ngraph::HostTensorPtr& arg,
|
||||
const ngraph::HostTensorPtr& out_indices,
|
||||
const ngraph::HostTensorPtr& out_values,
|
||||
const ov::Shape out_shape,
|
||||
const size_t axis,
|
||||
const size_t k,
|
||||
const bool max,
|
||||
const op::v1::TopK::SortType sort,
|
||||
const element::Type index_et) {
|
||||
bool rc = true;
|
||||
switch (arg->get_element_type()) {
|
||||
OPENVINO_TYPE_CASE(evaluate_topk, i32, arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
|
||||
OPENVINO_TYPE_CASE(evaluate_topk, i64, arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
|
||||
OPENVINO_TYPE_CASE(evaluate_topk, u32, arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
|
||||
OPENVINO_TYPE_CASE(evaluate_topk, u64, arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
|
||||
OPENVINO_TYPE_CASE(evaluate_topk, f16, arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
|
||||
OPENVINO_TYPE_CASE(evaluate_topk, f32, arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
|
||||
default:
|
||||
rc = false;
|
||||
break;
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
bool TopK_evaluate(const ov::op::util::TopKBase* const node,
|
||||
const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) {
|
||||
const auto& arg_shape = inputs[0]->get_shape();
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
const auto axis = normalize_axis(node, node->get_provided_axis(), arg_shape.size());
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
const auto compute_max = node->get_mode() == ov::op::TopKMode::MAX;
|
||||
const auto sort_type = node->get_sort_type();
|
||||
private:
|
||||
struct EvalByIdxType : public element::NoAction<bool> {
|
||||
using element::NoAction<bool>::visit;
|
||||
|
||||
const auto input_shapes = vector<PartialShape>{inputs[0]->get_partial_shape(), inputs[1]->get_partial_shape()};
|
||||
auto output_shape = shape_infer(node, input_shapes, ov::make_tensor_accessor(inputs)).front().to_shape();
|
||||
template <element::Type_t ET, class T, class I = fundamental_type_for<ET>>
|
||||
static result_type visit(const T* in_first,
|
||||
T* out_first,
|
||||
Tensor& out_indices,
|
||||
const Shape& in_shape,
|
||||
const Shape& out_shape,
|
||||
const size_t axis,
|
||||
const size_t k,
|
||||
const bool compute_max,
|
||||
const TopKSortType sort) {
|
||||
reference::topk(in_first,
|
||||
out_indices.data<I>(),
|
||||
out_first,
|
||||
in_shape,
|
||||
out_shape,
|
||||
axis,
|
||||
k,
|
||||
compute_max,
|
||||
sort);
|
||||
return true;
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
namespace {
|
||||
bool evaluate(const util::TopKBase* const node, TensorVector& outputs, const TensorVector& inputs) {
|
||||
auto output_shapes = shape_infer(node, ov::util::get_tensors_partial_shapes(inputs), make_tensor_accessor(inputs));
|
||||
OPENVINO_ASSERT(outputs.size() == output_shapes.size());
|
||||
|
||||
auto output_shape = output_shapes.front().get_shape();
|
||||
const auto axis = ov::util::normalize(node->get_provided_axis(), output_shape.size());
|
||||
if (output_shape[axis] == 0) {
|
||||
// the kernel can't handle K (output_shape[axis]) equal 0, use arg_shape[axis] instead.
|
||||
output_shape[axis] = arg_shape[axis];
|
||||
output_shape[axis] = inputs[0].get_shape()[axis];
|
||||
}
|
||||
|
||||
const size_t k = output_shape[axis];
|
||||
OPENVINO_ASSERT(k <= arg_shape[axis], "'K' exceeds the dimension of top_k_axis");
|
||||
for (auto& t : outputs) {
|
||||
t.set_shape(output_shape);
|
||||
}
|
||||
|
||||
// TopK reference implementation provides stable indices output so this parameter is not passed on
|
||||
return evaluate_topk(inputs[0],
|
||||
outputs[1],
|
||||
outputs[0],
|
||||
output_shape,
|
||||
axis,
|
||||
k,
|
||||
compute_max,
|
||||
sort_type,
|
||||
node->get_index_element_type());
|
||||
using namespace ov::element;
|
||||
return IfTypeOf<f16, f32, i32, i64, u32, u64>::apply<topk::Evaluate>(inputs[0].get_element_type(),
|
||||
inputs[0],
|
||||
outputs[0],
|
||||
outputs[1],
|
||||
output_shape,
|
||||
axis,
|
||||
(node->get_mode() == ov::op::TopKMode::MAX),
|
||||
node->get_sort_type());
|
||||
}
|
||||
} // namespace
|
||||
} // namespace topk
|
||||
|
||||
// v1 version starts
|
||||
|
||||
op::v1::TopK::TopK(const Output<Node>& data,
|
||||
const Output<Node>& k,
|
||||
const int64_t axis,
|
||||
const std::string& mode,
|
||||
const std::string& sort,
|
||||
const element::Type& index_element_type)
|
||||
namespace v1 {
|
||||
TopK::TopK(const Output<Node>& data,
|
||||
const Output<Node>& k,
|
||||
const int64_t axis,
|
||||
const std::string& mode,
|
||||
const std::string& sort,
|
||||
const element::Type& index_element_type)
|
||||
: util::TopKBase(data, k, axis, mode, sort, index_element_type) {
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
op::v1::TopK::TopK(const Output<Node>& data,
|
||||
const Output<Node>& k,
|
||||
const int64_t axis,
|
||||
const Mode mode,
|
||||
const SortType sort,
|
||||
const element::Type& index_element_type)
|
||||
TopK::TopK(const Output<Node>& data,
|
||||
const Output<Node>& k,
|
||||
const int64_t axis,
|
||||
const Mode mode,
|
||||
const SortType sort,
|
||||
const element::Type& index_element_type)
|
||||
: util::TopKBase(data, k, axis, mode, sort, index_element_type) {
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
void op::v1::TopK::k_type_check(const element::Type& k_element_type) const {
|
||||
void TopK::k_type_check(const element::Type& k_element_type) const {
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
k_element_type == element::i8 || k_element_type == element::i32 || k_element_type == element::i64,
|
||||
@ -169,156 +159,84 @@ void op::v1::TopK::k_type_check(const element::Type& k_element_type) const {
|
||||
").");
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::v1::TopK::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
std::shared_ptr<Node> TopK::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
OV_OP_SCOPE(v1_TopK_clone_with_new_inputs);
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<v1::TopK>(new_args.at(0), new_args.at(1), m_axis, m_mode, m_sort, m_index_element_type);
|
||||
return std::make_shared<TopK>(new_args.at(0), new_args.at(1), m_axis, m_mode, m_sort, m_index_element_type);
|
||||
}
|
||||
|
||||
bool op::v1::TopK::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
|
||||
bool TopK::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
|
||||
OV_OP_SCOPE(v1_TopK_evaluate);
|
||||
return topk::TopK_evaluate(this, outputs, inputs);
|
||||
return topk::evaluate(this, outputs, inputs);
|
||||
}
|
||||
|
||||
bool op::v1::TopK::has_evaluate() const {
|
||||
bool TopK::has_evaluate() const {
|
||||
OV_OP_SCOPE(v1_TopK_has_evaluate);
|
||||
|
||||
switch (get_input_element_type(0)) {
|
||||
case element::i32:
|
||||
case element::i64:
|
||||
case element::u32:
|
||||
case element::u64:
|
||||
case element::f16:
|
||||
case element::f32:
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
if (op::util::is_constant(input_value(1).get_node())) {
|
||||
switch (get_input_element_type(1)) {
|
||||
case element::i8:
|
||||
case element::i32:
|
||||
case element::i64:
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
switch (get_input_element_type(1)) {
|
||||
case element::i8:
|
||||
case element::i16:
|
||||
case element::i32:
|
||||
case element::i64:
|
||||
case element::u8:
|
||||
case element::u16:
|
||||
case element::u32:
|
||||
case element::u64:
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
return topk::validate::data_type(get_input_element_type(0)) && topk::validate::k_type(get_input_element_type(1));
|
||||
}
|
||||
} // namespace v1
|
||||
|
||||
// v3 version starts
|
||||
op::v3::TopK::TopK(const Output<Node>& data,
|
||||
const Output<Node>& k,
|
||||
const int64_t axis,
|
||||
const std::string& mode,
|
||||
const std::string& sort,
|
||||
const element::Type& index_element_type)
|
||||
namespace v3 {
|
||||
TopK::TopK(const Output<Node>& data,
|
||||
const Output<Node>& k,
|
||||
const int64_t axis,
|
||||
const std::string& mode,
|
||||
const std::string& sort,
|
||||
const element::Type& index_element_type)
|
||||
: TopK(data, k, axis, as_enum<Mode>(mode), as_enum<SortType>(sort), index_element_type) {}
|
||||
|
||||
op::v3::TopK::TopK(const Output<Node>& data,
|
||||
const Output<Node>& k,
|
||||
const int64_t axis,
|
||||
const Mode mode,
|
||||
const SortType sort,
|
||||
const element::Type& index_element_type)
|
||||
TopK::TopK(const Output<Node>& data,
|
||||
const Output<Node>& k,
|
||||
const int64_t axis,
|
||||
const Mode mode,
|
||||
const SortType sort,
|
||||
const element::Type& index_element_type)
|
||||
: util::TopKBase{data, k, axis, mode, sort, index_element_type} {
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::v3::TopK::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
std::shared_ptr<Node> TopK::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
OV_OP_SCOPE(v3_TopK_clone_with_new_inputs);
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<v3::TopK>(new_args.at(0), new_args.at(1), m_axis, m_mode, m_sort, m_index_element_type);
|
||||
return std::make_shared<TopK>(new_args.at(0), new_args.at(1), m_axis, m_mode, m_sort, m_index_element_type);
|
||||
}
|
||||
|
||||
bool op::v3::TopK::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
|
||||
bool TopK::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
|
||||
OV_OP_SCOPE(v3_TopK_evaluate);
|
||||
return topk::TopK_evaluate(this, outputs, inputs);
|
||||
return topk::evaluate(this, outputs, inputs);
|
||||
}
|
||||
|
||||
bool op::v3::TopK::has_evaluate() const {
|
||||
bool TopK::has_evaluate() const {
|
||||
OV_OP_SCOPE(v3_TopK_has_evaluate);
|
||||
|
||||
switch (get_input_element_type(0)) {
|
||||
case element::i32:
|
||||
case element::i64:
|
||||
case element::u32:
|
||||
case element::u64:
|
||||
case element::f16:
|
||||
case element::f32:
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
if (op::util::is_constant(input_value(1).get_node())) {
|
||||
switch (get_input_element_type(1)) {
|
||||
case element::i8:
|
||||
case element::i32:
|
||||
case element::i64:
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
switch (get_input_element_type(1)) {
|
||||
case element::i8:
|
||||
case element::i16:
|
||||
case element::i32:
|
||||
case element::i64:
|
||||
case element::u8:
|
||||
case element::u16:
|
||||
case element::u32:
|
||||
case element::u64:
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
return topk::validate::data_type(get_input_element_type(0)) && topk::validate::k_type(get_input_element_type(1));
|
||||
}
|
||||
} // namespace v3
|
||||
|
||||
// =============== V11 ===============
|
||||
ov::op::v11::TopK::TopK(const Output<Node>& data,
|
||||
const Output<Node>& k,
|
||||
const int64_t axis,
|
||||
const std::string& mode,
|
||||
const std::string& sort,
|
||||
const element::Type& index_element_type,
|
||||
const bool stable)
|
||||
namespace v11 {
|
||||
TopK::TopK(const Output<Node>& data,
|
||||
const Output<Node>& k,
|
||||
const int64_t axis,
|
||||
const std::string& mode,
|
||||
const std::string& sort,
|
||||
const element::Type& index_element_type,
|
||||
const bool stable)
|
||||
: TopK(data, k, axis, as_enum<TopKMode>(mode), as_enum<TopKSortType>(sort), index_element_type, stable) {}
|
||||
|
||||
ov::op::v11::TopK::TopK(const Output<Node>& data,
|
||||
const Output<Node>& k,
|
||||
const int64_t axis,
|
||||
const TopKMode mode,
|
||||
const TopKSortType sort,
|
||||
const element::Type& index_element_type,
|
||||
const bool stable)
|
||||
TopK::TopK(const Output<Node>& data,
|
||||
const Output<Node>& k,
|
||||
const int64_t axis,
|
||||
const TopKMode mode,
|
||||
const TopKSortType sort,
|
||||
const element::Type& index_element_type,
|
||||
const bool stable)
|
||||
: util::TopKBase{data, k, axis, mode, sort, index_element_type},
|
||||
m_stable{stable} {
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
void ov::op::v11::TopK::validate_and_infer_types() {
|
||||
void TopK::validate_and_infer_types() {
|
||||
OV_OP_SCOPE(v11_TopK_validate_and_infer_types);
|
||||
|
||||
if (m_stable) {
|
||||
@ -331,44 +249,34 @@ void ov::op::v11::TopK::validate_and_infer_types() {
|
||||
util::TopKBase::validate_and_infer_types();
|
||||
}
|
||||
|
||||
bool ov::op::v11::TopK::visit_attributes(AttributeVisitor& visitor) {
|
||||
bool TopK::visit_attributes(AttributeVisitor& visitor) {
|
||||
OV_OP_SCOPE(v11_TopK_visit_attributes);
|
||||
util::TopKBase::visit_attributes(visitor);
|
||||
visitor.on_attribute("stable", m_stable);
|
||||
return true;
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> ov::op::v11::TopK::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
std::shared_ptr<Node> TopK::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
OV_OP_SCOPE(v11_TopK_clone_with_new_inputs);
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<ov::op::v11::TopK>(new_args.at(0),
|
||||
new_args.at(1),
|
||||
m_axis,
|
||||
m_mode,
|
||||
m_sort,
|
||||
m_index_element_type,
|
||||
m_stable);
|
||||
return std::make_shared<TopK>(new_args.at(0),
|
||||
new_args.at(1),
|
||||
m_axis,
|
||||
m_mode,
|
||||
m_sort,
|
||||
m_index_element_type,
|
||||
m_stable);
|
||||
}
|
||||
|
||||
bool ov::op::v11::TopK::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
|
||||
bool TopK::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
|
||||
OV_OP_SCOPE(v11_TopK_evaluate);
|
||||
return topk::TopK_evaluate(this, outputs, inputs);
|
||||
return topk::evaluate(this, outputs, inputs);
|
||||
}
|
||||
|
||||
bool ov::op::v11::TopK::has_evaluate() const {
|
||||
bool TopK::has_evaluate() const {
|
||||
OV_OP_SCOPE(v11_TopK_has_evaluate);
|
||||
|
||||
switch (get_input_element_type(0)) {
|
||||
case element::i32:
|
||||
case element::i64:
|
||||
case element::u32:
|
||||
case element::u64:
|
||||
case element::f16:
|
||||
case element::f32:
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
return topk::validate::data_type(get_input_element_type(0));
|
||||
}
|
||||
} // namespace v11
|
||||
} // namespace op
|
||||
} // namespace ov
|
||||
|
@ -4,8 +4,6 @@
|
||||
|
||||
#include "ngraph/op/util/evaluate_helpers.hpp"
|
||||
|
||||
#include "openvino/op/util/evaluate_helpers.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
AxisSet get_normalized_axes_from_tensor(const HostTensorPtr tensor,
|
||||
const ngraph::Rank& rank,
|
||||
@ -17,18 +15,3 @@ AxisSet get_normalized_axes_from_tensor(const HostTensorPtr tensor,
|
||||
return AxisSet{normalized_axes};
|
||||
}
|
||||
} // namespace ngraph
|
||||
|
||||
namespace ov {
|
||||
namespace op {
|
||||
namespace util {
|
||||
std::vector<PartialShape> get_tensors_partial_shapes(const TensorVector& tensors) {
|
||||
std::vector<PartialShape> shapes;
|
||||
shapes.reserve(tensors.size());
|
||||
for (const auto& t : tensors) {
|
||||
shapes.emplace_back(t.get_shape());
|
||||
}
|
||||
return shapes;
|
||||
}
|
||||
} // namespace util
|
||||
} // namespace op
|
||||
} // namespace ov
|
||||
|
@ -1384,5 +1384,14 @@ std::shared_ptr<Constant> get_constant_from_source(const Output<Node>& source) {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<PartialShape> get_tensors_partial_shapes(const TensorVector& tensors) {
|
||||
std::vector<PartialShape> shapes;
|
||||
shapes.reserve(tensors.size());
|
||||
for (const auto& t : tensors) {
|
||||
shapes.emplace_back(t.get_shape());
|
||||
}
|
||||
return shapes;
|
||||
}
|
||||
} // namespace util
|
||||
} // namespace ov
|
||||
|
Loading…
Reference in New Issue
Block a user