[GPU] Adjust preferred format of resample operation (#9919)
* Adjust preferred format of resample operation
* Applied review comment
* Do not fix the resample layout when there is a permute user, unless the permute order is rotating
This commit is contained in:
parent
f9b88c385c
commit
54678f47cf
@ -278,6 +278,33 @@ struct format {
|
||||
fmt == bfyx || fmt == fyxb ||
|
||||
fmt == bfzyx || fmt == bfwzyx);
|
||||
}
|
||||
|
||||
/// @brief Returns the default (plain) format for a tensor of the given rank.
/// @param rank       Tensor rank (number of dimensions).
/// @param is_weights True when the tensor holds weights (o/i channel dims).
/// @param is_grouped True when the weights carry a leading group (g) dim.
/// @return The rank-appropriate plain format; bfyx when no rank matches.
static format get_default_format(size_t rank, bool is_weights, bool is_grouped) {
    if (!is_weights) {
        // Activation data: one extra spatial dim per rank above 4D.
        if (rank == 5)
            return cldnn::format::bfzyx;
        if (rank == 6)
            return cldnn::format::bfwzyx;
        return cldnn::format::bfyx;
    }
    if (is_grouped) {
        // Grouped weights: leading group dimension before o/i channels.
        if (rank == 5)
            return cldnn::format::goiyx;
        if (rank == 6)
            return cldnn::format::goizyx;
        return cldnn::format::bfyx;
    }
    // Ungrouped weights: output/input channel dims first.
    if (rank == 4)
        return cldnn::format::oiyx;
    if (rank == 5)
        return cldnn::format::oizyx;
    return cldnn::format::bfyx;
}
|
||||
|
||||
/// @brief Checks if @p format is of grouped type
|
||||
static bool is_grouped(type fmt) { return group_num(fmt) != 0; }
|
||||
/// @brief Checks if @p format is of image type
|
||||
|
@ -180,6 +180,7 @@ public:
|
||||
explicit layout_optimizer(bool output_size_handling_enabled = true);
|
||||
|
||||
format get_preferred_format(program_node& node);
|
||||
bool all_users_simple_format_until_output(program_node& origin_node, program_node& cur_node, int32_t cur_depth, int32_t max_depth);
|
||||
impl_types get_preferred_impl_type(program_node& node, format preferred_format);
|
||||
|
||||
bool are_data_types_suitable_for_onednn(program_node& node);
|
||||
|
@ -6,7 +6,6 @@
|
||||
#include "primitive_inst.h"
|
||||
#include "program_helpers.h"
|
||||
#include "intel_gpu/runtime/error_handler.hpp"
|
||||
|
||||
#include "data_inst.h"
|
||||
#include "reorder_inst.h"
|
||||
#include "resample_inst.h"
|
||||
@ -1650,11 +1649,56 @@ format layout_optimizer::get_preferred_format(program_node& node) {
|
||||
if (input_layout.format.dimension() == 5 &&
|
||||
(input_layout.data_type == data_types::f32 || input_layout.data_type == data_types::f16))
|
||||
expected = format::bfzyx;
|
||||
} else if (node.is_type<resample>()) {
|
||||
// if the resample is in the last part of the network and there are no users using blocked format,
|
||||
// it is better to reorder to bfyx before resample is done.
|
||||
if (all_users_simple_format_until_output(node, node, 0, 10)) {
|
||||
const auto& dim = format::dimension(node.get_output_layout().format);
|
||||
expected = format::get_default_format(dim, false, false);
|
||||
} else {
|
||||
expected = format::any;
|
||||
}
|
||||
}
|
||||
|
||||
return expected;
|
||||
}
|
||||
|
||||
// Returns true when every user chain of origin_node, followed up to max_depth
// levels deep, prefers a simple (non-blocked) data format all the way to a
// network output. Used to decide whether reordering to a plain format before
// a node (e.g. resample) is profitable.
//
// @param origin_node node the query originates from (its own type is excluded
//                    from the preferred-format check to avoid self-recursion)
// @param cur_node    node currently being examined
// @param cur_depth   current recursion depth (0 at the origin)
// @param max_depth   hard cap on recursion depth; deeper chains fail the check
bool layout_optimizer::all_users_simple_format_until_output(program_node& origin_node, program_node& cur_node, int32_t cur_depth, int32_t max_depth) {
    if (cur_node.is_output()) return true;
    if (cur_depth > max_depth) return false;

    auto is_rotating_except_batch = [](const std::vector<uint16_t>& order) {
        // Target transform: rotate the feature dim to the back so it becomes
        // the inner-most axis.
        // ex) 0(b), 4(f), 1(z), 2(y), 3(x)
        // ex) 0(b), 3(f), 1(y), 2(x)
        // Guard: the pattern needs at least batch and feature axes; the
        // original indexed order[1] unconditionally (OOB for tiny orders).
        if (order.size() < 2) return false;
        if (order[0] != 0) return false;
        // Compare in size_t to avoid the signed/unsigned mismatch of the
        // original (int32_t)order[1] != order.size() - 1.
        if (static_cast<size_t>(order[1]) != order.size() - 1) return false;
        for (size_t i = 2; i < order.size(); ++i) {
            if (order[i] != i - 1) return false;
        }
        return true;
    };

    if (cur_node.is_type<permute>()) {
        // Only a feature-rotating permute keeps the chain "simple"; any other
        // permute order disqualifies it.
        auto& permute_order = cur_node.as<permute>().get_primitive()->permute_order;
        if (!is_rotating_except_batch(permute_order))
            return false;
    }

    if (cur_node.is_in_data_flow() && (cur_node.type() != origin_node.type())) {
        const auto& fmt = get_preferred_format(cur_node);
        if (fmt != format::any && !format::is_simple_data_format(fmt)) {
            return false;
        }
    }

    // Short-circuit on the first failing user instead of accumulating with
    // `res &=` as before — same result, but avoids exploring remaining
    // subtrees once the answer is already false.
    for (const auto& usr : cur_node.get_users()) {
        if (!all_users_simple_format_until_output(origin_node, *usr, cur_depth + 1, max_depth))
            return false;
    }
    return true;
}
|
||||
|
||||
void layout_optimizer::set_optimization_attribute(optimization_attributes_type attribute, int32_t val) {
|
||||
switch (attribute) {
|
||||
case optimization_attributes_type::splitted_convolution:
|
||||
|
@ -86,39 +86,13 @@ tensor::value_type layout::ifm() const {
|
||||
return dims[dim_idx];
|
||||
}
|
||||
|
||||
// Maps a tensor rank to its default (plain) format, distinguishing activation
// data, ungrouped weights, and grouped weights. Falls back to bfyx when the
// rank has no specific mapping.
static format get_default_format(size_t rank, bool is_weights, bool is_grouped) {
    auto result = cldnn::format::bfyx;
    if (!is_weights) {
        // Activation data: extra spatial dims beyond the 4D default.
        if (rank == 5)
            result = cldnn::format::bfzyx;
        else if (rank == 6)
            result = cldnn::format::bfwzyx;
    } else if (is_grouped) {
        // Grouped weights carry a leading group (g) dimension.
        if (rank == 5)
            result = cldnn::format::goiyx;
        else if (rank == 6)
            result = cldnn::format::goizyx;
    } else {
        // Ungrouped weights: output/input channel dims first.
        if (rank == 4)
            result = cldnn::format::oiyx;
        else if (rank == 5)
            result = cldnn::format::oizyx;
    }
    return result;
}
|
||||
std::vector<tensor::value_type> layout::get_dims() const {
|
||||
auto default_fmt = get_default_format(format.dimension(), format::is_weights_format(format), format::is_grouped(format));
|
||||
auto default_fmt = format::get_default_format(format.dimension(), format::is_weights_format(format), format::is_grouped(format));
|
||||
return size.sizes(default_fmt);
|
||||
}
|
||||
|
||||
std::vector<tensor::value_type> layout::get_padded_dims() const {
|
||||
auto default_fmt = get_default_format(format.dimension(), format::is_weights_format(format), format::is_grouped(format));
|
||||
auto default_fmt = format::get_default_format(format.dimension(), format::is_weights_format(format), format::is_grouped(format));
|
||||
auto padded_size = size.add(data_padding.lower_size()).add(data_padding.upper_size());
|
||||
return padded_size.sizes(default_fmt);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user