[GNA] Expanding transformations: swap_input_matmul and handle_transposes_around_matmul (#7333)
* Expanding transformations: swap_input_matmul and handle_transposes_around_matmul * insert_reshape_around_matmul * fixed failed of smoke tests
This commit is contained in:
parent
39120a7f62
commit
ba34a1989c
@ -66,6 +66,7 @@
|
|||||||
#include "transformations/handle_transposes_around_matmul.hpp"
|
#include "transformations/handle_transposes_around_matmul.hpp"
|
||||||
#include "transformations/decompose_2d_conv.hpp"
|
#include "transformations/decompose_2d_conv.hpp"
|
||||||
#include "transformations/convert_padded2valid_conv.hpp"
|
#include "transformations/convert_padded2valid_conv.hpp"
|
||||||
|
#include "transformations/insert_reshape_around_matmul.hpp"
|
||||||
#include "transformations/op_conversions/lstm_cell_decomposition.hpp"
|
#include "transformations/op_conversions/lstm_cell_decomposition.hpp"
|
||||||
#include "transformations/remove_single_input_concat.hpp"
|
#include "transformations/remove_single_input_concat.hpp"
|
||||||
|
|
||||||
@ -730,10 +731,14 @@ void GNAPlugin::LoadNetwork(CNNNetwork & _network) {
|
|||||||
manager.register_pass<SplitConvolutionWithFq>();
|
manager.register_pass<SplitConvolutionWithFq>();
|
||||||
manager.register_pass<SplitConvolutionWithBias>();
|
manager.register_pass<SplitConvolutionWithBias>();
|
||||||
manager.register_pass<SplitConvolution>();
|
manager.register_pass<SplitConvolution>();
|
||||||
manager.register_pass<HandleTransposesAroundMatMul>();
|
manager.register_pass<InsertReshapeAroundMatmulWithTranspose>();
|
||||||
|
manager.register_pass<InsertReshapeAroundMatmulWithFq>();
|
||||||
|
manager.register_pass<InsertReshapeAroundMatmulWithAdd>();
|
||||||
|
manager.register_pass<InsertReshapeAroundMatmul>();
|
||||||
manager.register_pass<SwapInputMatMulWithFq>();
|
manager.register_pass<SwapInputMatMulWithFq>();
|
||||||
manager.register_pass<SwapInputMatMulWithBias>();
|
manager.register_pass<SwapInputMatMulWithBias>();
|
||||||
manager.register_pass<SwapInputMatMul>();
|
manager.register_pass<SwapInputMatMul>();
|
||||||
|
manager.register_pass<HandleTransposesAroundMatMul>();
|
||||||
manager.register_pass<InsertTransposeAfterConvOrPool>();
|
manager.register_pass<InsertTransposeAfterConvOrPool>();
|
||||||
manager.register_pass<ReorderActivationAndPooling>();
|
manager.register_pass<ReorderActivationAndPooling>();
|
||||||
manager.register_pass<RemoveSingleInputConcat>();
|
manager.register_pass<RemoveSingleInputConcat>();
|
||||||
|
@ -6,31 +6,33 @@
|
|||||||
|
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
|
||||||
#include <ngraph/opsets/opset7.hpp>
|
#include <openvino/cc/ngraph/itt.hpp>
|
||||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
|
||||||
#include <ngraph/pattern/op/or.hpp>
|
|
||||||
#include <ngraph/rt_info.hpp>
|
#include <ngraph/rt_info.hpp>
|
||||||
|
#include <ngraph/opsets/opset8.hpp>
|
||||||
|
#include <ngraph/pattern/op/or.hpp>
|
||||||
|
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||||
|
#include <ie/ie_common.h>
|
||||||
|
|
||||||
#include "gna_plugin_log.hpp"
|
#include "gna_plugin_log.hpp"
|
||||||
#include "backend/gna_limitations.hpp"
|
#include "backend/gna_limitations.hpp"
|
||||||
|
|
||||||
using namespace GNAPluginNS;
|
namespace GNAPluginNS {
|
||||||
|
|
||||||
NGRAPH_RTTI_DEFINITION(HandleTransposesAroundMatMul, "HandleTransposesAroundMatMul", 0);
|
NGRAPH_RTTI_DEFINITION(HandleTransposesAroundMatMul, "HandleTransposesAroundMatMul", 0);
|
||||||
NGRAPH_RTTI_DEFINITION(HandleTransposeBeforeMatMul, "HandleTransposeBeforeMatMul", 0);
|
NGRAPH_RTTI_DEFINITION(HandleTransposeBeforeMatMul, "HandleTransposeBeforeMatMul", 0);
|
||||||
NGRAPH_RTTI_DEFINITION(HandleTransposeAfterMatMul, "HandleTransposeAfterMatMul", 0);
|
NGRAPH_RTTI_DEFINITION(HandleTransposeAfterMatMul, "HandleTransposeAfterMatMul", 0);
|
||||||
|
|
||||||
static void ReplaceTransposeWithReshape(std::shared_ptr<ngraph::Node> transpose_node) {
|
void ReplaceTransposeWithReshape(std::shared_ptr<ngraph::Node> transpose_node) {
|
||||||
auto shape = transpose_node->get_output_shape(0);
|
auto shape = transpose_node->get_output_shape(0);
|
||||||
auto reshape_const = std::make_shared<ngraph::opset7::Constant>(ngraph::element::Type_t::i64,
|
auto reshape_const = std::make_shared<ngraph::opset8::Constant>(ngraph::element::Type_t::i64,
|
||||||
ngraph::Shape{shape.size()}, shape);
|
ngraph::Shape{shape.size()}, shape);
|
||||||
auto reshape_node = std::make_shared<ngraph::opset7::Reshape>(transpose_node->input_value(0), reshape_const, false);
|
auto reshape_node = std::make_shared<ngraph::opset8::Reshape>(transpose_node->input_value(0), reshape_const, false);
|
||||||
reshape_node->set_friendly_name(transpose_node->get_friendly_name() + "/reshape");
|
reshape_node->set_friendly_name(transpose_node->get_friendly_name());
|
||||||
ngraph::copy_runtime_info(transpose_node, reshape_node);
|
ngraph::copy_runtime_info(transpose_node, reshape_node);
|
||||||
transpose_node->output(0).replace(reshape_node->output(0));
|
transpose_node->output(0).replace(reshape_node->output(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
static void InsertTranspose(std::shared_ptr<ngraph::Node> prev_node, const std::string& base_name) {
|
void InsertTranspose(std::shared_ptr<ngraph::Node> prev_node, const std::string& base_name) {
|
||||||
auto consumers = prev_node->output(0).get_target_inputs();
|
auto consumers = prev_node->output(0).get_target_inputs();
|
||||||
const auto orig_shape = prev_node->get_output_shape(0);
|
const auto orig_shape = prev_node->get_output_shape(0);
|
||||||
std::vector<size_t> transpose_ids;
|
std::vector<size_t> transpose_ids;
|
||||||
@ -44,13 +46,13 @@ static void InsertTranspose(std::shared_ptr<ngraph::Node> prev_node, const std::
|
|||||||
std::iota(std::begin(permute_order), std::end(permute_order), 0);
|
std::iota(std::begin(permute_order), std::end(permute_order), 0);
|
||||||
std::swap(permute_order[transpose_ids[0]], permute_order[transpose_ids[1]]);
|
std::swap(permute_order[transpose_ids[0]], permute_order[transpose_ids[1]]);
|
||||||
|
|
||||||
auto transpose_order = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{permute_order.size()}, permute_order);
|
auto transpose_order = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{permute_order.size()}, permute_order);
|
||||||
auto transpose = std::make_shared<ngraph::opset7::Transpose>(prev_node, transpose_order);
|
auto transpose = std::make_shared<ngraph::opset8::Transpose>(prev_node, transpose_order);
|
||||||
transpose->set_friendly_name(base_name + "/in_transpose");
|
transpose->set_friendly_name(base_name + "/in_transpose");
|
||||||
|
|
||||||
auto reshapeConstAfter = std::make_shared<ngraph::opset7::Constant>(ngraph::element::Type_t::i64,
|
auto reshapeConstAfter = std::make_shared<ngraph::opset8::Constant>(ngraph::element::Type_t::i64,
|
||||||
ngraph::Shape{orig_shape.size()}, orig_shape);
|
ngraph::Shape{orig_shape.size()}, orig_shape);
|
||||||
auto reshapeAfter = std::make_shared<ngraph::opset7::Reshape>(transpose, reshapeConstAfter, false);
|
auto reshapeAfter = std::make_shared<ngraph::opset8::Reshape>(transpose, reshapeConstAfter, false);
|
||||||
reshapeAfter->set_friendly_name(base_name + "/reshape_after_transpose");
|
reshapeAfter->set_friendly_name(base_name + "/reshape_after_transpose");
|
||||||
ngraph::copy_runtime_info(prev_node, ngraph::NodeVector{transpose, reshapeAfter});
|
ngraph::copy_runtime_info(prev_node, ngraph::NodeVector{transpose, reshapeAfter});
|
||||||
|
|
||||||
@ -59,74 +61,102 @@ static void InsertTranspose(std::shared_ptr<ngraph::Node> prev_node, const std::
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool VerifyReshape(const ngraph::Output<ngraph::Node>& reshape_out) {
|
||||||
|
auto in_shape = reshape_out.get_node_shared_ptr()->get_input_shape(0);
|
||||||
|
auto out_shape = reshape_out.get_node_shared_ptr()->get_output_shape(0);
|
||||||
|
return in_shape[0] != out_shape[0];
|
||||||
|
}
|
||||||
|
|
||||||
HandleTransposeBeforeMatMul::HandleTransposeBeforeMatMul() {
|
HandleTransposeBeforeMatMul::HandleTransposeBeforeMatMul() {
|
||||||
auto reshape = ngraph::pattern::wrap_type<ngraph::opset7::Reshape>({ngraph::pattern::any_input(),
|
auto constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
|
||||||
ngraph::pattern::any_input()}, VerifyReshape());
|
auto fq = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({constant, ngraph::pattern::any_input(),
|
||||||
auto transpose = ngraph::pattern::wrap_type<ngraph::opset7::Transpose>({reshape,
|
ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()});
|
||||||
|
auto reshape = ngraph::pattern::wrap_type<ngraph::opset8::Reshape>({}, VerifyReshape);
|
||||||
|
auto transpose = ngraph::pattern::wrap_type<ngraph::opset8::Transpose>({reshape,
|
||||||
ngraph::pattern::any_input()});
|
ngraph::pattern::any_input()});
|
||||||
auto matmul_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{reshape, transpose});
|
auto matmul1 = ngraph::pattern::wrap_type<ngraph::opset8::MatMul>({
|
||||||
auto matmul1 = ngraph::pattern::wrap_type<ngraph::opset7::MatMul>({matmul_input, ngraph::pattern::any_input()});
|
std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{reshape, transpose}),
|
||||||
auto matmul2 = ngraph::pattern::wrap_type<ngraph::opset7::MatMul>({ngraph::pattern::any_input(), matmul_input});
|
ngraph::pattern::any_input()});
|
||||||
|
auto matmul2 = ngraph::pattern::wrap_type<ngraph::opset8::MatMul>({
|
||||||
|
std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{constant, fq}),
|
||||||
|
std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{reshape, transpose, ngraph::pattern::any_input()})});
|
||||||
auto matmul = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{matmul1, matmul2});
|
auto matmul = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{matmul1, matmul2});
|
||||||
|
|
||||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
|
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &matcher) {
|
||||||
const auto& pattern_map = m.get_pattern_value_map();
|
const auto& pattern_map = matcher.get_pattern_value_map();
|
||||||
|
auto matmul_iter = pattern_map.find(matmul1);
|
||||||
|
if (matmul_iter == std::end(pattern_map) &&
|
||||||
|
(matmul_iter = pattern_map.find(matmul2)) == std::end(pattern_map)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto transpose_reshape_it = pattern_map.find(transpose);
|
||||||
|
if (transpose_reshape_it != std::end(pattern_map)) {
|
||||||
|
ReplaceTransposeWithReshape(transpose_reshape_it->second.get_node_shared_ptr());
|
||||||
|
} else if ((transpose_reshape_it = pattern_map.find(reshape)) != std::end(pattern_map)) {
|
||||||
|
auto reshape_node = pattern_map.at(reshape).get_node_shared_ptr();
|
||||||
|
if (GNALimitations::IsTransposeSupported(reshape_node->get_output_shape(0))) {
|
||||||
|
auto matmul_node = matmul_iter->second.get_node_shared_ptr();
|
||||||
|
InsertTranspose(reshape_node, matmul_node->get_friendly_name());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto iter = pattern_map.find(fq);
|
||||||
|
if (iter != pattern_map.end() ||
|
||||||
|
(iter = pattern_map.find(constant)) != pattern_map.end()) {
|
||||||
|
auto prev_node = iter->second.get_node_shared_ptr();
|
||||||
|
if (!GNALimitations::IsTransposeSupported(prev_node->get_output_shape(0))) return false;
|
||||||
|
auto matmul_node = iter->second.get_node_shared_ptr();
|
||||||
|
InsertTranspose(prev_node, matmul_node->get_friendly_name());
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto matcher = std::make_shared<ngraph::pattern::Matcher>(matmul, "HandleTransposeBeforeMatMul");
|
||||||
|
this->register_matcher(matcher, callback);
|
||||||
|
}
|
||||||
|
|
||||||
|
HandleTransposeAfterMatMul::HandleTransposeAfterMatMul() {
|
||||||
|
auto matmul = ngraph::pattern::wrap_type<ngraph::opset8::MatMul>();
|
||||||
|
auto add_left = ngraph::pattern::wrap_type<ngraph::opset8::Add>({matmul, ngraph::pattern::any_input()});
|
||||||
|
auto add_right = ngraph::pattern::wrap_type<ngraph::opset8::Add>({ngraph::pattern::any_input(), matmul});
|
||||||
|
auto fq_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{matmul, add_left, add_right});
|
||||||
|
auto fq = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({fq_input, ngraph::pattern::any_input(),
|
||||||
|
ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()});
|
||||||
|
auto transpose_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{fq_input, fq});
|
||||||
|
auto transpose = ngraph::pattern::wrap_type<ngraph::opset8::Transpose>({transpose_input, ngraph::pattern::any_input()});
|
||||||
|
auto reshape_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{transpose_input, transpose});
|
||||||
|
auto reshape = ngraph::pattern::wrap_type<ngraph::opset8::Reshape>(
|
||||||
|
{reshape_input, ngraph::pattern::any_input()}, VerifyReshape);
|
||||||
|
|
||||||
|
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &matcher) {
|
||||||
|
const auto& pattern_map = matcher.get_pattern_value_map();
|
||||||
auto transpose_it = pattern_map.find(transpose);
|
auto transpose_it = pattern_map.find(transpose);
|
||||||
if (transpose_it != std::end(pattern_map)) {
|
if (transpose_it != std::end(pattern_map)) {
|
||||||
ReplaceTransposeWithReshape(transpose_it->second.get_node_shared_ptr());
|
ReplaceTransposeWithReshape(transpose_it->second.get_node_shared_ptr());
|
||||||
} else {
|
} else {
|
||||||
auto reshape_node = pattern_map.at(reshape).get_node_shared_ptr();
|
auto reshape_node = pattern_map.at(reshape).get_node_shared_ptr();
|
||||||
if (!GNALimitations::IsTransposeSupported(reshape_node->get_output_shape(0))) return false;
|
if (!GNALimitations::IsTransposeSupported(reshape_node->get_output_shape(0))) return false;
|
||||||
auto matmul_it = pattern_map.find(matmul1);
|
auto iter = pattern_map.find(fq);
|
||||||
auto matmul_out = matmul_it != std::end(pattern_map) ? matmul_it->second : pattern_map.at(matmul2);
|
if (iter == pattern_map.end() &&
|
||||||
InsertTranspose(reshape_node, matmul_out.get_node_shared_ptr()->get_friendly_name());
|
(iter = pattern_map.find(add_left)) == pattern_map.end() &&
|
||||||
|
(iter = pattern_map.find(add_right)) == pattern_map.end() &&
|
||||||
|
(iter = pattern_map.find(matmul)) == pattern_map.end()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto node = iter->second.get_node_shared_ptr();
|
||||||
|
InsertTranspose(node, node->get_friendly_name());
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
|
|
||||||
auto m = std::make_shared<ngraph::pattern::Matcher>(matmul, "HandleTransposeBeforeMatMul");
|
auto matcher = std::make_shared<ngraph::pattern::Matcher>(reshape, "HandleTransposeAfterMatMul");
|
||||||
this->register_matcher(m, callback);
|
this->register_matcher(matcher, callback);
|
||||||
}
|
|
||||||
|
|
||||||
HandleTransposeAfterMatMul::HandleTransposeAfterMatMul() {
|
|
||||||
auto matmul = ngraph::pattern::wrap_type<ngraph::opset7::MatMul>();
|
|
||||||
auto fq = ngraph::pattern::wrap_type<ngraph::opset7::FakeQuantize>({matmul, ngraph::pattern::any_input(),
|
|
||||||
ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()});
|
|
||||||
auto transpose_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{matmul, fq});
|
|
||||||
auto transpose = ngraph::pattern::wrap_type<ngraph::opset7::Transpose>({transpose_input, ngraph::pattern::any_input()});
|
|
||||||
auto reshape_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{transpose_input, transpose});
|
|
||||||
auto reshape = ngraph::pattern::wrap_type<ngraph::opset7::Reshape>({reshape_input,
|
|
||||||
ngraph::pattern::any_input()}, VerifyReshape());
|
|
||||||
|
|
||||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
|
|
||||||
const auto& pattern_map = m.get_pattern_value_map();
|
|
||||||
auto transpose_it = pattern_map.find(transpose);
|
|
||||||
if (transpose_it != std::end(pattern_map)) {
|
|
||||||
ReplaceTransposeWithReshape(transpose_it->second.get_node_shared_ptr());
|
|
||||||
} else {
|
|
||||||
auto reshape_node = pattern_map.at(reshape).get_node_shared_ptr();
|
|
||||||
if (!GNALimitations::IsTransposeSupported(reshape_node->get_input_shape(0))) return false;
|
|
||||||
auto matmul_node = pattern_map.at(matmul).get_node_shared_ptr();
|
|
||||||
InsertTranspose(matmul_node, matmul_node->get_friendly_name());
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
};
|
|
||||||
|
|
||||||
auto m = std::make_shared<ngraph::pattern::Matcher>(reshape, "HandleTransposeAfterMatMul");
|
|
||||||
this->register_matcher(m, callback);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool VerifyReshape::operator()(const ngraph::Output<ngraph::Node>& reshape_out) const {
|
|
||||||
auto in_shape = reshape_out.get_node_shared_ptr()->get_input_shape(0);
|
|
||||||
auto out_shape = reshape_out.get_node_shared_ptr()->get_output_shape(0);
|
|
||||||
|
|
||||||
// Check if Reshape changes the final 2d shape of Affine primitive
|
|
||||||
in_shape.erase(std::remove(in_shape.begin(), in_shape.end(), 1), in_shape.end());
|
|
||||||
out_shape.erase(std::remove(out_shape.begin(), out_shape.end(), 1), out_shape.end());
|
|
||||||
return in_shape != out_shape;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
HandleTransposesAroundMatMul::HandleTransposesAroundMatMul() {
|
HandleTransposesAroundMatMul::HandleTransposesAroundMatMul() {
|
||||||
add_matcher<HandleTransposeBeforeMatMul>();
|
add_matcher<HandleTransposeBeforeMatMul>();
|
||||||
add_matcher<HandleTransposeAfterMatMul>();
|
add_matcher<HandleTransposeAfterMatMul>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
} // namespace GNAPluginNS
|
||||||
|
@ -8,10 +8,6 @@
|
|||||||
|
|
||||||
namespace GNAPluginNS {
|
namespace GNAPluginNS {
|
||||||
|
|
||||||
struct VerifyReshape {
|
|
||||||
bool operator()(const ngraph::Output<ngraph::Node>& reshape_out) const;
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Inserts Transpose before MatMul or removes it (if it exists) if there is Reshape
|
* @brief Inserts Transpose before MatMul or removes it (if it exists) if there is Reshape
|
||||||
* before MatMul which changes the batch size:
|
* before MatMul which changes the batch size:
|
||||||
@ -48,16 +44,16 @@ public:
|
|||||||
* | |
|
* | |
|
||||||
* [1, A*B] [1, A*B]
|
* [1, A*B] [1, A*B]
|
||||||
*/
|
*/
|
||||||
class HandleTransposeAfterMatMul : public ngraph::pass::MatcherPass {
|
class HandleTransposeAfterMatMul: public ngraph::pass::MatcherPass {
|
||||||
public:
|
|
||||||
NGRAPH_RTTI_DECLARATION;
|
|
||||||
HandleTransposeAfterMatMul();
|
|
||||||
};
|
|
||||||
|
|
||||||
class HandleTransposesAroundMatMul: public ngraph::pass::GraphRewrite {
|
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
HandleTransposesAroundMatMul();
|
HandleTransposeAfterMatMul();
|
||||||
|
};
|
||||||
|
|
||||||
|
class HandleTransposesAroundMatMul : public ngraph::pass::GraphRewrite {
|
||||||
|
public:
|
||||||
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
HandleTransposesAroundMatMul();
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace GNAPluginNS
|
} // namespace GNAPluginNS
|
||||||
|
@ -0,0 +1,237 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "transformations/insert_reshape_around_matmul.hpp"
|
||||||
|
#include <openvino/cc/ngraph/itt.hpp>
|
||||||
|
#include <ngraph/rt_info.hpp>
|
||||||
|
#include <ngraph/opsets/opset8.hpp>
|
||||||
|
#include <ngraph/pattern/op/or.hpp>
|
||||||
|
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||||
|
#include <ie/ie_common.h>
|
||||||
|
|
||||||
|
#include "gna_plugin_log.hpp"
|
||||||
|
|
||||||
|
namespace GNAPluginNS {
|
||||||
|
|
||||||
|
NGRAPH_RTTI_DEFINITION(InsertReshapeAroundMatmul, "InsertReshapeAroundMatmul", 0);
|
||||||
|
NGRAPH_RTTI_DEFINITION(InsertReshapeAroundMatmulWithAdd, "InsertReshapeAroundMatmulWithAdd", 0);
|
||||||
|
NGRAPH_RTTI_DEFINITION(InsertReshapeAroundMatmulWithFq, "InsertReshapeAroundMatmulWithFq", 0);
|
||||||
|
NGRAPH_RTTI_DEFINITION(InsertReshapeAroundMatmulWithTranspose, "InsertReshapeAroundMatmulWithTranspose", 0);
|
||||||
|
|
||||||
|
static bool InsertReshape(
|
||||||
|
ngraph::pattern::Matcher &matcher,
|
||||||
|
const std::shared_ptr<ngraph::Node>& input,
|
||||||
|
const std::shared_ptr<ngraph::Node>& matmul1,
|
||||||
|
const std::shared_ptr<ngraph::Node>& matmul2,
|
||||||
|
const std::shared_ptr<ngraph::Node>& add1 = nullptr,
|
||||||
|
const std::shared_ptr<ngraph::Node>& add2 = nullptr,
|
||||||
|
const std::shared_ptr<ngraph::Node>& fake_quantize2 = nullptr,
|
||||||
|
const std::shared_ptr<ngraph::Node>& transpose = nullptr) {
|
||||||
|
const auto& pattern_map = matcher.get_pattern_value_map();
|
||||||
|
size_t matmul_input_index = 1;
|
||||||
|
auto iter = pattern_map.find(matmul1);
|
||||||
|
if (iter == pattern_map.end()) {
|
||||||
|
iter = pattern_map.find(matmul2);
|
||||||
|
if ((iter = pattern_map.find(matmul2)) == pattern_map.end()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
matmul_input_index = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<ngraph::Node> matmul_node = iter->second.get_node_shared_ptr();
|
||||||
|
auto matmul_node_shape = matmul_node->get_output_shape(0);
|
||||||
|
if ((iter = pattern_map.find(input)) == std::end(pattern_map)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<ngraph::Node> first_node = iter->second.get_node_shared_ptr();
|
||||||
|
auto reshape_input_node = std::dynamic_pointer_cast<ngraph::opset8::Reshape>(first_node);
|
||||||
|
bool need_reshape_before = !reshape_input_node || reshape_input_node->get_output_shape(0).size() != 2;
|
||||||
|
if (need_reshape_before) {
|
||||||
|
auto input_shape = first_node->get_output_shape(0);
|
||||||
|
std::vector<size_t> before_shape(2, 1);
|
||||||
|
std::copy_if(input_shape.begin(), input_shape.end(), before_shape.begin(), [](size_t e) { return e > 1; });
|
||||||
|
auto reshape_before_node = std::make_shared<ngraph::opset8::Reshape>(first_node,
|
||||||
|
std::make_shared<ngraph::opset8::Constant>(ngraph::element::Type_t::i64, ngraph::Shape{before_shape.size()}, before_shape), false);
|
||||||
|
reshape_before_node->set_friendly_name(matmul_node->get_friendly_name() + "/reshape_before_matmul");
|
||||||
|
ngraph::copy_runtime_info(first_node, reshape_before_node);
|
||||||
|
matmul_node->input(matmul_input_index).replace_source_output(reshape_before_node->output(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<ngraph::Node> last_node;
|
||||||
|
iter = pattern_map.find(transpose);
|
||||||
|
if (iter == pattern_map.end() &&
|
||||||
|
(iter = pattern_map.find(fake_quantize2)) == pattern_map.end() &&
|
||||||
|
(iter = pattern_map.find(add1)) == pattern_map.end() &&
|
||||||
|
(iter = pattern_map.find(add2)) == pattern_map.end()) {
|
||||||
|
last_node = matmul_node;
|
||||||
|
} else {
|
||||||
|
last_node = iter->second.get_node_shared_ptr();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto consumers = last_node->output(0).get_target_inputs();
|
||||||
|
auto last_node_shape = last_node->get_output_shape(0);
|
||||||
|
bool need_reshape_after = false;
|
||||||
|
for (auto consumer : consumers) {
|
||||||
|
auto reshape_output_node = dynamic_cast<ngraph::opset8::Reshape*>(consumer.get_node());
|
||||||
|
if (!reshape_output_node || reshape_output_node->get_output_shape(0).size() != last_node_shape.size()) {
|
||||||
|
need_reshape_after = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (need_reshape_after) {
|
||||||
|
auto reshape_after_node = std::make_shared<ngraph::opset8::Reshape>(last_node,
|
||||||
|
std::make_shared<ngraph::opset8::Constant>(ngraph::element::Type_t::i64, ngraph::Shape{last_node_shape.size()}, last_node_shape), false);
|
||||||
|
reshape_after_node->set_friendly_name(last_node->get_friendly_name());
|
||||||
|
ngraph::copy_runtime_info(last_node, reshape_after_node);
|
||||||
|
for (auto consumer : consumers) {
|
||||||
|
consumer.replace_source_output(reshape_after_node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return need_reshape_before || need_reshape_after;
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::shared_ptr<ngraph::Node> CreateMatmulPattern(
|
||||||
|
std::shared_ptr<ngraph::Node>& input,
|
||||||
|
std::shared_ptr<ngraph::Node>& matmul1,
|
||||||
|
std::shared_ptr<ngraph::Node>& matmul2,
|
||||||
|
const ngraph::pattern::op::ValuePredicate& pred = [](const ngraph::Output<ngraph::Node>& output) { return true; }) {
|
||||||
|
auto constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
|
||||||
|
auto fake_quantize = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({constant,
|
||||||
|
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
|
||||||
|
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
|
||||||
|
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
|
||||||
|
ngraph::pattern::wrap_type<ngraph::opset8::Constant>()});
|
||||||
|
auto matmul_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{constant, fake_quantize});
|
||||||
|
input = ngraph::pattern::any_input([](const ngraph::Output<ngraph::Node>& node) {
|
||||||
|
auto shape = node.get_node_shared_ptr()->get_output_shape(0);
|
||||||
|
return shape.size() > 2 && std::count_if(shape.begin(), shape.end(), [](size_t e) { return e > 1; }) <= 2; });
|
||||||
|
matmul1 = ngraph::pattern::wrap_type<ngraph::opset8::MatMul>({matmul_input, input}, pred);
|
||||||
|
matmul2 = ngraph::pattern::wrap_type<ngraph::opset8::MatMul>({input, matmul_input}, pred);
|
||||||
|
return std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{matmul1, matmul2});
|
||||||
|
}
|
||||||
|
|
||||||
|
InsertReshapeAroundMatmul::InsertReshapeAroundMatmul() {
|
||||||
|
MATCHER_SCOPE(InsertReshapeAroundMatmul);
|
||||||
|
|
||||||
|
auto pred = [](const ngraph::Output<ngraph::Node>& node) {
|
||||||
|
const auto& outputs = node.get_node_shared_ptr()->outputs();
|
||||||
|
const auto& inputs = outputs[0].get_target_inputs();
|
||||||
|
if (inputs.empty()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto next_node = inputs.begin()->get_node();
|
||||||
|
return outputs.size() != 1 ||
|
||||||
|
!dynamic_cast<ngraph::opset8::Transpose*>(next_node) &&
|
||||||
|
!dynamic_cast<ngraph::opset8::FakeQuantize*>(next_node) &&
|
||||||
|
!dynamic_cast<ngraph::opset8::Add*>(next_node);
|
||||||
|
};
|
||||||
|
|
||||||
|
std::shared_ptr<ngraph::Node> input;
|
||||||
|
std::shared_ptr<ngraph::Node> matmul1;
|
||||||
|
std::shared_ptr<ngraph::Node> matmul2;
|
||||||
|
auto matmul = CreateMatmulPattern(input, matmul1, matmul2, pred);
|
||||||
|
|
||||||
|
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &matcher) {
|
||||||
|
return InsertReshape(matcher, input, matmul1, matmul2);
|
||||||
|
};
|
||||||
|
|
||||||
|
auto matcher = std::make_shared<ngraph::pattern::Matcher>(matmul, "InsertReshapeAroundMatmul");
|
||||||
|
this->register_matcher(matcher, callback);
|
||||||
|
}
|
||||||
|
|
||||||
|
InsertReshapeAroundMatmulWithAdd::InsertReshapeAroundMatmulWithAdd() {
|
||||||
|
MATCHER_SCOPE(InsertReshapeAroundMatmulWithAdd);
|
||||||
|
|
||||||
|
auto pred = [](const ngraph::Output<ngraph::Node>& node) {
|
||||||
|
const auto& outputs = node.get_node_shared_ptr()->outputs();
|
||||||
|
const auto& inputs = outputs[0].get_target_inputs();
|
||||||
|
if (inputs.empty()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto next_node = inputs.begin()->get_node();
|
||||||
|
return outputs.size() != 1 ||
|
||||||
|
!dynamic_cast<ngraph::opset8::Transpose*>(next_node) &&
|
||||||
|
!dynamic_cast<ngraph::opset8::FakeQuantize*>(next_node);
|
||||||
|
};
|
||||||
|
|
||||||
|
std::shared_ptr<ngraph::Node> input;
|
||||||
|
std::shared_ptr<ngraph::Node> matmul1;
|
||||||
|
std::shared_ptr<ngraph::Node> matmul2;
|
||||||
|
auto matmul = CreateMatmulPattern(input, matmul1, matmul2);
|
||||||
|
auto add_input = ngraph::pattern::any_input();
|
||||||
|
auto add1 = ngraph::pattern::wrap_type<ngraph::opset8::Add>({matmul, add_input}, pred);
|
||||||
|
auto add2 = ngraph::pattern::wrap_type<ngraph::opset8::Add>({add_input, matmul}, pred);
|
||||||
|
auto add = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{add1, add2});
|
||||||
|
|
||||||
|
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &matcher) {
|
||||||
|
return InsertReshape(matcher, input, matmul1, matmul2, add1, add2);
|
||||||
|
};
|
||||||
|
|
||||||
|
auto matcher = std::make_shared<ngraph::pattern::Matcher>(add, "InsertReshapeAroundMatmulWithAdd");
|
||||||
|
this->register_matcher(matcher, callback);
|
||||||
|
}
|
||||||
|
|
||||||
|
InsertReshapeAroundMatmulWithFq::InsertReshapeAroundMatmulWithFq() {
|
||||||
|
MATCHER_SCOPE(InsertReshapeAroundMatmulWithFq);
|
||||||
|
|
||||||
|
std::shared_ptr<ngraph::Node> input;
|
||||||
|
std::shared_ptr<ngraph::Node> matmul1;
|
||||||
|
std::shared_ptr<ngraph::Node> matmul2;
|
||||||
|
auto matmul = CreateMatmulPattern(input, matmul1, matmul2);
|
||||||
|
auto add_input = ngraph::pattern::any_input();
|
||||||
|
auto add1 = ngraph::pattern::wrap_type<ngraph::opset8::Add>({matmul, add_input});
|
||||||
|
auto add2 = ngraph::pattern::wrap_type<ngraph::opset8::Add>({add_input, matmul});
|
||||||
|
auto fq_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{matmul, add1, add2});
|
||||||
|
auto fake_quantize2 = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({fq_input, ngraph::pattern::any_input(),
|
||||||
|
ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()},
|
||||||
|
[](const ngraph::Output<ngraph::Node>& node) {
|
||||||
|
const auto& outputs = node.get_node_shared_ptr()->outputs();
|
||||||
|
const auto& inputs = outputs[0].get_target_inputs();
|
||||||
|
if (inputs.empty()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto next_node = inputs.begin()->get_node();
|
||||||
|
return outputs.size() != 1 ||
|
||||||
|
!dynamic_cast<ngraph::opset8::Transpose*>(next_node);
|
||||||
|
});
|
||||||
|
|
||||||
|
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &matcher) {
|
||||||
|
return InsertReshape(matcher, input, matmul1, matmul2, add1, add2, fake_quantize2);
|
||||||
|
};
|
||||||
|
|
||||||
|
auto matcher = std::make_shared<ngraph::pattern::Matcher>(fake_quantize2, "InsertReshapeAroundMatmulWithFq");
|
||||||
|
this->register_matcher(matcher, callback);
|
||||||
|
}
|
||||||
|
|
||||||
|
InsertReshapeAroundMatmulWithTranspose::InsertReshapeAroundMatmulWithTranspose() {
|
||||||
|
MATCHER_SCOPE(InsertReshapeAroundMatmulWithTranspose);
|
||||||
|
|
||||||
|
std::shared_ptr<ngraph::Node> input;
|
||||||
|
std::shared_ptr<ngraph::Node> matmul1;
|
||||||
|
std::shared_ptr<ngraph::Node> matmul2;
|
||||||
|
auto matmul = CreateMatmulPattern(input, matmul1, matmul2);
|
||||||
|
auto add_input = ngraph::pattern::any_input();
|
||||||
|
auto add1 = ngraph::pattern::wrap_type<ngraph::opset8::Add>({matmul, add_input});
|
||||||
|
auto add2 = ngraph::pattern::wrap_type<ngraph::opset8::Add>({add_input, matmul});
|
||||||
|
auto fq_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{matmul, add1, add2});
|
||||||
|
auto fake_quantize2 = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({fq_input, ngraph::pattern::any_input(),
|
||||||
|
ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()});
|
||||||
|
auto transpose_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{fq_input, fake_quantize2});
|
||||||
|
auto transpose = ngraph::pattern::wrap_type<ngraph::opset8::Transpose>({transpose_input, ngraph::pattern::any_input()});
|
||||||
|
|
||||||
|
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &matcher) {
|
||||||
|
return InsertReshape(matcher, input, matmul1, matmul2, add1, add2, fake_quantize2, transpose);
|
||||||
|
};
|
||||||
|
|
||||||
|
auto matcher = std::make_shared<ngraph::pattern::Matcher>(transpose, "InsertReshapeAroundMatmulWithTranspose");
|
||||||
|
this->register_matcher(matcher, callback);
|
||||||
|
}
|
||||||
|
} // namespace GNAPluginNS
|
@ -0,0 +1,39 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef INSERT_RESHAPE_AROUND_MATMUL_HPP
|
||||||
|
#define INSERT_RESHAPE_AROUND_MATMUL_HPP
|
||||||
|
|
||||||
|
#include <ngraph/pass/graph_rewrite.hpp>
|
||||||
|
|
||||||
|
namespace GNAPluginNS {
|
||||||
|
|
||||||
|
// @brief Insert Reshapes from 3d/4d to 2d before MatMul and from 2d to 3d/4d after MatMul
|
||||||
|
class InsertReshapeAroundMatmul : public ngraph::pass::MatcherPass {
|
||||||
|
public:
|
||||||
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
InsertReshapeAroundMatmul();
|
||||||
|
};
|
||||||
|
|
||||||
|
class InsertReshapeAroundMatmulWithAdd : public ngraph::pass::MatcherPass {
|
||||||
|
public:
|
||||||
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
InsertReshapeAroundMatmulWithAdd();
|
||||||
|
};
|
||||||
|
|
||||||
|
class InsertReshapeAroundMatmulWithFq : public ngraph::pass::MatcherPass {
|
||||||
|
public:
|
||||||
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
InsertReshapeAroundMatmulWithFq();
|
||||||
|
};
|
||||||
|
|
||||||
|
class InsertReshapeAroundMatmulWithTranspose : public ngraph::pass::MatcherPass {
|
||||||
|
public:
|
||||||
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
InsertReshapeAroundMatmulWithTranspose();
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace GNAPluginNS
|
||||||
|
|
||||||
|
#endif // INSERT_RESHAPE_AROUND_MATMUL_HPP
|
@ -2,31 +2,34 @@
|
|||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
//
|
//
|
||||||
|
|
||||||
|
#include <transformations/swap_input_matmul_gna.hpp>
|
||||||
#include <openvino/cc/ngraph/itt.hpp>
|
#include <openvino/cc/ngraph/itt.hpp>
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include <ngraph/pass/manager.hpp>
|
#include <ngraph/pass/manager.hpp>
|
||||||
#include <ngraph/pattern/op/or.hpp>
|
|
||||||
#include <ngraph/opsets/opset8.hpp>
|
|
||||||
#include <ngraph/rt_info.hpp>
|
#include <ngraph/rt_info.hpp>
|
||||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <transformations/swap_input_matmul_gna.hpp>
|
#include <ngraph/opsets/opset8.hpp>
|
||||||
|
#include <ngraph/pattern/op/or.hpp>
|
||||||
|
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||||
|
#include <ie/ie_common.h>
|
||||||
|
|
||||||
#include "gna_plugin_log.hpp"
|
#include "gna_plugin_log.hpp"
|
||||||
|
|
||||||
using namespace GNAPluginNS;
|
namespace GNAPluginNS {
|
||||||
|
|
||||||
NGRAPH_RTTI_DEFINITION(SwapInputMatMul, "SwapInputMatMul", 0);
|
NGRAPH_RTTI_DEFINITION(SwapInputMatMul, "SwapInputMatMul", 0);
|
||||||
NGRAPH_RTTI_DEFINITION(SwapInputMatMulWithBias, "SwapInputMatMulWithBias", 0);
|
NGRAPH_RTTI_DEFINITION(SwapInputMatMulWithBias, "SwapInputMatMulWithBias", 0);
|
||||||
NGRAPH_RTTI_DEFINITION(SwapInputMatMulWithFq, "SwapInputMatMulWithFq", 0);
|
NGRAPH_RTTI_DEFINITION(SwapInputMatMulWithFq, "SwapInputMatMulWithFq", 0);
|
||||||
|
|
||||||
static void SwapAndTransposeInputs(std::shared_ptr<ngraph::opset8::MatMul> matmul_node,
|
static void SwapAndTransposeInputs(
|
||||||
std::shared_ptr<ngraph::Node> add,
|
std::shared_ptr<ngraph::opset8::MatMul> matmul_node,
|
||||||
std::shared_ptr<ngraph::Node> bias,
|
std::shared_ptr<ngraph::Node> add,
|
||||||
std::shared_ptr<ngraph::Node> fq) {
|
std::shared_ptr<ngraph::Node> bias,
|
||||||
|
std::shared_ptr<ngraph::Node> fq,
|
||||||
|
const std::string& last_layer_name) {
|
||||||
auto create_transpose =
|
auto create_transpose =
|
||||||
[](ngraph::Output<ngraph::Node> node, const std::string& transpose_name) -> std::shared_ptr<ngraph::Node> {
|
[](ngraph::Output<ngraph::Node> node, const std::string& transpose_name) -> std::shared_ptr<ngraph::Node> {
|
||||||
ngraph::Shape output_shape = node.get_node_shared_ptr()->get_shape();
|
ngraph::Shape output_shape = node.get_node_shared_ptr()->get_shape();
|
||||||
@ -52,15 +55,28 @@ static void SwapAndTransposeInputs(std::shared_ptr<ngraph::opset8::MatMul> matmu
|
|||||||
|
|
||||||
std::shared_ptr<ngraph::Node> old_root_node = matmul_node;
|
std::shared_ptr<ngraph::Node> old_root_node = matmul_node;
|
||||||
if (bias != nullptr) {
|
if (bias != nullptr) {
|
||||||
// output of MatMul will be transposed comparing with original one, so the bias should be transposed too
|
// output of MatMul will be transposed comparing with original one, so the bias should be transposed too
|
||||||
if (bias->get_output_shape(0).size() > 1) {
|
if (bias->get_output_shape(0).size() > 1) {
|
||||||
bias = create_transpose(bias, bias->get_friendly_name() + "/transpose");
|
bias = create_transpose(bias, bias->get_friendly_name() + "/transpose");
|
||||||
new_ops.push_back(bias);
|
new_ops.push_back(bias);
|
||||||
}
|
|
||||||
|
|
||||||
new_matmul = std::make_shared<ngraph::opset8::Add>(new_matmul, bias);
|
auto transpose_shape = bias->get_output_shape(0);
|
||||||
old_root_node = add;
|
auto matmul_shape = matmul_node->get_output_shape(0);
|
||||||
new_ops.push_back(new_matmul);
|
if (transpose_shape.size() > matmul_shape.size()) {
|
||||||
|
std::vector<size_t> reshape_shape(matmul_shape.size(), 1);
|
||||||
|
std::copy_if(transpose_shape.begin(), transpose_shape.end(), reshape_shape.begin(), [](size_t e) { return e > 1; });
|
||||||
|
bias = std::make_shared<ngraph::opset8::Reshape>(bias,
|
||||||
|
std::make_shared<ngraph::opset8::Constant>(ngraph::element::Type_t::i64,
|
||||||
|
ngraph::Shape{reshape_shape.size()}, reshape_shape), false);
|
||||||
|
bias->set_friendly_name(add->get_friendly_name() + "/reshape");
|
||||||
|
ngraph::copy_runtime_info(add, bias);
|
||||||
|
new_ops.push_back(bias);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
new_matmul = std::make_shared<ngraph::opset8::Add>(new_matmul, bias);
|
||||||
|
old_root_node = add;
|
||||||
|
new_ops.push_back(new_matmul);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (fq != nullptr) {
|
if (fq != nullptr) {
|
||||||
@ -70,113 +86,151 @@ static void SwapAndTransposeInputs(std::shared_ptr<ngraph::opset8::MatMul> matmu
|
|||||||
new_ops.push_back(new_matmul);
|
new_ops.push_back(new_matmul);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto output = create_transpose(new_matmul, matmul_node->get_friendly_name());
|
auto output = create_transpose(new_matmul, last_layer_name);
|
||||||
new_ops.push_back(output);
|
new_ops.push_back(output);
|
||||||
|
|
||||||
ngraph::copy_runtime_info(matmul_node, new_ops);
|
ngraph::copy_runtime_info(matmul_node, new_ops);
|
||||||
ngraph::replace_node(old_root_node, output);
|
ngraph::replace_node(old_root_node, output);
|
||||||
}
|
}
|
||||||
|
|
||||||
SwapInputMatMul::SwapInputMatMul() {
|
static std::shared_ptr<ngraph::Node> CreateMatmul(
|
||||||
MATCHER_SCOPE(SwapInputMatMul);
|
bool is_first_constant,
|
||||||
auto constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>({}, [](const ngraph::Output<ngraph::Node>& node) {
|
ngraph::pattern::op::ValuePredicate const_predicate,
|
||||||
auto shape = node.get_node_shared_ptr()->get_output_shape(0);
|
ngraph::pattern::op::ValuePredicate matmul_predicate = ngraph::pattern::has_static_shape()) {
|
||||||
if (shape.size() != 2 || shape[0] < 8 || ((shape[0] % 8 != 0 || shape[1] % 8 != 0))) {
|
auto constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>({}, const_predicate);
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
});
|
|
||||||
auto fake_quantize = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({constant,
|
auto fake_quantize = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({constant,
|
||||||
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
|
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
|
||||||
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
|
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
|
||||||
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
|
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
|
||||||
ngraph::pattern::wrap_type<ngraph::opset8::Constant>()});
|
ngraph::pattern::wrap_type<ngraph::opset8::Constant>()});
|
||||||
auto matmul_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{constant, fake_quantize});
|
auto matmul_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{constant, fake_quantize});
|
||||||
auto matmul = ngraph::pattern::wrap_type<ngraph::opset8::MatMul>({matmul_input, ngraph::pattern::any_input()},
|
if (is_first_constant) {
|
||||||
ngraph::pattern::has_static_shape());
|
return ngraph::pattern::wrap_type<ngraph::opset8::MatMul>(
|
||||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
{matmul_input, ngraph::pattern::any_input()}, matmul_predicate);
|
||||||
|
}
|
||||||
|
return ngraph::pattern::wrap_type<ngraph::opset8::MatMul>(
|
||||||
|
{ngraph::pattern::any_input(), matmul_input}, matmul_predicate);
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::shared_ptr<ngraph::Node> CreateMatmuls(
|
||||||
|
std::shared_ptr<ngraph::Node>& matmul1,
|
||||||
|
std::shared_ptr<ngraph::Node>& matmul2) {
|
||||||
|
matmul1 = CreateMatmul(
|
||||||
|
true,
|
||||||
|
[](const ngraph::Output<ngraph::Node>& node) { return true; },
|
||||||
|
[](const ngraph::Output<ngraph::Node>& node) {
|
||||||
|
auto matmul_node = std::dynamic_pointer_cast<ngraph::opset8::MatMul>(node.get_node_shared_ptr());
|
||||||
|
IE_ASSERT(matmul_node != nullptr);
|
||||||
|
auto input_shape = matmul_node->get_input_shape(0);
|
||||||
|
return input_shape.size() == 2 &&
|
||||||
|
(!matmul_node->get_transpose_a() && input_shape[0] > 8 ||
|
||||||
|
matmul_node->get_transpose_a() && input_shape[1] > 8); });
|
||||||
|
matmul2 = CreateMatmul(
|
||||||
|
false,
|
||||||
|
[](const ngraph::Output<ngraph::Node>& node) { return true; },
|
||||||
|
[](const ngraph::Output<ngraph::Node>& node) {
|
||||||
|
auto matmul_node = std::dynamic_pointer_cast<ngraph::opset8::MatMul>(node.get_node_shared_ptr());
|
||||||
|
IE_ASSERT(matmul_node != nullptr);
|
||||||
|
auto first_input_shape = matmul_node->get_input_shape(0);
|
||||||
|
first_input_shape.erase(std::remove(first_input_shape.begin(), first_input_shape.end(), 1), first_input_shape.end());
|
||||||
|
auto second_input_shape = matmul_node->get_input_shape(1);
|
||||||
|
return node.get_partial_shape().is_static() &&
|
||||||
|
second_input_shape.size() == 2 &&
|
||||||
|
(!matmul_node->get_transpose_b() && second_input_shape[1] <= 8 ||
|
||||||
|
matmul_node->get_transpose_b() && second_input_shape[0] <= 8) &&
|
||||||
|
first_input_shape.size() == 2 &&
|
||||||
|
first_input_shape[0] > 8; });
|
||||||
|
return std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{matmul1, matmul2});
|
||||||
|
}
|
||||||
|
|
||||||
|
SwapInputMatMul::SwapInputMatMul() {
|
||||||
|
MATCHER_SCOPE(SwapInputMatMul);
|
||||||
|
std::shared_ptr<ngraph::Node> matmul1;
|
||||||
|
std::shared_ptr<ngraph::Node> matmul2;
|
||||||
|
auto matmul = CreateMatmuls(matmul1, matmul2);
|
||||||
|
auto callback = [=](ngraph::pattern::Matcher& m) {
|
||||||
const auto& pattern_map = m.get_pattern_value_map();
|
const auto& pattern_map = m.get_pattern_value_map();
|
||||||
auto matmul_node = std::dynamic_pointer_cast<ngraph::opset8::MatMul>(pattern_map.at(matmul).get_node_shared_ptr());
|
auto iter = pattern_map.find(matmul1);
|
||||||
|
if (iter == pattern_map.end() &&
|
||||||
|
(iter = pattern_map.find(matmul2)) == pattern_map.end()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto matmul_node = std::dynamic_pointer_cast<ngraph::opset8::MatMul>(iter->second.get_node_shared_ptr());
|
||||||
IE_ASSERT(matmul_node != nullptr);
|
IE_ASSERT(matmul_node != nullptr);
|
||||||
SwapAndTransposeInputs(matmul_node, nullptr, nullptr, nullptr);
|
SwapAndTransposeInputs(matmul_node, nullptr, nullptr, nullptr, "");
|
||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
|
|
||||||
auto m = std::make_shared<ngraph::pattern::Matcher>(matmul, matcher_name);
|
auto matcher = std::make_shared<ngraph::pattern::Matcher>(matmul, "SwapInputMatMul");
|
||||||
this->register_matcher(m, callback);
|
this->register_matcher(matcher, callback);
|
||||||
}
|
}
|
||||||
|
|
||||||
SwapInputMatMulWithBias::SwapInputMatMulWithBias() {
|
SwapInputMatMulWithBias::SwapInputMatMulWithBias() {
|
||||||
MATCHER_SCOPE(SwapInputMatMulWithBias);
|
MATCHER_SCOPE(SwapInputMatMulWithBias);
|
||||||
auto constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>({}, [](const ngraph::Output<ngraph::Node>& node) {
|
std::shared_ptr<ngraph::Node> matmul1;
|
||||||
auto shape = node.get_node_shared_ptr()->get_output_shape(0);
|
std::shared_ptr<ngraph::Node> matmul2;
|
||||||
if (shape.size() != 2 || shape[0] < 8 || ((shape[0] % 8 != 0 || shape[1] % 8 != 0))) {
|
auto matmul = CreateMatmuls(matmul1, matmul2);
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
});
|
|
||||||
auto fake_quantize = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({constant,
|
|
||||||
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
|
|
||||||
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
|
|
||||||
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
|
|
||||||
ngraph::pattern::wrap_type<ngraph::opset8::Constant>()});
|
|
||||||
auto matmul_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{constant, fake_quantize});
|
|
||||||
auto matmul = ngraph::pattern::wrap_type<ngraph::opset8::MatMul>({matmul_input, ngraph::pattern::any_input()},
|
|
||||||
ngraph::pattern::has_static_shape());
|
|
||||||
auto bias = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
|
auto bias = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
|
||||||
auto add = ngraph::pattern::wrap_type<ngraph::opset8::Add>({matmul, bias});
|
auto add = ngraph::pattern::wrap_type<ngraph::opset8::Add>({matmul, bias});
|
||||||
|
auto callback = [=](ngraph::pattern::Matcher& m) {
|
||||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
|
||||||
const auto& pattern_map = m.get_pattern_value_map();
|
const auto& pattern_map = m.get_pattern_value_map();
|
||||||
auto matmul_node = std::dynamic_pointer_cast<ngraph::opset8::MatMul>(pattern_map.at(matmul).get_node_shared_ptr());
|
auto iter = pattern_map.find(matmul1);
|
||||||
|
if (iter == pattern_map.end() &&
|
||||||
|
(iter = pattern_map.find(matmul2)) == pattern_map.end()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto matmul_node = std::dynamic_pointer_cast<ngraph::opset8::MatMul>(iter->second.get_node_shared_ptr());
|
||||||
IE_ASSERT(matmul_node != nullptr);
|
IE_ASSERT(matmul_node != nullptr);
|
||||||
SwapAndTransposeInputs(matmul_node, pattern_map.at(add).get_node_shared_ptr(),
|
SwapAndTransposeInputs(
|
||||||
pattern_map.at(bias).get_node_shared_ptr(), nullptr);
|
matmul_node,
|
||||||
|
pattern_map.at(add).get_node_shared_ptr(),
|
||||||
|
pattern_map.at(bias).get_node_shared_ptr(),
|
||||||
|
nullptr,
|
||||||
|
pattern_map.at(add).get_node_shared_ptr()->get_friendly_name());
|
||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
|
|
||||||
auto m = std::make_shared<ngraph::pattern::Matcher>(add, matcher_name);
|
auto matcher = std::make_shared<ngraph::pattern::Matcher>(add, "SwapInputMatMulWithBias");
|
||||||
this->register_matcher(m, callback);
|
this->register_matcher(matcher, callback);
|
||||||
}
|
}
|
||||||
|
|
||||||
SwapInputMatMulWithFq::SwapInputMatMulWithFq() {
|
SwapInputMatMulWithFq::SwapInputMatMulWithFq() {
|
||||||
MATCHER_SCOPE(SwapInputMatMulWithFq);
|
MATCHER_SCOPE(SwapInputMatMulWithFq);
|
||||||
auto constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>({}, [](const ngraph::Output<ngraph::Node>& node) {
|
std::shared_ptr<ngraph::Node> matmul1;
|
||||||
auto shape = node.get_node_shared_ptr()->get_output_shape(0);
|
std::shared_ptr<ngraph::Node> matmul2;
|
||||||
if (shape.size() != 2 || shape[0] < 8 || ((shape[0] % 8 != 0 || shape[1] % 8 != 0))) {
|
auto matmul = CreateMatmuls(matmul1, matmul2);
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
});
|
|
||||||
auto fake_quantize = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({constant,
|
|
||||||
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
|
|
||||||
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
|
|
||||||
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
|
|
||||||
ngraph::pattern::wrap_type<ngraph::opset8::Constant>()});
|
|
||||||
auto matmul_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{constant, fake_quantize});
|
|
||||||
auto matmul = ngraph::pattern::wrap_type<ngraph::opset8::MatMul>({matmul_input, ngraph::pattern::any_input()},
|
|
||||||
ngraph::pattern::has_static_shape());
|
|
||||||
auto bias = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
|
auto bias = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
|
||||||
auto add = ngraph::pattern::wrap_type<ngraph::opset8::Add>({matmul, bias});
|
auto add = ngraph::pattern::wrap_type<ngraph::opset8::Add>({matmul, bias});
|
||||||
auto matmul_out = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{add, matmul});
|
auto fq_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{add, matmul});
|
||||||
auto out_fq = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({matmul_out,
|
auto fq = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({fq_input,
|
||||||
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
|
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
|
||||||
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
|
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
|
||||||
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
|
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
|
||||||
ngraph::pattern::wrap_type<ngraph::opset8::Constant>()});
|
ngraph::pattern::wrap_type<ngraph::opset8::Constant>()});
|
||||||
|
auto callback = [=](ngraph::pattern::Matcher& m) {
|
||||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
|
||||||
const auto& pattern_map = m.get_pattern_value_map();
|
const auto& pattern_map = m.get_pattern_value_map();
|
||||||
auto matmul_node = std::dynamic_pointer_cast<ngraph::opset8::MatMul>(pattern_map.at(matmul).get_node_shared_ptr());
|
auto iter = pattern_map.find(matmul1);
|
||||||
|
if (iter == pattern_map.end() &&
|
||||||
|
(iter = pattern_map.find(matmul2)) == pattern_map.end()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto iter_add = pattern_map.find(add);
|
||||||
|
auto iter_bias = pattern_map.find(bias);
|
||||||
|
auto matmul_node = std::dynamic_pointer_cast<ngraph::opset8::MatMul>(iter->second.get_node_shared_ptr());
|
||||||
IE_ASSERT(matmul_node != nullptr);
|
IE_ASSERT(matmul_node != nullptr);
|
||||||
auto add_it = pattern_map.find(add);
|
SwapAndTransposeInputs(
|
||||||
auto add_node = (add_it == std::end(pattern_map) ? nullptr : add_it->second.get_node_shared_ptr());
|
matmul_node,
|
||||||
auto bias_it = pattern_map.find(bias);
|
iter_add != pattern_map.end() ? iter_add->second.get_node_shared_ptr() : nullptr,
|
||||||
auto bias_node = (bias_it == std::end(pattern_map) ? nullptr : bias_it->second.get_node_shared_ptr());
|
iter_bias != pattern_map.end() ? iter_bias->second.get_node_shared_ptr() : nullptr,
|
||||||
SwapAndTransposeInputs(matmul_node, add_node, bias_node, pattern_map.at(out_fq).get_node_shared_ptr());
|
pattern_map.at(fq).get_node_shared_ptr(),
|
||||||
|
pattern_map.at(fq).get_node_shared_ptr()->get_friendly_name());
|
||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
|
|
||||||
auto m = std::make_shared<ngraph::pattern::Matcher>(out_fq, matcher_name);
|
auto matcher = std::make_shared<ngraph::pattern::Matcher>(fq, "SwapInputMatMulWithFq");
|
||||||
this->register_matcher(m, callback);
|
this->register_matcher(matcher, callback);
|
||||||
}
|
}
|
||||||
|
} // namespace GNAPluginNS
|
||||||
|
@ -2,15 +2,15 @@
|
|||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
//
|
//
|
||||||
|
|
||||||
#pragma once
|
#ifndef SWAP_INPUT_MATMUL_GNA_HPP
|
||||||
|
#define SWAP_INPUT_MATMUL_GNA_HPP
|
||||||
|
|
||||||
#include <memory>
|
|
||||||
#include <transformations_visibility.hpp>
|
|
||||||
#include <ngraph/pass/graph_rewrite.hpp>
|
#include <ngraph/pass/graph_rewrite.hpp>
|
||||||
|
|
||||||
namespace GNAPluginNS {
|
namespace GNAPluginNS {
|
||||||
|
// @brief Swaps and transposes inputs of MatMul if
|
||||||
// @brief Swaps and transposes inputs of MatMul if its first input is const and its batch size isn't supported by GNA
|
// 1. its first input is const and its batch size isn't supported by GNA
|
||||||
|
// 2. its first input is non-const and its batch size isn't supported by GNA
|
||||||
class SwapInputMatMul: public ngraph::pass::MatcherPass {
|
class SwapInputMatMul: public ngraph::pass::MatcherPass {
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
@ -28,4 +28,6 @@ public:
|
|||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
SwapInputMatMulWithFq();
|
SwapInputMatMulWithFq();
|
||||||
};
|
};
|
||||||
} // namespace GNAPluginNS
|
} // namespace GNAPluginNS
|
||||||
|
|
||||||
|
#endif // SWAP_INPUT_MATMUL_GNA_HPP
|
||||||
|
@ -99,7 +99,8 @@ const std::vector<std::vector<std::vector<size_t>>> input_shapes = {
|
|||||||
{{1, 8}, {8, 1}},
|
{{1, 8}, {8, 1}},
|
||||||
{{128, 8}, {8, 1}},
|
{{128, 8}, {8, 1}},
|
||||||
{{8, 8}, {8, 8}},
|
{{8, 8}, {8, 8}},
|
||||||
{{1, 16}, {16, 8}}
|
{{1, 16}, {16, 8}},
|
||||||
|
{{6, 16}, {16, 8}}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
@ -110,4 +111,4 @@ INSTANTIATE_TEST_SUITE_P(smoke_convert_matmul_to_fc, ConvertMatmulToFcPass,
|
|||||||
::testing::Values(CommonTestUtils::DEVICE_GNA),
|
::testing::Values(CommonTestUtils::DEVICE_GNA),
|
||||||
::testing::ValuesIn(configs)),
|
::testing::ValuesIn(configs)),
|
||||||
ConvertMatmulToFcPass::getTestCaseName);
|
ConvertMatmulToFcPass::getTestCaseName);
|
||||||
} // namespace LayerTestsDefinitions
|
} // namespace LayerTestsDefinitions
|
||||||
|
@ -0,0 +1,190 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include "transformations/insert_reshape_around_matmul.hpp"
|
||||||
|
|
||||||
|
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||||
|
#include <ngraph/function.hpp>
|
||||||
|
#include <ngraph/opsets/opset8.hpp>
|
||||||
|
#include <ngraph/pass/manager.hpp>
|
||||||
|
#include <transformations/init_node_info.hpp>
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
|
template<bool ADD, bool ADD_FIRST_INPUT_NOT_CONSTANT, bool FQ>
|
||||||
|
struct InsertReshapeAroundMatmulTest {
|
||||||
|
static std::shared_ptr<ngraph::Node> CreateAdd(std::shared_ptr<ngraph::Node> input, const ngraph::Shape& constant_shape) {
|
||||||
|
std::vector<size_t> data(ngraph::shape_size(constant_shape));
|
||||||
|
std::iota(std::begin(data), std::end(data), 1);
|
||||||
|
auto constant = ngraph::opset8::Constant::create(ngraph::element::i64, constant_shape, data);
|
||||||
|
return std::make_shared<ngraph::opset8::Add>(input, constant);
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::shared_ptr<ngraph::Node> CreateMatmul(
|
||||||
|
std::shared_ptr<ngraph::Node> input,
|
||||||
|
const ngraph::Shape& matmul_constant_shape) {
|
||||||
|
std::vector<size_t> data(ngraph::shape_size(matmul_constant_shape));
|
||||||
|
std::iota(std::begin(data), std::end(data), 1);
|
||||||
|
auto constant = ngraph::opset8::Constant::create(ngraph::element::i64, matmul_constant_shape, data);
|
||||||
|
std::shared_ptr<ngraph::Node> node;
|
||||||
|
node = std::make_shared<ngraph::opset8::MatMul>(input, constant);
|
||||||
|
|
||||||
|
if (ADD) {
|
||||||
|
auto matmul_shape = node->get_output_shape(0);
|
||||||
|
data.resize(ngraph::shape_size(matmul_shape));
|
||||||
|
std::iota(std::begin(data), std::end(data), 1);
|
||||||
|
std::vector<size_t> constant_add_shape(2, 1);
|
||||||
|
std::copy_if(matmul_shape.begin(), matmul_shape.end(), constant_add_shape.begin(), [](size_t e) { return e > 1; });
|
||||||
|
auto constant_add = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{constant_add_shape}, data);
|
||||||
|
if (ADD_FIRST_INPUT_NOT_CONSTANT) {
|
||||||
|
node = std::make_shared<ngraph::opset8::Add>(node, constant_add);
|
||||||
|
} else {
|
||||||
|
node = std::make_shared<ngraph::opset8::Add>(constant_add, node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (FQ) {
|
||||||
|
node = std::make_shared<ngraph::opset8::FakeQuantize>(
|
||||||
|
node,
|
||||||
|
ngraph::opset8::Constant::create(ngraph::element::f32, {1}, {-0.1}),
|
||||||
|
ngraph::opset8::Constant::create(ngraph::element::f32, {1}, {0.1}),
|
||||||
|
ngraph::opset8::Constant::create(ngraph::element::f32, {1}, {-0.1}),
|
||||||
|
ngraph::opset8::Constant::create(ngraph::element::f32, {1}, {0.1}),
|
||||||
|
255);
|
||||||
|
}
|
||||||
|
|
||||||
|
return node;
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::shared_ptr<ngraph::Function> CreateFunction(
|
||||||
|
const ngraph::Shape& input_shape,
|
||||||
|
const ngraph::Shape& matmul_constant_shape,
|
||||||
|
const ngraph::Shape& result_shape) {
|
||||||
|
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, input_shape);
|
||||||
|
auto before = std::make_shared<ngraph::opset8::Relu>(input);
|
||||||
|
auto matmul = CreateMatmul(before, matmul_constant_shape);
|
||||||
|
auto after = std::make_shared<ngraph::opset8::Relu>(matmul);
|
||||||
|
return std::make_shared<ngraph::Function>(
|
||||||
|
ngraph::ResultVector{std::make_shared<ngraph::opset8::Result>(after)},
|
||||||
|
ngraph::ParameterVector{input});
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::shared_ptr<ngraph::Function> CreateReferenceFunction(
|
||||||
|
const ngraph::Shape& input_shape,
|
||||||
|
const ngraph::Shape& reshape_before_shape,
|
||||||
|
const ngraph::Shape& matmul_constant_shape,
|
||||||
|
const ngraph::Shape& reshape_after_shape,
|
||||||
|
const ngraph::Shape& result_shape) {
|
||||||
|
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, input_shape);
|
||||||
|
auto before = std::make_shared<ngraph::opset8::Relu>(input);
|
||||||
|
auto reshape_before_constant = ngraph::opset8::Constant::create(ngraph::element::i64,
|
||||||
|
ngraph::Shape{reshape_before_shape.size()}, reshape_before_shape);
|
||||||
|
auto reshape_before = std::make_shared<ngraph::opset8::Reshape>(before, reshape_before_constant, false);
|
||||||
|
auto matmul = CreateMatmul(reshape_before, matmul_constant_shape);
|
||||||
|
auto reshape_after_constant = ngraph::opset8::Constant::create(ngraph::element::i64,
|
||||||
|
ngraph::Shape{reshape_after_shape.size()}, reshape_after_shape);
|
||||||
|
auto reshape_after = std::make_shared<ngraph::opset8::Reshape>(matmul, reshape_after_constant, false);
|
||||||
|
auto after = std::make_shared<ngraph::opset8::Relu>(reshape_after);
|
||||||
|
return std::make_shared<ngraph::Function>(
|
||||||
|
ngraph::ResultVector{std::make_shared<ngraph::opset8::Result>(after)},
|
||||||
|
ngraph::ParameterVector{input});
|
||||||
|
}
|
||||||
|
}; // struct InsertReshapeAroundMatmulTest
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
void RunTest(const std::shared_ptr<ngraph::Function>& func, const std::shared_ptr<ngraph::Function>& reference_func) {
|
||||||
|
{
|
||||||
|
ngraph::pass::Manager m;
|
||||||
|
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
|
m.register_pass<GNAPluginNS::InsertReshapeAroundMatmulWithTranspose>();
|
||||||
|
m.register_pass<GNAPluginNS::InsertReshapeAroundMatmulWithFq>();
|
||||||
|
m.register_pass<GNAPluginNS::InsertReshapeAroundMatmulWithAdd>();
|
||||||
|
m.register_pass<GNAPluginNS::InsertReshapeAroundMatmul>();
|
||||||
|
m.run_passes(func);
|
||||||
|
ASSERT_NO_THROW(check_rt_info(func));
|
||||||
|
}
|
||||||
|
|
||||||
|
const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
|
||||||
|
const FunctionsComparator::Result result = func_comparator(func, reference_func);
|
||||||
|
ASSERT_TRUE(result.valid);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
TEST(TransformationTests, InsertReshapeAroundMatmul) {
|
||||||
|
RunTest(
|
||||||
|
InsertReshapeAroundMatmulTest<false, false, false>::
|
||||||
|
CreateFunction({1, 6, 8}, {8, 10}, {1, 6, 10}),
|
||||||
|
InsertReshapeAroundMatmulTest<false, false, false>::
|
||||||
|
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}));
|
||||||
|
RunTest(
|
||||||
|
InsertReshapeAroundMatmulTest<false, false, false>::
|
||||||
|
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}),
|
||||||
|
InsertReshapeAroundMatmulTest<false, false, false>::
|
||||||
|
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}));
|
||||||
|
RunTest(
|
||||||
|
InsertReshapeAroundMatmulTest<false, false, false>::
|
||||||
|
CreateFunction({1, 6, 1, 8}, {8, 10}, {1, 6, 1, 10}),
|
||||||
|
InsertReshapeAroundMatmulTest<false, false, false>::
|
||||||
|
CreateReferenceFunction({1, 6, 1, 8}, {6, 8}, {8, 10}, {1, 6, 1, 10}, {1, 6, 1, 10}));
|
||||||
|
RunTest(
|
||||||
|
InsertReshapeAroundMatmulTest<false, false, false>::
|
||||||
|
CreateReferenceFunction({1, 6, 1, 8}, {6, 8}, {8, 10}, {1, 6, 1, 10}, {1, 6, 1, 10}),
|
||||||
|
InsertReshapeAroundMatmulTest<false, false, false>::
|
||||||
|
CreateReferenceFunction({1, 6, 1, 8}, {6, 8}, {8, 10}, {1, 6, 1, 10}, {1, 6, 1, 10}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TransformationTests, InsertReshapeAroundMatmulWithAdd) {
|
||||||
|
RunTest(
|
||||||
|
InsertReshapeAroundMatmulTest<true, true, false>::
|
||||||
|
CreateFunction({1, 6, 8}, {8, 10}, {1, 6, 10}),
|
||||||
|
InsertReshapeAroundMatmulTest<true, true, false>::
|
||||||
|
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}));
|
||||||
|
RunTest(
|
||||||
|
InsertReshapeAroundMatmulTest<true, true, false>::
|
||||||
|
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}),
|
||||||
|
InsertReshapeAroundMatmulTest<true, true, false>::
|
||||||
|
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TransformationTests, InsertReshapeAroundMatmulWithAdd_AddFirstInputConstant) {
|
||||||
|
RunTest(
|
||||||
|
InsertReshapeAroundMatmulTest<true, false, false>::
|
||||||
|
CreateFunction({1, 6, 8}, {8, 10}, {1, 6, 10}),
|
||||||
|
InsertReshapeAroundMatmulTest<true, false, false>::
|
||||||
|
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}));
|
||||||
|
RunTest(
|
||||||
|
InsertReshapeAroundMatmulTest<true, false, false>::
|
||||||
|
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}),
|
||||||
|
InsertReshapeAroundMatmulTest<true, false, false>::
|
||||||
|
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TransformationTests, InsertReshapeAroundMatmulWithFq) {
|
||||||
|
RunTest(
|
||||||
|
InsertReshapeAroundMatmulTest<false, false, true>::
|
||||||
|
CreateFunction({1, 6, 8}, {8, 10}, {1, 6, 10}),
|
||||||
|
InsertReshapeAroundMatmulTest<false, false, true>::
|
||||||
|
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}));
|
||||||
|
RunTest(
|
||||||
|
InsertReshapeAroundMatmulTest<false, false, true>::
|
||||||
|
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}),
|
||||||
|
InsertReshapeAroundMatmulTest<false, false, true>::
|
||||||
|
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TransformationTests, InsertReshapeAroundMatmulWithAddAndFq) {
|
||||||
|
RunTest(
|
||||||
|
InsertReshapeAroundMatmulTest<true, true, true>::
|
||||||
|
CreateFunction({1, 6, 8}, {8, 10}, {1, 6, 10}),
|
||||||
|
InsertReshapeAroundMatmulTest<true, true, true>::
|
||||||
|
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}));
|
||||||
|
RunTest(
|
||||||
|
InsertReshapeAroundMatmulTest<true, true, true>::
|
||||||
|
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}),
|
||||||
|
InsertReshapeAroundMatmulTest<true, true, true>::
|
||||||
|
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}));
|
||||||
|
}
|
@ -20,7 +20,8 @@ static std::shared_ptr<ngraph::Function> CreateMatMulFunction(const ngraph::Shap
|
|||||||
bool withBias,
|
bool withBias,
|
||||||
bool withWeightsFq,
|
bool withWeightsFq,
|
||||||
bool withOutFq,
|
bool withOutFq,
|
||||||
bool swappedInputs) {
|
bool swappedInputs,
|
||||||
|
bool needTranspose) {
|
||||||
auto input_params = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, input2_shape);
|
auto input_params = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, input2_shape);
|
||||||
|
|
||||||
auto constant = ngraph::opset8::Constant::create(ngraph::element::i64, input1_shape, {1});
|
auto constant = ngraph::opset8::Constant::create(ngraph::element::i64, input1_shape, {1});
|
||||||
@ -33,14 +34,14 @@ static std::shared_ptr<ngraph::Function> CreateMatMulFunction(const ngraph::Shap
|
|||||||
const_input = std::make_shared<ngraph::opset8::FakeQuantize>(const_input, input_low, input_high,
|
const_input = std::make_shared<ngraph::opset8::FakeQuantize>(const_input, input_low, input_high,
|
||||||
output_low, output_high, 11);
|
output_low, output_high, 11);
|
||||||
}
|
}
|
||||||
auto matmul = swappedInputs ? std::make_shared<ngraph::opset8::MatMul>(input_params, const_input, true, true) :
|
auto matmul = swappedInputs ? std::make_shared<ngraph::opset8::MatMul>(input_params, const_input, needTranspose, needTranspose) :
|
||||||
std::make_shared<ngraph::opset8::MatMul>(const_input, input_params);
|
std::make_shared<ngraph::opset8::MatMul>(const_input, input_params, needTranspose, needTranspose);
|
||||||
|
|
||||||
std::shared_ptr<ngraph::Node> final_node = matmul;
|
std::shared_ptr<ngraph::Node> final_node = matmul;
|
||||||
if (withBias) {
|
if (withBias) {
|
||||||
auto bias = ngraph::opset8::Constant::create(ngraph::element::i64, bias_shape, {1});
|
auto bias = ngraph::opset8::Constant::create(ngraph::element::i64, bias_shape, {1});
|
||||||
std::shared_ptr<ngraph::Node> bias_node = bias;
|
std::shared_ptr<ngraph::Node> bias_node = bias;
|
||||||
if (swappedInputs && bias_shape.size() > 1) {
|
if (needTranspose && bias_shape.size() > 1) {
|
||||||
auto transpose_order = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{2},
|
auto transpose_order = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{2},
|
||||||
std::vector<size_t>{1, 0});
|
std::vector<size_t>{1, 0});
|
||||||
bias_node = std::make_shared<ngraph::opset8::Transpose>(bias_node, transpose_order);
|
bias_node = std::make_shared<ngraph::opset8::Transpose>(bias_node, transpose_order);
|
||||||
@ -57,7 +58,7 @@ static std::shared_ptr<ngraph::Function> CreateMatMulFunction(const ngraph::Shap
|
|||||||
output_low, output_high, 11);
|
output_low, output_high, 11);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (swappedInputs) {
|
if (needTranspose) {
|
||||||
auto transpose_order = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{2},
|
auto transpose_order = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{2},
|
||||||
std::vector<size_t>{1, 0});
|
std::vector<size_t>{1, 0});
|
||||||
final_node = std::make_shared<ngraph::opset8::Transpose>(final_node, transpose_order);
|
final_node = std::make_shared<ngraph::opset8::Transpose>(final_node, transpose_order);
|
||||||
@ -104,6 +105,12 @@ static std::string getTestCaseName(testing::TestParamInfo<SwapInputMatmulParams>
|
|||||||
return result.str();
|
return result.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
enum class MatmulInputType {
|
||||||
|
FirstInputConstant,
|
||||||
|
SecondInputConstant
|
||||||
|
}; // enum class MatmulInputType
|
||||||
|
|
||||||
|
template<MatmulInputType E>
|
||||||
class SwapInputMatmul : public CommonTestUtils::TestsCommon,
|
class SwapInputMatmul : public CommonTestUtils::TestsCommon,
|
||||||
public ::testing::WithParamInterface<SwapInputMatmulParams> {
|
public ::testing::WithParamInterface<SwapInputMatmulParams> {
|
||||||
public:
|
public:
|
||||||
@ -112,14 +119,24 @@ public:
|
|||||||
bool withBias, withWeightsFq, withOutFq;
|
bool withBias, withWeightsFq, withOutFq;
|
||||||
std::tie(shapes, withBias, withWeightsFq, withOutFq) = this->GetParam();
|
std::tie(shapes, withBias, withWeightsFq, withOutFq) = this->GetParam();
|
||||||
|
|
||||||
function = CreateMatMulFunction(shapes[0], shapes[1], shapes[2], withBias, withWeightsFq, withOutFq, false);
|
bool swap_inputs = false;
|
||||||
|
switch (E) {
|
||||||
|
case MatmulInputType::FirstInputConstant:
|
||||||
|
break;
|
||||||
|
case MatmulInputType::SecondInputConstant:
|
||||||
|
swap_inputs = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
function = CreateMatMulFunction(shapes[0], shapes[1], shapes[2], withBias, withWeightsFq, withOutFq, swap_inputs, false);
|
||||||
reference_function = CreateMatMulFunction(shapes[0], shapes[1], shapes[2], withBias, withWeightsFq,
|
reference_function = CreateMatMulFunction(shapes[0], shapes[1], shapes[2], withBias, withWeightsFq,
|
||||||
withOutFq, true);
|
withOutFq, !swap_inputs, true);
|
||||||
}
|
}
|
||||||
public:
|
public:
|
||||||
std::shared_ptr<ngraph::Function> function, reference_function;
|
std::shared_ptr<ngraph::Function> function, reference_function;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template<MatmulInputType E>
|
||||||
class SwapInputMatmulNotApplied : public CommonTestUtils::TestsCommon,
|
class SwapInputMatmulNotApplied : public CommonTestUtils::TestsCommon,
|
||||||
public ::testing::WithParamInterface<SwapInputMatmulParams> {
|
public ::testing::WithParamInterface<SwapInputMatmulParams> {
|
||||||
public:
|
public:
|
||||||
@ -128,42 +145,92 @@ public:
|
|||||||
bool withBias, withWeightsFq, withOutFq;
|
bool withBias, withWeightsFq, withOutFq;
|
||||||
std::tie(shapes, withBias, withWeightsFq, withOutFq) = this->GetParam();
|
std::tie(shapes, withBias, withWeightsFq, withOutFq) = this->GetParam();
|
||||||
|
|
||||||
function = CreateMatMulFunction(shapes[0], shapes[1], shapes[2], withBias, withWeightsFq, withOutFq, false);
|
bool swap_inputs = false;
|
||||||
|
switch (E) {
|
||||||
|
case MatmulInputType::FirstInputConstant:
|
||||||
|
break;
|
||||||
|
case MatmulInputType::SecondInputConstant:
|
||||||
|
swap_inputs = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
function = CreateMatMulFunction(shapes[0], shapes[1], shapes[2], withBias, withWeightsFq, withOutFq, swap_inputs, false);
|
||||||
reference_function = ngraph::clone_function(*function);
|
reference_function = ngraph::clone_function(*function);
|
||||||
}
|
}
|
||||||
public:
|
public:
|
||||||
std::shared_ptr<ngraph::Function> function, reference_function;
|
std::shared_ptr<ngraph::Function> function, reference_function;
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_P(SwapInputMatmul, CompareFunctions) {
|
using SwapInputMatmulWithFirstInputConstant = SwapInputMatmul<MatmulInputType::FirstInputConstant>;
|
||||||
|
using SwapInputMatmulWithSecondInputConstant = SwapInputMatmul<MatmulInputType::SecondInputConstant>;
|
||||||
|
using SwapInputMatmulWithFirstInputConstantNotApplied = SwapInputMatmulNotApplied<MatmulInputType::FirstInputConstant>;
|
||||||
|
using SwapInputMatmulWithSecondInputConstantNotApplied = SwapInputMatmulNotApplied<MatmulInputType::SecondInputConstant>;
|
||||||
|
|
||||||
|
TEST_P(SwapInputMatmulWithFirstInputConstant, CompareFunctions) {
|
||||||
Execute(function, reference_function);
|
Execute(function, reference_function);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(SwapInputMatmulNotApplied, CompareFunctions) {
|
TEST_P(SwapInputMatmulWithFirstInputConstantNotApplied, CompareFunctions) {
|
||||||
Execute(function, reference_function);
|
Execute(function, reference_function);
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::vector<std::vector<ngraph::Shape>> input_shapes_applied = {
|
TEST_P(SwapInputMatmulWithSecondInputConstant, CompareFunctions) {
|
||||||
|
Execute(function, reference_function);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(SwapInputMatmulWithSecondInputConstantNotApplied, CompareFunctions) {
|
||||||
|
Execute(function, reference_function);
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::vector<std::vector<ngraph::Shape>> input_shapes_for_matmul_with_first_constant_applied = {
|
||||||
{{16, 8}, {8, 8}, {16, 8}},
|
{{16, 8}, {8, 8}, {16, 8}},
|
||||||
{{16, 8}, {8, 8}, {1}},
|
{{16, 8}, {8, 8}, {1}},
|
||||||
};
|
};
|
||||||
|
|
||||||
const std::vector<std::vector<ngraph::Shape>> input_shapes_not_applied = {
|
const std::vector<std::vector<ngraph::Shape>> input_shapes_for_matmul_with_first_constant_not_applied = {
|
||||||
{{1, 8}, {8, 8}, {1, 8}},
|
{{1, 8}, {8, 8}, {1, 8}},
|
||||||
{{8}, {8, 8}, {8}}
|
{{8}, {8, 8}, {8}}
|
||||||
};
|
};
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(smoke_swap_input_matmul, SwapInputMatmul,
|
const std::vector<std::vector<ngraph::Shape>> input_shapes_for_matmul_with_second_constant_applied = {
|
||||||
|
{{64, 6}, {100, 64}, {100, 6}},
|
||||||
|
{{64, 6}, {100, 64}, {1}},
|
||||||
|
};
|
||||||
|
|
||||||
|
const std::vector<std::vector<ngraph::Shape>> input_shapes_for_matmul_with_second_constant_not_applied = {
|
||||||
|
{{64, 16}, {100, 64}, {100, 16}},
|
||||||
|
{{64, 6}, {8, 64}, {8, 6}},
|
||||||
|
{{8, 1}, {8, 8}, {8, 1}},
|
||||||
|
{{8}, {8, 8}, {8}}
|
||||||
|
};
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(smoke_swap_input_matmul, SwapInputMatmulWithFirstInputConstant,
|
||||||
::testing::Combine(
|
::testing::Combine(
|
||||||
::testing::ValuesIn(input_shapes_applied),
|
::testing::ValuesIn(input_shapes_for_matmul_with_first_constant_applied),
|
||||||
::testing::ValuesIn(std::vector<bool>{false, true}),
|
::testing::ValuesIn(std::vector<bool>{false, true}),
|
||||||
::testing::ValuesIn(std::vector<bool>{false, true}),
|
::testing::ValuesIn(std::vector<bool>{false, true}),
|
||||||
::testing::ValuesIn(std::vector<bool>{false, true})),
|
::testing::ValuesIn(std::vector<bool>{false, true})),
|
||||||
getTestCaseName);
|
getTestCaseName);
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(smoke_swap_input_matmul, SwapInputMatmulNotApplied,
|
INSTANTIATE_TEST_SUITE_P(smoke_swap_input_matmul, SwapInputMatmulWithFirstInputConstantNotApplied,
|
||||||
::testing::Combine(
|
::testing::Combine(
|
||||||
::testing::ValuesIn(input_shapes_not_applied),
|
::testing::ValuesIn(input_shapes_for_matmul_with_first_constant_not_applied),
|
||||||
|
::testing::ValuesIn(std::vector<bool>{false, true}),
|
||||||
|
::testing::ValuesIn(std::vector<bool>{false, true}),
|
||||||
|
::testing::ValuesIn(std::vector<bool>{false, true})),
|
||||||
|
getTestCaseName);
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(smoke_swap_input_matmul, SwapInputMatmulWithSecondInputConstant,
|
||||||
|
::testing::Combine(
|
||||||
|
::testing::ValuesIn(input_shapes_for_matmul_with_second_constant_applied),
|
||||||
|
::testing::ValuesIn(std::vector<bool>{false, true}),
|
||||||
|
::testing::ValuesIn(std::vector<bool>{false, true}),
|
||||||
|
::testing::ValuesIn(std::vector<bool>{false, true})),
|
||||||
|
getTestCaseName);
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(smoke_swap_input_matmul, SwapInputMatmulWithSecondInputConstantNotApplied,
|
||||||
|
::testing::Combine(
|
||||||
|
::testing::ValuesIn(input_shapes_for_matmul_with_second_constant_not_applied),
|
||||||
::testing::ValuesIn(std::vector<bool>{false, true}),
|
::testing::ValuesIn(std::vector<bool>{false, true}),
|
||||||
::testing::ValuesIn(std::vector<bool>{false, true}),
|
::testing::ValuesIn(std::vector<bool>{false, true}),
|
||||||
::testing::ValuesIn(std::vector<bool>{false, true})),
|
::testing::ValuesIn(std::vector<bool>{false, true})),
|
||||||
|
@ -70,56 +70,117 @@ std::shared_ptr<ngraph::Function> CreateMatmulFunction(const ngraph::Shape& inpu
|
|||||||
|
|
||||||
namespace handle_transpose_after_matmul {
|
namespace handle_transpose_after_matmul {
|
||||||
|
|
||||||
std::shared_ptr<ngraph::Function> CreateMatmulTransposeFunction(const ngraph::Shape& input_shape,
|
std::shared_ptr<ngraph::Function> CreateMatmulTransposeFunction(
|
||||||
const ngraph::Shape& matmul_shape, const ngraph::Shape& reshape_shape, bool create_reshape_after_transpose) {
|
const ngraph::Shape& input_shape,
|
||||||
|
const ngraph::Shape& matmul_shape,
|
||||||
|
const ngraph::Shape& reshape_shape,
|
||||||
|
bool create_reshape_after_transpose,
|
||||||
|
bool enable_last_reshape,
|
||||||
|
bool enable_add,
|
||||||
|
bool matmul_on_left_side,
|
||||||
|
bool enable_fq) {
|
||||||
auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, input_shape);
|
auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, input_shape);
|
||||||
|
|
||||||
std::vector<size_t> data(ngraph::shape_size(matmul_shape));
|
std::vector<size_t> data(ngraph::shape_size(matmul_shape));
|
||||||
std::iota(std::begin(data), std::end(data), 1);
|
std::iota(std::begin(data), std::end(data), 1);
|
||||||
auto matmul_constant = ngraph::opset7::Constant::create(ngraph::element::i64, matmul_shape, data);
|
auto matmul_constant = ngraph::opset7::Constant::create(ngraph::element::i64, matmul_shape, data);
|
||||||
auto matmul = std::make_shared<ngraph::opset7::MatMul>(input_params, matmul_constant);
|
std::shared_ptr<ngraph::Node> node = std::make_shared<ngraph::opset7::MatMul>(input_params, matmul_constant);
|
||||||
const auto matmul_output_shape = matmul->get_output_shape(0);
|
const auto matmul_output_shape = node->get_output_shape(0);
|
||||||
|
if (enable_add) {
|
||||||
|
auto add_const = ngraph::opset7::Constant::create(ngraph::element::i64, matmul_output_shape, {1});
|
||||||
|
if (matmul_on_left_side) {
|
||||||
|
node = std::make_shared<ngraph::opset7::Add>(add_const, node);
|
||||||
|
} else {
|
||||||
|
node = std::make_shared<ngraph::opset7::Add>(node, add_const);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (enable_fq) {
|
||||||
|
node = std::make_shared<ngraph::opset7::FakeQuantize>(
|
||||||
|
node,
|
||||||
|
ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {-0.1}),
|
||||||
|
ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {0.1}),
|
||||||
|
ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {-0.1}),
|
||||||
|
ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {0.1}),
|
||||||
|
255);
|
||||||
|
}
|
||||||
|
|
||||||
auto transpose_order = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {1, 0});
|
auto transpose_order = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {1, 0});
|
||||||
auto transpose = std::make_shared<ngraph::opset7::Transpose>(matmul, transpose_order);
|
auto transpose = std::make_shared<ngraph::opset7::Transpose>(node, transpose_order);
|
||||||
const auto transpose_output_shape = transpose->get_output_shape(0);
|
const auto transpose_output_shape = transpose->get_output_shape(0);
|
||||||
|
|
||||||
std::shared_ptr<ngraph::opset7::Reshape> reshape;
|
std::shared_ptr<ngraph::Node> reshape;
|
||||||
auto shape_const = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{reshape_shape.size()}, reshape_shape);
|
auto shape_const = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{reshape_shape.size()}, reshape_shape);
|
||||||
if (create_reshape_after_transpose) {
|
if (create_reshape_after_transpose) {
|
||||||
const auto matmul_output_shape = matmul->get_output_shape(0);
|
const auto matmul_output_shape = node->get_output_shape(0);
|
||||||
auto reshape_after_transpose_const = ngraph::opset7::Constant::create(ngraph::element::i64,
|
auto reshape_after_transpose_const = ngraph::opset7::Constant::create(ngraph::element::i64,
|
||||||
ngraph::Shape{matmul_output_shape.size()}, matmul_output_shape);
|
ngraph::Shape{matmul_output_shape.size()}, matmul_output_shape);
|
||||||
auto reshape_after_transpose = std::make_shared<ngraph::opset7::Reshape>(transpose, reshape_after_transpose_const, false);
|
auto reshape_after_transpose = std::make_shared<ngraph::opset7::Reshape>(transpose, reshape_after_transpose_const, false);
|
||||||
reshape = std::make_shared<ngraph::opset7::Reshape>(reshape_after_transpose, shape_const, false);
|
reshape = reshape_after_transpose;
|
||||||
|
if (enable_last_reshape) {
|
||||||
|
reshape = std::make_shared<ngraph::opset7::Reshape>(reshape_after_transpose, shape_const, false);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
reshape = std::make_shared<ngraph::opset7::Reshape>(transpose, shape_const, false);
|
reshape = transpose;
|
||||||
const auto reshape_output_shape = reshape->get_output_shape(0);
|
if (enable_last_reshape) {
|
||||||
|
reshape = std::make_shared<ngraph::opset7::Reshape>(transpose, shape_const, false);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto result = std::make_shared<ngraph::opset7::Result>(reshape);
|
auto result = std::make_shared<ngraph::opset7::Result>(reshape);
|
||||||
return std::make_shared<ngraph::Function>(ngraph::ResultVector{result}, ngraph::ParameterVector{input_params});
|
return std::make_shared<ngraph::Function>(ngraph::ResultVector{result}, ngraph::ParameterVector{input_params});
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<ngraph::Function> CreateMatmulFunction(const ngraph::Shape& input_shape,
|
std::shared_ptr<ngraph::Function> CreateMatmulFunction(
|
||||||
const ngraph::Shape& matmul_shape, const ngraph::Shape& reshape_shape, bool create_reshape_instead_of_transpose) {
|
const ngraph::Shape& input_shape,
|
||||||
|
const ngraph::Shape& matmul_shape,
|
||||||
|
const ngraph::Shape& reshape_shape,
|
||||||
|
bool create_reshape_instead_of_transpose,
|
||||||
|
bool enable_last_reshape,
|
||||||
|
bool enable_add,
|
||||||
|
bool matmul_on_left_side,
|
||||||
|
bool enable_fq) {
|
||||||
auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, input_shape);
|
auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, input_shape);
|
||||||
|
|
||||||
std::vector<size_t> data(ngraph::shape_size(matmul_shape));
|
std::vector<size_t> data(ngraph::shape_size(matmul_shape));
|
||||||
std::iota(std::begin(data), std::end(data), 1);
|
std::iota(std::begin(data), std::end(data), 1);
|
||||||
auto matmul_constant = ngraph::opset7::Constant::create(ngraph::element::i64, matmul_shape, data);
|
auto matmul_constant = ngraph::opset7::Constant::create(ngraph::element::i64, matmul_shape, data);
|
||||||
auto matmul = std::make_shared<ngraph::opset7::MatMul>(input_params, matmul_constant);
|
std::shared_ptr<ngraph::Node> node = std::make_shared<ngraph::opset7::MatMul>(input_params, matmul_constant);
|
||||||
|
const auto matmul_output_shape = node->get_output_shape(0);
|
||||||
|
if (enable_add) {
|
||||||
|
auto add_const = ngraph::opset7::Constant::create(ngraph::element::i64, matmul_output_shape, {1});
|
||||||
|
if (matmul_on_left_side) {
|
||||||
|
node = std::make_shared<ngraph::opset7::Add>(add_const, node);
|
||||||
|
} else {
|
||||||
|
node = std::make_shared<ngraph::opset7::Add>(node, add_const);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::shared_ptr<ngraph::opset7::Reshape> reshape;
|
if (enable_fq) {
|
||||||
|
node = std::make_shared<ngraph::opset7::FakeQuantize>(
|
||||||
|
node,
|
||||||
|
ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {-0.1}),
|
||||||
|
ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {0.1}),
|
||||||
|
ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {-0.1}),
|
||||||
|
ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {0.1}),
|
||||||
|
255);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<ngraph::Node> reshape;
|
||||||
auto shape_const = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{reshape_shape.size()}, reshape_shape);
|
auto shape_const = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{reshape_shape.size()}, reshape_shape);
|
||||||
if (create_reshape_instead_of_transpose) {
|
if (create_reshape_instead_of_transpose) {
|
||||||
const auto matmul_output_shape = matmul->get_output_shape(0);
|
|
||||||
auto reshape_instead_of_transpose_const = ngraph::opset7::Constant::create(ngraph::element::i64,
|
auto reshape_instead_of_transpose_const = ngraph::opset7::Constant::create(ngraph::element::i64,
|
||||||
ngraph::Shape{matmul_output_shape.size()}, {matmul_output_shape[1], matmul_output_shape[0]});
|
ngraph::Shape{matmul_output_shape.size()}, {matmul_output_shape[1], matmul_output_shape[0]});
|
||||||
auto reshape_instead_of_transpose = std::make_shared<ngraph::opset7::Reshape>(matmul, reshape_instead_of_transpose_const, false);
|
auto reshape_instead_of_transpose = std::make_shared<ngraph::opset7::Reshape>(node, reshape_instead_of_transpose_const, false);
|
||||||
reshape = std::make_shared<ngraph::opset7::Reshape>(reshape_instead_of_transpose, shape_const, false);
|
reshape = reshape_instead_of_transpose;
|
||||||
|
if (enable_last_reshape) {
|
||||||
|
reshape = std::make_shared<ngraph::opset7::Reshape>(reshape_instead_of_transpose, shape_const, false);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
reshape = std::make_shared<ngraph::opset7::Reshape>(matmul, shape_const, false);
|
reshape = node;
|
||||||
|
if (enable_last_reshape) {
|
||||||
|
reshape = std::make_shared<ngraph::opset7::Reshape>(node, shape_const, false);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto result = std::make_shared<ngraph::opset7::Result>(reshape);
|
auto result = std::make_shared<ngraph::opset7::Result>(reshape);
|
||||||
@ -153,6 +214,9 @@ TEST(TransformationTests, InsertTransposeBeforeMatmulTest) {
|
|||||||
RunTest(
|
RunTest(
|
||||||
handle_transpose_before_matmul::CreateMatmulFunction({1, 16}, {8, 2}, {2, 1}, false),
|
handle_transpose_before_matmul::CreateMatmulFunction({1, 16}, {8, 2}, {2, 1}, false),
|
||||||
handle_transpose_before_matmul::CreateTransposeMatmulFunction({1, 16}, {8, 2}, {2, 1}, true));
|
handle_transpose_before_matmul::CreateTransposeMatmulFunction({1, 16}, {8, 2}, {2, 1}, true));
|
||||||
|
RunTest(
|
||||||
|
handle_transpose_before_matmul::CreateMatmulFunction({1, 2, 8}, {8, 2}, {2, 1}, false),
|
||||||
|
handle_transpose_before_matmul::CreateTransposeMatmulFunction({1, 2, 8}, {8, 2}, {2, 1}, true));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(TransformationTests, InsertTransposeBeforeMatmulTestReshapeInOutEq) {
|
TEST(TransformationTests, InsertTransposeBeforeMatmulTestReshapeInOutEq) {
|
||||||
@ -177,25 +241,59 @@ TEST(TransformationTests, RemoveTransposeBeforeMatmulTestReshapeInOutEq) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(TransformationTests, InsertTransposeAfterMatmulTest) {
|
TEST(TransformationTests, InsertTransposeAfterMatmulTest) {
|
||||||
RunTest(
|
for (auto enable_add : { true, false}) {
|
||||||
handle_transpose_after_matmul::CreateMatmulFunction({4, 1}, {1, 8}, {2, 16}, false),
|
for (auto matmul_on_left_side : { true, false}) {
|
||||||
handle_transpose_after_matmul::CreateMatmulTransposeFunction({4, 1}, {1, 8}, {2, 16}, true));
|
for (auto enable_fq : { true, false}) {
|
||||||
|
RunTest(
|
||||||
|
handle_transpose_after_matmul::CreateMatmulFunction(
|
||||||
|
{4, 1}, {1, 8}, {2, 16}, false, true, enable_add, matmul_on_left_side, enable_fq),
|
||||||
|
handle_transpose_after_matmul::CreateMatmulTransposeFunction(
|
||||||
|
{4, 1}, {1, 8}, {2, 16}, true, true, enable_add, matmul_on_left_side, enable_fq));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(TransformationTests, RemoveTransposeAfterMatmulTest) {
|
TEST(TransformationTests, RemoveTransposeAfterMatmulTest) {
|
||||||
RunTest(
|
for (auto enable_add : { true, false }) {
|
||||||
handle_transpose_after_matmul::CreateMatmulTransposeFunction({4, 1}, {1, 8}, {2, 16}, false),
|
for (auto matmul_on_left_side : { true, false }) {
|
||||||
handle_transpose_after_matmul::CreateMatmulFunction({4, 1}, {1, 8}, {2, 16}, true));
|
for (auto enable_fq : { true, false }) {
|
||||||
|
RunTest(
|
||||||
|
handle_transpose_after_matmul::CreateMatmulTransposeFunction(
|
||||||
|
{4, 1}, {1, 8}, {2, 16}, false, true, enable_add, matmul_on_left_side, enable_fq),
|
||||||
|
handle_transpose_after_matmul::CreateMatmulFunction(
|
||||||
|
{4, 1}, {1, 8}, {2, 16}, true, true, enable_add, matmul_on_left_side, enable_fq));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(TransformationTests, RemoveTransposeAfterMatmulTestReshapeInOutEq) {
|
TEST(TransformationTests, RemoveTransposeAfterMatmulTestReshapeInOutEq) {
|
||||||
RunTest(
|
for (auto enable_add : { true, false }) {
|
||||||
handle_transpose_after_matmul::CreateMatmulTransposeFunction({4, 1}, {1, 8}, {8, 4}, false),
|
for (auto matmul_on_left_side : { true, false }) {
|
||||||
handle_transpose_after_matmul::CreateMatmulTransposeFunction({4, 1}, {1, 8}, {8, 4}, false));
|
for (auto enable_fq : { true, false }) {
|
||||||
|
RunTest(
|
||||||
|
handle_transpose_after_matmul::CreateMatmulTransposeFunction(
|
||||||
|
{4, 1}, {1, 8}, {8, 4}, false, true, enable_add, matmul_on_left_side, enable_fq),
|
||||||
|
handle_transpose_after_matmul::CreateMatmulTransposeFunction(
|
||||||
|
{4, 1}, {1, 8}, {8, 4}, false, true, enable_add, matmul_on_left_side, enable_fq));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(TransformationTests, InsertTransposeAfterMatmulTestReshapeInOutEq) {
|
TEST(TransformationTests, InsertTransposeAfterMatmulTestReshapeInOutEq) {
|
||||||
RunTest(
|
for (auto enable_last_reshape : { true, false }) {
|
||||||
handle_transpose_after_matmul::CreateMatmulFunction({4, 1}, {1, 8}, {4, 8}, false),
|
for (auto enable_add : { true, false }) {
|
||||||
handle_transpose_after_matmul::CreateMatmulFunction({4, 1}, {1, 8}, {4, 8}, false));
|
for (auto matmul_on_left_side : { true, false }) {
|
||||||
|
for (auto enable_fq : { true, false }) {
|
||||||
|
RunTest(
|
||||||
|
handle_transpose_after_matmul::CreateMatmulFunction(
|
||||||
|
{4, 1}, {1, 8}, {4, 8}, false, enable_last_reshape, enable_add, matmul_on_left_side, enable_fq),
|
||||||
|
handle_transpose_after_matmul::CreateMatmulFunction(
|
||||||
|
{4, 1}, {1, 8}, {4, 8}, false, enable_last_reshape, enable_add, matmul_on_left_side, enable_fq));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user