[CPU] Get interpolate scales input during interpolate node init if the input is Constant. (#10229)

This commit is contained in:
Mang Guo 2022-02-11 15:27:50 +08:00 committed by GitHub
parent cf805b17b9
commit 8bbabf8720
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 1 deletions

View File

@ -1881,6 +1881,12 @@ MKLDNNInterpolateNode::MKLDNNInterpolateNode(const std::shared_ptr<ngraph::Node>
interpAttrs.padEnd[i] = static_cast<int>(interpAttr.pads_end[i]); interpAttrs.padEnd[i] = static_cast<int>(interpAttr.pads_end[i]);
} }
const auto scalesNode = std::dynamic_pointer_cast<const ngraph::opset1::Constant>(interp->get_input_node_shared_ptr(SCALES_ID));
if (scalesNode) {
scales = scalesNode->cast_vector<float>();
isScaleConstant = true;
}
if (isAxesSpecified) { if (isAxesSpecified) {
axes = std::dynamic_pointer_cast<const ngraph::opset1::Constant>(interp->get_input_node_shared_ptr(AXES_ID))->cast_vector<int>(); axes = std::dynamic_pointer_cast<const ngraph::opset1::Constant>(interp->get_input_node_shared_ptr(AXES_ID))->cast_vector<int>();
} else { } else {
@ -2095,6 +2101,12 @@ void MKLDNNInterpolateNode::prepareParams() {
const auto &srcDims = srcMemPtr->getStaticDims(); const auto &srcDims = srcMemPtr->getStaticDims();
const auto &dstDims = dstMemPtr->getStaticDims(); const auto &dstDims = dstMemPtr->getStaticDims();
if (!isScaleConstant) {
const auto& scalesMem = getParentEdgesAtPort(SCALES_ID)[0]->getMemory();
const float* scalesData = reinterpret_cast<const float *>(scalesMem.GetPtr());
scales.assign(scalesData, scalesData + scalesMem.getStaticDims()[0]);
}
std::vector<float> dataScales = getScales(getPaddedInputShape(srcDims, interpAttrs.padBegin, interpAttrs.padEnd), dstDims); std::vector<float> dataScales = getScales(getPaddedInputShape(srcDims, interpAttrs.padBegin, interpAttrs.padEnd), dstDims);
if (getOutputShapeAtPort(0).getRank() > 2 && (dataScales[0] != 1.f || dataScales[1] != 1.f)) { if (getOutputShapeAtPort(0).getRank() > 2 && (dataScales[0] != 1.f || dataScales[1] != 1.f)) {
IE_THROW() << "Interpolate layer only supports resize on spatial dimensions(depth, height and width)"; IE_THROW() << "Interpolate layer only supports resize on spatial dimensions(depth, height and width)";
@ -2222,7 +2234,6 @@ SizeVector MKLDNNInterpolateNode::getPaddedInputShape(const VectorDims &srcDims,
// scales is a required input, but should not use input scales when "size" case, which may added eps that lead to inaccurate result, recalculate scales instead. // scales is a required input, but should not use input scales when "size" case, which may added eps that lead to inaccurate result, recalculate scales instead.
std::vector<float> MKLDNNInterpolateNode::getScales(const VectorDims &srcDimPad, const VectorDims &dstDim) { std::vector<float> MKLDNNInterpolateNode::getScales(const VectorDims &srcDimPad, const VectorDims &dstDim) {
const size_t dataRank = getInputShapeAtPort(DATA_ID).getRank(); const size_t dataRank = getInputShapeAtPort(DATA_ID).getRank();
const float *scales = reinterpret_cast<const float *>(getParentEdgesAtPort(SCALES_ID)[0]->getMemory().GetPtr());
std::vector<float> fullScales(dataRank, 1.f); std::vector<float> fullScales(dataRank, 1.f);
const size_t axesRank = axes.size(); const size_t axesRank = axes.size();
for (size_t i = 0; i < axesRank; i++) { for (size_t i = 0; i < axesRank; i++) {

View File

@ -243,6 +243,8 @@ private:
bool isAxesSpecified = false; bool isAxesSpecified = false;
std::vector<int> axes; std::vector<int> axes;
std::vector<float> scales;
bool isScaleConstant = false;
// 6 ptrs for each quantization, 2 ptrs for each depth_wise // 6 ptrs for each quantization, 2 ptrs for each depth_wise
std::vector<const void*> postOpsDataPtrs; std::vector<const void*> postOpsDataPtrs;