[GNA] Support bias and FQ in SwapInputMatMul transformation (#6996) (#7027)

This commit is contained in:
Elizaveta Lobanova
2021-08-13 12:07:08 +03:00
committed by GitHub
parent 3117879c54
commit 114ed1cb4b
4 changed files with 335 additions and 72 deletions

View File

@@ -685,6 +685,8 @@ void GNAPlugin::LoadNetwork(CNNNetwork & _network) {
manager.register_pass<SplitConvolution>();
manager.register_pass<InsertTransposeBeforeMatmul>();
manager.register_pass<SwapInputMatMul>();
manager.register_pass<SwapInputMatMulWithBias>();
manager.register_pass<SwapInputMatMulWithFq>();
manager.register_pass<InsertTransposeAfterConvOrPool>();
manager.register_pass<ReorderActivationAndPooling>();
manager.register_pass<ngraph::pass::ConvertOpSet3ToOpSet2>();

View File

@@ -6,6 +6,7 @@
#include <vector>
#include <ngraph/pass/manager.hpp>
#include <ngraph/pattern/op/or.hpp>
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
@@ -17,84 +18,160 @@
using namespace GNAPluginNS;
NGRAPH_RTTI_DEFINITION(SwapInputMatMul, "SwapInputMatMul", 0);
NGRAPH_RTTI_DEFINITION(SwapInputMatMulWithBias, "SwapInputMatMulWithBias", 0);
NGRAPH_RTTI_DEFINITION(SwapInputMatMulWithFq, "SwapInputMatMulWithFq", 0);
// Rewrites A*B as Transpose(B'*A') (with inverted transpose flags), re-attaching
// an optional bias Add and an optional output FakeQuantize on top of the swapped
// MatMul, and replaces the old subgraph root with the final Transpose.
static void SwapAndTransposeInputs(std::shared_ptr<ngraph::opset7::MatMul> matmul_node,
                                   std::shared_ptr<ngraph::Node> add,
                                   std::shared_ptr<ngraph::Node> bias,
                                   std::shared_ptr<ngraph::Node> fq) {
    // Builds a Transpose that swaps the two innermost dimensions of `out`.
    auto make_transpose = [](ngraph::Output<ngraph::Node> out,
                             const std::string& name) -> std::shared_ptr<ngraph::Node> {
        ngraph::Shape out_shape = out.get_node_shared_ptr()->get_shape();
        std::vector<size_t> order(out_shape.size());
        std::iota(order.begin(), order.end(), 0);
        std::swap(order[order.size() - 1], order[order.size() - 2]);
        auto order_const = ngraph::opset7::Constant::create(
            ngraph::element::i64, ngraph::Shape {order.size()}, order);
        auto transpose = std::make_shared<ngraph::opset7::Transpose>(out, order_const);
        transpose->set_friendly_name(name);
        return transpose;
    };

    gnalog() << "Swap and transpose inputs for " << matmul_node->get_friendly_name() << "\n";

    ngraph::NodeVector created_nodes;
    // A*B == (B^T * A^T)^T: swap the operands and invert both transpose flags.
    std::shared_ptr<ngraph::Node> last_node = std::make_shared<ngraph::opset7::MatMul>(
        matmul_node->input(1).get_source_output(), matmul_node->input(0).get_source_output(),
        !matmul_node->get_transpose_b(), !matmul_node->get_transpose_a());
    last_node->set_friendly_name(matmul_node->get_friendly_name() + "/swap_inputs");
    created_nodes.push_back(last_node);

    std::shared_ptr<ngraph::Node> root_to_replace = matmul_node;
    if (bias != nullptr) {
        // The swapped MatMul output is transposed w.r.t. the original one,
        // so a multi-dimensional bias must be transposed as well.
        if (bias->get_output_shape(0).size() > 1) {
            bias = make_transpose(bias, bias->get_friendly_name() + "/transpose");
            created_nodes.push_back(bias);
        }
        last_node = std::make_shared<ngraph::opset7::Add>(last_node, bias);
        root_to_replace = add;
        created_nodes.push_back(last_node);
    }
    if (fq != nullptr) {
        // Clone the output FakeQuantize on top of the rebuilt chain, keeping its limits.
        last_node = fq->clone_with_new_inputs({last_node, fq->input_value(1), fq->input_value(2),
                                               fq->input_value(3), fq->input_value(4)});
        root_to_replace = fq;
        created_nodes.push_back(last_node);
    }

    // The final Transpose restores the original output layout and inherits the
    // original MatMul's friendly name so downstream references stay valid.
    auto final_transpose = make_transpose(last_node, matmul_node->get_friendly_name());
    created_nodes.push_back(final_transpose);
    ngraph::copy_runtime_info(matmul_node, created_nodes);
    ngraph::replace_node(root_to_replace, final_transpose);
}
// Matches MatMul(Constant[/FakeQuantize], any) and swaps its inputs for GNA.
// NOTE(review): as rendered, this block interleaved the removed pre-change
// implementation (a callback referencing `matmul` before its declaration,
// an inline FQ-skipping path, unbalanced braces) with the new pattern-based
// one; reconstructed here to match the sibling passes below.
SwapInputMatMul::SwapInputMatMul() {
    // First MatMul input: a 2D constant with GNA-friendly dimensions
    // (first dim >= 8, both dims divisible by 8).
    auto constant = ngraph::pattern::wrap_type<ngraph::opset7::Constant>({}, [](const ngraph::Output<ngraph::Node>& node) {
        auto shape = node.get_node_shared_ptr()->get_output_shape(0);
        if (shape.size() != 2 || shape[0] < 8 || ((shape[0] % 8 != 0 || shape[1] % 8 != 0))) {
            return false;
        }
        return true;
    });
    // The constant may be wrapped into FakeQuantize(weights, 4 x Constant limits).
    auto fake_quantize = ngraph::pattern::wrap_type<ngraph::opset7::FakeQuantize>({constant,
        ngraph::pattern::wrap_type<ngraph::opset7::Constant>(),
        ngraph::pattern::wrap_type<ngraph::opset7::Constant>(),
        ngraph::pattern::wrap_type<ngraph::opset7::Constant>(),
        ngraph::pattern::wrap_type<ngraph::opset7::Constant>()});
    auto matmul_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{constant, fake_quantize});
    auto matmul = ngraph::pattern::wrap_type<ngraph::opset7::MatMul>({matmul_input, ngraph::pattern::any_input()},
        ngraph::pattern::has_static_shape());
    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
        const auto& pattern_map = m.get_pattern_value_map();
        auto matmul_node = std::dynamic_pointer_cast<ngraph::opset7::MatMul>(pattern_map.at(matmul).get_node_shared_ptr());
        IE_ASSERT(matmul_node != nullptr);
        // Plain MatMul: no bias Add and no output FakeQuantize to re-attach.
        SwapAndTransposeInputs(matmul_node, nullptr, nullptr, nullptr);
        return true;
    };
    auto m = std::make_shared<ngraph::pattern::Matcher>(matmul, "SwapInputMatMul");
    this->register_matcher(m, callback);
}
// Matches Add(MatMul(Constant[/FakeQuantize], any), Constant bias) and swaps
// the MatMul inputs, re-attaching the bias on the rebuilt subgraph.
SwapInputMatMulWithBias::SwapInputMatMulWithBias() {
    // Weights: a 2D constant with GNA-friendly dimensions
    // (first dim >= 8, both dims divisible by 8).
    auto constant = ngraph::pattern::wrap_type<ngraph::opset7::Constant>({},
        [](const ngraph::Output<ngraph::Node>& output) {
            const auto shape = output.get_node_shared_ptr()->get_output_shape(0);
            return shape.size() == 2 && shape[0] >= 8 &&
                   shape[0] % 8 == 0 && shape[1] % 8 == 0;
        });
    // Optionally quantized weights: FakeQuantize(weights, 4 x Constant limits).
    auto fake_quantize = ngraph::pattern::wrap_type<ngraph::opset7::FakeQuantize>({constant,
        ngraph::pattern::wrap_type<ngraph::opset7::Constant>(),
        ngraph::pattern::wrap_type<ngraph::opset7::Constant>(),
        ngraph::pattern::wrap_type<ngraph::opset7::Constant>(),
        ngraph::pattern::wrap_type<ngraph::opset7::Constant>()});
    auto matmul_input = std::make_shared<ngraph::pattern::op::Or>(
        ngraph::OutputVector{constant, fake_quantize});
    auto matmul = ngraph::pattern::wrap_type<ngraph::opset7::MatMul>(
        {matmul_input, ngraph::pattern::any_input()},
        ngraph::pattern::has_static_shape());
    auto bias = ngraph::pattern::wrap_type<ngraph::opset7::Constant>();
    auto add = ngraph::pattern::wrap_type<ngraph::opset7::Add>({matmul, bias});

    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
        const auto& pattern_map = m.get_pattern_value_map();
        auto matmul_node = std::dynamic_pointer_cast<ngraph::opset7::MatMul>(
            pattern_map.at(matmul).get_node_shared_ptr());
        IE_ASSERT(matmul_node != nullptr);
        // Bias present, no output FakeQuantize in this pattern.
        SwapAndTransposeInputs(matmul_node,
                               pattern_map.at(add).get_node_shared_ptr(),
                               pattern_map.at(bias).get_node_shared_ptr(),
                               nullptr);
        return true;
    };

    this->register_matcher(
        std::make_shared<ngraph::pattern::Matcher>(add, "SwapInputMatMulWithBias"), callback);
}
// Matches FakeQuantize([Add(MatMul, bias) | MatMul(Constant[/FQ], any)], 4 x Constant)
// and swaps the MatMul inputs, re-attaching the optional bias and the output FQ.
SwapInputMatMulWithFq::SwapInputMatMulWithFq() {
    // Weights: a 2D constant with GNA-friendly dimensions
    // (first dim >= 8, both dims divisible by 8).
    auto constant = ngraph::pattern::wrap_type<ngraph::opset7::Constant>({},
        [](const ngraph::Output<ngraph::Node>& output) {
            const auto shape = output.get_node_shared_ptr()->get_output_shape(0);
            return shape.size() == 2 && shape[0] >= 8 &&
                   shape[0] % 8 == 0 && shape[1] % 8 == 0;
        });
    // Optionally quantized weights: FakeQuantize(weights, 4 x Constant limits).
    auto fake_quantize = ngraph::pattern::wrap_type<ngraph::opset7::FakeQuantize>({constant,
        ngraph::pattern::wrap_type<ngraph::opset7::Constant>(),
        ngraph::pattern::wrap_type<ngraph::opset7::Constant>(),
        ngraph::pattern::wrap_type<ngraph::opset7::Constant>(),
        ngraph::pattern::wrap_type<ngraph::opset7::Constant>()});
    auto matmul_input = std::make_shared<ngraph::pattern::op::Or>(
        ngraph::OutputVector{constant, fake_quantize});
    auto matmul = ngraph::pattern::wrap_type<ngraph::opset7::MatMul>(
        {matmul_input, ngraph::pattern::any_input()},
        ngraph::pattern::has_static_shape());
    auto bias = ngraph::pattern::wrap_type<ngraph::opset7::Constant>();
    auto add = ngraph::pattern::wrap_type<ngraph::opset7::Add>({matmul, bias});
    // The output FakeQuantize may sit directly on the MatMul or on the bias Add.
    auto matmul_out = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{add, matmul});
    auto out_fq = ngraph::pattern::wrap_type<ngraph::opset7::FakeQuantize>({matmul_out,
        ngraph::pattern::wrap_type<ngraph::opset7::Constant>(),
        ngraph::pattern::wrap_type<ngraph::opset7::Constant>(),
        ngraph::pattern::wrap_type<ngraph::opset7::Constant>(),
        ngraph::pattern::wrap_type<ngraph::opset7::Constant>()});

    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
        const auto& pattern_map = m.get_pattern_value_map();
        // Returns the matched node for `label`, or nullptr when the optional
        // branch (Add/bias) did not participate in the match.
        auto find_match = [&pattern_map](const std::shared_ptr<ngraph::Node>& label)
                -> std::shared_ptr<ngraph::Node> {
            const auto it = pattern_map.find(label);
            return it == pattern_map.end() ? nullptr : it->second.get_node_shared_ptr();
        };
        auto matmul_node = std::dynamic_pointer_cast<ngraph::opset7::MatMul>(
            pattern_map.at(matmul).get_node_shared_ptr());
        IE_ASSERT(matmul_node != nullptr);
        SwapAndTransposeInputs(matmul_node,
                               find_match(add),
                               find_match(bias),
                               pattern_map.at(out_fq).get_node_shared_ptr());
        return true;
    };

    this->register_matcher(
        std::make_shared<ngraph::pattern::Matcher>(out_fq, "SwapInputMatMulWithFq"), callback);
}

