[CPU] Get interpolate scales input during interpolate node init if the input is Constant. (#10229)
This commit is contained in:
parent
cf805b17b9
commit
8bbabf8720
@ -1881,6 +1881,12 @@ MKLDNNInterpolateNode::MKLDNNInterpolateNode(const std::shared_ptr<ngraph::Node>
|
|||||||
interpAttrs.padEnd[i] = static_cast<int>(interpAttr.pads_end[i]);
|
interpAttrs.padEnd[i] = static_cast<int>(interpAttr.pads_end[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const auto scalesNode = std::dynamic_pointer_cast<const ngraph::opset1::Constant>(interp->get_input_node_shared_ptr(SCALES_ID));
|
||||||
|
if (scalesNode) {
|
||||||
|
scales = scalesNode->cast_vector<float>();
|
||||||
|
isScaleConstant = true;
|
||||||
|
}
|
||||||
|
|
||||||
if (isAxesSpecified) {
|
if (isAxesSpecified) {
|
||||||
axes = std::dynamic_pointer_cast<const ngraph::opset1::Constant>(interp->get_input_node_shared_ptr(AXES_ID))->cast_vector<int>();
|
axes = std::dynamic_pointer_cast<const ngraph::opset1::Constant>(interp->get_input_node_shared_ptr(AXES_ID))->cast_vector<int>();
|
||||||
} else {
|
} else {
|
||||||
@ -2095,6 +2101,12 @@ void MKLDNNInterpolateNode::prepareParams() {
|
|||||||
const auto &srcDims = srcMemPtr->getStaticDims();
|
const auto &srcDims = srcMemPtr->getStaticDims();
|
||||||
const auto &dstDims = dstMemPtr->getStaticDims();
|
const auto &dstDims = dstMemPtr->getStaticDims();
|
||||||
|
|
||||||
|
if (!isScaleConstant) {
|
||||||
|
const auto& scalesMem = getParentEdgesAtPort(SCALES_ID)[0]->getMemory();
|
||||||
|
const float* scalesData = reinterpret_cast<const float *>(scalesMem.GetPtr());
|
||||||
|
scales.assign(scalesData, scalesData + scalesMem.getStaticDims()[0]);
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<float> dataScales = getScales(getPaddedInputShape(srcDims, interpAttrs.padBegin, interpAttrs.padEnd), dstDims);
|
std::vector<float> dataScales = getScales(getPaddedInputShape(srcDims, interpAttrs.padBegin, interpAttrs.padEnd), dstDims);
|
||||||
if (getOutputShapeAtPort(0).getRank() > 2 && (dataScales[0] != 1.f || dataScales[1] != 1.f)) {
|
if (getOutputShapeAtPort(0).getRank() > 2 && (dataScales[0] != 1.f || dataScales[1] != 1.f)) {
|
||||||
IE_THROW() << "Interpolate layer only supports resize on spatial dimensions(depth, height and width)";
|
IE_THROW() << "Interpolate layer only supports resize on spatial dimensions(depth, height and width)";
|
||||||
@ -2222,7 +2234,6 @@ SizeVector MKLDNNInterpolateNode::getPaddedInputShape(const VectorDims &srcDims,
|
|||||||
// scales is a required input, but should not use input scales when "size" case, which may added eps that lead to inaccurate result, recalculate scales instead.
|
// scales is a required input, but should not use input scales when "size" case, which may added eps that lead to inaccurate result, recalculate scales instead.
|
||||||
std::vector<float> MKLDNNInterpolateNode::getScales(const VectorDims &srcDimPad, const VectorDims &dstDim) {
|
std::vector<float> MKLDNNInterpolateNode::getScales(const VectorDims &srcDimPad, const VectorDims &dstDim) {
|
||||||
const size_t dataRank = getInputShapeAtPort(DATA_ID).getRank();
|
const size_t dataRank = getInputShapeAtPort(DATA_ID).getRank();
|
||||||
const float *scales = reinterpret_cast<const float *>(getParentEdgesAtPort(SCALES_ID)[0]->getMemory().GetPtr());
|
|
||||||
std::vector<float> fullScales(dataRank, 1.f);
|
std::vector<float> fullScales(dataRank, 1.f);
|
||||||
const size_t axesRank = axes.size();
|
const size_t axesRank = axes.size();
|
||||||
for (size_t i = 0; i < axesRank; i++) {
|
for (size_t i = 0; i < axesRank; i++) {
|
||||||
|
@ -243,6 +243,8 @@ private:
|
|||||||
|
|
||||||
bool isAxesSpecified = false;
|
bool isAxesSpecified = false;
|
||||||
std::vector<int> axes;
|
std::vector<int> axes;
|
||||||
|
std::vector<float> scales;
|
||||||
|
bool isScaleConstant = false;
|
||||||
|
|
||||||
// 6 ptrs for each quantization, 2 ptrs for each depth_wise
|
// 6 ptrs for each quantization, 2 ptrs for each depth_wise
|
||||||
std::vector<const void*> postOpsDataPtrs;
|
std::vector<const void*> postOpsDataPtrs;
|
||||||
|
Loading…
Reference in New Issue
Block a user