stateful interface impl for AUTO/HETERO (#11590)

* CPU for stateful model

Signed-off-by: fishbell <bell.song@intel.com>

* log

Signed-off-by: fishbell <bell.song@intel.com>

* hetero impl

Signed-off-by: fishbell <bell.song@intel.com>

* enable tests

Signed-off-by: fishbell <bell.song@intel.com>
This commit is contained in:
yanlan song 2022-05-06 12:43:56 +08:00 committed by GitHub
parent 870455675c
commit 912f40e74d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 123 additions and 20 deletions

View File

@ -743,7 +743,17 @@ std::vector<DeviceInformation> MultiDeviceInferencePlugin::FilterDeviceByNetwork
std::vector<DeviceInformation> filterDevice;
auto model = network.getFunction();
if (model->is_dynamic()) {
auto isStateful = [&]() {
for (auto& op : model->get_ops()) {
if (std::dynamic_pointer_cast<ngraph::op::AssignBase>(op) ||
std::dynamic_pointer_cast<ngraph::op::ReadValueBase>(op)) {
LOG_INFO("[AUTOPLUGIN]:stateful mode, try deployed to CPU");
return true;
}
}
return false;
};
if (model->is_dynamic() || isStateful()) {
for (auto& iter : metaDevices) {
if (iter.deviceName.find("CPU") != std::string::npos) {
filterDevice.push_back(iter);

View File

@ -126,6 +126,18 @@ void HeteroInferRequest::InferImpl() {
}
}
std::vector<std::shared_ptr<InferenceEngine::IVariableStateInternal>> HeteroInferRequest::QueryState() {
    // Aggregate the variable states exposed by every per-device sub-request so
    // the caller sees the state of the whole heterogeneous pipeline at once.
    memoryStates = {};
    for (auto&& subRequestDesc : _inferRequests) {
        auto& subRequest = subRequestDesc._request;
        assert(subRequest);
        auto subStates = subRequest->QueryState();
        memoryStates.insert(memoryStates.end(), subStates.begin(), subStates.end());
    }
    return memoryStates;
}
std::map<std::string, InferenceEngineProfileInfo> HeteroInferRequest::GetPerformanceCounts() const {
std::map<std::string, InferenceEngineProfileInfo> perfMap;
for (size_t i = 0; i < _inferRequests.size(); i++) {

View File

@ -50,6 +50,8 @@ public:
const InferenceEngine::PreProcessInfo& GetPreProcess(const std::string& name) const override;
std::vector<std::shared_ptr<InferenceEngine::IVariableStateInternal>> QueryState() override;
std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> GetPerformanceCounts() const override;
SubRequestsList _inferRequests;
@ -58,6 +60,7 @@ public:
private:
void CreateInferRequest(const std::unordered_map<std::string, std::string>& subgraphInputToOutputBlobNames);
std::vector<std::shared_ptr<InferenceEngine::IVariableStateInternal>> memoryStates;
};
} // namespace HeteroPlugin

View File

@ -0,0 +1,53 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <common_test_utils/test_constants.hpp>
#include "behavior/infer_request/memory_states.hpp"
#include "functional_test_utils/plugin_cache.hpp"
#include "ngraph_functions/builders.hpp"
using namespace BehaviorTestsDefinitions;
namespace {
// Builds a tiny stateful network with two ReadValue/Assign variable pairs:
//   out = sigmoid(state_c * (state_r * input))
// Both variables are written back from the same product node.
InferenceEngine::CNNNetwork getNetwork() {
    ngraph::Shape shape = {1, 200};
    ngraph::element::Type type = ngraph::element::f32;

    auto param = std::make_shared<ngraph::op::v0::Parameter>(type, shape);

    // First variable ("r_1-3"): zero-initialized, multiplied with the input.
    auto init_r = std::make_shared<ngraph::op::v0::Constant>(type, shape, 0);
    auto read_r = std::make_shared<ngraph::op::v3::ReadValue>(init_r, "r_1-3");
    auto prod_r = std::make_shared<ngraph::op::v1::Multiply>(read_r, param);

    // Second variable ("c_1-3"): zero-initialized, multiplied with the first product.
    auto init_c = std::make_shared<ngraph::op::v0::Constant>(type, shape, 0);
    auto read_c = std::make_shared<ngraph::op::v3::ReadValue>(init_c, "c_1-3");
    auto prod_c = std::make_shared<ngraph::op::v1::Multiply>(read_c, prod_r);

    // Both states are updated from the same node — presumably intentional so
    // the two variables always hold identical values; TODO confirm.
    auto write_c = std::make_shared<ngraph::op::v3::Assign>(prod_c, "c_1-3");
    auto write_r = std::make_shared<ngraph::op::v3::Assign>(prod_c, "r_1-3");

    auto sigmoid = std::make_shared<ngraph::op::Sigmoid>(prod_c);
    sigmoid->set_friendly_name("sigmod_state");

    // Control dependencies enforce read -> write -> output ordering per variable.
    read_r->set_friendly_name("Memory_1");
    write_r->add_control_dependency(read_r);
    sigmoid->add_control_dependency(write_r);

    read_c->set_friendly_name("Memory_2");
    write_c->add_control_dependency(read_c);
    sigmoid->add_control_dependency(write_c);

    auto function =
        std::make_shared<ngraph::Function>(ngraph::NodeVector{sigmoid}, ngraph::ParameterVector{param}, "addOutput");
    return InferenceEngine::CNNNetwork{function};
}
// Exercise the variable-state behavior tests on CPU directly and through the
// AUTO and HETERO meta-plugins; both meta-plugins are pinned to CPU via the
// DEVICE_PRIORITIES configuration entry. The trailing map is the per-test
// plugin configuration.
std::vector<memoryStateParams> memoryStateTestCases = {
memoryStateParams(getNetwork(), {"c_1-3", "r_1-3"}, CommonTestUtils::DEVICE_CPU, {}),
memoryStateParams(getNetwork(), {"c_1-3", "r_1-3"}, CommonTestUtils::DEVICE_AUTO,
{{MULTI_CONFIG_KEY(DEVICE_PRIORITIES) , CommonTestUtils::DEVICE_CPU}}),
memoryStateParams(getNetwork(), {"c_1-3", "r_1-3"}, CommonTestUtils::DEVICE_HETERO,
{{MULTI_CONFIG_KEY(DEVICE_PRIORITIES) , CommonTestUtils::DEVICE_CPU}})
};
INSTANTIATE_TEST_SUITE_P(smoke_VariableStateBasic, InferRequestVariableStateTest,
::testing::ValuesIn(memoryStateTestCases),
InferRequestVariableStateTest::getTestCaseName);
} // namespace

View File

@ -27,6 +27,7 @@ InferenceEngine::CNNNetwork getNetwork() {
auto mem_w1 = std::make_shared<ngraph::op::v3::Assign>(mul2, "r_1-3");
auto sigm = std::make_shared<ngraph::op::Sigmoid>(mul2);
sigm->set_friendly_name("sigmod_state");
mem_r1->set_friendly_name("Memory_1");
mem_w1->add_control_dependency(mem_r1);
sigm->add_control_dependency(mem_w1);
@ -40,7 +41,7 @@ InferenceEngine::CNNNetwork getNetwork() {
}
// GNA variable-state test cases. The trailing {} is the (empty) plugin
// configuration required by the extended 4-element memoryStateParams tuple.
// NOTE: the stale 3-argument entry left over from the previous tuple shape
// has been removed — keeping both entries made the initializer list malformed.
std::vector<memoryStateParams> memoryStateTestCases = {
    memoryStateParams(getNetwork(), {"c_1-3", "r_1-3"}, CommonTestUtils::DEVICE_GNA, {})
};
INSTANTIATE_TEST_SUITE_P(smoke_VariableStateBasic, InferRequestVariableStateTest,

View File

@ -12,7 +12,8 @@ namespace BehaviorTestsDefinitions {
typedef std::tuple<
InferenceEngine::CNNNetwork, // CNNNetwork to work with
std::vector<std::string>, // Memory States to query
std::string> // Target device name
std::string, // Target device name
std::map<std::string, std::string>> // device configuration
memoryStateParams;
class InferRequestVariableStateTest : public CommonTestUtils::TestsCommon,
@ -21,7 +22,7 @@ protected:
InferenceEngine::CNNNetwork net;
std::vector<std::string> statesToQuery;
std::string deviceName;
std::map<std::string, std::string> configuration;
InferenceEngine::ExecutableNetwork PrepareNetwork();
public:

View File

@ -13,22 +13,28 @@ std::string InferRequestVariableStateTest::getTestCaseName(const testing::TestPa
InferenceEngine::CNNNetwork net;
std::string targetDevice;
std::vector<std::string> statesToQuery;
std::tie(net, statesToQuery, targetDevice) = obj.param;
std::map<std::string, std::string> configuration;
std::tie(net, statesToQuery, targetDevice, configuration) = obj.param;
result << "targetDevice=" << targetDevice;
if (!configuration.empty()) {
for (auto &configItem : configuration) {
result << "_configItem=" << configItem.first << "_" << configItem.second << "_";
}
}
return result.str();
}
void InferRequestVariableStateTest::SetUp() {
    // Skip test according to plugin specific disabledTestPatterns() (if any)
    SKIP_IF_CURRENT_TEST_IS_DISABLED()
    // Unpack the 4-element test tuple (network, states, device, config).
    // The superseded 3-element std::tie was removed: it no longer matches the
    // extended memoryStateParams tuple and would not compile.
    std::tie(net, statesToQuery, deviceName, configuration) = GetParam();
}
InferenceEngine::ExecutableNetwork InferRequestVariableStateTest::PrepareNetwork() {
    // Expose both memory layers as network outputs so their states can be read.
    net.addOutput("Memory_1");
    net.addOutput("Memory_2");
    auto ie = PluginCache::get().ie(deviceName);
    // Load with the per-test configuration (e.g. DEVICE_PRIORITIES for
    // AUTO/HETERO). The superseded two-argument LoadNetwork call was removed:
    // it duplicated this return and ignored the configuration.
    return ie->LoadNetwork(net, deviceName, configuration);
}
TEST_P(InferRequestVariableStateTest, inferreq_smoke_VariableState_QueryState) {
@ -175,7 +181,9 @@ TEST_P(InferRequestVariableStateTest, inferreq_smoke_VariableState_2infers) {
auto executableNet = PrepareNetwork();
auto inferReq = executableNet.CreateInferRequest();
auto inferReq2 = executableNet.CreateInferRequest();
const float new_state_val = 13.0f;
// set the input data for the network
for (const auto &input : executableNet.GetInputsInfo()) {
const auto &info = input.second;
InferenceEngine::Blob::Ptr inBlob;
@ -185,17 +193,36 @@ TEST_P(InferRequestVariableStateTest, inferreq_smoke_VariableState_2infers) {
inferReq.SetBlob(info->name(), inBlob);
}
for (auto &&state : inferReq.QueryState()) {
state.Reset();
}
// initial state for 2nd infer request
for (auto &&state : inferReq2.QueryState()) {
auto state_val = state.GetState();
auto element_count = state_val->size();
float *new_state_data = new float[element_count];
for (int i = 0; i < element_count; i++) {
new_state_data[i] = new_state_val;
}
auto stateBlob = make_blob_with_precision(state_val->getTensorDesc());
stateBlob->allocate();
std::memcpy(stateBlob->buffer(), new_state_data, element_count * sizeof(float));
delete[]new_state_data;
state.SetState(stateBlob);
}
// reset state for 1st infer request
for (auto &&state : inferReq.QueryState()) {
state.Reset();
}
inferReq.Infer();
auto states = inferReq.QueryState();
auto states2 = inferReq2.QueryState();
// check the output and state of 1st request
auto outputBlob = inferReq.GetBlob("sigmod_state");
auto output_data = InferenceEngine::as<InferenceEngine::MemoryBlob>(outputBlob)->rmap().as<float*>();
for (int i = 0; i < outputBlob->size(); i++) {
EXPECT_NEAR(0.5f, output_data[i], 1e-5);
}
for (int i = 0; i < states.size(); ++i) {
auto lastState = states[i].GetState();
auto last_state_size = lastState->size();
@ -203,16 +230,12 @@ TEST_P(InferRequestVariableStateTest, inferreq_smoke_VariableState_2infers) {
ASSERT_TRUE(last_state_size != 0) << "State size should not be 0";
if (i == 0) {
for (int j = 0; j < last_state_size; ++j) {
EXPECT_NEAR(0.5f, last_state_data[j], 1e-3);
for (int j = 0; j < last_state_size; ++j) {
EXPECT_NEAR(0.0, last_state_data[j], 1e-5);
}
} else {
for (int j = 0; j < last_state_size; ++j) {
EXPECT_NEAR(0.0f, last_state_data[j], 1e-5);
}
}
}
// check the output and state of 2nd request
for (int i = 0; i < states2.size(); ++i) {
auto lastState = states2[i].GetState();
auto last_state_size = lastState->size();
@ -221,7 +244,7 @@ TEST_P(InferRequestVariableStateTest, inferreq_smoke_VariableState_2infers) {
ASSERT_TRUE(last_state_size != 0) << "State size should not be 0";
for (int j = 0; j < last_state_size; ++j) {
EXPECT_NEAR(0.0f, last_state_data[j], 1e-5);
EXPECT_NEAR(new_state_val, last_state_data[j], 1e-5);
}
}
}