Fixed issue with running stateful networks with several infer requests in MKLDNNPlugin (#3711)

Svetlana Dolinina 2021-01-21 15:01:03 +03:00 committed by GitHub
parent 88b200ea5b
commit 05d97fa24a
6 changed files with 232 additions and 43 deletions

File: mkldnn_infer_request.cpp

@@ -1,4 +1,4 @@
-// Copyright (C) 2018-2020 Intel Corporation
+// Copyright (C) 2018-2021 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 //
@@ -16,6 +16,7 @@
 #include "nodes/common/cpu_convert.h"
 #include "mkldnn_memory_state.h"
 #include "nodes/mkldnn_memory_node.hpp"
+#include "nodes/common/cpu_memcpy.h"
 
 MKLDNNPlugin::MKLDNNInferRequest::MKLDNNInferRequest(InferenceEngine::InputsDataMap networkInputs,
                                                      InferenceEngine::OutputsDataMap networkOutputs,
@@ -42,7 +43,7 @@ MKLDNNPlugin::MKLDNNInferRequest::MKLDNNInferRequest(InferenceEngine::InputsData
     // of MemoryLayer implementation. It uses output edge of MemoryLayer
     // producer as storage for tensor to keep it between infer calls.
     IE_SUPPRESS_DEPRECATED_START
-    if (execNetwork->QueryState().size() == 0) {
+    if (execNetwork->_numRequests > 1 || execNetwork->QueryState().size() == 0) {
         for (auto &node : graph->GetNodes()) {
             if (node->getType() == MemoryInput) {
                 auto memoryNode = dynamic_cast<MKLDNNMemoryInputNode*>(node.get());
@@ -132,6 +133,49 @@ void MKLDNNPlugin::MKLDNNInferRequest::PushInputData() {
     }
 }
 
+void MKLDNNPlugin::MKLDNNInferRequest::PushStates() {
+    for (auto &node : graph->GetNodes()) {
+        if (node->getType() == MemoryInput) {
+            auto cur_node = dynamic_cast<MKLDNNMemoryInputNode*>(node.get());
+            auto cur_id = cur_node->getId();
+            for (const auto& state : memoryStates) {
+                if (state->GetName() == cur_id) {
+                    auto cur_state_mem = cur_node->getStore();
+                    auto data_ptr = state->GetState()->cbuffer().as<void*>();
+                    auto data_size = state->GetState()->byteSize();
+                    auto elemSize = MKLDNNExtensionUtils::sizeOfDataType(cur_state_mem->GetDataType());
+                    auto padSize = cur_state_mem->GetDescriptor().data.layout_desc.blocking.offset_padding;
+                    auto cur_state_mem_buf = static_cast<uint8_t*>(cur_state_mem->GetData()) + padSize * elemSize;
+
+                    cpu_memcpy(cur_state_mem_buf, data_ptr, data_size);
+                }
+            }
+        }
+    }
+}
+
+void MKLDNNPlugin::MKLDNNInferRequest::PullStates() {
+    for (auto &node : graph->GetNodes()) {
+        if (node->getType() == MemoryInput) {
+            auto cur_node = dynamic_cast<MKLDNNMemoryInputNode*>(node.get());
+            auto cur_id = cur_node->getId();
+            for (const auto& state : memoryStates) {
+                if (state->GetName() == cur_id) {
+                    auto cur_state_mem = cur_node->getStore();
+                    auto data_ptr = state->GetState()->cbuffer().as<void*>();
+                    auto data_size = state->GetState()->byteSize();
+                    auto elemSize = MKLDNNExtensionUtils::sizeOfDataType(cur_state_mem->GetDataType());
+                    auto padSize = cur_state_mem->GetDescriptor().data.layout_desc.blocking.offset_padding;
+                    auto cur_state_mem_buf = static_cast<uint8_t*>(cur_state_mem->GetData()) + padSize * elemSize;
+
+                    cpu_memcpy(data_ptr, cur_state_mem_buf, data_size);
+                }
+            }
+        }
+    }
+}
+
 void MKLDNNPlugin::MKLDNNInferRequest::InferImpl() {
     using namespace openvino::itt;
     OV_ITT_SCOPED_TASK(itt::domains::MKLDNNPlugin, profilingTask);
@@ -144,8 +188,16 @@ void MKLDNNPlugin::MKLDNNInferRequest::InferImpl() {
     PushInputData();
 
+    if (memoryStates.size() != 0) {
+        PushStates();
+    }
+
     graph->Infer(m_curBatch);
 
+    if (memoryStates.size() != 0) {
+        PullStates();
+    }
+
     graph->PullOutputData(_outputs);
 }
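The new PushStates/PullStates pair implements a push/infer/pull pattern: each infer request now owns its state blobs (the memoryStates vector) and synchronizes them with the MemoryInput storage inside the shared graph around every inference, so two requests can no longer clobber each other's state between calls. A minimal sketch of the pattern, using illustrative stand-in types (Graph, InferRequest with raw byte buffers) rather than the plugin's real classes, and ignoring the padding-offset arithmetic the real code performs:

```cpp
#include <cstdint>
#include <cstring>
#include <string>
#include <unordered_map>
#include <vector>

// Illustrative stand-ins; the plugin's real types are MKLDNNGraph,
// MKLDNNMemoryInputNode and MKLDNNVariableState.
struct Graph {
    // Shared storage behind each MemoryInput node, keyed by state name;
    // assumed to be allocated to its final size when the network loads.
    std::unordered_map<std::string, std::vector<uint8_t>> memory_inputs;
    void Infer() { /* reads and updates memory_inputs */ }
};

class InferRequest {
public:
    explicit InferRequest(const Graph& g) {
        // Each request starts with its own private copy of every state.
        for (const auto& kv : g.memory_inputs)
            states_[kv.first] = kv.second;
    }

    void Infer(Graph& graph) {
        // Push: copy this request's states into the shared graph storage.
        for (auto& kv : states_)
            std::memcpy(graph.memory_inputs.at(kv.first).data(),
                        kv.second.data(), kv.second.size());

        graph.Infer();

        // Pull: copy the updated states back out, so a later Infer() on a
        // different request cannot overwrite them.
        for (auto& kv : states_)
            std::memcpy(kv.second.data(),
                        graph.memory_inputs.at(kv.first).data(),
                        kv.second.size());
    }

private:
    std::unordered_map<std::string, std::vector<uint8_t>> states_;
};
```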

File: mkldnn_infer_request.h

@@ -1,4 +1,4 @@
-// Copyright (C) 2018-2020 Intel Corporation
+// Copyright (C) 2018-2021 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 //
@@ -49,6 +49,8 @@ public:
 private:
     void PushInputData();
+    void PushStates();
+    void PullStates();
 
     void pushInput(const std::string& inputName, InferenceEngine::Blob::Ptr& inputBlob, InferenceEngine::Precision dataType);

File: mkldnn_memory_state.cpp

@@ -1,4 +1,4 @@
-// Copyright (C) 2018-2020 Intel Corporation
+// Copyright (C) 2018-2021 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 //
@@ -15,24 +15,15 @@ std::string MKLDNNVariableState::GetName() const {
 }
 
 void MKLDNNVariableState::Reset() {
-    storage->FillZero();
+    std::memset(this->storage->buffer(), 0, storage->byteSize());
 }
 
 void MKLDNNVariableState::SetState(Blob::Ptr newState) {
-    auto prec = newState->getTensorDesc().getPrecision();
-    auto data_type = MKLDNNExtensionUtils::IEPrecisionToDataType(prec);
-    auto data_layout = MKLDNNMemory::Convert(newState->getTensorDesc().getLayout());
-    auto data_ptr = newState->cbuffer().as<void*>();
-    auto data_size = newState->byteSize();
-
-    storage->SetData(data_type, data_layout, data_ptr, data_size);
+    storage = newState;
 }
 
 InferenceEngine::Blob::CPtr MKLDNNVariableState::GetState() const {
-    auto result_blob = make_blob_with_precision(MKLDNNMemoryDesc(storage->GetDescriptor()));
-    result_blob->allocate();
-    std::memcpy(result_blob->buffer(), storage->GetData(), storage->GetSize());
-    return result_blob;
+    return storage;
 }
 
 } // namespace MKLDNNPlugin

File: mkldnn_memory_state.h

@@ -1,11 +1,13 @@
-// Copyright (C) 2018-2020 Intel Corporation
+// Copyright (C) 2018-2021 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 //
 
 #pragma once
 
 #include "cpp_interfaces/impl/ie_variable_state_internal.hpp"
+#include "blob_factory.hpp"
 #include "mkldnn_memory.h"
+#include "nodes/common/cpu_memcpy.h"
 #include <string>
@@ -14,7 +16,11 @@ namespace MKLDNNPlugin {
 class MKLDNNVariableState : public InferenceEngine::IVariableStateInternal {
 public:
     MKLDNNVariableState(std::string name, MKLDNNMemoryPtr storage) :
-        name(name), storage(storage) {}
+        name(name) {
+        this->storage = make_blob_with_precision(MKLDNNMemoryDesc(storage->GetDescriptor()));
+        this->storage->allocate();
+        cpu_memcpy(this->storage->buffer(), storage->GetData(), storage->GetSize());
+    }
 
     std::string GetName() const override;
     void Reset() override;
@@ -23,7 +29,7 @@ public:
 private:
     std::string name;
-    MKLDNNMemoryPtr storage;
+    InferenceEngine::Blob::Ptr storage;
 };
 
 } // namespace MKLDNNPlugin
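The net effect of the two state-file changes: MKLDNNVariableState no longer wraps the graph's MKLDNNMemory directly. It deep-copies that memory into a plain Blob at construction, SetState simply retains the caller's blob, and GetState returns the held blob without copying. A hedged sketch of these copy semantics, with a simplified Blob stand-in instead of the real InferenceEngine::Blob and IVariableStateInternal types:

```cpp
#include <algorithm>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Simplified stand-in for InferenceEngine::Blob.
struct Blob {
    std::vector<uint8_t> bytes;
};
using BlobPtr = std::shared_ptr<Blob>;

class VariableState {
public:
    // Deep-copy the backing graph memory once, at construction, so the
    // state becomes independent of the graph (and of other requests).
    VariableState(std::string name, const uint8_t* graph_mem, size_t size)
        : name_(std::move(name)), storage_(std::make_shared<Blob>()) {
        storage_->bytes.assign(graph_mem, graph_mem + size);
    }

    // Reset zeroes the private copy, not the graph memory.
    void Reset() {
        std::fill(storage_->bytes.begin(), storage_->bytes.end(), 0);
    }

    // SetState retains the caller's blob as-is (no copy into graph memory).
    void SetState(BlobPtr new_state) { storage_ = std::move(new_state); }

    // GetState returns the held blob directly (previously it allocated and
    // copied on every call).
    BlobPtr GetState() const { return storage_; }

private:
    std::string name_;
    BlobPtr storage_;
};
```

Because SetState keeps the blob rather than copying out of it, a caller that mutates a blob after passing it in would alias the state; consistent with that, the updated tests below allocate a fresh blob for every SetState call.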

File: skip_tests_config.cpp

@@ -1,4 +1,4 @@
-// Copyright (C) 2020 Intel Corporation
+// Copyright (C) 2020-2021 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 //
@@ -51,6 +51,8 @@ std::vector<std::string> disabledTestPatterns() {
         R"(.*(ConstantResultSubgraphTest).*)",
         // TODO: Issue: 29577
         R"(.*CoreThreadingTests.smoke_QueryNetwork.*)",
+        // TODO: Issue: 46416
+        R"(.*VariableStateTest.inferreq_smoke_VariableState_2infers*.*)",
         // TODO: Issue 24839
         R"(.*ConvolutionLayerTest.CompareWithRefs.*D=\(1.3\).*)",
         R"(.*ConvolutionLayerTest.CompareWithRefs.*D=\(3.1\).*)"

File: memory_states.cpp

@@ -1,11 +1,13 @@
-// Copyright (C) 2018-2020 Intel Corporation
+// Copyright (C) 2018-2021 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 //
 
 #include <base/behavior_test_utils.hpp>
+#include <common_test_utils/common_utils.hpp>
 #include "behavior/memory_states.hpp"
 #include "functional_test_utils/plugin_cache.hpp"
+#include "blob_factory.hpp"
 
 std::string VariableStateTest::getTestCaseName(const testing::TestParamInfo<memoryStateParams> &obj) {
     std::ostringstream result;
@@ -29,6 +31,8 @@ InferenceEngine::ExecutableNetwork VariableStateTest::PrepareNetwork() {
 }
 
 TEST_P(VariableStateTest, smoke_VariableState_QueryState) {
+    // Skip test according to plugin specific disabledTestPatterns() (if any)
+    SKIP_IF_CURRENT_TEST_IS_DISABLED()
     IE_SUPPRESS_DEPRECATED_START
     auto executableNet = PrepareNetwork();
@@ -44,6 +48,8 @@ TEST_P(VariableStateTest, smoke_VariableState_QueryState) {
 }
 
 TEST_P(VariableStateTest, smoke_VariableState_SetState) {
+    // Skip test according to plugin specific disabledTestPatterns() (if any)
+    SKIP_IF_CURRENT_TEST_IS_DISABLED()
     IE_SUPPRESS_DEPRECATED_START
     auto executableNet = PrepareNetwork();
     const float new_state_val = 13.0f;
@@ -52,11 +58,14 @@ TEST_P(VariableStateTest, smoke_VariableState_SetState) {
         auto state_val = state.GetState();
         auto element_count = state_val->size();
 
-        std::vector<float> new_state_data(element_count, new_state_val);
-        auto stateBlob = InferenceEngine::make_shared_blob<float>(
-            { state_val->getTensorDesc().getPrecision(), {1, element_count}, state_val->getTensorDesc().getLayout() },
-            new_state_data.data(), new_state_data.size());
-
+        float *new_state_data = new float[element_count];
+        for (int i = 0; i < element_count; i++) {
+            new_state_data[i] = new_state_val;
+        }
+        auto stateBlob = make_blob_with_precision(state_val->getTensorDesc());
+        stateBlob->allocate();
+        std::memcpy(stateBlob->buffer(), new_state_data, element_count * sizeof(float));
+        delete []new_state_data;
         state.SetState(stateBlob);
     }
@@ -74,6 +83,8 @@ TEST_P(VariableStateTest, smoke_VariableState_SetState) {
 }
 
 TEST_P(VariableStateTest, smoke_VariableState_Reset) {
+    // Skip test according to plugin specific disabledTestPatterns() (if any)
+    SKIP_IF_CURRENT_TEST_IS_DISABLED()
     IE_SUPPRESS_DEPRECATED_START
     auto executableNet = PrepareNetwork();
     const float new_state_val = 13.0f;
@@ -82,10 +93,14 @@ TEST_P(VariableStateTest, smoke_VariableState_Reset) {
         auto state_val = state.GetState();
         auto element_count = state_val->size();
 
-        std::vector<float> new_state_data(element_count, new_state_val);
-        auto stateBlob = InferenceEngine::make_shared_blob<float>(
-            { state_val->getTensorDesc().getPrecision(), {1, element_count}, state_val->getTensorDesc().getLayout() },
-            new_state_data.data(), new_state_data.size());
+        float *new_state_data = new float[element_count];
+        for (int i = 0; i < element_count; i++) {
+            new_state_data[i] = new_state_val;
+        }
+        auto stateBlob = make_blob_with_precision(state_val->getTensorDesc());
+        stateBlob->allocate();
+        std::memcpy(stateBlob->buffer(), new_state_data, element_count * sizeof(float));
+        delete []new_state_data;
         state.SetState(stateBlob);
     }
@@ -106,7 +121,7 @@ TEST_P(VariableStateTest, smoke_VariableState_Reset) {
             }
         } else {
             for (int j = 0; j < last_state_size; ++j) {
-                EXPECT_NEAR(13.0f, last_state_data[j], 1e-5);
+                EXPECT_NEAR(new_state_val, last_state_data[j], 1e-5);
             }
         }
     }
@@ -114,6 +129,8 @@ TEST_P(VariableStateTest, smoke_VariableState_Reset) {
 }
 
 TEST_P(VariableStateTest, inferreq_smoke_VariableState_QueryState) {
+    // Skip test according to plugin specific disabledTestPatterns() (if any)
+    SKIP_IF_CURRENT_TEST_IS_DISABLED()
     auto executableNet = PrepareNetwork();
     auto inferReq = executableNet.CreateInferRequest();
@@ -128,6 +145,8 @@ TEST_P(VariableStateTest, inferreq_smoke_VariableState_QueryState) {
 }
 
 TEST_P(VariableStateTest, inferreq_smoke_VariableState_SetState) {
+    // Skip test according to plugin specific disabledTestPatterns() (if any)
+    SKIP_IF_CURRENT_TEST_IS_DISABLED()
     auto executableNet = PrepareNetwork();
     auto inferReq = executableNet.CreateInferRequest();
@@ -137,11 +156,14 @@ TEST_P(VariableStateTest, inferreq_smoke_VariableState_SetState) {
         auto state_val = state.GetState();
         auto element_count = state_val->size();
 
-        std::vector<float> new_state_data(element_count, new_state_val);
-        auto stateBlob = InferenceEngine::make_shared_blob<float>(
-            { state_val->getTensorDesc().getPrecision(), {1, element_count}, state_val->getTensorDesc().getLayout() },
-            new_state_data.data(), new_state_data.size());
-
+        float *new_state_data = new float[element_count];
+        for (int i = 0; i < element_count; i++) {
+            new_state_data[i] = new_state_val;
+        }
+        auto stateBlob = make_blob_with_precision(state_val->getTensorDesc());
+        stateBlob->allocate();
+        std::memcpy(stateBlob->buffer(), new_state_data, element_count * sizeof(float));
+        delete []new_state_data;
         state.SetState(stateBlob);
     }
@@ -150,7 +172,6 @@ TEST_P(VariableStateTest, inferreq_smoke_VariableState_SetState) {
     auto last_state_size = lastState->size();
     auto last_state_data = lastState->cbuffer().as<float*>();
     ASSERT_TRUE(last_state_size != 0) << "State size should not be 0";
-
     for (int i = 0; i < last_state_size; i++) {
         EXPECT_NEAR(new_state_val, last_state_data[i], 1e-5);
     }
@@ -158,6 +179,8 @@ TEST_P(VariableStateTest, inferreq_smoke_VariableState_SetState) {
 }
 
 TEST_P(VariableStateTest, inferreq_smoke_VariableState_Reset) {
+    // Skip test according to plugin specific disabledTestPatterns() (if any)
+    SKIP_IF_CURRENT_TEST_IS_DISABLED()
     auto executableNet = PrepareNetwork();
     auto inferReq = executableNet.CreateInferRequest();
@@ -167,10 +190,14 @@ TEST_P(VariableStateTest, inferreq_smoke_VariableState_Reset) {
         auto state_val = state.GetState();
         auto element_count = state_val->size();
 
-        std::vector<float> new_state_data(element_count, new_state_val);
-        auto stateBlob = InferenceEngine::make_shared_blob<float>(
-            { state_val->getTensorDesc().getPrecision(), {1, element_count}, state_val->getTensorDesc().getLayout() },
-            new_state_data.data(), new_state_data.size());
+        float *new_state_data = new float[element_count];
+        for (int i = 0; i < element_count; i++) {
+            new_state_data[i] = new_state_val;
+        }
+        auto stateBlob = make_blob_with_precision(state_val->getTensorDesc());
+        stateBlob->allocate();
+        std::memcpy(stateBlob->buffer(), new_state_data, element_count * sizeof(float));
+        delete []new_state_data;
         state.SetState(stateBlob);
     }
@@ -184,15 +211,124 @@ TEST_P(VariableStateTest, inferreq_smoke_VariableState_Reset) {
         auto last_state_data = lastState->cbuffer().as<float*>();
 
         ASSERT_TRUE(last_state_size != 0) << "State size should not be 0";
         if (i == 0) {
             for (int j = 0; j < last_state_size; ++j) {
                 EXPECT_NEAR(0, last_state_data[j], 1e-5);
             }
         } else {
             for (int j = 0; j < last_state_size; ++j) {
-                EXPECT_NEAR(13.0f, last_state_data[j], 1e-5);
+                EXPECT_NEAR(new_state_val, last_state_data[j], 1e-5);
             }
         }
     }
 }
+
+TEST_P(VariableStateTest, inferreq_smoke_VariableState_2infers_set) {
+    // Skip test according to plugin specific disabledTestPatterns() (if any)
+    SKIP_IF_CURRENT_TEST_IS_DISABLED()
+    auto executableNet = PrepareNetwork();
+    auto inferReq = executableNet.CreateInferRequest();
+    auto inferReq2 = executableNet.CreateInferRequest();
+    const float new_state_val = 13.0f;
+
+    for (auto&& state : inferReq.QueryState()) {
+        state.Reset();
+        auto state_val = state.GetState();
+        auto element_count = state_val->size();
+
+        float *new_state_data = new float[element_count];
+        for (int i = 0; i < element_count; i++) {
+            new_state_data[i] = new_state_val;
+        }
+        auto stateBlob = make_blob_with_precision(state_val->getTensorDesc());
+        stateBlob->allocate();
+        std::memcpy(stateBlob->buffer(), new_state_data, element_count * sizeof(float));
+        delete []new_state_data;
+        state.SetState(stateBlob);
+    }
+    for (auto&& state : inferReq2.QueryState()) {
+        state.Reset();
+    }
+
+    auto states = inferReq.QueryState();
+    auto states2 = inferReq2.QueryState();
+    for (int i = 0; i < states.size(); ++i) {
+        auto lastState = states[i].GetState();
+        auto last_state_size = lastState->size();
+        auto last_state_data = lastState->cbuffer().as<float*>();
+
+        ASSERT_TRUE(last_state_size != 0) << "State size should not be 0";
+
+        for (int j = 0; j < last_state_size; ++j) {
+            EXPECT_NEAR(13.0f, last_state_data[j], 1e-5);
+        }
+    }
+    for (int i = 0; i < states2.size(); ++i) {
+        auto lastState = states2[i].GetState();
+        auto last_state_size = lastState->size();
+        auto last_state_data = lastState->cbuffer().as<float*>();
+
+        ASSERT_TRUE(last_state_size != 0) << "State size should not be 0";
+
+        for (int j = 0; j < last_state_size; ++j) {
+            EXPECT_NEAR(0, last_state_data[j], 1e-5);
+        }
+    }
+}
+
+TEST_P(VariableStateTest, inferreq_smoke_VariableState_2infers) {
+    // Skip test according to plugin specific disabledTestPatterns() (if any)
+    SKIP_IF_CURRENT_TEST_IS_DISABLED()
+    auto executableNet = PrepareNetwork();
+    auto inferReq = executableNet.CreateInferRequest();
+    auto inferReq2 = executableNet.CreateInferRequest();
+
+    for (const auto &input : executableNet.GetInputsInfo()) {
+        const auto &info = input.second;
+        InferenceEngine::Blob::Ptr inBlob;
+        inBlob = make_blob_with_precision(info->getTensorDesc());
+        inBlob->allocate();
+        std::memset(inBlob->buffer(), 0, inBlob->byteSize());
+        inferReq.SetBlob(info->name(), inBlob);
+    }
+    for (auto&& state : inferReq.QueryState()) {
+        state.Reset();
+    }
+    for (auto&& state : inferReq2.QueryState()) {
+        state.Reset();
+    }
+    inferReq.Infer();
+
+    auto states = inferReq.QueryState();
+    auto states2 = inferReq2.QueryState();
+    for (int i = 0; i < states.size(); ++i) {
+        auto lastState = states[i].GetState();
+        auto last_state_size = lastState->size();
+        auto last_state_data = lastState->cbuffer().as<float*>();
+
+        ASSERT_TRUE(last_state_size != 0) << "State size should not be 0";
+        if (i == 0) {
+            for (int j = 0; j < last_state_size; ++j) {
+                EXPECT_NEAR(0.5f, last_state_data[j], 1e-3);
+            }
+        } else {
+            for (int j = 0; j < last_state_size; ++j) {
+                EXPECT_NEAR(0.0f, last_state_data[j], 1e-5);
+            }
+        }
+    }
+    for (int i = 0; i < states2.size(); ++i) {
+        auto lastState = states2[i].GetState();
+        auto last_state_size = lastState->size();
+        auto last_state_data = lastState->cbuffer().as<float*>();
+
+        ASSERT_TRUE(last_state_size != 0) << "State size should not be 0";
+        for (int j = 0; j < last_state_size; ++j) {
+            EXPECT_NEAR(0.0f, last_state_data[j], 1e-5);
+        }
+    }
+}
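Together the two new tests pin down the user-visible contract: states queried from different infer requests of one ExecutableNetwork are independent, whether they are written through SetState or updated by Infer(). A condensed sketch of that contract from the API side (assuming `net` is an already-loaded stateful InferenceEngine::ExecutableNetwork; error handling omitted):

```cpp
auto req1 = net.CreateInferRequest();
auto req2 = net.CreateInferRequest();

// Both requests start from a clean state.
for (auto&& s : req1.QueryState()) s.Reset();
for (auto&& s : req2.QueryState()) s.Reset();

req1.Infer();  // updates only req1's states

// req2's states must still be zero: inference on req1 may not leak here.
for (auto&& s : req2.QueryState()) {
    auto blob = s.GetState();
    // ... assert blob contents are all zero ...
}
```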