From 8b20ccc6c88232e57ba40a62854e4dfb2f036718 Mon Sep 17 00:00:00 2001 From: Egor Duplensky Date: Wed, 27 Oct 2021 15:07:33 +0300 Subject: [PATCH] [CPU] Fix MatMul node for the case of strided inputs and outputs (#8070) --- .../nodes/mkldnn_matmul_node.cpp | 30 ++---- .../mkldnn_plugin/nodes/mkldnn_matmul_node.h | 3 - .../src/align_mamtul_input_ranks.cpp | 2 +- .../src/matmul_strided_inputs_outputs.cpp | 101 ++++++++++++++++++ 4 files changed, 110 insertions(+), 26 deletions(-) create mode 100644 inference-engine/tests/functional/plugin/cpu/subgraph_tests/src/matmul_strided_inputs_outputs.cpp diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_matmul_node.cpp b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_matmul_node.cpp index 8da70c508fd..4cc3e0d7021 100644 --- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_matmul_node.cpp +++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_matmul_node.cpp @@ -189,15 +189,14 @@ void MKLDNNMatMulNode::getSupportedDescriptors() { return strides; }; - initialInShapes[0] = inputShapes[0]; - initialInShapes[1] = inputShapes[1]; + std::array newShapes{getInputShapeAtPort(0), getInputShapeAtPort(1)}; - const VectorDims inStrides0 = getStridesAndDims(inputShapes[0], transposeIn[0]); - const VectorDims inStrides1 = getStridesAndDims(inputShapes[1], transposeIn[1]); + const VectorDims inStrides0 = getStridesAndDims(newShapes[0], transposeIn[0]); + const VectorDims inStrides1 = getStridesAndDims(newShapes[1], transposeIn[1]); const VectorDims outStrides = getStridesAndDims(outputShapes[0], false); - inDataDesc[0] = std::make_shared(firstInPortPrec, inputShapes[0], inStrides0); - inDataDesc[1] = std::make_shared(secondInPortPrec, inputShapes[1], inStrides1); + inDataDesc[0] = std::make_shared(firstInPortPrec, newShapes[0], inStrides0); + inDataDesc[1] = std::make_shared(secondInPortPrec, newShapes[1], inStrides1); outDataDesc = std::make_shared(outPortPrec, getOutputShapeAtPort(0), outStrides); createDescriptor({inDataDesc[0], inDataDesc[1]}, {outDataDesc}); @@ -229,13 +228,7 @@ void MKLDNNMatMulNode::initSupportedPrimitiveDescriptors() { PortConfig portConfig; portConfig.inPlace = -1; portConfig.constant = false; - - auto src_desc = getSrcMemDesc(itpd, i); - if (src_desc->getType() & MemoryDescType::Blocked) { - portConfig.desc = src_desc->as()->cloneWithUndefStridesAndOffset(); - } else { - portConfig.desc = std::move(src_desc); - } + portConfig.desc = getSrcMemDesc(itpd, i); config.inConfs.push_back(portConfig); } @@ -244,13 +237,7 @@ void MKLDNNMatMulNode::initSupportedPrimitiveDescriptors() { PortConfig portConfig; portConfig.inPlace = canBeInPlace() ? 0 : -1; portConfig.constant = false; - - auto dst_desc = getDstMemDesc(itpd, i); - if (dst_desc->getType() & MemoryDescType::Blocked) { - portConfig.desc = dst_desc->as()->cloneWithUndefStridesAndOffset(); - } else { - portConfig.desc = std::move(dst_desc); - } + portConfig.desc = getDstMemDesc(itpd, i); config.outConfs.push_back(portConfig); } @@ -294,10 +281,9 @@ void MKLDNNMatMulNode::createPrimitive() { MemoryDescPtr MKLDNNMatMulNode::getSrcMemDesc(mkldnn::primitive_desc_iterator &primitive_desc_it, size_t idx) { auto desc = idx > 0 ? primitive_desc_it.weights_desc(idx - 1): primitive_desc_it.src_desc(idx); - return std::make_shared( MKLDNNExtensionUtils::DataTypeToIEPrecision(static_cast(desc.data.data_type)), - initialInShapes[idx]); /* provide initial shapes, so hide transpose effect */ + getInputShapeAtPort(idx)); /* provide initial shapes, so hide transpose effect */ } bool MKLDNNMatMulNode::created() const { diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_matmul_node.h b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_matmul_node.h index 8820f7a4e6e..d23e089dfe3 100644 --- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_matmul_node.h +++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_matmul_node.h @@ -43,9 +43,6 @@ private: /* whether to transpose input */ std::array transposeIn; - /* initial shapes without transpose, - * necessary to hide transpose effect from plugin */ - std::array initialInShapes; std::array inDataDesc; MemoryDescPtr outDataDesc; diff --git a/inference-engine/tests/functional/plugin/cpu/subgraph_tests/src/align_mamtul_input_ranks.cpp b/inference-engine/tests/functional/plugin/cpu/subgraph_tests/src/align_mamtul_input_ranks.cpp index fdffeb99b4b..45b61d95249 100644 --- a/inference-engine/tests/functional/plugin/cpu/subgraph_tests/src/align_mamtul_input_ranks.cpp +++ b/inference-engine/tests/functional/plugin/cpu/subgraph_tests/src/align_mamtul_input_ranks.cpp @@ -64,7 +64,7 @@ protected: int expectedNumOfReshapes = 0; }; -TEST_P(AlignMatMulInputRanksTest, supportedInputShapes) { +TEST_P(AlignMatMulInputRanksTest, CompareWithRefs) { SKIP_IF_CURRENT_TEST_IS_DISABLED() Run(); diff --git a/inference-engine/tests/functional/plugin/cpu/subgraph_tests/src/matmul_strided_inputs_outputs.cpp b/inference-engine/tests/functional/plugin/cpu/subgraph_tests/src/matmul_strided_inputs_outputs.cpp new file mode 100644 index 00000000000..116f4d67dab --- /dev/null +++ b/inference-engine/tests/functional/plugin/cpu/subgraph_tests/src/matmul_strided_inputs_outputs.cpp @@ -0,0 +1,101 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test_utils/cpu_test_utils.hpp" +#include "ngraph_functions/builders.hpp" + +using namespace ngraph; +using namespace InferenceEngine; +using namespace CPUTestUtils; + +namespace SubgraphTestsDefinitions { + +using MatmulStridedInputsOutputsTestParams = Precision; + +class MatmulStridedInputsOutputsTest : public testing::WithParamInterface, + public CPUTestsBase, + virtual public LayerTestsUtils::LayerTestsCommon { +public: + static std::string getTestCaseName(testing::TestParamInfo obj) { + Precision netPrecision; + netPrecision = obj.param; + + std::ostringstream result; + result << "netPRC=" << netPrecision.name() << "_"; + + return result.str(); + } + +protected: + void SetUp() override { + targetDevice = CommonTestUtils::DEVICE_CPU; + Precision netPrecision; + netPrecision = this->GetParam(); + const auto ngPrec = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision); + + SizeVector splitShape{1, 2, 1, 16}; + auto splitInputParams = builder::makeParams(ngPrec, {splitShape}); + const auto splitOutputNodes = helpers::convert2OutputVector(helpers::castOps2Nodes(splitInputParams)); + const auto split = builder::makeSplit(splitOutputNodes[0], ngPrec, 2 /* splits */, 1 /* 2nd axis */); + + std::vector concatShapes{{1, 1, 8, 8}, {1, 1, 8, 8}}; + auto concatInputParams = builder::makeParams(ngPrec, {concatShapes}); + const auto concatOutputNodes = helpers::convert2OutputVector(helpers::castOps2Nodes(concatInputParams)); + const auto concat = builder::makeConcat(concatOutputNodes, 2); + + const auto matMul1 = builder::makeMatMul(split->output(0), concat, false, false); + + SizeVector matmulShape{1, 1, 16, 8}; + auto matmulInputParams = builder::makeParams(ngPrec, {matmulShape}); + const auto matmulOutputNodes = helpers::convert2OutputVector(helpers::castOps2Nodes(matmulInputParams)); + + const auto matMul2 = builder::makeMatMul(split->output(1), matmulOutputNodes[0], false, false); + + const auto concatMatMuls = builder::makeConcat({matMul1, matMul2}, 2 /* 3rd axis */); + + ngraph::ParameterVector inputParams = {splitInputParams[0], concatInputParams[0], concatInputParams[1], matmulInputParams[0]}; + function = makeNgraphFunction(ngPrec, inputParams, concatMatMuls, "MatmulStridedInputsOutputs"); + } +}; + +/* Network with two MatMul nodes and multiple inplace nodes + * Test that MatMul node works correctly with strided inputs / outputs + + Input Input Input + \ / | + \ / | + \ / | + \ / | + Concat Split Input + \ / \ / + \ / \ / + \ / \ / + \ / \ / + MatMul MatMul + \ / + \ / + \ / + \ / + Concat + | + | + Output +*/ + +TEST_P(MatmulStridedInputsOutputsTest, CompareWithRefs) { + SKIP_IF_CURRENT_TEST_IS_DISABLED() + + Run(); +} + +namespace { + +INSTANTIATE_TEST_SUITE_P(smoke_Check, MatmulStridedInputsOutputsTest, + ::testing::Values(Precision::FP32, + Precision::BF16), + MatmulStridedInputsOutputsTest::getTestCaseName); + +} // namespace + +} // namespace SubgraphTestsDefinitions