[IE][VPU]: Fixes for Yolo-V3 (#4517)

* Fix negative axis processing for StaticShapeTopK
* Preserve output names in MergeSubsequentDSROperations
* Do validate_and_infer_types every time it's called in StaticShape* operations. It's needed to infer the correct output type in case it was changed from the last call moment (e.g. the ConvertPrecision pass have been called)
This commit is contained in:
Andrew Bakalin 2021-03-03 16:47:42 +03:00 committed by GitHub
parent 7aed4ab3e7
commit 430adbc191
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 69 additions and 52 deletions

View File

@ -37,6 +37,9 @@ public:
bool visit_attributes(ngraph::AttributeVisitor& visitor) override; bool visit_attributes(ngraph::AttributeVisitor& visitor) override;
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override; bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
protected:
ngraph::PartialShape m_evaluatedOutputShape;
}; };
} // namespace op } // namespace op

View File

@ -16,6 +16,9 @@ public:
explicit StaticShapeLoop(const Loop& loop); explicit StaticShapeLoop(const Loop& loop);
void validate_and_infer_types() override; void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor&) override; bool visit_attributes(AttributeVisitor&) override;
protected:
ngraph::PartialShape m_evaluatedIterationsCount;
}; };
} // namespace op } // namespace op

View File

@ -22,6 +22,9 @@ public:
const NodeTypeInfo& get_type_info() const override { return type_info; } const NodeTypeInfo& get_type_info() const override { return type_info; }
void validate_and_infer_types() override; void validate_and_infer_types() override;
protected:
ngraph::PartialShape m_evaluatedOutputShape;
}; };
} // namespace op } // namespace op

View File

@ -33,6 +33,9 @@ public:
const element::Type& index_element_type = element::i32); const element::Type& index_element_type = element::i32);
void validate_and_infer_types() override; void validate_and_infer_types() override;
protected:
ngraph::PartialShape m_evaluatedOutputShape;
}; };
} // namespace op } // namespace op

View File