View File

@@ -1,4 +1,4 @@
// Copyright (C) 2020 Intel Corporation
// Copyright (C) 2020-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
@@ -16,4 +16,16 @@ public:
NGRAPH_RTTI_DECLARATION;
SwapInputMatMul();
};
// Swaps the inputs of a MatMul followed by an Add with a constant bias,
// re-attaching the bias on the swapped subgraph.
class SwapInputMatMulWithBias: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
SwapInputMatMulWithBias();
};
// Swaps the inputs of a MatMul (optionally followed by a bias Add) whose
// output goes through a FakeQuantize, re-attaching bias and FQ on top.
class SwapInputMatMulWithFq: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
SwapInputMatMulWithFq();
};
} // namespace GNAPluginNS

View File

@@ -0,0 +1,172 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include "transformations/swap_input_matmul_gna.hpp"
#include "common_test_utils/ngraph_test_utils.hpp"
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/init_node_info.hpp>
namespace testing {
// Builds a test graph: MatMul(constant, parameter) with optional weights FQ,
// bias Add and output FQ. With swappedInputs=true it builds the expected
// post-transformation form: swapped MatMul inputs, transposed bias (when 2D),
// and a final Transpose restoring the output layout.
static std::shared_ptr<ngraph::Function> CreateMatMulFunction(const ngraph::Shape& input1_shape,
                                                              const ngraph::Shape& input2_shape,
                                                              const ngraph::Shape& bias_shape,
                                                              bool withBias,
                                                              bool withWeightsFq,
                                                              bool withOutFq,
                                                              bool swappedInputs) {
    // Wraps a node into FakeQuantize with fixed limits [1, 20] -> [0, 10], 11 levels.
    auto wrap_with_fq = [](const std::shared_ptr<ngraph::Node>& node) -> std::shared_ptr<ngraph::Node> {
        auto in_low = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1});
        auto in_high = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {20});
        auto out_low = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {0});
        auto out_high = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {10});
        return std::make_shared<ngraph::opset7::FakeQuantize>(node, in_low, in_high,
                                                              out_low, out_high, 11);
    };
    // Creates the [1, 0] permutation used to swap the two dims of a 2D tensor.
    auto make_swap_order = []() {
        return ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2},
                                                std::vector<size_t>{1, 0});
    };

    auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, input2_shape);
    std::shared_ptr<ngraph::Node> const_input =
        ngraph::opset7::Constant::create(ngraph::element::i64, input1_shape, {1});
    if (withWeightsFq) {
        const_input = wrap_with_fq(const_input);
    }

    // Reference (swapped) form inverts the operand order and sets both transpose flags.
    std::shared_ptr<ngraph::Node> matmul;
    if (swappedInputs) {
        matmul = std::make_shared<ngraph::opset7::MatMul>(input_params, const_input, true, true);
    } else {
        matmul = std::make_shared<ngraph::opset7::MatMul>(const_input, input_params);
    }

    std::shared_ptr<ngraph::Node> final_node = matmul;
    if (withBias) {
        std::shared_ptr<ngraph::Node> bias_node =
            ngraph::opset7::Constant::create(ngraph::element::i64, bias_shape, {1});
        // A multi-dimensional bias is transposed in the swapped (reference) graph.
        if (swappedInputs && bias_shape.size() > 1) {
            bias_node = std::make_shared<ngraph::opset7::Transpose>(bias_node, make_swap_order());
        }
        final_node = std::make_shared<ngraph::opset7::Add>(matmul, bias_node);
    }
    if (withOutFq) {
        final_node = wrap_with_fq(final_node);
    }
    // The swapped graph ends with a Transpose restoring the original layout.
    if (swappedInputs) {
        final_node = std::make_shared<ngraph::opset7::Transpose>(final_node, make_swap_order());
    }

    auto result = std::make_shared<ngraph::opset7::Result>(final_node);
    return std::make_shared<ngraph::Function>(ngraph::ResultVector{result},
                                              ngraph::ParameterVector{input_params});
}
// Runs the three SwapInputMatMul* passes on `function` and compares the result
// against `reference_function` (attributes included). Fails with the
// comparator's diagnostic message on mismatch.
static void Execute(std::shared_ptr<ngraph::Function> function, std::shared_ptr<ngraph::Function> reference_function) {
    ngraph::pass::Manager m;
    m.register_pass<ngraph::pass::InitNodeInfo>();
    // Most specific pattern first, mirroring the plugin's registration intent.
    m.register_pass<GNAPluginNS::SwapInputMatMulWithFq>();
    m.register_pass<GNAPluginNS::SwapInputMatMulWithBias>();
    m.register_pass<GNAPluginNS::SwapInputMatMul>();
    m.run_passes(function);
    ASSERT_NO_THROW(check_rt_info(function));
    const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
    const FunctionsComparator::Result result = func_comparator(function, reference_function);
    // Stream the comparator message so a mismatch reports *what* differed,
    // not just a bare boolean failure.
    ASSERT_TRUE(result.valid) << result.message;
}
// Parameter pack for the SwapInputMatmul* test fixtures below.
typedef std::tuple<
std::vector<ngraph::Shape>, // constant input shape, non-const input shape, bias shape
bool, // with bias
bool, // with weights FakeQuantize
bool // with output FakeQuantize
> SwapInputMatmulParams;
// Builds a readable test-case name from the parameter tuple, e.g.
// "IS1={16,8}_IS2={8,8}_BS={1}_bias=1_wFQ=0_oFQ=1".
static std::string getTestCaseName(testing::TestParamInfo<SwapInputMatmulParams> obj) {
    std::vector<ngraph::Shape> shapes;
    bool withBias, withWeightsFq, withOutFq;
    std::tie(shapes, withBias, withWeightsFq, withOutFq) = obj.param;
    std::ostringstream name;
    name << "IS1=" << shapes[0] << "_"
         << "IS2=" << shapes[1] << "_"
         << "BS=" << shapes[2] << "_"
         << "bias=" << withBias << "_"
         << "wFQ=" << withWeightsFq << "_"
         << "oFQ=" << withOutFq;
    return name.str();
}
// Positive fixture: `function` is the original graph, `reference_function`
// the expected graph after input swapping (same params, swappedInputs=true).
class SwapInputMatmul : public CommonTestUtils::TestsCommon,
                        public ::testing::WithParamInterface<SwapInputMatmulParams> {
public:
    void SetUp() override {
        std::vector<ngraph::Shape> shapes;
        bool withBias, withWeightsFq, withOutFq;
        std::tie(shapes, withBias, withWeightsFq, withOutFq) = this->GetParam();
        const auto& const_shape = shapes[0];
        const auto& input_shape = shapes[1];
        const auto& bias_shape = shapes[2];
        function = CreateMatMulFunction(const_shape, input_shape, bias_shape,
                                        withBias, withWeightsFq, withOutFq,
                                        /*swappedInputs=*/false);
        reference_function = CreateMatMulFunction(const_shape, input_shape, bias_shape,
                                                  withBias, withWeightsFq, withOutFq,
                                                  /*swappedInputs=*/true);
    }
public:
    std::shared_ptr<ngraph::Function> function, reference_function;
};
// Negative fixture: the reference is an untouched clone of the input graph,
// so the test passes only when the transformation leaves the graph unchanged.
class SwapInputMatmulNotApplied : public CommonTestUtils::TestsCommon,
                                  public ::testing::WithParamInterface<SwapInputMatmulParams> {
public:
    void SetUp() override {
        std::vector<ngraph::Shape> shapes;
        bool withBias, withWeightsFq, withOutFq;
        std::tie(shapes, withBias, withWeightsFq, withOutFq) = this->GetParam();
        function = CreateMatMulFunction(shapes[0], shapes[1], shapes[2],
                                        withBias, withWeightsFq, withOutFq,
                                        /*swappedInputs=*/false);
        // Expectation: graph stays identical to its pre-pass clone.
        reference_function = ngraph::clone_function(*function);
    }
public:
    std::shared_ptr<ngraph::Function> function, reference_function;
};
// Transformation must rewrite the graph into the swapped reference form.
TEST_P(SwapInputMatmul, CompareFunctions) {
Execute(function, reference_function);
}
// Transformation must leave the graph untouched (reference is a clone).
TEST_P(SwapInputMatmulNotApplied, CompareFunctions) {
Execute(function, reference_function);
}
// Shapes accepted by the pattern: 2D constant with first dim >= 8 and both
// dims divisible by 8 (see the Constant predicate in the passes).
const std::vector<std::vector<ngraph::Shape>> input_shapes_applied = {
{{16, 8}, {8, 8}, {16, 8}},
{{16, 8}, {8, 8}, {1}}
};
// Shapes rejected by the pattern: first dim < 8, or constant rank != 2.
const std::vector<std::vector<ngraph::Shape>> input_shapes_not_applied = {
{{1, 8}, {8, 8}, {1, 8}},
{{8}, {8, 8}, {8}}
};
INSTANTIATE_TEST_CASE_P(smoke_swap_input_matmul, SwapInputMatmul,
::testing::Combine(
::testing::ValuesIn(input_shapes_applied),
::testing::ValuesIn(std::vector<bool>{false, true}),
::testing::ValuesIn(std::vector<bool>{false, true}),
::testing::ValuesIn(std::vector<bool>{false, true})),
getTestCaseName);
INSTANTIATE_TEST_CASE_P(smoke_swap_input_matmul, SwapInputMatmulNotApplied,
::testing::Combine(
::testing::ValuesIn(input_shapes_not_applied),
::testing::ValuesIn(std::vector<bool>{false, true}),
::testing::ValuesIn(std::vector<bool>{false, true}),
::testing::ValuesIn(std::vector<bool>{false, true})),
getTestCaseName);
} // namespace testing