[GNA] Rewrite RemoveSingleInputConcatPass using ngraph (#7208)
* initial matcher pass * write test implementation; + add unit tests * base * add unit tests * code review fixes * code review fixes * fix * fix * move RemoveSingleInputConcat before opset to legacy conversion
This commit is contained in:
parent
9e68a673e4
commit
8985feff6f
@ -67,6 +67,7 @@
|
|||||||
#include "transformations/decompose_2d_conv.hpp"
|
#include "transformations/decompose_2d_conv.hpp"
|
||||||
#include "transformations/convert_padded2valid_conv.hpp"
|
#include "transformations/convert_padded2valid_conv.hpp"
|
||||||
#include "transformations/op_conversions/lstm_cell_decomposition.hpp"
|
#include "transformations/op_conversions/lstm_cell_decomposition.hpp"
|
||||||
|
#include "transformations/remove_single_input_concat.hpp"
|
||||||
|
|
||||||
#include <ngraph/opsets/opset7.hpp>
|
#include <ngraph/opsets/opset7.hpp>
|
||||||
|
|
||||||
@ -738,6 +739,7 @@ void GNAPlugin::LoadNetwork(CNNNetwork & _network) {
|
|||||||
manager.register_pass<SwapInputMatMul>();
|
manager.register_pass<SwapInputMatMul>();
|
||||||
manager.register_pass<InsertTransposeAfterConvOrPool>();
|
manager.register_pass<InsertTransposeAfterConvOrPool>();
|
||||||
manager.register_pass<ReorderActivationAndPooling>();
|
manager.register_pass<ReorderActivationAndPooling>();
|
||||||
|
manager.register_pass<RemoveSingleInputConcat>();
|
||||||
manager.register_pass<ngraph::pass::ConvertOpSet3ToOpSet2>();
|
manager.register_pass<ngraph::pass::ConvertOpSet3ToOpSet2>();
|
||||||
manager.register_pass<ngraph::pass::ConvertOpSet2ToOpSet1>();
|
manager.register_pass<ngraph::pass::ConvertOpSet2ToOpSet1>();
|
||||||
manager.register_pass<ngraph::pass::ConvertOpSet1ToLegacy>();
|
manager.register_pass<ngraph::pass::ConvertOpSet1ToLegacy>();
|
||||||
@ -793,10 +795,9 @@ void GNAPlugin::LoadNetwork(CNNNetwork & _network) {
|
|||||||
passes->registerPass<UnrollTIPass>();
|
passes->registerPass<UnrollTIPass>();
|
||||||
passes->registerPass<RemoveConstPass>();
|
passes->registerPass<RemoveConstPass>();
|
||||||
passes->registerPass<UnrollLSTMCellPass>();
|
passes->registerPass<UnrollLSTMCellPass>();
|
||||||
|
passes->registerPass<RemoveSingleInputConcatPass>();
|
||||||
}
|
}
|
||||||
|
|
||||||
passes->registerPass<RemoveSingleInputConcatPass>();
|
|
||||||
|
|
||||||
// fake quantisation aware passes
|
// fake quantisation aware passes
|
||||||
passes->registerPass<FuseFQIntoWeightsPass>();
|
passes->registerPass<FuseFQIntoWeightsPass>();
|
||||||
passes->registerPass<MoveFakeQuantizeLayerIntoQuantParamsPass>();
|
passes->registerPass<MoveFakeQuantizeLayerIntoQuantParamsPass>();
|
||||||
|
@ -0,0 +1,47 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <openvino/cc/ngraph/itt.hpp>
|
||||||
|
|
||||||
|
#include "transformations/remove_single_input_concat.hpp"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include <ngraph/opsets/opset8.hpp>
|
||||||
|
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||||
|
#include <ngraph/pass/manager.hpp>
|
||||||
|
|
||||||
|
using NodeInput = ngraph::Input<ngraph::Node>;
|
||||||
|
using NodeOutput = ngraph::Output<ngraph::Node>;
|
||||||
|
|
||||||
|
namespace GNAPluginNS {
|
||||||
|
NGRAPH_RTTI_DEFINITION(RemoveSingleInputConcat, "RemoveSingleInputConcat", 0);
|
||||||
|
|
||||||
|
RemoveSingleInputConcat::RemoveSingleInputConcat() {
|
||||||
|
MATCHER_SCOPE(RemoveSingleInputConcat);
|
||||||
|
|
||||||
|
auto is_required_node = [](const ngraph::Output<ngraph::Node>& value) {
|
||||||
|
return value.get_node_shared_ptr()->get_input_size() == 1;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto concat_operation = ngraph::pattern::wrap_type<ngraph::opset8::Concat>(is_required_node);
|
||||||
|
|
||||||
|
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
||||||
|
const auto& pattern_map = m.get_pattern_value_map();
|
||||||
|
auto concat_operation_node = pattern_map.find(concat_operation)->second.get_node_shared_ptr();
|
||||||
|
|
||||||
|
NodeOutput prev_node_output = concat_operation_node->get_input_source_output(0);
|
||||||
|
|
||||||
|
for (NodeInput child_input : concat_operation_node->get_output_target_inputs(0))
|
||||||
|
child_input.replace_source_output(prev_node_output);
|
||||||
|
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto m = std::make_shared<ngraph::pattern::Matcher>(concat_operation, matcher_name);
|
||||||
|
this->register_matcher(m, callback);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace GNAPluginNS
|
@ -0,0 +1,20 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <ngraph/pass/graph_rewrite.hpp>
|
||||||
|
|
||||||
|
namespace GNAPluginNS {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief remove concat layers with single input
|
||||||
|
*/
|
||||||
|
class RemoveSingleInputConcat : public ngraph::pass::MatcherPass {
|
||||||
|
public:
|
||||||
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
RemoveSingleInputConcat();
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace GNAPluginNS
|
@ -0,0 +1,143 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include "transformations/remove_single_input_concat.hpp"
|
||||||
|
|
||||||
|
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||||
|
#include <ngraph/function.hpp>
|
||||||
|
#include <ngraph/opsets/opset8.hpp>
|
||||||
|
#include <ngraph/pass/manager.hpp>
|
||||||
|
#include <transformations/init_node_info.hpp>
|
||||||
|
|
||||||
|
namespace testing {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using GraphInputs = std::vector<std::shared_ptr<ngraph::opset8::Parameter>>;
|
||||||
|
using GraphOutputs = ngraph::OutputVector;
|
||||||
|
|
||||||
|
struct Graph {
|
||||||
|
std::shared_ptr<ngraph::Function> createFunction();
|
||||||
|
|
||||||
|
GraphInputs inputs;
|
||||||
|
GraphOutputs outputs;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::shared_ptr<ngraph::Function> Graph::createFunction() {
|
||||||
|
ngraph::ResultVector results;
|
||||||
|
std::transform(outputs.begin(), outputs.end(), std::back_inserter(results),
|
||||||
|
[] (ngraph::Output<ngraph::Node> output) {
|
||||||
|
return std::make_shared<ngraph::opset8::Result>(output);
|
||||||
|
});
|
||||||
|
|
||||||
|
ngraph::ParameterVector params(inputs.begin(), inputs.end());
|
||||||
|
|
||||||
|
return std::make_shared<ngraph::Function>(results, params);
|
||||||
|
}
|
||||||
|
|
||||||
|
// -------------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
using Operations = std::vector<std::shared_ptr<ngraph::op::Op>>;
|
||||||
|
|
||||||
|
Graph createGraph(int n_inputs, bool has_concat, int n_outputs) {
|
||||||
|
GraphInputs inputs;
|
||||||
|
Operations outputs;
|
||||||
|
|
||||||
|
for (int i = 0; i < n_inputs; ++i) {
|
||||||
|
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64,
|
||||||
|
ngraph::Shape{1, 3, 64});
|
||||||
|
inputs.push_back(input);
|
||||||
|
outputs.push_back(input);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
Operations new_outputs;
|
||||||
|
for (auto output : outputs) {
|
||||||
|
auto add_bias = ngraph::opset8::Constant::create(ngraph::element::i64, {1, 1, 1}, {2});
|
||||||
|
auto add_operation = std::make_shared<ngraph::opset8::Add>(output, add_bias);
|
||||||
|
new_outputs.push_back(add_operation);
|
||||||
|
}
|
||||||
|
outputs.swap(new_outputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (has_concat) {
|
||||||
|
auto concat_operation = std::make_shared<ngraph::opset8::Concat>(ngraph::OutputVector(outputs.begin(),
|
||||||
|
outputs.end()),
|
||||||
|
0);
|
||||||
|
outputs = {concat_operation};
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
Operations new_outputs;
|
||||||
|
for (auto output : outputs) {
|
||||||
|
for (int i = 0; i < n_outputs; ++i) {
|
||||||
|
auto add_bias = ngraph::opset8::Constant::create(ngraph::element::i64, {1, 1, 1}, {3});
|
||||||
|
auto add_operation = std::make_shared<ngraph::opset8::Add>(output, add_bias);
|
||||||
|
new_outputs.push_back(add_operation);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
outputs.swap(new_outputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
Graph graph;
|
||||||
|
graph.inputs.swap(inputs);
|
||||||
|
graph.outputs.insert(graph.outputs.end(),
|
||||||
|
std::make_move_iterator(outputs.begin()),
|
||||||
|
std::make_move_iterator(outputs.end()));
|
||||||
|
|
||||||
|
return graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
// -------------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class RemoveSingleInputConcatFixture: public CommonTestUtils::TestsCommon,
|
||||||
|
public ::testing::WithParamInterface<std::tuple<Graph /* tranformed */,
|
||||||
|
Graph /* reference */>> {
|
||||||
|
public:
|
||||||
|
void SetUp() override;
|
||||||
|
public:
|
||||||
|
std::shared_ptr<ngraph::Function> function, reference_function;
|
||||||
|
};
|
||||||
|
|
||||||
|
void RemoveSingleInputConcatFixture::SetUp() {
|
||||||
|
// TODO: use auto & [transformed_graph, reference_graph] = this->GetParam() when C++17
|
||||||
|
Graph transformed_graph;
|
||||||
|
Graph reference_graph;
|
||||||
|
std::tie(transformed_graph, reference_graph) = this->GetParam();
|
||||||
|
|
||||||
|
function = transformed_graph.createFunction();
|
||||||
|
reference_function = reference_graph.createFunction();
|
||||||
|
}
|
||||||
|
|
||||||
|
ngraph::pass::Manager createPassManager() {
|
||||||
|
ngraph::pass::Manager manager;
|
||||||
|
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
|
manager.register_pass<GNAPluginNS::RemoveSingleInputConcat>();
|
||||||
|
return manager;
|
||||||
|
}
|
||||||
|
|
||||||
|
void execute_test(std::shared_ptr<ngraph::Function> function,
|
||||||
|
std::shared_ptr<ngraph::Function> reference_function) {
|
||||||
|
ngraph::pass::Manager pass_manager = createPassManager();
|
||||||
|
pass_manager.run_passes(function);
|
||||||
|
const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
|
||||||
|
const FunctionsComparator::Result result = func_comparator(function, reference_function);
|
||||||
|
ASSERT_TRUE(result.valid);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(RemoveSingleInputConcatFixture, CompareFunctions) {
|
||||||
|
execute_test(function, reference_function);
|
||||||
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(RemoveSingleInputConcatTestSuite, RemoveSingleInputConcatFixture,
|
||||||
|
::testing::Values(std::make_tuple(createGraph(1 /* n_inputs */, true /* has_concat */, 1 /* n_outputs */),
|
||||||
|
createGraph(1 /* n_inputs */, false /* has_concat */, 1 /* n_outputs */)),
|
||||||
|
std::make_tuple(createGraph(1 /* n_inputs */, true /* has_concat */, 2 /* n_outputs */),
|
||||||
|
createGraph(1 /* n_inputs */, false /* has_concat */, 2 /* n_outputs */)),
|
||||||
|
std::make_tuple(createGraph(2 /* n_inputs */, true /* has_concat */, 1 /* n_outputs */),
|
||||||
|
createGraph(2 /* n_inputs */, true /* has_concat */, 1 /* n_outputs */))));
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace testing
|
Loading…
Reference in New Issue
Block a user