stateful inferface impl for AUTO/HETERO (#11590)
* CPU for stateful model Signed-off-by: fishbell <bell.song@intel.com> * log Signed-off-by: fishbell <bell.song@intel.com> * hetero impl Signed-off-by: fishbell <bell.song@intel.com> * enable tests Signed-off-by: fishbell <bell.song@intel.com>
This commit is contained in:
parent
870455675c
commit
912f40e74d
@ -743,7 +743,17 @@ std::vector<DeviceInformation> MultiDeviceInferencePlugin::FilterDeviceByNetwork
|
||||
|
||||
std::vector<DeviceInformation> filterDevice;
|
||||
auto model = network.getFunction();
|
||||
if (model->is_dynamic()) {
|
||||
auto isStateful = [&]() {
|
||||
for (auto& op : model->get_ops()) {
|
||||
if (std::dynamic_pointer_cast<ngraph::op::AssignBase>(op) ||
|
||||
std::dynamic_pointer_cast<ngraph::op::ReadValueBase>(op)) {
|
||||
LOG_INFO("[AUTOPLUGIN]:stateful mode, try deployed to CPU");
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
if (model->is_dynamic() || isStateful()) {
|
||||
for (auto& iter : metaDevices) {
|
||||
if (iter.deviceName.find("CPU") != std::string::npos) {
|
||||
filterDevice.push_back(iter);
|
||||
|
@ -126,6 +126,18 @@ void HeteroInferRequest::InferImpl() {
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<InferenceEngine::IVariableStateInternal>> HeteroInferRequest::QueryState() {
|
||||
memoryStates = {};
|
||||
for (auto&& desc : _inferRequests) {
|
||||
auto& r = desc._request;
|
||||
assert(r);
|
||||
for (auto&& state : r->QueryState()) {
|
||||
memoryStates.emplace_back(state);
|
||||
}
|
||||
}
|
||||
return memoryStates;
|
||||
}
|
||||
|
||||
std::map<std::string, InferenceEngineProfileInfo> HeteroInferRequest::GetPerformanceCounts() const {
|
||||
std::map<std::string, InferenceEngineProfileInfo> perfMap;
|
||||
for (size_t i = 0; i < _inferRequests.size(); i++) {
|
||||
|
@ -50,6 +50,8 @@ public:
|
||||
|
||||
const InferenceEngine::PreProcessInfo& GetPreProcess(const std::string& name) const override;
|
||||
|
||||
std::vector<std::shared_ptr<InferenceEngine::IVariableStateInternal>> QueryState() override;
|
||||
|
||||
std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> GetPerformanceCounts() const override;
|
||||
|
||||
SubRequestsList _inferRequests;
|
||||
@ -58,6 +60,7 @@ public:
|
||||
|
||||
private:
|
||||
void CreateInferRequest(const std::unordered_map<std::string, std::string>& subgraphInputToOutputBlobNames);
|
||||
std::vector<std::shared_ptr<InferenceEngine::IVariableStateInternal>> memoryStates;
|
||||
};
|
||||
|
||||
} // namespace HeteroPlugin
|
||||
|
@ -0,0 +1,53 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <common_test_utils/test_constants.hpp>
|
||||
#include "behavior/infer_request/memory_states.hpp"
|
||||
#include "functional_test_utils/plugin_cache.hpp"
|
||||
#include "ngraph_functions/builders.hpp"
|
||||
|
||||
using namespace BehaviorTestsDefinitions;
|
||||
|
||||
namespace {
|
||||
InferenceEngine::CNNNetwork getNetwork() {
|
||||
ngraph::Shape shape = {1, 200};
|
||||
ngraph::element::Type type = ngraph::element::f32;
|
||||
|
||||
auto input = std::make_shared<ngraph::op::v0::Parameter>(type, shape);
|
||||
auto mem_i1 = std::make_shared<ngraph::op::v0::Constant>(type, shape, 0);
|
||||
auto mem_r1 = std::make_shared<ngraph::op::v3::ReadValue>(mem_i1, "r_1-3");
|
||||
auto mul1 = std::make_shared<ngraph::op::v1::Multiply>(mem_r1, input);
|
||||
|
||||
auto mem_i2 = std::make_shared<ngraph::op::v0::Constant>(type, shape, 0);
|
||||
auto mem_r2 = std::make_shared<ngraph::op::v3::ReadValue>(mem_i2, "c_1-3");
|
||||
auto mul2 = std::make_shared<ngraph::op::v1::Multiply>(mem_r2, mul1);
|
||||
auto mem_w2 = std::make_shared<ngraph::op::v3::Assign>(mul2, "c_1-3");
|
||||
|
||||
auto mem_w1 = std::make_shared<ngraph::op::v3::Assign>(mul2, "r_1-3");
|
||||
auto sigm = std::make_shared<ngraph::op::Sigmoid>(mul2);
|
||||
sigm->set_friendly_name("sigmod_state");
|
||||
mem_r1->set_friendly_name("Memory_1");
|
||||
mem_w1->add_control_dependency(mem_r1);
|
||||
sigm->add_control_dependency(mem_w1);
|
||||
|
||||
mem_r2->set_friendly_name("Memory_2");
|
||||
mem_w2->add_control_dependency(mem_r2);
|
||||
sigm->add_control_dependency(mem_w2);
|
||||
|
||||
auto function = std::make_shared<ngraph::Function>(ngraph::NodeVector{sigm}, ngraph::ParameterVector{input}, "addOutput");
|
||||
return InferenceEngine::CNNNetwork{function};
|
||||
}
|
||||
|
||||
std::vector<memoryStateParams> memoryStateTestCases = {
|
||||
memoryStateParams(getNetwork(), {"c_1-3", "r_1-3"}, CommonTestUtils::DEVICE_CPU, {}),
|
||||
memoryStateParams(getNetwork(), {"c_1-3", "r_1-3"}, CommonTestUtils::DEVICE_AUTO,
|
||||
{{MULTI_CONFIG_KEY(DEVICE_PRIORITIES) , CommonTestUtils::DEVICE_CPU}}),
|
||||
memoryStateParams(getNetwork(), {"c_1-3", "r_1-3"}, CommonTestUtils::DEVICE_HETERO,
|
||||
{{MULTI_CONFIG_KEY(DEVICE_PRIORITIES) , CommonTestUtils::DEVICE_CPU}})
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_VariableStateBasic, InferRequestVariableStateTest,
|
||||
::testing::ValuesIn(memoryStateTestCases),
|
||||
InferRequestVariableStateTest::getTestCaseName);
|
||||
} // namespace
|
@ -27,6 +27,7 @@ InferenceEngine::CNNNetwork getNetwork() {
|
||||
auto mem_w1 = std::make_shared<ngraph::op::v3::Assign>(mul2, "r_1-3");
|
||||
auto sigm = std::make_shared<ngraph::op::Sigmoid>(mul2);
|
||||
|
||||
sigm->set_friendly_name("sigmod_state");
|
||||
mem_r1->set_friendly_name("Memory_1");
|
||||
mem_w1->add_control_dependency(mem_r1);
|
||||
sigm->add_control_dependency(mem_w1);
|
||||
@ -40,7 +41,7 @@ InferenceEngine::CNNNetwork getNetwork() {
|
||||
}
|
||||
|
||||
std::vector<memoryStateParams> memoryStateTestCases = {
|
||||
memoryStateParams(getNetwork(), {"c_1-3", "r_1-3"}, CommonTestUtils::DEVICE_GNA)
|
||||
memoryStateParams(getNetwork(), {"c_1-3", "r_1-3"}, CommonTestUtils::DEVICE_GNA, {})
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_VariableStateBasic, InferRequestVariableStateTest,
|
||||
|
@ -12,7 +12,8 @@ namespace BehaviorTestsDefinitions {
|
||||
typedef std::tuple<
|
||||
InferenceEngine::CNNNetwork, // CNNNetwork to work with
|
||||
std::vector<std::string>, // Memory States to query
|
||||
std::string> // Target device name
|
||||
std::string, // Target device name
|
||||
std::map<std::string, std::string>> // device configuration
|
||||
memoryStateParams;
|
||||
|
||||
class InferRequestVariableStateTest : public CommonTestUtils::TestsCommon,
|
||||
@ -21,7 +22,7 @@ protected:
|
||||
InferenceEngine::CNNNetwork net;
|
||||
std::vector<std::string> statesToQuery;
|
||||
std::string deviceName;
|
||||
|
||||
std::map<std::string, std::string> configuration;
|
||||
InferenceEngine::ExecutableNetwork PrepareNetwork();
|
||||
|
||||
public:
|
||||
|
@ -13,22 +13,28 @@ std::string InferRequestVariableStateTest::getTestCaseName(const testing::TestPa
|
||||
InferenceEngine::CNNNetwork net;
|
||||
std::string targetDevice;
|
||||
std::vector<std::string> statesToQuery;
|
||||
std::tie(net, statesToQuery, targetDevice) = obj.param;
|
||||
std::map<std::string, std::string> configuration;
|
||||
std::tie(net, statesToQuery, targetDevice, configuration) = obj.param;
|
||||
result << "targetDevice=" << targetDevice;
|
||||
if (!configuration.empty()) {
|
||||
for (auto &configItem : configuration) {
|
||||
result << "_configItem=" << configItem.first << "_" << configItem.second << "_";
|
||||
}
|
||||
}
|
||||
return result.str();
|
||||
}
|
||||
|
||||
void InferRequestVariableStateTest::SetUp() {
|
||||
// Skip test according to plugin specific disabledTestPatterns() (if any)
|
||||
SKIP_IF_CURRENT_TEST_IS_DISABLED()
|
||||
std::tie(net, statesToQuery, deviceName) = GetParam();
|
||||
std::tie(net, statesToQuery, deviceName, configuration) = GetParam();
|
||||
}
|
||||
|
||||
InferenceEngine::ExecutableNetwork InferRequestVariableStateTest::PrepareNetwork() {
|
||||
net.addOutput("Memory_1");
|
||||
net.addOutput("Memory_2");
|
||||
auto ie = PluginCache::get().ie(deviceName);
|
||||
return ie->LoadNetwork(net, deviceName);
|
||||
return ie->LoadNetwork(net, deviceName, configuration);
|
||||
}
|
||||
|
||||
TEST_P(InferRequestVariableStateTest, inferreq_smoke_VariableState_QueryState) {
|
||||
@ -175,7 +181,9 @@ TEST_P(InferRequestVariableStateTest, inferreq_smoke_VariableState_2infers) {
|
||||
auto executableNet = PrepareNetwork();
|
||||
auto inferReq = executableNet.CreateInferRequest();
|
||||
auto inferReq2 = executableNet.CreateInferRequest();
|
||||
const float new_state_val = 13.0f;
|
||||
|
||||
// set the input data for the network
|
||||
for (const auto &input : executableNet.GetInputsInfo()) {
|
||||
const auto &info = input.second;
|
||||
InferenceEngine::Blob::Ptr inBlob;
|
||||
@ -185,17 +193,36 @@ TEST_P(InferRequestVariableStateTest, inferreq_smoke_VariableState_2infers) {
|
||||
inferReq.SetBlob(info->name(), inBlob);
|
||||
}
|
||||
|
||||
for (auto &&state : inferReq.QueryState()) {
|
||||
state.Reset();
|
||||
}
|
||||
// initial state for 2nd infer request
|
||||
for (auto &&state : inferReq2.QueryState()) {
|
||||
auto state_val = state.GetState();
|
||||
auto element_count = state_val->size();
|
||||
|
||||
float *new_state_data = new float[element_count];
|
||||
for (int i = 0; i < element_count; i++) {
|
||||
new_state_data[i] = new_state_val;
|
||||
}
|
||||
auto stateBlob = make_blob_with_precision(state_val->getTensorDesc());
|
||||
stateBlob->allocate();
|
||||
std::memcpy(stateBlob->buffer(), new_state_data, element_count * sizeof(float));
|
||||
delete[]new_state_data;
|
||||
state.SetState(stateBlob);
|
||||
}
|
||||
|
||||
// reset state for 1st infer request
|
||||
for (auto &&state : inferReq.QueryState()) {
|
||||
state.Reset();
|
||||
}
|
||||
|
||||
inferReq.Infer();
|
||||
|
||||
auto states = inferReq.QueryState();
|
||||
auto states2 = inferReq2.QueryState();
|
||||
// check the output and state of 1st request
|
||||
auto outputBlob = inferReq.GetBlob("sigmod_state");
|
||||
auto output_data = InferenceEngine::as<InferenceEngine::MemoryBlob>(outputBlob)->rmap().as<float*>();
|
||||
for (int i = 0; i < outputBlob->size(); i++) {
|
||||
EXPECT_NEAR(0.5f, output_data[i], 1e-5);
|
||||
}
|
||||
for (int i = 0; i < states.size(); ++i) {
|
||||
auto lastState = states[i].GetState();
|
||||
auto last_state_size = lastState->size();
|
||||
@ -203,16 +230,12 @@ TEST_P(InferRequestVariableStateTest, inferreq_smoke_VariableState_2infers) {
|
||||
|
||||
ASSERT_TRUE(last_state_size != 0) << "State size should not be 0";
|
||||
|
||||
if (i == 0) {
|
||||
for (int j = 0; j < last_state_size; ++j) {
|
||||
EXPECT_NEAR(0.5f, last_state_data[j], 1e-3);
|
||||
for (int j = 0; j < last_state_size; ++j) {
|
||||
EXPECT_NEAR(0.0, last_state_data[j], 1e-5);
|
||||
}
|
||||
} else {
|
||||
for (int j = 0; j < last_state_size; ++j) {
|
||||
EXPECT_NEAR(0.0f, last_state_data[j], 1e-5);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check the output and state of 2nd request
|
||||
for (int i = 0; i < states2.size(); ++i) {
|
||||
auto lastState = states2[i].GetState();
|
||||
auto last_state_size = lastState->size();
|
||||
@ -221,7 +244,7 @@ TEST_P(InferRequestVariableStateTest, inferreq_smoke_VariableState_2infers) {
|
||||
ASSERT_TRUE(last_state_size != 0) << "State size should not be 0";
|
||||
|
||||
for (int j = 0; j < last_state_size; ++j) {
|
||||
EXPECT_NEAR(0.0f, last_state_data[j], 1e-5);
|
||||
EXPECT_NEAR(new_state_val, last_state_data[j], 1e-5);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user