nGraph implementation of NMS-5 (without evaluate()) (#2651)

* Written nGraph NMS-5 without evaluate().

* Used NGRAPH_RTTI_DECLARATION.
This commit is contained in:
Vladimir Gavrilov
2020-10-14 16:47:43 +03:00
committed by GitHub
parent d86019d104
commit d277334028
6 changed files with 817 additions and 2 deletions

View File

@@ -58,5 +58,33 @@ public:
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector & new_args) const override;
};
class INFERENCE_ENGINE_API_CLASS(NonMaxSuppressionIE3) : public Op {
public:
NGRAPH_RTTI_DECLARATION;
NonMaxSuppressionIE3(const Output<Node>& boxes,
const Output<Node>& scores,
const Output<Node>& max_output_boxes_per_class,
const Output<Node>& iou_threshold,
const Output<Node>& score_threshold,
const Output<Node>& soft_nms_sigma,
int center_point_box,
bool sort_result_descending,
const ngraph::element::Type& output_type = ngraph::element::i64);
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector & new_args) const override;
int m_center_point_box;
bool m_sort_result_descending = true;
element::Type m_output_type;
private:
int64_t max_boxes_output_from_input() const;
};
} // namespace op
} // namespace ngraph

View File

@@ -101,3 +101,74 @@ void op::NonMaxSuppressionIE2::validate_and_infer_types() {
m_output_type);
set_output_type(0, nms->output(0).get_element_type(), nms->output(0).get_partial_shape());
}
NGRAPH_RTTI_DEFINITION(op::NonMaxSuppressionIE3, "NonMaxSuppressionIE", 3);
op::NonMaxSuppressionIE3::NonMaxSuppressionIE3(const Output<Node>& boxes,
const Output<Node>& scores,
const Output<Node>& max_output_boxes_per_class,
const Output<Node>& iou_threshold,
const Output<Node>& score_threshold,
const Output<Node>& soft_nms_sigma,
int center_point_box,
bool sort_result_descending,
const ngraph::element::Type& output_type)
: Op({boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, soft_nms_sigma}),
m_center_point_box(center_point_box), m_sort_result_descending(sort_result_descending), m_output_type(output_type) {
constructor_validate_and_infer_types();
}
std::shared_ptr<Node> op::NonMaxSuppressionIE3::clone_with_new_inputs(const ngraph::OutputVector &new_args) const {
check_new_args_count(this, new_args);
return make_shared<NonMaxSuppressionIE3>(new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3),
new_args.at(4), new_args.at(5), m_center_point_box, m_sort_result_descending,
m_output_type);
}
bool op::NonMaxSuppressionIE3::visit_attributes(AttributeVisitor& visitor) {
visitor.on_attribute("center_point_box", m_center_point_box);
visitor.on_attribute("sort_result_descending", m_sort_result_descending);
visitor.on_attribute("output_type", m_output_type);
return true;
}
static constexpr size_t boxes_port = 0;
static constexpr size_t scores_port = 1;
static constexpr size_t max_output_boxes_per_class_port = 2;
int64_t op::NonMaxSuppressionIE3::max_boxes_output_from_input() const {
int64_t max_output_boxes{0};
const auto max_output_boxes_input =
as_type_ptr<op::Constant>(input_value(2).get_node_shared_ptr());
max_output_boxes = max_output_boxes_input->cast_vector<int64_t>().at(0);
return max_output_boxes;
}
void op::NonMaxSuppressionIE3::validate_and_infer_types() {
const auto boxes_ps = get_input_partial_shape(boxes_port);
const auto scores_ps = get_input_partial_shape(scores_port);
// NonMaxSuppression produces triplets
// that have the following format: [batch_index, class_index, box_index]
PartialShape out_shape = {Dimension::dynamic(), 3};
if (boxes_ps.rank().is_static() && scores_ps.rank().is_static()) {
const auto num_boxes_boxes = boxes_ps[1];
const auto max_output_boxes_per_class_node = input_value(max_output_boxes_per_class_port).get_node_shared_ptr();
if (num_boxes_boxes.is_static() && scores_ps[0].is_static() && scores_ps[1].is_static() &&
op::is_constant(max_output_boxes_per_class_node)) {
const auto num_boxes = num_boxes_boxes.get_length();
const auto num_classes = scores_ps[1].get_length();
const auto max_output_boxes_per_class = max_boxes_output_from_input();
out_shape[0] = std::min(num_boxes, max_output_boxes_per_class) * num_classes *
scores_ps[0].get_length();
}
}
set_output_type(0, m_output_type, out_shape);
set_output_type(1, element::f32, out_shape);
set_output_type(2, m_output_type, Shape{1});
}

