[IE Common][VPU]: Fix small input size inference for dynamic model (#3847)

* eliminate an Unsqueeze+Gather pair when Gather gathers data along the dimension that was previously added by Unsqueeze, so the pair is actually doing nothing.
* calculate K only once in StaticShapeTopK. The problem happens when we have a ShapeOf->Concat->ReduceMin subgraph for K evaluation. If we have a pretty small input size, the value that we receive from ShapeOf may be less than the one it is concatenated with (e.g. ShapeOf gives 283 vs const 300), so ReduceMin returns 283. After ShapeOf elimination we don't have a chance to propagate 283, so we get 300 as a result and shape inference then fails. There are no problems with bigger input sizes simply because ShapeOf always propagates a value >300, so there is no such mismatch.
This commit is contained in:
Andrew Bakalin 2021-01-25 19:17:32 +03:00 committed by GitHub
parent ce0537bd1f
commit 09e2231720
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 289 additions and 1 deletions

View File

@ -0,0 +1,31 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <transformations_visibility.hpp>
#include <ngraph/ngraph.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
#include "ngraph/pattern/matcher.hpp"
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API EliminateUnsqueezeGather;
} // namespace pass
} // namespace ngraph
/**
* @ingroup ie_transformation_common_api
 * @brief Removes an Unsqueeze + Gather pair when Gather gathers data along the
 * dimension that was previously added by Unsqueeze
*/
// Matcher pass declaration; the matching conditions (constant axes, scalar
// zero index) live in the .cpp implementation of the constructor.
class ngraph::pass::EliminateUnsqueezeGather : public ngraph::pass::MatcherPass {
public:
    NGRAPH_RTTI_DECLARATION;
    // Registers the Unsqueeze->Gather matcher and its folding callback.
    EliminateUnsqueezeGather();
};

View File

@ -29,6 +29,7 @@
#include "transformations/common_optimizations/relu_fake_quantize_fusion.hpp"
#include "transformations/common_optimizations/clamp_fusion.hpp"
#include "transformations/common_optimizations/pad_fusion.hpp"
#include "transformations/common_optimizations/eliminate_unsqueeze_gather.hpp"
#include "transformations/op_conversions/bidirectional_sequences_decomposition.hpp"
#include "transformations/op_conversions/convert_pad_to_group_conv.hpp"
#include "transformations/op_conversions/convert_divide.hpp"
@ -68,6 +69,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
manager.register_pass<ngraph::pass::StridedSliceOptimization>(); // depends on CF
manager.register_pass<ngraph::pass::AlgebraicSimplification>(); // may introduce fake dynamism
manager.register_pass<ngraph::pass::BroadcastElementwiseFusion>();
manager.register_pass<ngraph::pass::EliminateUnsqueezeGather>();
manager.register_pass<ngraph::pass::NopElimination>(); // may introduce fake dynamism
manager.register_pass<ngraph::pass::ConstantFolding>();
manager.register_pass<ngraph::pass::ConvertScatterElementsToScatter>(); // partially depends on CF

View File

@ -0,0 +1,58 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/eliminate_unsqueeze_gather.hpp"
#include <ngraph/opsets/opset6.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>
#include "itt.hpp"
NGRAPH_RTTI_DEFINITION(ngraph::pass::EliminateUnsqueezeGather, "EliminateUnsqueezeGather", 0);
ngraph::pass::EliminateUnsqueezeGather::EliminateUnsqueezeGather() {
    MATCHER_SCOPE(EliminateUnsqueezeGather);

    // Pattern: any_input -> Unsqueeze(axis) -> Gather(indices = scalar {0}, axis).
    // Both axes are left as any_input here; the callback verifies they are
    // matching single-element constants.
    const auto unsqueezeAxis = ngraph::pattern::any_input();
    const auto unsqueeze = ngraph::pattern::wrap_type<ngraph::opset6::Unsqueeze>({ngraph::pattern::any_input(), unsqueezeAxis});
    const auto gatherIndices = ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{}, {0});
    const auto gatherAxis = ngraph::pattern::any_input();
    const auto gather = ngraph::pattern::wrap_type<ngraph::opset6::Gather>({unsqueeze, gatherIndices, gatherAxis});

    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
        auto& patternMap = m.get_pattern_value_map();

        // Both axes must be constants, otherwise we cannot prove the Gather
        // reads along the very dimension the Unsqueeze inserted.
        const auto unsqueezeAxisConst =
            ngraph::as_type_ptr<ngraph::opset6::Constant>(patternMap.at(unsqueezeAxis).get_node_shared_ptr());
        const auto gatherAxisConst =
            ngraph::as_type_ptr<ngraph::opset6::Constant>(patternMap.at(gatherAxis).get_node_shared_ptr());
        if (!unsqueezeAxisConst || !gatherAxisConst) {
            return false;
        }

        const auto unsqueezeAxes = unsqueezeAxisConst->cast_vector<int64_t>();
        const auto gatherAxes = gatherAxisConst->cast_vector<int64_t>();
        // Exactly one axis on each side, and they must coincide; then gathering
        // index 0 along the freshly inserted size-1 dimension is a no-op.
        // NOTE(review): normalization of negative axes is presumably handled by
        // the exact-value comparison below — confirm if negative axes can occur.
        const bool foldable = unsqueezeAxes.size() == 1 &&
                              gatherAxes.size() == 1 &&
                              unsqueezeAxes.front() == gatherAxes.front();
        if (!foldable) {
            return false;
        }

        // Bypass both ops: rewire consumers of Gather to the Unsqueeze input.
        auto& gatherOutput = patternMap.at(gather);
        const auto unsqueezeInput = patternMap.at(unsqueeze).get_node_shared_ptr()->get_input_node_shared_ptr(0);

        ngraph::copy_runtime_info(gatherOutput.get_node_shared_ptr(), unsqueezeInput);
        gatherOutput.replace(unsqueezeInput);

        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(gather, "EliminateUnsqueezeGather");
    ngraph::pass::MatcherPass::register_matcher(m, callback);
}

