Parametrize NMS5ToLegacy conversion to avoid Convert operatoins insertion that breaks outputs naming (#3480)
This commit is contained in:
parent
8213505e24
commit
4d81bd9e0e
@ -340,7 +340,7 @@ CNNNetworkNGraphImpl::reshape(const std::map<std::string, std::vector<size_t>>&
|
||||
OV_ITT_SCOPED_TASK(itt::domains::IE, "CNNNetworkNGraphImpl::ConvertToLegacy");
|
||||
::ngraph::pass::Manager manager;
|
||||
// resolves dynamism by replacing dynamic operation with static version
|
||||
manager.register_pass<::ngraph::pass::ConvertNMS5ToLegacyMatcher>();
|
||||
manager.register_pass<::ngraph::pass::ConvertNMS5ToLegacyMatcher>(false);
|
||||
manager.register_pass<::ngraph::pass::ConstantFolding>();
|
||||
// OneHotToLegacy changes output precision
|
||||
manager.register_pass<::ngraph::pass::ConvertOneHotToOneHotIEMatcher>()->detect_output_type(
|
||||
|
@ -28,6 +28,6 @@ class INFERENCE_ENGINE_API_CLASS(ConvertNMS5ToLegacyMatcher);
|
||||
class ngraph::pass::ConvertNMS5ToLegacyMatcher: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
ConvertNMS5ToLegacyMatcher();
|
||||
ConvertNMS5ToLegacyMatcher(bool force_i32_output_type = true);
|
||||
};
|
||||
|
||||
|
@ -17,10 +17,10 @@
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertNMS5ToLegacyMatcher, "ConvertNMS5ToLegacyMatcher", 0);
|
||||
|
||||
ngraph::pass::ConvertNMS5ToLegacyMatcher::ConvertNMS5ToLegacyMatcher() {
|
||||
ngraph::pass::ConvertNMS5ToLegacyMatcher::ConvertNMS5ToLegacyMatcher(bool force_i32_output_type) {
|
||||
auto nms = ngraph::pattern::wrap_type<ngraph::opset5::NonMaxSuppression>();
|
||||
|
||||
ngraph::matcher_pass_callback callback = [](pattern::Matcher &m) {
|
||||
ngraph::matcher_pass_callback callback = [force_i32_output_type](pattern::Matcher &m) {
|
||||
auto nms_5 = std::dynamic_pointer_cast<ngraph::opset5::NonMaxSuppression>(m.get_match_root());
|
||||
if (!nms_5) {
|
||||
return false;
|
||||
@ -72,6 +72,7 @@ ngraph::pass::ConvertNMS5ToLegacyMatcher::ConvertNMS5ToLegacyMatcher() {
|
||||
|
||||
std::shared_ptr<op::NonMaxSuppressionIE3> nms_legacy{nullptr};
|
||||
|
||||
auto output_type = force_i32_output_type ? element::i32 : nms_5->get_output_type();
|
||||
if (num_of_inputs > 5 && nms_5->soft_nms_sigma_from_input() != 0.0f) {
|
||||
new_soft_nms_sigma = std::make_shared<opset1::Reshape>(new_args.at(5), new_shape_for_soft_nms_sigma, true);
|
||||
new_ops.emplace_back(new_soft_nms_sigma.get_node_shared_ptr());
|
||||
@ -84,7 +85,7 @@ ngraph::pass::ConvertNMS5ToLegacyMatcher::ConvertNMS5ToLegacyMatcher() {
|
||||
new_soft_nms_sigma,
|
||||
center_point_box,
|
||||
nms_5->get_sort_result_descending(),
|
||||
element::i32);
|
||||
output_type);
|
||||
new_ops.push_back(nms_legacy);
|
||||
} else {
|
||||
nms_legacy = std::make_shared<op::NonMaxSuppressionIE3>(
|
||||
@ -95,7 +96,7 @@ ngraph::pass::ConvertNMS5ToLegacyMatcher::ConvertNMS5ToLegacyMatcher() {
|
||||
new_score_threshold,
|
||||
center_point_box,
|
||||
nms_5->get_sort_result_descending(),
|
||||
element::i32);
|
||||
output_type);
|
||||
new_ops.push_back(nms_legacy);
|
||||
}
|
||||
|
||||
|
@ -23,6 +23,7 @@
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
|
||||
#include <ngraph/opsets/opset3.hpp>
|
||||
#include <ngraph/opsets/opset5.hpp>
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/variant.hpp>
|
||||
#include <ngraph/op/maximum.hpp>
|
||||
@ -44,6 +45,28 @@ using namespace InferenceEngine;
|
||||
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
|
||||
TEST(CNNNGraphImplTests, TestNMS5OutputNames) {
|
||||
std::shared_ptr<ngraph::Function> f;
|
||||
{
|
||||
auto boxes = std::make_shared<ngraph::opset5::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1000, 4});
|
||||
auto scores = std::make_shared<ngraph::opset5::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1, 1000});
|
||||
auto max_output_boxes_per_class = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{}, {10});
|
||||
auto iou_threshold = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{}, {0.75});
|
||||
auto score_threshold = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{}, {0.7});
|
||||
auto nms = std::make_shared<ngraph::opset5::NonMaxSuppression>(boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold,
|
||||
ngraph::opset5::NonMaxSuppression::BoxEncodingType::CORNER, true);
|
||||
nms->set_friendly_name("nms");
|
||||
f = std::make_shared<ngraph::Function>(ngraph::OutputVector{nms->output(0), nms->output(1), nms->output(2)}, ngraph::ParameterVector{boxes, scores});
|
||||
}
|
||||
|
||||
InferenceEngine::CNNNetwork cnnNet(f);
|
||||
auto outputs_info = cnnNet.getOutputsInfo();
|
||||
ASSERT_EQ(outputs_info.size(), 3);
|
||||
ASSERT_EQ(outputs_info.count("nms.0"), 1);
|
||||
ASSERT_EQ(outputs_info.count("nms.1"), 1);
|
||||
ASSERT_EQ(outputs_info.count("nms.2"), 1);
|
||||
}
|
||||
|
||||
TEST(CNNNGraphImplTests, TestConvertWithRemoveLastLayerNetwork) {
|
||||
std::shared_ptr<ngraph::Function> ngraph;
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user