[GNA] Expanding transformations: swap_input_matmul and handle_transposes_around_matmul (#7333)

* Expanding transformations: swap_input_matmul and handle_transposes_around_matmul

* insert_reshape_around_matmul

* fixed failures of smoke tests
This commit is contained in:
Dmitrii Khurtin 2021-09-14 13:39:33 +03:00 committed by GitHub
parent 39120a7f62
commit ba34a1989c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 934 additions and 215 deletions

View File

@ -66,6 +66,7 @@
#include "transformations/handle_transposes_around_matmul.hpp"
#include "transformations/decompose_2d_conv.hpp"
#include "transformations/convert_padded2valid_conv.hpp"
#include "transformations/insert_reshape_around_matmul.hpp"
#include "transformations/op_conversions/lstm_cell_decomposition.hpp"
#include "transformations/remove_single_input_concat.hpp"
@ -730,10 +731,14 @@ void GNAPlugin::LoadNetwork(CNNNetwork & _network) {
manager.register_pass<SplitConvolutionWithFq>();
manager.register_pass<SplitConvolutionWithBias>();
manager.register_pass<SplitConvolution>();
manager.register_pass<HandleTransposesAroundMatMul>();
manager.register_pass<InsertReshapeAroundMatmulWithTranspose>();
manager.register_pass<InsertReshapeAroundMatmulWithFq>();
manager.register_pass<InsertReshapeAroundMatmulWithAdd>();
manager.register_pass<InsertReshapeAroundMatmul>();
manager.register_pass<SwapInputMatMulWithFq>();
manager.register_pass<SwapInputMatMulWithBias>();
manager.register_pass<SwapInputMatMul>();
manager.register_pass<HandleTransposesAroundMatMul>();
manager.register_pass<InsertTransposeAfterConvOrPool>();
manager.register_pass<ReorderActivationAndPooling>();
manager.register_pass<RemoveSingleInputConcat>();

View File

@ -6,31 +6,33 @@
#include <numeric>
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/pattern/op/or.hpp>
#include <openvino/cc/ngraph/itt.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pattern/op/or.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ie/ie_common.h>
#include "gna_plugin_log.hpp"
#include "backend/gna_limitations.hpp"
using namespace GNAPluginNS;
namespace GNAPluginNS {
NGRAPH_RTTI_DEFINITION(HandleTransposesAroundMatMul, "HandleTransposesAroundMatMul", 0);
NGRAPH_RTTI_DEFINITION(HandleTransposeBeforeMatMul, "HandleTransposeBeforeMatMul", 0);
NGRAPH_RTTI_DEFINITION(HandleTransposeAfterMatMul, "HandleTransposeAfterMatMul", 0);
static void ReplaceTransposeWithReshape(std::shared_ptr<ngraph::Node> transpose_node) {
void ReplaceTransposeWithReshape(std::shared_ptr<ngraph::Node> transpose_node) {
auto shape = transpose_node->get_output_shape(0);
auto reshape_const = std::make_shared<ngraph::opset7::Constant>(ngraph::element::Type_t::i64,
auto reshape_const = std::make_shared<ngraph::opset8::Constant>(ngraph::element::Type_t::i64,
ngraph::Shape{shape.size()}, shape);
auto reshape_node = std::make_shared<ngraph::opset7::Reshape>(transpose_node->input_value(0), reshape_const, false);
reshape_node->set_friendly_name(transpose_node->get_friendly_name() + "/reshape");
auto reshape_node = std::make_shared<ngraph::opset8::Reshape>(transpose_node->input_value(0), reshape_const, false);
reshape_node->set_friendly_name(transpose_node->get_friendly_name());
ngraph::copy_runtime_info(transpose_node, reshape_node);
transpose_node->output(0).replace(reshape_node->output(0));
}
static void InsertTranspose(std::shared_ptr<ngraph::Node> prev_node, const std::string& base_name) {
void InsertTranspose(std::shared_ptr<ngraph::Node> prev_node, const std::string& base_name) {
auto consumers = prev_node->output(0).get_target_inputs();
const auto orig_shape = prev_node->get_output_shape(0);
std::vector<size_t> transpose_ids;
@ -44,13 +46,13 @@ static void InsertTranspose(std::shared_ptr<ngraph::Node> prev_node, const std::
std::iota(std::begin(permute_order), std::end(permute_order), 0);
std::swap(permute_order[transpose_ids[0]], permute_order[transpose_ids[1]]);
auto transpose_order = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{permute_order.size()}, permute_order);
auto transpose = std::make_shared<ngraph::opset7::Transpose>(prev_node, transpose_order);
auto transpose_order = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{permute_order.size()}, permute_order);
auto transpose = std::make_shared<ngraph::opset8::Transpose>(prev_node, transpose_order);
transpose->set_friendly_name(base_name + "/in_transpose");
auto reshapeConstAfter = std::make_shared<ngraph::opset7::Constant>(ngraph::element::Type_t::i64,
auto reshapeConstAfter = std::make_shared<ngraph::opset8::Constant>(ngraph::element::Type_t::i64,
ngraph::Shape{orig_shape.size()}, orig_shape);
auto reshapeAfter = std::make_shared<ngraph::opset7::Reshape>(transpose, reshapeConstAfter, false);
auto reshapeAfter = std::make_shared<ngraph::opset8::Reshape>(transpose, reshapeConstAfter, false);
reshapeAfter->set_friendly_name(base_name + "/reshape_after_transpose");
ngraph::copy_runtime_info(prev_node, ngraph::NodeVector{transpose, reshapeAfter});
@ -59,74 +61,102 @@ static void InsertTranspose(std::shared_ptr<ngraph::Node> prev_node, const std::
}
}
static bool VerifyReshape(const ngraph::Output<ngraph::Node>& reshape_out) {
auto in_shape = reshape_out.get_node_shared_ptr()->get_input_shape(0);
auto out_shape = reshape_out.get_node_shared_ptr()->get_output_shape(0);
return in_shape[0] != out_shape[0];
}
HandleTransposeBeforeMatMul::HandleTransposeBeforeMatMul() {
auto reshape = ngraph::pattern::wrap_type<ngraph::opset7::Reshape>({ngraph::pattern::any_input(),
ngraph::pattern::any_input()}, VerifyReshape());
auto transpose = ngraph::pattern::wrap_type<ngraph::opset7::Transpose>({reshape,
auto constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
auto fq = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({constant, ngraph::pattern::any_input(),
ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()});
auto reshape = ngraph::pattern::wrap_type<ngraph::opset8::Reshape>({}, VerifyReshape);
auto transpose = ngraph::pattern::wrap_type<ngraph::opset8::Transpose>({reshape,
ngraph::pattern::any_input()});
auto matmul_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{reshape, transpose});
auto matmul1 = ngraph::pattern::wrap_type<ngraph::opset7::MatMul>({matmul_input, ngraph::pattern::any_input()});
auto matmul2 = ngraph::pattern::wrap_type<ngraph::opset7::MatMul>({ngraph::pattern::any_input(), matmul_input});
auto matmul1 = ngraph::pattern::wrap_type<ngraph::opset8::MatMul>({
std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{reshape, transpose}),
ngraph::pattern::any_input()});
auto matmul2 = ngraph::pattern::wrap_type<ngraph::opset8::MatMul>({
std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{constant, fq}),
std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{reshape, transpose, ngraph::pattern::any_input()})});
auto matmul = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{matmul1, matmul2});
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
const auto& pattern_map = m.get_pattern_value_map();
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &matcher) {
const auto& pattern_map = matcher.get_pattern_value_map();
auto matmul_iter = pattern_map.find(matmul1);
if (matmul_iter == std::end(pattern_map) &&
(matmul_iter = pattern_map.find(matmul2)) == std::end(pattern_map)) {
return false;
}
auto transpose_reshape_it = pattern_map.find(transpose);
if (transpose_reshape_it != std::end(pattern_map)) {
ReplaceTransposeWithReshape(transpose_reshape_it->second.get_node_shared_ptr());
} else if ((transpose_reshape_it = pattern_map.find(reshape)) != std::end(pattern_map)) {
auto reshape_node = pattern_map.at(reshape).get_node_shared_ptr();
if (GNALimitations::IsTransposeSupported(reshape_node->get_output_shape(0))) {
auto matmul_node = matmul_iter->second.get_node_shared_ptr();
InsertTranspose(reshape_node, matmul_node->get_friendly_name());
}
}
auto iter = pattern_map.find(fq);
if (iter != pattern_map.end() ||
(iter = pattern_map.find(constant)) != pattern_map.end()) {
auto prev_node = iter->second.get_node_shared_ptr();
if (!GNALimitations::IsTransposeSupported(prev_node->get_output_shape(0))) return false;
auto matmul_node = iter->second.get_node_shared_ptr();
InsertTranspose(prev_node, matmul_node->get_friendly_name());
}
return true;
};
auto matcher = std::make_shared<ngraph::pattern::Matcher>(matmul, "HandleTransposeBeforeMatMul");
this->register_matcher(matcher, callback);
}
HandleTransposeAfterMatMul::HandleTransposeAfterMatMul() {
auto matmul = ngraph::pattern::wrap_type<ngraph::opset8::MatMul>();
auto add_left = ngraph::pattern::wrap_type<ngraph::opset8::Add>({matmul, ngraph::pattern::any_input()});
auto add_right = ngraph::pattern::wrap_type<ngraph::opset8::Add>({ngraph::pattern::any_input(), matmul});
auto fq_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{matmul, add_left, add_right});
auto fq = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({fq_input, ngraph::pattern::any_input(),
ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()});
auto transpose_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{fq_input, fq});
auto transpose = ngraph::pattern::wrap_type<ngraph::opset8::Transpose>({transpose_input, ngraph::pattern::any_input()});
auto reshape_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{transpose_input, transpose});
auto reshape = ngraph::pattern::wrap_type<ngraph::opset8::Reshape>(
{reshape_input, ngraph::pattern::any_input()}, VerifyReshape);
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &matcher) {
const auto& pattern_map = matcher.get_pattern_value_map();
auto transpose_it = pattern_map.find(transpose);
if (transpose_it != std::end(pattern_map)) {
ReplaceTransposeWithReshape(transpose_it->second.get_node_shared_ptr());
} else {
auto reshape_node = pattern_map.at(reshape).get_node_shared_ptr();
if (!GNALimitations::IsTransposeSupported(reshape_node->get_output_shape(0))) return false;
auto matmul_it = pattern_map.find(matmul1);
auto matmul_out = matmul_it != std::end(pattern_map) ? matmul_it->second : pattern_map.at(matmul2);
InsertTranspose(reshape_node, matmul_out.get_node_shared_ptr()->get_friendly_name());
auto iter = pattern_map.find(fq);
if (iter == pattern_map.end() &&
(iter = pattern_map.find(add_left)) == pattern_map.end() &&
(iter = pattern_map.find(add_right)) == pattern_map.end() &&
(iter = pattern_map.find(matmul)) == pattern_map.end()) {
return false;
}
auto node = iter->second.get_node_shared_ptr();
InsertTranspose(node, node->get_friendly_name());
}
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(matmul, "HandleTransposeBeforeMatMul");
this->register_matcher(m, callback);
}
HandleTransposeAfterMatMul::HandleTransposeAfterMatMul() {
auto matmul = ngraph::pattern::wrap_type<ngraph::opset7::MatMul>();
auto fq = ngraph::pattern::wrap_type<ngraph::opset7::FakeQuantize>({matmul, ngraph::pattern::any_input(),
ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()});
auto transpose_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{matmul, fq});
auto transpose = ngraph::pattern::wrap_type<ngraph::opset7::Transpose>({transpose_input, ngraph::pattern::any_input()});
auto reshape_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{transpose_input, transpose});
auto reshape = ngraph::pattern::wrap_type<ngraph::opset7::Reshape>({reshape_input,
ngraph::pattern::any_input()}, VerifyReshape());
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
const auto& pattern_map = m.get_pattern_value_map();
auto transpose_it = pattern_map.find(transpose);
if (transpose_it != std::end(pattern_map)) {
ReplaceTransposeWithReshape(transpose_it->second.get_node_shared_ptr());
} else {
auto reshape_node = pattern_map.at(reshape).get_node_shared_ptr();
if (!GNALimitations::IsTransposeSupported(reshape_node->get_input_shape(0))) return false;
auto matmul_node = pattern_map.at(matmul).get_node_shared_ptr();
InsertTranspose(matmul_node, matmul_node->get_friendly_name());
}
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(reshape, "HandleTransposeAfterMatMul");
this->register_matcher(m, callback);
}
bool VerifyReshape::operator()(const ngraph::Output<ngraph::Node>& reshape_out) const {
auto in_shape = reshape_out.get_node_shared_ptr()->get_input_shape(0);
auto out_shape = reshape_out.get_node_shared_ptr()->get_output_shape(0);
// Check if Reshape changes the final 2d shape of Affine primitive
in_shape.erase(std::remove(in_shape.begin(), in_shape.end(), 1), in_shape.end());
out_shape.erase(std::remove(out_shape.begin(), out_shape.end(), 1), out_shape.end());
return in_shape != out_shape;
auto matcher = std::make_shared<ngraph::pattern::Matcher>(reshape, "HandleTransposeAfterMatMul");
this->register_matcher(matcher, callback);
}
HandleTransposesAroundMatMul::HandleTransposesAroundMatMul() {
add_matcher<HandleTransposeBeforeMatMul>();
add_matcher<HandleTransposeAfterMatMul>();
}
} // namespace GNAPluginNS

