[GPU] Dynamism support for Proposal (#18489)
This commit is contained in:
parent
67c88f4434
commit
3f67b3948d
@ -98,8 +98,10 @@ struct proposal : public primitive_base<proposal> {
|
||||
bool round_ratios,
|
||||
bool shift_anchors,
|
||||
bool normalize,
|
||||
const padding& output_padding = padding())
|
||||
: primitive_base(id, {cls_scores, bbox_pred, image_info}, {output_padding}),
|
||||
const padding& output_padding = padding(),
|
||||
data_types output_data_type = data_types::f32,
|
||||
const size_t num_outputs = 1)
|
||||
: primitive_base(id, {cls_scores, bbox_pred, image_info}, {output_padding}, {optional_data_type{output_data_type}}, num_outputs),
|
||||
max_proposals(max_proposals),
|
||||
iou_threshold(iou_threshold),
|
||||
base_bbox_size(base_bbox_size),
|
||||
|
@ -9,6 +9,7 @@
|
||||
#include "data_inst.h"
|
||||
#include "mutable_data_inst.h"
|
||||
#include "reshape_inst.h"
|
||||
#include "proposal_inst.h"
|
||||
#include "quantize_inst.h"
|
||||
#include "arg_max_min_inst.h"
|
||||
#include "fully_connected_inst.h"
|
||||
@ -73,7 +74,7 @@ void compile_graph::run(program& p) {
|
||||
if (node->is_dynamic() && !is_planar)
|
||||
can_select_impl = false;
|
||||
|
||||
if (node->is_type<condition>())
|
||||
if (node->is_type<condition>() || node->is_type<proposal>())
|
||||
can_select_impl = true;
|
||||
|
||||
if (can_select_impl) {
|
||||
|
@ -414,6 +414,15 @@ struct proposal_impl : typed_primitive_impl<proposal> {
|
||||
mem_lock<data_type_to_type<data_types::f32>::type, mem_lock_type::read> proposal_prob_ptr{proposal_probabilities, stream};
|
||||
execute<data_type_to_type<data_types::f32>::type>(stream, instance, im_info, proposal_prob_ptr.data());
|
||||
}
|
||||
} else if (instance.outputs_memory_count() == 2) {
|
||||
auto proposal_probabilities = instance.output_memory_ptr(1);
|
||||
if (instance.dep_memory(proposal_inst::cls_scores_index).get_layout().data_type == data_types::f16) {
|
||||
mem_lock<data_type_to_type<data_types::f16>::type, mem_lock_type::write> proposal_prob_ptr{proposal_probabilities, stream};
|
||||
execute<data_type_to_type<data_types::f16>::type>(stream, instance, im_info, proposal_prob_ptr.data());
|
||||
} else {
|
||||
mem_lock<data_type_to_type<data_types::f32>::type, mem_lock_type::write> proposal_prob_ptr{proposal_probabilities, stream};
|
||||
execute<data_type_to_type<data_types::f32>::type>(stream, instance, im_info, proposal_prob_ptr.data());
|
||||
}
|
||||
} else {
|
||||
if (instance.dep_memory(proposal_inst::cls_scores_index).get_layout().data_type == data_types::f16) {
|
||||
execute<data_type_to_type<data_types::f16>::type>(stream, instance, im_info);
|
||||
@ -430,7 +439,9 @@ struct proposal_impl : typed_primitive_impl<proposal> {
|
||||
|
||||
static std::unique_ptr<primitive_impl> create(const proposal_node& arg, const kernel_impl_params& impl_param) {
|
||||
const layout& l = impl_param.input_layouts[2];
|
||||
const size_t count = l.feature() == 1 ? static_cast<size_t>(l.batch()) : static_cast<size_t>(l.feature());
|
||||
if (l.is_static() && l.get_partial_shape().size() >= 2) {
|
||||
const size_t count = l.get_partial_shape()[1].get_length() == 1 ? l.get_partial_shape()[0].get_length() :
|
||||
l.get_partial_shape()[1].get_length();
|
||||
|
||||
// Supported image_info sizes and components meaning:
|
||||
// - image_info[3] = { img_height, img_width, img_depth }
|
||||
@ -439,6 +450,7 @@ struct proposal_impl : typed_primitive_impl<proposal> {
|
||||
if (count != 3 && count != 4 && count != 6) {
|
||||
CLDNN_ERROR_MESSAGE(arg.id(), "image_info must have either 3, 4 or 6 items");
|
||||
}
|
||||
}
|
||||
|
||||
return make_unique<proposal_impl>(arg);
|
||||
}
|
||||
@ -447,10 +459,17 @@ struct proposal_impl : typed_primitive_impl<proposal> {
|
||||
namespace detail {
|
||||
|
||||
attach_proposal_impl::attach_proposal_impl() {
|
||||
implementation_map<proposal>::add(impl_types::cpu, proposal_impl::create, {
|
||||
std::make_tuple(data_types::f32, format::bfyx),
|
||||
std::make_tuple(data_types::f16, format::bfyx)
|
||||
});
|
||||
auto formats = {
|
||||
format::bfyx
|
||||
};
|
||||
|
||||
auto types = {
|
||||
data_types::f32,
|
||||
data_types::f16
|
||||
};
|
||||
|
||||
implementation_map<proposal>::add(impl_types::cpu, shape_types::static_shape, proposal_impl::create, types, formats);
|
||||
implementation_map<proposal>::add(impl_types::cpu, shape_types::dynamic_shape, proposal_impl::create, types, formats);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
@ -20,6 +20,7 @@ struct typed_program_node<proposal> : public typed_program_node_base<proposal> {
|
||||
program_node& cls_score() const { return get_dependency(0); }
|
||||
program_node& bbox_pred() const { return get_dependency(1); }
|
||||
program_node& image_info() const { return get_dependency(2); }
|
||||
std::vector<size_t> get_shape_infer_dependencies() const override { return {}; }
|
||||
};
|
||||
|
||||
using proposal_node = typed_program_node<proposal>;
|
||||
@ -72,6 +73,8 @@ public:
|
||||
image_info_scale_depth_index,
|
||||
};
|
||||
|
||||
template<typename ShapeType>
|
||||
static std::vector<layout> calc_output_layouts(proposal_node const& node, kernel_impl_params const& impl_param);
|
||||
static layout calc_output_layout(proposal_node const& node, kernel_impl_params const& impl_param);
|
||||
static std::string to_string(proposal_node const& node);
|
||||
|
||||
|
@ -3,6 +3,7 @@
|
||||
//
|
||||
|
||||
#include "proposal_inst.h"
|
||||
#include "proposal_shape_inference.hpp"
|
||||
#include "primitive_type_base.h"
|
||||
#include "json_object.h"
|
||||
|
||||
@ -33,6 +34,40 @@ layout proposal_inst::calc_output_layout(proposal_node const& node, kernel_impl_
|
||||
{input_layout.batch() * desc->post_nms_topn, CLDNN_ROI_VECTOR_SIZE, 1, 1});
|
||||
}
|
||||
|
||||
template<typename ShapeType>
|
||||
std::vector<layout> proposal_inst::calc_output_layouts(proposal_node const& node, kernel_impl_params const& impl_param) {
|
||||
std::vector<layout> layouts;
|
||||
|
||||
auto desc = impl_param.typed_desc<proposal>();
|
||||
auto input0_layout = impl_param.get_input_layout(0);
|
||||
auto class_probs_shape = input0_layout.get<ShapeType>();
|
||||
|
||||
ov::op::v4::Proposal op;
|
||||
ov::op::v0::Proposal::Attributes attrs;
|
||||
attrs.base_size = desc->base_bbox_size;
|
||||
attrs.pre_nms_topn = desc->pre_nms_topn;
|
||||
attrs.post_nms_topn = desc->post_nms_topn;
|
||||
op.set_attrs(attrs);
|
||||
|
||||
ShapeType bbox_deltas_shape = impl_param.get_input_layout(1).get<ShapeType>();
|
||||
ShapeType image_shape_shape = impl_param.get_input_layout(2).get<ShapeType>();
|
||||
std::vector<ShapeType> input_shapes = {
|
||||
class_probs_shape,
|
||||
bbox_deltas_shape,
|
||||
image_shape_shape
|
||||
};
|
||||
|
||||
const auto output_shapes = ov::op::v4::shape_infer(&op, input_shapes);
|
||||
|
||||
for (size_t i = 0; i < desc->num_outputs; ++i) {
|
||||
auto dt = desc->output_data_types[i].value_or(input0_layout.data_type);
|
||||
layouts.push_back({output_shapes[i], dt, format::get_default_format(output_shapes[i].size())});
|
||||
}
|
||||
return layouts;
|
||||
}
|
||||
|
||||
template std::vector<layout> proposal_inst::calc_output_layouts<ov::PartialShape>(proposal_node const& node, const kernel_impl_params& impl_param);
|
||||
|
||||
static inline std::string stringify_vector(std::vector<float> v) {
|
||||
std::stringstream s;
|
||||
|
||||
|
@ -17,6 +17,7 @@ namespace intel_gpu {
|
||||
static void CreateProposalOp(Program& p, const std::shared_ptr<ngraph::op::v0::Proposal>& op) {
|
||||
validate_inputs_count(op, {3});
|
||||
auto inputs = p.GetInputInfo(op);
|
||||
std::string layerName = layer_type_name_ID(op);
|
||||
|
||||
auto attrs = op->get_attrs();
|
||||
float nms_thresh = attrs.nms_thresh;
|
||||
@ -54,6 +55,54 @@ static void CreateProposalOp(Program& p, const std::shared_ptr<ngraph::op::v0::P
|
||||
swap_xy = false;
|
||||
}
|
||||
|
||||
if (p.use_new_shape_infer()) {
|
||||
size_t num_outputs = op->get_output_size();
|
||||
auto get_output_paddings = [&]() {
|
||||
std::vector<cldnn::padding> output_paddings;
|
||||
for (size_t i = 0; i < num_outputs; i++)
|
||||
output_paddings.push_back(cldnn::padding());
|
||||
return output_paddings;
|
||||
};
|
||||
auto get_output_data_types = [&]() {
|
||||
std::vector<cldnn::optional_data_type> output_data_types;
|
||||
for (size_t i = 0; i < num_outputs; i++) {
|
||||
auto type = op->get_output_element_type(i);
|
||||
output_data_types.push_back(cldnn::element_type_to_data_type(type));
|
||||
}
|
||||
return output_data_types;
|
||||
};
|
||||
|
||||
auto proposalPrim = cldnn::proposal(layerName,
|
||||
inputs[0], // cls_score
|
||||
inputs[1], // bbox_pred
|
||||
inputs[2], // im_info
|
||||
0, // max_num_proposals is unused
|
||||
nms_thresh,
|
||||
base_size,
|
||||
min_size,
|
||||
feature_stride,
|
||||
pre_nms_topn,
|
||||
post_nms_topn,
|
||||
ratio,
|
||||
scale,
|
||||
coordinates_offset,
|
||||
box_coordinate_scale,
|
||||
box_size_scale,
|
||||
false,
|
||||
swap_xy,
|
||||
initial_clip,
|
||||
clip_before_nms,
|
||||
clip_after_nms,
|
||||
round_ratios,
|
||||
shift_anchors,
|
||||
normalize,
|
||||
cldnn::padding({0, 0, 0, 0}, 0),
|
||||
cldnn::element_type_to_data_type(op->get_output_element_type(0)),
|
||||
num_outputs);
|
||||
proposalPrim.output_paddings = get_output_paddings();
|
||||
proposalPrim.output_data_types = get_output_data_types();
|
||||
p.add_primitive(*op, proposalPrim);
|
||||
} else {
|
||||
if (op->get_output_size() == 2) {
|
||||
auto mutable_precision = op->get_output_element_type(1);
|
||||
if (mutable_precision == ngraph::element::i64) {
|
||||
@ -64,16 +113,16 @@ static void CreateProposalOp(Program& p, const std::shared_ptr<ngraph::op::v0::P
|
||||
cldnn::format::get_default_format(op->get_output_shape(1).size()),
|
||||
tensor_from_dims(op->get_output_shape(1)));
|
||||
|
||||
GPU_DEBUG_LOG << "[" << layer_type_name_ID(op) << ": mutable data]" << std::endl;
|
||||
GPU_DEBUG_LOG << "[" << layerName << ": mutable data]" << std::endl;
|
||||
auto shared_memory = p.get_engine().allocate_memory(mutableLayout);
|
||||
|
||||
cldnn::primitive_id proposal_mutable_id_w = layer_type_name_ID(op) + "_md_write";
|
||||
cldnn::primitive_id proposal_mutable_id_w = layerName + "_md_write";
|
||||
auto argmax_mutable_prim = cldnn::mutable_data(proposal_mutable_id_w,
|
||||
shared_memory);
|
||||
p.add_primitive(*op, argmax_mutable_prim);
|
||||
inputs.push_back(cldnn::input_info(proposal_mutable_id_w));
|
||||
|
||||
std::string proposalLayerName = layer_type_name_ID(op) + ".out0";
|
||||
std::string proposalLayerName = layerName + ".out0";
|
||||
auto proposalPrim = cldnn::proposal(proposalLayerName,
|
||||
inputs[0], // cls_score
|
||||
inputs[1], // bbox_pred
|
||||
@ -102,16 +151,14 @@ static void CreateProposalOp(Program& p, const std::shared_ptr<ngraph::op::v0::P
|
||||
|
||||
p.add_primitive(*op, proposalPrim);
|
||||
|
||||
cldnn::primitive_id proposal_mutable_id_r = layer_type_name_ID(op) + ".out1";
|
||||
cldnn::primitive_id proposal_mutable_id_r = layerName + ".out1";
|
||||
auto argmax_mutable_prim_r = cldnn::mutable_data(proposal_mutable_id_r,
|
||||
{ cldnn::input_info(proposalLayerName) },
|
||||
shared_memory);
|
||||
p.add_primitive(*op, argmax_mutable_prim_r);
|
||||
return;
|
||||
}
|
||||
|
||||
std::string proposalLayerName = layer_type_name_ID(op);
|
||||
auto proposalPrim = cldnn::proposal(proposalLayerName,
|
||||
} else if (op->get_output_size() == 1) {
|
||||
auto proposalPrim = cldnn::proposal(layerName,
|
||||
inputs[0], // cls_score
|
||||
inputs[1], // bbox_pred
|
||||
inputs[2], // im_info
|
||||
@ -137,6 +184,10 @@ static void CreateProposalOp(Program& p, const std::shared_ptr<ngraph::op::v0::P
|
||||
normalize);
|
||||
|
||||
p.add_primitive(*op, proposalPrim);
|
||||
} else {
|
||||
IE_THROW() << op->get_friendly_name() << " Incorrect Proposal outputs number";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_FACTORY_IMPL(v0, Proposal);
|
||||
|
@ -0,0 +1,134 @@
|
||||
// Copyright (C) 2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "test_utils.h"
|
||||
|
||||
#include <intel_gpu/primitives/input_layout.hpp>
|
||||
#include <intel_gpu/primitives/proposal.hpp>
|
||||
#include <intel_gpu/primitives/data.hpp>
|
||||
|
||||
#include "proposal_inst.h"
|
||||
|
||||
#include "program_wrapper.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <algorithm>
|
||||
|
||||
using namespace cldnn;
|
||||
using namespace ::tests;
|
||||
|
||||
namespace shape_infer_tests {
|
||||
|
||||
const float iou_threshold = 0.7f;
|
||||
const int base_bbox_size = 16;
|
||||
const int min_bbox_size = 12;
|
||||
const int feature_stride = 16;
|
||||
const int pre_nms_topn = 6000;
|
||||
const int post_nms_topn = 300;
|
||||
const float coordinates_offset = 1.0f;
|
||||
const float box_coordinate_scale = 1.0f;
|
||||
const float box_size_scale = 1.0f;
|
||||
const bool swap_xy = false;
|
||||
const bool initial_clip = false;
|
||||
const bool clip_before_nms = true;
|
||||
const bool clip_after_nms = false;
|
||||
const bool round_ratios = true;
|
||||
const bool shift_anchors = false;
|
||||
const bool normalize = true;
|
||||
const std::vector<float> ratios = { 0.5f, 1.0f, 2.0f };
|
||||
const std::vector<float> scales = { 2.0f, 4.0f, 8.0f, 16.0f, 32.0f };
|
||||
|
||||
struct proposal_test_params {
|
||||
std::vector<layout> in_layouts;
|
||||
data_types output_data_type;
|
||||
size_t num_outputs;
|
||||
std::vector<layout> expected_layouts;
|
||||
};
|
||||
|
||||
class proposal_test : public testing::TestWithParam<proposal_test_params> { };
|
||||
|
||||
TEST_P(proposal_test, shape_infer) {
|
||||
auto p = GetParam();
|
||||
|
||||
auto& engine = get_test_engine();
|
||||
|
||||
cldnn::program prog(engine);
|
||||
std::vector<std::shared_ptr<primitive>> input_prims;
|
||||
std::vector<input_info> input_prim_ids;
|
||||
for (size_t i = 0; i < p.in_layouts.size(); i++) {
|
||||
auto prim_id = "input" + std::to_string(i);
|
||||
auto input_layout_prim = std::make_shared<input_layout>(prim_id, p.in_layouts[i]);
|
||||
input_prims.push_back(input_layout_prim);
|
||||
input_prim_ids.push_back(input_info(prim_id));
|
||||
}
|
||||
|
||||
auto proposal_prim = std::make_shared<proposal>("depth_to_space",
|
||||
input_prim_ids[0],
|
||||
input_prim_ids[1],
|
||||
input_prim_ids[2],
|
||||
0,
|
||||
iou_threshold,
|
||||
base_bbox_size,
|
||||
min_bbox_size,
|
||||
feature_stride,
|
||||
pre_nms_topn,
|
||||
post_nms_topn,
|
||||
ratios,
|
||||
scales,
|
||||
coordinates_offset,
|
||||
box_coordinate_scale,
|
||||
box_size_scale,
|
||||
false,
|
||||
swap_xy,
|
||||
initial_clip,
|
||||
clip_before_nms,
|
||||
clip_after_nms,
|
||||
round_ratios,
|
||||
shift_anchors,
|
||||
normalize,
|
||||
padding({0, 0, 0, 0}, 0),
|
||||
p.output_data_type,
|
||||
p.num_outputs);
|
||||
std::vector<padding> output_paddings;
|
||||
std::vector<optional_data_type> output_data_types;
|
||||
for (size_t i = 0; i < p.num_outputs; i++) {
|
||||
output_paddings.push_back(padding());
|
||||
output_data_types.push_back(optional_data_type{p.output_data_type});
|
||||
}
|
||||
proposal_prim->output_paddings = output_paddings;
|
||||
proposal_prim->output_data_types = output_data_types;
|
||||
auto& proposal_node = prog.get_or_create(proposal_prim);
|
||||
for (auto& prim : input_prims) {
|
||||
auto& input_layout_node = prog.get_or_create(prim);
|
||||
program_wrapper::add_connection(prog, input_layout_node, proposal_node);
|
||||
}
|
||||
|
||||
auto res = proposal_inst::calc_output_layouts<ov::PartialShape>(proposal_node, *proposal_node.get_kernel_impl_params());
|
||||
|
||||
ASSERT_EQ(res.size(), p.num_outputs);
|
||||
for (size_t i = 0; i < p.expected_layouts.size(); i++)
|
||||
ASSERT_EQ(res[i], p.expected_layouts[i]);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke, proposal_test,
|
||||
testing::ValuesIn(std::vector<proposal_test_params>{
|
||||
{
|
||||
{layout{ov::PartialShape{-1, 30, -1, -1}, data_types::f32, format::bfyx},
|
||||
layout{ov::PartialShape{-1, 60, -1, -1}, data_types::f32, format::bfyx},
|
||||
layout{ov::PartialShape{3}, data_types::f32, format::bfyx}},
|
||||
data_types::f32, 2,
|
||||
{layout{ov::PartialShape{ov::Dimension::dynamic(), 5}, data_types::f32, format::bfyx},
|
||||
layout{ov::PartialShape::dynamic(1), data_types::f32, format::bfyx}}
|
||||
},
|
||||
{
|
||||
{layout{ov::PartialShape{1, 24, -1, -1}, data_types::f32, format::bfyx},
|
||||
layout{ov::PartialShape{1, 48, -1, -1}, data_types::f32, format::bfyx},
|
||||
layout{ov::PartialShape{3}, data_types::f32, format::bfyx}},
|
||||
data_types::f32, 2,
|
||||
{layout{ov::PartialShape{300, 5}, data_types::f32, format::bfyx},
|
||||
layout{ov::PartialShape{300}, data_types::f32, format::bfyx}}
|
||||
},
|
||||
}));
|
||||
|
||||
} // shape_infer_tests
|
Loading…
Reference in New Issue
Block a user