[GNA] Support bias and FQ in SwapInputMatMul transformation (#6996)
* [GNA] Support bias and FQ in SwapInputMatMul transformation * Updated opset for transformation and removed debug info
This commit is contained in:
parent
ff500b0bed
commit
0834ae2e6d
@ -704,6 +704,8 @@ void GNAPlugin::LoadNetwork(CNNNetwork & _network) {
|
||||
manager.register_pass<SplitConvolution>();
|
||||
manager.register_pass<HandleTransposesAroundMatMul>();
|
||||
manager.register_pass<SwapInputMatMul>();
|
||||
manager.register_pass<SwapInputMatMulWithBias>();
|
||||
manager.register_pass<SwapInputMatMulWithFq>();
|
||||
manager.register_pass<InsertTransposeAfterConvOrPool>();
|
||||
manager.register_pass<ReorderActivationAndPooling>();
|
||||
manager.register_pass<ngraph::pass::ConvertOpSet3ToOpSet2>();
|
||||
|
@ -9,7 +9,7 @@
|
||||
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <ngraph/pattern/op/or.hpp>
|
||||
#include <ngraph/opsets/opset7.hpp>
|
||||
#include <ngraph/opsets/opset8.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <numeric>
|
||||
@ -20,75 +20,163 @@
|
||||
using namespace GNAPluginNS;
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(SwapInputMatMul, "SwapInputMatMul", 0);
|
||||
NGRAPH_RTTI_DEFINITION(SwapInputMatMulWithBias, "SwapInputMatMulWithBias", 0);
|
||||
NGRAPH_RTTI_DEFINITION(SwapInputMatMulWithFq, "SwapInputMatMulWithFq", 0);
|
||||
|
||||
/**
 * @brief Rebuilds a MatMul (and the optional bias Add / output FakeQuantize that follow it)
 *        with the two MatMul inputs swapped, then transposes the result back.
 *
 * The new MatMul takes the original second input first and the original first input second;
 * its transpose_a is the negation of the original transpose_b and its transpose_b is the
 * negation of the original transpose_a. The result is therefore the transpose of the original
 * output, so a Transpose over the last two axes is appended to restore the original layout.
 *
 * @param matmul_node MatMul whose inputs are swapped.
 * @param add   Add node that consumed the MatMul output, or nullptr when there is no bias.
 * @param bias  Bias input of @p add, or nullptr when there is no bias.
 * @param fq    FakeQuantize node that consumed the MatMul/Add output, or nullptr.
 */
static void SwapAndTransposeInputs(std::shared_ptr<ngraph::opset8::MatMul> matmul_node,
                                   std::shared_ptr<ngraph::Node> add,
                                   std::shared_ptr<ngraph::Node> bias,
                                   std::shared_ptr<ngraph::Node> fq) {
    // Builds a Transpose that swaps the two innermost axes of `node`.
    // Callers must only pass values of rank >= 2.
    auto create_transpose =
        [](ngraph::Output<ngraph::Node> node, const std::string& transpose_name) -> std::shared_ptr<ngraph::Node> {
        ngraph::Shape output_shape = node.get_node_shared_ptr()->get_shape();

        std::vector<size_t> transpose_order(output_shape.size());
        std::iota(transpose_order.begin(), transpose_order.end(), 0);
        std::swap(*(transpose_order.end() - 1), *(transpose_order.end() - 2));

        auto transpose = std::make_shared<ngraph::opset8::Transpose>(
            node, ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape {transpose_order.size()}, transpose_order));
        transpose->set_friendly_name(transpose_name);
        return transpose;
    };

    ngraph::NodeVector new_ops;
    // Every node that is dropped from the graph; their runtime info is merged into the new nodes.
    ngraph::NodeVector replaced_ops{matmul_node};

    gnalog() << "Swap and transpose inputs for " << matmul_node->get_friendly_name() << "\n";
    std::shared_ptr<ngraph::Node> new_matmul = std::make_shared<ngraph::opset8::MatMul>(
        matmul_node->input(1).get_source_output(), matmul_node->input(0).get_source_output(),
        !matmul_node->get_transpose_b(), !matmul_node->get_transpose_a());
    new_matmul->set_friendly_name(matmul_node->get_friendly_name() + "/swap_inputs");
    new_ops.push_back(new_matmul);

    std::shared_ptr<ngraph::Node> old_root_node = matmul_node;
    if (bias != nullptr) {
        // output of MatMul will be transposed comparing with original one, so the bias should be transposed too
        if (bias->get_output_shape(0).size() > 1) {
            bias = create_transpose(bias, bias->get_friendly_name() + "/transpose");
            new_ops.push_back(bias);
        }

        new_matmul = std::make_shared<ngraph::opset8::Add>(new_matmul, bias);
        old_root_node = add;
        if (add != nullptr) {
            replaced_ops.push_back(add);
        }
        new_ops.push_back(new_matmul);
    }

    if (fq != nullptr) {
        // Re-create the FakeQuantize on top of the new subgraph, keeping its range inputs.
        new_matmul = fq->clone_with_new_inputs({new_matmul, fq->input_value(1), fq->input_value(2),
                                                fq->input_value(3), fq->input_value(4)});
        old_root_node = fq;
        replaced_ops.push_back(fq);
        new_ops.push_back(new_matmul);
    }

    // A rank-1 result (1D non-constant input) has no two axes to swap and is already in the
    // original layout; transposing it would index before the start of the order vector.
    std::shared_ptr<ngraph::Node> output = new_matmul;
    if (new_matmul->get_output_shape(0).size() > 1) {
        output = create_transpose(new_matmul, matmul_node->get_friendly_name());
        new_ops.push_back(output);
    }

    ngraph::copy_runtime_info(replaced_ops, new_ops);
    ngraph::replace_node(old_root_node, output);
}
|
||||
|
||||
/**
 * @brief Registers a matcher for MatMul whose first input is a Constant, or a FakeQuantize
 *        over a Constant, and swaps the MatMul inputs via SwapAndTransposeInputs().
 *
 * The constant must be 2D, with a first dimension of at least 8 and both dimensions
 * divisible by 8. The MatMul output must have a static shape.
 */
SwapInputMatMul::SwapInputMatMul() {
    MATCHER_SCOPE(SwapInputMatMul);
    auto constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>({}, [](const ngraph::Output<ngraph::Node>& node) {
        auto shape = node.get_node_shared_ptr()->get_output_shape(0);
        if (shape.size() != 2 || shape[0] < 8 || ((shape[0] % 8 != 0 || shape[1] % 8 != 0))) {
            return false;
        }
        return true;
    });
    auto fake_quantize = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({constant,
        ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
        ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
        ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
        ngraph::pattern::wrap_type<ngraph::opset8::Constant>()});
    auto matmul_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{constant, fake_quantize});
    auto matmul = ngraph::pattern::wrap_type<ngraph::opset8::MatMul>({matmul_input, ngraph::pattern::any_input()},
                                                                     ngraph::pattern::has_static_shape());

    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
        const auto& pattern_map = m.get_pattern_value_map();
        auto matmul_node = std::dynamic_pointer_cast<ngraph::opset8::MatMul>(pattern_map.at(matmul).get_node_shared_ptr());
        IE_ASSERT(matmul_node != nullptr);
        // No bias and no output FakeQuantize in this variant.
        SwapAndTransposeInputs(matmul_node, nullptr, nullptr, nullptr);
        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(matmul, matcher_name);
    this->register_matcher(m, callback);
}
|
||||
|
||||
/**
 * @brief Registers a matcher for `Add(MatMul(weights, any), Constant)` and swaps the MatMul inputs.
 *
 * `weights` is a Constant, or a FakeQuantize over a Constant; the constant must be 2D with
 * a first dimension >= 8 and both dimensions divisible by 8. The matched Add and its constant
 * bias are forwarded to SwapAndTransposeInputs().
 */
