[CPU] FC shape infer with a CPU shapeInfer object (#15092)

* with a shapeInfer object

* more efficient vector creation
This commit is contained in:
Chenhu Wang 2023-02-01 22:35:15 +08:00 committed by GitHub
parent 2d9a213ed6
commit 8fc4c2c6e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -83,6 +83,53 @@ bool FCKey::operator==(const FCKey &rhs) const {
return retVal;
}
/**
 * Shape inference for the FullyConnected node.
 *
 * Builds the output dims from the activation dims (input 0) and the weight
 * dims (input 1). The output has `out_rank` entries, all initialised to 1;
 * the last entry is set to weightShape[0] (Co) and the leading "batch" dims
 * of the activation are copied into the slots directly in front of it.
 */
class FCShapeInfer : public ShapeInferEmptyPads {
public:
    /**
     * @param outPut_rank number of dims the inferred output shape must have.
     *
     * Marked explicit so that a bare rank value is never implicitly converted
     * into a shape-infer object.
     */
    explicit FCShapeInfer(size_t outPut_rank) : out_rank(outPut_rank) {}

    /**
     * @param input_shapes    input_shapes[0] is read as the activation dims,
     *                        input_shapes[1] as the weight dims.
     * @param data_dependency not read by this implementation.
     * @return a vector holding exactly one shape: the output dims.
     *
     * NOTE(review): nothing is validated here. The code assumes that
     * input_shapes has at least two entries, that the weight shape is
     * non-empty, that activationRank >= weightRank - 1 and that
     * out_rank >= batchRank + 1; otherwise the size_t arithmetic below wraps
     * around. Confirm that callers guarantee these preconditions.
     */
    std::vector<VectorDims> infer(
        const std::vector<std::reference_wrapper<const VectorDims>>& input_shapes,
        const std::unordered_map<size_t, MemoryPtr>& data_dependency) override {
        const VectorDims& activationShape = input_shapes[0].get();
        const VectorDims& weightShape = input_shapes[1].get();
        const size_t activationRank = activationShape.size();
        // Every weight dim except the leading Co one.
        const size_t channelRank = weightShape.size() - 1;

        // activation   weight    output_shape
        // NCHW         CoCHW     NCo
        // TNC          CoC       TNCo
        // NC           CoC       NCo
        VectorDims outputShape(out_rank, 1);

        // set Co
        outputShape.back() = weightShape[0];

        // set batch dims: the leading activation dims not covered by the
        // weight's channel dims are written right in front of Co; any output
        // dims before startIdx keep their initial value of 1.
        const size_t batchRank = activationRank - channelRank;
        const size_t startIdx = out_rank - batchRank - 1;
        for (size_t i = 0; i < batchRank; i++) {
            outputShape[i + startIdx] = activationShape[i];
        }

        return {outputShape};
    }

    /// Always reports EMPTY_PORT_MASK.
    port_mask_t get_port_mask() const override {
        return EMPTY_PORT_MASK;
    }

private:
    // Rank of the output shape produced by infer().
    size_t out_rank = 0;
};
/**
 * Factory that creates FCShapeInfer objects for a given operation node.
 *
 * The node is stored at construction time; each makeShapeInfer() call reads
 * the rank of the node's output 0 and passes it to a new FCShapeInfer.
 */
class FCShapeInferFactory : public ShapeInferFactory {
public:
    /**
     * @param op operation node whose output 0 rank is queried later.
     *
     * Marked explicit so that a node pointer is never implicitly converted
     * into a factory.
     */
    explicit FCShapeInferFactory(std::shared_ptr<ov::Node> op) : m_op(op) {}

    /**
     * @return a new FCShapeInfer configured with the rank of output 0.
     *
     * NOTE(review): presumably the output rank must be static for
     * get_length() to be meaningful - confirm against the partial-shape API.
     */
    ShapeInferPtr makeShapeInfer() const override {
        const auto outputRank = m_op->get_output_partial_shape(0).rank().get_length();
        // Make the conversion to FCShapeInfer's size_t argument explicit.
        return std::make_shared<FCShapeInfer>(static_cast<size_t>(outputRank));
    }

private:
    // Node whose output 0 partial shape supplies the rank.
    std::shared_ptr<const ngraph::Node> m_op;
};
} // namespace
bool FullyConnected::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept {
@ -114,7 +161,7 @@ bool FullyConnected::isSupportedOperation(const std::shared_ptr<const ngraph::No
}
FullyConnected::FullyConnected(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context)
: Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)), withBiases(false) {
: Node(op, context, FCShapeInferFactory(op)), withBiases(false) {
std::string errorMessage;
if (isSupportedOperation(op, errorMessage)) {
errorPrefix = "FullyConnected node with name '" + getName() + "'";