View File

@ -8,10 +8,6 @@
namespace GNAPluginNS {
struct VerifyReshape {
bool operator()(const ngraph::Output<ngraph::Node>& reshape_out) const;
};
/**
* @brief Inserts Transpose before MatMul or removes it (if it exists) if there is Reshape
* before MatMul which changes the batch size:
@ -48,13 +44,13 @@ public:
* | |
* [1, A*B] [1, A*B]
*/
class HandleTransposeAfterMatMul : public ngraph::pass::MatcherPass {
class HandleTransposeAfterMatMul: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
HandleTransposeAfterMatMul();
};
class HandleTransposesAroundMatMul: public ngraph::pass::GraphRewrite {
class HandleTransposesAroundMatMul : public ngraph::pass::GraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;
HandleTransposesAroundMatMul();

View File

@ -0,0 +1,237 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/insert_reshape_around_matmul.hpp"
#include <openvino/cc/ngraph/itt.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pattern/op/or.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ie/ie_common.h>
#include "gna_plugin_log.hpp"
namespace GNAPluginNS {
NGRAPH_RTTI_DEFINITION(InsertReshapeAroundMatmul, "InsertReshapeAroundMatmul", 0);
NGRAPH_RTTI_DEFINITION(InsertReshapeAroundMatmulWithAdd, "InsertReshapeAroundMatmulWithAdd", 0);
NGRAPH_RTTI_DEFINITION(InsertReshapeAroundMatmulWithFq, "InsertReshapeAroundMatmulWithFq", 0);
NGRAPH_RTTI_DEFINITION(InsertReshapeAroundMatmulWithTranspose, "InsertReshapeAroundMatmulWithTranspose", 0);
// Inserts a 2d Reshape before the matched MatMul (on its non-constant input)
// and, when needed, a rank-restoring Reshape after the last node of the
// matched subgraph (MatMul / Add / FakeQuantize / Transpose), so the affine
// primitive operates on a 2d shape.
// Returns true iff at least one reshape was inserted.
// Fixes vs. previous revision: removed the duplicated pattern_map.find(matmul2)
// lookup and the unused local matmul_node_shape.
static bool InsertReshape(
    ngraph::pattern::Matcher &matcher,
    const std::shared_ptr<ngraph::Node>& input,
    const std::shared_ptr<ngraph::Node>& matmul1,
    const std::shared_ptr<ngraph::Node>& matmul2,
    const std::shared_ptr<ngraph::Node>& add1 = nullptr,
    const std::shared_ptr<ngraph::Node>& add2 = nullptr,
    const std::shared_ptr<ngraph::Node>& fake_quantize2 = nullptr,
    const std::shared_ptr<ngraph::Node>& transpose = nullptr) {
    const auto& pattern_map = matcher.get_pattern_value_map();
    // matmul1 has the data input second, matmul2 has it first.
    size_t matmul_input_index = 1;
    auto iter = pattern_map.find(matmul1);
    if (iter == pattern_map.end()) {
        if ((iter = pattern_map.find(matmul2)) == pattern_map.end()) {
            return false;
        }
        matmul_input_index = 0;
    }
    std::shared_ptr<ngraph::Node> matmul_node = iter->second.get_node_shared_ptr();
    if ((iter = pattern_map.find(input)) == std::end(pattern_map)) {
        return false;
    }
    std::shared_ptr<ngraph::Node> first_node = iter->second.get_node_shared_ptr();
    // No reshape is needed before MatMul if the input is already a 2d Reshape.
    auto reshape_input_node = std::dynamic_pointer_cast<ngraph::opset8::Reshape>(first_node);
    bool need_reshape_before = !reshape_input_node || reshape_input_node->get_output_shape(0).size() != 2;
    if (need_reshape_before) {
        auto input_shape = first_node->get_output_shape(0);
        // Squeeze to 2d by keeping the (at most two, per the pattern predicate)
        // dimensions greater than 1 and padding with 1s.
        std::vector<size_t> before_shape(2, 1);
        std::copy_if(input_shape.begin(), input_shape.end(), before_shape.begin(), [](size_t e) { return e > 1; });
        auto reshape_before_node = std::make_shared<ngraph::opset8::Reshape>(first_node,
            std::make_shared<ngraph::opset8::Constant>(ngraph::element::Type_t::i64,
                ngraph::Shape{before_shape.size()}, before_shape), false);
        reshape_before_node->set_friendly_name(matmul_node->get_friendly_name() + "/reshape_before_matmul");
        ngraph::copy_runtime_info(first_node, reshape_before_node);
        matmul_node->input(matmul_input_index).replace_source_output(reshape_before_node->output(0));
    }
    // Locate the last node of the matched chain, in priority order:
    // Transpose > FakeQuantize > Add > MatMul.
    std::shared_ptr<ngraph::Node> last_node;
    iter = pattern_map.find(transpose);
    if (iter == pattern_map.end() &&
        (iter = pattern_map.find(fake_quantize2)) == pattern_map.end() &&
        (iter = pattern_map.find(add1)) == pattern_map.end() &&
        (iter = pattern_map.find(add2)) == pattern_map.end()) {
        last_node = matmul_node;
    } else {
        last_node = iter->second.get_node_shared_ptr();
    }
    auto consumers = last_node->output(0).get_target_inputs();
    auto last_node_shape = last_node->get_output_shape(0);
    // A reshape after the chain is needed unless every consumer is already a
    // Reshape of the same rank.
    bool need_reshape_after = false;
    for (auto consumer : consumers) {
        auto reshape_output_node = dynamic_cast<ngraph::opset8::Reshape*>(consumer.get_node());
        if (!reshape_output_node || reshape_output_node->get_output_shape(0).size() != last_node_shape.size()) {
            need_reshape_after = true;
            break;
        }
    }
    if (need_reshape_after) {
        auto reshape_after_node = std::make_shared<ngraph::opset8::Reshape>(last_node,
            std::make_shared<ngraph::opset8::Constant>(ngraph::element::Type_t::i64,
                ngraph::Shape{last_node_shape.size()}, last_node_shape), false);
        // Keep the original node's name so external references stay valid.
        reshape_after_node->set_friendly_name(last_node->get_friendly_name());
        ngraph::copy_runtime_info(last_node, reshape_after_node);
        for (auto consumer : consumers) {
            consumer.replace_source_output(reshape_after_node);
        }
    }
    return need_reshape_before || need_reshape_after;
}
// Builds the MatMul pattern shared by all InsertReshapeAroundMatmul* passes:
// a MatMul whose weight input is a Constant (optionally wrapped in a
// FakeQuantize) and whose data input is higher than 2d but has at most two
// non-trivial dimensions, so it can be squeezed to 2d.
// Out-params matmul1/matmul2 receive the two operand orderings; input receives
// the data-input pattern handle.
static std::shared_ptr<ngraph::Node> CreateMatmulPattern(
    std::shared_ptr<ngraph::Node>& input,
    std::shared_ptr<ngraph::Node>& matmul1,
    std::shared_ptr<ngraph::Node>& matmul2,
    const ngraph::pattern::op::ValuePredicate& pred = [](const ngraph::Output<ngraph::Node>& output) { return true; }) {
    // Weights: a constant, possibly quantized through a FakeQuantize whose
    // range inputs are themselves constants.
    auto weights = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
    auto quantized_weights = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({weights,
        ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
        ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
        ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
        ngraph::pattern::wrap_type<ngraph::opset8::Constant>()});
    auto weights_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{weights, quantized_weights});
    // Data input: rank > 2 with no more than two dimensions greater than 1.
    input = ngraph::pattern::any_input([](const ngraph::Output<ngraph::Node>& node) {
        auto shape = node.get_node_shared_ptr()->get_output_shape(0);
        if (shape.size() <= 2) {
            return false;
        }
        size_t non_trivial_dims = std::count_if(shape.begin(), shape.end(), [](size_t dim) { return dim > 1; });
        return non_trivial_dims <= 2;
    });
    matmul1 = ngraph::pattern::wrap_type<ngraph::opset8::MatMul>({weights_input, input}, pred);
    matmul2 = ngraph::pattern::wrap_type<ngraph::opset8::MatMul>({input, weights_input}, pred);
    return std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{matmul1, matmul2});
}
// Bare-MatMul variant: matches a MatMul that is NOT followed by a Transpose,
// FakeQuantize or Add (those chains are handled by the more specific passes).
InsertReshapeAroundMatmul::InsertReshapeAroundMatmul() {
    MATCHER_SCOPE(InsertReshapeAroundMatmul);
    // Accept the MatMul only when its single consumer is none of the nodes
    // matched by the sibling passes (or when it has no consumers at all).
    auto pred = [](const ngraph::Output<ngraph::Node>& out) {
        const auto& node_outputs = out.get_node_shared_ptr()->outputs();
        const auto& consumers = node_outputs[0].get_target_inputs();
        if (consumers.empty()) {
            return true;
        }
        if (node_outputs.size() != 1) {
            return true;
        }
        auto* next = consumers.begin()->get_node();
        return !(dynamic_cast<ngraph::opset8::Transpose*>(next) ||
                 dynamic_cast<ngraph::opset8::FakeQuantize*>(next) ||
                 dynamic_cast<ngraph::opset8::Add*>(next));
    };
    std::shared_ptr<ngraph::Node> input;
    std::shared_ptr<ngraph::Node> matmul1;
    std::shared_ptr<ngraph::Node> matmul2;
    auto matmul = CreateMatmulPattern(input, matmul1, matmul2, pred);
    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& matcher) {
        return InsertReshape(matcher, input, matmul1, matmul2);
    };
    this->register_matcher(std::make_shared<ngraph::pattern::Matcher>(matmul, "InsertReshapeAroundMatmul"), callback);
}
// MatMul + Add variant: matches a MatMul followed by an Add (bias on either
// side) whose result is NOT further consumed by a Transpose or FakeQuantize
// (those chains have their own passes).
InsertReshapeAroundMatmulWithAdd::InsertReshapeAroundMatmulWithAdd() {
    MATCHER_SCOPE(InsertReshapeAroundMatmulWithAdd);
    // Accept the Add only when its single consumer is neither a Transpose
    // nor a FakeQuantize (or when it has no consumers at all).
    auto pred = [](const ngraph::Output<ngraph::Node>& out) {
        const auto& node_outputs = out.get_node_shared_ptr()->outputs();
        const auto& consumers = node_outputs[0].get_target_inputs();
        if (consumers.empty()) {
            return true;
        }
        if (node_outputs.size() != 1) {
            return true;
        }
        auto* next = consumers.begin()->get_node();
        return !(dynamic_cast<ngraph::opset8::Transpose*>(next) ||
                 dynamic_cast<ngraph::opset8::FakeQuantize*>(next));
    };
    std::shared_ptr<ngraph::Node> input;
    std::shared_ptr<ngraph::Node> matmul1;
    std::shared_ptr<ngraph::Node> matmul2;
    auto matmul = CreateMatmulPattern(input, matmul1, matmul2);
    // The bias may appear on either input of the Add.
    auto bias = ngraph::pattern::any_input();
    auto add1 = ngraph::pattern::wrap_type<ngraph::opset8::Add>({matmul, bias}, pred);
    auto add2 = ngraph::pattern::wrap_type<ngraph::opset8::Add>({bias, matmul}, pred);
    auto add = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{add1, add2});
    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& matcher) {
        return InsertReshape(matcher, input, matmul1, matmul2, add1, add2);
    };
    this->register_matcher(std::make_shared<ngraph::pattern::Matcher>(add, "InsertReshapeAroundMatmulWithAdd"), callback);
}
// MatMul [+ Add] + FakeQuantize variant: matches a MatMul (optionally followed
// by a bias Add) feeding a FakeQuantize, and inserts 2d reshapes around the
// whole chain. The FakeQuantize must not itself feed a Transpose — that case
// is handled by InsertReshapeAroundMatmulWithTranspose.
InsertReshapeAroundMatmulWithFq::InsertReshapeAroundMatmulWithFq() {
MATCHER_SCOPE(InsertReshapeAroundMatmulWithFq);
// Pattern handles filled in by CreateMatmulPattern (two operand orderings).
std::shared_ptr<ngraph::Node> input;
std::shared_ptr<ngraph::Node> matmul1;
std::shared_ptr<ngraph::Node> matmul2;
auto matmul = CreateMatmulPattern(input, matmul1, matmul2);
// Optional bias Add; the bias may appear on either input.
auto add_input = ngraph::pattern::any_input();
auto add1 = ngraph::pattern::wrap_type<ngraph::opset8::Add>({matmul, add_input});
auto add2 = ngraph::pattern::wrap_type<ngraph::opset8::Add>({add_input, matmul});
auto fq_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{matmul, add1, add2});
auto fake_quantize2 = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({fq_input, ngraph::pattern::any_input(),
ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()},
[](const ngraph::Output<ngraph::Node>& node) {
const auto& outputs = node.get_node_shared_ptr()->outputs();
const auto& inputs = outputs[0].get_target_inputs();
if (inputs.empty()) {
return true;
}
auto next_node = inputs.begin()->get_node();
// Reject when the single consumer is a Transpose: the Transpose variant of
// this pass must place the trailing reshape after the Transpose instead.
return outputs.size() != 1 ||
!dynamic_cast<ngraph::opset8::Transpose*>(next_node);
});
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &matcher) {
return InsertReshape(matcher, input, matmul1, matmul2, add1, add2, fake_quantize2);
};
auto matcher = std::make_shared<ngraph::pattern::Matcher>(fake_quantize2, "InsertReshapeAroundMatmulWithFq");
this->register_matcher(matcher, callback);
}
// Full-chain variant: matches MatMul [+ Add] [+ FakeQuantize] + Transpose and
// inserts the 2d reshapes around the whole chain, placing the trailing
// reshape after the Transpose.
InsertReshapeAroundMatmulWithTranspose::InsertReshapeAroundMatmulWithTranspose() {
    MATCHER_SCOPE(InsertReshapeAroundMatmulWithTranspose);
    std::shared_ptr<ngraph::Node> input;
    std::shared_ptr<ngraph::Node> matmul1;
    std::shared_ptr<ngraph::Node> matmul2;
    auto matmul = CreateMatmulPattern(input, matmul1, matmul2);
    // Optional bias Add on either input.
    auto bias = ngraph::pattern::any_input();
    auto bias_add1 = ngraph::pattern::wrap_type<ngraph::opset8::Add>({matmul, bias});
    auto bias_add2 = ngraph::pattern::wrap_type<ngraph::opset8::Add>({bias, matmul});
    auto fq_src = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{matmul, bias_add1, bias_add2});
    // Optional FakeQuantize between the (biased) MatMul and the Transpose.
    auto fq = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({fq_src, ngraph::pattern::any_input(),
        ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()});
    auto transpose_src = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{fq_src, fq});
    auto transpose = ngraph::pattern::wrap_type<ngraph::opset8::Transpose>({transpose_src, ngraph::pattern::any_input()});
    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& matcher) {
        return InsertReshape(matcher, input, matmul1, matmul2, bias_add1, bias_add2, fq, transpose);
    };
    this->register_matcher(std::make_shared<ngraph::pattern::Matcher>(transpose, "InsertReshapeAroundMatmulWithTranspose"), callback);
}
} // namespace GNAPluginNS

