[CPU]FC shape infer with an CPU shapeInfer object (#15092)
* with an shapeInfer object * more efficient vector creation
This commit is contained in:
parent
2d9a213ed6
commit
8fc4c2c6e1
@ -83,6 +83,53 @@ bool FCKey::operator==(const FCKey &rhs) const {
|
||||
return retVal;
|
||||
}
|
||||
|
||||
class FCShapeInfer : public ShapeInferEmptyPads {
|
||||
public:
|
||||
FCShapeInfer(size_t outPut_rank) : out_rank(outPut_rank) {}
|
||||
std::vector<VectorDims> infer(
|
||||
const std::vector<std::reference_wrapper<const VectorDims>>& input_shapes,
|
||||
const std::unordered_map<size_t, MemoryPtr>& data_dependency) override {
|
||||
const VectorDims& activationShape = input_shapes[0].get();
|
||||
const VectorDims& weightShape = input_shapes[1].get();
|
||||
size_t activationRank = activationShape.size();
|
||||
size_t channelRank = weightShape.size() - 1;
|
||||
|
||||
// activation weight output_shape
|
||||
// NCHW CoCHW NCo
|
||||
// TNC CoC TNCo
|
||||
// NC CoC NCo
|
||||
VectorDims outputShape(out_rank, 1);
|
||||
// set Co
|
||||
outputShape.back() = weightShape[0];
|
||||
// set batch dims
|
||||
size_t batchRank = activationRank - channelRank;
|
||||
size_t startIdx = out_rank - batchRank - 1;
|
||||
for (size_t i = 0; i < batchRank; i++) {
|
||||
outputShape[i + startIdx] = activationShape[i];
|
||||
}
|
||||
|
||||
return {outputShape};
|
||||
}
|
||||
|
||||
port_mask_t get_port_mask() const override {
|
||||
return EMPTY_PORT_MASK;
|
||||
}
|
||||
|
||||
private:
|
||||
size_t out_rank = 0;
|
||||
};
|
||||
|
||||
class FCShapeInferFactory : public ShapeInferFactory {
|
||||
public:
|
||||
FCShapeInferFactory(std::shared_ptr<ov::Node> op) : m_op(op) {}
|
||||
ShapeInferPtr makeShapeInfer() const override {
|
||||
return std::make_shared<FCShapeInfer>(m_op->get_output_partial_shape(0).rank().get_length());
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<const ngraph::Node> m_op;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
bool FullyConnected::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept {
|
||||
@ -114,7 +161,7 @@ bool FullyConnected::isSupportedOperation(const std::shared_ptr<const ngraph::No
|
||||
}
|
||||
|
||||
FullyConnected::FullyConnected(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context)
|
||||
: Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)), withBiases(false) {
|
||||
: Node(op, context, FCShapeInferFactory(op)), withBiases(false) {
|
||||
std::string errorMessage;
|
||||
if (isSupportedOperation(op, errorMessage)) {
|
||||
errorPrefix = "FullyConnected node with name '" + getName() + "'";
|
||||
|
Loading…
Reference in New Issue
Block a user