* Fix NMS1/3/4 to Legacy conversion * Added functoinal tests
This commit is contained in:
parent
84e775af29
commit
2ed92688dd
@ -12,7 +12,6 @@
|
||||
#include "legacy/transformations/convert_opset1_to_legacy/convert_matmul_to_fc_or_gemm.hpp"
|
||||
#include "legacy/transformations/convert_opset1_to_legacy/convert_mul_add_to_scaleshift_or_power.hpp"
|
||||
#include "legacy/transformations/convert_opset1_to_legacy/convert_mul_or_add_finally.hpp"
|
||||
#include "legacy/transformations/convert_opset1_to_legacy/convert_nms_to_nms_ie.hpp"
|
||||
#include "legacy/transformations/convert_opset1_to_legacy/convert_nms_5_to_legacy.hpp"
|
||||
#include "legacy/transformations/convert_opset1_to_legacy/convert_normalizel2_to_normalize_ie.hpp"
|
||||
#include "legacy/transformations/convert_opset1_to_legacy/convert_one_hot_to_one_hot_ie.hpp"
|
||||
@ -45,6 +44,7 @@
|
||||
|
||||
#include <transformations/common_optimizations/conv_bias_fusion.hpp>
|
||||
#include <transformations/op_conversions/convert_convolutions.hpp>
|
||||
#include <transformations/op_conversions/convert_previous_nms_to_nms_5.hpp>
|
||||
|
||||
#include <ngraph/pass/constant_folding.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
@ -129,7 +129,10 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph
|
||||
anchor->add_matcher<ngraph::pass::ConvertOneHotToOneHotIEMatcher>()->detect_output_type(f);
|
||||
anchor->add_matcher<ngraph::pass::ConvertGatherTreeToGatherTreeIEMatcher>();
|
||||
anchor->add_matcher<ngraph::pass::ConvertTopKToTopKIEMatcher>();
|
||||
anchor->add_matcher<ngraph::pass::ConvertNMSToNMSIEMatcher>();
|
||||
anchor->add_matcher<ngraph::pass::ConvertNMS1ToNMS5>();
|
||||
anchor->add_matcher<ngraph::pass::ConvertNMS3ToNMS5>();
|
||||
anchor->add_matcher<ngraph::pass::ConvertNMS4ToNMS5>();
|
||||
anchor->add_matcher<ngraph::pass::ConvertNMS5ToLegacyMatcher>();
|
||||
anchor->add_matcher<ngraph::pass::ConvertGRUSequenceMatcher>();
|
||||
anchor->add_matcher<ngraph::pass::ConvertRNNSequenceMatcher>();
|
||||
anchor->add_matcher<ngraph::pass::ConvertLSTMSequenceMatcher>();
|
||||
|
@ -16,7 +16,6 @@ namespace pass {
|
||||
class TRANSFORMATIONS_API ConvertNMS1ToNMS5;
|
||||
class TRANSFORMATIONS_API ConvertNMS3ToNMS5;
|
||||
class TRANSFORMATIONS_API ConvertNMS4ToNMS5;
|
||||
class TRANSFORMATIONS_API ConvertPreviousNMSToNMS5;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
@ -39,12 +38,3 @@ public:
|
||||
ConvertNMS4ToNMS5();
|
||||
};
|
||||
|
||||
class ngraph::pass::ConvertPreviousNMSToNMS5: public ngraph::pass::GraphRewrite {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
ConvertPreviousNMSToNMS5() {
|
||||
add_matcher<ngraph::pass::ConvertNMS1ToNMS5>();
|
||||
add_matcher<ngraph::pass::ConvertNMS3ToNMS5>();
|
||||
add_matcher<ngraph::pass::ConvertNMS4ToNMS5>();
|
||||
}
|
||||
};
|
||||
|
@ -113,8 +113,6 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
|
||||
manager.register_pass<ngraph::pass::ConstantFolding>();
|
||||
manager.register_pass<ngraph::pass::ConvertInterpolate1ToInterpolate4, false>();
|
||||
|
||||
manager.register_pass<ngraph::pass::ConvertPreviousNMSToNMS5>();
|
||||
|
||||
auto fq_fusions = manager.register_pass<ngraph::pass::GraphRewrite>();
|
||||
fq_fusions->add_matcher<ngraph::pass::FakeQuantizeMulFusion>();
|
||||
fq_fusions->add_matcher<ngraph::pass::FakeQuantizeReshapeFusion>();
|
||||
|
@ -125,7 +125,7 @@ struct NMSAttributes {
|
||||
return attrs;
|
||||
}
|
||||
|
||||
bool callback_func(pattern::Matcher &m) {
|
||||
bool callback_func(pattern::Matcher &m, pass::MatcherPass * impl) {
|
||||
auto root = m.get_match_root();
|
||||
|
||||
auto attrs = get_nms_attrs(root);
|
||||
@ -141,20 +141,7 @@ struct NMSAttributes {
|
||||
const auto& arg3 = num_of_args > 3 ? new_args.at(3) : ngraph::opset5::Constant::create(element::f32, Shape{}, {.0f});
|
||||
const auto& arg4 = num_of_args > 4 ? new_args.at(4) : ngraph::opset5::Constant::create(element::f32, Shape{}, {.0f});
|
||||
|
||||
// list of new nGraph operations
|
||||
std::list<std::shared_ptr<::ngraph::Node>> new_ops_list;
|
||||
|
||||
if (num_of_args <= 4) {
|
||||
new_ops_list.push_front(arg4.get_node_shared_ptr());
|
||||
}
|
||||
if (num_of_args <= 3) {
|
||||
new_ops_list.push_front(arg3.get_node_shared_ptr());
|
||||
}
|
||||
if (num_of_args <= 2) {
|
||||
new_ops_list.push_front(arg2.get_node_shared_ptr());
|
||||
}
|
||||
|
||||
const auto nms_5 = std::make_shared<ngraph::op::v5::NonMaxSuppression>(
|
||||
const auto nms_5 = impl->register_new_node<ngraph::op::v5::NonMaxSuppression>(
|
||||
new_args.at(0),
|
||||
new_args.at(1),
|
||||
arg2,
|
||||
@ -164,25 +151,20 @@ struct NMSAttributes {
|
||||
attrs.sort_result_descending,
|
||||
attrs.output_type);
|
||||
|
||||
new_ops_list.push_back(nms_5);
|
||||
|
||||
// vector of new nGraph operations
|
||||
NodeVector new_ops(new_ops_list.begin(), new_ops_list.end());
|
||||
|
||||
nms_5->set_friendly_name(root->get_friendly_name());
|
||||
ngraph::copy_runtime_info(root, new_ops);
|
||||
ngraph::copy_runtime_info(root, nms_5);
|
||||
root->output(0).replace(nms_5->output(0));
|
||||
return true;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertPreviousNMSToNMS5, "ConvertPreviousNMSToNMS5", 0);
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertNMS4ToNMS5, "ConvertNMS4ToNMS5", 0);
|
||||
|
||||
ngraph::pass::ConvertNMS4ToNMS5::ConvertNMS4ToNMS5() {
|
||||
auto nms = ngraph::pattern::wrap_type<ngraph::opset4::NonMaxSuppression>();
|
||||
ngraph::matcher_pass_callback callback = callback_func;
|
||||
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
|
||||
return callback_func(m, this);
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(nms, "ConvertNMS4ToNMS5");
|
||||
this->register_matcher(m, callback);
|
||||
@ -192,7 +174,9 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertNMS3ToNMS5, "ConvertNMS3ToNMS5", 0);
|
||||
|
||||
ngraph::pass::ConvertNMS3ToNMS5::ConvertNMS3ToNMS5() {
|
||||
auto nms = ngraph::pattern::wrap_type<ngraph::opset3::NonMaxSuppression>();
|
||||
ngraph::matcher_pass_callback callback = callback_func;
|
||||
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
|
||||
return callback_func(m, this);
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(nms, "ConvertNMS3ToNMS5");
|
||||
this->register_matcher(m, callback);
|
||||
@ -202,7 +186,9 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertNMS1ToNMS5, "ConvertNMS1ToNMS5", 0);
|
||||
|
||||
ngraph::pass::ConvertNMS1ToNMS5::ConvertNMS1ToNMS5() {
|
||||
auto nms = ngraph::pattern::wrap_type<ngraph::opset1::NonMaxSuppression>();
|
||||
ngraph::matcher_pass_callback callback = callback_func;
|
||||
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
|
||||
return callback_func(m, this);
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(nms, "ConvertNMS1ToNMS5");
|
||||
this->register_matcher(m, callback);
|
||||
|
@ -11,7 +11,7 @@
|
||||
#include <ngraph/opsets/opset3.hpp>
|
||||
#include <ngraph/opsets/opset4.hpp>
|
||||
#include <ngraph/opsets/opset5.hpp>
|
||||
#include <transformations/common_optimizations/common_optimizations.hpp>
|
||||
#include <transformations/op_conversions/convert_previous_nms_to_nms_5.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
@ -366,149 +366,3 @@ TEST_F(NGraphReaderTests, ReadNonMaxSuppression5) {
|
||||
fp32w[4] = 0.0;
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(NGraphReaderTests, ReadNonMaxSuppression4) {
|
||||
std::string model = R"V0G0N(
|
||||
<net name="Network" version="10">
|
||||
<layers>
|
||||
<layer id="0" name="in1" type="Parameter" version="opset1">
|
||||
<data element_type="f32" shape="1,15130,4"/>
|
||||
<output>
|
||||
<port id="0" precision="FP32">
|
||||
<dim>1</dim>
|
||||
<dim>15130</dim>
|
||||
<dim>4</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
<layer id="1" name="in2" type="Parameter" version="opset1">
|
||||
<data element_type="f32" shape="1,80,15130"/>
|
||||
<output>
|
||||
<port id="0" precision="FP32">
|
||||
<dim>1</dim>
|
||||
<dim>80</dim>
|
||||
<dim>15130</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
<layer id="2" name="max_output_boxes_per_class" precision="I64" type="Const" version="opset1">
|
||||
<data offset="0" size="8"/>
|
||||
<output>
|
||||
<port id="0" precision="I64"/>
|
||||
</output>
|
||||
</layer>
|
||||
<layer id="3" name="iou_threshold" precision="FP32" type="Const" version="opset1">
|
||||
<data offset="8" size="4"/>
|
||||
<output>
|
||||
<port id="0" precision="FP32"/>
|
||||
</output>
|
||||
</layer>
|
||||
<layer id="4" name="score_threshold" precision="FP32" type="Const" version="opset1">
|
||||
<data offset="12" size="4"/>
|
||||
<output>
|
||||
<port id="0" precision="FP32"/>
|
||||
</output>
|
||||
</layer>
|
||||
<layer id="5" name="nms" type="NonMaxSuppression" version="opset4">
|
||||
<data box_encoding="corner" output_type="i32" sort_result_descending="0"/>
|
||||
<input>
|
||||
<port id="0" precision="FP32">
|
||||
<dim>1</dim>
|
||||
<dim>15130</dim>
|
||||
<dim>4</dim>
|
||||
</port>
|
||||
<port id="1" precision="FP32">
|
||||
<dim>1</dim>
|
||||
<dim>80</dim>
|
||||
<dim>15130</dim>
|
||||
</port>
|
||||
<port id="2" precision="I64"/>
|
||||
<port id="3" precision="FP32"/>
|
||||
<port id="4" precision="FP32"/>
|
||||
</input>
|
||||
<output>
|
||||
<port id="5" precision="I32">
|
||||
<dim>16000</dim>
|
||||
<dim>3</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
<layer id="6" name="mul" type="Multiply" version="opset1">
|
||||
<input>
|
||||
<port id="0" precision="I32">
|
||||
<dim>16000</dim>
|
||||
<dim>3</dim>
|
||||
</port>
|
||||
<port id="1" precision="I32">
|
||||
<dim>16000</dim>
|
||||
<dim>3</dim>
|
||||
</port>
|
||||
</input>
|
||||
<output>
|
||||
<port id="2" precision="I32">
|
||||
<dim>16000</dim>
|
||||
<dim>3</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
<layer id="7" name="output" type="Result" version="opset1">
|
||||
<input>
|
||||
<port id="0" precision="I32">
|
||||
<dim>16000</dim>
|
||||
<dim>3</dim>
|
||||
</port>
|
||||
</input>
|
||||
</layer>
|
||||
</layers>
|
||||
<edges>
|
||||
<edge from-layer="0" from-port="0" to-layer="5" to-port="0"/>
|
||||
<edge from-layer="1" from-port="0" to-layer="5" to-port="1"/>
|
||||
<edge from-layer="2" from-port="0" to-layer="5" to-port="2"/>
|
||||
<edge from-layer="3" from-port="0" to-layer="5" to-port="3"/>
|
||||
<edge from-layer="4" from-port="0" to-layer="5" to-port="4"/>
|
||||
<edge from-layer="5" from-port="5" to-layer="6" to-port="0"/>
|
||||
<edge from-layer="5" from-port="5" to-layer="6" to-port="1"/>
|
||||
<edge from-layer="6" from-port="2" to-layer="7" to-port="0"/>
|
||||
</edges>
|
||||
</net>
|
||||
)V0G0N";
|
||||
|
||||
constexpr size_t weightsSize = 16;
|
||||
|
||||
Blob::Ptr weights;
|
||||
|
||||
weights = make_shared_blob<uint8_t>(TensorDesc(Precision::U8, {weightsSize}, Layout::C));
|
||||
weights->allocate();
|
||||
CommonTestUtils::fill_data(weights->buffer().as<float *>(), weights->size() / sizeof(float));
|
||||
|
||||
auto * i64w = weights->buffer().as<int64_t*>();
|
||||
i64w[0] = 200;
|
||||
|
||||
auto * fp32w = weights->buffer().as<float*>();
|
||||
fp32w[2] = 0.5;
|
||||
fp32w[3] = 0.05;
|
||||
|
||||
Core ie;
|
||||
|
||||
auto ngraphImpl = ie.ReadNetwork(model, weights);
|
||||
auto graph = ngraph::clone_function(*ngraphImpl.getFunction());
|
||||
|
||||
::ngraph::pass::Manager manager;
|
||||
manager.register_pass<::ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<::ngraph::pass::CommonOptimizations>();
|
||||
manager.run_passes(graph);
|
||||
|
||||
auto boxes = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 15130, 4});
|
||||
auto scores = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 80, 15130});
|
||||
auto max_output_boxes_per_class = opset5::Constant::create(element::i64, Shape{}, {200});
|
||||
auto iou_threshold = opset5::Constant::create(element::f32, Shape{}, {0.5});
|
||||
auto score_threshold = opset5::Constant::create(element::f32, Shape{}, {0.05});
|
||||
auto nms = std::make_shared<opset5::NonMaxSuppression>(boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold,
|
||||
opset5::NonMaxSuppression::BoxEncodingType::CORNER, true, ngraph::element::i32);
|
||||
|
||||
auto mul = std::make_shared<opset5::Multiply>(nms, nms);
|
||||
auto graph_ref = std::make_shared<Function>(NodeVector{mul}, ParameterVector{boxes, scores});
|
||||
|
||||
auto res = compare_functions(graph, graph_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
@ -17,6 +17,7 @@
|
||||
#include <legacy/ngraph_ops/nms_ie.hpp>
|
||||
#include <ngraph/pass/constant_folding.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <legacy/transformations/convert_opset1_to_legacy/convert_opset1_to_legacy.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
@ -141,3 +142,121 @@ TEST(TransformationTests, ConvertNMSToNMSIEDynamic2) {
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertNMST1oNMSIE) {
|
||||
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto boxes = std::make_shared<opset1::Parameter>(element::f32, Shape{1, 1000, 4});
|
||||
auto scores = std::make_shared<opset1::Parameter>(element::f32, Shape{1, 1, 1000});
|
||||
auto max_output_boxes_per_class = opset1::Constant::create(element::i64, Shape{}, {10});
|
||||
auto iou_threshold = opset1::Constant::create(element::f32, Shape{}, {0.75});
|
||||
auto score_threshold = opset1::Constant::create(element::f32, Shape{}, {0.7});
|
||||
auto nms = std::make_shared<opset1::NonMaxSuppression>(boxes, scores, max_output_boxes_per_class,
|
||||
iou_threshold, score_threshold, op::v1::NonMaxSuppression::BoxEncodingType::CORNER, true);
|
||||
|
||||
f = std::make_shared<Function>(NodeVector{nms}, ParameterVector{boxes, scores});
|
||||
|
||||
const auto & orig_shape = f->get_output_partial_shape(0);
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::ConvertOpSet1ToLegacy>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_TRUE(f->get_output_partial_shape(0).is_static()) << "Shape " << f->get_output_partial_shape(0) << " should be static";
|
||||
}
|
||||
|
||||
{
|
||||
auto boxes = std::make_shared<opset1::Parameter>(element::f32, Shape{1, 1000, 4});
|
||||
auto scores = std::make_shared<opset1::Parameter>(element::f32, Shape{1, 1, 1000});
|
||||
auto max_output_boxes_per_class = opset1::Constant::create(element::i64, Shape{1}, {10});
|
||||
auto iou_threshold = opset1::Constant::create(element::f32, Shape{1}, {0.75});
|
||||
auto score_threshold = opset1::Constant::create(element::f32, Shape{1}, {0.7});
|
||||
auto nms = std::make_shared<op::NonMaxSuppressionIE3>(boxes, scores, max_output_boxes_per_class,
|
||||
iou_threshold, score_threshold, 0, true, element::i32);
|
||||
auto convert = std::make_shared<opset1::Convert>(nms->output(0), element::i64);
|
||||
|
||||
f_ref = std::make_shared<Function>(NodeVector{convert}, ParameterVector{boxes, scores});
|
||||
ASSERT_TRUE(f_ref->get_output_partial_shape(0).is_static()) << "Shape " << f_ref->get_output_partial_shape(0) << " should be static";
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertNMST3oNMSIE) {
|
||||
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto boxes = std::make_shared<opset1::Parameter>(element::f32, Shape{1, 1000, 4});
|
||||
auto scores = std::make_shared<opset1::Parameter>(element::f32, Shape{1, 1, 1000});
|
||||
auto max_output_boxes_per_class = opset1::Constant::create(element::i32, Shape{}, {10});
|
||||
auto iou_threshold = opset1::Constant::create(element::f32, Shape{}, {0.75});
|
||||
auto score_threshold = opset1::Constant::create(element::f32, Shape{}, {0.7});
|
||||
auto nms = std::make_shared<opset3::NonMaxSuppression>(boxes, scores, max_output_boxes_per_class,
|
||||
iou_threshold, score_threshold, opset3::NonMaxSuppression::BoxEncodingType::CORNER, true, element::i32);
|
||||
|
||||
f = std::make_shared<Function>(NodeVector{nms}, ParameterVector{boxes, scores});
|
||||
|
||||
const auto & orig_shape = f->get_output_partial_shape(0);
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::ConvertOpSet1ToLegacy>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_TRUE(f->get_output_partial_shape(0).is_static()) << "Shape " << f->get_output_partial_shape(0) << " should be static";
|
||||
}
|
||||
|
||||
{
|
||||
auto boxes = std::make_shared<opset1::Parameter>(element::f32, Shape{1, 1000, 4});
|
||||
auto scores = std::make_shared<opset1::Parameter>(element::f32, Shape{1, 1, 1000});
|
||||
auto max_output_boxes_per_class = opset1::Constant::create(element::i32, Shape{1}, {10});
|
||||
auto iou_threshold = opset1::Constant::create(element::f32, Shape{1}, {0.75});
|
||||
auto score_threshold = opset1::Constant::create(element::f32, Shape{1}, {0.7});
|
||||
auto nms = std::make_shared<op::NonMaxSuppressionIE3>(boxes, scores, max_output_boxes_per_class,
|
||||
iou_threshold, score_threshold, 0, true, element::i32);
|
||||
|
||||
f_ref = std::make_shared<Function>(NodeVector{nms}, ParameterVector{boxes, scores});
|
||||
ASSERT_TRUE(f_ref->get_output_partial_shape(0).is_static()) << "Shape " << f_ref->get_output_partial_shape(0) << " should be static";
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertNMST4oNMSIE) {
|
||||
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto boxes = std::make_shared<opset1::Parameter>(element::f32, Shape{1, 1000, 4});
|
||||
auto scores = std::make_shared<opset1::Parameter>(element::f32, Shape{1, 1, 1000});
|
||||
auto max_output_boxes_per_class = opset1::Constant::create(element::i32, Shape{}, {10});
|
||||
auto iou_threshold = opset1::Constant::create(element::f32, Shape{}, {0.75});
|
||||
auto score_threshold = opset1::Constant::create(element::f32, Shape{}, {0.7});
|
||||
auto nms = std::make_shared<opset4::NonMaxSuppression>(boxes, scores, max_output_boxes_per_class,
|
||||
iou_threshold, score_threshold, opset4::NonMaxSuppression::BoxEncodingType::CORNER, true, element::i32);
|
||||
|
||||
f = std::make_shared<Function>(NodeVector{nms}, ParameterVector{boxes, scores});
|
||||
|
||||
const auto & orig_shape = f->get_output_partial_shape(0);
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::ConvertOpSet1ToLegacy>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_TRUE(f->get_output_partial_shape(0).is_static()) << "Shape " << f->get_output_partial_shape(0) << " should be static";
|
||||
}
|
||||
|
||||
{
|
||||
auto boxes = std::make_shared<opset1::Parameter>(element::f32, Shape{1, 1000, 4});
|
||||
auto scores = std::make_shared<opset1::Parameter>(element::f32, Shape{1, 1, 1000});
|
||||
auto max_output_boxes_per_class = opset1::Constant::create(element::i32, Shape{1}, {10});
|
||||
auto iou_threshold = opset1::Constant::create(element::f32, Shape{1}, {0.75});
|
||||
auto score_threshold = opset1::Constant::create(element::f32, Shape{1}, {0.7});
|
||||
auto nms = std::make_shared<op::NonMaxSuppressionIE3>(boxes, scores, max_output_boxes_per_class,
|
||||
iou_threshold, score_threshold, 0, true, element::i32);
|
||||
|
||||
f_ref = std::make_shared<Function>(NodeVector{nms}, ParameterVector{boxes, scores});
|
||||
ASSERT_TRUE(f_ref->get_output_partial_shape(0).is_static()) << "Shape " << f_ref->get_output_partial_shape(0) << " should be static";
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
Loading…
Reference in New Issue
Block a user