From 921e621404727cacc4810bc6d7ed70e45d18de56 Mon Sep 17 00:00:00 2001 From: Tomasz Jankowski Date: Thu, 5 Oct 2023 05:26:49 +0200 Subject: [PATCH] [Ref] Drop legacy API (#20006) * Drop legacy API - CoordinateTransform * Refactor AvgPool ref * Fix zero padding indices ceil rounding * Transpose 3D reshaped kernels * Reuse max_pool v8 ref in v1 * Change ref::slice params validation * Fix wrong alloc in max_pool ref * Drop legacy from TopK * Fix mvn ref * Deprecate unused * Drop ngraph from TopK op * Remove deprecated * Use OPENVINO_ASSERT * Replace headers paths * Replace headers paths * Add missing include * Address review comments * Clean up * Remove unused and duplicated code --- .../openvino/reference/adaptive_avg_pool.hpp | 9 +- .../openvino/reference/adaptive_max_pool.hpp | 16 +- .../include/openvino/reference/and.hpp | 4 +- .../reference/autobroadcast_binop.hpp | 76 ++-- .../include/openvino/reference/avg_pool.hpp | 341 ++++++++---------- .../include/openvino/reference/batch_norm.hpp | 6 +- .../openvino/reference/binary_convolution.hpp | 2 +- .../include/openvino/reference/broadcast.hpp | 4 +- .../include/openvino/reference/bucketize.hpp | 2 +- .../include/openvino/reference/concat.hpp | 2 +- .../include/openvino/reference/convert.hpp | 6 +- .../openvino/reference/convert_color_nv12.hpp | 12 +- .../openvino/reference/convolution.hpp | 46 +-- .../reference/convolution_backprop_data.hpp | 48 ++- .../openvino/reference/ctc_greedy_decoder.hpp | 14 +- .../include/openvino/reference/ctc_loss.hpp | 6 +- .../reference/deformable_convolution.hpp | 39 +- .../reference/deformable_psroi_pooling.hpp | 2 +- .../openvino/reference/depth_to_space.hpp | 4 +- .../openvino/reference/detection_output.hpp | 27 +- .../include/openvino/reference/divide.hpp | 8 +- .../include/openvino/reference/einsum.hpp | 3 +- .../reference/embedding_bag_offsets_sum.hpp | 2 +- .../reference/embedding_bag_packed_sum.hpp | 2 +- .../reference/embedding_segments_sum.hpp | 2 +- .../include/openvino/reference/equal.hpp | 4 +- .../include/openvino/reference/erf.hpp | 3 - ...xperimental_detectron_detection_output.hpp | 9 +- ...imental_detectron_prior_grid_generator.hpp | 7 +- ...mental_detectron_proposal_single_image.hpp | 9 +- ...mental_detectron_roi_feature_extractor.hpp | 8 +- .../experimental_detectron_topk_rois.hpp | 7 +- .../reference/extract_image_patches.hpp | 4 +- .../include/openvino/reference/eye.hpp | 2 +- .../openvino/reference/fake_quantize.hpp | 30 +- .../include/openvino/reference/floor_mod.hpp | 1 - .../include/openvino/reference/gather.hpp | 2 +- .../openvino/reference/gather_tree.hpp | 4 +- .../include/openvino/reference/gelu.hpp | 4 +- .../openvino/reference/generate_proposal.hpp | 11 +- .../include/openvino/reference/greater.hpp | 4 +- .../include/openvino/reference/greater_eq.hpp | 4 +- .../openvino/reference/grid_sample.hpp | 2 +- .../group_convolution_backprop_data.hpp | 9 +- .../openvino/reference/interpolate.hpp | 59 ++- .../openvino/reference/interpolate_pil.hpp | 4 +- .../include/openvino/reference/irdft.hpp | 2 +- .../include/openvino/reference/less.hpp | 4 +- .../include/openvino/reference/less_eq.hpp | 4 +- .../openvino/reference/log_softmax.hpp | 27 +- .../openvino/reference/logical_reduction.hpp | 1 - .../include/openvino/reference/lrn.hpp | 8 +- .../include/openvino/reference/matmul.hpp | 1 - .../include/openvino/reference/matrix_nms.hpp | 13 +- .../include/openvino/reference/max_pool.hpp | 157 ++------ .../include/openvino/reference/maximum.hpp | 4 +- .../include/openvino/reference/minimum.hpp | 4 +- .../openvino/reference/multiclass_nms.hpp | 13 +- .../include/openvino/reference/multiply.hpp | 4 +- .../include/openvino/reference/mvn.hpp | 10 +- .../reference/non_max_suppression.hpp | 21 +- .../include/openvino/reference/non_zero.hpp | 2 +- .../openvino/reference/normalize_l2.hpp | 2 - .../include/openvino/reference/not_equal.hpp | 4 +- .../include/openvino/reference/one_hot.hpp | 2 +- .../include/openvino/reference/or.hpp | 4 +- .../include/openvino/reference/pad.hpp | 6 +- .../include/openvino/reference/power.hpp | 4 +- .../include/openvino/reference/prelu.hpp | 16 +- .../include/openvino/reference/prior_box.hpp | 7 +- .../reference/prior_box_clustered.hpp | 9 +- .../include/openvino/reference/proposal.hpp | 4 +- .../openvino/reference/psroi_pooling.hpp | 4 +- .../include/openvino/reference/quantize.hpp | 76 ---- .../openvino/reference/random_uniform.hpp | 6 +- .../include/openvino/reference/range.hpp | 6 +- .../include/openvino/reference/rdft.hpp | 6 +- .../openvino/reference/region_yolo.hpp | 4 +- .../include/openvino/reference/reorg_yolo.hpp | 2 +- .../include/openvino/reference/reshape.hpp | 4 +- .../include/openvino/reference/result.hpp | 2 - .../openvino/reference/reverse_sequence.hpp | 5 +- .../include/openvino/reference/roi_align.hpp | 69 ++-- .../openvino/reference/roi_pooling.hpp | 4 +- .../include/openvino/reference/roll.hpp | 2 +- .../include/openvino/reference/round.hpp | 4 +- .../{round_guard.hpp => rounding_guard.hpp} | 6 +- .../reference/scatter_elements_update.hpp | 4 +- .../openvino/reference/scatter_nd_update.hpp | 3 +- .../openvino/reference/scatter_update.hpp | 5 +- .../include/openvino/reference/sequences.hpp | 48 +-- .../include/openvino/reference/shape_of.hpp | 2 +- .../openvino/reference/shuffle_channels.hpp | 2 +- .../include/openvino/reference/slice.hpp | 2 +- .../include/openvino/reference/softmax.hpp | 34 +- .../openvino/reference/space_to_depth.hpp | 4 +- .../openvino/reference/squared_difference.hpp | 1 - .../include/openvino/reference/subtract.hpp | 2 + .../include/openvino/reference/tile.hpp | 5 +- .../include/openvino/reference/topk.hpp | 33 +- .../include/openvino/reference/transpose.hpp | 2 +- .../include/openvino/reference/unique.hpp | 4 +- .../openvino/reference/utils/fft_common.hpp | 5 +- .../openvino/reference/utils/nms_common.hpp | 6 +- src/core/reference/src/op/depth_to_space.cpp | 22 +- src/core/reference/src/op/einsum.cpp | 28 +- ...xperimental_detectron_detection_output.cpp | 6 +- ...mental_detectron_proposal_single_image.cpp | 6 +- ...mental_detectron_roi_feature_extractor.cpp | 6 +- src/core/reference/src/op/function.cpp | 4 - src/core/reference/src/op/gather_tree.cpp | 5 +- .../reference/src/op/generate_proposal.cpp | 8 +- .../reference/src/op/group_convolution.cpp | 18 +- .../op/group_convolution_backprop_data.cpp | 60 +-- src/core/reference/src/op/if.cpp | 16 +- src/core/reference/src/op/interpolate.cpp | 20 +- src/core/reference/src/op/irdft.cpp | 2 +- src/core/reference/src/op/matmul.cpp | 4 +- src/core/reference/src/op/matrix_nms.cpp | 4 +- src/core/reference/src/op/multiclass_nms.cpp | 5 +- src/core/reference/src/op/pad.cpp | 9 +- src/core/reference/src/op/random_uniform.cpp | 18 +- src/core/reference/src/op/rdft.cpp | 2 +- src/core/reference/src/op/reorg_yolo.cpp | 2 +- src/core/reference/src/op/reshape.cpp | 4 +- src/core/reference/src/op/reverse.cpp | 6 +- src/core/reference/src/op/slice.cpp | 22 +- src/core/reference/src/op/space_to_depth.cpp | 16 +- src/core/reference/src/op/split.cpp | 2 - src/core/reference/src/op/strided_slice.cpp | 1 - src/core/reference/src/op/transpose.cpp | 2 +- .../reference/src/op/utils/fft_common.cpp | 4 +- .../reference/src/op/utils/nms_common.cpp | 12 +- .../reference/src/op/utils/round_guard.cpp | 11 - .../reference/src/op/utils/rounding_guard.cpp | 11 + .../src/runtime/opt_kernel/reshape.cpp | 1 - .../src/utils/coordinate_transform.cpp | 16 +- src/core/src/op/interpolate.cpp | 31 +- src/core/src/op/topk.cpp | 114 +++--- .../template/backend/ops/interpolate.cpp | 36 +- .../functional/op_reference/avg_pool.cpp | 32 +- 141 files changed, 869 insertions(+), 1262 deletions(-) delete mode 100644 src/core/reference/include/openvino/reference/quantize.hpp rename src/core/reference/include/openvino/reference/{round_guard.hpp => rounding_guard.hpp} (84%) delete mode 100644 src/core/reference/src/op/utils/round_guard.cpp create mode 100644 src/core/reference/src/op/utils/rounding_guard.cpp diff --git a/src/core/reference/include/openvino/reference/adaptive_avg_pool.hpp b/src/core/reference/include/openvino/reference/adaptive_avg_pool.hpp index d8c6bf94430..d59eb1bf6e1 100644 --- a/src/core/reference/include/openvino/reference/adaptive_avg_pool.hpp +++ b/src/core/reference/include/openvino/reference/adaptive_avg_pool.hpp @@ -8,8 +8,7 @@ #include #include -#include "ngraph/axis_vector.hpp" -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" namespace ov { namespace reference { @@ -23,7 +22,7 @@ inline size_t window_end(size_t idx, size_t arg_shape, size_t out_shape) { } template T avg_div(const T sum, size_t n) { - NGRAPH_CHECK(n != 0, "AdaptiveAvgPool elements == 0, must be non-zero"); + OPENVINO_ASSERT(n != 0, "AdaptiveAvgPool elements == 0, must be non-zero"); if (std::is_same::value || std::is_same::value) { return static_cast(std::nearbyint(static_cast(sum) / n)); @@ -90,8 +89,8 @@ void adaptive_avg_pool_3d(const T* arg, } // namespace adaptive_pool template void adaptive_avg_pool(const T* arg, T* out, const Shape& arg_shape, const Shape& out_shape) { - NGRAPH_CHECK(arg_shape.size() == out_shape.size() && 2 < arg_shape.size() && arg_shape.size() < 6, - "AdaptiveAvgPool supports only 3D, 4D and 5D input shape"); + OPENVINO_ASSERT(arg_shape.size() == out_shape.size() && 2 < arg_shape.size() && arg_shape.size() < 6, + "AdaptiveAvgPool supports only 3D, 4D and 5D input shape"); size_t channel_size = 1; for (size_t i = 2; i < arg_shape.size(); i++) { channel_size *= arg_shape[i]; diff --git a/src/core/reference/include/openvino/reference/adaptive_max_pool.hpp b/src/core/reference/include/openvino/reference/adaptive_max_pool.hpp index 69c8ef2c940..a84ed81b47d 100644 --- a/src/core/reference/include/openvino/reference/adaptive_max_pool.hpp +++ b/src/core/reference/include/openvino/reference/adaptive_max_pool.hpp @@ -8,8 +8,7 @@ #include #include -#include "ngraph/axis_vector.hpp" -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" #include "openvino/reference/adaptive_avg_pool.hpp" namespace ov { @@ -19,7 +18,7 @@ void adaptive_max_pool_1d(const T* arg, T* out, IT* indices, size_t h_in, size_t for (size_t i = 0; i < h_out; i++) { auto from = arg + adaptive_pool::window_start(i, h_in, h_out); auto to = arg + adaptive_pool::window_end(i, h_in, h_out); - NGRAPH_CHECK(to - from != 0, "AdaptiveMaxPool elements == 0, must be non-zero"); + OPENVINO_ASSERT(to - from != 0, "AdaptiveMaxPool elements == 0, must be non-zero"); auto it = std::max_element(from, to); out[i] = static_cast(*it); indices[i] = static_cast(it - arg); @@ -33,7 +32,8 @@ void adaptive_max_pool_2d(const T* arg, T* out, IT* indices, size_t h_in, size_t for (size_t j = 0; j < w_out; j++) { size_t w_start = adaptive_pool::window_start(j, w_in, w_out); size_t w_end = adaptive_pool::window_end(j, w_in, w_out); - NGRAPH_CHECK((w_end - w_start) * (h_end - h_start) != 0, "AdaptiveMaxPool elements == 0, must be non-zero"); + OPENVINO_ASSERT((w_end - w_start) * (h_end - h_start) != 0, + "AdaptiveMaxPool elements == 0, must be non-zero"); auto result = arg + h_start * w_in + w_start; for (size_t n = h_start; n < h_end; n++) { auto from = arg + n * w_in + w_start; @@ -65,8 +65,8 @@ void adaptive_max_pool_3d(const T* arg, for (size_t k = 0; k < w_out; k++) { size_t w_start = adaptive_pool::window_start(k, w_in, w_out); size_t w_end = adaptive_pool::window_end(k, w_in, w_out); - NGRAPH_CHECK((w_end - w_start) * (h_end - h_start) != 0, - "AdaptiveMaxPool elements == 0, must be non-zero"); + OPENVINO_ASSERT((w_end - w_start) * (h_end - h_start) != 0, + "AdaptiveMaxPool elements == 0, must be non-zero"); auto result = arg + d_start * h_in * w_in + h_start * w_in + w_start; for (size_t n = d_start; n < d_end; n++) { for (size_t m = h_start; m < h_end; m++) { @@ -84,8 +84,8 @@ void adaptive_max_pool_3d(const T* arg, } template void adaptive_max_pool(const T* arg, T* out, IT* selected_indices, const Shape& arg_shape, const Shape& out_shape) { - NGRAPH_CHECK(arg_shape.size() == out_shape.size() && 2 < arg_shape.size() && arg_shape.size() < 6, - "AdaptiveAvgPool supports only 3D, 4D and 5D input shape"); + OPENVINO_ASSERT(arg_shape.size() == out_shape.size() && 2 < arg_shape.size() && arg_shape.size() < 6, + "AdaptiveAvgPool supports only 3D, 4D and 5D input shape"); size_t channel_size = 1; for (size_t i = 2; i < arg_shape.size(); i++) { channel_size *= arg_shape[i]; diff --git a/src/core/reference/include/openvino/reference/and.hpp b/src/core/reference/include/openvino/reference/and.hpp index f1b6783b285..326e4b59d77 100644 --- a/src/core/reference/include/openvino/reference/and.hpp +++ b/src/core/reference/include/openvino/reference/and.hpp @@ -6,8 +6,8 @@ #include -#include "ngraph/op/util/attr_types.hpp" -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" +#include "openvino/op/util/attr_types.hpp" #include "openvino/reference/autobroadcast_binop.hpp" namespace ov { diff --git a/src/core/reference/include/openvino/reference/autobroadcast_binop.hpp b/src/core/reference/include/openvino/reference/autobroadcast_binop.hpp index 0de0e66de9f..510d691f228 100644 --- a/src/core/reference/include/openvino/reference/autobroadcast_binop.hpp +++ b/src/core/reference/include/openvino/reference/autobroadcast_binop.hpp @@ -8,8 +8,9 @@ #include #include -#include "ngraph/op/util/attr_types.hpp" -#include "ngraph/shape_util.hpp" +#include "openvino/core/shape_util.hpp" +#include "openvino/op/util/attr_types.hpp" +#include "openvino/reference/utils/coordinate_index.hpp" #include "openvino/reference/utils/coordinate_transform.hpp" namespace ov { @@ -103,13 +104,13 @@ void autobroadcast_binop(const T* arg0, } break; case op::AutoBroadcastType::NUMPY: - // We'll be using CoordinateTransform to handle the broadcasting. The general + // We'll be using CoordinateTransformBasic to handle the broadcasting. The general // procedure is as follows: // // (1) Left pad the shorter of the two shapes with ones. // (2) Squeeze (remove ones from) both shapes, and record the squeezed axis // indices. - // (3) Using CoordinateTransform, broadcast both args to the final output + // (3) Using CoordinateTransformBasic, broadcast both args to the final output // shape. The "broadcasted axes" will be those that were squeezed in step // 2. // @@ -207,7 +208,7 @@ void autobroadcast_binop(const T* arg0, } break; case op::AutoBroadcastType::PDPD: - // We'll be using CoordinateTransform to handle the broadcasting. No need to + // We'll be using CoordinateTransformBasic to handle the broadcasting. No need to // process arg0 and output shape will be the same as arg0. We need to process // arg1 and the general procedure is as follows: // @@ -216,7 +217,7 @@ void autobroadcast_binop(const T* arg0, // to align between arg0 and arg1. // (3) Squeeze (remove ones from) arg1 shape, and record the squeezed axis // indices. - // (3) Using CoordinateTransform, broadcast arg1 to the final output + // (3) Using CoordinateTransformBasic, broadcast arg1 to the final output // shape. The "broadcasted axes" will be those that were squeezed in step // 23. // @@ -262,18 +263,15 @@ void autobroadcast_binop(const T* arg0, } } - NGRAPH_SUPPRESS_DEPRECATED_START - CoordinateTransform arg0_transform(arg0_shape); - CoordinateTransform arg1_transform(arg1_squeezed_shape); - CoordinateTransform output_transform(arg0_shape); + const CoordinateTransformBasic output_transform{arg0_shape}; for (const Coordinate& output_coord : output_transform) { - Coordinate arg1_coord = ngraph::reduce(output_coord, arg1_squeezed_axes, false); - out[output_transform.index(output_coord)] = - elementwise_functor(arg0[arg0_transform.index(output_coord)], - arg1[arg1_transform.index(arg1_coord)]); + const auto arg1_coord = util::reduce(output_coord, arg1_squeezed_axes); + const auto out_index = coordinate_index(output_coord, arg0_shape); + const auto arg0_index = coordinate_index(output_coord, arg0_shape); + const auto arg1_index = coordinate_index(arg1_coord, arg1_squeezed_shape); + out[out_index] = elementwise_functor(arg0[arg0_index], arg1[arg1_index]); } - NGRAPH_SUPPRESS_DEPRECATED_END } } } @@ -366,10 +364,7 @@ void autobroadcast_select(const U* arg0, output_shape.push_back(std::max({arg0_padded_shape[i], arg2_padded_shape[i], arg1_padded_shape[i]})); } - CoordinateTransformBasic arg0_transform(arg0_squeezed_shape); - CoordinateTransformBasic arg1_transform(arg1_squeezed_shape); - CoordinateTransformBasic arg2_transform(arg2_squeezed_shape); - CoordinateTransformBasic output_transform(output_shape); + const CoordinateTransformBasic output_transform{output_shape}; const auto arg0_strides = row_major_strides(arg0_squeezed_shape); const auto arg1_strides = row_major_strides(arg1_squeezed_shape); @@ -377,20 +372,14 @@ void autobroadcast_select(const U* arg0, const auto output_strides = row_major_strides(output_shape); for (const Coordinate& output_coord : output_transform) { - NGRAPH_SUPPRESS_DEPRECATED_START - const Coordinate arg0_coord = ngraph::reduce(output_coord, arg0_squeezed_axes, false); - const Coordinate arg1_coord = ngraph::reduce(output_coord, arg1_squeezed_axes, false); - const Coordinate arg2_coord = ngraph::reduce(output_coord, arg2_squeezed_axes, false); - NGRAPH_SUPPRESS_DEPRECATED_END + const auto arg0_coord = util::reduce(output_coord, arg0_squeezed_axes); + const auto arg1_coord = util::reduce(output_coord, arg1_squeezed_axes); + const auto arg2_coord = util::reduce(output_coord, arg2_squeezed_axes); - const size_t arg0_idx = - std::inner_product(arg0_coord.begin(), arg0_coord.end(), arg0_strides.begin(), uint64_t(0)); - const size_t arg1_idx = - std::inner_product(arg1_coord.begin(), arg1_coord.end(), arg1_strides.begin(), uint64_t(0)); - const size_t arg2_idx = - std::inner_product(arg2_coord.begin(), arg2_coord.end(), arg2_strides.begin(), uint64_t(0)); - const size_t output_idx = - std::inner_product(output_coord.begin(), output_coord.end(), output_strides.begin(), uint64_t(0)); + const size_t arg0_idx = coordinate_offset(arg0_coord, arg0_strides); + const size_t arg1_idx = coordinate_offset(arg1_coord, arg1_strides); + const size_t arg2_idx = coordinate_offset(arg2_coord, arg2_strides); + const size_t output_idx = coordinate_offset(output_coord, output_strides); out[output_idx] = elementwise_functor(arg0[arg0_idx], arg1[arg1_idx], arg2[arg2_idx]); } } @@ -446,29 +435,20 @@ void autobroadcast_select(const U* arg0, } } - CoordinateTransformBasic arg0_transform(arg0_squeezed_shape); - CoordinateTransformBasic arg1_transform(arg1_shape); - CoordinateTransformBasic arg2_transform(arg2_squeezed_shape); - CoordinateTransformBasic output_transform(arg1_shape); + const CoordinateTransformBasic output_transform{arg1_shape}; const auto arg0_strides = row_major_strides(arg0_squeezed_shape); const auto arg2_strides = row_major_strides(arg2_squeezed_shape); const auto output_strides = row_major_strides(arg1_shape); for (const Coordinate& output_coord : output_transform) { - NGRAPH_SUPPRESS_DEPRECATED_START - const Coordinate arg0_coord = ngraph::reduce(output_coord, arg0_squeezed_axes, false); - const Coordinate arg2_coord = ngraph::reduce(output_coord, arg2_squeezed_axes, false); - NGRAPH_SUPPRESS_DEPRECATED_END + const auto arg0_coord = util::reduce(output_coord, arg0_squeezed_axes); + const auto arg2_coord = util::reduce(output_coord, arg2_squeezed_axes); - const size_t arg0_idx = - std::inner_product(arg0_coord.begin(), arg0_coord.end(), arg0_strides.begin(), uint64_t(0)); - const size_t arg1_idx = - std::inner_product(output_coord.begin(), output_coord.end(), output_strides.begin(), uint64_t(0)); - const size_t arg2_idx = - std::inner_product(arg2_coord.begin(), arg2_coord.end(), arg2_strides.begin(), uint64_t(0)); - const size_t output_idx = - std::inner_product(output_coord.begin(), output_coord.end(), output_strides.begin(), uint64_t(0)); + const size_t arg0_idx = coordinate_offset(arg0_coord, arg0_strides); + const size_t arg1_idx = coordinate_offset(output_coord, output_strides); + const size_t arg2_idx = coordinate_offset(arg2_coord, arg2_strides); + const size_t output_idx = coordinate_offset(output_coord, output_strides); out[output_idx] = elementwise_functor(arg0[arg0_idx], arg1[arg1_idx], arg2[arg2_idx]); } diff --git a/src/core/reference/include/openvino/reference/avg_pool.hpp b/src/core/reference/include/openvino/reference/avg_pool.hpp index bb7bd25d933..c395ff5c52f 100644 --- a/src/core/reference/include/openvino/reference/avg_pool.hpp +++ b/src/core/reference/include/openvino/reference/avg_pool.hpp @@ -7,227 +7,168 @@ #include #include #include -#include #include -#include "ngraph/axis_vector.hpp" -#include "ngraph/shape.hpp" +#include "openvino/core/axis_vector.hpp" +#include "openvino/core/coordinate.hpp" +#include "openvino/core/shape.hpp" +#include "openvino/reference/rounding_guard.hpp" #include "openvino/reference/utils/coordinate_transform.hpp" namespace ov { namespace reference { -template -void avg_pool_backprop(const T* delta, - T* out, - const Shape& delta_shape, - const Shape& out_shape, - const Shape& window_shape, - const Strides& window_movement_strides, - const Shape& padding_below, - const Shape& padding_above, - bool include_padding_in_avg_computation) { - NGRAPH_SUPPRESS_DEPRECATED_START - CoordinateTransform out_transform(out_shape); - - for (const Coordinate& out_coord : out_transform) { - out[out_transform.index(out_coord)] = 0; - } - - CoordinateTransform delta_transform(delta_shape); - - for (const Coordinate& delta_coord : delta_transform) { - size_t img_index = delta_coord[0]; - size_t channel = delta_coord[1]; - - size_t n_image_dimensions = out_shape.size() - 2; - Coordinate source_window_transform_start(2 + n_image_dimensions); - Coordinate source_window_transform_end(2 + n_image_dimensions); - Strides source_window_transform_source_strides(2 + n_image_dimensions, 1); - AxisVector source_window_transform_source_axis_order(2 + n_image_dimensions); - CoordinateDiff source_window_transform_padding_below(2 + n_image_dimensions); - CoordinateDiff source_window_transform_padding_above(2 + n_image_dimensions); - - source_window_transform_start[0] = img_index; - source_window_transform_end[0] = img_index + 1; - source_window_transform_start[1] = channel; - source_window_transform_end[1] = channel + 1; - source_window_transform_padding_below[0] = 0; - source_window_transform_padding_below[1] = 0; - source_window_transform_padding_above[0] = 0; - source_window_transform_padding_above[1] = 0; - - for (size_t i = 2; i < n_image_dimensions + 2; i++) { - size_t window_shape_this_dim = window_shape[i - 2]; - size_t movement_stride = window_movement_strides[i - 2]; - - source_window_transform_start[i] = movement_stride * delta_coord[i]; - source_window_transform_end[i] = source_window_transform_start[i] + window_shape_this_dim; - source_window_transform_padding_below[i] = padding_below[i - 2]; - source_window_transform_padding_above[i] = padding_above[i - 2]; - } - std::iota(begin(source_window_transform_source_axis_order), end(source_window_transform_source_axis_order), 0); - - CoordinateTransform source_window_transform(out_shape, - source_window_transform_start, - source_window_transform_end, - source_window_transform_source_strides, - source_window_transform_source_axis_order, - source_window_transform_padding_below, - source_window_transform_padding_above); - - size_t num_elements_in_window = 0; - - for (const Coordinate& source_window_coord : source_window_transform) { - if (source_window_transform.has_source_coordinate(source_window_coord) || - include_padding_in_avg_computation) { - num_elements_in_window++; - } - } - - for (const Coordinate& source_window_coord : source_window_transform) { - if (source_window_transform.has_source_coordinate(source_window_coord)) { - size_t out_index = source_window_transform.index(source_window_coord); - out[out_index] += delta[delta_transform.index(delta_coord)] / num_elements_in_window; - } +namespace { +inline bool elem_in_padding_area(const Coordinate& kernel_position, + const Coordinate& kernel_offset, + const Shape& data_shape) { + for (size_t dim = 0; dim + 2 < data_shape.size(); ++dim) { + if (static_cast(kernel_position[dim]) + static_cast(kernel_offset[dim]) < 0LL || + kernel_position[dim] + kernel_offset[dim] >= data_shape[dim + 2]) { + return true; } } - NGRAPH_SUPPRESS_DEPRECATED_END + + return false; } +inline Coordinate calculate_kernel_position(const Coordinate& out_elem_coord, + const Strides& kernel_strides, + const Shape& pads_begin) { + Coordinate top_left_corner; + top_left_corner.reserve(out_elem_coord.size()); + for (size_t i = 0u; i < out_elem_coord.size(); ++i) { + top_left_corner.emplace_back(out_elem_coord[i] * kernel_strides[i] - pads_begin[i]); + } + return top_left_corner; +} + +namespace kernel { +template +void avg_pool_3d(const Values_t* data, + Values_t* out, + const Shape& data_shape, + const Shape& out_shape, + const Shape& kernel, + const Strides& kernel_strides, + const Shape& pads_begin, + const Shape& pads_end, + const bool pads_in_avg) { + // helper constants(axes) denoting dimensions in the input data shape and kernel shape + constexpr size_t data_D = 2, data_H = 3, data_W = 4; + constexpr size_t kernel_D = 0, kernel_H = 1, kernel_W = 2; + + // select max elem and its index for each "placeholder" in the out buffer (pointed to by out_idx) + size_t out_idx = 0u; + for (size_t out_channel = 0u; out_channel < out_shape[data_D]; ++out_channel) { + for (size_t out_row = 0u; out_row < out_shape[data_H]; ++out_row) { + for (size_t out_col = 0u; out_col < out_shape[data_W]; ++out_col) { + auto sum = Values_t{0}; + auto count = size_t{0}; + + const auto kernel_position = + calculate_kernel_position({out_channel, out_row, out_col}, kernel_strides, pads_begin); + + for (size_t kernel_channel = 0; kernel_channel < kernel[kernel_D]; ++kernel_channel) { + for (size_t kernel_row = 0; kernel_row < kernel[kernel_H]; ++kernel_row) { + for (size_t kernel_col = 0; kernel_col < kernel[kernel_W]; ++kernel_col) { + // offset from the top-left corner of the kernel for a given row and col + const Coordinate kernel_offset{kernel_channel, kernel_row, kernel_col}; + + const auto in_padding = elem_in_padding_area(kernel_position, kernel_offset, data_shape); + // ignore the elements in the padding area + if (!in_padding) { + // index of the flattened tensor element under the current row & column of the kernel + const size_t data_elem_index = + data_shape[data_H] * data_shape[data_W] * + (kernel_offset[kernel_D] + kernel_position[kernel_D]) + + data_shape[data_W] * (kernel_offset[kernel_H] + kernel_position[kernel_H]) + + kernel_offset[kernel_W] + kernel_position[kernel_W]; + + sum += data[data_elem_index]; + } + if (pads_in_avg || !in_padding) { + ++count; + } + } + } + } + + if (count != 0) { + if (std::is_same::value || std::is_same::value) { + out[out_idx] = static_cast(std::nearbyint(sum / count)); + } else { + out[out_idx] = sum / static_cast(count); + } + } else { + out[out_idx] = Values_t{0}; + } + ++out_idx; + } + } + } +} +} // namespace kernel +} // namespace + template -void avg_pool(const T* arg, - T* out, +void avg_pool(const T* const arg, + T* const out, const Shape& arg_shape, const Shape& out_shape, const Shape& window_shape, const Strides& window_movement_strides, const Shape& padding_below, const Shape& padding_above, - bool include_padding_in_avg_computation) { - NGRAPH_SUPPRESS_DEPRECATED_START - auto old_mode = std::fegetround(); - std::fesetround(FE_TONEAREST); - // At the outermost level we will walk over every output coordinate O. - CoordinateTransform output_transform(out_shape); + const bool include_padding_in_avg_computation) { + if (window_shape.size() > 3) + return; + const RoundingGuard rounding_g{FE_TONEAREST}; - for (const Coordinate& out_coord : output_transform) { - // Our output coordinate O will have the form: - // - // (N,chan,i_1,...,i_n) + const auto not_zero = [](size_t p) { + return p != 0; + }; + const auto pads_in_avg = + include_padding_in_avg_computation && (std::any_of(padding_below.begin(), padding_below.end(), not_zero) || + std::any_of(padding_above.begin(), padding_above.end(), not_zero)); - size_t batch_index = out_coord[0]; - size_t channel = out_coord[1]; + Shape arg_shape_3D{arg_shape}; + Shape out_shape_3D{out_shape}; + Shape window_shape_3D{window_shape}; + Strides window_movement_strides_3D{window_movement_strides}; + Shape padding_below_3D{padding_below}; + Shape padding_above_3D{padding_above}; - // For the input data we need to iterate the coordinate: - // - // I: - // - // over the range (noninclusive on the right): - // - // (N,chan,s_1*i_1,s_2*i_2,...,s_n*i_n) -> - // - // (N+1,chan+1,s_1*i_1 + window_shape_1,...,s_n*i_n + window_shape_n) - // - // with unit stride. - // - // We iterate this over the *padded* data, so below we will need to check for - // coordinates that fall in the padding area. - - size_t n_spatial_dimensions = arg_shape.size() - 2; - - Coordinate input_batch_transform_start(2 + n_spatial_dimensions); - Coordinate input_batch_transform_end(2 + n_spatial_dimensions); - Strides input_batch_transform_source_strides(2 + n_spatial_dimensions, 1); - AxisVector input_batch_transform_source_axis_order(2 + n_spatial_dimensions); - CoordinateDiff input_batch_transform_padding_below(2 + n_spatial_dimensions); - CoordinateDiff input_batch_transform_padding_above(2 + n_spatial_dimensions); - - input_batch_transform_start[0] = batch_index; - input_batch_transform_end[0] = batch_index + 1; - input_batch_transform_start[1] = channel; - input_batch_transform_end[1] = channel + 1; - input_batch_transform_padding_below[0] = 0; - input_batch_transform_padding_below[1] = 0; - input_batch_transform_padding_above[0] = 0; - input_batch_transform_padding_above[1] = 0; - - for (size_t i = 2; i < n_spatial_dimensions + 2; i++) { - size_t window_shape_this_dim = window_shape[i - 2]; - size_t movement_stride = window_movement_strides[i - 2]; - - input_batch_transform_start[i] = movement_stride * out_coord[i]; - input_batch_transform_end[i] = input_batch_transform_start[i] + window_shape_this_dim; - input_batch_transform_padding_below[i] = padding_below[i - 2]; - input_batch_transform_padding_above[i] = padding_above[i - 2]; - // If a window (kernel) is out of arg shape bounds, trim it to fit - auto padded_upper_bound = arg_shape[i] + padding_below[i - 2] + padding_above[i - 2]; - if (input_batch_transform_end[i] > padded_upper_bound) { - input_batch_transform_end[i] = padded_upper_bound; - } - } - - for (size_t i = 0; i < arg_shape.size(); i++) { - input_batch_transform_source_axis_order[i] = i; - } - - CoordinateTransform input_batch_transform(arg_shape, - input_batch_transform_start, - input_batch_transform_end, - input_batch_transform_source_strides, - input_batch_transform_source_axis_order, - input_batch_transform_padding_below, - input_batch_transform_padding_above); - - // As we go, we compute the sum value: - // - // output[O] := output[O] + arg[I] - // - // and the number of elements: - // - // n_elements := n_elements + 1 - - T result = 0; - size_t n_elements = 0; - - // The below conditions are to provide conformance between the ref and plugins: - // If exclude_padding is disabled (include_padding... enabled), then: - // The size of window doesn't change even if the window was clipped to fit the - // input, number of elements will be equal to window_size.width * - // window_size.height. The exception from this rule is if padding is not - // present, then window size is calculated each time. - - auto padding_present = - padding_below[0] != 0 || padding_below[1] != 0 || padding_above[0] != 0 || padding_above[1] != 0; - - if (include_padding_in_avg_computation && padding_present) { - n_elements = shape_size(window_shape); - } - for (const Coordinate& input_batch_coord : input_batch_transform) { - bool in_bounds = input_batch_transform.has_source_coordinate(input_batch_coord); - - if (in_bounds || include_padding_in_avg_computation) { - T v = in_bounds ? arg[input_batch_transform.index(input_batch_coord)] : static_cast(0); - result += v; - if (!padding_present || (in_bounds && !include_padding_in_avg_computation)) { - n_elements++; - } - } - } - - if (n_elements != 0) { - if (std::is_same::value || std::is_same::value) { - out[output_transform.index(out_coord)] = - static_cast(std::nearbyint(static_cast(result) / n_elements)); - } else { - out[output_transform.index(out_coord)] = result / static_cast(n_elements); - } - } else { - out[output_transform.index(out_coord)] = T{0}; - } - - std::fesetround(old_mode); + if (window_shape.size() < 3) { + const size_t dim_diff = 3 - window_shape.size(); + arg_shape_3D.insert(std::next(arg_shape_3D.begin(), 2), dim_diff, 1); + out_shape_3D.insert(std::next(out_shape_3D.begin(), 2), dim_diff, 1); + window_shape_3D.insert(window_shape_3D.begin(), dim_diff, 1); + window_movement_strides_3D.insert(window_movement_strides_3D.begin(), dim_diff, 1); + padding_below_3D.insert(padding_below_3D.begin(), dim_diff, 0); + padding_above_3D.insert(padding_above_3D.begin(), dim_diff, 0); + } + + const auto data_batch_elems = shape_size(std::begin(arg_shape) + 1, std::end(arg_shape)); + const auto data_channel_elems = shape_size(std::begin(arg_shape) + 2, std::end(arg_shape)); + + const auto out_batch_elems = shape_size(std::begin(out_shape) + 1, std::end(out_shape)); + const auto out_channel_elems = shape_size(std::begin(out_shape) + 2, std::end(out_shape)); + + for (size_t b = 0; b < arg_shape[0]; ++b) { + for (size_t c = 0; c < arg_shape[1]; ++c) { + const T* data_channel_first_elem = arg + b * data_batch_elems + c * data_channel_elems; + T* out_channel_first_elem = out + b * out_batch_elems + c * out_channel_elems; + kernel::avg_pool_3d(data_channel_first_elem, + out_channel_first_elem, + arg_shape_3D, + out_shape_3D, + window_shape_3D, + window_movement_strides_3D, + padding_below_3D, + padding_above_3D, + pads_in_avg); + } } - NGRAPH_SUPPRESS_DEPRECATED_END } } // namespace reference } // namespace ov diff --git a/src/core/reference/include/openvino/reference/batch_norm.hpp b/src/core/reference/include/openvino/reference/batch_norm.hpp index 15e10cd5c91..1050d78ecaa 100644 --- a/src/core/reference/include/openvino/reference/batch_norm.hpp +++ b/src/core/reference/include/openvino/reference/batch_norm.hpp @@ -7,7 +7,7 @@ #include #include -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" #include "openvino/reference/utils/coordinate_transform.hpp" namespace ov { @@ -26,11 +26,10 @@ void batch_norm_inference(float eps, const T* variance, T* out, const Shape& in_shape) { - NGRAPH_SUPPRESS_DEPRECATED_START auto eps_casted = static_cast(eps); size_t in_idx = 0; - CoordinateTransform in_transform(in_shape); + const CoordinateTransformBasic in_transform{in_shape}; for (Coordinate in_coord : in_transform) { auto ch_num = in_coord[1]; auto ch_gamma = gamma[ch_num]; @@ -42,7 +41,6 @@ void batch_norm_inference(float eps, out[in_idx] = normalized * ch_gamma + ch_beta; in_idx++; } - NGRAPH_SUPPRESS_DEPRECATED_END } } // namespace reference } // namespace ov diff --git a/src/core/reference/include/openvino/reference/binary_convolution.hpp b/src/core/reference/include/openvino/reference/binary_convolution.hpp index bf3e16beae2..ba13dad1a29 100644 --- a/src/core/reference/include/openvino/reference/binary_convolution.hpp +++ b/src/core/reference/include/openvino/reference/binary_convolution.hpp @@ -4,7 +4,7 @@ #pragma once -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" #include "openvino/reference/convolution.hpp" namespace ov { diff --git a/src/core/reference/include/openvino/reference/broadcast.hpp b/src/core/reference/include/openvino/reference/broadcast.hpp index abad1b8ce5e..36767f94cbd 100644 --- a/src/core/reference/include/openvino/reference/broadcast.hpp +++ b/src/core/reference/include/openvino/reference/broadcast.hpp @@ -4,8 +4,8 @@ #pragma once -#include "ngraph/axis_set.hpp" -#include "ngraph/shape.hpp" +#include "openvino/core/axis_set.hpp" +#include "openvino/core/shape.hpp" namespace ov { namespace reference { diff --git a/src/core/reference/include/openvino/reference/bucketize.hpp b/src/core/reference/include/openvino/reference/bucketize.hpp index 2288d15388d..49fd22eaf66 100644 --- a/src/core/reference/include/openvino/reference/bucketize.hpp +++ b/src/core/reference/include/openvino/reference/bucketize.hpp @@ -6,7 +6,7 @@ #include -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" namespace ov { namespace reference { diff --git a/src/core/reference/include/openvino/reference/concat.hpp b/src/core/reference/include/openvino/reference/concat.hpp index 6d210a2244b..13d4499dc85 100644 --- a/src/core/reference/include/openvino/reference/concat.hpp +++ b/src/core/reference/include/openvino/reference/concat.hpp @@ -6,7 +6,7 @@ #include -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" namespace ov { namespace reference { diff --git a/src/core/reference/include/openvino/reference/convert.hpp b/src/core/reference/include/openvino/reference/convert.hpp index a6b843bd5f0..e943e548a8f 100644 --- a/src/core/reference/include/openvino/reference/convert.hpp +++ b/src/core/reference/include/openvino/reference/convert.hpp @@ -6,8 +6,8 @@ #include -#include "ngraph/type/element_type.hpp" -#include "ngraph/type/float16.hpp" +#include "openvino/core/type/element_type.hpp" +#include "openvino/core/type/float16.hpp" namespace ov { namespace reference { @@ -123,7 +123,7 @@ size_t count_out_of_f16_range(const float* arg, size_t count); // Convert values from f32 to f16 with claming to f16 min/max when value is out of normal finite numbers range void convert_from_f32_to_f16_with_clamp(const float* arg, float16* out, size_t count); -// overload to handle ngraph::boolean (it is stored as char) +// overload to handle ov::boolean (it is stored as char) template typename std::enable_if::value>::type convert(const TI* arg, TO* out, size_t count) { for (size_t i = 0; i < count; ++i) { diff --git a/src/core/reference/include/openvino/reference/convert_color_nv12.hpp b/src/core/reference/include/openvino/reference/convert_color_nv12.hpp index a42aff6184c..110e1caf411 100644 --- a/src/core/reference/include/openvino/reference/convert_color_nv12.hpp +++ b/src/core/reference/include/openvino/reference/convert_color_nv12.hpp @@ -116,9 +116,9 @@ inline bool color_convert_nv12(const std::shared_ptr& op, static const size_t N_DIM = 0; static const size_t H_DIM = 1; static const size_t W_DIM = 2; - NGRAPH_CHECK(op->get_input_size() == 1 || op->get_input_size() == 2, - "NV12 conversion shall have one or 2 inputs, but it is ", - op->get_input_size()); + OPENVINO_ASSERT(op->get_input_size() == 1 || op->get_input_size() == 2, + "NV12 conversion shall have one or 2 inputs, but it is ", + op->get_input_size()); auto single_plane = op->get_input_size() == 1; const auto& y_tensor = inputs[0]; @@ -163,9 +163,9 @@ inline bool color_convert_i420(const std::shared_ptr& op, static const size_t N_DIM = 0; static const size_t H_DIM = 1; static const size_t W_DIM = 2; - NGRAPH_CHECK(op->get_input_size() == 1 || op->get_input_size() == 3, - "I420 conversion shall have one or 3 inputs, but it is ", - op->get_input_size()); + OPENVINO_ASSERT(op->get_input_size() == 1 || op->get_input_size() == 3, + "I420 conversion shall have one or 3 inputs, but it is ", + op->get_input_size()); auto single_plane = op->get_input_size() == 1; const auto& y_tensor = inputs[0]; diff --git a/src/core/reference/include/openvino/reference/convolution.hpp b/src/core/reference/include/openvino/reference/convolution.hpp index 6a26372befe..fb9e68fa3c3 100644 --- a/src/core/reference/include/openvino/reference/convolution.hpp +++ b/src/core/reference/include/openvino/reference/convolution.hpp @@ -6,7 +6,9 @@ #include -#include "ngraph/util.hpp" +#include "openvino/core/coordinate_diff.hpp" +#include "openvino/core/shape.hpp" +#include "openvino/core/strides.hpp" namespace ov { namespace reference { @@ -260,33 +262,33 @@ inline void validate_convolution_parameters(const Shape& in_shape, const CoordinateDiff& pads_begin, const CoordinateDiff& pads_end) { // this implementation supports 1D, 2D and 3D convolutions - NGRAPH_CHECK(in_shape.size() >= 3 && in_shape.size() <= 5, "Unsupported input rank: ", in_shape); + OPENVINO_ASSERT(in_shape.size() >= 3 && in_shape.size() <= 5, "Unsupported input rank: ", in_shape); - NGRAPH_CHECK(in_shape.size() == f_shape.size(), - "Incompatible input ranks: ", - in_shape.size(), - " and ", - f_shape.size()); + OPENVINO_ASSERT(in_shape.size() == f_shape.size(), + "Incompatible input ranks: ", + in_shape.size(), + " and ", + f_shape.size()); - NGRAPH_CHECK(in_shape[in_channel_axis] == f_shape[filter_in_ch_axis], - "Incompatible input channels in data batch and filters shapes: ", - in_shape[in_channel_axis], - " and ", - f_shape[filter_in_ch_axis]); + OPENVINO_ASSERT(in_shape[in_channel_axis] == f_shape[filter_in_ch_axis], + "Incompatible input channels in data batch and filters shapes: ", + in_shape[in_channel_axis], + " and ", + f_shape[filter_in_ch_axis]); - NGRAPH_CHECK(in_shape.size() == out_shape.size(), - "Incompatible input and output ranks: ", - in_shape.size(), - " and ", - out_shape.size()); + OPENVINO_ASSERT(in_shape.size() == out_shape.size(), + "Incompatible input and output ranks: ", + in_shape.size(), + " and ", + out_shape.size()); const auto spatial_dims = in_shape.size() - 2; - NGRAPH_CHECK(strides.size() == spatial_dims, "Strides not definied for all and only spatial dimensions"); + OPENVINO_ASSERT(strides.size() == spatial_dims, "Strides not definied for all and only spatial dimensions"); - NGRAPH_CHECK(dilations.size() == spatial_dims, "Dilations not defined for all and only spatial dimensions"); + OPENVINO_ASSERT(dilations.size() == spatial_dims, "Dilations not defined for all and only spatial dimensions"); - NGRAPH_CHECK((pads_begin.size() == pads_end.size()) && (pads_begin.size() == spatial_dims), - "Pads not defined for all and only spatial dimensions"); + OPENVINO_ASSERT((pads_begin.size() == pads_end.size()) && (pads_begin.size() == spatial_dims), + "Pads not defined for all and only spatial dimensions"); Shape out_spatial_shape{std::next(out_shape.begin(), 2), std::end(out_shape)}; Shape infered_out_spatial_shape{}; @@ -297,7 +299,7 @@ inline void validate_convolution_parameters(const Shape& in_shape, dilations, pads_begin, pads_end); - NGRAPH_CHECK(out_spatial_shape == infered_out_spatial_shape, "Incorrect output shape provided"); + OPENVINO_ASSERT(out_spatial_shape == infered_out_spatial_shape, "Incorrect output shape provided"); } } // namespace diff --git a/src/core/reference/include/openvino/reference/convolution_backprop_data.hpp b/src/core/reference/include/openvino/reference/convolution_backprop_data.hpp index 05d6e8d559a..d1491d68b5f 100644 --- a/src/core/reference/include/openvino/reference/convolution_backprop_data.hpp +++ b/src/core/reference/include/openvino/reference/convolution_backprop_data.hpp @@ -9,8 +9,6 @@ #include #include -#include "ngraph/axis_vector.hpp" -#include "ngraph/util.hpp" #include "openvino/reference/convolution.hpp" #include "openvino/reference/reverse.hpp" @@ -105,36 +103,36 @@ inline void validate_convolution_backprop_parameters(const Shape& in_shape, const CoordinateDiff& pads_end, const CoordinateDiff& output_padding) { // this implementation supports 1D, 2D and 3D convolutions - NGRAPH_CHECK(in_shape.size() >= 3 && in_shape.size() <= 5, "Unsupported input rank: ", in_shape); + OPENVINO_ASSERT(in_shape.size() >= 3 && in_shape.size() <= 5, "Unsupported input rank: ", in_shape); - NGRAPH_CHECK(in_shape.size() == f_shape.size(), - "Incompatible input ranks: ", - in_shape.size(), - " and ", - f_shape.size()); + OPENVINO_ASSERT(in_shape.size() == f_shape.size(), + "Incompatible input ranks: ", + in_shape.size(), + " and ", + f_shape.size()); - NGRAPH_CHECK(in_shape[in_channel_axis] == f_shape[filter_input_ch_axis], - "Incompatible input channels in data batch and filters shapes: ", - in_shape[in_channel_axis], - " and ", - f_shape[filter_input_ch_axis]); + OPENVINO_ASSERT(in_shape[in_channel_axis] == f_shape[filter_input_ch_axis], + "Incompatible input channels in data batch and filters shapes: ", + in_shape[in_channel_axis], + " and ", + f_shape[filter_input_ch_axis]); - NGRAPH_CHECK(in_shape.size() == out_shape.size(), - "Incompatible input and output ranks: ", - in_shape.size(), - " and ", - out_shape.size()); + OPENVINO_ASSERT(in_shape.size() == out_shape.size(), + "Incompatible input and output ranks: ", + in_shape.size(), + " and ", + out_shape.size()); const auto spatial_dims = in_shape.size() - 2; - NGRAPH_CHECK(strides.size() == spatial_dims, "Strides not definied for all and only spatial dimensions."); + OPENVINO_ASSERT(strides.size() == spatial_dims, "Strides not definied for all and only spatial dimensions."); - NGRAPH_CHECK(dilations.size() == spatial_dims, "Dilations not defined for all and only spatial dimensions."); + OPENVINO_ASSERT(dilations.size() == spatial_dims, "Dilations not defined for all and only spatial dimensions."); - NGRAPH_CHECK((pads_begin.size() == pads_end.size()) && (pads_begin.size() == spatial_dims), - "Pads not defined for all and only spatial dimensions."); + OPENVINO_ASSERT((pads_begin.size() == pads_end.size()) && (pads_begin.size() == spatial_dims), + "Pads not defined for all and only spatial dimensions."); - NGRAPH_CHECK(!output_padding.empty() && output_padding.size() == spatial_dims, - "Output padding not defined for all and only spatial dimensions."); + OPENVINO_ASSERT(!output_padding.empty() && output_padding.size() == spatial_dims, + "Output padding not defined for all and only spatial dimensions."); Shape out_spatial_shape{std::next(out_shape.begin(), 2), std::end(out_shape)}; Shape infered_out_spatial_shape{}; @@ -145,7 +143,7 @@ inline void validate_convolution_backprop_parameters(const Shape& in_shape, strides, dilations, output_padding); - NGRAPH_CHECK(out_spatial_shape == infered_out_spatial_shape, "Incorrect output shape provided"); + OPENVINO_ASSERT(out_spatial_shape == infered_out_spatial_shape, "Incorrect output shape provided"); } } // namespace diff --git a/src/core/reference/include/openvino/reference/ctc_greedy_decoder.hpp b/src/core/reference/include/openvino/reference/ctc_greedy_decoder.hpp index 1cab6887744..bb19527c659 100644 --- a/src/core/reference/include/openvino/reference/ctc_greedy_decoder.hpp +++ b/src/core/reference/include/openvino/reference/ctc_greedy_decoder.hpp @@ -8,7 +8,7 @@ #include #include -#include "openvino/reference/utils/coordinate_transform.hpp" +#include "openvino/reference/utils/coordinate_index.hpp" namespace ov { namespace reference { @@ -20,16 +20,11 @@ void ctc_greedy_decoder(const T* data, const Shape& sequence_masks_shape, const Shape& out_shape, const bool ctc_merge_repeated) { - OPENVINO_SUPPRESS_DEPRECATED_START const auto max_seq_len = data_shape[0]; const auto batch_size = data_shape[1]; const auto class_count = data_shape[2]; const uint64_t blank_index = class_count - 1; - CoordinateTransform out_transform = CoordinateTransform(out_shape); - CoordinateTransform data_transform = CoordinateTransform(data_shape); - CoordinateTransform seq_masks_transform = CoordinateTransform(sequence_masks_shape); - // final sequences don't have to fill the whole output, elements that don't store // information are set to -1 @@ -38,10 +33,10 @@ void ctc_greedy_decoder(const T* data, for (unsigned int batch_ind = 0; batch_ind < batch_size; batch_ind++) { T previous_class_index = static_cast(-1); - auto out_index = out_transform.index({batch_ind, 0, 0, 0}); + auto out_index = coordinate_index({batch_ind, 0, 0, 0}, out_shape); for (unsigned int seq_ind = 0; seq_ind < max_seq_len; seq_ind++) { - auto data_index = data_transform.index({seq_ind, batch_ind, 0}); - auto mask_index = seq_masks_transform.index({seq_ind, batch_ind}); + auto data_index = coordinate_index({seq_ind, batch_ind, 0}, data_shape); + auto mask_index = coordinate_index({seq_ind, batch_ind}, sequence_masks_shape); if (sequence_masks[mask_index] == T{0}) { break; @@ -59,7 +54,6 @@ void ctc_greedy_decoder(const T* data, } } std::copy(tmp_out.begin(), tmp_out.end(), out); - OPENVINO_SUPPRESS_DEPRECATED_END } } // namespace reference } // namespace ov diff --git a/src/core/reference/include/openvino/reference/ctc_loss.hpp b/src/core/reference/include/openvino/reference/ctc_loss.hpp index 1b16b352eb4..7e12b9abf84 100644 --- a/src/core/reference/include/openvino/reference/ctc_loss.hpp +++ b/src/core/reference/include/openvino/reference/ctc_loss.hpp @@ -4,11 +4,11 @@ #pragma once -#include - +#include +#include #include -#include "ngraph/shape_util.hpp" +#include "openvino/core/shape.hpp" namespace ov { namespace reference { diff --git a/src/core/reference/include/openvino/reference/deformable_convolution.hpp b/src/core/reference/include/openvino/reference/deformable_convolution.hpp index e3b553352c2..a6255782263 100644 --- a/src/core/reference/include/openvino/reference/deformable_convolution.hpp +++ b/src/core/reference/include/openvino/reference/deformable_convolution.hpp @@ -21,17 +21,17 @@ inline void validate_deformable_convolution_params(const Shape& in_shape, const int64_t groups, const int64_t deformable_groups) { // this implementation supports 2D deformable convolutions - NGRAPH_CHECK(in_shape.size() == 4, "Unsupported input rank: ", in_shape); - NGRAPH_CHECK(o_shape.size() == 4, "Unsupported offset rank: ", o_shape); - NGRAPH_CHECK(f_shape.size() == 4, "Unsupported kernel rank: ", f_shape); - NGRAPH_CHECK(m_shape.size() == 4, "Unsupported mask rank: ", m_shape); + OPENVINO_ASSERT(in_shape.size() == 4, "Unsupported input rank: ", in_shape); + OPENVINO_ASSERT(o_shape.size() == 4, "Unsupported offset rank: ", o_shape); + OPENVINO_ASSERT(f_shape.size() == 4, "Unsupported kernel rank: ", f_shape); + OPENVINO_ASSERT(m_shape.size() == 4, "Unsupported mask rank: ", m_shape); - NGRAPH_CHECK(in_shape[1] % groups == 0, - "Input channels of data batch input must be evenly divisible by " - "'groups' attribute"); - NGRAPH_CHECK(f_shape[0] % groups == 0, - "Output channels of filters must be evenly divisible by 'groups' " - "attribute"); + OPENVINO_ASSERT(in_shape[1] % groups == 0, + "Input channels of data batch input must be evenly divisible by " + "'groups' attribute"); + OPENVINO_ASSERT(f_shape[0] % groups == 0, + "Output channels of filters must be evenly divisible by 'groups' " + "attribute"); const Shape scaled_f_shape = [f_shape](int64_t g) { Shape shape{f_shape}; @@ -46,14 +46,15 @@ inline void validate_deformable_convolution_params(const Shape& in_shape, const Shape m_spatial_shape{std::next(m_shape.begin(), 2), std::end(m_shape)}; const Shape out_spatial_shape{std::next(out_shape.begin(), 2), std::end(out_shape)}; - NGRAPH_CHECK(o_shape[1] == deformable_groups * shape_size(f_spatial_shape) * 2, - "The channels dimension of offsets input is not " - "compatible with filters and 'deformable group' attribute"); - NGRAPH_CHECK(m_shape[1] == deformable_groups * shape_size(f_spatial_shape), - "The channels dimension of mask input is not " - "compatible with filters and 'deformable group' attribute"); - NGRAPH_CHECK(out_spatial_shape == o_spatial_shape, "Spatial dimensions of output and offsets values must be equal"); - NGRAPH_CHECK(out_spatial_shape == m_spatial_shape, "Spatial dimensions of output and mask values must be equal"); + OPENVINO_ASSERT(o_shape[1] == deformable_groups * shape_size(f_spatial_shape) * 2, + "The channels dimension of offsets input is not " + "compatible with filters and 'deformable group' attribute"); + OPENVINO_ASSERT(m_shape[1] == deformable_groups * shape_size(f_spatial_shape), + "The channels dimension of mask input is not " + "compatible with filters and 'deformable group' attribute"); + OPENVINO_ASSERT(out_spatial_shape == o_spatial_shape, + "Spatial dimensions of output and offsets values must be equal"); + OPENVINO_ASSERT(out_spatial_shape == m_spatial_shape, "Spatial dimensions of output and mask values must be equal"); } inline Shape shape_reduce(const Shape& s) { @@ -295,7 +296,7 @@ void deformable_convolution(const T* in, const int64_t deformable_groups, const bool bilinear_interpolation_pad = false) { Shape m_shape = {o_shape[0], o_shape[1] / 2, o_shape[2], o_shape[3]}; - std::vector mask(ngraph::shape_size(m_shape), 1); + std::vector mask(shape_size(m_shape), 1); deformable_convolution(in, offsets, filters, diff --git a/src/core/reference/include/openvino/reference/deformable_psroi_pooling.hpp b/src/core/reference/include/openvino/reference/deformable_psroi_pooling.hpp index 5c1d5fed7df..1d62f34f4f9 100644 --- a/src/core/reference/include/openvino/reference/deformable_psroi_pooling.hpp +++ b/src/core/reference/include/openvino/reference/deformable_psroi_pooling.hpp @@ -14,7 +14,7 @@ #include #include "clamp.hpp" -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" namespace ov { namespace reference { diff --git a/src/core/reference/include/openvino/reference/depth_to_space.hpp b/src/core/reference/include/openvino/reference/depth_to_space.hpp index d895354aad0..93b957f93ea 100644 --- a/src/core/reference/include/openvino/reference/depth_to_space.hpp +++ b/src/core/reference/include/openvino/reference/depth_to_space.hpp @@ -4,8 +4,8 @@ #pragma once -#include "ngraph/op/depth_to_space.hpp" -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" +#include "openvino/op/depth_to_space.hpp" namespace ov { namespace reference { diff --git a/src/core/reference/include/openvino/reference/detection_output.hpp b/src/core/reference/include/openvino/reference/detection_output.hpp index b55ad70916d..f5e79988ba7 100644 --- a/src/core/reference/include/openvino/reference/detection_output.hpp +++ b/src/core/reference/include/openvino/reference/detection_output.hpp @@ -6,12 +6,11 @@ #include #include -#include #include -#include "ngraph/op/detection_output.hpp" -#include "ngraph/op/util/detection_output_base.hpp" -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" +#include "openvino/op/detection_output.hpp" +#include "openvino/op/util/detection_output_base.hpp" namespace ov { namespace reference { @@ -28,7 +27,7 @@ private: }; using LabelBBox = std::map>; - ngraph::op::util::DetectionOutputBase::AttributesBase attrs; + op::util::DetectionOutputBase::AttributesBase attrs; size_t numImages; size_t priorSize; size_t numPriors; @@ -417,10 +416,10 @@ private: } public: - referenceDetectionOutput(const ngraph::op::DetectionOutputAttrs& _attrs, - const ngraph::Shape& locShape, - const ngraph::Shape& priorsShape, - const ngraph::Shape& outShape) + referenceDetectionOutput(const op::v0::DetectionOutput::Attributes& _attrs, + const Shape& locShape, + const Shape& priorsShape, + const Shape& outShape) : attrs(_attrs) { numImages = locShape[0]; priorSize = _attrs.normalized ? 4 : 5; @@ -433,11 +432,11 @@ public: outTotalSize = shape_size(outShape); } - referenceDetectionOutput(const ngraph::op::util::DetectionOutputBase::AttributesBase& _attrs, - const ngraph::Shape& locShape, - const ngraph::Shape& classPredShape, - const ngraph::Shape& priorsShape, - const ngraph::Shape& outShape) + referenceDetectionOutput(const op::util::DetectionOutputBase::AttributesBase& _attrs, + const Shape& locShape, + const Shape& classPredShape, + const Shape& priorsShape, + const Shape& outShape) : attrs(_attrs) { numImages = locShape[0]; priorSize = _attrs.normalized ? 4 : 5; diff --git a/src/core/reference/include/openvino/reference/divide.hpp b/src/core/reference/include/openvino/reference/divide.hpp index 858d8f4f696..08b75017c29 100644 --- a/src/core/reference/include/openvino/reference/divide.hpp +++ b/src/core/reference/include/openvino/reference/divide.hpp @@ -8,10 +8,10 @@ #include #include -#include "ngraph/op/util/attr_types.hpp" -#include "ngraph/shape.hpp" -#include "ngraph/type/bfloat16.hpp" -#include "ngraph/type/float16.hpp" +#include "openvino/core/shape.hpp" +#include "openvino/core/type/bfloat16.hpp" +#include "openvino/core/type/float16.hpp" +#include "openvino/op/util/attr_types.hpp" #include "openvino/reference/autobroadcast_binop.hpp" namespace ov { diff --git a/src/core/reference/include/openvino/reference/einsum.hpp b/src/core/reference/include/openvino/reference/einsum.hpp index 8e477959bcb..c1a42524d50 100644 --- a/src/core/reference/include/openvino/reference/einsum.hpp +++ b/src/core/reference/include/openvino/reference/einsum.hpp @@ -5,9 +5,8 @@ #pragma once #include -#include -#include "ngraph/shape.hpp" +#include "openvino/runtime/tensor.hpp" namespace ov { namespace reference { diff --git a/src/core/reference/include/openvino/reference/embedding_bag_offsets_sum.hpp b/src/core/reference/include/openvino/reference/embedding_bag_offsets_sum.hpp index d6b03e1fbc5..0d87538de89 100644 --- a/src/core/reference/include/openvino/reference/embedding_bag_offsets_sum.hpp +++ b/src/core/reference/include/openvino/reference/embedding_bag_offsets_sum.hpp @@ -4,7 +4,7 @@ #pragma once -#include "ngraph/shape_util.hpp" +#include "openvino/core/shape.hpp" namespace ov { namespace reference { diff --git a/src/core/reference/include/openvino/reference/embedding_bag_packed_sum.hpp b/src/core/reference/include/openvino/reference/embedding_bag_packed_sum.hpp index f16b2355b94..678b7495c73 100644 --- a/src/core/reference/include/openvino/reference/embedding_bag_packed_sum.hpp +++ b/src/core/reference/include/openvino/reference/embedding_bag_packed_sum.hpp @@ -4,7 +4,7 @@ #pragma once -#include "ngraph/shape_util.hpp" +#include "openvino/core/shape.hpp" namespace ov { namespace reference { diff --git a/src/core/reference/include/openvino/reference/embedding_segments_sum.hpp b/src/core/reference/include/openvino/reference/embedding_segments_sum.hpp index f11947fac9b..557fd248d0d 100644 --- a/src/core/reference/include/openvino/reference/embedding_segments_sum.hpp +++ b/src/core/reference/include/openvino/reference/embedding_segments_sum.hpp @@ -4,7 +4,7 @@ #pragma once -#include "ngraph/shape_util.hpp" +#include "openvino/core/shape.hpp" namespace ov { namespace reference { diff --git a/src/core/reference/include/openvino/reference/equal.hpp b/src/core/reference/include/openvino/reference/equal.hpp index 62554f9b4a2..c81d47c23d1 100644 --- a/src/core/reference/include/openvino/reference/equal.hpp +++ b/src/core/reference/include/openvino/reference/equal.hpp @@ -11,8 +11,8 @@ #include -#include "ngraph/op/util/attr_types.hpp" -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" +#include "openvino/op/util/attr_types.hpp" #include "openvino/reference/autobroadcast_binop.hpp" namespace ov { diff --git a/src/core/reference/include/openvino/reference/erf.hpp b/src/core/reference/include/openvino/reference/erf.hpp index 09ff245dd49..ea69fe98bd6 100644 --- a/src/core/reference/include/openvino/reference/erf.hpp +++ b/src/core/reference/include/openvino/reference/erf.hpp @@ -8,9 +8,6 @@ #include #include -#include "ngraph/type/bfloat16.hpp" -#include "ngraph/type/float16.hpp" - namespace ov { namespace reference { template ::value, bool>::type = true> diff --git a/src/core/reference/include/openvino/reference/experimental_detectron_detection_output.hpp b/src/core/reference/include/openvino/reference/experimental_detectron_detection_output.hpp index c2ba17605e6..52e3602897c 100644 --- a/src/core/reference/include/openvino/reference/experimental_detectron_detection_output.hpp +++ b/src/core/reference/include/openvino/reference/experimental_detectron_detection_output.hpp @@ -16,16 +16,11 @@ #pragma once -#include #include #include -#include #include -#include "ngraph/node.hpp" -#include "ngraph/op/util/op_types.hpp" -#include "ngraph/ops.hpp" -#include "ngraph/shape_util.hpp" +#include "openvino/op/experimental_detectron_detection_output.hpp" namespace ov { namespace reference { @@ -41,7 +36,7 @@ void experimental_detectron_detection_output(const float* input_rois, void experimental_detectron_detection_output_postprocessing(void* pboxes, void* pclasses, void* pscores, - const ngraph::element::Type output_type, + const element::Type output_type, const std::vector& output_boxes, const std::vector& output_classes, const std::vector& output_scores, diff --git a/src/core/reference/include/openvino/reference/experimental_detectron_prior_grid_generator.hpp b/src/core/reference/include/openvino/reference/experimental_detectron_prior_grid_generator.hpp index 565baefb5e1..c5437649d54 100644 --- a/src/core/reference/include/openvino/reference/experimental_detectron_prior_grid_generator.hpp +++ b/src/core/reference/include/openvino/reference/experimental_detectron_prior_grid_generator.hpp @@ -18,13 +18,8 @@ #include #include -#include -#include -#include "ngraph/node.hpp" -#include "ngraph/op/util/op_types.hpp" -#include "ngraph/ops.hpp" -#include "ngraph/shape_util.hpp" +#include "openvino/core/shape.hpp" namespace ov { namespace reference { diff --git a/src/core/reference/include/openvino/reference/experimental_detectron_proposal_single_image.hpp b/src/core/reference/include/openvino/reference/experimental_detectron_proposal_single_image.hpp index 4e890c051e3..68d7f7e6889 100644 --- a/src/core/reference/include/openvino/reference/experimental_detectron_proposal_single_image.hpp +++ b/src/core/reference/include/openvino/reference/experimental_detectron_proposal_single_image.hpp @@ -4,16 +4,11 @@ #pragma once -#include #include #include -#include #include -#include "ngraph/node.hpp" -#include "ngraph/op/util/op_types.hpp" -#include "ngraph/ops.hpp" -#include "ngraph/shape_util.hpp" +#include "openvino/op/experimental_detectron_generate_proposals.hpp" namespace ov { namespace reference { @@ -32,7 +27,7 @@ void experimental_detectron_proposals_single_image( void experimental_detectron_proposals_single_image_postprocessing(void* prois, void* pscores, - const ngraph::element::Type output_type, + const element::Type output_type, const std::vector& output_rois, const std::vector& output_scores, const Shape& output_rois_shape, diff --git a/src/core/reference/include/openvino/reference/experimental_detectron_roi_feature_extractor.hpp b/src/core/reference/include/openvino/reference/experimental_detectron_roi_feature_extractor.hpp index df8edf00986..80d283d2dad 100644 --- a/src/core/reference/include/openvino/reference/experimental_detectron_roi_feature_extractor.hpp +++ b/src/core/reference/include/openvino/reference/experimental_detectron_roi_feature_extractor.hpp @@ -6,13 +6,9 @@ #include #include -#include #include -#include "ngraph/node.hpp" -#include "ngraph/op/util/op_types.hpp" -#include "ngraph/ops.hpp" -#include "ngraph/shape_util.hpp" +#include "openvino/op/experimental_detectron_roi_feature.hpp" namespace ov { namespace reference { @@ -25,7 +21,7 @@ void experimental_detectron_roi_feature_extractor( void experimental_detectron_roi_feature_extractor_postprocessing(void* prois_features, void* prois, - const ngraph::element::Type output_type, + const element::Type output_type, const std::vector& output_roi_features, const std::vector& output_rois, const Shape& output_roi_features_shape, diff --git a/src/core/reference/include/openvino/reference/experimental_detectron_topk_rois.hpp b/src/core/reference/include/openvino/reference/experimental_detectron_topk_rois.hpp index 8f29fc33161..876874c2a3a 100644 --- a/src/core/reference/include/openvino/reference/experimental_detectron_topk_rois.hpp +++ b/src/core/reference/include/openvino/reference/experimental_detectron_topk_rois.hpp @@ -4,16 +4,11 @@ #pragma once -#include #include #include -#include #include -#include "ngraph/node.hpp" -#include "ngraph/op/util/op_types.hpp" -#include "ngraph/ops.hpp" -#include "ngraph/shape_util.hpp" +#include "openvino/core/shape.hpp" namespace ov { namespace reference { diff --git a/src/core/reference/include/openvino/reference/extract_image_patches.hpp b/src/core/reference/include/openvino/reference/extract_image_patches.hpp index c6130c38906..24840917856 100644 --- a/src/core/reference/include/openvino/reference/extract_image_patches.hpp +++ b/src/core/reference/include/openvino/reference/extract_image_patches.hpp @@ -2,9 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 // -#include - -#include "ngraph/shape_util.hpp" +#include "openvino/op/extractimagepatches.hpp" namespace ov { namespace reference { diff --git a/src/core/reference/include/openvino/reference/eye.hpp b/src/core/reference/include/openvino/reference/eye.hpp index 328c7f942a2..09916370315 100644 --- a/src/core/reference/include/openvino/reference/eye.hpp +++ b/src/core/reference/include/openvino/reference/eye.hpp @@ -6,7 +6,7 @@ #include -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" #include "utils/span.hpp" namespace ov { diff --git a/src/core/reference/include/openvino/reference/fake_quantize.hpp b/src/core/reference/include/openvino/reference/fake_quantize.hpp index a9122ede16e..d0828cd2308 100644 --- a/src/core/reference/include/openvino/reference/fake_quantize.hpp +++ b/src/core/reference/include/openvino/reference/fake_quantize.hpp @@ -11,10 +11,10 @@ #include #include -#include "ngraph/check.hpp" -#include "ngraph/op/util/attr_types.hpp" -#include "ngraph/shape.hpp" -#include "ngraph/shape_util.hpp" +#include "openvino/core/except.hpp" +#include "openvino/core/shape.hpp" +#include "openvino/core/shape_util.hpp" +#include "openvino/op/util/attr_types.hpp" #include "openvino/reference/utils/coordinate_transform.hpp" namespace ov { @@ -65,11 +65,11 @@ void fake_quantize(const T* const arg, out[i] = q(arg[i]); } } else { - NGRAPH_CHECK(in_low_shape.size() <= arg_shape.size() && in_high_shape.size() <= arg_shape.size() && - out_low_shape.size() <= arg_shape.size() && out_high_shape.size() <= arg_shape.size(), - "Tensors with input\\output ranges should have rank less or " - "equal to data tensor rank equal to ", - arg_shape.size()); + OPENVINO_ASSERT(in_low_shape.size() <= arg_shape.size() && in_high_shape.size() <= arg_shape.size() && + out_low_shape.size() <= arg_shape.size() && out_high_shape.size() <= arg_shape.size(), + "Tensors with input\\output ranges should have rank less or " + "equal to data tensor rank equal to ", + arg_shape.size()); Shape arg0_padded_shape = arg_shape; Shape arg1_padded_shape = in_low_shape; @@ -156,13 +156,11 @@ void fake_quantize(const T* const arg, const auto output_strides = row_major_strides(output_shape); for (const Coordinate& output_coord : output_transform) { - OPENVINO_SUPPRESS_DEPRECATED_START - const Coordinate arg0_coord = ngraph::reduce(output_coord, arg0_squeezed_axes, false); - const Coordinate arg1_coord = ngraph::reduce(output_coord, arg1_squeezed_axes, false); - const Coordinate arg2_coord = ngraph::reduce(output_coord, arg2_squeezed_axes, false); - const Coordinate arg3_coord = ngraph::reduce(output_coord, arg3_squeezed_axes, false); - const Coordinate arg4_coord = ngraph::reduce(output_coord, arg4_squeezed_axes, false); - OPENVINO_SUPPRESS_DEPRECATED_END + const auto arg0_coord = util::reduce(output_coord, arg0_squeezed_axes); + const auto arg1_coord = util::reduce(output_coord, arg1_squeezed_axes); + const auto arg2_coord = util::reduce(output_coord, arg2_squeezed_axes); + const auto arg3_coord = util::reduce(output_coord, arg3_squeezed_axes); + const auto arg4_coord = util::reduce(output_coord, arg4_squeezed_axes); const size_t arg0_idx = std::inner_product(arg0_coord.begin(), arg0_coord.end(), arg0_strides.begin(), uint64_t(0)); diff --git a/src/core/reference/include/openvino/reference/floor_mod.hpp b/src/core/reference/include/openvino/reference/floor_mod.hpp index 09add88410d..2c63b92310c 100644 --- a/src/core/reference/include/openvino/reference/floor_mod.hpp +++ b/src/core/reference/include/openvino/reference/floor_mod.hpp @@ -7,7 +7,6 @@ #include #include -#include "ngraph/shape_util.hpp" #include "openvino/reference/autobroadcast_binop.hpp" namespace ov { diff --git a/src/core/reference/include/openvino/reference/gather.hpp b/src/core/reference/include/openvino/reference/gather.hpp index 30a52889a7b..4324e0ffc5d 100644 --- a/src/core/reference/include/openvino/reference/gather.hpp +++ b/src/core/reference/include/openvino/reference/gather.hpp @@ -6,7 +6,7 @@ #include -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" #include "utils/span.hpp" namespace ov { diff --git a/src/core/reference/include/openvino/reference/gather_tree.hpp b/src/core/reference/include/openvino/reference/gather_tree.hpp index 5ac49c4c337..df0e581f0df 100644 --- a/src/core/reference/include/openvino/reference/gather_tree.hpp +++ b/src/core/reference/include/openvino/reference/gather_tree.hpp @@ -4,8 +4,8 @@ #pragma once -#include "ngraph/shape.hpp" -#include "ngraph/type/element_type.hpp" +#include "openvino/core/shape.hpp" +#include "openvino/core/type/element_type.hpp" namespace ov { namespace reference { diff --git a/src/core/reference/include/openvino/reference/gelu.hpp b/src/core/reference/include/openvino/reference/gelu.hpp index d67bcd8827d..091887c4f9d 100644 --- a/src/core/reference/include/openvino/reference/gelu.hpp +++ b/src/core/reference/include/openvino/reference/gelu.hpp @@ -4,9 +4,9 @@ #pragma once -#include #include -#include + +#include "openvino/op/gelu.hpp" namespace ov { namespace reference { diff --git a/src/core/reference/include/openvino/reference/generate_proposal.hpp b/src/core/reference/include/openvino/reference/generate_proposal.hpp index ae29efa0cc3..64bbfc6e55d 100644 --- a/src/core/reference/include/openvino/reference/generate_proposal.hpp +++ b/src/core/reference/include/openvino/reference/generate_proposal.hpp @@ -4,16 +4,11 @@ #pragma once -#include #include #include -#include #include -#include "ngraph/node.hpp" -#include "ngraph/op/util/op_types.hpp" -#include "ngraph/ops.hpp" -#include "ngraph/shape_util.hpp" +#include "openvino/op/generate_proposals.hpp" namespace ov { namespace reference { @@ -33,8 +28,8 @@ void generate_proposals(const std::vector& im_info, void generate_proposals_postprocessing(void* prois, void* pscores, void* proi_num, - const ngraph::element::Type& output_type, - const ngraph::element::Type& roi_num_type, + const element::Type& output_type, + const element::Type& roi_num_type, const std::vector& output_rois, const std::vector& output_scores, const std::vector& num_rois, diff --git a/src/core/reference/include/openvino/reference/greater.hpp b/src/core/reference/include/openvino/reference/greater.hpp index 41128fc1c5d..2dff5e6c489 100644 --- a/src/core/reference/include/openvino/reference/greater.hpp +++ b/src/core/reference/include/openvino/reference/greater.hpp @@ -6,8 +6,8 @@ #include -#include "ngraph/op/util/attr_types.hpp" -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" +#include "openvino/op/util/attr_types.hpp" #include "openvino/reference/autobroadcast_binop.hpp" namespace ov { diff --git a/src/core/reference/include/openvino/reference/greater_eq.hpp b/src/core/reference/include/openvino/reference/greater_eq.hpp index 5072604a413..79f66e3280f 100644 --- a/src/core/reference/include/openvino/reference/greater_eq.hpp +++ b/src/core/reference/include/openvino/reference/greater_eq.hpp @@ -6,8 +6,8 @@ #include -#include "ngraph/op/util/attr_types.hpp" -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" +#include "openvino/op/util/attr_types.hpp" #include "openvino/reference/autobroadcast_binop.hpp" namespace ov { diff --git a/src/core/reference/include/openvino/reference/grid_sample.hpp b/src/core/reference/include/openvino/reference/grid_sample.hpp index 9ad50767dc2..88c071538cc 100644 --- a/src/core/reference/include/openvino/reference/grid_sample.hpp +++ b/src/core/reference/include/openvino/reference/grid_sample.hpp @@ -10,7 +10,7 @@ #include #include -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" #include "openvino/op/grid_sample.hpp" namespace ov { diff --git a/src/core/reference/include/openvino/reference/group_convolution_backprop_data.hpp b/src/core/reference/include/openvino/reference/group_convolution_backprop_data.hpp index d64c63734a1..66d2a6f431d 100644 --- a/src/core/reference/include/openvino/reference/group_convolution_backprop_data.hpp +++ b/src/core/reference/include/openvino/reference/group_convolution_backprop_data.hpp @@ -4,7 +4,8 @@ #pragma once -#include "ngraph/util.hpp" +#include "openvino/core/coordinate_diff.hpp" +#include "openvino/core/strides.hpp" #include "openvino/reference/group_convolution.hpp" namespace ov { @@ -100,8 +101,8 @@ void group_convolution_backprop_data(const T* in, // DEPRECATED, can't be removed currently due to arm-plugin dependency template ::type> -NGRAPH_DEPRECATED("group_convolution_backprop_data function without output_paddings is deprecated, " - "use the one with output_padding.") +OPENVINO_DEPRECATED("group_convolution_backprop_data function without output_paddings is deprecated, " + "use the one with output_padding.") void group_convolution_backprop_data(const INPUT* in, const FILTER* f, OUTPUT* out, @@ -112,7 +113,7 @@ void group_convolution_backprop_data(const INPUT* in, const Strides& dilation, const CoordinateDiff& pads_begin, const CoordinateDiff& pads_end) { - const ngraph::CoordinateDiff output_padding(in_shape.size() - 2, 0); + const CoordinateDiff output_padding(in_shape.size() - 2, 0); group_convolution_backprop_data(in, f, diff --git a/src/core/reference/include/openvino/reference/interpolate.hpp b/src/core/reference/include/openvino/reference/interpolate.hpp index 10b6a466f22..13fb11c1620 100644 --- a/src/core/reference/include/openvino/reference/interpolate.hpp +++ b/src/core/reference/include/openvino/reference/interpolate.hpp @@ -13,16 +13,16 @@ #include #include "interpolate_pil.hpp" -#include "ngraph/op/interpolate.hpp" -#include "ngraph/shape_util.hpp" +#include "openvino/op/interpolate.hpp" +#include "openvino/reference/utils/coordinate_index.hpp" #include "openvino/reference/utils/coordinate_transform.hpp" #include "transpose.hpp" namespace ov { namespace reference { -using Nearest_mode = ngraph::op::v4::Interpolate::NearestMode; -using Transform_mode = ngraph::op::v4::Interpolate::CoordinateTransformMode; -using InterpolateMode = ngraph::op::v4::Interpolate::InterpolateMode; +using Nearest_mode = op::v4::Interpolate::NearestMode; +using Transform_mode = op::v4::Interpolate::CoordinateTransformMode; +using InterpolateMode = op::v4::Interpolate::InterpolateMode; /// \brief Calculation of nearest pixel. class GetNearestPixel final { @@ -211,7 +211,7 @@ public: float prod_a; std::vector a; std::vector r; - Shape shape_for_indeces; + Shape shape_for_indices; }; InfoForLinearMode get_info_for_linear_mode(); @@ -362,9 +362,7 @@ template void InterpolateEval::linear_func(const T* input_data, T* out) { auto info = helper.get_info_for_linear_mode(); - NGRAPH_SUPPRESS_DEPRECATED_START - CoordinateTransform output_transform(m_out_shape); - CoordinateTransform input_transform(m_input_data_shape); + const CoordinateTransformBasic output_transform{m_out_shape}; for (const Coordinate& output_coord : output_transform) { auto icoords_data = helper.get_icoords(output_coord); @@ -372,29 +370,30 @@ void InterpolateEval::linear_func(const T* input_data, T* out) { float summa = 0.0f; float wsum = 0.0f; - CoordinateTransform indices{info.shape_for_indeces}; + const CoordinateTransformBasic indices{info.shape_for_indices}; for (const auto& index : indices) { auto inner_result = helper.inner_calculation(output_coord, icoords_data, info, index); if (!inner_result.condition) { continue; } + const auto input_index = coordinate_index(inner_result.inner_coord, m_input_data_shape); wsum += inner_result.w; - summa += inner_result.w * static_cast(input_data[input_transform.index(inner_result.inner_coord)]); + summa += inner_result.w * static_cast(input_data[input_index]); } + const auto out_index = coordinate_index(output_coord, m_out_shape); if (wsum == 0.0f) { - out[output_transform.index(output_coord)] = T{}; + out[out_index] = T{}; } else { if (std::is_integral()) { // Round value for integral return types - out[output_transform.index(output_coord)] = static_cast(std::round(summa / wsum)); + out[out_index] = static_cast(std::round(summa / wsum)); } else { - out[output_transform.index(output_coord)] = static_cast(summa / wsum); + out[out_index] = static_cast(summa / wsum); } } } - NGRAPH_SUPPRESS_DEPRECATED_END } template @@ -532,9 +531,7 @@ void InterpolateEval::cubic_func(const T* input_data, T* out) { size_t input_rank = m_input_data_shape.size(); size_t num_of_axes = m_axes.size(); - NGRAPH_SUPPRESS_DEPRECATED_START - CoordinateTransform output_transform(m_out_shape); - CoordinateTransform input_transform(m_input_data_shape); + const CoordinateTransformBasic output_transform{m_out_shape}; Shape indices_shape{std::vector(num_of_axes, 4)}; for (const Coordinate& output_coord : output_transform) { @@ -551,7 +548,7 @@ void InterpolateEval::cubic_func(const T* input_data, T* out) { } float summa = 0.0f; - CoordinateTransform indices{indices_shape}; + const CoordinateTransformBasic indices{indices_shape}; for (const Coordinate& idx : indices) { auto coords_for_sum = output_coord; @@ -567,12 +564,12 @@ void InterpolateEval::cubic_func(const T* input_data, T* out) { coeffs_prod *= cubic_coeffs[axis][idx[i]]; } - summa += coeffs_prod * static_cast(input_data[input_transform.index(coords_for_sum)]); + const auto input_index = coordinate_index(coords_for_sum, m_input_data_shape); + summa += coeffs_prod * static_cast(input_data[input_index]); } - out[output_transform.index(output_coord)] = static_cast(summa); + out[coordinate_index(output_coord, m_out_shape)] = static_cast(summa); } - NGRAPH_SUPPRESS_DEPRECATED_END } template @@ -677,15 +674,14 @@ void InterpolateEval::multidim_pil_func(const T* input_data, T* out, const in template void InterpolateEval::nearest_func(const T* input_data, T* out) { - NGRAPH_SUPPRESS_DEPRECATED_START - CoordinateTransform output_transform(m_out_shape); - CoordinateTransform input_transform(m_input_data_shape); + const CoordinateTransformBasic output_transform{m_out_shape}; for (const Coordinate& output_coord : output_transform) { auto input_coord = helper.get_input_coords_for_nearest_mode(output_coord); - out[output_transform.index(output_coord)] = input_data[input_transform.index(input_coord)]; + const auto input_index = coordinate_index(input_coord, m_input_data_shape); + const auto out_index = coordinate_index(output_coord, m_out_shape); + out[out_index] = input_data[input_index]; } - NGRAPH_SUPPRESS_DEPRECATED_END } inline void pad_input_data(const uint8_t* data_ptr, @@ -694,9 +690,7 @@ inline void pad_input_data(const uint8_t* data_ptr, const ov::Shape& input_shape, const ov::Shape& padded_input_shape, const std::vector& pads_begin) { - NGRAPH_SUPPRESS_DEPRECATED_START - CoordinateTransform input_transform(input_shape); - CoordinateTransform padded_transform(padded_input_shape); + const CoordinateTransformBasic input_transform{input_shape}; for (const Coordinate& input_coord : input_transform) { auto padded_coord = input_coord; @@ -705,11 +699,10 @@ inline void pad_input_data(const uint8_t* data_ptr, padded_coord[i] += pad; ++i; } - uint8_t* dst_ptr = padded_data_ptr + type_size * padded_transform.index(padded_coord); - const uint8_t* src_ptr = data_ptr + type_size * input_transform.index(input_coord); + uint8_t* dst_ptr = padded_data_ptr + type_size * coordinate_index(padded_coord, padded_input_shape); + const uint8_t* src_ptr = data_ptr + type_size * coordinate_index(input_coord, input_shape); memcpy(dst_ptr, src_ptr, type_size); } - NGRAPH_SUPPRESS_DEPRECATED_END } inline PartialShape get_padded_input_shape(const PartialShape& input_shape, diff --git a/src/core/reference/include/openvino/reference/interpolate_pil.hpp b/src/core/reference/include/openvino/reference/interpolate_pil.hpp index 66a40c8f88c..d57875cc538 100644 --- a/src/core/reference/include/openvino/reference/interpolate_pil.hpp +++ b/src/core/reference/include/openvino/reference/interpolate_pil.hpp @@ -40,9 +40,9 @@ #include #include +#include -#include "ngraph/op/interpolate.hpp" -#include "ngraph/shape_util.hpp" +#include "openvino/core/shape.hpp" namespace ov { namespace reference { diff --git a/src/core/reference/include/openvino/reference/irdft.hpp b/src/core/reference/include/openvino/reference/irdft.hpp index a32c20e9765..0ee03fda858 100644 --- a/src/core/reference/include/openvino/reference/irdft.hpp +++ b/src/core/reference/include/openvino/reference/irdft.hpp @@ -6,7 +6,7 @@ #include -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" namespace ov { namespace reference { diff --git a/src/core/reference/include/openvino/reference/less.hpp b/src/core/reference/include/openvino/reference/less.hpp index c4519726589..21d2321f566 100644 --- a/src/core/reference/include/openvino/reference/less.hpp +++ b/src/core/reference/include/openvino/reference/less.hpp @@ -6,8 +6,8 @@ #include -#include "ngraph/op/util/attr_types.hpp" -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" +#include "openvino/op/util/attr_types.hpp" #include "openvino/reference/autobroadcast_binop.hpp" namespace ov { diff --git a/src/core/reference/include/openvino/reference/less_eq.hpp b/src/core/reference/include/openvino/reference/less_eq.hpp index 80aff2fad73..d4ab3c2775b 100644 --- a/src/core/reference/include/openvino/reference/less_eq.hpp +++ b/src/core/reference/include/openvino/reference/less_eq.hpp @@ -6,8 +6,8 @@ #include -#include "ngraph/op/util/attr_types.hpp" -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" +#include "openvino/op/util/attr_types.hpp" #include "openvino/reference/autobroadcast_binop.hpp" namespace ov { diff --git a/src/core/reference/include/openvino/reference/log_softmax.hpp b/src/core/reference/include/openvino/reference/log_softmax.hpp index 7335bd36989..710b605e850 100644 --- a/src/core/reference/include/openvino/reference/log_softmax.hpp +++ b/src/core/reference/include/openvino/reference/log_softmax.hpp @@ -6,40 +6,39 @@ #include -#include "ngraph/shape_util.hpp" +#include "openvino/core/shape_util.hpp" #include "openvino/reference/reduce_max.hpp" #include "openvino/reference/reduce_sum.hpp" +#include "openvino/reference/utils/coordinate_index.hpp" #include "openvino/reference/utils/coordinate_transform.hpp" namespace ov { namespace reference { template void log_softmax(const T* arg, T* out, const Shape& shape, const AxisSet& axes) { - NGRAPH_SUPPRESS_DEPRECATED_START - auto temp_shape = ngraph::reduce(shape, axes, true); - auto temp_elements = shape_size(temp_shape); + const auto temp_shape = util::reduce_keep_dims(shape, axes); + const auto temp_elements = shape_size(temp_shape); auto temp_max = std::vector(temp_elements, 0); auto temp_sum = std::vector(temp_elements, 0); reduce_max(arg, temp_max.data(), shape, axes); - CoordinateTransform transform(shape); - CoordinateTransform temp_transform(temp_shape); + const CoordinateTransformBasic transform{shape}; for (const Coordinate& coord : transform) { - Coordinate temp_coord = ngraph::reduce(coord, axes, true); - out[transform.index(coord)] = - static_cast(std::exp(arg[transform.index(coord)] - temp_max[temp_transform.index(temp_coord)])); + const Coordinate temp_coord = util::reduce_keep_dims(coord, axes); + const auto out_index = coordinate_index(coord, shape); + const auto temp_index = coordinate_index(temp_coord, temp_shape); + out[out_index] = static_cast(std::exp(arg[out_index] - temp_max[temp_index])); } reduce_sum(out, temp_sum.data(), shape, axes); for (const Coordinate& coord : transform) { - Coordinate temp_coord = ngraph::reduce(coord, axes, true); - out[transform.index(coord)] = - static_cast((arg[transform.index(coord)] - temp_max[temp_transform.index(temp_coord)]) - - std::log(temp_sum[temp_transform.index(temp_coord)])); + const Coordinate temp_coord = util::reduce_keep_dims(coord, axes); + const auto out_index = coordinate_index(coord, shape); + const auto temp_index = coordinate_index(temp_coord, temp_shape); + out[out_index] = static_cast((arg[out_index] - temp_max[temp_index]) - std::log(temp_sum[temp_index])); } - NGRAPH_SUPPRESS_DEPRECATED_END } } // namespace reference } // namespace ov diff --git a/src/core/reference/include/openvino/reference/logical_reduction.hpp b/src/core/reference/include/openvino/reference/logical_reduction.hpp index 97be74d1122..9005c574c4c 100644 --- a/src/core/reference/include/openvino/reference/logical_reduction.hpp +++ b/src/core/reference/include/openvino/reference/logical_reduction.hpp @@ -7,7 +7,6 @@ #include #include -#include "ngraph/shape_util.hpp" #include "openvino/core/shape_util.hpp" #include "openvino/reference/utils/coordinate_index.hpp" #include "openvino/reference/utils/coordinate_transform.hpp" diff --git a/src/core/reference/include/openvino/reference/lrn.hpp b/src/core/reference/include/openvino/reference/lrn.hpp index e3e2177ef21..b4df9e363f7 100644 --- a/src/core/reference/include/openvino/reference/lrn.hpp +++ b/src/core/reference/include/openvino/reference/lrn.hpp @@ -8,7 +8,7 @@ #include #include -#include "ngraph/util.hpp" +#include "openvino/reference/utils/coordinate_index.hpp" #include "openvino/reference/utils/coordinate_transform.hpp" namespace ov { @@ -61,7 +61,6 @@ void lrn(const T* arg, double dbeta, double dbias, size_t size) { - NGRAPH_SUPPRESS_DEPRECATED_START T alpha = static_cast(dalpha); T beta = static_cast(dbeta); T bias = static_cast(dbias); @@ -74,7 +73,7 @@ void lrn(const T* arg, axes_map[axis_coord] = true; } - CoordinateTransform input_transform(arg_shape); + const CoordinateTransformBasic input_transform{arg_shape}; for (const Coordinate& in_coord : input_transform) { // area determined by in_coord local neighborhood for (size_t i = 0; i < axes_map.size(); i++) { @@ -89,11 +88,10 @@ void lrn(const T* arg, } T square_sum = sum_region_across_axes(arg, slice_indices(arg_shape, begin_area, area_shape)); - auto index = input_transform.index(in_coord); + const auto index = coordinate_index(in_coord, arg_shape); T x = arg[index]; out[index] = x / static_cast(std::pow(bias + scale * square_sum, beta)); } - NGRAPH_SUPPRESS_DEPRECATED_END } } // namespace reference } // namespace ov diff --git a/src/core/reference/include/openvino/reference/matmul.hpp b/src/core/reference/include/openvino/reference/matmul.hpp index f05dd224f88..b4a09e0f276 100644 --- a/src/core/reference/include/openvino/reference/matmul.hpp +++ b/src/core/reference/include/openvino/reference/matmul.hpp @@ -10,7 +10,6 @@ #include #include "ngraph/runtime/opt_kernel/reshape.hpp" -#include "ngraph/shape_util.hpp" #include "openvino/reference/broadcast.hpp" namespace ov { diff --git a/src/core/reference/include/openvino/reference/matrix_nms.hpp b/src/core/reference/include/openvino/reference/matrix_nms.hpp index 4a3b5b015e7..eb2040dabc5 100644 --- a/src/core/reference/include/openvino/reference/matrix_nms.hpp +++ b/src/core/reference/include/openvino/reference/matrix_nms.hpp @@ -4,20 +4,9 @@ #pragma once -#include -#include -#include -#include #include -#include -#include -#include -#include -#include "ngraph/node.hpp" -#include "ngraph/op/matrix_nms.hpp" -#include "ngraph/op/util/op_types.hpp" -#include "ngraph/shape_util.hpp" +#include "openvino/op/matrix_nms.hpp" namespace ov { namespace reference { diff --git a/src/core/reference/include/openvino/reference/max_pool.hpp b/src/core/reference/include/openvino/reference/max_pool.hpp index eaf2dee8932..df56c4f2b26 100644 --- a/src/core/reference/include/openvino/reference/max_pool.hpp +++ b/src/core/reference/include/openvino/reference/max_pool.hpp @@ -11,105 +11,6 @@ namespace ov { namespace reference { -template -void max_pool(const T* arg, - T* out, - const Shape& arg_shape, - const Shape& out_shape, - const Shape& window_shape, - const Strides& window_movement_strides, - const Shape& padding_below, - const Shape& padding_above) { - NGRAPH_SUPPRESS_DEPRECATED_START - // At the outermost level we will walk over every output coordinate O. - CoordinateTransform output_transform(out_shape); - - for (const Coordinate& out_coord : output_transform) { - // Our output coordinate O will have the form: - // - // (N,chan,i_1,...,i_n) - - size_t batch_index = out_coord[0]; - size_t channel = out_coord[1]; - - // For the input data we need to iterate the coordinate: - // - // I: - // - // over the range (noninclusive on the right): - // - // (N,chan,s_1*i_1,s_2*i_2,...,s_n*i_n) -> - // - // (N+1,chan+1,s_1*i_1 + window_shape_1,...,s_n*i_n + window_shape_n) - // - // with unit stride. - // - // We iterate this over the *padded* data, so below we will need to check for - // coordinates that fall in the padding area. - - size_t n_spatial_dimensions = arg_shape.size() - 2; - - Coordinate input_batch_transform_start(2 + n_spatial_dimensions); - Coordinate input_batch_transform_end(2 + n_spatial_dimensions); - Strides input_batch_transform_source_strides(2 + n_spatial_dimensions, 1); - AxisVector input_batch_transform_source_axis_order(2 + n_spatial_dimensions); - CoordinateDiff input_batch_transform_padding_below(2 + n_spatial_dimensions); - CoordinateDiff input_batch_transform_padding_above(2 + n_spatial_dimensions); - - input_batch_transform_start[0] = batch_index; - input_batch_transform_end[0] = batch_index + 1; - input_batch_transform_start[1] = channel; - input_batch_transform_end[1] = channel + 1; - input_batch_transform_padding_below[0] = 0; - input_batch_transform_padding_below[1] = 0; - input_batch_transform_padding_above[0] = 0; - input_batch_transform_padding_above[1] = 0; - - for (size_t i = 2; i < n_spatial_dimensions + 2; i++) { - size_t window_shape_this_dim = window_shape[i - 2]; - size_t movement_stride = window_movement_strides[i - 2]; - - input_batch_transform_start[i] = movement_stride * out_coord[i]; - input_batch_transform_end[i] = input_batch_transform_start[i] + window_shape_this_dim; - // If a window (kernel) is out of arg shape bounds, trim it to fit - auto padded_upper_bound = arg_shape[i] + padding_below[i - 2] + padding_above[i - 2]; - if (input_batch_transform_end[i] > padded_upper_bound) { - input_batch_transform_end[i] = padded_upper_bound; - } - input_batch_transform_padding_below[i] = padding_below[i - 2]; - input_batch_transform_padding_above[i] = padding_above[i - 2]; - } - - for (size_t i = 0; i < arg_shape.size(); i++) { - input_batch_transform_source_axis_order[i] = i; - } - - CoordinateTransform input_batch_transform(arg_shape, - input_batch_transform_start, - input_batch_transform_end, - input_batch_transform_source_strides, - input_batch_transform_source_axis_order, - input_batch_transform_padding_below, - input_batch_transform_padding_above); - - // As we go, we compute the maximum value: - // - // output[O] = max(output[O],arg[I]) - - T result = std::numeric_limits::lowest(); - - for (const Coordinate& input_batch_coord : input_batch_transform) { - if (input_batch_transform.has_source_coordinate(input_batch_coord)) { - T x = arg[input_batch_transform.index(input_batch_coord)]; - result = x > result ? x : result; - } - } - - out[output_transform.index(out_coord)] = result; - } - NGRAPH_SUPPRESS_DEPRECATED_END -} - namespace { void validate_max_pool_kernel_params(const size_t dims, const Shape& kernel, @@ -117,20 +18,20 @@ void validate_max_pool_kernel_params(const size_t dims, const Strides& kernel_dilations, const Shape& pads_begin, const Shape& pads_end) { - NGRAPH_CHECK(kernel.size() == dims && kernel_strides.size() == dims && kernel_dilations.size() == dims && - pads_begin.size() == dims && pads_end.size() == dims, - "One of the MaxPool params does not match the ", - dims, - "D implementation.\nkernel=", - kernel, - "\nkernel_strides=", - kernel_strides, - "\nkernel_dilations=", - kernel_dilations, - "\npads_begin=", - pads_begin, - "\npads_end=", - pads_end); + OPENVINO_ASSERT(kernel.size() == dims && kernel_strides.size() == dims && kernel_dilations.size() == dims && + pads_begin.size() == dims && pads_end.size() == dims, + "One of the MaxPool params does not match the ", + dims, + "D implementation.\nkernel=", + kernel, + "\nkernel_strides=", + kernel_strides, + "\nkernel_dilations=", + kernel_dilations, + "\npads_begin=", + pads_begin, + "\npads_end=", + pads_end); } /// \brief A helper struct representing spatial coordinates of a tensor element. It can use signed numbers as the @@ -390,10 +291,9 @@ void max_pool(const Values_t* data, pads_end, indices_offset); } else { - NGRAPH_CHECK(false, - "Unsupported input shape ", - data_shape, - " passed to the MaxPool reference implementation. Supported shapes: 3D, 4D and 5D."); + OPENVINO_THROW("Unsupported input shape ", + data_shape, + " passed to the MaxPool reference implementation. Supported shapes: 3D, 4D and 5D."); } } } @@ -409,5 +309,28 @@ void max_pool(const Values_t* data, } } } + +template +void max_pool(const Value_t* data, + Value_t* values, + const Shape& data_shape, + const Shape& out_shape, + const Shape& kernel, + const Strides& strides, + const Shape& pads_begin, + const Shape& pads_end) { + std::vector indices(shape_size(out_shape)); + const Strides dilations(kernel.size(), 1); + max_pool(data, + values, + indices.data(), + data_shape, + out_shape, + kernel, + strides, + dilations, + pads_begin, + pads_end); +} } // namespace reference } // namespace ov diff --git a/src/core/reference/include/openvino/reference/maximum.hpp b/src/core/reference/include/openvino/reference/maximum.hpp index e918d4281b3..12388a1026c 100644 --- a/src/core/reference/include/openvino/reference/maximum.hpp +++ b/src/core/reference/include/openvino/reference/maximum.hpp @@ -6,8 +6,8 @@ #include -#include "ngraph/op/util/attr_types.hpp" -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" +#include "openvino/op/util/attr_types.hpp" #include "openvino/reference/autobroadcast_binop.hpp" namespace ov { diff --git a/src/core/reference/include/openvino/reference/minimum.hpp b/src/core/reference/include/openvino/reference/minimum.hpp index 78b8788ef80..4bfe8ff0c89 100644 --- a/src/core/reference/include/openvino/reference/minimum.hpp +++ b/src/core/reference/include/openvino/reference/minimum.hpp @@ -6,8 +6,8 @@ #include -#include "ngraph/op/util/attr_types.hpp" -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" +#include "openvino/op/util/attr_types.hpp" #include "openvino/reference/autobroadcast_binop.hpp" namespace ov { diff --git a/src/core/reference/include/openvino/reference/multiclass_nms.hpp b/src/core/reference/include/openvino/reference/multiclass_nms.hpp index 1a33c91eb55..58b67f3257d 100644 --- a/src/core/reference/include/openvino/reference/multiclass_nms.hpp +++ b/src/core/reference/include/openvino/reference/multiclass_nms.hpp @@ -4,20 +4,9 @@ #pragma once -#include -#include -#include -#include #include -#include -#include -#include -#include -#include "ngraph/node.hpp" -#include "ngraph/op/util/multiclass_nms_base.hpp" -#include "ngraph/op/util/op_types.hpp" -#include "ngraph/shape_util.hpp" +#include "openvino/op/util/multiclass_nms_base.hpp" namespace ov { namespace reference { diff --git a/src/core/reference/include/openvino/reference/multiply.hpp b/src/core/reference/include/openvino/reference/multiply.hpp index bfc1dd01e67..91d279cc693 100644 --- a/src/core/reference/include/openvino/reference/multiply.hpp +++ b/src/core/reference/include/openvino/reference/multiply.hpp @@ -6,8 +6,8 @@ #include -#include "ngraph/op/util/attr_types.hpp" -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" +#include "openvino/op/util/attr_types.hpp" #include "openvino/reference/autobroadcast_binop.hpp" namespace ov { diff --git a/src/core/reference/include/openvino/reference/mvn.hpp b/src/core/reference/include/openvino/reference/mvn.hpp index 8bcce83295e..3083b6e0d58 100644 --- a/src/core/reference/include/openvino/reference/mvn.hpp +++ b/src/core/reference/include/openvino/reference/mvn.hpp @@ -5,9 +5,9 @@ #pragma once #include -#include -#include +#include "openvino/core/shape.hpp" +#include "openvino/op/mvn.hpp" #include "openvino/reference/add.hpp" #include "openvino/reference/divide.hpp" #include "openvino/reference/multiply.hpp" @@ -18,7 +18,6 @@ namespace ov { namespace reference { -OPENVINO_SUPPRESS_DEPRECATED_START template void mvn(const T* arg, T* out, @@ -26,7 +25,7 @@ void mvn(const T* arg, const bool normalize_variance, const AxisSet& reduction_axes, const double eps) { - auto reduced_shape = ngraph::reduce(in_shape, reduction_axes, true); + auto reduced_shape = util::reduce_keep_dims(in_shape, reduction_axes); std::vector tmp_buffer(shape_size(in_shape)); reduce_mean(arg, tmp_buffer.data(), in_shape, reduction_axes); subtract(arg, tmp_buffer.data(), out, in_shape, reduced_shape, op::AutoBroadcastType::NUMPY); @@ -56,7 +55,7 @@ void mvn_6(const T* arg, bool normalize_variance, double eps, op::MVNEpsMode eps_mode) { - auto reduced_shape = ngraph::reduce(in_shape, reduction_axes, true); + auto reduced_shape = util::reduce_keep_dims(in_shape, reduction_axes); std::vector tmp_buffer(shape_size(in_shape)); reduce_mean(arg, tmp_buffer.data(), in_shape, reduction_axes); subtract(arg, tmp_buffer.data(), out, in_shape, reduced_shape, op::AutoBroadcastType::NUMPY); @@ -87,7 +86,6 @@ void mvn_6(const T* arg, divide(out, tmp_buffer.data(), out, in_shape, reduced_shape, op::AutoBroadcastType::NUMPY, true); } } -OPENVINO_SUPPRESS_DEPRECATED_END template AxisSet mvn_6_reduction_axes(const ov::Tensor& axes_input, size_t rank) { diff --git a/src/core/reference/include/openvino/reference/non_max_suppression.hpp b/src/core/reference/include/openvino/reference/non_max_suppression.hpp index b9e37e28c6a..9787e38ed47 100644 --- a/src/core/reference/include/openvino/reference/non_max_suppression.hpp +++ b/src/core/reference/include/openvino/reference/non_max_suppression.hpp @@ -4,20 +4,11 @@ #pragma once -#include -#include -#include -#include #include -#include -#include -#include #include -#include "ngraph/node.hpp" -#include "ngraph/op/util/op_types.hpp" -#include "ngraph/ops.hpp" -#include "ngraph/shape_util.hpp" +#include "openvino/core/shape.hpp" +#include "openvino/runtime/tensor.hpp" namespace ov { namespace reference { @@ -37,11 +28,11 @@ void non_max_suppression5(const float* boxes_data, const bool sort_result_descending); void nms5_postprocessing(ov::TensorVector& outputs, - const ngraph::element::Type output_type, + const element::Type output_type, const std::vector& selected_indices, const std::vector& selected_scores, int64_t valid_outputs, - const ngraph::element::Type selected_scores_type); + const element::Type selected_scores_type); void non_max_suppression(const float* boxes_data, const Shape& boxes_data_shape, @@ -59,10 +50,10 @@ void non_max_suppression(const float* boxes_data, const bool sort_result_descending); void nms_postprocessing(ov::TensorVector& outputs, - const ngraph::element::Type output_type, + const element::Type output_type, const std::vector& selected_indices, const std::vector& selected_scores, int64_t valid_outputs, - const ngraph::element::Type selected_scores_type); + const element::Type selected_scores_type); } // namespace reference } // namespace ov diff --git a/src/core/reference/include/openvino/reference/non_zero.hpp b/src/core/reference/include/openvino/reference/non_zero.hpp index 3ac5ee3c8ba..69276e37594 100644 --- a/src/core/reference/include/openvino/reference/non_zero.hpp +++ b/src/core/reference/include/openvino/reference/non_zero.hpp @@ -6,7 +6,7 @@ #include -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" namespace ov { namespace reference { diff --git a/src/core/reference/include/openvino/reference/normalize_l2.hpp b/src/core/reference/include/openvino/reference/normalize_l2.hpp index 69c0cff34fd..b1af06ea763 100644 --- a/src/core/reference/include/openvino/reference/normalize_l2.hpp +++ b/src/core/reference/include/openvino/reference/normalize_l2.hpp @@ -4,8 +4,6 @@ #pragma once -#include - #include "openvino/reference/autobroadcast_binop.hpp" #include "openvino/reference/reduce_sum.hpp" diff --git a/src/core/reference/include/openvino/reference/not_equal.hpp b/src/core/reference/include/openvino/reference/not_equal.hpp index f033bcdde70..b6b5c1a3484 100644 --- a/src/core/reference/include/openvino/reference/not_equal.hpp +++ b/src/core/reference/include/openvino/reference/not_equal.hpp @@ -11,8 +11,8 @@ #include -#include "ngraph/op/util/attr_types.hpp" -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" +#include "openvino/op/util/attr_types.hpp" #include "openvino/reference/autobroadcast_binop.hpp" namespace ov { diff --git a/src/core/reference/include/openvino/reference/one_hot.hpp b/src/core/reference/include/openvino/reference/one_hot.hpp index 8527c1281bc..abf2bd7142d 100644 --- a/src/core/reference/include/openvino/reference/one_hot.hpp +++ b/src/core/reference/include/openvino/reference/one_hot.hpp @@ -4,7 +4,7 @@ #pragma once -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" namespace ov { namespace reference { diff --git a/src/core/reference/include/openvino/reference/or.hpp b/src/core/reference/include/openvino/reference/or.hpp index 9d34d7b0978..7e821de63e3 100644 --- a/src/core/reference/include/openvino/reference/or.hpp +++ b/src/core/reference/include/openvino/reference/or.hpp @@ -6,8 +6,8 @@ #include -#include "ngraph/op/util/attr_types.hpp" -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" +#include "openvino/op/util/attr_types.hpp" #include "openvino/reference/autobroadcast_binop.hpp" namespace ov { diff --git a/src/core/reference/include/openvino/reference/pad.hpp b/src/core/reference/include/openvino/reference/pad.hpp index ce66f64b05a..27ef1a471fb 100644 --- a/src/core/reference/include/openvino/reference/pad.hpp +++ b/src/core/reference/include/openvino/reference/pad.hpp @@ -4,9 +4,9 @@ #pragma once -#include "ngraph/coordinate_diff.hpp" -#include "ngraph/op/util/attr_types.hpp" // for op::PadMode -#include "ngraph/shape.hpp" +#include "openvino/core/coordinate_diff.hpp" +#include "openvino/core/shape.hpp" +#include "openvino/op/util/attr_types.hpp" // for op::PadMode namespace ov { namespace reference { diff --git a/src/core/reference/include/openvino/reference/power.hpp b/src/core/reference/include/openvino/reference/power.hpp index 2aeb3042fcf..23b404fdb84 100644 --- a/src/core/reference/include/openvino/reference/power.hpp +++ b/src/core/reference/include/openvino/reference/power.hpp @@ -7,8 +7,8 @@ #include #include -#include "ngraph/op/util/attr_types.hpp" -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" +#include "openvino/op/util/attr_types.hpp" #include "openvino/reference/autobroadcast_binop.hpp" namespace ov { diff --git a/src/core/reference/include/openvino/reference/prelu.hpp b/src/core/reference/include/openvino/reference/prelu.hpp index 3de4744bf7e..7c3005e7e57 100644 --- a/src/core/reference/include/openvino/reference/prelu.hpp +++ b/src/core/reference/include/openvino/reference/prelu.hpp @@ -6,9 +6,9 @@ #include #include -#include -#include +#include "openvino/core/shape.hpp" +#include "openvino/op/util/attr_types.hpp" #include "openvino/reference/autobroadcast_binop.hpp" namespace ov { @@ -22,15 +22,9 @@ void prelu(const T* arg, const T* slope, T* out, const Shape& arg_shape, const S channel_slope_shape[channel_dim_idx] = slope_shape[0]; std::swap(slope_shape_tmp, channel_slope_shape); } - autobroadcast_binop(arg, - slope, - out, - arg_shape, - slope_shape_tmp, - ngraph::op::AutoBroadcastType::NUMPY, - [](T x, T y) -> T { - return x < T(0) ? T(x * y) : x; - }); + autobroadcast_binop(arg, slope, out, arg_shape, slope_shape_tmp, op::AutoBroadcastType::NUMPY, [](T x, T y) -> T { + return x < T(0) ? T(x * y) : x; + }); } } // namespace reference } // namespace ov diff --git a/src/core/reference/include/openvino/reference/prior_box.hpp b/src/core/reference/include/openvino/reference/prior_box.hpp index e4ca13ae310..57b5373e498 100644 --- a/src/core/reference/include/openvino/reference/prior_box.hpp +++ b/src/core/reference/include/openvino/reference/prior_box.hpp @@ -6,9 +6,8 @@ #include -#include "ngraph/axis_vector.hpp" -#include "ngraph/check.hpp" -#include "ngraph/op/prior_box.hpp" +#include "openvino/core/except.hpp" +#include "openvino/op/prior_box.hpp" #include "openvino/reference/utils/coordinate_transform.hpp" namespace ov { @@ -50,7 +49,7 @@ void prior_box(const T* data, } std::vector variance = attrs.variance; - NGRAPH_CHECK(variance.size() == 1 || variance.size() == 4 || variance.empty()); + OPENVINO_ASSERT(variance.size() == 1 || variance.size() == 4 || variance.empty()); if (variance.empty()) variance.push_back(0.1f); diff --git a/src/core/reference/include/openvino/reference/prior_box_clustered.hpp b/src/core/reference/include/openvino/reference/prior_box_clustered.hpp index d15d69a6757..d4b2b2f64bd 100644 --- a/src/core/reference/include/openvino/reference/prior_box_clustered.hpp +++ b/src/core/reference/include/openvino/reference/prior_box_clustered.hpp @@ -6,9 +6,8 @@ #include -#include "ngraph/axis_vector.hpp" -#include "ngraph/check.hpp" -#include "ngraph/op/prior_box_clustered.hpp" +#include "openvino/core/except.hpp" +#include "openvino/op/prior_box_clustered.hpp" #include "openvino/reference/utils/coordinate_transform.hpp" namespace ov { @@ -18,11 +17,11 @@ void prior_box_clustered(const T* data, const T* img, float* dst_data, const Shape& out_shape, - const ngraph::op::PriorBoxClusteredAttrs& attrs) { + const op::v0::PriorBoxClustered::Attributes& attrs) { size_t num_priors_ = attrs.widths.size(); auto variances = attrs.variances; - NGRAPH_CHECK(variances.size() == 1 || variances.size() == 4 || variances.empty()); + OPENVINO_ASSERT(variances.size() == 1 || variances.size() == 4 || variances.empty()); if (variances.empty()) variances.push_back(0.1f); diff --git a/src/core/reference/include/openvino/reference/proposal.hpp b/src/core/reference/include/openvino/reference/proposal.hpp index febec805131..2f25027ba36 100644 --- a/src/core/reference/include/openvino/reference/proposal.hpp +++ b/src/core/reference/include/openvino/reference/proposal.hpp @@ -2,8 +2,8 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "ngraph/op/proposal.hpp" -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" +#include "openvino/op/proposal.hpp" namespace ov { namespace reference { namespace details { diff --git a/src/core/reference/include/openvino/reference/psroi_pooling.hpp b/src/core/reference/include/openvino/reference/psroi_pooling.hpp index 482f48ea28f..e7182325648 100644 --- a/src/core/reference/include/openvino/reference/psroi_pooling.hpp +++ b/src/core/reference/include/openvino/reference/psroi_pooling.hpp @@ -7,7 +7,7 @@ #include #include -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" namespace ov { namespace reference { @@ -29,7 +29,7 @@ void psroi_pooling(const T* input, } else if (mode_str == "bilinear") { mode = BILINEAR; } else { - NGRAPH_CHECK(false, "Invalid PS ROI pooling mode: " + mode_str); + OPENVINO_ASSERT(false, "Invalid PS ROI pooling mode: " + mode_str); } size_t channels_in = input_shape[1]; size_t height = input_shape[2]; diff --git a/src/core/reference/include/openvino/reference/quantize.hpp b/src/core/reference/include/openvino/reference/quantize.hpp deleted file mode 100644 index e5333aadca5..00000000000 --- a/src/core/reference/include/openvino/reference/quantize.hpp +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (C) 2018-2023 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#pragma once - -#include "ngraph/op/quantize.hpp" -#include "ngraph/shape_util.hpp" -#include "openvino/reference/utils/coordinate_transform.hpp" - -namespace ov { -namespace reference { -template -void quantize(const REAL* input, - const REAL* scale, - const QUANT* zero_point, - QUANT* output, - const Shape& input_shape, - const Shape& scale_zero_point_shape, - const AxisSet& axes, - op::Quantize::RoundMode round_mode) { - CoordinateTransform input_transform(input_shape); - CoordinateTransform scale_zero_point_transform(scale_zero_point_shape); - - for (const Coordinate& input_coord : input_transform) { - Coordinate scale_zero_point_coord = project(input_coord, axes); - - // apply scale - REAL qvalue = - input[input_transform.index(input_coord)] / scale[scale_zero_point_transform.index(scale_zero_point_coord)]; - - // round - if (round_mode == op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_INFINITY) { - REAL abs_qvalue = std::fabs(qvalue); - REAL abs_qvalue_toward_inf = std::floor(abs_qvalue + static_cast(0.5)); - qvalue = (qvalue < static_cast(0.0)) ? -abs_qvalue_toward_inf : abs_qvalue_toward_inf; - } else if (round_mode == op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_ZERO) { - auto abs_qvalue = std::fabs(qvalue); - auto abs_qvalue_toward_zero = std::ceil(abs_qvalue - static_cast(0.5)); - qvalue = (qvalue < static_cast(0.0)) ? -abs_qvalue_toward_zero : abs_qvalue_toward_zero; - } else if (round_mode == op::Quantize::RoundMode::ROUND_NEAREST_UPWARD) { - qvalue = std::floor(qvalue + static_cast(0.5)); - } else if (round_mode == op::Quantize::RoundMode::ROUND_NEAREST_DOWNWARD) { - qvalue = std::ceil(qvalue - static_cast(0.5)); - } else if (round_mode == op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN) { - auto up_qvalue = std::floor(qvalue + static_cast(0.5)); - auto dn_qvalue = std::ceil(qvalue - static_cast(0.5)); - auto rem = std::fmod(up_qvalue, 2.0); - qvalue = (rem == 0.0) ? up_qvalue : dn_qvalue; - } else if (round_mode == op::Quantize::RoundMode::ROUND_TOWARD_INFINITY) { - auto abs_qvalue = std::fabs(qvalue); - auto abs_qvalue_toward_inf = std::ceil(abs_qvalue); - qvalue = (qvalue < static_cast(0.0)) ? -abs_qvalue_toward_inf : abs_qvalue_toward_inf; - } else if (round_mode == op::Quantize::RoundMode::ROUND_TOWARD_ZERO) { - auto abs_qvalue = std::fabs(qvalue); - auto abs_qvalue_toward_zero = std::floor(abs_qvalue); - qvalue = (qvalue < static_cast(0.0)) ? -abs_qvalue_toward_zero : abs_qvalue_toward_zero; - } else if (round_mode == op::Quantize::RoundMode::ROUND_UP) { - qvalue = std::ceil(qvalue); - } else if (round_mode == op::Quantize::RoundMode::ROUND_DOWN) { - qvalue = std::floor(qvalue); - } - - // apply zero_point - qvalue += zero_point[scale_zero_point_transform.index(scale_zero_point_coord)]; - - // clamp - qvalue = std::max(qvalue, static_cast(std::numeric_limits::min())); - qvalue = std::min(qvalue, static_cast(std::numeric_limits::max())); - - // cast - output[input_transform.index(input_coord)] = static_cast(qvalue); - } -} -} // namespace reference -} // namespace ov diff --git a/src/core/reference/include/openvino/reference/random_uniform.hpp b/src/core/reference/include/openvino/reference/random_uniform.hpp index 6f942b97dc6..35257bba4a0 100644 --- a/src/core/reference/include/openvino/reference/random_uniform.hpp +++ b/src/core/reference/include/openvino/reference/random_uniform.hpp @@ -5,9 +5,9 @@ #pragma once #include -#include -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" +#include "openvino/core/type/element_type.hpp" namespace ov { namespace reference { @@ -28,7 +28,7 @@ std::pair random_uniform(const uint64_t* out_shape, const char* max_val, char* out, const Shape& out_shape_shape, - const ngraph::element::Type& elem_type, + const element::Type& elem_type, uint64_t seed, uint64_t seed2, std::pair prev_state); diff --git a/src/core/reference/include/openvino/reference/range.hpp b/src/core/reference/include/openvino/reference/range.hpp index 99a6cb39d87..cc9cb2f643a 100644 --- a/src/core/reference/include/openvino/reference/range.hpp +++ b/src/core/reference/include/openvino/reference/range.hpp @@ -7,10 +7,8 @@ #include #include -#include "ngraph/axis_vector.hpp" -#include "ngraph/check.hpp" -#include "ngraph/type/bfloat16.hpp" -#include "ngraph/type/float16.hpp" +#include "openvino/core/type/bfloat16.hpp" +#include "openvino/core/type/float16.hpp" #include "openvino/reference/utils/coordinate_transform.hpp" namespace ov { diff --git a/src/core/reference/include/openvino/reference/rdft.hpp b/src/core/reference/include/openvino/reference/rdft.hpp index ecfdae9585f..5abc2c7bcfb 100644 --- a/src/core/reference/include/openvino/reference/rdft.hpp +++ b/src/core/reference/include/openvino/reference/rdft.hpp @@ -17,13 +17,9 @@ #pragma once #include -#include #include -#include "ngraph/node.hpp" -#include "ngraph/op/util/op_types.hpp" -#include "ngraph/ops.hpp" -#include "ngraph/shape_util.hpp" +#include "openvino/core/shape.hpp" namespace ov { namespace reference { diff --git a/src/core/reference/include/openvino/reference/region_yolo.hpp b/src/core/reference/include/openvino/reference/region_yolo.hpp index f510c683db0..58a110c4429 100644 --- a/src/core/reference/include/openvino/reference/region_yolo.hpp +++ b/src/core/reference/include/openvino/reference/region_yolo.hpp @@ -7,7 +7,7 @@ #include #include -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" namespace ov { namespace reference { @@ -57,7 +57,7 @@ void region_yolo(const T* input, const int regions, const bool do_softmax, const std::vector& mask) { - NGRAPH_CHECK(input_shape.size() == 4); + OPENVINO_ASSERT(input_shape.size() == 4); const int batches = static_cast(input_shape[0]); const int height = static_cast(input_shape[2]); diff --git a/src/core/reference/include/openvino/reference/reorg_yolo.hpp b/src/core/reference/include/openvino/reference/reorg_yolo.hpp index 2678a4e82e3..64e5f2180a8 100644 --- a/src/core/reference/include/openvino/reference/reorg_yolo.hpp +++ b/src/core/reference/include/openvino/reference/reorg_yolo.hpp @@ -7,7 +7,7 @@ #include #include -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" namespace ov { namespace reference { diff --git a/src/core/reference/include/openvino/reference/reshape.hpp b/src/core/reference/include/openvino/reference/reshape.hpp index 3d769475352..b3cdd12df47 100644 --- a/src/core/reference/include/openvino/reference/reshape.hpp +++ b/src/core/reference/include/openvino/reference/reshape.hpp @@ -4,8 +4,8 @@ #pragma once -#include "ngraph/axis_vector.hpp" -#include "ngraph/shape.hpp" +#include "openvino/core/axis_vector.hpp" +#include "openvino/core/shape.hpp" namespace ov { namespace reference { diff --git a/src/core/reference/include/openvino/reference/result.hpp b/src/core/reference/include/openvino/reference/result.hpp index ca71bab4dcb..9b59302fb0c 100644 --- a/src/core/reference/include/openvino/reference/result.hpp +++ b/src/core/reference/include/openvino/reference/result.hpp @@ -9,8 +9,6 @@ #include #include -#include "ngraph/shape.hpp" - namespace ov { namespace reference { template diff --git a/src/core/reference/include/openvino/reference/reverse_sequence.hpp b/src/core/reference/include/openvino/reference/reverse_sequence.hpp index 6a01bc9303e..07f7f6f68af 100644 --- a/src/core/reference/include/openvino/reference/reverse_sequence.hpp +++ b/src/core/reference/include/openvino/reference/reverse_sequence.hpp @@ -7,7 +7,6 @@ #include #include -#include "ngraph/util.hpp" #include "openvino/reference/utils/coordinate_transform.hpp" namespace ov { @@ -25,8 +24,8 @@ void reverse_sequence(const T* arg, size_t batch_index = in_coord[batch_axis]; auto orig_seq_index = static_cast(sequence_lengths[batch_index]); - NGRAPH_CHECK(orig_seq_index <= arg_shape.at(sequence_axis), - "One of the elements of sequence lengths is greater than sequence axis dimension"); + OPENVINO_ASSERT(orig_seq_index <= arg_shape.at(sequence_axis), + "One of the elements of sequence lengths is greater than sequence axis dimension"); if (orig_seq_index == 0) { orig_seq_index = 1; diff --git a/src/core/reference/include/openvino/reference/roi_align.hpp b/src/core/reference/include/openvino/reference/roi_align.hpp index f7f9d1bc791..31eca09ebe4 100644 --- a/src/core/reference/include/openvino/reference/roi_align.hpp +++ b/src/core/reference/include/openvino/reference/roi_align.hpp @@ -6,9 +6,11 @@ #include -#include "ngraph/op/roi_align.hpp" // for ROIAlign:PoolingMode -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" +#include "openvino/op/roi_align.hpp" +#include "openvino/reference/utils/coordinate_index.hpp" #include "openvino/reference/utils/coordinate_transform.hpp" + namespace ov { namespace reference { using ROIPoolingMode = op::v3::ROIAlign::PoolingMode; @@ -33,11 +35,6 @@ void roi_align(const T* feature_maps, auto feature_map_width = feature_maps_shape[3]; auto num_rois = rois_shape[0]; - NGRAPH_SUPPRESS_DEPRECATED_START - CoordinateTransform feature_maps_transform(feature_maps_shape); - CoordinateTransform rois_transform(rois_shape); - CoordinateTransform out_transform(out_shape); - bool aligned = false; T offset_src = static_cast(0); T offset_dst = static_cast(0); @@ -64,10 +61,10 @@ void roi_align(const T* feature_maps, for (unsigned int roi_index = 0; roi_index < num_rois; roi_index++) { // Get ROI`s corners - T x1 = (rois[rois_transform.index({roi_index, 0})] + offset_src) * spatial_scale + offset_dst; - T y1 = (rois[rois_transform.index({roi_index, 1})] + offset_src) * spatial_scale + offset_dst; - T x2 = (rois[rois_transform.index({roi_index, 2})] + offset_src) * spatial_scale + offset_dst; - T y2 = (rois[rois_transform.index({roi_index, 3})] + offset_src) * spatial_scale + offset_dst; + T x1 = (rois[coordinate_index({roi_index, 0}, rois_shape)] + offset_src) * spatial_scale + offset_dst; + T y1 = (rois[coordinate_index({roi_index, 1}, rois_shape)] + offset_src) * spatial_scale + offset_dst; + T x2 = (rois[coordinate_index({roi_index, 2}, rois_shape)] + offset_src) * spatial_scale + offset_dst; + T y2 = (rois[coordinate_index({roi_index, 3}, rois_shape)] + offset_src) * spatial_scale + offset_dst; T roi_width = x2 - x1; T roi_height = y2 - y1; @@ -83,7 +80,7 @@ void roi_align(const T* feature_maps, auto sampling_ratio_x = sampling_ratio == 0 ? static_cast(ceil(bin_width)) : sampling_ratio; auto sampling_ratio_y = sampling_ratio == 0 ? static_cast(ceil(bin_height)) : sampling_ratio; - NGRAPH_CHECK(sampling_ratio_x >= 0 && sampling_ratio_y >= 0); + OPENVINO_ASSERT(sampling_ratio_x >= 0 && sampling_ratio_y >= 0); uint64_t num_samples_in_bin = static_cast(sampling_ratio_x) * static_cast(sampling_ratio_y); @@ -169,26 +166,27 @@ void roi_align(const T* feature_maps, // the four parts are values of the four closest surrounding // neighbours of considered sample, then basing on all sampled // values in bin we calculate pooled value - auto sample_part_1 = feature_maps[feature_maps_transform.index( - {static_cast(batch_indices[roi_index]), - channel_index, - pooling_points[sample_index].first, - pooling_points[sample_index].second})]; - auto sample_part_2 = feature_maps[feature_maps_transform.index( - {static_cast(batch_indices[roi_index]), - channel_index, - pooling_points[sample_index + 1].first, - pooling_points[sample_index + 1].second})]; - auto sample_part_3 = feature_maps[feature_maps_transform.index( - {static_cast(batch_indices[roi_index]), - channel_index, - pooling_points[sample_index + 2].first, - pooling_points[sample_index + 2].second})]; - auto sample_part_4 = feature_maps[feature_maps_transform.index( - {static_cast(batch_indices[roi_index]), - channel_index, - pooling_points[sample_index + 3].first, - pooling_points[sample_index + 3].second})]; + const auto batch_index = static_cast(batch_indices[roi_index]); + auto sample_part_1 = feature_maps[coordinate_index({batch_index, + channel_index, + pooling_points[sample_index].first, + pooling_points[sample_index].second}, + feature_maps_shape)]; + auto sample_part_2 = feature_maps[coordinate_index({batch_index, + channel_index, + pooling_points[sample_index + 1].first, + pooling_points[sample_index + 1].second}, + feature_maps_shape)]; + auto sample_part_3 = feature_maps[coordinate_index({batch_index, + channel_index, + pooling_points[sample_index + 2].first, + pooling_points[sample_index + 2].second}, + feature_maps_shape)]; + auto sample_part_4 = feature_maps[coordinate_index({batch_index, + channel_index, + pooling_points[sample_index + 3].first, + pooling_points[sample_index + 3].second}, + feature_maps_shape)]; T sample_value = pooling_weights[sample_index] * sample_part_1 + pooling_weights[sample_index + 1] * sample_part_2 + @@ -210,17 +208,12 @@ void roi_align(const T* feature_maps, } } // save the calculations for all bins across this channel - auto output_channel_offset = out_transform.index({static_cast(roi_index), - static_cast(channel_index), - static_cast(0), - static_cast(0)}); + auto output_channel_offset = coordinate_index({roi_index, channel_index, 0ul, 0ul}, out_shape); std::copy(tmp_out.begin(), tmp_out.end(), out + output_channel_offset); tmp_out.clear(); } } - NGRAPH_SUPPRESS_DEPRECATED_END - return; } } // namespace reference } // namespace ov diff --git a/src/core/reference/include/openvino/reference/roi_pooling.hpp b/src/core/reference/include/openvino/reference/roi_pooling.hpp index 5dbe13d1de5..02247ee2fa5 100644 --- a/src/core/reference/include/openvino/reference/roi_pooling.hpp +++ b/src/core/reference/include/openvino/reference/roi_pooling.hpp @@ -7,7 +7,7 @@ #include #include -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" namespace ov { namespace reference { @@ -42,7 +42,7 @@ void roi_pooling(const T* feature_maps, int roi_batch_id = static_cast(rois[roi_idx + 0]); // ROI batch id must be in the range of [0, N-1] - NGRAPH_CHECK(0 <= roi_batch_id && roi_batch_id < batches, "ROI batch id must be in the range of [0, N-1]"); + OPENVINO_ASSERT(0 <= roi_batch_id && roi_batch_id < batches, "ROI batch id must be in the range of [0, N-1]"); if (pooling_method == "max") { // ROI coordinates scaled to input feature maps diff --git a/src/core/reference/include/openvino/reference/roll.hpp b/src/core/reference/include/openvino/reference/roll.hpp index 16b50bc32f6..dfe33c00ffa 100644 --- a/src/core/reference/include/openvino/reference/roll.hpp +++ b/src/core/reference/include/openvino/reference/roll.hpp @@ -16,7 +16,7 @@ #pragma once -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" #include "openvino/reference/utils/coordinate_transform.hpp" namespace ov { diff --git a/src/core/reference/include/openvino/reference/round.hpp b/src/core/reference/include/openvino/reference/round.hpp index b5c1ee87f55..f3080c4edcd 100644 --- a/src/core/reference/include/openvino/reference/round.hpp +++ b/src/core/reference/include/openvino/reference/round.hpp @@ -9,7 +9,7 @@ #include #include "openvino/op/round.hpp" -#include "openvino/reference/round_guard.hpp" +#include "openvino/reference/rounding_guard.hpp" #include "openvino/reference/utils/type_util.hpp" namespace ov { @@ -50,7 +50,7 @@ T round_half_away_zero(T value) { */ template ()>::type* = nullptr> void round(const T* arg, T* out, const size_t count, const op::v5::Round::RoundMode mode) { - const ov::RoundGuard round_g{FE_TONEAREST}; + const ov::RoundingGuard round_g{FE_TONEAREST}; const auto round_algo = (mode == op::v5::Round::RoundMode::HALF_TO_EVEN) ? round_to_nearest_even : round_half_away_zero; diff --git a/src/core/reference/include/openvino/reference/round_guard.hpp b/src/core/reference/include/openvino/reference/rounding_guard.hpp similarity index 84% rename from src/core/reference/include/openvino/reference/round_guard.hpp rename to src/core/reference/include/openvino/reference/rounding_guard.hpp index cfccdc01b7a..4c11b2637ae 100644 --- a/src/core/reference/include/openvino/reference/round_guard.hpp +++ b/src/core/reference/include/openvino/reference/rounding_guard.hpp @@ -18,10 +18,10 @@ namespace ov { * - FE_UPWARD * see std header for details. */ -class RoundGuard { +class RoundingGuard { public: - RoundGuard(int mode); - ~RoundGuard(); + RoundingGuard(int mode); + ~RoundingGuard(); private: int m_prev_round_mode; diff --git a/src/core/reference/include/openvino/reference/scatter_elements_update.hpp b/src/core/reference/include/openvino/reference/scatter_elements_update.hpp index 0262db0a1ce..3fd38f06600 100644 --- a/src/core/reference/include/openvino/reference/scatter_elements_update.hpp +++ b/src/core/reference/include/openvino/reference/scatter_elements_update.hpp @@ -8,8 +8,8 @@ #include #include -#include "ngraph/check.hpp" -#include "ngraph/shape.hpp" +#include "openvino/core/except.hpp" +#include "openvino/core/shape.hpp" #include "openvino/op/scatter_elements_update.hpp" #include "openvino/reference/utils/coordinate_transform.hpp" diff --git a/src/core/reference/include/openvino/reference/scatter_nd_update.hpp b/src/core/reference/include/openvino/reference/scatter_nd_update.hpp index f4c5821dac0..ff63313823b 100644 --- a/src/core/reference/include/openvino/reference/scatter_nd_update.hpp +++ b/src/core/reference/include/openvino/reference/scatter_nd_update.hpp @@ -7,8 +7,7 @@ #include #include -#include "ngraph/coordinate.hpp" -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" #include "utils/span.hpp" namespace ov { diff --git a/src/core/reference/include/openvino/reference/scatter_update.hpp b/src/core/reference/include/openvino/reference/scatter_update.hpp index 90a1b50b1b7..07dab6be32b 100644 --- a/src/core/reference/include/openvino/reference/scatter_update.hpp +++ b/src/core/reference/include/openvino/reference/scatter_update.hpp @@ -6,8 +6,7 @@ #include -#include "ngraph/check.hpp" -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" #include "openvino/reference/utils/coordinate_transform.hpp" #include "openvino/util/common_util.hpp" @@ -22,12 +21,10 @@ static const CoordinateTransformBasic get_target_shape(const Shape& data_shape, AxisVector axis_order(m_n_axes); std::iota(axis_order.begin(), axis_order.end(), 0); const Strides strides(m_n_axes, 1); - OPENVINO_SUPPRESS_DEPRECATED_START for (size_t axis = 0; axis < m_n_axes; axis++) { target_shape.push_back( util::ceil_div(end_corner[axis_order[axis]] - start_corner[axis_order[axis]], strides[axis_order[axis]])); } - OPENVINO_SUPPRESS_DEPRECATED_END return target_shape; } diff --git a/src/core/reference/include/openvino/reference/sequences.hpp b/src/core/reference/include/openvino/reference/sequences.hpp index c466ec5d559..cbbd91655b1 100644 --- a/src/core/reference/include/openvino/reference/sequences.hpp +++ b/src/core/reference/include/openvino/reference/sequences.hpp @@ -51,7 +51,7 @@ void cell_pass(CellType type, return new_shape; }; - size_t x_shape_size = ngraph::shape_size(shapes[0]); + size_t x_shape_size = shape_size(shapes[0]); // split X size_t num_splits = shapes[0].at(1); @@ -90,7 +90,7 @@ void cell_pass(CellType type, // split A std::vector a_seqs; if (type == CellType::AUGRU) { - const auto a_shape_size = ngraph::shape_size(shapes[6]); + const auto a_shape_size = shape_size(shapes[6]); a_seqs.resize(a_shape_size * sizeof(T)); std::vector a_pointers(num_splits); for (size_t i = 0; i < num_splits; ++i) { @@ -100,18 +100,18 @@ void cell_pass(CellType type, } Shape part_shape{batch, 1, hidden_size}; - size_t part_shape_size = ngraph::shape_size(part_shape); + size_t part_shape_size = shape_size(part_shape); std::vector> h_list(num_splits, std::vector(part_shape_size * sizeof(T), 0)); std::vector> c_list(num_splits, std::vector(part_shape_size * sizeof(T), 0)); // use outputs as a buffer for temporarily values char* H_i = outputs[1]; - std::memcpy(H_i, inputs[2], ngraph::shape_size(shapes[2]) * sizeof(T)); + std::memcpy(H_i, inputs[2], shape_size(shapes[2]) * sizeof(T)); char* C_i = nullptr; // LSTMCell only if ((type == CellType::LSTM) || (type == CellType::LSTM_v1)) { C_i = outputs[2]; - std::memcpy(C_i, inputs[3], ngraph::shape_size(shapes[3]) * sizeof(T)); + std::memcpy(C_i, inputs[3], shape_size(shapes[3]) * sizeof(T)); } for (size_t time_step = 0; time_step < num_splits; ++time_step) { @@ -310,11 +310,11 @@ void lstm_sequence(const char* X, } else if (direction == op::RecurrentSequenceDirection::BIDIRECTIONAL) { // Split bidirectional case to forward + reverse passes. // split inputs - std::vector> H_split(2, std::vector(sizeof(T) * ngraph::shape_size(H_shape) / 2)); - std::vector> C_split(2, std::vector(sizeof(T) * ngraph::shape_size(C_shape) / 2)); - std::vector> W_split(2, std::vector(sizeof(T) * ngraph::shape_size(W_shape) / 2)); - std::vector> R_split(2, std::vector(sizeof(T) * ngraph::shape_size(R_shape) / 2)); - std::vector> B_split(2, std::vector(sizeof(T) * ngraph::shape_size(B_shape) / 2)); + std::vector> H_split(2, std::vector(sizeof(T) * shape_size(H_shape) / 2)); + std::vector> C_split(2, std::vector(sizeof(T) * shape_size(C_shape) / 2)); + std::vector> W_split(2, std::vector(sizeof(T) * shape_size(W_shape) / 2)); + std::vector> R_split(2, std::vector(sizeof(T) * shape_size(R_shape) / 2)); + std::vector> B_split(2, std::vector(sizeof(T) * shape_size(B_shape) / 2)); char* h_pointers[2] = {H_split[0].data(), H_split[1].data()}; char* c_pointers[2] = {C_split[0].data(), C_split[1].data()}; char* w_pointers[2] = {W_split[0].data(), W_split[1].data()}; @@ -428,12 +428,12 @@ void lstm_sequence_v1(const char* X, } else if (direction == op::RecurrentSequenceDirection::BIDIRECTIONAL) { // Split bidirectional case to forward + reverse passes. // split inputs - std::vector> H_split(2, std::vector(sizeof(T) * ngraph::shape_size(H_shape) / 2)); - std::vector> C_split(2, std::vector(sizeof(T) * ngraph::shape_size(C_shape) / 2)); - std::vector> W_split(2, std::vector(sizeof(T) * ngraph::shape_size(W_shape) / 2)); - std::vector> R_split(2, std::vector(sizeof(T) * ngraph::shape_size(R_shape) / 2)); - std::vector> B_split(2, std::vector(sizeof(T) * ngraph::shape_size(B_shape) / 2)); - std::vector> P_split(2, std::vector(sizeof(T) * ngraph::shape_size(P_shape) / 2)); + std::vector> H_split(2, std::vector(sizeof(T) * shape_size(H_shape) / 2)); + std::vector> C_split(2, std::vector(sizeof(T) * shape_size(C_shape) / 2)); + std::vector> W_split(2, std::vector(sizeof(T) * shape_size(W_shape) / 2)); + std::vector> R_split(2, std::vector(sizeof(T) * shape_size(R_shape) / 2)); + std::vector> B_split(2, std::vector(sizeof(T) * shape_size(B_shape) / 2)); + std::vector> P_split(2, std::vector(sizeof(T) * shape_size(P_shape) / 2)); char* h_pointers[2] = {H_split[0].data(), H_split[1].data()}; char* c_pointers[2] = {C_split[0].data(), C_split[1].data()}; char* w_pointers[2] = {W_split[0].data(), W_split[1].data()}; @@ -554,10 +554,10 @@ void gru_sequence(const char* X, } else if (direction == op::RecurrentSequenceDirection::BIDIRECTIONAL) { // Split bidirectional case to forward + reverse passes. // split inputs - std::vector> H_split(2, std::vector(sizeof(T) * ngraph::shape_size(H_shape) / 2)); - std::vector> W_split(2, std::vector(sizeof(T) * ngraph::shape_size(W_shape) / 2)); - std::vector> R_split(2, std::vector(sizeof(T) * ngraph::shape_size(R_shape) / 2)); - std::vector> B_split(2, std::vector(sizeof(T) * ngraph::shape_size(B_shape) / 2)); + std::vector> H_split(2, std::vector(sizeof(T) * shape_size(H_shape) / 2)); + std::vector> W_split(2, std::vector(sizeof(T) * shape_size(W_shape) / 2)); + std::vector> R_split(2, std::vector(sizeof(T) * shape_size(R_shape) / 2)); + std::vector> B_split(2, std::vector(sizeof(T) * shape_size(B_shape) / 2)); char* h_pointers[2] = {H_split[0].data(), H_split[1].data()}; char* w_pointers[2] = {W_split[0].data(), W_split[1].data()}; char* r_pointers[2] = {R_split[0].data(), R_split[1].data()}; @@ -645,10 +645,10 @@ void rnn_sequence(const char* X, } else if (direction == op::RecurrentSequenceDirection::BIDIRECTIONAL) { // Split bidirectional case to forward + reverse passes. // split inputs - std::vector> H_split(2, std::vector(sizeof(T) * ngraph::shape_size(H_shape) / 2)); - std::vector> W_split(2, std::vector(sizeof(T) * ngraph::shape_size(W_shape) / 2)); - std::vector> R_split(2, std::vector(sizeof(T) * ngraph::shape_size(R_shape) / 2)); - std::vector> B_split(2, std::vector(sizeof(T) * ngraph::shape_size(B_shape) / 2)); + std::vector> H_split(2, std::vector(sizeof(T) * shape_size(H_shape) / 2)); + std::vector> W_split(2, std::vector(sizeof(T) * shape_size(W_shape) / 2)); + std::vector> R_split(2, std::vector(sizeof(T) * shape_size(R_shape) / 2)); + std::vector> B_split(2, std::vector(sizeof(T) * shape_size(B_shape) / 2)); char* h_pointers[2] = {H_split[0].data(), H_split[1].data()}; char* w_pointers[2] = {W_split[0].data(), W_split[1].data()}; char* r_pointers[2] = {R_split[0].data(), R_split[1].data()}; diff --git a/src/core/reference/include/openvino/reference/shape_of.hpp b/src/core/reference/include/openvino/reference/shape_of.hpp index 940c236ec1f..b2b6ab478b1 100644 --- a/src/core/reference/include/openvino/reference/shape_of.hpp +++ b/src/core/reference/include/openvino/reference/shape_of.hpp @@ -4,7 +4,7 @@ #pragma once -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" namespace ov { namespace reference { diff --git a/src/core/reference/include/openvino/reference/shuffle_channels.hpp b/src/core/reference/include/openvino/reference/shuffle_channels.hpp index aff376b15b9..36d25657b19 100644 --- a/src/core/reference/include/openvino/reference/shuffle_channels.hpp +++ b/src/core/reference/include/openvino/reference/shuffle_channels.hpp @@ -8,7 +8,7 @@ #include #include -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" namespace ov { namespace reference { diff --git a/src/core/reference/include/openvino/reference/slice.hpp b/src/core/reference/include/openvino/reference/slice.hpp index a86b47f76fe..bb45b596c02 100644 --- a/src/core/reference/include/openvino/reference/slice.hpp +++ b/src/core/reference/include/openvino/reference/slice.hpp @@ -4,7 +4,7 @@ #pragma once -#include "ngraph/type/element_type.hpp" +#include "openvino/core/type/element_type.hpp" #include "openvino/reference/utils/coordinate_transform.hpp" namespace ov { diff --git a/src/core/reference/include/openvino/reference/softmax.hpp b/src/core/reference/include/openvino/reference/softmax.hpp index 69ea583fbc6..1e03f940376 100644 --- a/src/core/reference/include/openvino/reference/softmax.hpp +++ b/src/core/reference/include/openvino/reference/softmax.hpp @@ -6,39 +6,39 @@ #include -#include "ngraph/shape_util.hpp" +#include "openvino/core/shape_util.hpp" #include "openvino/reference/reduce_max.hpp" #include "openvino/reference/reduce_sum.hpp" +#include "openvino/reference/utils/coordinate_index.hpp" #include "openvino/reference/utils/coordinate_transform.hpp" namespace ov { namespace reference { template void softmax(const T* arg, T* out, const Shape& shape, const AxisSet& axes) { - NGRAPH_SUPPRESS_DEPRECATED_START - auto temp_shape = ngraph::reduce(shape, axes, true); - auto temp_elements = shape_size(temp_shape); - auto temp_ptr = new T[temp_elements]; + const auto temp_shape = util::reduce_keep_dims(shape, axes); + const auto temp_elements = shape_size(temp_shape); + auto temp_storage = std::vector(temp_elements); + const auto temp_ptr = temp_storage.data(); reduce_max(arg, temp_ptr, shape, axes); - CoordinateTransform transform(shape); - CoordinateTransform temp_transform(temp_shape); - for (const Coordinate& coord : transform) { - Coordinate temp_coord = ngraph::reduce(coord, axes, true); - out[transform.index(coord)] = - std::exp(arg[transform.index(coord)] - temp_ptr[temp_transform.index(temp_coord)]); + const CoordinateTransformBasic transform{shape}; + for (const auto& coord : transform) { + const Coordinate temp_coord = util::reduce_keep_dims(coord, axes); + const auto out_index = coordinate_index(coord, shape); + const auto temp_index = coordinate_index(temp_coord, temp_shape); + out[out_index] = std::exp(arg[out_index] - temp_ptr[temp_index]); } reduce_sum(out, temp_ptr, shape, axes); - for (const Coordinate& coord : transform) { - Coordinate temp_coord = ngraph::reduce(coord, axes, true); - out[transform.index(coord)] /= temp_ptr[temp_transform.index(temp_coord)]; + for (const auto& coord : transform) { + const Coordinate temp_coord = util::reduce_keep_dims(coord, axes); + const auto out_index = coordinate_index(coord, shape); + const auto temp_index = coordinate_index(temp_coord, temp_shape); + out[out_index] /= temp_ptr[temp_index]; } - - delete[] temp_ptr; - NGRAPH_SUPPRESS_DEPRECATED_END } } // namespace reference } // namespace ov diff --git a/src/core/reference/include/openvino/reference/space_to_depth.hpp b/src/core/reference/include/openvino/reference/space_to_depth.hpp index 3eeb2653463..3df0bdd4123 100644 --- a/src/core/reference/include/openvino/reference/space_to_depth.hpp +++ b/src/core/reference/include/openvino/reference/space_to_depth.hpp @@ -4,8 +4,8 @@ #pragma once -#include "ngraph/op/space_to_depth.hpp" -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" +#include "openvino/op/space_to_depth.hpp" namespace ov { namespace reference { diff --git a/src/core/reference/include/openvino/reference/squared_difference.hpp b/src/core/reference/include/openvino/reference/squared_difference.hpp index b28586d3623..a7e4149de21 100644 --- a/src/core/reference/include/openvino/reference/squared_difference.hpp +++ b/src/core/reference/include/openvino/reference/squared_difference.hpp @@ -7,7 +7,6 @@ #include #include -#include "ngraph/shape_util.hpp" #include "openvino/reference/autobroadcast_binop.hpp" namespace ov { diff --git a/src/core/reference/include/openvino/reference/subtract.hpp b/src/core/reference/include/openvino/reference/subtract.hpp index 2051dd1874d..689ceb6915b 100644 --- a/src/core/reference/include/openvino/reference/subtract.hpp +++ b/src/core/reference/include/openvino/reference/subtract.hpp @@ -7,6 +7,8 @@ #include #include +#include "openvino/core/shape.hpp" +#include "openvino/op/util/attr_types.hpp" #include "openvino/reference/autobroadcast_binop.hpp" namespace ov { diff --git a/src/core/reference/include/openvino/reference/tile.hpp b/src/core/reference/include/openvino/reference/tile.hpp index 2ee3da6b0c0..81fcf4b0182 100644 --- a/src/core/reference/include/openvino/reference/tile.hpp +++ b/src/core/reference/include/openvino/reference/tile.hpp @@ -4,10 +4,7 @@ #pragma once -#include - -#include "ngraph/type/element_type.hpp" -#include "openvino/reference/utils/coordinate_transform.hpp" +#include "openvino/core/shape.hpp" namespace ov { namespace reference { diff --git a/src/core/reference/include/openvino/reference/topk.hpp b/src/core/reference/include/openvino/reference/topk.hpp index 0b7b4d48a53..c84fb54e996 100644 --- a/src/core/reference/include/openvino/reference/topk.hpp +++ b/src/core/reference/include/openvino/reference/topk.hpp @@ -8,7 +8,8 @@ #include #include -#include "ngraph/op/topk.hpp" +#include "openvino/op/topk.hpp" +#include "openvino/reference/utils/coordinate_index.hpp" #include "openvino/reference/utils/coordinate_transform.hpp" namespace ov { @@ -51,30 +52,21 @@ void topk(const T* arg, size_t k, bool compute_max, op::TopKSortType sort = op::TopKSortType::NONE) { - NGRAPH_SUPPRESS_DEPRECATED_START using namespace std; - // reorder source axis visit order and make "axis" inner most - size_t ndim = static_cast(in_shape.size()); - Coordinate start_corner(ndim, 0); - Coordinate end_corner(in_shape); - end_corner[axis] = 1; - Strides strides(ndim, 1); - AxisVector axis_order(ndim); - iota(axis_order.begin(), axis_order.end(), 0); - axis_order.erase(axis_order.begin() + axis); - axis_order.push_back(axis); - // Create CoordinateTransforms that visits only the first element along "axis" - CoordinateTransform input_transform(in_shape, start_corner, end_corner, strides, axis_order); - CoordinateTransform output_transform(out_shape, start_corner, end_corner, strides, axis_order); // Create temp vector for sorting. vector> workspace(in_shape[axis]); - vector in_strides = ngraph::row_major_strides(in_shape); - vector out_strides = ngraph::row_major_strides(out_shape); + vector in_strides = row_major_strides(in_shape); + vector out_strides = row_major_strides(out_shape); auto in_axis_stride = in_strides[axis]; auto out_axis_stride = out_strides[axis]; - for (const Coordinate& coord : input_transform) { - auto arg_index = input_transform.index(coord); - auto out_index = output_transform.index(coord); + + // Iterate over elements with 0 index at "axis" dimension + auto traverse_shape = in_shape; + traverse_shape[axis] = 1; + CoordinateTransformBasic traverse_transform(traverse_shape); + for (const Coordinate& coord : traverse_transform) { + auto arg_index = coordinate_index(coord, in_shape); + auto out_index = coordinate_index(coord, out_shape); // Fill the temp vector U i = 0; for (tuple& entry : workspace) { @@ -109,7 +101,6 @@ void topk(const T* arg, out_index += out_axis_stride; } } - NGRAPH_SUPPRESS_DEPRECATED_END } } // namespace reference } // namespace ov diff --git a/src/core/reference/include/openvino/reference/transpose.hpp b/src/core/reference/include/openvino/reference/transpose.hpp index 03af9040382..6d91676dab9 100644 --- a/src/core/reference/include/openvino/reference/transpose.hpp +++ b/src/core/reference/include/openvino/reference/transpose.hpp @@ -9,7 +9,7 @@ #include #include -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" namespace ov { namespace reference { diff --git a/src/core/reference/include/openvino/reference/unique.hpp b/src/core/reference/include/openvino/reference/unique.hpp index f037f1e7f00..fc823432005 100644 --- a/src/core/reference/include/openvino/reference/unique.hpp +++ b/src/core/reference/include/openvino/reference/unique.hpp @@ -5,7 +5,7 @@ #pragma once #include "gather.hpp" -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" #include "openvino/reference/utils/coordinate_index.hpp" #include "openvino/reference/utils/coordinate_transform.hpp" @@ -105,7 +105,7 @@ UniqueElements find_unique_elements(const Data_t* data, using std::begin; using std::end; - const auto data_shape_strides = ngraph::row_major_strides(data_shape); + const auto data_shape_strides = row_major_strides(data_shape); if (axis && *axis < 0) { const auto normalized_axis = *axis + data_shape.size(); diff --git a/src/core/reference/include/openvino/reference/utils/fft_common.hpp b/src/core/reference/include/openvino/reference/utils/fft_common.hpp index d445efbba71..02ede7769ea 100644 --- a/src/core/reference/include/openvino/reference/utils/fft_common.hpp +++ b/src/core/reference/include/openvino/reference/utils/fft_common.hpp @@ -11,8 +11,7 @@ #include #include -#include "ngraph/shape.hpp" -#include "ngraph/type/element_type.hpp" +#include "openvino/core/shape.hpp" namespace ov { namespace reference { @@ -25,7 +24,7 @@ namespace fft_common { // into [N_{r - 1}, ..., N_0]. // At this time, complex tensors are supported only for FFT-like operations, as // DFT, IDFT, RDFT -std::vector reverse_shape_of_emulated_complex_tensor(const ngraph::Shape& shape); +std::vector reverse_shape_of_emulated_complex_tensor(const Shape& shape); // Calculates strides for all axes. std::vector compute_strides(const std::vector& v); diff --git a/src/core/reference/include/openvino/reference/utils/nms_common.hpp b/src/core/reference/include/openvino/reference/utils/nms_common.hpp index 4138db64085..b8364bb0661 100644 --- a/src/core/reference/include/openvino/reference/utils/nms_common.hpp +++ b/src/core/reference/include/openvino/reference/utils/nms_common.hpp @@ -10,7 +10,7 @@ #include #include -#include "ngraph/type/element_type.hpp" +#include "openvino/core/type/element_type.hpp" namespace ov { namespace reference { @@ -60,11 +60,11 @@ struct BoxInfo { void nms_common_postprocessing(void* prois, void* pscores, void* pselected_num, - const ngraph::element::Type& output_type, + const element::Type& output_type, const std::vector& selected_outputs, const std::vector& selected_indices, const std::vector& valid_outputs, - const ngraph::element::Type& selected_outputs_type); + const element::Type& selected_outputs_type); } // namespace nms_common } // namespace reference diff --git a/src/core/reference/src/op/depth_to_space.cpp b/src/core/reference/src/op/depth_to_space.cpp index 490b566f456..0a0d33596d4 100644 --- a/src/core/reference/src/op/depth_to_space.cpp +++ b/src/core/reference/src/op/depth_to_space.cpp @@ -7,8 +7,8 @@ #include #include -#include "ngraph/check.hpp" #include "ngraph/runtime/opt_kernel/reshape.hpp" +#include "openvino/core/except.hpp" namespace ov { namespace reference { @@ -35,16 +35,16 @@ void depth_to_space(const char* const in, const size_t spatial_dims = in_shape.size() - spatial_dim_index; const size_t c_dim_divider = static_cast(std::pow(block_size, spatial_dims)); - NGRAPH_CHECK(block_size > 0 && c_dim % c_dim_divider == 0, - "DepthToSpace: The input data's 'channels' axis size: ", - c_dim, - " must be evenly divided by 'block_size'^'spatial_dims': (", - c_dim_divider, - ", ", - block_size, - "^", - spatial_dims, - ")"); + OPENVINO_ASSERT(block_size > 0 && c_dim % c_dim_divider == 0, + "DepthToSpace: The input data's 'channels' axis size: ", + c_dim, + " must be evenly divided by 'block_size'^'spatial_dims': (", + c_dim_divider, + ", ", + block_size, + "^", + spatial_dims, + ")"); const size_t c_flat = c_dim / c_dim_divider; diff --git a/src/core/reference/src/op/einsum.cpp b/src/core/reference/src/op/einsum.cpp index a48dc998495..abe8f8c14ba 100644 --- a/src/core/reference/src/op/einsum.cpp +++ b/src/core/reference/src/op/einsum.cpp @@ -249,6 +249,14 @@ Shape compute_matmul_output_shape(const Shape& common_sub_shape, return matmul_output_shape; } +/// @brief Prepares default order axis vector +/// +AxisVector get_default_order(size_t rank) { + AxisVector default_order(rank); + std::iota(begin(default_order), end(default_order), 0); + return default_order; +} + /// \brief Update a vector of inputs and subscripts by removing items for /// inputs with indices input_ind1 and input_ind2 and inserted new input and /// the corresponsing subscript in the tail @@ -278,8 +286,8 @@ ov::Tensor unsqueeze_input(const ov::Tensor& input, std::vector& unsque return input; } - Shape input_shape = input.get_shape(); - Shape output_shape = input_shape; + const auto& input_shape = input.get_shape(); + auto output_shape = input_shape; std::sort(unsqueeze_axes.begin(), unsqueeze_axes.end()); for (auto unsqueeze_axis : unsqueeze_axes) { OPENVINO_ASSERT(unsqueeze_axis >= 0); @@ -288,9 +296,7 @@ ov::Tensor unsqueeze_input(const ov::Tensor& input, std::vector& unsque } auto output = ov::Tensor(input.get_element_type(), output_shape); - OPENVINO_SUPPRESS_DEPRECATED_START - const AxisVector order = ngraph::get_default_order(input.get_shape()); - OPENVINO_SUPPRESS_DEPRECATED_END + const auto order = get_default_order(input_shape.size()); const auto element_type = input.get_element_type(); reference::reshape(reinterpret_cast(input.data()), @@ -645,11 +651,9 @@ ov::Tensor reshape_input_for_matmul(const ov::Tensor& input, } const auto element_type = input.get_element_type(); - const auto input_shape = input.get_shape(); + const auto& input_shape = input.get_shape(); auto output = ov::Tensor(element_type, new_shape); - OPENVINO_SUPPRESS_DEPRECATED_START - const AxisVector order = ngraph::get_default_order(input_shape); - OPENVINO_SUPPRESS_DEPRECATED_END + const auto order = get_default_order(input_shape.size()); reference::reshape(reinterpret_cast(input.data()), reinterpret_cast(output.data()), @@ -871,7 +875,7 @@ void contract_two_inputs(ov::TensorVector& inputs, // broadcast both inputs to have common sub-shape broadcasted that is needed // in case of ellipsis among the common labels // reference::broadcast() - PartialShape::broadcast_merge_into(common_sub_shape1, common_sub_shape2, ngraph::op::AutoBroadcastType::NUMPY); + PartialShape::broadcast_merge_into(common_sub_shape1, common_sub_shape2, op::AutoBroadcastType::NUMPY); Shape common_sub_shape = common_sub_shape1.get_shape(); broadcast_input(inputs, input_ind1, @@ -926,9 +930,7 @@ void contract_two_inputs(ov::TensorVector& inputs, back_shape.insert(back_shape.end(), separate2_sub_shape.begin(), separate2_sub_shape.end()); auto contract_output = ov::Tensor(matmul_output.get_element_type(), back_shape); - OPENVINO_SUPPRESS_DEPRECATED_START - const AxisVector order = ngraph::get_default_order(matmul_output.get_shape()); - OPENVINO_SUPPRESS_DEPRECATED_END + const auto order = get_default_order(matmul_output.get_shape().size()); reference::reshape(reinterpret_cast(matmul_output.data()), reinterpret_cast(contract_output.data()), matmul_output.get_shape(), diff --git a/src/core/reference/src/op/experimental_detectron_detection_output.cpp b/src/core/reference/src/op/experimental_detectron_detection_output.cpp index da60c8a7ce8..bf297fef97d 100644 --- a/src/core/reference/src/op/experimental_detectron_detection_output.cpp +++ b/src/core/reference/src/op/experimental_detectron_detection_output.cpp @@ -14,13 +14,13 @@ // limitations under the License. //***************************************************************************** -#include "ngraph/op/experimental_detectron_detection_output.hpp" +#include "openvino/op/experimental_detectron_detection_output.hpp" #include #include #include -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" #include "openvino/reference/experimental_detectron_detection_output.hpp" namespace { @@ -318,7 +318,7 @@ void experimental_detectron_detection_output(const float* boxes, void experimental_detectron_detection_output_postprocessing(void* pboxes, void* pclasses, void* pscores, - const ngraph::element::Type output_type, + const element::Type output_type, const std::vector& output_boxes, const std::vector& output_classes, const std::vector& output_scores, diff --git a/src/core/reference/src/op/experimental_detectron_proposal_single_image.cpp b/src/core/reference/src/op/experimental_detectron_proposal_single_image.cpp index 319c4c75f8d..ef9ad0002d4 100644 --- a/src/core/reference/src/op/experimental_detectron_proposal_single_image.cpp +++ b/src/core/reference/src/op/experimental_detectron_proposal_single_image.cpp @@ -10,8 +10,8 @@ #include #include -#include "ngraph/op/experimental_detectron_generate_proposals.hpp" -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" +#include "openvino/op/experimental_detectron_generate_proposals.hpp" #include "openvino/reference/proposal.hpp" namespace { @@ -295,7 +295,7 @@ void experimental_detectron_proposals_single_image( void experimental_detectron_proposals_single_image_postprocessing(void* prois, void* pscores, - const ngraph::element::Type output_type, + const element::Type output_type, const std::vector& output_rois, const std::vector& output_scores, const Shape& output_rois_shape, diff --git a/src/core/reference/src/op/experimental_detectron_roi_feature_extractor.cpp b/src/core/reference/src/op/experimental_detectron_roi_feature_extractor.cpp index ad4c35f3482..423fb3d4a7d 100644 --- a/src/core/reference/src/op/experimental_detectron_roi_feature_extractor.cpp +++ b/src/core/reference/src/op/experimental_detectron_roi_feature_extractor.cpp @@ -10,8 +10,8 @@ #include #include -#include "ngraph/op/experimental_detectron_roi_feature.hpp" -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" +#include "openvino/op/experimental_detectron_roi_feature.hpp" #if defined(__GNUC__) && !defined(__clang__) # if defined(__linux__) && defined(OPENVINO_ARCH_X86) && \ @@ -344,7 +344,7 @@ void experimental_detectron_roi_feature_extractor( void experimental_detectron_roi_feature_extractor_postprocessing(void* prois_features, void* prois, - const ngraph::element::Type output_type, + const element::Type output_type, const std::vector& output_rois_features, const std::vector& output_rois, const Shape& output_rois_features_shape, diff --git a/src/core/reference/src/op/function.cpp b/src/core/reference/src/op/function.cpp index ebf706e3f03..c70bf4020b1 100644 --- a/src/core/reference/src/op/function.cpp +++ b/src/core/reference/src/op/function.cpp @@ -6,12 +6,8 @@ #include -#include "ngraph/opsets/opset5.hpp" -#include "ngraph/runtime/host_tensor.hpp" -#include "ngraph/runtime/tensor.hpp" #include "openvino/core/deprecated.hpp" #include "openvino/core/shape_util.hpp" -#include "openvino/reference/concat.hpp" namespace ov { namespace reference { diff --git a/src/core/reference/src/op/gather_tree.cpp b/src/core/reference/src/op/gather_tree.cpp index 6e9ef4bb04d..4a5dd31092e 100644 --- a/src/core/reference/src/op/gather_tree.cpp +++ b/src/core/reference/src/op/gather_tree.cpp @@ -9,7 +9,6 @@ #include #include -#include "ngraph/check.hpp" #include "openvino/core/except.hpp" #include "openvino/reference/utils/coordinate_transform.hpp" @@ -19,8 +18,8 @@ static size_t _asIndex(const char* source, const element::Type& element_type) { // According to the GatherTree op specification only I32 and FP32 precisions are supported. switch (element_type) { case element::Type_t::f16: { - ngraph::float16 tmpBuff = 0.f; - memcpy(&tmpBuff, source, sizeof(ngraph::float16)); + ov::float16 tmpBuff = 0.f; + memcpy(&tmpBuff, source, sizeof(ov::float16)); return static_cast(tmpBuff); } case element::Type_t::f32: { diff --git a/src/core/reference/src/op/generate_proposal.cpp b/src/core/reference/src/op/generate_proposal.cpp index 3e4f9b8707b..10fa35d0dcb 100644 --- a/src/core/reference/src/op/generate_proposal.cpp +++ b/src/core/reference/src/op/generate_proposal.cpp @@ -10,8 +10,8 @@ #include #include -#include "ngraph/op/generate_proposals.hpp" -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" +#include "openvino/op/generate_proposals.hpp" struct sProposalBox { float x0; @@ -358,8 +358,8 @@ void generate_proposals(const std::vector& im_info, void generate_proposals_postprocessing(void* prois, void* pscores, void* proi_num, - const ngraph::element::Type& output_type, - const ngraph::element::Type& roi_num_type, + const element::Type& output_type, + const element::Type& roi_num_type, const std::vector& output_rois, const std::vector& output_scores, const std::vector& num_rois, diff --git a/src/core/reference/src/op/group_convolution.cpp b/src/core/reference/src/op/group_convolution.cpp index f3a85a8b37a..cb613b74ed7 100644 --- a/src/core/reference/src/op/group_convolution.cpp +++ b/src/core/reference/src/op/group_convolution.cpp @@ -15,19 +15,19 @@ void validate_group_convolution_parameters(const Shape& in_shape, const CoordinateDiff& pads_begin, const CoordinateDiff& pads_end) { // this implementation supports 1D, 2D and 3D convolutions - NGRAPH_CHECK(in_shape.size() >= 3 && in_shape.size() <= 5, "Unsupported input rank: ", in_shape); + OPENVINO_ASSERT(in_shape.size() >= 3 && in_shape.size() <= 5, "Unsupported input rank: ", in_shape); - NGRAPH_CHECK(in_shape.size() + 1 == f_shape.size(), "Unsupported filter rank: ", f_shape.size()); + OPENVINO_ASSERT(in_shape.size() + 1 == f_shape.size(), "Unsupported filter rank: ", f_shape.size()); - NGRAPH_CHECK(in_shape.size() == out_shape.size(), - "Incompatible input and output ranks: ", - in_shape.size(), - " and ", - out_shape.size()); + OPENVINO_ASSERT(in_shape.size() == out_shape.size(), + "Incompatible input and output ranks: ", + in_shape.size(), + " and ", + out_shape.size()); const size_t groups = f_shape[filter_group_axis]; const size_t in_channels = in_shape[in_channel_axis]; - NGRAPH_CHECK(in_channels % groups == 0, "Input channels of data batch input must be multiple of groups"); + OPENVINO_ASSERT(in_channels % groups == 0, "Input channels of data batch input must be multiple of groups"); const Shape in_group_shape = [&]() { Shape new_shape{in_shape}; new_shape[in_channel_axis] /= groups; @@ -35,7 +35,7 @@ void validate_group_convolution_parameters(const Shape& in_shape, }(); const size_t out_channels = out_shape[out_channel_axis]; - NGRAPH_CHECK(out_channels % groups == 0, "Output channels of output must be multiple of groups"); + OPENVINO_ASSERT(out_channels % groups == 0, "Output channels of output must be multiple of groups"); const Shape out_group_shape = [&]() { Shape new_shape{out_shape}; new_shape[out_channel_axis] /= groups; diff --git a/src/core/reference/src/op/group_convolution_backprop_data.cpp b/src/core/reference/src/op/group_convolution_backprop_data.cpp index a45356d109d..dd8cacb46d7 100644 --- a/src/core/reference/src/op/group_convolution_backprop_data.cpp +++ b/src/core/reference/src/op/group_convolution_backprop_data.cpp @@ -32,27 +32,27 @@ void validate_convolution_backprop_data_parameters(const Shape& in_shape, const CoordinateDiff& pads_begin, const CoordinateDiff& pads_end) { // this implementation supports 1D, 2D and 3D convolutions - NGRAPH_CHECK(in_shape.size() >= 3 && in_shape.size() <= 5, "Unsupported input rank: ", in_shape); - NGRAPH_CHECK(in_shape.size() == f_shape.size(), - "Incompatible input ranks: ", - in_shape.size(), - " and ", - f_shape.size()); - NGRAPH_CHECK(in_shape[in_channel_axis] == f_shape[filter_in_ch_axis], - "Incompatible input channels in data batch and filters shapes: ", - in_shape[in_channel_axis], - " and ", - f_shape[filter_in_ch_axis]); - NGRAPH_CHECK(in_shape.size() == out_shape.size(), - "Incompatible input and output ranks: ", - in_shape.size(), - " and ", - out_shape.size()); + OPENVINO_ASSERT(in_shape.size() >= 3 && in_shape.size() <= 5, "Unsupported input rank: ", in_shape); + OPENVINO_ASSERT(in_shape.size() == f_shape.size(), + "Incompatible input ranks: ", + in_shape.size(), + " and ", + f_shape.size()); + OPENVINO_ASSERT(in_shape[in_channel_axis] == f_shape[filter_in_ch_axis], + "Incompatible input channels in data batch and filters shapes: ", + in_shape[in_channel_axis], + " and ", + f_shape[filter_in_ch_axis]); + OPENVINO_ASSERT(in_shape.size() == out_shape.size(), + "Incompatible input and output ranks: ", + in_shape.size(), + " and ", + out_shape.size()); const auto spatial_dims = in_shape.size() - 2; - NGRAPH_CHECK(strides.size() == spatial_dims, "Strides not definied for all and only spatial dimensions"); - NGRAPH_CHECK(dilations.size() == spatial_dims, "Dilations not defined for all and only spatial dimensions"); - NGRAPH_CHECK((pads_begin.size() == pads_end.size()) && (pads_begin.size() == spatial_dims), - "Pads not defined for all and only spatial dimensions"); + OPENVINO_ASSERT(strides.size() == spatial_dims, "Strides not definied for all and only spatial dimensions"); + OPENVINO_ASSERT(dilations.size() == spatial_dims, "Dilations not defined for all and only spatial dimensions"); + OPENVINO_ASSERT((pads_begin.size() == pads_end.size()) && (pads_begin.size() == spatial_dims), + "Pads not defined for all and only spatial dimensions"); Shape out_spatial_shape{std::next(out_shape.begin(), 2), std::end(out_shape)}; Shape infered_out_spatial_shape{}; @@ -63,7 +63,7 @@ void validate_convolution_backprop_data_parameters(const Shape& in_shape, dilations, pads_begin, pads_end); - NGRAPH_CHECK(out_spatial_shape == infered_out_spatial_shape, "Incorrect output shape provided"); + OPENVINO_ASSERT(out_spatial_shape == infered_out_spatial_shape, "Incorrect output shape provided"); } void validate_group_convolution_backprop_data_parameters(const Shape& in_shape, @@ -74,19 +74,19 @@ void validate_group_convolution_backprop_data_parameters(const Shape& in_shape, const CoordinateDiff& pads_begin, const CoordinateDiff& pads_end) { // this implementation supports 1D, 2D and 3D convolutions - NGRAPH_CHECK(in_shape.size() >= 3 && in_shape.size() <= 5, "Unsupported input rank: ", in_shape); + OPENVINO_ASSERT(in_shape.size() >= 3 && in_shape.size() <= 5, "Unsupported input rank: ", in_shape); - NGRAPH_CHECK(in_shape.size() + 1 == f_shape.size(), "Unsupported filter rank: ", f_shape.size()); + OPENVINO_ASSERT(in_shape.size() + 1 == f_shape.size(), "Unsupported filter rank: ", f_shape.size()); - NGRAPH_CHECK(in_shape.size() == out_shape.size(), - "Incompatible input and output ranks: ", - in_shape.size(), - " and ", - out_shape.size()); + OPENVINO_ASSERT(in_shape.size() == out_shape.size(), + "Incompatible input and output ranks: ", + in_shape.size(), + " and ", + out_shape.size()); const size_t groups = f_shape[filter_group_axis]; const size_t in_channels = in_shape[in_channel_axis]; - NGRAPH_CHECK(in_channels % groups == 0, "Input channels of data batch input must be multiple of groups"); + OPENVINO_ASSERT(in_channels % groups == 0, "Input channels of data batch input must be multiple of groups"); const Shape in_group_shape = [&]() { Shape new_shape{in_shape}; new_shape[in_channel_axis] /= groups; @@ -94,7 +94,7 @@ void validate_group_convolution_backprop_data_parameters(const Shape& in_shape, }(); const size_t out_channels = out_shape[out_channel_axis]; - NGRAPH_CHECK(out_channels % groups == 0, "Output channels of output must be multiple of groups"); + OPENVINO_ASSERT(out_channels % groups == 0, "Output channels of output must be multiple of groups"); const Shape out_group_shape = [&]() { Shape new_shape{out_shape}; new_shape[out_channel_axis] /= groups; diff --git a/src/core/reference/src/op/if.cpp b/src/core/reference/src/op/if.cpp index 4bffb99470e..3b74fa78108 100644 --- a/src/core/reference/src/op/if.cpp +++ b/src/core/reference/src/op/if.cpp @@ -4,7 +4,7 @@ #include "openvino/reference/if.hpp" -#include "ngraph/op/if.hpp" +#include "openvino/op/if.hpp" #include "openvino/reference/function.hpp" namespace ov { @@ -14,7 +14,7 @@ void if_reference(const std::vector>& bodies, const std::vector& input_descs, ov::TensorVector& out, const ov::TensorVector& args) { - NGRAPH_CHECK(args.size() > 0, "If operation must have input condition value"); + OPENVINO_ASSERT(args.size() > 0, "If operation must have input condition value"); auto condition_value = args[0].data()[0]; auto branch_index = (condition_value) ? op::v8::If::THEN_BODY_INDEX : op::v8::If::ELSE_BODY_INDEX; @@ -24,16 +24,16 @@ void if_reference(const std::vector>& bodies, auto inputs_size = args.size(); auto output_size = out.size(); for (const auto& input_desc : input_descs[branch_index]) { - NGRAPH_CHECK(inputs_size > input_desc->m_input_index, - "Incorrect associating! If has not input with id ", - input_desc->m_input_index); + OPENVINO_ASSERT(inputs_size > input_desc->m_input_index, + "Incorrect associating! If has not input with id ", + input_desc->m_input_index); inputs_to_body[input_desc->m_body_parameter_index] = args[input_desc->m_input_index]; } reference::function(bodies[branch_index], inputs_to_body, outs_from_body); for (const auto& out_descr : out_descs[branch_index]) { - NGRAPH_CHECK(output_size > out_descr->m_output_index, - "Incorrect associating! If has not output with id ", - out_descr->m_output_index); + OPENVINO_ASSERT(output_size > out_descr->m_output_index, + "Incorrect associating! If has not output with id ", + out_descr->m_output_index); auto res = outs_from_body[out_descr->m_body_value_index]; res.copy_to(out[out_descr->m_output_index]); } diff --git a/src/core/reference/src/op/interpolate.cpp b/src/core/reference/src/op/interpolate.cpp index 64bf9f8b469..e7b4deb9e84 100644 --- a/src/core/reference/src/op/interpolate.cpp +++ b/src/core/reference/src/op/interpolate.cpp @@ -6,9 +6,8 @@ #include -using namespace ov::reference; - -using Coordinate = ngraph::Coordinate; +namespace ov { +namespace reference { float InterpolateEvalHelper::triangle_coeff(float dz) { return std::max(0.0f, 1.0f - std::fabs(dz)); @@ -122,19 +121,15 @@ InterpolateEvalHelper::InfoForLinearMode InterpolateEvalHelper::get_info_for_lin std::vector a(num_of_axes); std::vector r(num_of_axes); - NGRAPH_SUPPRESS_DEPRECATED_START - CoordinateTransform output_transform(m_out_shape); - CoordinateTransform input_transform(m_input_data_shape); - - std::vector vector_for_indeces(num_of_axes); + std::vector vector_for_indices(num_of_axes); float prod_a = 1; for (std::size_t i = 0; i < num_of_axes; ++i) { a[i] = antialias ? m_scales[i] : 1.0f; prod_a *= a[i]; r[i] = (m_scales[i] > 1.0) ? static_cast(2) : static_cast(std::ceil(2.0f / a[i])); - vector_for_indeces[i] = 2 * r[i] + 1; + vector_for_indices[i] = 2 * r[i] + 1; } - Shape shape_for_indeces{vector_for_indeces}; + Shape shape_for_indices{vector_for_indices}; InfoForLinearMode result; @@ -142,8 +137,7 @@ InterpolateEvalHelper::InfoForLinearMode InterpolateEvalHelper::get_info_for_lin result.a = a; result.r = r; result.prod_a = prod_a; - result.shape_for_indeces = shape_for_indeces; - NGRAPH_SUPPRESS_DEPRECATED_END + result.shape_for_indices = shape_for_indices; return result; } @@ -228,3 +222,5 @@ InterpolateEvalHelper::LinearModeInnerIterationResult InterpolateEvalHelper::inn return result; } +} // namespace reference +} // namespace ov diff --git a/src/core/reference/src/op/irdft.cpp b/src/core/reference/src/op/irdft.cpp index d66c7a8e556..4046f8460e6 100644 --- a/src/core/reference/src/op/irdft.cpp +++ b/src/core/reference/src/op/irdft.cpp @@ -10,7 +10,7 @@ #include #include -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" #include "openvino/reference/fft.hpp" #include "openvino/reference/utils/fft_common.hpp" diff --git a/src/core/reference/src/op/matmul.cpp b/src/core/reference/src/op/matmul.cpp index 446e941e984..f84287c0256 100644 --- a/src/core/reference/src/op/matmul.cpp +++ b/src/core/reference/src/op/matmul.cpp @@ -9,14 +9,12 @@ #include #include -#include "ngraph/shape_util.hpp" - namespace ov { namespace reference { namespace details { std::vector get_transpose_order(const Shape& input_shape) { size_t rank = input_shape.size(); - NGRAPH_CHECK(rank > 1, "Invalid input for transpose"); + OPENVINO_ASSERT(rank > 1, "Invalid input for transpose"); std::vector axes_order(rank); std::iota(axes_order.begin(), axes_order.end(), 0); std::swap(axes_order[rank - 1], axes_order[rank - 2]); diff --git a/src/core/reference/src/op/matrix_nms.cpp b/src/core/reference/src/op/matrix_nms.cpp index bc6b3847854..2dfe451afe6 100644 --- a/src/core/reference/src/op/matrix_nms.cpp +++ b/src/core/reference/src/op/matrix_nms.cpp @@ -2,7 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "ngraph/op/matrix_nms.hpp" +#include "openvino/op/matrix_nms.hpp" #include #include @@ -10,7 +10,7 @@ #include #include -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" #include "openvino/reference/matrix_nms.hpp" #include "openvino/reference/utils/nms_common.hpp" diff --git a/src/core/reference/src/op/multiclass_nms.cpp b/src/core/reference/src/op/multiclass_nms.cpp index cee59347ac1..b38091c7dd7 100644 --- a/src/core/reference/src/op/multiclass_nms.cpp +++ b/src/core/reference/src/op/multiclass_nms.cpp @@ -2,7 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "ngraph/op/multiclass_nms.hpp" +#include "openvino/op/multiclass_nms.hpp" #include #include @@ -10,14 +10,13 @@ #include #include -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" #include "openvino/reference/multiclass_nms.hpp" #include "openvino/reference/utils/nms_common.hpp" namespace ov { namespace reference { namespace multiclass_nms_impl { -OPENVINO_SUPPRESS_DEPRECATED_START using Rectangle = reference::nms_common::Rectangle; using BoxInfo = reference::nms_common::BoxInfo; diff --git a/src/core/reference/src/op/pad.cpp b/src/core/reference/src/op/pad.cpp index ad8ab0f9f0b..e5894721357 100644 --- a/src/core/reference/src/op/pad.cpp +++ b/src/core/reference/src/op/pad.cpp @@ -6,8 +6,7 @@ #include -#include "ngraph/axis_vector.hpp" -#include "ngraph/check.hpp" +#include "openvino/core/except.hpp" #include "openvino/reference/utils/coordinate_index.hpp" #include "openvino/reference/utils/coordinate_transform.hpp" @@ -153,9 +152,9 @@ struct SymmetricAndReflectPad : PadBase { void check_inputs() const override { for (size_t i = 0; i != padding_begin.size(); ++i) { const auto axis_size = static_cast(data_shape[i]); - NGRAPH_CHECK(padding_begin.at(i) - axis_correction < axis_size, - "padding below should be less than data shape"); - NGRAPH_CHECK(padding_end.at(i) - axis_correction < axis_size, "padding should be less than data shape"); + OPENVINO_ASSERT(padding_begin.at(i) - axis_correction < axis_size, + "padding below should be less than data shape"); + OPENVINO_ASSERT(padding_end.at(i) - axis_correction < axis_size, "padding should be less than data shape"); } } diff --git a/src/core/reference/src/op/random_uniform.cpp b/src/core/reference/src/op/random_uniform.cpp index 99bd70aca40..01215b095d2 100644 --- a/src/core/reference/src/op/random_uniform.cpp +++ b/src/core/reference/src/op/random_uniform.cpp @@ -6,8 +6,8 @@ #include -#include "ngraph/shape.hpp" #include "openvino/core/except.hpp" +#include "openvino/core/shape.hpp" namespace ov { namespace reference { @@ -140,7 +140,7 @@ void run_philox(uint64_t key, uint64_t counter, uint64_t n, size_t n_rounds, std template void convert_to_output_type(const std::vector& res, size_t step, - const ngraph::element::Type& elem_type, + const element::Type& elem_type, const char* min_val, const char* max_val, char* out, @@ -185,7 +185,7 @@ std::pair random_uniform(const uint64_t* out_shape, const char* max_val, char* out, const Shape& out_shape_shape, - const ngraph::element::Type& elem_type, + const element::Type& elem_type, uint64_t seed, uint64_t seed2, std::pair prev_state) { @@ -229,11 +229,11 @@ std::pair random_uniform(const uint64_t* out_shape, // convert values to corresponding output_type switch (elem_type) { - case ngraph::element::Type_t::f32: { + case element::Type_t::f32: { convert_to_output_type(res, step, elem_type, min_val, max_val, out, k, elem_count, uint32_to_float); break; } - case ngraph::element::Type_t::f16: { + case element::Type_t::f16: { convert_to_output_type(res, step, elem_type, @@ -245,7 +245,7 @@ std::pair random_uniform(const uint64_t* out_shape, uint32_to_float16); break; } - case ngraph::element::Type_t::bf16: { + case element::Type_t::bf16: { convert_to_output_type(res, step, elem_type, @@ -257,7 +257,7 @@ std::pair random_uniform(const uint64_t* out_shape, uint32_to_bfloat16); break; } - case ngraph::element::Type_t::f64: { + case element::Type_t::f64: { convert_to_output_type(res, step, elem_type, @@ -272,7 +272,7 @@ std::pair random_uniform(const uint64_t* out_shape, }); break; } - case ngraph::element::Type_t::i32: { + case element::Type_t::i32: { convert_to_output_type(res, step, elem_type, @@ -288,7 +288,7 @@ std::pair random_uniform(const uint64_t* out_shape, }); break; } - case ngraph::element::Type_t::i64: { + case element::Type_t::i64: { convert_to_output_type(res, step, elem_type, diff --git a/src/core/reference/src/op/rdft.cpp b/src/core/reference/src/op/rdft.cpp index 771b05f73a1..f24eadfa559 100644 --- a/src/core/reference/src/op/rdft.cpp +++ b/src/core/reference/src/op/rdft.cpp @@ -19,7 +19,7 @@ #include #include -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" #include "openvino/reference/fft.hpp" #include "openvino/reference/utils/fft_common.hpp" diff --git a/src/core/reference/src/op/reorg_yolo.cpp b/src/core/reference/src/op/reorg_yolo.cpp index 1bf40680e16..68d1936e893 100644 --- a/src/core/reference/src/op/reorg_yolo.cpp +++ b/src/core/reference/src/op/reorg_yolo.cpp @@ -8,7 +8,7 @@ #include -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" namespace ov { namespace reference { diff --git a/src/core/reference/src/op/reshape.cpp b/src/core/reference/src/op/reshape.cpp index 2da555a542a..dec23afda86 100644 --- a/src/core/reference/src/op/reshape.cpp +++ b/src/core/reference/src/op/reshape.cpp @@ -7,7 +7,7 @@ #include #include -#include "ngraph/check.hpp" +#include "openvino/core/except.hpp" #include "openvino/reference/utils/coordinate_range.hpp" #include "openvino/reference/utils/coordinate_transform.hpp" @@ -17,7 +17,7 @@ namespace { std::vector reorder(const std::vector& origin, const AxisVector& order) { std::vector reordered = origin; auto out = begin(reordered); - NGRAPH_CHECK(origin.size() <= order.size()); + OPENVINO_ASSERT(origin.size() <= order.size()); for (size_t i = 0; i < origin.size(); ++i) { *out = origin.at(order[i]); ++out; diff --git a/src/core/reference/src/op/reverse.cpp b/src/core/reference/src/op/reverse.cpp index 0794a81ce95..e0555cde41e 100644 --- a/src/core/reference/src/op/reverse.cpp +++ b/src/core/reference/src/op/reverse.cpp @@ -8,11 +8,9 @@ #include #include -#include "ngraph/check.hpp" +#include "openvino/core/except.hpp" #include "openvino/reference/utils/coordinate_range.hpp" -using namespace ngraph; - namespace ov { namespace reference { void reverse(const char* arg, @@ -21,7 +19,7 @@ void reverse(const char* arg, const Shape& out_shape, const AxisSet& reversed_axes, size_t elem_size) { - NGRAPH_CHECK(shape_size(arg_shape) == shape_size(out_shape)); + OPENVINO_ASSERT(shape_size(arg_shape) == shape_size(out_shape)); const bool nothing_to_revers = reversed_axes.empty(); if (nothing_to_revers) { diff --git a/src/core/reference/src/op/slice.cpp b/src/core/reference/src/op/slice.cpp index 4ea9bc6be23..4f01cbce8a8 100644 --- a/src/core/reference/src/op/slice.cpp +++ b/src/core/reference/src/op/slice.cpp @@ -6,9 +6,9 @@ #include -#include "ngraph/check.hpp" #include "openvino/core/except.hpp" #include "openvino/reference/utils/coordinate_range.hpp" +#include "openvino/util/common_util.hpp" namespace ov { namespace reference { @@ -64,12 +64,21 @@ void slice(const char* arg, const Strides& strides, const Shape& out_shape, size_t elem_size) { - NGRAPH_SUPPRESS_DEPRECATED_START - const CoordinateTransform input_transform(arg_shape, lower_bounds, upper_bounds, strides); + const auto rank = arg_shape.size(); + OPENVINO_ASSERT( + lower_bounds.size() == rank && upper_bounds.size() == rank && strides.size() == rank && + out_shape.size() == rank, + "arg_shape, lower_bounds, upper_bounds, strides and out_shape are expected to have the same rank equal ", + rank); - const CoordinateTransform output_transform(out_shape); - - NGRAPH_CHECK(shape_size(input_transform.get_target_shape()) == shape_size(output_transform.get_target_shape())); + auto expected_out_shape = Shape(arg_shape); + for (size_t i = 0; i < rank; ++i) + expected_out_shape[i] = util::ceil_div(upper_bounds[i] - lower_bounds[i], strides[i]); + OPENVINO_ASSERT(out_shape == expected_out_shape, + "Expected output shape is ", + expected_out_shape, + ". Got ", + out_shape); auto dst_mem = out; @@ -81,7 +90,6 @@ void slice(const char* arg, std::advance(dst_mem, elem_size); } } - NGRAPH_SUPPRESS_DEPRECATED_END } } // namespace reference } // namespace ov diff --git a/src/core/reference/src/op/space_to_depth.cpp b/src/core/reference/src/op/space_to_depth.cpp index 646550bb2a6..247efe39412 100644 --- a/src/core/reference/src/op/space_to_depth.cpp +++ b/src/core/reference/src/op/space_to_depth.cpp @@ -6,8 +6,8 @@ #include -#include "ngraph/check.hpp" #include "ngraph/runtime/opt_kernel/reshape.hpp" +#include "openvino/core/except.hpp" namespace ov { namespace reference { @@ -34,13 +34,13 @@ void space_to_depth(const char* const in, const size_t spatial_dims = in_shape.size() - spatial_dim_index; for (size_t i = spatial_dim_index; i < in_shape.size(); ++i) { - NGRAPH_CHECK(block_size > 0 && in_shape.at(i) % block_size == 0, - "SpaceToDepth: The dimension on position: ", - i, - " equal to: ", - in_shape.at(i), - " must be a multiple of blocksize: ", - block_size); + OPENVINO_ASSERT(block_size > 0 && in_shape.at(i) % block_size == 0, + "SpaceToDepth: The dimension on position: ", + i, + " equal to: ", + in_shape.at(i), + " must be a multiple of blocksize: ", + block_size); } Shape dispersed_shape{n_dim, c_dim}; diff --git a/src/core/reference/src/op/split.cpp b/src/core/reference/src/op/split.cpp index 41a2dba235a..6186bdd5af9 100644 --- a/src/core/reference/src/op/split.cpp +++ b/src/core/reference/src/op/split.cpp @@ -8,8 +8,6 @@ #include -#include "ngraph/check.hpp" - using namespace ov; void reference::split(const char* data, diff --git a/src/core/reference/src/op/strided_slice.cpp b/src/core/reference/src/op/strided_slice.cpp index 457a65dec5d..2ff07ba8500 100644 --- a/src/core/reference/src/op/strided_slice.cpp +++ b/src/core/reference/src/op/strided_slice.cpp @@ -8,7 +8,6 @@ #include -#include "ngraph/check.hpp" #include "ngraph/runtime/aligned_buffer.hpp" #include "ngraph/runtime/opt_kernel/reshape.hpp" diff --git a/src/core/reference/src/op/transpose.cpp b/src/core/reference/src/op/transpose.cpp index 1706775f572..5b893ccc569 100644 --- a/src/core/reference/src/op/transpose.cpp +++ b/src/core/reference/src/op/transpose.cpp @@ -10,7 +10,7 @@ #include #include "ngraph/runtime/opt_kernel/reshape.hpp" -#include "ngraph/shape.hpp" +#include "openvino/core/shape.hpp" namespace ov { namespace reference { diff --git a/src/core/reference/src/op/utils/fft_common.cpp b/src/core/reference/src/op/utils/fft_common.cpp index ec9c91a1211..6bf248190ed 100644 --- a/src/core/reference/src/op/utils/fft_common.cpp +++ b/src/core/reference/src/op/utils/fft_common.cpp @@ -10,12 +10,10 @@ #include #include -#include "ngraph/check.hpp" - namespace ov { namespace reference { namespace fft_common { -std::vector reverse_shape_of_emulated_complex_tensor(const ngraph::Shape& shape) { +std::vector reverse_shape_of_emulated_complex_tensor(const Shape& shape) { assert(shape.size() >= 2); std::vector reversed_shape(shape.begin(), shape.end() - 1); std::reverse(reversed_shape.begin(), reversed_shape.end()); diff --git a/src/core/reference/src/op/utils/nms_common.cpp b/src/core/reference/src/op/utils/nms_common.cpp index a8dfe16687b..427ef985fe8 100644 --- a/src/core/reference/src/op/utils/nms_common.cpp +++ b/src/core/reference/src/op/utils/nms_common.cpp @@ -8,7 +8,7 @@ #include #include -#include "ngraph/check.hpp" +#include "openvino/core/except.hpp" namespace ov { namespace reference { @@ -16,11 +16,11 @@ namespace nms_common { void nms_common_postprocessing(void* prois, void* pscores, void* pselected_num, - const ngraph::element::Type& output_type, + const element::Type& output_type, const std::vector& selected_outputs, const std::vector& selected_indices, const std::vector& valid_outputs, - const ngraph::element::Type& selected_outputs_type) { + const element::Type& selected_outputs_type) { int64_t total_num = std::accumulate(valid_outputs.begin(), valid_outputs.end(), int64_t(0)); switch (selected_outputs_type) { @@ -41,11 +41,11 @@ void nms_common_postprocessing(void* prois, memcpy(ptr, selected_outputs.data(), total_num * sizeof(float) * 6); } break; default: - NGRAPH_UNREACHABLE("unsupported element type, should be [bf16, f16, f32]"); + OPENVINO_THROW("unsupported element type, should be [bf16, f16, f32]"); } if (pscores) { - if (output_type == ngraph::element::i64) { + if (output_type == element::i64) { int64_t* indices_ptr = static_cast(pscores); memcpy(indices_ptr, selected_indices.data(), total_num * sizeof(int64_t)); } else { @@ -57,7 +57,7 @@ void nms_common_postprocessing(void* prois, } if (pselected_num) { - if (output_type == ngraph::element::i64) { + if (output_type == element::i64) { int64_t* valid_outputs_ptr = static_cast(pselected_num); std::copy(valid_outputs.begin(), valid_outputs.end(), valid_outputs_ptr); } else { diff --git a/src/core/reference/src/op/utils/round_guard.cpp b/src/core/reference/src/op/utils/round_guard.cpp deleted file mode 100644 index 565fb2db598..00000000000 --- a/src/core/reference/src/op/utils/round_guard.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "openvino/reference/round_guard.hpp" - -namespace ov { -RoundGuard::RoundGuard(int mode) : m_prev_round_mode{std::fegetround()} { - std::fesetround(mode); -} - -RoundGuard::~RoundGuard() { - std::fesetround(m_prev_round_mode); -} -} // namespace ov diff --git a/src/core/reference/src/op/utils/rounding_guard.cpp b/src/core/reference/src/op/utils/rounding_guard.cpp new file mode 100644 index 00000000000..70b0ce5897e --- /dev/null +++ b/src/core/reference/src/op/utils/rounding_guard.cpp @@ -0,0 +1,11 @@ +#include "openvino/reference/rounding_guard.hpp" + +namespace ov { +RoundingGuard::RoundingGuard(int mode) : m_prev_round_mode{std::fegetround()} { + std::fesetround(mode); +} + +RoundingGuard::~RoundingGuard() { + std::fesetround(m_prev_round_mode); +} +} // namespace ov diff --git a/src/core/reference/src/runtime/opt_kernel/reshape.cpp b/src/core/reference/src/runtime/opt_kernel/reshape.cpp index f1bfd265182..e0ca720845c 100644 --- a/src/core/reference/src/runtime/opt_kernel/reshape.cpp +++ b/src/core/reference/src/runtime/opt_kernel/reshape.cpp @@ -7,7 +7,6 @@ #include #include -#include "ngraph/check.hpp" #include "openvino/core/parallel.hpp" #include "openvino/reference/reshape.hpp" diff --git a/src/core/reference/src/utils/coordinate_transform.cpp b/src/core/reference/src/utils/coordinate_transform.cpp index e62f6154652..cd97834e6d0 100644 --- a/src/core/reference/src/utils/coordinate_transform.cpp +++ b/src/core/reference/src/utils/coordinate_transform.cpp @@ -11,17 +11,16 @@ #include #include -#include "ngraph/axis_vector.hpp" -#include "ngraph/coordinate_diff.hpp" -#include "ngraph/except.hpp" -#include "ngraph/shape.hpp" -#include "ngraph/strides.hpp" #include "ngraph/util.hpp" +#include "openvino/core/axis_vector.hpp" +#include "openvino/core/coordinate_diff.hpp" +#include "openvino/core/shape.hpp" +#include "openvino/core/strides.hpp" #include "openvino/reference/utils/coordinate_index.hpp" using namespace ov; -NGRAPH_SUPPRESS_DEPRECATED_START +OPENVINO_SUPPRESS_DEPRECATED_START namespace { Strides default_strides(size_t n_axes) { return Strides(n_axes, 1); @@ -42,6 +41,7 @@ Coordinate default_source_end_corner(const Shape& source_shape) { return source_shape; } } // namespace +OPENVINO_SUPPRESS_DEPRECATED_END CoordinateTransformBasic::CoordinateTransformBasic(const Shape& source_shape) : m_source_shape(source_shape) {} @@ -58,6 +58,7 @@ const CoordinateIterator& CoordinateTransformBasic::end() const noexcept { return CoordinateIterator::end(); } +OPENVINO_SUPPRESS_DEPRECATED_START CoordinateTransform::CoordinateTransform(const Shape& source_shape, const Coordinate& source_start_corner, const Coordinate& source_end_corner, @@ -123,11 +124,9 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape, std::vector padded_upper_bounds; for (size_t i = 0; i < m_n_axes; i++) { - NGRAPH_SUPPRESS_DEPRECATED_START std::ptrdiff_t padded_upper_bound = ngraph::subtract_or_zero(source_shape[i], size_t(1)) * target_dilation_strides[i] + 1 + target_padding_below[i] + target_padding_above[i]; - NGRAPH_SUPPRESS_DEPRECATED_END if (padded_upper_bound < 0) { std::stringstream ss; @@ -343,6 +342,7 @@ CoordinateIterator CoordinateTransform::begin() const noexcept { const CoordinateIterator& CoordinateTransform::end() const noexcept { return CoordinateIterator::end(); } +OPENVINO_SUPPRESS_DEPRECATED_END // The "is_end" parameter is true if we want the "end()" iterator. CoordinateIterator::CoordinateIterator(const Shape& target_shape, bool is_end) diff --git a/src/core/src/op/interpolate.cpp b/src/core/src/op/interpolate.cpp index d541cc9ed37..47cfe4e169f 100644 --- a/src/core/src/op/interpolate.cpp +++ b/src/core/src/op/interpolate.cpp @@ -174,30 +174,6 @@ std::vector get_scales_vector(const ov::TensorVector& args, } } // namespace -static void pad_input_data(const uint8_t* data_ptr, - uint8_t* padded_data_ptr, - size_t type_size, - const ov::Shape& input_shape, - const ov::Shape& padded_input_shape, - const std::vector& pads_begin) { - NGRAPH_SUPPRESS_DEPRECATED_START - ov::CoordinateTransform input_transform(input_shape); - ov::CoordinateTransform padded_transform(padded_input_shape); - - for (const ngraph::Coordinate& input_coord : input_transform) { - auto padded_coord = input_coord; - size_t i = 0; - for (size_t pad : pads_begin) { - padded_coord[i] += pad; - ++i; - } - uint8_t* dst_ptr = padded_data_ptr + type_size * padded_transform.index(padded_coord); - const uint8_t* src_ptr = data_ptr + type_size * input_transform.index(input_coord); - memcpy(dst_ptr, src_ptr, type_size); - } - NGRAPH_SUPPRESS_DEPRECATED_END -} - bool ov::op::v4::Interpolate::evaluate_interpolate(TensorVector& outputs, const TensorVector& inputs) const { auto input_shapes = std::vector(); const auto inputs_num = inputs.size(); @@ -229,7 +205,12 @@ bool ov::op::v4::Interpolate::evaluate_interpolate(TensorVector& outputs, const auto* data_ptr = static_cast(inputs[data_port].data()); auto* padded_data_ptr = padded_input_data.data(); - pad_input_data(data_ptr, padded_data_ptr, type_size, inputs[data_port].get_shape(), padded_input_shape, pads_begin); + reference::pad_input_data(data_ptr, + padded_data_ptr, + type_size, + inputs[data_port].get_shape(), + padded_input_shape, + pads_begin); switch (input_et) { case element::Type_t::f32: diff --git a/src/core/src/op/topk.cpp b/src/core/src/op/topk.cpp index 36ba74a7977..485dc4e91fe 100644 --- a/src/core/src/op/topk.cpp +++ b/src/core/src/op/topk.cpp @@ -8,26 +8,25 @@ #include #include "itt.hpp" -#include "ngraph/attribute_visitor.hpp" -#include "ngraph/axis_vector.hpp" -#include "ngraph/op/constant.hpp" -#include "ngraph/op/util/op_types.hpp" -#include "ngraph/runtime/host_tensor.hpp" -#include "ngraph/shape.hpp" -#include "ngraph/validation_util.hpp" +#include "openvino/core/attribute_visitor.hpp" +#include "openvino/core/axis_vector.hpp" #include "openvino/core/dimension_tracker.hpp" +#include "openvino/core/shape.hpp" +#include "openvino/core/validation_util.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/util/op_types.hpp" #include "openvino/reference/topk.hpp" using namespace std; -using namespace ngraph; +namespace ov { OPENVINO_SUPPRESS_DEPRECATED_START namespace topk { namespace { template -inline bool evaluate_execute(const HostTensorPtr& arg0, - const HostTensorPtr& out_indices, - const HostTensorPtr& out_values, +inline bool evaluate_execute(const ngraph::HostTensorPtr& arg0, + const ngraph::HostTensorPtr& out_indices, + const ngraph::HostTensorPtr& out_values, const ov::Shape out_shape, const size_t axis, const size_t k, @@ -61,9 +60,9 @@ inline bool evaluate_execute(const HostTensorPtr& arg0, } break template -bool evaluate(const HostTensorPtr& arg, - const HostTensorPtr& out_indices, - const HostTensorPtr& out_values, +bool evaluate(const ngraph::HostTensorPtr& arg, + const ngraph::HostTensorPtr& out_indices, + const ngraph::HostTensorPtr& out_values, const ov::Shape out_shape, const size_t axis, const size_t k, @@ -81,9 +80,9 @@ bool evaluate(const HostTensorPtr& arg, return rc; } -bool evaluate_topk(const HostTensorPtr& arg, - const HostTensorPtr& out_indices, - const HostTensorPtr& out_values, +bool evaluate_topk(const ngraph::HostTensorPtr& arg, + const ngraph::HostTensorPtr& out_indices, + const ngraph::HostTensorPtr& out_values, const ov::Shape out_shape, const size_t axis, const size_t k, @@ -185,12 +184,12 @@ bool op::v1::TopK::has_evaluate() const { OV_OP_SCOPE(v1_TopK_has_evaluate); switch (get_input_element_type(0)) { - case ngraph::element::i32: - case ngraph::element::i64: - case ngraph::element::u32: - case ngraph::element::u64: - case ngraph::element::f16: - case ngraph::element::f32: + case element::i32: + case element::i64: + case element::u32: + case element::u64: + case element::f16: + case element::f32: break; default: return false; @@ -198,23 +197,23 @@ bool op::v1::TopK::has_evaluate() const { if (op::util::is_constant(input_value(1).get_node())) { switch (get_input_element_type(1)) { - case ngraph::element::i8: - case ngraph::element::i32: - case ngraph::element::i64: + case element::i8: + case element::i32: + case element::i64: break; default: return false; } } else { switch (get_input_element_type(1)) { - case ngraph::element::i8: - case ngraph::element::i16: - case ngraph::element::i32: - case ngraph::element::i64: - case ngraph::element::u8: - case ngraph::element::u16: - case ngraph::element::u32: - case ngraph::element::u64: + case element::i8: + case element::i16: + case element::i32: + case element::i64: + case element::u8: + case element::u16: + case element::u32: + case element::u64: break; default: return false; @@ -258,12 +257,12 @@ bool op::v3::TopK::has_evaluate() const { OV_OP_SCOPE(v3_TopK_has_evaluate); switch (get_input_element_type(0)) { - case ngraph::element::i32: - case ngraph::element::i64: - case ngraph::element::u32: - case ngraph::element::u64: - case ngraph::element::f16: - case ngraph::element::f32: + case element::i32: + case element::i64: + case element::u32: + case element::u64: + case element::f16: + case element::f32: break; default: return false; @@ -271,23 +270,23 @@ bool op::v3::TopK::has_evaluate() const { if (op::util::is_constant(input_value(1).get_node())) { switch (get_input_element_type(1)) { - case ngraph::element::i8: - case ngraph::element::i32: - case ngraph::element::i64: + case element::i8: + case element::i32: + case element::i64: break; default: return false; } } else { switch (get_input_element_type(1)) { - case ngraph::element::i8: - case ngraph::element::i16: - case ngraph::element::i32: - case ngraph::element::i64: - case ngraph::element::u8: - case ngraph::element::u16: - case ngraph::element::u32: - case ngraph::element::u64: + case element::i8: + case element::i16: + case element::i32: + case element::i64: + case element::u8: + case element::u16: + case element::u32: + case element::u64: break; default: return false; @@ -360,15 +359,16 @@ bool ov::op::v11::TopK::has_evaluate() const { OV_OP_SCOPE(v11_TopK_has_evaluate); switch (get_input_element_type(0)) { - case ngraph::element::i32: - case ngraph::element::i64: - case ngraph::element::u32: - case ngraph::element::u64: - case ngraph::element::f16: - case ngraph::element::f32: + case element::i32: + case element::i64: + case element::u32: + case element::u64: + case element::f16: + case element::f32: break; default: return false; } return true; } +} // namespace ov diff --git a/src/plugins/template/backend/ops/interpolate.cpp b/src/plugins/template/backend/ops/interpolate.cpp index 1d1cc4caced..180488da0ff 100644 --- a/src/plugins/template/backend/ops/interpolate.cpp +++ b/src/plugins/template/backend/ops/interpolate.cpp @@ -170,30 +170,6 @@ std::vector get_scales_vector(const ov::TensorVector& args, return scales; } -static void pad_input_data(const uint8_t* data_ptr, - uint8_t* padded_data_ptr, - size_t type_size, - const ov::Shape& input_shape, - const ov::Shape& padded_input_shape, - const std::vector& pads_begin) { - OPENVINO_SUPPRESS_DEPRECATED_START - ov::CoordinateTransform input_transform(input_shape); - ov::CoordinateTransform padded_transform(padded_input_shape); - - for (const ov::Coordinate& input_coord : input_transform) { - auto padded_coord = input_coord; - size_t i = 0; - for (size_t pad : pads_begin) { - padded_coord[i] += pad; - ++i; - } - uint8_t* dst_ptr = padded_data_ptr + type_size * padded_transform.index(padded_coord); - const uint8_t* src_ptr = data_ptr + type_size * input_transform.index(input_coord); - memcpy(dst_ptr, src_ptr, type_size); - } - OPENVINO_SUPPRESS_DEPRECATED_END -} - namespace v11 { bool evaluate_interpolate(const std::shared_ptr& op, ov::TensorVector& outputs, @@ -236,12 +212,12 @@ bool evaluate_interpolate(const std::shared_ptr& op, const uint8_t* data_ptr = static_cast(inputs[0].data()); uint8_t* padded_data_ptr = padded_input_data.data(); - pad_input_data(data_ptr, - padded_data_ptr, - type_size, - input_shape.to_shape(), - padded_input_shape, - m_attrs.pads_begin); + reference::pad_input_data(data_ptr, + padded_data_ptr, + type_size, + input_shape.to_shape(), + padded_input_shape, + m_attrs.pads_begin); switch (input_et) { case element::f32: diff --git a/src/plugins/template/tests/functional/op_reference/avg_pool.cpp b/src/plugins/template/tests/functional/op_reference/avg_pool.cpp index 7bf582184f3..fdd5205f993 100644 --- a/src/plugins/template/tests/functional/op_reference/avg_pool.cpp +++ b/src/plugins/template/tests/functional/op_reference/avg_pool.cpp @@ -60,7 +60,7 @@ struct AvgPoolParams { class ReferenceAvgPoolLayerTest : public testing::TestWithParam, public CommonReferenceTest { public: void SetUp() override { - auto params = GetParam(); + const auto& params = GetParam(); function = CreateFunction(params.m_input_shape, params.m_input_type, params.m_strides, @@ -75,11 +75,11 @@ public: } static std::string getTestCaseName(const testing::TestParamInfo& obj) { - auto params = obj.param; + const auto& params = obj.param; std::ostringstream result; result << "iShape=" << params.m_input_shape << "_"; result << "iType=" << params.m_input_type << "_"; - result << "iShape=" << params.m_output_shape << "_"; + result << "oShape=" << params.m_output_shape << "_"; result << "oType=" << params.m_output_type << "_"; result << "excludePad=" << params.m_exclude_pad << "_"; result << "roundingType=" << params.m_rounding_type << "_"; @@ -126,6 +126,32 @@ std::vector generateParamsForAvgPool() { using T = typename element_type_traits::value_type; std::vector params{ + AvgPoolParams(ov::Shape{1, 1, 5}, + ov::Shape{1, 1, 5}, + IN_ET, + IN_ET, + std::vector{1, 2, 3, 4, 5}, + std::vector{1.5, 2.5, 3.5, 4.5, 5}, + Strides{1}, + Shape{0}, + Shape{1}, + Shape{2}, + true, + op::RoundingType::FLOOR, + op::PadType::EXPLICIT), + AvgPoolParams(ov::Shape{1, 1, 8}, + ov::Shape{1, 1, 4}, + IN_ET, + IN_ET, + std::vector{1, 2, 3, 4, 5, 6, 7, 8}, + std::vector{2, 4, 6, 7.5}, + Strides{2}, + Shape{0}, + Shape{0}, + Shape{3}, + false, + op::RoundingType::CEIL, + op::PadType::EXPLICIT), AvgPoolParams(ov::Shape{1, 1, 3, 3}, ov::Shape{1, 1, 2, 2}, IN_ET,