View File

@ -0,0 +1,39 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#ifndef INSERT_RESHAPE_AROUND_MATMUL_HPP
#define INSERT_RESHAPE_AROUND_MATMUL_HPP
#include <ngraph/pass/graph_rewrite.hpp>
namespace GNAPluginNS {
// @brief Insert Reshapes from 3d/4d to 2d before MatMul and from 2d to 3d/4d after MatMul
// Handles a bare MatMul (no trailing Transpose/FakeQuantize/Add).
class InsertReshapeAroundMatmul : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
InsertReshapeAroundMatmul();
};
// @brief Same as InsertReshapeAroundMatmul for a MatMul followed by a bias Add;
// the trailing reshape is placed after the Add.
class InsertReshapeAroundMatmulWithAdd : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
InsertReshapeAroundMatmulWithAdd();
};
// @brief Same as above for MatMul [+ Add] followed by a FakeQuantize;
// the trailing reshape is placed after the FakeQuantize.
class InsertReshapeAroundMatmulWithFq : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
InsertReshapeAroundMatmulWithFq();
};
// @brief Same as above for MatMul [+ Add] [+ FakeQuantize] followed by a
// Transpose; the trailing reshape is placed after the Transpose.
class InsertReshapeAroundMatmulWithTranspose : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
InsertReshapeAroundMatmulWithTranspose();
};
} // namespace GNAPluginNS
#endif // INSERT_RESHAPE_AROUND_MATMUL_HPP

