[CPU] Optimize shape_infer time of matmul (#15601)
* optimize shape infer of matmul
* remove some redundant checks
* fix some comments
* Update src/plugins/intel_cpu/src/nodes/matmul.cpp: fix an implicit bug
* Update src/plugins/intel_cpu/src/nodes/matmul.cpp: optimize by using OV RTTI instead of dynamic_pointer_cast

Co-authored-by: Maksim Kutakov <maxim.kutakov@gmail.com>
parent bd4d74d3dc
commit 36cb32a8f5
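The commit message mentions swapping std::dynamic_pointer_cast for OpenVINO's own RTTI helper. As a rough standalone illustration of the two styles (not part of the patch; the function name downcast_styles is hypothetical, and OpenVINO's public headers are assumed):

#include <memory>
#include <openvino/core/node.hpp>
#include <openvino/core/type.hpp>
#include <openvino/opsets/opset1.hpp>

// Hypothetical sketch of the downcast change referenced in the commit message.
// ov::as_type_ptr dispatches on OpenVINO's own type_info tags instead of the
// compiler's RTTI machinery, which is what the "OV RTTI" wording refers to.
void downcast_styles(const std::shared_ptr<ov::Node>& op) {
    // Before: compiler RTTI walk through the class hierarchy.
    auto via_dynamic_cast = std::dynamic_pointer_cast<const ov::opset1::MatMul>(op);
    // After: lightweight OpenVINO type_info comparison.
    auto via_ov_rtti = ov::as_type_ptr<const ov::opset1::MatMul>(op);
    (void)via_dynamic_cast;
    (void)via_ov_rtti;
}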
@@ -110,8 +110,80 @@ bool MatMul::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op,
     return true;
 }
 
+class MMShapeInfer : public ShapeInferEmptyPads {
+public:
+    MMShapeInfer(const size_t& out_rank, const bool& transpose_a, const bool& transpose_b) :
+        m_out_rank(out_rank), m_transpose_a(transpose_a), m_transpose_b(transpose_b) {
+        m_shapeY = VectorDims(m_out_rank, 1); // cached output shape, reused across calls
+    }
+    std::vector<VectorDims> infer(
+        const std::vector<std::reference_wrapper<const VectorDims>>& input_shapes,
+        const std::unordered_map<size_t, MemoryPtr>& data_dependency) override {
+        const VectorDims& shapeA = input_shapes[0].get();
+        const VectorDims& shapeB = input_shapes[1].get();
+        const size_t rankA = shapeA.size();
+        const size_t rankB = shapeB.size();
+
+        // getSupportedDescriptors has already performed the shape checks, so here:
+        // 1. No need to assert on scalar inputs; matmul_shape_inference has checked them.
+        // 2. No need to check the compatibility of the last two dims.
+        // 3. The 1-D x 1-D case still needs handling.
+        // 4. The transpose flags still need to be applied.
+        // 5. Only inputs of equal rank are supported.
+        // 6. The broadcast check is simplified.
+        if (rankA == 1 && rankB == 1 && shapeA[0] == shapeB[0]) {
+            return {m_shapeY};
+        }
+
+        m_shapeY[m_out_rank-2] = m_transpose_a ? shapeA[rankA-1] : shapeA[rankA-2];
+        m_shapeY[m_out_rank-1] = m_transpose_b ? shapeB[rankB-2] : shapeB[rankB-1];
+
+        for (size_t i = 0; i < m_out_rank-2; ++i) {
+            if (shapeA[i] != shapeB[i]) {
+                if (shapeB[i] == 1) {
+                    m_shapeY[i] = shapeA[i];
+                    continue;
+                } else if (shapeA[i] != 1) {
+                    IE_THROW() << "Incompatible MatMul batch dimension. Can't merge the first input dimension=" <<
+                                  shapeA[i] << " with second input dimension=" << shapeB[i] << " at index=" << i;
+                }
+            }
+            m_shapeY[i] = shapeB[i];
+        }
+
+        return {m_shapeY};
+    }
+
+    port_mask_t get_port_mask() const override {
+        return EMPTY_PORT_MASK;
+    }
+
+private:
+    VectorDims m_shapeY;
+    const size_t m_out_rank;
+    const bool m_transpose_a;
+    const bool m_transpose_b;
+};
+
+class MMShapeInferFactory : public ShapeInferFactory {
+public:
+    MMShapeInferFactory(const std::shared_ptr<ngraph::Node>& op) : m_op(op) {}
+    ShapeInferPtr makeShapeInfer() const override {
+        if (const auto matmul = ov::as_type_ptr<const ngraph::opset1::MatMul>(m_op)) {
+            const auto output_rank = matmul->get_output_partial_shape(0).rank().get_length();
+            const bool transpose_a = matmul->get_transpose_a();
+            const bool transpose_b = matmul->get_transpose_b();
+            return std::make_shared<MMShapeInfer>(output_rank, transpose_a, transpose_b);
+        } else {
+            IE_THROW() << "Unexpected operation type in the MatMul shape inference factory";
+        }
+    }
+private:
+    std::shared_ptr<ngraph::Node> m_op;
+};
+
 MatMul::MatMul(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context) :
-    Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)), withBiases(false) {
+    Node(op, context, MMShapeInferFactory(op)), withBiases(false) {
     std::string errorMessage;
     errorPrefix = "MatMul node with name '" + getName() + "'";
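For context on what the new infer() computes: the batch dimensions follow numpy-style broadcasting restricted to equal ranks, i.e. each pair of batch dims must either match or have one side equal to 1. A self-contained sketch of just that rule (the helper broadcast_batch is hypothetical, not code from the patch):

#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <vector>

using VectorDims = std::vector<size_t>;

// Hypothetical standalone restatement of the batch-broadcast rule used by
// MMShapeInfer: for same-rank inputs, each batch dim must match or be 1.
// Only the batch dims (all but the last two) are filled in here.
VectorDims broadcast_batch(const VectorDims& a, const VectorDims& b) {
    VectorDims out(a.size(), 1);
    for (size_t i = 0; i + 2 < a.size(); ++i) {
        if (a[i] == b[i] || b[i] == 1) {
            out[i] = a[i];
        } else if (a[i] == 1) {
            out[i] = b[i];
        } else {
            throw std::runtime_error("Incompatible MatMul batch dimension");
        }
    }
    return out;
}

int main() {
    // A: [2, 1, 3, 4], B: [1, 5, 4, 6] -> batch dims broadcast to [2, 5].
    VectorDims y = broadcast_batch({2, 1, 3, 4}, {1, 5, 4, 6});
    std::cout << y[0] << " " << y[1] << "\n";  // prints: 2 5
    return 0;
}

Any pair that fails both tests (e.g. a batch dim of 3 against 5) is rejected, which mirrors the IE_THROW branch in the patch.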