[CPU] Fix MatMul node for the case of strided inputs and outputs (#8070)

This commit is contained in:
Egor Duplensky 2021-10-27 15:07:33 +03:00 committed by GitHub
parent ce9a968030
commit 8b20ccc6c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 110 additions and 26 deletions

View File

@ -189,15 +189,14 @@ void MKLDNNMatMulNode::getSupportedDescriptors() {
return strides; return strides;
}; };
initialInShapes[0] = inputShapes[0]; std::array<Shape, 2> newShapes{getInputShapeAtPort(0), getInputShapeAtPort(1)};
initialInShapes[1] = inputShapes[1];
const VectorDims inStrides0 = getStridesAndDims(inputShapes[0], transposeIn[0]); const VectorDims inStrides0 = getStridesAndDims(newShapes[0], transposeIn[0]);
const VectorDims inStrides1 = getStridesAndDims(inputShapes[1], transposeIn[1]); const VectorDims inStrides1 = getStridesAndDims(newShapes[1], transposeIn[1]);
const VectorDims outStrides = getStridesAndDims(outputShapes[0], false); const VectorDims outStrides = getStridesAndDims(outputShapes[0], false);
inDataDesc[0] = std::make_shared<DnnlBlockedMemoryDesc>(firstInPortPrec, inputShapes[0], inStrides0); inDataDesc[0] = std::make_shared<DnnlBlockedMemoryDesc>(firstInPortPrec, newShapes[0], inStrides0);
inDataDesc[1] = std::make_shared<DnnlBlockedMemoryDesc>(secondInPortPrec, inputShapes[1], inStrides1); inDataDesc[1] = std::make_shared<DnnlBlockedMemoryDesc>(secondInPortPrec, newShapes[1], inStrides1);
outDataDesc = std::make_shared<DnnlBlockedMemoryDesc>(outPortPrec, getOutputShapeAtPort(0), outStrides); outDataDesc = std::make_shared<DnnlBlockedMemoryDesc>(outPortPrec, getOutputShapeAtPort(0), outStrides);
createDescriptor({inDataDesc[0], inDataDesc[1]}, {outDataDesc}); createDescriptor({inDataDesc[0], inDataDesc[1]}, {outDataDesc});
@ -229,13 +228,7 @@ void MKLDNNMatMulNode::initSupportedPrimitiveDescriptors() {
PortConfig portConfig; PortConfig portConfig;
portConfig.inPlace = -1; portConfig.inPlace = -1;
portConfig.constant = false; portConfig.constant = false;
portConfig.desc = getSrcMemDesc(itpd, i);
auto src_desc = getSrcMemDesc(itpd, i);
if (src_desc->getType() & MemoryDescType::Blocked) {
portConfig.desc = src_desc->as<BlockedMemoryDesc>()->cloneWithUndefStridesAndOffset();
} else {
portConfig.desc = std::move(src_desc);
}
config.inConfs.push_back(portConfig); config.inConfs.push_back(portConfig);
} }
@ -244,13 +237,7 @@ void MKLDNNMatMulNode::initSupportedPrimitiveDescriptors() {
PortConfig portConfig; PortConfig portConfig;
portConfig.inPlace = canBeInPlace() ? 0 : -1; portConfig.inPlace = canBeInPlace() ? 0 : -1;
portConfig.constant = false; portConfig.constant = false;
portConfig.desc = getDstMemDesc(itpd, i);
auto dst_desc = getDstMemDesc(itpd, i);
if (dst_desc->getType() & MemoryDescType::Blocked) {
portConfig.desc = dst_desc->as<BlockedMemoryDesc>()->cloneWithUndefStridesAndOffset();
} else {
portConfig.desc = std::move(dst_desc);
}
config.outConfs.push_back(portConfig); config.outConfs.push_back(portConfig);
} }
@ -294,10 +281,9 @@ void MKLDNNMatMulNode::createPrimitive() {
MemoryDescPtr MKLDNNMatMulNode::getSrcMemDesc(mkldnn::primitive_desc_iterator &primitive_desc_it, size_t idx) { MemoryDescPtr MKLDNNMatMulNode::getSrcMemDesc(mkldnn::primitive_desc_iterator &primitive_desc_it, size_t idx) {
auto desc = idx > 0 ? primitive_desc_it.weights_desc(idx - 1): primitive_desc_it.src_desc(idx); auto desc = idx > 0 ? primitive_desc_it.weights_desc(idx - 1): primitive_desc_it.src_desc(idx);
return std::make_shared<CpuBlockedMemoryDesc>( return std::make_shared<CpuBlockedMemoryDesc>(
MKLDNNExtensionUtils::DataTypeToIEPrecision(static_cast<mkldnn::memory::data_type>(desc.data.data_type)), MKLDNNExtensionUtils::DataTypeToIEPrecision(static_cast<mkldnn::memory::data_type>(desc.data.data_type)),
initialInShapes[idx]); /* provide initial shapes, so hide transpose effect */ getInputShapeAtPort(idx)); /* provide initial shapes, so hide transpose effect */
} }
bool MKLDNNMatMulNode::created() const { bool MKLDNNMatMulNode::created() const {

View File

@ -43,9 +43,6 @@ private:
/* whether to transpose input */ /* whether to transpose input */
std::array<bool, 2> transposeIn; std::array<bool, 2> transposeIn;
/* initial shapes without transpose,
* necessary to hide transpose effect from plugin */
std::array<Shape, 2> initialInShapes;
std::array<MemoryDescPtr, 2> inDataDesc; std::array<MemoryDescPtr, 2> inDataDesc;
MemoryDescPtr outDataDesc; MemoryDescPtr outDataDesc;

View File

@ -64,7 +64,7 @@ protected:
int expectedNumOfReshapes = 0; int expectedNumOfReshapes = 0;
}; };
TEST_P(AlignMatMulInputRanksTest, supportedInputShapes) { TEST_P(AlignMatMulInputRanksTest, CompareWithRefs) {
SKIP_IF_CURRENT_TEST_IS_DISABLED() SKIP_IF_CURRENT_TEST_IS_DISABLED()
Run(); Run();

View File

@ -0,0 +1,101 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "test_utils/cpu_test_utils.hpp"
#include "ngraph_functions/builders.hpp"
using namespace ngraph;
using namespace InferenceEngine;
using namespace CPUTestUtils;
namespace SubgraphTestsDefinitions {
using MatmulStridedInputsOutputsTestParams = Precision;
// Subgraph test fixture: builds a network in which both MatMul inputs and the
// MatMul outputs are connected to Split/Concat nodes, so the CPU plugin may
// execute MatMul on strided (non-dense) memory. Parameterized only by the
// network precision (FP32 / BF16).
class MatmulStridedInputsOutputsTest : public testing::WithParamInterface<MatmulStridedInputsOutputsTestParams>,
public CPUTestsBase,
virtual public LayerTestsUtils::LayerTestsCommon {
public:
// Encodes the network precision into the generated test-case name.
static std::string getTestCaseName(testing::TestParamInfo<MatmulStridedInputsOutputsTestParams> obj) {
Precision netPrecision;
netPrecision = obj.param;
std::ostringstream result;
result << "netPRC=" << netPrecision.name() << "_";
return result.str();
}
protected:
// Builds the tested topology (see the ASCII diagram below the class):
// Split feeds one input of each MatMul; one MatMul consumes a Concat
// output, and both MatMul results are merged by a final Concat.
void SetUp() override {
targetDevice = CommonTestUtils::DEVICE_CPU;
Precision netPrecision;
netPrecision = this->GetParam();
const auto ngPrec = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
// Split input {1,2,1,16} into two {1,1,1,16} chunks along axis 1.
SizeVector splitShape{1, 2, 1, 16};
auto splitInputParams = builder::makeParams(ngPrec, {splitShape});
const auto splitOutputNodes = helpers::convert2OutputVector(helpers::castOps2Nodes<op::Parameter>(splitInputParams));
const auto split = builder::makeSplit(splitOutputNodes[0], ngPrec, 2 /* splits */, 1 /* 2nd axis */);
// Concat two {1,1,8,8} parameters along axis 2 -> {1,1,16,8}; this
// becomes the second input of the first MatMul.
std::vector<SizeVector> concatShapes{{1, 1, 8, 8}, {1, 1, 8, 8}};
auto concatInputParams = builder::makeParams(ngPrec, {concatShapes});
const auto concatOutputNodes = helpers::convert2OutputVector(helpers::castOps2Nodes<op::Parameter>(concatInputParams));
const auto concat = builder::makeConcat(concatOutputNodes, 2);
const auto matMul1 = builder::makeMatMul(split->output(0), concat, false, false);
// Second MatMul: the other Split output times a plain {1,1,16,8} parameter.
SizeVector matmulShape{1, 1, 16, 8};
auto matmulInputParams = builder::makeParams(ngPrec, {matmulShape});
const auto matmulOutputNodes = helpers::convert2OutputVector(helpers::castOps2Nodes<op::Parameter>(matmulInputParams));
const auto matMul2 = builder::makeMatMul(split->output(1), matmulOutputNodes[0], false, false);
// Merge both MatMul results so the MatMul outputs are consumed by an
// in-place-capable Concat as well.
const auto concatMatMuls = builder::makeConcat({matMul1, matMul2}, 2 /* 3rd axis */);
ngraph::ParameterVector inputParams = {splitInputParams[0], concatInputParams[0], concatInputParams[1], matmulInputParams[0]};
function = makeNgraphFunction(ngPrec, inputParams, concatMatMuls, "MatmulStridedInputsOutputs");
}
};
/* Network with two MatMul nodes and multiple in-place nodes.
 * Tests that the MatMul node works correctly with strided inputs/outputs.
Input Input Input
\ / |
\ / |
\ / |
\ / |
Concat Split Input
\ / \ /
\ / \ /
\ / \ /
\ / \ /
MatMul MatMul
\ /
\ /
\ /
\ /
Concat
|
|
Output
*/
// Runs the subgraph on the CPU device and compares the results against the
// reference implementation (the heavy lifting is done by LayerTestsCommon::Run).
TEST_P(MatmulStridedInputsOutputsTest, CompareWithRefs) {
SKIP_IF_CURRENT_TEST_IS_DISABLED()
Run();
}
namespace {
// Instantiate the suite for both supported network precisions; the test-case
// name suffix comes from getTestCaseName (e.g. "netPRC=FP32_").
INSTANTIATE_TEST_SUITE_P(smoke_Check, MatmulStridedInputsOutputsTest,
::testing::Values(Precision::FP32,
Precision::BF16),
MatmulStridedInputsOutputsTest::getTestCaseName);
} // namespace
} // namespace SubgraphTestsDefinitions