View File

@ -92,7 +92,7 @@ void ngraph::vpu::op::StaticShapeTopK::validate_and_infer_types() {
m_normalized_axis = ngraph::normalize_axis(this->description(), m_axis, output_shape.rank());
if (k != 0) {
output_shape[m_normalized_axis] = k;
} else {
} else if (m_maximumK == -1) {
auto max_k = maximum_value(input_value(1));
const auto is_max_value_calculated = max_k.first;
const auto calculated_max_value = max_k.second;

View File

@ -0,0 +1,87 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <transformations/common_optimizations/eliminate_unsqueeze_gather.hpp>
#include <ngraph_functions/utils/ngraph_helpers.hpp>
#include <common_test_utils/test_common.hpp>
#include <ngraph/opsets/opset6.hpp>
#include <ngraph/pass/manager.hpp>
namespace {
using TensorType = ngraph::element::Type_t;
using TensorShape = ngraph::Shape;
// Functional check of the EliminateUnsqueezeGather pass: builds
// Parameter -> Unsqueeze(axis) -> Gather(axis), runs the pass, and compares
// the result against a function containing only the Parameter.
class EliminateUnsqueezeGatherTest : public CommonTestUtils::TestsCommon,
                                     public testing::WithParamInterface<std::tuple<TensorType, TensorShape, size_t>> {
public:
    void SetUp() override {
        const auto& params = GetParam();
        const auto& elemType = std::get<0>(params);
        const auto& shape = std::get<1>(params);
        const auto& axis = std::get<2>(params);

        ngraph::helpers::CompareFunctions(*transform(shape, elemType, axis),
                                          *reference(shape, elemType, axis));
    }

protected:
    // Builds the actual function with the Unsqueeze+Gather pair and runs the pass on it.
    std::shared_ptr<const ngraph::Function> transform(
            const TensorShape& inShape,
            const TensorType& inType,
            size_t axis) {
        const auto param = std::make_shared<ngraph::opset6::Parameter>(inType, inShape);

        const auto unsqueeze = std::make_shared<ngraph::opset6::Unsqueeze>(
            param,
            ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {axis}));

        const auto gather = std::make_shared<ngraph::opset6::Gather>(
            unsqueeze,
            ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0}),
            ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {axis}));

        const auto function = std::make_shared<ngraph::Function>(
            ngraph::NodeVector{gather},
            ngraph::ParameterVector{param},
            "Actual");

        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::EliminateUnsqueezeGather>();
        manager.run_passes(function);

        return function;
    }

    // Expected graph after the pass: the Unsqueeze+Gather pair is gone entirely.
    std::shared_ptr<const ngraph::Function> reference(
            const TensorShape& inShape,
            const TensorType& inType,
            size_t axis) {
        const auto param = std::make_shared<ngraph::opset6::Parameter>(inType, inShape);
        return std::make_shared<ngraph::Function>(
            ngraph::NodeVector{param},
            ngraph::ParameterVector{param},
            "Reference");
    }
};
// The comparison itself happens in SetUp(); the test body is intentionally empty.
TEST_P(EliminateUnsqueezeGatherTest, CompareFunctions) {
}
// Axis 3 is valid for the rank-3 input because Unsqueeze extends it to rank 4.
INSTANTIATE_TEST_CASE_P(smoke_NGraph, EliminateUnsqueezeGatherTest, testing::Combine(
    testing::Values(
        ngraph::element::f16,
        ngraph::element::f32,
        ngraph::element::i32,
        ngraph::element::i64,
        ngraph::element::u8),
    testing::Values(
        TensorShape{3, 128, 256}),
    testing::Values(0, 1, 2, 3)
));
} // namespace

View File

