[GPU] GPU Remote tensor API update (#8485)

Vladimir Paramuzov 2021-11-15 09:58:12 +03:00 committed by GitHub
parent b320061ea3
commit f3e1dc25b2
18 changed files with 695 additions and 400 deletions


@ -786,7 +786,9 @@ void CLDNNInferRequest::allocate_inputs() {
} else {
auto blobPtr = create_device_blob(desc, litr->second);
_deviceInputs[name] = blobPtr;
_inputs[name] = blobPtr;
Blob::Ptr inputBlob = create_host_blob(desc);
inputBlob->allocate();
_inputs[name] = inputBlob;
}
}
}
@ -832,7 +834,9 @@ void CLDNNInferRequest::allocate_outputs() {
}
auto blobPtr = create_device_blob(desc, output_layout);
_deviceOutputs[no.first] = blobPtr;
_outputs[no.first] = blobPtr;
Blob::Ptr outputBlob = create_host_blob(desc);
outputBlob->allocate();
_outputs[no.first] = outputBlob;
outputsMap[no.first] = outputID;
}
}


@ -204,6 +204,7 @@ CLDNNExecutionContextImpl::CLDNNExecutionContextImpl(const std::shared_ptr<IInfe
lock.clear(std::memory_order_relaxed);
gpu_handle_param _context_id = nullptr;
gpu_handle_param _va_device = nullptr;
int ctx_device_id = 0;
int target_tile_id = -1;
if (params.size()) {
@ -212,6 +213,12 @@ CLDNNExecutionContextImpl::CLDNNExecutionContextImpl(const std::shared_ptr<IInfe
if (GPU_PARAM_VALUE(OCL) == contextTypeStr) {
_context_id = _ObjFromParamSimple<gpu_handle_param>(params, GPU_PARAM_KEY(OCL_CONTEXT));
if (params.find(GPU_PARAM_KEY(OCL_QUEUE)) != params.end())
m_external_queue = _ObjFromParamSimple<gpu_handle_param>(params, GPU_PARAM_KEY(OCL_QUEUE));
if (params.find(GPU_PARAM_KEY(OCL_CONTEXT_DEVICE_ID)) != params.end())
ctx_device_id = _ObjFromParamSimple<int>(params, GPU_PARAM_KEY(OCL_CONTEXT_DEVICE_ID));
} else if (GPU_PARAM_VALUE(VA_SHARED) == contextTypeStr) {
m_va_display = _va_device = _ObjFromParamSimple<gpu_handle_param>(params, GPU_PARAM_KEY(VA_DEVICE));
m_type = ContextType::DEV_SHARED;
@ -222,16 +229,13 @@ CLDNNExecutionContextImpl::CLDNNExecutionContextImpl(const std::shared_ptr<IInfe
if (tile_id_itr != params.end()) {
target_tile_id = tile_id_itr->second.as<int>();
}
if (params.find(GPU_PARAM_KEY(OCL_QUEUE)) != params.end())
m_external_queue = _ObjFromParamSimple<gpu_handle_param>(params, GPU_PARAM_KEY(OCL_QUEUE));
}
// TODO: Parameterize this based on plugin config and compilation options
auto engine_type = cldnn::engine_types::ocl;
auto runtime_type = cldnn::runtime_types::ocl;
// Use actual runtime and engine types
cldnn::device_query device_query(engine_type, runtime_type, _context_id, _va_device, target_tile_id);
cldnn::device_query device_query(engine_type, runtime_type, _context_id, _va_device, ctx_device_id, target_tile_id);
auto device_map = device_query.get_available_devices();
auto iter = device_map.find(m_config.device_id);
@ -273,6 +277,7 @@ ParamMap CLDNNExecutionContextImpl::getParams() const {
switch (m_type) {
case OCL:
ret[GPU_PARAM_KEY(CONTEXT_TYPE)] = GPU_PARAM_VALUE(OCL);
ret[GPU_PARAM_KEY(OCL_QUEUE)] = static_cast<gpu_handle_param>(m_external_queue);
break;
case DEV_SHARED:
ret[GPU_PARAM_KEY(CONTEXT_TYPE)] = GPU_PARAM_VALUE(VA_SHARED);
@ -287,6 +292,19 @@ ParamMap CLDNNExecutionContextImpl::getParams() const {
std::string CLDNNExecutionContextImpl::getDeviceName() const noexcept {
auto devName = m_plugin.lock()->GetName();
auto engine_type = cldnn::engine_types::ocl;
auto runtime_type = cldnn::runtime_types::ocl;
// Use actual runtime and engine types
cldnn::device_query device_query(engine_type, runtime_type);
auto all_devices = device_query.get_available_devices();
auto current_device = m_engine->get_device();
for (auto& kv : all_devices) {
if (current_device->is_same(kv.second))
return devName + "." + kv.first;
}
if (!m_config.device_id.empty())
devName += "." + m_config.device_id;
return devName;
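With this change the execution context resolves its device name by matching the engine's device against all detected devices, so a context created on a specific device of a shared cl_context reports that device's fully qualified name (e.g. "GPU.0" or "GPU.1"). A minimal sketch of how an application could check this, assuming an application-created cl_context named user_cl_context and using the ov::runtime::gpu::ocl::ClContext wrapper updated later in this commit:

#include <string>
#include "openvino/runtime/core.hpp"
#include "openvino/runtime/gpu/ocl/ocl.hpp"

// Illustrative sketch only; user_cl_context is an assumed, application-created cl_context.
std::string query_context_device_name(cl_context user_cl_context) {
    ov::runtime::Core core;
    // Bind the remote context to device index 1 inside the shared context (new ctx_device_id argument).
    auto remote_ctx = ov::runtime::gpu::ocl::ClContext(core, user_cl_context, 1);
    // getDeviceName() above matches the engine's device against the detected devices,
    // so this returns a fully qualified name such as "GPU.0" or "GPU.1".
    return remote_ctx.get_device_name();
}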


@ -62,6 +62,12 @@ DECLARE_GPU_PARAM_VALUE(VA_SHARED);
*/
DECLARE_GPU_PARAM_KEY(OCL_CONTEXT, gpu_handle_param);
/**
* @brief This key identifies ID of device in OpenCL context
* if multiple devices are present in the context
*/
DECLARE_GPU_PARAM_KEY(OCL_CONTEXT_DEVICE_ID, int);
/**
* @brief In case of multi-tile system,
* this key identifies tile within given context
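The new OCL_CONTEXT_DEVICE_ID key above is filled in automatically by the ClContext wrapper updated later in this commit; for the low-level ParamMap path, usage looks roughly like the sketch below (assumptions: an ov::runtime::Core instance named core and an application-created cl_context named user_cl_context; the map mirrors the ClContext constructor body):

// Illustrative sketch: select the second device of a shared multi-device cl_context.
ParamMap context_params = {
    {GPU_PARAM_KEY(CONTEXT_TYPE), GPU_PARAM_VALUE(OCL)},
    {GPU_PARAM_KEY(OCL_CONTEXT), static_cast<gpu_handle_param>(user_cl_context)},
    {GPU_PARAM_KEY(OCL_CONTEXT_DEVICE_ID), 1}};
auto remote_context = core.create_context("GPU", context_params);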


@ -7,7 +7,7 @@
* shared Video Acceleration device contexts
* and shared memory tensors which contain Video Acceleration surfaces
*
* @file openvino/runtime/gpu/dx.hpp
* @file openvino/runtime/gpu/ocl/dx.hpp
*/
#pragma once
@ -15,16 +15,21 @@
# define NOMINMAX
#endif
#ifndef _WIN32
# error "OpenCL DirectX interoperability is supported only on Windows platforms"
#endif
#include <d3d11.h>
#include <memory>
#include <string>
#include "openvino/runtime/gpu/ocl.hpp"
#include "openvino/runtime/gpu/ocl/ocl.hpp"
namespace ov {
namespace runtime {
namespace gpu {
namespace ocl {
/**
* @brief This class represents an abstraction for GPU plugin remote tensor
@ -122,12 +127,15 @@ public:
* @brief Constructs D3DContext remote context object from ID3D11Device
* @param core OpenVINO Runtime Core object instance
* @param device A pointer to ID3D11Device to be used to create a remote context
* @param target_tile_id Desired tile id within given context for multi-tile system. Default value (-1) means
* that root device should be used
*/
D3DContext(Core& core, ID3D11Device* device) : ClContext(core, (cl_context) nullptr) {
D3DContext(Core& core, ID3D11Device* device, int target_tile_id = -1) : ClContext(core, (cl_context) nullptr) {
// clang-format off
ParamMap context_params = {
{GPU_PARAM_KEY(CONTEXT_TYPE), GPU_PARAM_VALUE(VA_SHARED)},
{GPU_PARAM_KEY(VA_DEVICE), static_cast<gpu_handle_param>(device)}
{GPU_PARAM_KEY(VA_DEVICE), static_cast<gpu_handle_param>(device)},
{GPU_PARAM_KEY(TILE_ID), target_tile_id}
};
*this = core.create_context(device_name, context_params);
}
@ -183,6 +191,7 @@ public:
return create_tensor(type, shape, params);
}
};
} // namespace ocl
} // namespace gpu
} // namespace runtime
} // namespace ov
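D3DContext now also accepts a target_tile_id (default -1, meaning the root device), which it forwards as GPU_PARAM_KEY(TILE_ID). A minimal sketch, assuming an already-created ID3D11Device* named d3d_device and an ngraph function named model:

// Illustrative sketch only; d3d_device and model are assumed to exist in the application.
ov::runtime::Core core;
auto d3d_context = ov::runtime::gpu::ocl::D3DContext(core, d3d_device, /*target_tile_id=*/0);
auto compiled = core.compile_model(model, d3d_context);  // compile directly on the shared D3D11 device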


@ -6,7 +6,7 @@
* @brief a header that defines wrappers for internal GPU plugin-specific
* OpenCL context and OpenCL shared memory tensors
*
* @file openvino/runtime/gpu/ocl.hpp
* @file openvino/runtime/gpu/ocl/ocl.hpp
*/
#pragma once
@ -15,13 +15,14 @@
#include "gpu/gpu_params.hpp"
#include "openvino/runtime/core.hpp"
#include "openvino/runtime/gpu/ocl_wrapper.hpp"
#include "openvino/runtime/gpu/ocl/ocl_wrapper.hpp"
#include "openvino/runtime/remote_context.hpp"
#include "openvino/runtime/remote_tensor.hpp"
namespace ov {
namespace runtime {
namespace gpu {
namespace ocl {
/**
* @brief Shortcut for defining a handle parameter
@ -146,10 +147,12 @@ public:
* @brief Constructs context object from user-supplied OpenCL context handle
* @param core A reference to OpenVINO Runtime Core object
* @param ctx A OpenCL context to be used to create shared remote context
* @param ctx_device_id An ID of device to be used from ctx
*/
ClContext(Core& core, cl_context ctx) {
ClContext(Core& core, cl_context ctx, int ctx_device_id = 0) {
ParamMap context_params = {{GPU_PARAM_KEY(CONTEXT_TYPE), GPU_PARAM_VALUE(OCL)},
{GPU_PARAM_KEY(OCL_CONTEXT), static_cast<gpu_handle_param>(ctx)}};
{GPU_PARAM_KEY(OCL_CONTEXT), static_cast<gpu_handle_param>(ctx)},
{GPU_PARAM_KEY(OCL_CONTEXT_DEVICE_ID), ctx_device_id}};
*this = core.create_context(device_name, context_params);
}
@ -250,6 +253,7 @@ public:
return create_tensor(type, shape, params);
}
};
} // namespace ocl
} // namespace gpu
} // namespace runtime
} // namespace ov
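The new ctx_device_id argument selects which device of a multi-device cl_context backs the remote context; the smoke_canInferOnUserContextWithMultipleDevices test later in this commit exercises it. A condensed sketch of that flow (ocl_device is an assumed cl::Device, model an assumed ngraph function):

// Illustrative sketch: create a two-device OpenCL context and bind the remote
// context to device index 1 inside it.
cl::Context multi_device_ctx({ocl_device, ocl_device});
ov::runtime::Core core;
auto remote_context = ov::runtime::gpu::ocl::ClContext(core, multi_device_ctx.get(), 1);
auto compiled = core.compile_model(model, remote_context);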


@ -7,14 +7,18 @@
* shared Video Acceleration device contexts
* and shared memory tensors which contain Video Acceleration surfaces
*
* @file openvino/runtime/gpu/va.hpp
* @file openvino/runtime/gpu/ocl/va.hpp
*/
#pragma once
#ifdef _WIN32
# error "OpenCL VA-API interoperability is supported only on Linux-based platforms"
#endif
#include <memory>
#include <string>
#include "openvino/runtime/gpu/ocl.hpp"
#include "openvino/runtime/gpu/ocl/ocl.hpp"
// clang-format off
#include <va/va.h>
@ -23,6 +27,7 @@
namespace ov {
namespace runtime {
namespace gpu {
namespace ocl {
/**
* @brief This class represents an abstraction for GPU plugin remote tensor
@ -92,10 +97,13 @@ public:
* @brief Constructs remote context object from VA display handle
* @param core OpenVINO Runtime Core object
* @param device A `VADisplay` to create remote context from
* @param target_tile_id Desired tile id within given context for multi-tile system. Default value (-1) means
* that root device should be used
*/
VAContext(Core& core, VADisplay device) : ClContext(core, (cl_context) nullptr) {
VAContext(Core& core, VADisplay device, int target_tile_id = -1) : ClContext(core, (cl_context) nullptr) {
ParamMap context_params = {{GPU_PARAM_KEY(CONTEXT_TYPE), GPU_PARAM_VALUE(VA_SHARED)},
{GPU_PARAM_KEY(VA_DEVICE), static_cast<gpu_handle_param>(device)}};
{GPU_PARAM_KEY(VA_DEVICE), static_cast<gpu_handle_param>(device)},
{GPU_PARAM_KEY(TILE_ID), target_tile_id}};
*this = core.create_context(device_name, context_params);
}
@ -137,6 +145,7 @@ public:
return create_tensor(type, shape, params);
}
};
} // namespace ocl
} // namespace gpu
} // namespace runtime
} // namespace ov
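VAContext likewise gains a target_tile_id (default -1, root device) that is passed through as GPU_PARAM_KEY(TILE_ID). A minimal sketch, assuming an already-initialized VADisplay named va_display and an ngraph function named model:

// Illustrative sketch only; va_display and model are assumed to exist in the application.
ov::runtime::Core core;
auto va_context = ov::runtime::gpu::ocl::VAContext(core, va_display, /*target_tile_id=*/0);
auto compiled = core.compile_model(model, va_context);  // inference shares the VA-API device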


@ -149,18 +149,18 @@ function(ie_headers_compilation_with_custom_flags)
"gpu/gpu_context_api_va.hpp"
"gpu/gpu_context_api_dx.hpp"
"gpu/gpu_ocl_wrapper.hpp"
"openvino/runtime/gpu/ocl_wrapper.hpp"
"openvino/runtime/gpu/ocl.hpp"
"openvino/runtime/gpu/va.hpp"
"openvino/runtime/gpu/dx.hpp")
"openvino/runtime/gpu/ocl/ocl_wrapper.hpp"
"openvino/runtime/gpu/ocl/ocl.hpp"
"openvino/runtime/gpu/ocl/va.hpp"
"openvino/runtime/gpu/ocl/dx.hpp")
endif()
if(NOT WIN32)
list(APPEND IE_TEST_HEADERS_TO_SKIP "gpu/gpu_context_api_dx.hpp"
"openvino/runtime/gpu/dx.hpp")
"openvino/runtime/gpu/ocl/dx.hpp")
endif()
if(NOT LIBVA_FOUND)
list(APPEND IE_TEST_HEADERS_TO_SKIP "gpu/gpu_context_api_va.hpp"
"openvino/runtime/gpu/va.hpp")
"openvino/runtime/gpu/ocl/va.hpp")
endif()
endif()


@ -0,0 +1,110 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <string>
#include <utility>
#include <vector>
#include <memory>
#include "openvino/runtime/core.hpp"
#include <gpu/gpu_config.hpp>
#include <common_test_utils/test_common.hpp>
#include <functional_test_utils/plugin_cache.hpp>
#include "ngraph_functions/subgraph_builders.hpp"
#include "functional_test_utils/blob_utils.hpp"
#include "openvino/core/preprocess/pre_post_process.hpp"
#include "transformations/utils/utils.hpp"
using namespace ::testing;
using ConcurrencyTestParams = std::tuple<size_t, // number of streams
size_t>; // number of requests
class OVConcurrencyTest : public CommonTestUtils::TestsCommon,
public testing::WithParamInterface<ConcurrencyTestParams> {
void SetUp() override {
std::tie(num_streams, num_requests) = this->GetParam();
fn_ptrs = {ngraph::builder::subgraph::makeSplitMultiConvConcat(),
ngraph::builder::subgraph::makeMultiSingleConv()};
};
public:
static std::string getTestCaseName(const testing::TestParamInfo<ConcurrencyTestParams>& obj) {
size_t streams, requests;
std::tie(streams, requests) = obj.param;
return "_num_streams_" + std::to_string(streams) + "_num_req_" +
std::to_string(requests);
}
protected:
size_t num_streams;
size_t num_requests;
std::vector<std::shared_ptr<ngraph::Function>> fn_ptrs;
};
TEST_P(OVConcurrencyTest, canInferTwoExecNets) {
auto ie = ov::runtime::Core();
ov::ResultVector outputs;
std::vector<ov::runtime::InferRequest> irs;
std::vector<std::vector<uint8_t>> ref;
std::vector<int> outElementsCount;
for (size_t i = 0; i < fn_ptrs.size(); ++i) {
auto fn = fn_ptrs[i];
auto exec_net = ie.compile_model(fn_ptrs[i], CommonTestUtils::DEVICE_GPU,
{{ov::ie::PluginConfigParams::KEY_GPU_THROUGHPUT_STREAMS, std::to_string(num_streams)}});
auto input = fn_ptrs[i]->get_parameters().at(0);
auto output = fn_ptrs[i]->get_results().at(0);
for (int j = 0; j < num_streams * num_requests; j++) {
outputs.push_back(output);
auto inf_req = exec_net.create_infer_request();
irs.push_back(inf_req);
auto tensor = FuncTestUtils::create_and_fill_tensor(input->get_element_type(), input->get_shape());
inf_req.set_tensor(input, tensor);
outElementsCount.push_back(ov::shape_size(fn_ptrs[i]->get_output_shape(0)));
const auto in_tensor = inf_req.get_tensor(input);
const auto tensorSize = in_tensor.get_byte_size();
const auto inBlobBuf = static_cast<uint8_t*>(in_tensor.data());
std::vector<uint8_t> inData(inBlobBuf, inBlobBuf + tensorSize);
auto reOutData = ngraph::helpers::interpreterFunction(fn_ptrs[i], {inData}).front().second;
ref.push_back(reOutData);
}
}
const int niter = 10;
for (int i = 0; i < niter; i++) {
for (auto ir : irs) {
ir.start_async();
}
for (auto ir : irs) {
ir.wait();
}
}
auto thr = FuncTestUtils::GetComparisonThreshold(InferenceEngine::Precision::FP32);
for (size_t i = 0; i < irs.size(); ++i) {
const auto &refBuffer = ref[i].data();
ASSERT_EQ(outElementsCount[i], irs[i].get_tensor(outputs[i]).get_size());
FuncTestUtils::compareRawBuffers(irs[i].get_tensor(outputs[i]).data<float>(),
reinterpret_cast<const float *>(refBuffer), outElementsCount[i],
outElementsCount[i],
thr);
}
}
const std::vector<size_t> num_streams{ 1, 2 };
const std::vector<size_t> num_requests{ 1, 4 };
INSTANTIATE_TEST_SUITE_P(smoke_RemoteTensor, OVConcurrencyTest,
::testing::Combine(::testing::ValuesIn(num_streams),
::testing::ValuesIn(num_requests)),
OVConcurrencyTest::getTestCaseName);


@ -1,367 +0,0 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <string>
#include <utility>
#include <vector>
#include <memory>
#include "openvino/runtime/gpu/ocl.hpp"
#include "openvino/runtime/core.hpp"
#include <gpu/gpu_config.hpp>
#include <remote_blob_tests/remote_blob_helpers.hpp>
#include <common_test_utils/test_common.hpp>
#include <functional_test_utils/plugin_cache.hpp>
#include "ngraph_functions/subgraph_builders.hpp"
#include "functional_test_utils/blob_utils.hpp"
#include "openvino/core/preprocess/pre_post_process.hpp"
#include "transformations/utils/utils.hpp"
using namespace ::testing;
class OVRemoteTensor_Test : public CommonTestUtils::TestsCommon {
protected:
std::shared_ptr<ngraph::Function> fn_ptr;
void SetUp() override {
fn_ptr = ngraph::builder::subgraph::makeSplitMultiConvConcat();
}
};
TEST_F(OVRemoteTensor_Test, DISABLED_smoke_canInputUserTensor) {
#if defined(ANDROID)
GTEST_SKIP();
#endif
auto ie = ov::runtime::Core();
using namespace ov::preprocess;
auto function = PrePostProcessor()
.input(InputInfo()
.tensor(InputTensorInfo().set_element_type(ov::element::i8))
.preprocess(PreProcessSteps().convert_element_type(ov::element::f32)))
.build(fn_ptr);
auto exec_net = ie.compile_model(function, CommonTestUtils::DEVICE_GPU);
// regular inference
auto inf_req_regular = exec_net.create_infer_request();
auto input = function->get_parameters().at(0);
auto output = function->get_results().at(0);
auto fakeImageData = FuncTestUtils::create_and_fill_tensor(input->get_element_type(), input->get_shape());
inf_req_regular.set_tensor(input->get_friendly_name(), fakeImageData);
inf_req_regular.infer();
auto output_tensor_regular = inf_req_regular.get_tensor(ngraph::op::util::create_ie_output_name(output->input_value(0)));
// inference using remote tensor
auto inf_req_shared = exec_net.create_infer_request();
auto cldnn_context = exec_net.get_context().as<ov::runtime::gpu::ClContext>();
cl_context ctx = cldnn_context;
auto ocl_instance = std::make_shared<OpenCL>(ctx);
cl_int err;
auto imSize = ov::shape_size(input->get_shape());
cl::Buffer shared_buffer(ocl_instance->_context, CL_MEM_READ_WRITE, imSize, NULL, &err);
{
void* buffer = fakeImageData.data();
ocl_instance->_queue.enqueueWriteBuffer(shared_buffer, true, 0, imSize, buffer);
}
auto cldnn_tensor = cldnn_context.create_tensor(input->get_element_type(), input->get_shape(), shared_buffer);
inf_req_shared.set_tensor(input->get_friendly_name(), cldnn_tensor);
inf_req_shared.infer();
auto output_tensor_shared = inf_req_shared.get_tensor(ngraph::op::util::create_ie_output_name(output->input_value(0)));
// compare results
{
ASSERT_EQ(output->get_element_type(), ov::element::f32);
ASSERT_EQ(output_tensor_regular.get_size(), output_tensor_shared.get_size());
auto thr = FuncTestUtils::GetComparisonThreshold(InferenceEngine::Precision::FP32);
ASSERT_NO_THROW(output_tensor_regular.data());
ASSERT_NO_THROW(output_tensor_shared.data());
FuncTestUtils::compare_tensor(output_tensor_regular, output_tensor_shared, thr);
}
}
TEST_F(OVRemoteTensor_Test, DISABLED_smoke_canInferOnUserContext) {
auto ie = ov::runtime::Core();
using namespace ov::preprocess;
auto function = PrePostProcessor()
.input(InputInfo()
.tensor(InputTensorInfo().set_element_type(ov::element::i8))
.preprocess(PreProcessSteps().convert_element_type(ov::element::f32)))
.build(fn_ptr);
auto exec_net_regular = ie.compile_model(function, CommonTestUtils::DEVICE_GPU);
auto input = function->get_parameters().at(0);
auto output = function->get_results().at(0);
// regular inference
auto inf_req_regular = exec_net_regular.create_infer_request();
auto fakeImageData = FuncTestUtils::create_and_fill_tensor(input->get_element_type(), input->get_shape());
inf_req_regular.set_tensor(input->get_friendly_name(), fakeImageData);
inf_req_regular.infer();
auto output_tensor_regular = inf_req_regular.get_tensor(ngraph::op::util::create_ie_output_name(output->input_value(0)));
// inference using remote tensor
auto ocl_instance = std::make_shared<OpenCL>();
auto remote_context = ov::runtime::gpu::ClContext(ie, ocl_instance->_context.get());
auto exec_net_shared = ie.compile_model(function, remote_context);
auto inf_req_shared = exec_net_shared.create_infer_request();
inf_req_shared.set_tensor(input->get_friendly_name(), fakeImageData);
inf_req_shared.infer();
auto output_tensor_shared = inf_req_shared.get_tensor(ngraph::op::util::create_ie_output_name(output->input_value(0)));
// compare results
{
ASSERT_EQ(output->get_element_type(), ov::element::f32);
ASSERT_EQ(output_tensor_regular.get_size(), output_tensor_shared.get_size());
auto thr = FuncTestUtils::GetComparisonThreshold(InferenceEngine::Precision::FP32);
ASSERT_NO_THROW(output_tensor_regular.data());
ASSERT_NO_THROW(output_tensor_shared.data());
FuncTestUtils::compare_tensor(output_tensor_regular, output_tensor_shared, thr);
}
}
class OVRemoteTensorBatched_Test : public CommonTestUtils::TestsCommon, public testing::WithParamInterface<size_t> {
void SetUp() override {
num_batch = this->GetParam();
};
public:
static std::string getTestCaseName(const testing::TestParamInfo<std::size_t> &obj) {
return "num_batch_" + std::to_string(obj.param);
}
protected:
size_t num_batch;
std::vector<std::shared_ptr<ngraph::Function>> fn_ptrs;
};
TEST_P(OVRemoteTensorBatched_Test, DISABLED_canInputNV12) {
#if defined(ANDROID)
GTEST_SKIP();
#endif
const int height = 16;
const int width = 16;
// ------------------------------------------------------
// Prepare input data
std::vector<ov::runtime::Tensor> fake_image_data_y;
std::vector<ov::runtime::Tensor> fake_image_data_uv;
for (int i = 0; i < num_batch; i++) {
fake_image_data_y.push_back(FuncTestUtils::create_and_fill_tensor(ov::element::u8, {1, 1, height, width}, 50, 0, 1, i));
fake_image_data_uv.push_back(FuncTestUtils::create_and_fill_tensor(ov::element::u8, {1, 2, height / 2, width / 2}, 256, 0, 1, i));
}
auto ie = ov::runtime::Core();
// ------------------------------------------------------
// inference using remote tensor with batch
auto fn_ptr_remote = ngraph::builder::subgraph::makeConvPoolRelu({num_batch, 3, height, width});
// TODO: Add preprocessing!
// CNNNetwork net_remote(fn_ptr_remote);
// net_remote.getInputsInfo().begin()->second->setLayout(Layout::NCHW);
// net_remote.getInputsInfo().begin()->second->setPrecision(Precision::U8);
// net_remote.getInputsInfo().begin()->second->getPreProcess().setColorFormat(ColorFormat::NV12);
/* XXX: is it correct to set KEY_CLDNN_NV12_TWO_INPUTS in case of remote tensor? */
auto exec_net_b = ie.compile_model(fn_ptr_remote, CommonTestUtils::DEVICE_GPU,
{ { ov::ie::GPUConfigParams::KEY_GPU_NV12_TWO_INPUTS, ov::ie::PluginConfigParams::YES} });
auto inf_req_remote = exec_net_b.create_infer_request();
auto cldnn_context = exec_net_b.get_context().as<ov::runtime::gpu::ClContext>();
cl_context ctx = cldnn_context.get();
auto ocl_instance = std::make_shared<OpenCL>(ctx);
cl_int err;
std::vector<cl_mem> nv12_image_plane_y, nv12_image_plane_uv;
std::vector<cl::Image2D> img_y, img_uv;
std::vector<std::pair<ov::runtime::RemoteTensor, ov::runtime::RemoteTensor>> tensor_remote;
for (int i = 0; i < num_batch; i++) {
cl_image_format image_format;
cl_image_desc image_desc = { 0 };
image_format.image_channel_order = CL_R;
image_format.image_channel_data_type = CL_UNORM_INT8;
image_desc.image_type = CL_MEM_OBJECT_IMAGE2D;
image_desc.image_width = width;
image_desc.image_height = height;
nv12_image_plane_y.push_back(clCreateImage(ocl_instance->_context.get(), CL_MEM_READ_WRITE, &image_format, &image_desc, NULL, &err));
ASSERT_EQ(err, 0);
image_format.image_channel_order = CL_RG;
image_desc.image_width = width / 2;
image_desc.image_height = height / 2;
nv12_image_plane_uv.push_back(clCreateImage(ocl_instance->_context.get(), CL_MEM_READ_WRITE, &image_format, &image_desc, NULL, &err));
ASSERT_EQ(err, 0);
size_t origin[3] = { 0, 0, 0 };
size_t y_region[3] = { (size_t)width, (size_t)height, 1 };
size_t uv_region[3] = { (size_t)width / 2, (size_t)height / 2, 1 };
err = clEnqueueWriteImage(ocl_instance->_queue.get(), nv12_image_plane_y[i],
true, origin, y_region, 0, 0, fake_image_data_y[i].data(), 0, NULL, NULL);
ASSERT_EQ(err, 0);
err = clEnqueueWriteImage(ocl_instance->_queue.get(), nv12_image_plane_uv[i],
true, origin, uv_region, 0, 0, fake_image_data_uv[i].data(), 0, NULL, NULL);
ASSERT_EQ(err, 0);
img_y.push_back(cl::Image2D(nv12_image_plane_y[i]));
img_uv.push_back(cl::Image2D(nv12_image_plane_uv[i]));
tensor_remote.push_back(cldnn_context.create_tensor_nv12(img_y[i], img_uv[i]));
}
if (num_batch == 1) {
inf_req_remote.set_tensor(fn_ptr_remote->get_parameters().front()->get_friendly_name() + "/y", tensor_remote[0].first);
inf_req_remote.set_tensor(fn_ptr_remote->get_parameters().front()->get_friendly_name() + "/uv", tensor_remote[0].second);
} else {
GTEST_SKIP() << "Not implemented test";
}
inf_req_remote.infer();
auto outputTensor_shared = inf_req_remote.get_tensor(
ngraph::op::util::create_ie_output_name(fn_ptr_remote->get_results().front()->input_value(0)));
// ------------------------------------------------------
// Setup to inference using local tensor with batch=1
auto fn_ptr_local = ngraph::builder::subgraph::makeConvPoolRelu({1, 3, height, width});
// net_local.getInputsInfo().begin()->second->setLayout(Layout::NCHW);
// net_local.getInputsInfo().begin()->second->setPrecision(Precision::U8);
// net_local.getInputsInfo().begin()->second->getPreProcess().setColorFormat(ColorFormat::NV12);
auto exec_net_b1 = ie.compile_model(fn_ptr_local, CommonTestUtils::DEVICE_GPU);
auto inf_req_local = exec_net_b1.create_infer_request();
// Run regular input for each image and compare against batched tensor
for (int i = 0; i < num_batch; i++) {
auto y_tensor = ov::runtime::Tensor{ov::element::u8, {1, 1, height, width}};
auto uv_tensor = ov::runtime::Tensor{ov::element::u8, {1, 2, height / 2, width / 2}};
inf_req_local.set_tensor(fn_ptr_local->get_parameters().front()->get_friendly_name() + "/y", y_tensor);
inf_req_local.set_tensor(fn_ptr_local->get_parameters().front()->get_friendly_name() + "/uv", uv_tensor);
inf_req_local.infer();
auto output_tensor_local = inf_req_local.get_tensor(
ngraph::op::util::create_ie_output_name(fn_ptr_local->get_results().front()->input_value(0)));
// This network generates [1, size] tensor whether batch=1 or 2. So need to split
auto split_shared_tensor = ov::runtime::Tensor{output_tensor_local.get_element_type(),
output_tensor_local.get_shape(),
outputTensor_shared.data<float_t>() + output_tensor_local.get_size() * i};
ASSERT_EQ(output_tensor_local.get_size(), split_shared_tensor.get_size());
float thr = 0.1;
FuncTestUtils::compare_tensor(output_tensor_local, split_shared_tensor, thr, "", false);
}
}
const std::vector<size_t> num_batches{1, 2, 4};
INSTANTIATE_TEST_SUITE_P(smoke_RemoteTensor, OVRemoteTensorBatched_Test, ::testing::ValuesIn(num_batches), OVRemoteTensorBatched_Test::getTestCaseName);
using TwoNetsParams = std::tuple<size_t, // number of streams
size_t>; // number of requests
class OVRemoteTensorTwoNets_Test : public CommonTestUtils::TestsCommon,
public testing::WithParamInterface<TwoNetsParams> {
void SetUp() override {
std::tie(num_streams, num_requests) = this->GetParam();
fn_ptrs = {ngraph::builder::subgraph::makeSplitMultiConvConcat(),
ngraph::builder::subgraph::makeMultiSingleConv()};
};
public:
static std::string getTestCaseName(const testing::TestParamInfo<TwoNetsParams>& obj) {
size_t streams, requests;
std::tie(streams, requests) = obj.param;
return "_num_streams_" + std::to_string(streams) + "_num_req_" +
std::to_string(requests);
}
protected:
size_t num_streams;
size_t num_requests;
std::vector<std::shared_ptr<ngraph::Function>> fn_ptrs;
};
TEST_P(OVRemoteTensorTwoNets_Test, DISABLED_canInferTwoExecNets) {
auto ie = ov::runtime::Core();
std::vector<std::string> outputs;
std::vector<ov::runtime::InferRequest> irs;
std::vector<std::vector<uint8_t>> ref;
std::vector<int> outElementsCount;
for (size_t i = 0; i < fn_ptrs.size(); ++i) {
auto fn = fn_ptrs[i];
auto exec_net = ie.compile_model(fn_ptrs[i], CommonTestUtils::DEVICE_GPU,
{{ov::ie::PluginConfigParams::KEY_GPU_THROUGHPUT_STREAMS, std::to_string(num_streams)}});
auto input = fn_ptrs[i]->get_parameters().at(0);
auto output = fn_ptrs[i]->get_results().at(0);
for (int j = 0; j < num_streams * num_requests; j++) {
outputs.push_back(ngraph::op::util::create_ie_output_name(output->input_value(0)));
auto inf_req = exec_net.create_infer_request();
irs.push_back(inf_req);
auto tensor = FuncTestUtils::create_and_fill_tensor(input->get_element_type(), input->get_shape());
inf_req.set_tensor(input->get_friendly_name(), tensor);
outElementsCount.push_back(
std::accumulate(begin(fn_ptrs[i]->get_output_shape(0)), end(fn_ptrs[i]->get_output_shape(0)), 1,
std::multiplies<size_t>()));
const auto in_tensor = inf_req.get_tensor(input->get_friendly_name());
const auto tensorSize = in_tensor.get_byte_size();
const auto inBlobBuf = static_cast<uint8_t*>(in_tensor.data());
std::vector<uint8_t> inData(inBlobBuf, inBlobBuf + tensorSize);
auto reOutData = ngraph::helpers::interpreterFunction(fn_ptrs[i], {inData}).front().second;
ref.push_back(reOutData);
}
}
const int niter = 10;
for (int i = 0; i < niter; i++) {
for (auto ir : irs) {
ir.start_async();
}
for (auto ir : irs) {
ir.wait();
}
}
auto thr = FuncTestUtils::GetComparisonThreshold(InferenceEngine::Precision::FP32);
for (size_t i = 0; i < irs.size(); ++i) {
const auto &refBuffer = ref[i].data();
ASSERT_EQ(outElementsCount[i], irs[i].get_tensor(outputs[i]).get_size());
FuncTestUtils::compareRawBuffers(irs[i].get_tensor(outputs[i]).data<float>(),
reinterpret_cast<const float *>(refBuffer), outElementsCount[i],
outElementsCount[i],
thr);
}
}
const std::vector<size_t> num_streams{ 1, 2 };
const std::vector<size_t> num_requests{ 1, 4 };
INSTANTIATE_TEST_SUITE_P(smoke_RemoteTensor, OVRemoteTensorTwoNets_Test,
::testing::Combine(::testing::ValuesIn(num_streams),
::testing::ValuesIn(num_requests)),
OVRemoteTensorTwoNets_Test::getTestCaseName);


@ -0,0 +1,479 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <string>
#include <utility>
#include <vector>
#include <memory>
#include "openvino/runtime/gpu/ocl/ocl.hpp"
#include "openvino/runtime/core.hpp"
#include <gpu/gpu_config.hpp>
#include <remote_blob_tests/remote_blob_helpers.hpp>
#include <common_test_utils/test_common.hpp>
#include <functional_test_utils/plugin_cache.hpp>
#include "ngraph_functions/subgraph_builders.hpp"
#include "functional_test_utils/blob_utils.hpp"
#include "openvino/core/preprocess/pre_post_process.hpp"
#include "transformations/utils/utils.hpp"
using namespace ::testing;
class OVRemoteTensor_Test : public CommonTestUtils::TestsCommon {
protected:
std::shared_ptr<ngraph::Function> fn_ptr;
void SetUp() override {
fn_ptr = ngraph::builder::subgraph::makeSplitMultiConvConcat();
}
};
TEST_F(OVRemoteTensor_Test, smoke_canInputUserTensor) {
#if defined(ANDROID)
GTEST_SKIP();
#endif
auto ie = ov::runtime::Core();
using namespace ov::preprocess;
auto function = PrePostProcessor()
.input(InputInfo()
.tensor(InputTensorInfo().set_element_type(ov::element::i8))
.preprocess(PreProcessSteps().convert_element_type(ov::element::f32)))
.build(fn_ptr);
auto exec_net = ie.compile_model(function, CommonTestUtils::DEVICE_GPU);
// regular inference
auto inf_req_regular = exec_net.create_infer_request();
auto input = function->get_parameters().at(0);
auto output = function->get_results().at(0);
auto fakeImageData = FuncTestUtils::create_and_fill_tensor(input->get_element_type(), input->get_shape());
inf_req_regular.set_tensor(input, fakeImageData);
inf_req_regular.infer();
auto output_tensor_regular = inf_req_regular.get_tensor(output);
// inference using remote tensor
auto inf_req_shared = exec_net.create_infer_request();
auto cldnn_context = exec_net.get_context().as<ov::runtime::gpu::ocl::ClContext>();
cl_context ctx = cldnn_context;
auto ocl_instance = std::make_shared<OpenCL>(ctx);
cl_int err;
auto imSize = ov::shape_size(input->get_shape());
cl::Buffer shared_buffer(ocl_instance->_context, CL_MEM_READ_WRITE, imSize, NULL, &err);
{
void* buffer = fakeImageData.data();
ocl_instance->_queue.enqueueWriteBuffer(shared_buffer, true, 0, imSize, buffer);
}
auto cldnn_tensor = cldnn_context.create_tensor(input->get_element_type(), input->get_shape(), shared_buffer);
inf_req_shared.set_tensor(input, cldnn_tensor);
inf_req_shared.infer();
auto output_tensor_shared = inf_req_shared.get_tensor(output);
// compare results
{
ASSERT_EQ(output->get_element_type(), ov::element::f32);
ASSERT_EQ(output_tensor_regular.get_size(), output_tensor_shared.get_size());
auto thr = FuncTestUtils::GetComparisonThreshold(InferenceEngine::Precision::FP32);
ASSERT_NO_THROW(output_tensor_regular.data());
ASSERT_NO_THROW(output_tensor_shared.data());
FuncTestUtils::compare_tensor(output_tensor_regular, output_tensor_shared, thr);
}
}
TEST_F(OVRemoteTensor_Test, smoke_canInferOnUserContext) {
auto ie = ov::runtime::Core();
using namespace ov::preprocess;
auto function = PrePostProcessor()
.input(InputInfo()
.tensor(InputTensorInfo().set_element_type(ov::element::i8))
.preprocess(PreProcessSteps().convert_element_type(ov::element::f32)))
.build(fn_ptr);
auto exec_net_regular = ie.compile_model(function, CommonTestUtils::DEVICE_GPU);
auto input = function->get_parameters().at(0);
auto output = function->get_results().at(0);
// regular inference
auto inf_req_regular = exec_net_regular.create_infer_request();
auto fakeImageData = FuncTestUtils::create_and_fill_tensor(input->get_element_type(), input->get_shape());
inf_req_regular.set_tensor(input, fakeImageData);
inf_req_regular.infer();
auto output_tensor_regular = inf_req_regular.get_tensor(exec_net_regular.output());
// inference using remote tensor
auto ocl_instance = std::make_shared<OpenCL>();
auto remote_context = ov::runtime::gpu::ocl::ClContext(ie, ocl_instance->_context.get());
auto exec_net_shared = ie.compile_model(function, remote_context);
auto inf_req_shared = exec_net_shared.create_infer_request();
inf_req_shared.set_tensor(input, fakeImageData);
inf_req_shared.infer();
auto output_tensor_shared = inf_req_shared.get_tensor(output);
// compare results
{
ASSERT_EQ(output->get_element_type(), ov::element::f32);
ASSERT_EQ(output_tensor_regular.get_size(), output_tensor_shared.get_size());
auto thr = FuncTestUtils::GetComparisonThreshold(InferenceEngine::Precision::FP32);
ASSERT_NO_THROW(output_tensor_regular.data());
ASSERT_NO_THROW(output_tensor_shared.data());
FuncTestUtils::compare_tensor(output_tensor_regular, output_tensor_shared, thr);
}
}
TEST_F(OVRemoteTensor_Test, smoke_canInferOnUserContextWithMultipleDevices) {
auto ie = ov::runtime::Core();
using namespace ov::preprocess;
auto function = PrePostProcessor()
.input(InputInfo()
.tensor(InputTensorInfo().set_element_type(ov::element::i8))
.preprocess(PreProcessSteps().convert_element_type(ov::element::f32)))
.build(fn_ptr);
auto exec_net_regular = ie.compile_model(function, CommonTestUtils::DEVICE_GPU);
auto input = function->get_parameters().at(0);
auto output = function->get_results().at(0);
// regular inference
auto inf_req_regular = exec_net_regular.create_infer_request();
auto fakeImageData = FuncTestUtils::create_and_fill_tensor(input->get_element_type(), input->get_shape());
inf_req_regular.set_tensor(input, fakeImageData);
inf_req_regular.infer();
auto output_tensor_regular = inf_req_regular.get_tensor(exec_net_regular.output());
// inference using remote tensor
auto ocl_instance_tmp = std::make_shared<OpenCL>();
cl::Context multi_device_ctx({ocl_instance_tmp->_device, ocl_instance_tmp->_device});
auto ocl_instance = std::make_shared<OpenCL>(multi_device_ctx.get());
auto remote_context = ov::runtime::gpu::ocl::ClContext(ie, ocl_instance->_context.get(), 1);
ASSERT_EQ(remote_context.get_device_name(), "GPU.0");
auto exec_net_shared = ie.compile_model(function, remote_context);
auto inf_req_shared = exec_net_shared.create_infer_request();
inf_req_shared.set_tensor(input, fakeImageData);
inf_req_shared.infer();
auto output_tensor_shared = inf_req_shared.get_tensor(output);
// compare results
{
ASSERT_EQ(output->get_element_type(), ov::element::f32);
ASSERT_EQ(output_tensor_regular.get_size(), output_tensor_shared.get_size());
auto thr = FuncTestUtils::GetComparisonThreshold(InferenceEngine::Precision::FP32);
ASSERT_NO_THROW(output_tensor_regular.data());
ASSERT_NO_THROW(output_tensor_shared.data());
FuncTestUtils::compare_tensor(output_tensor_regular, output_tensor_shared, thr);
}
}
TEST_F(OVRemoteTensor_Test, smoke_canInferOnUserQueue_out_of_order) {
auto ie = ov::runtime::Core();
using namespace ov::preprocess;
auto function = PrePostProcessor()
.input(InputInfo()
.tensor(InputTensorInfo().set_element_type(ov::element::i8))
.preprocess(PreProcessSteps().convert_element_type(ov::element::f32)))
.build(fn_ptr);
auto exec_net_regular = ie.compile_model(function, CommonTestUtils::DEVICE_GPU);
auto input = function->get_parameters().at(0);
auto output = function->get_results().at(0);
// regular inference
auto inf_req_regular = exec_net_regular.create_infer_request();
auto fakeImageData = FuncTestUtils::create_and_fill_tensor(input->get_element_type(), input->get_shape());
inf_req_regular.set_tensor(input, fakeImageData);
inf_req_regular.infer();
auto output_tensor_regular = inf_req_regular.get_tensor(exec_net_regular.output());
auto in_size = ov::shape_size(input->get_output_shape(0)) * input->get_output_element_type(0).size();
auto out_size = ov::shape_size(output->get_output_shape(0)) * output->get_output_element_type(0).size();
// inference using remote tensor
auto ocl_instance = std::make_shared<OpenCL>();
cl_int err;
// Allocate shared buffers for input and output data which will be set to infer request
cl::Buffer shared_input_buffer(ocl_instance->_context, CL_MEM_READ_WRITE, in_size, NULL, &err);
cl::Buffer shared_output_buffer(ocl_instance->_context, CL_MEM_READ_WRITE, out_size, NULL, &err);
auto remote_context = ov::runtime::gpu::ocl::ClContext(ie, ocl_instance->_queue.get());
auto exec_net_shared = ie.compile_model(function, remote_context);
auto gpu_context = exec_net_shared.get_context().as<ov::runtime::gpu::ocl::ClContext>();
auto gpu_in_tensor = gpu_context.create_tensor(input->get_output_element_type(0), input->get_output_shape(0), shared_input_buffer);
auto gpu_out_tensor = gpu_context.create_tensor(output->get_output_element_type(0), output->get_output_shape(0), shared_output_buffer);
auto out_tensor = FuncTestUtils::create_and_fill_tensor(output->get_output_element_type(0), output->get_output_shape(0));
auto inf_req_shared = exec_net_shared.create_infer_request();
inf_req_shared.set_tensor(input, gpu_in_tensor);
inf_req_shared.set_tensor(output, gpu_out_tensor);
// 1. Pre-processing. Enqueue non-blocking copy from host ptr to shared device input buffer and barrier to ensure that copy is finished before
// inference primitives starts execution
{
void* buffer = fakeImageData.data();
ocl_instance->_queue.enqueueWriteBuffer(shared_input_buffer, false, 0, in_size, buffer);
ocl_instance->_queue.enqueueBarrierWithWaitList(nullptr, nullptr);
}
// 2. Enqueue inference primitives. With shared queue this call ensures that all kernels are scheduled to the corresponding queue
// before giving the control back
inf_req_shared.start_async();
// 3. Post-processing. Enqueue copy from shared blob with inference result to another output blob
// Enqueue barrier with empty wait list is needed to ensure that previous kernels are finished before copying the data. It's needed here since we
// create OOO queue.
// Note: inf_req_shared.wait() can be dropped in some cases, but if plugin-side post-processing is required,
// then the result may be incorrect without Wait().
{
ocl_instance->_queue.enqueueBarrierWithWaitList(nullptr, nullptr);
ocl_instance->_queue.enqueueReadBuffer(shared_output_buffer, false, 0, out_size, out_tensor.data(), nullptr, nullptr);
}
// 4. Wait for infer request and post-processing completion
ocl_instance->_queue.finish();
// compare results
{
ASSERT_EQ(output->get_element_type(), ov::element::f32);
ASSERT_EQ(output_tensor_regular.get_size(), out_tensor.get_size());
auto thr = FuncTestUtils::GetComparisonThreshold(InferenceEngine::Precision::FP32);
ASSERT_NO_THROW(output_tensor_regular.data());
FuncTestUtils::compare_tensor(output_tensor_regular, out_tensor, thr);
}
}
TEST_F(OVRemoteTensor_Test, smoke_canInferOnUserQueue_in_order) {
auto ie = ov::runtime::Core();
using namespace ov::preprocess;
auto function = PrePostProcessor()
.input(InputInfo()
.tensor(InputTensorInfo().set_element_type(ov::element::i8))
.preprocess(PreProcessSteps().convert_element_type(ov::element::f32)))
.build(fn_ptr);
auto exec_net_regular = ie.compile_model(function, CommonTestUtils::DEVICE_GPU);
auto input = function->get_parameters().at(0);
auto output = function->get_results().at(0);
// regular inference
auto inf_req_regular = exec_net_regular.create_infer_request();
auto fakeImageData = FuncTestUtils::create_and_fill_tensor(input->get_element_type(), input->get_shape());
inf_req_regular.set_tensor(input, fakeImageData);
inf_req_regular.infer();
auto output_tensor_regular = inf_req_regular.get_tensor(exec_net_regular.output());
auto in_size = ov::shape_size(input->get_output_shape(0)) * input->get_output_element_type(0).size();
auto out_size = ov::shape_size(output->get_output_shape(0)) * output->get_output_element_type(0).size();
// inference using remote tensor
auto ocl_instance = std::make_shared<OpenCL>();
ocl_instance->_queue = cl::CommandQueue(ocl_instance->_context, ocl_instance->_device);
cl_int err;
// Allocate shared buffers for input and output data which will be set to infer request
cl::Buffer shared_input_buffer(ocl_instance->_context, CL_MEM_READ_WRITE, in_size, NULL, &err);
cl::Buffer shared_output_buffer(ocl_instance->_context, CL_MEM_READ_WRITE, out_size, NULL, &err);
auto remote_context = ov::runtime::gpu::ocl::ClContext(ie, ocl_instance->_queue.get());
auto exec_net_shared = ie.compile_model(function, remote_context);
auto gpu_context = exec_net_shared.get_context().as<ov::runtime::gpu::ocl::ClContext>();
auto gpu_in_tensor = gpu_context.create_tensor(input->get_output_element_type(0), input->get_output_shape(0), shared_input_buffer);
auto gpu_out_tensor = gpu_context.create_tensor(output->get_output_element_type(0), output->get_output_shape(0), shared_output_buffer);
auto out_tensor = FuncTestUtils::create_and_fill_tensor(output->get_output_element_type(0), output->get_output_shape(0));
auto inf_req_shared = exec_net_shared.create_infer_request();
inf_req_shared.set_tensor(input, gpu_in_tensor);
inf_req_shared.set_tensor(output, gpu_out_tensor);
// 1. Pre-processing. Enqueue non-blocking copy from host ptr to shared device input buffer
{
void* buffer = fakeImageData.data();
ocl_instance->_queue.enqueueWriteBuffer(shared_input_buffer, false, 0, in_size, buffer);
}
// 2. Enqueue inference primitives. With shared queue this call ensures that all kernels are scheduled to the corresponding queue
// before giving the control back
inf_req_shared.start_async();
// 3. Post-processing. Enqueue copy from shared blob with inference result to another output blob
// Note: inf_req_shared.Wait() can be dropped in some cases, but if plugin-side post-processing is required,
// then the result may be incorrect without Wait().
{
ocl_instance->_queue.enqueueReadBuffer(shared_output_buffer, false, 0, out_size, out_tensor.data(), nullptr, nullptr);
}
// 4. Wait for infer request and post-processing completion
ocl_instance->_queue.finish();
// compare results
{
ASSERT_EQ(output->get_element_type(), ov::element::f32);
ASSERT_EQ(output_tensor_regular.get_size(), out_tensor.get_size());
auto thr = FuncTestUtils::GetComparisonThreshold(InferenceEngine::Precision::FP32);
ASSERT_NO_THROW(output_tensor_regular.data());
FuncTestUtils::compare_tensor(output_tensor_regular, out_tensor, thr);
}
}
class OVRemoteTensorBatched_Test : public CommonTestUtils::TestsCommon, public testing::WithParamInterface<size_t> {
void SetUp() override {
num_batch = this->GetParam();
};
public:
static std::string getTestCaseName(const testing::TestParamInfo<std::size_t> &obj) {
return "num_batch_" + std::to_string(obj.param);
}
protected:
size_t num_batch;
std::vector<std::shared_ptr<ngraph::Function>> fn_ptrs;
};
TEST_P(OVRemoteTensorBatched_Test, DISABLED_canInputNV12) {
#if defined(ANDROID)
GTEST_SKIP();
#endif
const int height = 16;
const int width = 16;
// ------------------------------------------------------
// Prepare input data
std::vector<ov::runtime::Tensor> fake_image_data_y;
std::vector<ov::runtime::Tensor> fake_image_data_uv;
for (int i = 0; i < num_batch; i++) {
fake_image_data_y.push_back(FuncTestUtils::create_and_fill_tensor(ov::element::u8, {1, 1, height, width}, 50, 0, 1, i));
fake_image_data_uv.push_back(FuncTestUtils::create_and_fill_tensor(ov::element::u8, {1, 2, height / 2, width / 2}, 256, 0, 1, i));
}
auto ie = ov::runtime::Core();
// ------------------------------------------------------
// inference using remote tensor with batch
auto fn_ptr_remote = ngraph::builder::subgraph::makeConvPoolRelu({num_batch, 3, height, width});
using namespace ov::preprocess;
auto function = PrePostProcessor()
.input(InputInfo()
.tensor(InputTensorInfo().set_element_type(ov::element::i8).set_color_format(ov::preprocess::ColorFormat::NV12_TWO_PLANES))
.preprocess(PreProcessSteps().convert_element_type(ov::element::f32)))
.build(fn_ptr_remote);
auto exec_net_b = ie.compile_model(fn_ptr_remote, CommonTestUtils::DEVICE_GPU);
auto inf_req_remote = exec_net_b.create_infer_request();
auto cldnn_context = exec_net_b.get_context().as<ov::runtime::gpu::ocl::ClContext>();
cl_context ctx = cldnn_context.get();
auto ocl_instance = std::make_shared<OpenCL>(ctx);
cl_int err;
std::vector<cl_mem> nv12_image_plane_y, nv12_image_plane_uv;
std::vector<cl::Image2D> img_y, img_uv;
std::vector<std::pair<ov::runtime::RemoteTensor, ov::runtime::RemoteTensor>> tensor_remote;
for (int i = 0; i < num_batch; i++) {
cl_image_format image_format;
cl_image_desc image_desc = { 0 };
image_format.image_channel_order = CL_R;
image_format.image_channel_data_type = CL_UNORM_INT8;
image_desc.image_type = CL_MEM_OBJECT_IMAGE2D;
image_desc.image_width = width;
image_desc.image_height = height;
nv12_image_plane_y.push_back(clCreateImage(ocl_instance->_context.get(), CL_MEM_READ_WRITE, &image_format, &image_desc, NULL, &err));
ASSERT_EQ(err, 0);
image_format.image_channel_order = CL_RG;
image_desc.image_width = width / 2;
image_desc.image_height = height / 2;
nv12_image_plane_uv.push_back(clCreateImage(ocl_instance->_context.get(), CL_MEM_READ_WRITE, &image_format, &image_desc, NULL, &err));
ASSERT_EQ(err, 0);
size_t origin[3] = { 0, 0, 0 };
size_t y_region[3] = { (size_t)width, (size_t)height, 1 };
size_t uv_region[3] = { (size_t)width / 2, (size_t)height / 2, 1 };
err = clEnqueueWriteImage(ocl_instance->_queue.get(), nv12_image_plane_y[i],
true, origin, y_region, 0, 0, fake_image_data_y[i].data(), 0, NULL, NULL);
ASSERT_EQ(err, 0);
err = clEnqueueWriteImage(ocl_instance->_queue.get(), nv12_image_plane_uv[i],
true, origin, uv_region, 0, 0, fake_image_data_uv[i].data(), 0, NULL, NULL);
ASSERT_EQ(err, 0);
img_y.push_back(cl::Image2D(nv12_image_plane_y[i]));
img_uv.push_back(cl::Image2D(nv12_image_plane_uv[i]));
tensor_remote.push_back(cldnn_context.create_tensor_nv12(img_y[i], img_uv[i]));
}
if (num_batch == 1) {
inf_req_remote.set_tensor(fn_ptr_remote->get_parameters().front()->get_friendly_name() + "/y", tensor_remote[0].first);
inf_req_remote.set_tensor(fn_ptr_remote->get_parameters().front()->get_friendly_name() + "/uv", tensor_remote[0].second);
} else {
GTEST_SKIP() << "Not implemented test";
}
inf_req_remote.infer();
auto outputTensor_shared = inf_req_remote.get_tensor(
ngraph::op::util::create_ie_output_name(fn_ptr_remote->get_results().front()->input_value(0)));
// ------------------------------------------------------
// Setup to inference using local tensor with batch=1
auto fn_ptr_local = ngraph::builder::subgraph::makeConvPoolRelu({1, 3, height, width});
// net_local.getInputsInfo().begin()->second->setLayout(Layout::NCHW);
// net_local.getInputsInfo().begin()->second->setPrecision(Precision::U8);
// net_local.getInputsInfo().begin()->second->getPreProcess().setColorFormat(ColorFormat::NV12);
auto exec_net_b1 = ie.compile_model(fn_ptr_local, CommonTestUtils::DEVICE_GPU);
auto inf_req_local = exec_net_b1.create_infer_request();
// Run regular input for each image and compare against batched tensor
for (int i = 0; i < num_batch; i++) {
auto y_tensor = ov::runtime::Tensor{ov::element::u8, {1, 1, height, width}};
auto uv_tensor = ov::runtime::Tensor{ov::element::u8, {1, 2, height / 2, width / 2}};
inf_req_local.set_tensor(fn_ptr_local->get_parameters().front()->get_friendly_name() + "/y", y_tensor);
inf_req_local.set_tensor(fn_ptr_local->get_parameters().front()->get_friendly_name() + "/uv", uv_tensor);
inf_req_local.infer();
auto output_tensor_local = inf_req_local.get_tensor(
ngraph::op::util::create_ie_output_name(fn_ptr_local->get_results().front()->input_value(0)));
// This network generates [1, size] tensor whether batch=1 or 2. So need to split
auto split_shared_tensor = ov::runtime::Tensor{output_tensor_local.get_element_type(),
output_tensor_local.get_shape(),
outputTensor_shared.data<float_t>() + output_tensor_local.get_size() * i};
ASSERT_EQ(output_tensor_local.get_size(), split_shared_tensor.get_size());
float thr = 0.1;
FuncTestUtils::compare_tensor(output_tensor_local, split_shared_tensor, thr, "", false);
}
}
const std::vector<size_t> num_batches{1, 2, 4};
INSTANTIATE_TEST_SUITE_P(smoke_RemoteTensor, OVRemoteTensorBatched_Test, ::testing::ValuesIn(num_batches), OVRemoteTensorBatched_Test::getTestCaseName);


@ -18,6 +18,8 @@ public:
virtual device_info get_info() const = 0;
virtual memory_capabilities get_mem_caps() const = 0;
virtual bool is_same(const device::ptr other) = 0;
virtual ~device() = default;
};


@ -20,6 +20,7 @@ public:
runtime_types runtime_type,
void* user_context = nullptr,
void* user_device = nullptr,
int ctx_device_id = 0,
int target_tile_id = -1);
std::map<std::string, device::ptr> get_available_devices() const {


@ -10,14 +10,19 @@
namespace cldnn {
device_query::device_query(engine_types engine_type, runtime_types runtime_type, void* user_context, void* user_device, int target_tile_id) {
device_query::device_query(engine_types engine_type,
runtime_types runtime_type,
void* user_context,
void* user_device,
int ctx_device_id,
int target_tile_id) {
switch (engine_type) {
case engine_types::ocl: {
if (runtime_type != runtime_types::ocl)
throw std::runtime_error("Unsupported runtime type for ocl engine");
ocl::ocl_device_detector ocl_detector;
_available_devices = ocl_detector.get_available_devices(user_context, user_device, target_tile_id);
_available_devices = ocl_detector.get_available_devices(user_context, user_device, ctx_device_id, target_tile_id);
break;
}
default: throw std::runtime_error("Unsupported engine type in device_query");


@ -300,5 +300,13 @@ ocl_device::ocl_device(const cl::Device dev, const cl::Context& ctx, const cl_pl
, _info(init_device_info(dev))
, _mem_caps(init_memory_caps(dev, _info)) { }
bool ocl_device::is_same(const device::ptr other) {
auto casted = downcast<ocl_device>(other.get());
if (!casted)
return false;
return _context == casted->get_context() && _device == casted->get_device() && _platform == casted->get_platform();
}
} // namespace ocl
} // namespace cldnn


@ -28,6 +28,8 @@ public:
const cl::Context& get_context() const { return _context; }
cl_platform_id get_platform() const { return _platform; }
bool is_same(const device::ptr other) override;
~ocl_device() = default;
private:


@ -91,11 +91,14 @@ static std::vector<cl::Device> getSubDevices(cl::Device& rootDevice) {
return subDevices;
}
std::map<std::string, device::ptr> ocl_device_detector::get_available_devices(void* user_context, void* user_device, int target_tile_id) const {
std::map<std::string, device::ptr> ocl_device_detector::get_available_devices(void* user_context,
void* user_device,
int ctx_device_id,
int target_tile_id) const {
bool host_out_of_order = true; // Change to false, if debug requires in-order queue.
std::vector<device::ptr> dev_orig, dev_sorted;
if (user_context != nullptr) {
dev_orig = create_device_list_from_user_context(host_out_of_order, user_context);
dev_orig = create_device_list_from_user_context(host_out_of_order, user_context, ctx_device_id);
} else if (user_device != nullptr) {
dev_orig = create_device_list_from_user_device(host_out_of_order, user_device);
} else {
@ -171,13 +174,14 @@ std::vector<device::ptr> ocl_device_detector::create_device_list(bool out_out_or
return ret;
}
std::vector<device::ptr> ocl_device_detector::create_device_list_from_user_context(bool out_out_order, void* user_context) const {
std::vector<device::ptr> ocl_device_detector::create_device_list_from_user_context(bool out_out_order, void* user_context, int ctx_device_id) const {
cl::Context ctx = cl::Context(static_cast<cl_context>(user_context), true);
auto all_devices = ctx.getInfo<CL_CONTEXT_DEVICES>();
std::vector<device::ptr> ret;
for (auto& device : all_devices) {
if (!does_device_match_config(out_out_order, device))
for (size_t i = 0; i < all_devices.size(); i++) {
auto& device = all_devices[i];
if (!does_device_match_config(out_out_order, device) || i != ctx_device_id)
continue;
ret.emplace_back(std::make_shared<ocl_device>(device, ctx, device.getInfo<CL_DEVICE_PLATFORM>()));
}


@ -19,10 +19,11 @@ class ocl_device_detector {
public:
ocl_device_detector() = default;
std::map<std::string, device::ptr> get_available_devices(void* user_context, void* user_device, int target_tile_id = -1) const;
std::map<std::string, device::ptr> get_available_devices(void *user_context, void *user_device, int ctx_device_id = 0, int target_tile_id = -1) const;
private:
std::vector<device::ptr> create_device_list(bool out_out_order) const;
std::vector<device::ptr> create_device_list_from_user_context(bool out_out_order, void* user_context) const;
std::vector<device::ptr> create_device_list_from_user_context(bool out_out_order, void* user_context, int ctx_device_id = 0) const;
std::vector<device::ptr> create_device_list_from_user_device(bool out_out_order, void* user_device) const;
};