[GPU] Fix permute performance degradation (#10559)

* [GPU] Fix permute performance degradation

Signed-off-by: Andrew Kwangwoong Park <andrew.kwangwoong.park@intel.com>

* add description for update

Signed-off-by: Andrew Kwangwoong Park <andrew.kwangwoong.park@intel.com>
This commit is contained in:
Andrew Kwangwoong Park 2022-02-22 11:35:04 +09:00 committed by GitHub
parent aea0532d76
commit 33062bef7a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 26 additions and 40 deletions

View File

@ -24,6 +24,18 @@ public:
program_node& input() const { return get_dependency(0); } program_node& input() const { return get_dependency(0); }
std::vector<uint16_t> get_permute_order() const { return get_primitive()->permute_order; } std::vector<uint16_t> get_permute_order() const { return get_primitive()->permute_order; }
/// Checks whether this permute rotates the feature dimension to the back
/// (so it becomes the inner-most axis) while keeping batch in place.
///
/// The order matches when:
///   - order[0] == 0                  (batch stays first)
///   - order[1] == order.size() - 1   (second entry is the last index)
///   - order[i] == i - 1 for i >= 2   (remaining entries shift by one)
/// ex) 0(b), 4(f), 1(z), 2(y), 3(x)
/// ex) 0(b), 3(f), 1(y), 2(x)
///
/// @return true if the permute order has the shape described above;
///         false otherwise, including when the order has fewer than two entries.
bool is_rotating_except_batch() const {
    const auto& order = get_primitive()->permute_order;
    const size_t rank = order.size();

    // Guard before touching order[1]: indexing a shorter vector would be
    // an out-of-bounds read.
    if (rank < 2) return false;

    if (order[0] != 0) return false;

    // Compare in a single unsigned type to avoid mixing int32_t with size_t.
    if (static_cast<size_t>(order[1]) != rank - 1) return false;

    for (size_t i = 2; i < rank; ++i) {
        if (static_cast<size_t>(order[i]) != i - 1) return false;
    }
    return true;
}
}; };
using permute_node = typed_program_node<permute>; using permute_node = typed_program_node<permute>;

View File

