[ONNX FE] Extend ONNX FE for operation GenerateProposals (#12510)
* Create ONNX FrontEnd GenerateProposals op * Add onnx GP Op validation * Add batch 2 test * Improve code readability .. per review comments * Fix test model paths * Use heterogeneous test values
This commit is contained in:
parent
8b75e8d4b9
commit
e24a5b8ac3
@ -65,7 +65,7 @@ All proposals of the whole batch are concated image by image, and distinguishabl
|
||||
* *nms_eta*
|
||||
|
||||
* **Description**: eta parameter for adaptive NMS.
|
||||
* **Range of values**: a floating-point number in close range `[0, 1.0]`.
|
||||
* **Range of values**: a floating-point number in closed range `[0, 1.0]`.
|
||||
* **Type**: float
|
||||
* **Default value**: `1.0`
|
||||
* **Required**: *no*
|
||||
|
@ -3589,7 +3589,7 @@ template <element::Type_t ET>
|
||||
bool evaluate(const shared_ptr<op::v9::GenerateProposals>& op,
|
||||
const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) {
|
||||
const auto attrs = op->get_attrs();
|
||||
const auto& attrs = op->get_attrs();
|
||||
|
||||
size_t post_nms_count = 0;
|
||||
if (attrs.post_nms_count < 0) {
|
||||
@ -3600,12 +3600,12 @@ bool evaluate(const shared_ptr<op::v9::GenerateProposals>& op,
|
||||
post_nms_count = static_cast<size_t>(attrs.post_nms_count);
|
||||
}
|
||||
|
||||
const auto output_type = op->get_input_element_type(0);
|
||||
const auto& output_type = op->get_input_element_type(0);
|
||||
|
||||
const auto im_info_shape = inputs[0]->get_shape();
|
||||
const auto anchors_shape = inputs[1]->get_shape();
|
||||
const auto deltas_shape = inputs[2]->get_shape();
|
||||
const auto scores_shape = inputs[3]->get_shape();
|
||||
const auto& im_info_shape = inputs[0]->get_shape();
|
||||
const auto& anchors_shape = inputs[1]->get_shape();
|
||||
const auto& deltas_shape = inputs[2]->get_shape();
|
||||
const auto& scores_shape = inputs[3]->get_shape();
|
||||
|
||||
const auto im_info_data = get_floats(inputs[0], im_info_shape);
|
||||
const auto anchors_data = get_floats(inputs[1], anchors_shape);
|
||||
@ -3639,7 +3639,7 @@ bool evaluate(const shared_ptr<op::v9::GenerateProposals>& op,
|
||||
outputs[1]->set_element_type(output_type);
|
||||
outputs[1]->set_shape(output_scores_shape);
|
||||
|
||||
const auto roi_num_type = op->get_output_element_type(2);
|
||||
const auto& roi_num_type = op->get_output_element_type(2);
|
||||
Shape output_roi_num_shape = Shape{im_info_shape[0]};
|
||||
outputs[2]->set_element_type(roi_num_type);
|
||||
outputs[2]->set_shape(output_roi_num_shape);
|
||||
|
@ -0,0 +1,163 @@
|
||||
ir_version: 8
|
||||
producer_name: "OpenVINO"
|
||||
graph {
|
||||
name: "just GenerateProposals"
|
||||
node {
|
||||
input: "scores"
|
||||
input: "deltas"
|
||||
input: "im_info"
|
||||
input: "anchors"
|
||||
output: "rpnrois"
|
||||
output: "rpnscores"
|
||||
output: "rpnroisnum"
|
||||
op_type: "GenerateProposals"
|
||||
domain: "org.openvinotoolkit"
|
||||
attribute {
|
||||
name: "pre_nms_topN"
|
||||
i: 1000
|
||||
type: INT
|
||||
}
|
||||
attribute {
|
||||
name: "post_nms_topN"
|
||||
i: 6
|
||||
type: INT
|
||||
}
|
||||
attribute {
|
||||
name: "nms_thresh"
|
||||
f: 0.7
|
||||
type: FLOAT
|
||||
}
|
||||
attribute {
|
||||
name: "min_size"
|
||||
f: 1
|
||||
type: FLOAT
|
||||
}
|
||||
attribute {
|
||||
name: "legacy_plus_one"
|
||||
i: 0
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "scores"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 6
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "deltas"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
dim {
|
||||
dim_value: 12
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 6
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "im_info"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "anchors"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "rpnrois"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_param: "num_rois"
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "rpnscores"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_param: "num_rois"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "rpnroisnum"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 7
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
domain: "org.openvinotoolkit"
|
||||
version: 1
|
||||
}
|
@ -0,0 +1,163 @@
|
||||
ir_version: 8
|
||||
producer_name: "OpenVINO"
|
||||
graph {
|
||||
name: "just GenerateProposals"
|
||||
node {
|
||||
input: "scores"
|
||||
input: "deltas"
|
||||
input: "im_info"
|
||||
input: "anchors"
|
||||
output: "rpnrois"
|
||||
output: "rpnscores"
|
||||
output: "rpnroisnum"
|
||||
op_type: "GenerateProposals"
|
||||
domain: "org.openvinotoolkit"
|
||||
attribute {
|
||||
name: "pre_nms_topN"
|
||||
i: 1000
|
||||
type: INT
|
||||
}
|
||||
attribute {
|
||||
name: "post_nms_topN"
|
||||
i: 5
|
||||
type: INT
|
||||
}
|
||||
attribute {
|
||||
name: "nms_thresh"
|
||||
f: 0.7
|
||||
type: FLOAT
|
||||
}
|
||||
attribute {
|
||||
name: "min_size"
|
||||
f: 1
|
||||
type: FLOAT
|
||||
}
|
||||
attribute {
|
||||
name: "legacy_plus_one"
|
||||
i: 0
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "scores"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "deltas"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 12
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "im_info"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "anchors"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "rpnrois"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_param: "num_rois"
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "rpnscores"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_param: "num_rois"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "rpnroisnum"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 7
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
domain: "org.openvinotoolkit"
|
||||
version: 1
|
||||
}
|
@ -22,6 +22,10 @@
|
||||
RETHROW_FRONTEND_EXCEPTION(ov::frontend::NotImplementedFailure) \
|
||||
RETHROW_FRONTEND_EXCEPTION(ov::AssertFailure) \
|
||||
RETHROW_FRONTEND_EXCEPTION(ov::Exception) \
|
||||
catch (const std::exception& e) { \
|
||||
const auto message = std::string(MESSAGE "\n") + e.what(); \
|
||||
OPENVINO_ASSERT(false, message); \
|
||||
} \
|
||||
catch (...) { \
|
||||
OPENVINO_ASSERT(false, (MESSAGE)); \
|
||||
}
|
||||
@ -37,6 +41,10 @@
|
||||
RETHROW_FRONTEND_EXCEPTION(ov::frontend::NotImplementedFailure) \
|
||||
RETHROW_FRONTEND_EXCEPTION(ov::AssertFailure) \
|
||||
RETHROW_FRONTEND_EXCEPTION(ov::Exception) \
|
||||
catch (const std::exception& e) { \
|
||||
const auto message = std::string(MESSAGE "\n") + e.what(); \
|
||||
OPENVINO_ASSERT(false, message); \
|
||||
} \
|
||||
catch (...) { \
|
||||
OPENVINO_ASSERT(false, (MESSAGE)); \
|
||||
}
|
||||
|
@ -42,6 +42,7 @@ static const std::vector<std::string> legacy_ops_to_fixup = {"DeformableConv2D",
|
||||
"ExperimentalDetectronROIFeatureExtractor",
|
||||
"ExperimentalDetectronTopKROIs",
|
||||
"FakeQuantize",
|
||||
"GenerateProposals",
|
||||
"GroupNorm",
|
||||
"Normalize",
|
||||
"PriorBox",
|
||||
|
@ -0,0 +1,63 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "generate_proposals.hpp"
|
||||
|
||||
#include "default_opset.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace onnx_import {
|
||||
namespace op {
|
||||
namespace set_1 {
|
||||
|
||||
namespace {
|
||||
void validate_generate_proposals_inputs(const OutputVector& inputs) {
|
||||
OPENVINO_ASSERT(inputs.size() == 4, "GenerateProposals operator expects 4 inputs, got ", inputs.size());
|
||||
|
||||
const auto scores_rank = inputs[0].get_partial_shape().rank();
|
||||
OPENVINO_ASSERT(scores_rank.compatible(4), "GenerateProposals input scores rank should be 4, is ", scores_rank);
|
||||
|
||||
const auto& anchors_shape = inputs[3].get_partial_shape();
|
||||
const auto anchors_rank = anchors_shape.rank();
|
||||
OPENVINO_ASSERT(anchors_rank == Rank(2), "GenerateProposals input anchors rank should be 2, is ", anchors_rank);
|
||||
OPENVINO_ASSERT(anchors_shape[1].compatible(4),
|
||||
"GenerateProposals input anchors shape should be {A, 4}, is ",
|
||||
anchors_shape);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
OutputVector generate_proposals(const Node& node) {
|
||||
const auto inputs = node.get_ng_inputs();
|
||||
validate_generate_proposals_inputs(inputs);
|
||||
|
||||
const auto& scores = inputs[0]; // shape [N, A, H, W]
|
||||
const auto& deltas = inputs[1]; // shape [N, A*4, H, W]
|
||||
const auto& im_info = inputs[2]; // shape [N, 3] or [N, 4]
|
||||
const auto& anchors = inputs[3]; // shape [A, 4]
|
||||
|
||||
ov::op::v9::GenerateProposals::Attributes attrs;
|
||||
attrs.min_size = node.get_attribute_value<float>("min_size", 1.f);
|
||||
attrs.nms_threshold = node.get_attribute_value<float>("nms_thresh", 0.7f);
|
||||
attrs.pre_nms_count = node.get_attribute_value<int64_t>("pre_nms_topN", 6000);
|
||||
attrs.post_nms_count = node.get_attribute_value<int64_t>("post_nms_topN", 300);
|
||||
attrs.normalized = !node.get_attribute_value<int64_t>("legacy_plus_one", true);
|
||||
|
||||
// Broadcast anchors from [A, 4] to [H, W, A, 4] where [H, W] is taken from scores shape.
|
||||
const auto zero = default_opset::Constant::create(element::i64, Shape{1}, {0});
|
||||
const auto scores_shape = std::make_shared<default_opset::ShapeOf>(scores);
|
||||
const auto anchors_shape = std::make_shared<default_opset::ShapeOf>(anchors);
|
||||
const auto scores_shape_tail = default_opset::Constant::create(element::i64, Shape{2}, {2, 3});
|
||||
const auto new_anchors_shape_front = std::make_shared<default_opset::Gather>(scores_shape, scores_shape_tail, zero);
|
||||
const auto new_anchors_shape =
|
||||
std::make_shared<default_opset::Concat>(OutputVector{new_anchors_shape_front, anchors_shape}, 0);
|
||||
const auto new_anchors = std::make_shared<default_opset::Broadcast>(anchors, new_anchors_shape);
|
||||
|
||||
const auto proposals = std::make_shared<ov::op::v9::GenerateProposals>(im_info, new_anchors, deltas, scores, attrs);
|
||||
|
||||
return proposals->outputs();
|
||||
}
|
||||
} // namespace set_1
|
||||
} // namespace op
|
||||
} // namespace onnx_import
|
||||
} // namespace ngraph
|
@ -0,0 +1,18 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ngraph/node.hpp"
|
||||
#include "onnx_import/core/node.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace onnx_import {
|
||||
namespace op {
|
||||
namespace set_1 {
|
||||
OutputVector generate_proposals(const Node& node);
|
||||
} // namespace set_1
|
||||
} // namespace op
|
||||
} // namespace onnx_import
|
||||
} // namespace ngraph
|
@ -111,6 +111,7 @@
|
||||
#include "op/org.openvinotoolkit/experimental_detectron/roi_feature_extractor.hpp"
|
||||
#include "op/org.openvinotoolkit/experimental_detectron/topk_rios.hpp"
|
||||
#include "op/org.openvinotoolkit/fake_quantize.hpp"
|
||||
#include "op/org.openvinotoolkit/generate_proposals.hpp"
|
||||
#include "op/org.openvinotoolkit/group_norm.hpp"
|
||||
#include "op/org.openvinotoolkit/normalize.hpp"
|
||||
#include "op/org.openvinotoolkit/prior_box.hpp"
|
||||
@ -494,6 +495,7 @@ OperatorsBridge::OperatorsBridge() {
|
||||
1,
|
||||
experimental_detectron_topk_rois);
|
||||
REGISTER_OPERATOR_WITH_DOMAIN(OPENVINO_ONNX_DOMAIN, "FakeQuantize", 1, fake_quantize);
|
||||
REGISTER_OPERATOR_WITH_DOMAIN(OPENVINO_ONNX_DOMAIN, "GenerateProposals", 1, generate_proposals);
|
||||
REGISTER_OPERATOR_WITH_DOMAIN(OPENVINO_ONNX_DOMAIN, "GroupNorm", 1, group_norm);
|
||||
REGISTER_OPERATOR_WITH_DOMAIN(OPENVINO_ONNX_DOMAIN, "Normalize", 1, normalize);
|
||||
REGISTER_OPERATOR_WITH_DOMAIN(OPENVINO_ONNX_DOMAIN, "PriorBox", 1, prior_box);
|
||||
|
@ -561,3 +561,77 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_deformable_conv_2d) {
|
||||
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_generate_proposals) {
|
||||
auto function =
|
||||
onnx_import::import_onnx_model(file_util::path_join(CommonTestUtils::getExecutableDirectory(),
|
||||
SERIALIZED_ZOO,
|
||||
"onnx/org.openvinotoolkit/generate_proposals.onnx"));
|
||||
|
||||
auto test_case = test::TestCase(function, s_device);
|
||||
|
||||
// scores
|
||||
test_case.add_input<float>(
|
||||
Shape{1, 3, 2, 6},
|
||||
{0.56637216, 0.90457034, 0.69827306, 0.4353543, 0.47985056, 0.42658508, 0.14516132, 0.08081771, 0.1799732,
|
||||
0.9229515, 0.42420176, 0.50857586, 0.82664067, 0.4972319, 0.3752427, 0.56731623, 0.18241242, 0.33252355,
|
||||
0.30608943, 0.6572437, 0.69185436, 0.88646156, 0.36985755, 0.5590753, 0.5256446, 0.03342898, 0.1344396,
|
||||
0.68642473, 0.37953874, 0.32575172, 0.21108444, 0.5661886, 0.45378175, 0.62126315, 0.26799858, 0.37272978});
|
||||
// deltas
|
||||
test_case.add_input<float>(
|
||||
Shape{1, 12, 2, 6},
|
||||
{0.5337073, 0.86607957, 0.55151343, 0.21626699, 0.4462629, 0.03985678, 0.5157072, 0.9932138, 0.7565954,
|
||||
0.43803605, 0.802818, 0.14834064, 0.53932905, 0.14314, 0.3817048, 0.95075196, 0.05516243, 0.2567484,
|
||||
0.25508744, 0.77438325, 0.43561, 0.2094628, 0.8299043, 0.44982538, 0.95615596, 0.5651084, 0.11801951,
|
||||
0.05352486, 0.9774733, 0.14439464, 0.62644225, 0.14370479, 0.54161614, 0.557915, 0.53102225, 0.0840179,
|
||||
0.7249888, 0.9843559, 0.5490522, 0.53788143, 0.822474, 0.3278008, 0.39688024, 0.3286012, 0.5117038,
|
||||
0.04743988, 0.9408995, 0.29885054, 0.81039643, 0.85277915, 0.06807619, 0.86430097, 0.36225632, 0.16606331,
|
||||
0.5401001, 0.7541649, 0.11998601, 0.5131829, 0.40606487, 0.327888, 0.27721855, 0.6378373, 0.22795396,
|
||||
0.4961256, 0.3215895, 0.15607187, 0.14782153, 0.8908137, 0.8835288, 0.834191, 0.29907143, 0.7983525,
|
||||
0.755875, 0.30837986, 0.0839176, 0.26624718, 0.04371626, 0.09472824, 0.20689541, 0.37622106, 0.1083321,
|
||||
0.1342548, 0.05815459, 0.7676379, 0.8105144, 0.92348766, 0.26761323, 0.7183306, 0.8947588, 0.19020908,
|
||||
0.42731014, 0.7473663, 0.85775334, 0.9340091, 0.3278848, 0.755993, 0.05307213, 0.39705503, 0.21003333,
|
||||
0.5625373, 0.66188884, 0.80521655, 0.6125863, 0.44678232, 0.97802377, 0.0204936, 0.02686367, 0.7390654,
|
||||
0.74631, 0.58399844, 0.5988792, 0.37413648, 0.5946692, 0.6955776, 0.36377597, 0.7891322, 0.40900692,
|
||||
0.99139464, 0.50169915, 0.41435778, 0.17142445, 0.26761186, 0.31591868, 0.14249913, 0.12919712, 0.5418711,
|
||||
0.6523203, 0.50259084, 0.7379765, 0.01171071, 0.94423133, 0.00841132, 0.97486794, 0.2921785, 0.7633071,
|
||||
0.88477814, 0.03563205, 0.50833166, 0.01354555, 0.535081, 0.41366324, 0.0694767, 0.9944055, 0.9981207});
|
||||
// im_info
|
||||
test_case.add_input<float>(Shape{1, 3}, {200, 200, 0});
|
||||
// anchors
|
||||
test_case.add_input<float>(Shape{3, 4}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
|
||||
|
||||
test_case.add_expected_output<float>(
|
||||
Shape{6, 4},
|
||||
{0.12904608, 1.3703424, 3.6230984, 3.4675088, 0.9725206, 0., 4.4917974, 4.9623675,
|
||||
4.882682, 5.1236916, 7.1700497, 10.213073, 4.4913187, 4.305372, 8.750267, 8.803502,
|
||||
0.9777608, 1.0317986, 3.228293, 4.495021, 4.125554, 5.4091997, 6.35439, 10.124915});
|
||||
test_case.add_expected_output<float>(Shape{6},
|
||||
{0.9229515, 0.90457034, 0.88646156, 0.82664067, 0.69827306, 0.69185436});
|
||||
test_case.add_expected_output<int64_t>(Shape{1}, {6});
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_generate_proposals_batch) {
|
||||
auto function =
|
||||
onnx_import::import_onnx_model(file_util::path_join(CommonTestUtils::getExecutableDirectory(),
|
||||
SERIALIZED_ZOO,
|
||||
"onnx/org.openvinotoolkit/generate_proposals_batch2.onnx"));
|
||||
|
||||
auto test_case = test::TestCase(function, s_device);
|
||||
|
||||
// scores
|
||||
test_case.add_input<float>(Shape{2, 3, 2, 3}, {5, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 7, 1, 1, 1, 1,
|
||||
1, 1, 4, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 8, 1});
|
||||
// deltas
|
||||
test_case.add_input<float>(Shape{2, 12, 2, 3}, std::vector<float>(144, 1));
|
||||
// im_info
|
||||
test_case.add_input<float>(Shape{2, 3}, {1, 1, 0, 1, 1, 0});
|
||||
// anchors
|
||||
test_case.add_input<float>(Shape{3, 4}, std::vector<float>(12, 1));
|
||||
|
||||
test_case.add_expected_output<float>(Shape{10, 4}, std::vector<float>(40, 1));
|
||||
test_case.add_expected_output<float>(Shape{10}, {7, 5, 3, 1, 1, 8, 4, 2, 1, 1});
|
||||
test_case.add_expected_output<int64_t>(Shape{2}, {5, 5});
|
||||
test_case.run();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user