Fixed interpolate to work with int32 axes, target_shapes (#8379)
This commit is contained in:
parent
7db0973ba7
commit
279d905011
@ -294,9 +294,20 @@ std::vector<int64_t> get_axes_vector(const HostTensorVector& args) {
|
||||
std::vector<int64_t> axes;
|
||||
|
||||
if (num_of_inputs == max_num_of_ports) {
|
||||
int64_t* axes_data_ptr = args[axes_port]->get_data_ptr<int64_t>();
|
||||
auto axes_arg = args[axes_port];
|
||||
size_t num_of_axes = args[axes_port]->get_shape()[0];
|
||||
axes.insert(axes.end(), axes_data_ptr, axes_data_ptr + num_of_axes);
|
||||
axes.reserve(num_of_axes);
|
||||
|
||||
if (axes_arg->get_element_type() == ngraph::element::i64) {
|
||||
int64_t* axes_ptr = axes_arg->get_data_ptr<int64_t>();
|
||||
axes.insert(axes.end(), axes_ptr, axes_ptr + num_of_axes);
|
||||
} else if (axes_arg->get_element_type() == ngraph::element::i32) {
|
||||
int32_t* axes_ptr = axes_arg->get_data_ptr<int32_t>();
|
||||
for (size_t i = 0; i < num_of_axes; ++i)
|
||||
axes.push_back(axes_ptr[i]);
|
||||
} else {
|
||||
OPENVINO_ASSERT(false, "Failed to process ", axes_arg->get_element_type());
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < input_rank; ++i) {
|
||||
axes.push_back(i);
|
||||
@ -308,9 +319,19 @@ std::vector<int64_t> get_axes_vector(const HostTensorVector& args) {
|
||||
|
||||
std::vector<int64_t> get_target_shape_vector(const HostTensorVector& args, size_t num_of_axes) {
|
||||
std::vector<int64_t> target_shape;
|
||||
target_shape.reserve(num_of_axes);
|
||||
|
||||
int64_t* target_shape_ptr = args[target_shape_port]->get_data_ptr<int64_t>();
|
||||
target_shape.insert(target_shape.end(), target_shape_ptr, target_shape_ptr + num_of_axes);
|
||||
auto target_shape_arg = args[target_shape_port];
|
||||
if (target_shape_arg->get_element_type() == ngraph::element::i64) {
|
||||
int64_t* target_shape_ptr = target_shape_arg->get_data_ptr<int64_t>();
|
||||
target_shape.insert(target_shape.end(), target_shape_ptr, target_shape_ptr + num_of_axes);
|
||||
} else if (target_shape_arg->get_element_type() == ngraph::element::i32) {
|
||||
int32_t* target_shape_ptr = target_shape_arg->get_data_ptr<int32_t>();
|
||||
for (size_t i = 0; i < num_of_axes; ++i)
|
||||
target_shape.push_back(target_shape_ptr[i]);
|
||||
} else {
|
||||
OPENVINO_ASSERT(false, "Failed to process ", target_shape_arg->get_element_type());
|
||||
}
|
||||
|
||||
return target_shape;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user