[CPU] RNN: Shape checks are relaxed (#17724)

This commit is contained in:
Vladislav Golubev
2023-05-31 14:13:44 +02:00
committed by GitHub
parent 65caa9d745
commit b655fa55a1
2 changed files with 27 additions and 17 deletions

View File

@@ -479,26 +479,30 @@ void RNN::initCell() {
else
DC = getInputShapeAtPort(2).getDims()[1];
// Expected shapes.
const Shape shapeD{{N.minVal, DC}, {N.maxVal, DC}}, shapeS{{N.minVal, SC}, {N.maxVal, SC}};
if (N.isStatic()) {
// Expected shapes.
const auto B = N.minVal;
const Shape shapeD{B, DC}, shapeS{B, SC};
if ((getInputShapeAtPort(0).isStatic() && getInputShapeAtPort(0) != shapeD) ||
(getInputShapeAtPort(1).isStatic() && getInputShapeAtPort(1) != shapeS) ||
(getOutputShapeAtPort(0) != shapeS)) {
THROW_ERROR << "has incorrect input/output shapes. Data shape: " << getInputShapeAtPort(0).toString() <<
"; Hidden state input: " << getInputShapeAtPort(1).toString() << "; Hidden state output: " << getOutputShapeAtPort(0).toString();
}
if ((getInputShapeAtPort(0).isStatic() && getInputShapeAtPort(0) != shapeD) ||
(getInputShapeAtPort(1).isStatic() && getInputShapeAtPort(1) != shapeS) ||
(getOutputShapeAtPort(0).isStatic() && getOutputShapeAtPort(0) != shapeS)) {
THROW_ERROR << "has incorrect input/output shapes. Data shape: " << getInputShapeAtPort(0).toString() <<
"; Hidden state input: " << getInputShapeAtPort(1).toString() << "; Hidden state output: " << getOutputShapeAtPort(0).toString();
}
if (S == 2) {
if ((getInputShapeAtPort(2).isStatic() && getInputShapeAtPort(2) != shapeS) || (getOutputShapeAtPort(1) != shapeS))
THROW_ERROR << "has incorrect input/output shapes. Cell state input: " << getInputShapeAtPort(2).toString() <<
"; Cell state output: " << getOutputShapeAtPort(1).toString();
}
if (S == 2) {
if ((getInputShapeAtPort(2).isStatic() && getInputShapeAtPort(2) != shapeS) ||
(getOutputShapeAtPort(1).isStatic() && getOutputShapeAtPort(1) != shapeS))
THROW_ERROR << "has incorrect input/output shapes. Cell state input: " << getInputShapeAtPort(2).toString() <<
"; Cell state output: " << getOutputShapeAtPort(1).toString();
}
if (is_augru) {
const Shape shapeA{{N.minVal, 1}, {N.maxVal, 1}};
if (getInputShapeAtPort(5).isStatic() && getInputShapeAtPort(5) != shapeA) {
THROW_ERROR << "has incorrect input shapes. Attention shape: " << getInputShapeAtPort(5).toString();
if (is_augru) {
const Shape shapeA{B, 1};
if (getInputShapeAtPort(5).isStatic() && getInputShapeAtPort(5) != shapeA) {
THROW_ERROR << "has incorrect input shapes. Attention shape: " << getInputShapeAtPort(5).toString();
}
}
}
}

View File

@@ -153,6 +153,12 @@ const std::vector<std::vector<ov::test::InputShape>> dynamicShapes = {
{ {1, 1}, {3, 1}, {5, 1} } }, // Target shapes
{ { -1, 1 }, // Dynamic shape 2
{ {1, 1}, {3, 1}, {5, 1} } } }, // Target shapes
{ { { -1, 1 }, // Dynamic shape 0
{ {1, 1}, {5, 1} } }, // Target shapes
{ { {1, 5}, 1 }, // Dynamic shape 1
{ {1, 1}, {5, 1} } }, // Target shapes
{ { {1, 5}, 1 }, // Dynamic shape 2
{ {1, 1}, {5, 1} } } }, // Target shapes
{ { { {1, 20}, 30 }, // Dynamic shape 0
{ {2, 30}, {5, 30}, {8, 30} } }, // Target shapes
{ { {1, 20}, 10 }, // Dynamic shape 1