[GPU] Add batching buffer to new API (#9565)
commit baa1a8e59d
parent c7d216db13
@@ -412,10 +412,6 @@ void InferRequest::SetBlobs(const std::string& name, const std::vector<Blob::Ptr
        IE_THROW() << "SetBlobs method doesn't support outputs";
    }

-   if (is_buffer) {
-       IE_THROW(NotImplemented) << "SetBlobs method doesn't support buffer blobs";
-   }
-
    const TensorDesc& desc = foundInput->getTensorDesc();

    size_t dataBinSize = blobs.front()->size() * blobs.front()->element_size() * blobs.size();
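With the NotImplemented guard removed, SetBlobs now accepts a batch of OpenCL buffer blobs and only validates that their combined byte size matches the expected input. A minimal standalone sketch of that size computation, using a hypothetical FakeBlob stand-in rather than the plugin's Blob::Ptr type:

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for an InferenceEngine blob: element count plus element size in bytes.
struct FakeBlob {
    std::size_t size;          // number of elements
    std::size_t element_size;  // bytes per element
};

// Mirrors dataBinSize above: the batch is assumed to hold equally-shaped blobs,
// so the merged size is per-blob bytes times the number of blobs.
std::size_t merged_bin_size(const std::vector<FakeBlob>& blobs) {
    if (blobs.empty())
        throw std::invalid_argument("SetBlobs expects a non-empty vector of blobs");
    return blobs.front().size * blobs.front().element_size * blobs.size();
}
```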
@@ -445,8 +441,7 @@ void InferRequest::SetBlobs(const std::string& name, const std::vector<Blob::Ptr
            }
        }
    }

-   inputTensorsMap.insert({ name, blobs });
+   inputTensorsMap[name] = blobs;
}

void InferRequest::checkBlobs() {
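Switching from insert to operator[] matters because std::map::insert is a no-op when the key already exists, so calling SetBlobs twice for the same input name would silently keep the stale batch. A minimal sketch of the difference, with plain ints standing in for Blob::Ptr:

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

int main() {
    // inputTensorsMap maps an input name to the batch of blobs set for it;
    // here ints stand in for Blob::Ptr.
    std::map<std::string, std::vector<int>> inputTensorsMap;

    inputTensorsMap.insert({ "data", { 1, 2 } });
    // A second insert with the same key does nothing: the old batch survives.
    inputTensorsMap.insert({ "data", { 3, 4, 5 } });
    assert(inputTensorsMap["data"].size() == 2);

    // operator[] overwrites, so re-running SetBlobs replaces the batch.
    inputTensorsMap["data"] = { 3, 4, 5 };
    assert(inputTensorsMap["data"].size() == 3);
    return 0;
}
```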
@@ -618,19 +613,56 @@ void InferRequest::enqueue() {
    std::vector<cldnn::event::ptr> dependencies;

    for (const auto& inputTensor : inputTensorsMap) {
        const std::string name = inputTensor.first;
        const auto& blobs = inputTensor.second;

        auto blobsDesc = blobs.front()->getTensorDesc();
        blobsDesc.getDims().front() = blobs.size();

        bool is_surface = std::all_of(blobs.begin(), blobs.end(), [](const Blob::Ptr& blob) {
            return blob->is<gpu::ClImage2DBlob>();
        });
        bool is_buffer = std::all_of(blobs.begin(), blobs.end(), [](const Blob::Ptr& blob) {
            return blob->is<gpu::ClBufferBlob>();
        });
        bool is_remote = is_buffer || is_surface;

        if (is_surface) {
            for (size_t i = 0; i < blobs.size(); ++i) {
-               std::string new_name = inputTensor.first + "_" + std::to_string(i);
+               std::string new_name = name + "_" + std::to_string(i);
                _inputs[new_name] = blobs[i];
                _deviceInputs[new_name] = blobs[i];
            }
        } else {
            uint8_t* dst = nullptr;
            if (_deviceInputs.find(name) != _deviceInputs.end()) {
                if (_deviceInputs[name]->getTensorDesc() == blobsDesc) {
                    dst = _deviceInputs[name]->buffer().as<uint8_t*>();
                }
            }
            if (dst == nullptr) {
                cldnn::layout layout(DataTypeFromPrecision(blobsDesc.getPrecision()),
                                     FormatFromTensorDesc(blobsDesc),
                                     tensor_from_dims(blobsDesc.getDims()));

                auto mergedBlobs = std::make_shared<RemoteCLbuffer>(m_graph->GetContext(),
                                                                    m_graph->GetNetwork()->get_stream(),
                                                                    blobsDesc,
                                                                    layout);
                mergedBlobs->allocate();
                dst = mergedBlobs->buffer().as<uint8_t*>();

                _inputs[name] = mergedBlobs;
                if (is_remote) {
                    _deviceInputs[name] = mergedBlobs;
                }
            }

            for (auto& blob : blobs) {
                const uint8_t* src = blob->cbuffer().as<const uint8_t*>();
                std::copy(src, src + blob->byteSize(), dst);
                dst += blob->byteSize();
            }
        }
    }
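The new else branch is the batching buffer itself: if the cached device input does not already match the batched TensorDesc, a RemoteCLbuffer sized for the whole batch is allocated, and each per-batch blob is then copied into it back to back. A host-memory sketch of that concatenation pattern, using plain std::vector instead of the plugin's blob and remote-buffer types:

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Concatenate N per-batch host buffers into one contiguous destination,
// mirroring the std::copy loop over `blobs` above. The plugin copies into a
// RemoteCLbuffer; here the destination is ordinary host memory.
std::vector<uint8_t> merge_batches(const std::vector<std::vector<uint8_t>>& blobs) {
    std::size_t total = 0;
    for (const auto& b : blobs)
        total += b.size();

    std::vector<uint8_t> merged(total);
    uint8_t* dst = merged.data();
    for (const auto& b : blobs) {
        const uint8_t* src = b.data();
        std::copy(src, src + b.size(), dst);   // append this batch item
        dst += b.size();                       // advance past what was written
    }
    return merged;
}
```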
@@ -931,5 +931,158 @@ TEST_P(OVRemoteTensorBatched_Test, NV12toBGR_image) {
    }
}

TEST_P(OVRemoteTensorBatched_Test, NV12toBGR_buffer) {
#if defined(ANDROID)
    GTEST_SKIP();
#endif
    const int height = 16;
    const int width = 16;

    // ------------------------------------------------------
    // Prepare input data
    std::vector<ov::runtime::Tensor> fake_image_data_y, fake_image_data_uv;
    for (size_t i = 0; i < num_batch * 2; ++i) {
        fake_image_data_y.push_back(FuncTestUtils::create_and_fill_tensor(ov::element::u8, {1, 1, height, width}, 50, 0, 1, i));
        fake_image_data_uv.push_back(FuncTestUtils::create_and_fill_tensor(ov::element::u8, {1, 2, height / 2, width / 2}, 256, 0, 1, i));
    }

    auto ie = ov::runtime::Core();

    // ------------------------------------------------------
    // inference using remote tensor
    auto fn_ptr_remote = ngraph::builder::subgraph::makeConvPoolRelu({num_batch, 3, height, width});

    using namespace ov::preprocess;
    auto p = PrePostProcessor(fn_ptr_remote);
    p.input().tensor().set_element_type(ov::element::u8)
                      .set_color_format(ov::preprocess::ColorFormat::NV12_TWO_PLANES, {"y", "uv"})
                      .set_memory_type(GPU_CONFIG_KEY(BUFFER));
    p.input().preprocess().convert_color(ov::preprocess::ColorFormat::BGR);
    p.input().model().set_layout("NCHW");
    auto function = p.build();

    auto param_input_y = fn_ptr_remote->get_parameters().at(0);
    auto param_input_uv = fn_ptr_remote->get_parameters().at(1);
    auto output = function->get_results().at(0);

    auto ocl_instance = std::make_shared<OpenCL>();
    ocl_instance->_queue = cl::CommandQueue(ocl_instance->_context, ocl_instance->_device);
    cl_int err;

    auto in_size_y = ov::shape_size(param_input_y->get_output_shape(0)) * param_input_y->get_output_element_type(0).size();
    auto in_size_uv = ov::shape_size(param_input_uv->get_output_shape(0)) * param_input_uv->get_output_element_type(0).size();
    auto out_size = ov::shape_size(output->get_output_shape(0)) * output->get_output_element_type(0).size();

    auto remote_context = ov::runtime::intel_gpu::ocl::ClContext(ie, ocl_instance->_queue.get());
    auto exec_net_shared = ie.compile_model(function, remote_context);
    auto gpu_context = exec_net_shared.get_context().as<ov::runtime::intel_gpu::ocl::ClContext>();

    std::vector<cl::Buffer> shared_input_y_buffer, shared_input_uv_buffer;
    std::vector<ov::runtime::Tensor> gpu_in_y_tensor, gpu_in_uv_tensor;

    for (size_t i = 0; i < num_batch; ++i) {
        shared_input_y_buffer.emplace_back(cl::Buffer(ocl_instance->_context, CL_MEM_READ_WRITE, in_size_y, NULL, &err));
        shared_input_uv_buffer.emplace_back(cl::Buffer(ocl_instance->_context, CL_MEM_READ_WRITE, in_size_uv, NULL, &err));

        gpu_in_y_tensor.emplace_back(gpu_context.create_tensor(param_input_y->get_output_element_type(0),
                                                               fake_image_data_y[i].get_shape(),
                                                               shared_input_y_buffer[i]));
        gpu_in_uv_tensor.emplace_back(gpu_context.create_tensor(param_input_uv->get_output_element_type(0),
                                                                fake_image_data_uv[i].get_shape(),
                                                                shared_input_uv_buffer[i]));
    }
    cl::Buffer shared_output_buffer(ocl_instance->_context, CL_MEM_READ_WRITE, out_size, NULL, &err);
    auto gpu_out_tensor = gpu_context.create_tensor(output->get_output_element_type(0), output->get_output_shape(0), shared_output_buffer);
    auto out_tensor = FuncTestUtils::create_and_fill_tensor(output->get_output_element_type(0), output->get_output_shape(0));

    auto inf_req_shared = exec_net_shared.create_infer_request();
    inf_req_shared.set_tensors(*param_input_y->output(0).get_tensor().get_names().begin(), gpu_in_y_tensor);
    inf_req_shared.set_tensors(*param_input_uv->output(0).get_tensor().get_names().begin(), gpu_in_uv_tensor);
    inf_req_shared.set_tensor(output, gpu_out_tensor);

    for (size_t i = 0; i < num_batch; ++i) {
        void* buffer_y = fake_image_data_y[i].data();
        void* buffer_uv = fake_image_data_uv[i].data();

        ocl_instance->_queue.enqueueWriteBuffer(shared_input_y_buffer[i], false, 0, in_size_y, buffer_y);
        ocl_instance->_queue.enqueueWriteBuffer(shared_input_uv_buffer[i], false, 0, in_size_uv, buffer_uv);
    }

    inf_req_shared.start_async();
    ocl_instance->_queue.enqueueReadBuffer(shared_output_buffer, false, 0, out_size, out_tensor.data(), nullptr, nullptr);
    ocl_instance->_queue.finish();
    ASSERT_NO_THROW(out_tensor.data());

    // ------------------------------------------------------
    // inference using the same InferRequest but with new data
    inf_req_shared.wait();

    std::vector<cl::Buffer> shared_input_y_buffer_new, shared_input_uv_buffer_new;
    std::vector<ov::runtime::Tensor> gpu_in_y_tensor_new, gpu_in_uv_tensor_new;
    for (size_t i = 0; i < num_batch; ++i) {
        shared_input_y_buffer_new.emplace_back(cl::Buffer(ocl_instance->_context, CL_MEM_READ_WRITE, in_size_y, NULL, &err));
        shared_input_uv_buffer_new.emplace_back(cl::Buffer(ocl_instance->_context, CL_MEM_READ_WRITE, in_size_uv, NULL, &err));

        gpu_in_y_tensor_new.emplace_back(gpu_context.create_tensor(param_input_y->get_output_element_type(0),
                                                                   fake_image_data_y[i + num_batch].get_shape(),
                                                                   shared_input_y_buffer_new[i]));
        gpu_in_uv_tensor_new.emplace_back(gpu_context.create_tensor(param_input_uv->get_output_element_type(0),
                                                                    fake_image_data_uv[i + num_batch].get_shape(),
                                                                    shared_input_uv_buffer_new[i]));
    }
    cl::Buffer shared_output_buffer_new(ocl_instance->_context, CL_MEM_READ_WRITE, out_size, NULL, &err);
    auto gpu_out_tensor_new = gpu_context.create_tensor(output->get_output_element_type(0), output->get_output_shape(0), shared_output_buffer_new);
    auto out_tensor_new = FuncTestUtils::create_and_fill_tensor(output->get_output_element_type(0), output->get_output_shape(0));

    inf_req_shared.set_tensors(*param_input_y->output(0).get_tensor().get_names().begin(), gpu_in_y_tensor_new);
    inf_req_shared.set_tensors(*param_input_uv->output(0).get_tensor().get_names().begin(), gpu_in_uv_tensor_new);
    inf_req_shared.set_tensor(output, gpu_out_tensor_new);

    for (size_t i = 0; i < num_batch; ++i) {
        void* buffer_y = fake_image_data_y[i + num_batch].data();
        void* buffer_uv = fake_image_data_uv[i + num_batch].data();

        ocl_instance->_queue.enqueueWriteBuffer(shared_input_y_buffer_new[i], false, 0, in_size_y, buffer_y);
        ocl_instance->_queue.enqueueWriteBuffer(shared_input_uv_buffer_new[i], false, 0, in_size_uv, buffer_uv);
    }
    inf_req_shared.start_async();
    ocl_instance->_queue.enqueueReadBuffer(shared_output_buffer_new, false, 0, out_size, out_tensor_new.data(), nullptr, nullptr);
    ocl_instance->_queue.finish();
    ASSERT_NO_THROW(out_tensor_new.data());

    // ------------------------------------------------------
    // regular inference
    auto fn_ptr_regular = ngraph::builder::subgraph::makeConvPoolRelu({1, 3, height, width});

    using namespace ov::preprocess;
    auto p_reg = PrePostProcessor(fn_ptr_regular);
    p_reg.input().tensor().set_element_type(ov::element::u8)
                          .set_color_format(ov::preprocess::ColorFormat::NV12_TWO_PLANES, {"y", "uv"})
                          .set_memory_type(GPU_CONFIG_KEY(BUFFER));
    p_reg.input().preprocess().convert_color(ov::preprocess::ColorFormat::BGR);
    p_reg.input().model().set_layout("NCHW");
    auto function_regular = p_reg.build();

    auto param_input_y_reg = fn_ptr_regular->get_parameters().at(0);
    auto param_input_uv_reg = fn_ptr_regular->get_parameters().at(1);

    auto exec_net_regular = ie.compile_model(function_regular, CommonTestUtils::DEVICE_GPU);
    auto inf_req_regular = exec_net_regular.create_infer_request();

    for (size_t i = 0; i < num_batch; ++i) {
        inf_req_regular.set_tensor(param_input_y_reg, fake_image_data_y[i + num_batch]);
        inf_req_regular.set_tensor(param_input_uv_reg, fake_image_data_uv[i + num_batch]);
        inf_req_regular.infer();
        auto output_tensor_regular = inf_req_regular.get_tensor(exec_net_regular.output());

        ASSERT_EQ(output_tensor_regular.get_size() * num_batch, out_tensor_new.get_size());
        float thr = 0.1;

        FuncTestUtils::compareRawBuffers<float>(static_cast<float*>(out_tensor_new.data()) + i * output_tensor_regular.get_size(),
                                                static_cast<float*>(output_tensor_regular.data()),
                                                output_tensor_regular.get_size(), output_tensor_regular.get_size(), thr);
    }
}

const std::vector<size_t> num_batches{ 1, 2, 4 };
INSTANTIATE_TEST_SUITE_P(smoke_RemoteTensor, OVRemoteTensorBatched_Test, ::testing::ValuesIn(num_batches), OVRemoteTensorBatched_Test::getTestCaseName);
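For reference, a condensed sketch of the usage pattern the new test exercises: one cl::Buffer per batch item, each wrapped into a remote tensor and passed as a batch through set_tensors(). It follows the transitional ov::runtime::intel_gpu::ocl naming used in the test; the header paths, the model pointer type, and every parameter name are assumptions for illustration, not taken from this commit.

```cpp
#include <string>
#include <vector>

#include <openvino/runtime/core.hpp>
// Header path assumed; the GPU remote-context wrappers moved during the 2022.x API transition.
#include <openvino/runtime/intel_gpu/ocl/ocl.hpp>

// `ModelPtr` stands for whatever PrePostProcessor::build() returns in this API revision.
template <typename ModelPtr>
void infer_batched_buffers(const ModelPtr& model,
                           cl::Context& cl_ctx, cl::CommandQueue& queue,
                           const std::string& input_name,
                           ov::element::Type elem_type, const ov::Shape& item_shape,
                           size_t item_bytes, size_t num_batch) {
    ov::runtime::Core core;
    auto remote_ctx = ov::runtime::intel_gpu::ocl::ClContext(core, queue.get());
    auto compiled = core.compile_model(model, remote_ctx);
    auto gpu_ctx = compiled.get_context().as<ov::runtime::intel_gpu::ocl::ClContext>();

    // One user-owned cl::Buffer per batch item, wrapped into a remote tensor.
    std::vector<cl::Buffer> buffers;
    std::vector<ov::runtime::Tensor> batch;
    for (size_t i = 0; i < num_batch; ++i) {
        buffers.emplace_back(cl_ctx, CL_MEM_READ_WRITE, item_bytes);
        batch.emplace_back(gpu_ctx.create_tensor(elem_type, item_shape, buffers[i]));
    }

    auto request = compiled.create_infer_request();
    request.set_tensors(input_name, batch);  // batched buffer inputs, merged by the plugin in enqueue()
    request.start_async();
    request.wait();
}
```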