[CPU] Get interpolate scales input during interpolate node init if the input is Constant. (#10229)
This commit is contained in:
parent
cf805b17b9
commit
8bbabf8720
@ -1881,6 +1881,12 @@ MKLDNNInterpolateNode::MKLDNNInterpolateNode(const std::shared_ptr<ngraph::Node>
|
||||
interpAttrs.padEnd[i] = static_cast<int>(interpAttr.pads_end[i]);
|
||||
}
|
||||
|
||||
const auto scalesNode = std::dynamic_pointer_cast<const ngraph::opset1::Constant>(interp->get_input_node_shared_ptr(SCALES_ID));
|
||||
if (scalesNode) {
|
||||
scales = scalesNode->cast_vector<float>();
|
||||
isScaleConstant = true;
|
||||
}
|
||||
|
||||
if (isAxesSpecified) {
|
||||
axes = std::dynamic_pointer_cast<const ngraph::opset1::Constant>(interp->get_input_node_shared_ptr(AXES_ID))->cast_vector<int>();
|
||||
} else {
|
||||
@ -2095,6 +2101,12 @@ void MKLDNNInterpolateNode::prepareParams() {
|
||||
const auto &srcDims = srcMemPtr->getStaticDims();
|
||||
const auto &dstDims = dstMemPtr->getStaticDims();
|
||||
|
||||
if (!isScaleConstant) {
|
||||
const auto& scalesMem = getParentEdgesAtPort(SCALES_ID)[0]->getMemory();
|
||||
const float* scalesData = reinterpret_cast<const float *>(scalesMem.GetPtr());
|
||||
scales.assign(scalesData, scalesData + scalesMem.getStaticDims()[0]);
|
||||
}
|
||||
|
||||
std::vector<float> dataScales = getScales(getPaddedInputShape(srcDims, interpAttrs.padBegin, interpAttrs.padEnd), dstDims);
|
||||
if (getOutputShapeAtPort(0).getRank() > 2 && (dataScales[0] != 1.f || dataScales[1] != 1.f)) {
|
||||
IE_THROW() << "Interpolate layer only supports resize on spatial dimensions(depth, height and width)";
|
||||
@ -2222,7 +2234,6 @@ SizeVector MKLDNNInterpolateNode::getPaddedInputShape(const VectorDims &srcDims,
|
||||
// scales is a required input, but should not use input scales when "size" case, which may added eps that lead to inaccurate result, recalculate scales instead.
|
||||
std::vector<float> MKLDNNInterpolateNode::getScales(const VectorDims &srcDimPad, const VectorDims &dstDim) {
|
||||
const size_t dataRank = getInputShapeAtPort(DATA_ID).getRank();
|
||||
const float *scales = reinterpret_cast<const float *>(getParentEdgesAtPort(SCALES_ID)[0]->getMemory().GetPtr());
|
||||
std::vector<float> fullScales(dataRank, 1.f);
|
||||
const size_t axesRank = axes.size();
|
||||
for (size_t i = 0; i < axesRank; i++) {
|
||||
|
@ -243,6 +243,8 @@ private:
|
||||
|
||||
bool isAxesSpecified = false;
|
||||
std::vector<int> axes;
|
||||
std::vector<float> scales;
|
||||
bool isScaleConstant = false;
|
||||
|
||||
// 6 ptrs for each quantization, 2 ptrs for each depth_wise
|
||||
std::vector<const void*> postOpsDataPtrs;
|
||||
|
Loading…
Reference in New Issue
Block a user