View File

@@ -235,6 +235,156 @@ namespace ngraph
clone_with_new_inputs(const OutputVector& new_args) const override;
};
} // namespace v4
namespace v5
{
/// \brief NonMaxSuppression operation
///
class NGRAPH_API NonMaxSuppression : public Op
{
public:
NGRAPH_RTTI_DECLARATION;
enum class BoxEncodingType
{
CORNER,
CENTER
};
NonMaxSuppression() = default;
/// \brief Constructs a NonMaxSuppression operation with default values in the last
/// 4 inputs.
///
/// \param boxes Node producing the box coordinates
/// \param scores Node producing the box scores
/// \param box_encoding Specifies the format of boxes data encoding
/// \param sort_result_descending Specifies whether it is necessary to sort selected
/// boxes across batches
/// \param output_type Specifies the output tensor type
NonMaxSuppression(const Output<Node>& boxes,
const Output<Node>& scores,
const BoxEncodingType box_encoding = BoxEncodingType::CORNER,
const bool sort_result_descending = true,
const ngraph::element::Type& output_type = ngraph::element::i64);
/// \brief Constructs a NonMaxSuppression operation with default values in the last.
/// 3 inputs.
///
/// \param boxes Node producing the box coordinates
/// \param scores Node producing the box scores
/// \param max_output_boxes_per_class Node producing maximum number of boxes to be
/// selected per class
/// \param box_encoding Specifies the format of boxes data encoding
/// \param sort_result_descending Specifies whether it is necessary to sort selected
/// boxes across batches
/// \param output_type Specifies the output tensor type
NonMaxSuppression(const Output<Node>& boxes,
const Output<Node>& scores,
const Output<Node>& max_output_boxes_per_class,
const BoxEncodingType box_encoding = BoxEncodingType::CORNER,
const bool sort_result_descending = true,
const ngraph::element::Type& output_type = ngraph::element::i64);
/// \brief Constructs a NonMaxSuppression operation with default values in the last.
/// 2 inputs.
///
/// \param boxes Node producing the box coordinates
/// \param scores Node producing the box scores
/// \param max_output_boxes_per_class Node producing maximum number of boxes to be
/// selected per class
/// \param iou_threshold Node producing intersection over union threshold
/// \param box_encoding Specifies the format of boxes data encoding
/// \param sort_result_descending Specifies whether it is necessary to sort selected
/// boxes across batches
/// \param output_type Specifies the output tensor type
NonMaxSuppression(const Output<Node>& boxes,
const Output<Node>& scores,
const Output<Node>& max_output_boxes_per_class,
const Output<Node>& iou_threshold,
const BoxEncodingType box_encoding = BoxEncodingType::CORNER,
const bool sort_result_descending = true,
const ngraph::element::Type& output_type = ngraph::element::i64);
/// \brief Constructs a NonMaxSuppression operation with default value in the last.
/// input.
///
/// \param boxes Node producing the box coordinates
/// \param scores Node producing the box scores
/// \param max_output_boxes_per_class Node producing maximum number of boxes to be
/// selected per class
/// \param iou_threshold Node producing intersection over union threshold
/// \param score_threshold Node producing minimum score threshold
/// \param box_encoding Specifies the format of boxes data encoding
/// \param sort_result_descending Specifies whether it is necessary to sort selected
/// boxes across batches
/// \param output_type Specifies the output tensor type
NonMaxSuppression(const Output<Node>& boxes,
const Output<Node>& scores,
const Output<Node>& max_output_boxes_per_class,
const Output<Node>& iou_threshold,
const Output<Node>& score_threshold,
const BoxEncodingType box_encoding = BoxEncodingType::CORNER,
const bool sort_result_descending = true,
const ngraph::element::Type& output_type = ngraph::element::i64);
/// \brief Constructs a NonMaxSuppression operation.
///
/// \param boxes Node producing the box coordinates
/// \param scores Node producing the box scores
/// \param max_output_boxes_per_class Node producing maximum number of boxes to be
/// selected per class
/// \param iou_threshold Node producing intersection over union threshold
/// \param score_threshold Node producing minimum score threshold
/// \param soft_nms_sigma Node specifying the sigma parameter for Soft-NMS
/// \param box_encoding Specifies the format of boxes data encoding
/// \param sort_result_descending Specifies whether it is necessary to sort selected
/// boxes across batches
/// \param output_type Specifies the output tensor type
NonMaxSuppression(const Output<Node>& boxes,
const Output<Node>& scores,
const Output<Node>& max_output_boxes_per_class,
const Output<Node>& iou_threshold,
const Output<Node>& score_threshold,
const Output<Node>& soft_nms_sigma,
const BoxEncodingType box_encoding = BoxEncodingType::CORNER,
const bool sort_result_descending = true,
const ngraph::element::Type& output_type = ngraph::element::i64);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
BoxEncodingType get_box_encoding() const { return m_box_encoding; }
void set_box_encoding(const BoxEncodingType box_encoding)
{
m_box_encoding = box_encoding;
}
bool get_sort_result_descending() const { return m_sort_result_descending; }
void set_sort_result_descending(const bool sort_result_descending)
{
m_sort_result_descending = sort_result_descending;
}
element::Type get_output_type() const { return m_output_type; }
void set_output_type(const element::Type& output_type)
{
m_output_type = output_type;
}
using Node::set_output_type;
protected:
BoxEncodingType m_box_encoding = BoxEncodingType::CORNER;
bool m_sort_result_descending = true;
ngraph::element::Type m_output_type = ngraph::element::i64;
void validate();
int64_t max_boxes_output_from_input() const;
float iou_threshold_from_input() const;
float score_threshold_from_input() const;
float soft_nms_sigma_from_input() const;
};
} // namespace v5
} // namespace op
NGRAPH_API
@@ -274,4 +424,23 @@ namespace ngraph
"AttributeAdapter<op::v3::NonMaxSuppression::BoxEncodingType>", 1};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
} // namespace ngraph
NGRAPH_API
std::ostream& operator<<(std::ostream& s,
const op::v5::NonMaxSuppression::BoxEncodingType& type);
template <>
class NGRAPH_API AttributeAdapter<op::v5::NonMaxSuppression::BoxEncodingType>
: public EnumAttributeAdapterBase<op::v5::NonMaxSuppression::BoxEncodingType>
{
public:
AttributeAdapter(op::v5::NonMaxSuppression::BoxEncodingType& value)
: EnumAttributeAdapterBase<op::v5::NonMaxSuppression::BoxEncodingType>(value)
{
}
static constexpr DiscreteTypeInfo type_info{
"AttributeAdapter<op::v5::NonMaxSuppression::BoxEncodingType>", 1};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
} // namespace ngraph

