[GNA] fixed failed to inserting transpose after matmul (#7720)

* fixed failed to inserting transpose after matmul

* added verifier for matmul
This commit is contained in:
Dmitrii Khurtin 2021-09-30 09:55:34 +03:00 committed by GitHub
parent ffd2091477
commit a13b934622
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 1 deletions

View File

@ -117,7 +117,9 @@ HandleTransposeBeforeMatMul::HandleTransposeBeforeMatMul() {
} }
HandleTransposeAfterMatMul::HandleTransposeAfterMatMul() { HandleTransposeAfterMatMul::HandleTransposeAfterMatMul() {
auto matmul = ngraph::pattern::wrap_type<ngraph::opset8::MatMul>(); auto matmul = ngraph::pattern::wrap_type<ngraph::opset8::MatMul>({}, [](const ngraph::Output<ngraph::Node>& node) {
auto out_shape = node.get_node_shared_ptr()->get_output_shape(0);
return std::count_if(out_shape.begin(), out_shape.end(), [](size_t n) { return n > 1; }) > 1; });
auto add_left = ngraph::pattern::wrap_type<ngraph::opset8::Add>({matmul, ngraph::pattern::any_input()}); auto add_left = ngraph::pattern::wrap_type<ngraph::opset8::Add>({matmul, ngraph::pattern::any_input()});
auto add_right = ngraph::pattern::wrap_type<ngraph::opset8::Add>({ngraph::pattern::any_input(), matmul}); auto add_right = ngraph::pattern::wrap_type<ngraph::opset8::Add>({ngraph::pattern::any_input(), matmul});
auto fq_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{matmul, add_left, add_right}); auto fq_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{matmul, add_left, add_right});

View File

@ -249,6 +249,11 @@ TEST(TransformationTests, InsertTransposeAfterMatmulTest) {
{4, 1}, {1, 8}, {2, 16}, false, true, enable_add, matmul_on_left_side, enable_fq), {4, 1}, {1, 8}, {2, 16}, false, true, enable_add, matmul_on_left_side, enable_fq),
handle_transpose_after_matmul::CreateMatmulTransposeFunction( handle_transpose_after_matmul::CreateMatmulTransposeFunction(
{4, 1}, {1, 8}, {2, 16}, true, true, enable_add, matmul_on_left_side, enable_fq)); {4, 1}, {1, 8}, {2, 16}, true, true, enable_add, matmul_on_left_side, enable_fq));
RunTest(
handle_transpose_after_matmul::CreateMatmulFunction(
{1, 256}, {256, 256}, {8, 32}, false, true, enable_add, matmul_on_left_side, enable_fq),
handle_transpose_after_matmul::CreateMatmulFunction(
{1, 256}, {256, 256}, {8, 32}, false, true, enable_add, matmul_on_left_side, enable_fq));
} }
} }
} }