[IE][VPU] Fix NMS DTS (#2880)
Add a new constructor to fix absent NMS-5 inputs that will be introduced after #2450 will be merged.
This commit is contained in:
parent
935ac543ac
commit
e758d2b325
@ -6,6 +6,7 @@
|
||||
|
||||
#include <ngraph/node.hpp>
|
||||
#include <legacy/ngraph_ops/nms_ie.hpp>
|
||||
#include <ngraph/opsets/opset5.hpp>
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
@ -17,6 +18,8 @@ public:
|
||||
static constexpr NodeTypeInfo type_info{"StaticShapeNonMaxSuppression", 0};
|
||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
||||
|
||||
explicit StaticShapeNonMaxSuppression(const ngraph::opset5::NonMaxSuppression& nms);
|
||||
|
||||
StaticShapeNonMaxSuppression(const Output<Node>& boxes,
|
||||
const Output<Node>& scores,
|
||||
const Output<Node>& maxOutputBoxesPerClass,
|
||||
|
@ -12,6 +12,18 @@ namespace ngraph { namespace vpu { namespace op {
|
||||
|
||||
constexpr NodeTypeInfo StaticShapeNonMaxSuppression::type_info;
|
||||
|
||||
StaticShapeNonMaxSuppression::StaticShapeNonMaxSuppression(const ngraph::opset5::NonMaxSuppression& nms)
|
||||
: StaticShapeNonMaxSuppression(
|
||||
nms.input_value(0),
|
||||
nms.input_value(1),
|
||||
nms.get_input_size() > 2 ? nms.input_value(2) : ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{}, {0}),
|
||||
nms.get_input_size() > 3 ? nms.input_value(3) : ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{}, {.0f}),
|
||||
nms.get_input_size() > 4 ? nms.input_value(4) : ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{}, {.0f}),
|
||||
nms.get_input_size() > 5 ? nms.input_value(5) : ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{}, {.0f}),
|
||||
nms.get_box_encoding() == ngraph::opset5::NonMaxSuppression::BoxEncodingType::CENTER ? 1 : 0,
|
||||
nms.get_sort_result_descending(),
|
||||
nms.get_output_type()) {}
|
||||
|
||||
StaticShapeNonMaxSuppression::StaticShapeNonMaxSuppression(
|
||||
const Output<Node>& boxes,
|
||||
const Output<Node>& scores,
|
||||
|
@ -21,16 +21,7 @@ void dynamicToStaticNonMaxSuppression(std::shared_ptr<ngraph::Node> node) {
|
||||
VPU_THROW_UNLESS(nms, "dynamicToStaticNonMaxSuppression transformation for {} of type {} expects {} as node for replacement",
|
||||
node->get_friendly_name(), node->get_type_info(), ngraph::opset5::NonMaxSuppression::type_info);
|
||||
|
||||
auto staticShapeNMS = std::make_shared<ngraph::vpu::op::StaticShapeNonMaxSuppression>(
|
||||
nms->input_value(0),
|
||||
nms->input_value(1),
|
||||
nms->input_value(2),
|
||||
nms->input_value(3),
|
||||
nms->input_value(4),
|
||||
nms->input_value(5),
|
||||
nms->get_box_encoding() == ngraph::opset5::NonMaxSuppression::BoxEncodingType::CENTER ? 1 : 0,
|
||||
nms->get_sort_result_descending(),
|
||||
nms->get_output_type());
|
||||
auto staticShapeNMS = std::make_shared<ngraph::vpu::op::StaticShapeNonMaxSuppression>(*nms);
|
||||
|
||||
auto dsrIndices = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(
|
||||
staticShapeNMS->output(0), staticShapeNMS->output(2));
|
||||
|
Loading…
Reference in New Issue
Block a user