[GPU] Host time optimizations for in order queue (#11255)
* [GPU] Host time optimizations * Fix failed fusings_gpu/permute_eltwise_loop.basic/* tests
This commit is contained in:
parent
f13b6252e9
commit
cd703580b6
@ -102,7 +102,10 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
network_output get_output(const primitive_id& output_id) {
|
network_output get_output(const primitive_id& output_id) {
|
||||||
return network_output(get_primitive_event(output_id), get_output_memory(output_id), get_stream_ptr());
|
event::ptr evt;
|
||||||
|
if (get_stream().get_queue_type() == queue_types::out_of_order)
|
||||||
|
evt = get_primitive_event(output_id);
|
||||||
|
return network_output(evt, get_output_memory(output_id), get_stream_ptr());
|
||||||
}
|
}
|
||||||
|
|
||||||
memory::ptr get_output_memory(const primitive_id& output_id);
|
memory::ptr get_output_memory(const primitive_id& output_id);
|
||||||
@ -133,8 +136,12 @@ public:
|
|||||||
}
|
}
|
||||||
std::map<primitive_id, event::ptr> result;
|
std::map<primitive_id, event::ptr> result;
|
||||||
for (auto& id : primitive_ids) {
|
for (auto& id : primitive_ids) {
|
||||||
if (std::find(optimized_primitives.begin(), optimized_primitives.end(), id) == optimized_primitives.end())
|
if (std::find(optimized_primitives.begin(), optimized_primitives.end(), id) == optimized_primitives.end()) {
|
||||||
result.emplace(id, get_primitive_event(id));
|
if (has_event(id))
|
||||||
|
result.emplace(id, get_primitive_event(id));
|
||||||
|
else
|
||||||
|
result.emplace(id, nullptr);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
@ -118,7 +118,9 @@ struct loop_impl : typed_primitive_impl<loop> {
|
|||||||
|
|
||||||
loop_carried_dep.clear();
|
loop_carried_dep.clear();
|
||||||
for (const auto& backedge : node.get_back_edges()) {
|
for (const auto& backedge : node.get_back_edges()) {
|
||||||
event::ptr body_event = body_network->get_primitive_event(backedge.from);
|
event::ptr body_event;
|
||||||
|
if (body_network->has_event(backedge.from))
|
||||||
|
body_event = body_network->get_primitive_event(backedge.from);
|
||||||
loop_carried_dep.emplace_back(body_event);
|
loop_carried_dep.emplace_back(body_event);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -209,14 +209,6 @@ protected:
|
|||||||
if (profiling) {
|
if (profiling) {
|
||||||
stream.finish();
|
stream.finish();
|
||||||
event->set();
|
event->set();
|
||||||
} else {
|
|
||||||
// Create and set user event as complete
|
|
||||||
event = stream.create_user_event(true);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!event) {
|
|
||||||
std::string error_msg = "Event was not created properly for " + instance.id();
|
|
||||||
throw std::runtime_error(error_msg);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return event;
|
return event;
|
||||||
|
@ -700,34 +700,39 @@ void network::execute_impl(const std::vector<event::ptr>& events) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto& inst : _program->get_processing_order()) {
|
// Store events only in case of OOO queue or enabled Profiling
|
||||||
// Special handling for mutable data. The event should be the same as the user or dependency with highest
|
auto store_events = get_stream().get_queue_type() == queue_types::out_of_order ||
|
||||||
// processing_num as the mutable_data can be updated when is both user or dependency.
|
get_engine().configuration().enable_profiling;
|
||||||
if (inst->is_type<mutable_data>()) {
|
if (store_events) {
|
||||||
decltype(_program->get_processing_order().get_processing_number(inst)) proc_num = 0;
|
for (auto& inst : _program->get_processing_order()) {
|
||||||
for (auto& user : inst->get_users()) {
|
// Special handling for mutable data. The event should be the same as the user or dependency with highest
|
||||||
auto user_proc_num = _program->get_processing_order().get_processing_number(user);
|
// processing_num, as the mutable_data can be updated both when it is a user and when it is a dependency.
|
||||||
if (user_proc_num > proc_num) {
|
if (inst->is_type<mutable_data>()) {
|
||||||
_events[inst->id()] = _events[user->id()];
|
decltype(_program->get_processing_order().get_processing_number(inst)) proc_num = 0;
|
||||||
proc_num = user_proc_num;
|
for (auto& user : inst->get_users()) {
|
||||||
|
auto user_proc_num = _program->get_processing_order().get_processing_number(user);
|
||||||
|
if (user_proc_num > proc_num) {
|
||||||
|
_events[inst->id()] = _events[user->id()];
|
||||||
|
proc_num = user_proc_num;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if (!inst->get_dependencies().empty()) {
|
if (!inst->get_dependencies().empty()) {
|
||||||
for (auto& dep : inst->get_dependencies()) {
|
for (auto& dep : inst->get_dependencies()) {
|
||||||
auto dep_proc_num = _program->get_processing_order().get_processing_number(dep);
|
auto dep_proc_num = _program->get_processing_order().get_processing_number(dep);
|
||||||
if (dep_proc_num > proc_num) {
|
if (dep_proc_num > proc_num) {
|
||||||
_events[inst->id()] = _events[dep->id()];
|
_events[inst->id()] = _events[dep->id()];
|
||||||
proc_num = dep_proc_num;
|
proc_num = dep_proc_num;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
for (auto& dout : _data_outputs) { // data primitives are not executed so if they are marked as output we need to add
|
for (auto& dout : _data_outputs) { // data primitives are not executed so if they are marked as output we need to add
|
||||||
// them valid events manually
|
// them valid events manually
|
||||||
_events[dout->id()] = get_stream().create_user_event(true);
|
_events[dout->id()] = get_stream().create_user_event(true);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto& prim : _primitives) {
|
for (auto& prim : _primitives) {
|
||||||
@ -828,17 +833,15 @@ std::vector<std::shared_ptr<primitive_inst>> network::get_primitives(const std::
|
|||||||
}
|
}
|
||||||
|
|
||||||
void network::execute_primitive(const std::shared_ptr<primitive_inst>& primitive,
|
void network::execute_primitive(const std::shared_ptr<primitive_inst>& primitive,
|
||||||
const std::vector<event::ptr>& events) {
|
const std::vector<event::ptr>& events) {
|
||||||
auto id = primitive->id();
|
|
||||||
auto it = _events.find(id);
|
|
||||||
bool found = (it != _events.end());
|
|
||||||
CLDNN_ERROR_BOOL(id,
|
|
||||||
"Invalid primitive call ",
|
|
||||||
found,
|
|
||||||
"Primitive " + id + " is tried to be executed for the second time");
|
|
||||||
|
|
||||||
event::ptr ev = primitive->execute(events);
|
event::ptr ev = primitive->execute(events);
|
||||||
_events.insert({id, ev});
|
|
||||||
|
// Collect events only for OOO queue and Profiling mode
|
||||||
|
if (get_stream().get_queue_type() == queue_types::out_of_order ||
|
||||||
|
get_engine().configuration().enable_profiling) {
|
||||||
|
auto id = primitive->id();
|
||||||
|
_events.insert({id, ev});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void network::allocate_primitive_instance(program_node const& node) {
|
void network::allocate_primitive_instance(program_node const& node) {
|
||||||
|
@ -148,18 +148,21 @@ event::ptr primitive_inst::execute(const std::vector<event::ptr>& events) {
|
|||||||
return _impl->execute(events, *this);
|
return _impl->execute(events, *this);
|
||||||
|
|
||||||
std::vector<event::ptr> dependencies;
|
std::vector<event::ptr> dependencies;
|
||||||
dependencies.reserve(_exec_deps.size());
|
auto queue_type = get_network().get_stream().get_queue_type();
|
||||||
for (auto& input : _exec_deps) {
|
if (queue_type == queue_types::out_of_order) {
|
||||||
auto id = input->id();
|
dependencies.reserve(_exec_deps.size());
|
||||||
try {
|
for (auto& input : _exec_deps) {
|
||||||
// if the requested event does not exits it means that it has not been executed, so the processing_order is
|
auto id = input->id();
|
||||||
// wrong or synchronization failed.
|
try {
|
||||||
auto ev = get_network().get_primitive_event(id);
|
// if the requested event does not exist it means that it has not been executed, so the processing_order is
|
||||||
dependencies.emplace_back(ev);
|
// wrong or synchronization failed.
|
||||||
} catch (const std::out_of_range& oor) {
|
auto ev = get_network().get_primitive_event(id);
|
||||||
std::string temp = std::string("internal CLDNN error: execution order corrupted.") + std::string("\n") +
|
dependencies.emplace_back(ev);
|
||||||
std::string(oor.what() + std::string("\n"));
|
} catch (const std::out_of_range& oor) {
|
||||||
CLDNN_ERROR_MESSAGE(id, temp);
|
std::string temp = std::string("internal CLDNN error: execution order corrupted.") + std::string("\n") +
|
||||||
|
std::string(oor.what() + std::string("\n"));
|
||||||
|
CLDNN_ERROR_MESSAGE(id, temp);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return _impl->execute(dependencies, *this);
|
return _impl->execute(dependencies, *this);
|
||||||
|
@ -52,7 +52,7 @@ public:
|
|||||||
};
|
};
|
||||||
|
|
||||||
class permute_eltwise_loop: public LoopFusingTest {};
|
class permute_eltwise_loop: public LoopFusingTest {};
|
||||||
TEST_P(permute_eltwise_loop, basic_taylor) {
|
TEST_P(permute_eltwise_loop, basic) {
|
||||||
auto p = GetParam();
|
auto p = GetParam();
|
||||||
auto num_iteration_mem = engine.allocate_memory({data_types::i64, format::bfyx, {1, 1, 1, 1}});
|
auto num_iteration_mem = engine.allocate_memory({data_types::i64, format::bfyx, {1, 1, 1, 1}});
|
||||||
auto trip_count_mem = engine.allocate_memory({data_types::i64, format::bfyx, {1, 1, 1, 1}});
|
auto trip_count_mem = engine.allocate_memory({data_types::i64, format::bfyx, {1, 1, 1, 1}});
|
||||||
|
Loading…
Reference in New Issue
Block a user