View File

@ -2,31 +2,34 @@
// SPDX-License-Identifier: Apache-2.0
//
#include <transformations/swap_input_matmul_gna.hpp>
#include <openvino/cc/ngraph/itt.hpp>
#include <memory>
#include <vector>
#include <ngraph/pass/manager.hpp>
#include <ngraph/pattern/op/or.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <numeric>
#include <transformations/swap_input_matmul_gna.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pattern/op/or.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ie/ie_common.h>
#include "gna_plugin_log.hpp"
using namespace GNAPluginNS;
namespace GNAPluginNS {
NGRAPH_RTTI_DEFINITION(SwapInputMatMul, "SwapInputMatMul", 0);
NGRAPH_RTTI_DEFINITION(SwapInputMatMulWithBias, "SwapInputMatMulWithBias", 0);
NGRAPH_RTTI_DEFINITION(SwapInputMatMulWithFq, "SwapInputMatMulWithFq", 0);
static void SwapAndTransposeInputs(std::shared_ptr<ngraph::opset8::MatMul> matmul_node,
static void SwapAndTransposeInputs(
std::shared_ptr<ngraph::opset8::MatMul> matmul_node,
std::shared_ptr<ngraph::Node> add,
std::shared_ptr<ngraph::Node> bias,
std::shared_ptr<ngraph::Node> fq) {
std::shared_ptr<ngraph::Node> fq,
const std::string& last_layer_name) {
auto create_transpose =
[](ngraph::Output<ngraph::Node> node, const std::string& transpose_name) -> std::shared_ptr<ngraph::Node> {
ngraph::Shape output_shape = node.get_node_shared_ptr()->get_shape();
@ -56,6 +59,19 @@ static void SwapAndTransposeInputs(std::shared_ptr<ngraph::opset8::MatMul> matmu
if (bias->get_output_shape(0).size() > 1) {
bias = create_transpose(bias, bias->get_friendly_name() + "/transpose");
new_ops.push_back(bias);
auto transpose_shape = bias->get_output_shape(0);
auto matmul_shape = matmul_node->get_output_shape(0);
if (transpose_shape.size() > matmul_shape.size()) {
std::vector<size_t> reshape_shape(matmul_shape.size(), 1);
std::copy_if(transpose_shape.begin(), transpose_shape.end(), reshape_shape.begin(), [](size_t e) { return e > 1; });
bias = std::make_shared<ngraph::opset8::Reshape>(bias,
std::make_shared<ngraph::opset8::Constant>(ngraph::element::Type_t::i64,
ngraph::Shape{reshape_shape.size()}, reshape_shape), false);
bias->set_friendly_name(add->get_friendly_name() + "/reshape");
ngraph::copy_runtime_info(add, bias);
new_ops.push_back(bias);
}
}
new_matmul = std::make_shared<ngraph::opset8::Add>(new_matmul, bias);
@ -70,113 +86,151 @@ static void SwapAndTransposeInputs(std::shared_ptr<ngraph::opset8::MatMul> matmu
new_ops.push_back(new_matmul);
}
auto output = create_transpose(new_matmul, matmul_node->get_friendly_name());
auto output = create_transpose(new_matmul, last_layer_name);
new_ops.push_back(output);
ngraph::copy_runtime_info(matmul_node, new_ops);
ngraph::replace_node(old_root_node, output);
}
SwapInputMatMul::SwapInputMatMul() {
MATCHER_SCOPE(SwapInputMatMul);
auto constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>({}, [](const ngraph::Output<ngraph::Node>& node) {
auto shape = node.get_node_shared_ptr()->get_output_shape(0);
if (shape.size() != 2 || shape[0] < 8 || ((shape[0] % 8 != 0 || shape[1] % 8 != 0))) {
return false;
}
return true;
});
static std::shared_ptr<ngraph::Node> CreateMatmul(
bool is_first_constant,
ngraph::pattern::op::ValuePredicate const_predicate,
ngraph::pattern::op::ValuePredicate matmul_predicate = ngraph::pattern::has_static_shape()) {
auto constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>({}, const_predicate);
auto fake_quantize = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({constant,
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>()});
auto matmul_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{constant, fake_quantize});
auto matmul = ngraph::pattern::wrap_type<ngraph::opset8::MatMul>({matmul_input, ngraph::pattern::any_input()},
ngraph::pattern::has_static_shape());
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
auto matmul_node = std::dynamic_pointer_cast<ngraph::opset8::MatMul>(pattern_map.at(matmul).get_node_shared_ptr());
if (is_first_constant) {
return ngraph::pattern::wrap_type<ngraph::opset8::MatMul>(
{matmul_input, ngraph::pattern::any_input()}, matmul_predicate);
}
return ngraph::pattern::wrap_type<ngraph::opset8::MatMul>(
{ngraph::pattern::any_input(), matmul_input}, matmul_predicate);
}
static std::shared_ptr<ngraph::Node> CreateMatmuls(
std::shared_ptr<ngraph::Node>& matmul1,
std::shared_ptr<ngraph::Node>& matmul2) {
matmul1 = CreateMatmul(
true,
[](const ngraph::Output<ngraph::Node>& node) { return true; },
[](const ngraph::Output<ngraph::Node>& node) {
auto matmul_node = std::dynamic_pointer_cast<ngraph::opset8::MatMul>(node.get_node_shared_ptr());
IE_ASSERT(matmul_node != nullptr);
SwapAndTransposeInputs(matmul_node, nullptr, nullptr, nullptr);
auto input_shape = matmul_node->get_input_shape(0);
return input_shape.size() == 2 &&
(!matmul_node->get_transpose_a() && input_shape[0] > 8 ||
matmul_node->get_transpose_a() && input_shape[1] > 8); });
matmul2 = CreateMatmul(
false,
[](const ngraph::Output<ngraph::Node>& node) { return true; },
[](const ngraph::Output<ngraph::Node>& node) {
auto matmul_node = std::dynamic_pointer_cast<ngraph::opset8::MatMul>(node.get_node_shared_ptr());
IE_ASSERT(matmul_node != nullptr);
auto first_input_shape = matmul_node->get_input_shape(0);
first_input_shape.erase(std::remove(first_input_shape.begin(), first_input_shape.end(), 1), first_input_shape.end());
auto second_input_shape = matmul_node->get_input_shape(1);
return node.get_partial_shape().is_static() &&
second_input_shape.size() == 2 &&
(!matmul_node->get_transpose_b() && second_input_shape[1] <= 8 ||
matmul_node->get_transpose_b() && second_input_shape[0] <= 8) &&
first_input_shape.size() == 2 &&
first_input_shape[0] > 8; });
return std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{matmul1, matmul2});
}
SwapInputMatMul::SwapInputMatMul() {
MATCHER_SCOPE(SwapInputMatMul);
std::shared_ptr<ngraph::Node> matmul1;
std::shared_ptr<ngraph::Node> matmul2;
auto matmul = CreateMatmuls(matmul1, matmul2);
auto callback = [=](ngraph::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
auto iter = pattern_map.find(matmul1);
if (iter == pattern_map.end() &&
(iter = pattern_map.find(matmul2)) == pattern_map.end()) {
return false;
}
auto matmul_node = std::dynamic_pointer_cast<ngraph::opset8::MatMul>(iter->second.get_node_shared_ptr());
IE_ASSERT(matmul_node != nullptr);
SwapAndTransposeInputs(matmul_node, nullptr, nullptr, nullptr, "");
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(matmul, matcher_name);
this->register_matcher(m, callback);
auto matcher = std::make_shared<ngraph::pattern::Matcher>(matmul, "SwapInputMatMul");
this->register_matcher(matcher, callback);
}
SwapInputMatMulWithBias::SwapInputMatMulWithBias() {
MATCHER_SCOPE(SwapInputMatMulWithBias);
auto constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>({}, [](const ngraph::Output<ngraph::Node>& node) {
auto shape = node.get_node_shared_ptr()->get_output_shape(0);
if (shape.size() != 2 || shape[0] < 8 || ((shape[0] % 8 != 0 || shape[1] % 8 != 0))) {
return false;
}
return true;
});
auto fake_quantize = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({constant,
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>()});
auto matmul_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{constant, fake_quantize});
auto matmul = ngraph::pattern::wrap_type<ngraph::opset8::MatMul>({matmul_input, ngraph::pattern::any_input()},
ngraph::pattern::has_static_shape());
std::shared_ptr<ngraph::Node> matmul1;
std::shared_ptr<ngraph::Node> matmul2;
auto matmul = CreateMatmuls(matmul1, matmul2);
auto bias = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
auto add = ngraph::pattern::wrap_type<ngraph::opset8::Add>({matmul, bias});
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
auto callback = [=](ngraph::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
auto matmul_node = std::dynamic_pointer_cast<ngraph::opset8::MatMul>(pattern_map.at(matmul).get_node_shared_ptr());
auto iter = pattern_map.find(matmul1);
if (iter == pattern_map.end() &&
(iter = pattern_map.find(matmul2)) == pattern_map.end()) {
return false;
}
auto matmul_node = std::dynamic_pointer_cast<ngraph::opset8::MatMul>(iter->second.get_node_shared_ptr());
IE_ASSERT(matmul_node != nullptr);
SwapAndTransposeInputs(matmul_node, pattern_map.at(add).get_node_shared_ptr(),
pattern_map.at(bias).get_node_shared_ptr(), nullptr);
SwapAndTransposeInputs(
matmul_node,
pattern_map.at(add).get_node_shared_ptr(),
pattern_map.at(bias).get_node_shared_ptr(),
nullptr,
pattern_map.at(add).get_node_shared_ptr()->get_friendly_name());
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(add, matcher_name);
this->register_matcher(m, callback);
auto matcher = std::make_shared<ngraph::pattern::Matcher>(add, "SwapInputMatMulWithBias");
this->register_matcher(matcher, callback);
}
SwapInputMatMulWithFq::SwapInputMatMulWithFq() {
MATCHER_SCOPE(SwapInputMatMulWithFq);
auto constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>({}, [](const ngraph::Output<ngraph::Node>& node) {
auto shape = node.get_node_shared_ptr()->get_output_shape(0);
if (shape.size() != 2 || shape[0] < 8 || ((shape[0] % 8 != 0 || shape[1] % 8 != 0))) {
return false;
}
return true;
});
auto fake_quantize = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({constant,
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>()});
auto matmul_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{constant, fake_quantize});
auto matmul = ngraph::pattern::wrap_type<ngraph::opset8::MatMul>({matmul_input, ngraph::pattern::any_input()},
ngraph::pattern::has_static_shape());
std::shared_ptr<ngraph::Node> matmul1;
std::shared_ptr<ngraph::Node> matmul2;
auto matmul = CreateMatmuls(matmul1, matmul2);
auto bias = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
auto add = ngraph::pattern::wrap_type<ngraph::opset8::Add>({matmul, bias});
auto matmul_out = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{add, matmul});
auto out_fq = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({matmul_out,
auto fq_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{add, matmul});
auto fq = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({fq_input,
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>()});
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
auto callback = [=](ngraph::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
auto matmul_node = std::dynamic_pointer_cast<ngraph::opset8::MatMul>(pattern_map.at(matmul).get_node_shared_ptr());
auto iter = pattern_map.find(matmul1);
if (iter == pattern_map.end() &&
(iter = pattern_map.find(matmul2)) == pattern_map.end()) {
return false;
}
auto iter_add = pattern_map.find(add);
auto iter_bias = pattern_map.find(bias);
auto matmul_node = std::dynamic_pointer_cast<ngraph::opset8::MatMul>(iter->second.get_node_shared_ptr());
IE_ASSERT(matmul_node != nullptr);
auto add_it = pattern_map.find(add);
auto add_node = (add_it == std::end(pattern_map) ? nullptr : add_it->second.get_node_shared_ptr());
auto bias_it = pattern_map.find(bias);
auto bias_node = (bias_it == std::end(pattern_map) ? nullptr : bias_it->second.get_node_shared_ptr());
SwapAndTransposeInputs(matmul_node, add_node, bias_node, pattern_map.at(out_fq).get_node_shared_ptr());
SwapAndTransposeInputs(
matmul_node,
iter_add != pattern_map.end() ? iter_add->second.get_node_shared_ptr() : nullptr,
iter_bias != pattern_map.end() ? iter_bias->second.get_node_shared_ptr() : nullptr,
pattern_map.at(fq).get_node_shared_ptr(),
pattern_map.at(fq).get_node_shared_ptr()->get_friendly_name());
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(out_fq, matcher_name);
this->register_matcher(m, callback);
auto matcher = std::make_shared<ngraph::pattern::Matcher>(fq, "SwapInputMatMulWithFq");
this->register_matcher(matcher, callback);
}
} // namespace GNAPluginNS

View File

@ -2,15 +2,15 @@
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#ifndef SWAP_INPUT_MATMUL_GNA_HPP
#define SWAP_INPUT_MATMUL_GNA_HPP
#include <memory>
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
namespace GNAPluginNS {
// @brief Swaps and transposes inputs of MatMul if its first input is const and its batch size isn't supported by GNA
// @brief Swaps and transposes inputs of MatMul if
// 1. its first input is const and its batch size isn't supported by GNA
// 2. its first input is non-const and its batch size isn't supported by GNA
class SwapInputMatMul: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
@ -29,3 +29,5 @@ public:
SwapInputMatMulWithFq();
};
} // namespace GNAPluginNS
#endif // SWAP_INPUT_MATMUL_GNA_HPP

View File

@ -99,7 +99,8 @@ const std::vector<std::vector<std::vector<size_t>>> input_shapes = {
{{1, 8}, {8, 1}},
{{128, 8}, {8, 1}},
{{8, 8}, {8, 8}},
{{1, 16}, {16, 8}}
{{1, 16}, {16, 8}},
{{6, 16}, {16, 8}}
};

View File

@ -0,0 +1,190 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include "transformations/insert_reshape_around_matmul.hpp"
#include "common_test_utils/ngraph_test_utils.hpp"
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/init_node_info.hpp>
#include <numeric>
// Helper that builds the test graphs
//   Parameter -> Relu -> [Reshape] -> MatMul [-> Add] [-> FakeQuantize] -> [Reshape] -> Relu -> Result
// for the InsertReshapeAroundMatmul* transformations.
// Template flags:
//   ADD                          - append Add(result, constant) after the MatMul
//   ADD_FIRST_INPUT_NOT_CONSTANT - the MatMul result is the FIRST Add input (constant is second)
//   FQ                           - append a FakeQuantize after the (optional) Add
template<bool ADD, bool ADD_FIRST_INPUT_NOT_CONSTANT, bool FQ>
struct InsertReshapeAroundMatmulTest {
    // Creates Add(input, constant) where the constant has shape |constant_shape|
    // filled with 1..N. NOTE(review): currently unreferenced by the tests below;
    // kept because removing a public static helper could break external users.
    static std::shared_ptr<ngraph::Node> CreateAdd(std::shared_ptr<ngraph::Node> input, const ngraph::Shape& constant_shape) {
        std::vector<size_t> data(ngraph::shape_size(constant_shape));
        std::iota(std::begin(data), std::end(data), 1);
        auto constant = ngraph::opset8::Constant::create(ngraph::element::i64, constant_shape, data);
        return std::make_shared<ngraph::opset8::Add>(input, constant);
    }

    // Creates MatMul(input, constant) with the constant filled with 1..N and,
    // depending on the template flags, appends Add and/or FakeQuantize nodes.
    static std::shared_ptr<ngraph::Node> CreateMatmul(
        std::shared_ptr<ngraph::Node> input,
        const ngraph::Shape& matmul_constant_shape) {
        std::vector<size_t> data(ngraph::shape_size(matmul_constant_shape));
        std::iota(std::begin(data), std::end(data), 1);
        auto constant = ngraph::opset8::Constant::create(ngraph::element::i64, matmul_constant_shape, data);
        std::shared_ptr<ngraph::Node> node = std::make_shared<ngraph::opset8::MatMul>(input, constant);
        if (ADD) {
            auto matmul_shape = node->get_output_shape(0);
            data.resize(ngraph::shape_size(matmul_shape));
            std::iota(std::begin(data), std::end(data), 1);
            // Squeeze the MatMul output shape down to its non-trivial dimensions
            // and pad with trailing 1s up to rank 2.
            // BUGFIX: the previous code did std::copy_if into a fixed-size
            // std::vector<size_t>(2, 1) through begin(), which writes past the
            // end (UB) whenever more than two dimensions are greater than 1.
            // Building the shape incrementally is safe and produces identical
            // shapes for every input used in this file.
            std::vector<size_t> constant_add_shape;
            for (size_t dim : matmul_shape) {
                if (dim > 1) {
                    constant_add_shape.push_back(dim);
                }
            }
            while (constant_add_shape.size() < 2) {
                constant_add_shape.push_back(1);
            }
            auto constant_add = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{constant_add_shape}, data);
            if (ADD_FIRST_INPUT_NOT_CONSTANT) {
                node = std::make_shared<ngraph::opset8::Add>(node, constant_add);
            } else {
                node = std::make_shared<ngraph::opset8::Add>(constant_add, node);
            }
        }
        if (FQ) {
            node = std::make_shared<ngraph::opset8::FakeQuantize>(
                node,
                ngraph::opset8::Constant::create(ngraph::element::f32, {1}, {-0.1}),
                ngraph::opset8::Constant::create(ngraph::element::f32, {1}, {0.1}),
                ngraph::opset8::Constant::create(ngraph::element::f32, {1}, {-0.1}),
                ngraph::opset8::Constant::create(ngraph::element::f32, {1}, {0.1}),
                255);
        }
        return node;
    }

    // Builds the graph to be transformed:
    //   Parameter(input_shape) -> Relu -> MatMul(+Add/FQ per flags) -> Relu -> Result
    // |result_shape| is kept for call-site symmetry with CreateReferenceFunction
    // but is not used to build the graph.
    static std::shared_ptr<ngraph::Function> CreateFunction(
        const ngraph::Shape& input_shape,
        const ngraph::Shape& matmul_constant_shape,
        const ngraph::Shape& result_shape) {
        auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, input_shape);
        auto before = std::make_shared<ngraph::opset8::Relu>(input);
        auto matmul = CreateMatmul(before, matmul_constant_shape);
        auto after = std::make_shared<ngraph::opset8::Relu>(matmul);
        return std::make_shared<ngraph::Function>(
            ngraph::ResultVector{std::make_shared<ngraph::opset8::Result>(after)},
            ngraph::ParameterVector{input});
    }

