[CPU] Adopt the static shape inference interface developed by ngraph (#17719)
commit 915de21626 (parent d32b6904bd)
memory_accessor.hpp (new file)
@@ -0,0 +1,46 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <unordered_map>
#include <vector>

#include "cpu_memory.h"
#include "openvino/core/shape.hpp"
#include "tensor_data_accessor.hpp"

namespace ov {
namespace intel_cpu {

/**
 * @brief CPU memory accessor implementing ov::ITensorAccessor to get data as a tensor from a CPU memory container.
 */
class MemoryAccessor : public ov::ITensorAccessor {
    using container_type = std::unordered_map<size_t, MemoryPtr>;

public:
    MemoryAccessor(const container_type& ptrs, const std::vector<int64_t>& ranks)
        : m_ptrs{ptrs}, m_ranks(ranks) {}

    ~MemoryAccessor() = default;

    ov::Tensor operator()(size_t port) const override {
        const auto t_iter = m_ptrs.find(port);
        if (t_iter != m_ptrs.cend()) {
            auto memPtr = t_iter->second;
            // Use the scalar shape {} instead of {1} when the shape inference expects rank 0.
            const auto shape = (m_ranks[port] != 0) ? ov::Shape(memPtr->getStaticDims()) : ov::Shape();
            return {InferenceEngine::details::convertPrecision(memPtr->getDesc().getPrecision()),
                    shape,
                    memPtr->getData()};
        } else {
            return {};
        }
    }

private:
    const container_type& m_ptrs;         //!< Reference to CPU memory pointers holding the operands' data.
    const std::vector<int64_t>& m_ranks;  //!< Expected input ranks; 0 marks a scalar input.
};

}  // namespace intel_cpu
}  // namespace ov
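Taken together, the new header gives the shape inference call a lazy, non-owning view of the plugin's data dependencies: nothing is converted until an operator's shape inference asks for a particular port. A minimal standalone sketch of the same pattern follows; FakeTensor, ITensorAccessorLike and MapAccessor are hypothetical STL-only stand-ins, not OpenVINO types.

#include <cstddef>
#include <iostream>
#include <unordered_map>

// Hypothetical stand-in for ov::Tensor: empty unless it wraps data.
struct FakeTensor {
    const void* data = nullptr;
    explicit operator bool() const { return data != nullptr; }
};

// Same contract as ov::ITensorAccessor: one virtual call per input port,
// returning an empty tensor when that port has no data dependency.
struct ITensorAccessorLike {
    virtual FakeTensor operator()(size_t port) const = 0;
    virtual ~ITensorAccessorLike() = default;
};

// Analogue of MemoryAccessor: a non-owning view over a port->memory map;
// a tensor is produced only when the caller asks for a specific port.
class MapAccessor : public ITensorAccessorLike {
public:
    explicit MapAccessor(const std::unordered_map<size_t, const void*>& ptrs) : m_ptrs(ptrs) {}

    FakeTensor operator()(size_t port) const override {
        const auto it = m_ptrs.find(port);
        return it != m_ptrs.cend() ? FakeTensor{it->second} : FakeTensor{};
    }

private:
    const std::unordered_map<size_t, const void*>& m_ptrs;  // non-owning, like MemoryAccessor::m_ptrs
};

int main() {
    int axis = 1;
    const std::unordered_map<size_t, const void*> deps{{1, &axis}};  // only port 1 carries constant data
    const MapAccessor accessor(deps);

    if (auto t = accessor(1)) {  // materialized on demand
        std::cout << "port 1 value: " << *static_cast<const int*>(t.data) << '\n';
    }
    if (!accessor(0)) {
        std::cout << "port 0 has no data dependency\n";
    }
    return 0;
}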
shape_inference_ngraph.cpp
@@ -3,6 +3,8 @@
 //

 #include "shape_inference_ngraph.hpp"
+#include <memory>
+#include "memory_accessor.hpp"

 using namespace ov::intel_cpu;

@@ -13,34 +15,19 @@ NgraphShapeInfer::infer(
     const auto& iranks = m_shape_infer->get_input_ranks();
     IE_ASSERT(iranks.size() <= input_shapes.size()) << "Too few input shapes passed to Shape infer.";
     std::vector<StaticShapeRef> input_static_shapes;
-    std::map<size_t, ov::HostTensorPtr> input_values;

     input_static_shapes.reserve(input_shapes.size());

     for (size_t port = 0; port < iranks.size(); port++) {
         if (iranks[port] == 0) {
             input_static_shapes.emplace_back();
         } else {
             input_static_shapes.emplace_back(input_shapes[port].get());
         }
-        auto itr = data_dependency.find(port);
-        if (itr != data_dependency.end()) {
-            const auto& memPtr = itr->second;
-
-            ov::Shape shape;
-
-            // use scalar shape {} instead of {1} if required by shapeInference
-            if (iranks[port] != 0) {
-                shape = ov::Shape(memPtr->getStaticDims());
-            }
-
-            input_values[port] = std::make_shared<ngraph::runtime::HostTensor>(
-                InferenceEngine::details::convertPrecision(memPtr->getDesc().getPrecision()),
-                shape,
-                memPtr->getData());
-        }
     }

     // call shape inference API
-    auto shape_infer_result = m_shape_infer->infer(input_static_shapes, ov::make_tensor_accessor(input_values));
+    auto shape_infer_result = m_shape_infer->infer(input_static_shapes, MemoryAccessor(data_dependency, iranks));

     Result result{{}, shape_infer_result ? ShapeInferStatus::success : ShapeInferStatus::skip};
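The net effect of this hunk: infer() no longer eagerly wraps every data_dependency entry in a ngraph::runtime::HostTensor up front; the MemoryAccessor produces an ov::Tensor only for the ports the shape inference actually queries. A sketch of what a consumer of the accessor might look like, assuming the ov::ITensorAccessor interface from tensor_data_accessor.hpp; infer_gather_axis, the port number, and the fallback value are illustrative, not code from this commit.

#include <cstdint>

#include "openvino/runtime/tensor.hpp"
#include "tensor_data_accessor.hpp"

// Illustrative only: read an axis value from port 2 through the accessor.
// The accessor call is the point where MemoryAccessor builds the tensor.
int64_t infer_gather_axis(const ov::ITensorAccessor& accessor) {
    if (auto t = accessor(2)) {     // tensor materialized on demand
        return *t.data<int64_t>();  // assumes the axis is stored as i64
    }
    return 0;  // assumed default when no constant is available
}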