[GNA] Expanding transformations: swap_input_matmul and handle_transposes_around_matmul (#7333)
* Expanding transformations: swap_input_matmul and handle_transposes_around_matmul
* insert_reshape_around_matmul
* fixed failing smoke tests
parent 39120a7f62
commit ba34a1989c
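For orientation only (a sketch, not part of this diff): all of the passes touched here are regular ngraph matcher or graph-rewrite passes, so they can be driven through an ngraph::pass::Manager the same way the plugin code and the new unit tests below do. Assuming the GNA plugin's include paths, a minimal driver could look like this; the registration order follows the order used in GNAPlugin::LoadNetwork in the first hunk.

#include <memory>
#include <ngraph/pass/manager.hpp>
#include <transformations/init_node_info.hpp>
#include "transformations/handle_transposes_around_matmul.hpp"
#include "transformations/insert_reshape_around_matmul.hpp"
#include <transformations/swap_input_matmul_gna.hpp>

// Run the reworked GNA MatMul-related passes on an existing ngraph::Function.
void RunGnaMatMulPasses(const std::shared_ptr<ngraph::Function>& func) {
    ngraph::pass::Manager manager;
    manager.register_pass<ngraph::pass::InitNodeInfo>();
    manager.register_pass<GNAPluginNS::HandleTransposesAroundMatMul>();
    manager.register_pass<GNAPluginNS::InsertReshapeAroundMatmulWithTranspose>();
    manager.register_pass<GNAPluginNS::InsertReshapeAroundMatmulWithFq>();
    manager.register_pass<GNAPluginNS::InsertReshapeAroundMatmulWithAdd>();
    manager.register_pass<GNAPluginNS::InsertReshapeAroundMatmul>();
    manager.register_pass<GNAPluginNS::SwapInputMatMulWithFq>();
    manager.register_pass<GNAPluginNS::SwapInputMatMulWithBias>();
    manager.register_pass<GNAPluginNS::SwapInputMatMul>();
    manager.run_passes(func);
}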
@@ -66,6 +66,7 @@
#include "transformations/handle_transposes_around_matmul.hpp"
#include "transformations/decompose_2d_conv.hpp"
#include "transformations/convert_padded2valid_conv.hpp"
#include "transformations/insert_reshape_around_matmul.hpp"
#include "transformations/op_conversions/lstm_cell_decomposition.hpp"
#include "transformations/remove_single_input_concat.hpp"

@@ -730,10 +731,14 @@ void GNAPlugin::LoadNetwork(CNNNetwork & _network) {
manager.register_pass<SplitConvolutionWithFq>();
manager.register_pass<SplitConvolutionWithBias>();
manager.register_pass<SplitConvolution>();
manager.register_pass<HandleTransposesAroundMatMul>();
manager.register_pass<InsertReshapeAroundMatmulWithTranspose>();
manager.register_pass<InsertReshapeAroundMatmulWithFq>();
manager.register_pass<InsertReshapeAroundMatmulWithAdd>();
manager.register_pass<InsertReshapeAroundMatmul>();
manager.register_pass<SwapInputMatMulWithFq>();
manager.register_pass<SwapInputMatMulWithBias>();
manager.register_pass<SwapInputMatMul>();
manager.register_pass<HandleTransposesAroundMatMul>();
manager.register_pass<InsertTransposeAfterConvOrPool>();
manager.register_pass<ReorderActivationAndPooling>();
manager.register_pass<RemoveSingleInputConcat>();

@@ -6,31 +6,33 @@

#include <numeric>

#include <ngraph/opsets/opset7.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/pattern/op/or.hpp>
#include <openvino/cc/ngraph/itt.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pattern/op/or.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ie/ie_common.h>

#include "gna_plugin_log.hpp"
#include "backend/gna_limitations.hpp"

using namespace GNAPluginNS;
namespace GNAPluginNS {

NGRAPH_RTTI_DEFINITION(HandleTransposesAroundMatMul, "HandleTransposesAroundMatMul", 0);
NGRAPH_RTTI_DEFINITION(HandleTransposeBeforeMatMul, "HandleTransposeBeforeMatMul", 0);
NGRAPH_RTTI_DEFINITION(HandleTransposeAfterMatMul, "HandleTransposeAfterMatMul", 0);

static void ReplaceTransposeWithReshape(std::shared_ptr<ngraph::Node> transpose_node) {
void ReplaceTransposeWithReshape(std::shared_ptr<ngraph::Node> transpose_node) {
auto shape = transpose_node->get_output_shape(0);
auto reshape_const = std::make_shared<ngraph::opset7::Constant>(ngraph::element::Type_t::i64,
auto reshape_const = std::make_shared<ngraph::opset8::Constant>(ngraph::element::Type_t::i64,
ngraph::Shape{shape.size()}, shape);
auto reshape_node = std::make_shared<ngraph::opset7::Reshape>(transpose_node->input_value(0), reshape_const, false);
reshape_node->set_friendly_name(transpose_node->get_friendly_name() + "/reshape");
auto reshape_node = std::make_shared<ngraph::opset8::Reshape>(transpose_node->input_value(0), reshape_const, false);
reshape_node->set_friendly_name(transpose_node->get_friendly_name());
ngraph::copy_runtime_info(transpose_node, reshape_node);
transpose_node->output(0).replace(reshape_node->output(0));
}

static void InsertTranspose(std::shared_ptr<ngraph::Node> prev_node, const std::string& base_name) {
void InsertTranspose(std::shared_ptr<ngraph::Node> prev_node, const std::string& base_name) {
auto consumers = prev_node->output(0).get_target_inputs();
const auto orig_shape = prev_node->get_output_shape(0);
std::vector<size_t> transpose_ids;
@@ -44,13 +46,13 @@ static void InsertTranspose(std::shared_ptr<ngraph::Node> prev_node, const std::
std::iota(std::begin(permute_order), std::end(permute_order), 0);
std::swap(permute_order[transpose_ids[0]], permute_order[transpose_ids[1]]);

auto transpose_order = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{permute_order.size()}, permute_order);
auto transpose = std::make_shared<ngraph::opset7::Transpose>(prev_node, transpose_order);
auto transpose_order = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{permute_order.size()}, permute_order);
auto transpose = std::make_shared<ngraph::opset8::Transpose>(prev_node, transpose_order);
transpose->set_friendly_name(base_name + "/in_transpose");

auto reshapeConstAfter = std::make_shared<ngraph::opset7::Constant>(ngraph::element::Type_t::i64,
auto reshapeConstAfter = std::make_shared<ngraph::opset8::Constant>(ngraph::element::Type_t::i64,
ngraph::Shape{orig_shape.size()}, orig_shape);
auto reshapeAfter = std::make_shared<ngraph::opset7::Reshape>(transpose, reshapeConstAfter, false);
auto reshapeAfter = std::make_shared<ngraph::opset8::Reshape>(transpose, reshapeConstAfter, false);
reshapeAfter->set_friendly_name(base_name + "/reshape_after_transpose");
ngraph::copy_runtime_info(prev_node, ngraph::NodeVector{transpose, reshapeAfter});

@@ -59,74 +61,102 @@ static void InsertTranspose(std::shared_ptr<ngraph::Node> prev_node, const std::
}
}

static bool VerifyReshape(const ngraph::Output<ngraph::Node>& reshape_out) {
auto in_shape = reshape_out.get_node_shared_ptr()->get_input_shape(0);
auto out_shape = reshape_out.get_node_shared_ptr()->get_output_shape(0);
return in_shape[0] != out_shape[0];
}

HandleTransposeBeforeMatMul::HandleTransposeBeforeMatMul() {
auto reshape = ngraph::pattern::wrap_type<ngraph::opset7::Reshape>({ngraph::pattern::any_input(),
ngraph::pattern::any_input()}, VerifyReshape());
auto transpose = ngraph::pattern::wrap_type<ngraph::opset7::Transpose>({reshape,
auto constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
auto fq = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({constant, ngraph::pattern::any_input(),
ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()});
auto reshape = ngraph::pattern::wrap_type<ngraph::opset8::Reshape>({}, VerifyReshape);
auto transpose = ngraph::pattern::wrap_type<ngraph::opset8::Transpose>({reshape,
ngraph::pattern::any_input()});
auto matmul_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{reshape, transpose});
auto matmul1 = ngraph::pattern::wrap_type<ngraph::opset7::MatMul>({matmul_input, ngraph::pattern::any_input()});
auto matmul2 = ngraph::pattern::wrap_type<ngraph::opset7::MatMul>({ngraph::pattern::any_input(), matmul_input});
auto matmul1 = ngraph::pattern::wrap_type<ngraph::opset8::MatMul>({
std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{reshape, transpose}),
ngraph::pattern::any_input()});
auto matmul2 = ngraph::pattern::wrap_type<ngraph::opset8::MatMul>({
std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{constant, fq}),
std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{reshape, transpose, ngraph::pattern::any_input()})});
auto matmul = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{matmul1, matmul2});

ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
const auto& pattern_map = m.get_pattern_value_map();
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &matcher) {
const auto& pattern_map = matcher.get_pattern_value_map();
auto matmul_iter = pattern_map.find(matmul1);
if (matmul_iter == std::end(pattern_map) &&
(matmul_iter = pattern_map.find(matmul2)) == std::end(pattern_map)) {
return false;
}

auto transpose_reshape_it = pattern_map.find(transpose);
if (transpose_reshape_it != std::end(pattern_map)) {
ReplaceTransposeWithReshape(transpose_reshape_it->second.get_node_shared_ptr());
} else if ((transpose_reshape_it = pattern_map.find(reshape)) != std::end(pattern_map)) {
auto reshape_node = pattern_map.at(reshape).get_node_shared_ptr();
if (GNALimitations::IsTransposeSupported(reshape_node->get_output_shape(0))) {
auto matmul_node = matmul_iter->second.get_node_shared_ptr();
InsertTranspose(reshape_node, matmul_node->get_friendly_name());
}
}

auto iter = pattern_map.find(fq);
if (iter != pattern_map.end() ||
(iter = pattern_map.find(constant)) != pattern_map.end()) {
auto prev_node = iter->second.get_node_shared_ptr();
if (!GNALimitations::IsTransposeSupported(prev_node->get_output_shape(0))) return false;
auto matmul_node = iter->second.get_node_shared_ptr();
InsertTranspose(prev_node, matmul_node->get_friendly_name());
}
return true;
};

auto matcher = std::make_shared<ngraph::pattern::Matcher>(matmul, "HandleTransposeBeforeMatMul");
this->register_matcher(matcher, callback);
}

HandleTransposeAfterMatMul::HandleTransposeAfterMatMul() {
auto matmul = ngraph::pattern::wrap_type<ngraph::opset8::MatMul>();
auto add_left = ngraph::pattern::wrap_type<ngraph::opset8::Add>({matmul, ngraph::pattern::any_input()});
auto add_right = ngraph::pattern::wrap_type<ngraph::opset8::Add>({ngraph::pattern::any_input(), matmul});
auto fq_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{matmul, add_left, add_right});
auto fq = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({fq_input, ngraph::pattern::any_input(),
ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()});
auto transpose_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{fq_input, fq});
auto transpose = ngraph::pattern::wrap_type<ngraph::opset8::Transpose>({transpose_input, ngraph::pattern::any_input()});
auto reshape_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{transpose_input, transpose});
auto reshape = ngraph::pattern::wrap_type<ngraph::opset8::Reshape>(
{reshape_input, ngraph::pattern::any_input()}, VerifyReshape);

ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &matcher) {
const auto& pattern_map = matcher.get_pattern_value_map();
auto transpose_it = pattern_map.find(transpose);
if (transpose_it != std::end(pattern_map)) {
ReplaceTransposeWithReshape(transpose_it->second.get_node_shared_ptr());
} else {
auto reshape_node = pattern_map.at(reshape).get_node_shared_ptr();
if (!GNALimitations::IsTransposeSupported(reshape_node->get_output_shape(0))) return false;
auto matmul_it = pattern_map.find(matmul1);
auto matmul_out = matmul_it != std::end(pattern_map) ? matmul_it->second : pattern_map.at(matmul2);
InsertTranspose(reshape_node, matmul_out.get_node_shared_ptr()->get_friendly_name());
auto iter = pattern_map.find(fq);
if (iter == pattern_map.end() &&
(iter = pattern_map.find(add_left)) == pattern_map.end() &&
(iter = pattern_map.find(add_right)) == pattern_map.end() &&
(iter = pattern_map.find(matmul)) == pattern_map.end()) {
return false;
}
auto node = iter->second.get_node_shared_ptr();
InsertTranspose(node, node->get_friendly_name());
}
return true;
};

auto m = std::make_shared<ngraph::pattern::Matcher>(matmul, "HandleTransposeBeforeMatMul");
this->register_matcher(m, callback);
}

HandleTransposeAfterMatMul::HandleTransposeAfterMatMul() {
auto matmul = ngraph::pattern::wrap_type<ngraph::opset7::MatMul>();
auto fq = ngraph::pattern::wrap_type<ngraph::opset7::FakeQuantize>({matmul, ngraph::pattern::any_input(),
ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()});
auto transpose_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{matmul, fq});
auto transpose = ngraph::pattern::wrap_type<ngraph::opset7::Transpose>({transpose_input, ngraph::pattern::any_input()});
auto reshape_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{transpose_input, transpose});
auto reshape = ngraph::pattern::wrap_type<ngraph::opset7::Reshape>({reshape_input,
ngraph::pattern::any_input()}, VerifyReshape());

ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
const auto& pattern_map = m.get_pattern_value_map();
auto transpose_it = pattern_map.find(transpose);
if (transpose_it != std::end(pattern_map)) {
ReplaceTransposeWithReshape(transpose_it->second.get_node_shared_ptr());
} else {
auto reshape_node = pattern_map.at(reshape).get_node_shared_ptr();
if (!GNALimitations::IsTransposeSupported(reshape_node->get_input_shape(0))) return false;
auto matmul_node = pattern_map.at(matmul).get_node_shared_ptr();
InsertTranspose(matmul_node, matmul_node->get_friendly_name());
}
return true;
};

auto m = std::make_shared<ngraph::pattern::Matcher>(reshape, "HandleTransposeAfterMatMul");
this->register_matcher(m, callback);
}

bool VerifyReshape::operator()(const ngraph::Output<ngraph::Node>& reshape_out) const {
auto in_shape = reshape_out.get_node_shared_ptr()->get_input_shape(0);
auto out_shape = reshape_out.get_node_shared_ptr()->get_output_shape(0);

// Check if Reshape changes the final 2d shape of Affine primitive
in_shape.erase(std::remove(in_shape.begin(), in_shape.end(), 1), in_shape.end());
out_shape.erase(std::remove(out_shape.begin(), out_shape.end(), 1), out_shape.end());
return in_shape != out_shape;
auto matcher = std::make_shared<ngraph::pattern::Matcher>(reshape, "HandleTransposeAfterMatMul");
this->register_matcher(matcher, callback);
}

HandleTransposesAroundMatMul::HandleTransposesAroundMatMul() {
add_matcher<HandleTransposeBeforeMatMul>();
add_matcher<HandleTransposeAfterMatMul>();
}

} // namespace GNAPluginNS
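To make the new predicate concrete (an illustrative sketch, not code from the commit): VerifyReshape now squeezes out unit dimensions before comparing, so only Reshapes that change the final 2D shape of the affine primitive count. A standalone equivalent of that check:

#include <algorithm>
#include <cassert>
#include <vector>

// Squeeze out dimensions equal to 1 and compare what is left,
// mirroring the shape check performed by VerifyReshape above.
static bool ChangesFinal2dShape(std::vector<size_t> in_shape, std::vector<size_t> out_shape) {
    in_shape.erase(std::remove(in_shape.begin(), in_shape.end(), 1), in_shape.end());
    out_shape.erase(std::remove(out_shape.begin(), out_shape.end(), 1), out_shape.end());
    return in_shape != out_shape;
}

int main() {
    assert(!ChangesFinal2dShape({1, 6, 8}, {6, 8}));  // only unit dims removed: no change
    assert(ChangesFinal2dShape({6, 8}, {8, 6}));      // genuinely different 2D shape: change
    return 0;
}

With the old check (in_shape[0] != out_shape[0]) the first case above would have been reported as a change.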
@@ -8,10 +8,6 @@

namespace GNAPluginNS {

struct VerifyReshape {
bool operator()(const ngraph::Output<ngraph::Node>& reshape_out) const;
};

/**
* @brief Inserts Transpose before MatMul or removes it (if it exists) if there is Reshape
* before MatMul which changes the batch size:
@@ -48,13 +44,13 @@ public:
* | |
* [1, A*B] [1, A*B]
*/
class HandleTransposeAfterMatMul : public ngraph::pass::MatcherPass {
class HandleTransposeAfterMatMul: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
HandleTransposeAfterMatMul();
};

class HandleTransposesAroundMatMul: public ngraph::pass::GraphRewrite {
class HandleTransposesAroundMatMul : public ngraph::pass::GraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;
HandleTransposesAroundMatMul();
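Concretely (an illustration based on the .cpp diff above, not text from this header): ReplaceTransposeWithReshape swaps a Transpose node for a Reshape with the same output shape, so the tensor is reinterpreted rather than physically permuted, which is the rewrite the [1, A*B] tail of the diagram refers to. InsertTranspose does the opposite for shapes that pass GNALimitations::IsTransposeSupported: it adds a two-axis Transpose followed by a Reshape back to the original shape, using the friendly names "<base>/in_transpose" and "<base>/reshape_after_transpose".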
@@ -0,0 +1,237 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/insert_reshape_around_matmul.hpp"
#include <openvino/cc/ngraph/itt.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pattern/op/or.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ie/ie_common.h>

#include "gna_plugin_log.hpp"

namespace GNAPluginNS {

NGRAPH_RTTI_DEFINITION(InsertReshapeAroundMatmul, "InsertReshapeAroundMatmul", 0);
NGRAPH_RTTI_DEFINITION(InsertReshapeAroundMatmulWithAdd, "InsertReshapeAroundMatmulWithAdd", 0);
NGRAPH_RTTI_DEFINITION(InsertReshapeAroundMatmulWithFq, "InsertReshapeAroundMatmulWithFq", 0);
NGRAPH_RTTI_DEFINITION(InsertReshapeAroundMatmulWithTranspose, "InsertReshapeAroundMatmulWithTranspose", 0);

static bool InsertReshape(
ngraph::pattern::Matcher &matcher,
const std::shared_ptr<ngraph::Node>& input,
const std::shared_ptr<ngraph::Node>& matmul1,
const std::shared_ptr<ngraph::Node>& matmul2,
const std::shared_ptr<ngraph::Node>& add1 = nullptr,
const std::shared_ptr<ngraph::Node>& add2 = nullptr,
const std::shared_ptr<ngraph::Node>& fake_quantize2 = nullptr,
const std::shared_ptr<ngraph::Node>& transpose = nullptr) {
const auto& pattern_map = matcher.get_pattern_value_map();
size_t matmul_input_index = 1;
auto iter = pattern_map.find(matmul1);
if (iter == pattern_map.end()) {
iter = pattern_map.find(matmul2);
if ((iter = pattern_map.find(matmul2)) == pattern_map.end()) {
return false;
}

matmul_input_index = 0;
}

std::shared_ptr<ngraph::Node> matmul_node = iter->second.get_node_shared_ptr();
auto matmul_node_shape = matmul_node->get_output_shape(0);
if ((iter = pattern_map.find(input)) == std::end(pattern_map)) {
return false;
}

std::shared_ptr<ngraph::Node> first_node = iter->second.get_node_shared_ptr();
auto reshape_input_node = std::dynamic_pointer_cast<ngraph::opset8::Reshape>(first_node);
bool need_reshape_before = !reshape_input_node || reshape_input_node->get_output_shape(0).size() != 2;
if (need_reshape_before) {
auto input_shape = first_node->get_output_shape(0);
std::vector<size_t> before_shape(2, 1);
std::copy_if(input_shape.begin(), input_shape.end(), before_shape.begin(), [](size_t e) { return e > 1; });
auto reshape_before_node = std::make_shared<ngraph::opset8::Reshape>(first_node,
std::make_shared<ngraph::opset8::Constant>(ngraph::element::Type_t::i64, ngraph::Shape{before_shape.size()}, before_shape), false);
reshape_before_node->set_friendly_name(matmul_node->get_friendly_name() + "/reshape_before_matmul");
ngraph::copy_runtime_info(first_node, reshape_before_node);
matmul_node->input(matmul_input_index).replace_source_output(reshape_before_node->output(0));
}

std::shared_ptr<ngraph::Node> last_node;
iter = pattern_map.find(transpose);
if (iter == pattern_map.end() &&
(iter = pattern_map.find(fake_quantize2)) == pattern_map.end() &&
(iter = pattern_map.find(add1)) == pattern_map.end() &&
(iter = pattern_map.find(add2)) == pattern_map.end()) {
last_node = matmul_node;
} else {
last_node = iter->second.get_node_shared_ptr();
}

auto consumers = last_node->output(0).get_target_inputs();
auto last_node_shape = last_node->get_output_shape(0);
bool need_reshape_after = false;
for (auto consumer : consumers) {
auto reshape_output_node = dynamic_cast<ngraph::opset8::Reshape*>(consumer.get_node());
if (!reshape_output_node || reshape_output_node->get_output_shape(0).size() != last_node_shape.size()) {
need_reshape_after = true;
break;
}
}

if (need_reshape_after) {
auto reshape_after_node = std::make_shared<ngraph::opset8::Reshape>(last_node,
std::make_shared<ngraph::opset8::Constant>(ngraph::element::Type_t::i64, ngraph::Shape{last_node_shape.size()}, last_node_shape), false);
reshape_after_node->set_friendly_name(last_node->get_friendly_name());
ngraph::copy_runtime_info(last_node, reshape_after_node);
for (auto consumer : consumers) {
consumer.replace_source_output(reshape_after_node);
}
}

return need_reshape_before || need_reshape_after;
}

static std::shared_ptr<ngraph::Node> CreateMatmulPattern(
std::shared_ptr<ngraph::Node>& input,
std::shared_ptr<ngraph::Node>& matmul1,
std::shared_ptr<ngraph::Node>& matmul2,
const ngraph::pattern::op::ValuePredicate& pred = [](const ngraph::Output<ngraph::Node>& output) { return true; }) {
auto constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
auto fake_quantize = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({constant,
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>()});
auto matmul_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{constant, fake_quantize});
input = ngraph::pattern::any_input([](const ngraph::Output<ngraph::Node>& node) {
auto shape = node.get_node_shared_ptr()->get_output_shape(0);
return shape.size() > 2 && std::count_if(shape.begin(), shape.end(), [](size_t e) { return e > 1; }) <= 2; });
matmul1 = ngraph::pattern::wrap_type<ngraph::opset8::MatMul>({matmul_input, input}, pred);
matmul2 = ngraph::pattern::wrap_type<ngraph::opset8::MatMul>({input, matmul_input}, pred);
return std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{matmul1, matmul2});
}

InsertReshapeAroundMatmul::InsertReshapeAroundMatmul() {
MATCHER_SCOPE(InsertReshapeAroundMatmul);

auto pred = [](const ngraph::Output<ngraph::Node>& node) {
const auto& outputs = node.get_node_shared_ptr()->outputs();
const auto& inputs = outputs[0].get_target_inputs();
if (inputs.empty()) {
return true;
}

auto next_node = inputs.begin()->get_node();
return outputs.size() != 1 ||
!dynamic_cast<ngraph::opset8::Transpose*>(next_node) &&
!dynamic_cast<ngraph::opset8::FakeQuantize*>(next_node) &&
!dynamic_cast<ngraph::opset8::Add*>(next_node);
};

std::shared_ptr<ngraph::Node> input;
std::shared_ptr<ngraph::Node> matmul1;
std::shared_ptr<ngraph::Node> matmul2;
auto matmul = CreateMatmulPattern(input, matmul1, matmul2, pred);

ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &matcher) {
return InsertReshape(matcher, input, matmul1, matmul2);
};

auto matcher = std::make_shared<ngraph::pattern::Matcher>(matmul, "InsertReshapeAroundMatmul");
this->register_matcher(matcher, callback);
}

InsertReshapeAroundMatmulWithAdd::InsertReshapeAroundMatmulWithAdd() {
MATCHER_SCOPE(InsertReshapeAroundMatmulWithAdd);

auto pred = [](const ngraph::Output<ngraph::Node>& node) {
const auto& outputs = node.get_node_shared_ptr()->outputs();
const auto& inputs = outputs[0].get_target_inputs();
if (inputs.empty()) {
return true;
}

auto next_node = inputs.begin()->get_node();
return outputs.size() != 1 ||
!dynamic_cast<ngraph::opset8::Transpose*>(next_node) &&
!dynamic_cast<ngraph::opset8::FakeQuantize*>(next_node);
};

std::shared_ptr<ngraph::Node> input;
std::shared_ptr<ngraph::Node> matmul1;
std::shared_ptr<ngraph::Node> matmul2;
auto matmul = CreateMatmulPattern(input, matmul1, matmul2);
auto add_input = ngraph::pattern::any_input();
auto add1 = ngraph::pattern::wrap_type<ngraph::opset8::Add>({matmul, add_input}, pred);
auto add2 = ngraph::pattern::wrap_type<ngraph::opset8::Add>({add_input, matmul}, pred);
auto add = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{add1, add2});

ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &matcher) {
return InsertReshape(matcher, input, matmul1, matmul2, add1, add2);
};

auto matcher = std::make_shared<ngraph::pattern::Matcher>(add, "InsertReshapeAroundMatmulWithAdd");
this->register_matcher(matcher, callback);
}

InsertReshapeAroundMatmulWithFq::InsertReshapeAroundMatmulWithFq() {
MATCHER_SCOPE(InsertReshapeAroundMatmulWithFq);

std::shared_ptr<ngraph::Node> input;
std::shared_ptr<ngraph::Node> matmul1;
std::shared_ptr<ngraph::Node> matmul2;
auto matmul = CreateMatmulPattern(input, matmul1, matmul2);
auto add_input = ngraph::pattern::any_input();
auto add1 = ngraph::pattern::wrap_type<ngraph::opset8::Add>({matmul, add_input});
auto add2 = ngraph::pattern::wrap_type<ngraph::opset8::Add>({add_input, matmul});
auto fq_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{matmul, add1, add2});
auto fake_quantize2 = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({fq_input, ngraph::pattern::any_input(),
ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()},
[](const ngraph::Output<ngraph::Node>& node) {
const auto& outputs = node.get_node_shared_ptr()->outputs();
const auto& inputs = outputs[0].get_target_inputs();
if (inputs.empty()) {
return true;
}

auto next_node = inputs.begin()->get_node();
return outputs.size() != 1 ||
!dynamic_cast<ngraph::opset8::Transpose*>(next_node);
});

ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &matcher) {
return InsertReshape(matcher, input, matmul1, matmul2, add1, add2, fake_quantize2);
};

auto matcher = std::make_shared<ngraph::pattern::Matcher>(fake_quantize2, "InsertReshapeAroundMatmulWithFq");
this->register_matcher(matcher, callback);
}

InsertReshapeAroundMatmulWithTranspose::InsertReshapeAroundMatmulWithTranspose() {
MATCHER_SCOPE(InsertReshapeAroundMatmulWithTranspose);

std::shared_ptr<ngraph::Node> input;
std::shared_ptr<ngraph::Node> matmul1;
std::shared_ptr<ngraph::Node> matmul2;
auto matmul = CreateMatmulPattern(input, matmul1, matmul2);
auto add_input = ngraph::pattern::any_input();
auto add1 = ngraph::pattern::wrap_type<ngraph::opset8::Add>({matmul, add_input});
auto add2 = ngraph::pattern::wrap_type<ngraph::opset8::Add>({add_input, matmul});
auto fq_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{matmul, add1, add2});
auto fake_quantize2 = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({fq_input, ngraph::pattern::any_input(),
ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()});
auto transpose_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{fq_input, fake_quantize2});
auto transpose = ngraph::pattern::wrap_type<ngraph::opset8::Transpose>({transpose_input, ngraph::pattern::any_input()});

ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &matcher) {
return InsertReshape(matcher, input, matmul1, matmul2, add1, add2, fake_quantize2, transpose);
};

auto matcher = std::make_shared<ngraph::pattern::Matcher>(transpose, "InsertReshapeAroundMatmulWithTranspose");
this->register_matcher(matcher, callback);
}
} // namespace GNAPluginNS
@@ -0,0 +1,39 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#ifndef INSERT_RESHAPE_AROUND_MATMUL_HPP
#define INSERT_RESHAPE_AROUND_MATMUL_HPP

#include <ngraph/pass/graph_rewrite.hpp>

namespace GNAPluginNS {

// @brief Insert Reshapes from 3d/4d to 2d before MatMul and from 2d to 3d/4d after MatMul
class InsertReshapeAroundMatmul : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
InsertReshapeAroundMatmul();
};

class InsertReshapeAroundMatmulWithAdd : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
InsertReshapeAroundMatmulWithAdd();
};

class InsertReshapeAroundMatmulWithFq : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
InsertReshapeAroundMatmulWithFq();
};

class InsertReshapeAroundMatmulWithTranspose : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
InsertReshapeAroundMatmulWithTranspose();
};

} // namespace GNAPluginNS

#endif // INSERT_RESHAPE_AROUND_MATMUL_HPP
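A worked example of what the header describes, using shapes taken from the unit tests added later in this commit: with an input of shape {1, 6, 8} and a MatMul constant of shape {8, 10}, the pass inserts Reshape {1, 6, 8} -> {6, 8} in front of the MatMul, the MatMul then produces {6, 10}, and a second Reshape restores the original rank, {6, 10} -> {1, 6, 10}.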
@@ -2,31 +2,34 @@
// SPDX-License-Identifier: Apache-2.0
//

#include <transformations/swap_input_matmul_gna.hpp>
#include <openvino/cc/ngraph/itt.hpp>

#include <memory>
#include <vector>

#include <ngraph/pass/manager.hpp>
#include <ngraph/pattern/op/or.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <numeric>
#include <transformations/swap_input_matmul_gna.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pattern/op/or.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ie/ie_common.h>

#include "gna_plugin_log.hpp"

using namespace GNAPluginNS;
namespace GNAPluginNS {

NGRAPH_RTTI_DEFINITION(SwapInputMatMul, "SwapInputMatMul", 0);
NGRAPH_RTTI_DEFINITION(SwapInputMatMulWithBias, "SwapInputMatMulWithBias", 0);
NGRAPH_RTTI_DEFINITION(SwapInputMatMulWithFq, "SwapInputMatMulWithFq", 0);

static void SwapAndTransposeInputs(std::shared_ptr<ngraph::opset8::MatMul> matmul_node,
static void SwapAndTransposeInputs(
std::shared_ptr<ngraph::opset8::MatMul> matmul_node,
std::shared_ptr<ngraph::Node> add,
std::shared_ptr<ngraph::Node> bias,
std::shared_ptr<ngraph::Node> fq) {
std::shared_ptr<ngraph::Node> fq,
const std::string& last_layer_name) {
auto create_transpose =
[](ngraph::Output<ngraph::Node> node, const std::string& transpose_name) -> std::shared_ptr<ngraph::Node> {
ngraph::Shape output_shape = node.get_node_shared_ptr()->get_shape();
@@ -56,6 +59,19 @@ static void SwapAndTransposeInputs(std::shared_ptr<ngraph::opset8::MatMul> matmu
if (bias->get_output_shape(0).size() > 1) {
bias = create_transpose(bias, bias->get_friendly_name() + "/transpose");
new_ops.push_back(bias);

auto transpose_shape = bias->get_output_shape(0);
auto matmul_shape = matmul_node->get_output_shape(0);
if (transpose_shape.size() > matmul_shape.size()) {
std::vector<size_t> reshape_shape(matmul_shape.size(), 1);
std::copy_if(transpose_shape.begin(), transpose_shape.end(), reshape_shape.begin(), [](size_t e) { return e > 1; });
bias = std::make_shared<ngraph::opset8::Reshape>(bias,
std::make_shared<ngraph::opset8::Constant>(ngraph::element::Type_t::i64,
ngraph::Shape{reshape_shape.size()}, reshape_shape), false);
bias->set_friendly_name(add->get_friendly_name() + "/reshape");
ngraph::copy_runtime_info(add, bias);
new_ops.push_back(bias);
}
}

new_matmul = std::make_shared<ngraph::opset8::Add>(new_matmul, bias);
@@ -70,113 +86,151 @@ static void SwapAndTransposeInputs(std::shared_ptr<ngraph::opset8::MatMul> matmu
new_ops.push_back(new_matmul);
}

auto output = create_transpose(new_matmul, matmul_node->get_friendly_name());
auto output = create_transpose(new_matmul, last_layer_name);
new_ops.push_back(output);

ngraph::copy_runtime_info(matmul_node, new_ops);
ngraph::replace_node(old_root_node, output);
}

SwapInputMatMul::SwapInputMatMul() {
MATCHER_SCOPE(SwapInputMatMul);
auto constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>({}, [](const ngraph::Output<ngraph::Node>& node) {
auto shape = node.get_node_shared_ptr()->get_output_shape(0);
if (shape.size() != 2 || shape[0] < 8 || ((shape[0] % 8 != 0 || shape[1] % 8 != 0))) {
return false;
}
return true;
});
static std::shared_ptr<ngraph::Node> CreateMatmul(
bool is_first_constant,
ngraph::pattern::op::ValuePredicate const_predicate,
ngraph::pattern::op::ValuePredicate matmul_predicate = ngraph::pattern::has_static_shape()) {
auto constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>({}, const_predicate);
auto fake_quantize = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({constant,
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>()});
auto matmul_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{constant, fake_quantize});
auto matmul = ngraph::pattern::wrap_type<ngraph::opset8::MatMul>({matmul_input, ngraph::pattern::any_input()},
ngraph::pattern::has_static_shape());
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
auto matmul_node = std::dynamic_pointer_cast<ngraph::opset8::MatMul>(pattern_map.at(matmul).get_node_shared_ptr());
if (is_first_constant) {
return ngraph::pattern::wrap_type<ngraph::opset8::MatMul>(
{matmul_input, ngraph::pattern::any_input()}, matmul_predicate);
}
return ngraph::pattern::wrap_type<ngraph::opset8::MatMul>(
{ngraph::pattern::any_input(), matmul_input}, matmul_predicate);
}

static std::shared_ptr<ngraph::Node> CreateMatmuls(
std::shared_ptr<ngraph::Node>& matmul1,
std::shared_ptr<ngraph::Node>& matmul2) {
matmul1 = CreateMatmul(
true,
[](const ngraph::Output<ngraph::Node>& node) { return true; },
[](const ngraph::Output<ngraph::Node>& node) {
auto matmul_node = std::dynamic_pointer_cast<ngraph::opset8::MatMul>(node.get_node_shared_ptr());
IE_ASSERT(matmul_node != nullptr);
SwapAndTransposeInputs(matmul_node, nullptr, nullptr, nullptr);
auto input_shape = matmul_node->get_input_shape(0);
return input_shape.size() == 2 &&
(!matmul_node->get_transpose_a() && input_shape[0] > 8 ||
matmul_node->get_transpose_a() && input_shape[1] > 8); });
matmul2 = CreateMatmul(
false,
[](const ngraph::Output<ngraph::Node>& node) { return true; },
[](const ngraph::Output<ngraph::Node>& node) {
auto matmul_node = std::dynamic_pointer_cast<ngraph::opset8::MatMul>(node.get_node_shared_ptr());
IE_ASSERT(matmul_node != nullptr);
auto first_input_shape = matmul_node->get_input_shape(0);
first_input_shape.erase(std::remove(first_input_shape.begin(), first_input_shape.end(), 1), first_input_shape.end());
auto second_input_shape = matmul_node->get_input_shape(1);
return node.get_partial_shape().is_static() &&
second_input_shape.size() == 2 &&
(!matmul_node->get_transpose_b() && second_input_shape[1] <= 8 ||
matmul_node->get_transpose_b() && second_input_shape[0] <= 8) &&
first_input_shape.size() == 2 &&
first_input_shape[0] > 8; });
return std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{matmul1, matmul2});
}

SwapInputMatMul::SwapInputMatMul() {
MATCHER_SCOPE(SwapInputMatMul);
std::shared_ptr<ngraph::Node> matmul1;
std::shared_ptr<ngraph::Node> matmul2;
auto matmul = CreateMatmuls(matmul1, matmul2);
auto callback = [=](ngraph::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
auto iter = pattern_map.find(matmul1);
if (iter == pattern_map.end() &&
(iter = pattern_map.find(matmul2)) == pattern_map.end()) {
return false;
}

auto matmul_node = std::dynamic_pointer_cast<ngraph::opset8::MatMul>(iter->second.get_node_shared_ptr());
IE_ASSERT(matmul_node != nullptr);
SwapAndTransposeInputs(matmul_node, nullptr, nullptr, nullptr, "");
return true;
};

auto m = std::make_shared<ngraph::pattern::Matcher>(matmul, matcher_name);
this->register_matcher(m, callback);
auto matcher = std::make_shared<ngraph::pattern::Matcher>(matmul, "SwapInputMatMul");
this->register_matcher(matcher, callback);
}

SwapInputMatMulWithBias::SwapInputMatMulWithBias() {
MATCHER_SCOPE(SwapInputMatMulWithBias);
auto constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>({}, [](const ngraph::Output<ngraph::Node>& node) {
auto shape = node.get_node_shared_ptr()->get_output_shape(0);
if (shape.size() != 2 || shape[0] < 8 || ((shape[0] % 8 != 0 || shape[1] % 8 != 0))) {
return false;
}
return true;
});
auto fake_quantize = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({constant,
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>()});
auto matmul_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{constant, fake_quantize});
auto matmul = ngraph::pattern::wrap_type<ngraph::opset8::MatMul>({matmul_input, ngraph::pattern::any_input()},
ngraph::pattern::has_static_shape());
std::shared_ptr<ngraph::Node> matmul1;
std::shared_ptr<ngraph::Node> matmul2;
auto matmul = CreateMatmuls(matmul1, matmul2);
auto bias = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
auto add = ngraph::pattern::wrap_type<ngraph::opset8::Add>({matmul, bias});

ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
auto callback = [=](ngraph::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
auto matmul_node = std::dynamic_pointer_cast<ngraph::opset8::MatMul>(pattern_map.at(matmul).get_node_shared_ptr());
auto iter = pattern_map.find(matmul1);
if (iter == pattern_map.end() &&
(iter = pattern_map.find(matmul2)) == pattern_map.end()) {
return false;
}

auto matmul_node = std::dynamic_pointer_cast<ngraph::opset8::MatMul>(iter->second.get_node_shared_ptr());
IE_ASSERT(matmul_node != nullptr);
SwapAndTransposeInputs(matmul_node, pattern_map.at(add).get_node_shared_ptr(),
pattern_map.at(bias).get_node_shared_ptr(), nullptr);
SwapAndTransposeInputs(
matmul_node,
pattern_map.at(add).get_node_shared_ptr(),
pattern_map.at(bias).get_node_shared_ptr(),
nullptr,
pattern_map.at(add).get_node_shared_ptr()->get_friendly_name());
return true;
};

auto m = std::make_shared<ngraph::pattern::Matcher>(add, matcher_name);
this->register_matcher(m, callback);
auto matcher = std::make_shared<ngraph::pattern::Matcher>(add, "SwapInputMatMulWithBias");
this->register_matcher(matcher, callback);
}

SwapInputMatMulWithFq::SwapInputMatMulWithFq() {
MATCHER_SCOPE(SwapInputMatMulWithFq);
auto constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>({}, [](const ngraph::Output<ngraph::Node>& node) {
auto shape = node.get_node_shared_ptr()->get_output_shape(0);
if (shape.size() != 2 || shape[0] < 8 || ((shape[0] % 8 != 0 || shape[1] % 8 != 0))) {
return false;
}
return true;
});
auto fake_quantize = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({constant,
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>()});
auto matmul_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{constant, fake_quantize});
auto matmul = ngraph::pattern::wrap_type<ngraph::opset8::MatMul>({matmul_input, ngraph::pattern::any_input()},
ngraph::pattern::has_static_shape());
std::shared_ptr<ngraph::Node> matmul1;
std::shared_ptr<ngraph::Node> matmul2;
auto matmul = CreateMatmuls(matmul1, matmul2);
auto bias = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
auto add = ngraph::pattern::wrap_type<ngraph::opset8::Add>({matmul, bias});
auto matmul_out = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{add, matmul});
auto out_fq = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({matmul_out,
auto fq_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{add, matmul});
auto fq = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({fq_input,
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>()});

ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
auto callback = [=](ngraph::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
auto matmul_node = std::dynamic_pointer_cast<ngraph::opset8::MatMul>(pattern_map.at(matmul).get_node_shared_ptr());
auto iter = pattern_map.find(matmul1);
if (iter == pattern_map.end() &&
(iter = pattern_map.find(matmul2)) == pattern_map.end()) {
return false;
}

auto iter_add = pattern_map.find(add);
auto iter_bias = pattern_map.find(bias);
auto matmul_node = std::dynamic_pointer_cast<ngraph::opset8::MatMul>(iter->second.get_node_shared_ptr());
IE_ASSERT(matmul_node != nullptr);
auto add_it = pattern_map.find(add);
auto add_node = (add_it == std::end(pattern_map) ? nullptr : add_it->second.get_node_shared_ptr());
auto bias_it = pattern_map.find(bias);
auto bias_node = (bias_it == std::end(pattern_map) ? nullptr : bias_it->second.get_node_shared_ptr());
SwapAndTransposeInputs(matmul_node, add_node, bias_node, pattern_map.at(out_fq).get_node_shared_ptr());
SwapAndTransposeInputs(
matmul_node,
iter_add != pattern_map.end() ? iter_add->second.get_node_shared_ptr() : nullptr,
iter_bias != pattern_map.end() ? iter_bias->second.get_node_shared_ptr() : nullptr,
pattern_map.at(fq).get_node_shared_ptr(),
pattern_map.at(fq).get_node_shared_ptr()->get_friendly_name());
return true;
};

auto m = std::make_shared<ngraph::pattern::Matcher>(out_fq, matcher_name);
this->register_matcher(m, callback);
auto matcher = std::make_shared<ngraph::pattern::Matcher>(fq, "SwapInputMatMulWithFq");
this->register_matcher(matcher, callback);
}
} // namespace GNAPluginNS
@@ -2,15 +2,15 @@
// SPDX-License-Identifier: Apache-2.0
//

#pragma once
#ifndef SWAP_INPUT_MATMUL_GNA_HPP
#define SWAP_INPUT_MATMUL_GNA_HPP

#include <memory>
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>

namespace GNAPluginNS {

// @brief Swaps and transposes inputs of MatMul if its first input is const and its batch size isn't supported by GNA
// @brief Swaps and transposes inputs of MatMul if
// 1. its first input is const and its batch size isn't supported by GNA
// 2. its first input is non-const and its batch size isn't supported by GNA
class SwapInputMatMul: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
@@ -29,3 +29,5 @@ public:
SwapInputMatMulWithFq();
};
} // namespace GNAPluginNS

#endif // SWAP_INPUT_MATMUL_GNA_HPP
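For background (an illustration, not part of the diff): the swap relies on the identity (W * X)^T = X^T * W^T, so a MatMul whose constant input sits in a position with a GNA-unsupported batch size can be computed with swapped, transposed inputs plus a final Transpose. The test helper later in this commit builds exactly that reference sub-graph; a trimmed-down sketch with assumed 2D inputs:

#include <vector>
#include <ngraph/opsets/opset8.hpp>

// Build MatMul(W, X) rewritten as Transpose(MatMul(X, W, transpose_a, transpose_b)),
// the equivalent form produced after the inputs are swapped.
std::shared_ptr<ngraph::Node> SwappedMatMul(const ngraph::Output<ngraph::Node>& W,
                                            const ngraph::Output<ngraph::Node>& X) {
    // (W * X)^T == X^T * W^T, so compute the transposed product first ...
    auto swapped = std::make_shared<ngraph::opset8::MatMul>(X, W, true, true);
    // ... and transpose the 2D result back to the original layout.
    auto order = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{2},
                                                  std::vector<size_t>{1, 0});
    return std::make_shared<ngraph::opset8::Transpose>(swapped, order);
}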
@@ -99,7 +99,8 @@ const std::vector<std::vector<std::vector<size_t>>> input_shapes = {
{{1, 8}, {8, 1}},
{{128, 8}, {8, 1}},
{{8, 8}, {8, 8}},
{{1, 16}, {16, 8}}
{{1, 16}, {16, 8}},
{{6, 16}, {16, 8}}
};

@@ -0,0 +1,190 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <gtest/gtest.h>

#include "transformations/insert_reshape_around_matmul.hpp"

#include "common_test_utils/ngraph_test_utils.hpp"
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/init_node_info.hpp>
#include <numeric>

template<bool ADD, bool ADD_FIRST_INPUT_NOT_CONSTANT, bool FQ>
struct InsertReshapeAroundMatmulTest {
static std::shared_ptr<ngraph::Node> CreateAdd(std::shared_ptr<ngraph::Node> input, const ngraph::Shape& constant_shape) {
std::vector<size_t> data(ngraph::shape_size(constant_shape));
std::iota(std::begin(data), std::end(data), 1);
auto constant = ngraph::opset8::Constant::create(ngraph::element::i64, constant_shape, data);
return std::make_shared<ngraph::opset8::Add>(input, constant);
}

static std::shared_ptr<ngraph::Node> CreateMatmul(
std::shared_ptr<ngraph::Node> input,
const ngraph::Shape& matmul_constant_shape) {
std::vector<size_t> data(ngraph::shape_size(matmul_constant_shape));
std::iota(std::begin(data), std::end(data), 1);
auto constant = ngraph::opset8::Constant::create(ngraph::element::i64, matmul_constant_shape, data);
std::shared_ptr<ngraph::Node> node;
node = std::make_shared<ngraph::opset8::MatMul>(input, constant);

if (ADD) {
auto matmul_shape = node->get_output_shape(0);
data.resize(ngraph::shape_size(matmul_shape));
std::iota(std::begin(data), std::end(data), 1);
std::vector<size_t> constant_add_shape(2, 1);
std::copy_if(matmul_shape.begin(), matmul_shape.end(), constant_add_shape.begin(), [](size_t e) { return e > 1; });
auto constant_add = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{constant_add_shape}, data);
if (ADD_FIRST_INPUT_NOT_CONSTANT) {
node = std::make_shared<ngraph::opset8::Add>(node, constant_add);
} else {
node = std::make_shared<ngraph::opset8::Add>(constant_add, node);
}
}

if (FQ) {
node = std::make_shared<ngraph::opset8::FakeQuantize>(
node,
ngraph::opset8::Constant::create(ngraph::element::f32, {1}, {-0.1}),
ngraph::opset8::Constant::create(ngraph::element::f32, {1}, {0.1}),
ngraph::opset8::Constant::create(ngraph::element::f32, {1}, {-0.1}),
ngraph::opset8::Constant::create(ngraph::element::f32, {1}, {0.1}),
255);
}

return node;
}

static std::shared_ptr<ngraph::Function> CreateFunction(
const ngraph::Shape& input_shape,
const ngraph::Shape& matmul_constant_shape,
const ngraph::Shape& result_shape) {
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, input_shape);
auto before = std::make_shared<ngraph::opset8::Relu>(input);
auto matmul = CreateMatmul(before, matmul_constant_shape);
auto after = std::make_shared<ngraph::opset8::Relu>(matmul);
return std::make_shared<ngraph::Function>(
ngraph::ResultVector{std::make_shared<ngraph::opset8::Result>(after)},
ngraph::ParameterVector{input});
}

static std::shared_ptr<ngraph::Function> CreateReferenceFunction(
const ngraph::Shape& input_shape,
const ngraph::Shape& reshape_before_shape,
const ngraph::Shape& matmul_constant_shape,
const ngraph::Shape& reshape_after_shape,
const ngraph::Shape& result_shape) {
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, input_shape);
auto before = std::make_shared<ngraph::opset8::Relu>(input);
auto reshape_before_constant = ngraph::opset8::Constant::create(ngraph::element::i64,
ngraph::Shape{reshape_before_shape.size()}, reshape_before_shape);
auto reshape_before = std::make_shared<ngraph::opset8::Reshape>(before, reshape_before_constant, false);
auto matmul = CreateMatmul(reshape_before, matmul_constant_shape);
auto reshape_after_constant = ngraph::opset8::Constant::create(ngraph::element::i64,
ngraph::Shape{reshape_after_shape.size()}, reshape_after_shape);
auto reshape_after = std::make_shared<ngraph::opset8::Reshape>(matmul, reshape_after_constant, false);
auto after = std::make_shared<ngraph::opset8::Relu>(reshape_after);
return std::make_shared<ngraph::Function>(
ngraph::ResultVector{std::make_shared<ngraph::opset8::Result>(after)},
ngraph::ParameterVector{input});
}
}; // struct InsertReshapeAroundMatmulTest

namespace {

void RunTest(const std::shared_ptr<ngraph::Function>& func, const std::shared_ptr<ngraph::Function>& reference_func) {
{
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<GNAPluginNS::InsertReshapeAroundMatmulWithTranspose>();
m.register_pass<GNAPluginNS::InsertReshapeAroundMatmulWithFq>();
m.register_pass<GNAPluginNS::InsertReshapeAroundMatmulWithAdd>();
m.register_pass<GNAPluginNS::InsertReshapeAroundMatmul>();
m.run_passes(func);
ASSERT_NO_THROW(check_rt_info(func));
}

const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
const FunctionsComparator::Result result = func_comparator(func, reference_func);
ASSERT_TRUE(result.valid);
}

} // namespace

TEST(TransformationTests, InsertReshapeAroundMatmul) {
RunTest(
InsertReshapeAroundMatmulTest<false, false, false>::
CreateFunction({1, 6, 8}, {8, 10}, {1, 6, 10}),
InsertReshapeAroundMatmulTest<false, false, false>::
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}));
RunTest(
InsertReshapeAroundMatmulTest<false, false, false>::
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}),
InsertReshapeAroundMatmulTest<false, false, false>::
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}));
RunTest(
InsertReshapeAroundMatmulTest<false, false, false>::
CreateFunction({1, 6, 1, 8}, {8, 10}, {1, 6, 1, 10}),
InsertReshapeAroundMatmulTest<false, false, false>::
CreateReferenceFunction({1, 6, 1, 8}, {6, 8}, {8, 10}, {1, 6, 1, 10}, {1, 6, 1, 10}));
RunTest(
InsertReshapeAroundMatmulTest<false, false, false>::
CreateReferenceFunction({1, 6, 1, 8}, {6, 8}, {8, 10}, {1, 6, 1, 10}, {1, 6, 1, 10}),
InsertReshapeAroundMatmulTest<false, false, false>::
CreateReferenceFunction({1, 6, 1, 8}, {6, 8}, {8, 10}, {1, 6, 1, 10}, {1, 6, 1, 10}));
}

TEST(TransformationTests, InsertReshapeAroundMatmulWithAdd) {
RunTest(
InsertReshapeAroundMatmulTest<true, true, false>::
CreateFunction({1, 6, 8}, {8, 10}, {1, 6, 10}),
InsertReshapeAroundMatmulTest<true, true, false>::
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}));
RunTest(
InsertReshapeAroundMatmulTest<true, true, false>::
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}),
InsertReshapeAroundMatmulTest<true, true, false>::
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}));
}

TEST(TransformationTests, InsertReshapeAroundMatmulWithAdd_AddFirstInputConstant) {
RunTest(
InsertReshapeAroundMatmulTest<true, false, false>::
CreateFunction({1, 6, 8}, {8, 10}, {1, 6, 10}),
InsertReshapeAroundMatmulTest<true, false, false>::
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}));
RunTest(
InsertReshapeAroundMatmulTest<true, false, false>::
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}),
InsertReshapeAroundMatmulTest<true, false, false>::
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}));
}

TEST(TransformationTests, InsertReshapeAroundMatmulWithFq) {
RunTest(
InsertReshapeAroundMatmulTest<false, false, true>::
CreateFunction({1, 6, 8}, {8, 10}, {1, 6, 10}),
InsertReshapeAroundMatmulTest<false, false, true>::
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}));
RunTest(
InsertReshapeAroundMatmulTest<false, false, true>::
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}),
InsertReshapeAroundMatmulTest<false, false, true>::
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}));
}

TEST(TransformationTests, InsertReshapeAroundMatmulWithAddAndFq) {
RunTest(
InsertReshapeAroundMatmulTest<true, true, true>::
CreateFunction({1, 6, 8}, {8, 10}, {1, 6, 10}),
InsertReshapeAroundMatmulTest<true, true, true>::
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}));
RunTest(
InsertReshapeAroundMatmulTest<true, true, true>::
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}),
InsertReshapeAroundMatmulTest<true, true, true>::
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}));
}
@ -20,7 +20,8 @@ static std::shared_ptr<ngraph::Function> CreateMatMulFunction(const ngraph::Shap
        bool withBias,
        bool withWeightsFq,
        bool withOutFq,
        bool swappedInputs) {
        bool swappedInputs,
        bool needTranspose) {
    auto input_params = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, input2_shape);

    auto constant = ngraph::opset8::Constant::create(ngraph::element::i64, input1_shape, {1});
@ -33,14 +34,14 @@ static std::shared_ptr<ngraph::Function> CreateMatMulFunction(const ngraph::Shap
        const_input = std::make_shared<ngraph::opset8::FakeQuantize>(const_input, input_low, input_high,
                                                                     output_low, output_high, 11);
    }
    auto matmul = swappedInputs ? std::make_shared<ngraph::opset8::MatMul>(input_params, const_input, true, true) :
        std::make_shared<ngraph::opset8::MatMul>(const_input, input_params);
    auto matmul = swappedInputs ? std::make_shared<ngraph::opset8::MatMul>(input_params, const_input, needTranspose, needTranspose) :
        std::make_shared<ngraph::opset8::MatMul>(const_input, input_params, needTranspose, needTranspose);

    std::shared_ptr<ngraph::Node> final_node = matmul;
    if (withBias) {
        auto bias = ngraph::opset8::Constant::create(ngraph::element::i64, bias_shape, {1});
        std::shared_ptr<ngraph::Node> bias_node = bias;
        if (swappedInputs && bias_shape.size() > 1) {
        if (needTranspose && bias_shape.size() > 1) {
            auto transpose_order = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{2},
                                                                    std::vector<size_t>{1, 0});
            bias_node = std::make_shared<ngraph::opset8::Transpose>(bias_node, transpose_order);
@ -57,7 +58,7 @@ static std::shared_ptr<ngraph::Function> CreateMatMulFunction(const ngraph::Shap
                                                                    output_low, output_high, 11);
    }

    if (swappedInputs) {
    if (needTranspose) {
        auto transpose_order = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{2},
                                                                std::vector<size_t>{1, 0});
        final_node = std::make_shared<ngraph::opset8::Transpose>(final_node, transpose_order);
@ -104,6 +105,12 @@ static std::string getTestCaseName(testing::TestParamInfo<SwapInputMatmulParams>
    return result.str();
}

enum class MatmulInputType {
    FirstInputConstant,
    SecondInputConstant
}; // enum class MatmulInputType

template<MatmulInputType E>
class SwapInputMatmul : public CommonTestUtils::TestsCommon,
                        public ::testing::WithParamInterface<SwapInputMatmulParams> {
public:
@ -112,14 +119,24 @@ public:
        bool withBias, withWeightsFq, withOutFq;
        std::tie(shapes, withBias, withWeightsFq, withOutFq) = this->GetParam();

        function = CreateMatMulFunction(shapes[0], shapes[1], shapes[2], withBias, withWeightsFq, withOutFq, false);
        bool swap_inputs = false;
        switch (E) {
        case MatmulInputType::FirstInputConstant:
            break;
        case MatmulInputType::SecondInputConstant:
            swap_inputs = true;
            break;
        }

        function = CreateMatMulFunction(shapes[0], shapes[1], shapes[2], withBias, withWeightsFq, withOutFq, swap_inputs, false);
        reference_function = CreateMatMulFunction(shapes[0], shapes[1], shapes[2], withBias, withWeightsFq,
                                                  withOutFq, true);
                                                  withOutFq, !swap_inputs, true);
    }
public:
    std::shared_ptr<ngraph::Function> function, reference_function;
};

template<MatmulInputType E>
class SwapInputMatmulNotApplied : public CommonTestUtils::TestsCommon,
                                  public ::testing::WithParamInterface<SwapInputMatmulParams> {
public:
@ -128,42 +145,92 @@ public:
        bool withBias, withWeightsFq, withOutFq;
        std::tie(shapes, withBias, withWeightsFq, withOutFq) = this->GetParam();

        function = CreateMatMulFunction(shapes[0], shapes[1], shapes[2], withBias, withWeightsFq, withOutFq, false);
        bool swap_inputs = false;
        switch (E) {
        case MatmulInputType::FirstInputConstant:
            break;
        case MatmulInputType::SecondInputConstant:
            swap_inputs = true;
            break;
        }

        function = CreateMatMulFunction(shapes[0], shapes[1], shapes[2], withBias, withWeightsFq, withOutFq, swap_inputs, false);
        reference_function = ngraph::clone_function(*function);
    }
public:
    std::shared_ptr<ngraph::Function> function, reference_function;
};

TEST_P(SwapInputMatmul, CompareFunctions) {
using SwapInputMatmulWithFirstInputConstant = SwapInputMatmul<MatmulInputType::FirstInputConstant>;
using SwapInputMatmulWithSecondInputConstant = SwapInputMatmul<MatmulInputType::SecondInputConstant>;
using SwapInputMatmulWithFirstInputConstantNotApplied = SwapInputMatmulNotApplied<MatmulInputType::FirstInputConstant>;
using SwapInputMatmulWithSecondInputConstantNotApplied = SwapInputMatmulNotApplied<MatmulInputType::SecondInputConstant>;

TEST_P(SwapInputMatmulWithFirstInputConstant, CompareFunctions) {
    Execute(function, reference_function);
}

TEST_P(SwapInputMatmulNotApplied, CompareFunctions) {
TEST_P(SwapInputMatmulWithFirstInputConstantNotApplied, CompareFunctions) {
    Execute(function, reference_function);
}

const std::vector<std::vector<ngraph::Shape>> input_shapes_applied = {
TEST_P(SwapInputMatmulWithSecondInputConstant, CompareFunctions) {
    Execute(function, reference_function);
}

TEST_P(SwapInputMatmulWithSecondInputConstantNotApplied, CompareFunctions) {
    Execute(function, reference_function);
}

const std::vector<std::vector<ngraph::Shape>> input_shapes_for_matmul_with_first_constant_applied = {
    {{16, 8}, {8, 8}, {16, 8}},
    {{16, 8}, {8, 8}, {1}},
};

const std::vector<std::vector<ngraph::Shape>> input_shapes_not_applied = {
const std::vector<std::vector<ngraph::Shape>> input_shapes_for_matmul_with_first_constant_not_applied = {
    {{1, 8}, {8, 8}, {1, 8}},
    {{8}, {8, 8}, {8}}
};

INSTANTIATE_TEST_SUITE_P(smoke_swap_input_matmul, SwapInputMatmul,
const std::vector<std::vector<ngraph::Shape>> input_shapes_for_matmul_with_second_constant_applied = {
    {{64, 6}, {100, 64}, {100, 6}},
    {{64, 6}, {100, 64}, {1}},
};

const std::vector<std::vector<ngraph::Shape>> input_shapes_for_matmul_with_second_constant_not_applied = {
    {{64, 16}, {100, 64}, {100, 16}},
    {{64, 6}, {8, 64}, {8, 6}},
    {{8, 1}, {8, 8}, {8, 1}},
    {{8}, {8, 8}, {8}}
};

INSTANTIATE_TEST_SUITE_P(smoke_swap_input_matmul, SwapInputMatmulWithFirstInputConstant,
    ::testing::Combine(
        ::testing::ValuesIn(input_shapes_applied),
        ::testing::ValuesIn(input_shapes_for_matmul_with_first_constant_applied),
        ::testing::ValuesIn(std::vector<bool>{false, true}),
        ::testing::ValuesIn(std::vector<bool>{false, true}),
        ::testing::ValuesIn(std::vector<bool>{false, true})),
    getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_swap_input_matmul, SwapInputMatmulNotApplied,
INSTANTIATE_TEST_SUITE_P(smoke_swap_input_matmul, SwapInputMatmulWithFirstInputConstantNotApplied,
    ::testing::Combine(
        ::testing::ValuesIn(input_shapes_not_applied),
        ::testing::ValuesIn(input_shapes_for_matmul_with_first_constant_not_applied),
        ::testing::ValuesIn(std::vector<bool>{false, true}),
        ::testing::ValuesIn(std::vector<bool>{false, true}),
        ::testing::ValuesIn(std::vector<bool>{false, true})),
    getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_swap_input_matmul, SwapInputMatmulWithSecondInputConstant,
    ::testing::Combine(
        ::testing::ValuesIn(input_shapes_for_matmul_with_second_constant_applied),
        ::testing::ValuesIn(std::vector<bool>{false, true}),
        ::testing::ValuesIn(std::vector<bool>{false, true}),
        ::testing::ValuesIn(std::vector<bool>{false, true})),
    getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_swap_input_matmul, SwapInputMatmulWithSecondInputConstantNotApplied,
    ::testing::Combine(
        ::testing::ValuesIn(input_shapes_for_matmul_with_second_constant_not_applied),
        ::testing::ValuesIn(std::vector<bool>{false, true}),
        ::testing::ValuesIn(std::vector<bool>{false, true}),
        ::testing::ValuesIn(std::vector<bool>{false, true})),
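The applied and not-applied shape lists above rest on the MatMul identity A x B == transpose(transpose(B) x transpose(A)): when the pass fires, the two MatMul inputs are exchanged with both transpose flags set, and a final Transpose restores the original output layout, which is exactly what the reference graphs built with needTranspose = true expect. The sketch below builds that swapped form by hand for the {100, 64} x {64, 6} test shape so the equivalence is explicit; it is an illustration under those assumptions, not the SwapInputMatMul pass itself, and the function name is invented for the example.

// Illustration only: hand-built swapped MatMul equivalent to
// MatMul(activation {100, 64}, weights {64, 6}) -> {100, 6}.
#include <memory>
#include <vector>
#include <ngraph/ngraph.hpp>
#include <ngraph/opsets/opset8.hpp>

std::shared_ptr<ngraph::Function> MakeSwappedMatMulSketch() {
    using namespace ngraph;
    auto activation = std::make_shared<opset8::Parameter>(element::f32, Shape{100, 64});
    auto weights = opset8::Constant::create(element::f32, Shape{64, 6}, std::vector<float>(64 * 6, 1.0f));
    // Swapped form: MatMul(weights^T, activation^T) yields {6, 100}.
    auto swapped = std::make_shared<opset8::MatMul>(weights, activation, true, true);
    // A final Transpose restores the original {100, 6} layout.
    auto order = opset8::Constant::create(element::i64, Shape{2}, std::vector<int64_t>{1, 0});
    auto restored = std::make_shared<opset8::Transpose>(swapped, order);
    return std::make_shared<Function>(NodeVector{restored}, ParameterVector{activation});
}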
@ -70,56 +70,117 @@ std::shared_ptr<ngraph::Function> CreateMatmulFunction(const ngraph::Shape& inpu

namespace handle_transpose_after_matmul {

std::shared_ptr<ngraph::Function> CreateMatmulTransposeFunction(const ngraph::Shape& input_shape,
    const ngraph::Shape& matmul_shape, const ngraph::Shape& reshape_shape, bool create_reshape_after_transpose) {
std::shared_ptr<ngraph::Function> CreateMatmulTransposeFunction(
    const ngraph::Shape& input_shape,
    const ngraph::Shape& matmul_shape,
    const ngraph::Shape& reshape_shape,
    bool create_reshape_after_transpose,
    bool enable_last_reshape,
    bool enable_add,
    bool matmul_on_left_side,
    bool enable_fq) {
    auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, input_shape);

    std::vector<size_t> data(ngraph::shape_size(matmul_shape));
    std::iota(std::begin(data), std::end(data), 1);
    auto matmul_constant = ngraph::opset7::Constant::create(ngraph::element::i64, matmul_shape, data);
    auto matmul = std::make_shared<ngraph::opset7::MatMul>(input_params, matmul_constant);
    const auto matmul_output_shape = matmul->get_output_shape(0);
    std::shared_ptr<ngraph::Node> node = std::make_shared<ngraph::opset7::MatMul>(input_params, matmul_constant);
    const auto matmul_output_shape = node->get_output_shape(0);
    if (enable_add) {
        auto add_const = ngraph::opset7::Constant::create(ngraph::element::i64, matmul_output_shape, {1});
        if (matmul_on_left_side) {
            node = std::make_shared<ngraph::opset7::Add>(add_const, node);
        } else {
            node = std::make_shared<ngraph::opset7::Add>(node, add_const);
        }
    }

    if (enable_fq) {
        node = std::make_shared<ngraph::opset7::FakeQuantize>(
            node,
            ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {-0.1}),
            ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {0.1}),
            ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {-0.1}),
            ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {0.1}),
            255);
    }

    auto transpose_order = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {1, 0});
    auto transpose = std::make_shared<ngraph::opset7::Transpose>(matmul, transpose_order);
    auto transpose = std::make_shared<ngraph::opset7::Transpose>(node, transpose_order);
    const auto transpose_output_shape = transpose->get_output_shape(0);

    std::shared_ptr<ngraph::opset7::Reshape> reshape;
    std::shared_ptr<ngraph::Node> reshape;
    auto shape_const = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{reshape_shape.size()}, reshape_shape);
    if (create_reshape_after_transpose) {
        const auto matmul_output_shape = matmul->get_output_shape(0);
        const auto matmul_output_shape = node->get_output_shape(0);
        auto reshape_after_transpose_const = ngraph::opset7::Constant::create(ngraph::element::i64,
            ngraph::Shape{matmul_output_shape.size()}, matmul_output_shape);
        auto reshape_after_transpose = std::make_shared<ngraph::opset7::Reshape>(transpose, reshape_after_transpose_const, false);
        reshape = reshape_after_transpose;
        if (enable_last_reshape) {
            reshape = std::make_shared<ngraph::opset7::Reshape>(reshape_after_transpose, shape_const, false);
        }
    } else {
        reshape = transpose;
        if (enable_last_reshape) {
            reshape = std::make_shared<ngraph::opset7::Reshape>(transpose, shape_const, false);
            const auto reshape_output_shape = reshape->get_output_shape(0);
        }
    }

    auto result = std::make_shared<ngraph::opset7::Result>(reshape);
    return std::make_shared<ngraph::Function>(ngraph::ResultVector{result}, ngraph::ParameterVector{input_params});
}

std::shared_ptr<ngraph::Function> CreateMatmulFunction(const ngraph::Shape& input_shape,
    const ngraph::Shape& matmul_shape, const ngraph::Shape& reshape_shape, bool create_reshape_instead_of_transpose) {
std::shared_ptr<ngraph::Function> CreateMatmulFunction(
    const ngraph::Shape& input_shape,
    const ngraph::Shape& matmul_shape,
    const ngraph::Shape& reshape_shape,
    bool create_reshape_instead_of_transpose,
    bool enable_last_reshape,
    bool enable_add,
    bool matmul_on_left_side,
    bool enable_fq) {
    auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, input_shape);

    std::vector<size_t> data(ngraph::shape_size(matmul_shape));
    std::iota(std::begin(data), std::end(data), 1);
    auto matmul_constant = ngraph::opset7::Constant::create(ngraph::element::i64, matmul_shape, data);
    auto matmul = std::make_shared<ngraph::opset7::MatMul>(input_params, matmul_constant);
    std::shared_ptr<ngraph::Node> node = std::make_shared<ngraph::opset7::MatMul>(input_params, matmul_constant);
    const auto matmul_output_shape = node->get_output_shape(0);
    if (enable_add) {
        auto add_const = ngraph::opset7::Constant::create(ngraph::element::i64, matmul_output_shape, {1});
        if (matmul_on_left_side) {
            node = std::make_shared<ngraph::opset7::Add>(add_const, node);
        } else {
            node = std::make_shared<ngraph::opset7::Add>(node, add_const);
        }
    }

    std::shared_ptr<ngraph::opset7::Reshape> reshape;
    if (enable_fq) {
        node = std::make_shared<ngraph::opset7::FakeQuantize>(
            node,
            ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {-0.1}),
            ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {0.1}),
            ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {-0.1}),
            ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {0.1}),
            255);
    }

    std::shared_ptr<ngraph::Node> reshape;
    auto shape_const = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{reshape_shape.size()}, reshape_shape);
    if (create_reshape_instead_of_transpose) {
        const auto matmul_output_shape = matmul->get_output_shape(0);
        auto reshape_instead_of_transpose_const = ngraph::opset7::Constant::create(ngraph::element::i64,
            ngraph::Shape{matmul_output_shape.size()}, {matmul_output_shape[1], matmul_output_shape[0]});
        auto reshape_instead_of_transpose = std::make_shared<ngraph::opset7::Reshape>(matmul, reshape_instead_of_transpose_const, false);
        auto reshape_instead_of_transpose = std::make_shared<ngraph::opset7::Reshape>(node, reshape_instead_of_transpose_const, false);
        reshape = reshape_instead_of_transpose;
        if (enable_last_reshape) {
            reshape = std::make_shared<ngraph::opset7::Reshape>(reshape_instead_of_transpose, shape_const, false);
        }
    } else {
        reshape = std::make_shared<ngraph::opset7::Reshape>(matmul, shape_const, false);
        reshape = node;
        if (enable_last_reshape) {
            reshape = std::make_shared<ngraph::opset7::Reshape>(node, shape_const, false);
        }
    }

    auto result = std::make_shared<ngraph::opset7::Result>(reshape);
@ -153,6 +214,9 @@ TEST(TransformationTests, InsertTransposeBeforeMatmulTest) {
    RunTest(
        handle_transpose_before_matmul::CreateMatmulFunction({1, 16}, {8, 2}, {2, 1}, false),
        handle_transpose_before_matmul::CreateTransposeMatmulFunction({1, 16}, {8, 2}, {2, 1}, true));
    RunTest(
        handle_transpose_before_matmul::CreateMatmulFunction({1, 2, 8}, {8, 2}, {2, 1}, false),
        handle_transpose_before_matmul::CreateTransposeMatmulFunction({1, 2, 8}, {8, 2}, {2, 1}, true));
}

TEST(TransformationTests, InsertTransposeBeforeMatmulTestReshapeInOutEq) {
@ -177,25 +241,59 @@ TEST(TransformationTests, RemoveTransposeBeforeMatmulTestReshapeInOutEq) {
}

TEST(TransformationTests, InsertTransposeAfterMatmulTest) {
    for (auto enable_add : { true, false }) {
        for (auto matmul_on_left_side : { true, false }) {
            for (auto enable_fq : { true, false }) {
                RunTest(
        handle_transpose_after_matmul::CreateMatmulFunction({4, 1}, {1, 8}, {2, 16}, false),
        handle_transpose_after_matmul::CreateMatmulTransposeFunction({4, 1}, {1, 8}, {2, 16}, true));
                    handle_transpose_after_matmul::CreateMatmulFunction(
                        {4, 1}, {1, 8}, {2, 16}, false, true, enable_add, matmul_on_left_side, enable_fq),
                    handle_transpose_after_matmul::CreateMatmulTransposeFunction(
                        {4, 1}, {1, 8}, {2, 16}, true, true, enable_add, matmul_on_left_side, enable_fq));
            }
        }
    }
}

TEST(TransformationTests, RemoveTransposeAfterMatmulTest) {
    for (auto enable_add : { true, false }) {
        for (auto matmul_on_left_side : { true, false }) {
            for (auto enable_fq : { true, false }) {
                RunTest(
        handle_transpose_after_matmul::CreateMatmulTransposeFunction({4, 1}, {1, 8}, {2, 16}, false),
        handle_transpose_after_matmul::CreateMatmulFunction({4, 1}, {1, 8}, {2, 16}, true));
                    handle_transpose_after_matmul::CreateMatmulTransposeFunction(
                        {4, 1}, {1, 8}, {2, 16}, false, true, enable_add, matmul_on_left_side, enable_fq),
                    handle_transpose_after_matmul::CreateMatmulFunction(
                        {4, 1}, {1, 8}, {2, 16}, true, true, enable_add, matmul_on_left_side, enable_fq));
            }
        }
    }
}

TEST(TransformationTests, RemoveTransposeAfterMatmulTestReshapeInOutEq) {
    for (auto enable_add : { true, false }) {
        for (auto matmul_on_left_side : { true, false }) {
            for (auto enable_fq : { true, false }) {
                RunTest(
        handle_transpose_after_matmul::CreateMatmulTransposeFunction({4, 1}, {1, 8}, {8, 4}, false),
        handle_transpose_after_matmul::CreateMatmulTransposeFunction({4, 1}, {1, 8}, {8, 4}, false));
                    handle_transpose_after_matmul::CreateMatmulTransposeFunction(
                        {4, 1}, {1, 8}, {8, 4}, false, true, enable_add, matmul_on_left_side, enable_fq),
                    handle_transpose_after_matmul::CreateMatmulTransposeFunction(
                        {4, 1}, {1, 8}, {8, 4}, false, true, enable_add, matmul_on_left_side, enable_fq));
            }
        }
    }
}

TEST(TransformationTests, InsertTransposeAfterMatmulTestReshapeInOutEq) {
    for (auto enable_last_reshape : { true, false }) {
        for (auto enable_add : { true, false }) {
            for (auto matmul_on_left_side : { true, false }) {
                for (auto enable_fq : { true, false }) {
                    RunTest(
        handle_transpose_after_matmul::CreateMatmulFunction({4, 1}, {1, 8}, {4, 8}, false),
        handle_transpose_after_matmul::CreateMatmulFunction({4, 1}, {1, 8}, {4, 8}, false));
                        handle_transpose_after_matmul::CreateMatmulFunction(
                            {4, 1}, {1, 8}, {4, 8}, false, enable_last_reshape, enable_add, matmul_on_left_side, enable_fq),
                        handle_transpose_after_matmul::CreateMatmulFunction(
                            {4, 1}, {1, 8}, {4, 8}, false, enable_last_reshape, enable_add, matmul_on_left_side, enable_fq));
                }
            }
        }
    }
}
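For the transpose-after-MatMul cases, the reference helper called with create_reshape_instead_of_transpose = true describes the expected result: the 2D Transpose that followed the MatMul is replaced by a Reshape to the transposed shape, and the trailing Reshape to the consumer's shape is kept. A self-contained sketch of that reference graph for the {4, 1} x {1, 8} -> {2, 16} case used in RemoveTransposeAfterMatmulTest follows; it mirrors the test helper only, it is not the HandleTransposeAfterMatMul pass itself, and the function name is illustrative.

// Illustration only: the reference graph shape expected after the pass for the
// {4, 1} x {1, 8} case with enable_add = false and enable_fq = false:
// MatMul -> Reshape (standing in for the removed Transpose) -> final Reshape.
#include <memory>
#include <numeric>
#include <vector>
#include <ngraph/ngraph.hpp>
#include <ngraph/opsets/opset7.hpp>

std::shared_ptr<ngraph::Function> MakeTransposeFreeMatMulSketch() {
    using namespace ngraph;
    auto input = std::make_shared<opset7::Parameter>(element::f32, Shape{4, 1});
    std::vector<float> data(8);
    std::iota(data.begin(), data.end(), 1.0f);
    auto weights = opset7::Constant::create(element::f32, Shape{1, 8}, data);
    auto matmul = std::make_shared<opset7::MatMul>(input, weights);  // {4, 8}
    // Reshape to the transposed shape {8, 4}, replacing the original Transpose.
    auto swap_dims = opset7::Constant::create(element::i64, Shape{2}, std::vector<int64_t>{8, 4});
    auto reshape_instead_of_transpose = std::make_shared<opset7::Reshape>(matmul, swap_dims, false);
    // Final reshape to the shape the test consumes: {8, 4} -> {2, 16}.
    auto out_shape = opset7::Constant::create(element::i64, Shape{2}, std::vector<int64_t>{2, 16});
    auto last_reshape = std::make_shared<opset7::Reshape>(reshape_instead_of_transpose, out_shape, false);
    return std::make_shared<Function>(NodeVector{last_reshape}, ParameterVector{input});
}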