[core]Migrate TopK operator to new API (#20254)

* Migrate TopK to new API

* Refactor compare_max for TopK

* Unify check of k for const and non-const input

* Update src/core/include/openvino/op/util/evaluate_helpers.hpp

Co-authored-by: Tomasz Jankowski <tomasz1.jankowski@intel.com>

* Move `get_tensors_partial_shapes` to dev API

---------

Co-authored-by: Tomasz Jankowski <tomasz1.jankowski@intel.com>
This commit is contained in:
Pawel Raasz 2023-10-13 12:53:58 +02:00 committed by Alexander Nesterov
parent 146b0c0be8
commit 08ab7da931
8 changed files with 244 additions and 368 deletions

View File

@ -78,5 +78,10 @@ bool try_apply_auto_padding(const PartialShape& image_shape,
CoordinateDiff& padding_above,
CoordinateDiff& padding_below);
/// @brief Get the tensors shapes as ov::PartialShape.
///
/// @param tensors Input tensors vector to get their shapes.
/// @return Vector of partial shapes same size as input tensor vector.
OPENVINO_API std::vector<PartialShape> get_tensors_partial_shapes(const TensorVector& tensors);
} // namespace util
} // namespace ov

View File

@ -36,7 +36,7 @@ public:
/// the biggest element of two.
/// \param sort Specifies order of output elements and/or indices
/// Accepted values: none, index, value
/// \param index_element_type Specyfies type of produced indices
/// \param index_element_type Specifies type of produced indices
TopK(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
@ -53,9 +53,7 @@ public:
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
OPENVINO_SUPPRESS_DEPRECATED_START
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
OPENVINO_SUPPRESS_DEPRECATED_END
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
bool has_evaluate() const override;
protected:
@ -83,7 +81,7 @@ public:
/// the biggest element of two.
/// \param sort Specifies order of output elements and/or indices
/// Accepted values: none, index, value
/// \param index_element_type Specyfies type of produced indices
/// \param index_element_type Specifies type of produced indices
TopK(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
@ -99,9 +97,7 @@ public:
const element::Type& index_element_type = element::i32);
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
OPENVINO_SUPPRESS_DEPRECATED_START
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
OPENVINO_SUPPRESS_DEPRECATED_END
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
bool has_evaluate() const override;
};
} // namespace v3
@ -153,9 +149,7 @@ public:
bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
OPENVINO_SUPPRESS_DEPRECATED_START
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
OPENVINO_SUPPRESS_DEPRECATED_END
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
bool has_evaluate() const override;
bool get_stable() const {

View File

@ -1,23 +0,0 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/core/partial_shape.hpp"
#include "openvino/runtime/tensor.hpp"
namespace ov {
namespace op {
namespace util {
/**
* @brief Get the tensors shapes as ov::PartialShape.
*
* @param tensors Input tensors vector to get its shapes.
* @return Vector of partial shapes sam size as input tensor vector.
*/
std::vector<PartialShape> get_tensors_partial_shapes(const TensorVector& tensors);
} // namespace util
} // namespace op
} // namespace ov

View File

@ -8,7 +8,7 @@
#include <cmath>
#include <numeric>
#include "openvino/op/topk.hpp"
#include "openvino/op/util/attr_types.hpp"
#include "openvino/reference/utils/coordinate_index.hpp"
#include "openvino/reference/utils/coordinate_transform.hpp"
@ -17,23 +17,11 @@ namespace reference {
// This used to be lambda expressions but MSVC had difficulty compiling it. This way is more explicit.
template <bool D, typename T, typename U>
inline bool compare_max(const std::tuple<T, U>& a, const std::tuple<T, U>& b) {
// this is intentional to be able to compare floats directly
// without using relative or absolute tolerance
#if defined(__GNUC__)
# pragma GCC diagnostic push
# pragma GCC diagnostic ignored "-Wfloat-equal"
#endif
if (std::get<0>(a) == std::get<0>(b)) {
if (std::get<0>(a) != std::get<0>(b)) {
return D ? std::get<0>(a) > std::get<0>(b) : std::get<0>(a) < std::get<0>(b);
} else {
return std::get<1>(a) < std::get<1>(b);
}
#if defined(__GNUC__)
# pragma GCC diagnostic pop
#endif
if (D)
return std::get<0>(a) > std::get<0>(b);
else
return std::get<0>(a) < std::get<0>(b);
}
template <typename T, typename U>
@ -41,63 +29,76 @@ inline bool compare_indices_ascending(const std::tuple<T, U>& a, const std::tupl
return std::get<1>(a) < std::get<1>(b);
}
// TopK reference implementation provides stable indices output
/**
* @brief Reference implementation fo TopK operator
*
* @param arg Pointer to input data.
* @param out_indices Pointer to output indicies.
* @param out_values Pointer to output values.
* @param in_shape Input data shape.
* @param out_shape Output data (values, indicies) shape.
* @param axis Axis for search of top K elements.
* @param k Number to find of top elements.
* @param compute_max Select mode of find max or min.
* @param sort Sorting type.
*/
template <typename T, typename U>
void topk(const T* arg,
U* out_indices,
T* out_values,
const Shape& in_shape,
const Shape& out_shape,
size_t axis,
size_t k,
bool compute_max,
op::TopKSortType sort = op::TopKSortType::NONE) {
using namespace std;
const size_t axis,
const size_t k,
const bool compute_max,
const op::TopKSortType sort = op::TopKSortType::NONE) {
// Create temp vector for sorting.
vector<tuple<T, U>> workspace(in_shape[axis]);
vector<size_t> in_strides = row_major_strides(in_shape);
vector<size_t> out_strides = row_major_strides(out_shape);
auto in_axis_stride = in_strides[axis];
auto out_axis_stride = out_strides[axis];
std::vector<std::tuple<T, U>> workspace(in_shape[axis]);
const auto in_strides = row_major_strides(in_shape);
const auto out_strides = row_major_strides(out_shape);
const auto in_axis_stride = in_strides[axis];
const auto out_axis_stride = out_strides[axis];
// Iterate over elements with 0 index at "axis" dimension
auto traverse_shape = in_shape;
traverse_shape[axis] = 1;
CoordinateTransformBasic traverse_transform(traverse_shape);
for (const Coordinate& coord : traverse_transform) {
for (const auto& coord : traverse_transform) {
auto arg_index = coordinate_index(coord, in_shape);
auto out_index = coordinate_index(coord, out_shape);
// Fill the temp vector
U i = 0;
for (tuple<T, U>& entry : workspace) {
get<0>(entry) = arg[arg_index];
get<1>(entry) = i;
for (auto& entry : workspace) {
std::get<0>(entry) = arg[arg_index];
std::get<1>(entry) = i;
arg_index += in_axis_stride;
i++;
++i;
}
// Sort the temp vector
if (compute_max) {
nth_element(workspace.begin(), workspace.begin() + k, workspace.end(), compare_max<true, T, U>);
} else {
nth_element(workspace.begin(), workspace.begin() + k, workspace.end(), compare_max<false, T, U>);
}
// Write temp vector to output
const auto cmp_func = compute_max ? compare_max<true, T, U> : compare_max<false, T, U>;
typename std::decay<decltype(cmp_func)>::type sort_func;
switch (sort) {
case op::TopKSortType::NONE:
break;
case op::TopKSortType::SORT_INDICES:
std::sort(workspace.begin(), workspace.begin() + k, compare_indices_ascending<T, U>);
sort_func = compare_indices_ascending<T, U>;
break;
case op::TopKSortType::SORT_VALUES:
if (compute_max)
std::sort(workspace.begin(), workspace.begin() + k, compare_max<true, T, U>);
else
std::sort(workspace.begin(), workspace.begin() + k, compare_max<false, T, U>);
sort_func = cmp_func;
break;
default:
sort_func = nullptr;
break;
}
for (size_t j = 0; j < k; j++) {
std::nth_element(workspace.begin(), workspace.begin() + k, workspace.end(), cmp_func);
if (sort_func) {
std::sort(workspace.begin(), workspace.begin() + k, sort_func);
}
for (size_t j = 0; j < k; ++j) {
const auto& entry = workspace[j];
out_values[out_index] = get<0>(entry);
out_indices[out_index] = get<1>(entry);
out_values[out_index] = std::get<0>(entry);
out_indices[out_index] = std::get<1>(entry);
out_index += out_axis_stride;
}
}

View File

@ -8,7 +8,6 @@
#include "eye_shape_inference.hpp"
#include "itt.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/op/util/evaluate_helpers.hpp"
#include "openvino/reference/eye.hpp"
namespace ov {
@ -107,7 +106,7 @@ bool Eye::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
OPENVINO_ASSERT(outputs.size() == 1);
// Inputs size and shapes checked by shape_infer
const auto input_shapes = util::get_tensors_partial_shapes(inputs);
const auto input_shapes = ov::util::get_tensors_partial_shapes(inputs);
const auto output_shape = shape_infer(this, input_shapes, make_tensor_accessor(inputs)).front().to_shape();
int64_t diagonal_index;

View File

@ -4,163 +4,153 @@
#include "openvino/op/topk.hpp"
#include <memory>
#include <topk_shape_inference.hpp>
#include "element_visitor.hpp"
#include "itt.hpp"
#include "openvino/core/attribute_visitor.hpp"
#include "openvino/core/axis_vector.hpp"
#include "openvino/core/dimension_tracker.hpp"
#include "openvino/core/shape.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/reference/topk.hpp"
using namespace std;
#include "topk_shape_inference.hpp"
namespace ov {
OPENVINO_SUPPRESS_DEPRECATED_START
namespace op {
namespace topk {
namespace validate {
namespace {
template <element::Type_t INPUT_ET, element::Type_t INDEX_ET>
inline bool evaluate_execute(const ngraph::HostTensorPtr& arg0,
const ngraph::HostTensorPtr& out_indices,
const ngraph::HostTensorPtr& out_values,
const ov::Shape out_shape,
bool data_type(const element::Type& et) {
switch (et) {
case element::f16:
case element::f32:
case element::i32:
case element::i64:
case element::u32:
case element::u64:
return true;
default:
return false;
}
}
bool k_type(const element::Type& et) {
switch (et) {
case element::i8:
case element::i16:
case element::i32:
case element::i64:
case element::u8:
case element::u16:
case element::u32:
case element::u64:
return true;
default:
return false;
}
}
} // namespace
} // namespace validate
struct Evaluate : public element::NoAction<bool> {
using element::NoAction<bool>::visit;
template <element::Type_t ET, class T = fundamental_type_for<ET>>
static result_type visit(const Tensor& in,
Tensor& out_values,
Tensor& out_indices,
const Shape& out_shape,
const size_t axis,
const size_t k,
const bool compute_max,
const op::v1::TopK::SortType sort) {
using T = typename element_type_traits<INPUT_ET>::value_type;
using U = typename element_type_traits<INDEX_ET>::value_type;
const ov::Shape in_shape = arg0->get_shape();
out_indices->set_shape(out_shape);
out_indices->set_element_type(INDEX_ET);
out_values->set_shape(out_shape);
out_values->set_element_type(arg0->get_element_type());
ov::reference::topk<T, U>(arg0->get_data_ptr<INPUT_ET>(),
out_indices->get_data_ptr<INDEX_ET>(),
out_values->get_data_ptr<INPUT_ET>(),
in_shape,
out_shape,
axis,
k,
compute_max,
sort);
return true;
}
#define EXECUTE_EVALUATE_TOPK(a, ...) \
case element::Type_t::a: { \
OV_OP_SCOPE(OV_PP_CAT3(exec_topk_eval, _, a)); \
rc = evaluate_execute<INPUT_ET, element::Type_t::a>(__VA_ARGS__); \
} break
template <element::Type_t INPUT_ET>
bool evaluate(const ngraph::HostTensorPtr& arg,
const ngraph::HostTensorPtr& out_indices,
const ngraph::HostTensorPtr& out_values,
const ov::Shape out_shape,
const size_t axis,
const size_t k,
const bool max,
const op::v1::TopK::SortType sort,
const element::Type index_et) {
bool rc = true;
switch (index_et) {
EXECUTE_EVALUATE_TOPK(i32, arg, out_indices, out_values, out_shape, axis, k, max, sort);
EXECUTE_EVALUATE_TOPK(i64, arg, out_indices, out_values, out_shape, axis, k, max, sort);
default:
rc = false;
break;
const TopKSortType sort) {
using namespace ov::element;
return IfTypeOf<i32, i64>::apply<EvalByIdxType>(out_indices.get_element_type(),
in.data<const T>(),
out_values.data<T>(),
out_indices,
in.get_shape(),
out_shape,
axis,
out_shape[axis],
compute_max,
sort);
}
return rc;
}
bool evaluate_topk(const ngraph::HostTensorPtr& arg,
const ngraph::HostTensorPtr& out_indices,
const ngraph::HostTensorPtr& out_values,
const ov::Shape out_shape,
const size_t axis,
const size_t k,
const bool max,
const op::v1::TopK::SortType sort,
const element::Type index_et) {
bool rc = true;
switch (arg->get_element_type()) {
OPENVINO_TYPE_CASE(evaluate_topk, i32, arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
OPENVINO_TYPE_CASE(evaluate_topk, i64, arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
OPENVINO_TYPE_CASE(evaluate_topk, u32, arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
OPENVINO_TYPE_CASE(evaluate_topk, u64, arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
OPENVINO_TYPE_CASE(evaluate_topk, f16, arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
OPENVINO_TYPE_CASE(evaluate_topk, f32, arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
default:
rc = false;
break;
}
return rc;
}
bool TopK_evaluate(const ov::op::util::TopKBase* const node,
const HostTensorVector& outputs,
const HostTensorVector& inputs) {
const auto& arg_shape = inputs[0]->get_shape();
OPENVINO_SUPPRESS_DEPRECATED_START
const auto axis = normalize_axis(node, node->get_provided_axis(), arg_shape.size());
OPENVINO_SUPPRESS_DEPRECATED_END
const auto compute_max = node->get_mode() == ov::op::TopKMode::MAX;
const auto sort_type = node->get_sort_type();
private:
struct EvalByIdxType : public element::NoAction<bool> {
using element::NoAction<bool>::visit;
const auto input_shapes = vector<PartialShape>{inputs[0]->get_partial_shape(), inputs[1]->get_partial_shape()};
auto output_shape = shape_infer(node, input_shapes, ov::make_tensor_accessor(inputs)).front().to_shape();
template <element::Type_t ET, class T, class I = fundamental_type_for<ET>>
static result_type visit(const T* in_first,
T* out_first,
Tensor& out_indices,
const Shape& in_shape,
const Shape& out_shape,
const size_t axis,
const size_t k,
const bool compute_max,
const TopKSortType sort) {
reference::topk(in_first,
out_indices.data<I>(),
out_first,
in_shape,
out_shape,
axis,
k,
compute_max,
sort);
return true;
}
};
};
namespace {
bool evaluate(const util::TopKBase* const node, TensorVector& outputs, const TensorVector& inputs) {
auto output_shapes = shape_infer(node, ov::util::get_tensors_partial_shapes(inputs), make_tensor_accessor(inputs));
OPENVINO_ASSERT(outputs.size() == output_shapes.size());
auto output_shape = output_shapes.front().get_shape();
const auto axis = ov::util::normalize(node->get_provided_axis(), output_shape.size());
if (output_shape[axis] == 0) {
// the kernel can't handle K (output_shape[axis]) equal 0, use arg_shape[axis] instead.
output_shape[axis] = arg_shape[axis];
output_shape[axis] = inputs[0].get_shape()[axis];
}
const size_t k = output_shape[axis];
OPENVINO_ASSERT(k <= arg_shape[axis], "'K' exceeds the dimension of top_k_axis");
for (auto& t : outputs) {
t.set_shape(output_shape);
}
// TopK reference implementation provides stable indices output so this parameter is not passed on
return evaluate_topk(inputs[0],
outputs[1],
outputs[0],
output_shape,
axis,
k,
compute_max,
sort_type,
node->get_index_element_type());
using namespace ov::element;
return IfTypeOf<f16, f32, i32, i64, u32, u64>::apply<topk::Evaluate>(inputs[0].get_element_type(),
inputs[0],
outputs[0],
outputs[1],
output_shape,
axis,
(node->get_mode() == ov::op::TopKMode::MAX),
node->get_sort_type());
}
} // namespace
} // namespace topk
// v1 version starts
op::v1::TopK::TopK(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
const std::string& mode,
const std::string& sort,
const element::Type& index_element_type)
namespace v1 {
TopK::TopK(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
const std::string& mode,
const std::string& sort,
const element::Type& index_element_type)
: util::TopKBase(data, k, axis, mode, sort, index_element_type) {
constructor_validate_and_infer_types();
}
op::v1::TopK::TopK(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
const Mode mode,
const SortType sort,
const element::Type& index_element_type)
TopK::TopK(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
const Mode mode,
const SortType sort,
const element::Type& index_element_type)
: util::TopKBase(data, k, axis, mode, sort, index_element_type) {
constructor_validate_and_infer_types();
}
void op::v1::TopK::k_type_check(const element::Type& k_element_type) const {
void TopK::k_type_check(const element::Type& k_element_type) const {
NODE_VALIDATION_CHECK(
this,
k_element_type == element::i8 || k_element_type == element::i32 || k_element_type == element::i64,
@ -169,156 +159,84 @@ void op::v1::TopK::k_type_check(const element::Type& k_element_type) const {
").");
}
shared_ptr<Node> op::v1::TopK::clone_with_new_inputs(const OutputVector& new_args) const {
std::shared_ptr<Node> TopK::clone_with_new_inputs(const OutputVector& new_args) const {
OV_OP_SCOPE(v1_TopK_clone_with_new_inputs);
check_new_args_count(this, new_args);
return make_shared<v1::TopK>(new_args.at(0), new_args.at(1), m_axis, m_mode, m_sort, m_index_element_type);
return std::make_shared<TopK>(new_args.at(0), new_args.at(1), m_axis, m_mode, m_sort, m_index_element_type);
}
bool op::v1::TopK::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
bool TopK::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
OV_OP_SCOPE(v1_TopK_evaluate);
return topk::TopK_evaluate(this, outputs, inputs);
return topk::evaluate(this, outputs, inputs);
}
bool op::v1::TopK::has_evaluate() const {
bool TopK::has_evaluate() const {
OV_OP_SCOPE(v1_TopK_has_evaluate);
switch (get_input_element_type(0)) {
case element::i32:
case element::i64:
case element::u32:
case element::u64:
case element::f16:
case element::f32:
break;
default:
return false;
}
if (op::util::is_constant(input_value(1).get_node())) {
switch (get_input_element_type(1)) {
case element::i8:
case element::i32:
case element::i64:
break;
default:
return false;
}
} else {
switch (get_input_element_type(1)) {
case element::i8:
case element::i16:
case element::i32:
case element::i64:
case element::u8:
case element::u16:
case element::u32:
case element::u64:
break;
default:
return false;
}
}
return true;
return topk::validate::data_type(get_input_element_type(0)) && topk::validate::k_type(get_input_element_type(1));
}
} // namespace v1
// v3 version starts
op::v3::TopK::TopK(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
const std::string& mode,
const std::string& sort,
const element::Type& index_element_type)
namespace v3 {
TopK::TopK(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
const std::string& mode,
const std::string& sort,
const element::Type& index_element_type)
: TopK(data, k, axis, as_enum<Mode>(mode), as_enum<SortType>(sort), index_element_type) {}
op::v3::TopK::TopK(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
const Mode mode,
const SortType sort,
const element::Type& index_element_type)
TopK::TopK(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
const Mode mode,
const SortType sort,
const element::Type& index_element_type)
: util::TopKBase{data, k, axis, mode, sort, index_element_type} {
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::v3::TopK::clone_with_new_inputs(const OutputVector& new_args) const {
std::shared_ptr<Node> TopK::clone_with_new_inputs(const OutputVector& new_args) const {
OV_OP_SCOPE(v3_TopK_clone_with_new_inputs);
check_new_args_count(this, new_args);
return make_shared<v3::TopK>(new_args.at(0), new_args.at(1), m_axis, m_mode, m_sort, m_index_element_type);
return std::make_shared<TopK>(new_args.at(0), new_args.at(1), m_axis, m_mode, m_sort, m_index_element_type);
}
bool op::v3::TopK::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
bool TopK::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
OV_OP_SCOPE(v3_TopK_evaluate);
return topk::TopK_evaluate(this, outputs, inputs);
return topk::evaluate(this, outputs, inputs);
}
bool op::v3::TopK::has_evaluate() const {
bool TopK::has_evaluate() const {
OV_OP_SCOPE(v3_TopK_has_evaluate);
switch (get_input_element_type(0)) {
case element::i32:
case element::i64:
case element::u32:
case element::u64:
case element::f16:
case element::f32:
break;
default:
return false;
}
if (op::util::is_constant(input_value(1).get_node())) {
switch (get_input_element_type(1)) {
case element::i8:
case element::i32:
case element::i64:
break;
default:
return false;
}
} else {
switch (get_input_element_type(1)) {
case element::i8:
case element::i16:
case element::i32:
case element::i64:
case element::u8:
case element::u16:
case element::u32:
case element::u64:
break;
default:
return false;
}
}
return true;
return topk::validate::data_type(get_input_element_type(0)) && topk::validate::k_type(get_input_element_type(1));
}
} // namespace v3
// =============== V11 ===============
ov::op::v11::TopK::TopK(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
const std::string& mode,
const std::string& sort,
const element::Type& index_element_type,
const bool stable)
namespace v11 {
TopK::TopK(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
const std::string& mode,
const std::string& sort,
const element::Type& index_element_type,
const bool stable)
: TopK(data, k, axis, as_enum<TopKMode>(mode), as_enum<TopKSortType>(sort), index_element_type, stable) {}
ov::op::v11::TopK::TopK(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
const TopKMode mode,
const TopKSortType sort,
const element::Type& index_element_type,
const bool stable)
TopK::TopK(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
const TopKMode mode,
const TopKSortType sort,
const element::Type& index_element_type,
const bool stable)
: util::TopKBase{data, k, axis, mode, sort, index_element_type},
m_stable{stable} {
constructor_validate_and_infer_types();
}
void ov::op::v11::TopK::validate_and_infer_types() {
void TopK::validate_and_infer_types() {
OV_OP_SCOPE(v11_TopK_validate_and_infer_types);
if (m_stable) {
@ -331,44 +249,34 @@ void ov::op::v11::TopK::validate_and_infer_types() {
util::TopKBase::validate_and_infer_types();
}
bool ov::op::v11::TopK::visit_attributes(AttributeVisitor& visitor) {
bool TopK::visit_attributes(AttributeVisitor& visitor) {
OV_OP_SCOPE(v11_TopK_visit_attributes);
util::TopKBase::visit_attributes(visitor);
visitor.on_attribute("stable", m_stable);
return true;
}
std::shared_ptr<Node> ov::op::v11::TopK::clone_with_new_inputs(const OutputVector& new_args) const {
std::shared_ptr<Node> TopK::clone_with_new_inputs(const OutputVector& new_args) const {
OV_OP_SCOPE(v11_TopK_clone_with_new_inputs);
check_new_args_count(this, new_args);
return make_shared<ov::op::v11::TopK>(new_args.at(0),
new_args.at(1),
m_axis,
m_mode,
m_sort,
m_index_element_type,
m_stable);
return std::make_shared<TopK>(new_args.at(0),
new_args.at(1),
m_axis,
m_mode,
m_sort,
m_index_element_type,
m_stable);
}
bool ov::op::v11::TopK::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
bool TopK::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
OV_OP_SCOPE(v11_TopK_evaluate);
return topk::TopK_evaluate(this, outputs, inputs);
return topk::evaluate(this, outputs, inputs);
}
bool ov::op::v11::TopK::has_evaluate() const {
bool TopK::has_evaluate() const {
OV_OP_SCOPE(v11_TopK_has_evaluate);
switch (get_input_element_type(0)) {
case element::i32:
case element::i64:
case element::u32:
case element::u64:
case element::f16:
case element::f32:
break;
default:
return false;
}
return true;
return topk::validate::data_type(get_input_element_type(0));
}
} // namespace v11
} // namespace op
} // namespace ov

View File

@ -4,8 +4,6 @@
#include "ngraph/op/util/evaluate_helpers.hpp"
#include "openvino/op/util/evaluate_helpers.hpp"
namespace ngraph {
AxisSet get_normalized_axes_from_tensor(const HostTensorPtr tensor,
const ngraph::Rank& rank,
@ -17,18 +15,3 @@ AxisSet get_normalized_axes_from_tensor(const HostTensorPtr tensor,
return AxisSet{normalized_axes};
}
} // namespace ngraph
namespace ov {
namespace op {
namespace util {
std::vector<PartialShape> get_tensors_partial_shapes(const TensorVector& tensors) {
std::vector<PartialShape> shapes;
shapes.reserve(tensors.size());
for (const auto& t : tensors) {
shapes.emplace_back(t.get_shape());
}
return shapes;
}
} // namespace util
} // namespace op
} // namespace ov

View File

@ -1384,5 +1384,14 @@ std::shared_ptr<Constant> get_constant_from_source(const Output<Node>& source) {
return {};
}
}
std::vector<PartialShape> get_tensors_partial_shapes(const TensorVector& tensors) {
std::vector<PartialShape> shapes;
shapes.reserve(tensors.size());
for (const auto& t : tensors) {
shapes.emplace_back(t.get_shape());
}
return shapes;
}
} // namespace util
} // namespace ov