[LPT] Turn back checks in reshape transformation when subtract is absent (#10939)

This commit is contained in:
Vladimir Zinoviev 2022-03-15 11:34:12 +03:00 committed by GitHub
parent ef5ad90dd7
commit 4a2d0f39dd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -236,21 +236,19 @@ bool ReshapeTransformation::canBeTransformed(const TransformationContext& contex
multiplyShapeWithBatch.insert(multiplyShapeWithBatch.begin(), 1ul); multiplyShapeWithBatch.insert(multiplyShapeWithBatch.begin(), 1ul);
} }
if (subtractShapeWithBatch.size() > 1 && multiplyShapeWithBatch.size() > 1) {
const size_t outputChannel = static_cast<size_t>(outputPShape[1].get_length()); const size_t outputChannel = static_cast<size_t>(outputPShape[1].get_length());
if (!subtractShapeWithBatch.empty() && (outputChannel < subtractShapeWithBatch[1])) { if ((subtractShapeWithBatch.size() > 1) && (outputChannel < subtractShapeWithBatch[1])) {
return false; return false;
} }
if (!multiplyShapeWithBatch.empty() && (outputChannel < multiplyShapeWithBatch[1])) { if ((multiplyShapeWithBatch.size() > 1) && (outputChannel < multiplyShapeWithBatch[1])) {
return false; return false;
} }
if (outputPShape.is_static() && if (outputPShape.is_static() &&
((!subtractShapeWithBatch.empty() && ((outputChannel % subtractShapeWithBatch[1]) != 0)) || (((subtractShapeWithBatch.size() > 1) && ((outputChannel % subtractShapeWithBatch[1]) != 0)) ||
(!multiplyShapeWithBatch.empty() && (outputChannel % multiplyShapeWithBatch[1] != 0)))) { ((multiplyShapeWithBatch.size() > 1) && (outputChannel % multiplyShapeWithBatch[1] != 0)))) {
return false; return false;
} }
}
return canBeTransformed(subtractShapeWithBatch, multiplyShapeWithBatch, inputPShape, outputPShape); return canBeTransformed(subtractShapeWithBatch, multiplyShapeWithBatch, inputPShape, outputPShape);
} }