[IE][VPU]: Fixes for Yolo-V3 (#4517)

* Fix negative axis processing for StaticShapeTopK
* Preserve output names in MergeSubsequentDSROperations
* Do validate_and_infer_types every time it's called in StaticShape* operations. It's needed to infer the correct output type in case it was changed from the last call moment (e.g. the ConvertPrecision pass have been called)
This commit is contained in:
Andrew Bakalin 2021-03-03 16:47:42 +03:00 committed by GitHub
parent 7aed4ab3e7
commit 430adbc191
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 69 additions and 52 deletions

View File

@ -37,6 +37,9 @@ public:
bool visit_attributes(ngraph::AttributeVisitor& visitor) override;
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
protected:
ngraph::PartialShape m_evaluatedOutputShape;
};
} // namespace op

View File

@ -16,6 +16,9 @@ public:
explicit StaticShapeLoop(const Loop& loop);
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor&) override;
protected:
ngraph::PartialShape m_evaluatedIterationsCount;
};
} // namespace op

View File

@ -22,6 +22,9 @@ public:
const NodeTypeInfo& get_type_info() const override { return type_info; }
void validate_and_infer_types() override;
protected:
ngraph::PartialShape m_evaluatedOutputShape;
};
} // namespace op

View File

@ -33,6 +33,9 @@ public:
const element::Type& index_element_type = element::i32);
void validate_and_infer_types() override;
protected:
ngraph::PartialShape m_evaluatedOutputShape;
};
} // namespace op

View File

