[GPU] Host time optimizations for in order queue (#11255)

* [GPU] Host time optimizations

* Fix failed fusings_gpu/permute_eltwise_loop.basic/* tests
This commit is contained in:
Sergey Shlyapnikov 2022-03-30 10:53:53 +03:00 committed by GitHub
parent f13b6252e9
commit cd703580b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 63 additions and 56 deletions

View File

@ -102,7 +102,10 @@ public:
}
network_output get_output(const primitive_id& output_id) {
return network_output(get_primitive_event(output_id), get_output_memory(output_id), get_stream_ptr());
event::ptr evt;
if (get_stream().get_queue_type() == queue_types::out_of_order)
evt = get_primitive_event(output_id);
return network_output(evt, get_output_memory(output_id), get_stream_ptr());
}
memory::ptr get_output_memory(const primitive_id& output_id);
@ -133,8 +136,12 @@ public:
}
std::map<primitive_id, event::ptr> result;
for (auto& id : primitive_ids) {
if (std::find(optimized_primitives.begin(), optimized_primitives.end(), id) == optimized_primitives.end())
result.emplace(id, get_primitive_event(id));
if (std::find(optimized_primitives.begin(), optimized_primitives.end(), id) == optimized_primitives.end()) {
if (has_event(id))
result.emplace(id, get_primitive_event(id));
else
result.emplace(id, nullptr);
}
}
return result;
}

View File

@ -118,7 +118,9 @@ struct loop_impl : typed_primitive_impl<loop> {
loop_carried_dep.clear();
for (const auto& backedge : node.get_back_edges()) {
event::ptr body_event = body_network->get_primitive_event(backedge.from);
event::ptr body_event;
if (body_network->has_event(backedge.from))
body_event = body_network->get_primitive_event(backedge.from);
loop_carried_dep.emplace_back(body_event);
}

View File

@ -209,14 +209,6 @@ protected:
if (profiling) {
stream.finish();
event->set();
} else {
// Create and set user event as complete
event = stream.create_user_event(true);
}
if (!event) {
std::string error_msg = "Event was not created properly for " + instance.id();
throw std::runtime_error(error_msg);
}
return event;

View File

@ -700,34 +700,39 @@ void network::execute_impl(const std::vector<event::ptr>& events) {
}
}
for (auto& inst : _program->get_processing_order()) {
// Special handling for mutable data. The event should be the same as the user or dependency with highest
// processing_num as the mutable_data can be updated when is both user or dependency.
if (inst->is_type<mutable_data>()) {
decltype(_program->get_processing_order().get_processing_number(inst)) proc_num = 0;
for (auto& user : inst->get_users()) {
auto user_proc_num = _program->get_processing_order().get_processing_number(user);
if (user_proc_num > proc_num) {
_events[inst->id()] = _events[user->id()];
proc_num = user_proc_num;
// Store events only in case of OOO queue or enabled Profiling
auto store_events = get_stream().get_queue_type() == queue_types::out_of_order ||
get_engine().configuration().enable_profiling;
if (store_events) {
for (auto& inst : _program->get_processing_order()) {
// Special handling for mutable data. The event should be the same as the user or dependency with highest
// processing_num as the mutable_data can be updated when is both user or dependency.
if (inst->is_type<mutable_data>()) {
decltype(_program->get_processing_order().get_processing_number(inst)) proc_num = 0;
for (auto& user : inst->get_users()) {
auto user_proc_num = _program->get_processing_order().get_processing_number(user);
if (user_proc_num > proc_num) {
_events[inst->id()] = _events[user->id()];
proc_num = user_proc_num;
}
}
}
if (!inst->get_dependencies().empty()) {
for (auto& dep : inst->get_dependencies()) {
auto dep_proc_num = _program->get_processing_order().get_processing_number(dep);
if (dep_proc_num > proc_num) {
_events[inst->id()] = _events[dep->id()];
proc_num = dep_proc_num;
if (!inst->get_dependencies().empty()) {
for (auto& dep : inst->get_dependencies()) {
auto dep_proc_num = _program->get_processing_order().get_processing_number(dep);
if (dep_proc_num > proc_num) {
_events[inst->id()] = _events[dep->id()];
proc_num = dep_proc_num;
}
}
}
}
}
}
for (auto& dout : _data_outputs) { // data primitives are not executed so if they are marked as output we need to add
// them valid events manually
_events[dout->id()] = get_stream().create_user_event(true);
for (auto& dout : _data_outputs) { // data primitives are not executed so if they are marked as output we need to add
// them valid events manually
_events[dout->id()] = get_stream().create_user_event(true);
}
}
for (auto& prim : _primitives) {
@ -828,17 +833,15 @@ std::vector<std::shared_ptr<primitive_inst>> network::get_primitives(const std::
}
void network::execute_primitive(const std::shared_ptr<primitive_inst>& primitive,
const std::vector<event::ptr>& events) {
auto id = primitive->id();
auto it = _events.find(id);
bool found = (it != _events.end());
CLDNN_ERROR_BOOL(id,
"Invalid primitive call ",
found,
"Primitive " + id + " is tried to be executed for the second time");
const std::vector<event::ptr>& events) {
event::ptr ev = primitive->execute(events);
_events.insert({id, ev});
// Collect events only for OOO queue and Profiling mode
if (get_stream().get_queue_type() == queue_types::out_of_order ||
get_engine().configuration().enable_profiling) {
auto id = primitive->id();
_events.insert({id, ev});
}
}
void network::allocate_primitive_instance(program_node const& node) {

View File

@ -148,18 +148,21 @@ event::ptr primitive_inst::execute(const std::vector<event::ptr>& events) {
return _impl->execute(events, *this);
std::vector<event::ptr> dependencies;
dependencies.reserve(_exec_deps.size());
for (auto& input : _exec_deps) {
auto id = input->id();
try {
// if the requested event does not exits it means that it has not been executed, so the processing_order is
// wrong or synchronization failed.
auto ev = get_network().get_primitive_event(id);
dependencies.emplace_back(ev);
} catch (const std::out_of_range& oor) {
std::string temp = std::string("internal CLDNN error: execution order corrupted.") + std::string("\n") +
std::string(oor.what() + std::string("\n"));
CLDNN_ERROR_MESSAGE(id, temp);
auto queue_type = get_network().get_stream().get_queue_type();
if (queue_type == queue_types::out_of_order) {
dependencies.reserve(_exec_deps.size());
for (auto& input : _exec_deps) {
auto id = input->id();
try {
// if the requested event does not exists it means that it has not been executed, so the processing_order is
// wrong or synchronization failed.
auto ev = get_network().get_primitive_event(id);
dependencies.emplace_back(ev);
} catch (const std::out_of_range& oor) {
std::string temp = std::string("internal CLDNN error: execution order corrupted.") + std::string("\n") +
std::string(oor.what() + std::string("\n"));
CLDNN_ERROR_MESSAGE(id, temp);
}
}
}
return _impl->execute(dependencies, *this);

View File

@ -52,7 +52,7 @@ public:
};
class permute_eltwise_loop: public LoopFusingTest {};
TEST_P(permute_eltwise_loop, basic_taylor) {
TEST_P(permute_eltwise_loop, basic) {
auto p = GetParam();
auto num_iteration_mem = engine.allocate_memory({data_types::i64, format::bfyx, {1, 1, 1, 1}});
auto trip_count_mem = engine.allocate_memory({data_types::i64, format::bfyx, {1, 1, 1, 1}});