Fixed interpolate to work with int32 axes, target_shapes (#8379)

This commit is contained in:
Ilya Lavrenov 2021-11-08 13:40:54 +03:00 committed by GitHub
parent 7db0973ba7
commit 279d905011
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -294,9 +294,20 @@ std::vector<int64_t> get_axes_vector(const HostTensorVector& args) {
std::vector<int64_t> axes;
if (num_of_inputs == max_num_of_ports) {
int64_t* axes_data_ptr = args[axes_port]->get_data_ptr<int64_t>();
auto axes_arg = args[axes_port];
size_t num_of_axes = args[axes_port]->get_shape()[0];
axes.insert(axes.end(), axes_data_ptr, axes_data_ptr + num_of_axes);
axes.reserve(num_of_axes);
if (axes_arg->get_element_type() == ngraph::element::i64) {
int64_t* axes_ptr = axes_arg->get_data_ptr<int64_t>();
axes.insert(axes.end(), axes_ptr, axes_ptr + num_of_axes);
} else if (axes_arg->get_element_type() == ngraph::element::i32) {
int32_t* axes_ptr = axes_arg->get_data_ptr<int32_t>();
for (size_t i = 0; i < num_of_axes; ++i)
axes.push_back(axes_ptr[i]);
} else {
OPENVINO_ASSERT(false, "Failed to process ", axes_arg->get_element_type());
}
} else {
for (size_t i = 0; i < input_rank; ++i) {
axes.push_back(i);
@ -308,9 +319,19 @@ std::vector<int64_t> get_axes_vector(const HostTensorVector& args) {
std::vector<int64_t> get_target_shape_vector(const HostTensorVector& args, size_t num_of_axes) {
std::vector<int64_t> target_shape;
target_shape.reserve(num_of_axes);
int64_t* target_shape_ptr = args[target_shape_port]->get_data_ptr<int64_t>();
target_shape.insert(target_shape.end(), target_shape_ptr, target_shape_ptr + num_of_axes);
auto target_shape_arg = args[target_shape_port];
if (target_shape_arg->get_element_type() == ngraph::element::i64) {
int64_t* target_shape_ptr = target_shape_arg->get_data_ptr<int64_t>();
target_shape.insert(target_shape.end(), target_shape_ptr, target_shape_ptr + num_of_axes);
} else if (target_shape_arg->get_element_type() == ngraph::element::i32) {
int32_t* target_shape_ptr = target_shape_arg->get_data_ptr<int32_t>();
for (size_t i = 0; i < num_of_axes; ++i)
target_shape.push_back(target_shape_ptr[i]);
} else {
OPENVINO_ASSERT(false, "Failed to process ", target_shape_arg->get_element_type());
}
return target_shape;
}