Review topk for shape inference aspects (#14890)

* Review TopK for:
- label and dimension propagation
- partial value and label propagation
- preserve partial value and labels
- add evaluate upper, lower and label

* TopK v1 uses shape_infer instead of fallback
- update static shape inference tests

* TopK shape_infer returns output shapes

* Add new way to get tensor data as shape
with custom data processing
- Update tile op to use new function
- Update topk op to use this function
- Add test for negative k

* Add missing include

* Fix compilation issues

* Add support for i4 and u4 element types in
get_raw_data_as

* Fix signed/unsigned comparison and compile warnings

* Remove constexpr from InTypeRange::operator()

* Use forwarding reference for functor
- minor corrections in InTypeRange class

* Use shape_infer in evaluate
- fix TopK v3 ctor for input data validation

* Fix transformation tests to use correct type for k

* Fix f16 handling in get_raw_data_as

* Correct topk bounds evaluators

* TopK: detect overlap for same-size dimensions,
as the op specification does not guarantee the order of
several elements with the same value

* Remove evaluate bounds
needs investigation; will be provided later if required

* Remove bound evaluation leftovers

* Update get const data in slice ops
Pawel Raasz 2023-01-24 09:38:08 +01:00 committed by GitHub
parent ea776672ba
commit 9ee80d67b7
16 changed files with 1083 additions and 308 deletions

@@ -143,7 +143,7 @@ TEST_F(TransformationTestsF, ConvertTopKToTopKIEDynamic3) {
TEST_F(TransformationTestsF, ConvertTopKToTopKIENegative) {
{
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{15, 20, 3});
auto k = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
auto k = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i32, ngraph::PartialShape::dynamic());
auto topk = std::make_shared<ngraph::opset1::TopK>(input, k, 1, "min", "value", ngraph::element::i32);
// due to the 'compare_functions' limitation we will check only one output
function = std::make_shared<ngraph::Function>(ngraph::OutputVector{topk->output(0)},
@@ -154,7 +154,7 @@ TEST_F(TransformationTestsF, ConvertTopKToTopKIENegative) {
{
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{15, 20, 3});
auto k = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
auto k = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i32, ngraph::PartialShape::dynamic());
auto topk = std::make_shared<ngraph::opset1::TopK>(input, k, 1, "min", "value", ngraph::element::i32);
// due to the 'compare_functions' limitation we will check only one output
function_ref = std::make_shared<ngraph::Function>(ngraph::OutputVector{topk->output(0)},

@@ -112,6 +112,7 @@ protected:
Shape compute_output_shape(const std::string& node_description,
const PartialShape input_partial_shape,
const int64_t k) const;
virtual void k_type_check(const element::Type& k_element_type) const;
};
} // namespace v1
@@ -161,6 +162,7 @@ public:
protected:
size_t read_k_from_constant_node(const std::shared_ptr<Node>& node,
const element::Type& k_element_type) const override;
void k_type_check(const element::Type& k_element_type) const override;
};
} // namespace v3
} // namespace op

@@ -4,6 +4,10 @@
#pragma once
#include <type_traits>
#include "openvino/core/type/float16.hpp"
namespace ov {
namespace cmp {
/** \brief Enumerate bounds to compare */
@@ -73,5 +77,92 @@ public:
return _exp_value == value;
}
};
/**
* \brief Compare two integers (a < b) in a safe way against lossy integer conversion.
*
* \tparam T Type of a value.
* \tparam U Type of b value.
*
* \param a Integer value.
* \param b Integer value.
*
* \return true if a < b, otherwise false.
*/
template <
class T,
class U,
typename std::enable_if<(std::is_signed<T>::value && std::is_signed<U>::value) ||
(std::is_unsigned<T>::value && std::is_unsigned<U>::value) ||
// temporary, to be able to compare float element types
(std::is_floating_point<T>::value || std::is_floating_point<U>::value) ||
(std::is_same<T, float16>::value || std::is_same<U, float16>::value)>::type* = nullptr>
constexpr bool lt(T a, U b) noexcept {
return a < b;
}
template <class T,
class U,
typename std::enable_if<std::is_signed<T>::value && std::is_integral<T>::value &&
std::is_unsigned<U>::value>::type* = nullptr>
constexpr bool lt(T a, U b) noexcept {
return a < 0 ? true : static_cast<typename std::make_unsigned<T>::type>(a) < b;
}
template <class T,
class U,
typename std::enable_if<std::is_unsigned<T>::value && std::is_integral<U>::value &&
std::is_signed<U>::value>::type* = nullptr>
constexpr bool lt(T a, U b) noexcept {
return b < 0 ? false : a < static_cast<typename std::make_unsigned<U>::type>(b);
}
/**
* \brief Compare two integers (a > b) in a safe way against lossy integer conversion.
*
* \tparam T Type of a value.
* \tparam U Type of b value.
*
* \param a Integer value.
* \param b Integer value.
*
* \return true if a > b, otherwise false.
*/
template <class T, class U>
bool gt(T a, U b) noexcept {
return lt(b, a);
}
/**
* \brief Compare two integers (a <= b) in a safe way against lossy integer conversion.
*
* \tparam T Type of a value.
* \tparam U Type of b value.
*
* \param a Integer value.
* \param b Integer value.
*
* \return true if a <= b, otherwise false.
*/
template <class T, class U>
bool le(T a, U b) noexcept {
return !gt(a, b);
}
/**
* \brief Compare two integers (a >= b) in a safe way against lossy integer conversion.
*
* \tparam T Type of a value.
* \tparam U Type of b value.
*
* \param a Integer value.
* \param b Integer value.
*
* \return true if a >= b, otherwise false.
*/
template <class T, class U>
bool ge(T a, U b) noexcept {
return !lt(a, b);
}
} // namespace cmp
} // namespace ov
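
To illustrate the comparators above, a minimal standalone sketch of the mixed-signedness behavior (assuming "compare.hpp" is on the include path; values are illustrative):

#include <cstdint>

#include "compare.hpp"

int main() {
    const int32_t a = -1;
    const uint64_t b = 2;
    // A built-in comparison after integral conversion turns -1 into a huge unsigned value.
    const bool lossy = static_cast<uint64_t>(a) < b;  // false
    // The safe comparators check the sign before any conversion.
    const bool safe = ov::cmp::lt(a, b);     // true: -1 < 2
    const bool safe_ge = ov::cmp::ge(b, a);  // true: 2 >= -1
    return (!lossy && safe && safe_ge) ? 0 : 1;
}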

@@ -0,0 +1,58 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "compare.hpp"
#include "openvino/core/except.hpp"
namespace ov {
namespace sh_infer {
namespace tr {
/**
* \brief Transform tensor data by casting it to type T
*
* \tparam T Type of returned value.
*/
template <class T>
struct Cast {
constexpr Cast() = default;
template <class U>
constexpr T operator()(const U u) const {
return static_cast<T>(u);
}
};
/**
* \brief Check if input data is in [T::min(), T::max()] and then cast it to T.
*
* \tparam T Type of returned value, also used to specify the min and max of the valid value range.
*
* \throws ov::AssertFailure if the input value is not in the type's range.
*/
template <class T>
struct InTypeRange {
const std::pair<T, T> m_range{};
constexpr InTypeRange() : m_range{std::numeric_limits<T>::min(), std::numeric_limits<T>::max()} {};
template <class U>
T operator()(const U u) const {
OPENVINO_ASSERT(cmp::le(m_range.first, u) && cmp::le(u, m_range.second),
"Value ",
u,
" not in range [",
m_range.first,
":",
m_range.second,
"]");
return static_cast<T>(u);
}
};
} // namespace tr
} // namespace sh_infer
} // namespace ov
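
A minimal usage sketch of the two functors above (the header name is taken from the includes elsewhere in this commit; out-of-range behavior is noted in a comment):

#include <cstdint>

#include "shape_infer_transformations.hpp"

int main() {
    // Cast: a plain static_cast of each element.
    constexpr auto cast_i64 = ov::sh_infer::tr::Cast<int64_t>();
    const int64_t v = cast_i64(int8_t{-5});  // -5

    // InTypeRange: validates the value against [min(), max()] of the target type first.
    const auto to_i8 = ov::sh_infer::tr::InTypeRange<int8_t>();
    const int8_t ok = to_i8(100);  // 100 fits into int8_t, returned as-is
    // to_i8(1000) would throw ov::AssertFailure (value not in the type's range).
    return (v == -5 && ok == 100) ? 0 : 1;
}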

@@ -94,10 +94,12 @@ void shape_infer(const Slice* op,
return;
}
constexpr auto cast_i64 = sh_infer::tr::Cast<int64_t>();
// compute constant values of begin, end, and strides if possible
const auto start = slice::get_input_bounds<T>(op, 1, constant_data);
const auto stop = slice::get_input_bounds<T>(op, 2, constant_data);
const auto steps = get_input_const_data_as<T, int64_t>(op, 3, constant_data);
const auto steps = get_input_const_data_as<T, int64_t>(op, 3, constant_data, cast_i64);
slice::AxesMap axes_map;
if (input_shapes.size() > 4) {
@@ -105,7 +107,7 @@
input_shapes[4].compatible(start_shape),
"Slice `axes` input must have compatible shape with `start`, `stop`, `step` inputs.");
if (auto axes = get_input_const_data_as<T, int64_t>(op, 4, constant_data)) {
if (auto axes = get_input_const_data_as<T, int64_t>(op, 4, constant_data, cast_i64)) {
ov::normalize_axes(op, input_shape.rank().get_length(), *axes);
axes_map.add(*axes);
NODE_VALIDATION_CHECK(op, axes_map.is_valid, "Slice values in `axes` input must be unique.");

@@ -233,7 +233,8 @@ std::unique_ptr<TResult> get_input_bounds(const ov::Node* op,
};
std::unique_ptr<TResult> out;
if (auto lowers = op::get_input_const_data_as<TShape, int64_t>(op, idx, constant_data)) {
if (auto lowers =
op::get_input_const_data_as<TShape, int64_t>(op, idx, constant_data, sh_infer::tr::Cast<int64_t>())) {
const auto& et = get_input_const_element_type(op, idx, constant_data);
out.reset(new TResult(make_bounds_vec(et, *lowers, *lowers)));
} else {

@@ -64,7 +64,7 @@ void shape_infer(const StridedSlice* op,
std::unique_ptr<std::vector<int64_t>> strides;
if (input_shapes.size() > 3) {
strides = get_input_const_data_as<T, int64_t>(op, 3, constant_data);
strides = get_input_const_data_as<T, int64_t>(op, 3, constant_data, sh_infer::tr::Cast<int64_t>());
} else if (begin) {
// generate default strides
strides.reset(new std::vector<int64_t>(begin->size(), 1));

@@ -4,6 +4,7 @@
#pragma once
#include <openvino/op/tile.hpp>
#include "shape_infer_transformations.hpp"
#include "utils.hpp"
namespace ov {
@@ -15,7 +16,8 @@ void shape_infer(const Tile* op,
const std::vector<T>& input_shapes,
std::vector<T>& output_shapes,
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {}) {
using TDim = typename std::iterator_traits<typename T::iterator>::value_type;
using TDim = typename T::value_type;
using TDimValue = typename TDim::value_type;
NODE_VALIDATION_CHECK(op, input_shapes.size() == 2 && output_shapes.size() == 1);
@@ -26,32 +28,25 @@ void shape_infer(const Tile* op,
auto& output_shape = output_shapes[0];
// Get repeats and pre process values
T repeats;
bool has_repeats;
if (auto rep_data = get_input_const_data_as<T, int64_t>(op, 1, constant_data)) {
// set negatives repeats to 0
repeats.resize(rep_data->size());
std::transform(rep_data->begin(), rep_data->end(), repeats.begin(), [](int64_t r) -> TDim {
return {static_cast<typename TDim::value_type>(std::max(static_cast<int64_t>(0), r))};
});
has_repeats = true;
} else {
has_repeats = get_data_as_shape(1, op, repeats);
}
auto negative_repeats_to_zero = [](const TDimValue v) -> TDimValue {
return std::max<TDimValue>(0, sh_infer::tr::InTypeRange<TDimValue>()(v));
};
auto repeats = get_input_const_data_as_shape<T>(op, 1, constant_data, negative_repeats_to_zero);
const auto& arg_rank = arg_shape.rank();
if (arg_rank.is_static() && has_repeats) {
const auto output_rank = std::max(arg_shape.size(), repeats.size());
if (arg_rank.is_static() && repeats) {
const auto output_rank = std::max(arg_shape.size(), repeats->size());
std::vector<TDim> dims;
dims.reserve(output_rank);
// add missing repeats
repeats.insert(repeats.begin(), output_rank - repeats.size(), TDim{1});
repeats->insert(repeats->begin(), output_rank - repeats->size(), TDim{1});
// insert missing input dimensions
auto rep_it = std::next(repeats.begin(), output_rank - arg_shape.size());
dims.insert(dims.begin(), repeats.begin(), rep_it);
auto rep_it = std::next(repeats->begin(), output_rank - arg_shape.size());
dims.insert(dims.begin(), repeats->begin(), rep_it);
// calc repeated output dimensions
std::transform(arg_shape.begin(), arg_shape.end(), rep_it, std::back_inserter(dims), std::multiplies<TDim>());
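
A standalone sketch of the repeats pre-processing above (plain int64_t stands in for TDimValue, and the InTypeRange validation step is omitted):

#include <algorithm>
#include <cstdint>
#include <vector>

int main() {
    const auto negative_repeats_to_zero = [](int64_t v) { return std::max<int64_t>(0, v); };
    std::vector<int64_t> repeats{-1, 3};
    std::transform(repeats.begin(), repeats.end(), repeats.begin(), negative_repeats_to_zero);
    // repeats is now {0, 3}; tiling a {2, 4} input would produce a {0, 12} output.
    return (repeats[0] == 0 && repeats[1] == 3) ? 0 : 1;
}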

@@ -13,12 +13,49 @@ namespace ov {
namespace op {
namespace v1 {
template <typename T>
void shape_infer(const TopK* op,
const std::vector<T>& input_shapes,
std::vector<T>& output_shapes,
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {}) {
NODE_VALIDATION_CHECK(op, (input_shapes.size() == 2 && output_shapes.size() == 2));
// Helper to get correct K from tensor as shape.
template <class T>
struct GetK {
const TopK* m_op;
GetK(const TopK* op) : m_op{op} {}
template <class K>
T operator()(const K k) const {
NODE_VALIDATION_CHECK(m_op,
cmp::ge(k, 0) && cmp::le(k, std::numeric_limits<T>::max()),
"The value of 'K' must be more or equal zero.",
" (got ",
k,
").");
return static_cast<T>(k);
}
};
/**
* \brief TopK shape inference
*
* \tparam TShape Type of shape.
*
* \param op Pointer to TopK operator.
* \param input_shapes Input shapes of TopK.
* \param constant_data Map of constant data. Default empty.
*
* \return Vector of output shapes for both TopK outputs.
*/
template <class TShape>
std::vector<TShape> shape_infer(const TopK* op,
const std::vector<TShape>& input_shapes,
const std::map<size_t, HostTensorPtr>& constant_data = {}) {
using TDim = typename TShape::value_type;
using TDimValue = typename TDim::value_type;
NODE_VALIDATION_CHECK(op, input_shapes.size() == 2);
const auto& idx_element_type = op->get_index_element_type();
NODE_VALIDATION_CHECK(op,
idx_element_type == element::i32 || idx_element_type == element::i64,
"Index element type attribute should be either \'i32\' or \'i64\'. Got: ",
idx_element_type);
const auto& input_shape = input_shapes[0];
const auto input_rank = input_shape.rank();
@@ -31,47 +68,58 @@ void shape_infer(const TopK* op,
auto output_shape = input_shape;
if (input_shape.rank().is_static()) {
T k_as_shape;
auto input_rank = static_cast<int64_t>(input_shape.size());
auto normalized_axis = ov::normalize_axis(op, op->get_provided_axis(), input_rank, -input_rank, input_rank - 1);
const auto normalized_axis = ov::normalize_axis(op, op->get_provided_axis(), input_shape.rank());
auto& dim_axis = output_shape[normalized_axis];
if (get_data_as_shape<T>(1, op, k_as_shape, constant_data)) {
if (auto k_as_shape = get_input_const_data_as_shape<TShape>(op, 1, constant_data, GetK<TDimValue>(op))) {
NODE_VALIDATION_CHECK(op,
k_as_shape.size() == 1,
k_as_shape->size() == 1,
"Only one value (scalar) should be provided as the 'K' input to TopK",
" (got ",
k_as_shape.size(),
k_as_shape->size(),
" elements).");
if (k_as_shape[0].is_static()) {
NODE_VALIDATION_CHECK(op,
k_as_shape[0].get_max_length() >= 0,
"The value of 'K' must not be a negative number.",
" (got ",
k_as_shape[0].get_max_length(),
").");
dim_axis = k_as_shape[0].get_length();
const auto& k = (*k_as_shape)[0];
if (k.is_static()) {
dim_axis = k;
} else {
// in this dynamic branch we are sure of dim_axis's type
const auto in_min = dim_axis.get_min_length();
const auto in_max = dim_axis.get_max_length();
const auto k_min = k_as_shape[0].get_min_length();
const auto k_max = k_as_shape[0].get_max_length();
const auto k_min = k.get_min_length();
const auto k_max = k.get_max_length();
const auto lower = std::min<Dimension::value_type>(in_min, k_min);
const auto lower = std::min<TDimValue>(in_min, k_min);
const auto upper =
in_max < 0 ? Dimension::dynamic().get_max_length() : std::max<Dimension::value_type>(in_max, k_max);
dim_axis = Dimension(lower, upper);
in_max < 0 ? Dimension::dynamic().get_max_length() : std::max<TDimValue>(in_max, k_max);
dim_axis = TDim(lower, upper);
}
} else {
dim_axis = Dimension(0, dim_axis.get_max_length());
dim_axis = TDim(0, dim_axis.get_max_length());
}
}
output_shapes[0] = output_shape;
output_shapes[1] = output_shape;
} // namespace
return std::vector<TShape>(2, output_shape);
}
/**
* \brief TopK shape inference
*
* \tparam TShape Type of shape.
*
* \param op Pointer to TopK operator.
* \param input_shapes Input shapes of TopK.
* \param output_shapes Output shapes of TopK
* \param constant_data Map of constant data. Default empty.
*/
template <typename T>
void shape_infer(const TopK* op,
const std::vector<T>& input_shapes,
std::vector<T>& output_shapes,
const std::map<size_t, HostTensorPtr>& constant_data = {}) {
output_shapes = shape_infer(op, input_shapes, constant_data);
}
} // namespace v1
} // namespace op
} // namespace ov
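
A usage sketch of the new returning overload, in the spirit of the type-prop tests later in this commit (shapes and the k value are illustrative):

#include "openvino/opsets/opset1.hpp"
#include "topk_shape_inference.hpp"

int main() {
    const auto data = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::PartialShape{4, 10});
    const auto k = ov::opset1::Constant::create(ov::element::i64, ov::Shape{}, {3});
    const auto topk = std::make_shared<ov::opset1::TopK>(data, k, 1, "max", "value");

    // k is a Constant, so it folds into the inferred axis dimension.
    const auto output_shapes =
        ov::op::v1::shape_infer(topk.get(), std::vector<ov::PartialShape>{{4, 10}, {}});
    // Both outputs (values and indices) get the same shape: {4, 3}.
    return output_shapes.size() == 2 ? 0 : 1;
}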

@@ -3,8 +3,12 @@
//
#pragma once
#include <iterator>
#include <ngraph/validation_util.hpp>
#include <openvino/opsets/opset1.hpp>
#include <type_traits>
#include "shape_infer_transformations.hpp"
template <class OpType, class T>
void copy_shape_infer(const OpType* op, const std::vector<T>& input_shapes, std::vector<T>& output_shapes) {
@@ -44,8 +48,156 @@ void eltwise_shape_infer(const OpType* op, const std::vector<T>& input_shapes, s
}
namespace ov {
namespace op {
/**
* \brief Get the raw data as TResult object.
*
* \tparam T TResult data type.
* \tparam TResult Type of return object, must support creation of std::inserter. Default std::vector<T>.
* \tparam UnaryOperation Unary function object applied on data with signature (T f(const U u)).
*
* \param et Element type of input data.
* \param ptr Pointer to data of type et.
* \param size Data size as number of elements.
* \param func Unary operation function object.
*
* \throws ov::AssertFailure for an unsupported element type.
* \return Object of TResult with data from input pointer and transformed by unary operation.
*/
template <class T, class TResult = std::vector<T>, class UnaryOperation>
TResult get_raw_data_as(const element::Type_t et, const void* const ptr, const size_t size, UnaryOperation&& func) {
TResult out;
auto out_it = std::inserter(out, out.end());
switch (et) {
case element::Type_t::i4: {
using dtype = fundamental_type_for<element::Type_t::i4>;
std::transform(static_cast<const dtype*>(ptr),
static_cast<const dtype*>(ptr) + size,
out_it,
std::forward<UnaryOperation>(func));
} break;
case element::Type_t::i8: {
using dtype = fundamental_type_for<element::Type_t::i8>;
std::transform(static_cast<const dtype*>(ptr),
static_cast<const dtype*>(ptr) + size,
out_it,
std::forward<UnaryOperation>(func));
} break;
case element::Type_t::i16: {
using dtype = fundamental_type_for<element::Type_t::i16>;
std::transform(static_cast<const dtype*>(ptr),
static_cast<const dtype*>(ptr) + size,
out_it,
std::forward<UnaryOperation>(func));
} break;
case element::Type_t::i32: {
using dtype = fundamental_type_for<element::Type_t::i32>;
std::transform(static_cast<const dtype*>(ptr),
static_cast<const dtype*>(ptr) + size,
out_it,
std::forward<UnaryOperation>(func));
} break;
case element::Type_t::i64: {
using dtype = fundamental_type_for<element::Type_t::i64>;
std::transform(static_cast<const dtype*>(ptr),
static_cast<const dtype*>(ptr) + size,
out_it,
std::forward<UnaryOperation>(func));
} break;
case element::Type_t::u4: {
using dtype = fundamental_type_for<element::Type_t::u4>;
std::transform(static_cast<const dtype*>(ptr),
static_cast<const dtype*>(ptr) + size,
out_it,
std::forward<UnaryOperation>(func));
} break;
case element::Type_t::u8: {
using dtype = fundamental_type_for<element::Type_t::u8>;
std::transform(static_cast<const dtype*>(ptr),
static_cast<const dtype*>(ptr) + size,
out_it,
std::forward<UnaryOperation>(func));
} break;
case element::Type_t::u16: {
using dtype = fundamental_type_for<element::Type_t::u16>;
std::transform(static_cast<const dtype*>(ptr),
static_cast<const dtype*>(ptr) + size,
out_it,
std::forward<UnaryOperation>(func));
} break;
case element::Type_t::u32: {
using dtype = fundamental_type_for<element::Type_t::u32>;
std::transform(static_cast<const dtype*>(ptr),
static_cast<const dtype*>(ptr) + size,
out_it,
std::forward<UnaryOperation>(func));
} break;
case element::Type_t::u64: {
using dtype = fundamental_type_for<element::Type_t::u64>;
std::transform(static_cast<const dtype*>(ptr),
static_cast<const dtype*>(ptr) + size,
out_it,
std::forward<UnaryOperation>(func));
} break;
case element::Type_t::f16: {
using dtype = fundamental_type_for<element::Type_t::f16>;
std::transform(static_cast<const dtype*>(ptr),
static_cast<const dtype*>(ptr) + size,
out_it,
std::forward<UnaryOperation>(func));
} break;
case element::Type_t::f32: {
using dtype = fundamental_type_for<element::Type_t::f32>;
std::transform(static_cast<const dtype*>(ptr),
static_cast<const dtype*>(ptr) + size,
out_it,
std::forward<UnaryOperation>(func));
} break;
default:
OPENVINO_ASSERT(false, "Not supported element type ", et);
};
return out;
}
/**
* \brief Get data from Host tensor as object TResult.
*
* \tparam T TResult data type.
* \tparam TResult Type of return object, must support creation of std::inserter. Default std::vector<T>.
* \tparam UnaryOperation Unary function object applied on data with signature (T f(const U u)).
*
* \param tv Input host tensor.
* \param func Unary operation function object.
*
* \return Object of TResult with data from host tensor.
*/
template <class T, class TResult = std::vector<T>, class UnaryOperation>
TResult get_tensor_data_as(HostTensor& tv, UnaryOperation&& func) {
auto t = Tensor(tv.get_element_type(), tv.get_shape(), tv.get_data_ptr());
return get_tensor_data_as<T, TResult>(t, std::forward<UnaryOperation>(func));
}
/**
* \brief Get data from ov::Tensor as object TResult.
*
* \tparam T TResult data type.
* \tparam TResult Type of return object, must support creation of std::inserter. Default std::vector<T>.
* \tparam UnaryOperation Unary function object applied on data with signature (T f(const U u)).
*
* \param t Input tensor.
* \param func Unary operation function object.
*
* \return Object of TResult with data from tensor.
*/
template <class T, class TResult = std::vector<T>, class UnaryOperation>
TResult get_tensor_data_as(const Tensor& t, UnaryOperation&& func) {
return get_raw_data_as<T, TResult>(t.get_element_type(),
t.data(),
t.get_size(),
std::forward<UnaryOperation>(func));
}
namespace op {
/**
* \brief Get the operator's input const as pointer to vector of specified type.
*
@@ -55,26 +207,36 @@ namespace op {
* \tparam TShape Shape type which enabled this version (not ov::PartialShape)
* \tparam TData Type used to cast input's data.
* \tparam TRes Result type which defaults to std::vector<TData>.
* \tparam UnaryOperation Unary function object applied on data with signature (Ret f(const TData &a)).
*
* \param op Pointer to operator.
* \param idx Operator's input number.
* \param constant_data Map with constant. Default empty.
* \param func Unary operation function object.
*
* \return Pointer to constant data or nullptr if input has no constant data.
*/
template <class TShape,
class TData,
class TRes = std::vector<TData>,
class UnaryOperation,
typename std::enable_if<!std::is_same<TShape, ov::PartialShape>::value>::type* = nullptr>
std::unique_ptr<TRes> get_input_const_data_as(const ov::Node* op,
size_t idx,
const std::map<size_t, HostTensorPtr>& constant_data = {}) {
const std::map<size_t, HostTensorPtr>& constant_data = {},
UnaryOperation&& func = sh_infer::tr::Cast<TData>()) {
if (constant_data.count(idx)) {
return std::unique_ptr<TRes>(new TRes(ov::opset1::Constant(constant_data.at(idx)).cast_vector<TData>()));
return std::unique_ptr<TRes>(
new TRes(get_tensor_data_as<TData, TRes>(*constant_data.at(idx), std::forward<UnaryOperation>(func))));
} else {
const auto& constant = ov::as_type_ptr<ov::opset1::Constant>(op->get_input_node_shared_ptr(idx));
NODE_VALIDATION_CHECK(op, constant != nullptr, "Static shape inference lacks constant data on port ", idx);
return std::unique_ptr<TRes>(new TRes(constant->cast_vector<TData>()));
const auto& et = constant->get_element_type();
const auto& shape = constant->get_shape();
return std::unique_ptr<TRes>(new TRes(get_raw_data_as<TData, TRes>(et,
constant->get_data_ptr(),
shape_size(shape),
std::forward<UnaryOperation>(func))));
}
}
@@ -87,29 +249,76 @@ std::unique_ptr<TRes> get_input_const_data_as(const ov::Node* op,
* \tparam TShape Shape type which enabled this version (ov::PartialShape)
* \tparam TData Type used to cast input's data.
* \tparam TRes Result type which defaults to std::vector<TData>.
* \tparam UnaryOperation Unary function object applied on data with signature (Ret f(const TData &a)).
*
* \param op Pointer to operator.
* \param idx Operator's input number.
* \param constant_data Map with constant. Default empty.
* \param func Unary operation function object.
*
* \return Pointer to constant data or nullptr if input has no constant data.
*/
template <class TShape,
class TData,
class TRes = std::vector<TData>,
class UnaryOperation,
typename std::enable_if<std::is_same<TShape, ov::PartialShape>::value>::type* = nullptr>
std::unique_ptr<TRes> get_input_const_data_as(const ov::Node* op,
size_t idx,
const std::map<size_t, HostTensorPtr>& constant_data = {}) {
const std::map<size_t, HostTensorPtr>& constant_data = {},
UnaryOperation&& func = sh_infer::tr::Cast<TData>()) {
if (constant_data.count(idx)) {
return std::unique_ptr<TRes>(new TRes(ov::opset1::Constant(constant_data.at(idx)).cast_vector<TData>()));
return std::unique_ptr<TRes>(
new TRes(get_tensor_data_as<TData, TRes>(*constant_data.at(idx), std::forward<UnaryOperation>(func))));
} else if (const auto& constant = ov::get_constant_from_source(op->input_value(idx))) {
return std::unique_ptr<TRes>(new TRes(constant->cast_vector<TData>()));
const auto& et = constant->get_element_type();
const auto& shape = constant->get_shape();
return std::unique_ptr<TRes>(new TRes(get_raw_data_as<TData, TRes>(et,
constant->get_data_ptr(),
shape_size(shape),
std::forward<UnaryOperation>(func))));
} else {
return {};
}
}
/**
* \brief Get the input const data as shape object.
*
* The input data can be processed by a unary operation. By default it is validated and cast to the shape's
* dimension type.
*
* \tparam TShape Shape type.
* \tparam UnaryOperation Unary function object applied on data with signature (Ret f(const TDimValue &a)).
*
* \param op Pointer to operator.
* \param idx Operator input index.
* \param constant_data Map with constant data. Default empty.
* \param func Unary operation function object to apply to input data.
* Default sh_infer::tr::InTypeRange<TDimValue>.
*
* \return Unique pointer to shape created from input data.
*/
template <class TShape,
class TDimValue = typename TShape::value_type::value_type,
class UnaryOperation = sh_infer::tr::InTypeRange<TDimValue>>
std::unique_ptr<TShape> get_input_const_data_as_shape(
const ov::Node* op,
size_t idx,
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {},
UnaryOperation&& func = sh_infer::tr::InTypeRange<TDimValue>()) {
std::unique_ptr<TShape> shape_ptr;
if (auto d =
get_input_const_data_as<TShape, TDimValue>(op, idx, constant_data, std::forward<UnaryOperation>(func))) {
shape_ptr.reset(new TShape(std::move(*d)));
} else {
PartialShape shape;
if (ov::evaluate_as_partial_shape(op->input_value(idx), shape)) {
shape_ptr.reset(new TShape(std::move(shape)));
}
}
return shape_ptr;
}
} // namespace op
} // namespace ov
@@ -119,7 +328,8 @@ inline bool get_data_as(const ov::Node* op,
size_t idx,
std::vector<TData>& data_out,
const std::map<size_t, ov::HostTensorPtr>& constant_data = {}) {
if (auto out = ov::op::get_input_const_data_as<TShape, TData>(op, idx, constant_data)) {
if (auto out =
ov::op::get_input_const_data_as<TShape, TData>(op, idx, constant_data, ov::sh_infer::tr::Cast<TData>())) {
data_out = std::move(*out);
return true;
} else {
@@ -163,8 +373,9 @@ inline bool get_data_as_shape(size_t idx,
const ov::Node* op,
TShape& shape,
const std::map<size_t, ov::HostTensorPtr>& constant_data = {}) {
// Note, assumes that get_input_const_data_as throws exception for TShape different then ov::PartialShape.
shape = *ov::op::get_input_const_data_as<TShape, size_t, TShape>(op, idx, constant_data);
using TDimValue = typename TShape::value_type::value_type;
shape = std::move(
*ov::op::get_input_const_data_as_shape<TShape>(op, idx, constant_data, ov::sh_infer::tr::Cast<TDimValue>()));
return true;
}
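
A sketch of reading a node's constant input through the extended helper (the "utils.hpp" include mirrors the static shape inference tests later in this commit; the Abs consumer is an arbitrary choice):

#include <vector>

#include "openvino/opsets/opset1.hpp"
#include "utils.hpp"

int main() {
    const auto c = ov::opset1::Constant::create(ov::element::i32, ov::Shape{3}, {1, -2, 3});
    const auto node = std::make_shared<ov::opset1::Abs>(c);

    // Read input 0 of `node` as int64_t, applying the Cast functor to every element.
    const auto data = ov::op::get_input_const_data_as<ov::PartialShape, int64_t>(
        node.get(), 0, {}, ov::sh_infer::tr::Cast<int64_t>());
    return (data && *data == std::vector<int64_t>{1, -2, 3}) ? 0 : 1;
}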

@@ -7,6 +7,7 @@
#include <memory>
#include <topk_shape_inference.hpp>
#include "dimension_tracker.hpp"
#include "itt.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/axis_vector.hpp"
@@ -103,44 +104,11 @@ bool evaluate_topk(const HostTensorPtr& arg,
}
return rc;
}
template <element::Type_t K_ET>
size_t get_k_from_hosttensor(const HostTensorPtr& arg) {
using T = typename element_type_traits<K_ET>::value_type;
auto p = arg->get_data_ptr<T>();
size_t k = p[0];
return k;
}
#define CASE_GET_K(a, ...) \
case element::Type_t::a: { \
OV_OP_SCOPE(OV_PP_CAT3(topk_get_k, _, a)); \
k = get_k_from_hosttensor<element::Type_t::a>(__VA_ARGS__); \
} break
size_t read_k_from_host_tensor(const HostTensorPtr& arg_k) {
size_t k = 0;
switch (arg_k->get_element_type()) {
CASE_GET_K(i8, arg_k);
CASE_GET_K(i16, arg_k);
CASE_GET_K(i32, arg_k);
CASE_GET_K(i64, arg_k);
CASE_GET_K(u8, arg_k);
CASE_GET_K(u16, arg_k);
CASE_GET_K(u32, arg_k);
CASE_GET_K(u64, arg_k);
default:
// other types are not supported and would have thrown in ctor
ngraph_error("read_k_from_host_tensor: type is not integral\n");
break;
}
return k;
}
} // namespace
} // namespace topk
// v1 version starts
static const std::uint64_t UNKNOWN_NORMALIZED_AXIS = std::numeric_limits<uint64_t>::max();
constexpr auto UNKNOWN_NORMALIZED_AXIS = std::numeric_limits<uint64_t>::max();
op::v1::TopK::TopK(const Output<Node>& data,
const Output<Node>& k,
@@ -148,15 +116,7 @@ op::v1::TopK::TopK(const Output<Node>& data,
const std::string& mode,
const std::string& sort,
const element::Type& index_element_type)
: Op{{data, k}},
m_axis{axis},
m_normalized_axis{UNKNOWN_NORMALIZED_AXIS},
m_mode{as_enum<Mode>(mode)},
m_sort{as_enum<SortType>(sort)},
m_index_element_type{index_element_type} {
ov::mark_as_precision_sensitive(input(1));
constructor_validate_and_infer_types();
}
: TopK(data, k, axis, as_enum<Mode>(mode), as_enum<SortType>(sort), index_element_type) {}
op::v1::TopK::TopK(const Output<Node>& data,
const Output<Node>& k,
@@ -186,23 +146,12 @@ bool ngraph::op::v1::TopK::visit_attributes(AttributeVisitor& visitor) {
void op::v1::TopK::validate_and_infer_types() {
OV_OP_SCOPE(v1_TopK_validate_and_infer_types);
NODE_VALIDATION_CHECK(this,
m_index_element_type == element::i32 || m_index_element_type == element::i64,
"Index element type attribute should be either \'i32\' or \'i64\'. Got: ",
m_index_element_type);
if (ov::op::util::is_constant(input_value(1).get_node())) {
// Check k value
read_k_from_constant_node(input_value(1).get_node_shared_ptr(), get_input_element_type(1));
}
k_type_check(get_input_element_type(1));
set_axis(get_input_partial_shape(0).rank(), get_provided_axis());
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape{}, ov::PartialShape{}};
std::vector<ov::PartialShape> input_shapes = {get_input_partial_shape(0), get_input_partial_shape(1)};
shape_infer(this, input_shapes, output_shapes);
const auto output_shapes = shape_infer(this, get_node_input_partial_shapes(*this));
set_output_size(2);
set_output_type(0, get_input_element_type(0), output_shapes[0]);
set_output_type(1, m_index_element_type, output_shapes[1]);
}
@@ -211,33 +160,17 @@ ov::Shape op::v1::TopK::compute_output_shape(const std::string& node_description
const ov::PartialShape input_partial_shape,
const int64_t k) const {
ov::PartialShape output_shape{input_partial_shape};
auto normalized_axis = ngraph::normalize_axis(node_description, m_axis, output_shape.rank());
if (k != 0) {
output_shape[normalized_axis] = k;
} else {
output_shape[normalized_axis] = input_partial_shape[normalized_axis];
}
const auto normalized_axis = ngraph::normalize_axis(node_description, m_axis, output_shape.rank());
output_shape[normalized_axis] = (k != 0) ? k : input_partial_shape[normalized_axis];
return output_shape.get_shape();
}
void op::v1::TopK::set_axis(const int64_t axis) {
const auto input_rank = get_input_partial_shape(0).rank();
if (input_rank.is_static()) {
m_normalized_axis = ngraph::normalize_axis(this, axis, input_rank);
} else {
m_normalized_axis = UNKNOWN_NORMALIZED_AXIS;
}
m_axis = axis;
set_axis(get_input_partial_shape(0).rank(), axis);
}
void op::v1::TopK::set_axis(const Rank& input_rank, const int64_t axis) {
if (input_rank.is_static()) {
m_normalized_axis = ngraph::normalize_axis(this, axis, input_rank);
} else {
m_normalized_axis = UNKNOWN_NORMALIZED_AXIS;
}
m_normalized_axis = input_rank.is_static() ? normalize_axis(this, axis, input_rank) : UNKNOWN_NORMALIZED_AXIS;
m_axis = axis;
}
@@ -247,14 +180,18 @@ uint64_t op::v1::TopK::get_axis() const {
return m_normalized_axis;
}
size_t op::v1::TopK::read_k_from_constant_node(const shared_ptr<Node>& node,
const element::Type& k_element_type) const {
void op::v1::TopK::k_type_check(const element::Type& k_element_type) const {
NODE_VALIDATION_CHECK(
this,
k_element_type == element::i8 || k_element_type == element::i32 || k_element_type == element::i64,
"K input element type must be i8, i32 or i64 (got ",
k_element_type,
").");
}
size_t op::v1::TopK::read_k_from_constant_node(const shared_ptr<Node>& node,
const element::Type& k_element_type) const {
k_type_check(k_element_type);
const auto k_constant = ov::as_type_ptr<op::v0::Constant>(node);
@@ -325,29 +262,24 @@ void op::v1::TopK::set_k(size_t k) {
bool op::v1::TopK::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
OV_OP_SCOPE(v1_TopK_evaluate);
ov::Shape arg_shape = inputs[0]->get_shape();
const auto& arg_shape = inputs[0]->get_shape();
// 1. get axis, mode (max/min), sort_type
size_t axis = ngraph::normalize_axis(this, m_axis, arg_shape.size());
bool compute_max = get_mode() == TopKMode::MAX ? true : false;
SortType sort_type = get_sort_type();
auto axis = ngraph::normalize_axis(this, m_axis, arg_shape.size());
auto compute_max = get_mode() == TopKMode::MAX;
auto sort_type = get_sort_type();
// 2. get value of k - from constant node or from HT
size_t k = 0;
if (op::util::is_constant(input_value(1).get_node())) {
k = read_k_from_constant_node(input_value(1).get_node_shared_ptr(), get_input_element_type(1));
NGRAPH_CHECK(k <= arg_shape[axis], "'K' exceeds the dimension of top_k_axis");
} else {
k = topk::read_k_from_host_tensor(inputs[1]);
const auto input_shapes = std::vector<PartialShape>{inputs[0]->get_partial_shape(), inputs[1]->get_partial_shape()};
const auto constant_data = std::map<size_t, HostTensorPtr>{{1, inputs[1]}};
auto output_shape = shape_infer(this, input_shapes, constant_data).front().to_shape();
if (output_shape[axis] == 0) {
// the kernel can't handle K (output_shape[axis]) equal 0, use arg_shape[axis] instead.
output_shape[axis] = arg_shape[axis];
}
// 3. Compute output_shape
auto output_shape = compute_output_shape(this->description(), inputs[0]->get_shape(), k);
// do this after compute_output_shape
if (k == 0) {
// the kernel can't handle k = 0, but output_shape[axis] = arg_shape[axis]
k = arg_shape[axis];
}
// 2. get value of k
size_t k = output_shape[axis];
OPENVINO_ASSERT(k <= arg_shape[axis], "'K' exceeds the dimension of top_k_axis");
return topk::evaluate_topk(inputs[0],
outputs[1],
@@ -410,9 +342,7 @@ op::v3::TopK::TopK(const Output<Node>& data,
const std::string& mode,
const std::string& sort,
const element::Type& index_element_type)
: op::v1::TopK{data, k, axis, mode, sort, index_element_type} {
constructor_validate_and_infer_types();
}
: TopK(data, k, axis, as_enum<Mode>(mode), as_enum<SortType>(sort), index_element_type) {}
op::v3::TopK::TopK(const Output<Node>& data,
const Output<Node>& k,
@@ -420,7 +350,14 @@ op::v3::TopK::TopK(const Output<Node>& data,
const Mode mode,
const SortType sort,
const element::Type& index_element_type)
: op::v1::TopK{data, k, axis, mode, sort, index_element_type} {
: op::v1::TopK{} {
set_arguments(OutputVector{data, k});
m_axis = axis;
m_normalized_axis = UNKNOWN_NORMALIZED_AXIS;
m_mode = mode;
m_sort = sort;
m_index_element_type = index_element_type;
ov::mark_as_precision_sensitive(input(1));
constructor_validate_and_infer_types();
}
@@ -435,15 +372,21 @@ bool ngraph::op::v3::TopK::visit_attributes(AttributeVisitor& visitor) {
void op::v3::TopK::validate_and_infer_types() {
OV_OP_SCOPE(v3_TopK_validate_and_infer_types);
NODE_VALIDATION_CHECK(this,
get_input_element_type(1).is_integral_number(),
"K input has to be an integer type, which does match the provided one:",
get_input_element_type(1));
op::v1::TopK::validate_and_infer_types();
k_type_check(get_input_element_type(1));
set_axis(get_input_partial_shape(0).rank(), get_provided_axis());
const auto output_shapes = shape_infer(this, get_node_input_partial_shapes(*this));
set_output_type(0, get_input_element_type(0), output_shapes[0]);
set_output_type(1, m_index_element_type, output_shapes[1]);
}
size_t op::v3::TopK::read_k_from_constant_node(const shared_ptr<Node>& node,
const element::Type& k_element_type) const {
k_type_check(k_element_type);
const auto k_constant = ov::as_type_ptr<op::v0::Constant>(node);
size_t k = 0;
@@ -480,6 +423,13 @@ size_t op::v3::TopK::read_k_from_constant_node(const shared_ptr<Node>& node,
return k;
}
void op::v3::TopK::k_type_check(const element::Type& k_element_type) const {
NODE_VALIDATION_CHECK(this,
k_element_type.is_integral_number(),
"K input has to be an integer type, which does match the provided one:",
k_element_type);
}
shared_ptr<Node> op::v3::TopK::clone_with_new_inputs(const OutputVector& new_args) const {
OV_OP_SCOPE(v3_TopK_clone_with_new_inputs);
check_new_args_count(this, new_args);

@@ -1419,7 +1419,7 @@ TEST(eval, topk_v1_dyn) {
Shape shape{2, 3, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto k = make_shared<op::Parameter>(element::u32, Shape{});
auto k = make_shared<op::Parameter>(element::i32, Shape{});
auto B = make_shared<op::v1::TopK>(A, k, 1, "max", "index", element::i32);
auto fun = make_shared<Function>(OutputVector{B->output(0), B->output(1)}, ParameterVector{A, k});

@@ -2,102 +2,400 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "common_test_utils/test_assertions.hpp"
#include "dimension_tracker.hpp"
#include "openvino/opsets/opset10.hpp"
#include "topk_shape_inference.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
using namespace ov;
using namespace ov::opset10;
using namespace testing;
// Since v3::TopK is backward compatible with v1::TopK all of these tests should pass
template <typename T>
class topk_type_prop : public ::testing::Test {};
class topk_type_prop : public TypePropOpTest<T> {
protected:
PartialShapes make_broadcast_shapes_of_topk_outs(T* topk) {
PartialShapes bcs_outputs;
for (size_t i = 0; i < topk->get_output_size(); ++i) {
auto bc = std::make_shared<Broadcast>(std::make_shared<Parameter>(element::i64, PartialShape{1}),
topk->output(i),
"BIDIRECTIONAL");
bcs_outputs.push_back(bc->get_output_partial_shape(0));
}
return bcs_outputs;
}
element::Type exp_default_idx_type{element::i32};
};
TYPED_TEST_SUITE_P(topk_type_prop);
TYPED_TEST_P(topk_type_prop, topk_negative_axis_support) {
const auto data_shape = Shape{1, 2, 3, 4};
const auto data = make_shared<op::Parameter>(element::f32, data_shape);
const auto k = op::Constant::create(element::i64, Shape{}, {2});
TYPED_TEST_P(topk_type_prop, default_ctor) {
constexpr int64_t exp_axis = -2;
constexpr auto exp_idx_type = element::i64;
constexpr auto exp_data_type = element::f32;
const auto data = std::make_shared<Parameter>(exp_data_type, Shape{1, 2, 3, 4});
const auto k = Constant::create(element::i64, Shape{}, {2});
const auto op = this->make_op();
op->set_arguments(OutputVector{data, k});
op->set_axis(exp_axis);
op->set_index_element_type(exp_idx_type);
op->set_mode(op::TopKMode::MIN);
op->set_sort_type(op::TopKSortType::SORT_INDICES);
op->validate_and_infer_types();
EXPECT_EQ(op->get_provided_axis(), exp_axis);
EXPECT_EQ(op->get_axis(), 2);
EXPECT_EQ(op->get_input_size(), 2);
EXPECT_EQ(op->get_output_size(), 2);
EXPECT_EQ(op->get_mode(), op::TopKMode::MIN);
EXPECT_EQ(op->get_sort_type(), op::TopKSortType::SORT_INDICES);
EXPECT_THAT(op->outputs(),
ElementsAre(Property("Value type", &Output<Node>::get_element_type, exp_data_type),
Property("Index type", &Output<Node>::get_element_type, exp_idx_type)));
EXPECT_THAT(op->outputs(), Each(Property("Shape", &Output<Node>::get_shape, Shape({1, 2, 2, 4}))));
}
TYPED_TEST_P(topk_type_prop, default_ctor_no_arguments) {
constexpr int64_t exp_axis = 3;
const auto data_shape = PartialShape{1, {3, 4}, 4, {2, 6}};
int64_t k = 3;
const auto op = this->make_op();
op->set_axis(data_shape.rank(), exp_axis);
op->set_mode(op::TopKMode::MIN);
op->set_sort_type(op::TopKSortType::SORT_INDICES);
const auto constant_map =
std::map<size_t, HostTensorPtr>{{1, std::make_shared<HostTensor>(element::i64, Shape{}, &k)}};
const auto outputs = op::v1::shape_infer(op.get(), PartialShapes{data_shape, {}}, constant_map);
EXPECT_EQ(op->get_provided_axis(), exp_axis);
EXPECT_EQ(op->get_axis(), exp_axis);
EXPECT_EQ(op->get_input_size(), 0);
EXPECT_EQ(op->get_output_size(), 0);
EXPECT_EQ(op->get_mode(), op::TopKMode::MIN);
EXPECT_EQ(op->get_sort_type(), op::TopKSortType::SORT_INDICES);
EXPECT_THAT(op->outputs(),
Each(Property("Partial shape", &Output<Node>::get_partial_shape, PartialShape({1, {3, 4}, 4, 3}))));
}
TYPED_TEST_P(topk_type_prop, negative_axis_support) {
constexpr int64_t exp_axis = -1;
constexpr auto exp_data_type = element::f32;
constexpr auto exp_idx_type = element::i64;
auto data_shape = PartialShape{1, 2, 3, 4};
set_shape_labels(data_shape, 10);
const auto data = std::make_shared<Parameter>(exp_data_type, data_shape);
const auto k = Constant::create(exp_idx_type, Shape{}, {2});
const auto op = this->make_op(data, k, exp_axis, "max", "value", exp_idx_type);
EXPECT_EQ(op->get_provided_axis(), exp_axis);
EXPECT_EQ(op->get_axis(), 3);
EXPECT_EQ(op->get_input_size(), 2);
EXPECT_EQ(op->get_output_size(), 2);
EXPECT_EQ(op->get_mode(), op::TopKMode::MAX);
EXPECT_EQ(op->get_sort_type(), op::TopKSortType::SORT_VALUES);
EXPECT_THAT(op->outputs(),
ElementsAre(Property("Value type", &Output<Node>::get_element_type, exp_data_type),
Property("Index type", &Output<Node>::get_element_type, exp_idx_type)));
EXPECT_THAT(op->outputs(), Each(Property("Shape", &Output<Node>::get_shape, Shape({1, 2, 3, 2}))));
EXPECT_THAT(op->outputs(),
Each(Property(&Output<Node>::get_partial_shape,
ResultOf(get_shape_labels, ElementsAre(10, 11, 12, ov::no_label)))));
}
TYPED_TEST_P(topk_type_prop, default_index_element_type) {
constexpr auto exp_data_type = element::f32;
const auto data = std::make_shared<Parameter>(exp_data_type, Shape{1, 2, 3, 4});
const auto k = Constant::create(element::i64, Shape{}, {3});
{
// k > dimension
const auto op = this->make_op(data, k, 0, "max", "value");
EXPECT_THAT(op->outputs(),
ElementsAre(Property("Value type", &Output<Node>::get_element_type, exp_data_type),
Property("Index type", &Output<Node>::get_element_type, this->exp_default_idx_type)));
EXPECT_THAT(op->outputs(), Each(Property("Shape", &Output<Node>::get_shape, Shape({3, 2, 3, 4}))));
}
{
// k < dimension
const auto op = this->make_op(data, k, 3, "max", "value");
EXPECT_THAT(op->outputs(),
ElementsAre(Property("Value type", &Output<Node>::get_element_type, exp_data_type),
Property("Index type", &Output<Node>::get_element_type, this->exp_default_idx_type)));
EXPECT_THAT(op->outputs(), Each(Property("Shape", &Output<Node>::get_shape, Shape({1, 2, 3, 3}))));
}
}
TYPED_TEST_P(topk_type_prop, k_is_negative) {
const auto data = std::make_shared<Parameter>(element::f32, PartialShape{-1, {-1, 2}});
const auto k = Constant::create(element::i64, Shape{}, {-1});
OV_EXPECT_THROW(const auto op = this->make_op(data, k, 0, "max", "value"),
NodeValidationFailure,
HasSubstr("The value of 'K' must be more or equal zero."));
}
TYPED_TEST_P(topk_type_prop, k_for_dynamic_dimension) {
const auto data = std::make_shared<Parameter>(element::f32, PartialShape{-1, {-1, 2}});
const auto k = Constant::create(element::i64, Shape{}, {5});
const auto op = this->make_op(data, k, 0, "max", "value");
EXPECT_THAT(op->outputs(),
Each(Property("Partial Shape", &Output<Node>::get_partial_shape, PartialShape({5, {-1, 2}}))));
}
TYPED_TEST_P(topk_type_prop, k_for_interval_dimension) {
const auto data = std::make_shared<Parameter>(element::f32, PartialShape{{2, 12}, {-1, 2}});
const auto k = Constant::create(element::i64, Shape{}, {6});
const auto op = this->make_op(data, k, 0, "max", "value");
EXPECT_THAT(op->outputs(),
Each(Property("Partial Shape", &Output<Node>::get_partial_shape, PartialShape({6, {-1, 2}}))));
}
TYPED_TEST_P(topk_type_prop, k_is_unknown_for_static_dimension) {
const auto data = std::make_shared<Parameter>(element::f32, PartialShape{2, 10});
const auto k = std::make_shared<Parameter>(element::i32, PartialShape({}));
const auto op = this->make_op(data, k, 1, "max", "value");
EXPECT_THAT(op->outputs(),
Each(Property("Partial Shape", &Output<Node>::get_partial_shape, PartialShape({2, {0, 10}}))));
}
TYPED_TEST_P(topk_type_prop, k_is_unknown_for_dynamic_dimension) {
const auto data = std::make_shared<Parameter>(element::f32, PartialShape{-1, {-1, 2}});
const auto k = std::make_shared<Parameter>(element::i32, PartialShape::dynamic());
const auto op = this->make_op(data, k, 0, "max", "value");
EXPECT_THAT(op->outputs(),
Each(Property("Partial Shape", &Output<Node>::get_partial_shape, PartialShape({-1, {-1, 2}}))));
}
TYPED_TEST_P(topk_type_prop, k_is_unknown_for_interval_dimension) {
const auto data = std::make_shared<Parameter>(element::f32, PartialShape{{2, 100}, {-1, 2}});
const auto k = std::make_shared<Parameter>(element::i32, PartialShape::dynamic());
const auto op = this->make_op(data, k, 0, "max", "value");
EXPECT_THAT(op->outputs(),
Each(Property("Partial Shape", &Output<Node>::get_partial_shape, PartialShape({{0, 100}, {-1, 2}}))));
}
TYPED_TEST_P(topk_type_prop, k_is_unknown_for_interval_with_no_upper_bound_dimension) {
const auto data = std::make_shared<Parameter>(element::f32, PartialShape{{2, -1}, {-1, 2}});
const auto k = std::make_shared<Parameter>(element::i32, PartialShape::dynamic());
const auto op = this->make_op(data, k, 0, "max", "value");
EXPECT_THAT(op->outputs(),
Each(Property("Partial Shape", &Output<Node>::get_partial_shape, PartialShape({-1, {-1, 2}}))));
}
TYPED_TEST_P(topk_type_prop, data_and_k_shapes_are_dynamic) {
const auto data = std::make_shared<Parameter>(element::f32, PartialShape::dynamic());
const auto k = std::make_shared<Parameter>(element::i32, PartialShape::dynamic());
const auto op = this->make_op(data, k, 1, "max", "value");
EXPECT_THAT(op->outputs(),
Each(Property("Partial Shape", &Output<Node>::get_partial_shape, PartialShape::dynamic())));
}
TYPED_TEST_P(topk_type_prop, propagate_label_and_not_interval_value_max) {
auto p_shape = PartialShape{5, 6, 4, 3, 8};
set_shape_labels(p_shape, 1);
constexpr auto et = element::i64;
const auto labeled_param = std::make_shared<Parameter>(et, p_shape);
const auto labeled_shape_of = std::make_shared<ShapeOf>(labeled_param);
const auto k = Constant::create(et, Shape{}, {3});
const auto op = this->make_op(labeled_shape_of, k, 0, "max", "index", element::i32);
const auto bc_shapes = this->make_broadcast_shapes_of_topk_outs(op.get());
EXPECT_THAT(bc_shapes, ElementsAre(PartialShape({5, 6, 8}), PartialShape({0, 1, 4})));
EXPECT_THAT(bc_shapes, Each(ResultOf(get_shape_labels, Each(ov::no_label))));
}
TYPED_TEST_P(topk_type_prop, propagate_label_and_not_interval_value_min) {
auto p_shape = PartialShape{5, 6, 3, 4, 8};
set_shape_labels(p_shape, 1);
constexpr auto et = element::i64;
const auto labeled_param = std::make_shared<Parameter>(et, p_shape);
const auto labeled_shape_of = std::make_shared<ShapeOf>(labeled_param);
const auto k = Constant::create(et, Shape{}, {3});
const auto op = this->make_op(labeled_shape_of, k, 0, "min", "index", element::i32);
const auto bc_shapes = this->make_broadcast_shapes_of_topk_outs(op.get());
EXPECT_THAT(bc_shapes, ElementsAre(PartialShape({5, 3, 4}), PartialShape({0, 2, 3})));
EXPECT_THAT(bc_shapes, Each(ResultOf(get_shape_labels, Each(ov::no_label))));
}
TYPED_TEST_P(topk_type_prop, preserve_partial_values_and_labels_k_is_interval) {
auto k_dim = Dimension{10, 20};
auto shape = PartialShape{k_dim};
ov::DimensionTracker::set_label(k_dim, 20);
const auto p_k = std::make_shared<Parameter>(element::i64, shape);
const auto shape_of_k = std::make_shared<ShapeOf>(p_k);
const auto k = std::make_shared<Squeeze>(shape_of_k, Constant::create(element::i64, Shape{}, {0}));
auto data_shape = PartialShape{{2, 5}, {12, 18}, {2, 30}, {30, 40}, {-1, 15}, {15, -1}};
set_shape_labels(data_shape, 1);
const auto data = std::make_shared<Parameter>(element::f32, data_shape);
{
// dim{2,5} k{10,20} -> {2,20}
const auto op = this->make_op(data, k, 0, "max", "value");
EXPECT_THAT(op->get_output_partial_shape(0),
AllOf(PartialShape({{2, 20}, {12, 18}, {2, 30}, {30, 40}, {-1, 15}, {15, -1}}),
ResultOf(get_shape_labels, ElementsAre(no_label, 2, 3, 4, 5, 6))));
}
{
// dim{12,18} k{10,20} -> {10,20}
const auto op = this->make_op(data, k, 1, "max", "value");
EXPECT_THAT(op->get_output_partial_shape(0),
AllOf(PartialShape({{2, 5}, {10, 20}, {2, 30}, {30, 40}, {-1, 15}, {15, -1}}),
ResultOf(get_shape_labels, ElementsAre(1, no_label, 3, 4, 5, 6))));
}
{
// dim{2, 30} k{10,20} -> {2,30}
const auto op = this->make_op(data, k, 2, "max", "value");
EXPECT_THAT(op->get_output_partial_shape(0),
AllOf(PartialShape({{2, 5}, {12, 18}, {2, 30}, {30, 40}, {-1, 15}, {15, -1}}),
ResultOf(get_shape_labels, ElementsAre(1, 2, no_label, 4, 5, 6))));
}
{
// dim{30,40} k{10,20} -> {10,40} (should use k upper bounds??)
const auto op = this->make_op(data, k, 3, "max", "value");
EXPECT_THAT(op->get_output_partial_shape(0),
AllOf(PartialShape({{2, 5}, {12, 18}, {2, 30}, {10, 40}, {-1, 15}, {15, -1}}),
ResultOf(get_shape_labels, ElementsAre(1, 2, 3, no_label, 5, 6))));
}
{
// dim{-inf,15} k{10,20} -> {0,20}
const auto op = this->make_op(data, k, 4, "max", "value");
EXPECT_THAT(op->get_output_partial_shape(0),
AllOf(PartialShape({{2, 5}, {12, 18}, {2, 30}, {30, 40}, {0, 20}, {15, -1}}),
ResultOf(get_shape_labels, ElementsAre(1, 2, 3, 4, no_label, 6))));
}
{
// dim{15,inf} k{10,20} -> {10,inf}
const auto op = this->make_op(data, k, 5, "max", "value");
EXPECT_THAT(op->get_output_partial_shape(0),
AllOf(PartialShape({{2, 5}, {12, 18}, {2, 30}, {30, 40}, {-1, 15}, {10, -1}}),
ResultOf(get_shape_labels, ElementsAre(1, 2, 3, 4, 5, no_label))));
}
}
TYPED_TEST_P(topk_type_prop, preserve_partial_values_and_labels_k_is_interval_with_no_upper_bound) {
auto shape = PartialShape{{10, -1}};
set_shape_labels(shape, 20);
const auto p_k = std::make_shared<Parameter>(element::i64, shape);
const auto shape_of_k = std::make_shared<ShapeOf>(p_k);
// Squeeze makes it a scalar, but if the interval value has no upper bound the result will be {0,inf}
const auto k = std::make_shared<Squeeze>(shape_of_k, Constant::create(element::i64, Shape{}, {0}));
auto data_shape = PartialShape{5, {2, 8}, {2, 100}};
set_shape_labels(data_shape, 10);
const auto data = std::make_shared<Parameter>(element::f32, data_shape);
{
// dim{5} k{0,inf} -> {0,5}
const auto op = this->make_op(data, k, 0, "max", "value");
EXPECT_THAT(op->get_output_partial_shape(0),
AllOf(PartialShape({{0, 5}, {2, 8}, {2, 100}}),
ResultOf(get_shape_labels, ElementsAre(ov::no_label, 11, 12))));
}
{
// dim{2,8} k{0,inf} -> {0,8}
const auto op = this->make_op(data, k, 1, "max", "value");
EXPECT_THAT(
op->get_output_partial_shape(0),
AllOf(PartialShape({5, {0, 8}, {2, 100}}), ResultOf(get_shape_labels, ElementsAre(10, ov::no_label, 12))));
}
{
// dim{2,100} k{0,inf} -> {0,100}
const auto op = this->make_op(data, k, 2, "max", "value");
EXPECT_THAT(
op->get_output_partial_shape(0),
AllOf(PartialShape({5, {2, 8}, {0, 100}}), ResultOf(get_shape_labels, ElementsAre(10, 11, ov::no_label))));
}
}
TYPED_TEST_P(topk_type_prop, negative_axis_dynamic_rank) {
const auto data = std::make_shared<Parameter>(element::f32, PartialShape::dynamic());
const auto k = Constant::create(element::i64, Shape{}, {2});
const int64_t axis = -2;
const auto op = this->make_op(data, k, axis, "max", "value");
OV_EXPECT_THROW(op->get_axis(), NodeValidationFailure, HasSubstr("Normalized axis of TopK is unknown"));
}
TYPED_TEST_P(topk_type_prop, incorrect_index_element_type) {
const auto data = std::make_shared<Parameter>(element::f32, PartialShape::dynamic());
const auto k = Constant::create(element::i64, Shape{}, {2});
const int64_t axis = -2;
const auto topk = make_shared<TypeParam>(data, k, axis, "max", "value");
ASSERT_EQ(topk->get_provided_axis(), axis);
const auto expect_shape = Shape{1, 2, 2, 4};
ASSERT_EQ(topk->get_output_shape(0), expect_shape);
ASSERT_EQ(topk->get_output_shape(1), expect_shape);
}
TYPED_TEST_P(topk_type_prop, topk_default_index_element_type) {
const auto data_shape = Shape{1, 2, 3, 4};
const auto data = make_shared<op::Parameter>(element::f32, data_shape);
const auto k = op::Constant::create(element::i64, Shape{}, {2});
const int64_t axis = -2;
const auto op = make_shared<op::v1::TopK>(data, k, axis, "max", "value");
ASSERT_EQ(op->get_index_element_type(), element::i32);
}
TYPED_TEST_P(topk_type_prop, topk_negative_axis_dynamic_rank) {
const auto data_shape = PartialShape::dynamic();
const auto data = make_shared<op::Parameter>(element::f32, data_shape);
const auto k = op::Constant::create(element::i64, Shape{}, {2});
const int64_t axis = -2;
const auto topk = make_shared<TypeParam>(data, k, axis, "max", "value");
try {
topk->get_axis();
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Normalized axis of TopK is unknown"));
} catch (...) {
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TYPED_TEST_P(topk_type_prop, topk_v1_partial_ouptut) {
auto data_shape = PartialShape{2, 10};
auto data = make_shared<op::Parameter>(element::f32, data_shape);
{
auto k = make_shared<op::Parameter>(element::i32, PartialShape({}));
auto topk = make_shared<TypeParam>(data, k, 1, "max", "value");
EXPECT_EQ(topk->get_output_partial_shape(0), PartialShape({2, Dimension(0, 10)}));
}
{
auto k = make_shared<op::Constant>(element::i32, Shape{}, 3);
auto topk = make_shared<TypeParam>(data, k, 1, "max", "value");
EXPECT_EQ(topk->get_output_shape(0), Shape({2, 3}));
EXPECT_EQ(topk->get_output_partial_shape(0), PartialShape({2, 3}));
}
}
TYPED_TEST_P(topk_type_prop, topk_rank_static_k_unknown) {
const int64_t axis = 1;
const auto data_shape = Shape{1, 10, 100};
const auto data = make_shared<op::Parameter>(element::f32, data_shape);
{
const auto k = make_shared<op::Parameter>(element::i32, PartialShape({}));
const auto topk = make_shared<TypeParam>(data, k, axis, "max", "value");
const PartialShape fully_dynamic_axis_shape{1, Dimension(0, 10), 100};
EXPECT_EQ(topk->get_output_partial_shape(0), fully_dynamic_axis_shape);
}
{
const auto k = make_shared<op::v0::Constant>(element::i64, Shape{}, 5);
const auto convert_k = make_shared<op::v0::Convert>(k, element::i32);
const auto topk = make_shared<TypeParam>(data, convert_k, axis, "max", "value");
const PartialShape ranged_dynamic_axis_shape{1, Dimension{5}, 100};
EXPECT_EQ(topk->get_output_partial_shape(0), ranged_dynamic_axis_shape);
}
OV_EXPECT_THROW(const auto op = this->make_op(data, k, axis, "max", "value", element::i16),
NodeValidationFailure,
HasSubstr("Index element type attribute should be either \'i32\' or \'i64\'. Got:"));
}
REGISTER_TYPED_TEST_SUITE_P(topk_type_prop,
topk_negative_axis_support,
topk_negative_axis_dynamic_rank,
topk_v1_partial_ouptut,
topk_rank_static_k_unknown,
topk_default_index_element_type);
default_ctor,
default_ctor_no_arguments,
negative_axis_support,
default_index_element_type,
k_is_negative,
k_for_dynamic_dimension,
k_for_interval_dimension,
k_is_unknown_for_static_dimension,
k_is_unknown_for_dynamic_dimension,
k_is_unknown_for_interval_dimension,
k_is_unknown_for_interval_with_no_upper_bound_dimension,
data_and_k_shapes_are_dynamic,
propagate_label_and_not_interval_value_max,
propagate_label_and_not_interval_value_min,
preserve_partial_values_and_labels_k_is_interval,
preserve_partial_values_and_labels_k_is_interval_with_no_upper_bound,
negative_axis_dynamic_rank,
incorrect_index_element_type);
typedef ::testing::Types<op::v1::TopK, op::v3::TopK> TopKTypes;
INSTANTIATE_TYPED_TEST_SUITE_P(type_prop, topk_type_prop, TopKTypes, );
typedef Types<op::v1::TopK, op::v3::TopK> TopKTypes;
INSTANTIATE_TYPED_TEST_SUITE_P(type_prop, topk_type_prop, TopKTypes);
class TypePropTopKV1Test : public TypePropOpTest<op::v1::TopK> {};
TEST_F(TypePropTopKV1Test, k_is_u32) {
const auto data = std::make_shared<Parameter>(element::f32, PartialShape{5, {-1, 2}});
const auto k = Constant::create(element::u32, Shape{}, {1});
OV_EXPECT_THROW(const auto op = this->make_op(data, k, 0, "max", "value"),
NodeValidationFailure,
HasSubstr("K input element type must be i8, i32 or i64 (got u32)"));
}
class TypePropTopKV3Test : public TypePropOpTest<op::v3::TopK> {};
TEST_F(TypePropTopKV3Test, k_is_u32) {
const auto data = std::make_shared<Parameter>(element::f32, PartialShape{5, {-1, 2}});
const auto k = Constant::create(element::u32, Shape{}, {1});
const auto op = this->make_op(data, k, 0, "max", "value");
EXPECT_THAT(op->outputs(),
Each(Property("PartialShape", &Output<Node>::get_partial_shape, PartialShape({1, {-1, 2}}))));
}
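The two fixtures above pin down the versioned behaviour: TopK-1 rejects an unsigned `k`, while TopK-3 accepts any integral type. A minimal sketch of one way to express such a version-specific check through a virtual hook; the names and wiring here are illustrative, not the exact OpenVINO implementation:
#include <stdexcept>
#include <string>
// Hypothetical v1-like rule: only signed integer element types for 'k'.
struct topk_v1_like {
    virtual ~topk_v1_like() = default;
    virtual void k_type_check(const std::string& k_type) const {
        if (k_type != "i8" && k_type != "i32" && k_type != "i64")
            throw std::runtime_error("K input element type must be i8, i32 or i64 (got " + k_type + ")");
    }
};
// Hypothetical v3-like override relaxing the rule to any integral type, so "u32" passes.
struct topk_v3_like : topk_v1_like {
    void k_type_check(const std::string& k_type) const override {
        if (k_type.empty() || (k_type.front() != 'i' && k_type.front() != 'u'))
            throw std::runtime_error("K input element type must be integral (got " + k_type + ")");
    }
};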

View File

@ -439,6 +439,8 @@ std::shared_ptr<IShapeInferCommon> make_shape_inference(const std::shared_ptr<ng
return make_shared_entryIO(node);
} else if (auto node = ov::as_type_ptr<ov::opset6::ExperimentalDetectronDetectionOutput>(op)) {
return make_shared_entryIO(node);
} else if (auto node = ov::as_type_ptr<ov::opset1::TopK>(op)) {
return make_shared_entryIOC(node);
} else if (auto node = ov::as_type_ptr<ov::opset3::TopK>(op)) {
return make_shared_entryIOC(node);
} else if (auto node = ov::as_type_ptr<ov::opset3::Bucketize>(op)) {
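Both TopK opsets are registered with the constant-aware entry (`make_shared_entryIOC`) rather than the shape-only `make_shared_entryIO`; a reasonable reading is that the IOC entry additionally threads constant input data, here the value of `k`, into shape inference. A hedged usage sketch mirroring the new tests below:
#include <map>
#include <memory>
#include "openvino/opsets/opset10.hpp"
#include "topk_shape_inference.hpp"  // internal CPU-plugin header, as used by this PR's tests
#include "utils.hpp"
int main() {
    using namespace ov;
    using namespace ov::intel_cpu;
    using namespace ov::opset10;
    // 'k' is a Parameter, so its value is not available from the graph itself.
    const auto data = std::make_shared<Parameter>(element::f32, PartialShape::dynamic());
    const auto k_node = std::make_shared<Parameter>(element::i64, PartialShape::dynamic());
    const auto topk = std::make_shared<op::v3::TopK>(data, k_node, 1, "max", "value");
    auto input_shapes = ShapeVector{{1, 10, 100}, {}};
    auto output_shapes = ShapeVector(2);
    // Supply k through the per-port host-tensor map instead.
    int64_t k = 3;
    const auto const_map =
        std::map<size_t, HostTensorPtr>{{1, std::make_shared<HostTensor>(element::i64, Shape{}, &k)}};
    shape_inference(topk.get(), input_shapes, output_shapes, const_map);
    // Expected: both output shapes are {1, 3, 100}.
    return 0;
}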

View File

@ -0,0 +1,161 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "common_test_utils/test_assertions.hpp"
#include "gmock/gmock.h"
#include "openvino/opsets/opset10.hpp"
#include "topk_shape_inference.hpp"
#include "utils.hpp"
using namespace ov;
using namespace ov::intel_cpu;
using namespace ov::opset10;
using namespace testing;
namespace topk_test {
using TopKTestParams = std::tuple<ShapeVector, // Input shapes
int64_t, // axis
int64_t, // k value
StaticShape // Expected output shape
>;
template <class TOp>
class TopKTest : public OpStaticShapeInferenceTest<TOp>, public WithParamInterface<TopKTestParams> {
protected:
void SetUp() override {
std::tie(this->input_shapes, this->axis, this->k, this->exp_shape) = GetParam();
this->output_shapes.resize(2);
}
int64_t axis, k;
};
const auto TopkTestValues = Values(make_tuple(ShapeVector{{0}, {}}, 0, 1, StaticShape{1}),
make_tuple(ShapeVector{{5, 2, 10, 0}, {}}, -1, 5, StaticShape{5, 2, 10, 5}),
make_tuple(ShapeVector{{3, 5, 6}, {}}, 1, 2, StaticShape{3, 2, 6}));
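As a sanity check on the table above, a hedged helper (plain `ov::Shape` for brevity, not part of the PR) that reproduces the expected shapes: the output keeps the input shape, with the dimension at `axis` (negative values count from the back) replaced by `k`, and both TopK outputs share it.
#include <cstdint>
#include "openvino/core/shape.hpp"
// e.g. expected_topk_shape({3, 5, 6}, 1, 2) == ov::Shape{3, 2, 6}
ov::Shape expected_topk_shape(ov::Shape in, int64_t axis, int64_t k) {
    const auto rank = static_cast<int64_t>(in.size());
    in[static_cast<size_t>(axis < 0 ? axis + rank : axis)] = static_cast<size_t>(k);
    return in;
}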
namespace v1 {
using TopKV1AssertStaticShapeInferenceTest = OpStaticShapeInferenceTest<op::v1::TopK>;
TEST_F(TopKV1AssertStaticShapeInferenceTest, k_is_negative) {
const auto data = std::make_shared<Parameter>(element::f32, PartialShape::dynamic());
const auto k_node = std::make_shared<Parameter>(element::i64, PartialShape::dynamic());
const auto op = make_op(data, k_node, 0, "max", "value");
input_shapes = ShapeVector{{5, 2}, {}};
output_shapes = ShapeVector(2);
int64_t k = -2;
const auto const_map =
std::map<size_t, HostTensorPtr>{{1, std::make_shared<HostTensor>(element::i64, Shape{}, &k)}};
OV_EXPECT_THROW(shape_inference(op.get(), input_shapes, output_shapes, const_map),
ov::AssertFailure,
HasSubstr("The value of 'K' must be more or equal zero. (got " + std::to_string(k) + ")"));
}
using TopKV1Test = TopKTest<op::v1::TopK>;
INSTANTIATE_TEST_SUITE_P(StaticShapeInference, TopKV1Test, TopkTestValues, PrintToStringParamName());
TEST_P(TopKV1Test, no_constant_map) {
const auto data = std::make_shared<Parameter>(element::f32, PartialShape::dynamic());
const auto k_node = Constant::create(element::i64, Shape{}, {k});
const auto op = make_op(data, k_node, axis, "max", "value");
shape_inference(op.get(), input_shapes, output_shapes);
EXPECT_EQ(output_shapes.size(), 2);
EXPECT_THAT(output_shapes, Each(exp_shape));
}
TEST_P(TopKV1Test, k_as_param_no_const_map) {
const auto data = std::make_shared<Parameter>(element::f32, PartialShape::dynamic());
const auto k_node = std::make_shared<Parameter>(element::i64, PartialShape::dynamic());
const auto op = make_op(data, k_node, axis, "min", "value");
OV_EXPECT_THROW(shape_inference(op.get(), input_shapes, output_shapes),
NodeValidationFailure,
HasSubstr("Static shape inference lacks constant data on port 1"));
}
TEST_P(TopKV1Test, k_as_param_in_const_map) {
const auto data = std::make_shared<Parameter>(element::f32, PartialShape::dynamic());
const auto k_node = std::make_shared<Parameter>(element::i64, PartialShape::dynamic());
const auto const_map =
std::map<size_t, HostTensorPtr>{{1, std::make_shared<HostTensor>(element::i64, Shape{}, &k)}};
const auto op = make_op(data, k_node, axis, "min", "value");
shape_inference(op.get(), input_shapes, output_shapes, const_map);
EXPECT_EQ(output_shapes.size(), 2);
EXPECT_THAT(output_shapes, Each(exp_shape));
}
} // namespace v1
namespace v3 {
using TopKV3AssertStaticShapeInferenceTest = OpStaticShapeInferenceTest<op::v3::TopK>;
TEST_F(TopKV3AssertStaticShapeInferenceTest, k_is_negative) {
const auto data = std::make_shared<Parameter>(element::f32, PartialShape::dynamic());
const auto k_node = std::make_shared<Parameter>(element::i64, PartialShape::dynamic());
const auto op = make_op(data, k_node, 0, "max", "value");
input_shapes = ShapeVector{{5, 2}, {}};
output_shapes = ShapeVector(2);
int64_t k = -2;
const auto const_map =
std::map<size_t, HostTensorPtr>{{1, std::make_shared<HostTensor>(element::i64, Shape{}, &k)}};
OV_EXPECT_THROW(shape_inference(op.get(), input_shapes, output_shapes, const_map),
ov::AssertFailure,
HasSubstr("The value of 'K' must be more or equal zero. (got " + std::to_string(k) + ")"));
}
using TopKV3Test = TopKTest<op::v3::TopK>;
INSTANTIATE_TEST_SUITE_P(StaticShapeInference, TopKV3Test, TopkTestValues, PrintToStringParamName());
TEST_P(TopKV3Test, k_as_constant) {
const auto data = std::make_shared<Parameter>(element::f32, PartialShape::dynamic());
const auto k_node = Constant::create(element::i64, Shape{}, {k});
const auto op = make_op(data, k_node, axis, "min", "value");
shape_inference(op.get(), input_shapes, output_shapes);
EXPECT_EQ(output_shapes.size(), 2);
EXPECT_THAT(output_shapes, Each(exp_shape));
}
TEST_P(TopKV3Test, k_as_param_no_const_map) {
const auto data = std::make_shared<Parameter>(element::f32, PartialShape::dynamic());
const auto k_node = std::make_shared<Parameter>(element::i64, PartialShape::dynamic());
const auto op = make_op(data, k_node, axis, "min", "value");
OV_EXPECT_THROW(shape_inference(op.get(), input_shapes, output_shapes),
NodeValidationFailure,
HasSubstr("Static shape inference lacks constant data on port 1"));
}
TEST_P(TopKV3Test, k_as_param_in_const_map) {
const auto data = std::make_shared<Parameter>(element::f32, PartialShape::dynamic());
const auto k_node = std::make_shared<Parameter>(element::i64, PartialShape::dynamic());
const auto const_map =
std::map<size_t, HostTensorPtr>{{1, std::make_shared<HostTensor>(element::i64, Shape{}, &k)}};
const auto op = make_op(data, k_node, axis, "max", "value");
shape_inference(op.get(), input_shapes, output_shapes, const_map);
EXPECT_EQ(output_shapes.size(), 2);
EXPECT_THAT(output_shapes, Each(exp_shape));
}
} // namespace v3
} // namespace topk_test

View File

@ -1,44 +0,0 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <topk_shape_inference.hpp>
#include "utils.hpp"
using namespace ov;
using namespace ov::intel_cpu;
static std::shared_ptr<op::v3::TopK> build_topk(PartialShape data_shape = PartialShape::dynamic(),
int64_t axis = 1,
int k_value = -1) {
std::shared_ptr<ov::Node> k;
const auto data = std::make_shared<op::v0::Parameter>(element::f32, data_shape);
if (k_value >= 0)
k = op::v0::Constant::create(element::i64, ov::Shape{}, {2});
else
k = std::make_shared<op::v0::Parameter>(element::i64, ov::PartialShape{});
return std::make_shared<op::v3::TopK>(data, k, axis, "max", "value");
}
TEST(StaticShapeInferenceTest, TopKv3) {
const auto topk = build_topk(PartialShape::dynamic(), 1, 2);
check_static_shape(topk.get(),
{StaticShape{1, 10, 100}, StaticShape{}},
{StaticShape({1, 2, 100}), StaticShape({1, 2, 100})});
}
TEST(StaticShapeInferenceTest, TopKv3_StaticNoConstMap) {
const auto topk = build_topk();
std::vector<StaticShape> static_input_shapes = {StaticShape{1, 10, 100}, StaticShape{}};
std::vector<StaticShape> static_output_shapes = {StaticShape{}, StaticShape{}};
EXPECT_THROW(shape_inference(topk.get(), static_input_shapes, static_output_shapes), NodeValidationFailure);
}
TEST(StaticShapeInferenceTest, TopKv3_StaticWithConstMap) {
const auto topk = build_topk();
check_static_shape(topk.get(), {StaticShape{1, 10, 100}, 2}, {StaticShape{1, 2, 100}, StaticShape{1, 2, 100}});
}