[GNA] Rewrite RemoveSingleInputConcatPass using ngraph (#7208)

* initial matcher pass

* write test implementation; + add unit tests

* base

* add unit tests

* code review fixes

* code review fixes

* fix

* fix

* move RemoveSingleInputConcat before opset to legacy conversion
This commit is contained in:
Evgeny Kotov 2021-09-07 10:56:41 +03:00 committed by GitHub
parent 9e68a673e4
commit 8985feff6f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 213 additions and 2 deletions

View File

@ -67,6 +67,7 @@
#include "transformations/decompose_2d_conv.hpp"
#include "transformations/convert_padded2valid_conv.hpp"
#include "transformations/op_conversions/lstm_cell_decomposition.hpp"
#include "transformations/remove_single_input_concat.hpp"
#include <ngraph/opsets/opset7.hpp>
@ -738,6 +739,7 @@ void GNAPlugin::LoadNetwork(CNNNetwork & _network) {
manager.register_pass<SwapInputMatMul>();
manager.register_pass<InsertTransposeAfterConvOrPool>();
manager.register_pass<ReorderActivationAndPooling>();
manager.register_pass<RemoveSingleInputConcat>();
manager.register_pass<ngraph::pass::ConvertOpSet3ToOpSet2>();
manager.register_pass<ngraph::pass::ConvertOpSet2ToOpSet1>();
manager.register_pass<ngraph::pass::ConvertOpSet1ToLegacy>();
@ -793,10 +795,9 @@ void GNAPlugin::LoadNetwork(CNNNetwork & _network) {
passes->registerPass<UnrollTIPass>();
passes->registerPass<RemoveConstPass>();
passes->registerPass<UnrollLSTMCellPass>();
passes->registerPass<RemoveSingleInputConcatPass>();
}
passes->registerPass<RemoveSingleInputConcatPass>();
// fake quantisation aware passes
passes->registerPass<FuseFQIntoWeightsPass>();
passes->registerPass<MoveFakeQuantizeLayerIntoQuantParamsPass>();

View File

@ -0,0 +1,47 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <openvino/cc/ngraph/itt.hpp>
#include "transformations/remove_single_input_concat.hpp"
#include <memory>
#include <vector>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/pass/manager.hpp>
using NodeInput = ngraph::Input<ngraph::Node>;
using NodeOutput = ngraph::Output<ngraph::Node>;
namespace GNAPluginNS {
NGRAPH_RTTI_DEFINITION(RemoveSingleInputConcat, "RemoveSingleInputConcat", 0);
RemoveSingleInputConcat::RemoveSingleInputConcat() {
MATCHER_SCOPE(RemoveSingleInputConcat);
auto is_required_node = [](const ngraph::Output<ngraph::Node>& value) {
return value.get_node_shared_ptr()->get_input_size() == 1;
};
auto concat_operation = ngraph::pattern::wrap_type<ngraph::opset8::Concat>(is_required_node);
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
auto concat_operation_node = pattern_map.find(concat_operation)->second.get_node_shared_ptr();
NodeOutput prev_node_output = concat_operation_node->get_input_source_output(0);
for (NodeInput child_input : concat_operation_node->get_output_target_inputs(0))
child_input.replace_source_output(prev_node_output);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(concat_operation, matcher_name);
this->register_matcher(m, callback);
}
} // namespace GNAPluginNS

View File

@ -0,0 +1,20 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <ngraph/pass/graph_rewrite.hpp>
namespace GNAPluginNS {
/**
* @brief remove concat layers with single input
*/
class RemoveSingleInputConcat : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
RemoveSingleInputConcat();
};
} // namespace GNAPluginNS

View File

@ -0,0 +1,143 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include "transformations/remove_single_input_concat.hpp"
#include "common_test_utils/ngraph_test_utils.hpp"
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/init_node_info.hpp>
namespace testing {
namespace {
using GraphInputs = std::vector<std::shared_ptr<ngraph::opset8::Parameter>>;
using GraphOutputs = ngraph::OutputVector;
struct Graph {
std::shared_ptr<ngraph::Function> createFunction();
GraphInputs inputs;
GraphOutputs outputs;
};
std::shared_ptr<ngraph::Function> Graph::createFunction() {
ngraph::ResultVector results;
std::transform(outputs.begin(), outputs.end(), std::back_inserter(results),
[] (ngraph::Output<ngraph::Node> output) {
return std::make_shared<ngraph::opset8::Result>(output);
});
ngraph::ParameterVector params(inputs.begin(), inputs.end());
return std::make_shared<ngraph::Function>(results, params);
}
// -------------------------------------------------------------------------------------------------------
using Operations = std::vector<std::shared_ptr<ngraph::op::Op>>;
Graph createGraph(int n_inputs, bool has_concat, int n_outputs) {
GraphInputs inputs;
Operations outputs;
for (int i = 0; i < n_inputs; ++i) {
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64,
ngraph::Shape{1, 3, 64});
inputs.push_back(input);
outputs.push_back(input);
}
{
Operations new_outputs;
for (auto output : outputs) {
auto add_bias = ngraph::opset8::Constant::create(ngraph::element::i64, {1, 1, 1}, {2});
auto add_operation = std::make_shared<ngraph::opset8::Add>(output, add_bias);
new_outputs.push_back(add_operation);
}
outputs.swap(new_outputs);
}
if (has_concat) {
auto concat_operation = std::make_shared<ngraph::opset8::Concat>(ngraph::OutputVector(outputs.begin(),
outputs.end()),
0);
outputs = {concat_operation};
}
{
Operations new_outputs;
for (auto output : outputs) {
for (int i = 0; i < n_outputs; ++i) {
auto add_bias = ngraph::opset8::Constant::create(ngraph::element::i64, {1, 1, 1}, {3});
auto add_operation = std::make_shared<ngraph::opset8::Add>(output, add_bias);
new_outputs.push_back(add_operation);
}
}
outputs.swap(new_outputs);
}
Graph graph;
graph.inputs.swap(inputs);
graph.outputs.insert(graph.outputs.end(),
std::make_move_iterator(outputs.begin()),
std::make_move_iterator(outputs.end()));
return graph;
}
// -------------------------------------------------------------------------------------------------------
class RemoveSingleInputConcatFixture: public CommonTestUtils::TestsCommon,
public ::testing::WithParamInterface<std::tuple<Graph /* tranformed */,
Graph /* reference */>> {
public:
void SetUp() override;
public:
std::shared_ptr<ngraph::Function> function, reference_function;
};
void RemoveSingleInputConcatFixture::SetUp() {
// TODO: use auto & [transformed_graph, reference_graph] = this->GetParam() when C++17
Graph transformed_graph;
Graph reference_graph;
std::tie(transformed_graph, reference_graph) = this->GetParam();
function = transformed_graph.createFunction();
reference_function = reference_graph.createFunction();
}
ngraph::pass::Manager createPassManager() {
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<GNAPluginNS::RemoveSingleInputConcat>();
return manager;
}
void execute_test(std::shared_ptr<ngraph::Function> function,
std::shared_ptr<ngraph::Function> reference_function) {
ngraph::pass::Manager pass_manager = createPassManager();
pass_manager.run_passes(function);
const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
const FunctionsComparator::Result result = func_comparator(function, reference_function);
ASSERT_TRUE(result.valid);
}
TEST_P(RemoveSingleInputConcatFixture, CompareFunctions) {
execute_test(function, reference_function);
}
INSTANTIATE_TEST_SUITE_P(RemoveSingleInputConcatTestSuite, RemoveSingleInputConcatFixture,
::testing::Values(std::make_tuple(createGraph(1 /* n_inputs */, true /* has_concat */, 1 /* n_outputs */),
createGraph(1 /* n_inputs */, false /* has_concat */, 1 /* n_outputs */)),
std::make_tuple(createGraph(1 /* n_inputs */, true /* has_concat */, 2 /* n_outputs */),
createGraph(1 /* n_inputs */, false /* has_concat */, 2 /* n_outputs */)),
std::make_tuple(createGraph(2 /* n_inputs */, true /* has_concat */, 1 /* n_outputs */),
createGraph(2 /* n_inputs */, true /* has_concat */, 1 /* n_outputs */))));
} // namespace
} // namespace testing