[CPU] Allow external blob to be reallocated (#12029)

This commit is contained in:
Yuan Hu 2023-01-24 13:30:07 +08:00 committed by GitHub
parent ea519f85db
commit ea776672ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 119 additions and 9 deletions

View File

@ -373,6 +373,27 @@ void LegacyInferRequest::SetBatch(int new_batch) {
}
}
void LegacyInferRequest::changeDefaultPtr() {
    // Blobs handed out via GetBlob() may have been reallocated by the user
    // since the previous inference, so the cached external pointers must be
    // re-synchronized with the current blob buffers before running.
    for (const auto& inputEntry : graph->inputNodesMap) {
        const auto& nodeName = inputEntry.first;
        auto extIt = externalPtr.find(nodeName);
        // Only refresh entries we already track, and only when the buffer moved.
        if (extIt != externalPtr.end() && extIt->second != _inputs[nodeName]->buffer()) {
            extIt->second = _inputs[nodeName]->buffer();
        }
    }
    for (const auto& outputEntry : graph->outputNodesMap) {
        const auto& nodeName = outputEntry.first;
        auto extIt = externalPtr.find(nodeName);
        if (extIt != externalPtr.end() && extIt->second != _outputs[nodeName]->buffer()) {
            extIt->second = _outputs[nodeName]->buffer();
        }
    }
    // Let the base class apply the (now up-to-date) external pointers.
    InferRequestBase::changeDefaultPtr();
}
void LegacyInferRequest::SetBlob(const std::string& name, const InferenceEngine::Blob::Ptr &data) {
OV_ITT_SCOPED_TASK(itt::domains::intel_cpu, "SetBlobLegacy");
if (name.empty()) {
@ -495,20 +516,23 @@ InferenceEngine::Blob::Ptr LegacyInferRequest::GetBlob(const std::string& name)
}
if (_inputs.find(name) == _inputs.end()) {
auto pBlobDesc = MemoryDescUtils::interpretAsBlobDesc(graph->getInputNodeByName(name)->getChildEdgesAtPort(0)[0]->getMemory());
InferenceEngine::TensorDesc desc = pBlobDesc;
if (_networkInputs.find(name) != _networkInputs.end()) {
InferenceEngine::Layout l = _networkInputs[name]->getLayout();
InferenceEngine::Precision p = _networkInputs[name]->getPrecision();
InferenceEngine::SizeVector dims = _networkInputs[name]->getTensorDesc().getDims();
auto pBlob = MemoryDescUtils::interpretAsBlob(graph->getInputNodeByName(name)->getChildEdgesAtPort(0)[0]->getMemory());
if (!pBlob) {
IE_THROW() << "Can not interpret cpu plugin memory object as InferenceEngine::Blob. Input node name: " << name;
}
InferenceEngine::TensorDesc desc = pBlob->getTensorDesc();
auto itr = _networkInputs.find(name);
if (itr != _networkInputs.end()) {
const InferenceEngine::Layout &l = itr->second->getLayout();
const InferenceEngine::Precision &p = itr->second->getPrecision();
const InferenceEngine::SizeVector &dims = itr->second->getTensorDesc().getDims();
desc = InferenceEngine::TensorDesc(p, dims, l);
}
_inputs[name] = make_blob_with_precision(desc);
_inputs[name]->allocate();
if (pBlobDesc == desc &&
if (pBlob->getTensorDesc() == desc &&
graph->_normalizePreprocMap.find(name) == graph->_normalizePreprocMap.end() && !graph->getConfig().batchLimit) {
externalPtr[name] = _inputs[name]->buffer();
}

View File

@ -63,11 +63,13 @@ private:
void PullStates();
void redefineMemoryForInputNodes();
void changeDefaultPtr();
std::shared_ptr<ExecNetwork> execNetwork;
openvino::itt::handle_t profilingTask;
std::vector<std::shared_ptr<InferenceEngine::IVariableStateInternal>> memoryStates;
AsyncInferRequest* _asyncRequest = nullptr;
protected:
virtual void changeDefaultPtr();
};
class LegacyInferRequest : public InferRequestBase {
@ -83,6 +85,7 @@ private:
void PushInputData() override;
void initBlobs() override;
void SetBatch(int batch = -1) override;
void changeDefaultPtr() override;
};
class InferRequest : public InferRequestBase {

View File

@ -55,6 +55,10 @@ std::vector<std::string> disabledTestPatterns() {
R"(.*Behavior.*InferRequestIOBBlobSetLayoutTest.*CanSetOutBlobWithDifferentLayouts.*layout=(CN|HW).*)",
R"(.*Behavior.*(Multi|Auto).*InferRequestSetBlobByType.*Batched.*)",
R"(.*(Multi|Auto).*Behavior.*InferRequestIOBBlobTest.*canProcessDeallocatedOutputBlobAfterGetAndSetBlob.*)",
// TODO Issue 100145
R"(.*Behavior.*InferRequestIOBBlobTest.*canReallocateExternalBlobViaGet.*)",
R"(.*Behavior.*OVInferRequestIOTensorTest.*canInferAfterIOBlobReallocation.*)",
R"(.*Behavior.*OVInferRequestDynamicTests.*InferUpperBoundNetworkAfterIOTensorsReshaping.*)",
R"(.*(Auto|Multi).*Behavior.*IncorrectConfigTests.*CanNotLoadNetworkWithIncorrectConfig.*)",
// TODO: until issue xxx-59670 is resolved
R"(.*Gather8LayerTest.*)",

View File

@ -331,6 +331,39 @@ TEST_P(InferRequestIOBBlobTest, canInferWithGetOut) {
ASSERT_NO_THROW(InferenceEngine::Blob::Ptr outputBlob = req.GetBlob(cnnNet.getOutputsInfo().begin()->first));
}
TEST_P(InferRequestIOBBlobTest, canReallocateExternalBlobViaGet) {
    SKIP_IF_CURRENT_TEST_IS_DISABLED()
    // Build a minimal param -> relu -> result function
    // (avoid naming the local after the ngraph namespace).
    std::shared_ptr<ngraph::Function> function;
    {
        ngraph::element::Type precision(ngraph::element::Type_t::f32);
        ngraph::PartialShape inputShape({1, 3, 10, 10});
        auto param = std::make_shared<ngraph::op::Parameter>(precision, inputShape);
        param->set_friendly_name("param");
        auto relu = std::make_shared<ngraph::op::Relu>(param);
        relu->set_friendly_name("relu");
        auto result = std::make_shared<ngraph::op::Result>(relu);
        result->set_friendly_name("result");
        ngraph::ResultVector results = {result};
        ngraph::ParameterVector params = {param};
        function = std::make_shared<ngraph::Function>(results, params);
    }
    // Create CNNNetwork from ngraph::Function and load it on the target plugin
    InferenceEngine::CNNNetwork cnnNet(function);
    auto execNet = ie->LoadNetwork(cnnNet, target_device, configuration);
    auto req = execNet.CreateInferRequest();
    // Fetch the I/O blobs, then reallocate them externally; inference must
    // still succeed with the new buffers.
    auto inBlob = req.GetBlob("param");
    auto outBlob = req.GetBlob("relu");
    inBlob->allocate();
    outBlob->allocate();
    ASSERT_NO_THROW(req.Infer());
}
class InferRequestIOBBlobSetPrecisionTest : public BehaviorTestsUtils::BehaviorTestsBasicBase,
public BehaviorTestsUtils::IEInferRequestTestBase {
protected:

View File

@ -275,6 +275,32 @@ TEST_P(OVInferRequestDynamicTests, InferUpperBoundNetworkWithGetTensor) {
ASSERT_TRUE(checkOutput(req.get_tensor("input_tensor"), req.get_tensor(outputname)));
}
TEST_P(OVInferRequestDynamicTests, InferUpperBoundNetworkAfterIOTensorsReshaping) {
    const std::string tensor_name = "input_tensor";
    // Constrain the first dimension to an upper-bounded dynamic range.
    std::map<std::string, ov::PartialShape> shapes;
    shapes[tensor_name] = {ov::Dimension(0, 19), 4, 20, 20};
    OV_ASSERT_NO_THROW(function->reshape(shapes));
    // Compile the ov::Model on the target plugin and create a request.
    auto execNet = ie->compile_model(function, target_device, configuration);
    ov::InferRequest req;
    ov::Tensor inTensor, outTensor;
    const std::string outputname = function->outputs().back().get_any_name();
    OV_ASSERT_NO_THROW(req = execNet.create_infer_request());
    OV_ASSERT_NO_THROW(outTensor = req.get_tensor(outputname));
    ASSERT_EQ(0, outTensor.get_size()); // output tensor is not allocated
    // Reshape the output tensor back and forth to force reallocations.
    OV_ASSERT_NO_THROW(outTensor.set_shape({1, 4, 20, 20}));
    OV_ASSERT_NO_THROW(outTensor.set_shape({4, 4, 20, 20}));
    OV_ASSERT_NO_THROW(outTensor.set_shape({1, 4, 20, 20}));
    // Same reshaping dance on the input tensor.
    OV_ASSERT_NO_THROW(inTensor = req.get_tensor(function->inputs().back().get_any_name()));
    OV_ASSERT_NO_THROW(inTensor.set_shape({1, 4, 20, 20}));
    OV_ASSERT_NO_THROW(inTensor.set_shape({4, 4, 20, 20}));
    OV_ASSERT_NO_THROW(inTensor.set_shape({1, 4, 20, 20}));
    // Both synchronous and asynchronous inference must still work.
    OV_ASSERT_NO_THROW(req.infer());
    OV_ASSERT_NO_THROW(req.start_async());
    OV_ASSERT_NO_THROW(req.wait());
}
TEST_P(OVInferRequestDynamicTests, InferFullyDynamicNetworkWithGetTensor) {
const std::string tensor_name = "input_tensor";
const ov::Shape refShape = inOutShapes[0].first;

View File

@ -166,6 +166,26 @@ TEST_P(OVInferRequestIOTensorTest, canInferWithGetIn) {
OV_ASSERT_NO_THROW(req.get_tensor(output));
}
TEST_P(OVInferRequestIOTensorTest, canInferAfterIOBlobReallocation) {
    ov::Tensor inTensor, outTensor;
    const auto originalInShape = input.get_shape();
    const auto originalOutShape = output.get_shape();
    // Imitate external blob reallocation: resize each I/O tensor to a foreign
    // shape and then back to its original one.
    OV_ASSERT_NO_THROW(inTensor = req.get_tensor(input));
    OV_ASSERT_NO_THROW(inTensor.set_shape({5, 5, 5, 5}));
    OV_ASSERT_NO_THROW(inTensor.set_shape(originalInShape));
    OV_ASSERT_NO_THROW(outTensor = req.get_tensor(output));
    OV_ASSERT_NO_THROW(outTensor.set_shape({20, 20}));
    OV_ASSERT_NO_THROW(outTensor.set_shape(originalOutShape));
    // Inference (sync and async) must survive the reallocations.
    OV_ASSERT_NO_THROW(req.infer());
    OV_ASSERT_NO_THROW(req.start_async());
    OV_ASSERT_NO_THROW(req.wait());
    OV_ASSERT_NO_THROW(req.get_tensor(output));
}
TEST_P(OVInferRequestIOTensorTest, canInferWithGetOut) {
ov::Tensor output_tensor;
OV_ASSERT_NO_THROW(output_tensor = req.get_tensor(output));