diff --git a/src/plugins/intel_gna/backend/gna_limitations.cpp b/src/plugins/intel_gna/backend/gna_limitations.cpp
index a532cf65ca6..b50b796486a 100644
--- a/src/plugins/intel_gna/backend/gna_limitations.cpp
+++ b/src/plugins/intel_gna/backend/gna_limitations.cpp
@@ -395,7 +395,7 @@ bool AreLayersSupported(InferenceEngine::CNNNetwork& network, std::string& errMessage) {
                 check_result = false;
             }
         } else if (info.isConcat()) {
-            if (!ValidateConcatAxis(layer, errMessage) && userWarning) {
+            if (userWarning && !ValidateConcatAxis(layer, errMessage)) {
                 std::cout << errMessage;
             }
         }
diff --git a/src/plugins/intel_gna/gna_graph_patterns.hpp b/src/plugins/intel_gna/gna_graph_patterns.hpp
index 440fd805ad3..60e2ad066e4 100644
--- a/src/plugins/intel_gna/gna_graph_patterns.hpp
+++ b/src/plugins/intel_gna/gna_graph_patterns.hpp
@@ -13,7 +13,7 @@ namespace GNAPluginNS {
 
 /**
- * @brief checks if it's a reshape from 4d to 3d tensor inserted after convolution
+ * @brief checks if it's a reshape from 4d to 3d tensor
  * @param layer Non-functional layer
  */
 inline bool IsReshapeFrom4dTo3d(InferenceEngine::CNNLayerPtr layer) {
@@ -24,8 +24,7 @@ inline bool IsReshapeFrom4dTo3d(InferenceEngine::CNNLayerPtr layer) {
     auto input_dims = layer->insData[0].lock()->getDims();
     auto output_dims = layer->outData[0]->getDims();
     // If H input dimension is not 1, it can't be just skipped during reshape to 3d
-    size_t h_dim = input_dims[2];
-    if (input_dims.size() != 4 || output_dims.size() != 3 || h_dim != 1) {
+    if (input_dims.size() != 4 || output_dims.size() != 3 || input_dims[2] != 1) {
         return false;
     }
 
@@ -61,7 +60,7 @@ inline bool IsReshapeFrom3dTo4d(InferenceEngine::CNNLayerPtr layer) {
 }
 
 /**
- * @brief searchs for a pattern: Permute(NHWC->NCHW) -> ... -> Convolution -> ... -> Permute(NCHW->NHWC) or
+ * @brief searches for a pattern: Permute(NHWC->NCHW) -> ... -> Convolution -> ... -> Permute(NCHW->NHWC) or
  * Reshape(NHWC->NCHW) -> ... -> Convolution -> ... -> Reshape(NCHW->NHWC) if Convolution has only one input/output
  * dimension not equal to 1,
  * if the original convolution layout is 3d, 3d->4d/4d->3d reshapes will be inserted before/after the convolution,
diff --git a/src/plugins/intel_gna/gna_plugin_config.cpp b/src/plugins/intel_gna/gna_plugin_config.cpp
index a65ea64f33a..09a966d7bce 100644
--- a/src/plugins/intel_gna/gna_plugin_config.cpp
+++ b/src/plugins/intel_gna/gna_plugin_config.cpp
@@ -208,7 +208,7 @@ void Config::UpdateFromMap(const std::map& config) {
         if (value == PluginConfigParams::LOG_WARNING || value == PluginConfigParams::LOG_NONE) {
             gnaFlags.log_level = value;
         } else {
-            log << "Currently only LOG_LEVEL = LOG_WARNING or LOG_NONE is supported, not " << value;
+            log << "Currently only LOG_LEVEL = LOG_WARNING and LOG_NONE are supported, not " << value;
             THROW_GNA_EXCEPTION << "Currently only LOG_LEVEL = LOG_WARNING and LOG_NONE are supported, not " << value;
         }
     } else {
diff --git a/src/plugins/intel_gna/transformations/handle_transposes_around_matmul.cpp b/src/plugins/intel_gna/transformations/handle_transposes_around_matmul.cpp
index f00a5b4dcf7..91c98aaa237 100644
--- a/src/plugins/intel_gna/transformations/handle_transposes_around_matmul.cpp
+++ b/src/plugins/intel_gna/transformations/handle_transposes_around_matmul.cpp
@@ -88,21 +88,36 @@ bool VerifyReshape(const ngraph::Output& reshape_out) {
     return in_shape[0] != out_shape[0];
 }
 
+bool VerifyConcat(const ngraph::Output& node) {
+    auto concat_node = std::dynamic_pointer_cast(node.get_node_shared_ptr());
+    return (concat_node->get_axis() == 0);
+}
+
 } // namespace
 
 HandleTransposeBeforeMatMul::HandleTransposeBeforeMatMul() {
+    auto concat1 = ngraph::pattern::wrap_type(VerifyConcat);
+    auto reshape1 = ngraph::pattern::wrap_type(VerifyReshape);
+    auto transpose_input1 = std::make_shared(ngraph::OutputVector{concat1, reshape1});
+    auto transpose1 = ngraph::pattern::wrap_type({transpose_input1, ngraph::pattern::any_input()});
+
+    auto concat2 = ngraph::pattern::wrap_type(VerifyConcat);
+    auto reshape2 = ngraph::pattern::wrap_type(VerifyReshape);
+    auto transpose_input2 = std::make_shared(ngraph::OutputVector{concat2, reshape2});
+    auto transpose2 = ngraph::pattern::wrap_type({transpose_input2, ngraph::pattern::any_input()});
+
     auto constant = ngraph::pattern::wrap_type();
     auto fq = ngraph::pattern::wrap_type({constant, ngraph::pattern::any_input(),
         ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()});
-    auto reshape = ngraph::pattern::wrap_type({}, VerifyReshape);
-    auto transpose = ngraph::pattern::wrap_type({reshape,
-        ngraph::pattern::any_input()});
+
     auto matmul1 = ngraph::pattern::wrap_type({
-        std::make_shared(ngraph::OutputVector{reshape, transpose}),
-        ngraph::pattern::any_input()});
+        std::make_shared(ngraph::OutputVector{reshape1, concat1, transpose1, constant, fq, ngraph::pattern::any_input()}),
+        std::make_shared(ngraph::OutputVector{reshape2, concat2, transpose2})});
+
     auto matmul2 = ngraph::pattern::wrap_type({
-        std::make_shared(ngraph::OutputVector{constant, fq}),
-        std::make_shared(ngraph::OutputVector{reshape, transpose, ngraph::pattern::any_input()})});
+        std::make_shared(ngraph::OutputVector{reshape1, concat1, transpose1, constant, fq}),
+        ngraph::pattern::any_input()});
+
     auto matmul = std::make_shared(ngraph::OutputVector{matmul1, matmul2});
 
     ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &matcher) {
@@ -114,17 +129,25 @@ HandleTransposeBeforeMatMul::HandleTransposeBeforeMatMul() {
         }
 
         auto matmul_node = matmul_iter->second.get_node_shared_ptr();
-        auto transpose_reshape_it = pattern_map.find(transpose);
+        auto transpose_reshape_it = pattern_map.find(transpose1);
         if (transpose_reshape_it != std::end(pattern_map)) {
             ReplaceTransposeWithReshape(transpose_reshape_it->second.get_node_shared_ptr());
-        } else if ((transpose_reshape_it = pattern_map.find(reshape)) != std::end(pattern_map)) {
-            auto reshape_node = pattern_map.at(reshape).get_node_shared_ptr();
-            if (GNALimitations::IsTransposeSupported(reshape_node->get_output_shape(0))) {
-                InsertTranspose(reshape_node, matmul_node->get_friendly_name(), true);
+        } else {
+            std::shared_ptr prev_node = nullptr;
+            if ((transpose_reshape_it = pattern_map.find(reshape1)) != std::end(pattern_map)) {
+                prev_node = pattern_map.at(reshape1).get_node_shared_ptr();
+            } else if ((transpose_reshape_it = pattern_map.find(concat1)) != std::end(pattern_map)) {
+                prev_node = pattern_map.at(concat1).get_node_shared_ptr();
+            }
+
+            if (prev_node) {
+                if (GNALimitations::IsTransposeSupported(prev_node->get_output_shape(0))) {
+                    InsertTranspose(prev_node, matmul_node->get_friendly_name(), true);
+                }
             }
         }
 
-        // Transpose the constant input if it's the first input
+        // Transpose the first input if it's a constant
         auto iter = pattern_map.find(fq);
         if (iter != pattern_map.end() ||
             (iter = pattern_map.find(constant)) != pattern_map.end()) {
@@ -133,6 +156,24 @@ HandleTransposeBeforeMatMul::HandleTransposeBeforeMatMul() {
                 InsertTranspose(prev_node, prev_node->get_friendly_name(), true);
             }
         }
+
+        transpose_reshape_it = pattern_map.find(transpose2);
+        if (transpose_reshape_it != std::end(pattern_map)) {
+            ReplaceTransposeWithReshape(transpose_reshape_it->second.get_node_shared_ptr());
+        } else {
+            std::shared_ptr prev_node = nullptr;
+            if ((transpose_reshape_it = pattern_map.find(reshape2)) != std::end(pattern_map)) {
+                prev_node = pattern_map.at(reshape2).get_node_shared_ptr();
+            } else if ((transpose_reshape_it = pattern_map.find(concat2)) != std::end(pattern_map)) {
+                prev_node = pattern_map.at(concat2).get_node_shared_ptr();
+            }
+
+            if (prev_node) {
+                if (GNALimitations::IsTransposeSupported(prev_node->get_output_shape(0))) {
+                    InsertTranspose(prev_node, matmul_node->get_friendly_name(), true);
+                }
+            }
+        }
         return true;
     };
diff --git a/src/plugins/intel_gna/transformations/handle_transposes_around_matmul.hpp b/src/plugins/intel_gna/transformations/handle_transposes_around_matmul.hpp
index 9594f5eb21f..ba5702da1ba 100644
--- a/src/plugins/intel_gna/transformations/handle_transposes_around_matmul.hpp
+++ b/src/plugins/intel_gna/transformations/handle_transposes_around_matmul.hpp
@@ -9,22 +9,23 @@ namespace GNAPluginNS {
 
 /**
- * @brief Inserts Transpose before MatMul or removes it (if it exists) if there is Reshape
- * before MatMul which changes the batch size:
- *      [1, A*B]               [1, A*B]
- *         |                      |
- *      Reshape                Reshape
- *         |                      |
- *       [A, B]                 [A, B]
- *         |                      |
- *         |                   Transpose
- *         |          ->          |
- *         |          <-        [B, A]
- *         |                      |
- *         |                   Reshape
- *         |                    [A, B]
- *         |                      |
- *      MatMul                 MatMul
+ * @brief Inserts Transpose before MatMul or removes it (if it exists)
+ * if there is Reshape/Concat layer before MatMul which changes the batch size,
+ * or transpose the input if there's a Constant/FQ layer on the first input:
+ *      [1, A*B]               [1, A*B]
+ *         |                      |
+ *  Reshape / Concat      Reshape / Concat
+ *         |                      |
+ *       [A, B]                 [A, B]
+ *         |                      |
+ *         |                   Transpose
+ *         |          ->          |
+ *         |          <-        [B, A]
+ *         |                      |
+ *         |                   Reshape
+ *         |                    [A, B]
+ *         |                      |
+ *      MatMul                 MatMul
  */
 class HandleTransposeBeforeMatMul : public ngraph::pass::MatcherPass {
 public:
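
[Editor's illustration, not part of the patch] The hunks above teach HandleTransposeBeforeMatMul to treat a Concat over axis 0 like a batch-changing Reshape: when GNALimitations::IsTransposeSupported accepts the Concat output shape, a Transpose is inserted between the Concat and the MatMul, and an already existing Transpose is folded into a Reshape instead. The sketch below mirrors the shapes used by the new unit tests further down; the include paths and the opset namespace are assumptions, so treat it as a sketch rather than plugin code.

```cpp
// Sketch only: builds Concat(axis 0) -> MatMul, runs the pass, and checks whether a Transpose appeared.
#include <numeric>
#include <vector>

#include <ngraph/function.hpp>
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/pass/manager.hpp>

#include "transformations/handle_transposes_around_matmul.hpp"  // assumed include path

std::shared_ptr<ngraph::Function> BuildConcatMatMul() {
    // First MatMul input: a [4, 16] parameter concatenated with a [4, 16] constant along axis 0 -> [8, 16]
    auto input = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, ngraph::Shape{4, 16});
    std::vector<int64_t> data(4 * 16);
    std::iota(data.begin(), data.end(), 1);
    auto concat_const = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{4, 16}, data);
    auto concat = std::make_shared<ngraph::opset7::Concat>(ngraph::OutputVector{input, concat_const}, 0);

    // Second MatMul input: a [16, 8] constant, so MatMul([8, 16], [16, 8]) -> [8, 8]
    std::vector<int64_t> weights(16 * 8);
    std::iota(weights.begin(), weights.end(), 1);
    auto weights_const = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{16, 8}, weights);

    auto matmul = std::make_shared<ngraph::opset7::MatMul>(concat, weights_const);
    auto result = std::make_shared<ngraph::opset7::Result>(matmul);
    return std::make_shared<ngraph::Function>(ngraph::ResultVector{result}, ngraph::ParameterVector{input});
}

size_t CountTransposes(const std::shared_ptr<ngraph::Function>& func) {
    size_t count = 0;
    for (const auto& node : func->get_ordered_ops()) {
        if (std::dynamic_pointer_cast<ngraph::opset7::Transpose>(node)) {
            ++count;
        }
    }
    return count;
}

int main() {
    auto func = BuildConcatMatMul();
    // Run only this matcher pass; inside the plugin it is part of a larger pipeline.
    ngraph::pass::Manager manager;
    manager.register_pass<GNAPluginNS::HandleTransposeBeforeMatMul>();
    manager.run_passes(func);
    // With a supported Concat output shape ({8, 16}) a Transpose is expected to appear before the MatMul.
    return CountTransposes(func) > 0 ? 0 : 1;
}
```
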
diff --git a/src/plugins/intel_gna/transformations/swap_input_matmul_gna.cpp b/src/plugins/intel_gna/transformations/swap_input_matmul_gna.cpp
index 2d9ccff8041..487575a0c24 100644
--- a/src/plugins/intel_gna/transformations/swap_input_matmul_gna.cpp
+++ b/src/plugins/intel_gna/transformations/swap_input_matmul_gna.cpp
@@ -66,17 +66,25 @@ static void SwapAndTransposeInputs(
     gnalog() << "Swap and transpose inputs for " << matmul_node->get_friendly_name() << "\n";
 
     bool first_input_const = false;
+    bool second_input_const = false;
     auto first_input = matmul_node->input_value(0).get_node_shared_ptr();
+    auto second_input = matmul_node->input_value(1).get_node_shared_ptr();
     if (std::dynamic_pointer_cast(first_input)) {
         first_input = first_input->input_value(0).get_node_shared_ptr();
     }
+    if (std::dynamic_pointer_cast(second_input)) {
+        second_input = second_input->input_value(0).get_node_shared_ptr();
+    }
     if (std::dynamic_pointer_cast(first_input)) {
         first_input_const = true;
     }
+    if (std::dynamic_pointer_cast(second_input)) {
+        second_input_const = true;
+    }
 
-    auto input1 = first_input_const ? transpose_matmul_input(1) : matmul_node->input_value(1);
+    auto input1 = (!first_input_const && second_input_const) ? matmul_node->input_value(1) : transpose_matmul_input(1);
     auto input2 = first_input_const ? matmul_node->input_value(0) : transpose_matmul_input(0);
-    bool transpose_1 = first_input_const ? matmul_node->get_transpose_b() : !matmul_node->get_transpose_b();
+    bool transpose_1 = (!first_input_const && second_input_const) ? !matmul_node->get_transpose_b() : matmul_node->get_transpose_b();
     bool transpose_2 = first_input_const ? !matmul_node->get_transpose_a() : matmul_node->get_transpose_a();
     std::shared_ptr new_node = std::make_shared(input1, input2, transpose_1, transpose_2);
     new_node->set_friendly_name(matmul_node->get_friendly_name() + "/swap_inputs");
diff --git a/src/tests/functional/plugin/gna/pass_tests/decompose_2d_conv.cpp b/src/tests/functional/plugin/gna/pass_tests/decompose_2d_conv.cpp
index a3d608fa8c4..13d10745cf7 100644
--- a/src/tests/functional/plugin/gna/pass_tests/decompose_2d_conv.cpp
+++ b/src/tests/functional/plugin/gna/pass_tests/decompose_2d_conv.cpp
@@ -313,6 +313,11 @@ const std::vector> configsStrides = {
     {
         {"GNA_DEVICE_MODE", "GNA_SW_FP32"},
         {"GNA_SCALE_FACTOR_0", "1"},
        {"GNA_EXEC_TARGET", "GNA_TARGET_2_0"}
+    },
+    {
+        {"GNA_DEVICE_MODE", "GNA_SW_EXACT"},
+        {"GNA_SCALE_FACTOR_0", "1"},
+        {"GNA_EXEC_TARGET", "GNA_TARGET_2_0"}
     }
 };
diff --git a/src/tests/functional/plugin/gna/pass_tests/insert_transpose_before_matmul.cpp b/src/tests/functional/plugin/gna/pass_tests/insert_transpose_before_matmul.cpp
index 05cfb2fbd23..948ade3e98b 100644
--- a/src/tests/functional/plugin/gna/pass_tests/insert_transpose_before_matmul.cpp
+++ b/src/tests/functional/plugin/gna/pass_tests/insert_transpose_before_matmul.cpp
@@ -129,4 +129,91 @@ INSTANTIATE_TEST_SUITE_P(smoke_InsertTransposeBeforeMatmulTest, InsertTransposeBeforeMatmul,
                             ::testing::ValuesIn(firstInputConst)),
                         InsertTransposeBeforeMatmul::getTestCaseName);
 
+/* Case with two inputs with concat instead of reshape */
+
+class InsertTransposeBeforeConcatConcat : public testing::WithParamInterface,
+                                          public LayerTestsUtils::LayerTestsCommon {
+public:
+    static std::string getTestCaseName(testing::TestParamInfo obj) {
+        InferenceEngine::Precision netPrecision;
+        std::string targetDevice;
+        std::map configuration;
+        size_t inputShape;
+        bool firstInConst;
+        std::tie(netPrecision, targetDevice, configuration, inputShape, firstInConst) = obj.param;
+
+        std::ostringstream result;
+        result << "netPRC=" << netPrecision.name() << "_";
+        result << "targetDevice=" << targetDevice << "_";
+        for (auto const& configItem : configuration) {
+            result << "_configItem=" << configItem.first << "_" << configItem.second;
+        }
+        result << "_inputShape=" << inputShape;
+        result << "_firstInConst=" << firstInConst;
+        return result.str();
+    }
+
+    InferenceEngine::Blob::Ptr GenerateInput(const InferenceEngine::InputInfo& info) const override {
+        InferenceEngine::Blob::Ptr blob = make_blob_with_precision(info.getTensorDesc());
+        blob->allocate();
+
+        auto* rawBlobDataPtr = blob->buffer().as();
+        std::vector values = CommonTestUtils::generate_float_numbers(blob->size(), -0.2f, 0.2f);
+        for (size_t i = 0; i < blob->size(); i++) {
+            rawBlobDataPtr[i] = values[i];
+        }
+        return blob;
+    }
+
+protected:
+    void SetUp() override {
+        InferenceEngine::Precision netPrecision;
+        size_t inputShape;
+        bool firstInConst;
+        std::tie(netPrecision, targetDevice, configuration, inputShape, firstInConst) = this->GetParam();
+        auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
+
+        auto params = ngraph::builder::makeParams(ngPrc, {{1, inputShape}});
+        auto matmul_in_shape = ngraph::Shape{inputShape / 8, 8};
+        auto pattern = std::make_shared(ngraph::element::Type_t::i64, ngraph::Shape{ 2 }, matmul_in_shape);
+        auto reshape = std::make_shared(params[0], pattern, false);
+
+        std::vector data = CommonTestUtils::generate_float_numbers(ngraph::shape_size(matmul_in_shape), -0.2f, 0.2f);
+        auto concat_const = std::make_shared(ngPrc, matmul_in_shape, data);
+        ngraph::OutputVector concat_chunks{reshape, concat_const};
+        auto concat = std::make_shared(concat_chunks, 0);
+
+        std::shared_ptr weights_node;
+        std::vector weights = CommonTestUtils::generate_float_numbers(matmul_in_shape[0] * 2, -0.2f, 0.2f);
+        weights_node = std::make_shared(ngPrc, ngraph::Shape{ 1, matmul_in_shape[0] * 2 }, weights);
+
+        auto matmul = firstInConst ? ngraph::builder::makeMatMul(weights_node, concat, false, false) :
+                                     ngraph::builder::makeMatMul(concat, weights_node, false, false);
+
+        ngraph::ResultVector results{ std::make_shared(matmul)};
+        function = std::make_shared(results, params, "InsertTransposeBeforeConcatConcat");
+    }
+};
+
+TEST_P(InsertTransposeBeforeConcatConcat, CompareWithRefImpl) {
+    Run();
+};
+
+const std::vector concatInputShape = {
+    64,
+    96,
+    128
+};
+
+const std::vector firstInputConstConcat = {true};
+
+INSTANTIATE_TEST_SUITE_P(smoke_InsertTransposeBeforeMatmulConcat, InsertTransposeBeforeConcatConcat,
+                         ::testing::Combine(
+                             ::testing::ValuesIn(netPrecisions),
+                             ::testing::Values(CommonTestUtils::DEVICE_GNA),
+                             ::testing::ValuesIn(configs),
+                             ::testing::ValuesIn(concatInputShape),
+                             ::testing::ValuesIn(firstInputConstConcat)),
+                         InsertTransposeBeforeConcatConcat::getTestCaseName);
+
 } // namespace LayerTestsDefinitions
diff --git a/src/tests/unit/gna/ngraph/transformations/handle_transposes_around_matmul.cpp b/src/tests/unit/gna/ngraph/transformations/gna_handle_transposes_around_matmul.cpp
similarity index 74%
rename from src/tests/unit/gna/ngraph/transformations/handle_transposes_around_matmul.cpp
rename to src/tests/unit/gna/ngraph/transformations/gna_handle_transposes_around_matmul.cpp
index 20d9d46421e..17eaaee67b9 100644
--- a/src/tests/unit/gna/ngraph/transformations/handle_transposes_around_matmul.cpp
+++ b/src/tests/unit/gna/ngraph/transformations/gna_handle_transposes_around_matmul.cpp
@@ -66,6 +66,78 @@ std::shared_ptr CreateMatmulFunction(const ngraph::Shape& inpu
     return std::make_shared(ngraph::ResultVector{result}, ngraph::ParameterVector{input_params});
 }
 
+std::shared_ptr CreateConcatTransposeMatmulFunction(const ngraph::Shape& input1_shape, const ngraph::Shape& input2_shape,
+    const ngraph::Shape& reshape1_shape, const ngraph::Shape& reshape2_shape, bool create_reshape_after_transpose) {
+    auto transpose_order = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {1, 0});
+
+    auto input1_params = std::make_shared(ngraph::element::i64, input1_shape);
+    std::vector data1(ngraph::shape_size(input1_shape));
+    std::iota(std::begin(data1), std::end(data1), 1);
+    auto concat1_const = ngraph::opset7::Constant::create(ngraph::element::i64, input1_shape, data1);
+    ngraph::OutputVector concat1_chunks{input1_params, concat1_const};
+    auto concat1 = std::make_shared(concat1_chunks, 0);
+    auto transpose1 = std::make_shared(concat1, transpose_order);
+
+    auto input2_params = std::make_shared(ngraph::element::i64, input2_shape);
+    std::vector data2(ngraph::shape_size(input2_shape));
+    std::iota(std::begin(data2), std::end(data2), 1);
+    auto concat2_const = ngraph::opset7::Constant::create(ngraph::element::i64, input2_shape, data2);
+    ngraph::OutputVector concat2_chunks{input2_params, concat2_const};
+    auto concat2 = std::make_shared(concat2_chunks, 0);
+    auto transpose2 = std::make_shared(concat2, transpose_order);
+
+    std::shared_ptr matmul;
+
+    if (create_reshape_after_transpose) {
+        auto reshape_after_transpose1_const = ngraph::opset7::Constant::create(ngraph::element::i64,
+            ngraph::Shape{reshape1_shape.size()}, reshape1_shape);
+        auto reshape_after_transpose1 = std::make_shared(transpose1, reshape_after_transpose1_const, false);
+        auto reshape_after_transpose2_const = ngraph::opset7::Constant::create(ngraph::element::i64,
+            ngraph::Shape{reshape2_shape.size()}, reshape2_shape);
+        auto reshape_after_transpose2 = std::make_shared(transpose2, reshape_after_transpose2_const, false);
+        matmul = std::make_shared(reshape_after_transpose1, reshape_after_transpose2);
+    } else {
+        matmul = std::make_shared(transpose1, transpose2);
+    }
+
+    auto result = std::make_shared(matmul);
+    return std::make_shared(ngraph::ResultVector{result}, ngraph::ParameterVector{input1_params, input2_params});
+}
+
+std::shared_ptr CreateConcatMatmulFunction(const ngraph::Shape& input1_shape, const ngraph::Shape& input2_shape,
+    const ngraph::Shape& reshape1_shape, const ngraph::Shape& reshape2_shape, bool create_reshape_instead_of_transpose) {
+    auto input1_params = std::make_shared(ngraph::element::i64, input1_shape);
+    std::vector data1(ngraph::shape_size(input1_shape));
+    std::iota(std::begin(data1), std::end(data1), 1);
+    auto concat1_const = ngraph::opset7::Constant::create(ngraph::element::i64, input1_shape, data1);
+    ngraph::OutputVector concat1_chunks{input1_params, concat1_const};
+    auto concat1 = std::make_shared(concat1_chunks, 0);
+
+    auto input2_params = std::make_shared(ngraph::element::i64, input2_shape);
+    std::vector data2(ngraph::shape_size(input2_shape));
+    std::iota(std::begin(data2), std::end(data2), 1);
+    auto concat2_const = ngraph::opset7::Constant::create(ngraph::element::i64, input2_shape, data2);
+    ngraph::OutputVector concat2_chunks{input2_params, concat2_const};
+    auto concat2 = std::make_shared(concat2_chunks, 0);
+
+    std::shared_ptr matmul;
+
+    if (create_reshape_instead_of_transpose) {
+        auto new_shape_after_transpose1 = ngraph::opset7::Constant::create(ngraph::element::i64,
+            ngraph::Shape{reshape1_shape.size()}, {reshape1_shape[1], reshape1_shape[0]});
+        auto reshape1 = std::make_shared(concat1, new_shape_after_transpose1, false);
+        auto new_shape_after_transpose2 = ngraph::opset7::Constant::create(ngraph::element::i64,
+            ngraph::Shape{reshape2_shape.size()}, {reshape2_shape[1], reshape2_shape[0]});
+        auto reshape2 = std::make_shared(concat2, new_shape_after_transpose2, false);
+        matmul = std::make_shared(reshape1, reshape2);
+    } else {
+        matmul = std::make_shared(concat1, concat2);
+    }
+
+    auto result = std::make_shared(matmul);
+    return std::make_shared(ngraph::ResultVector{result}, ngraph::ParameterVector{input1_params, input2_params});
+}
+
 } // namespace handle_transpose_before_matmul
 
 namespace handle_transpose_after_matmul {
@@ -235,6 +307,9 @@ TEST(TransformationTests, InsertTransposeBeforeMatmulTest) {
     RunTest(
         handle_transpose_before_matmul::CreateMatmulFunction({1, 2, 8}, {8, 2}, {2, 1}, false),
         handle_transpose_before_matmul::CreateTransposeMatmulFunction({1, 2, 8}, {8, 2}, {2, 1}, true));
+    RunTest(
+        handle_transpose_before_matmul::CreateConcatMatmulFunction({4, 16}, {8, 8}, {8, 16}, {16, 8}, false),
+        handle_transpose_before_matmul::CreateConcatTransposeMatmulFunction({4, 16}, {8, 8}, {8, 16}, {16, 8}, true));
 }
 
 TEST(TransformationTests, InsertTransposeBeforeMatmulTestReshapeInOutEq) {
@@ -244,12 +319,18 @@ TEST(TransformationTests, InsertTransposeBeforeMatmulTestReshapeInOutEq) {
     RunTest(
         handle_transpose_before_matmul::CreateMatmulFunction({9, 2}, {9, 2}, {2, 1}, false),
         handle_transpose_before_matmul::CreateMatmulFunction({9, 2}, {9, 2}, {2, 1}, false));
+    RunTest(
+        handle_transpose_before_matmul::CreateConcatMatmulFunction({8, 16}, {8, 16}, {16, 16}, {16, 16}, false),
+        handle_transpose_before_matmul::CreateConcatMatmulFunction({8, 16}, {8, 16}, {16, 16}, {16, 16}, false));
 }
 
 TEST(TransformationTests, RemoveTransposeBeforeMatmulTest) {
     RunTest(
         handle_transpose_before_matmul::CreateTransposeMatmulFunction({1, 8}, {2, 4}, {2, 1}, false),
         handle_transpose_before_matmul::CreateMatmulFunction({1, 8}, {2, 4}, {2, 1}, true));
+    RunTest(
+        handle_transpose_before_matmul::CreateConcatTransposeMatmulFunction({4, 16}, {8, 8}, {8, 16}, {16, 8}, false),
+        handle_transpose_before_matmul::CreateConcatMatmulFunction({4, 16}, {8, 8}, {8, 16}, {16, 8}, true));
 }
 
 TEST(TransformationTests, RemoveTransposeBeforeMatmulTestReshapeInOutEq) {
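
[Editor's note, not part of the patch] The swap_input_matmul_gna.cpp hunk above chooses which MatMul input to transpose based on which side is constant; the justification is the identity (A x B)^T = B^T x A^T, so swapping the inputs while flipping the corresponding transpose flags preserves the product up to a final transpose. A minimal standalone check of that identity, independent of any OpenVINO API:

```cpp
// Standalone illustration: verifies (A x B)^T == B^T x A^T on a tiny example.
#include <array>
#include <cassert>

int main() {
    // A is 2x3, B is 3x2, so C = A x B is 2x2.
    const std::array<std::array<int, 3>, 2> A{{{1, 2, 3}, {4, 5, 6}}};
    const std::array<std::array<int, 2>, 3> B{{{7, 8}, {9, 10}, {11, 12}}};

    // C = A x B
    std::array<std::array<int, 2>, 2> C{};
    for (int i = 0; i < 2; ++i)
        for (int j = 0; j < 2; ++j)
            for (int k = 0; k < 3; ++k)
                C[i][j] += A[i][k] * B[k][j];

    // D = B^T x A^T, computed directly from the untransposed arrays.
    std::array<std::array<int, 2>, 2> D{};
    for (int i = 0; i < 2; ++i)
        for (int j = 0; j < 2; ++j)
            for (int k = 0; k < 3; ++k)
                D[i][j] += B[k][i] * A[j][k];

    // D must equal C transposed.
    for (int i = 0; i < 2; ++i)
        for (int j = 0; j < 2; ++j)
            assert(D[i][j] == C[j][i]);
    return 0;
}
```
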