[CPU] FullyConnected leftovers (#9621)

This commit is contained in:
Maxim Andronov
2022-01-18 09:14:29 +03:00
committed by GitHub
parent 26f222dea5
commit dd8a073aa4
2 changed files with 34 additions and 34 deletions

View File

@@ -129,6 +129,33 @@ std::vector<memory::format_tag> MKLDNNFullyConnectedNode::getAvailableFormatsFor
return {memory::format_tag::any};
}
VectorDims MKLDNNFullyConnectedNode::makeDummyInputDims() const {
const auto& inShape = getInputShapeAtPort(DATA_ID);
const auto& weightDims = getInputShapeAtPort(WEIGHTS_ID).getStaticDims();
auto inMinDims = inShape.getMinDims();
auto inMaxDims = inShape.getMaxDims();
if (inMinDims.size() == 3) {
inMinDims.back() = weightDims.back();
inMaxDims.back() = weightDims.back();
} else {
for (size_t i = 1; i < inMinDims.size(); i++) {
inMinDims[i] = weightDims[i];
inMaxDims[i] = weightDims[i];
}
}
return MemoryDescUtils::makeDummyShape(Shape(inMinDims, inMaxDims)).getStaticDims();
}
VectorDims MKLDNNFullyConnectedNode::makeDummyOutputDims(const VectorDims& inDims) const {
std::vector<Shape> inShapes = {Shape(inDims), getInputShapeAtPort(WEIGHTS_ID)};
if (inputShapes.size() > 2) {
inShapes.emplace_back(getInputShapeAtPort(BIAS_ID));
}
return shapeInferGeneric(inShapes).front();
}
void MKLDNNFullyConnectedNode::getSupportedDescriptors() {
if (getParentEdges().size() != 2 && getParentEdges().size() != 3)
IE_THROW() << errorPrefix << " has incorrect number of input edges";
@@ -163,32 +190,8 @@ void MKLDNNFullyConnectedNode::getSupportedDescriptors() {
outputDataType = memory::data_type::bf16;
}
const auto& inShape = getInputShapeAtPort(DATA_ID);
inDims = inShape.getDims();
const auto& weightDims = getInputShapeAtPort(WEIGHTS_ID).getStaticDims();
outDims = getOutputShapeAtPort(0).getDims();
if (isDynamicNode()) {
auto inMinDims = inShape.getMinDims();
auto inMaxDims = inShape.getMaxDims();
if (inMinDims.size() == 3) {
inMinDims.back() = weightDims.back();
inMaxDims.back() = weightDims.back();
} else {
for (size_t i = 1; i < inMinDims.size(); i++) {
inMinDims[i] = weightDims[i];
inMaxDims[i] = weightDims[i];
}
}
inDims = MemoryDescUtils::makeDummyShape(Shape(inMinDims, inMaxDims)).getStaticDims();
std::vector<Shape> inShapes = {Shape(inDims), Shape(weightDims)};
if (inputShapes.size() > 2) {
inShapes.emplace_back(getInputShapeAtPort(BIAS_ID));
}
outDims = shapeInferGeneric(inShapes).front();
}
inDims = isDynamicNode() ? makeDummyInputDims() : getInputShapeAtPort(DATA_ID).getStaticDims();
outDims = isDynamicNode() ? makeDummyOutputDims(inDims) : getOutputShapeAtPort(0).getStaticDims();
for (auto format : getAvailableFormatsForDims(getInputShapeAtPort(0))) {
auto in_candidate = mkldnn::memory::desc(MKLDNNExtensionUtils::convertToDnnlDims(inDims), inputDataType, format);
@@ -219,14 +222,8 @@ void MKLDNNFullyConnectedNode::prepareParams() {
if (selected_pd == nullptr)
IE_THROW() << "Preferable primitive descriptor is not set for node " << getName() << ".";
auto initPrimitiveAttr = [&]() {
mkldnn::primitive_attr attr;
setPostOps(attr, dstMemPtr->getStaticDims());
return std::make_shared<mkldnn::primitive_attr>(std::move(attr));
};
AttrPtr attr = initPrimitiveAttr();
AttrPtr attr = std::make_shared<mkldnn::primitive_attr>();
setPostOps(*attr, dstMemPtr->getStaticDims());
DnnlMemoryDescCPtr weightDesc = wghMemPtr->GetDescWithType<DnnlMemoryDesc>();
DnnlMemoryDescCPtr biasDesc = nullptr;

View File

@@ -55,6 +55,9 @@ private:
void createDescriptorInternal(const mkldnn::memory::desc &inputDesc,
const mkldnn::memory::desc &outputDesc);
VectorDims makeDummyInputDims() const;
VectorDims makeDummyOutputDims(const VectorDims& inDims) const;
VectorDims inDims;
VectorDims outDims;