[CPU] RNN: Shape checks are relaxed (#17724)
This commit is contained in:
committed by
GitHub
parent
65caa9d745
commit
b655fa55a1
@@ -479,26 +479,30 @@ void RNN::initCell() {
|
||||
else
|
||||
DC = getInputShapeAtPort(2).getDims()[1];
|
||||
|
||||
// Expected shapes.
|
||||
const Shape shapeD{{N.minVal, DC}, {N.maxVal, DC}}, shapeS{{N.minVal, SC}, {N.maxVal, SC}};
|
||||
if (N.isStatic()) {
|
||||
// Expected shapes.
|
||||
const auto B = N.minVal;
|
||||
const Shape shapeD{B, DC}, shapeS{B, SC};
|
||||
|
||||
if ((getInputShapeAtPort(0).isStatic() && getInputShapeAtPort(0) != shapeD) ||
|
||||
(getInputShapeAtPort(1).isStatic() && getInputShapeAtPort(1) != shapeS) ||
|
||||
(getOutputShapeAtPort(0) != shapeS)) {
|
||||
THROW_ERROR << "has incorrect input/output shapes. Data shape: " << getInputShapeAtPort(0).toString() <<
|
||||
"; Hidden state input: " << getInputShapeAtPort(1).toString() << "; Hidden state output: " << getOutputShapeAtPort(0).toString();
|
||||
}
|
||||
if ((getInputShapeAtPort(0).isStatic() && getInputShapeAtPort(0) != shapeD) ||
|
||||
(getInputShapeAtPort(1).isStatic() && getInputShapeAtPort(1) != shapeS) ||
|
||||
(getOutputShapeAtPort(0).isStatic() && getOutputShapeAtPort(0) != shapeS)) {
|
||||
THROW_ERROR << "has incorrect input/output shapes. Data shape: " << getInputShapeAtPort(0).toString() <<
|
||||
"; Hidden state input: " << getInputShapeAtPort(1).toString() << "; Hidden state output: " << getOutputShapeAtPort(0).toString();
|
||||
}
|
||||
|
||||
if (S == 2) {
|
||||
if ((getInputShapeAtPort(2).isStatic() && getInputShapeAtPort(2) != shapeS) || (getOutputShapeAtPort(1) != shapeS))
|
||||
THROW_ERROR << "has incorrect input/output shapes. Cell state input: " << getInputShapeAtPort(2).toString() <<
|
||||
"; Cell state output: " << getOutputShapeAtPort(1).toString();
|
||||
}
|
||||
if (S == 2) {
|
||||
if ((getInputShapeAtPort(2).isStatic() && getInputShapeAtPort(2) != shapeS) ||
|
||||
(getOutputShapeAtPort(1).isStatic() && getOutputShapeAtPort(1) != shapeS))
|
||||
THROW_ERROR << "has incorrect input/output shapes. Cell state input: " << getInputShapeAtPort(2).toString() <<
|
||||
"; Cell state output: " << getOutputShapeAtPort(1).toString();
|
||||
}
|
||||
|
||||
if (is_augru) {
|
||||
const Shape shapeA{{N.minVal, 1}, {N.maxVal, 1}};
|
||||
if (getInputShapeAtPort(5).isStatic() && getInputShapeAtPort(5) != shapeA) {
|
||||
THROW_ERROR << "has incorrect input shapes. Attention shape: " << getInputShapeAtPort(5).toString();
|
||||
if (is_augru) {
|
||||
const Shape shapeA{B, 1};
|
||||
if (getInputShapeAtPort(5).isStatic() && getInputShapeAtPort(5) != shapeA) {
|
||||
THROW_ERROR << "has incorrect input shapes. Attention shape: " << getInputShapeAtPort(5).toString();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -153,6 +153,12 @@ const std::vector<std::vector<ov::test::InputShape>> dynamicShapes = {
|
||||
{ {1, 1}, {3, 1}, {5, 1} } }, // Target shapes
|
||||
{ { -1, 1 }, // Dynamic shape 2
|
||||
{ {1, 1}, {3, 1}, {5, 1} } } }, // Target shapes
|
||||
{ { { -1, 1 }, // Dynamic shape 0
|
||||
{ {1, 1}, {5, 1} } }, // Target shapes
|
||||
{ { {1, 5}, 1 }, // Dynamic shape 1
|
||||
{ {1, 1}, {5, 1} } }, // Target shapes
|
||||
{ { {1, 5}, 1 }, // Dynamic shape 2
|
||||
{ {1, 1}, {5, 1} } } }, // Target shapes
|
||||
{ { { {1, 20}, 30 }, // Dynamic shape 0
|
||||
{ {2, 30}, {5, 30}, {8, 30} } }, // Target shapes
|
||||
{ { {1, 20}, 10 }, // Dynamic shape 1
|
||||
|
||||
Reference in New Issue
Block a user