[CPU] Fix FullyConnected node for strided inputs and outputs (#9575)
FullyConnected node cannot work with strided inputs (for example inplace descs) because it implicitly reshapes input and output tensors.
This commit is contained in:
parent
b5ea943267
commit
9cb6626ffd
@ -516,6 +516,55 @@ void MKLDNNFullyConnectedNode::createDescriptor(const std::vector<MemoryDescPtr>
|
||||
MemoryDescUtils::convertToDnnlMemoryDesc(outDesc)->getDnnlDesc());
|
||||
}
|
||||
|
||||
void MKLDNNFullyConnectedNode::initSupportedPrimitiveDescriptors() {
|
||||
if (!supportedPrimitiveDescriptors.empty())
|
||||
return;
|
||||
|
||||
for (auto& desc : descs) {
|
||||
auto itpd = desc.createPrimitiveDescriptorIterator(getEngine());
|
||||
while (static_cast<bool>(itpd)) {
|
||||
// 3D FC requires implicit reshape so strides should be defined
|
||||
auto supportsUndefStridesAndOffset = [&]() {
|
||||
return getOutputShapeAtPort(0).getRank() == 2;
|
||||
};
|
||||
|
||||
NodeConfig config;
|
||||
config.dynBatchSupport = true;
|
||||
for (size_t i = 0; i < descInputNumbers(desc); i++) {
|
||||
PortConfig portConfig;
|
||||
portConfig.inPlace = -1;
|
||||
portConfig.constant = false;
|
||||
auto desc = getSrcMemDesc(itpd, i);
|
||||
if (supportsUndefStridesAndOffset()) {
|
||||
portConfig.desc = desc->as<BlockedMemoryDesc>()->cloneWithUndefStridesAndOffset();
|
||||
} else {
|
||||
portConfig.desc = std::move(desc);
|
||||
}
|
||||
config.inConfs.push_back(portConfig);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < descOutputNumbers(desc); i++) {
|
||||
PortConfig portConfig;
|
||||
portConfig.inPlace = canBeInPlace() ? 0 : -1;
|
||||
portConfig.constant = false;
|
||||
auto desc = getDstMemDesc(itpd, i);
|
||||
if (supportsUndefStridesAndOffset()) {
|
||||
portConfig.desc = desc->as<BlockedMemoryDesc>()->cloneWithUndefStridesAndOffset();
|
||||
} else {
|
||||
portConfig.desc = std::move(desc);
|
||||
}
|
||||
config.outConfs.push_back(portConfig);
|
||||
}
|
||||
|
||||
impl_desc_type impl_type = parse_impl_name(itpd.impl_info_str());
|
||||
|
||||
supportedPrimitiveDescriptors.emplace_back(config, impl_type);
|
||||
if (!itpd.next_impl())
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<MemoryDesc> MKLDNNFullyConnectedNode::getSrcMemDesc(mkldnn::primitive_desc_iterator &primitive_desc_it, size_t idx) {
|
||||
auto desc = idx > 0 ? primitive_desc_it.weights_desc(idx - 1) : primitive_desc_it.src_desc(idx);
|
||||
|
||||
|
@ -37,6 +37,7 @@ public:
|
||||
return static_cast<size_t>(getOriginalInputsNumber());
|
||||
}
|
||||
|
||||
void initSupportedPrimitiveDescriptors() override;
|
||||
std::shared_ptr<MemoryDesc> getSrcMemDesc(mkldnn::primitive_desc_iterator &primitive_desc_it, size_t idx) override;
|
||||
std::shared_ptr<MemoryDesc> getDstMemDesc(mkldnn::primitive_desc_iterator &primitive_desc_it, size_t idx) override;
|
||||
|
||||
|
@ -0,0 +1,110 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/core/partial_shape.hpp"
|
||||
#include "test_utils/cpu_test_utils.hpp"
|
||||
#include "ngraph_functions/builders.hpp"
|
||||
|
||||
using namespace ngraph;
|
||||
using namespace InferenceEngine;
|
||||
using namespace CPUTestUtils;
|
||||
|
||||
namespace SubgraphTestsDefinitions {
|
||||
|
||||
using FullyConnectedStridedInputsOutputsTestParams = std::tuple<Precision,
|
||||
size_t>; // rank (2D or 3D)
|
||||
|
||||
class FullyConnectedStridedInputsOutputsTest : public testing::WithParamInterface<FullyConnectedStridedInputsOutputsTestParams>,
|
||||
public CPUTestsBase,
|
||||
virtual public LayerTestsUtils::LayerTestsCommon {
|
||||
public:
|
||||
static std::string getTestCaseName(testing::TestParamInfo<FullyConnectedStridedInputsOutputsTestParams> obj) {
|
||||
Precision netPrecision;
|
||||
size_t rank;
|
||||
std::tie(netPrecision, rank) = obj.param;
|
||||
|
||||
std::ostringstream result;
|
||||
result << "netPRC=" << netPrecision.name() << "_";
|
||||
result << "rank=" << rank;
|
||||
|
||||
return result.str();
|
||||
}
|
||||
|
||||
protected:
|
||||
void SetUp() override {
|
||||
targetDevice = CommonTestUtils::DEVICE_CPU;
|
||||
Precision netPrecision;
|
||||
size_t rank;
|
||||
std::tie(netPrecision, rank) = this->GetParam();
|
||||
const auto ngPrec = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
|
||||
|
||||
auto bcastTo3D = [](SizeVector& shape) {
|
||||
shape.insert(shape.begin(), 1);
|
||||
};
|
||||
|
||||
SizeVector splitShape{2, 16};
|
||||
if (rank == 3) bcastTo3D(splitShape);
|
||||
|
||||
auto params = builder::makeParams(ngPrec, {splitShape});
|
||||
|
||||
const auto splitOutputNodes = helpers::convert2OutputVector(helpers::castOps2Nodes<op::Parameter>(params));
|
||||
const auto splitAxis = rank == 3 ? 1 : 0;
|
||||
const auto split = builder::makeSplit(splitOutputNodes[0], ngPrec, 2 /* splits */, splitAxis);
|
||||
|
||||
SizeVector fcWeightsShape{16, 8};
|
||||
if (rank == 3) bcastTo3D(fcWeightsShape);
|
||||
|
||||
auto fc1secondInput = builder::makeInputLayer(ngPrec, helpers::InputLayerType::CONSTANT, fcWeightsShape);
|
||||
const auto fc1 = builder::makeMatMul(split->output(0), fc1secondInput, false, false);
|
||||
|
||||
auto fc2secondInputB = builder::makeInputLayer(ngPrec, helpers::InputLayerType::CONSTANT, fcWeightsShape);
|
||||
const auto fc2 = builder::makeMatMul(split->output(1), fc2secondInputB, false, false);
|
||||
|
||||
const auto fcConcatAxis = rank == 3 ? 1 : 0;
|
||||
const auto concatMatMuls = builder::makeConcat({fc1, fc2}, fcConcatAxis);
|
||||
|
||||
function = makeNgraphFunction(ngPrec, params, concatMatMuls, "FullyConnectedStridedInputsOutputs");
|
||||
}
|
||||
};
|
||||
|
||||
/* Network with two FullyConnected (FC) nodes and multiple inplace nodes
|
||||
* Test that MatMul node works correctly with strided inputs / outputs
|
||||
|
||||
Input
|
||||
|
|
||||
|
|
||||
|
|
||||
|
|
||||
Input Split Input
|
||||
\ / \ /
|
||||
\ / \ /
|
||||
\ / \ /
|
||||
\ / \ /
|
||||
FC FC
|
||||
\ /
|
||||
\ /
|
||||
\ /
|
||||
\ /
|
||||
Concat
|
||||
|
|
||||
|
|
||||
Output
|
||||
*/
|
||||
|
||||
TEST_P(FullyConnectedStridedInputsOutputsTest, CompareWithRefs) {
|
||||
SKIP_IF_CURRENT_TEST_IS_DISABLED()
|
||||
|
||||
Run();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_Check, FullyConnectedStridedInputsOutputsTest,
|
||||
::testing::Combine(::testing::Values(Precision::FP32, Precision::BF16),
|
||||
::testing::Values(2, 3)),
|
||||
FullyConnectedStridedInputsOutputsTest::getTestCaseName);
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace SubgraphTestsDefinitions
|
Loading…
Reference in New Issue
Block a user