[GNA] Extend matmul handling support (#9810)
Support MatMuls with two non-constant inputs. Detect Concat inputs to a MatMul as changing the batch size and handle them appropriately. Enable tests in GNA_SW_EXACT mode for convolutions with stride greater than kernel size.
parent 068bdff787
commit ef0a080323
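For orientation, here is a minimal sketch (not part of the diff) of the newly supported case: a MatMul whose inputs are both non-constant and are each produced by a batch-changing Concat. It mirrors the CreateConcatMatmulFunction helper added in the tests further down; the function name, shapes, and f32 precision are illustrative assumptions.

// Illustrative only: a MatMul with two non-constant inputs, each produced by a
// Concat that changes the batch dimension (axis 0). Shapes and precision are
// example assumptions; the new tests below use the same {4, 16} x {8, 8} layout.
#include <ngraph/ngraph.hpp>
#include <ngraph/opsets/opset7.hpp>

std::shared_ptr<ngraph::Function> MakeTwoInputConcatMatMul() {
    using namespace ngraph;
    auto in1 = std::make_shared<opset7::Parameter>(element::f32, Shape{4, 16});
    auto in2 = std::make_shared<opset7::Parameter>(element::f32, Shape{8, 8});

    // Concatenating each parameter with a constant chunk of the same shape doubles the batch size.
    auto c1 = opset7::Constant::create(element::f32, Shape{4, 16}, std::vector<float>(4 * 16, 0.1f));
    auto c2 = opset7::Constant::create(element::f32, Shape{8, 8}, std::vector<float>(8 * 8, 0.1f));
    auto concat1 = std::make_shared<opset7::Concat>(OutputVector{in1, c1}, 0);  // {8, 16}
    auto concat2 = std::make_shared<opset7::Concat>(OutputVector{in2, c2}, 0);  // {16, 8}

    // Both MatMul inputs are non-constant; the extended HandleTransposeBeforeMatMul pass now matches this case.
    auto matmul = std::make_shared<opset7::MatMul>(concat1, concat2);           // {8, 8}
    auto result = std::make_shared<opset7::Result>(matmul);
    return std::make_shared<Function>(ResultVector{result}, ParameterVector{in1, in2});
}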
@@ -395,7 +395,7 @@ bool AreLayersSupported(InferenceEngine::CNNNetwork& network, std::string& errMe
check_result = false;
}
} else if (info.isConcat()) {
if (!ValidateConcatAxis(layer, errMessage) && userWarning) {
if (userWarning && !ValidateConcatAxis(layer, errMessage)) {
std::cout << errMessage;
}
}
@@ -13,7 +13,7 @@
namespace GNAPluginNS {

/**
* @brief checks if it's a reshape from 4d to 3d tensor inserted after convolution
* @brief checks if it's a reshape from 4d to 3d tensor
* @param layer Non-functional layer
*/
inline bool IsReshapeFrom4dTo3d(InferenceEngine::CNNLayerPtr layer) {
@@ -24,8 +24,7 @@ inline bool IsReshapeFrom4dTo3d(InferenceEngine::CNNLayerPtr layer) {
auto input_dims = layer->insData[0].lock()->getDims();
auto output_dims = layer->outData[0]->getDims();
// If H input dimension is not 1, it can't be just skipped during reshape to 3d
size_t h_dim = input_dims[2];
if (input_dims.size() != 4 || output_dims.size() != 3 || h_dim != 1) {
if (input_dims.size() != 4 || output_dims.size() != 3 || input_dims[2] != 1) {
return false;
}
@@ -61,7 +60,7 @@ inline bool IsReshapeFrom3dTo4d(InferenceEngine::CNNLayerPtr layer) {
}

/**
* @brief searchs for a pattern: Permute(NHWC->NCHW) -> ... -> Convolution -> ... -> Permute(NCHW->NHWC) or
* @brief searches for a pattern: Permute(NHWC->NCHW) -> ... -> Convolution -> ... -> Permute(NCHW->NHWC) or
* Reshape(NHWC->NCHW) -> ... -> Convolution -> ... -> Reshape(NCHW->NHWC) if Convolution has only one input/output
* dimension not equal to 1,
* if the original convolution layout is 3d, 3d->4d/4d->3d reshapes will be inserted before/after the convolution,
@@ -208,7 +208,7 @@ void Config::UpdateFromMap(const std::map<std::string, std::string>& config) {
if (value == PluginConfigParams::LOG_WARNING || value == PluginConfigParams::LOG_NONE) {
gnaFlags.log_level = value;
} else {
log << "Currently only LOG_LEVEL = LOG_WARNING or LOG_NONE is supported, not " << value;
log << "Currently only LOG_LEVEL = LOG_WARNING and LOG_NONE are supported, not " << value;
THROW_GNA_EXCEPTION << "Currently only LOG_LEVEL = LOG_WARNING and LOG_NONE are supported, not " << value;
}
} else {
@@ -88,21 +88,36 @@ bool VerifyReshape(const ngraph::Output<ngraph::Node>& reshape_out) {
return in_shape[0] != out_shape[0];
}

bool VerifyConcat(const ngraph::Output<ngraph::Node>& node) {
auto concat_node = std::dynamic_pointer_cast<ngraph::opset8::Concat>(node.get_node_shared_ptr());
return (concat_node->get_axis() == 0);
}

} // namespace

HandleTransposeBeforeMatMul::HandleTransposeBeforeMatMul() {
auto concat1 = ngraph::pattern::wrap_type<ngraph::opset8::Concat>(VerifyConcat);
auto reshape1 = ngraph::pattern::wrap_type<ngraph::opset8::Reshape>(VerifyReshape);
auto transpose_input1 = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{concat1, reshape1});
auto transpose1 = ngraph::pattern::wrap_type<ngraph::opset8::Transpose>({transpose_input1, ngraph::pattern::any_input()});

auto concat2 = ngraph::pattern::wrap_type<ngraph::opset8::Concat>(VerifyConcat);
auto reshape2 = ngraph::pattern::wrap_type<ngraph::opset8::Reshape>(VerifyReshape);
auto transpose_input2 = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{concat2, reshape2});
auto transpose2 = ngraph::pattern::wrap_type<ngraph::opset8::Transpose>({transpose_input2, ngraph::pattern::any_input()});

auto constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
auto fq = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({constant, ngraph::pattern::any_input(),
ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()});
auto reshape = ngraph::pattern::wrap_type<ngraph::opset8::Reshape>({}, VerifyReshape);
auto transpose = ngraph::pattern::wrap_type<ngraph::opset8::Transpose>({reshape,
ngraph::pattern::any_input()});

auto matmul1 = ngraph::pattern::wrap_type<ngraph::opset8::MatMul>({
std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{reshape, transpose}),
ngraph::pattern::any_input()});
std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{reshape1, concat1, transpose1, constant, fq, ngraph::pattern::any_input()}),
std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{reshape2, concat2, transpose2})});

auto matmul2 = ngraph::pattern::wrap_type<ngraph::opset8::MatMul>({
std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{constant, fq}),
std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{reshape, transpose, ngraph::pattern::any_input()})});
std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{reshape1, concat1, transpose1, constant, fq}),
ngraph::pattern::any_input()});

auto matmul = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{matmul1, matmul2});

ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &matcher) {
@@ -114,17 +129,25 @@ HandleTransposeBeforeMatMul::HandleTransposeBeforeMatMul() {
}

auto matmul_node = matmul_iter->second.get_node_shared_ptr();
auto transpose_reshape_it = pattern_map.find(transpose);
auto transpose_reshape_it = pattern_map.find(transpose1);
if (transpose_reshape_it != std::end(pattern_map)) {
ReplaceTransposeWithReshape(transpose_reshape_it->second.get_node_shared_ptr());
} else if ((transpose_reshape_it = pattern_map.find(reshape)) != std::end(pattern_map)) {
auto reshape_node = pattern_map.at(reshape).get_node_shared_ptr();
if (GNALimitations::IsTransposeSupported(reshape_node->get_output_shape(0))) {
InsertTranspose(reshape_node, matmul_node->get_friendly_name(), true);
} else {
std::shared_ptr<ngraph::Node> prev_node = nullptr;
if ((transpose_reshape_it = pattern_map.find(reshape1)) != std::end(pattern_map)) {
prev_node = pattern_map.at(reshape1).get_node_shared_ptr();
} else if ((transpose_reshape_it = pattern_map.find(concat1)) != std::end(pattern_map)) {
prev_node = pattern_map.at(concat1).get_node_shared_ptr();
}

if (prev_node) {
if (GNALimitations::IsTransposeSupported(prev_node->get_output_shape(0))) {
InsertTranspose(prev_node, matmul_node->get_friendly_name(), true);
}
}
}

// Transpose the constant input if it's the first input
// Transpose the first input if it's a constant
auto iter = pattern_map.find(fq);
if (iter != pattern_map.end() ||
(iter = pattern_map.find(constant)) != pattern_map.end()) {
@@ -133,6 +156,24 @@ HandleTransposeBeforeMatMul::HandleTransposeBeforeMatMul() {
InsertTranspose(prev_node, prev_node->get_friendly_name(), true);
}
}

transpose_reshape_it = pattern_map.find(transpose2);
if (transpose_reshape_it != std::end(pattern_map)) {
ReplaceTransposeWithReshape(transpose_reshape_it->second.get_node_shared_ptr());
} else {
std::shared_ptr<ngraph::Node> prev_node = nullptr;
if ((transpose_reshape_it = pattern_map.find(reshape2)) != std::end(pattern_map)) {
prev_node = pattern_map.at(reshape2).get_node_shared_ptr();
} else if ((transpose_reshape_it = pattern_map.find(concat2)) != std::end(pattern_map)) {
prev_node = pattern_map.at(concat2).get_node_shared_ptr();
}

if (prev_node) {
if (GNALimitations::IsTransposeSupported(prev_node->get_output_shape(0))) {
InsertTranspose(prev_node, matmul_node->get_friendly_name(), true);
}
}
}
return true;
};
@@ -9,11 +9,12 @@
namespace GNAPluginNS {

/**
* @brief Inserts Transpose before MatMul or removes it (if it exists) if there is Reshape
* before MatMul which changes the batch size:
* @brief Inserts Transpose before MatMul or removes it (if it exists)
* if there is Reshape/Concat layer before MatMul which changes the batch size,
* or transpose the input if there's a Constant/FQ layer on the first input:
* [1, A*B] [1, A*B]
* | |
* Reshape Reshape
* Reshape / Concat Reshape / Concat
* | |
* [A, B] [A, B]
* | |
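As a rough sketch (not part of the diff) of what the pass described above is expected to produce for the Concat case, one MatMul input ends up as Concat -> Transpose -> Reshape back to the original shape, mirroring the CreateConcatTransposeMatmulFunction reference used by the transformation tests further down; shapes, names, and f32 precision are illustrative assumptions.

// Illustrative "after" subgraph for one MatMul input: the batch-changing Concat
// output is transposed and then reshaped back to the shape the MatMul expects.
auto order   = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {1, 0});
auto in      = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::f32, ngraph::Shape{4, 16});
auto chunk   = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{4, 16},
                                                std::vector<float>(4 * 16, 0.1f));
auto concat  = std::make_shared<ngraph::opset7::Concat>(ngraph::OutputVector{in, chunk}, 0);  // {8, 16}
auto transp  = std::make_shared<ngraph::opset7::Transpose>(concat, order);                    // {16, 8}
auto shape   = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {8, 16});
auto reshape = std::make_shared<ngraph::opset7::Reshape>(transp, shape, false);               // {8, 16}
// `reshape` then feeds the MatMul input in place of `concat`.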
@@ -66,17 +66,25 @@ static void SwapAndTransposeInputs(
gnalog() << "Swap and transpose inputs for " << matmul_node->get_friendly_name() << "\n";

bool first_input_const = false;
bool second_input_const = false;
auto first_input = matmul_node->input_value(0).get_node_shared_ptr();
auto second_input = matmul_node->input_value(1).get_node_shared_ptr();
if (std::dynamic_pointer_cast<ngraph::opset8::FakeQuantize>(first_input)) {
first_input = first_input->input_value(0).get_node_shared_ptr();
}
if (std::dynamic_pointer_cast<ngraph::opset8::FakeQuantize>(second_input)) {
second_input = second_input->input_value(1).get_node_shared_ptr();
}
if (std::dynamic_pointer_cast<ngraph::opset8::Constant>(first_input)) {
first_input_const = true;
}
if (std::dynamic_pointer_cast<ngraph::opset8::Constant>(second_input)) {
second_input_const = true;
}

auto input1 = first_input_const ? transpose_matmul_input(1) : matmul_node->input_value(1);
auto input1 = (!first_input_const && second_input_const) ? matmul_node->input_value(1) : transpose_matmul_input(1);
auto input2 = first_input_const ? matmul_node->input_value(0) : transpose_matmul_input(0);
bool transpose_1 = first_input_const ? matmul_node->get_transpose_b() : !matmul_node->get_transpose_b();
bool transpose_1 = (!first_input_const && second_input_const) ? !matmul_node->get_transpose_b() : matmul_node->get_transpose_b();
bool transpose_2 = first_input_const ? !matmul_node->get_transpose_a() : matmul_node->get_transpose_a();
std::shared_ptr<ngraph::Node> new_node = std::make_shared<ngraph::opset8::MatMul>(input1, input2, transpose_1, transpose_2);
new_node->set_friendly_name(matmul_node->get_friendly_name() + "/swap_inputs");
@@ -313,6 +313,11 @@ const std::vector<std::map<std::string, std::string>> configsStrides = {
{"GNA_DEVICE_MODE", "GNA_SW_FP32"},
{"GNA_SCALE_FACTOR_0", "1"},
{"GNA_EXEC_TARGET", "GNA_TARGET_2_0"}
},
{
{"GNA_DEVICE_MODE", "GNA_SW_EXACT"},
{"GNA_SCALE_FACTOR_0", "1"},
{"GNA_EXEC_TARGET", "GNA_TARGET_2_0"}
}
};
@@ -129,4 +129,91 @@ INSTANTIATE_TEST_SUITE_P(smoke_InsertTransposeBeforeMatmulTest, InsertTransposeB
::testing::ValuesIn(firstInputConst)),
InsertTransposeBeforeMatmul::getTestCaseName);

/* Case with two inputs with concat instead of reshape */

class InsertTransposeBeforeConcatConcat : public testing::WithParamInterface<insertTransposeBeforeMatmulParams>,
public LayerTestsUtils::LayerTestsCommon {
public:
static std::string getTestCaseName(testing::TestParamInfo<insertTransposeBeforeMatmulParams> obj) {
InferenceEngine::Precision netPrecision;
std::string targetDevice;
std::map<std::string, std::string> configuration;
size_t inputShape;
bool firstInConst;
std::tie(netPrecision, targetDevice, configuration, inputShape, firstInConst) = obj.param;

std::ostringstream result;
result << "netPRC=" << netPrecision.name() << "_";
result << "targetDevice=" << targetDevice << "_";
for (auto const& configItem : configuration) {
result << "_configItem=" << configItem.first << "_" << configItem.second;
}
result << "_inputShape=" << inputShape;
result << "_firstInConst=" << firstInConst;
return result.str();
}

InferenceEngine::Blob::Ptr GenerateInput(const InferenceEngine::InputInfo& info) const override {
InferenceEngine::Blob::Ptr blob = make_blob_with_precision(info.getTensorDesc());
blob->allocate();

auto* rawBlobDataPtr = blob->buffer().as<float*>();
std::vector<float> values = CommonTestUtils::generate_float_numbers(blob->size(), -0.2f, 0.2f);
for (size_t i = 0; i < blob->size(); i++) {
rawBlobDataPtr[i] = values[i];
}
return blob;
}

protected:
void SetUp() override {
InferenceEngine::Precision netPrecision;
size_t inputShape;
bool firstInConst;
std::tie(netPrecision, targetDevice, configuration, inputShape, firstInConst) = this->GetParam();
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);

auto params = ngraph::builder::makeParams(ngPrc, {{1, inputShape}});
auto matmul_in_shape = ngraph::Shape{inputShape / 8, 8};
auto pattern = std::make_shared<ngraph::opset1::Constant>(ngraph::element::Type_t::i64, ngraph::Shape{ 2 }, matmul_in_shape);
auto reshape = std::make_shared<ngraph::opset1::Reshape>(params[0], pattern, false);

std::vector<float> data = CommonTestUtils::generate_float_numbers(ngraph::shape_size(matmul_in_shape), -0.2f, 0.2f);
auto concat_const = std::make_shared<ngraph::opset1::Constant>(ngPrc, matmul_in_shape, data);
ngraph::OutputVector concat_chunks{reshape, concat_const};
auto concat = std::make_shared<ngraph::opset7::Concat>(concat_chunks, 0);

std::shared_ptr<ngraph::Node> weights_node;
std::vector<float> weights = CommonTestUtils::generate_float_numbers(matmul_in_shape[0] * 2, -0.2f, 0.2f);
weights_node = std::make_shared<ngraph::opset1::Constant>(ngPrc, ngraph::Shape{ 1, matmul_in_shape[0] * 2 }, weights);

auto matmul = firstInConst ? ngraph::builder::makeMatMul(weights_node, concat, false, false) :
ngraph::builder::makeMatMul(concat, weights_node, false, false);

ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(matmul)};
function = std::make_shared<ngraph::Function>(results, params, "InsertTransposeBeforeConcatConcat");
}
};

TEST_P(InsertTransposeBeforeConcatConcat, CompareWithRefImpl) {
Run();
};

const std::vector<size_t> concatInputShape = {
64,
96,
128
};

const std::vector<bool> firstInputConstConcat = {true};

INSTANTIATE_TEST_SUITE_P(smoke_InsertTransposeBeforeMatmulConcat, InsertTransposeBeforeConcatConcat,
::testing::Combine(
::testing::ValuesIn(netPrecisions),
::testing::Values(CommonTestUtils::DEVICE_GNA),
::testing::ValuesIn(configs),
::testing::ValuesIn(concatInputShape),
::testing::ValuesIn(firstInputConstConcat)),
InsertTransposeBeforeConcatConcat::getTestCaseName);

} // namespace LayerTestsDefinitions
@@ -66,6 +66,78 @@ std::shared_ptr<ngraph::Function> CreateMatmulFunction(const ngraph::Shape& inpu
return std::make_shared<ngraph::Function>(ngraph::ResultVector{result}, ngraph::ParameterVector{input_params});
}

std::shared_ptr<ngraph::Function> CreateConcatTransposeMatmulFunction(const ngraph::Shape& input1_shape, const ngraph::Shape& input2_shape,
const ngraph::Shape& reshape1_shape, const ngraph::Shape& reshape2_shape, bool create_reshape_after_transpose) {
auto transpose_order = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {1, 0});

auto input1_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, input1_shape);
std::vector<size_t> data1(ngraph::shape_size(input1_shape));
std::iota(std::begin(data1), std::end(data1), 1);
auto concat1_const = ngraph::opset7::Constant::create(ngraph::element::i64, input1_shape, data1);
ngraph::OutputVector concat1_chunks{input1_params, concat1_const};
auto concat1 = std::make_shared<ngraph::opset7::Concat>(concat1_chunks, 0);
auto transpose1 = std::make_shared<ngraph::opset7::Transpose>(concat1, transpose_order);

auto input2_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, input2_shape);
std::vector<size_t> data2(ngraph::shape_size(input2_shape));
std::iota(std::begin(data2), std::end(data2), 1);
auto concat2_const = ngraph::opset7::Constant::create(ngraph::element::i64, input2_shape, data2);
ngraph::OutputVector concat2_chunks{input2_params, concat2_const};
auto concat2 = std::make_shared<ngraph::opset7::Concat>(concat2_chunks, 0);
auto transpose2 = std::make_shared<ngraph::opset7::Transpose>(concat2, transpose_order);

std::shared_ptr<ngraph::opset7::MatMul> matmul;

if (create_reshape_after_transpose) {
auto reshape_after_transpose1_const = ngraph::opset7::Constant::create(ngraph::element::i64,
ngraph::Shape{reshape1_shape.size()}, reshape1_shape);
auto reshape_after_transpose1 = std::make_shared<ngraph::opset7::Reshape>(transpose1, reshape_after_transpose1_const, false);
auto reshape_after_transpose2_const = ngraph::opset7::Constant::create(ngraph::element::i64,
ngraph::Shape{reshape2_shape.size()}, reshape2_shape);
auto reshape_after_transpose2 = std::make_shared<ngraph::opset7::Reshape>(transpose2, reshape_after_transpose2_const, false);
matmul = std::make_shared<ngraph::opset7::MatMul>(reshape_after_transpose1, reshape_after_transpose2);
} else {
matmul = std::make_shared<ngraph::opset7::MatMul>(transpose1, transpose2);
}

auto result = std::make_shared<ngraph::opset7::Result>(matmul);
return std::make_shared<ngraph::Function>(ngraph::ResultVector{result}, ngraph::ParameterVector{input1_params, input2_params});
}

std::shared_ptr<ngraph::Function> CreateConcatMatmulFunction(const ngraph::Shape& input1_shape, const ngraph::Shape& input2_shape,
const ngraph::Shape& reshape1_shape, const ngraph::Shape& reshape2_shape, bool create_reshape_instead_of_transpose) {
auto input1_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, input1_shape);
std::vector<size_t> data1(ngraph::shape_size(input1_shape));
std::iota(std::begin(data1), std::end(data1), 1);
auto concat1_const = ngraph::opset7::Constant::create(ngraph::element::i64, input1_shape, data1);
ngraph::OutputVector concat1_chunks{input1_params, concat1_const};
auto concat1 = std::make_shared<ngraph::opset7::Concat>(concat1_chunks, 0);

auto input2_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, input2_shape);
std::vector<size_t> data2(ngraph::shape_size(input2_shape));
std::iota(std::begin(data2), std::end(data2), 1);
auto concat2_const = ngraph::opset7::Constant::create(ngraph::element::i64, input2_shape, data2);
ngraph::OutputVector concat2_chunks{input2_params, concat2_const};
auto concat2 = std::make_shared<ngraph::opset7::Concat>(concat2_chunks, 0);

std::shared_ptr<ngraph::opset7::MatMul> matmul;

if (create_reshape_instead_of_transpose) {
auto new_shape_after_transpose1 = ngraph::opset7::Constant::create(ngraph::element::i64,
ngraph::Shape{reshape1_shape.size()}, {reshape1_shape[1], reshape1_shape[0]});
auto reshape1 = std::make_shared<ngraph::opset7::Reshape>(concat1, new_shape_after_transpose1, false);
auto new_shape_after_transpose2 = ngraph::opset7::Constant::create(ngraph::element::i64,
ngraph::Shape{reshape2_shape.size()}, {reshape2_shape[1], reshape2_shape[0]});
auto reshape2 = std::make_shared<ngraph::opset7::Reshape>(concat2, new_shape_after_transpose2, false);
matmul = std::make_shared<ngraph::opset7::MatMul>(reshape1, reshape2);
} else {
matmul = std::make_shared<ngraph::opset7::MatMul>(concat1, concat2);
}

auto result = std::make_shared<ngraph::opset7::Result>(matmul);
return std::make_shared<ngraph::Function>(ngraph::ResultVector{result}, ngraph::ParameterVector{input1_params, input2_params});
}

} // namespace handle_transpose_before_matmul

namespace handle_transpose_after_matmul {
@@ -235,6 +307,9 @@ TEST(TransformationTests, InsertTransposeBeforeMatmulTest) {
RunTest(
handle_transpose_before_matmul::CreateMatmulFunction({1, 2, 8}, {8, 2}, {2, 1}, false),
handle_transpose_before_matmul::CreateTransposeMatmulFunction({1, 2, 8}, {8, 2}, {2, 1}, true));
RunTest(
handle_transpose_before_matmul::CreateConcatMatmulFunction({4, 16}, {8, 8}, {8, 16}, {16, 8}, false),
handle_transpose_before_matmul::CreateConcatTransposeMatmulFunction({4, 16}, {8, 8}, {8, 16}, {16, 8}, true));
}

TEST(TransformationTests, InsertTransposeBeforeMatmulTestReshapeInOutEq) {
@@ -244,12 +319,18 @@ TEST(TransformationTests, InsertTransposeBeforeMatmulTestReshapeInOutEq) {
RunTest(
handle_transpose_before_matmul::CreateMatmulFunction({9, 2}, {9, 2}, {2, 1}, false),
handle_transpose_before_matmul::CreateMatmulFunction({9, 2}, {9, 2}, {2, 1}, false));
RunTest(
handle_transpose_before_matmul::CreateConcatMatmulFunction({8, 16}, {8, 16}, {16, 16}, {16, 16}, false),
handle_transpose_before_matmul::CreateConcatMatmulFunction({8, 16}, {8, 16}, {16, 16}, {16, 16}, false));
}

TEST(TransformationTests, RemoveTransposeBeforeMatmulTest) {
RunTest(
handle_transpose_before_matmul::CreateTransposeMatmulFunction({1, 8}, {2, 4}, {2, 1}, false),
handle_transpose_before_matmul::CreateMatmulFunction({1, 8}, {2, 4}, {2, 1}, true));
RunTest(
handle_transpose_before_matmul::CreateConcatTransposeMatmulFunction({4, 16}, {8, 8}, {8, 16}, {16, 8}, false),
handle_transpose_before_matmul::CreateConcatMatmulFunction({4, 16}, {8, 8}, {8, 16}, {16, 8}, true));
}

TEST(TransformationTests, RemoveTransposeBeforeMatmulTestReshapeInOutEq) {