diff --git a/inference-engine/src/gna_plugin/transformations/handle_transposes_around_matmul.cpp b/inference-engine/src/gna_plugin/transformations/handle_transposes_around_matmul.cpp index e0a009a9e92..24dc28efe1a 100644 --- a/inference-engine/src/gna_plugin/transformations/handle_transposes_around_matmul.cpp +++ b/inference-engine/src/gna_plugin/transformations/handle_transposes_around_matmul.cpp @@ -117,7 +117,9 @@ HandleTransposeBeforeMatMul::HandleTransposeBeforeMatMul() { } HandleTransposeAfterMatMul::HandleTransposeAfterMatMul() { - auto matmul = ngraph::pattern::wrap_type(); + auto matmul = ngraph::pattern::wrap_type({}, [](const ngraph::Output& node) { + auto out_shape = node.get_node_shared_ptr()->get_output_shape(0); + return std::count_if(out_shape.begin(), out_shape.end(), [](size_t n) { return n > 1; }) > 1; }); auto add_left = ngraph::pattern::wrap_type({matmul, ngraph::pattern::any_input()}); auto add_right = ngraph::pattern::wrap_type({ngraph::pattern::any_input(), matmul}); auto fq_input = std::make_shared(ngraph::OutputVector{matmul, add_left, add_right}); diff --git a/inference-engine/tests/unit/gna/ngraph/transformations/handle_transposes_around_matmul.cpp b/inference-engine/tests/unit/gna/ngraph/transformations/handle_transposes_around_matmul.cpp index cd35294579c..522ebfc06a1 100644 --- a/inference-engine/tests/unit/gna/ngraph/transformations/handle_transposes_around_matmul.cpp +++ b/inference-engine/tests/unit/gna/ngraph/transformations/handle_transposes_around_matmul.cpp @@ -249,6 +249,11 @@ TEST(TransformationTests, InsertTransposeAfterMatmulTest) { {4, 1}, {1, 8}, {2, 16}, false, true, enable_add, matmul_on_left_side, enable_fq), handle_transpose_after_matmul::CreateMatmulTransposeFunction( {4, 1}, {1, 8}, {2, 16}, true, true, enable_add, matmul_on_left_side, enable_fq)); + RunTest( + handle_transpose_after_matmul::CreateMatmulFunction( + {1, 256}, {256, 256}, {8, 32}, false, true, enable_add, matmul_on_left_side, enable_fq), + handle_transpose_after_matmul::CreateMatmulFunction( + {1, 256}, {256, 256}, {8, 32}, false, true, enable_add, matmul_on_left_side, enable_fq)); } } }