    // Builds the expected graph with explicit Reshapes around the MatMul:
    //   Parameter -> Relu -> Reshape(reshape_before_shape) -> MatMul(+Add/FQ)
    //             -> Reshape(reshape_after_shape) -> Relu -> Result
    // |result_shape| is unused (see CreateFunction).
    static std::shared_ptr<ngraph::Function> CreateReferenceFunction(
        const ngraph::Shape& input_shape,
        const ngraph::Shape& reshape_before_shape,
        const ngraph::Shape& matmul_constant_shape,
        const ngraph::Shape& reshape_after_shape,
        const ngraph::Shape& result_shape) {
        auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, input_shape);
        auto before = std::make_shared<ngraph::opset8::Relu>(input);
        auto reshape_before_constant = ngraph::opset8::Constant::create(ngraph::element::i64,
            ngraph::Shape{reshape_before_shape.size()}, reshape_before_shape);
        auto reshape_before = std::make_shared<ngraph::opset8::Reshape>(before, reshape_before_constant, false);
        auto matmul = CreateMatmul(reshape_before, matmul_constant_shape);
        auto reshape_after_constant = ngraph::opset8::Constant::create(ngraph::element::i64,
            ngraph::Shape{reshape_after_shape.size()}, reshape_after_shape);
        auto reshape_after = std::make_shared<ngraph::opset8::Reshape>(matmul, reshape_after_constant, false);
        auto after = std::make_shared<ngraph::opset8::Relu>(reshape_after);
        return std::make_shared<ngraph::Function>(
            ngraph::ResultVector{std::make_shared<ngraph::opset8::Result>(after)},
            ngraph::ParameterVector{input});
    }
};  // struct InsertReshapeAroundMatmulTest
namespace {

// Runs the InsertReshapeAroundMatmul* pass pipeline on |func| and checks that
// the transformed graph (attributes included) matches |reference_func|.
void RunTest(const std::shared_ptr<ngraph::Function>& func, const std::shared_ptr<ngraph::Function>& reference_func) {
    {
        // Pass order mirrors the plugin pipeline: most specific pattern first.
        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<GNAPluginNS::InsertReshapeAroundMatmulWithTranspose>();
        manager.register_pass<GNAPluginNS::InsertReshapeAroundMatmulWithFq>();
        manager.register_pass<GNAPluginNS::InsertReshapeAroundMatmulWithAdd>();
        manager.register_pass<GNAPluginNS::InsertReshapeAroundMatmul>();
        manager.run_passes(func);
        // The passes must keep runtime info consistent on every node.
        ASSERT_NO_THROW(check_rt_info(func));
    }
    const auto comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
    const auto comparison = comparator(func, reference_func);
    ASSERT_TRUE(comparison.valid);
}

}  // namespace
// Plain MatMul (no Add, no FQ): 2D Reshapes must be inserted around a MatMul
// whose activation input is 3D or 4D, matching the reference graph.
TEST(TransformationTests, InsertReshapeAroundMatmul) {
    RunTest(
        InsertReshapeAroundMatmulTest<false, false, false>::
            CreateFunction({1, 6, 8}, {8, 10}, {1, 6, 10}),
        InsertReshapeAroundMatmulTest<false, false, false>::
            CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}));
    // Idempotence: an already-transformed graph must stay unchanged.
    RunTest(
        InsertReshapeAroundMatmulTest<false, false, false>::
            CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}),
        InsertReshapeAroundMatmulTest<false, false, false>::
            CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}));
    // Same pair of checks for a 4D input containing a trivial dimension.
    RunTest(
        InsertReshapeAroundMatmulTest<false, false, false>::
            CreateFunction({1, 6, 1, 8}, {8, 10}, {1, 6, 1, 10}),
        InsertReshapeAroundMatmulTest<false, false, false>::
            CreateReferenceFunction({1, 6, 1, 8}, {6, 8}, {8, 10}, {1, 6, 1, 10}, {1, 6, 1, 10}));
    RunTest(
        InsertReshapeAroundMatmulTest<false, false, false>::
            CreateReferenceFunction({1, 6, 1, 8}, {6, 8}, {8, 10}, {1, 6, 1, 10}, {1, 6, 1, 10}),
        InsertReshapeAroundMatmulTest<false, false, false>::
            CreateReferenceFunction({1, 6, 1, 8}, {6, 8}, {8, 10}, {1, 6, 1, 10}, {1, 6, 1, 10}));
}
// MatMul followed by Add where the MatMul result is the FIRST Add input
// (template flags: ADD=true, ADD_FIRST_INPUT_NOT_CONSTANT=true).
TEST(TransformationTests, InsertReshapeAroundMatmulWithAdd) {
    RunTest(
        InsertReshapeAroundMatmulTest<true, true, false>::
            CreateFunction({1, 6, 8}, {8, 10}, {1, 6, 10}),
        InsertReshapeAroundMatmulTest<true, true, false>::
            CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}));
    // Idempotence: an already-transformed graph must stay unchanged.
    RunTest(
        InsertReshapeAroundMatmulTest<true, true, false>::
            CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}),
        InsertReshapeAroundMatmulTest<true, true, false>::
            CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}));
}
// MatMul followed by Add where the constant is the FIRST Add input
// (template flags: ADD=true, ADD_FIRST_INPUT_NOT_CONSTANT=false).
TEST(TransformationTests, InsertReshapeAroundMatmulWithAdd_AddFirstInputConstant) {
    RunTest(
        InsertReshapeAroundMatmulTest<true, false, false>::
            CreateFunction({1, 6, 8}, {8, 10}, {1, 6, 10}),
        InsertReshapeAroundMatmulTest<true, false, false>::
            CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}));
    // Idempotence: an already-transformed graph must stay unchanged.
    RunTest(
        InsertReshapeAroundMatmulTest<true, false, false>::
            CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}),
        InsertReshapeAroundMatmulTest<true, false, false>::
            CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}));
}
// MatMul followed by FakeQuantize only (template flags: FQ=true).
TEST(TransformationTests, InsertReshapeAroundMatmulWithFq) {
    RunTest(
        InsertReshapeAroundMatmulTest<false, false, true>::
            CreateFunction({1, 6, 8}, {8, 10}, {1, 6, 10}),
        InsertReshapeAroundMatmulTest<false, false, true>::
            CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}));
    // Idempotence: an already-transformed graph must stay unchanged.
    RunTest(
        InsertReshapeAroundMatmulTest<false, false, true>::
            CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}),
        InsertReshapeAroundMatmulTest<false, false, true>::
            CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}));
}
// MatMul followed by both Add (MatMul result first) and FakeQuantize
// (template flags: ADD=true, ADD_FIRST_INPUT_NOT_CONSTANT=true, FQ=true).
TEST(TransformationTests, InsertReshapeAroundMatmulWithAddAndFq) {
    RunTest(
        InsertReshapeAroundMatmulTest<true, true, true>::
            CreateFunction({1, 6, 8}, {8, 10}, {1, 6, 10}),
        InsertReshapeAroundMatmulTest<true, true, true>::
            CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}));
    // Idempotence: an already-transformed graph must stay unchanged.
    RunTest(
        InsertReshapeAroundMatmulTest<true, true, true>::
            CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}),
        InsertReshapeAroundMatmulTest<true, true, true>::
            CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}));
}

