Crop(Split and VariadicSplit) new shape infer (#13216)
* [GPU] Split and VariadicSplit new shape infer (#13216); update GPU test script
* [GPU] Added W/A for crop offset (#13216) Co-authored-by: Ahn, Paul Y <paul.y.ahn@intel.com>
* [GPU] Fix crop GPU test failures (#13216) - addressed review comments
* [GPU] Move input offsets calculation to crop_inst::calc_output_layouts (#13216) Co-authored-by: Taylor Yeonbok Lee <taylor.lee@intel.com>
This commit is contained in: parent 35b943d21a, commit cd5928da53
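For context, here is a minimal sketch of how a VariadicSplit is now expressed as crop primitives carrying the new op_mode/output_idx arguments. It is condensed from the dynamic-shape unit test added in this commit; the primitive ids, input shape, and split lengths are illustrative, not part of the commit itself.

// Illustrative only: mirrors the dynamic VariadicSplit test added below, with hypothetical ids/shapes.
auto axis_mem          = engine.allocate_memory({ {},  data_types::i64, format::bfyx });  // split axis (scalar)
auto splits_length_mem = engine.allocate_memory({ {2}, data_types::i64, format::bfyx });  // per-output split lengths
set_values(axis_mem, {1});
set_values(splits_length_mem, std::vector<int64_t>{3, 1});

topology topology;
topology.add(input_layout("input", layout{ ov::PartialShape::dynamic(4), data_types::f32, format::bfyx }));
topology.add(data("axis", axis_mem));
topology.add(data("splits_length", splits_length_mem));
// One crop primitive per split output; output_idx selects which piece each crop produces,
// and crop_inst::calc_output_layouts runs the ngraph VariadicSplit shape inference to resolve it.
topology.add(crop("crop.out0", {"input", "axis", "splits_length"},
                  cldnn::tensor(1), cldnn::tensor(0), cldnn::crop_ngraph_op_mode::variadic_split, 0));
topology.add(crop("crop.out1", {"input", "axis", "splits_length"},
                  cldnn::tensor(1), cldnn::tensor(0), cldnn::crop_ngraph_op_mode::variadic_split, 1));

At shape-inference time the axis and split-length constants are read through memory dependencies (inputs 1 and 2), which is why they appear in the crop's input list above.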
@ -18,7 +18,6 @@ void shape_infer(const Split* op,
                 const std::vector<T>& input_shapes,
                 std::vector<T>& output_shapes,
                 const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {}) {
    using DimType = typename std::iterator_traits<typename T::iterator>::value_type;
    NODE_VALIDATION_CHECK(op, (input_shapes.size() == 2));

    output_shapes.clear();
@ -18,7 +18,6 @@ void shape_infer(const VariadicSplit* op,
                 const std::vector<T>& input_shapes,
                 std::vector<T>& output_shapes,
                 const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {}) {
    using DimType = typename std::iterator_traits<typename T::iterator>::value_type;
    constexpr bool is_dynamic_shape = std::is_base_of<ov::PartialShape, T>::value;

    NODE_VALIDATION_CHECK(op, (input_shapes.size() == 3));
@ -14,6 +14,15 @@ namespace cldnn {
/// @addtogroup cpp_primitives Primitives
/// @{

/// @brief Select original ngraph op mode for the @ref crop layer.
enum class crop_ngraph_op_mode : int32_t {
    none,
    /// @brief ngraph split op.
    split,
    /// @brief ngraph variadic split op.
    variadic_split
};

/// @brief Marker type indicating that instead of reference input size left, top,
/// right and bottom borders (to cut out) should be specified.
///
@ -52,7 +61,8 @@ struct crop : public primitive_base<crop> {
         const tensor& reference_input,
         const tensor& offsets,
         const padding& output_padding = padding())
        : primitive_base(id, {input}, output_padding), reference_input(reference_input), offsets(offsets) {}
        : primitive_base(id, {input}, output_padding), reference_input(reference_input),
          offsets(offsets), op_mode(crop_ngraph_op_mode::none) {}

    /// @brief Constructs crop primitive (borders variant).
    ///
@ -72,7 +82,8 @@ struct crop : public primitive_base<crop> {
         const tensor& rb_borders,
         const crop_borders_t,
         const padding& output_padding = padding())
        : primitive_base(id, {input}, output_padding), reference_input(rb_borders.negate()), offsets(lt_borders) {}
        : primitive_base(id, {input}, output_padding), reference_input(rb_borders.negate()),
          offsets(lt_borders), op_mode(crop_ngraph_op_mode::none) {}

    /// @brief Constructs crop primitive (symmetric borders variant).
    ///
@ -89,12 +100,37 @@ struct crop : public primitive_base<crop> {
         const tensor& xy_borders,
         const crop_borders_t,
         const padding& output_padding = padding())
        : primitive_base(id, {input}, output_padding), reference_input(xy_borders.negate()), offsets(xy_borders) {}
        : primitive_base(id, {input}, output_padding), reference_input(xy_borders.negate()),
          offsets(xy_borders), op_mode(crop_ngraph_op_mode::none) {}

    /// @brief Constructs crop primitive.
    /// @param id This primitive id.
    /// @param inputs Input primitive id vector.
    /// @param reference_input Reference input tensor with the required dimensions.
    /// @param offsets Input offsets.
    /// @param output_idx Output data index of the split output.
    /// @param num_splits The number of pieces that the data tensor should be split into.
    crop(const primitive_id& id,
         const std::vector<primitive_id>& inputs,
         const tensor& reference_input,
         const tensor& offsets,
         const crop_ngraph_op_mode op_mode,
         const int output_idx,
         const size_t num_splits = 1,
         const padding& output_padding = padding())
        : primitive_base(id, inputs, output_padding), reference_input(reference_input),
          offsets(offsets), output_idx(output_idx), num_splits(num_splits), op_mode(op_mode) {}

    /// @brief Reference input tensor with the required dimensions.
    tensor reference_input;
    /// @brief Input offsets.
    tensor offsets;
    /// @brief Data index of the split output.
    int output_idx = 0;
    /// @brief Number of splits; for the Split op this comes from its num_splits property.
    size_t num_splits = 1;
    /// @brief Original ngraph operation type.
    crop_ngraph_op_mode op_mode;
};
/// @}
/// @}
@ -6,9 +6,13 @@
#include "primitive_type_base.h"
#include "intel_gpu/runtime/memory.hpp"
#include "intel_gpu/runtime/error_handler.hpp"
#include "intel_gpu/plugin/common_utils.hpp"
#include "json_object.h"
#include <string>

#include "variadic_split_shape_inference.hpp"
#include "split_shape_inference.hpp"

namespace cldnn {
primitive_type_id crop::type_id() {
    static primitive_type_base<crop> instance;
@ -38,6 +42,105 @@ layout crop_inst::calc_output_layout(crop_node const& node, kernel_impl_params c
    return layout({in_layout.data_type, in_layout.format, ref_in_sizes});
}

template<typename ShapeType>
std::vector<layout> crop_inst::calc_output_layouts(const crop_node& /*node*/, const kernel_impl_params& impl_param) {
    OPENVINO_ASSERT(static_cast<bool>(impl_param.desc->output_data_type) == false,
                    "Output data type forcing is not supported for crop_node!");

    auto desc = impl_param.typed_desc<crop>();
    const auto in_layout = impl_param.get_input_layout();
    std::vector<ShapeType> output_shapes = {ShapeType()};
    std::vector<ShapeType> input_shapes = {
        impl_param.input_layouts[0].get<ShapeType>(),
    };
    for (size_t i = 1; i < impl_param.input_layouts.size(); ++i) {
        input_shapes.push_back(impl_param.input_layouts[i].get<ShapeType>());
    }

    // TODO: calling shape_infer for all cropped outputs is redundant... Need to optimize.
    if (desc->op_mode == cldnn::crop_ngraph_op_mode::variadic_split) {
        std::map<size_t, ngraph::HostTensorPtr> const_data;

        OPENVINO_ASSERT(impl_param.memory_deps.count(1) > 0, "[GPU] Can't find Crop(ngraph VariadicSplit op mode) axis values memory dependency");
        auto axis_values_mem = impl_param.memory_deps.at(1);
        cldnn::mem_lock<uint8_t, mem_lock_type::read> axis_values_mem_lock(axis_values_mem, impl_param.prog.get_stream());
        const_data.emplace(1, make_host_tensor(axis_values_mem->get_layout(), axis_values_mem_lock.data()));

        OPENVINO_ASSERT(impl_param.memory_deps.count(2) > 0, "[GPU] Can't find Crop(ngraph VariadicSplit op mode) split length values memory dependency");
        auto split_length_mem = impl_param.memory_deps.at(2);
        cldnn::mem_lock<uint8_t, mem_lock_type::read> split_length_mem_lock(split_length_mem, impl_param.prog.get_stream());
        const_data.emplace(2, make_host_tensor(split_length_mem->get_layout(), split_length_mem_lock.data()));

        ov::op::v1::VariadicSplit op;
        shape_infer(&op, input_shapes, output_shapes, const_data);
    } else if (desc->op_mode == cldnn::crop_ngraph_op_mode::split) {
        std::map<size_t, ngraph::HostTensorPtr> const_data;

        OPENVINO_ASSERT(impl_param.memory_deps.count(1) > 0, "[GPU] Can't find Crop(ngraph Split op mode) axis values memory dependency");
        auto axis_values_mem = impl_param.memory_deps.at(1);
        cldnn::mem_lock<uint8_t, mem_lock_type::read> axis_values_mem_lock(axis_values_mem, impl_param.prog.get_stream());
        const_data.emplace(1, make_host_tensor(axis_values_mem->get_layout(), axis_values_mem_lock.data()));

        ov::op::v1::Split op;
        op.set_num_splits(desc->num_splits);
        shape_infer(&op, input_shapes, output_shapes, const_data);
    } else if (desc->op_mode == cldnn::crop_ngraph_op_mode::none) {
        // Legacy usage
        if (in_layout.is_dynamic()) {
            auto in_shape = in_layout.get<ShapeType>();
            auto r = (in_shape.rank().is_static()) ? in_shape.size() : 1;
            return { layout{ShapeType::dynamic(r),
                            in_layout.data_type, in_layout.format.adjust_to_rank(in_layout.format, r)} };
        }

        const auto& ref_in_sizes = desc->reference_input;
        const auto& in_sizes = in_layout.get_tensor();
        const auto& offsets = desc->offsets;

        // Check for borders variant of crop.
        if (ref_in_sizes.batch[0] < 0 || ref_in_sizes.feature[0] < 0 || ref_in_sizes.spatial[0] < 0 ||
            ref_in_sizes.spatial[1] < 0 || ref_in_sizes.spatial[2] < 0) {
            // Ignore not supported dimensions.
            const auto rb_sizes = ref_in_sizes.negate().sub({0, 0, 0, 0, 0});
            const auto lt_sizes = offsets.sub({0, 0, 0, 0, 0});
            const auto out_sizes = in_sizes - (rb_sizes + lt_sizes);

            return {layout({in_layout.data_type, in_layout.format, out_sizes})};
        }
        return {layout({in_layout.data_type, in_layout.format, ref_in_sizes})};
    }

    bool is_output_static = false;
    std::vector<layout> output_layouts;
    for (size_t i = 0; i < output_shapes.size(); ++i) {
        output_layouts.push_back(layout({output_shapes[i], in_layout.data_type, in_layout.format}));
        is_output_static = (output_shapes[i].is_static()) ? true : is_output_static;
    }

    // update split offsets
    if (is_output_static) {
        auto p_param = const_cast<kernel_impl_params*>(&impl_param);
        InferenceEngine::SizeVector startOffset(p_param->input_layouts[0].get_partial_shape().size());
        auto input_shape = p_param->input_layouts[0].get_partial_shape();
        auto dims = p_param->input_layouts[0].get_partial_shape().size();
        for (int32_t prev = 0; prev < desc->output_idx; prev++) {
            auto prev_crop_shape = output_layouts[prev].get_partial_shape().to_shape();
            for (size_t i = 0; i < dims; ++i) {
                if (prev_crop_shape[i] != input_shape.to_shape()[i])
                    startOffset[i] += prev_crop_shape[i];
            }
        }

        if (p_param->input_offsets.empty()) {
            p_param->input_offsets.resize(1);
            p_param->input_offsets[0] = desc->offsets;
        }

        p_param->input_offsets[0] = ov::intel_gpu::tensor_from_dims(startOffset, 0);
    }
    return {output_layouts[desc->output_idx]};
}

std::string crop_inst::to_string(crop_node const& node) {
    const auto& desc = node.get_primitive();
    auto ref_in_sizes = desc->reference_input;
@ -72,64 +175,66 @@ std::string crop_inst::to_string(crop_node const& node) {
crop_inst::typed_primitive_inst(network& network, crop_node const& node) : parent(network, node) {
const auto& ref_in_sizes = argument.reference_input;
const auto in_layout = node.input().get_output_layout();
const auto& in_sizes = in_layout.get_tensor();
const auto& offsets = argument.offsets;
tensor null_tensor {};
tensor value_tensor { 1, 1, 1, 1, 1 };

// Check for borders variant of crop.
if (ref_in_sizes.batch[0] < 0 || ref_in_sizes.feature[0] < 0 || ref_in_sizes.spatial[0] < 0 ||
ref_in_sizes.spatial[1] < 0 || ref_in_sizes.spatial[2] < 0) {
// Ignore not supported dimensions.
const auto rb_sizes = ref_in_sizes.negate().sub({0, 0, 0, 0, 0});
const auto lt_sizes = offsets.sub({0, 0, 0, 0, 0});
if (in_layout.is_static()) {
const auto& in_sizes = in_layout.get_tensor();
// Check for borders variant of crop.
if (ref_in_sizes.batch[0] < 0 || ref_in_sizes.feature[0] < 0 || ref_in_sizes.spatial[0] < 0 ||
ref_in_sizes.spatial[1] < 0 || ref_in_sizes.spatial[2] < 0) {
// Ignore not supported dimensions.
const auto rb_sizes = ref_in_sizes.negate().sub({0, 0, 0, 0, 0});
const auto lt_sizes = offsets.sub({0, 0, 0, 0, 0});

const auto out_sizes = in_sizes - (rb_sizes + lt_sizes);
const auto out_sizes = in_sizes - (rb_sizes + lt_sizes);

CLDNN_ERROR_TENSOR_SIZES_LESS_THAN(node.id(),
"Left/top/lower borders",
lt_sizes,
"0 value",
null_tensor,
"Invalid border size: negative");
CLDNN_ERROR_TENSOR_SIZES_LESS_THAN(node.id(),
"Right/bottom/upper borders",
rb_sizes,
"0 value",
null_tensor,
"Invalid border size: negative");
CLDNN_ERROR_TENSOR_SIZES_LESS_THAN(node.id(),
"Left/top/lower borders",
lt_sizes,
"0 value",
null_tensor,
"Invalid border size: negative");
CLDNN_ERROR_TENSOR_SIZES_LESS_THAN(node.id(),
"Right/bottom/upper borders",
rb_sizes,
"0 value",
null_tensor,
"Invalid border size: negative");

CLDNN_ERROR_TENSOR_SIZES_LESS_THAN(node.id(),
"Input sizes - border sizes",
out_sizes,
"1 value",
value_tensor,
"Invalid border sizes: greater-equal input sizes");
}

// check if output sizes matches reference input sizes
CLDNN_ERROR_TENSOR_SIZES_GREATER_THAN(node.id(),
"Reference input",
ref_in_sizes,
"input sizes",
in_sizes,
"Reference input tensor / input tensor mismatch");

// check if offsets do not extend input sizes and if match the output sizes
CLDNN_ERROR_TENSOR_SIZES_LESS_THAN(node.id(),
"Input sizes - border sizes",
out_sizes,
"1 value",
value_tensor,
"Invalid border sizes: greater-equal input sizes");
"Batch offsets",
offsets,
"0 value",
null_tensor,
"Invalid Batch offset: negative value");
auto input_size_sub_offsets = in_sizes - offsets;
CLDNN_ERROR_TENSOR_SIZES_LESS_THAN(node.id(),
"input sizes - offsets",
input_size_sub_offsets,
"reference input sizes",
ref_in_sizes,
"Invalid Batch offset: exceeds data for output!");
}

// check if output sizes matches reference input sizes
CLDNN_ERROR_TENSOR_SIZES_GREATER_THAN(node.id(),
"Reference input",
ref_in_sizes,
"input sizes",
in_sizes,
"Reference input tensor / input tensor mismatch");

// check if offsets do not extend input sizes and if match the output sizes
CLDNN_ERROR_TENSOR_SIZES_LESS_THAN(node.id(),
"Batch offsets",
offsets,
"0 value",
null_tensor,
"Invalid Batch offset: negative value");
auto input_size_sub_offsets = in_sizes - offsets;
CLDNN_ERROR_TENSOR_SIZES_LESS_THAN(node.id(),
"input sizes - offsets",
input_size_sub_offsets,
"reference input sizes",
ref_in_sizes,
"Invalid Batch offset: exceeds data for output!");

if (node.can_be_optimized()) {
build_deps();
reuse_input();
@ -41,16 +41,12 @@ protected:

public:
    static primitive_impl* create(const crop_node& arg, const kernel_impl_params& impl_param) {
        const auto& primitive = arg.get_primitive();

        auto ew_params = get_default_params<kernel_selector::eltwise_params>(impl_param, 1);
        auto ew_optional_params = get_default_optional_params<kernel_selector::eltwise_optional_params>(arg.get_program());

        ew_params.operations.push_back(
            {{kernel_selector::eltwise_params::InputType::Buffer(0)}, kernel_selector::eltwise_mode::ASSIGN});

        const auto& input_layout = impl_param.input_layouts[0];
        ew_params.inputs[0] = convert_data_tensor(input_layout, 1, primitive->offsets);
        ew_params.inputs[0] = convert_data_tensor(impl_param.get_input_layout(), 1, impl_param.input_offsets[0]);

        auto& kernel_selector = kernel_selector::eltwise_kernel_selector::Instance();
        auto best_kernels = kernel_selector.GetBestKernels(ew_params, ew_optional_params);
@ -24,6 +24,22 @@ public:
        support_padding_all(true);
    }
    program_node& input() const { return get_dependency(0); }

    std::vector<size_t> get_shape_infer_dependencies() const override {
        std::vector<size_t> vec;
        for (size_t i = 1; i < get_dependencies().size(); i++) {
            vec.push_back(i);
        }
        return vec;
    }

    using parent::get_kernel_impl_params;
    std::unique_ptr<kernel_impl_params> get_kernel_impl_params(const std::vector<layout>& in_layouts, const layout& out_layout) const override {
        auto params = parent::get_kernel_impl_params(in_layouts, out_layout);
        params->input_offsets.reserve(1);
        params->input_offsets[0] = get_primitive()->offsets;
        return params;
    }
};

using crop_node = typed_program_node<crop>;
@ -33,6 +49,8 @@ class typed_primitive_inst<crop> : public typed_primitive_inst_base<crop> {
    using parent = typed_primitive_inst_base<crop>;

public:
    template<typename ShapeType>
    static std::vector<layout> calc_output_layouts(const crop_node& /*node*/, const kernel_impl_params& impl_param);
    static layout calc_output_layout(crop_node const& node, kernel_impl_params const& impl_param);
    static std::string to_string(crop_node const& node);
    typed_primitive_inst(network& network, crop_node const& node);
@ -113,6 +113,7 @@ struct kernel_impl_params {
    size_t unique_id;
    std::vector<layout> input_layouts;
    layout output_layout;
    std::vector<tensor> input_offsets;
    std::vector<cldnn::fused_primitive_desc> fused_desc;
    std::vector<activation_func> fused_act_funcs;
    std::vector<activation_additional_params> activation_params;
@ -10,10 +10,13 @@
#include "arg_max_min_inst.h"
#include "fully_connected_inst.h"
#include "convolution_inst.h"
#include "strided_slice_inst.h"
#include "crop_inst.h"
#include "deconvolution_inst.h"
#include "shape_of_inst.h"
#include "strided_slice_inst.h"
#include "experimental_detectron_roi_feature_extractor_inst.hpp"
#include "intel_gpu/plugin/common_utils.hpp"

#include "intel_gpu/graph/network.hpp"
#include "intel_gpu/runtime/engine.hpp"
@ -151,13 +154,19 @@ void primitive_inst::update_shape() {
    }
}

if (input_shape_changed)
    set_shape_change();

// We assume that tensor ranks are static, thus shape_of doesn't need to update anything even if input shape is dynamic
if (_node.is_type<shape_of>())
if (_node.is_type<shape_of>() && !input_shape_changed)
    return;

// Even though the predecessors' shapes are not changed, the output shape might be updated by the mem_dep
auto memory_deps = _node.get_const_memory_deps();
for (auto& i : _node.get_shape_infer_dependencies()) {
    if (memory_deps.count(i) > 0) {
        continue;
    }
    input_shape_changed = true;
}

// Strided slice loads data from {1,2,3} dependencies in impl::create method.
// It means that this data must be put into impl_params map
// Thus we treat it as "dynamic" case
@ -173,7 +182,9 @@ void primitive_inst::update_shape() {
if (!strided_slice_wa && !input_shape_changed && !_node.generates_dynamic_output() && _impl_params->output_layout.is_static())
    return;

auto memory_deps = _node.get_const_memory_deps();
if (input_shape_changed)
    set_shape_change();

std::vector<event::ptr> dependencies_events;
auto queue_type = get_network().get_stream().get_queue_type();
bool has_runtime_deps = false;
@ -205,6 +216,7 @@ void primitive_inst::update_shape() {
}

_impl_params->memory_deps = memory_deps;

auto new_layouts = _node.type()->calc_output_layouts(_node, *_impl_params);
auto new_layout = new_layouts.empty() ? _node.type()->calc_output_layout(_node, *_impl_params) : new_layouts[0];
new_layout.data_padding = padding::max(_node.get_primitive()->output_padding, new_layout.data_padding);
@ -17,41 +17,51 @@ static void CreateCommonSplitOp(Program& p, const std::shared_ptr<ngraph::Node>&
auto inputPrimitives = p.GetInputPrimitiveIDs(op);
std::string layerName = layer_type_name_ID(op);

auto input_pshape = op->get_input_partial_shape(0);
OPENVINO_ASSERT(input_pshape.is_static(),
"Dynamic shapes are not supported yet for v1::Split and v1::VariadicSplit operations");
auto inPartialShape = op->get_input_partial_shape(0);
InferenceEngine::SizeVector startOffset(inPartialShape.size());

auto input_shape = input_pshape.to_shape();
InferenceEngine::SizeVector start_offset(input_shape.size());
cldnn::crop_ngraph_op_mode op_mode = cldnn::crop_ngraph_op_mode::variadic_split;
size_t num_splits = 1;
if (ngraph::is_type<ngraph::op::v1::Split>(op)) {
auto split = ngraph::as_type_ptr<ngraph::op::v1::Split>(op);
num_splits = split->get_num_splits();
op_mode = cldnn::crop_ngraph_op_mode::split;
}

bool is_single_out_split = op->get_output_size() == 1;

for (size_t i = 0; i < op->get_output_size(); i++) {
std::string outLayerName = layerName + (is_single_out_split ? "" : ".out" + std::to_string(i));
const auto outLayerDims = op->get_output_shape(i);
NGRAPH_SUPPRESS_DEPRECATED_START
if (outLayerDims.size() != start_offset.size()) {
IE_THROW() << "Invalid dimensions in split layer: " << op->get_friendly_name()
<< " output: " << op->get_output_tensor_name(i);
}
for (size_t i = 0; i < input_shape.size(); i++) {
if ((outLayerDims[i] + start_offset[i]) > input_shape[i]) {
const auto outPatialShape = op->get_output_partial_shape(i);
if (outPatialShape.is_static()) {
NGRAPH_SUPPRESS_DEPRECATED_START
if (outPatialShape.size() != startOffset.size()) {
IE_THROW() << "Invalid dimensions in split layer: " << op->get_friendly_name()
<< " output: " << op->get_output_tensor_name(i);
<< " output: " << op->get_output_tensor_name(i);
}
}
NGRAPH_SUPPRESS_DEPRECATED_END

auto outTensor = tensor_from_dims(outLayerDims, 1);
auto offsetTensor = tensor_from_dims(start_offset, 0);

auto cropPrim = cldnn::crop(outLayerName, inputPrimitives[0], outTensor, offsetTensor);
p.add_primitive(*op, cropPrim);

for (size_t i = 0; i < input_shape.size(); i++) {
if (outLayerDims[i] != input_shape[i]) {
start_offset[i] += outLayerDims[i];
for (size_t i = 0; i < inPartialShape.size(); i++) {
if ((outPatialShape[i].get_length() + static_cast<ov::Dimension::value_type>(startOffset[i])) > inPartialShape[i].get_length()) {
IE_THROW() << "Invalid dimensions in split layer: " << op->get_friendly_name()
<< " output: " << op->get_output_tensor_name(i);
}
}
NGRAPH_SUPPRESS_DEPRECATED_END

auto outTensor = tensor_from_dims(outPatialShape.to_shape(), 1);
auto offsetTensor = tensor_from_dims(startOffset, 0);
auto cropPrim = cldnn::crop(outLayerName, inputPrimitives, outTensor, offsetTensor, op_mode, i, num_splits);

p.add_primitive(*op, cropPrim);

for (size_t i = 0; i < inPartialShape.size(); i++) {
if (outPatialShape[i] != inPartialShape[i]) {
startOffset[i] += outPatialShape.to_shape()[i];
}
}
} else {
auto cropPrim = cldnn::crop(outLayerName, inputPrimitives, cldnn::tensor(1), cldnn::tensor(0), op_mode, i, num_splits);

p.add_primitive(*op, cropPrim);
}
}
}
src/plugins/intel_gpu/tests/shape_infer/crop_si_test.cpp (new file, 222 lines)
@ -0,0 +1,222 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "test_utils.h"

#include <intel_gpu/primitives/input_layout.hpp>
#include <intel_gpu/primitives/crop.hpp>
#include <intel_gpu/primitives/data.hpp>

#include "crop_inst.h"
#include "concatenation_inst.h"

#include "program_wrapper.h"

using namespace cldnn;
using namespace ::tests;

namespace shape_infer_tests {

struct crop_si_test_params {
    tensor reference_input_size;
    std::vector<tensor> offsets;
    std::vector<std::vector<int64_t>> const_values;
    std::vector<layout> input_layouts;
    std::vector<layout> expected_layouts;
    size_t param_num_splits;
};

std::string to_string(const cldnn::layout& l);

std::ostream& operator<<(std::ostream& ost, std::vector<std::vector<int64_t>> vec) {
    ost << "{";
    for (auto inner_vec : vec) {
        ost << "{";
        for (auto v : inner_vec) {
            std::cout << v << ",";
        }
        ost << "},";
    }
    ost << "}";
    return ost;
}

std::ostream& operator<<(std::ostream& ost, const crop_si_test_params& params) {
    ost << params.reference_input_size.to_string() << ",";
    ost << "{";
    for (auto& offset_tensor : params.offsets) {
        ost << offset_tensor.to_string() << ",";
    }
    ost << "}," << params.const_values << ",{";
    for (auto& t : params.input_layouts) {
        ost << to_string(t) << ",";
    }
    ost << "},{";
    for (auto& t : params.expected_layouts) {
        ost << to_string(t) << ",";
    }
    ost << "}";
    return ost;
}

class crop_si_test : public testing::TestWithParam<crop_si_test_params> { };

TEST_P(crop_si_test, shape_infer) {
    auto p = GetParam();
    auto& engine = get_test_engine();

    cldnn::program prog(engine);
    std::vector<std::shared_ptr<primitive>> input_prims;
    std::vector<std::string> input_prim_ids;
    {
        auto prim_id = "data0";
        auto data_layout_prim = std::make_shared<input_layout>(prim_id, p.input_layouts[0]);
        input_prims.push_back(data_layout_prim);
        input_prim_ids.push_back(prim_id);
    }

    for (size_t i = 1; i < p.input_layouts.size(); i++) {
        auto prim_id = "const::data"+std::to_string(i);
        auto prim_mem = engine.allocate_memory(p.input_layouts[i]);
        set_values(prim_mem, p.const_values[i-1]);
        auto const_data_prim = std::make_shared<data>(prim_id, prim_mem);
        input_prims.push_back(const_data_prim);
        input_prim_ids.push_back(prim_id);
    }

    crop_ngraph_op_mode op_mode = crop_ngraph_op_mode::none;
    if (p.const_values.size() == 2) {
        op_mode = crop_ngraph_op_mode::variadic_split;
    } else if (p.const_values.size() == 1) {
        op_mode = crop_ngraph_op_mode::split;
    }

    for (size_t output_idx = 0; output_idx < p.expected_layouts.size(); output_idx++) {
        auto prim_id = "crop.out" + std::to_string(output_idx);
        auto crop_prim = std::make_shared<crop>(prim_id, input_prim_ids, p.reference_input_size, p.offsets[output_idx], op_mode, output_idx, p.param_num_splits);
        auto& crop_node = prog.get_or_create(crop_prim);

        for (auto& prim : input_prims) {
            auto& input_node = prog.get_or_create(prim);
            program_wrapper::add_connection(prog, input_node, crop_node);
        }

        auto params = crop_node.get_kernel_impl_params();
        auto res = crop_inst::calc_output_layouts<ov::PartialShape>(crop_node, *params);

        ASSERT_EQ(res.size(), 1);
        ASSERT_EQ(res[0], p.expected_layouts[output_idx]);
    }
}

INSTANTIATE_TEST_SUITE_P(smoke, crop_si_test,
    testing::ValuesIn(std::vector<crop_si_test_params>{
        {
            tensor({1,1,1,1,1,1,1}),
            {tensor({0,0,0,0,1,1,1}),tensor({0,0,0,0,1,1,1})},
            {{-1}, {1,1}},
            {{{1,32,2},data_types::f32,format::bfyx}, {{},data_types::i64,format::bfyx}, {{2},data_types::i64,format::bfyx}},
            {{{1,32,1},data_types::f32,format::bfyx}, {{1,32,1},data_types::f32,format::bfyx}}, 0
        },
        {
            tensor({1,1,1,1,1,1,1}),
            {tensor({0,0,0,0,1,1,1}),tensor({0,0,0,0,1,1,1})},
            {{-1}, {1,1}},
            {{ov::PartialShape::dynamic(),data_types::f32,format::bfyx}, {{},data_types::i64,format::bfyx}, {{2},data_types::i64,format::bfyx}},
            {{ov::PartialShape::dynamic(),data_types::f32,format::bfyx}, {ov::PartialShape::dynamic(),data_types::f32,format::bfyx}}, 0
        },
        {
            tensor({3,1,1,1,1,1,1}),
            {tensor({0,0,0,0,1,1,1}),tensor({0,0,0,0,1,1,1})},
            {},
            {{ov::PartialShape::dynamic(),data_types::f32,format::bfyx}},
            {{ov::PartialShape::dynamic(1),data_types::f32,format::bfyx}}, 0
        },
        {
            tensor({3,1,1,1,1,1,1}),
            {tensor({0,0,0,0,1,1,1}),tensor({0,0,0,0,1,1,1})},
            {},
            {{{4},data_types::f32,format::bfyx}},
            {{{3},data_types::f32,format::bfyx}}, 0
        },
        {
            tensor({-1,-1,-1,-1,-1,-1,-1}),
            {tensor({0,0,0,0,1,1,1}),tensor({0,0,0,0,1,1,1})},
            {},
            {{{4,3,2,5},data_types::f32,format::bfyx}},
            {{{3,2,1,4},data_types::f32,format::bfyx}}, 0
        },
        {
            tensor({1,1,1,1,1,1,1}),
            {tensor({0,0,0,0,1,1,1}),tensor({0,1,0,0,1,1,1}),tensor({0,2,0,0,1,1,1}),tensor({0,3,0,0,1,1,1})},
            {{1}, {1,1,1,1}},
            {{{4819,4,1,1,4},data_types::f32,format::bfzyx}, {{},data_types::i64,format::bfzyx}, {{4},data_types::i64,format::bfzyx}},
            {{{4819,1,1,1,4},data_types::f32,format::bfzyx}, {{4819,1,1,1,4},data_types::f32,format::bfzyx}, {{4819,1,1,1,4},data_types::f32,format::bfzyx}, {{4819,1,1,1,4},data_types::f32,format::bfzyx}}, 0
        },
        {
            tensor({4507,1,1,1,1,1,1}),
            {tensor({0,2,0,0,1,1,1})},
            {},
            {{{4507,3,1,1},data_types::f32,format::bfyx}},
            {{{4507,1,1,1},data_types::f32,format::bfyx}}, 0
        },
        {
            tensor({1,1,1,1,1,1,1}),
            {tensor({0,0,0,0,1,1,1}),tensor({0,0,0,0,1,1,1})},
            {{2}, {11,3}},
            {{{1,14,14,384},data_types::f32,format::bfyx}, {{},data_types::i64,format::bfyx}, {{2},data_types::i64,format::bfyx}},
            {{{1,14,11,384},data_types::f32,format::bfyx}, {{1,14,3,384},data_types::f32,format::bfyx}}, 0
        },
        {
            tensor({1,1,2048,1,1,1,1}),
            {tensor({0,2,0,0,1,1,1})},
            {},
            {{{1,16,1,2048},data_types::f32,format::bfyx}},
            {{{1,1,1,2048},data_types::f32,format::bfyx}}, 0
        },
        {
            tensor({1,1,1,1,1,1,1}),
            {tensor({0,0,0,0,1,1,1}),tensor({0,1320,0,0,1,1,1})},
            {{1},{1320,99}},
            {{{1,1419},data_types::f32,format::bfyx}, {{},data_types::i64,format::bfyx}, {{2},data_types::i64,format::bfyx}},
            {{{1,1320},data_types::f32,format::bfyx}, {{1,99},data_types::f32,format::bfyx}}, 0
        },
        {
            tensor({1,128,2,64,1,1,1}),
            {tensor({0,0,8,0,1,1,1})},
            {},
            {{{1,128,64,10},data_types::f32,format::bfyx}},
            {{{1,128,64,2},data_types::f32,format::bfyx}}, 0
        },
        {
            tensor({1,1,1,1,1,1,1}),
            {tensor({0,0,0,0,1,1,1}), tensor({0,1,0,0,1,1,1})},
            {{1}},
            {{{4,2},data_types::f32,format::bfyx}, {{},data_types::i64,format::bfyx}},
            {{{4,1},data_types::f32,format::bfyx}, {{4,1},data_types::f32,format::bfyx}}, 2
        },
        {
            tensor({1,1,1,1,1,1,1}),
            {tensor({0,0,0,0,1,1,1}), tensor({0,0,2048,0,1,1,1})},
            {{2}},
            {{{5,1,4096,1},data_types::f32,format::bfyx}, {{},data_types::i64,format::bfyx}},
            {{{5,1,2048,1},data_types::f32,format::bfyx}, {{5,1,2048,1},data_types::f32,format::bfyx}}, 2
        },
        {
            tensor({1,1400,1,1,1,1,1}),
            {tensor({0,100,0,0,1,1,1})},
            {},
            {{{1,1500,1,1},data_types::f32,format::bfyx}},
            {{{1,1400,1,1},data_types::f32,format::bfyx}}, 0
        },
        {
            tensor({1,1,1,1,1,1,1}),
            {tensor({0,0,0,0,1,1,1})},
            {{2},{1}},
            {{{7,1,1},data_types::f32,format::bfyx}, {{},data_types::i64,format::bfyx}, {{1},data_types::i64,format::bfyx}},
            {{{7,1,1},data_types::f32,format::bfyx}}, 0
        },
    }));

};  // shape_infer_tests
@ -1242,3 +1242,181 @@ INSTANTIATE_TEST_SUITE_P(crop_test, crop_gpu,
    ::testing::ValuesIn(crop_features),
    ::testing::ValuesIn(formats)
));


TEST(crop_gpu, dynamic_i32_in2x3x2x2_crop_offsets) {
    auto& engine = get_test_engine();

    auto batch_num = 2;
    auto feature_num = 2;
    auto x_size = 3;
    auto y_size = 2;

    auto crop_batch_num = batch_num - 1;
    auto crop_feature_num = feature_num;
    auto crop_x_size = x_size - 1;
    auto crop_y_size = y_size - 1;

    auto batch_offset = 1;
    auto feature_offset = 0;
    auto x_offset = 1;
    auto y_offset = 1;

    auto input_dyn_layout = layout{ ov::PartialShape{ov::Dimension(1, 10), feature_num, y_size, x_size}, data_types::f32, format::bfyx };
    auto input_actual_layout = layout{ ov::PartialShape{batch_num, feature_num, y_size, x_size}, data_types::f32, format::bfyx };

    auto input = engine.allocate_memory(input_actual_layout);

    topology topology;
    topology.add(input_layout("input", input_dyn_layout));
    topology.add(crop("crop", "input", tensor(batch(crop_batch_num), spatial(crop_x_size, crop_y_size), feature(crop_feature_num)), { tensor(feature(0)) }));

    std::vector<int32_t> input_vec = { 1, 0, 5, 15,
                                       2, 0, 6, 52,
                                       -10, -11, -12, -13,
                                       3, 50, 7, 12,
                                       4, -5, 8, 8,
                                       -14, -15, -16, -17 };
    set_values(input, input_vec);
    build_options bo;
    bo.set_option(build_option::allow_new_shape_infer(true));
    network network(engine, topology, bo);

    network.set_input_data("input", input);

    auto outputs = network.execute();

    auto output = outputs.at("crop").get_memory();
    cldnn::mem_lock<int32_t> output_ptr(output, get_test_stream());

    for (int b = 0; b < crop_batch_num; ++b) { //B
        for (int f = 0; f < crop_feature_num; ++f) { //F
            for (int y = 0; y < crop_y_size; ++y) { //Y
                for (int x = 0; x < crop_x_size; ++x) { //X
                    int linear_id = (b + batch_offset) * (feature_num * y_size * x_size) + (f + feature_offset) * (y_size * x_size) + (y + y_offset) * x_size + (x + x_offset);
                    int output_linear_id = b * (crop_feature_num * crop_y_size * crop_x_size) + f * (crop_y_size * crop_x_size) + y * crop_x_size + x;
                    EXPECT_EQ(output_ptr[output_linear_id], input_vec[linear_id]);
                }
            }
        }
    }
}

TEST(crop_gpu, dynamic_in1x4x1x1_split) {
    auto& engine = get_test_engine();

    auto batch_num = 1;
    auto feature_num = 4;
    auto x_size = 1;
    auto y_size = 1;

    auto crop_batch_num = 1;
    auto crop_feature_num_1 = 2;
    auto crop_feature_num_2 = 2;
    auto crop_x_size = 1;
    auto crop_y_size = 1;
    auto feature_offset_1 = 0;
    auto feature_offset_2 = 2;

    auto input_dyn_layout = layout{ ov::PartialShape{ov::Dimension(1, 10), feature_num, y_size, x_size}, data_types::f32, format::bfyx };
    auto input_actual_layout = layout{ ov::PartialShape{batch_num, feature_num, y_size, x_size}, data_types::f32, format::bfyx };

    auto input_mem = engine.allocate_memory(input_actual_layout);
    auto data_mem = engine.allocate_memory({ {}, data_types::i64, format::bfyx });
    set_values(data_mem, {1});

    cldnn::crop_ngraph_op_mode op_mode = cldnn::crop_ngraph_op_mode::split;
    size_t num_splits = 2;
    topology topology;
    topology.add(input_layout("input", input_dyn_layout));
    topology.add(data("data", data_mem));
    topology.add(crop("crop1", {"input", "data"}, tensor(batch(crop_batch_num), spatial(crop_x_size, crop_y_size), feature(crop_feature_num_1)), { tensor(feature(feature_offset_1), spatial(0,0),batch(0)) }, op_mode, 0, num_splits));
    topology.add(crop("crop2", {"input", "data"}, tensor(batch(crop_batch_num), spatial(crop_x_size, crop_y_size), feature(crop_feature_num_2)), { tensor(feature(feature_offset_2), spatial(0,0),batch(0)) }, op_mode, 1, num_splits));

    std::vector<int32_t> input_vec = { -1, 2, -3, 4 };
    std::vector<int32_t> out1 = { -1, 2 };
    std::vector<int32_t> out2 = { -3, 4 };
    set_values(input_mem, input_vec);
    build_options bo;
    bo.set_option(build_option::allow_new_shape_infer(true));
    bo.set_option(build_option::optimize_data(true));
    bo.set_option(build_option::outputs(topology.get_primitives_ids()));

    network network(engine, topology, bo);
    network.set_input_data("input", input_mem);
    auto outputs = network.execute();

    auto output = outputs.at("crop1").get_memory();
    cldnn::mem_lock<int32_t> output_ptr(output, get_test_stream());

    for (size_t i = 0; i < out1.size(); i++)
        EXPECT_EQ(output_ptr[i], out1[i]);

    auto output_2 = outputs.at("crop2").get_memory();
    cldnn::mem_lock<int32_t> output_ptr_2(output_2, get_test_stream());

    for (size_t i = 0; i < out2.size(); i++)
        EXPECT_EQ(output_ptr_2[i], out2[i]);
}

TEST(crop_gpu, dynamic_in1x4x1x1_varaidic_split) {
    auto& engine = get_test_engine();

    auto batch_num = 1;
    auto feature_num = 4;
    auto x_size = 1;
    auto y_size = 1;

    auto crop_batch_num = 1;
    auto crop_feature_num_1 = 3;
    auto crop_feature_num_2 = 1;
    auto crop_x_size = 1;
    auto crop_y_size = 1;
    auto feature_offset_1 = 0;
    auto feature_offset_2 = 3;

    auto input_dyn_layout = layout{ ov::PartialShape{ov::Dimension(1, 10), feature_num, y_size, x_size}, data_types::f32, format::bfyx };
    auto input_actual_layout = layout{ ov::PartialShape{batch_num, feature_num, y_size, x_size}, data_types::f32, format::bfyx };

    auto input_mem = engine.allocate_memory(input_actual_layout);
    auto axis_mem = engine.allocate_memory({ {}, data_types::i64, format::bfyx });
    auto splits_length_mem = engine.allocate_memory({ {2}, data_types::i64, format::bfyx });

    cldnn::crop_ngraph_op_mode op_mode = cldnn::crop_ngraph_op_mode::variadic_split;
    topology topology;
    topology.add(input_layout("input", input_dyn_layout));
    topology.add(data("axis", axis_mem));
    topology.add(data("splits_length", splits_length_mem));
    topology.add(crop("crop1", {"input", "axis", "splits_length"}, tensor(batch(crop_batch_num), spatial(crop_x_size, crop_y_size), feature(crop_feature_num_1)), { tensor(feature(feature_offset_1), spatial(0,0),batch(0)) }, op_mode, 0));
    topology.add(crop("crop2", {"input", "axis", "splits_length"}, tensor(batch(crop_batch_num), spatial(crop_x_size, crop_y_size), feature(crop_feature_num_2)), { tensor(feature(feature_offset_2), spatial(0,0),batch(0)) }, op_mode, 1));

    std::vector<int32_t> input_vec = { -1, 2, -3, 4 };
    std::vector<int32_t> out1 = { -1, 2, -3 };
    std::vector<int32_t> out2 = { 4 };
    std::vector<int64_t> splits_vec = {3, 1};

    set_values(input_mem, input_vec);
    set_values(axis_mem, {1});
    set_values(splits_length_mem, splits_vec);

    build_options bo;
    bo.set_option(build_option::allow_new_shape_infer(true));
    bo.set_option(build_option::optimize_data(true));
    bo.set_option(build_option::outputs(topology.get_primitives_ids()));

    network network(engine, topology, bo);
    network.set_input_data("input", input_mem);
    auto outputs = network.execute();

    auto output = outputs.at("crop1").get_memory();
    cldnn::mem_lock<int32_t> output_ptr(output, get_test_stream());

    for (size_t i = 0; i < out1.size(); i++)
        EXPECT_EQ(output_ptr[i], out1[i]);

    auto output_2 = outputs.at("crop2").get_memory();
    cldnn::mem_lock<int32_t> output_ptr_2(output_2, get_test_stream());

    for (size_t i = 0; i < out2.size(); i++)
        EXPECT_EQ(output_ptr_2[i], out2[i]);
}