[GNA] Support bias and FQ in SwapInputMatMul transformation (#6996)

* [GNA] Support bias and FQ in SwapInputMatMul transformation

* Updated opset for transformation and removed debug info
Elizaveta Lobanova 2021-08-11 10:10:33 +03:00 committed by GitHub
parent ff500b0bed
commit 0834ae2e6d
4 changed files with 307 additions and 204 deletions


@@ -704,6 +704,8 @@ void GNAPlugin::LoadNetwork(CNNNetwork & _network) {
manager.register_pass<SplitConvolution>();
manager.register_pass<HandleTransposesAroundMatMul>();
manager.register_pass<SwapInputMatMul>();
manager.register_pass<SwapInputMatMulWithBias>();
manager.register_pass<SwapInputMatMulWithFq>();
manager.register_pass<InsertTransposeAfterConvOrPool>();
manager.register_pass<ReorderActivationAndPooling>();
manager.register_pass<ngraph::pass::ConvertOpSet3ToOpSet2>();


@@ -9,7 +9,7 @@
#include <ngraph/pass/manager.hpp>
#include <ngraph/pattern/op/or.hpp>
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <numeric>
@@ -20,75 +20,163 @@
using namespace GNAPluginNS;
NGRAPH_RTTI_DEFINITION(SwapInputMatMul, "SwapInputMatMul", 0);
NGRAPH_RTTI_DEFINITION(SwapInputMatMulWithBias, "SwapInputMatMulWithBias", 0);
NGRAPH_RTTI_DEFINITION(SwapInputMatMulWithFq, "SwapInputMatMulWithFq", 0);
static void SwapAndTransposeInputs(std::shared_ptr<ngraph::opset8::MatMul> matmul_node,
std::shared_ptr<ngraph::Node> add,
std::shared_ptr<ngraph::Node> bias,
std::shared_ptr<ngraph::Node> fq) {
auto create_transpose =
[](ngraph::Output<ngraph::Node> node, const std::string& transpose_name) -> std::shared_ptr<ngraph::Node> {
ngraph::Shape output_shape = node.get_node_shared_ptr()->get_shape();
std::vector<size_t> transpose_order(output_shape.size());
std::iota(transpose_order.begin(), transpose_order.end(), 0);
std::swap(*(transpose_order.end() - 1), *(transpose_order.end() - 2));
auto transpose = std::make_shared<ngraph::opset8::Transpose>(
node, ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape {transpose_order.size()}, transpose_order));
transpose->set_friendly_name(transpose_name);
return transpose;
};
ngraph::NodeVector new_ops;
gnalog() << "Swap and transpose inputs for " << matmul_node->get_friendly_name() << "\n";
std::shared_ptr<ngraph::Node> new_matmul = std::make_shared<ngraph::opset8::MatMul>(
matmul_node->input(1).get_source_output(), matmul_node->input(0).get_source_output(),
!matmul_node->get_transpose_b(), !matmul_node->get_transpose_a());
new_matmul->set_friendly_name(matmul_node->get_friendly_name() + "/swap_inputs");
new_ops.push_back(new_matmul);
std::shared_ptr<ngraph::Node> old_root_node = matmul_node;
if (bias != nullptr) {
// the output of the new MatMul is transposed compared with the original one, so a 2-D bias must be transposed too
if (bias->get_output_shape(0).size() > 1) {
bias = create_transpose(bias, bias->get_friendly_name() + "/transpose");
new_ops.push_back(bias);
}
new_matmul = std::make_shared<ngraph::opset8::Add>(new_matmul, bias);
old_root_node = add;
new_ops.push_back(new_matmul);
}
if (fq != nullptr) {
new_matmul = fq->clone_with_new_inputs({new_matmul, fq->input_value(1), fq->input_value(2),
fq->input_value(3), fq->input_value(4)});
old_root_node = fq;
new_ops.push_back(new_matmul);
}
auto output = create_transpose(new_matmul, matmul_node->get_friendly_name());
new_ops.push_back(output);
ngraph::copy_runtime_info(matmul_node, new_ops);
ngraph::replace_node(old_root_node, output);
}
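The rewrite above is the matrix transpose identity applied at graph level: swapping the MatMul inputs and inverting both transpose flags yields the transposed product, and the Transpose created at the end restores the original output layout. A 2-D bias has to be transposed before it is added to the swapped product, while a 1-D bias broadcasts along the last dimension and can be reused as-is, which is what the size() > 1 check encodes. As a worked identity (with c the bias):

\[
(AB)^\top = B^\top A^\top,
\qquad
AB + c = \bigl(B^\top A^\top + c^\top\bigr)^\top .
\]

The cloned FakeQuantize keeps its original range constants; in the tests below those ranges are per-tensor (Shape{1}), so a transpose of the data does not affect them.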
SwapInputMatMul::SwapInputMatMul() {
MATCHER_SCOPE(SwapInputMatMul);
auto constant = ngraph::pattern::wrap_type<ngraph::opset7::Constant>({}, ngraph::pattern::rank_equals(2));
auto fake_quantize = ngraph::pattern::wrap_type<ngraph::opset7::FakeQuantize>({constant,
ngraph::pattern::wrap_type<ngraph::opset7::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset7::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset7::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset7::Constant>()});
auto constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>({}, [](const ngraph::Output<ngraph::Node>& node) {
auto shape = node.get_node_shared_ptr()->get_output_shape(0);
if (shape.size() != 2 || shape[0] < 8 || ((shape[0] % 8 != 0 || shape[1] % 8 != 0))) {
return false;
}
return true;
});
auto fake_quantize = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({constant,
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>()});
auto matmul_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{constant, fake_quantize});
auto matmul = ngraph::pattern::wrap_type<ngraph::opset7::MatMul>({matmul_input, ngraph::pattern::any_input()},
auto matmul = ngraph::pattern::wrap_type<ngraph::opset8::MatMul>({matmul_input, ngraph::pattern::any_input()},
ngraph::pattern::has_static_shape());
ngraph::matcher_pass_callback callback = [this](ngraph::pattern::Matcher& m) {
auto matmul = std::dynamic_pointer_cast<ngraph::opset7::MatMul>(m.get_match_root());
if (!matmul) {
return false;
}
auto input_a = matmul->input(0).get_source_output();
auto input_b = matmul->input(1).get_source_output();
ngraph::Shape shape_input_a = input_a.get_shape();
auto create_transpose = [this](ngraph::Output<ngraph::Node> node, const std::string& transpose_name) -> std::shared_ptr<ngraph::Node> {
ngraph::Shape output_shape = node.get_node_shared_ptr()->get_shape();
std::vector<size_t> transpose_order(output_shape.size());
std::iota(transpose_order.begin(), transpose_order.end(), 0);
std::swap(*(transpose_order.end() - 1), *(transpose_order.end() - 2));
auto transpose = register_new_node<ngraph::opset7::Transpose>(
node, ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape {transpose_order.size()}, transpose_order));
transpose->set_friendly_name(transpose_name);
return transpose;
};
ngraph::NodeVector new_ops;
if (shape_input_a[0] < 8 || ((shape_input_a[0] % 8 != 0 || shape_input_a[1] % 8 != 0))) {
return false;
}
gnalog() << "Swap and transpose inputs for " << matmul->get_friendly_name() << "\n";
auto new_matmul = std::make_shared<ngraph::opset7::MatMul>(input_b, input_a, !matmul->get_transpose_b(), !matmul->get_transpose_a());
new_matmul->set_friendly_name(matmul->get_friendly_name() + "/swap_inputs");
new_ops.push_back(new_matmul);
if (!matmul->get_output_target_inputs(0).empty()) {
auto matmul_out = matmul->get_output_target_inputs(0).begin()->get_node()->shared_from_this();
if (std::dynamic_pointer_cast<ngraph::opset7::FakeQuantize>(matmul_out) != nullptr) {
ngraph::copy_runtime_info(matmul, new_ops);
ngraph::replace_node(matmul, new_matmul);
auto consumers = matmul_out->output(0).get_target_inputs();
auto traspose_output = create_transpose(matmul_out, matmul->get_friendly_name());
for (auto input : consumers) {
input.replace_source_output(traspose_output);
}
return true;
}
}
auto traspose_output = create_transpose(new_matmul, matmul->get_friendly_name());
new_ops.push_back(traspose_output);
ngraph::copy_runtime_info(matmul, new_ops);
ngraph::replace_node(matmul, traspose_output);
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
auto matmul_node = std::dynamic_pointer_cast<ngraph::opset8::MatMul>(pattern_map.at(matmul).get_node_shared_ptr());
IE_ASSERT(matmul_node != nullptr);
SwapAndTransposeInputs(matmul_node, nullptr, nullptr, nullptr);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(matmul, matcher_name);
this->register_matcher(m, callback);
}
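The same constant-input predicate is repeated in all three matchers (SwapInputMatMul above, SwapInputMatMulWithBias and SwapInputMatMulWithFq below). A minimal standalone restatement for clarity, with the shapes used by the unit tests as examples; the helper name is mine, not part of the patch, and the divisibility-by-8 requirement is presumably a GNA layout/alignment constraint:

#include <ngraph/shape.hpp>

// Restatement of the matchers' constant-shape check (illustrative helper):
// the constant MatMul input must be 2-D, have at least 8 rows, and both
// dimensions must be divisible by 8.
static bool is_swappable_weight_shape(const ngraph::Shape& shape) {
    return shape.size() == 2 &&
           shape[0] >= 8 &&
           shape[0] % 8 == 0 &&
           shape[1] % 8 == 0;
}

// Shapes from the tests below:
//   {16, 8} -> true   (transformation applied)
//   {1, 8}  -> false  (first dimension < 8, left untouched)
//   {8}     -> false  (rank 1, left untouched)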
SwapInputMatMulWithBias::SwapInputMatMulWithBias() {
MATCHER_SCOPE(SwapInputMatMulWithBias);
auto constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>({}, [](const ngraph::Output<ngraph::Node>& node) {
auto shape = node.get_node_shared_ptr()->get_output_shape(0);
if (shape.size() != 2 || shape[0] < 8 || ((shape[0] % 8 != 0 || shape[1] % 8 != 0))) {
return false;
}
return true;
});
auto fake_quantize = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({constant,
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>()});
auto matmul_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{constant, fake_quantize});
auto matmul = ngraph::pattern::wrap_type<ngraph::opset8::MatMul>({matmul_input, ngraph::pattern::any_input()},
ngraph::pattern::has_static_shape());
auto bias = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
auto add = ngraph::pattern::wrap_type<ngraph::opset8::Add>({matmul, bias});
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
auto matmul_node = std::dynamic_pointer_cast<ngraph::opset8::MatMul>(pattern_map.at(matmul).get_node_shared_ptr());
IE_ASSERT(matmul_node != nullptr);
SwapAndTransposeInputs(matmul_node, pattern_map.at(add).get_node_shared_ptr(),
pattern_map.at(bias).get_node_shared_ptr(), nullptr);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(add, matcher_name);
this->register_matcher(m, callback);
}
SwapInputMatMulWithFq::SwapInputMatMulWithFq() {
MATCHER_SCOPE(SwapInputMatMulWithFq);
auto constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>({}, [](const ngraph::Output<ngraph::Node>& node) {
auto shape = node.get_node_shared_ptr()->get_output_shape(0);
if (shape.size() != 2 || shape[0] < 8 || ((shape[0] % 8 != 0 || shape[1] % 8 != 0))) {
return false;
}
return true;
});
auto fake_quantize = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({constant,
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>()});
auto matmul_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{constant, fake_quantize});
auto matmul = ngraph::pattern::wrap_type<ngraph::opset8::MatMul>({matmul_input, ngraph::pattern::any_input()},
ngraph::pattern::has_static_shape());
auto bias = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
auto add = ngraph::pattern::wrap_type<ngraph::opset8::Add>({matmul, bias});
auto matmul_out = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{add, matmul});
auto out_fq = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({matmul_out,
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>()});
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
auto matmul_node = std::dynamic_pointer_cast<ngraph::opset8::MatMul>(pattern_map.at(matmul).get_node_shared_ptr());
IE_ASSERT(matmul_node != nullptr);
auto add_it = pattern_map.find(add);
auto add_node = (add_it == std::end(pattern_map) ? nullptr : add_it->second.get_node_shared_ptr());
auto bias_it = pattern_map.find(bias);
auto bias_node = (bias_it == std::end(pattern_map) ? nullptr : bias_it->second.get_node_shared_ptr());
SwapAndTransposeInputs(matmul_node, add_node, bias_node, pattern_map.at(out_fq).get_node_shared_ptr());
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(out_fq, matcher_name);
this->register_matcher(m, callback);
}
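Outside the plugin the three passes can be driven by an ngraph::pass::Manager, just as GNAPlugin::LoadNetwork above and the unit-test Execute helper below do. A minimal sketch with a toy MatMul(Constant, Parameter) graph of my own; the model, the function name, and the include path are illustrative, not part of the patch:

#include <ngraph/function.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/init_node_info.hpp>
#include "transformations/swap_input_matmul_gna.hpp"  // GNA plugin header declaring the passes (path approximate)

std::shared_ptr<ngraph::Function> BuildAndSwap() {
    // MatMul with a {16, 8} constant as the first input -- the pattern the passes target.
    auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, ngraph::Shape{8, 8});
    auto weights = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{16, 8}, {1});
    auto matmul = std::make_shared<ngraph::opset8::MatMul>(weights, input);
    auto result = std::make_shared<ngraph::opset8::Result>(matmul);
    auto func = std::make_shared<ngraph::Function>(ngraph::ResultVector{result},
                                                   ngraph::ParameterVector{input});

    ngraph::pass::Manager manager;
    manager.register_pass<ngraph::pass::InitNodeInfo>();
    manager.register_pass<GNAPluginNS::SwapInputMatMulWithFq>();
    manager.register_pass<GNAPluginNS::SwapInputMatMulWithBias>();
    manager.register_pass<GNAPluginNS::SwapInputMatMul>();
    manager.run_passes(func);  // inputs swapped, a Transpose restores the original layout
    return func;
}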


@@ -1,4 +1,4 @@
// Copyright (C) 2020 Intel Corporation
// Copyright (C) 2020-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
@@ -16,4 +16,16 @@ public:
NGRAPH_RTTI_DECLARATION;
SwapInputMatMul();
};
class SwapInputMatMulWithBias: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
SwapInputMatMulWithBias();
};
class SwapInputMatMulWithFq: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
SwapInputMatMulWithFq();
};
} // namespace GNAPluginNS


@@ -8,164 +8,165 @@
#include "common_test_utils/ngraph_test_utils.hpp"
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/init_node_info.hpp>
namespace testing {
TEST(TransformationTests, SwapInputMatMulTestValidConstShape) {
std::shared_ptr<ngraph::Function> func(nullptr), reference_func(nullptr);
const ngraph::Shape data_shape{8, 8};
static std::shared_ptr<ngraph::Function> CreateMatMulFunction(const ngraph::Shape& input1_shape,
const ngraph::Shape& input2_shape,
const ngraph::Shape& bias_shape,
bool withBias,
bool withWeightsFq,
bool withOutFq,
bool swappedInputs) {
auto input_params = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, input2_shape);
{
auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, data_shape);
auto constant = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{1, 8}, {1});
auto matmul_operation = std::make_shared<ngraph::opset7::MatMul>(constant, input_params);
auto result = std::make_shared<ngraph::opset7::Result>(matmul_operation);
func = std::make_shared<ngraph::Function>(ngraph::ResultVector{result},
ngraph::ParameterVector{input_params});
reference_func = ngraph::clone_function(*func);
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<GNAPluginNS::SwapInputMatMul>();
m.run_passes(func);
ASSERT_NO_THROW(check_rt_info(func));
auto constant = ngraph::opset8::Constant::create(ngraph::element::i64, input1_shape, {1});
std::shared_ptr<ngraph::Node> const_input = constant;
if (withWeightsFq) {
auto input_low = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1});
auto input_high = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {20});
auto output_low = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {0});
auto output_high = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {10});
const_input = std::make_shared<ngraph::opset8::FakeQuantize>(const_input, input_low, input_high,
output_low, output_high, 11);
}
auto matmul = swappedInputs ? std::make_shared<ngraph::opset8::MatMul>(input_params, const_input, true, true) :
std::make_shared<ngraph::opset8::MatMul>(const_input, input_params);
const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
const FunctionsComparator::Result result = func_comparator(func, reference_func);
ASSERT_TRUE(result.valid);
}
TEST(TransformationTests, SwapInputMatMulTest) {
std::shared_ptr<ngraph::Function> func(nullptr), reference_func(nullptr);
const ngraph::Shape data_shape{8, 8};
{
auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, data_shape);
auto constant = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{16, 8}, {1});
auto matmul_operation = std::make_shared<ngraph::opset7::MatMul>(constant, input_params);
auto result = std::make_shared<ngraph::opset7::Result>(matmul_operation);
func = std::make_shared<ngraph::Function>(ngraph::ResultVector{result},
ngraph::ParameterVector{input_params});
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<GNAPluginNS::SwapInputMatMul>();
m.run_passes(func);
ASSERT_NO_THROW(check_rt_info(func));
}
{
auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, data_shape);
auto constant = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{16, 8}, {1});
auto matmul_operation = std::make_shared<ngraph::opset7::MatMul>(input_params, constant, 1, 1);
auto transpose_order = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2},
std::vector<size_t>{1, 0});
auto transpose_operation = std::make_shared<ngraph::opset7::Transpose>(matmul_operation, transpose_order);
auto result = std::make_shared<ngraph::opset7::Result>(transpose_operation);
reference_func = std::make_shared<ngraph::Function>(ngraph::ResultVector{result},
ngraph::ParameterVector{input_params});
}
const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
const FunctionsComparator::Result result = func_comparator(func, reference_func);
ASSERT_TRUE(result.valid);
}
TEST(TransformationTests, SwapInputMatMulTestFakeQuantize) {
std::shared_ptr<ngraph::Function> func(nullptr), reference_func(nullptr);
const ngraph::Shape data_shape{8, 8};
{
auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, data_shape);
auto constant = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{16, 8}, {1});
auto input_low = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1});
auto input_high = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {20});
auto output_low = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {0});
auto output_high = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {10});
auto fake_quantize_op = std::make_shared<ngraph::opset7::FakeQuantize>(constant, input_low,
input_high, output_low,
output_high, 11);
auto matmul_operation = std::make_shared<ngraph::opset7::MatMul>(fake_quantize_op, input_params);
auto result = std::make_shared<ngraph::opset7::Result>(matmul_operation);
func = std::make_shared<ngraph::Function>(ngraph::ResultVector{result},
ngraph::ParameterVector{input_params});
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<GNAPluginNS::SwapInputMatMul>();
m.run_passes(func);
ASSERT_NO_THROW(check_rt_info(func));
}
{
auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, data_shape);
auto constant = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{16, 8}, {1});
auto input_low = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1});
auto input_high = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {20});
auto output_low = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {0});
auto output_high = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {10});
auto fake_quantize_op = std::make_shared<ngraph::opset7::FakeQuantize>(constant, input_low,
input_high, output_low,
output_high, 11);
auto matmul_operation = std::make_shared<ngraph::opset7::MatMul>(input_params, fake_quantize_op, 1 , 1);
auto transpose_order = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2},
std::shared_ptr<ngraph::Node> final_node = matmul;
if (withBias) {
auto bias = ngraph::opset8::Constant::create(ngraph::element::i64, bias_shape, {1});
std::shared_ptr<ngraph::Node> bias_node = bias;
if (swappedInputs && bias_shape.size() > 1) {
auto transpose_order = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{2},
std::vector<size_t>{1, 0});
auto transpose_operation = std::make_shared<ngraph::opset7::Transpose>(matmul_operation, transpose_order);
auto result = std::make_shared<ngraph::opset7::Result>(transpose_operation);
reference_func = std::make_shared<ngraph::Function>(ngraph::ResultVector{result},
ngraph::ParameterVector{input_params});
bias_node = std::make_shared<ngraph::opset8::Transpose>(bias_node, transpose_order);
}
final_node = std::make_shared<ngraph::opset8::Add>(matmul, bias_node);
}
if (withOutFq) {
auto input_low = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1});
auto input_high = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {20});
auto output_low = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {0});
auto output_high = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {10});
final_node = std::make_shared<ngraph::opset8::FakeQuantize>(final_node, input_low, input_high,
output_low, output_high, 11);
}
if (swappedInputs) {
auto transpose_order = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{2},
std::vector<size_t>{1, 0});
final_node = std::make_shared<ngraph::opset8::Transpose>(final_node, transpose_order);
}
auto result = std::make_shared<ngraph::opset8::Result>(final_node);
return std::make_shared<ngraph::Function>(ngraph::ResultVector{result},
ngraph::ParameterVector{input_params});
}
static void Execute(std::shared_ptr<ngraph::Function> function, std::shared_ptr<ngraph::Function> reference_function) {
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<GNAPluginNS::SwapInputMatMulWithFq>();
m.register_pass<GNAPluginNS::SwapInputMatMulWithBias>();
m.register_pass<GNAPluginNS::SwapInputMatMul>();
m.run_passes(function);
ASSERT_NO_THROW(check_rt_info(function));
const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
const FunctionsComparator::Result result = func_comparator(func, reference_func);
const FunctionsComparator::Result result = func_comparator(function, reference_function);
ASSERT_TRUE(result.valid);
}
TEST(TransformationTests, SwapInputMatMulTestRank1) {
std::shared_ptr<ngraph::Function> func(nullptr), reference_func(nullptr);
const ngraph::Shape data_shape{8, 8};
typedef std::tuple<
std::vector<ngraph::Shape>, // constant input shape, non-const input shape, bias shape
bool, // with bias
bool, // with weights FakeQuantize
bool // with output FakeQuantize
> SwapInputMatmulParams;
{
auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, data_shape);
static std::string getTestCaseName(testing::TestParamInfo<SwapInputMatmulParams> obj) {
std::vector<ngraph::Shape> shapes;
bool withBias, withWeightsFq, withOutFq;
std::tie(shapes, withBias, withWeightsFq, withOutFq) = obj.param;
auto constant = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{8}, {1});
auto matmul_operation = std::make_shared<ngraph::opset7::MatMul>(constant, input_params);
auto result = std::make_shared<ngraph::opset7::Result>(matmul_operation);
func = std::make_shared<ngraph::Function>(ngraph::ResultVector{result},
ngraph::ParameterVector{input_params});
reference_func = ngraph::clone_function(*func);
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<GNAPluginNS::SwapInputMatMul>();
m.run_passes(func);
ASSERT_NO_THROW(check_rt_info(func));
}
const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
const FunctionsComparator::Result result = func_comparator(func, reference_func);
ASSERT_TRUE(result.valid);
std::ostringstream result;
result << "IS1=" << shapes[0] << "_";
result << "IS2=" << shapes[1] << "_";
result << "BS=" << shapes[2] << "_";
result << "bias=" << withBias << "_";
result << "wFQ=" << withWeightsFq << "_";
result << "oFQ=" << withOutFq;
return result.str();
}
class SwapInputMatmul : public CommonTestUtils::TestsCommon,
public ::testing::WithParamInterface<SwapInputMatmulParams> {
public:
void SetUp() override {
std::vector<ngraph::Shape> shapes;
bool withBias, withWeightsFq, withOutFq;
std::tie(shapes, withBias, withWeightsFq, withOutFq) = this->GetParam();
function = CreateMatMulFunction(shapes[0], shapes[1], shapes[2], withBias, withWeightsFq, withOutFq, false);
reference_function = CreateMatMulFunction(shapes[0], shapes[1], shapes[2], withBias, withWeightsFq,
withOutFq, true);
}
public:
std::shared_ptr<ngraph::Function> function, reference_function;
};
class SwapInputMatmulNotApplied : public CommonTestUtils::TestsCommon,
public ::testing::WithParamInterface<SwapInputMatmulParams> {
public:
void SetUp() override {
std::vector<ngraph::Shape> shapes;
bool withBias, withWeightsFq, withOutFq;
std::tie(shapes, withBias, withWeightsFq, withOutFq) = this->GetParam();
function = CreateMatMulFunction(shapes[0], shapes[1], shapes[2], withBias, withWeightsFq, withOutFq, false);
reference_function = ngraph::clone_function(*function);
}
public:
std::shared_ptr<ngraph::Function> function, reference_function;
};
TEST_P(SwapInputMatmul, CompareFunctions) {
Execute(function, reference_function);
}
TEST_P(SwapInputMatmulNotApplied, CompareFunctions) {
Execute(function, reference_function);
}
const std::vector<std::vector<ngraph::Shape>> input_shapes_applied = {
{{16, 8}, {8, 8}, {16, 8}},
{{16, 8}, {8, 8}, {1}},
};
const std::vector<std::vector<ngraph::Shape>> input_shapes_not_applied = {
{{1, 8}, {8, 8}, {1, 8}},
{{8}, {8, 8}, {8}}
};
INSTANTIATE_TEST_SUITE_P(smoke_swap_input_matmul, SwapInputMatmul,
::testing::Combine(
::testing::ValuesIn(input_shapes_applied),
::testing::ValuesIn(std::vector<bool>{false, true}),
::testing::ValuesIn(std::vector<bool>{false, true}),
::testing::ValuesIn(std::vector<bool>{false, true})),
getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_swap_input_matmul, SwapInputMatmulNotApplied,
::testing::Combine(
::testing::ValuesIn(input_shapes_not_applied),
::testing::ValuesIn(std::vector<bool>{false, true}),
::testing::ValuesIn(std::vector<bool>{false, true}),
::testing::ValuesIn(std::vector<bool>{false, true})),
getTestCaseName);
} // namespace testing