From 8fc4c2c6e1c1d3e354245f31ee6186e22b27ae31 Mon Sep 17 00:00:00 2001 From: Chenhu Wang Date: Wed, 1 Feb 2023 22:35:15 +0800 Subject: [PATCH] [CPU]FC shape infer with an CPU shapeInfer object (#15092) * with an shapeInfer object * more efficient vector creation --- .../intel_cpu/src/nodes/fullyconnected.cpp | 49 ++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/src/plugins/intel_cpu/src/nodes/fullyconnected.cpp b/src/plugins/intel_cpu/src/nodes/fullyconnected.cpp index 4e863dcbcb2..a4675ee3969 100644 --- a/src/plugins/intel_cpu/src/nodes/fullyconnected.cpp +++ b/src/plugins/intel_cpu/src/nodes/fullyconnected.cpp @@ -83,6 +83,53 @@ bool FCKey::operator==(const FCKey &rhs) const { return retVal; } +class FCShapeInfer : public ShapeInferEmptyPads { +public: + FCShapeInfer(size_t outPut_rank) : out_rank(outPut_rank) {} + std::vector infer( + const std::vector>& input_shapes, + const std::unordered_map& data_dependency) override { + const VectorDims& activationShape = input_shapes[0].get(); + const VectorDims& weightShape = input_shapes[1].get(); + size_t activationRank = activationShape.size(); + size_t channelRank = weightShape.size() - 1; + + // activation weight output_shape + // NCHW CoCHW NCo + // TNC CoC TNCo + // NC CoC NCo + VectorDims outputShape(out_rank, 1); + // set Co + outputShape.back() = weightShape[0]; + // set batch dims + size_t batchRank = activationRank - channelRank; + size_t startIdx = out_rank - batchRank - 1; + for (size_t i = 0; i < batchRank; i++) { + outputShape[i + startIdx] = activationShape[i]; + } + + return {outputShape}; + } + + port_mask_t get_port_mask() const override { + return EMPTY_PORT_MASK; + } + +private: + size_t out_rank = 0; +}; + +class FCShapeInferFactory : public ShapeInferFactory { +public: + FCShapeInferFactory(std::shared_ptr op) : m_op(op) {} + ShapeInferPtr makeShapeInfer() const override { + return std::make_shared(m_op->get_output_partial_shape(0).rank().get_length()); + } + +private: + std::shared_ptr m_op; +}; + } // namespace bool FullyConnected::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { @@ -114,7 +161,7 @@ bool FullyConnected::isSupportedOperation(const std::shared_ptr& op, const GraphContext::CPtr context) - : Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)), withBiases(false) { + : Node(op, context, FCShapeInferFactory(op)), withBiases(false) { std::string errorMessage; if (isSupportedOperation(op, errorMessage)) { errorPrefix = "FullyConnected node with name '" + getName() + "'";