[Interpolate-11] Reference implementation for Interpolate-11 (#16342)

* Reference impl for interpolate-11 init

* ND support init

* Tests clean up

* Add evaluate method for Interpolate-11

* New version tests init

* Type parametrized tests

* Tests duplication clean up and reusage of v4 test cases

* Add clipping to the type bounds

* Style fix

* Add float type tests

* Fix default ports values

* Commented code clean up

* Add passing cube_coeff param

* Tests clean up

* Add separate namespace

* Adjust variable names

* Adjust function name

* Use vectors instead of raw ptrs

* update func to static inline

* Adjust types

* Add Interpolate-11 to template plugin evaluates map

* Revert interpolate-11 core evaluate support

* Use const ref to filter

* Use static cast

* Update link
This commit is contained in:
Katarzyna Mitrus 2023-03-29 07:11:56 +02:00 committed by GitHub
parent daf562832f
commit f7891aa034
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 1196 additions and 2 deletions

View File

@ -1605,3 +1605,38 @@ Some of the benchmark data in testdata/ is licensed differently:
domain; the latter does not have expired copyright, but is still in the
public domain according to the license information
(http://www.gutenberg.org/ebooks/53).
-------------------------------------------------------------
29. Pillow (https://github.com/python-pillow/Pillow)
The Python Imaging Library (PIL) is
Copyright © 1997-2011 by Secret Labs AB
Copyright © 1995-2011 by Fredrik Lundh
Pillow is the friendly PIL fork. It is
Copyright © 2010-2023 by Jeffrey A. Clark (Alex) and contributors.
Like PIL, Pillow is licensed under the open source HPND License:
By obtaining, using, and/or copying this software and/or its associated
documentation, you agree that you have read, understood, and will comply
with the following terms and conditions:
Permission to use, copy, modify and distribute this software and its
documentation for any purpose and without fee is hereby granted,
provided that the above copyright notice appears in all copies, and that
both that copyright notice and this permission notice appear in supporting
documentation, and that the name of Secret Labs AB or the author not be
used in advertising or publicity pertaining to distribution of the software
without specific, written prior permission.
SECRET LABS AB AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS
SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS.
IN NO EVENT SHALL SECRET LABS AB OR THE AUTHOR BE LIABLE FOR ANY SPECIAL,
INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
PERFORMANCE OF THIS SOFTWARE.

View File

@ -12,9 +12,11 @@
#include <functional>
#include <map>
#include "interpolate_pil.hpp"
#include "ngraph/coordinate_transform.hpp"
#include "ngraph/op/interpolate.hpp"
#include "ngraph/shape_util.hpp"
#include "transpose.hpp"
namespace ngraph {
namespace runtime {
@ -302,6 +304,12 @@ public:
case InterpolateMode::CUBIC:
cubic_func(input_data, out);
break;
case InterpolateMode::BILINEAR_PILLOW:
bilinear_pil_func(input_data, out);
break;
case InterpolateMode::BICUBIC_PILLOW:
bicubic_pil_func(input_data, out);
break;
default:
OPENVINO_THROW("Unsupported interpolation mode");
break;
@ -345,6 +353,10 @@ private:
/// \param input_data pointer to input data
/// \param out pointer to memory block for output data
void nearest_func(const T* input_data, T* out);
void bilinear_pil_func(const T* input_data, T* out);
void bicubic_pil_func(const T* input_data, T* out);
void multidim_pil_func(const T* input_data, T* out, const interpolate_pil::filter& filterp);
};
template <typename T>
@ -564,6 +576,106 @@ void InterpolateEval<T>::cubic_func(const T* input_data, T* out) {
NGRAPH_SUPPRESS_DEPRECATED_END
}
template <typename T>
void InterpolateEval<T>::bilinear_pil_func(const T* input_data, T* out) {
struct interpolate_pil::filter bilinear = {interpolate_pil::bilinear_filter, 1.0, m_cube_coeff};
multidim_pil_func(input_data, out, bilinear);
}
template <typename T>
void InterpolateEval<T>::bicubic_pil_func(const T* input_data, T* out) {
struct interpolate_pil::filter bicubic = {interpolate_pil::bicubic_filter, 2.0, m_cube_coeff};
multidim_pil_func(input_data, out, bicubic);
}
template <typename T>
void InterpolateEval<T>::multidim_pil_func(const T* input_data, T* out, const interpolate_pil::filter& filterp) {
OPENVINO_ASSERT(m_axes.size() == 2, "For Pillow based modes exactly two (HW) axes need to be provided.");
auto h_dim_idx = m_axes[0];
auto w_dim_idx = m_axes[1];
auto h_dim_in = m_input_data_shape[h_dim_idx];
auto w_dim_in = m_input_data_shape[w_dim_idx];
auto h_dim_out = m_out_shape[h_dim_idx];
auto w_dim_out = m_out_shape[w_dim_idx];
auto in_matrix_elem_size = h_dim_in * w_dim_in;
auto out_matrix_elem_size = h_dim_out * w_dim_out;
auto box = std::vector<float>{0.f, 0.f, static_cast<float>(w_dim_in), static_cast<float>(h_dim_in)};
if (shape_size(m_input_data_shape) == in_matrix_elem_size) {
// Input data is 2D or ND with other dimensions equal 1
interpolate_pil::imaging_resample_inner(input_data,
w_dim_in,
h_dim_in,
w_dim_out,
h_dim_out,
filterp,
box.data(),
out);
} else {
// Flatten other dimensions and interpolate over 2D matrices
std::vector<int64_t> in_transp_axes_order;
for (size_t i = 0; i < m_input_data_shape.size(); ++i) {
if (std::find(m_axes.begin(), m_axes.end(), i) == m_axes.end()) {
in_transp_axes_order.push_back(i);
}
}
in_transp_axes_order.insert(in_transp_axes_order.end(), m_axes.begin(), m_axes.end());
Shape transp_input_shape;
Shape transp_output_shape;
for (auto&& axis : in_transp_axes_order) {
transp_input_shape.push_back(m_input_data_shape[axis]);
transp_output_shape.push_back(m_out_shape[axis]);
}
size_t flat_batch_size =
transp_input_shape.size() > 2
? shape_size(transp_input_shape.begin(), transp_input_shape.begin() + transp_input_shape.size() - 2)
: 1;
// Transpose HW dimensions to the end of the tensor shape
std::vector<T> transposed_in(input_data, input_data + shape_size(m_input_data_shape));
transpose(reinterpret_cast<const char*>(input_data),
reinterpret_cast<char*>(transposed_in.data()),
m_input_data_shape,
sizeof(T),
in_transp_axes_order.data(),
transp_input_shape);
std::vector<T> transposed_out(shape_size(m_out_shape));
T* in_matrix_ptr = transposed_in.data();
T* out_matrix_ptr = transposed_out.data();
// Resample each 2D matrix
for (size_t i = 0; i < flat_batch_size; ++i) {
interpolate_pil::imaging_resample_inner(in_matrix_ptr,
w_dim_in,
h_dim_in,
w_dim_out,
h_dim_out,
filterp,
box.data(),
out_matrix_ptr);
in_matrix_ptr += in_matrix_elem_size;
out_matrix_ptr += out_matrix_elem_size;
}
std::vector<int64_t> out_transp_axes_order(m_out_shape.size() - 2);
std::iota(out_transp_axes_order.begin(), out_transp_axes_order.end(), 0);
out_transp_axes_order.insert(out_transp_axes_order.begin() + h_dim_idx, transp_input_shape.size() - 2);
out_transp_axes_order.insert(out_transp_axes_order.begin() + w_dim_idx, transp_input_shape.size() - 1);
// Transpose back to the original data dimensions order
transpose(reinterpret_cast<const char*>(transposed_out.data()),
reinterpret_cast<char*>(out),
transp_output_shape,
sizeof(T),
out_transp_axes_order.data(),
m_out_shape);
}
}
template <typename T>
void InterpolateEval<T>::nearest_func(const T* input_data, T* out) {
NGRAPH_SUPPRESS_DEPRECATED_START

View File

@ -0,0 +1,320 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
// The implementation for BILINEAR_PILLOW and BICUBIC_PILLOW is based on the
// Pillow library code from:
// https://github.com/python-pillow/Pillow/blob/9.4.0/src/libImaging/Resample.c
// The Python Imaging Library (PIL) is
// Copyright © 1997-2011 by Secret Labs AB
// Copyright © 1995-2011 by Fredrik Lundh
// Pillow is the friendly PIL fork. It is
// Copyright © 2010-2023 by Jeffrey A. Clark (Alex) and contributors.
// Like PIL, Pillow is licensed under the open source HPND License:
// By obtaining, using, and/or copying this software and/or its associated
// documentation, you agree that you have read, understood, and will comply
// with the following terms and conditions:
// Permission to use, copy, modify and distribute this software and its
// documentation for any purpose and without fee is hereby granted,
// provided that the above copyright notice appears in all copies, and that
// both that copyright notice and this permission notice appear in supporting
// documentation, and that the name of Secret Labs AB or the author not be
// used in advertising or publicity pertaining to distribution of the software
// without specific, written prior permission.
// SECRET LABS AB AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS
// SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS.
// IN NO EVENT SHALL SECRET LABS AB OR THE AUTHOR BE LIABLE FOR ANY SPECIAL,
// INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
// LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
// OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
// PERFORMANCE OF THIS SOFTWARE.
#pragma once
#include <algorithm>
#include <cmath>
#include "ngraph/op/interpolate.hpp"
#include "ngraph/shape_util.hpp"
namespace ngraph {
namespace runtime {
namespace reference {
namespace interpolate_pil {
struct filter {
double (*filter)(double x, double coeff_a);
double support;
double coeff_a;
};
template <typename T_out, typename T_in>
T_out round_up(T_in x) {
return (T_out)(x >= 0.0 ? x + 0.5F : x - 0.5F);
}
template <typename T_out, typename T_in>
T_out clip(const T_in& x,
const T_out& min = std::numeric_limits<T_out>::min(),
const T_out& max = std::numeric_limits<T_out>::max()) {
return T_out(std::max(T_in(min), std::min(x, T_in(max))));
}
static inline double bilinear_filter(double x, double) {
if (x < 0.0) {
x = -x;
}
if (x < 1.0) {
return 1.0 - x;
}
return 0.0;
}
static inline double bicubic_filter(double x, double a) {
if (x < 0.0) {
x = -x;
}
if (x < 1.0) {
return ((a + 2.0) * x - (a + 3.0)) * x * x + 1;
}
if (x < 2.0) {
return (((x - 5) * x + 8) * x - 4) * a;
}
return 0.0;
}
static inline int precompute_coeffs(int in_size,
float in0,
float in1,
int out_size,
const filter& filterp,
std::vector<int>& bounds,
std::vector<double>& kk) {
double support, scale, filterscale;
double center, ww, ss;
int xx, x, ksize, xmin, xmax;
/* prepare for horizontal stretch */
filterscale = scale = (double)(in1 - in0) / out_size;
if (filterscale < 1.0) {
filterscale = 1.0;
}
/* determine support size (length of resampling filter) */
support = filterp.support * filterscale;
/* maximum number of coeffs */
ksize = (int)ceil(support) * 2 + 1;
/* coefficient buffer */
kk.resize(out_size * ksize);
bounds.resize(out_size * 2);
for (xx = 0; xx < out_size; xx++) {
center = in0 + (xx + 0.5) * scale;
ww = 0.0;
ss = 1.0 / filterscale;
// Round the value
xmin = (int)(center - support + 0.5);
if (xmin < 0) {
xmin = 0;
}
// Round the value
xmax = (int)(center + support + 0.5);
if (xmax > in_size) {
xmax = in_size;
}
xmax -= xmin;
double* k = &kk[xx * ksize];
for (x = 0; x < xmax; x++) {
double w = filterp.filter((x + xmin - center + 0.5) * ss, filterp.coeff_a);
k[x] = w;
ww += w;
}
for (x = 0; x < xmax; x++) {
if (ww != 0.0) {
k[x] /= ww;
}
}
// Remaining values should stay empty if they are used despite of xmax.
for (; x < ksize; x++) {
k[x] = 0;
}
bounds[xx * 2 + 0] = xmin;
bounds[xx * 2 + 1] = xmax;
}
return ksize;
}
template <typename T>
void imaging_resample_horizontal(T* im_out,
Shape im_out_shape,
const T* im_in,
Shape im_in_shape,
int offset,
int ksize,
std::vector<int>& bounds,
std::vector<double>& kk) {
double ss;
int x, xmin, xmax;
double* k;
for (size_t yy = 0; yy < im_out_shape[0]; yy++) {
for (size_t xx = 0; xx < im_out_shape[1]; xx++) {
xmin = bounds[xx * 2 + 0];
xmax = bounds[xx * 2 + 1];
k = &kk[xx * ksize];
ss = 0.0;
for (x = 0; x < xmax; x++) {
size_t in_idx = (yy + offset) * im_in_shape[1] + (x + xmin);
ss += im_in[in_idx] * k[x];
}
size_t out_idx = yy * im_out_shape[1] + xx;
if (std::is_integral<T>()) {
im_out[out_idx] = T(clip<T, int64_t>(round_up<int64_t, double>(ss)));
} else {
im_out[out_idx] = T(ss);
}
}
}
}
template <typename T>
void imaging_resample_vertical(T* im_out,
Shape im_out_shape,
const T* im_in,
Shape im_in_shape,
int offset,
int ksize,
std::vector<int>& bounds,
std::vector<double>& kk) {
double ss;
int y, ymin, ymax;
double* k;
for (size_t yy = 0; yy < im_out_shape[0]; yy++) {
ymin = bounds[yy * 2 + 0];
ymax = bounds[yy * 2 + 1];
k = &kk[yy * ksize];
for (size_t xx = 0; xx < im_out_shape[1]; xx++) {
ss = 0.0;
for (y = 0; y < ymax; y++) {
size_t in_idx = (y + ymin) * im_in_shape[1] + xx;
ss += im_in[in_idx] * k[y];
}
size_t out_idx = yy * im_out_shape[1] + xx;
if (std::is_integral<T>()) {
im_out[out_idx] = T(clip<T, int64_t>(round_up<int64_t, double>(ss)));
} else {
im_out[out_idx] = T(ss);
}
}
}
}
template <typename T>
void imaging_resample_inner(const T* im_in,
size_t im_in_xsize,
size_t im_in_ysize,
size_t xsize,
size_t ysize,
const filter& filterp,
float* box,
T* im_out) {
int ybox_first, ybox_last;
int ksize_horiz, ksize_vert;
std::vector<int> bounds_horiz;
std::vector<int> bounds_vert;
std::vector<double> kk_horiz;
std::vector<double> kk_vert;
auto need_horizontal = xsize != im_in_xsize || bool(box[0]) || box[2] != xsize;
auto need_vertical = ysize != im_in_ysize || bool(box[1]) || box[3] != ysize;
ksize_horiz = precompute_coeffs(static_cast<int>(im_in_xsize),
box[0],
box[2],
static_cast<int>(xsize),
filterp,
bounds_horiz,
kk_horiz);
ksize_vert = precompute_coeffs(static_cast<int>(im_in_ysize),
box[1],
box[3],
static_cast<int>(ysize),
filterp,
bounds_vert,
kk_vert);
// First used row in the source image
ybox_first = bounds_vert[0];
// Last used row in the source image
ybox_last = bounds_vert[ysize * 2 - 2] + bounds_vert[ysize * 2 - 1];
size_t im_temp_ysize = (ybox_last - ybox_first);
auto im_temp_elem_count = im_temp_ysize * xsize;
auto im_temp = std::vector<T>(im_temp_elem_count, 0);
/* two-pass resize, horizontal pass */
if (need_horizontal) {
// Shift bounds for vertical pass
for (size_t i = 0; i < ysize; i++) {
bounds_vert[i * 2] -= ybox_first;
}
if (im_temp.size() > 0) {
imaging_resample_horizontal(im_temp.data(),
Shape{im_temp_ysize, xsize},
im_in,
Shape{im_in_ysize, im_in_xsize},
ybox_first,
ksize_horiz,
bounds_horiz,
kk_horiz);
}
}
/* vertical pass */
if (need_vertical) {
/* im_in can be the original image or horizontally resampled one */
if (need_horizontal) {
imaging_resample_vertical(im_out,
Shape{ysize, xsize},
im_temp.data(),
Shape{im_temp_ysize, xsize},
0,
ksize_vert,
bounds_vert,
kk_vert);
} else {
imaging_resample_vertical(im_out,
Shape{ysize, xsize},
im_in,
Shape{im_in_ysize, im_in_xsize},
0,
ksize_vert,
bounds_vert,
kk_vert);
}
}
/* none of the previous steps are performed, copying */
if (!need_horizontal && !need_vertical) {
std::copy(im_in, im_in + (im_in_xsize * im_in_ysize), im_out);
} else if (need_horizontal && !need_vertical) {
std::copy(im_temp.begin(), im_temp.end(), im_out);
}
return;
}
} // namespace interpolate_pil
} // namespace reference
} // namespace runtime
} // namespace ngraph

View File

@ -32,7 +32,7 @@ target_compile_definitions(interpreter_backend
SHARED_LIB_PREFIX="${CMAKE_SHARED_LIBRARY_PREFIX}"
SHARED_LIB_SUFFIX="${IE_BUILD_POSTFIX}${CMAKE_SHARED_LIBRARY_SUFFIX}"
)
target_link_libraries(interpreter_backend PRIVATE ngraph::builder ngraph::reference openvino::util openvino::runtime::dev)
target_link_libraries(interpreter_backend PRIVATE ngraph::builder ngraph::reference openvino::util openvino::runtime::dev ov_shape_inference)
target_include_directories(interpreter_backend PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>)

View File

@ -91,6 +91,7 @@
#include <ngraph/runtime/reference/utils/nms_common.hpp>
#include "backend.hpp"
#include "interpolate_shape_inference.hpp"
#include "ngraph/ops.hpp"
#include "ngraph/runtime/reference/convert_color_nv12.hpp"
#include "ov_ops/augru_cell.hpp"
@ -4147,6 +4148,195 @@ bool evaluate(const shared_ptr<op::v0::Interpolate>& op,
return true;
}
namespace eval {
namespace interpolate {
// The helpers below are similar to the internal utils used in evaluate method of v4::Intepolate core op
// Those functions can be unified and moved to a common place
std::vector<int64_t> get_axes_vector(const ngraph::HostTensorVector& args,
size_t default_size,
size_t axes_port,
size_t max_num_of_ports) {
size_t num_of_inputs = args.size();
std::vector<int64_t> axes;
if (num_of_inputs == max_num_of_ports) {
auto axes_arg = args[axes_port];
size_t num_of_axes = args[axes_port]->get_shape()[0];
axes.reserve(num_of_axes);
if (axes_arg->get_element_type() == ov::element::i64) {
int64_t* axes_ptr = axes_arg->get_data_ptr<int64_t>();
axes.insert(axes.end(), axes_ptr, axes_ptr + num_of_axes);
} else if (axes_arg->get_element_type() == ov::element::i32) {
int32_t* axes_ptr = axes_arg->get_data_ptr<int32_t>();
for (size_t i = 0; i < num_of_axes; ++i)
axes.push_back(axes_ptr[i]);
} else {
OPENVINO_ASSERT(false, "Failed to process ", axes_arg->get_element_type());
}
} else {
for (size_t i = 0; i < default_size; ++i) {
axes.push_back(i);
}
}
return axes;
}
std::vector<int64_t> get_target_shape_vector(const ngraph::HostTensorVector& args,
size_t num_of_axes,
size_t target_shape_port = 1) {
std::vector<int64_t> target_shape;
target_shape.reserve(num_of_axes);
auto target_shape_arg = args[target_shape_port];
if (target_shape_arg->get_element_type() == ov::element::i64) {
int64_t* target_shape_ptr = target_shape_arg->get_data_ptr<int64_t>();
target_shape.insert(target_shape.end(), target_shape_ptr, target_shape_ptr + num_of_axes);
} else if (target_shape_arg->get_element_type() == ov::element::i32) {
int32_t* target_shape_ptr = target_shape_arg->get_data_ptr<int32_t>();
for (size_t i = 0; i < num_of_axes; ++i)
target_shape.push_back(target_shape_ptr[i]);
} else {
OPENVINO_ASSERT(false, "Failed to process ", target_shape_arg->get_element_type());
}
return target_shape;
}
std::vector<float> get_scales_vector(const ngraph::HostTensorVector& args,
const ov::Shape& input_shape,
const ov::op::util::InterpolateBase::InterpolateAttrs& attrs,
std::vector<int64_t> axes,
size_t scales_port) {
std::vector<float> scales;
size_t num_of_axes = axes.size();
if (attrs.shape_calculation_mode == ov::op::util::InterpolateBase::ShapeCalcMode::SCALES) {
float* scales_ptr = args[scales_port]->get_data_ptr<float>();
scales.insert(scales.end(), scales_ptr, scales_ptr + num_of_axes);
} else {
auto target_shape = get_target_shape_vector(args, num_of_axes);
for (size_t i = 0; i < num_of_axes; ++i) {
size_t axis = axes[i];
float scale = static_cast<float>(target_shape[i]) / static_cast<float>(input_shape[axis]);
scales.push_back(scale);
}
}
return scales;
}
static void pad_input_data(const uint8_t* data_ptr,
uint8_t* padded_data_ptr,
size_t type_size,
const ov::Shape& input_shape,
const ov::Shape& padded_input_shape,
const std::vector<size_t>& pads_begin) {
NGRAPH_SUPPRESS_DEPRECATED_START
ngraph::CoordinateTransform input_transform(input_shape);
ngraph::CoordinateTransform padded_transform(padded_input_shape);
for (const ngraph::Coordinate& input_coord : input_transform) {
auto padded_coord = input_coord;
size_t i = 0;
for (size_t pad : pads_begin) {
padded_coord[i] += pad;
++i;
}
uint8_t* dst_ptr = padded_data_ptr + type_size * padded_transform.index(padded_coord);
const uint8_t* src_ptr = data_ptr + type_size * input_transform.index(input_coord);
memcpy(dst_ptr, src_ptr, type_size);
}
NGRAPH_SUPPRESS_DEPRECATED_END
}
namespace v11 {
bool evaluate_interpolate(const shared_ptr<op::v11::Interpolate>& op,
const HostTensorVector& outputs,
const HostTensorVector& inputs) {
using namespace ov::op;
constexpr size_t data_port = 0;
constexpr size_t scales_sizes_port = 1;
constexpr size_t axes_port = 2;
constexpr size_t max_num_of_ports = 3;
element::Type input_et = inputs[0]->get_element_type();
size_t type_size = input_et.size();
ov::PartialShape input_shape{inputs[data_port]->get_shape()};
auto m_attrs = op->get_attrs();
util::correct_pads_attr(op.get(), m_attrs.pads_begin, m_attrs.pads_end, std::vector<PartialShape>{input_shape});
ov::Shape padded_input_shape;
for (size_t i = 0; i < input_shape.size(); ++i) {
padded_input_shape.emplace_back(m_attrs.pads_begin[i] + m_attrs.pads_end[i] + input_shape[i].get_length());
}
auto axes = get_axes_vector(inputs, inputs[1]->get_shape().size(), axes_port, max_num_of_ports);
auto scales = get_scales_vector(inputs, padded_input_shape, m_attrs, axes, scales_sizes_port);
ov::PartialShape output_shape{padded_input_shape};
if (m_attrs.shape_calculation_mode == util::InterpolateBase::ShapeCalcMode::SCALES) {
util::infer_using_scales(output_shape, axes, scales);
} else {
auto sizes = get_target_shape_vector(inputs, axes.size(), scales_sizes_port);
for (size_t i = 0; i < sizes.size(); ++i) {
output_shape[axes[i]] = Dimension(sizes[i]);
}
}
ov::Shape out_shape = output_shape.to_shape();
outputs[0]->set_shape(out_shape);
outputs[0]->set_element_type(input_et);
size_t bytes_in_padded_input = shape_size(padded_input_shape) * type_size;
std::vector<uint8_t> padded_input_data(bytes_in_padded_input, 0);
const uint8_t* data_ptr = inputs[0]->get_data_ptr<uint8_t>();
uint8_t* padded_data_ptr = padded_input_data.data();
pad_input_data(data_ptr,
padded_data_ptr,
type_size,
input_shape.to_shape(),
padded_input_shape,
m_attrs.pads_begin);
switch (input_et) {
case element::Type_t::f32:
ngraph::runtime::reference::interpolate<float>(reinterpret_cast<float*>(padded_data_ptr),
padded_input_shape,
scales,
axes,
outputs[0]->get_data_ptr<float>(),
out_shape,
m_attrs);
break;
case element::Type_t::u8:
ngraph::runtime::reference::interpolate<uint8_t>(reinterpret_cast<uint8_t*>(padded_data_ptr),
padded_input_shape,
scales,
axes,
outputs[0]->get_data_ptr<uint8_t>(),
out_shape,
m_attrs);
break;
default:;
}
return true;
}
} // namespace v11
} // namespace interpolate
} // namespace eval
template <element::Type_t ET>
bool evaluate(const shared_ptr<op::v11::Interpolate>& op,
const HostTensorVector& outputs,
const HostTensorVector& inputs) {
return eval::interpolate::v11::evaluate_interpolate(op, outputs, inputs);
}
template <element::Type_t ET>
bool evaluate(const shared_ptr<op::v9::SoftSign>& op, const HostTensorVector& outputs, const HostTensorVector& inputs) {
element::Type input_et = op->get_input_element_type(0);

View File

@ -146,5 +146,7 @@ _OPENVINO_OP_REG(IsInf, op::v10)
_OPENVINO_OP_REG(IsNaN, op::v10)
_OPENVINO_OP_REG(Unique, op::v10)
_OPENVINO_OP_REG(Interpolate, op::v11)
_OPENVINO_OP_REG(AUGRUCell, ov::op::internal)
_OPENVINO_OP_REG(AUGRUSequence, ov::op::internal)

View File

@ -777,7 +777,8 @@ private:
auto scales = op::v0::Constant::create<float>(element::f32, Shape{scales_data.size()}, scales_data);
const auto& axes_data = param.axes_data;
auto axes = op::v0::Constant::create<int64_t>(element::i64, Shape{axes_data.size()}, axes_data);
auto interpolate = std::make_shared<op::v4::Interpolate>(image, target_spatial_shape, scales, axes, param.attrs);
auto interpolate =
std::make_shared<op::v4::Interpolate>(image, target_spatial_shape, scales, axes, param.attrs);
return std::make_shared<Model>(NodeVector{interpolate}, ParameterVector{image});
}
};
@ -792,4 +793,538 @@ INSTANTIATE_TEST_SUITE_P(smoke,
ReferenceInterpolate_v4::getTestCaseName);
} // namespace attribute_tests
namespace interpolate_v11_tests {
using InterpolateAttrs = op::v11::Interpolate::InterpolateAttrs;
using InterpolateMode = op::v11::Interpolate::InterpolateMode;
using ShapeCalcMode = op::v11::Interpolate::ShapeCalcMode;
using CoordinateTransformMode = op::v11::Interpolate::CoordinateTransformMode;
using TransformMode = op::v11::Interpolate::CoordinateTransformMode;
using NearestMode = op::v11::Interpolate::NearestMode;
class InterpolateV11TestParams {
public:
template <class Data_t = float>
InterpolateV11TestParams(std::string test_name,
Shape input_data_shape,
std::vector<int64_t> spatial_shape_data,
Shape output_shape,
std::vector<float> scales_data,
std::vector<int64_t> axes_data,
InterpolateAttrs interp_attrs,
std::vector<Data_t> input_data,
std::vector<Data_t> expected_results,
double cube_coeff_a = -0.75,
element::Type inType = element::from<Data_t>())
: test_name(test_name),
input_data_shape(input_data_shape),
spatial_shape_data(spatial_shape_data),
output_shape(output_shape),
scales_data(scales_data),
axes_data(axes_data),
attrs(interp_attrs),
m_input_data(CreateTensor(inType, input_data)),
m_expected_result(CreateTensor(inType, expected_results)),
inType(inType) {
attrs.cube_coeff = cube_coeff_a;
};
template <class Data_t = float>
InterpolateV11TestParams(const attribute_tests::InterpolateV4TestParams& v4_params)
: test_name(v4_params.test_name),
input_data_shape(v4_params.input_data_shape),
spatial_shape_data(v4_params.spatial_shape_data),
output_shape(v4_params.output_shape),
scales_data(v4_params.scales_data),
axes_data(v4_params.axes_data),
attrs(v4_params.attrs),
m_input_data(CreateTensor(element::from<Data_t>(), v4_params.input_data)),
m_expected_result(CreateTensor(element::from<Data_t>(), v4_params.expected_results)),
inType(element::from<Data_t>()){};
std::string test_name;
Shape input_data_shape;
std::vector<int64_t> spatial_shape_data;
Shape output_shape;
std::vector<float> scales_data;
std::vector<int64_t> axes_data;
InterpolateAttrs attrs;
ov::Tensor m_input_data;
ov::Tensor m_expected_result;
element::Type inType;
};
template <typename Data_t = uint8_t>
std::vector<InterpolateV11TestParams> generateParamsForInterpolate_bilinear_pil_int() {
const std::vector<size_t> zero_pads{0, 0, 0, 0};
return {
{
"bilinear.downsample_sizes_linear_range_h_pixel_hw_2D",
Shape{8, 8},
{4, 4},
Shape{4, 4},
{},
{0, 1},
{InterpolateMode::BILINEAR_PILLOW, ShapeCalcMode::SIZES, {0, 0}, {0, 0}},
std::vector<Data_t>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43,
44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63},
std::vector<Data_t>{7, 9, 11, 12, 21, 23, 25, 26, 37, 39, 41, 42, 51, 53, 55, 56},
},
{
"bilinear.downsample_scales_linear_range_h_pixel_hw_2D_scales",
Shape{8, 8},
{},
Shape{4, 4},
{0.5f, 0.5f},
{0, 1},
{InterpolateMode::BILINEAR_PILLOW, ShapeCalcMode::SCALES, {0, 0}, {0, 0}},
std::vector<Data_t>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43,
44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63},
std::vector<Data_t>{7, 9, 11, 12, 21, 23, 25, 26, 37, 39, 41, 42, 51, 53, 55, 56},
},
{
"bilinear.downsample_scales_linear_rand_h_pixel_nhwc",
Shape{1, 4, 4, 3},
{},
Shape{1, 2, 2, 3},
{0.5f, 0.5f},
{1, 2},
{InterpolateMode::BILINEAR_PILLOW, ShapeCalcMode::SCALES, zero_pads, zero_pads},
std::vector<Data_t>{172, 10, 127, 140, 47, 170, 196, 151, 117, 166, 22, 183, 192, 204, 33, 216,
67, 179, 78, 154, 251, 82, 162, 219, 195, 118, 125, 139, 103, 125, 229, 216,
9, 164, 116, 108, 211, 222, 161, 159, 21, 81, 89, 165, 242, 214, 102, 98},
std::vector<Data_t>{174, 97, 132, 144, 119, 173, 175, 129, 124, 160, 138, 129},
},
{
"bilinear.downsample_scales_linear_range_h_pixel_nhwc",
Shape{1, 4, 4, 3},
{},
Shape{1, 2, 2, 3},
{0.5f, 0.5f},
{1, 2},
{InterpolateMode::BILINEAR_PILLOW, ShapeCalcMode::SCALES, zero_pads, zero_pads},
std::vector<Data_t>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47},
std::vector<Data_t>{11, 12, 13, 16, 17, 18, 29, 30, 31, 34, 35, 36},
},
{
"bilinear.downsample_scales_linear_rand_h_pixel_nhwc_batch_2",
Shape{2, 4, 4, 3},
{},
Shape{2, 2, 2, 3},
{0.5f, 0.5f},
{1, 2},
{InterpolateMode::BILINEAR_PILLOW, ShapeCalcMode::SCALES, zero_pads, zero_pads},
std::vector<Data_t>{172, 10, 127, 140, 47, 170, 196, 151, 117, 166, 22, 183, 192, 204, 33, 216,
67, 179, 78, 154, 251, 82, 162, 219, 195, 118, 125, 139, 103, 125, 229, 216,
9, 164, 116, 108, 211, 222, 161, 159, 21, 81, 89, 165, 242, 214, 102, 98,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47},
std::vector<Data_t>{174, 97, 132, 144, 119, 173, 175, 129, 124, 160, 138, 129,
11, 12, 13, 16, 17, 18, 29, 30, 31, 34, 35, 36},
},
{
"bilinear.downsample_sizes_nhwc_1x5x6x3_to_1x2x4x3",
Shape{1, 5, 6, 3},
{2, 4},
Shape{1, 2, 4, 3},
{},
{1, 2},
{InterpolateMode::BILINEAR_PILLOW, ShapeCalcMode::SIZES, zero_pads, zero_pads},
std::vector<Data_t>{37, 244, 193, 106, 235, 128, 71, 255, 140, 47, 103, 184, 72, 20, 188,
238, 255, 126, 7, 0, 137, 195, 204, 32, 203, 170, 101, 77, 133, 30,
193, 255, 79, 203, 145, 37, 192, 83, 112, 60, 144, 128, 163, 23, 129,
80, 134, 101, 204, 191, 174, 47, 71, 30, 78, 99, 237, 170, 118, 88,
252, 121, 116, 171, 134, 141, 146, 101, 25, 125, 127, 239, 178, 228, 239,
137, 20, 213, 167, 216, 254, 84, 80, 107, 101, 177, 50, 80, 146, 139},
std::vector<Data_t>{89 /* 90 */, 137, 129, 138, 169, 107, 109, 140, 113, 168, 161, 95,
134, 119, 178, 171, 118, 148, 138, 130, 106, 116, 133, 120},
},
{
"bilinear.upsample_sizes_nhwc_1x2x4x3_to_1x5x6x3",
Shape{1, 2, 4, 3},
{5, 6},
Shape{1, 5, 6, 3},
{},
{1, 2},
{InterpolateMode::BILINEAR_PILLOW, ShapeCalcMode::SIZES, zero_pads, zero_pads},
std::vector<Data_t>{37, 244, 193, 106, 235, 128, 71, 255, 140, 47, 103, 184,
72, 20, 188, 238, 255, 126, 7, 0, 137, 195, 204, 32},
std::vector<Data_t>{37, 244, 193, 72, 240, 161, 100, 238, 130, 77, 252, 138, 59, 179, 162, 47,
103, 184, 41 /* 40 */, 222, 193, 80, 230, 161, 110, 235, // Rounding?
130, 74, 231, 138, 63, 171, 154, 62, 113, 169, 55, 132, 191, 114, 189, 159,
150, 225, 129, 62, 148, 137, 80, 141, 124, 121, 154, 108, 69, 42, 188, 147,
148, 157, 189, 215, 128, 49, 64, 135, 97, 110, 93, 180, 194, 47, 72, 20,
188, 155, 138, 157, 199, 212, 128, 46, 43, 135, 101, 102, 85, 195, 204, 32},
},
{
"bilinear.downsample_sizes_nchw_1x3x5x6_to_1x3x2x4",
Shape{1, 3, 5, 6},
{2, 4},
Shape{1, 3, 2, 4},
{},
{2, 3},
{InterpolateMode::BILINEAR_PILLOW, ShapeCalcMode::SIZES, zero_pads, zero_pads},
std::vector<Data_t>{37, 106, 71, 47, 72, 238, 7, 195, 203, 77, 193, 203, 192, 60, 163,
80, 204, 47, 78, 170, 252, 171, 146, 125, 178, 137, 167, 84, 101, 80,
244, 235, 255, 103, 20, 255, 0, 204, 170, 133, 255, 145, 83, 144, 23,
134, 191, 71, 99, 118, 121, 134, 101, 127, 228, 20, 216, 80, 177, 146,
193, 128, 140, 184, 188, 126, 137, 32, 101, 30, 79, 37, 112, 128, 129,
101, 174, 30, 237, 88, 116, 141, 25, 239, 239, 213, 254, 107, 50, 139},
std::vector<Data_t>{89 /* 90 */, 138, 109, 168, 134, 171, 138, 116, 137, 169, 140, 161,
119, 118, 130, 133, 129, 107, 113, 95, 178, 148, 106, 120},
},
{
"bilinear.downsample_scales_range_h_pixel_nchw",
Shape{1, 3, 4, 4},
{},
Shape{1, 3, 2, 2},
{0.5f, 0.5f},
{2, 3},
{InterpolateMode::BILINEAR_PILLOW, ShapeCalcMode::SCALES, zero_pads, zero_pads},
std::vector<Data_t>{0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45,
1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 43, 46,
2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 32, 35, 38, 41, 44, 47},
std::vector<Data_t>{11, 16, 29, 34, 12, 17, 30, 35, 13, 18, 31, 36},
}};
}
template <typename Data_t = uint8_t>
std::vector<InterpolateV11TestParams> generateParamsForInterpolate_bicubic_pil_int() {
return {
{
"bicubic.downsample_scales_2D",
Shape{8, 8},
{},
Shape{4, 4},
{0.5f, 0.5f},
{0, 1},
{InterpolateMode::BICUBIC_PILLOW, ShapeCalcMode::SCALES, {0, 0}, {0, 0}},
std::vector<Data_t>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43,
44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63},
std::vector<Data_t>{5, 6, 9, 10, 21, 22, 25, 26, 37, 38, 41, 42, 53, 54, 57, 58},
-0.5, // cube_coeff
},
{
"bicubic.downsample_sizes_2D",
Shape{8, 8},
{4, 4},
Shape{4, 4},
{},
{0, 1},
{InterpolateMode::BICUBIC_PILLOW, ShapeCalcMode::SIZES, {0, 0}, {0, 0}},
std::vector<Data_t>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43,
44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63},
std::vector<Data_t>{5, 6, 9, 10, 21, 22, 25, 26, 37, 38, 41, 42, 53, 54, 57, 58},
-0.5, // cube_coeff
},
{
"bicubic.downsample_sizes_2D",
Shape{5, 6},
{2, 4},
Shape{2, 4},
{},
{0, 1},
{InterpolateMode::BICUBIC_PILLOW, ShapeCalcMode::SIZES, {0, 0}, {0, 0}},
std::vector<Data_t>{168, 92, 157, 111, 15, 138, 97, 47, 237, 25, 163, 6, 72, 118, 121,
238, 22, 174, 182, 140, 43, 121, 158, 242, 210, 73, 113, 111, 75, 132},
std::vector<Data_t>{99, 143, 105, 88, 146, 98, 123, 154},
-0.5, // cube_coeff
},
{
"bicubic.downsample_sizes_2D_ov_default_cube_coeff",
Shape{5, 6},
{2, 4},
Shape{2, 4},
{},
{0, 1},
{InterpolateMode::BICUBIC_PILLOW, ShapeCalcMode::SIZES, {0, 0}, {0, 0}},
std::vector<Data_t>{168, 92, 157, 111, 15, 138, 97, 47, 237, 25, 163, 6, 72, 118, 121,
238, 22, 174, 182, 140, 43, 121, 158, 242, 210, 73, 113, 111, 75, 132},
std::vector<Data_t>{97, 144, 106, 88, 145, 98, 121, 153},
// default cube_coeff -0.75
},
{
"bicubic.downsample_sizes_1x1x8x8_nchw",
Shape{1, 1, 8, 8},
{4, 4},
Shape{4, 4},
{},
{2, 3},
{InterpolateMode::BICUBIC_PILLOW, ShapeCalcMode::SIZES, {0, 0}, {0, 0}},
std::vector<Data_t>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43,
44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63},
std::vector<Data_t>{5, 6, 9, 10, 21, 22, 25, 26, 37, 38, 41, 42, 53, 54, 57, 58},
-0.5, // cube_coeff
},
{
"bicubic.downsample_sizes_1x8x8x1_nhwc",
Shape{1, 8, 8, 1},
{4, 4},
Shape{4, 4},
{},
{1, 2},
{InterpolateMode::BICUBIC_PILLOW, ShapeCalcMode::SIZES, {0, 0}, {0, 0}},
std::vector<Data_t>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43,
44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63},
std::vector<Data_t>{5, 6, 9, 10, 21, 22, 25, 26, 37, 38, 41, 42, 53, 54, 57, 58},
-0.5, // cube_coeff
},
{
"bicubic.upsample_sizes_1x2x4x3_to_1x5x6x3_nhwc",
Shape{1, 2, 4, 3},
{5, 6},
Shape{1, 5, 6, 3},
{},
{1, 2},
{InterpolateMode::BICUBIC_PILLOW, ShapeCalcMode::SIZES, {0, 0}, {0, 0}},
std::vector<Data_t>{168, 92, 157, 111, 15, 138, 97, 47, 237, 25, 163, 6,
72, 118, 121, 238, 22, 174, 182, 140, 43, 121, 158, 242},
std::vector<Data_t>{183, 94, 162, 141, 53, 141, 94, 11, 150, 93, 27, 255, 49, 105, 119, 10, 172, 0,
165, 99, 155, 143, 55, 143, 116, 14, 152, 108, 42, 226, 64, 113, 122, 26, 170, 17,
117, 111, 138, 148, 60, 148, 175, 22, 155, 148, 80, 143, 102, 133, 131, 69, 165, 128,
68, 122, 121, 152, 65, 153, 233, 29, 158, 188, 118, 60, 140, 153, 140, 111, 160, 238,
50, 127, 114, 154, 67, 155, 255, 32, 160, 203, 133, 29, 155, 161, 143, 127, 158, 255},
-0.5, // cube_coeff
},
{
"bicubic.downsample_sizes_1x5x6x3_to_1x2x4x3_nhwc",
Shape{1, 5, 6, 3},
{2, 4},
Shape{1, 2, 4, 3},
{},
{1, 2},
{InterpolateMode::BICUBIC_PILLOW, ShapeCalcMode::SIZES, {0, 0}, {0, 0}},
std::vector<Data_t>{168, 92, 157, 111, 15, 138, 97, 47, 237, 25, 163, 6, 72, 118, 121,
238, 22, 174, 182, 140, 43, 121, 158, 242, 210, 73, 113, 111, 75, 132,
24, 124, 104, 57, 157, 107, 7, 173, 14, 82, 162, 210, 144, 84, 177,
129, 136, 39, 95, 218, 99, 52, 75, 170, 232, 178, 213, 138, 136, 158,
47, 20, 181, 30, 63, 43, 182, 76, 31, 125, 52, 124, 218, 202, 78,
68, 148, 25, 251, 161, 124, 160, 2, 159, 116, 78, 119, 209, 37, 219},
std::vector<Data_t>{126, 125, 124, 133, 79, 181, 77, 127, 79, 95, 111, 131,
147, 178, 119, 124, 102, 144, 117, 75, 84, 135, 78, 134},
-0.5, // cube_coeff
}};
}
template <typename Data_t = float>
std::vector<InterpolateV11TestParams> generateParamsForInterpolate_bilinear_pil_float() {
return {
{
"bilinear.downsample_2D_sizes",
Shape{5, 6},
{2, 4},
Shape{2, 4},
{},
{0, 1},
{InterpolateMode::BILINEAR_PILLOW, ShapeCalcMode::SIZES, {0, 0}, {0, 0}},
std::vector<Data_t>{121.14, 131.03, 193.32, 243.32, 8.92, 36.9, 210.67, 242.85, 63.8, 79.83,
222.47, 108.37, 69.93, 211.89, 65.79, 104.75, 164.82, 140.7, 21.95, 7.06,
221.59, 192.9, 214.5, 137.76, 209.29, 84.41, 115.89, 201.84, 31.72, 77.62},
std::vector<Data_t>{159.58046, 141.59782, 138.78581, 111.842384, 96.50358, 129.36433, 159.38596, 128.2533},
},
{
"bilinear.downsample_to_2x4_2D_scales",
Shape{5, 6},
{},
Shape{2, 4},
{0.4f, 0.7f},
{0, 1},
{InterpolateMode::BILINEAR_PILLOW, ShapeCalcMode::SCALES, {0, 0}, {0, 0}},
std::vector<Data_t>{121.14, 131.03, 193.32, 243.32, 8.92, 36.9, 210.67, 242.85, 63.8, 79.83,
222.47, 108.37, 69.93, 211.89, 65.79, 104.75, 164.82, 140.7, 21.95, 7.06,
221.59, 192.9, 214.5, 137.76, 209.29, 84.41, 115.89, 201.84, 31.72, 77.62},
std::vector<Data_t>{159.58046, 141.59782, 138.78581, 111.842384, 96.50358, 129.36433, 159.38596, 128.2533},
},
{
"bilinear.downsample_to_2x3_2D_scales",
Shape{5, 6},
{},
Shape{2, 4},
{0.4f, 0.6666f},
{0, 1},
{InterpolateMode::BILINEAR_PILLOW, ShapeCalcMode::SCALES, {0, 0}, {0, 0}},
std::vector<Data_t>{121.14, 131.03, 193.32, 243.32, 8.92, 36.9, 210.67, 242.85, 63.8, 79.83,
222.47, 108.37, 69.93, 211.89, 65.79, 104.75, 164.82, 140.7, 21.95, 7.06,
221.59, 192.9, 214.5, 137.76, 209.29, 84.41, 115.89, 201.84, 31.72, 77.62},
std::vector<Data_t>{158.00597, 137.05489, 121.252205, 102.18909, 147.77483, 137.24052},
},
{
"bilinear.upsample_2D_sizes",
Shape{2, 4},
{5, 6},
Shape{5, 6},
{},
{0, 1},
{InterpolateMode::BILINEAR_PILLOW, ShapeCalcMode::SIZES, {0, 0}, {0, 0}},
std::vector<Data_t>{214.42, 66.97, 27.98, 76.41, 105.94, 208.44, 115.53, 23.53},
std::vector<Data_t>{214.42, 140.695, 60.47167, 34.478333, 52.195, 76.41, 203.57199, 142.34451,
73.72, 44.132, 53.9285, 71.122, 160.18, 148.9425, 126.71333, 82.746666,
60.8625, 49.97, 116.788, 155.5405, 179.70667, 121.361336, 67.7965, 28.818,
105.94, 157.19, 192.955, 131.015, 69.53, 23.53},
}};
}
template <typename Data_t = float>
std::vector<InterpolateV11TestParams> generateParamsForInterpolate_bicubic_pil_float() {
return {
{
"bicubic.downsample_2D_sizes",
Shape{5, 6},
{2, 4},
Shape{2, 4},
{},
{0, 1},
{InterpolateMode::BICUBIC_PILLOW, ShapeCalcMode::SIZES, {0, 0}, {0, 0}},
std::vector<Data_t>{121.14, 131.03, 193.32, 243.32, 8.92, 36.9, 210.67, 242.85, 63.8, 79.83,
222.47, 108.37, 69.93, 211.89, 65.79, 104.75, 164.82, 140.7, 21.95, 7.06,
221.59, 192.9, 214.5, 137.76, 209.29, 84.41, 115.89, 201.84, 31.72, 77.62},
std::vector<Data_t>{162.90814, 143.26627, 138.46507, 109.5325, 92.69513, 126.17204, 164.13477, 127.86513},
-0.5, // cube_coeff
},
{
"bicubic.downsample_to_2x4_2D_scales",
Shape{5, 6},
{},
Shape{2, 4},
{0.4f, 0.7f},
{0, 1},
{InterpolateMode::BICUBIC_PILLOW, ShapeCalcMode::SCALES, {0, 0}, {0, 0}},
std::vector<Data_t>{121.14, 131.03, 193.32, 243.32, 8.92, 36.9, 210.67, 242.85, 63.8, 79.83,
222.47, 108.37, 69.93, 211.89, 65.79, 104.75, 164.82, 140.7, 21.95, 7.06,
221.59, 192.9, 214.5, 137.76, 209.29, 84.41, 115.89, 201.84, 31.72, 77.62},
std::vector<Data_t>{162.90814, 143.26627, 138.46507, 109.5325, 92.69513, 126.17204, 164.13477, 127.86513},
-0.5, // cube_coeff
},
{
"bicubic.downsample_2D_sizes_cube_coeff_ov_default",
Shape{5, 6},
{2, 4},
Shape{2, 4},
{},
{0, 1},
{InterpolateMode::BICUBIC_PILLOW, ShapeCalcMode::SIZES, {0, 0}, {0, 0}},
std::vector<Data_t>{121.14, 131.03, 193.32, 243.32, 8.92, 36.9, 210.67, 242.85, 63.8, 79.83,
222.47, 108.37, 69.93, 211.89, 65.79, 104.75, 164.82, 140.7, 21.95, 7.06,
221.59, 192.9, 214.5, 137.76, 209.29, 84.41, 115.89, 201.84, 31.72, 77.62},
std::vector<
Data_t>{162.548325, 144.773224, 138.243408, 110.827049, 92.899925, 125.124802, 164.711548, 129.240463},
// default cube_coeff -0.75
},
{
"bicubic.downsample_to_2x3_2D_scales",
Shape{5, 6},
{},
Shape{2, 3},
{0.4f, 0.6666f},
{0, 1},
{InterpolateMode::BICUBIC_PILLOW, ShapeCalcMode::SCALES, {0, 0}, {0, 0}},
std::vector<Data_t>{121.14, 131.03, 193.32, 243.32, 8.92, 36.9, 210.67, 242.85, 63.8, 79.83,
222.47, 108.37, 69.93, 211.89, 65.79, 104.75, 164.82, 140.7, 21.95, 7.06,
221.59, 192.9, 214.5, 137.76, 209.29, 84.41, 115.89, 201.84, 31.72, 77.62},
std::vector<Data_t>{162.16028, 136.76193, 118.96405, 95.98418, 151.06361, 137.54117},
-0.5, // cube_coeff
},
{
"bicubic.upsample_2D_sizes",
Shape{2, 4},
{5, 6},
Shape{5, 6},
{},
{0, 1},
{InterpolateMode::BICUBIC_PILLOW, ShapeCalcMode::SIZES, {0, 0}, {0, 0}},
std::vector<Data_t>{214.42, 66.97, 27.98, 76.41, 105.94, 208.44, 115.53, 23.53},
std::vector<Data_t>{236.49521, 146.10538, 38.218796, 17.75709, 50.332058, 85.74947, 215.93185, 148.13255,
63.085896, 35.050694, 51.983547, 75.524284, 161.65862, 153.48294, 128.71808, 80.69401,
56.342354, 48.53678, 107.38538, 158.83333, 194.35027, 126.33732, 60.70116, 21.549273,
86.82202, 160.8605, 219.21736, 143.63092, 62.35265, 11.32409},
-0.5, // cube_coeff
}};
}
std::vector<InterpolateV11TestParams> generateCombinedParamsForInterpolate_v11() {
const std::vector<std::vector<InterpolateV11TestParams>> allTypeParamsV11{
generateParamsForInterpolate_bilinear_pil_float<float>(),
generateParamsForInterpolate_bicubic_pil_float<float>(),
generateParamsForInterpolate_bilinear_pil_int<uint8_t>(),
generateParamsForInterpolate_bicubic_pil_int<uint8_t>()};
const std::vector<std::vector<attribute_tests::InterpolateV4TestParams>> allTypeParamsV4{
attribute_tests::generateParamsForInterpolate_v4_cubic(),
attribute_tests::generateParamsForInterpolate_v4_nearest(),
attribute_tests::generateParamsForInterpolate_v4_linear_onnx(),
attribute_tests::generateParamsForInterpolate_v4_linear_onnx5d()};
std::vector<InterpolateV11TestParams> combinedParams;
for (auto& params : allTypeParamsV11) {
std::move(params.begin(), params.end(), std::back_inserter(combinedParams));
}
for (auto& params : allTypeParamsV4) {
for (auto& param : params) {
combinedParams.emplace_back(param);
}
}
return combinedParams;
}
class ReferenceInterpolate_v11 : public testing::TestWithParam<InterpolateV11TestParams>, public CommonReferenceTest {
public:
void SetUp() override {
const auto& params = GetParam();
function = CreateFunction(params);
inputData = {params.m_input_data};
refOutData = {params.m_expected_result};
}
static std::string getTestCaseName(const testing::TestParamInfo<InterpolateV11TestParams>& obj) {
const auto& param = obj.param;
std::ostringstream result;
result << "data_type=" << param.inType << "; ";
result << "data_shape=" << param.input_data_shape << "; ";
if (param.attrs.mode == InterpolateMode::BICUBIC_PILLOW || param.attrs.mode == InterpolateMode::CUBIC) {
result << "cubic_coeff=" << param.attrs.cube_coeff << "; ";
}
if (!param.test_name.empty()) {
result << "tested_case=" << param.test_name << "; ";
}
return result.str();
}
private:
static std::shared_ptr<Model> CreateFunction(const InterpolateV11TestParams& param) {
auto image = std::make_shared<op::v0::Parameter>(param.inType, param.input_data_shape);
ov::Output<ov::Node> sizes_or_scales;
if (param.attrs.shape_calculation_mode == ShapeCalcMode::SCALES) {
const auto& scales_data = param.scales_data;
sizes_or_scales = op::v0::Constant::create<float>(element::f32, Shape{scales_data.size()}, scales_data);
} else {
const auto& spatial_shape_data = param.spatial_shape_data;
sizes_or_scales =
op::v0::Constant::create<int64_t>(element::i64, Shape{spatial_shape_data.size()}, spatial_shape_data);
}
const auto& axes_data = param.axes_data;
auto axes = op::v0::Constant::create<int64_t>(element::i64, Shape{axes_data.size()}, axes_data);
auto interpolate = std::make_shared<op::v11::Interpolate>(image, sizes_or_scales, axes, param.attrs);
return std::make_shared<Model>(NodeVector{interpolate}, ParameterVector{image});
}
};
TEST_P(ReferenceInterpolate_v11, LayerTest) {
Exec();
}
INSTANTIATE_TEST_SUITE_P(smoke,
ReferenceInterpolate_v11,
::testing::ValuesIn(generateCombinedParamsForInterpolate_v11()),
ReferenceInterpolate_v11::getTestCaseName);
} // namespace interpolate_v11_tests
} // namespace