View File

@ -20,7 +20,8 @@ static std::shared_ptr<ngraph::Function> CreateMatMulFunction(const ngraph::Shap
bool withBias,
bool withWeightsFq,
bool withOutFq,
bool swappedInputs) {
bool swappedInputs,
bool needTranspose) {
auto input_params = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, input2_shape);
auto constant = ngraph::opset8::Constant::create(ngraph::element::i64, input1_shape, {1});
@ -33,14 +34,14 @@ static std::shared_ptr<ngraph::Function> CreateMatMulFunction(const ngraph::Shap
const_input = std::make_shared<ngraph::opset8::FakeQuantize>(const_input, input_low, input_high,
output_low, output_high, 11);
}
auto matmul = swappedInputs ? std::make_shared<ngraph::opset8::MatMul>(input_params, const_input, true, true) :
std::make_shared<ngraph::opset8::MatMul>(const_input, input_params);
auto matmul = swappedInputs ? std::make_shared<ngraph::opset8::MatMul>(input_params, const_input, needTranspose, needTranspose) :
std::make_shared<ngraph::opset8::MatMul>(const_input, input_params, needTranspose, needTranspose);
std::shared_ptr<ngraph::Node> final_node = matmul;
if (withBias) {
auto bias = ngraph::opset8::Constant::create(ngraph::element::i64, bias_shape, {1});
std::shared_ptr<ngraph::Node> bias_node = bias;
if (swappedInputs && bias_shape.size() > 1) {
if (needTranspose && bias_shape.size() > 1) {
auto transpose_order = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{2},
std::vector<size_t>{1, 0});
bias_node = std::make_shared<ngraph::opset8::Transpose>(bias_node, transpose_order);
@ -57,7 +58,7 @@ static std::shared_ptr<ngraph::Function> CreateMatMulFunction(const ngraph::Shap
output_low, output_high, 11);
}
if (swappedInputs) {
if (needTranspose) {
auto transpose_order = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{2},
std::vector<size_t>{1, 0});
final_node = std::make_shared<ngraph::opset8::Transpose>(final_node, transpose_order);
@ -104,6 +105,12 @@ static std::string getTestCaseName(testing::TestParamInfo<SwapInputMatmulParams>
return result.str();
}
// Selects which MatMul input is the constant one in the generated test graph
// (drives the `swap_inputs` flag in the test fixtures below).
enum class MatmulInputType {
    FirstInputConstant,
    SecondInputConstant
}; // enum class MatmulInputType
template<MatmulInputType E>
class SwapInputMatmul : public CommonTestUtils::TestsCommon,
public ::testing::WithParamInterface<SwapInputMatmulParams> {
public:
@ -112,14 +119,24 @@ public:
bool withBias, withWeightsFq, withOutFq;
std::tie(shapes, withBias, withWeightsFq, withOutFq) = this->GetParam();
function = CreateMatMulFunction(shapes[0], shapes[1], shapes[2], withBias, withWeightsFq, withOutFq, false);
bool swap_inputs = false;
switch (E) {
case MatmulInputType::FirstInputConstant:
break;
case MatmulInputType::SecondInputConstant:
swap_inputs = true;
break;
}
function = CreateMatMulFunction(shapes[0], shapes[1], shapes[2], withBias, withWeightsFq, withOutFq, swap_inputs, false);
reference_function = CreateMatMulFunction(shapes[0], shapes[1], shapes[2], withBias, withWeightsFq,
withOutFq, true);
withOutFq, !swap_inputs, true);
}
public:
std::shared_ptr<ngraph::Function> function, reference_function;
};
template<MatmulInputType E>
class SwapInputMatmulNotApplied : public CommonTestUtils::TestsCommon,
public ::testing::WithParamInterface<SwapInputMatmulParams> {
public:
@ -128,42 +145,92 @@ public:
bool withBias, withWeightsFq, withOutFq;
std::tie(shapes, withBias, withWeightsFq, withOutFq) = this->GetParam();
function = CreateMatMulFunction(shapes[0], shapes[1], shapes[2], withBias, withWeightsFq, withOutFq, false);
bool swap_inputs = false;
switch (E) {
case MatmulInputType::FirstInputConstant:
break;
case MatmulInputType::SecondInputConstant:
swap_inputs = true;
break;
}
function = CreateMatMulFunction(shapes[0], shapes[1], shapes[2], withBias, withWeightsFq, withOutFq, swap_inputs, false);
reference_function = ngraph::clone_function(*function);
}
public:
std::shared_ptr<ngraph::Function> function, reference_function;
};
TEST_P(SwapInputMatmul, CompareFunctions) {
using SwapInputMatmulWithFirstInputConstant = SwapInputMatmul<MatmulInputType::FirstInputConstant>;
using SwapInputMatmulWithSecondInputConstant = SwapInputMatmul<MatmulInputType::SecondInputConstant>;
using SwapInputMatmulWithFirstInputConstantNotApplied = SwapInputMatmulNotApplied<MatmulInputType::FirstInputConstant>;
using SwapInputMatmulWithSecondInputConstantNotApplied = SwapInputMatmulNotApplied<MatmulInputType::SecondInputConstant>;
TEST_P(SwapInputMatmulWithFirstInputConstant, CompareFunctions) {
Execute(function, reference_function);
}
TEST_P(SwapInputMatmulNotApplied, CompareFunctions) {
TEST_P(SwapInputMatmulWithFirstInputConstantNotApplied, CompareFunctions) {
Execute(function, reference_function);
}
const std::vector<std::vector<ngraph::Shape>> input_shapes_applied = {
TEST_P(SwapInputMatmulWithSecondInputConstant, CompareFunctions) {
Execute(function, reference_function);
}
TEST_P(SwapInputMatmulWithSecondInputConstantNotApplied, CompareFunctions) {
Execute(function, reference_function);
}
const std::vector<std::vector<ngraph::Shape>> input_shapes_for_matmul_with_first_constant_applied = {
{{16, 8}, {8, 8}, {16, 8}},
{{16, 8}, {8, 8}, {1}},
};
const std::vector<std::vector<ngraph::Shape>> input_shapes_not_applied = {
const std::vector<std::vector<ngraph::Shape>> input_shapes_for_matmul_with_first_constant_not_applied = {
{{1, 8}, {8, 8}, {1, 8}},
{{8}, {8, 8}, {8}}
};
INSTANTIATE_TEST_SUITE_P(smoke_swap_input_matmul, SwapInputMatmul,
const std::vector<std::vector<ngraph::Shape>> input_shapes_for_matmul_with_second_constant_applied = {
{{64, 6}, {100, 64}, {100, 6}},
{{64, 6}, {100, 64}, {1}},
};
const std::vector<std::vector<ngraph::Shape>> input_shapes_for_matmul_with_second_constant_not_applied = {
{{64, 16}, {100, 64}, {100, 16}},
{{64, 6}, {8, 64}, {8, 6}},
{{8, 1}, {8, 8}, {8, 1}},
{{8}, {8, 8}, {8}}
};
INSTANTIATE_TEST_SUITE_P(smoke_swap_input_matmul, SwapInputMatmulWithFirstInputConstant,
::testing::Combine(
::testing::ValuesIn(input_shapes_applied),
::testing::ValuesIn(input_shapes_for_matmul_with_first_constant_applied),
::testing::ValuesIn(std::vector<bool>{false, true}),
::testing::ValuesIn(std::vector<bool>{false, true}),
::testing::ValuesIn(std::vector<bool>{false, true})),
getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_swap_input_matmul, SwapInputMatmulNotApplied,
INSTANTIATE_TEST_SUITE_P(smoke_swap_input_matmul, SwapInputMatmulWithFirstInputConstantNotApplied,
::testing::Combine(
::testing::ValuesIn(input_shapes_not_applied),
::testing::ValuesIn(input_shapes_for_matmul_with_first_constant_not_applied),
::testing::ValuesIn(std::vector<bool>{false, true}),
::testing::ValuesIn(std::vector<bool>{false, true}),
::testing::ValuesIn(std::vector<bool>{false, true})),
getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_swap_input_matmul, SwapInputMatmulWithSecondInputConstant,
::testing::Combine(
::testing::ValuesIn(input_shapes_for_matmul_with_second_constant_applied),
::testing::ValuesIn(std::vector<bool>{false, true}),
::testing::ValuesIn(std::vector<bool>{false, true}),
::testing::ValuesIn(std::vector<bool>{false, true})),
getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_swap_input_matmul, SwapInputMatmulWithSecondInputConstantNotApplied,
::testing::Combine(
::testing::ValuesIn(input_shapes_for_matmul_with_second_constant_not_applied),
::testing::ValuesIn(std::vector<bool>{false, true}),
::testing::ValuesIn(std::vector<bool>{false, true}),
::testing::ValuesIn(std::vector<bool>{false, true})),

View File

@ -70,56 +70,117 @@ std::shared_ptr<ngraph::Function> CreateMatmulFunction(const ngraph::Shape& inpu
namespace handle_transpose_after_matmul {
std::shared_ptr<ngraph::Function> CreateMatmulTransposeFunction(const ngraph::Shape& input_shape,
const ngraph::Shape& matmul_shape, const ngraph::Shape& reshape_shape, bool create_reshape_after_transpose) {
std::shared_ptr<ngraph::Function> CreateMatmulTransposeFunction(
const ngraph::Shape& input_shape,
const ngraph::Shape& matmul_shape,
const ngraph::Shape& reshape_shape,
bool create_reshape_after_transpose,
bool enable_last_reshape,
bool enable_add,
bool matmul_on_left_side,
bool enable_fq) {
auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, input_shape);
std::vector<size_t> data(ngraph::shape_size(matmul_shape));
std::iota(std::begin(data), std::end(data), 1);
auto matmul_constant = ngraph::opset7::Constant::create(ngraph::element::i64, matmul_shape, data);
auto matmul = std::make_shared<ngraph::opset7::MatMul>(input_params, matmul_constant);
const auto matmul_output_shape = matmul->get_output_shape(0);
std::shared_ptr<ngraph::Node> node = std::make_shared<ngraph::opset7::MatMul>(input_params, matmul_constant);
const auto matmul_output_shape = node->get_output_shape(0);
if (enable_add) {
auto add_const = ngraph::opset7::Constant::create(ngraph::element::i64, matmul_output_shape, {1});
if (matmul_on_left_side) {
node = std::make_shared<ngraph::opset7::Add>(add_const, node);
} else {
node = std::make_shared<ngraph::opset7::Add>(node, add_const);
}
}
if (enable_fq) {
node = std::make_shared<ngraph::opset7::FakeQuantize>(
node,
ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {-0.1}),
ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {0.1}),
ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {-0.1}),
ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {0.1}),
255);
}
auto transpose_order = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {1, 0});
auto transpose = std::make_shared<ngraph::opset7::Transpose>(matmul, transpose_order);
auto transpose = std::make_shared<ngraph::opset7::Transpose>(node, transpose_order);
const auto transpose_output_shape = transpose->get_output_shape(0);
std::shared_ptr<ngraph::opset7::Reshape> reshape;
std::shared_ptr<ngraph::Node> reshape;
auto shape_const = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{reshape_shape.size()}, reshape_shape);
if (create_reshape_after_transpose) {
const auto matmul_output_shape = matmul->get_output_shape(0);
const auto matmul_output_shape = node->get_output_shape(0);
auto reshape_after_transpose_const = ngraph::opset7::Constant::create(ngraph::element::i64,
ngraph::Shape{matmul_output_shape.size()}, matmul_output_shape);
auto reshape_after_transpose = std::make_shared<ngraph::opset7::Reshape>(transpose, reshape_after_transpose_const, false);
reshape = reshape_after_transpose;
if (enable_last_reshape) {
reshape = std::make_shared<ngraph::opset7::Reshape>(reshape_after_transpose, shape_const, false);
}
} else {
reshape = transpose;
if (enable_last_reshape) {
reshape = std::make_shared<ngraph::opset7::Reshape>(transpose, shape_const, false);
const auto reshape_output_shape = reshape->get_output_shape(0);
}
}
auto result = std::make_shared<ngraph::opset7::Result>(reshape);
return std::make_shared<ngraph::Function>(ngraph::ResultVector{result}, ngraph::ParameterVector{input_params});
}
std::shared_ptr<ngraph::Function> CreateMatmulFunction(const ngraph::Shape& input_shape,
const ngraph::Shape& matmul_shape, const ngraph::Shape& reshape_shape, bool create_reshape_instead_of_transpose) {
std::shared_ptr<ngraph::Function> CreateMatmulFunction(
const ngraph::Shape& input_shape,
const ngraph::Shape& matmul_shape,
const ngraph::Shape& reshape_shape,
bool create_reshape_instead_of_transpose,
bool enable_last_reshape,
bool enable_add,
bool matmul_on_left_side,
bool enable_fq) {
auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, input_shape);
std::vector<size_t> data(ngraph::shape_size(matmul_shape));
std::iota(std::begin(data), std::end(data), 1);
auto matmul_constant = ngraph::opset7::Constant::create(ngraph::element::i64, matmul_shape, data);
auto matmul = std::make_shared<ngraph::opset7::MatMul>(input_params, matmul_constant);
std::shared_ptr<ngraph::Node> node = std::make_shared<ngraph::opset7::MatMul>(input_params, matmul_constant);
const auto matmul_output_shape = node->get_output_shape(0);
if (enable_add) {
auto add_const = ngraph::opset7::Constant::create(ngraph::element::i64, matmul_output_shape, {1});
if (matmul_on_left_side) {
node = std::make_shared<ngraph::opset7::Add>(add_const, node);
} else {
node = std::make_shared<ngraph::opset7::Add>(node, add_const);
}
}
std::shared_ptr<ngraph::opset7::Reshape> reshape;
if (enable_fq) {
node = std::make_shared<ngraph::opset7::FakeQuantize>(
node,
ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {-0.1}),
ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {0.1}),
ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {-0.1}),
ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {0.1}),
255);
}
std::shared_ptr<ngraph::Node> reshape;
auto shape_const = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{reshape_shape.size()}, reshape_shape);
if (create_reshape_instead_of_transpose) {
const auto matmul_output_shape = matmul->get_output_shape(0);
auto reshape_instead_of_transpose_const = ngraph::opset7::Constant::create(ngraph::element::i64,
ngraph::Shape{matmul_output_shape.size()}, {matmul_output_shape[1], matmul_output_shape[0]});
auto reshape_instead_of_transpose = std::make_shared<ngraph::opset7::Reshape>(matmul, reshape_instead_of_transpose_const, false);
auto reshape_instead_of_transpose = std::make_shared<ngraph::opset7::Reshape>(node, reshape_instead_of_transpose_const, false);
reshape = reshape_instead_of_transpose;
if (enable_last_reshape) {
reshape = std::make_shared<ngraph::opset7::Reshape>(reshape_instead_of_transpose, shape_const, false);
}
} else {
reshape = std::make_shared<ngraph::opset7::Reshape>(matmul, shape_const, false);
reshape = node;
if (enable_last_reshape) {
reshape = std::make_shared<ngraph::opset7::Reshape>(node, shape_const, false);
}
}
auto result = std::make_shared<ngraph::opset7::Result>(reshape);
@ -153,6 +214,9 @@ TEST(TransformationTests, InsertTransposeBeforeMatmulTest) {
RunTest(
handle_transpose_before_matmul::CreateMatmulFunction({1, 16}, {8, 2}, {2, 1}, false),
handle_transpose_before_matmul::CreateTransposeMatmulFunction({1, 16}, {8, 2}, {2, 1}, true));
RunTest(
handle_transpose_before_matmul::CreateMatmulFunction({1, 2, 8}, {8, 2}, {2, 1}, false),
handle_transpose_before_matmul::CreateTransposeMatmulFunction({1, 2, 8}, {8, 2}, {2, 1}, true));
}
TEST(TransformationTests, InsertTransposeBeforeMatmulTestReshapeInOutEq) {
@@ -177,25 +241,59 @@ TEST(TransformationTests, RemoveTransposeBeforeMatmulTestReshapeInOutEq) {
}
// Insert-transpose-after-MatMul case: for every combination of the optional
// Add node, the side of the MatMul operand, and the optional FakeQuantize,
// compares the graph built by CreateMatmulFunction against the graph built by
// CreateMatmulTransposeFunction (which differs only in its transpose flag).
// NOTE: the stale pre-refactoring RunTest argument lines were removed here —
// they left dangling expressions after the closed RunTest(...) call and broke
// compilation.
TEST(TransformationTests, InsertTransposeAfterMatmulTest) {
    for (auto enable_add : { true, false }) {
        for (auto matmul_on_left_side : { true, false }) {
            for (auto enable_fq : { true, false }) {
                RunTest(
                    handle_transpose_after_matmul::CreateMatmulFunction(
                        {4, 1}, {1, 8}, {2, 16}, false, true, enable_add, matmul_on_left_side, enable_fq),
                    handle_transpose_after_matmul::CreateMatmulTransposeFunction(
                        {4, 1}, {1, 8}, {2, 16}, true, true, enable_add, matmul_on_left_side, enable_fq));
            }
        }
    }
}
// Remove-transpose-after-MatMul case: for every combination of the optional
// Add node, the side of the MatMul operand, and the optional FakeQuantize,
// compares the MatMul+Transpose graph against the plain MatMul graph
// (the boolean fourth argument is swapped relative to the insert test above).
// NOTE: the stale pre-refactoring RunTest argument lines were removed here —
// they broke compilation.
TEST(TransformationTests, RemoveTransposeAfterMatmulTest) {
    for (auto enable_add : { true, false }) {
        for (auto matmul_on_left_side : { true, false }) {
            for (auto enable_fq : { true, false }) {
                RunTest(
                    handle_transpose_after_matmul::CreateMatmulTransposeFunction(
                        {4, 1}, {1, 8}, {2, 16}, false, true, enable_add, matmul_on_left_side, enable_fq),
                    handle_transpose_after_matmul::CreateMatmulFunction(
                        {4, 1}, {1, 8}, {2, 16}, true, true, enable_add, matmul_on_left_side, enable_fq));
            }
        }
    }
}
// Negative remove-transpose case with equal reshape in/out shapes ({8, 4}):
// both RunTest arguments are built with identical CreateMatmulTransposeFunction
// calls, i.e. the transformation is expected to leave the graph unchanged,
// across every combination of Add / operand side / FakeQuantize.
// NOTE: the stale pre-refactoring RunTest argument lines were removed here —
// they broke compilation.
TEST(TransformationTests, RemoveTransposeAfterMatmulTestReshapeInOutEq) {
    for (auto enable_add : { true, false }) {
        for (auto matmul_on_left_side : { true, false }) {
            for (auto enable_fq : { true, false }) {
                RunTest(
                    handle_transpose_after_matmul::CreateMatmulTransposeFunction(
                        {4, 1}, {1, 8}, {8, 4}, false, true, enable_add, matmul_on_left_side, enable_fq),
                    handle_transpose_after_matmul::CreateMatmulTransposeFunction(
                        {4, 1}, {1, 8}, {8, 4}, false, true, enable_add, matmul_on_left_side, enable_fq));
            }
        }
    }
}
// Negative insert-transpose case with equal reshape in/out shapes ({4, 8}):
// both RunTest arguments are identical CreateMatmulFunction calls, i.e. no
// transpose should be inserted. Iterates one extra flag compared to the tests
// above: enable_last_reshape, plus Add / operand side / FakeQuantize.
// NOTE: the stale pre-refactoring RunTest argument lines were removed here —
// they broke compilation.
TEST(TransformationTests, InsertTransposeAfterMatmulTestReshapeInOutEq) {
    for (auto enable_last_reshape : { true, false }) {
        for (auto enable_add : { true, false }) {
            for (auto matmul_on_left_side : { true, false }) {
                for (auto enable_fq : { true, false }) {
                    RunTest(
                        handle_transpose_after_matmul::CreateMatmulFunction(
                            {4, 1}, {1, 8}, {4, 8}, false, enable_last_reshape, enable_add, matmul_on_left_side, enable_fq),
                        handle_transpose_after_matmul::CreateMatmulFunction(
                            {4, 1}, {1, 8}, {4, 8}, false, enable_last_reshape, enable_add, matmul_on_left_side, enable_fq));
                }
            }
        }
    }
}