[CPU] Fix FullyConnected node for strided inputs and outputs (#9575)

FullyConnected node cannot work with strided inputs (for example
inplace descs) because it implicitly reshapes input and output tensors.
This commit is contained in:
Egor Duplensky 2022-01-28 22:20:09 +03:00 committed by GitHub
parent b5ea943267
commit 9cb6626ffd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 160 additions and 0 deletions

View File

@ -516,6 +516,55 @@ void MKLDNNFullyConnectedNode::createDescriptor(const std::vector<MemoryDescPtr>
MemoryDescUtils::convertToDnnlMemoryDesc(outDesc)->getDnnlDesc());
}
void MKLDNNFullyConnectedNode::initSupportedPrimitiveDescriptors() {
if (!supportedPrimitiveDescriptors.empty())
return;
for (auto& desc : descs) {
auto itpd = desc.createPrimitiveDescriptorIterator(getEngine());
while (static_cast<bool>(itpd)) {
// 3D FC requires implicit reshape so strides should be defined
auto supportsUndefStridesAndOffset = [&]() {
return getOutputShapeAtPort(0).getRank() == 2;
};
NodeConfig config;
config.dynBatchSupport = true;
for (size_t i = 0; i < descInputNumbers(desc); i++) {
PortConfig portConfig;
portConfig.inPlace = -1;
portConfig.constant = false;
auto desc = getSrcMemDesc(itpd, i);
if (supportsUndefStridesAndOffset()) {
portConfig.desc = desc->as<BlockedMemoryDesc>()->cloneWithUndefStridesAndOffset();
} else {
portConfig.desc = std::move(desc);
}
config.inConfs.push_back(portConfig);
}
for (size_t i = 0; i < descOutputNumbers(desc); i++) {
PortConfig portConfig;
portConfig.inPlace = canBeInPlace() ? 0 : -1;
portConfig.constant = false;
auto desc = getDstMemDesc(itpd, i);
if (supportsUndefStridesAndOffset()) {
portConfig.desc = desc->as<BlockedMemoryDesc>()->cloneWithUndefStridesAndOffset();
} else {
portConfig.desc = std::move(desc);
}
config.outConfs.push_back(portConfig);
}
impl_desc_type impl_type = parse_impl_name(itpd.impl_info_str());
supportedPrimitiveDescriptors.emplace_back(config, impl_type);
if (!itpd.next_impl())
break;
}
}
}
std::shared_ptr<MemoryDesc> MKLDNNFullyConnectedNode::getSrcMemDesc(mkldnn::primitive_desc_iterator &primitive_desc_it, size_t idx) {
auto desc = idx > 0 ? primitive_desc_it.weights_desc(idx - 1) : primitive_desc_it.src_desc(idx);

View File

@ -37,6 +37,7 @@ public:
return static_cast<size_t>(getOriginalInputsNumber());
}
void initSupportedPrimitiveDescriptors() override;
std::shared_ptr<MemoryDesc> getSrcMemDesc(mkldnn::primitive_desc_iterator &primitive_desc_it, size_t idx) override;
std::shared_ptr<MemoryDesc> getDstMemDesc(mkldnn::primitive_desc_iterator &primitive_desc_it, size_t idx) override;

View File

@ -0,0 +1,110 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/core/partial_shape.hpp"
#include "test_utils/cpu_test_utils.hpp"
#include "ngraph_functions/builders.hpp"
using namespace ngraph;
using namespace InferenceEngine;
using namespace CPUTestUtils;
namespace SubgraphTestsDefinitions {
using FullyConnectedStridedInputsOutputsTestParams = std::tuple<Precision,
size_t>; // rank (2D or 3D)
class FullyConnectedStridedInputsOutputsTest : public testing::WithParamInterface<FullyConnectedStridedInputsOutputsTestParams>,
public CPUTestsBase,
virtual public LayerTestsUtils::LayerTestsCommon {
public:
static std::string getTestCaseName(testing::TestParamInfo<FullyConnectedStridedInputsOutputsTestParams> obj) {
Precision netPrecision;
size_t rank;
std::tie(netPrecision, rank) = obj.param;
std::ostringstream result;
result << "netPRC=" << netPrecision.name() << "_";
result << "rank=" << rank;
return result.str();
}
protected:
void SetUp() override {
targetDevice = CommonTestUtils::DEVICE_CPU;
Precision netPrecision;
size_t rank;
std::tie(netPrecision, rank) = this->GetParam();
const auto ngPrec = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
auto bcastTo3D = [](SizeVector& shape) {
shape.insert(shape.begin(), 1);
};
SizeVector splitShape{2, 16};
if (rank == 3) bcastTo3D(splitShape);
auto params = builder::makeParams(ngPrec, {splitShape});
const auto splitOutputNodes = helpers::convert2OutputVector(helpers::castOps2Nodes<op::Parameter>(params));
const auto splitAxis = rank == 3 ? 1 : 0;
const auto split = builder::makeSplit(splitOutputNodes[0], ngPrec, 2 /* splits */, splitAxis);
SizeVector fcWeightsShape{16, 8};
if (rank == 3) bcastTo3D(fcWeightsShape);
auto fc1secondInput = builder::makeInputLayer(ngPrec, helpers::InputLayerType::CONSTANT, fcWeightsShape);
const auto fc1 = builder::makeMatMul(split->output(0), fc1secondInput, false, false);
auto fc2secondInputB = builder::makeInputLayer(ngPrec, helpers::InputLayerType::CONSTANT, fcWeightsShape);
const auto fc2 = builder::makeMatMul(split->output(1), fc2secondInputB, false, false);
const auto fcConcatAxis = rank == 3 ? 1 : 0;
const auto concatMatMuls = builder::makeConcat({fc1, fc2}, fcConcatAxis);
function = makeNgraphFunction(ngPrec, params, concatMatMuls, "FullyConnectedStridedInputsOutputs");
}
};
/* Network with two FullyConnected (FC) nodes and multiple inplace nodes
* Test that MatMul node works correctly with strided inputs / outputs
Input
|
|
|
|
Input Split Input
\ / \ /
\ / \ /
\ / \ /
\ / \ /
FC FC
\ /
\ /
\ /
\ /
Concat
|
|
Output
*/
TEST_P(FullyConnectedStridedInputsOutputsTest, CompareWithRefs) {
SKIP_IF_CURRENT_TEST_IS_DISABLED()
Run();
}
namespace {
INSTANTIATE_TEST_SUITE_P(smoke_Check, FullyConnectedStridedInputsOutputsTest,
::testing::Combine(::testing::Values(Precision::FP32, Precision::BF16),
::testing::Values(2, 3)),
FullyConnectedStridedInputsOutputsTest::getTestCaseName);
} // namespace
} // namespace SubgraphTestsDefinitions