diff --git a/src/plugins/intel_cpu/src/nodes/reshape.cpp b/src/plugins/intel_cpu/src/nodes/reshape.cpp
index f9cb43ab7a6..2398140996a 100644
--- a/src/plugins/intel_cpu/src/nodes/reshape.cpp
+++ b/src/plugins/intel_cpu/src/nodes/reshape.cpp
@@ -3,6 +3,7 @@
 //
 
 #include "reshape.h"
+#include "utils.hpp"
 #include
 #include
 #include
@@ -10,6 +11,7 @@
 #include
 #include
 #include
+#include "utils/shape_inference/shape_inference_cpu.hpp"
 
 #include "common/cpu_memcpy.h"
 
@@ -34,8 +36,193 @@ bool Reshape::isSupportedOperation(const std::shared_ptr<const ov::Node>& op
     return true;
 }
 
+namespace {
+class ReshapeShapeInfer : public ShapeInferEmptyPads {
+public:
+    ReshapeShapeInfer(bool specialZero) : m_specialZero(specialZero) {}
+    Result infer(const std::vector<std::reference_wrapper<const VectorDims>>& input_shapes,
+                 const std::unordered_map<size_t, MemoryPtr>& data_dependency) override {
+        static constexpr size_t RESHAPE_SRC = 0, RESHAPE_PATTERN = 1;
+        const auto& inputShape = input_shapes[RESHAPE_SRC].get();
+        const size_t inputShapeSize = inputShape.size();
+        const auto memPtr = data_dependency.at(RESHAPE_PATTERN);
+        const auto data = memPtr->GetPtr();
+        const auto& dims = memPtr->getStaticDims();
+        const auto outputPatternSize = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<Dim>());
+        std::vector<int64_t> outPattern = ov::get_raw_data_as<int64_t>(
+            InferenceEngine::details::convertPrecision(memPtr->getDesc().getPrecision()),
+            data,
+            outputPatternSize,
+            ov::util::Cast<int64_t>());
+        VectorDims outputShape(outputPatternSize);
+        size_t outputProduct(1);
+        int32_t minusOneIdx = -1;
+        int32_t minusOneCount = 0;
+        for (size_t i = 0; i < outputPatternSize; ++i) {
+            if (outPattern[i] == 0 && m_specialZero && i < inputShapeSize) {
+                outputShape[i] = inputShape[i];
+                outputProduct *= outputShape[i];
+            } else if (outPattern[i] == -1) {
+                minusOneIdx = i;
+                minusOneCount++;
+            } else {
+                outputShape[i] = outPattern[i];
+                outputProduct *= outputShape[i];
+            }
+        }
+        size_t inputProduct(1);
+        for (size_t i = 0; i < inputShapeSize; ++i) {
+            inputProduct *= inputShape[i];
+        }
+        if (outputProduct != 0 && minusOneIdx >= 0) {
+            outputShape[minusOneIdx] = inputProduct / outputProduct;
+            outputProduct *= outputShape[minusOneIdx];
+        }
+        if (minusOneCount > 1 || inputProduct != outputProduct) {
+            IE_THROW(Unexpected) << "[cpu]reshape: the shape of input data conflicts with the reshape pattern";
+        }
+        return {{std::move(outputShape)}, ShapeInferStatus::success};
+    }
+    port_mask_t get_port_mask() const override {
+        return PortMask(1);
+    }
+
+private:
+    bool m_specialZero;
+};
+
+class SqueezeShapeInfer : public ShapeInferEmptyPads {
+public:
+    SqueezeShapeInfer() {}
+    Result infer(const std::vector<std::reference_wrapper<const VectorDims>>& input_shapes,
+                 const std::unordered_map<size_t, MemoryPtr>& data_dependency) override {
+        static constexpr size_t SQUEEZE_SRC = 0, SQUEEZE_PATTERN = 1;
+        const auto& inputShape = input_shapes[SQUEEZE_SRC].get();
+        const size_t inputShapeSize = inputShape.size();
+        auto itr = data_dependency.find(SQUEEZE_PATTERN);
+        VectorDims outputShape;
+        if (itr != data_dependency.end()) {
+            const auto memPtr = data_dependency.at(SQUEEZE_PATTERN);
+            const auto data = memPtr->GetPtr();
+            const auto& dims = memPtr->getStaticDims();
+            const auto outputPatternSize = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<Dim>());
+            std::vector<int64_t> outPattern = ov::get_raw_data_as<int64_t>(
+                InferenceEngine::details::convertPrecision(memPtr->getDesc().getPrecision()),
+                data,
+                outputPatternSize,
+                ov::util::Cast<int64_t>());
+            std::vector<bool> removeMask(inputShapeSize, false);
+            bool existError = false;
+            for (size_t i = 0; i < outputPatternSize; i++) {
+                if (outPattern[i] < 0) {
+                    outPattern[i] = inputShapeSize + outPattern[i];
+                }
+                if (outPattern[i] >= 0 && outPattern[i] < static_cast<int64_t>(inputShapeSize)) {
+                    removeMask[outPattern[i]] = true;
+                } else {
+                    existError = true;
+                    break;
+                }
+            }
+            for (size_t i = 0; i < inputShapeSize; i++) {
+                if (!removeMask[i]) {
+                    outputShape.push_back(inputShape[i]);
+                } else if (inputShape[i] != 1) {
+                    existError = true;
+                    break;
+                }
+            }
+            if (existError) {
+                IE_THROW(Unexpected) << "[cpu]squeeze: the shape of input data conflicts with the squeeze pattern";
+            }
+        } else {
+            for (size_t i = 0; i < inputShapeSize; i++) {
+                if (inputShape[i] != 1) {
+                    outputShape.push_back(inputShape[i]);
+                }
+            }
+        }
+        return {{std::move(outputShape)}, ShapeInferStatus::success};
+    }
+    port_mask_t get_port_mask() const override {
+        return PortMask(1);
+    }
+};
+
+class UnsqueezeShapeInfer : public ShapeInferEmptyPads {
+public:
+    UnsqueezeShapeInfer() {}
+    Result infer(const std::vector<std::reference_wrapper<const VectorDims>>& input_shapes,
+                 const std::unordered_map<size_t, MemoryPtr>& data_dependency) override {
+        static constexpr size_t UNSQUEEZE_SRC = 0, UNSQUEEZE_PATTERN = 1;
+        const auto& inputShape = input_shapes[UNSQUEEZE_SRC].get();
+        const size_t inputShapeSize = inputShape.size();
+        const auto memPtr = data_dependency.at(UNSQUEEZE_PATTERN);
+        const auto data = memPtr->GetPtr();
+        const auto& dims = memPtr->getStaticDims();
+        const auto outputPatternSize = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<Dim>());
+        std::vector<int64_t> outPattern = ov::get_raw_data_as<int64_t>(
+            InferenceEngine::details::convertPrecision(memPtr->getDesc().getPrecision()),
+            data,
+            outputPatternSize,
+            ov::util::Cast<int64_t>());
+        size_t outputShapeSize = inputShapeSize + outputPatternSize;
+        VectorDims outputShape(outputShapeSize, 0);
+        bool existError = false;
+        for (size_t i = 0; i < outputPatternSize; i++) {
+            if (outPattern[i] < 0) {
+                outPattern[i] = outputShapeSize + outPattern[i];
+            }
+            if (outPattern[i] >= 0 && outPattern[i] < static_cast<int64_t>(outputShapeSize)) {
+                outputShape[outPattern[i]] = 1;
+            } else {
+                existError = true;
+                break;
+            }
+        }
+        for (size_t i = 0, y = 0; i < outputShapeSize; i++) {
+            if (outputShape[i] == 0) {
+                if (y < inputShapeSize) {
+                    outputShape[i] = inputShape[y];
+                    y++;
+                } else {
+                    existError = true;
+                    break;
+                }
+            }
+        }
+        if (existError) {
+            IE_THROW(Unexpected) << "[cpu]unsqueeze: the shape of input data conflicts with the unsqueeze pattern";
+        }
+        return {{std::move(outputShape)}, ShapeInferStatus::success};
+    }
+    port_mask_t get_port_mask() const override {
+        return PortMask(1);
+    }
+};
+
+class ReshapeShapeInferFactory : public ShapeInferFactory {
+public:
+    ReshapeShapeInferFactory(std::shared_ptr<ov::Node> op) : m_op(op) {}
+    ShapeInferPtr makeShapeInfer() const override {
+        if (const auto reshapeOp = ov::as_type_ptr<const ov::op::v1::Reshape>(m_op)) {
+            return std::make_shared<ReshapeShapeInfer>(reshapeOp->get_special_zero());
+        } else if (ov::is_type<ov::op::v0::Squeeze>(m_op)) {
+            return std::make_shared<SqueezeShapeInfer>();
+        } else if (ov::is_type<ov::op::v0::Unsqueeze>(m_op)) {
+            return std::make_shared<UnsqueezeShapeInfer>();
+        } else {
+            IE_THROW(Unexpected) << "[cpu]reshape: " << m_op->get_type_name() << " is not implemented";
+        }
+    }
+
+private:
+    std::shared_ptr<ov::Node> m_op;
+};
+} // namespace
+
 Reshape::Reshape(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr context) :
-        Node(op, context, NgraphShapeInferFactory(op, PortMask(1))) {
+        Node(op, context, ReshapeShapeInferFactory(op)) {
     std::string errorMessage;
     if (!isSupportedOperation(op, errorMessage)) {
         IE_THROW(NotImplemented) << errorMessage;
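(A minimal standalone sketch, not part of the patch: the pattern-resolution rule that ReshapeShapeInfer implements above can be exercised in isolation. All names below are hypothetical. A `0` entry copies the corresponding input dimension when specialZero is set, and a single `-1` absorbs whatever element count remains.)

#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <stdexcept>
#include <vector>

// Resolve a Reshape output pattern against a static input shape.
std::vector<size_t> resolveReshapePattern(const std::vector<size_t>& in,
                                          const std::vector<int64_t>& pattern,
                                          bool specialZero) {
    std::vector<size_t> out(pattern.size(), 0);
    size_t outProduct = 1;
    int64_t minusOneIdx = -1;
    for (size_t i = 0; i < pattern.size(); ++i) {
        if (pattern[i] == 0 && specialZero && i < in.size()) {
            out[i] = in[i];  // '0' keeps the corresponding input dimension
        } else if (pattern[i] == -1) {
            if (minusOneIdx >= 0)
                throw std::runtime_error("at most one -1 is allowed");
            minusOneIdx = static_cast<int64_t>(i);
            continue;  // filled in below, once the rest of the product is known
        } else {
            out[i] = static_cast<size_t>(pattern[i]);
        }
        outProduct *= out[i];
    }
    const size_t inProduct =
        std::accumulate(in.begin(), in.end(), static_cast<size_t>(1), std::multiplies<size_t>());
    if (minusOneIdx >= 0) {
        if (outProduct == 0 || inProduct % outProduct != 0)
            throw std::runtime_error("pattern conflicts with the input shape");
        out[minusOneIdx] = inProduct / outProduct;  // '-1' absorbs the remainder
    } else if (inProduct != outProduct) {
        throw std::runtime_error("pattern conflicts with the input shape");
    }
    return out;
}

int main() {
    // {2, 5, 7, 3} reshaped with pattern {0, -1, 3}: '0' keeps 2, '-1' resolves to 35.
    for (size_t d : resolveReshapePattern({2, 5, 7, 3}, {0, -1, 3}, true))
        std::cout << d << ' ';  // prints: 2 35 3
    std::cout << '\n';
    return 0;
}

(Like the plugin code, the sketch rejects a pattern whose fixed element count cannot reproduce the input's element count; the plugin detects this via the final product comparison, the sketch via a divisibility check.)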
diff --git a/src/plugins/intel_cpu/tests/functional/single_layer_tests/shape_ops.cpp b/src/plugins/intel_cpu/tests/functional/single_layer_tests/shape_ops.cpp
index 8faeb80f634..775579fd0ed 100644
--- a/src/plugins/intel_cpu/tests/functional/single_layer_tests/shape_ops.cpp
+++ b/src/plugins/intel_cpu/tests/functional/single_layer_tests/shape_ops.cpp
@@ -45,6 +45,7 @@ using shapeOpsParams = std::tuple<
         ngraph::helpers::InputLayerType,  // second input type
         shapeNodeType,                    // node type
         Precision,                        // precision
+        ngraph::element::Type_t,          // second input precision
         bool>;                            // special zero
 
 class ShapeOpsCPUTest : public testing::WithParamInterface<shapeOpsParams>, virtual public SubgraphBaseTest, public CPUTestsBase {
@@ -55,7 +56,8 @@ public:
         inputDescription inpDesc;
         ngraph::helpers::InputLayerType secondType;
         shapeNodeType nodeType;
         Precision prc;
         bool specialZero;
-        std::tie(inpDesc, secondType, nodeType, prc, specialZero) = obj.param;
+        element::Type_t tmpSecondInPrc;
+        std::tie(inpDesc, secondType, nodeType, prc, tmpSecondInPrc, specialZero) = obj.param;
 
         std::ostringstream result;
         result << nodeType << "_";
@@ -72,6 +74,7 @@ public:
         }
         result << "PRC=" << prc << "_";
         result << "specialZero=" << specialZero;
+        result << "_secondInPrc=" << tmpSecondInPrc;
 
         return result.str();
     }
@@ -84,10 +87,21 @@ protected:
             const auto& funcInput = funcInputs[i];
             ov::runtime::Tensor tensor;
             if (i == 1) {
-                tensor = ov::runtime::Tensor{ov::element::i32, targetInputStaticShapes[i]};
-                auto inputData = tensor.data<ov::element_type_traits<ov::element::i32>::value_type>();
-                for (size_t j = 0lu; j < data[idx].size(); ++j) {
-                    inputData[j] = data[idx][j];
+#define RESHAPE_TEST_CASE(INT_TYPE)                                                                 \
+    case ov::element::Type_t::INT_TYPE: {                                                           \
+        tensor = ov::runtime::Tensor{ov::element::INT_TYPE, targetInputStaticShapes[i]};            \
+        auto inputData = tensor.data<ov::element_type_traits<ov::element::INT_TYPE>::value_type>(); \
+        for (size_t j = 0lu; j < data[idx].size(); ++j) {                                           \
+            inputData[j] = data[idx][j];                                                            \
+        }                                                                                           \
+        break;                                                                                      \
+    }
+                switch (secondInPrc) {
+                    RESHAPE_TEST_CASE(i64)
+                    RESHAPE_TEST_CASE(i32)
+                    default:
+                        FAIL() << "We shouldn't get here.";
+#undef RESHAPE_TEST_CASE
                 }
             } else {
                 if (funcInput.get_element_type().is_real()) {
@@ -110,7 +124,7 @@ protected:
         shapeNodeType nodeType;
         Precision prc;
         bool specialZero;
-        std::tie(inpDesc, secondType, nodeType, prc, specialZero) = this->GetParam();
+        std::tie(inpDesc, secondType, nodeType, prc, secondInPrc, specialZero) = this->GetParam();
 
         selectedType = std::string("unknown_") + prc.name();
 
@@ -123,7 +137,6 @@ protected:
         init_input_shapes(inputShapes);
 
         auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(prc);
-        const auto secondInPrc = ngraph::element::Type_t::i32;
         auto inputs = ngraph::builder::makeDynamicParams(ngPrc, {inputDynamicShapes.front()});
         auto dataInput = inputs.front();
         dataInput->set_friendly_name("param_1");
@@ -158,6 +171,7 @@ protected:
 private:
     std::vector<std::vector<int>> data;
     size_t idx;
+    element::Type_t secondInPrc;
 };
 
 TEST_P(ShapeOpsCPUTest, CompareWithRefs) {
@@ -166,6 +180,7 @@ TEST_P(ShapeOpsCPUTest, CompareWithRefs) {
 }
 
 namespace reshapeTest {
+const std::vector<ov::element::Type_t> secondInPrcs{ov::element::Type_t::i64, ov::element::Type_t::i32};
 
 inputDescription noBounds{{{-1, -1, -1, -1},
                           {ngraph::Shape{2, 5, 7, 3}, ngraph::Shape{10, 6, 10, 5}, ngraph::Shape{10, 6, 10, 5}, ngraph::Shape{1, 2, 5, 5}}},
@@ -175,6 +190,7 @@ const auto params = ::testing::Combine(::testing::Values(noBounds),
                                        ::testing::Values(ngraph::helpers::InputLayerType::PARAMETER),
                                        ::testing::Values(shapeNodeType::Reshape),
                                        ::testing::Values(Precision::FP32),
+                                       ::testing::ValuesIn(secondInPrcs),
                                        ::testing::Values(true));
 
 INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_dynamic, ShapeOpsCPUTest, params, ShapeOpsCPUTest::getTestCaseName);
@@ -187,6 +203,7 @@ const auto params_const = ::testing::Combine(::testing::Values(noBounds_const),
                                              ::testing::Values(ngraph::helpers::InputLayerType::CONSTANT),
                                              ::testing::Values(shapeNodeType::Reshape),
                                              ::testing::Values(Precision::FP32),
+                                             ::testing::ValuesIn(secondInPrcs),
                                              ::testing::Values(true));
 
 INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_dynamic_const, ShapeOpsCPUTest, params_const, ShapeOpsCPUTest::getTestCaseName);
@@ -199,6 +216,7 @@ const auto params_dynBatch = ::testing::Combine(::testing::Values(shape_dynBatch),
                                                 ::testing::Values(ngraph::helpers::InputLayerType::CONSTANT),
                                                 ::testing::Values(shapeNodeType::Reshape),
                                                 ::testing::Values(Precision::FP32),
+                                                ::testing::ValuesIn(secondInPrcs),
                                                 ::testing::Values(true));
 
 INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_dynBatch, ShapeOpsCPUTest, params_dynBatch, ShapeOpsCPUTest::getTestCaseName);
@@ -206,7 +224,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_dynBatch, ShapeOpsCPUTest, params_dynBatch, ShapeOpsCPUTest::getTestCaseName);
 } // namespace reshapeTest
 
 namespace squeezeTest {
-
+const std::vector<ov::element::Type_t> secondInPrcs{ov::element::Type_t::i64, ov::element::Type_t::i32};
 inputDescription noBounds{{{-1, -1, -1, -1, -1, -1},
                           {
                               ngraph::Shape{2, 5, 1, 7, 3, 1},
@@ -220,6 +238,7 @@ const auto params = ::testing::Combine(::testing::Values(noBounds),
                                        ::testing::Values(ngraph::helpers::InputLayerType::PARAMETER),
                                        ::testing::Values(shapeNodeType::Squeeze),
                                        ::testing::Values(Precision::FP32),
+                                       ::testing::ValuesIn(secondInPrcs),
                                        ::testing::Values(true));
 
 // at this momemnt squeze produce dynamic output rank, if second input is not constant
@@ -234,6 +253,7 @@ const auto params_const = ::testing::Combine(::testing::Values(noBounds_const),
                                              ::testing::Values(ngraph::helpers::InputLayerType::CONSTANT),
                                              ::testing::Values(shapeNodeType::Squeeze),
                                              ::testing::Values(Precision::FP32),
+                                             ::testing::ValuesIn(secondInPrcs),
                                              ::testing::Values(true));
 
 INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_dynamic_const, ShapeOpsCPUTest, params_const, ShapeOpsCPUTest::getTestCaseName);
@@ -241,7 +261,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_dynamic_const, ShapeOpsCPUTest, params_const, ShapeOpsCPUTest::getTestCaseName);
 } // namespace squeezeTest
 
 namespace unsqueezeTest {
-
+const std::vector<ov::element::Type_t> secondInPrcs{ov::element::Type_t::i64, ov::element::Type_t::i32};
 inputDescription noBounds{{{-1, -1, -1, -1},
                           {ngraph::Shape{2, 5, 7, 3}, ngraph::Shape{10, 6, 10, 5}, ngraph::Shape{10, 6, 10, 5}, ngraph::Shape{5, 1, 5}}},
                           {std::vector<int>{2, 5}, std::vector<int>{1, 2}, std::vector<int>{4, 5}, std::vector<int>{0, 1}}};
@@ -250,6 +270,7 @@ const auto params = ::testing::Combine(::testing::Values(noBounds),
                                        ::testing::Values(ngraph::helpers::InputLayerType::PARAMETER),
                                        ::testing::Values(shapeNodeType::Unsqueeze),
                                        ::testing::Values(Precision::FP32),
+                                       ::testing::ValuesIn(secondInPrcs),
                                        ::testing::Values(true));
 
 // at this momemnt unsqueze produce dynamic output rank, if second input is not constant
@@ -264,6 +285,7 @@ const auto params_const = ::testing::Combine(::testing::Values(noBounds_const),
                                              ::testing::Values(ngraph::helpers::InputLayerType::CONSTANT),
                                              ::testing::Values(shapeNodeType::Unsqueeze),
                                              ::testing::Values(Precision::FP32),
+                                             ::testing::ValuesIn(secondInPrcs),
                                              ::testing::Values(true));
 
 INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_dynamic_const, ShapeOpsCPUTest, params_const, ShapeOpsCPUTest::getTestCaseName);
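(A toy sketch with assumed names, not the test suite itself: the #define/switch idiom used in generate_inputs above stamps out one typed case body per element type of the second input, so the same fill loop serves both i32 and i64 without duplication.)

#include <cstdint>
#include <iostream>
#include <vector>

enum class Prc { i32, i64 };

int main() {
    const std::vector<int> pattern{0, -1, 3};  // values to copy into the typed buffer
    const Prc secondInPrc = Prc::i64;

    std::vector<int32_t> buf32;
    std::vector<int64_t> buf64;

// One case per precision; the macro body widens each value to the buffer's element type.
#define FILL_CASE(PRC, BUF)   \
    case Prc::PRC:            \
        for (int v : pattern) \
            BUF.push_back(v); \
        break;

    switch (secondInPrc) {
        FILL_CASE(i32, buf32)
        FILL_CASE(i64, buf64)
    }
#undef FILL_CASE

    std::cout << "filled " << (secondInPrc == Prc::i64 ? buf64.size() : buf32.size())
              << " values\n";  // prints: filled 3 values
    return 0;
}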