[CPU] FullyConnected leftovers (#9621)
This commit is contained in:
@@ -129,6 +129,33 @@ std::vector<memory::format_tag> MKLDNNFullyConnectedNode::getAvailableFormatsFor
|
||||
return {memory::format_tag::any};
|
||||
}
|
||||
|
||||
VectorDims MKLDNNFullyConnectedNode::makeDummyInputDims() const {
|
||||
const auto& inShape = getInputShapeAtPort(DATA_ID);
|
||||
const auto& weightDims = getInputShapeAtPort(WEIGHTS_ID).getStaticDims();
|
||||
|
||||
auto inMinDims = inShape.getMinDims();
|
||||
auto inMaxDims = inShape.getMaxDims();
|
||||
|
||||
if (inMinDims.size() == 3) {
|
||||
inMinDims.back() = weightDims.back();
|
||||
inMaxDims.back() = weightDims.back();
|
||||
} else {
|
||||
for (size_t i = 1; i < inMinDims.size(); i++) {
|
||||
inMinDims[i] = weightDims[i];
|
||||
inMaxDims[i] = weightDims[i];
|
||||
}
|
||||
}
|
||||
return MemoryDescUtils::makeDummyShape(Shape(inMinDims, inMaxDims)).getStaticDims();
|
||||
}
|
||||
|
||||
VectorDims MKLDNNFullyConnectedNode::makeDummyOutputDims(const VectorDims& inDims) const {
|
||||
std::vector<Shape> inShapes = {Shape(inDims), getInputShapeAtPort(WEIGHTS_ID)};
|
||||
if (inputShapes.size() > 2) {
|
||||
inShapes.emplace_back(getInputShapeAtPort(BIAS_ID));
|
||||
}
|
||||
return shapeInferGeneric(inShapes).front();
|
||||
}
|
||||
|
||||
void MKLDNNFullyConnectedNode::getSupportedDescriptors() {
|
||||
if (getParentEdges().size() != 2 && getParentEdges().size() != 3)
|
||||
IE_THROW() << errorPrefix << " has incorrect number of input edges";
|
||||
@@ -163,32 +190,8 @@ void MKLDNNFullyConnectedNode::getSupportedDescriptors() {
|
||||
outputDataType = memory::data_type::bf16;
|
||||
}
|
||||
|
||||
const auto& inShape = getInputShapeAtPort(DATA_ID);
|
||||
inDims = inShape.getDims();
|
||||
const auto& weightDims = getInputShapeAtPort(WEIGHTS_ID).getStaticDims();
|
||||
outDims = getOutputShapeAtPort(0).getDims();
|
||||
|
||||
if (isDynamicNode()) {
|
||||
auto inMinDims = inShape.getMinDims();
|
||||
auto inMaxDims = inShape.getMaxDims();
|
||||
|
||||
if (inMinDims.size() == 3) {
|
||||
inMinDims.back() = weightDims.back();
|
||||
inMaxDims.back() = weightDims.back();
|
||||
} else {
|
||||
for (size_t i = 1; i < inMinDims.size(); i++) {
|
||||
inMinDims[i] = weightDims[i];
|
||||
inMaxDims[i] = weightDims[i];
|
||||
}
|
||||
}
|
||||
inDims = MemoryDescUtils::makeDummyShape(Shape(inMinDims, inMaxDims)).getStaticDims();
|
||||
|
||||
std::vector<Shape> inShapes = {Shape(inDims), Shape(weightDims)};
|
||||
if (inputShapes.size() > 2) {
|
||||
inShapes.emplace_back(getInputShapeAtPort(BIAS_ID));
|
||||
}
|
||||
outDims = shapeInferGeneric(inShapes).front();
|
||||
}
|
||||
inDims = isDynamicNode() ? makeDummyInputDims() : getInputShapeAtPort(DATA_ID).getStaticDims();
|
||||
outDims = isDynamicNode() ? makeDummyOutputDims(inDims) : getOutputShapeAtPort(0).getStaticDims();
|
||||
|
||||
for (auto format : getAvailableFormatsForDims(getInputShapeAtPort(0))) {
|
||||
auto in_candidate = mkldnn::memory::desc(MKLDNNExtensionUtils::convertToDnnlDims(inDims), inputDataType, format);
|
||||
@@ -219,14 +222,8 @@ void MKLDNNFullyConnectedNode::prepareParams() {
|
||||
if (selected_pd == nullptr)
|
||||
IE_THROW() << "Preferable primitive descriptor is not set for node " << getName() << ".";
|
||||
|
||||
auto initPrimitiveAttr = [&]() {
|
||||
mkldnn::primitive_attr attr;
|
||||
setPostOps(attr, dstMemPtr->getStaticDims());
|
||||
|
||||
return std::make_shared<mkldnn::primitive_attr>(std::move(attr));
|
||||
};
|
||||
|
||||
AttrPtr attr = initPrimitiveAttr();
|
||||
AttrPtr attr = std::make_shared<mkldnn::primitive_attr>();
|
||||
setPostOps(*attr, dstMemPtr->getStaticDims());
|
||||
|
||||
DnnlMemoryDescCPtr weightDesc = wghMemPtr->GetDescWithType<DnnlMemoryDesc>();
|
||||
DnnlMemoryDescCPtr biasDesc = nullptr;
|
||||
|
||||
@@ -55,6 +55,9 @@ private:
|
||||
void createDescriptorInternal(const mkldnn::memory::desc &inputDesc,
|
||||
const mkldnn::memory::desc &outputDesc);
|
||||
|
||||
VectorDims makeDummyInputDims() const;
|
||||
VectorDims makeDummyOutputDims(const VectorDims& inDims) const;
|
||||
|
||||
VectorDims inDims;
|
||||
VectorDims outDims;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user