@ -19,30 +19,32 @@ StaticShapeBroadcast::StaticShapeBroadcast(const Output<Node>& arg,
const Output<Node>& targetShape,
const Output<Node>& axesMapping,
const ngraph::op::BroadcastModeSpec& broadcastSpec)
: ::ngraph::op::v3::Broadcast{arg, targetShape, axesMapping, broadcastSpec} {
: ::ngraph::op::v3::Broadcast{arg, targetShape, axesMapping, broadcastSpec},
m_evaluatedOutputShape{PartialShape::dynamic()} {
constructor_validate_and_infer_types();
}
StaticShapeBroadcast::StaticShapeBroadcast(const Output<Node>& arg,
const Output<Node>& targetShape,
const ngraph::op::BroadcastModeSpec& broadcastSpec)
: ::ngraph::op::v3::Broadcast{arg, targetShape, broadcastSpec} {
: ::ngraph::op::v3::Broadcast{arg, targetShape, broadcastSpec},
m_evaluatedOutputShape{PartialShape::dynamic()} {
constructor_validate_and_infer_types();
}
void StaticShapeBroadcast::validate_and_infer_types() {
if (get_output_partial_shape(0).is_static()) {
return;
auto& outputShape = m_evaluatedOutputShape;
if (outputShape.is_dynamic()) {
::ngraph::op::v3::Broadcast::validate_and_infer_types();
outputShape = get_output_partial_shape(0);
NODE_VALIDATION_CHECK(this, outputShape.rank().is_static(), "StaticShapeBroadcast (", get_friendly_name(), ") ",
"output is expected to be of static rank");
for (size_t i = 0; i < outputShape.rank().get_length(); i++) {
outputShape[i] = outputShape[i].get_max_length();
}
}
::ngraph::op::v3::Broadcast::validate_and_infer_types();
auto outputShape = get_output_partial_shape(0);
NODE_VALIDATION_CHECK(this, outputShape.rank().is_static(), "StaticShapeBroadcast (", get_friendly_name(), ") ",
"output is expected to be of static rank");
for (size_t i = 0; i < outputShape.rank().get_length(); i++) {
outputShape[i] = outputShape[i].get_max_length();
}
NODE_VALIDATION_CHECK(this, outputShape.is_static(),
"StaticShapeBroadcast (", get_friendly_name(), ") can't evaluate output shape");

View File

@ -9,24 +9,17 @@ namespace ngraph { namespace vpu { namespace op {
constexpr NodeTypeInfo StaticShapeLoop::type_info;
StaticShapeLoop::StaticShapeLoop(const Loop& loop) : Loop(loop) {}
StaticShapeLoop::StaticShapeLoop(const Loop& loop) : Loop(loop), m_evaluatedIterationsCount{ngraph::PartialShape::dynamic()} {}
void StaticShapeLoop::validate_and_infer_types() {
const auto isLoopStatic = [this]() {
const auto& outs = outputs();
return !outs.empty() && std::all_of(outs.cbegin(), outs.cend(), [](const Output<Node>& output) { return output.get_partial_shape().is_static(); });
};
auto& iterationsCount = m_evaluatedIterationsCount;
if (iterationsCount.is_dynamic()) {
Loop::validate_and_infer_types();
if (isLoopStatic()) {
return;
NODE_VALIDATION_CHECK(this, ngraph::evaluate_as_partial_shape(input_value(0), iterationsCount),
"Encountered a loop for which upper-bound estimation for iterations count ", input_value(0), " failed");
}
Loop::validate_and_infer_types();
ngraph::PartialShape iterationsCount;
NODE_VALIDATION_CHECK(this, ngraph::evaluate_as_partial_shape(input_value(0), iterationsCount),
"Encountered a loop for which upper-bound estimation for iterations count ", input_value(0), " failed");
const auto& maxIterationsCount = iterationsCount[0].get_max_length();
NODE_VALIDATION_CHECK(this, maxIterationsCount > 0,
"Encountered a loop with non-positive upper-bound estimation for iterations count ",

View File

@ -12,7 +12,8 @@ namespace ngraph { namespace vpu { namespace op {
constexpr NodeTypeInfo StaticShapeReshape::type_info;
StaticShapeReshape::StaticShapeReshape(const Output<Node>& arg, const Output<Node>& pattern, bool special_zero)
: ::ngraph::opset3::Reshape(arg, pattern, special_zero) {
: ::ngraph::opset3::Reshape(arg, pattern, special_zero),
m_evaluatedOutputShape{PartialShape::dynamic()} {
constructor_validate_and_infer_types();
}
@ -21,21 +22,20 @@ StaticShapeReshape::StaticShapeReshape(const std::shared_ptr<ngraph::opset3::Res
}
void StaticShapeReshape::validate_and_infer_types() {
if (get_output_partial_shape(0).is_static()) {
return;
}
opset3::Reshape::validate_and_infer_types();
set_input_is_relevant_to_shape(1);
NODE_VALIDATION_CHECK(this, get_input_partial_shape(0).is_static(), "StaticShapeReshape (", get_friendly_name(), ") ",
"input#0 is expected to be of static shape, got: ", get_input_partial_shape(0));
auto outputShape = get_output_partial_shape(0);
NODE_VALIDATION_CHECK(this, outputShape.rank().is_static(), "StaticShapeReshape (", get_friendly_name(), ") ",
"output is expected to be of static rank");
for (size_t i = 0; i < outputShape.rank().get_length(); i++) {
outputShape[i] = outputShape[i].get_max_length();
auto& outputShape = m_evaluatedOutputShape;
if (outputShape.is_dynamic()) {
opset3::Reshape::validate_and_infer_types();
outputShape = get_output_partial_shape(0);
NODE_VALIDATION_CHECK(this, outputShape.rank().is_static(), "StaticShapeReshape (", get_friendly_name(), ") ",
"output is expected to be of static rank");
for (size_t i = 0; i < outputShape.rank().get_length(); i++) {
outputShape[i] = outputShape[i].get_max_length();
}
}
NODE_VALIDATION_CHECK(this, outputShape.is_static(),

View File

@ -15,7 +15,8 @@ ngraph::vpu::op::StaticShapeTopK::StaticShapeTopK(
const std::string& mode,
const std::string& sort,
const element::Type& index_element_type)
: ngraph::op::v3::TopK{data, k, axis, mode, sort, index_element_type} {
: ngraph::op::v3::TopK{data, k, axis, mode, sort, index_element_type},
m_evaluatedOutputShape{PartialShape::dynamic()} {
constructor_validate_and_infer_types();
}
@ -26,21 +27,22 @@ ngraph::vpu::op::StaticShapeTopK::StaticShapeTopK(
const ngraph::vpu::op::StaticShapeTopK::Mode mode,
const ngraph::vpu::op::StaticShapeTopK::SortType sort,
const ngraph::element::Type &index_element_type)
: ngraph::op::v3::TopK{data, k, axis, mode, sort, index_element_type} {
: ngraph::op::v3::TopK{data, k, axis, mode, sort, index_element_type},
m_evaluatedOutputShape{PartialShape::dynamic()} {
constructor_validate_and_infer_types();
}
void ngraph::vpu::op::StaticShapeTopK::validate_and_infer_types() {
if (get_output_partial_shape(0).is_static() && get_output_partial_shape(1).is_static()) {
return;
}
auto& outputShape = m_evaluatedOutputShape;
if (outputShape.is_dynamic()) {
ngraph::op::v3::TopK::validate_and_infer_types();
ngraph::op::v3::TopK::validate_and_infer_types();
auto outputShape = get_output_partial_shape(0);
NODE_VALIDATION_CHECK(this, outputShape.rank().is_static(), "StaticShapeTopK (", get_friendly_name(), ") ",
"output is expected to be of static rank");
for (size_t i = 0; i < outputShape.rank().get_length(); i++) {
outputShape[i] = outputShape[i].get_max_length();
outputShape = get_output_partial_shape(0);
NODE_VALIDATION_CHECK(this, outputShape.rank().is_static(), "StaticShapeTopK (", get_friendly_name(), ") ",
"output is expected to be of static rank");
for (size_t i = 0; i < outputShape.rank().get_length(); i++) {
outputShape[i] = outputShape[i].get_max_length();
}
}
NODE_VALIDATION_CHECK(this, outputShape.is_static(),
"StaticShapeTopK (", get_friendly_name(), ") can't evaluate output shape");

View File

@ -22,6 +22,7 @@ MergeSubsequentDSROperations::MergeSubsequentDSROperations() {
}
// this will create a new DSR with correct inputs
auto newDsr = dsr->copy_with_new_inputs({predecessor->input_value(0), dsr->input_value(1)});
newDsr->set_friendly_name(dsr->get_friendly_name());
// replace DSR2 with new so DSR2 will lose all consumers so it will die after pass execution
replace_node(dsr, newDsr);
// reconnect all DSR1 consumers even with DSR2 which will be destructed so this is no more an issue

View File

@ -139,10 +139,13 @@ void FrontEnd::parseTopK(const Model& model, const ie::CNNLayerPtr& _layer, cons
IE_ASSERT(!outputValues || outputValues->desc().numDims() == numDims);
IE_ASSERT(!outputIndices || outputIndices->desc().numDims() == numDims);
IE_ASSERT(layer->axis < numDims);
VPU_THROW_UNLESS(layer->axis < numDims && layer->axis >= -numDims,
"Failed to parse layer {} with type {}: axis is expected to be in range [{}, {}], but got {}",
layer->name, layer->type, -numDims, numDims - 1, layer->axis);
auto perm = DimsOrder::fromNumDims(numDims).toPermutation();
auto axis = perm[numDims - 1 - layer->axis];
const auto perm = DimsOrder::fromNumDims(numDims).toPermutation();
const auto normalizedAxis = layer->axis + (layer->axis < 0 ? numDims : 0);
const auto axis = perm[numDims - 1 - normalizedAxis];
const TopKMode mode = getMode(layer);
const TopKSort sort = getSort(layer);

View File

@ -19,6 +19,9 @@ const std::vector<int64_t> axes = {
0,
1,
2,
-1,
-2,
-3,
};
const std::vector<int64_t> k = {

View File

@ -21,7 +21,8 @@ const auto combinations = testing::Combine(
ngraph::element::i32),
testing::Values(
TopKTestCase{{{12345}, {80000}}, 75, 0},
TopKTestCase{{{1234}, {4663}}, 70, 0}),
TopKTestCase{{{1234}, {4663}}, 70, 0},
TopKTestCase{{{1234}, {4663}}, 70, -1}),
testing::Values(CommonTestUtils::DEVICE_MYRIAD));