[IE][VPU] Fix NMS DTS (#2880)

Add a new constructor to fix absent NMS-5 inputs that will be introduced after #2450 will be merged.
This commit is contained in:
Andrew Bakalin 2020-11-05 13:33:16 +03:00 committed by GitHub
parent 935ac543ac
commit e758d2b325
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 16 additions and 10 deletions

View File

@ -6,6 +6,7 @@
#include <ngraph/node.hpp>
#include <legacy/ngraph_ops/nms_ie.hpp>
#include <ngraph/opsets/opset5.hpp>
#include <memory>
#include <vector>
@ -17,6 +18,8 @@ public:
static constexpr NodeTypeInfo type_info{"StaticShapeNonMaxSuppression", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
explicit StaticShapeNonMaxSuppression(const ngraph::opset5::NonMaxSuppression& nms);
StaticShapeNonMaxSuppression(const Output<Node>& boxes,
const Output<Node>& scores,
const Output<Node>& maxOutputBoxesPerClass,

View File

@ -12,6 +12,18 @@ namespace ngraph { namespace vpu { namespace op {
constexpr NodeTypeInfo StaticShapeNonMaxSuppression::type_info;
StaticShapeNonMaxSuppression::StaticShapeNonMaxSuppression(const ngraph::opset5::NonMaxSuppression& nms)
: StaticShapeNonMaxSuppression(
nms.input_value(0),
nms.input_value(1),
nms.get_input_size() > 2 ? nms.input_value(2) : ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{}, {0}),
nms.get_input_size() > 3 ? nms.input_value(3) : ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{}, {.0f}),
nms.get_input_size() > 4 ? nms.input_value(4) : ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{}, {.0f}),
nms.get_input_size() > 5 ? nms.input_value(5) : ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{}, {.0f}),
nms.get_box_encoding() == ngraph::opset5::NonMaxSuppression::BoxEncodingType::CENTER ? 1 : 0,
nms.get_sort_result_descending(),
nms.get_output_type()) {}
StaticShapeNonMaxSuppression::StaticShapeNonMaxSuppression(
const Output<Node>& boxes,
const Output<Node>& scores,

View File

@ -21,16 +21,7 @@ void dynamicToStaticNonMaxSuppression(std::shared_ptr<ngraph::Node> node) {
VPU_THROW_UNLESS(nms, "dynamicToStaticNonMaxSuppression transformation for {} of type {} expects {} as node for replacement",
node->get_friendly_name(), node->get_type_info(), ngraph::opset5::NonMaxSuppression::type_info);
auto staticShapeNMS = std::make_shared<ngraph::vpu::op::StaticShapeNonMaxSuppression>(
nms->input_value(0),
nms->input_value(1),
nms->input_value(2),
nms->input_value(3),
nms->input_value(4),
nms->input_value(5),
nms->get_box_encoding() == ngraph::opset5::NonMaxSuppression::BoxEncodingType::CENTER ? 1 : 0,
nms->get_sort_result_descending(),
nms->get_output_type());
auto staticShapeNMS = std::make_shared<ngraph::vpu::op::StaticShapeNonMaxSuppression>(*nms);
auto dsrIndices = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(
staticShapeNMS->output(0), staticShapeNMS->output(2));