Fix ConvertPrecision for NMS5 (#7778)

* Fix FP32 to FP16 convert precision for NMS

* Add tests

* Code style

* Codestyle
This commit is contained in:
Gleb Kazantaev 2021-10-04 11:52:53 +03:00 committed by GitHub
parent 6a97decf67
commit d4e4e8d1e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 101 additions and 14 deletions

View File

@ -30,7 +30,8 @@ public:
const Output<Node>& score_threshold,
int center_point_box,
bool sort_result_descending,
const ngraph::element::Type& output_type = ngraph::element::i64);
const ngraph::element::Type& output_type = ngraph::element::i64,
const ngraph::element::Type& score_output_type = ngraph::element::f32);
NonMaxSuppressionIEInternal(const Output<Node>& boxes,
const Output<Node>& scores,
@ -40,7 +41,8 @@ public:
const Output<Node>& soft_nms_sigma,
int center_point_box,
bool sort_result_descending,
const ngraph::element::Type& output_type = ngraph::element::i64);
const ngraph::element::Type& output_type = ngraph::element::i64,
const ngraph::element::Type& score_output_type = ngraph::element::f32);
void validate_and_infer_types() override;
@ -51,6 +53,7 @@ public:
int m_center_point_box;
bool m_sort_result_descending = true;
element::Type m_output_type;
element::Type m_scores_output_type;
private:
int64_t max_boxes_output_from_input() const;

View File

@ -116,7 +116,7 @@ protected:
for (size_t i = 0; i < node.get_output_size(); ++i) {
auto overridden_output_type = get_overridden_output_type(i);
if (overridden_output_type != element::undefined) {
node.set_output_type(0, overridden_output_type, node.get_output_partial_shape(i));
node.set_output_type(i, overridden_output_type, node.get_output_partial_shape(i));
}
}
}

View File

@ -20,9 +20,11 @@ op::internal::NonMaxSuppressionIEInternal::NonMaxSuppressionIEInternal(const Out
const Output<Node>& score_threshold,
int center_point_box,
bool sort_result_descending,
const ngraph::element::Type& output_type)
const ngraph::element::Type& output_type,
const ngraph::element::Type& score_output_type)
: Op({boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold}),
m_center_point_box(center_point_box), m_sort_result_descending(sort_result_descending), m_output_type(output_type) {
m_center_point_box(center_point_box), m_sort_result_descending(sort_result_descending), m_output_type(output_type),
m_scores_output_type(score_output_type) {
constructor_validate_and_infer_types();
}
@ -34,9 +36,11 @@ op::internal::NonMaxSuppressionIEInternal::NonMaxSuppressionIEInternal(const Out
const Output<Node>& soft_nms_sigma,
int center_point_box,
bool sort_result_descending,
const ngraph::element::Type& output_type)
const ngraph::element::Type& output_type,
const ngraph::element::Type& score_output_type)
: Op({boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, soft_nms_sigma}),
m_center_point_box(center_point_box), m_sort_result_descending(sort_result_descending), m_output_type(output_type) {
m_center_point_box(center_point_box), m_sort_result_descending(sort_result_descending), m_output_type(output_type),
m_scores_output_type(score_output_type) {
constructor_validate_and_infer_types();
}
@ -59,6 +63,7 @@ bool op::internal::NonMaxSuppressionIEInternal::visit_attributes(AttributeVisito
visitor.on_attribute("center_point_box", m_center_point_box);
visitor.on_attribute("sort_result_descending", m_sort_result_descending);
visitor.on_attribute("output_type", m_output_type);
visitor.on_attribute("score_output_type", m_scores_output_type);
return true;
}
@ -105,6 +110,6 @@ void op::internal::NonMaxSuppressionIEInternal::validate_and_infer_types() {
}
set_output_type(0, m_output_type, out_shape);
set_output_type(1, element::f32, out_shape);
set_output_type(1, m_scores_output_type, out_shape);
set_output_type(2, m_output_type, Shape{1});
}

View File