SwapInputMatMulWithBias::SwapInputMatMulWithBias() {
    MATCHER_SCOPE(SwapInputMatMulWithBias);

    namespace pattern = ngraph::pattern;

    // Accepts only 2D weights whose dimensions are multiples of 8 (and whose first dimension is non-zero).
    const auto is_swappable_weights = [](const ngraph::Output<ngraph::Node>& node) {
        const auto dims = node.get_node_shared_ptr()->get_output_shape(0);
        return dims.size() == 2 && dims[0] >= 8 && dims[0] % 8 == 0 && dims[1] % 8 == 0;
    };
    // Each call yields a distinct pattern node, so the FakeQuantize range inputs may differ.
    const auto any_constant = []() {
        return pattern::wrap_type<ngraph::opset8::Constant>();
    };

    auto weights = pattern::wrap_type<ngraph::opset8::Constant>({}, is_swappable_weights);
    auto weights_fq = pattern::wrap_type<ngraph::opset8::FakeQuantize>(
        {weights, any_constant(), any_constant(), any_constant(), any_constant()});
    auto first_input = std::make_shared<pattern::op::Or>(ngraph::OutputVector{weights, weights_fq});
    auto matmul = pattern::wrap_type<ngraph::opset8::MatMul>({first_input, pattern::any_input()},
                                                             pattern::has_static_shape());
    auto bias = any_constant();
    auto add = pattern::wrap_type<ngraph::opset8::Add>({matmul, bias});

    ngraph::matcher_pass_callback callback = [matmul, add, bias](pattern::Matcher& m) {
        const auto& matched = m.get_pattern_value_map();
        auto matmul_node =
            std::dynamic_pointer_cast<ngraph::opset8::MatMul>(matched.at(matmul).get_node_shared_ptr());
        IE_ASSERT(matmul_node != nullptr);

        auto add_node = matched.at(add).get_node_shared_ptr();
        auto bias_node = matched.at(bias).get_node_shared_ptr();
        SwapAndTransposeInputs(matmul_node, add_node, bias_node, nullptr);
        return true;
    };

    this->register_matcher(std::make_shared<pattern::Matcher>(add, matcher_name), callback);
}
|
||||
|
||||
/**
 * @brief Registers a matcher for a FakeQuantize that follows a MatMul (optionally with a
 *        constant bias Add in between) and swaps the MatMul inputs.
 *
 * The first MatMul input is a Constant, or a FakeQuantize over a Constant; the constant must
 * be 2D with a first dimension >= 8 and both dimensions divisible by 8. The Add and bias are
 * passed to SwapAndTransposeInputs() only when they took part in the match.
 */
