[GPU] Allow setting remote output for dynamic model (#20608)

commit d9c4ca3021 (parent 63299ec217)
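The patch below lets an infer request keep a user-provided remote (OpenCL) tensor as the output of a dynamically shaped model: the plugin keeps its own internal allocation for the dynamic output and copies the result into the user's buffer once the actual shape is known (see the wait() changes below). A minimal usage sketch built only from API calls exercised in the new tests further down; the model file name and the shapes are placeholders, not part of this change:

#include <openvino/openvino.hpp>
#include <openvino/runtime/intel_gpu/ocl/ocl.hpp>

int main() {
    ov::Core core;
    auto model = core.read_model("model.xml");  // placeholder model with a 4D input
    std::map<size_t, ov::PartialShape> dynamic_shape = {{0, ov::PartialShape::dynamic(4)}};
    model->reshape(dynamic_shape);              // make the model dynamic, as in the tests below

    auto compiled_model = core.compile_model(model, "GPU");
    auto gpu_context = compiled_model.get_context().as<ov::intel_gpu::ocl::ClContext>();
    auto input = model->get_parameters().at(0);
    auto output = model->get_results().at(0);

    // Pre-allocate the remote output once, large enough for any result we expect to produce.
    auto user_output_tensor = gpu_context.create_tensor(output->get_element_type(), ov::Shape{1, 3, 32, 32});

    auto infer_request = compiled_model.create_infer_request();
    auto input_tensor = gpu_context.create_tensor(input->get_element_type(), ov::Shape{1, 2, 32, 32});
    // (filling the input buffer, e.g. by writing to the underlying cl_mem, is omitted here)

    infer_request.set_tensor(input, input_tensor);
    infer_request.set_tensor(output, user_output_tensor);  // remote output + dynamic shape: enabled by this patch
    infer_request.infer();

    // The reported output shape is the actual inferred one ({1, 2, 32, 32} in this example),
    // while the data stays in the user-provided buffer.
    auto out = infer_request.get_tensor(output);
    return 0;
}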
@@ -25,8 +25,17 @@ enum class TensorOwner : uint8_t {
};

struct TensorWrapper {
    TensorWrapper(const std::shared_ptr<ov::ITensor>& _ptr, TensorOwner _owner)
        : ptr(_ptr)
        , owner(_owner)
        , actual_size(_ptr ? _ptr->get_byte_size() : 0) {}

    TensorWrapper(const TensorWrapper& other) = default;
    TensorWrapper() = default;

    std::shared_ptr<ov::ITensor> ptr;
    TensorOwner owner;
    size_t actual_size;
};

class SyncInferRequest : public ov::ISyncInferRequest {
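The new actual_size field caches the user tensor's byte size at the moment the wrapper is created. This is presumably needed because, for dynamic outputs, the plugin later reshapes that same tensor to the actual result shape, after which get_byte_size() reflects the current (possibly smaller) shape rather than the original allocation; the capacity check added in wait() below compares against this cached value. A small illustration of that ov::Tensor behaviour (illustration only, not part of the patch):

// Illustration only: a size captured when the tensor is set can differ from the
// size reported later, once the tensor has been reshaped to the actual output shape.
ov::Tensor t(ov::element::f32, ov::Shape{1, 3, 32, 32});   // user allocates 1*3*32*32*4 = 12288 bytes
size_t actual_size = t.get_byte_size();                    // 12288, remembered like TensorWrapper::actual_size
t.set_shape(ov::Shape{1, 2, 32, 32});                      // plugin reports a smaller inferred shape
// t.get_byte_size() is now 8192, so a later capacity check must use the cached 12288.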
@@ -3,6 +3,8 @@
//

#include "intel_gpu/plugin/usm_host_tensor.hpp"
#include "intel_gpu/runtime/memory.hpp"
#include "intel_gpu/runtime/memory_caps.hpp"
#include "openvino/runtime/make_tensor.hpp"
#include "openvino/core/preprocess/input_tensor_info.hpp"
#include "openvino/core/parallel.hpp"
@@ -415,11 +417,13 @@ void SyncInferRequest::wait() {
    OV_ITT_SCOPED_TASK(itt::domains::intel_gpu_plugin, "SyncInferRequest::wait");
    OPENVINO_ASSERT(!m_internal_outputs.empty(), "[GPU] Inference was not started!\n");

    auto& network = *m_graph->get_network();

    // wait for completion & collect outputs as requested by the model
    // for in_order_queue, it is enough to call finish only once
-   bool do_sync_per_output = (m_graph->get_network()->get_stream().get_queue_type() == QueueTypes::in_order) ? false : true;
+   bool do_sync_per_output = (network.get_stream().get_queue_type() == QueueTypes::in_order) ? false : true;
    if (!do_sync_per_output)
-       m_graph->get_network()->get_stream().finish();
+       network.get_stream().finish();

    std::vector<cldnn::event::ptr> copy_events;

@@ -442,6 +446,7 @@ void SyncInferRequest::wait() {
        auto output_tensor = output_tensor_wrapper.ptr;
        auto remote_ptr = std::dynamic_pointer_cast<RemoteTensorImpl>(output_tensor);
        bool is_remote = remote_ptr != nullptr;
        bool is_dynamic = port.get_partial_shape().is_dynamic();

        if (is_remote) {
            GPU_DEBUG_TRACE_DETAIL << name << " handle output tensor (remote): " << remote_ptr->get_original_memory()->buffer_ptr() << std::endl;
@@ -449,6 +454,10 @@ void SyncInferRequest::wait() {
            GPU_DEBUG_TRACE_DETAIL << name << " handle output tensor (host): " << output_tensor->data() << std::endl;
        }

        OPENVINO_ASSERT(output_tensor_wrapper.owner == TensorOwner::PLUGIN || output_tensor_wrapper.actual_size >= output_memory->size(),
                        "[GPU] Output tensor set by user has smaller size (", output_tensor->get_byte_size(), ") ",
                        "than required (", output_memory->size(), ")");

        bool need_output_update = output_layout.bytes_count() == 0 || (output_memory && output_tensor->get_byte_size() != output_memory->size());
        if (need_output_update) {
            OV_ITT_SCOPED_TASK(itt::domains::intel_gpu_plugin, "SyncInferRequest::wait::update_output");
@@ -460,7 +469,7 @@ void SyncInferRequest::wait() {
                OPENVINO_ASSERT(ov::shape_size(port.get_shape()) == ov::shape_size(mem_shape), "[GPU] Unexpected elements count for output tensor");
                mem_shape = port.get_shape();
            }
-           if (port.get_partial_shape().is_dynamic()) {
+           if (is_dynamic) {
                bool need_reallocate = true;
                auto usm_host_tensor = std::dynamic_pointer_cast<USMHostTensor>(output_tensor);
                if (usm_host_tensor && output_memory)
@@ -488,11 +497,23 @@ void SyncInferRequest::wait() {
                    copy_events.push_back(ev);
                }
            }
        } else if (is_remote && is_dynamic) {
            auto& stream = m_graph->get_network()->get_stream();
            auto user_mem = remote_ptr->get_original_memory();
            if (user_mem->get_allocation_type() == cldnn::allocation_type::cl_mem && output_memory->get_allocation_type() != cldnn::allocation_type::cl_mem) {
                // WA: Copy between cl_mem and usm memory may fail for some reason (driver bug?)
                // so this explicit memcpy is used to provide correct output for cl_mem output in dynamic cases
                cldnn::mem_lock<uint8_t, cldnn::mem_lock_type::write> lock_dst(user_mem, stream);
                cldnn::mem_lock<uint8_t, cldnn::mem_lock_type::read> lock_src(output_memory, stream);
                std::memcpy(lock_dst.data(), lock_src.data(), output_memory->size());
            } else {
                copy_events.push_back(output_memory->copy_to(stream, *user_mem, false));
            }
        }
    }

    if (!copy_events.empty()) {
-       auto& stream = m_graph->get_network()->get_stream();
+       auto& stream = network.get_stream();
        if (stream.get_queue_type() == QueueTypes::in_order) {
            // wait only the last one
            stream.wait_for_events({copy_events.back()});
@@ -831,7 +852,7 @@ std::vector<cldnn::event::ptr> SyncInferRequest::prepare_output(const std::strin
    auto device_tensor_et = convert_to_supported_device_type(element_type);
    bool convert_needed = is_convert_required(device_tensor_et, element_type);
    cldnn::primitive_id internal_name = m_output_names_map.at(name);
-   if (is_remote && !convert_needed) {
+   if (is_remote && !convert_needed && !is_dynamic) {
        m_plugin_outputs[name] = user_tensor_wrapper;
    }
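From the application side, the assert added in wait() above means a pre-allocated remote output must be at least as large as the inferred result, otherwise infer() throws. Continuing the hypothetical sketch from the top of the page (same placeholder names), the failure is observable like this, mirroring the smoke_MixedTensorTypes test added below:

// Bind an output buffer that is too small for the result this input will produce.
auto small_output = gpu_context.create_tensor(output->get_element_type(), ov::Shape{1, 1, 32, 32});
infer_request.set_tensor(output, small_output);
try {
    infer_request.infer();
} catch (const ov::Exception& e) {
    // e.what() contains "Output tensor set by user has smaller size", per the assert above
}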
@@ -7,21 +7,25 @@
#include <vector>
#include <memory>

#include "common_test_utils/test_assertions.hpp"
#include "openvino/core/dimension.hpp"
#include "openvino/core/except.hpp"
#include "openvino/core/model.hpp"
#include "openvino/core/preprocess/pre_post_process.hpp"
#include "openvino/runtime/intel_gpu/ocl/ocl.hpp"
#include "openvino/runtime/core.hpp"
#include "openvino/runtime/intel_gpu/properties.hpp"
#include "openvino/runtime/properties.hpp"
#include "openvino/runtime/remote_tensor.hpp"

#include <remote_blob_tests/remote_blob_helpers.hpp>
#include <common_test_utils/test_common.hpp>
#include <functional_test_utils/plugin_cache.hpp>
#include "remote_blob_tests/remote_blob_helpers.hpp"
#include "common_test_utils/test_assertions.hpp"
#include "common_test_utils/ov_tensor_utils.hpp"
#include "common_test_utils/test_common.hpp"
#include "base/ov_behavior_test_utils.hpp"
#include "ov_models/subgraph_builders.hpp"
#include "functional_test_utils/blob_utils.hpp"
#include "openvino/core/preprocess/pre_post_process.hpp"
#include "subgraphs_builders.hpp"
#include "transformations/utils/utils.hpp"
#include "common_test_utils/ov_tensor_utils.hpp"

using namespace ::testing;

@@ -35,6 +39,7 @@ protected:
};

namespace {
std::vector<bool> ov_dynamic {true, false};
std::vector<bool> ov_with_auto_batching {true, false};
enum class RemoteTensorSharingType {
    USER_CL_TENSOR = 0,
@@ -61,7 +66,7 @@ std::ostream& operator<<(std::ostream& stream, RemoteTensorSharingType sharing_t
}
} // namespace

-using RemoteTensorSharingTestOptionsParams = std::tuple<RemoteTensorSharingType, bool /*auto-batching*/>;
+using RemoteTensorSharingTestOptionsParams = std::tuple<RemoteTensorSharingType, bool /*auto-batching*/, bool /*dynamic*/>;

class OVRemoteTensorInputBlob_Test : public OVRemoteTensor_Test,
        public testing::WithParamInterface<RemoteTensorSharingTestOptionsParams> {
@@ -75,7 +80,8 @@ public:
        deviceName = ov::test::utils::DEVICE_GPU;
        RemoteTensorSharingType sharing_type;
        bool with_auto_batching;
-       std::tie(sharing_type, with_auto_batching) = this->GetParam();
+       bool is_dynamic;
+       std::tie(sharing_type, with_auto_batching, is_dynamic) = this->GetParam();
        if (with_auto_batching) {
            config =
                {ov::hint::performance_mode(ov::hint::PerformanceMode::THROUGHPUT),
@@ -84,17 +90,24 @@ public:
                };
        }
        fn_ptr = ov::test::behavior::getDefaultNGraphFunctionForTheDevice();
        if (is_dynamic) {
            std::map<size_t, ov::PartialShape> target_shape = {{0, ov::PartialShape::dynamic(4)}};
            fn_ptr->reshape(target_shape);
        }
    }
    static std::string getTestCaseName(const testing::TestParamInfo<RemoteTensorSharingTestOptionsParams>& obj) {
        RemoteTensorSharingType sharing_type;
        bool with_auto_batching;
-       std::tie(sharing_type, with_auto_batching) = obj.param;
+       bool is_dynamic;
+       std::tie(sharing_type, with_auto_batching, is_dynamic) = obj.param;

        std::ostringstream result;
        result << "OVRemoteTensorInputBlob_Test_";
        result << sharing_type;
        if (with_auto_batching)
            result << "_WITH_AUTO_BATCHING";
        if (is_dynamic)
            result << "_DYNAMIC";
        return result.str();
    }
};
@@ -102,8 +115,9 @@ public:
TEST_P(OVRemoteTensorInputBlob_Test, smoke_cantCreateBlobWithInvalidSize) {
    RemoteTensorSharingType sharing_type;
    bool with_auto_batching;
-   std::tie(sharing_type, with_auto_batching) = GetParam();
-   if (with_auto_batching)
+   bool is_dynamic;
+   std::tie(sharing_type, with_auto_batching, is_dynamic) = GetParam();
+   if (with_auto_batching || is_dynamic)
        GTEST_SKIP();

    if (sharing_type == RemoteTensorSharingType::PLUGIN_CL_TENSOR ||
@@ -164,7 +178,8 @@ TEST_P(OVRemoteTensorInputBlob_Test, smoke_canInputRemoteTensor) {
    auto function = p.build();
    RemoteTensorSharingType sharing_type;
    bool with_auto_batching;
-   std::tie(sharing_type, with_auto_batching) = GetParam();
+   bool is_dynamic;
+   std::tie(sharing_type, with_auto_batching, is_dynamic) = GetParam();

    // auto-batching relies on availability of the lock() for the tensor (and the *USM_DEVICE is not lockable)
    if (with_auto_batching
@@ -173,12 +188,13 @@ TEST_P(OVRemoteTensorInputBlob_Test, smoke_canInputRemoteTensor) {
        GTEST_SKIP();

    auto exec_net = ie.compile_model(function, deviceName, config);
    ov::Shape input_shape{1, 2, 32, 32};

    // regular inference
    auto inf_req_regular = exec_net.create_infer_request();
    auto input = function->get_parameters().at(0);
    auto output = function->get_results().at(0);
-   auto fakeImageData = ov::test::utils::create_and_fill_tensor(input->get_element_type(), input->get_shape());
+   auto fakeImageData = ov::test::utils::create_and_fill_tensor(input->get_element_type(), input_shape);

    inf_req_regular.set_tensor(input, fakeImageData);

@@ -192,7 +208,7 @@ TEST_P(OVRemoteTensorInputBlob_Test, smoke_canInputRemoteTensor) {
    auto ocl_instance = std::make_shared<OpenCL>(ctx);
    cl_int err;

-   auto imSize = ov::shape_size(input->get_shape());
+   auto imSize = ov::shape_size(input_shape);

    switch (sharing_type) {
        case RemoteTensorSharingType::USER_CL_TENSOR: {
@@ -202,7 +218,7 @@ TEST_P(OVRemoteTensorInputBlob_Test, smoke_canInputRemoteTensor) {
                ocl_instance->_queue.enqueueWriteBuffer(shared_buffer, true, 0, imSize, buffer);
            }

-           auto cldnn_tensor = cldnn_context.create_tensor(input->get_element_type(), input->get_shape(), shared_buffer);
+           auto cldnn_tensor = cldnn_context.create_tensor(input->get_element_type(), input_shape, shared_buffer);
            inf_req_shared.set_tensor(input, cldnn_tensor);
            inf_req_shared.infer();

@@ -220,7 +236,7 @@ TEST_P(OVRemoteTensorInputBlob_Test, smoke_canInputRemoteTensor) {
                    FAIL() << "Failed to copy data from host buffer to USM device";
            }

-           auto cldnn_tensor = cldnn_context.create_tensor(input->get_element_type(), input->get_shape(), shared_buffer);
+           auto cldnn_tensor = cldnn_context.create_tensor(input->get_element_type(), input_shape, shared_buffer);
            inf_req_shared.set_tensor(input, cldnn_tensor);
            inf_req_shared.infer();

@@ -238,7 +254,7 @@ TEST_P(OVRemoteTensorInputBlob_Test, smoke_canInputRemoteTensor) {
                std::memcpy(shared_buffer, buffer, imSize);
            }

-           auto cldnn_tensor = cldnn_context.create_tensor(input->get_element_type(), input->get_shape(), shared_buffer);
+           auto cldnn_tensor = cldnn_context.create_tensor(input->get_element_type(), input_shape, shared_buffer);
            inf_req_shared.set_tensor(input, cldnn_tensor);
            inf_req_shared.infer();

@@ -247,7 +263,7 @@ TEST_P(OVRemoteTensorInputBlob_Test, smoke_canInputRemoteTensor) {
            break;
        }
        case RemoteTensorSharingType::PLUGIN_CL_TENSOR: {
-           auto cldnn_tensor = cldnn_context.create_tensor(input->get_element_type(), input->get_shape());
+           auto cldnn_tensor = cldnn_context.create_tensor(input->get_element_type(), input_shape);
            ASSERT_TRUE(cldnn_tensor.is<ov::intel_gpu::ocl::ClBufferTensor>());
            auto cl_tensor = cldnn_tensor.as<ov::intel_gpu::ocl::ClBufferTensor>();
            {
@@ -263,7 +279,7 @@ TEST_P(OVRemoteTensorInputBlob_Test, smoke_canInputRemoteTensor) {
            if (!ocl_instance->supports_usm())
                GTEST_SKIP();

-           auto cldnn_tensor = cldnn_context.create_usm_host_tensor(input->get_element_type(), input->get_shape());
+           auto cldnn_tensor = cldnn_context.create_usm_host_tensor(input->get_element_type(), input_shape);
            ASSERT_TRUE(cldnn_tensor.is<ov::intel_gpu::ocl::USMTensor>());
            {
                auto cl_tensor = cldnn_tensor.as<ov::intel_gpu::ocl::USMTensor>();
@@ -282,7 +298,7 @@ TEST_P(OVRemoteTensorInputBlob_Test, smoke_canInputRemoteTensor) {
            if (!ocl_instance->supports_usm())
                GTEST_SKIP();

-           auto cldnn_tensor = cldnn_context.create_usm_device_tensor(input->get_element_type(), input->get_shape());
+           auto cldnn_tensor = cldnn_context.create_usm_device_tensor(input->get_element_type(), input_shape);
            ASSERT_TRUE(cldnn_tensor.is<ov::intel_gpu::ocl::USMTensor>());
            {
                auto cl_tensor = cldnn_tensor.as<ov::intel_gpu::ocl::USMTensor>();
@@ -300,7 +316,7 @@ TEST_P(OVRemoteTensorInputBlob_Test, smoke_canInputRemoteTensor) {
            break;
        }
        case RemoteTensorSharingType::PLUGIN_HOST_TENSOR: {
-           auto cldnn_tensor = cldnn_context.create_host_tensor(input->get_element_type(), input->get_shape());
+           auto cldnn_tensor = cldnn_context.create_host_tensor(input->get_element_type(), input_shape);
            {
                ASSERT_NO_THROW(cldnn_tensor.data());
                void* shared_buffer = cldnn_tensor.data();
@@ -331,6 +347,277 @@ TEST_P(OVRemoteTensorInputBlob_Test, smoke_canInputRemoteTensor) {
    }
}

TEST_P(OVRemoteTensorInputBlob_Test, smoke_canInputOutputRemoteTensor) {
#if defined(ANDROID)
    GTEST_SKIP();
#endif
    auto ie = ov::Core();

    using namespace ov::preprocess;
    auto p = PrePostProcessor(fn_ptr);
    p.input().tensor().set_element_type(ov::element::i8);
    p.input().preprocess().convert_element_type(ov::element::f32);

    auto model = p.build();
    RemoteTensorSharingType sharing_type;
    bool with_auto_batching;
    bool is_dynamic;
    std::tie(sharing_type, with_auto_batching, is_dynamic) = GetParam();

    // auto-batching relies on availability of the lock() for the tensor (and the *USM_DEVICE is not lockable)
    if (with_auto_batching)
        GTEST_SKIP();

    auto compiled_model = ie.compile_model(model, deviceName, config);

    ov::Shape input_shape{1, 2, 32, 32};
    ov::Shape output_shape{1, 2, 32, 32};
    // regular inference
    auto inf_req_regular = compiled_model.create_infer_request();
    auto input = model->get_parameters().at(0);
    auto output = model->get_results().at(0);

    auto input_data = ov::test::utils::create_and_fill_tensor(input->get_element_type(), input_shape);

    inf_req_regular.set_tensor(input, input_data);

    inf_req_regular.infer();
    auto output_tensor_regular = inf_req_regular.get_tensor(output);

    // inference using remote tensor
    auto inf_req_shared = compiled_model.create_infer_request();
    auto gpu_context = compiled_model.get_context().as<ov::intel_gpu::ocl::ClContext>();
    cl_context ctx = gpu_context;
    auto ocl_instance = std::make_shared<OpenCL>(ctx);
    cl_int err;

    auto allocated_out_shape = output_shape;
    if (is_dynamic) {
        // In dynamic case we allocate more than required to check that out tensor is reshaped correctly
        allocated_out_shape[1]++;
    }

    auto in_size = ov::shape_size(input_shape);
    auto out_size = ov::shape_size(output_shape) * output->get_output_element_type(0).bitwidth() / 8;
    auto allocated_out_size = ov::shape_size(allocated_out_shape) * output->get_output_element_type(0).bitwidth() / 8;
    auto output_tensor_shared = ov::test::utils::create_and_fill_tensor(output->get_output_element_type(0), output_shape);

    switch (sharing_type) {
        case RemoteTensorSharingType::USER_CL_TENSOR: {
            cl::Buffer shared_input_buffer(ocl_instance->_context, CL_MEM_READ_WRITE, in_size, NULL, &err);
            cl::Buffer shared_output_buffer(ocl_instance->_context, CL_MEM_READ_WRITE, allocated_out_size, NULL, &err);
            {
                void* buffer = input_data.data();
                ocl_instance->_queue.enqueueWriteBuffer(shared_input_buffer, true, 0, in_size, buffer);
            }

            auto input_remote_tensor = gpu_context.create_tensor(input->get_element_type(), input_shape, shared_input_buffer);
            auto output_remote_tensor = gpu_context.create_tensor(output->get_output_element_type(0), allocated_out_shape, shared_output_buffer);
            inf_req_shared.set_tensor(input, input_remote_tensor);
            inf_req_shared.set_tensor(output, output_remote_tensor);
            inf_req_shared.infer();

            {
                void* buffer = output_tensor_shared.data();
                auto out_tensor = inf_req_shared.get_output_tensor();
                ASSERT_EQ(out_tensor.get_shape(), output_shape);
                ocl_instance->_queue.enqueueReadBuffer(shared_output_buffer, true, 0, out_size, buffer);
            }

            break;
        }
        case RemoteTensorSharingType::USER_USM_DEVICE_TENSOR: {
            if (!ocl_instance->supports_usm())
                GTEST_SKIP();

            void* shared_input_buffer = ocl_instance->allocate_usm_device_buffer(in_size);
            void* shared_output_buffer = ocl_instance->allocate_usm_device_buffer(allocated_out_size);
            {
                void* buffer = input_data.data();
                err = ocl_instance->memcpy(ocl_instance->_queue, shared_input_buffer, buffer, in_size, true, nullptr, nullptr);
                if (err != CL_SUCCESS)
                    FAIL() << "Failed to copy data from host buffer to USM device";
            }

            auto input_remote_tensor = gpu_context.create_tensor(input->get_element_type(), input_shape, shared_input_buffer);
            auto output_remote_tensor = gpu_context.create_tensor(output->get_output_element_type(0), allocated_out_shape, shared_output_buffer);
            inf_req_shared.set_tensor(input, input_remote_tensor);
            inf_req_shared.set_tensor(output, output_remote_tensor);
            inf_req_shared.infer();

            {
                void* buffer = output_tensor_shared.data();
                auto out_tensor = inf_req_shared.get_output_tensor();
                ASSERT_EQ(out_tensor.get_shape(), output_shape);
                err = ocl_instance->memcpy(ocl_instance->_queue, buffer, shared_output_buffer, out_size, true, nullptr, nullptr);
                if (err != CL_SUCCESS)
                    FAIL() << "Failed to copy data from USM device to host buffer";
            }

            ocl_instance->free_mem(shared_input_buffer);
            ocl_instance->free_mem(shared_output_buffer);

            break;
        }
        case RemoteTensorSharingType::USER_USM_HOST_TENSOR: {
            if (!ocl_instance->supports_usm())
                GTEST_SKIP();

            void* shared_input_buffer = ocl_instance->allocate_usm_host_buffer(in_size);
            void* shared_output_buffer = ocl_instance->allocate_usm_host_buffer(allocated_out_size);
            {
                void* buffer = input_data.data();
                std::memcpy(shared_input_buffer, buffer, in_size);
            }

            auto input_remote_tensor = gpu_context.create_tensor(input->get_element_type(), input_shape, shared_input_buffer);
            auto output_remote_tensor = gpu_context.create_tensor(output->get_output_element_type(0), allocated_out_shape, shared_output_buffer);
            inf_req_shared.set_tensor(input, input_remote_tensor);
            inf_req_shared.set_tensor(output, output_remote_tensor);
            inf_req_shared.infer();

            {
                void* buffer = output_tensor_shared.data();
                auto out_tensor = inf_req_shared.get_output_tensor();
                ASSERT_EQ(out_tensor.get_shape(), output_shape);
                err = ocl_instance->memcpy(ocl_instance->_queue, buffer, shared_output_buffer, out_size, true, nullptr, nullptr);
                if (err != CL_SUCCESS)
                    FAIL() << "Failed to copy data from USM host to host buffer";
            }

            ocl_instance->free_mem(shared_input_buffer);
            ocl_instance->free_mem(shared_output_buffer);

            break;
        }
        case RemoteTensorSharingType::PLUGIN_CL_TENSOR: {
            auto input_remote_tensor = gpu_context.create_tensor(input->get_element_type(), input_shape);
            auto output_remote_tensor = gpu_context.create_tensor(output->get_output_element_type(0), allocated_out_shape);
            ASSERT_TRUE(input_remote_tensor.is<ov::intel_gpu::ocl::ClBufferTensor>());
            auto cl_tensor = input_remote_tensor.as<ov::intel_gpu::ocl::ClBufferTensor>();
            {
                cl::Buffer shared_buffer = cl_tensor;
                void* buffer = input_data.data();
                ocl_instance->_queue.enqueueWriteBuffer(shared_buffer, true, 0, in_size, buffer);
            }
            inf_req_shared.set_tensor(input, input_remote_tensor);
            inf_req_shared.set_tensor(output, output_remote_tensor);
            inf_req_shared.infer();

            {
                auto out_cl_tensor = output_remote_tensor.as<ov::intel_gpu::ocl::ClBufferTensor>();

                void* buffer = output_tensor_shared.data();
                auto out_tensor = inf_req_shared.get_output_tensor();
                ASSERT_EQ(out_tensor.get_shape(), output_shape);
                ocl_instance->_queue.enqueueReadBuffer(out_cl_tensor, true, 0, out_size, buffer);
            }

            break;
        }
        case RemoteTensorSharingType::PLUGIN_USM_HOST_TENSOR: {
            if (!ocl_instance->supports_usm())
                GTEST_SKIP();

            auto input_remote_tensor = gpu_context.create_usm_host_tensor(input->get_element_type(), input_shape);
            auto output_remote_tensor = gpu_context.create_usm_host_tensor(output->get_output_element_type(0), allocated_out_shape);
            ASSERT_TRUE(input_remote_tensor.is<ov::intel_gpu::ocl::USMTensor>());
            {
                auto cl_tensor = input_remote_tensor.as<ov::intel_gpu::ocl::USMTensor>();
                void* shared_buffer = cl_tensor.get();
                ASSERT_EQ(ocl_instance->get_allocation_type(shared_buffer), CL_MEM_TYPE_HOST_INTEL);
                void* buffer = input_data.data();
                std::memcpy(shared_buffer, buffer, in_size);
            }

            inf_req_shared.set_tensor(input, input_remote_tensor);
            inf_req_shared.set_tensor(output, output_remote_tensor);
            inf_req_shared.infer();

            {
                void* buffer = output_tensor_shared.data();
                auto out_tensor = inf_req_shared.get_output_tensor();
                auto cl_tensor = out_tensor.as<ov::intel_gpu::ocl::USMTensor>();
                void* shared_output_buffer = cl_tensor.get();
                ASSERT_EQ(ocl_instance->get_allocation_type(shared_output_buffer), CL_MEM_TYPE_HOST_INTEL);
                ASSERT_EQ(out_tensor.get_shape(), output_shape);
                std::memcpy(buffer, shared_output_buffer, out_size);
            }

            break;
        }
        case RemoteTensorSharingType::PLUGIN_USM_DEVICE_TENSOR: {
            if (!ocl_instance->supports_usm())
                GTEST_SKIP();

            auto input_remote_tensor = gpu_context.create_usm_device_tensor(input->get_element_type(), input_shape);
            auto output_remote_tensor = gpu_context.create_usm_device_tensor(output->get_output_element_type(0), allocated_out_shape);
            ASSERT_TRUE(input_remote_tensor.is<ov::intel_gpu::ocl::USMTensor>());
            {
                auto cl_tensor = input_remote_tensor.as<ov::intel_gpu::ocl::USMTensor>();
                void* shared_buffer = cl_tensor.get();
                ASSERT_EQ(ocl_instance->get_allocation_type(shared_buffer), CL_MEM_TYPE_DEVICE_INTEL);
                void* buffer = input_data.data();
                err = ocl_instance->memcpy(ocl_instance->_queue, shared_buffer, buffer, in_size, true, nullptr, nullptr);
                if (err != CL_SUCCESS)
                    FAIL() << "Failed to copy data from host buffer to USM device";
            }

            inf_req_shared.set_tensor(input, input_remote_tensor);
            inf_req_shared.set_tensor(output, output_remote_tensor);
            inf_req_shared.infer();

            {
                auto cl_tensor = output_remote_tensor.as<ov::intel_gpu::ocl::USMTensor>();
                void* shared_output_buffer = cl_tensor.get();

                void* buffer = output_tensor_shared.data();
                auto out_tensor = inf_req_shared.get_output_tensor();
                ASSERT_EQ(out_tensor.get_shape(), output_shape);
                err = ocl_instance->memcpy(ocl_instance->_queue, buffer, shared_output_buffer, out_size, true, nullptr, nullptr);
            }

            break;
        }
        case RemoteTensorSharingType::PLUGIN_HOST_TENSOR: {
            auto input_tensor = gpu_context.create_host_tensor(input->get_element_type(), input_shape);
            auto output_tensor = gpu_context.create_host_tensor(output->get_output_element_type(0), allocated_out_shape);
            {
                ASSERT_NO_THROW(input_tensor.data());
                void* shared_buffer = input_tensor.data();
                if (ocl_instance->supports_usm()) {
                    ASSERT_EQ(ocl_instance->get_allocation_type(shared_buffer), CL_MEM_TYPE_HOST_INTEL);
                }
                void* buffer = input_data.data();
                std::memcpy(shared_buffer, buffer, in_size);
            }

            inf_req_shared.set_tensor(input, input_tensor);
            inf_req_shared.set_tensor(output, output_tensor);
            inf_req_shared.infer();

            {
                void* buffer = output_tensor_shared.data();
                auto out_tensor = inf_req_shared.get_output_tensor();
                ASSERT_EQ(out_tensor.get_shape(), output_shape);
                err = ocl_instance->memcpy(ocl_instance->_queue, buffer, output_tensor.data(), out_size, true, nullptr, nullptr);
            }
            break;
        }
    }

    // compare results
    {
        ASSERT_EQ(output->get_element_type(), ov::element::f32);
        ASSERT_EQ(output_tensor_regular.get_size(), output_tensor_shared.get_size());
        auto thr = FuncTestUtils::GetComparisonThreshold(InferenceEngine::Precision::FP32);
        ASSERT_NO_THROW(output_tensor_regular.data());
        ASSERT_NO_THROW(output_tensor_shared.data());
        ov::test::utils::compare(output_tensor_regular, output_tensor_shared, thr);
    }
}

INSTANTIATE_TEST_SUITE_P(
    smoke_GPU,
    OVRemoteTensorInputBlob_Test,
@@ -342,9 +629,125 @@ INSTANTIATE_TEST_SUITE_P(
        RemoteTensorSharingType::PLUGIN_USM_HOST_TENSOR,
        RemoteTensorSharingType::PLUGIN_USM_DEVICE_TENSOR,
        RemoteTensorSharingType::PLUGIN_HOST_TENSOR}),
-       ::testing::ValuesIn(ov_with_auto_batching)),
+       ::testing::ValuesIn(ov_with_auto_batching),
+       ::testing::ValuesIn(ov_dynamic)),
    OVRemoteTensorInputBlob_Test::getTestCaseName);

TEST(OVRemoteTensorTests, smoke_MixedTensorTypes) {
#if defined(ANDROID)
    GTEST_SKIP();
#endif
    auto core = ov::Core();
    auto model = ov::test::behavior::getDefaultNGraphFunctionForTheDevice();
    std::map<size_t, ov::PartialShape> dynamic_shape = {{0, ov::PartialShape::dynamic(4)}};
    model->reshape(dynamic_shape);

    auto dynamic_compiled_model = core.compile_model(model, ov::test::utils::DEVICE_GPU);

    auto input = model->get_parameters().at(0);
    auto output = model->get_results().at(0);

    auto gpu_context = dynamic_compiled_model.get_context().as<ov::intel_gpu::ocl::ClContext>();
    cl_context ctx = gpu_context;
    auto ocl_instance = std::make_shared<OpenCL>(ctx);

    ov::Shape output_shape_allocated{1, 3, 32, 32};
    auto user_output_tensor = gpu_context.create_tensor(output->get_element_type(), output_shape_allocated);
    ov::Tensor output_tensor_copy_0(output->get_element_type(), output_shape_allocated);
    ov::Tensor output_tensor_copy_1(output->get_element_type(), output_shape_allocated);

    {
        auto infer_request = dynamic_compiled_model.create_infer_request();
        {
            // Run infer request with user's input & output tensor
            // Output tensor size is larger than required
            ov::Shape input_shape{1, 2, 32, 32};
            auto input_tensor = gpu_context.create_tensor(input->get_element_type(), input_shape);
            ov::Shape output_shape_actual{1, 2, 32, 32};

            infer_request.set_tensor(input, input_tensor);
            infer_request.set_tensor(output, user_output_tensor);
            infer_request.infer();
            auto output_tensor = infer_request.get_tensor(output);

            ASSERT_TRUE(output_tensor.is<ov::intel_gpu::ocl::ClBufferTensor>());
            ASSERT_TRUE(user_output_tensor.is<ov::intel_gpu::ocl::ClBufferTensor>());
            auto t1 = output_tensor.as<ov::intel_gpu::ocl::ClBufferTensor>();
            auto t2 = user_output_tensor.as<ov::intel_gpu::ocl::ClBufferTensor>();

            ASSERT_EQ(t1.get(), t2.get());
            ASSERT_EQ(output_tensor.get_shape(), output_shape_actual);
        }

        {
            // Keep same output, but use larger input
            // In that case user tensor is not enough to store the result and the plugin throws exception
            ov::Shape input_shape{1, 4, 32, 32};
            auto input_tensor = gpu_context.create_tensor(input->get_element_type(), input_shape);

            infer_request.set_tensor(input, input_tensor);
            OV_EXPECT_THROW(infer_request.infer(), ov::Exception, HasSubstr("Output tensor set by user has smaller size"));
        }

        {
            // Now try to increase buffer size comparing to the 1st run
            // User output buffer is supposed to be the same
            ov::Shape input_shape{1, 3, 32, 32};
            ov::Shape output_shape_actual{1, 3, 32, 32};
            auto input_tensor_1 = gpu_context.create_tensor(input->get_element_type(), input_shape);
            auto data = ov::test::utils::create_and_fill_tensor(input->get_element_type(), input_shape);
            ASSERT_TRUE(input_tensor_1.is<ov::intel_gpu::ocl::ClBufferTensor>());
            auto cl_tensor = input_tensor_1.as<ov::intel_gpu::ocl::ClBufferTensor>();
            cl::Buffer shared_buffer = cl_tensor;
            void* buffer = data.data();
            ocl_instance->_queue.enqueueWriteBuffer(shared_buffer, true, 0, ov::shape_size(input_shape), buffer);

            infer_request.set_tensor(input, input_tensor_1);
            infer_request.infer();
            auto output_tensor = infer_request.get_tensor(output);
            ASSERT_TRUE(output_tensor.is<ov::intel_gpu::ocl::ClBufferTensor>());
            ASSERT_TRUE(user_output_tensor.is<ov::intel_gpu::ocl::ClBufferTensor>());
            auto t1 = output_tensor.as<ov::intel_gpu::ocl::ClBufferTensor>();
            auto t2 = user_output_tensor.as<ov::intel_gpu::ocl::ClBufferTensor>();

            // inference result of this iteration is stored to output_tensor_copy_0 for further values check
            ocl_instance->_queue.enqueueReadBuffer(t2, true, 0, user_output_tensor.get_byte_size(), output_tensor_copy_0.data());
            ASSERT_EQ(t1.get(), t2.get());
            ASSERT_EQ(output_tensor.get_shape(), output_shape_actual);
        }
    }

    {
        auto infer_request = dynamic_compiled_model.create_infer_request();
        ov::Shape input_shape_0{1, 2, 32, 32};
        ov::Shape output_shape_actual_0{1, 2, 32, 32};
        auto input_tensor_0 = gpu_context.create_tensor(input->get_element_type(), input_shape_0);
        auto data = ov::test::utils::create_and_fill_tensor(input->get_element_type(), input_shape_0);
        ASSERT_TRUE(input_tensor_0.is<ov::intel_gpu::ocl::ClBufferTensor>());
        auto cl_tensor = input_tensor_0.as<ov::intel_gpu::ocl::ClBufferTensor>();
        cl::Buffer shared_buffer = cl_tensor;
        void* buffer = data.data();
        ocl_instance->_queue.enqueueWriteBuffer(shared_buffer, true, 0, ov::shape_size(input_shape_0), buffer);

        infer_request.set_tensor(input, input_tensor_0);
        infer_request.infer();

        auto output_tensor = infer_request.get_tensor(output);

        ASSERT_FALSE(output_tensor.is<ov::RemoteTensor>());
        ASSERT_EQ(output_tensor.get_shape(), output_shape_actual_0);
    }

    // Finally, check that last result stored in user output tensor is not corrupted when we run after one more iteration with another output buffer
    ASSERT_TRUE(user_output_tensor.is<ov::intel_gpu::ocl::ClBufferTensor>());
    auto t2 = user_output_tensor.as<ov::intel_gpu::ocl::ClBufferTensor>();
    ocl_instance->_queue.enqueueReadBuffer(t2, true, 0, user_output_tensor.get_byte_size(), output_tensor_copy_1.data());

    for (size_t i = 0; i < output_tensor_copy_0.get_size(); i++) {
        ASSERT_EQ(output_tensor_copy_0.data<float>()[i], output_tensor_copy_1.data<float>()[i]) << " i = " << i;
    }
}

class OVRemoteTensor_TestsWithContext : public OVRemoteTensor_Test, public testing::WithParamInterface<bool> {
protected:
    std::shared_ptr<ngraph::Function> fn_ptr;