@ -375,23 +375,11 @@ bool layout_optimizer::can_fuse_reorder(program_node& prev, program_node& next,
return true; return true;
if (next.is_type<permute>()) { if (next.is_type<permute>()) {
auto is_rotating_except_batch = [](const std::vector<uint16_t>& order) {
// Target transform: Rotate feature dim to back to be taken as inner-most axis
// ex) 0(b), 4(f), 1(z), 2(y), 3(x)
// ex) 0(b), 3(f), 1(y), 2(x)
if ((int32_t) order[1] != order.size() - 1) return false;
if ((int32_t) order[0] != 0) return false;
for (int32_t i = 2; i < (int32_t) order.size(); ++i) {
if ((int32_t)order[i] != (i - 1)) return false;
}
return true;
};
auto& permute_order = next.as<permute>().get_primitive()->permute_order; auto& permute_order = next.as<permute>().get_primitive()->permute_order;
if ((fmt_prev == format::b_fs_yx_fsv4 || fmt_prev == format::b_fs_yx_fsv32 || fmt_prev == format::b_fs_zyx_fsv32 || if ((fmt_prev == format::b_fs_yx_fsv4 || fmt_prev == format::b_fs_yx_fsv32 || fmt_prev == format::b_fs_zyx_fsv32 ||
fmt_prev == format::b_fs_yx_fsv16 || fmt_prev == format::b_fs_zyx_fsv16 || fmt_prev == format::bs_fs_yx_bsv16_fsv16) fmt_prev == format::b_fs_yx_fsv16 || fmt_prev == format::b_fs_zyx_fsv16 || fmt_prev == format::bs_fs_yx_bsv16_fsv16)
&& permute_order[1] == 2 && permute_order[1] == 2
&& (!is_rotating_except_batch(permute_order))) { && (!next.as<permute>().is_rotating_except_batch())) {
return false; return false;
} }
return true; return true;
@ -439,23 +427,11 @@ bool layout_optimizer::can_fuse_reorder_to_prev(program_node& prev, program_node
return true; return true;
if (prev.is_type<permute>()) { if (prev.is_type<permute>()) {
auto is_rotating_except_batch = [](const std::vector<uint16_t>& order) {
// Target transform: Rotate feature dim to back to be taken as inner-most axis
// ex) 0(b), 4(f), 1(z), 2(y), 3(x)
// ex) 0(b), 3(f), 1(y), 2(x)
if ((int32_t) order[1] != order.size() - 1) return false;
if ((int32_t) order[0] != 0) return false;
for (int32_t i = 2; i < (int32_t) order.size(); ++i) {
if ((int32_t)order[i] != (i - 1)) return false;
}
return true;
};
auto& permute_order = prev.as<permute>().get_primitive()->permute_order; auto& permute_order = prev.as<permute>().get_primitive()->permute_order;
if ((fmt_prev == format::b_fs_yx_fsv4 || fmt_prev == format::b_fs_yx_fsv32 || fmt_prev == format::b_fs_zyx_fsv32 || if ((fmt_prev == format::b_fs_yx_fsv4 || fmt_prev == format::b_fs_yx_fsv32 || fmt_prev == format::b_fs_zyx_fsv32 ||
fmt_prev == format::b_fs_yx_fsv16 || fmt_prev == format::b_fs_zyx_fsv16 || fmt_prev == format::bs_fs_yx_bsv16_fsv16) fmt_prev == format::b_fs_yx_fsv16 || fmt_prev == format::b_fs_zyx_fsv16 || fmt_prev == format::bs_fs_yx_bsv16_fsv16)
&& permute_order[1] == 2 && permute_order[1] == 2
&& (!is_rotating_except_batch(permute_order))) { && (!prev.as<permute>().is_rotating_except_batch())) {
return false; return false;
} }
return true; return true;
@ -1707,6 +1683,17 @@ format layout_optimizer::get_preferred_format(program_node& node) {
} else { } else {
expected = format::any; expected = format::any;
} }
} else if (node.is_type<permute>()) {
if (node.get_dependencies().size() == 1 && node.get_dependencies().front()->is_type<convolution>()) {
auto& conv_node = node.get_dependencies().front()->as<convolution>();
const auto& fmt = get_preferred_format(conv_node);
// if the preferred format of the previous conv of permute is fs_b_yx_fsv32,
// it is better to set to b_fs_yx_fsv32 that supports tiled permute (permute_tile_8x8_4x4_fsv)
// because fs_b_yx_fsv32 is only supported by permute_ref.
if (node.as<permute>().is_rotating_except_batch() && fmt == format::fs_b_yx_fsv32) {
expected = format::b_fs_yx_fsv32;
}
}
} }
return expected; return expected;
@ -1716,21 +1703,8 @@ bool layout_optimizer::all_users_simple_format_until_output(program_node& origin
if (cur_node.is_output()) return true; if (cur_node.is_output()) return true;
if (cur_depth > max_depth) return false; if (cur_depth > max_depth) return false;
auto is_rotating_except_batch = [](const std::vector<uint16_t>& order) {
// Target transform: Rotate feature dim to back to be taken as inner-most axis
// ex) 0(b), 4(f), 1(z), 2(y), 3(x)
// ex) 0(b), 3(f), 1(y), 2(x)
if ((int32_t) order[1] != order.size() - 1) return false;
if ((int32_t) order[0] != 0) return false;
for (int32_t i = 2; i < (int32_t) order.size(); ++i) {
if ((int32_t)order[i] != (i - 1)) return false;
}
return true;
};
if (cur_node.is_type<permute>()) { if (cur_node.is_type<permute>()) {
auto& permute_order = cur_node.as<permute>().get_primitive()->permute_order; if (!cur_node.as<permute>().is_rotating_except_batch())
if (!is_rotating_except_batch(permute_order))
return false; return false;
} }