SwapInputMatMulWithFq::SwapInputMatMulWithFq() {
    MATCHER_SCOPE(SwapInputMatMulWithFq);

    namespace pattern = ngraph::pattern;

    // Accepts only 2D weights whose dimensions are multiples of 8 (and whose first dimension is non-zero).
    const auto is_swappable_weights = [](const ngraph::Output<ngraph::Node>& node) {
        const auto dims = node.get_node_shared_ptr()->get_output_shape(0);
        return dims.size() == 2 && dims[0] >= 8 && dims[0] % 8 == 0 && dims[1] % 8 == 0;
    };
    // Each call yields a distinct pattern node, so the FakeQuantize range inputs may differ.
    const auto any_constant = []() {
        return pattern::wrap_type<ngraph::opset8::Constant>();
    };

    auto weights = pattern::wrap_type<ngraph::opset8::Constant>({}, is_swappable_weights);
    auto weights_fq = pattern::wrap_type<ngraph::opset8::FakeQuantize>(
        {weights, any_constant(), any_constant(), any_constant(), any_constant()});
    auto first_input = std::make_shared<pattern::op::Or>(ngraph::OutputVector{weights, weights_fq});
    auto matmul = pattern::wrap_type<ngraph::opset8::MatMul>({first_input, pattern::any_input()},
                                                             pattern::has_static_shape());
    auto bias = any_constant();
    auto add = pattern::wrap_type<ngraph::opset8::Add>({matmul, bias});
    // The output FakeQuantize may sit directly on the MatMul or on the bias Add.
    auto fq_input = std::make_shared<pattern::op::Or>(ngraph::OutputVector{add, matmul});
    auto out_fq = pattern::wrap_type<ngraph::opset8::FakeQuantize>(
        {fq_input, any_constant(), any_constant(), any_constant(), any_constant()});

    ngraph::matcher_pass_callback callback = [matmul, add, bias, out_fq](pattern::Matcher& m) {
        const auto& matched = m.get_pattern_value_map();
        auto matmul_node =
            std::dynamic_pointer_cast<ngraph::opset8::MatMul>(matched.at(matmul).get_node_shared_ptr());
        IE_ASSERT(matmul_node != nullptr);

        // Returns the graph node bound to `pattern_node`, or nullptr when it was not matched.
        const auto matched_or_null =
            [&matched](const std::shared_ptr<ngraph::Node>& pattern_node) -> std::shared_ptr<ngraph::Node> {
            const auto found = matched.find(pattern_node);
            if (found == std::end(matched)) {
                return nullptr;
            }
            return found->second.get_node_shared_ptr();
        };

        auto add_node = matched_or_null(add);
        auto bias_node = matched_or_null(bias);
        SwapAndTransposeInputs(matmul_node, add_node, bias_node, matched.at(out_fq).get_node_shared_ptr());
        return true;
    };

    this->register_matcher(std::make_shared<pattern::Matcher>(out_fq, matcher_name), callback);
}
|
@ -1,4 +1,4 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// Copyright (C) 2020-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
@ -16,4 +16,16 @@ public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
SwapInputMatMul();
|
||||
};
|
||||
|
||||
class SwapInputMatMulWithBias: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
SwapInputMatMulWithBias();
|
||||
};
|
||||
|
||||
class SwapInputMatMulWithFq: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
SwapInputMatMulWithFq();
|
||||
};
|
||||
} // namespace GNAPluginNS
|
@ -8,164 +8,165 @@
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset7.hpp>
|
||||
#include <ngraph/opsets/opset8.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
|
||||
namespace testing {
|
||||
|
||||
TEST(TransformationTests, SwapInputMatMulTestValidConstShape) {
|
||||
std::shared_ptr<ngraph::Function> func(nullptr), reference_func(nullptr);
|
||||
const ngraph::Shape data_shape{8, 8};
|
||||
/**
 * @brief Builds a test function around a MatMul with one constant and one parameter input.
 *
 * @param input1_shape  Shape of the constant MatMul input.
 * @param input2_shape  Shape of the Parameter MatMul input.
 * @param bias_shape    Shape of the constant bias (used only when @p withBias is true).
 * @param withBias      Append an Add with a constant bias after the MatMul.
 * @param withWeightsFq Wrap the constant input in a FakeQuantize.
 * @param withOutFq     Append a FakeQuantize after the MatMul/Add.
 * @param swappedInputs When false, builds the original graph `MatMul(const, param)`.
 *                      When true, builds the expected transformed graph
 *                      `MatMul(param, const, true, true)`, with a bias of rank > 1 transposed
 *                      and a trailing Transpose {1, 0}.
 * @return Function with a single Parameter and a single Result.
 */
static std::shared_ptr<ngraph::Function> CreateMatMulFunction(const ngraph::Shape& input1_shape,
                                                              const ngraph::Shape& input2_shape,
                                                              const ngraph::Shape& bias_shape,
                                                              bool withBias,
                                                              bool withWeightsFq,
                                                              bool withOutFq,
                                                              bool swappedInputs) {
    auto input_params = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, input2_shape);

    auto constant = ngraph::opset8::Constant::create(ngraph::element::i64, input1_shape, {1});
    std::shared_ptr<ngraph::Node> const_input = constant;
    if (withWeightsFq) {
        auto input_low = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1});
        auto input_high = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {20});
        auto output_low = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {0});
        auto output_high = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {10});
        const_input = std::make_shared<ngraph::opset8::FakeQuantize>(const_input, input_low, input_high,
                                                                     output_low, output_high, 11);
    }
    auto matmul = swappedInputs ? std::make_shared<ngraph::opset8::MatMul>(input_params, const_input, true, true) :
                                  std::make_shared<ngraph::opset8::MatMul>(const_input, input_params);

    std::shared_ptr<ngraph::Node> final_node = matmul;
    if (withBias) {
        auto bias = ngraph::opset8::Constant::create(ngraph::element::i64, bias_shape, {1});
        std::shared_ptr<ngraph::Node> bias_node = bias;
        // The transformed graph adds the bias before the final Transpose, so a 2D bias is transposed.
        if (swappedInputs && bias_shape.size() > 1) {
            auto transpose_order = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{2},
                                                                    std::vector<size_t>{1, 0});
            bias_node = std::make_shared<ngraph::opset8::Transpose>(bias_node, transpose_order);
        }
        final_node = std::make_shared<ngraph::opset8::Add>(matmul, bias_node);
    }

    if (withOutFq) {
        auto input_low = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1});
        auto input_high = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {20});
        auto output_low = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {0});
        auto output_high = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {10});
        final_node = std::make_shared<ngraph::opset8::FakeQuantize>(final_node, input_low, input_high,
                                                                    output_low, output_high, 11);
    }

    if (swappedInputs) {
        auto transpose_order = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{2},
                                                                std::vector<size_t>{1, 0});
        final_node = std::make_shared<ngraph::opset8::Transpose>(final_node, transpose_order);
    }

    auto result = std::make_shared<ngraph::opset8::Result>(final_node);
    return std::make_shared<ngraph::Function>(ngraph::ResultVector{result},
                                              ngraph::ParameterVector{input_params});
}
|
||||
|
||||
/**
 * @brief Runs the SwapInputMatMul passes on @p function and checks it equals @p reference_function.
 *
 * The passes are registered from the most specific pattern (output FakeQuantize) to the least
 * specific (bare MatMul). Once a MatMul's inputs have been swapped its constant is no longer
 * the first input, so the remaining patterns would not match it again.
 *
 * @param function            Function to transform in place.
 * @param reference_function  Expected function after transformation.
 */