@ -8,10 +8,13 @@
#include <ngraph/function.hpp>
#include <ngraph_functions/utils/ngraph_helpers.hpp>
#include <ngraph/opsets/opset5.hpp>
#include <ngraph/opsets/opset6.hpp>
#include <ngraph/pass/manager.hpp>
#include <vpu/ngraph/operations/dynamic_shape_resolver.hpp>
#include <vpu/ngraph/operations/static_shape_topk.hpp>
#include <vpu/ngraph/transformations/dynamic_to_static_shape_topk.hpp>
#include <vpu/ngraph/transformations/dynamic_to_static_shape.hpp>
#include <vpu/ngraph/transformations/eliminate_shapeof_after_dsr.hpp>
namespace {
@ -162,4 +165,50 @@ TEST_P(DynamicToStaticTopKPropagationShapeOfGather, KPropagation) {
INSTANTIATE_TEST_CASE_P(smoke_NGraph, DynamicToStaticTopKPropagationShapeOfGather, ::testing::ValuesIn(kVec));
class KPropagationAfterShapeOfElimination : public DynamicToStaticTopKPropagationShapeOfBased {
void SetUp() override {
const auto& k = GetParam();
const auto data = std::make_shared<ngraph::opset5::Parameter>(ngraph::element::i64, ngraph::Shape{static_cast<size_t>(k)});
const auto shape = std::make_shared<ngraph::opset5::Parameter>(ngraph::element::i64, ngraph::Shape{1});
const auto dsr = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(
data,
shape);
const auto shapeOf = std::make_shared<ngraph::opset5::ShapeOf>(dsr);
const auto builtSubgraph = buildSubgraph(shapeOf);
const auto staticShapeTopK = std::make_shared<ngraph::vpu::op::StaticShapeTopK>(dsr, builtSubgraph, 0, "max", "value");
const auto function = std::make_shared<ngraph::Function>(
staticShapeTopK->outputs(),
ngraph::ParameterVector{data, shape},
"KPropagationAfterShapeOfElimination");
validate(*function);
ngraph::pass::Manager manager;
manager.register_pass<vpu::EliminateShapeOfAfterDSR>();
manager.run_passes(function);
function->validate_nodes_and_infer_types();
validate(*function);
}
std::shared_ptr<ngraph::Node> buildSubgraph(std::shared_ptr<ngraph::Node> node) const override {
const auto concat = std::make_shared<ngraph::opset6::Concat>(
ngraph::NodeVector{
node,
ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {upperBoundK})},
0);
return std::make_shared<ngraph::opset6::ReduceMin>(
concat,
ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0}));
}
};
// The check itself happens in SetUp(); the test body is intentionally empty.
TEST_P(KPropagationAfterShapeOfElimination, KPropagation) {
}
// kVec is declared earlier in this file (outside the visible hunk).
INSTANTIATE_TEST_CASE_P(smoke_NGraph, KPropagationAfterShapeOfElimination, ::testing::ValuesIn(kVec));
} // namespace

View File

@ -0,0 +1,61 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <shared_test_classes/base/layer_test_utils.hpp>
#include <ngraph/opsets/opset6.hpp>
#include <vpu/private_plugin_config.hpp>
namespace {
using TensorType = ngraph::element::Type_t;
using TensorShape = ngraph::Shape;
// Plugin-level smoke test: runs an Unsqueeze->Gather->Relu network on the
// target device and compares against the reference implementation.
class UnsqueezeGather : public testing::WithParamInterface<std::tuple<TensorType, TensorShape, size_t, std::string>>,
                        virtual public LayerTestsUtils::LayerTestsCommon {
protected:
    void SetUp() override {
        configuration[InferenceEngine::MYRIAD_DETECT_NETWORK_BATCH] = CONFIG_VALUE(NO);

        const auto& params = GetParam();
        const auto& netType = std::get<0>(params);
        const auto& shape = std::get<1>(params);
        const auto& axis = std::get<2>(params);
        targetDevice = std::get<3>(params);

        const auto param = std::make_shared<ngraph::opset6::Parameter>(netType, shape);

        const auto unsqueeze = std::make_shared<ngraph::opset6::Unsqueeze>(
            param,
            ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {axis}));

        const auto gather = std::make_shared<ngraph::opset6::Gather>(
            unsqueeze,
            ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0}),
            ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {axis}));

        // Relu keeps the subgraph from being trivially folded away before execution.
        const auto relu = std::make_shared<ngraph::opset6::Relu>(gather);

        function = std::make_shared<ngraph::Function>(
            ngraph::NodeVector{relu},
            ngraph::ParameterVector{param},
            "unsqueeze-gather");
    }
};
// Executes the network and compares device results with the reference.
TEST_P(UnsqueezeGather, CompareWithRefs) {
    Run();
}
// Axis 3 is valid for the rank-3 input because Unsqueeze extends it to rank 4.
INSTANTIATE_TEST_CASE_P(smoke_NGraph, UnsqueezeGather, testing::Combine(
    testing::Values(
        ngraph::element::f16,
        ngraph::element::f32),
    testing::Values(
        TensorShape{3, 128, 256}),
    testing::Values(0, 1, 2, 3),
    testing::Values(CommonTestUtils::DEVICE_MYRIAD)
));
} // namespace