[IE][VPU]: Fixes for Yolo-V3 (#4517)
* Fix negative axis processing for StaticShapeTopK * Preserve output names in MergeSubsequentDSROperations * Do validate_and_infer_types every time it's called in StaticShape* operations. It's needed to infer the correct output type in case it was changed from the last call moment (e.g. the ConvertPrecision pass have been called)
This commit is contained in:
parent
7aed4ab3e7
commit
430adbc191
@ -37,6 +37,9 @@ public:
|
|||||||
bool visit_attributes(ngraph::AttributeVisitor& visitor) override;
|
bool visit_attributes(ngraph::AttributeVisitor& visitor) override;
|
||||||
|
|
||||||
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
ngraph::PartialShape m_evaluatedOutputShape;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace op
|
} // namespace op
|
||||||
|
@ -16,6 +16,9 @@ public:
|
|||||||
explicit StaticShapeLoop(const Loop& loop);
|
explicit StaticShapeLoop(const Loop& loop);
|
||||||
void validate_and_infer_types() override;
|
void validate_and_infer_types() override;
|
||||||
bool visit_attributes(AttributeVisitor&) override;
|
bool visit_attributes(AttributeVisitor&) override;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
ngraph::PartialShape m_evaluatedIterationsCount;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace op
|
} // namespace op
|
||||||
|
@ -22,6 +22,9 @@ public:
|
|||||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
||||||
|
|
||||||
void validate_and_infer_types() override;
|
void validate_and_infer_types() override;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
ngraph::PartialShape m_evaluatedOutputShape;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace op
|
} // namespace op
|
||||||
|
@ -33,6 +33,9 @@ public:
|
|||||||
const element::Type& index_element_type = element::i32);
|
const element::Type& index_element_type = element::i32);
|
||||||
|
|
||||||
void validate_and_infer_types() override;
|
void validate_and_infer_types() override;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
ngraph::PartialShape m_evaluatedOutputShape;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace op
|
} // namespace op
|
||||||
|
@ -19,30 +19,32 @@ StaticShapeBroadcast::StaticShapeBroadcast(const Output<Node>& arg,
|
|||||||
const Output<Node>& targetShape,
|
const Output<Node>& targetShape,
|
||||||
const Output<Node>& axesMapping,
|
const Output<Node>& axesMapping,
|
||||||
const ngraph::op::BroadcastModeSpec& broadcastSpec)
|
const ngraph::op::BroadcastModeSpec& broadcastSpec)
|
||||||
: ::ngraph::op::v3::Broadcast{arg, targetShape, axesMapping, broadcastSpec} {
|
: ::ngraph::op::v3::Broadcast{arg, targetShape, axesMapping, broadcastSpec},
|
||||||
|
m_evaluatedOutputShape{PartialShape::dynamic()} {
|
||||||
constructor_validate_and_infer_types();
|
constructor_validate_and_infer_types();
|
||||||
}
|
}
|
||||||
|
|
||||||
StaticShapeBroadcast::StaticShapeBroadcast(const Output<Node>& arg,
|
StaticShapeBroadcast::StaticShapeBroadcast(const Output<Node>& arg,
|
||||||
const Output<Node>& targetShape,
|
const Output<Node>& targetShape,
|
||||||
const ngraph::op::BroadcastModeSpec& broadcastSpec)
|
const ngraph::op::BroadcastModeSpec& broadcastSpec)
|
||||||
: ::ngraph::op::v3::Broadcast{arg, targetShape, broadcastSpec} {
|
: ::ngraph::op::v3::Broadcast{arg, targetShape, broadcastSpec},
|
||||||
|
m_evaluatedOutputShape{PartialShape::dynamic()} {
|
||||||
constructor_validate_and_infer_types();
|
constructor_validate_and_infer_types();
|
||||||
}
|
}
|
||||||
|
|
||||||
void StaticShapeBroadcast::validate_and_infer_types() {
|
void StaticShapeBroadcast::validate_and_infer_types() {
|
||||||
if (get_output_partial_shape(0).is_static()) {
|
auto& outputShape = m_evaluatedOutputShape;
|
||||||
return;
|
if (outputShape.is_dynamic()) {
|
||||||
|
::ngraph::op::v3::Broadcast::validate_and_infer_types();
|
||||||
|
|
||||||
|
outputShape = get_output_partial_shape(0);
|
||||||
|
NODE_VALIDATION_CHECK(this, outputShape.rank().is_static(), "StaticShapeBroadcast (", get_friendly_name(), ") ",
|
||||||
|
"output is expected to be of static rank");
|
||||||
|
for (size_t i = 0; i < outputShape.rank().get_length(); i++) {
|
||||||
|
outputShape[i] = outputShape[i].get_max_length();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
::ngraph::op::v3::Broadcast::validate_and_infer_types();
|
|
||||||
|
|
||||||
auto outputShape = get_output_partial_shape(0);
|
|
||||||
NODE_VALIDATION_CHECK(this, outputShape.rank().is_static(), "StaticShapeBroadcast (", get_friendly_name(), ") ",
|
|
||||||
"output is expected to be of static rank");
|
|
||||||
for (size_t i = 0; i < outputShape.rank().get_length(); i++) {
|
|
||||||
outputShape[i] = outputShape[i].get_max_length();
|
|
||||||
}
|
|
||||||
NODE_VALIDATION_CHECK(this, outputShape.is_static(),
|
NODE_VALIDATION_CHECK(this, outputShape.is_static(),
|
||||||
"StaticShapeBroadcast (", get_friendly_name(), ") can't evaluate output shape");
|
"StaticShapeBroadcast (", get_friendly_name(), ") can't evaluate output shape");
|
||||||
|
|
||||||
|
@ -9,24 +9,17 @@ namespace ngraph { namespace vpu { namespace op {
|
|||||||
|
|
||||||
constexpr NodeTypeInfo StaticShapeLoop::type_info;
|
constexpr NodeTypeInfo StaticShapeLoop::type_info;
|
||||||
|
|
||||||
StaticShapeLoop::StaticShapeLoop(const Loop& loop) : Loop(loop) {}
|
StaticShapeLoop::StaticShapeLoop(const Loop& loop) : Loop(loop), m_evaluatedIterationsCount{ngraph::PartialShape::dynamic()} {}
|
||||||
|
|
||||||
void StaticShapeLoop::validate_and_infer_types() {
|
void StaticShapeLoop::validate_and_infer_types() {
|
||||||
const auto isLoopStatic = [this]() {
|
auto& iterationsCount = m_evaluatedIterationsCount;
|
||||||
const auto& outs = outputs();
|
if (iterationsCount.is_dynamic()) {
|
||||||
return !outs.empty() && std::all_of(outs.cbegin(), outs.cend(), [](const Output<Node>& output) { return output.get_partial_shape().is_static(); });
|
Loop::validate_and_infer_types();
|
||||||
};
|
|
||||||
|
|
||||||
if (isLoopStatic()) {
|
NODE_VALIDATION_CHECK(this, ngraph::evaluate_as_partial_shape(input_value(0), iterationsCount),
|
||||||
return;
|
"Encountered a loop for which upper-bound estimation for iterations count ", input_value(0), " failed");
|
||||||
}
|
}
|
||||||
|
|
||||||
Loop::validate_and_infer_types();
|
|
||||||
|
|
||||||
ngraph::PartialShape iterationsCount;
|
|
||||||
NODE_VALIDATION_CHECK(this, ngraph::evaluate_as_partial_shape(input_value(0), iterationsCount),
|
|
||||||
"Encountered a loop for which upper-bound estimation for iterations count ", input_value(0), " failed");
|
|
||||||
|
|
||||||
const auto& maxIterationsCount = iterationsCount[0].get_max_length();
|
const auto& maxIterationsCount = iterationsCount[0].get_max_length();
|
||||||
NODE_VALIDATION_CHECK(this, maxIterationsCount > 0,
|
NODE_VALIDATION_CHECK(this, maxIterationsCount > 0,
|
||||||
"Encountered a loop with non-positive upper-bound estimation for iterations count ",
|
"Encountered a loop with non-positive upper-bound estimation for iterations count ",
|
||||||
|
@ -12,7 +12,8 @@ namespace ngraph { namespace vpu { namespace op {
|
|||||||
constexpr NodeTypeInfo StaticShapeReshape::type_info;
|
constexpr NodeTypeInfo StaticShapeReshape::type_info;
|
||||||
|
|
||||||
StaticShapeReshape::StaticShapeReshape(const Output<Node>& arg, const Output<Node>& pattern, bool special_zero)
|
StaticShapeReshape::StaticShapeReshape(const Output<Node>& arg, const Output<Node>& pattern, bool special_zero)
|
||||||
: ::ngraph::opset3::Reshape(arg, pattern, special_zero) {
|
: ::ngraph::opset3::Reshape(arg, pattern, special_zero),
|
||||||
|
m_evaluatedOutputShape{PartialShape::dynamic()} {
|
||||||
constructor_validate_and_infer_types();
|
constructor_validate_and_infer_types();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -21,21 +22,20 @@ StaticShapeReshape::StaticShapeReshape(const std::shared_ptr<ngraph::opset3::Res
|
|||||||
}
|
}
|
||||||
|
|
||||||
void StaticShapeReshape::validate_and_infer_types() {
|
void StaticShapeReshape::validate_and_infer_types() {
|
||||||
if (get_output_partial_shape(0).is_static()) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
opset3::Reshape::validate_and_infer_types();
|
|
||||||
|
|
||||||
set_input_is_relevant_to_shape(1);
|
set_input_is_relevant_to_shape(1);
|
||||||
NODE_VALIDATION_CHECK(this, get_input_partial_shape(0).is_static(), "StaticShapeReshape (", get_friendly_name(), ") ",
|
NODE_VALIDATION_CHECK(this, get_input_partial_shape(0).is_static(), "StaticShapeReshape (", get_friendly_name(), ") ",
|
||||||
"input#0 is expected to be of static shape, got: ", get_input_partial_shape(0));
|
"input#0 is expected to be of static shape, got: ", get_input_partial_shape(0));
|
||||||
|
|
||||||
auto outputShape = get_output_partial_shape(0);
|
auto& outputShape = m_evaluatedOutputShape;
|
||||||
NODE_VALIDATION_CHECK(this, outputShape.rank().is_static(), "StaticShapeReshape (", get_friendly_name(), ") ",
|
if (outputShape.is_dynamic()) {
|
||||||
"output is expected to be of static rank");
|
opset3::Reshape::validate_and_infer_types();
|
||||||
for (size_t i = 0; i < outputShape.rank().get_length(); i++) {
|
|
||||||
outputShape[i] = outputShape[i].get_max_length();
|
outputShape = get_output_partial_shape(0);
|
||||||
|
NODE_VALIDATION_CHECK(this, outputShape.rank().is_static(), "StaticShapeReshape (", get_friendly_name(), ") ",
|
||||||
|
"output is expected to be of static rank");
|
||||||
|
for (size_t i = 0; i < outputShape.rank().get_length(); i++) {
|
||||||
|
outputShape[i] = outputShape[i].get_max_length();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
NODE_VALIDATION_CHECK(this, outputShape.is_static(),
|
NODE_VALIDATION_CHECK(this, outputShape.is_static(),
|
||||||
|
@ -15,7 +15,8 @@ ngraph::vpu::op::StaticShapeTopK::StaticShapeTopK(
|
|||||||
const std::string& mode,
|
const std::string& mode,
|
||||||
const std::string& sort,
|
const std::string& sort,
|
||||||
const element::Type& index_element_type)
|
const element::Type& index_element_type)
|
||||||
: ngraph::op::v3::TopK{data, k, axis, mode, sort, index_element_type} {
|
: ngraph::op::v3::TopK{data, k, axis, mode, sort, index_element_type},
|
||||||
|
m_evaluatedOutputShape{PartialShape::dynamic()} {
|
||||||
constructor_validate_and_infer_types();
|
constructor_validate_and_infer_types();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -26,21 +27,22 @@ ngraph::vpu::op::StaticShapeTopK::StaticShapeTopK(
|
|||||||
const ngraph::vpu::op::StaticShapeTopK::Mode mode,
|
const ngraph::vpu::op::StaticShapeTopK::Mode mode,
|
||||||
const ngraph::vpu::op::StaticShapeTopK::SortType sort,
|
const ngraph::vpu::op::StaticShapeTopK::SortType sort,
|
||||||
const ngraph::element::Type &index_element_type)
|
const ngraph::element::Type &index_element_type)
|
||||||
: ngraph::op::v3::TopK{data, k, axis, mode, sort, index_element_type} {
|
: ngraph::op::v3::TopK{data, k, axis, mode, sort, index_element_type},
|
||||||
|
m_evaluatedOutputShape{PartialShape::dynamic()} {
|
||||||
constructor_validate_and_infer_types();
|
constructor_validate_and_infer_types();
|
||||||
}
|
}
|
||||||
|
|
||||||
void ngraph::vpu::op::StaticShapeTopK::validate_and_infer_types() {
|
void ngraph::vpu::op::StaticShapeTopK::validate_and_infer_types() {
|
||||||
if (get_output_partial_shape(0).is_static() && get_output_partial_shape(1).is_static()) {
|
auto& outputShape = m_evaluatedOutputShape;
|
||||||
return;
|
if (outputShape.is_dynamic()) {
|
||||||
}
|
ngraph::op::v3::TopK::validate_and_infer_types();
|
||||||
|
|
||||||
ngraph::op::v3::TopK::validate_and_infer_types();
|
outputShape = get_output_partial_shape(0);
|
||||||
auto outputShape = get_output_partial_shape(0);
|
NODE_VALIDATION_CHECK(this, outputShape.rank().is_static(), "StaticShapeTopK (", get_friendly_name(), ") ",
|
||||||
NODE_VALIDATION_CHECK(this, outputShape.rank().is_static(), "StaticShapeTopK (", get_friendly_name(), ") ",
|
"output is expected to be of static rank");
|
||||||
"output is expected to be of static rank");
|
for (size_t i = 0; i < outputShape.rank().get_length(); i++) {
|
||||||
for (size_t i = 0; i < outputShape.rank().get_length(); i++) {
|
outputShape[i] = outputShape[i].get_max_length();
|
||||||
outputShape[i] = outputShape[i].get_max_length();
|
}
|
||||||
}
|
}
|
||||||
NODE_VALIDATION_CHECK(this, outputShape.is_static(),
|
NODE_VALIDATION_CHECK(this, outputShape.is_static(),
|
||||||
"StaticShapeTopK (", get_friendly_name(), ") can't evaluate output shape");
|
"StaticShapeTopK (", get_friendly_name(), ") can't evaluate output shape");
|
||||||
|
@ -22,6 +22,7 @@ MergeSubsequentDSROperations::MergeSubsequentDSROperations() {
|
|||||||
}
|
}
|
||||||
// this will create a new DSR with correct inputs
|
// this will create a new DSR with correct inputs
|
||||||
auto newDsr = dsr->copy_with_new_inputs({predecessor->input_value(0), dsr->input_value(1)});
|
auto newDsr = dsr->copy_with_new_inputs({predecessor->input_value(0), dsr->input_value(1)});
|
||||||
|
newDsr->set_friendly_name(dsr->get_friendly_name());
|
||||||
// replace DSR2 with new so DSR2 will lose all consumers so it will die after pass execution
|
// replace DSR2 with new so DSR2 will lose all consumers so it will die after pass execution
|
||||||
replace_node(dsr, newDsr);
|
replace_node(dsr, newDsr);
|
||||||
// reconnect all DSR1 consumers even with DSR2 which will be destructed so this is no more an issue
|
// reconnect all DSR1 consumers even with DSR2 which will be destructed so this is no more an issue
|
||||||
|
@ -139,10 +139,13 @@ void FrontEnd::parseTopK(const Model& model, const ie::CNNLayerPtr& _layer, cons
|
|||||||
IE_ASSERT(!outputValues || outputValues->desc().numDims() == numDims);
|
IE_ASSERT(!outputValues || outputValues->desc().numDims() == numDims);
|
||||||
IE_ASSERT(!outputIndices || outputIndices->desc().numDims() == numDims);
|
IE_ASSERT(!outputIndices || outputIndices->desc().numDims() == numDims);
|
||||||
|
|
||||||
IE_ASSERT(layer->axis < numDims);
|
VPU_THROW_UNLESS(layer->axis < numDims && layer->axis >= -numDims,
|
||||||
|
"Failed to parse layer {} with type {}: axis is expected to be in range [{}, {}], but got {}",
|
||||||
|
layer->name, layer->type, -numDims, numDims - 1, layer->axis);
|
||||||
|
|
||||||
auto perm = DimsOrder::fromNumDims(numDims).toPermutation();
|
const auto perm = DimsOrder::fromNumDims(numDims).toPermutation();
|
||||||
auto axis = perm[numDims - 1 - layer->axis];
|
const auto normalizedAxis = layer->axis + (layer->axis < 0 ? numDims : 0);
|
||||||
|
const auto axis = perm[numDims - 1 - normalizedAxis];
|
||||||
|
|
||||||
const TopKMode mode = getMode(layer);
|
const TopKMode mode = getMode(layer);
|
||||||
const TopKSort sort = getSort(layer);
|
const TopKSort sort = getSort(layer);
|
||||||
|
@ -19,6 +19,9 @@ const std::vector<int64_t> axes = {
|
|||||||
0,
|
0,
|
||||||
1,
|
1,
|
||||||
2,
|
2,
|
||||||
|
-1,
|
||||||
|
-2,
|
||||||
|
-3,
|
||||||
};
|
};
|
||||||
|
|
||||||
const std::vector<int64_t> k = {
|
const std::vector<int64_t> k = {
|
||||||
|
@ -21,7 +21,8 @@ const auto combinations = testing::Combine(
|
|||||||
ngraph::element::i32),
|
ngraph::element::i32),
|
||||||
testing::Values(
|
testing::Values(
|
||||||
TopKTestCase{{{12345}, {80000}}, 75, 0},
|
TopKTestCase{{{12345}, {80000}}, 75, 0},
|
||||||
TopKTestCase{{{1234}, {4663}}, 70, 0}),
|
TopKTestCase{{{1234}, {4663}}, 70, 0},
|
||||||
|
TopKTestCase{{{1234}, {4663}}, 70, -1}),
|
||||||
testing::Values(CommonTestUtils::DEVICE_MYRIAD));
|
testing::Values(CommonTestUtils::DEVICE_MYRIAD));
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user