static void Execute(std::shared_ptr<ngraph::Function> function, std::shared_ptr<ngraph::Function> reference_function) {
    ngraph::pass::Manager m;
    m.register_pass<ngraph::pass::InitNodeInfo>();
    m.register_pass<GNAPluginNS::SwapInputMatMulWithFq>();
    m.register_pass<GNAPluginNS::SwapInputMatMulWithBias>();
    m.register_pass<GNAPluginNS::SwapInputMatMul>();
    m.run_passes(function);
    ASSERT_NO_THROW(check_rt_info(function));

    const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
    const FunctionsComparator::Result result = func_comparator(function, reference_function);
    ASSERT_TRUE(result.valid);
}
|
||||
|
||||
TEST(TransformationTests, SwapInputMatMulTestRank1) {
|
||||
std::shared_ptr<ngraph::Function> func(nullptr), reference_func(nullptr);
|
||||
const ngraph::Shape data_shape{8, 8};
|
||||
typedef std::tuple<
|
||||
std::vector<ngraph::Shape>, // constant input shape, non-const input shape, bias shape
|
||||
bool, // with bias
|
||||
bool, // with weights FakeQuantize
|
||||
bool // with output FakeQuantize
|
||||
> SwapInputMatmulParams;
|
||||
|
||||
{
|
||||
auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, data_shape);
|
||||
/**
 * @brief Builds a readable test-case name from the test parameters.
 *
 * @param obj gtest parameter info carrying the shapes and the bias / weights-FQ / output-FQ flags.
 * @return String of the form `IS1=<shape>_IS2=<shape>_BS=<shape>_bias=<b>_wFQ=<b>_oFQ=<b>`.
 */
static std::string getTestCaseName(testing::TestParamInfo<SwapInputMatmulParams> obj) {
    std::vector<ngraph::Shape> shapes;
    bool withBias, withWeightsFq, withOutFq;
    std::tie(shapes, withBias, withWeightsFq, withOutFq) = obj.param;

    std::ostringstream result;
    result << "IS1=" << shapes[0] << "_";
    result << "IS2=" << shapes[1] << "_";
    result << "BS=" << shapes[2] << "_";
    result << "bias=" << withBias << "_";
    result << "wFQ=" << withWeightsFq << "_";
    result << "oFQ=" << withOutFq;
    return result.str();
}
|
||||
|
||||
class SwapInputMatmul : public CommonTestUtils::TestsCommon,
|
||||
public ::testing::WithParamInterface<SwapInputMatmulParams> {
|
||||
public:
|
||||
void SetUp() override {
|
||||
std::vector<ngraph::Shape> shapes;
|
||||
bool withBias, withWeightsFq, withOutFq;
|
||||
std::tie(shapes, withBias, withWeightsFq, withOutFq) = this->GetParam();
|
||||
|
||||
function = CreateMatMulFunction(shapes[0], shapes[1], shapes[2], withBias, withWeightsFq, withOutFq, false);
|
||||
reference_function = CreateMatMulFunction(shapes[0], shapes[1], shapes[2], withBias, withWeightsFq,
|
||||
withOutFq, true);
|
||||
}
|
||||
public:
|
||||
std::shared_ptr<ngraph::Function> function, reference_function;
|
||||
};
|
||||
|
||||
class SwapInputMatmulNotApplied : public CommonTestUtils::TestsCommon,
|
||||
public ::testing::WithParamInterface<SwapInputMatmulParams> {
|
||||
public:
|
||||
void SetUp() override {
|
||||
std::vector<ngraph::Shape> shapes;
|
||||
bool withBias, withWeightsFq, withOutFq;
|
||||
std::tie(shapes, withBias, withWeightsFq, withOutFq) = this->GetParam();
|
||||
|
||||
function = CreateMatMulFunction(shapes[0], shapes[1], shapes[2], withBias, withWeightsFq, withOutFq, false);
|
||||
reference_function = ngraph::clone_function(*function);
|
||||
}
|
||||
public:
|
||||
std::shared_ptr<ngraph::Function> function, reference_function;
|
||||
};
|
||||
|
||||
// The transformed function must match the hand-built swapped-inputs reference.
TEST_P(SwapInputMatmul, CompareFunctions) {
    Execute(function, reference_function);
}

// The passes must be a no-op: the function must still match its pre-transformation clone.
TEST_P(SwapInputMatmulNotApplied, CompareFunctions) {
    Execute(function, reference_function);
}
|
||||
|
||||
const std::vector<std::vector<ngraph::Shape>> input_shapes_applied = {
|
||||
{{16, 8}, {8, 8}, {16, 8}},
|
||||
{{16, 8}, {8, 8}, {1}},
|
||||
};
|
||||
|
||||
const std::vector<std::vector<ngraph::Shape>> input_shapes_not_applied = {
|
||||
{{1, 8}, {8, 8}, {1, 8}},
|
||||
{{8}, {8, 8}, {8}}
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_swap_input_matmul, SwapInputMatmul,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(input_shapes_applied),
|
||||
::testing::ValuesIn(std::vector<bool>{false, true}),
|
||||
::testing::ValuesIn(std::vector<bool>{false, true}),
|
||||
::testing::ValuesIn(std::vector<bool>{false, true})),
|
||||
getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_swap_input_matmul, SwapInputMatmulNotApplied,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(input_shapes_not_applied),
|
||||
::testing::ValuesIn(std::vector<bool>{false, true}),
|
||||
::testing::ValuesIn(std::vector<bool>{false, true}),
|
||||
::testing::ValuesIn(std::vector<bool>{false, true})),
|
||||
getTestCaseName);
|
||||
|
||||
} // namespace testing
|
||||
|
Loading…
Reference in New Issue
Block a user