[GPU] Host time optimizations for in order queue (#11255)

* [GPU] Host time optimizations

* Fix failed fusings_gpu/permute_eltwise_loop.basic/* tests
This commit is contained in:
Sergey Shlyapnikov 2022-03-30 10:53:53 +03:00 committed by GitHub
parent f13b6252e9
commit cd703580b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 63 additions and 56 deletions

View File

@ -102,7 +102,10 @@ public:
} }
network_output get_output(const primitive_id& output_id) { network_output get_output(const primitive_id& output_id) {
return network_output(get_primitive_event(output_id), get_output_memory(output_id), get_stream_ptr()); event::ptr evt;
if (get_stream().get_queue_type() == queue_types::out_of_order)
evt = get_primitive_event(output_id);
return network_output(evt, get_output_memory(output_id), get_stream_ptr());
} }
memory::ptr get_output_memory(const primitive_id& output_id); memory::ptr get_output_memory(const primitive_id& output_id);
@ -133,8 +136,12 @@ public:
} }
std::map<primitive_id, event::ptr> result; std::map<primitive_id, event::ptr> result;
for (auto& id : primitive_ids) { for (auto& id : primitive_ids) {
if (std::find(optimized_primitives.begin(), optimized_primitives.end(), id) == optimized_primitives.end()) if (std::find(optimized_primitives.begin(), optimized_primitives.end(), id) == optimized_primitives.end()) {
result.emplace(id, get_primitive_event(id)); if (has_event(id))
result.emplace(id, get_primitive_event(id));
else
result.emplace(id, nullptr);
}
} }
return result; return result;
} }

View File

@ -118,7 +118,9 @@ struct loop_impl : typed_primitive_impl<loop> {
loop_carried_dep.clear(); loop_carried_dep.clear();
for (const auto& backedge : node.get_back_edges()) { for (const auto& backedge : node.get_back_edges()) {
event::ptr body_event = body_network->get_primitive_event(backedge.from); event::ptr body_event;
if (body_network->has_event(backedge.from))
body_event = body_network->get_primitive_event(backedge.from);
loop_carried_dep.emplace_back(body_event); loop_carried_dep.emplace_back(body_event);
} }

View File

@ -209,14 +209,6 @@ protected:
if (profiling) { if (profiling) {
stream.finish(); stream.finish();
event->set(); event->set();
} else {
// Create and set user event as complete
event = stream.create_user_event(true);
}
if (!event) {
std::string error_msg = "Event was not created properly for " + instance.id();
throw std::runtime_error(error_msg);
} }
return event; return event;

View File

@ -700,34 +700,39 @@ void network::execute_impl(const std::vector<event::ptr>& events) {
} }
} }
for (auto& inst : _program->get_processing_order()) { // Store events only in case of OOO queue or enabled Profiling
// Special handling for mutable data. The event should be the same as the user or dependency with highest auto store_events = get_stream().get_queue_type() == queue_types::out_of_order ||
// processing_num as the mutable_data can be updated when is both user or dependency. get_engine().configuration().enable_profiling;
if (inst->is_type<mutable_data>()) { if (store_events) {
decltype(_program->get_processing_order().get_processing_number(inst)) proc_num = 0; for (auto& inst : _program->get_processing_order()) {
for (auto& user : inst->get_users()) { // Special handling for mutable data. The event should be the same as the user or dependency with highest
auto user_proc_num = _program->get_processing_order().get_processing_number(user); // processing_num as the mutable_data can be updated when is both user or dependency.
if (user_proc_num > proc_num) { if (inst->is_type<mutable_data>()) {
_events[inst->id()] = _events[user->id()]; decltype(_program->get_processing_order().get_processing_number(inst)) proc_num = 0;
proc_num = user_proc_num; for (auto& user : inst->get_users()) {
auto user_proc_num = _program->get_processing_order().get_processing_number(user);
if (user_proc_num > proc_num) {
_events[inst->id()] = _events[user->id()];
proc_num = user_proc_num;
}
} }
}
if (!inst->get_dependencies().empty()) { if (!inst->get_dependencies().empty()) {
for (auto& dep : inst->get_dependencies()) { for (auto& dep : inst->get_dependencies()) {
auto dep_proc_num = _program->get_processing_order().get_processing_number(dep); auto dep_proc_num = _program->get_processing_order().get_processing_number(dep);
if (dep_proc_num > proc_num) { if (dep_proc_num > proc_num) {
_events[inst->id()] = _events[dep->id()]; _events[inst->id()] = _events[dep->id()];
proc_num = dep_proc_num; proc_num = dep_proc_num;
}
} }
} }
} }
} }
}
for (auto& dout : _data_outputs) { // data primitives are not executed so if they are marked as output we need to add for (auto& dout : _data_outputs) { // data primitives are not executed so if they are marked as output we need to add
// them valid events manually // them valid events manually
_events[dout->id()] = get_stream().create_user_event(true); _events[dout->id()] = get_stream().create_user_event(true);
}
} }
for (auto& prim : _primitives) { for (auto& prim : _primitives) {
@ -828,17 +833,15 @@ std::vector<std::shared_ptr<primitive_inst>> network::get_primitives(const std::
} }
void network::execute_primitive(const std::shared_ptr<primitive_inst>& primitive, void network::execute_primitive(const std::shared_ptr<primitive_inst>& primitive,
const std::vector<event::ptr>& events) { const std::vector<event::ptr>& events) {
auto id = primitive->id();
auto it = _events.find(id);
bool found = (it != _events.end());
CLDNN_ERROR_BOOL(id,
"Invalid primitive call ",
found,
"Primitive " + id + " is tried to be executed for the second time");
event::ptr ev = primitive->execute(events); event::ptr ev = primitive->execute(events);
_events.insert({id, ev});
// Collect events only for OOO queue and Profiling mode
if (get_stream().get_queue_type() == queue_types::out_of_order ||
get_engine().configuration().enable_profiling) {
auto id = primitive->id();
_events.insert({id, ev});
}
} }
void network::allocate_primitive_instance(program_node const& node) { void network::allocate_primitive_instance(program_node const& node) {

View File

@ -148,18 +148,21 @@ event::ptr primitive_inst::execute(const std::vector<event::ptr>& events) {
return _impl->execute(events, *this); return _impl->execute(events, *this);
std::vector<event::ptr> dependencies; std::vector<event::ptr> dependencies;
dependencies.reserve(_exec_deps.size()); auto queue_type = get_network().get_stream().get_queue_type();
for (auto& input : _exec_deps) { if (queue_type == queue_types::out_of_order) {
auto id = input->id(); dependencies.reserve(_exec_deps.size());
try { for (auto& input : _exec_deps) {
// if the requested event does not exits it means that it has not been executed, so the processing_order is auto id = input->id();
// wrong or synchronization failed. try {
auto ev = get_network().get_primitive_event(id); // if the requested event does not exists it means that it has not been executed, so the processing_order is
dependencies.emplace_back(ev); // wrong or synchronization failed.
} catch (const std::out_of_range& oor) { auto ev = get_network().get_primitive_event(id);
std::string temp = std::string("internal CLDNN error: execution order corrupted.") + std::string("\n") + dependencies.emplace_back(ev);
std::string(oor.what() + std::string("\n")); } catch (const std::out_of_range& oor) {
CLDNN_ERROR_MESSAGE(id, temp); std::string temp = std::string("internal CLDNN error: execution order corrupted.") + std::string("\n") +
std::string(oor.what() + std::string("\n"));
CLDNN_ERROR_MESSAGE(id, temp);
}
} }
} }
return _impl->execute(dependencies, *this); return _impl->execute(dependencies, *this);

View File

@ -52,7 +52,7 @@ public:
}; };
class permute_eltwise_loop: public LoopFusingTest {}; class permute_eltwise_loop: public LoopFusingTest {};
TEST_P(permute_eltwise_loop, basic_taylor) { TEST_P(permute_eltwise_loop, basic) {
auto p = GetParam(); auto p = GetParam();
auto num_iteration_mem = engine.allocate_memory({data_types::i64, format::bfyx, {1, 1, 1, 1}}); auto num_iteration_mem = engine.allocate_memory({data_types::i64, format::bfyx, {1, 1, 1, 1}});
auto trip_count_mem = engine.allocate_memory({data_types::i64, format::bfyx, {1, 1, 1, 1}}); auto trip_count_mem = engine.allocate_memory({data_types::i64, format::bfyx, {1, 1, 1, 1}});