View File

@@ -157,7 +157,6 @@ NGRAPH_OP(CTCLoss, ngraph::op::v4)
NGRAPH_OP(HSwish, ngraph::op::v4)
NGRAPH_OP(Interpolate, ngraph::op::v4)
NGRAPH_OP(Mish, ngraph::op::v4)
NGRAPH_OP(NonMaxSuppression, ngraph::op::v4)
NGRAPH_OP(ReduceL1, ngraph::op::v4)
NGRAPH_OP(ReduceL2, ngraph::op::v4)
NGRAPH_OP(SoftPlus, ngraph::op::v4)
@@ -167,6 +166,7 @@ NGRAPH_OP(Swish, ngraph::op::v4)
NGRAPH_OP(GatherND, ngraph::op::v5)
NGRAPH_OP(LogSoftmax, ngraph::op::v5)
NGRAPH_OP(LSTMSequence, ngraph::op::v5)
NGRAPH_OP(NonMaxSuppression, ngraph::op::v5)
NGRAPH_OP(GRUSequence, ngraph::op::v5)
NGRAPH_OP(RNNSequence, ngraph::op::v5)
NGRAPH_OP(Round, ngraph::op::v5)

View File

@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/non_max_suppression.hpp"
#include <cstring>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/util/op_types.hpp"
@@ -530,3 +531,328 @@ void op::v4::NonMaxSuppression::validate_and_infer_types()
}
set_output_type(0, m_output_type, out_shape);
}
// ------------------------------ V5 ------------------------------
NGRAPH_RTTI_DEFINITION(op::v5::NonMaxSuppression, "NonMaxSuppression", 5);
op::v5::NonMaxSuppression::NonMaxSuppression(
const Output<Node>& boxes,
const Output<Node>& scores,
const op::v5::NonMaxSuppression::BoxEncodingType box_encoding,
const bool sort_result_descending,
const element::Type& output_type)
: Op({boxes,
scores,
op::Constant::create(element::i64, Shape{}, {0}),
op::Constant::create(element::f32, Shape{}, {.0f}),
op::Constant::create(element::f32, Shape{}, {.0f}),
op::Constant::create(element::f32, Shape{}, {.0f})})
, m_box_encoding{box_encoding}
, m_sort_result_descending{sort_result_descending}
, m_output_type{output_type}
{
constructor_validate_and_infer_types();
}
op::v5::NonMaxSuppression::NonMaxSuppression(
const Output<Node>& boxes,
const Output<Node>& scores,
const Output<Node>& max_output_boxes_per_class,
const op::v5::NonMaxSuppression::BoxEncodingType box_encoding,
const bool sort_result_descending,
const element::Type& output_type)
: Op({boxes,
scores,
max_output_boxes_per_class,
op::Constant::create(element::f32, Shape{}, {.0f}),
op::Constant::create(element::f32, Shape{}, {.0f}),
op::Constant::create(element::f32, Shape{}, {.0f})})
, m_box_encoding{box_encoding}
, m_sort_result_descending{sort_result_descending}
, m_output_type{output_type}
{
constructor_validate_and_infer_types();
}
op::v5::NonMaxSuppression::NonMaxSuppression(
const Output<Node>& boxes,
const Output<Node>& scores,
const Output<Node>& max_output_boxes_per_class,
const Output<Node>& iou_threshold,
const op::v5::NonMaxSuppression::BoxEncodingType box_encoding,
const bool sort_result_descending,
const element::Type& output_type)
: Op({boxes,
scores,
max_output_boxes_per_class,
iou_threshold,
op::Constant::create(element::f32, Shape{}, {.0f}),
op::Constant::create(element::f32, Shape{}, {.0f})})
, m_box_encoding{box_encoding}
, m_sort_result_descending{sort_result_descending}
, m_output_type{output_type}
{
constructor_validate_and_infer_types();
}
op::v5::NonMaxSuppression::NonMaxSuppression(
const Output<Node>& boxes,
const Output<Node>& scores,
const Output<Node>& max_output_boxes_per_class,
const Output<Node>& iou_threshold,
const Output<Node>& score_threshold,
const op::v5::NonMaxSuppression::BoxEncodingType box_encoding,
const bool sort_result_descending,
const element::Type& output_type)
: Op({boxes,
scores,
max_output_boxes_per_class,
iou_threshold,
score_threshold,
op::Constant::create(element::f32, Shape{}, {.0f})})
, m_box_encoding{box_encoding}
, m_sort_result_descending{sort_result_descending}
, m_output_type{output_type}
{
constructor_validate_and_infer_types();
}
op::v5::NonMaxSuppression::NonMaxSuppression(
const Output<Node>& boxes,
const Output<Node>& scores,
const Output<Node>& max_output_boxes_per_class,
const Output<Node>& iou_threshold,
const Output<Node>& score_threshold,
const Output<Node>& soft_nms_sigma,
const op::v5::NonMaxSuppression::BoxEncodingType box_encoding,
const bool sort_result_descending,
const element::Type& output_type)
: Op({boxes,
scores,
max_output_boxes_per_class,
iou_threshold,
score_threshold,
soft_nms_sigma})
, m_box_encoding{box_encoding}
, m_sort_result_descending{sort_result_descending}
, m_output_type{output_type}
{
constructor_validate_and_infer_types();
}
shared_ptr<Node>
op::v5::NonMaxSuppression::clone_with_new_inputs(const OutputVector& new_args) const
{
check_new_args_count(this, new_args);
NODE_VALIDATION_CHECK(this,
new_args.size() >= 2 && new_args.size() <= 6,
"Number of inputs must be 2, 3, 4, 5 or 6");
const auto& arg2 = new_args.size() > 2
? new_args.at(2)
: ngraph::op::Constant::create(element::i64, Shape{}, {0});
const auto& arg3 = new_args.size() > 3
? new_args.at(3)
: ngraph::op::Constant::create(element::f32, Shape{}, {.0f});
const auto& arg4 = new_args.size() > 4
? new_args.at(4)
: ngraph::op::Constant::create(element::f32, Shape{}, {.0f});
const auto& arg5 = new_args.size() > 5
? new_args.at(5)
: ngraph::op::Constant::create(element::f32, Shape{}, {.0f});
return std::make_shared<op::v5::NonMaxSuppression>(new_args.at(0),
new_args.at(1),
arg2,
arg3,
arg4,
arg5,
m_box_encoding,
m_sort_result_descending,
m_output_type);
}
void op::v5::NonMaxSuppression::validate()
{
const auto boxes_ps = get_input_partial_shape(0);
const auto scores_ps = get_input_partial_shape(1);
NODE_VALIDATION_CHECK(this,
m_output_type == element::i64 || m_output_type == element::i32,
"Output type must be i32 or i64");
if (boxes_ps.is_dynamic() || scores_ps.is_dynamic())
{
return;
}
NODE_VALIDATION_CHECK(this,
boxes_ps.rank().is_static() && boxes_ps.rank().get_length() == 3,
"Expected a 3D tensor for the 'boxes' input. Got: ",
boxes_ps);
NODE_VALIDATION_CHECK(this,
scores_ps.rank().is_static() && scores_ps.rank().get_length() == 3,
"Expected a 3D tensor for the 'scores' input. Got: ",
scores_ps);
if (inputs().size() >= 3)
{
const auto max_boxes_ps = get_input_partial_shape(2);
NODE_VALIDATION_CHECK(this,
max_boxes_ps.is_dynamic() || is_scalar(max_boxes_ps.to_shape()),
"Expected a scalar for the 'max_output_boxes_per_class' input. Got: ",
max_boxes_ps);
}
if (inputs().size() >= 4)
{
const auto iou_threshold_ps = get_input_partial_shape(3);
NODE_VALIDATION_CHECK(this,
iou_threshold_ps.is_dynamic() ||
is_scalar(iou_threshold_ps.to_shape()),
"Expected a scalar for the 'iou_threshold' input. Got: ",
iou_threshold_ps);
}
if (inputs().size() >= 5)
{
const auto score_threshold_ps = get_input_partial_shape(4);
NODE_VALIDATION_CHECK(this,
score_threshold_ps.is_dynamic() ||
is_scalar(score_threshold_ps.to_shape()),
"Expected a scalar for the 'score_threshold' input. Got: ",
score_threshold_ps);
}
if (inputs().size() >= 6)
{
const auto soft_nms_sigma = get_input_partial_shape(5);
NODE_VALIDATION_CHECK(this,
soft_nms_sigma.is_dynamic() || is_scalar(soft_nms_sigma.to_shape()),
"Expected a scalar for the 'soft_nms_sigma' input. Got: ",
soft_nms_sigma);
}
const auto num_batches_boxes = boxes_ps[0];
const auto num_batches_scores = scores_ps[0];
NODE_VALIDATION_CHECK(this,
num_batches_boxes.same_scheme(num_batches_scores),
"The first dimension of both 'boxes' and 'scores' must match. Boxes: ",
num_batches_boxes,
"; Scores: ",
num_batches_scores);
const auto num_boxes_boxes = boxes_ps[1];
const auto num_boxes_scores = scores_ps[2];
NODE_VALIDATION_CHECK(this,
num_boxes_boxes.same_scheme(num_boxes_scores),
"'boxes' and 'scores' input shapes must match at the second and third "
"dimension respectively. Boxes: ",
num_boxes_boxes,
"; Scores: ",
num_boxes_scores);
NODE_VALIDATION_CHECK(this,
boxes_ps[2].is_static() && boxes_ps[2].get_length() == 4u,
"The last dimension of the 'boxes' input must be equal to 4. Got:",
boxes_ps[2]);
}
int64_t op::v5::NonMaxSuppression::max_boxes_output_from_input() const
{
int64_t max_output_boxes{0};
const auto max_output_boxes_input =
as_type_ptr<op::Constant>(input_value(2).get_node_shared_ptr());
max_output_boxes = max_output_boxes_input->cast_vector<int64_t>().at(0);
return max_output_boxes;
}
static constexpr size_t boxes_port = 0;
static constexpr size_t scores_port = 1;
static constexpr size_t iou_threshold_port = 3;
static constexpr size_t score_threshold_port = 4;
static constexpr size_t soft_nms_sigma_port = 5;
float op::v5::NonMaxSuppression::iou_threshold_from_input() const
{
float iou_threshold = 0.0f;
const auto iou_threshold_input =
as_type_ptr<op::Constant>(input_value(iou_threshold_port).get_node_shared_ptr());
iou_threshold = iou_threshold_input->cast_vector<float>().at(0);
return iou_threshold;
}
float op::v5::NonMaxSuppression::score_threshold_from_input() const
{
float score_threshold = 0.0f;
const auto score_threshold_input =
as_type_ptr<op::Constant>(input_value(score_threshold_port).get_node_shared_ptr());
score_threshold = score_threshold_input->cast_vector<float>().at(0);
return score_threshold;
}
float op::v5::NonMaxSuppression::soft_nms_sigma_from_input() const
{
float soft_nms_sigma = 0.0f;
const auto soft_nms_sigma_input =
as_type_ptr<op::Constant>(input_value(soft_nms_sigma_port).get_node_shared_ptr());
soft_nms_sigma = soft_nms_sigma_input->cast_vector<float>().at(0);
return soft_nms_sigma;
}
bool ngraph::op::v5::NonMaxSuppression::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("box_encoding", m_box_encoding);
visitor.on_attribute("sort_result_descending", m_sort_result_descending);
visitor.on_attribute("output_type", m_output_type);
return true;
}
void op::v5::NonMaxSuppression::validate_and_infer_types()
{
const auto boxes_ps = get_input_partial_shape(0);
const auto scores_ps = get_input_partial_shape(1);
// NonMaxSuppression produces triplets
// that have the following format: [batch_index, class_index, box_index]
PartialShape out_shape = {Dimension::dynamic(), 3};
validate();
set_output_type(0, m_output_type, out_shape);
set_output_type(1, element::f32, out_shape);
set_output_type(2, m_output_type, Shape{1});
}
namespace ngraph
{
template <>
EnumNames<op::v5::NonMaxSuppression::BoxEncodingType>&
EnumNames<op::v5::NonMaxSuppression::BoxEncodingType>::get()
{
static auto enum_names = EnumNames<op::v5::NonMaxSuppression::BoxEncodingType>(
"op::v5::NonMaxSuppression::BoxEncodingType",
{{"corner", op::v5::NonMaxSuppression::BoxEncodingType::CORNER},
{"center", op::v5::NonMaxSuppression::BoxEncodingType::CENTER}});
return enum_names;
}
constexpr DiscreteTypeInfo
AttributeAdapter<op::v5::NonMaxSuppression::BoxEncodingType>::type_info;
std::ostream& operator<<(std::ostream& s,
const op::v5::NonMaxSuppression::BoxEncodingType& type)
{
return s << as_string(type);
}
} // namespace ngraph