@ -86,7 +86,8 @@ ngraph::pass::ConvertNMSToNMSIEInternal::ConvertNMSToNMSIEInternal() {
new_soft_nms_sigma,
center_point_box,
nms_5->get_sort_result_descending(),
element::i32);
element::i32,
nms_5->get_output_element_type(1));
new_ops.push_back(nms_legacy);
} else {
nms_legacy = std::make_shared<op::internal::NonMaxSuppressionIEInternal>(
@ -97,7 +98,8 @@ ngraph::pass::ConvertNMSToNMSIEInternal::ConvertNMSToNMSIEInternal() {
new_score_threshold,
center_point_box,
nms_5->get_sort_result_descending(),
element::i32);
element::i32,
nms_5->get_output_element_type(1));
new_ops.push_back(nms_legacy);
}

View File

@ -13,6 +13,7 @@
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/opsets/opset5.hpp>
#include <transformations/convert_precision.hpp>
#include <transformations/utils/utils.hpp>
#include <ngraph/pass/manager.hpp>
@ -96,6 +97,34 @@ TEST(TransformationTests, ConvertPrecision_NMS4) {
ASSERT_FALSE(has_type<ngraph::element::Type_t::f16>(f));
}
// Runs ConvertPrecision (i64 -> i32, f32 -> f16) over a function built around an
// opset5 NonMaxSuppression whose three outputs are all consumed by Results, and
// checks that no i64 and no f32 element type is left in the function afterwards.
TEST(TransformationTests, ConvertPrecision_NMS5) {
    // Build the function inside a lambda so that the temporary node handles are
    // released before the pass manager runs, as with a nested scope.
    const auto build_function = []() {
        auto boxes = std::make_shared<ngraph::opset5::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1000, 4});
        auto scores = std::make_shared<ngraph::opset5::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1, 1000});
        auto max_boxes = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{}, {10});
        auto iou_thr = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{}, {0.75});
        auto score_thr = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{}, {0.7});
        auto nms = std::make_shared<ngraph::opset5::NonMaxSuppression>(
            boxes, scores, max_boxes, iou_thr, score_thr,
            ngraph::opset5::NonMaxSuppression::BoxEncodingType::CORNER, true);

        // One Result per NMS output port, created in port order.
        ngraph::ResultVector results;
        for (size_t port = 0; port < 3; ++port) {
            results.push_back(std::make_shared<ngraph::opset5::Result>(nms->output(port)));
        }
        return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{boxes, scores});
    };
    std::shared_ptr<ngraph::Function> f = build_function();

    pass::Manager manager;

    static const precisions_array precisions = {
        { ngraph::element::i64, ngraph::element::i32 },
        { ngraph::element::f32, ngraph::element::f16 }
    };

    manager.register_pass<ngraph::pass::ConvertPrecision>(precisions);
    manager.run_passes(f);

    ASSERT_FALSE(has_type<ngraph::element::Type_t::i64>(f));
    ASSERT_FALSE(has_type<ngraph::element::Type_t::f32>(f));
}
TEST(TransformationTests, ConvertPrecision_ShapeOf) {
std::shared_ptr<Function> f(nullptr);
{

View File

@ -182,6 +182,27 @@ TEST_F(TypeRelaxedTests, notSupportedTypeOverridePartially) {
ASSERT_EQ(4, ngraph->get_ops().size());
}
// Wraps a three-way Split in TypeRelaxed with an f16 override for every output and
// checks that each of the three outputs reports the overridden element type and the
// expected {1, 1, 22, 22} shape.
TEST_F(TypeRelaxedTests, multiOutputTypeOverride) {
    const auto overriden_type = element::f16;
    const auto orig_type = element::f32;
    std::shared_ptr<ngraph::Function> ngraph;
    {
        const ngraph::PartialShape input_shape({1, 3, 22, 22});
        auto input = make_shared<ngraph::opset1::Parameter>(orig_type, input_shape);

        // Split along axis 1 into 3 equal parts.
        auto split_axis = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{}, {1});
        ngraph::opset1::Split split(input, split_axis, 3);

        // No input overrides; the same output override for all three outputs.
        auto relaxed_op = make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Split>>(
            split, TypeVector{}, TypeVector(3, overriden_type));

        auto result = make_shared<ngraph::opset1::Result>(relaxed_op);
        ngraph = make_shared<ngraph::Function>(ngraph::ResultVector{result}, ngraph::ParameterVector{input});

        for (size_t i = 0; i < 3; ++i) {
            ASSERT_EQ(overriden_type, relaxed_op->get_output_element_type(i));
            ASSERT_EQ(ngraph::Shape({1, 1, 22, 22}), relaxed_op->get_output_shape(i));
        }
    }
}
TEST_F(TypeRelaxedTests, setGetTypes) {
std::shared_ptr<ngraph::Function> ngraph;
{

View File

@ -331,7 +331,11 @@ bool fuse_type_to_convert(const std::shared_ptr<ngraph::Node>& node, element::Ty
bool fuse_type_to_nms3(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx) {
if (auto nms = ov::as_type_ptr<opset3::NonMaxSuppression>(node)) {
nms->set_output_type(to);
if (to == element::i32 || to == element::i64) {
nms->set_output_type(to);
} else {
throw ngraph_error("Type: " + to.get_type_name() + " is not supported for NMS3");
}
return true;
}
return false;
@ -339,18 +343,41 @@ bool fuse_type_to_nms3(const std::shared_ptr<ngraph::Node>& node, ngraph::elemen
/// Retargets the output type of an opset4 NonMaxSuppression node.
///
/// @param node  Candidate node; anything that is not an opset4 NonMaxSuppression is ignored.
/// @param to    Requested output element type; only i32 and i64 are accepted.
/// @param idx   Output index requested by the caller (not used by this handler).
/// @return true if `node` is an opset4 NonMaxSuppression and its type was updated,
///         false if `node` is a different operation.
/// @throws ngraph_error if `node` is an opset4 NonMaxSuppression but `to` is neither i32 nor i64.
bool fuse_type_to_nms4(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx) {
    if (auto nms = ov::as_type_ptr<opset4::NonMaxSuppression>(node)) {
        // Validate before mutating: the node must stay untouched when the requested
        // type is rejected, so the setter is called only on the accepted branch.
        if (to == element::i32 || to == element::i64) {
            nms->set_output_type(to);
        } else {
            throw ngraph_error("Type: " + to.get_type_name() + " is not supported for NMS4");
        }
        return true;
    }
    return false;
}
/// Changes the element type of one output of an opset5 NonMaxSuppression node.
///
/// Requests for output 0 or 2 with an i32/i64 target are served through the
/// operation's own output-type attribute. Every other request is served by overriding
/// the type of output `idx` through TypeRelaxed: on the node itself if it is already
/// type-relaxed, otherwise on a TypeRelaxed copy that replaces the node.
///
/// @param node  Candidate node; anything that is not an opset5 NonMaxSuppression is ignored.
/// @param to    Requested element type for the output.
/// @param idx   Index of the output whose type should change.
/// @return false if `node` is not an opset5 NonMaxSuppression, true otherwise.
bool fuse_type_to_nms5(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx) {
    auto nms = ov::as_type_ptr<opset5::NonMaxSuppression>(node);
    if (!nms) {
        return false;
    }

    // Outputs 0 and 2 follow the op's output-type attribute, which is only applied
    // here for i32 and i64 targets.
    if ((idx == 0 || idx == 2) && (to == element::i32 || to == element::i64)) {
        nms->set_output_type(to);
        return true;
    }

    // Already type-relaxed: override the requested output in place.
    if (auto type_relaxed = std::dynamic_pointer_cast<op::TypeRelaxedBase>(node)) {
        type_relaxed->set_overridden_output_type(to, idx);
        return true;
    }

    // Otherwise keep the current type of every output except `idx` and swap the node
    // for a TypeRelaxed copy carrying those output types.
    element::TypeVector output_types;
    for (const auto& output : nms->outputs()) {
        output_types.emplace_back(output.get_element_type());
    }
    output_types[idx] = to;
    auto relaxed_op =
        std::make_shared<ngraph::op::TypeRelaxed<opset5::NonMaxSuppression>>(*nms, element::TypeVector{}, output_types);
    replace_node(node, relaxed_op);
    return true;
}
bool fuse_type_to_topk(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx) {