[GNA] Rewrite RemoveSingleInputConcatPass using ngraph (#7208)
* initial matcher pass * write test implementation; + add unit tests * base * add unit tests * code review fixes * code review fixes * fix * fix * move RemoveSingleInputConcat before opset to legacy conversion
This commit is contained in:
parent
9e68a673e4
commit
8985feff6f
@ -67,6 +67,7 @@
|
||||
#include "transformations/decompose_2d_conv.hpp"
|
||||
#include "transformations/convert_padded2valid_conv.hpp"
|
||||
#include "transformations/op_conversions/lstm_cell_decomposition.hpp"
|
||||
#include "transformations/remove_single_input_concat.hpp"
|
||||
|
||||
#include <ngraph/opsets/opset7.hpp>
|
||||
|
||||
@ -738,6 +739,7 @@ void GNAPlugin::LoadNetwork(CNNNetwork & _network) {
|
||||
manager.register_pass<SwapInputMatMul>();
|
||||
manager.register_pass<InsertTransposeAfterConvOrPool>();
|
||||
manager.register_pass<ReorderActivationAndPooling>();
|
||||
manager.register_pass<RemoveSingleInputConcat>();
|
||||
manager.register_pass<ngraph::pass::ConvertOpSet3ToOpSet2>();
|
||||
manager.register_pass<ngraph::pass::ConvertOpSet2ToOpSet1>();
|
||||
manager.register_pass<ngraph::pass::ConvertOpSet1ToLegacy>();
|
||||
@ -793,9 +795,8 @@ void GNAPlugin::LoadNetwork(CNNNetwork & _network) {
|
||||
passes->registerPass<UnrollTIPass>();
|
||||
passes->registerPass<RemoveConstPass>();
|
||||
passes->registerPass<UnrollLSTMCellPass>();
|
||||
}
|
||||
|
||||
passes->registerPass<RemoveSingleInputConcatPass>();
|
||||
}
|
||||
|
||||
// fake quantisation aware passes
|
||||
passes->registerPass<FuseFQIntoWeightsPass>();
|
||||
|
@ -0,0 +1,47 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <openvino/cc/ngraph/itt.hpp>
|
||||
|
||||
#include "transformations/remove_single_input_concat.hpp"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include <ngraph/opsets/opset8.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
|
||||
using NodeInput = ngraph::Input<ngraph::Node>;
|
||||
using NodeOutput = ngraph::Output<ngraph::Node>;
|
||||
|
||||
namespace GNAPluginNS {
|
||||
NGRAPH_RTTI_DEFINITION(RemoveSingleInputConcat, "RemoveSingleInputConcat", 0);
|
||||
|
||||
RemoveSingleInputConcat::RemoveSingleInputConcat() {
|
||||
MATCHER_SCOPE(RemoveSingleInputConcat);
|
||||
|
||||
auto is_required_node = [](const ngraph::Output<ngraph::Node>& value) {
|
||||
return value.get_node_shared_ptr()->get_input_size() == 1;
|
||||
};
|
||||
|
||||
auto concat_operation = ngraph::pattern::wrap_type<ngraph::opset8::Concat>(is_required_node);
|
||||
|
||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
||||
const auto& pattern_map = m.get_pattern_value_map();
|
||||
auto concat_operation_node = pattern_map.find(concat_operation)->second.get_node_shared_ptr();
|
||||
|
||||
NodeOutput prev_node_output = concat_operation_node->get_input_source_output(0);
|
||||
|
||||
for (NodeInput child_input : concat_operation_node->get_output_target_inputs(0))
|
||||
child_input.replace_source_output(prev_node_output);
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(concat_operation, matcher_name);
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
|
||||
} // namespace GNAPluginNS
|
@ -0,0 +1,20 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ngraph/pass/graph_rewrite.hpp>
|
||||
|
||||
namespace GNAPluginNS {
|
||||
|
||||
/**
|
||||
* @brief remove concat layers with single input
|
||||
*/
|
||||
class RemoveSingleInputConcat : public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
RemoveSingleInputConcat();
|
||||
};
|
||||
|
||||
} // namespace GNAPluginNS
|
@ -0,0 +1,143 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "transformations/remove_single_input_concat.hpp"
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset8.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
|
||||
namespace testing {
|
||||
namespace {
|
||||
|
||||
using GraphInputs = std::vector<std::shared_ptr<ngraph::opset8::Parameter>>;
|
||||
using GraphOutputs = ngraph::OutputVector;
|
||||
|
||||
struct Graph {
|
||||
std::shared_ptr<ngraph::Function> createFunction();
|
||||
|
||||
GraphInputs inputs;
|
||||
GraphOutputs outputs;
|
||||
};
|
||||
|
||||
std::shared_ptr<ngraph::Function> Graph::createFunction() {
|
||||
ngraph::ResultVector results;
|
||||
std::transform(outputs.begin(), outputs.end(), std::back_inserter(results),
|
||||
[] (ngraph::Output<ngraph::Node> output) {
|
||||
return std::make_shared<ngraph::opset8::Result>(output);
|
||||
});
|
||||
|
||||
ngraph::ParameterVector params(inputs.begin(), inputs.end());
|
||||
|
||||
return std::make_shared<ngraph::Function>(results, params);
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------------------------------------
|
||||
|
||||
using Operations = std::vector<std::shared_ptr<ngraph::op::Op>>;
|
||||
|
||||
Graph createGraph(int n_inputs, bool has_concat, int n_outputs) {
|
||||
GraphInputs inputs;
|
||||
Operations outputs;
|
||||
|
||||
for (int i = 0; i < n_inputs; ++i) {
|
||||
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64,
|
||||
ngraph::Shape{1, 3, 64});
|
||||
inputs.push_back(input);
|
||||
outputs.push_back(input);
|
||||
}
|
||||
|
||||
{
|
||||
Operations new_outputs;
|
||||
for (auto output : outputs) {
|
||||
auto add_bias = ngraph::opset8::Constant::create(ngraph::element::i64, {1, 1, 1}, {2});
|
||||
auto add_operation = std::make_shared<ngraph::opset8::Add>(output, add_bias);
|
||||
new_outputs.push_back(add_operation);
|
||||
}
|
||||
outputs.swap(new_outputs);
|
||||
}
|
||||
|
||||
if (has_concat) {
|
||||
auto concat_operation = std::make_shared<ngraph::opset8::Concat>(ngraph::OutputVector(outputs.begin(),
|
||||
outputs.end()),
|
||||
0);
|
||||
outputs = {concat_operation};
|
||||
}
|
||||
|
||||
{
|
||||
Operations new_outputs;
|
||||
for (auto output : outputs) {
|
||||
for (int i = 0; i < n_outputs; ++i) {
|
||||
auto add_bias = ngraph::opset8::Constant::create(ngraph::element::i64, {1, 1, 1}, {3});
|
||||
auto add_operation = std::make_shared<ngraph::opset8::Add>(output, add_bias);
|
||||
new_outputs.push_back(add_operation);
|
||||
}
|
||||
}
|
||||
outputs.swap(new_outputs);
|
||||
}
|
||||
|
||||
Graph graph;
|
||||
graph.inputs.swap(inputs);
|
||||
graph.outputs.insert(graph.outputs.end(),
|
||||
std::make_move_iterator(outputs.begin()),
|
||||
std::make_move_iterator(outputs.end()));
|
||||
|
||||
return graph;
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------------------------------------
|
||||
|
||||
class RemoveSingleInputConcatFixture: public CommonTestUtils::TestsCommon,
|
||||
public ::testing::WithParamInterface<std::tuple<Graph /* tranformed */,
|
||||
Graph /* reference */>> {
|
||||
public:
|
||||
void SetUp() override;
|
||||
public:
|
||||
std::shared_ptr<ngraph::Function> function, reference_function;
|
||||
};
|
||||
|
||||
void RemoveSingleInputConcatFixture::SetUp() {
|
||||
// TODO: use auto & [transformed_graph, reference_graph] = this->GetParam() when C++17
|
||||
Graph transformed_graph;
|
||||
Graph reference_graph;
|
||||
std::tie(transformed_graph, reference_graph) = this->GetParam();
|
||||
|
||||
function = transformed_graph.createFunction();
|
||||
reference_function = reference_graph.createFunction();
|
||||
}
|
||||
|
||||
ngraph::pass::Manager createPassManager() {
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<GNAPluginNS::RemoveSingleInputConcat>();
|
||||
return manager;
|
||||
}
|
||||
|
||||
void execute_test(std::shared_ptr<ngraph::Function> function,
|
||||
std::shared_ptr<ngraph::Function> reference_function) {
|
||||
ngraph::pass::Manager pass_manager = createPassManager();
|
||||
pass_manager.run_passes(function);
|
||||
const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
|
||||
const FunctionsComparator::Result result = func_comparator(function, reference_function);
|
||||
ASSERT_TRUE(result.valid);
|
||||
}
|
||||
|
||||
TEST_P(RemoveSingleInputConcatFixture, CompareFunctions) {
|
||||
execute_test(function, reference_function);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(RemoveSingleInputConcatTestSuite, RemoveSingleInputConcatFixture,
|
||||
::testing::Values(std::make_tuple(createGraph(1 /* n_inputs */, true /* has_concat */, 1 /* n_outputs */),
|
||||
createGraph(1 /* n_inputs */, false /* has_concat */, 1 /* n_outputs */)),
|
||||
std::make_tuple(createGraph(1 /* n_inputs */, true /* has_concat */, 2 /* n_outputs */),
|
||||
createGraph(1 /* n_inputs */, false /* has_concat */, 2 /* n_outputs */)),
|
||||
std::make_tuple(createGraph(2 /* n_inputs */, true /* has_concat */, 1 /* n_outputs */),
|
||||
createGraph(2 /* n_inputs */, true /* has_concat */, 1 /* n_outputs */))));
|
||||
|
||||
} // namespace
|
||||
} // namespace testing
|
Loading…
Reference in New Issue
Block a user