From 8985feff6fb19b275065090bdd1bde8e77ee1bf6 Mon Sep 17 00:00:00 2001 From: Evgeny Kotov Date: Tue, 7 Sep 2021 10:56:41 +0300 Subject: [PATCH] [GNA] Rewrite RemoveSingleInputConcatPass using ngraph (#7208) * initial matcher pass * write test implementation; + add unit tests * base * add unit tests * code review fixes * code review fixes * fix * fix * move RemoveSingleInputConcat before opset to legacy conversion --- .../src/gna_plugin/gna_plugin.cpp | 5 +- .../remove_single_input_concat.cpp | 47 ++++++ .../remove_single_input_concat.hpp | 20 +++ .../gna_remove_single_input_concat.cpp | 143 ++++++++++++++++++ 4 files changed, 213 insertions(+), 2 deletions(-) create mode 100644 inference-engine/src/gna_plugin/transformations/remove_single_input_concat.cpp create mode 100644 inference-engine/src/gna_plugin/transformations/remove_single_input_concat.hpp create mode 100644 inference-engine/tests/unit/gna/ngraph/transformations/gna_remove_single_input_concat.cpp diff --git a/inference-engine/src/gna_plugin/gna_plugin.cpp b/inference-engine/src/gna_plugin/gna_plugin.cpp index 3f61d3289c7..e1d615dbdec 100644 --- a/inference-engine/src/gna_plugin/gna_plugin.cpp +++ b/inference-engine/src/gna_plugin/gna_plugin.cpp @@ -67,6 +67,7 @@ #include "transformations/decompose_2d_conv.hpp" #include "transformations/convert_padded2valid_conv.hpp" #include "transformations/op_conversions/lstm_cell_decomposition.hpp" +#include "transformations/remove_single_input_concat.hpp" #include @@ -738,6 +739,7 @@ void GNAPlugin::LoadNetwork(CNNNetwork & _network) { manager.register_pass(); manager.register_pass(); manager.register_pass(); + manager.register_pass(); manager.register_pass(); manager.register_pass(); manager.register_pass(); @@ -793,10 +795,9 @@ void GNAPlugin::LoadNetwork(CNNNetwork & _network) { passes->registerPass(); passes->registerPass(); passes->registerPass(); + passes->registerPass(); } - passes->registerPass(); - // fake quantisation aware passes passes->registerPass(); passes->registerPass(); diff --git a/inference-engine/src/gna_plugin/transformations/remove_single_input_concat.cpp b/inference-engine/src/gna_plugin/transformations/remove_single_input_concat.cpp new file mode 100644 index 00000000000..b367bd63811 --- /dev/null +++ b/inference-engine/src/gna_plugin/transformations/remove_single_input_concat.cpp @@ -0,0 +1,47 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include "transformations/remove_single_input_concat.hpp" + +#include +#include + +#include +#include +#include + +using NodeInput = ngraph::Input; +using NodeOutput = ngraph::Output; + +namespace GNAPluginNS { + NGRAPH_RTTI_DEFINITION(RemoveSingleInputConcat, "RemoveSingleInputConcat", 0); + + RemoveSingleInputConcat::RemoveSingleInputConcat() { + MATCHER_SCOPE(RemoveSingleInputConcat); + + auto is_required_node = [](const ngraph::Output& value) { + return value.get_node_shared_ptr()->get_input_size() == 1; + }; + + auto concat_operation = ngraph::pattern::wrap_type(is_required_node); + + ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) { + const auto& pattern_map = m.get_pattern_value_map(); + auto concat_operation_node = pattern_map.find(concat_operation)->second.get_node_shared_ptr(); + + NodeOutput prev_node_output = concat_operation_node->get_input_source_output(0); + + for (NodeInput child_input : concat_operation_node->get_output_target_inputs(0)) + child_input.replace_source_output(prev_node_output); + + return true; + }; + + auto m = std::make_shared(concat_operation, matcher_name); + this->register_matcher(m, callback); + } + +} // namespace GNAPluginNS diff --git a/inference-engine/src/gna_plugin/transformations/remove_single_input_concat.hpp b/inference-engine/src/gna_plugin/transformations/remove_single_input_concat.hpp new file mode 100644 index 00000000000..7730c36d9af --- /dev/null +++ b/inference-engine/src/gna_plugin/transformations/remove_single_input_concat.hpp @@ -0,0 +1,20 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +namespace GNAPluginNS { + +/** + * @brief remove concat layers with single input + */ +class RemoveSingleInputConcat : public ngraph::pass::MatcherPass { +public: + NGRAPH_RTTI_DECLARATION; + RemoveSingleInputConcat(); +}; + +} // namespace GNAPluginNS diff --git a/inference-engine/tests/unit/gna/ngraph/transformations/gna_remove_single_input_concat.cpp b/inference-engine/tests/unit/gna/ngraph/transformations/gna_remove_single_input_concat.cpp new file mode 100644 index 00000000000..dfb2a0f0a2d --- /dev/null +++ b/inference-engine/tests/unit/gna/ngraph/transformations/gna_remove_single_input_concat.cpp @@ -0,0 +1,143 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include "transformations/remove_single_input_concat.hpp" + +#include "common_test_utils/ngraph_test_utils.hpp" +#include +#include +#include +#include + +namespace testing { +namespace { + +using GraphInputs = std::vector>; +using GraphOutputs = ngraph::OutputVector; + +struct Graph { + std::shared_ptr createFunction(); + + GraphInputs inputs; + GraphOutputs outputs; +}; + +std::shared_ptr Graph::createFunction() { + ngraph::ResultVector results; + std::transform(outputs.begin(), outputs.end(), std::back_inserter(results), + [] (ngraph::Output output) { + return std::make_shared(output); + }); + + ngraph::ParameterVector params(inputs.begin(), inputs.end()); + + return std::make_shared(results, params); +} + +// ------------------------------------------------------------------------------------------------------- + +using Operations = std::vector>; + +Graph createGraph(int n_inputs, bool has_concat, int n_outputs) { + GraphInputs inputs; + Operations outputs; + + for (int i = 0; i < n_inputs; ++i) { + auto input = std::make_shared(ngraph::element::i64, + ngraph::Shape{1, 3, 64}); + inputs.push_back(input); + outputs.push_back(input); + } + + { + Operations new_outputs; + for (auto output : outputs) { + auto add_bias = ngraph::opset8::Constant::create(ngraph::element::i64, {1, 1, 1}, {2}); + auto add_operation = std::make_shared(output, add_bias); + new_outputs.push_back(add_operation); + } + outputs.swap(new_outputs); + } + + if (has_concat) { + auto concat_operation = std::make_shared(ngraph::OutputVector(outputs.begin(), + outputs.end()), + 0); + outputs = {concat_operation}; + } + + { + Operations new_outputs; + for (auto output : outputs) { + for (int i = 0; i < n_outputs; ++i) { + auto add_bias = ngraph::opset8::Constant::create(ngraph::element::i64, {1, 1, 1}, {3}); + auto add_operation = std::make_shared(output, add_bias); + new_outputs.push_back(add_operation); + } + } + outputs.swap(new_outputs); + } + + Graph graph; + graph.inputs.swap(inputs); + graph.outputs.insert(graph.outputs.end(), + std::make_move_iterator(outputs.begin()), + std::make_move_iterator(outputs.end())); + + return graph; +} + +// ------------------------------------------------------------------------------------------------------- + +class RemoveSingleInputConcatFixture: public CommonTestUtils::TestsCommon, + public ::testing::WithParamInterface> { +public: + void SetUp() override; +public: + std::shared_ptr function, reference_function; +}; + +void RemoveSingleInputConcatFixture::SetUp() { + // TODO: use auto & [transformed_graph, reference_graph] = this->GetParam() when C++17 + Graph transformed_graph; + Graph reference_graph; + std::tie(transformed_graph, reference_graph) = this->GetParam(); + + function = transformed_graph.createFunction(); + reference_function = reference_graph.createFunction(); +} + +ngraph::pass::Manager createPassManager() { + ngraph::pass::Manager manager; + manager.register_pass(); + manager.register_pass(); + return manager; +} + +void execute_test(std::shared_ptr function, + std::shared_ptr reference_function) { + ngraph::pass::Manager pass_manager = createPassManager(); + pass_manager.run_passes(function); + const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES); + const FunctionsComparator::Result result = func_comparator(function, reference_function); + ASSERT_TRUE(result.valid); +} + +TEST_P(RemoveSingleInputConcatFixture, CompareFunctions) { + execute_test(function, reference_function); +} + +INSTANTIATE_TEST_SUITE_P(RemoveSingleInputConcatTestSuite, RemoveSingleInputConcatFixture, + ::testing::Values(std::make_tuple(createGraph(1 /* n_inputs */, true /* has_concat */, 1 /* n_outputs */), + createGraph(1 /* n_inputs */, false /* has_concat */, 1 /* n_outputs */)), + std::make_tuple(createGraph(1 /* n_inputs */, true /* has_concat */, 2 /* n_outputs */), + createGraph(1 /* n_inputs */, false /* has_concat */, 2 /* n_outputs */)), + std::make_tuple(createGraph(2 /* n_inputs */, true /* has_concat */, 1 /* n_outputs */), + createGraph(2 /* n_inputs */, true /* has_concat */, 1 /* n_outputs */)))); + +} // namespace +} // namespace testing