[CPU] FullyConnected leftovers (#9621)
This commit is contained in:
@@ -129,6 +129,33 @@ std::vector<memory::format_tag> MKLDNNFullyConnectedNode::getAvailableFormatsFor
|
|||||||
return {memory::format_tag::any};
|
return {memory::format_tag::any};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
VectorDims MKLDNNFullyConnectedNode::makeDummyInputDims() const {
|
||||||
|
const auto& inShape = getInputShapeAtPort(DATA_ID);
|
||||||
|
const auto& weightDims = getInputShapeAtPort(WEIGHTS_ID).getStaticDims();
|
||||||
|
|
||||||
|
auto inMinDims = inShape.getMinDims();
|
||||||
|
auto inMaxDims = inShape.getMaxDims();
|
||||||
|
|
||||||
|
if (inMinDims.size() == 3) {
|
||||||
|
inMinDims.back() = weightDims.back();
|
||||||
|
inMaxDims.back() = weightDims.back();
|
||||||
|
} else {
|
||||||
|
for (size_t i = 1; i < inMinDims.size(); i++) {
|
||||||
|
inMinDims[i] = weightDims[i];
|
||||||
|
inMaxDims[i] = weightDims[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return MemoryDescUtils::makeDummyShape(Shape(inMinDims, inMaxDims)).getStaticDims();
|
||||||
|
}
|
||||||
|
|
||||||
|
VectorDims MKLDNNFullyConnectedNode::makeDummyOutputDims(const VectorDims& inDims) const {
|
||||||
|
std::vector<Shape> inShapes = {Shape(inDims), getInputShapeAtPort(WEIGHTS_ID)};
|
||||||
|
if (inputShapes.size() > 2) {
|
||||||
|
inShapes.emplace_back(getInputShapeAtPort(BIAS_ID));
|
||||||
|
}
|
||||||
|
return shapeInferGeneric(inShapes).front();
|
||||||
|
}
|
||||||
|
|
||||||
void MKLDNNFullyConnectedNode::getSupportedDescriptors() {
|
void MKLDNNFullyConnectedNode::getSupportedDescriptors() {
|
||||||
if (getParentEdges().size() != 2 && getParentEdges().size() != 3)
|
if (getParentEdges().size() != 2 && getParentEdges().size() != 3)
|
||||||
IE_THROW() << errorPrefix << " has incorrect number of input edges";
|
IE_THROW() << errorPrefix << " has incorrect number of input edges";
|
||||||
@@ -163,32 +190,8 @@ void MKLDNNFullyConnectedNode::getSupportedDescriptors() {
|
|||||||
outputDataType = memory::data_type::bf16;
|
outputDataType = memory::data_type::bf16;
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto& inShape = getInputShapeAtPort(DATA_ID);
|
inDims = isDynamicNode() ? makeDummyInputDims() : getInputShapeAtPort(DATA_ID).getStaticDims();
|
||||||
inDims = inShape.getDims();
|
outDims = isDynamicNode() ? makeDummyOutputDims(inDims) : getOutputShapeAtPort(0).getStaticDims();
|
||||||
const auto& weightDims = getInputShapeAtPort(WEIGHTS_ID).getStaticDims();
|
|
||||||
outDims = getOutputShapeAtPort(0).getDims();
|
|
||||||
|
|
||||||
if (isDynamicNode()) {
|
|
||||||
auto inMinDims = inShape.getMinDims();
|
|
||||||
auto inMaxDims = inShape.getMaxDims();
|
|
||||||
|
|
||||||
if (inMinDims.size() == 3) {
|
|
||||||
inMinDims.back() = weightDims.back();
|
|
||||||
inMaxDims.back() = weightDims.back();
|
|
||||||
} else {
|
|
||||||
for (size_t i = 1; i < inMinDims.size(); i++) {
|
|
||||||
inMinDims[i] = weightDims[i];
|
|
||||||
inMaxDims[i] = weightDims[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
inDims = MemoryDescUtils::makeDummyShape(Shape(inMinDims, inMaxDims)).getStaticDims();
|
|
||||||
|
|
||||||
std::vector<Shape> inShapes = {Shape(inDims), Shape(weightDims)};
|
|
||||||
if (inputShapes.size() > 2) {
|
|
||||||
inShapes.emplace_back(getInputShapeAtPort(BIAS_ID));
|
|
||||||
}
|
|
||||||
outDims = shapeInferGeneric(inShapes).front();
|
|
||||||
}
|
|
||||||
|
|
||||||
for (auto format : getAvailableFormatsForDims(getInputShapeAtPort(0))) {
|
for (auto format : getAvailableFormatsForDims(getInputShapeAtPort(0))) {
|
||||||
auto in_candidate = mkldnn::memory::desc(MKLDNNExtensionUtils::convertToDnnlDims(inDims), inputDataType, format);
|
auto in_candidate = mkldnn::memory::desc(MKLDNNExtensionUtils::convertToDnnlDims(inDims), inputDataType, format);
|
||||||
@@ -219,14 +222,8 @@ void MKLDNNFullyConnectedNode::prepareParams() {
|
|||||||
if (selected_pd == nullptr)
|
if (selected_pd == nullptr)
|
||||||
IE_THROW() << "Preferable primitive descriptor is not set for node " << getName() << ".";
|
IE_THROW() << "Preferable primitive descriptor is not set for node " << getName() << ".";
|
||||||
|
|
||||||
auto initPrimitiveAttr = [&]() {
|
AttrPtr attr = std::make_shared<mkldnn::primitive_attr>();
|
||||||
mkldnn::primitive_attr attr;
|
setPostOps(*attr, dstMemPtr->getStaticDims());
|
||||||
setPostOps(attr, dstMemPtr->getStaticDims());
|
|
||||||
|
|
||||||
return std::make_shared<mkldnn::primitive_attr>(std::move(attr));
|
|
||||||
};
|
|
||||||
|
|
||||||
AttrPtr attr = initPrimitiveAttr();
|
|
||||||
|
|
||||||
DnnlMemoryDescCPtr weightDesc = wghMemPtr->GetDescWithType<DnnlMemoryDesc>();
|
DnnlMemoryDescCPtr weightDesc = wghMemPtr->GetDescWithType<DnnlMemoryDesc>();
|
||||||
DnnlMemoryDescCPtr biasDesc = nullptr;
|
DnnlMemoryDescCPtr biasDesc = nullptr;
|
||||||
|
|||||||
@@ -55,6 +55,9 @@ private:
|
|||||||
void createDescriptorInternal(const mkldnn::memory::desc &inputDesc,
|
void createDescriptorInternal(const mkldnn::memory::desc &inputDesc,
|
||||||
const mkldnn::memory::desc &outputDesc);
|
const mkldnn::memory::desc &outputDesc);
|
||||||
|
|
||||||
|
VectorDims makeDummyInputDims() const;
|
||||||
|
VectorDims makeDummyOutputDims(const VectorDims& inDims) const;
|
||||||
|
|
||||||
VectorDims inDims;
|
VectorDims inDims;
|
||||||
VectorDims outDims;
|
VectorDims outDims;
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user