View File

@@ -547,3 +547,224 @@ TEST(type_prop, nms_v4_dynamic_boxes_and_scores)
ASSERT_TRUE(
nms->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic(), 3}));
}
// ------------------------------ V5 ------------------------------
TEST(type_prop, nms_v5_incorrect_boxes_rank)
{
try
{
const auto boxes = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3, 4});
const auto scores = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
make_shared<op::v5::NonMaxSuppression>(boxes, scores);
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Expected a 3D tensor for the 'boxes' input");
}
}
TEST(type_prop, nms_v5_incorrect_scores_rank)
{
try
{
const auto boxes = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
const auto scores = make_shared<op::Parameter>(element::f32, Shape{1, 2});
make_shared<op::v5::NonMaxSuppression>(boxes, scores);
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Expected a 3D tensor for the 'scores' input");
}
}
TEST(type_prop, nms_v5_incorrect_scheme_num_batches)
{
try
{
const auto boxes = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
const auto scores = make_shared<op::Parameter>(element::f32, Shape{2, 2, 3});
make_shared<op::v5::NonMaxSuppression>(boxes, scores);
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"The first dimension of both 'boxes' and 'scores' must match");
}
}
TEST(type_prop, nms_v5_incorrect_scheme_num_boxes)
{
try
{
const auto boxes = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
const auto scores = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
make_shared<op::v5::NonMaxSuppression>(boxes, scores);
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"'boxes' and 'scores' input shapes must match at the second and third "
"dimension respectively");
}
}
TEST(type_prop, nms_v5_scalar_inputs_check)
{
const auto boxes = make_shared<op::Parameter>(element::f32, Shape{1, 2, 4});
const auto scores = make_shared<op::Parameter>(element::f32, Shape{1, 2, 2});
const auto scalar = make_shared<op::Parameter>(element::f32, Shape{});
const auto non_scalar = make_shared<op::Parameter>(element::f32, Shape{1});
try
{
make_shared<op::v5::NonMaxSuppression>(boxes, scores, non_scalar, scalar, scalar);
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"Expected a scalar for the 'max_output_boxes_per_class' input");
}
try
{
make_shared<op::v5::NonMaxSuppression>(boxes, scores, scalar, non_scalar, scalar);
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Expected a scalar for the 'iou_threshold' input");
}
try
{
make_shared<op::v5::NonMaxSuppression>(boxes, scores, scalar, scalar, non_scalar);
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Expected a scalar for the 'score_threshold' input");
}
try
{
make_shared<op::v5::NonMaxSuppression>(boxes, scores, scalar, scalar, scalar, non_scalar);
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Expected a scalar for the 'soft_nms_sigma' input");
}
}
TEST(type_prop, nms_v5_output_shape)
{
const auto boxes = make_shared<op::Parameter>(element::f32, Shape{5, 2, 4});
const auto scores = make_shared<op::Parameter>(element::f32, Shape{5, 3, 2});
const auto nms = make_shared<op::v5::NonMaxSuppression>(boxes, scores);
ASSERT_TRUE(
nms->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic(), 3}));
ASSERT_TRUE(
nms->get_output_partial_shape(1).same_scheme(PartialShape{Dimension::dynamic(), 3}));
EXPECT_EQ(nms->get_output_shape(2), (Shape{1}));
}
TEST(type_prop, nms_v5_output_shape_2)
{
const auto boxes = make_shared<op::Parameter>(element::f32, Shape{2, 7, 4});
const auto scores = make_shared<op::Parameter>(element::f32, Shape{2, 5, 7});
const auto max_output_boxes_per_class = op::Constant::create(element::i32, Shape{}, {3});
const auto iou_threshold = make_shared<op::Parameter>(element::f32, Shape{});
const auto score_threshold = make_shared<op::Parameter>(element::f32, Shape{});
const auto nms = make_shared<op::v5::NonMaxSuppression>(
boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold);
ASSERT_EQ(nms->get_output_element_type(0), element::i64);
ASSERT_EQ(nms->get_output_element_type(1), element::f32);
ASSERT_EQ(nms->get_output_element_type(2), element::i64);
ASSERT_TRUE(
nms->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic(), 3}));
ASSERT_TRUE(
nms->get_output_partial_shape(1).same_scheme(PartialShape{Dimension::dynamic(), 3}));
EXPECT_EQ(nms->get_output_shape(2), (Shape{1}));
}
TEST(type_prop, nms_v5_output_shape_3)
{
const auto boxes = make_shared<op::Parameter>(element::f32, Shape{2, 7, 4});
const auto scores = make_shared<op::Parameter>(element::f32, Shape{2, 5, 7});
const auto max_output_boxes_per_class = op::Constant::create(element::i16, Shape{}, {1000});
const auto iou_threshold = make_shared<op::Parameter>(element::f32, Shape{});
const auto score_threshold = make_shared<op::Parameter>(element::f32, Shape{});
const auto nms = make_shared<op::v5::NonMaxSuppression>(
boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold);
ASSERT_EQ(nms->get_output_element_type(0), element::i64);
ASSERT_EQ(nms->get_output_element_type(1), element::f32);
ASSERT_EQ(nms->get_output_element_type(2), element::i64);
ASSERT_TRUE(
nms->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic(), 3}));
ASSERT_TRUE(
nms->get_output_partial_shape(1).same_scheme(PartialShape{Dimension::dynamic(), 3}));
EXPECT_EQ(nms->get_output_shape(2), (Shape{1}));
}
TEST(type_prop, nms_v5_output_shape_i32)
{
const auto boxes = make_shared<op::Parameter>(element::f32, Shape{2, 7, 4});
const auto scores = make_shared<op::Parameter>(element::f32, Shape{2, 5, 7});
const auto max_output_boxes_per_class = op::Constant::create(element::i16, Shape{}, {3});
const auto iou_threshold = make_shared<op::Parameter>(element::f32, Shape{});
const auto score_threshold = make_shared<op::Parameter>(element::f32, Shape{});
const auto nms =
make_shared<op::v5::NonMaxSuppression>(boxes,
scores,
max_output_boxes_per_class,
iou_threshold,
score_threshold,
op::v5::NonMaxSuppression::BoxEncodingType::CORNER,
true,
element::i32);
ASSERT_EQ(nms->get_output_element_type(0), element::i32);
ASSERT_EQ(nms->get_output_element_type(1), element::f32);
ASSERT_EQ(nms->get_output_element_type(2), element::i32);
ASSERT_TRUE(
nms->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic(), 3}));
ASSERT_TRUE(
nms->get_output_partial_shape(1).same_scheme(PartialShape{Dimension::dynamic(), 3}));
EXPECT_EQ(nms->get_output_shape(2), (Shape{1}));
}
TEST(type_prop, nms_v5_dynamic_boxes_and_scores)
{
const auto boxes = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
const auto scores = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
const auto max_output_boxes_per_class = op::Constant::create(element::i16, Shape{}, {3});
const auto iou_threshold = make_shared<op::Parameter>(element::f32, Shape{});
const auto score_threshold = make_shared<op::Parameter>(element::f32, Shape{});
const auto nms = make_shared<op::v5::NonMaxSuppression>(
boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold);
ASSERT_EQ(nms->get_output_element_type(0), element::i64);
ASSERT_EQ(nms->get_output_element_type(1), element::f32);
ASSERT_EQ(nms->get_output_element_type(2), element::i64);
ASSERT_TRUE(
nms->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic(), 3}));
ASSERT_TRUE(
nms->get_output_partial_shape(1).same_scheme(PartialShape{Dimension::dynamic(), 3}));
EXPECT_EQ(nms->get_output_shape(2), (Shape{1}));
}