[CPU] Optimize shape_infer time of matmul (#15601)

* optimize shape infer of matmul

* remove some redundant check

* fix some comments

* fix some comments

* Update src/plugins/intel_cpu/src/nodes/matmul.cpp

fix an implicit bug

Co-authored-by: Maksim Kutakov <maxim.kutakov@gmail.com>

* Update src/plugins/intel_cpu/src/nodes/matmul.cpp

optimize by using OV RTTI instead of dynamic_pointer_cast

Co-authored-by: Maksim Kutakov <maxim.kutakov@gmail.com>

---------

Co-authored-by: Maksim Kutakov <maxim.kutakov@gmail.com>
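The "OV RTTI" change mentioned above replaces a std::dynamic_pointer_cast, which walks the C++ RTTI hierarchy, with ov::as_type_ptr, which compares OpenVINO's static type information instead. A minimal sketch of the pattern, mirroring the cast used in the diff below (the helper name is_matmul_v0 is illustrative, not part of the patch):

#include <memory>
#include <ngraph/opsets/opset1.hpp>

bool is_matmul_v0(const std::shared_ptr<ngraph::Node>& op) {
    // Before: std::dynamic_pointer_cast<const ngraph::opset1::MatMul>(op)
    // relies on C++ RTTI. After: ov::as_type_ptr checks the node's static
    // type info, which is cheaper on shape-inference hot paths.
    return ov::as_type_ptr<const ngraph::opset1::MatMul>(op) != nullptr;
}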
Xiuchuan Zhai 2023-02-15 10:07:39 +08:00 committed by GitHub
parent bd4d74d3dc
commit 36cb32a8f5

@@ -110,8 +110,80 @@ bool MatMul::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op,
return true;
}
class MMShapeInfer : public ShapeInferEmptyPads {
public:
MMShapeInfer(const size_t& out_rank, const bool& transpose_a, const bool& transpose_b) :
m_out_rank(out_rank), m_transpose_a(transpose_a), m_transpose_b(transpose_b) {
m_shapeY = VectorDims(m_out_rank, 1); // output shape, cached and reused across infer() calls
}
std::vector<VectorDims> infer(
const std::vector<std::reference_wrapper<const VectorDims>>& input_shapes,
const std::unordered_map<size_t, MemoryPtr>& data_dependency) override {
const VectorDims& shapeA = input_shapes[0].get();
const VectorDims& shapeB = input_shapes[1].get();
const size_t rankA = shapeA.size();
const size_t rankB = shapeB.size();
// getSupportedDescriptors() has already validated the shapes, so this fast path relies on:
// 1. No scalar inputs: matmul_shape_inference has already rejected them.
// 2. The last two (reduction) dims are already known to be compatible, so they are not re-checked.
// 3. The 1-D x 1-D case still needs explicit handling (it produces a scalar).
// 4. The transpose flags must be applied to the last two dims.
// 5. Only inputs with the same rank are supported here.
// 6. The batch-dim broadcast check is simplified: dims must match, or one of them must be 1.
if (rankA == 1 && rankB == 1 && shapeA[0] == shapeB[0]) {
return {m_shapeY};
}
m_shapeY[m_out_rank-2] = m_transpose_a ? shapeA[rankA-1] : shapeA[rankA-2];
m_shapeY[m_out_rank-1] = m_transpose_b ? shapeB[rankB-2] : shapeB[rankB-1];
for (size_t i=0; i < m_out_rank-2; ++i) {
if (shapeA[i] != shapeB[i]) {
if (shapeB[i] == 1) {
m_shapeY[i] = shapeA[i];
continue;
} else if (shapeA[i] != 1) {
IE_THROW() << "Incompatible MatMul batch dimension. Cant merge the first input dimension=" <<
shapeA[i] << " with second input dimension=" << shapeB[i] << " at index=" << i;
}
}
m_shapeY[i] = shapeB[i];
}
return {m_shapeY};
}
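// EMPTY_PORT_MASK: the output shape depends only on input shapes, never on input values.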
port_mask_t get_port_mask() const override {
return EMPTY_PORT_MASK;
}
private:
VectorDims m_shapeY;
const size_t m_out_rank;
const bool m_transpose_a;
const bool m_transpose_b;
};
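// A minimal, self-contained sketch of the batch-broadcast rule implemented in
// MMShapeInfer::infer() above, assuming equal input ranks as that fast path does.
// infer_matmul_shape() and main() are hypothetical illustrations, not part of this patch.
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <vector>

using Dims = std::vector<size_t>;

Dims infer_matmul_shape(const Dims& a, const Dims& b, bool transpose_a, bool transpose_b) {
    const size_t rank = a.size(); // fast path: ranks of a, b and the output all match
    Dims y(rank, 1);
    y[rank - 2] = transpose_a ? a[rank - 1] : a[rank - 2];
    y[rank - 1] = transpose_b ? b[rank - 2] : b[rank - 1];
    for (size_t i = 0; i + 2 < rank; ++i) {
        if (a[i] == b[i] || b[i] == 1) {
            y[i] = a[i]; // equal dims, or b broadcasts over a
        } else if (a[i] == 1) {
            y[i] = b[i]; // a broadcasts over b
        } else {
            throw std::runtime_error("Incompatible MatMul batch dimensions");
        }
    }
    return y;
}

int main() {
    // [2,1,3,4] x [2,5,4,6] -> [2,5,3,6]: dim 1 is broadcast from the second input.
    for (size_t d : infer_matmul_shape({2, 1, 3, 4}, {2, 5, 4, 6}, false, false))
        std::cout << d << ' ';
    std::cout << '\n';
}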
class MMShapeInferFactory : public ShapeInferFactory {
public:
MMShapeInferFactory(const std::shared_ptr<ngraph::Node>& op) : m_op(op) {}
ShapeInferPtr makeShapeInfer() const override {
if (const auto matmul = ov::as_type_ptr<const ngraph::opset1::MatMul>(m_op)) {
const auto output_rank = matmul->get_output_partial_shape(0).rank().get_length();
const bool transpose_a = matmul->get_transpose_a();
const bool transpose_b = matmul->get_transpose_b();
return std::make_shared<MMShapeInfer>(output_rank, transpose_a, transpose_b);
} else {
IE_THROW() << "Unexpected operation type in the MatMul shape inference factory";
}
}
private:
std::shared_ptr<ngraph::Node> m_op;
};
MatMul::MatMul(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context) :
-    Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)), withBiases(false) {
+    Node(op, context, MMShapeInferFactory(op)), withBiases(false) {
std::string errorMessage;
errorPrefix = "MatMul node with name '" + getName() + "'";