Parametrize NMS5ToLegacy conversion to avoid Convert operatoins insertion that breaks outputs naming (#3480)

This commit is contained in:
Gleb Kazantaev 2020-12-10 14:07:20 +03:00 committed by GitHub
parent 8213505e24
commit 4d81bd9e0e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 30 additions and 6 deletions

View File

@ -340,7 +340,7 @@ CNNNetworkNGraphImpl::reshape(const std::map<std::string, std::vector<size_t>>&
OV_ITT_SCOPED_TASK(itt::domains::IE, "CNNNetworkNGraphImpl::ConvertToLegacy");
::ngraph::pass::Manager manager;
// resolves dynamism by replacing dynamic operation with static version
manager.register_pass<::ngraph::pass::ConvertNMS5ToLegacyMatcher>();
manager.register_pass<::ngraph::pass::ConvertNMS5ToLegacyMatcher>(false);
manager.register_pass<::ngraph::pass::ConstantFolding>();
// OneHotToLegacy changes output precision
manager.register_pass<::ngraph::pass::ConvertOneHotToOneHotIEMatcher>()->detect_output_type(

View File

@ -28,6 +28,6 @@ class INFERENCE_ENGINE_API_CLASS(ConvertNMS5ToLegacyMatcher);
class ngraph::pass::ConvertNMS5ToLegacyMatcher: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
ConvertNMS5ToLegacyMatcher();
ConvertNMS5ToLegacyMatcher(bool force_i32_output_type = true);
};

View File

@ -17,10 +17,10 @@
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertNMS5ToLegacyMatcher, "ConvertNMS5ToLegacyMatcher", 0);
ngraph::pass::ConvertNMS5ToLegacyMatcher::ConvertNMS5ToLegacyMatcher() {
ngraph::pass::ConvertNMS5ToLegacyMatcher::ConvertNMS5ToLegacyMatcher(bool force_i32_output_type) {
auto nms = ngraph::pattern::wrap_type<ngraph::opset5::NonMaxSuppression>();
ngraph::matcher_pass_callback callback = [](pattern::Matcher &m) {
ngraph::matcher_pass_callback callback = [force_i32_output_type](pattern::Matcher &m) {
auto nms_5 = std::dynamic_pointer_cast<ngraph::opset5::NonMaxSuppression>(m.get_match_root());
if (!nms_5) {
return false;
@ -72,6 +72,7 @@ ngraph::pass::ConvertNMS5ToLegacyMatcher::ConvertNMS5ToLegacyMatcher() {
std::shared_ptr<op::NonMaxSuppressionIE3> nms_legacy{nullptr};
auto output_type = force_i32_output_type ? element::i32 : nms_5->get_output_type();
if (num_of_inputs > 5 && nms_5->soft_nms_sigma_from_input() != 0.0f) {
new_soft_nms_sigma = std::make_shared<opset1::Reshape>(new_args.at(5), new_shape_for_soft_nms_sigma, true);
new_ops.emplace_back(new_soft_nms_sigma.get_node_shared_ptr());
@ -84,7 +85,7 @@ ngraph::pass::ConvertNMS5ToLegacyMatcher::ConvertNMS5ToLegacyMatcher() {
new_soft_nms_sigma,
center_point_box,
nms_5->get_sort_result_descending(),
element::i32);
output_type);
new_ops.push_back(nms_legacy);
} else {
nms_legacy = std::make_shared<op::NonMaxSuppressionIE3>(
@ -95,7 +96,7 @@ ngraph::pass::ConvertNMS5ToLegacyMatcher::ConvertNMS5ToLegacyMatcher() {
new_score_threshold,
center_point_box,
nms_5->get_sort_result_descending(),
element::i32);
output_type);
new_ops.push_back(nms_legacy);
}

View File

@ -23,6 +23,7 @@
#include <ngraph/pass/manager.hpp>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/opsets/opset5.hpp>
#include <ngraph/function.hpp>
#include <ngraph/variant.hpp>
#include <ngraph/op/maximum.hpp>
@ -44,6 +45,28 @@ using namespace InferenceEngine;
IE_SUPPRESS_DEPRECATED_START
TEST(CNNNGraphImplTests, TestNMS5OutputNames) {
std::shared_ptr<ngraph::Function> f;
{
auto boxes = std::make_shared<ngraph::opset5::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1000, 4});
auto scores = std::make_shared<ngraph::opset5::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1, 1000});
auto max_output_boxes_per_class = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{}, {10});
auto iou_threshold = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{}, {0.75});
auto score_threshold = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{}, {0.7});
auto nms = std::make_shared<ngraph::opset5::NonMaxSuppression>(boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold,
ngraph::opset5::NonMaxSuppression::BoxEncodingType::CORNER, true);
nms->set_friendly_name("nms");
f = std::make_shared<ngraph::Function>(ngraph::OutputVector{nms->output(0), nms->output(1), nms->output(2)}, ngraph::ParameterVector{boxes, scores});
}
InferenceEngine::CNNNetwork cnnNet(f);
auto outputs_info = cnnNet.getOutputsInfo();
ASSERT_EQ(outputs_info.size(), 3);
ASSERT_EQ(outputs_info.count("nms.0"), 1);
ASSERT_EQ(outputs_info.count("nms.1"), 1);
ASSERT_EQ(outputs_info.count("nms.2"), 1);
}
TEST(CNNNGraphImplTests, TestConvertWithRemoveLastLayerNetwork) {
std::shared_ptr<ngraph::Function> ngraph;
{