diff --git a/src/plugins/intel_cpu/src/nodes/matmul.cpp b/src/plugins/intel_cpu/src/nodes/matmul.cpp
index 55ecf7b9134..dcf9ea3d6c0 100644
--- a/src/plugins/intel_cpu/src/nodes/matmul.cpp
+++ b/src/plugins/intel_cpu/src/nodes/matmul.cpp
@@ -110,8 +110,80 @@ bool MatMul::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op,
     return true;
 }
 
+class MMShapeInfer : public ShapeInferEmptyPads {
+public:
+    MMShapeInfer(const size_t& out_rank, const bool& transpose_a, const bool& transpose_b) :
+        m_out_rank(out_rank), m_transpose_a(transpose_a), m_transpose_b(transpose_b) {
+        m_shapeY = VectorDims(m_out_rank, 1); // reused as both the output shape and its cache
+    }
+    std::vector<VectorDims> infer(
+        const std::vector<std::reference_wrapper<const VectorDims>>& input_shapes,
+        const std::unordered_map<size_t, MemoryPtr>& data_dependency) override {
+        const VectorDims& shapeA = input_shapes[0].get();
+        const VectorDims& shapeB = input_shapes[1].get();
+        const size_t rankA = shapeA.size();
+        const size_t rankB = shapeB.size();
+
+        // getSupportedDescriptors() has already validated the shapes, so here:
+        // 1. No need to assert on scalars: matmul_shape_inference has checked them.
+        // 2. No need to re-check the compatibility of the last two dims.
+        // 3. The 1-D x 1-D case still has to be handled.
+        // 4. The transpose flags still have to be applied.
+        // 5. Only inputs of the same rank are supported.
+        // 6. The broadcast check is therefore reduced to a per-dimension comparison.
+        if (rankA == 1 && rankB == 1 && shapeA[0] == shapeB[0]) {
+            return {m_shapeY};
+        }
+
+        m_shapeY[m_out_rank-2] = m_transpose_a ? shapeA[rankA-1] : shapeA[rankA-2];
+        m_shapeY[m_out_rank-1] = m_transpose_b ? shapeB[rankB-2] : shapeB[rankB-1];
+
+        for (size_t i=0; i < m_out_rank-2; ++i) {
+            if (shapeA[i] != shapeB[i]) {
+                if (shapeB[i] == 1) {
+                    m_shapeY[i] = shapeA[i];
+                    continue;
+                } else if (shapeA[i] != 1) {
+                    IE_THROW() << "Incompatible MatMul batch dimension. Can't merge the first input dimension=" <<
+                        shapeA[i] << " with second input dimension=" << shapeB[i] << " at index=" << i;
+                }
+            }
+            m_shapeY[i] = shapeB[i];
+        }
+
+        return {m_shapeY};
+    }
+
+    port_mask_t get_port_mask() const override {
+        return EMPTY_PORT_MASK;
+    }
+
+private:
+    VectorDims m_shapeY;
+    const size_t m_out_rank;
+    const bool m_transpose_a;
+    const bool m_transpose_b;
+};
+
+class MMShapeInferFactory : public ShapeInferFactory {
+public:
+    MMShapeInferFactory(const std::shared_ptr<ngraph::Node>& op) : m_op(op) {}
+    ShapeInferPtr makeShapeInfer() const override {
+        if (const auto matmul = ov::as_type_ptr<const ngraph::opset1::MatMul>(m_op)) {
+            const auto output_rank = matmul->get_output_partial_shape(0).rank().get_length();
+            const bool transpose_a = matmul->get_transpose_a();
+            const bool transpose_b = matmul->get_transpose_b();
+            return std::make_shared<MMShapeInfer>(output_rank, transpose_a, transpose_b);
+        } else {
+            IE_THROW() << "Unexpected operation type in the MatMul shape inference factory";
+        }
+    }
+private:
+    std::shared_ptr<ngraph::Node> m_op;
+};
+
 MatMul::MatMul(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context) :
-    Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)), withBiases(false) {
+    Node(op, context, MMShapeInferFactory(op)), withBiases(false) {
     std::string errorMessage;
     errorPrefix = "MatMul node with name '" + getName() + "'";
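
For anyone who wants to exercise the broadcast rule in isolation, below is a minimal standalone sketch of the batch-dimension logic MMShapeInfer implements. It is not OpenVINO code: Dims and infer_matmul_shape are made-up names, plain std::vector<size_t> stands in for VectorDims, and std::runtime_error stands in for IE_THROW(). It assumes equal-rank inputs, matching simplification 5 in the comment block above.

#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <vector>

using Dims = std::vector<size_t>;

// Computes the MatMul output shape for two equal-rank inputs, mirroring the
// simplifications in MMShapeInfer: the contraction dims are assumed already
// validated, and batch dims broadcast only against 1.
Dims infer_matmul_shape(const Dims& a, const Dims& b, bool transpose_a, bool transpose_b) {
    const size_t rank = a.size();
    Dims out(rank, 1);
    if (rank == 1) {
        // 1-D x 1-D dot product; the real code checks the dims match upstream.
        if (a[0] != b[0])
            throw std::runtime_error("incompatible 1-D dimensions");
        return out;
    }
    out[rank - 2] = transpose_a ? a[rank - 1] : a[rank - 2];  // M
    out[rank - 1] = transpose_b ? b[rank - 2] : b[rank - 1];  // N
    for (size_t i = 0; i + 2 < rank; ++i) {                   // batch dims
        if (a[i] != b[i] && a[i] != 1 && b[i] != 1)
            throw std::runtime_error("incompatible batch dimension");
        out[i] = (b[i] == 1) ? a[i] : b[i];                   // broadcast against 1
    }
    return out;
}

int main() {
    // {2, 1, 3, 4} x {1, 5, 4, 6} -> {2, 5, 3, 6}
    for (size_t d : infer_matmul_shape({2, 1, 3, 4}, {1, 5, 4, 6}, false, false))
        std::cout << d << ' ';
    std::cout << '\n';
}

Note the design choice the sketch makes visible: because ranks are guaranteed equal and broadcasting is restricted to 1, the batch loop is a single pass with no alignment or rank-padding step, which is exactly what lets the specialized shape inference beat the generic NgraphShapeInferFactory path.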