[CPU] Adopt the static shape inference interface developed by the ngraph (#17719)
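Summary, as reflected in the hunks below: a new MemoryAccessor class implementing ov::ITensorAccessor wraps the CPU plugin's memory pointers directly, so NgraphShapeInfer::infer() no longer builds a map of ngraph::runtime::HostTensor objects for its data dependencies before calling shape inference. The central call-site change is a one-liner (both lines appear verbatim in the second file's hunk):

    // before (removed in this commit):
    auto shape_infer_result = m_shape_infer->infer(input_static_shapes, ov::make_tensor_accessor(input_values));
    // after:
    auto shape_infer_result = m_shape_infer->infer(input_static_shapes, MemoryAccessor(data_dependency, iranks));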
memory_accessor.hpp
@@ -0,0 +1,46 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once
#include <vector>
#include "cpu_memory.h"
#include "openvino/core/shape.hpp"
#include "tensor_data_accessor.hpp"

namespace ov {
namespace intel_cpu {
/**
 * @brief cpu memory accessor implementing ov::ITensorAccessor to get data as tensor from cpu container.
 */
class MemoryAccessor : public ov::ITensorAccessor {
    using container_type = std::unordered_map<size_t, MemoryPtr>;

public:
    MemoryAccessor(const container_type& ptrs, const std::vector<int64_t>& ranks)
        : m_ptrs{ptrs}, m_ranks(ranks) {}

    ~MemoryAccessor() = default;

    ov::Tensor operator()(size_t port) const override {
        const auto t_iter = m_ptrs.find(port);
        if (t_iter != m_ptrs.cend()) {
            auto memPtr = t_iter->second;
            // use scalar shape {} instead of {1} if required by shapeInference
            const auto shape = (m_ranks[port] != 0) ? ov::Shape(memPtr->getStaticDims()) : ov::Shape();
            return {InferenceEngine::details::convertPrecision(memPtr->getDesc().getPrecision()),
                    shape,
                    memPtr->getData()
            };
        } else {
            return {};
        }
    }

private:
    const container_type& m_ptrs;  //!< Pointer to cpu memory pointers with op data.
    const std::vector<int64_t>& m_ranks;
};
}  // namespace intel_cpu
}  // namespace ov
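For orientation, a minimal, hypothetical usage sketch of the new accessor (not part of the commit): the data_dependency map and iranks vector are assumed to be prepared by the caller, exactly as they are in NgraphShapeInfer::infer in the next hunk; only MemoryAccessor, MemoryPtr, and get_input_ranks() come from the code in this commit.

    // Hypothetical sketch: query the data dependency of input port 1 through the accessor.
    // data_dependency : std::unordered_map<size_t, MemoryPtr>, filled by the node (port -> memory)
    // iranks          : std::vector<int64_t>, per-port input ranks (drives the scalar {} vs {1} choice)
    ov::intel_cpu::MemoryAccessor accessor(data_dependency, iranks);

    ov::Tensor t = accessor(1);   // tensor view over the memory bound to input port 1
    if (!t) {
        // no entry for port 1 in data_dependency: the accessor returned an empty ov::Tensor
    }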
shape_inference_ngraph.cpp
@@ -3,6 +3,8 @@
 //

 #include "shape_inference_ngraph.hpp"
 #include <memory>
+#include "memory_accessor.hpp"

 using namespace ov::intel_cpu;

@@ -13,34 +15,19 @@ NgraphShapeInfer::infer(
     const auto& iranks = m_shape_infer->get_input_ranks();
     IE_ASSERT(iranks.size() <= input_shapes.size()) << "Too few input shapes passed to Shape infer.";
     std::vector<StaticShapeRef> input_static_shapes;
-    std::map<size_t, ov::HostTensorPtr> input_values;

     input_static_shapes.reserve(input_shapes.size());

     for (size_t port = 0; port < iranks.size(); port++) {
         if (iranks[port] == 0) {
             input_static_shapes.emplace_back();
         } else {
             input_static_shapes.emplace_back(input_shapes[port].get());
         }
-        auto itr = data_dependency.find(port);
-        if (itr != data_dependency.end()) {
-            const auto& memPtr = itr->second;
-
-            ov::Shape shape;
-
-            // use scalar shape {} instead of {1} if required by shapeInference
-            if (iranks[port] != 0) {
-                shape = ov::Shape(memPtr->getStaticDims());
-            }
-
-            input_values[port] = std::make_shared<ngraph::runtime::HostTensor>(
-                InferenceEngine::details::convertPrecision(memPtr->getDesc().getPrecision()),
-                shape,
-                memPtr->getData());
-        }
     }

     // call shape inference API
-    auto shape_infer_result = m_shape_infer->infer(input_static_shapes, ov::make_tensor_accessor(input_values));
+    auto shape_infer_result = m_shape_infer->infer(input_static_shapes, MemoryAccessor(data_dependency, iranks));

     Result result{{}, shape_infer_result ? ShapeInferStatus::success : ShapeInferStatus::skip};
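Net effect, as the hunks above suggest: instead of eagerly wrapping every data dependency in an ngraph::runtime::HostTensor and collecting them in a std::map before each shape-inference call, the CPU plugin now hands the shape-inference API a lightweight MemoryAccessor, and tensor views over the node's memory are created only for the ports the shape-inference implementation actually asks for.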