@ -19,30 +19,32 @@ StaticShapeBroadcast::StaticShapeBroadcast(const Output<Node>& arg,
const Output<Node>& targetShape, const Output<Node>& targetShape,
const Output<Node>& axesMapping, const Output<Node>& axesMapping,
const ngraph::op::BroadcastModeSpec& broadcastSpec) const ngraph::op::BroadcastModeSpec& broadcastSpec)
: ::ngraph::op::v3::Broadcast{arg, targetShape, axesMapping, broadcastSpec} { : ::ngraph::op::v3::Broadcast{arg, targetShape, axesMapping, broadcastSpec},
m_evaluatedOutputShape{PartialShape::dynamic()} {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
StaticShapeBroadcast::StaticShapeBroadcast(const Output<Node>& arg, StaticShapeBroadcast::StaticShapeBroadcast(const Output<Node>& arg,
const Output<Node>& targetShape, const Output<Node>& targetShape,
const ngraph::op::BroadcastModeSpec& broadcastSpec) const ngraph::op::BroadcastModeSpec& broadcastSpec)
: ::ngraph::op::v3::Broadcast{arg, targetShape, broadcastSpec} { : ::ngraph::op::v3::Broadcast{arg, targetShape, broadcastSpec},
m_evaluatedOutputShape{PartialShape::dynamic()} {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
void StaticShapeBroadcast::validate_and_infer_types() { void StaticShapeBroadcast::validate_and_infer_types() {
if (get_output_partial_shape(0).is_static()) { auto& outputShape = m_evaluatedOutputShape;
return; if (outputShape.is_dynamic()) {
::ngraph::op::v3::Broadcast::validate_and_infer_types();
outputShape = get_output_partial_shape(0);
NODE_VALIDATION_CHECK(this, outputShape.rank().is_static(), "StaticShapeBroadcast (", get_friendly_name(), ") ",
"output is expected to be of static rank");
for (size_t i = 0; i < outputShape.rank().get_length(); i++) {
outputShape[i] = outputShape[i].get_max_length();
}
} }
::ngraph::op::v3::Broadcast::validate_and_infer_types();
auto outputShape = get_output_partial_shape(0);
NODE_VALIDATION_CHECK(this, outputShape.rank().is_static(), "StaticShapeBroadcast (", get_friendly_name(), ") ",
"output is expected to be of static rank");
for (size_t i = 0; i < outputShape.rank().get_length(); i++) {
outputShape[i] = outputShape[i].get_max_length();
}
NODE_VALIDATION_CHECK(this, outputShape.is_static(), NODE_VALIDATION_CHECK(this, outputShape.is_static(),
"StaticShapeBroadcast (", get_friendly_name(), ") can't evaluate output shape"); "StaticShapeBroadcast (", get_friendly_name(), ") can't evaluate output shape");

View File

@ -9,24 +9,17 @@ namespace ngraph { namespace vpu { namespace op {
constexpr NodeTypeInfo StaticShapeLoop::type_info; constexpr NodeTypeInfo StaticShapeLoop::type_info;
StaticShapeLoop::StaticShapeLoop(const Loop& loop) : Loop(loop) {} StaticShapeLoop::StaticShapeLoop(const Loop& loop) : Loop(loop), m_evaluatedIterationsCount{ngraph::PartialShape::dynamic()} {}
void StaticShapeLoop::validate_and_infer_types() { void StaticShapeLoop::validate_and_infer_types() {
const auto isLoopStatic = [this]() { auto& iterationsCount = m_evaluatedIterationsCount;
const auto& outs = outputs(); if (iterationsCount.is_dynamic()) {
return !outs.empty() && std::all_of(outs.cbegin(), outs.cend(), [](const Output<Node>& output) { return output.get_partial_shape().is_static(); }); Loop::validate_and_infer_types();
};
if (isLoopStatic()) { NODE_VALIDATION_CHECK(this, ngraph::evaluate_as_partial_shape(input_value(0), iterationsCount),
return; "Encountered a loop for which upper-bound estimation for iterations count ", input_value(0), " failed");
} }
Loop::validate_and_infer_types();
ngraph::PartialShape iterationsCount;
NODE_VALIDATION_CHECK(this, ngraph::evaluate_as_partial_shape(input_value(0), iterationsCount),
"Encountered a loop for which upper-bound estimation for iterations count ", input_value(0), " failed");
const auto& maxIterationsCount = iterationsCount[0].get_max_length(); const auto& maxIterationsCount = iterationsCount[0].get_max_length();
NODE_VALIDATION_CHECK(this, maxIterationsCount > 0, NODE_VALIDATION_CHECK(this, maxIterationsCount > 0,
"Encountered a loop with non-positive upper-bound estimation for iterations count ", "Encountered a loop with non-positive upper-bound estimation for iterations count ",

View File

@ -12,7 +12,8 @@ namespace ngraph { namespace vpu { namespace op {
constexpr NodeTypeInfo StaticShapeReshape::type_info; constexpr NodeTypeInfo StaticShapeReshape::type_info;
StaticShapeReshape::StaticShapeReshape(const Output<Node>& arg, const Output<Node>& pattern, bool special_zero) StaticShapeReshape::StaticShapeReshape(const Output<Node>& arg, const Output<Node>& pattern, bool special_zero)
: ::ngraph::opset3::Reshape(arg, pattern, special_zero) { : ::ngraph::opset3::Reshape(arg, pattern, special_zero),
m_evaluatedOutputShape{PartialShape::dynamic()} {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
@ -21,21 +22,20 @@ StaticShapeReshape::StaticShapeReshape(const std::shared_ptr<ngraph::opset3::Res
} }
void StaticShapeReshape::validate_and_infer_types() { void StaticShapeReshape::validate_and_infer_types() {
if (get_output_partial_shape(0).is_static()) {
return;
}
opset3::Reshape::validate_and_infer_types();
set_input_is_relevant_to_shape(1); set_input_is_relevant_to_shape(1);
NODE_VALIDATION_CHECK(this, get_input_partial_shape(0).is_static(), "StaticShapeReshape (", get_friendly_name(), ") ", NODE_VALIDATION_CHECK(this, get_input_partial_shape(0).is_static(), "StaticShapeReshape (", get_friendly_name(), ") ",
"input#0 is expected to be of static shape, got: ", get_input_partial_shape(0)); "input#0 is expected to be of static shape, got: ", get_input_partial_shape(0));
auto outputShape = get_output_partial_shape(0); auto& outputShape = m_evaluatedOutputShape;
NODE_VALIDATION_CHECK(this, outputShape.rank().is_static(), "StaticShapeReshape (", get_friendly_name(), ") ", if (outputShape.is_dynamic()) {
"output is expected to be of static rank"); opset3::Reshape::validate_and_infer_types();
for (size_t i = 0; i < outputShape.rank().get_length(); i++) {
outputShape[i] = outputShape[i].get_max_length(); outputShape = get_output_partial_shape(0);
NODE_VALIDATION_CHECK(this, outputShape.rank().is_static(), "StaticShapeReshape (", get_friendly_name(), ") ",
"output is expected to be of static rank");
for (size_t i = 0; i < outputShape.rank().get_length(); i++) {
outputShape[i] = outputShape[i].get_max_length();
}
} }
NODE_VALIDATION_CHECK(this, outputShape.is_static(), NODE_VALIDATION_CHECK(this, outputShape.is_static(),

View File

@ -15,7 +15,8 @@ ngraph::vpu::op::StaticShapeTopK::StaticShapeTopK(
const std::string& mode, const std::string& mode,
const std::string& sort, const std::string& sort,
const element::Type& index_element_type) const element::Type& index_element_type)
: ngraph::op::v3::TopK{data, k, axis, mode, sort, index_element_type} { : ngraph::op::v3::TopK{data, k, axis, mode, sort, index_element_type},
m_evaluatedOutputShape{PartialShape::dynamic()} {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
@ -26,21 +27,22 @@ ngraph::vpu::op::StaticShapeTopK::StaticShapeTopK(
const ngraph::vpu::op::StaticShapeTopK::Mode mode, const ngraph::vpu::op::StaticShapeTopK::Mode mode,
const ngraph::vpu::op::StaticShapeTopK::SortType sort, const ngraph::vpu::op::StaticShapeTopK::SortType sort,
const ngraph::element::Type &index_element_type) const ngraph::element::Type &index_element_type)
: ngraph::op::v3::TopK{data, k, axis, mode, sort, index_element_type} { : ngraph::op::v3::TopK{data, k, axis, mode, sort, index_element_type},
m_evaluatedOutputShape{PartialShape::dynamic()} {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
void ngraph::vpu::op::StaticShapeTopK::validate_and_infer_types() { void ngraph::vpu::op::StaticShapeTopK::validate_and_infer_types() {
if (get_output_partial_shape(0).is_static() && get_output_partial_shape(1).is_static()) { auto& outputShape = m_evaluatedOutputShape;
return; if (outputShape.is_dynamic()) {
} ngraph::op::v3::TopK::validate_and_infer_types();
ngraph::op::v3::TopK::validate_and_infer_types(); outputShape = get_output_partial_shape(0);
auto outputShape = get_output_partial_shape(0); NODE_VALIDATION_CHECK(this, outputShape.rank().is_static(), "StaticShapeTopK (", get_friendly_name(), ") ",
NODE_VALIDATION_CHECK(this, outputShape.rank().is_static(), "StaticShapeTopK (", get_friendly_name(), ") ", "output is expected to be of static rank");
"output is expected to be of static rank"); for (size_t i = 0; i < outputShape.rank().get_length(); i++) {
for (size_t i = 0; i < outputShape.rank().get_length(); i++) { outputShape[i] = outputShape[i].get_max_length();
outputShape[i] = outputShape[i].get_max_length(); }
} }
NODE_VALIDATION_CHECK(this, outputShape.is_static(), NODE_VALIDATION_CHECK(this, outputShape.is_static(),
"StaticShapeTopK (", get_friendly_name(), ") can't evaluate output shape"); "StaticShapeTopK (", get_friendly_name(), ") can't evaluate output shape");

View File

@ -22,6 +22,7 @@ MergeSubsequentDSROperations::MergeSubsequentDSROperations() {
} }
// this will create a new DSR with correct inputs // this will create a new DSR with correct inputs
auto newDsr = dsr->copy_with_new_inputs({predecessor->input_value(0), dsr->input_value(1)}); auto newDsr = dsr->copy_with_new_inputs({predecessor->input_value(0), dsr->input_value(1)});
newDsr->set_friendly_name(dsr->get_friendly_name());
// replace DSR2 with new so DSR2 will lose all consumers so it will die after pass execution // replace DSR2 with new so DSR2 will lose all consumers so it will die after pass execution
replace_node(dsr, newDsr); replace_node(dsr, newDsr);
// reconnect all DSR1 consumers even with DSR2 which will be destructed so this is no more an issue // reconnect all DSR1 consumers even with DSR2 which will be destructed so this is no more an issue

View File

@ -139,10 +139,13 @@ void FrontEnd::parseTopK(const Model& model, const ie::CNNLayerPtr& _layer, cons
IE_ASSERT(!outputValues || outputValues->desc().numDims() == numDims); IE_ASSERT(!outputValues || outputValues->desc().numDims() == numDims);
IE_ASSERT(!outputIndices || outputIndices->desc().numDims() == numDims); IE_ASSERT(!outputIndices || outputIndices->desc().numDims() == numDims);
IE_ASSERT(layer->axis < numDims); VPU_THROW_UNLESS(layer->axis < numDims && layer->axis >= -numDims,
"Failed to parse layer {} with type {}: axis is expected to be in range [{}, {}], but got {}",
layer->name, layer->type, -numDims, numDims - 1, layer->axis);
auto perm = DimsOrder::fromNumDims(numDims).toPermutation(); const auto perm = DimsOrder::fromNumDims(numDims).toPermutation();
auto axis = perm[numDims - 1 - layer->axis]; const auto normalizedAxis = layer->axis + (layer->axis < 0 ? numDims : 0);
const auto axis = perm[numDims - 1 - normalizedAxis];
const TopKMode mode = getMode(layer); const TopKMode mode = getMode(layer);
const TopKSort sort = getSort(layer); const TopKSort sort = getSort(layer);

View File

@ -19,6 +19,9 @@ const std::vector<int64_t> axes = {
0, 0,
1, 1,
2, 2,
-1,
-2,
-3,
}; };
const std::vector<int64_t> k = { const std::vector<int64_t> k = {

View File

@ -21,7 +21,8 @@ const auto combinations = testing::Combine(
ngraph::element::i32), ngraph::element::i32),
testing::Values( testing::Values(
TopKTestCase{{{12345}, {80000}}, 75, 0}, TopKTestCase{{{12345}, {80000}}, 75, 0},
TopKTestCase{{{1234}, {4663}}, 70, 0}), TopKTestCase{{{1234}, {4663}}, 70, 0},
TopKTestCase{{{1234}, {4663}}, 70, -1}),
testing::Values(CommonTestUtils::DEVICE_MYRIAD)); testing::Values(CommonTestUtils::DEVICE_MYRIAD));