Review interpolate for shape inference aspects (#17667)
* Review interpolate shapes and label propagation * Review shape_infer template implementation * Update shape infer of interpolate in GPU plugin - Add new tensor accessor for ov::Tensor map * Correct casting in dim::scale function * Remove validation of size of input 1 in v0 * Relax inputs check for interpolate v4 * Correct GPU shape inference * Use ov::Tensors in interpolate's evaluate - Remove some duplicated code - Apply comments from review * Set shape in interpolate's eval for output tensor
This commit is contained in:
@@ -34,7 +34,7 @@ struct MemoryAccessor : public ov::ITensorAccessor {
|
||||
* @param stream CLDNN stream used for memory locks.
|
||||
* @param clbk Function object for custom callback when accessing data and not found in CLDNN memories.
|
||||
*/
|
||||
MemoryAccessor(const container_type* ptrs, const stream& stream, std::function<const ov::Tensor(size_t)> clbk)
|
||||
MemoryAccessor(const container_type* ptrs, const stream& stream, std::function<ov::Tensor(size_t)> clbk)
|
||||
: m_ptrs{ptrs},
|
||||
m_stream{stream},
|
||||
m_clbk{std::move(clbk)},
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
#include "resample_inst.h"
|
||||
#include "primitive_type_base.h"
|
||||
#include <string>
|
||||
#include "json_object.h"
|
||||
|
||||
#include "interpolate_shape_inference.hpp"
|
||||
#include "json_object.h"
|
||||
#include "memory_accessor.hpp"
|
||||
#include "primitive_type_base.h"
|
||||
#include "resample_inst.h"
|
||||
|
||||
namespace cldnn {
|
||||
GPU_DEFINE_PRIMITIVE_TYPE_ID(resample)
|
||||
@@ -41,67 +42,34 @@ std::vector<layout> resample_inst::calc_output_layouts(resample_node const& /*no
|
||||
|
||||
ShapeType sizes_shape = desc->sizes.empty() ? ov::Shape{ input_rank }
|
||||
: ov::Shape{ desc->sizes.size() };
|
||||
ShapeType scales_shape = desc->scales.empty() ? ov::Shape{ input_rank }
|
||||
: ov::Shape{ desc->scales.size() };
|
||||
std::vector<ShapeType> output_shapes = {ShapeType()};
|
||||
std::vector<ShapeType> input_shapes = {
|
||||
input_shape,
|
||||
sizes_shape,
|
||||
scales_shape,
|
||||
ov::Shape{ desc->axes.size() }
|
||||
};
|
||||
ShapeType scales_shape = desc->scales.empty() ? ov::Shape{input_rank} : ov::Shape{desc->scales.size()};
|
||||
std::vector<ShapeType> input_shapes = {input_shape, sizes_shape, scales_shape};
|
||||
|
||||
std::unordered_map<size_t, ov::Tensor> tensors;
|
||||
|
||||
auto sizes = desc->sizes;
|
||||
if (!sizes.empty()) {
|
||||
tensors.emplace(1, ov::Tensor(ov::element::i64, ov::Shape{sizes.size()}, sizes.data()));
|
||||
}
|
||||
|
||||
auto scales = desc->scales;
|
||||
if (!scales.empty()) {
|
||||
tensors.emplace(2, ov::Tensor(ov::element::f32, ov::Shape{scales.size()}, scales.data()));
|
||||
}
|
||||
|
||||
auto axes = desc->axes;
|
||||
if (!axes.empty()) {
|
||||
auto axes_shape = ov::Shape{axes.size()};
|
||||
input_shapes.push_back(axes_shape);
|
||||
tensors.emplace(3, ov::Tensor(ov::element::i64, axes_shape, axes.data()));
|
||||
}
|
||||
|
||||
auto& memory_deps = impl_param.memory_deps;
|
||||
std::map<size_t, ngraph::HostTensorPtr> const_data;
|
||||
|
||||
auto sizes_data = desc->sizes;
|
||||
auto scales_data = desc->scales;
|
||||
|
||||
bool sizes_calc_mod = desc->get_attrs().shape_calculation_mode == ov::op::v4::Interpolate::ShapeCalcMode::SIZES;
|
||||
|
||||
if (((sizes_data.empty() && !memory_deps.count(1)) || !sizes_calc_mod) &&
|
||||
((scales_data.empty() && !memory_deps.count(2)) || sizes_calc_mod)) {
|
||||
return { layout{ShapeType::dynamic(input_rank), input_layout.data_type, input_layout.format} };
|
||||
}
|
||||
|
||||
auto axes_data = desc->axes;
|
||||
if (axes_data.empty()) {
|
||||
axes_data.resize(input_layout.get_rank());
|
||||
std::iota(axes_data.begin(), axes_data.end(), 0);
|
||||
}
|
||||
auto axes_tensor = make_host_tensor({ ov::PartialShape{ ov::Shape{axes_data.size()} }, data_types::i64, format::bfyx },
|
||||
static_cast<void*>(axes_data.data()));
|
||||
const_data.emplace(3, axes_tensor);
|
||||
const auto ta = MemoryAccessor(&memory_deps, impl_param.get_stream(), ov::make_tensor_accessor(tensors));
|
||||
|
||||
auto pads_begin = desc->pads_begin;
|
||||
auto pads_end = desc->pads_end;
|
||||
ov::op::util::correct_pads_attr(&op, pads_begin, pads_end, input_shapes);
|
||||
|
||||
if (sizes_calc_mod) {
|
||||
if (!sizes_data.empty()) {
|
||||
auto sizes_tensor = make_host_tensor({ sizes_shape, data_types::i64, format::bfyx }, static_cast<void*>(sizes_data.data()));
|
||||
const_data.emplace(1, sizes_tensor);
|
||||
ov::op::v4::shape_infer(&op, pads_begin, pads_end, input_shapes, output_shapes, {const_data});
|
||||
} else {
|
||||
auto sizes_mem = memory_deps.at(1);
|
||||
cldnn::mem_lock<uint8_t, mem_lock_type::read> lock(sizes_mem, impl_param.get_stream());
|
||||
auto sizes_tensor = make_host_tensor(sizes_mem->get_layout(), lock.data());
|
||||
const_data.emplace(1, sizes_tensor);
|
||||
ov::op::v4::shape_infer(&op, pads_begin, pads_end, input_shapes, output_shapes, {const_data});
|
||||
}
|
||||
} else {
|
||||
if (!scales_data.empty()) {
|
||||
auto scales_tensor = make_host_tensor({ scales_shape, data_types::f32, format::bfyx }, static_cast<void*>(scales_data.data()));
|
||||
const_data.emplace(2, scales_tensor);
|
||||
ov::op::v4::shape_infer(&op, pads_begin, pads_end, input_shapes, output_shapes, {const_data});
|
||||
} else {
|
||||
auto scales_mem = memory_deps.at(2);
|
||||
cldnn::mem_lock<uint8_t, mem_lock_type::read> lock(scales_mem, impl_param.get_stream());
|
||||
auto scales_tensor = make_host_tensor(scales_mem->get_layout(), lock.data());
|
||||
const_data.emplace(2, scales_tensor);
|
||||
ov::op::v4::shape_infer(&op, pads_begin, pads_end, input_shapes, output_shapes, {const_data});
|
||||
}
|
||||
}
|
||||
const auto output_shapes = ov::op::v4::shape_infer(&op, input_shapes, pads_begin, pads_end, ta);
|
||||
|
||||
return { layout{output_shapes[0], input_layout.data_type, format::adjust_to_rank(input_layout.format, output_shapes[0].size())} };
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user