[IE CLDNN] Don't force expected reorder layout & improve i64->i32 fallback (#1088)

This commit is contained in:
Jedrzej Hajduczenia 2020-07-02 09:18:38 +02:00 committed by GitHub
parent c8a6a7b6d0
commit fe4ff33a82
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -44,9 +44,6 @@ void add_required_reorders::add_reorder(program_impl& p, program_node* node, pro
auto new_reorder = std::make_shared<reorder>(node->id() + "_reorder_" + usr->id(), node->id(), reorder_layout);
auto& new_reorder_node = p.get_or_create(new_reorder);
// make sure that new_reorder_node has correct layout
new_reorder_node.set_output_layout(reorder_layout, false);
// ToDo: add a method to program_impl class which adds an intermediate node given a node and its user
auto it = std::find(usr->get_dependencies().begin(), usr->get_dependencies().end(), node);
if (it == usr->get_dependencies().end()) {
@ -98,7 +95,6 @@ void add_required_reorders::run(program_impl& p) {
usr->set_output_layout(current_layout, false);
if (usr->type()->does_possible_implementation_exist(p.get_engine(), *usr)) {
correct_layout_selected = true;
break;
} else {
current_layout = original_layout;
current_layout.data_type = data_types::i32;
@ -106,9 +102,27 @@ void add_required_reorders::run(program_impl& p) {
usr->set_output_layout(current_layout, false);
if (usr->type()->does_possible_implementation_exist(p.get_engine(), *usr)) {
correct_layout_selected = true;
break;
}
}
if (correct_layout_selected) {
// change output_data_type field in usr to i32
if ((static_cast<bool>(usr->get_primitive()->output_data_type) == true) &&
(*(usr->get_primitive()->output_data_type) == data_types::i64)) {
std::const_pointer_cast<primitive>(usr->get_primitive())->output_data_type = data_types::i32;
}
// add reorders between usr int32 output and inputs of its users
auto next_usr_itr = usr->get_users().begin();
while (next_usr_itr != usr->get_users().end()) {
auto next_usr = *next_usr_itr++;
if (!next_usr->is_type<reorder>()) {
if ((next_usr->get_output_layout() != usr->get_output_layout())) {
add_reorder(p, usr, next_usr);
}
}
}
break;
}
}
}
@ -185,7 +199,13 @@ void add_required_reorders::run(program_impl& p) {
" kernel which satisfies output format dependecies.");
}
// add reorders between usr int32 outputs and inputs of its users
// change output_data_type field in usr to i32
if ((static_cast<bool>(usr->get_primitive()->output_data_type) == true) &&
(*(usr->get_primitive()->output_data_type) == data_types::i64)) {
std::const_pointer_cast<primitive>(usr->get_primitive())->output_data_type = data_types::i32;
}
// add reorders between usr int32 output and inputs of its users
auto next_usr_itr = usr->get_users().begin();
while (next_usr_itr != usr->get_users().end()) {
auto next_usr = *next_usr_itr++;