Fix ConvertPrecision for NMS5 (#7778)
* Fix FP32 to FP16 precision conversion for NMS
* Add tests
* Code style
* Code style
This commit is contained in:
parent
6a97decf67
commit
d4e4e8d1e8
@ -30,7 +30,8 @@ public:
|
||||
const Output<Node>& score_threshold,
|
||||
int center_point_box,
|
||||
bool sort_result_descending,
|
||||
const ngraph::element::Type& output_type = ngraph::element::i64);
|
||||
const ngraph::element::Type& output_type = ngraph::element::i64,
|
||||
const ngraph::element::Type& score_output_type = ngraph::element::f32);
|
||||
|
||||
NonMaxSuppressionIEInternal(const Output<Node>& boxes,
|
||||
const Output<Node>& scores,
|
||||
@ -40,7 +41,8 @@ public:
|
||||
const Output<Node>& soft_nms_sigma,
|
||||
int center_point_box,
|
||||
bool sort_result_descending,
|
||||
const ngraph::element::Type& output_type = ngraph::element::i64);
|
||||
const ngraph::element::Type& output_type = ngraph::element::i64,
|
||||
const ngraph::element::Type& score_output_type = ngraph::element::f32);
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
@ -51,6 +53,7 @@ public:
|
||||
int m_center_point_box;
|
||||
bool m_sort_result_descending = true;
|
||||
element::Type m_output_type;
|
||||
element::Type m_scores_output_type;
|
||||
|
||||
private:
|
||||
int64_t max_boxes_output_from_input() const;
|
||||
|
@ -116,7 +116,7 @@ protected:
|
||||
for (size_t i = 0; i < node.get_output_size(); ++i) {
|
||||
auto overridden_output_type = get_overridden_output_type(i);
|
||||
if (overridden_output_type != element::undefined) {
|
||||
node.set_output_type(0, overridden_output_type, node.get_output_partial_shape(i));
|
||||
node.set_output_type(i, overridden_output_type, node.get_output_partial_shape(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -20,9 +20,11 @@ op::internal::NonMaxSuppressionIEInternal::NonMaxSuppressionIEInternal(const Out
|
||||
const Output<Node>& score_threshold,
|
||||
int center_point_box,
|
||||
bool sort_result_descending,
|
||||
const ngraph::element::Type& output_type)
|
||||
const ngraph::element::Type& output_type,
|
||||
const ngraph::element::Type& score_output_type)
|
||||
: Op({boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold}),
|
||||
m_center_point_box(center_point_box), m_sort_result_descending(sort_result_descending), m_output_type(output_type) {
|
||||
m_center_point_box(center_point_box), m_sort_result_descending(sort_result_descending), m_output_type(output_type),
|
||||
m_scores_output_type(score_output_type) {
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
@ -34,9 +36,11 @@ op::internal::NonMaxSuppressionIEInternal::NonMaxSuppressionIEInternal(const Out
|
||||
const Output<Node>& soft_nms_sigma,
|
||||
int center_point_box,
|
||||
bool sort_result_descending,
|
||||
const ngraph::element::Type& output_type)
|
||||
const ngraph::element::Type& output_type,
|
||||
const ngraph::element::Type& score_output_type)
|
||||
: Op({boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, soft_nms_sigma}),
|
||||
m_center_point_box(center_point_box), m_sort_result_descending(sort_result_descending), m_output_type(output_type) {
|
||||
m_center_point_box(center_point_box), m_sort_result_descending(sort_result_descending), m_output_type(output_type),
|
||||
m_scores_output_type(score_output_type) {
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
@ -59,6 +63,7 @@ bool op::internal::NonMaxSuppressionIEInternal::visit_attributes(AttributeVisito
|
||||
visitor.on_attribute("center_point_box", m_center_point_box);
|
||||
visitor.on_attribute("sort_result_descending", m_sort_result_descending);
|
||||
visitor.on_attribute("output_type", m_output_type);
|
||||
visitor.on_attribute("score_output_type", m_scores_output_type);
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -105,6 +110,6 @@ void op::internal::NonMaxSuppressionIEInternal::validate_and_infer_types() {
|
||||
}
|
||||
|
||||
set_output_type(0, m_output_type, out_shape);
|
||||
set_output_type(1, element::f32, out_shape);
|
||||
set_output_type(1, m_scores_output_type, out_shape);
|
||||
set_output_type(2, m_output_type, Shape{1});
|
||||
}
|
||||
|
@ -86,7 +86,8 @@ ngraph::pass::ConvertNMSToNMSIEInternal::ConvertNMSToNMSIEInternal() {
|
||||
new_soft_nms_sigma,
|
||||
center_point_box,
|
||||
nms_5->get_sort_result_descending(),
|
||||
element::i32);
|
||||
element::i32,
|
||||
nms_5->get_output_element_type(1));
|
||||
new_ops.push_back(nms_legacy);
|
||||
} else {
|
||||
nms_legacy = std::make_shared<op::internal::NonMaxSuppressionIEInternal>(
|
||||
@ -97,7 +98,8 @@ ngraph::pass::ConvertNMSToNMSIEInternal::ConvertNMSToNMSIEInternal() {
|
||||
new_score_threshold,
|
||||
center_point_box,
|
||||
nms_5->get_sort_result_descending(),
|
||||
element::i32);
|
||||
element::i32,
|
||||
nms_5->get_output_element_type(1));
|
||||
new_ops.push_back(nms_legacy);
|
||||
}
|
||||
|
||||
|
@ -13,6 +13,7 @@
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <ngraph/opsets/opset3.hpp>
|
||||
#include <ngraph/opsets/opset4.hpp>
|
||||
#include <ngraph/opsets/opset5.hpp>
|
||||
#include <transformations/convert_precision.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
@ -96,6 +97,34 @@ TEST(TransformationTests, ConvertPrecision_NMS4) {
|
||||
ASSERT_FALSE(has_type<ngraph::element::Type_t::f16>(f));
|
||||
}
|
||||
|
||||
// Test: builds a function around an opset5 NonMaxSuppression with f32 inputs,
// runs ConvertPrecision with the mappings i64 -> i32 and f32 -> f16, and
// asserts that has_type reports neither i64 nor f32 for the resulting function.
TEST(TransformationTests, ConvertPrecision_NMS5) {
|
||||
std::shared_ptr<ngraph::Function> f;
|
||||
{
|
||||
// f32 parameters: boxes of shape {1, 1000, 4} and scores of shape {1, 1, 1000}.
auto boxes = std::make_shared<ngraph::opset5::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1000, 4});
|
||||
auto scores = std::make_shared<ngraph::opset5::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1, 1000});
|
||||
// Scalar constants: an i64 box limit and two f32 thresholds.
auto max_output_boxes_per_class = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{}, {10});
|
||||
auto iou_threshold = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{}, {0.75});
|
||||
auto score_threshold = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{}, {0.7});
|
||||
auto nms = std::make_shared<ngraph::opset5::NonMaxSuppression>(boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold,
|
||||
ngraph::opset5::NonMaxSuppression::BoxEncodingType::CORNER, true);
|
||||
|
||||
// All three NMS outputs are wired to Results so each one is part of the function.
auto result1 = std::make_shared<ngraph::opset5::Result>(nms->output(0));
|
||||
auto result2 = std::make_shared<ngraph::opset5::Result>(nms->output(1));
|
||||
auto result3 = std::make_shared<ngraph::opset5::Result>(nms->output(2));
|
||||
f = std::make_shared<ngraph::Function>(ngraph::ResultVector{result1, result2, result3}, ngraph::ParameterVector{boxes, scores});
|
||||
}
|
||||
|
||||
pass::Manager manager;
|
||||
// Precision pairs handed to ConvertPrecision (presumably {from, to} — confirm against precisions_array).
static const precisions_array precisions = {
|
||||
{ ngraph::element::i64, ngraph::element::i32 },
|
||||
{ ngraph::element::f32, ngraph::element::f16 }
|
||||
};
|
||||
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions);
|
||||
manager.run_passes(f);
|
||||
// After the pass, neither of the two source precisions may be reported for f.
ASSERT_FALSE(has_type<ngraph::element::Type_t::i64>(f));
|
||||
ASSERT_FALSE(has_type<ngraph::element::Type_t::f32>(f));
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecision_ShapeOf) {
|
||||
std::shared_ptr<Function> f(nullptr);
|
||||
{
|
||||
|
@ -182,6 +182,27 @@ TEST_F(TypeRelaxedTests, notSupportedTypeOverridePartially) {
|
||||
ASSERT_EQ(4, ngraph->get_ops().size());
|
||||
}
|
||||
|
||||
// Test: wraps a three-way opset1 Split in TypeRelaxed with every output type
// overridden to f16 and checks that each of the three outputs reports the
// overridden element type together with the expected shape.
TEST_F(TypeRelaxedTests, multiOutputTypeOverride) {
|
||||
auto overriden_type = element::f16;
|
||||
auto orig_type = element::f32;
|
||||
std::shared_ptr<ngraph::Function> ngraph;
|
||||
{
|
||||
ngraph::PartialShape shape({1, 3, 22, 22});
|
||||
auto param1 = make_shared<ngraph::opset1::Parameter>(orig_type, shape);
|
||||
// Split of the f32 parameter: the i64 scalar constant 1 is the second input and 3 is the split count.
auto op = ngraph::opset1::Split(param1, ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{}, {1}), 3);
|
||||
// No input-type overrides (empty TypeVector); one f16 override per Split output.
auto relaxed_op = make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Split>>(
|
||||
op, TypeVector{}, TypeVector{overriden_type, overriden_type, overriden_type});
|
||||
auto result = make_shared<ngraph::opset1::Result>(relaxed_op);
|
||||
|
||||
ngraph = make_shared<ngraph::Function>(ngraph::ResultVector{result}, ngraph::ParameterVector{param1});
|
||||
|
||||
// Every output index must carry the override, not only index 0.
for (size_t i = 0; i < 3; ++i) {
|
||||
ASSERT_EQ(overriden_type, relaxed_op->get_output_element_type(i));
|
||||
ASSERT_EQ(ngraph::Shape({1, 1, 22, 22}), relaxed_op->get_output_shape(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TypeRelaxedTests, setGetTypes) {
|
||||
std::shared_ptr<ngraph::Function> ngraph;
|
||||
{
|
||||
|
@ -331,7 +331,11 @@ bool fuse_type_to_convert(const std::shared_ptr<ngraph::Node>& node, element::Ty
|
||||
|
||||
bool fuse_type_to_nms3(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx) {
|
||||
if (auto nms = ov::as_type_ptr<opset3::NonMaxSuppression>(node)) {
|
||||
nms->set_output_type(to);
|
||||
if (to == element::i32 || to == element::i64) {
|
||||
nms->set_output_type(to);
|
||||
} else {
|
||||
throw ngraph_error("Type: " + to.get_type_name() + " is not supported for NMS3");
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
@ -339,18 +343,41 @@ bool fuse_type_to_nms3(const std::shared_ptr<ngraph::Node>& node, ngraph::elemen
|
||||
|
||||
// Type-fusion callback for opset4 NonMaxSuppression nodes.
// Returns false when `node` is not an opset4 NonMaxSuppression; otherwise the
// output type is set to `to` and true is returned. `idx` is not read here.
bool fuse_type_to_nms4(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx) {
|
||||
if (auto nms = ov::as_type_ptr<opset4::NonMaxSuppression>(node)) {
|
||||
// NOTE(review): this unconditional call looks like the removed (pre-change) side
// of the diff; the guarded call below appears to be its replacement — confirm.
nms->set_output_type(to);
|
||||
// Only i32 and i64 are accepted as the requested type; anything else throws ngraph_error.
if (to == element::i32 || to == element::i64) {
|
||||
nms->set_output_type(to);
|
||||
} else {
|
||||
throw ngraph_error("Type: " + to.get_type_name() + " is not supported for NMS4");
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Type-fusion callback for opset5 NonMaxSuppression nodes: makes output `idx`
// of `node` report element type `to`.
// Returns false when `node` is not an opset5 NonMaxSuppression, true once the
// type has been applied by one of the three paths below.
// NOTE(review): this span is a rendered diff hunk, so lines from the old and the
// new implementation appear together; the notes below mark the lines that look
// like the removed side — confirm against the actual file.
bool fuse_type_to_nms5(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx) {
|
||||
// NOTE(review): presumably the removed form of the cast-and-check that follows.
if (auto nms = ov::as_type_ptr<opset5::NonMaxSuppression>(node)) {
|
||||
auto nms = ov::as_type_ptr<opset5::NonMaxSuppression>(node);
|
||||
if (!nms) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Path 1: outputs 0 and 2 with an i32/i64 request are handled through the
// op's own set_output_type, without wrapping the node.
if ((idx == 0 || idx == 2) && (to == element::i32 || to == element::i64)) {
|
||||
nms->set_output_type(to);
|
||||
return true;
|
||||
}
|
||||
// NOTE(review): presumably the removed trailing return of the old implementation;
// if it were live, the code below would be unreachable.
return false;
|
||||
|
||||
// Path 2: the node already derives from TypeRelaxedBase, so the override for
// output `idx` is recorded on it directly.
if (auto type_relaxed = std::dynamic_pointer_cast<op::TypeRelaxedBase>(node)) {
|
||||
type_relaxed->set_overridden_output_type(to, idx);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Path 3: collect the current element type of every output, substitute `to`
// at position `idx`, and replace the node with a TypeRelaxed copy that
// carries those output types (no input-type overrides).
element::TypeVector output_types;
|
||||
for (const auto& output : nms->outputs()) {
|
||||
output_types.emplace_back(output.get_element_type());
|
||||
}
|
||||
// `idx` is used as-is; no range check against the number of outputs is done here.
output_types[idx] = to;
|
||||
auto relaxed_op =
|
||||
std::make_shared<ngraph::op::TypeRelaxed<opset5::NonMaxSuppression>>(*nms, element::TypeVector{}, output_types);
|
||||
replace_node(node, relaxed_op);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool fuse_type_to_topk(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx) {
|
||||
|
Loading…
Reference in New Issue
Block a user