[CPU] Fix MatMul node for the case of strided inputs and outputs (#8070)
This commit is contained in:
parent
ce9a968030
commit
8b20ccc6c8
@ -189,15 +189,14 @@ void MKLDNNMatMulNode::getSupportedDescriptors() {
|
||||
return strides;
|
||||
};
|
||||
|
||||
initialInShapes[0] = inputShapes[0];
|
||||
initialInShapes[1] = inputShapes[1];
|
||||
std::array<Shape, 2> newShapes{getInputShapeAtPort(0), getInputShapeAtPort(1)};
|
||||
|
||||
const VectorDims inStrides0 = getStridesAndDims(inputShapes[0], transposeIn[0]);
|
||||
const VectorDims inStrides1 = getStridesAndDims(inputShapes[1], transposeIn[1]);
|
||||
const VectorDims inStrides0 = getStridesAndDims(newShapes[0], transposeIn[0]);
|
||||
const VectorDims inStrides1 = getStridesAndDims(newShapes[1], transposeIn[1]);
|
||||
const VectorDims outStrides = getStridesAndDims(outputShapes[0], false);
|
||||
|
||||
inDataDesc[0] = std::make_shared<DnnlBlockedMemoryDesc>(firstInPortPrec, inputShapes[0], inStrides0);
|
||||
inDataDesc[1] = std::make_shared<DnnlBlockedMemoryDesc>(secondInPortPrec, inputShapes[1], inStrides1);
|
||||
inDataDesc[0] = std::make_shared<DnnlBlockedMemoryDesc>(firstInPortPrec, newShapes[0], inStrides0);
|
||||
inDataDesc[1] = std::make_shared<DnnlBlockedMemoryDesc>(secondInPortPrec, newShapes[1], inStrides1);
|
||||
outDataDesc = std::make_shared<DnnlBlockedMemoryDesc>(outPortPrec, getOutputShapeAtPort(0), outStrides);
|
||||
|
||||
createDescriptor({inDataDesc[0], inDataDesc[1]}, {outDataDesc});
|
||||
@ -229,13 +228,7 @@ void MKLDNNMatMulNode::initSupportedPrimitiveDescriptors() {
|
||||
PortConfig portConfig;
|
||||
portConfig.inPlace = -1;
|
||||
portConfig.constant = false;
|
||||
|
||||
auto src_desc = getSrcMemDesc(itpd, i);
|
||||
if (src_desc->getType() & MemoryDescType::Blocked) {
|
||||
portConfig.desc = src_desc->as<BlockedMemoryDesc>()->cloneWithUndefStridesAndOffset();
|
||||
} else {
|
||||
portConfig.desc = std::move(src_desc);
|
||||
}
|
||||
portConfig.desc = getSrcMemDesc(itpd, i);
|
||||
|
||||
config.inConfs.push_back(portConfig);
|
||||
}
|
||||
@ -244,13 +237,7 @@ void MKLDNNMatMulNode::initSupportedPrimitiveDescriptors() {
|
||||
PortConfig portConfig;
|
||||
portConfig.inPlace = canBeInPlace() ? 0 : -1;
|
||||
portConfig.constant = false;
|
||||
|
||||
auto dst_desc = getDstMemDesc(itpd, i);
|
||||
if (dst_desc->getType() & MemoryDescType::Blocked) {
|
||||
portConfig.desc = dst_desc->as<BlockedMemoryDesc>()->cloneWithUndefStridesAndOffset();
|
||||
} else {
|
||||
portConfig.desc = std::move(dst_desc);
|
||||
}
|
||||
portConfig.desc = getDstMemDesc(itpd, i);
|
||||
|
||||
config.outConfs.push_back(portConfig);
|
||||
}
|
||||
@ -294,10 +281,9 @@ void MKLDNNMatMulNode::createPrimitive() {
|
||||
|
||||
MemoryDescPtr MKLDNNMatMulNode::getSrcMemDesc(mkldnn::primitive_desc_iterator &primitive_desc_it, size_t idx) {
|
||||
auto desc = idx > 0 ? primitive_desc_it.weights_desc(idx - 1): primitive_desc_it.src_desc(idx);
|
||||
|
||||
return std::make_shared<CpuBlockedMemoryDesc>(
|
||||
MKLDNNExtensionUtils::DataTypeToIEPrecision(static_cast<mkldnn::memory::data_type>(desc.data.data_type)),
|
||||
initialInShapes[idx]); /* provide initial shapes, so hide transpose effect */
|
||||
getInputShapeAtPort(idx)); /* provide initial shapes, so hide transpose effect */
|
||||
}
|
||||
|
||||
bool MKLDNNMatMulNode::created() const {
|
||||
|
@ -43,9 +43,6 @@ private:
|
||||
|
||||
/* whether to transpose input */
|
||||
std::array<bool, 2> transposeIn;
|
||||
/* initial shapes without transpose,
|
||||
* necessary to hide transpose effect from plugin */
|
||||
std::array<Shape, 2> initialInShapes;
|
||||
|
||||
std::array<MemoryDescPtr, 2> inDataDesc;
|
||||
MemoryDescPtr outDataDesc;
|
||||
|
@ -64,7 +64,7 @@ protected:
|
||||
int expectedNumOfReshapes = 0;
|
||||
};
|
||||
|
||||
TEST_P(AlignMatMulInputRanksTest, supportedInputShapes) {
|
||||
TEST_P(AlignMatMulInputRanksTest, CompareWithRefs) {
|
||||
SKIP_IF_CURRENT_TEST_IS_DISABLED()
|
||||
|
||||
Run();
|
||||
|
@ -0,0 +1,101 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "test_utils/cpu_test_utils.hpp"
|
||||
#include "ngraph_functions/builders.hpp"
|
||||
|
||||
using namespace ngraph;
|
||||
using namespace InferenceEngine;
|
||||
using namespace CPUTestUtils;
|
||||
|
||||
namespace SubgraphTestsDefinitions {
|
||||
|
||||
// The test is parameterized by a single value: the network precision.
using MatmulStridedInputsOutputsTestParams = Precision;
|
||||
|
||||
/**
 * Subgraph fixture parameterized by network precision.
 *
 * SetUp() builds a graph in which two MatMul nodes are fed by the outputs of a
 * Split (and, for the first MatMul, a Concat) and whose results are joined by a
 * final Concat.
 * NOTE(review): the file's own description says this targets MatMul with strided
 * inputs / outputs produced by in-place nodes; that the plugin actually runs
 * Split / Concat in place is not visible from this fixture - confirm.
 */
class MatmulStridedInputsOutputsTest : public testing::WithParamInterface<MatmulStridedInputsOutputsTestParams>,
                                       public CPUTestsBase,
                                       virtual public LayerTestsUtils::LayerTestsCommon {
public:
    /**
     * Produces the human-readable test-case name.
     * @param obj gtest parameter info; obj.param is the network precision
     * @return string of the form "netPRC=<precision name>_"
     */
    static std::string getTestCaseName(testing::TestParamInfo<MatmulStridedInputsOutputsTestParams> obj) {
        const Precision netPrecision = obj.param;

        std::ostringstream caseName;
        caseName << "netPRC=" << netPrecision.name() << "_";
        return caseName.str();
    }

protected:
    /// Selects the CPU device and assembles the tested function.
    void SetUp() override {
        targetDevice = CommonTestUtils::DEVICE_CPU;

        const Precision netPrecision = this->GetParam();
        const auto ngPrec = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);

        // Split branch: one {1, 2, 1, 16} parameter split into 2 parts along axis 1.
        SizeVector splitInShape{1, 2, 1, 16};
        auto splitParams = builder::makeParams(ngPrec, {splitInShape});
        const auto splitParamNodes = helpers::castOps2Nodes<op::Parameter>(splitParams);
        const auto splitInputs = helpers::convert2OutputVector(splitParamNodes);
        const auto split = builder::makeSplit(splitInputs[0], ngPrec, 2 /* splits */, 1 /* 2nd axis */);

        // Concat branch: two {1, 1, 8, 8} parameters concatenated along axis 2.
        std::vector<SizeVector> concatInShapes(2, SizeVector{1, 1, 8, 8});
        auto concatParams = builder::makeParams(ngPrec, {concatInShapes});
        const auto concatParamNodes = helpers::castOps2Nodes<op::Parameter>(concatParams);
        const auto concatInputs = helpers::convert2OutputVector(concatParamNodes);
        const auto concat = builder::makeConcat(concatInputs, 2);

        // First MatMul: Split output 0 multiplied by the Concat result, no transposes.
        const auto firstMatMul = builder::makeMatMul(split->output(0), concat, false, false);

        // Second MatMul: Split output 1 multiplied by a plain {1, 1, 16, 8} parameter.
        SizeVector plainInShape{1, 1, 16, 8};
        auto plainParams = builder::makeParams(ngPrec, {plainInShape});
        const auto plainParamNodes = helpers::castOps2Nodes<op::Parameter>(plainParams);
        const auto plainInputs = helpers::convert2OutputVector(plainParamNodes);
        const auto secondMatMul = builder::makeMatMul(split->output(1), plainInputs[0], false, false);

        // Both MatMul results are joined by the final Concat.
        const auto joinedMatMuls = builder::makeConcat({firstMatMul, secondMatMul}, 2 /* 3rd axis */);

        ngraph::ParameterVector allParams = {splitParams[0], concatParams[0], concatParams[1], plainParams[0]};
        function = makeNgraphFunction(ngPrec, allParams, joinedMatMuls, "MatmulStridedInputsOutputs");
    }
};
|
||||
|
||||
/* Network with two MatMul nodes and multiple inplace nodes
 * Test that MatMul node works correctly with strided inputs / outputs

   Input    Input    Input
     \       /         |
      \     /          |
       \   /           |
        \ /            |
       Concat        Split      Input
          \          /   \       /
           \        /     \     /
            \      /       \   /
             \    /         \ /
             MatMul        MatMul
                \            /
                 \          /
                  \        /
                   \      /
                    Concat
                      |
                      |
                    Output
*/
|
||||
|
||||
// Executes the fixture's inherited Run() unless the skip macro bails out first.
// NOTE(review): the name suggests Run() compares plugin results with reference
// results; Run() is defined outside this file - confirm.
TEST_P(MatmulStridedInputsOutputsTest, CompareWithRefs) {
    SKIP_IF_CURRENT_TEST_IS_DISABLED()

    Run();
}
|
||||
|
||||
namespace {
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_Check, MatmulStridedInputsOutputsTest,
|
||||
::testing::Values(Precision::FP32,
|
||||
Precision::BF16),
|
||||
MatmulStridedInputsOutputsTest::getTestCaseName);
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace SubgraphTestsDefinitions
|
Loading…
Reference in New Issue
Block a user