diff --git a/src/plugins/intel_cpu/src/nodes/rnn.cpp b/src/plugins/intel_cpu/src/nodes/rnn.cpp index 6dda7ed6784..7b7f9980c2e 100644 --- a/src/plugins/intel_cpu/src/nodes/rnn.cpp +++ b/src/plugins/intel_cpu/src/nodes/rnn.cpp @@ -479,26 +479,30 @@ void RNN::initCell() { else DC = getInputShapeAtPort(2).getDims()[1]; - // Expected shapes. - const Shape shapeD{{N.minVal, DC}, {N.maxVal, DC}}, shapeS{{N.minVal, SC}, {N.maxVal, SC}}; + if (N.isStatic()) { + // Expected shapes. + const auto B = N.minVal; + const Shape shapeD{B, DC}, shapeS{B, SC}; - if ((getInputShapeAtPort(0).isStatic() && getInputShapeAtPort(0) != shapeD) || - (getInputShapeAtPort(1).isStatic() && getInputShapeAtPort(1) != shapeS) || - (getOutputShapeAtPort(0) != shapeS)) { - THROW_ERROR << "has incorrect input/output shapes. Data shape: " << getInputShapeAtPort(0).toString() << - "; Hidden state input: " << getInputShapeAtPort(1).toString() << "; Hidden state output: " << getOutputShapeAtPort(0).toString(); - } + if ((getInputShapeAtPort(0).isStatic() && getInputShapeAtPort(0) != shapeD) || + (getInputShapeAtPort(1).isStatic() && getInputShapeAtPort(1) != shapeS) || + (getOutputShapeAtPort(0).isStatic() && getOutputShapeAtPort(0) != shapeS)) { + THROW_ERROR << "has incorrect input/output shapes. Data shape: " << getInputShapeAtPort(0).toString() << + "; Hidden state input: " << getInputShapeAtPort(1).toString() << "; Hidden state output: " << getOutputShapeAtPort(0).toString(); + } - if (S == 2) { - if ((getInputShapeAtPort(2).isStatic() && getInputShapeAtPort(2) != shapeS) || (getOutputShapeAtPort(1) != shapeS)) - THROW_ERROR << "has incorrect input/output shapes. Cell state input: " << getInputShapeAtPort(2).toString() << - "; Cell state output: " << getOutputShapeAtPort(1).toString(); - } + if (S == 2) { + if ((getInputShapeAtPort(2).isStatic() && getInputShapeAtPort(2) != shapeS) || + (getOutputShapeAtPort(1).isStatic() && getOutputShapeAtPort(1) != shapeS)) + THROW_ERROR << "has incorrect input/output shapes. Cell state input: " << getInputShapeAtPort(2).toString() << + "; Cell state output: " << getOutputShapeAtPort(1).toString(); + } - if (is_augru) { - const Shape shapeA{{N.minVal, 1}, {N.maxVal, 1}}; - if (getInputShapeAtPort(5).isStatic() && getInputShapeAtPort(5) != shapeA) { - THROW_ERROR << "has incorrect input shapes. Attention shape: " << getInputShapeAtPort(5).toString(); + if (is_augru) { + const Shape shapeA{B, 1}; + if (getInputShapeAtPort(5).isStatic() && getInputShapeAtPort(5) != shapeA) { + THROW_ERROR << "has incorrect input shapes. Attention shape: " << getInputShapeAtPort(5).toString(); + } } } } diff --git a/src/plugins/intel_cpu/tests/functional/single_layer_tests/lstm_cell.cpp b/src/plugins/intel_cpu/tests/functional/single_layer_tests/lstm_cell.cpp index 239eb822a33..e18d2b5f6e3 100644 --- a/src/plugins/intel_cpu/tests/functional/single_layer_tests/lstm_cell.cpp +++ b/src/plugins/intel_cpu/tests/functional/single_layer_tests/lstm_cell.cpp @@ -153,6 +153,12 @@ const std::vector> dynamicShapes = { { {1, 1}, {3, 1}, {5, 1} } }, // Target shapes { { -1, 1 }, // Dynamic shape 2 { {1, 1}, {3, 1}, {5, 1} } } }, // Target shapes + { { { -1, 1 }, // Dynamic shape 0 + { {1, 1}, {5, 1} } }, // Target shapes + { { {1, 5}, 1 }, // Dynamic shape 1 + { {1, 1}, {5, 1} } }, // Target shapes + { { {1, 5}, 1 }, // Dynamic shape 2 + { {1, 1}, {5, 1} } } }, // Target shapes { { { {1, 20}, 30 }, // Dynamic shape 0 { {2, 30}, {5, 30}, {8, 30} } }, // Target shapes { { {1, 20}, 10 }, // Dynamic shape 1