[GPU] Integrate PReLU fusing (#8958)

* [GPU] Add new data format: os_is_yx_osa2_isa8_osv8_isv2

* Fuse PReLU via oneDNN post-ops

- add onednn_post_op_type::binary_relu

* Update the onednn_gpu submodule
Sungeun Kim 2021-12-14 20:09:50 +09:00 committed by GitHub
parent 7ac9a8f88e
commit bf3046e3a0
13 changed files with 88 additions and 9 deletions
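
For reviewers unfamiliar with the op being fused: PReLU multiplies negative activations by a learnable per-channel slope, so fusing it lets the convolution primitive apply the slope to its own output instead of launching a separate activation kernel. A minimal reference of the math (illustrative C++ only, not code from this patch):

#include <cstddef>
#include <vector>

// prelu(x) = x             for x >= 0
//          = slope[c] * x  for x <  0, with one learnable slope per channel c.
// Plain reference over a CHW-flattened buffer; names are illustrative.
void prelu_reference(std::vector<float>& data, const std::vector<float>& slope,
                     std::size_t channels, std::size_t spatial) {
    for (std::size_t c = 0; c < channels; ++c)
        for (std::size_t s = 0; s < spatial; ++s) {
            float& v = data[c * spatial + s];
            if (v < 0.0f)
                v *= slope[c];
        }
}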


@@ -158,6 +158,7 @@ struct format {
os_is_zyx_osa4_isa8_osv8_isv2, ///< format for weights for MMAD fsv32 convolution
os_is_zyx_osa4_isa8_osv8_isv4, ///< format for weights for MMAD fsv32 convolution
os_is_yx_osa4_isa8_osv8_isv4, ///< format for weights for MMAD fsv32 convolution
os_is_yx_osa2_isa8_osv8_isv2,
os_is_yx_osa2_isa8_osv16_isv2,
os_is_yx_osa2_isa8_osv16_isv4,
is_o_yx_isv32, ///< format for weights for 1x1 MMAD convolutions
@@ -291,6 +292,7 @@ struct format {
{ os_is_yx_osa4_isa8_osv8_isv4, { 1, 1, 2, 0, "oiyx", "oixy", {{0, 32}, {1, 32}}}},
{ os_is_zyx_osa4_isa8_osv8_isv2, { 1, 1, 3, 0, "oizyx", "oixyz", {{0, 32}, {1, 16}}}},
{ os_is_zyx_osa4_isa8_osv8_isv4, { 1, 1, 3, 0, "oizyx", "oixyz", {{0, 32}, {1, 32}}}},
{ os_is_yx_osa2_isa8_osv8_isv2, { 1, 1, 2, 0, "oiyx", "oixy", {{0, 16}, {1, 16}}}},
{ os_is_yx_osa2_isa8_osv16_isv2, { 1, 1, 2, 0, "oiyx", "oixy", {{0, 32}, {1, 16}}}},
{ os_is_yx_osa2_isa8_osv16_isv4, { 1, 1, 2, 0, "oiyx", "oixy", {{0, 32}, {1, 32}}}},
{ os_is_zyx_isa8_osv8_isv4, { 1, 1, 3, 0, "oizyx", "oixyz", {{1, 8}, {0, 8}, {1, 4}}}},
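
The traits row added above says the new layout blocks both weight axes by 16: osa2 * osv8 = 16 output channels and isa8 * isv2 = 16 input channels per block, which is why the block-size list is {{0, 16}, {1, 16}} and why GetSimpleDims below rounds both dims up to 16. A small illustrative check (plain C++, helper names are not from the patch):

#include <cstdio>

// Mirrors the RoundUp(newDims[i], 16) padding added for os_is_yx_osa2_isa8_osv8_isv2.
static unsigned round_up(unsigned v, unsigned block) {
    return (v + block - 1) / block * block;
}

int main() {
    const unsigned osa = 2, osv = 8;  // outer/inner OFM blocks -> 16 OFM per tile
    const unsigned isa = 8, isv = 2;  // outer/inner IFM blocks -> 16 IFM per tile
    const unsigned ofm = 20, ifm = 3; // hypothetical weight dims
    std::printf("tile = %u x %u, padded OFM = %u, padded IFM = %u\n",
                osa * osv, isa * isv, round_up(ofm, osa * osv), round_up(ifm, isa * isv));
    return 0;  // prints: tile = 16 x 16, padded OFM = 32, padded IFM = 16
}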


@@ -90,6 +90,7 @@ WeightsTensor::WeightsChannelArray WeightsTensor::weightsChannelArray {{
{ WeightsLayout::os_is_zyx_osa4_isa8_osv8_isv4, { 0, 1, 2, 3, 4, -1 } },
{ WeightsLayout::g_os_is_yx_osa4_isa8_osv8_isv2, { 0, 1, -1, 2, 3, 4 } },
{ WeightsLayout::g_os_is_zyx_osa4_isa8_osv8_isv2, { 0, 1, 2, 3, 4, 5 } },
{ WeightsLayout::os_is_yx_osa2_isa8_osv8_isv2, { 0, 1, -1, 2, 3, -1 } },
{ WeightsLayout::os_is_yx_osa2_isa8_osv16_isv4, { 0, 1, -1, 2, 3, -1 } },
{ WeightsLayout::g_os_is_yx_osa2_isa8_osv16_isv4, { 0, 1, -1, 2, 3, 4 } },
{ WeightsLayout::os_is_yx_osa2_isa8_osv16_isv2, { 0, 1, -1, 2, 3, -1 } },
@@ -526,6 +527,10 @@ NDims WeightsTensor::GetSimpleDims(const std::vector<size_t>& d, WeightsLayout l
newDims[0] = RoundUp(newDims[0], 32);
newDims[3] = RoundUp(newDims[3], 32);
break;
case os_is_yx_osa2_isa8_osv8_isv2:
newDims[2] = RoundUp(newDims[2], 16);
newDims[3] = RoundUp(newDims[3], 16);
break;
case os_is_yx_osa4_isa8_osv8_isv4:
case g_os_is_yx_osa4_isa8_osv8_isv4:
case os_is_yx_osa2_isa8_osv16_isv4:


@@ -112,6 +112,7 @@ enum WeightsLayout {
os_is_zyx_osa4_isa8_osv8_isv4, // for MMAD convolution swizzled from ofm 0..7 to 0,4,8,12,16,20,24,28,
g_os_is_yx_osa4_isa8_osv8_isv2, // for MMAD convolution swizzled from ofm 0..7 to 0,4,8,12,16,20,24,28,
g_os_is_zyx_osa4_isa8_osv8_isv2, // for MMAD convolution swizzled from ofm 0..7 to 0,4,8,12,16,20,24,28,
os_is_yx_osa2_isa8_osv8_isv2,
os_is_yx_osa2_isa8_osv16_isv4,
os_is_yx_osa2_isa8_osv16_isv2,
g_os_is_yx_osa2_isa8_osv16_isv4,


@@ -662,6 +662,34 @@ inline uint get_g_os_is_yx_osa4_isa8_osv8_isv2(uint g, uint o, uint i, uint z, u
return idx;
}
inline uint get_g_os_is_yx_osa2_isa8_osv8_isv2(uint g, uint o, uint i, uint z, uint y, uint x,
uint size_x, uint size_y, uint size_z, uint size_ifm, uint size_ofm, uint offset)
{
const uint isv_idx = i % 2;
const uint isa_idx = (i / 2) % 8;
const uint is_idx = (i / 16);
const uint osv_idx = o % 8;
const uint osa_idx = (o / 8) % 2;
const uint os_idx = (o / 16);
const uint ifm_16_aligned = ((size_ifm + 15)/16);
const uint ofm_16_aligned = ((size_ofm + 15)/16);
size_t idx = offset +
isv_idx +
osv_idx * 2 +
isa_idx * 8 * 2 +
osa_idx * 8 * 16 +
x * 16 * 16 +
y * size_x * 16 * 16 +
z * size_y * size_x * 16 * 16 +
is_idx * 16 * 16 * size_x * size_y * size_z +
os_idx * 16 * 16 * ifm_16_aligned * size_x * size_y * size_z +
g * 16 * 16 * ifm_16_aligned * ofm_16_aligned * size_x * size_y * size_z;
return idx;
}
inline uint get_g_os_is_yx_osa2_isa8_osv16_isv4(uint g, uint o, uint i, uint y, uint x, uint size_x, uint size_y, uint size_ifm, uint size_ofm, uint offset)
{
const uint isv_idx = i % 4;
@@ -794,6 +822,16 @@ inline uint get_g_os_is_yx_osa2_isa8_osv16_isv2(uint g, uint o, uint i, uint y,
CAT(prefix, _OFM_NUM), \
CAT(prefix, _OFFSET))
#define GET_FILTER_OS_IS_YX_OSA2_ISA8_OSV8_ISV2_INDEX(prefix, o, i, y, x) \
get_g_os_is_yx_osa2_isa8_osv8_isv2( \
0, o, i, 0, y, x, \
CAT(prefix, _SIZE_X), \
CAT(prefix, _SIZE_Y), \
1, \
CAT(prefix, _IFM_NUM), \
CAT(prefix, _OFM_NUM), \
CAT(prefix, _OFFSET))
#define GET_FILTER_OS_IS_YX_OSA2_ISA8_OSV16_ISV4_INDEX(prefix, o, i, y, x) \
get_g_os_is_yx_osa2_isa8_osv16_isv4( \
0, o, i, y, x, \
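
To make the new indexing helper easier to check: inside each 16x16 (OFM x IFM) tile the strides are isv -> 1, osv -> 2, isa -> 16, osa -> 128, and the spatial and outer block dims then step by whole 256-element tiles. A host-side transcription with one worked value (illustrative C++; the offset, grouped, and z terms are folded away since the non-grouped macro passes g = 0 and size_z = 1):

#include <cstdio>

// Transcription of get_g_os_is_yx_osa2_isa8_osv8_isv2 for g = 0, z = 0, size_z = 1.
unsigned weight_index(unsigned o, unsigned i, unsigned y, unsigned x,
                      unsigned size_x, unsigned size_y, unsigned size_ifm) {
    const unsigned isv = i % 2, isa = (i / 2) % 8, is = i / 16;  // isa8 * isv2 = 16 IFM
    const unsigned osv = o % 8, osa = (o / 8) % 2, os = o / 16;  // osa2 * osv8 = 16 OFM
    const unsigned ifm_16_aligned = (size_ifm + 15) / 16;
    return isv + osv * 2 + isa * 16 + osa * 128         // position inside one 16x16 tile
         + (x + y * size_x) * 256                       // spatial dims step by whole tiles
         + is * 256 * size_x * size_y                   // outer IFM blocks
         + os * 256 * ifm_16_aligned * size_x * size_y; // outer OFM blocks
}

int main() {
    // o = 9, i = 3 at (y, x) = (0, 0): osv = 1, osa = 1, isv = 1, isa = 1
    std::printf("%u\n", weight_index(9, 3, 0, 0, 3, 3, 16));  // 1 + 2 + 16 + 128 = 147
    return 0;
}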


@@ -201,6 +201,8 @@ inline uint FUNC(get_output_index)(uint g, uint o, uint i, uint z, uint y, uint
return GET_FILTER_OS_IS_ZYX_OSV32_ISV4_INDEX(OUTPUT, o, i, z, y, x);
#elif defined OUTPUT_LAYOUT_OS_IS_YX_ISA8_OSV8_ISV4_SWIZZLED_BY_4
return GET_FILTER_OS_IS_YX_ISA8_OSV8_ISV4_SWIZZLED_BY_4_INDEX(OUTPUT, g, o, i, y, x);
#elif defined OUTPUT_LAYOUT_OS_IS_YX_OSA2_ISA8_OSV8_ISV2
return GET_FILTER_OS_IS_YX_OSA2_ISA8_OSV8_ISV2_INDEX(OUTPUT, o, i, y, x);
#elif defined OUTPUT_LAYOUT_OS_IS_YX_OSA4_ISA8_OSV8_ISV2
return GET_FILTER_OS_IS_YX_OSA4_ISA8_OSV8_ISV2_INDEX(OUTPUT, o, i, y, x);
#elif defined OUTPUT_LAYOUT_OS_IS_ZYX_OSA4_ISA8_OSV8_ISV2


@@ -393,6 +393,7 @@ std::string toString(WeightsLayout layout) {
case WeightsLayout::g_os_is_yx_osa2_isa8_osv16_isv4: return "G_OS_IS_YX_OSA2_ISA8_OSV16_ISV4";
case WeightsLayout::os_is_yx_osa2_isa8_osv16_isv2: return "OS_IS_YX_OSA2_ISA8_OSV16_ISV2";
case WeightsLayout::g_os_is_yx_osa2_isa8_osv16_isv2: return "G_OS_IS_YX_OSA2_ISA8_OSV16_ISV2";
case WeightsLayout::os_is_yx_osa2_isa8_osv8_isv2: return "OS_IS_YX_OSA2_ISA8_OSV8_ISV2";
case WeightsLayout::g_os_is_yx_isv16_osv16: return "G_OS_IS_YX_ISV16_OSV16";
case WeightsLayout::g_os_is_yx_osv16_isv4: return "G_OS_IS_YX_OSV16_ISV4";
case WeightsLayout::g_os_is_zyx_osv16_isv16: return "G_OS_IS_ZYX_OSV16_ISV16";


@@ -475,6 +475,9 @@ void prepare_primitive_fusing::fuse_simple_primitives(program &p) {
};
auto conv_supports_fusings = [&](convolution_node& node) -> bool {
if (_lo.get_optimization_attributes().use_onednn_impls == 1)
return true;
// Since reorder inputs is called after this pass,
// we have to check that blocked formats can be used in the network and that the layer is optimized for them.
if ((node.get_output_layout().format == format::b_fs_yx_fsv16 ||
@@ -692,10 +695,6 @@ void prepare_primitive_fusing::fuse_simple_primitives(program &p) {
if (!input_data_supports_fusings(input_data, activation_node.id()))
return;
if ((input_data.get_users().size() != 1 || activation_node.get_dependencies().size() != 1) &&
_lo.get_optimization_attributes().use_onednn_impls == 1)
return;
bool should_fuse = input_data.is_type<binary_convolution>();
should_fuse |= input_data.is_type<convolution>() && conv_supports_fusings(input_data.as<convolution>());


@@ -110,6 +110,14 @@ protected:
break;
}
case onednn_post_op_type::binary_relu:
{
auto binary_op_mem = instance.fused_memory(memory_offset);
args.insert({DNNL_ARG_ATTR_MULTIPLE_POST_OP(static_cast<int>(onednn_post_op_idx)) | DNNL_ARG_WEIGHTS,
binary_op_mem->get_onednn_memory(_pd.dnnl::primitive_desc_base::weights_desc(0))});
break;
}
case onednn_post_op_type::scale:
{
auto scale_op_mem = instance.fused_memory(memory_offset);
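
For context on the new binary_relu branch: oneDNN gives every post-op its own argument namespace, so the PReLU slope tensor is bound at execute() time under DNNL_ARG_ATTR_MULTIPLE_POST_OP(n) | DNNL_ARG_WEIGHTS, which is exactly the key built above. A self-contained sketch of that mechanism (assumes the oneDNN 2.x C++ API this plugin links against; not code from the patch):

#include <cstdio>
#include <dnnl.hpp>

int main() {
    dnnl::post_ops ops;
    ops.append_prelu(1 << 1);  // one slope per value of dim 1 (output channels)

    dnnl::primitive_attr attr;
    attr.set_post_ops(ops);    // attr would be passed when creating the conv primitive_desc

    // At execute() time the slope memory goes into the args map under this key,
    // mirroring the binary_relu case above that uses instance.fused_memory().
    const int slope_arg = DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_WEIGHTS;
    std::printf("slope argument key for post-op 0: %d\n", slope_arg);
    return 0;
}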


@@ -259,6 +259,7 @@ cldnn::format convert_format(dnnl::memory::format_tag fmt, bool is_grouped) {
case dnnl::memory::format_tag::Acdb16a: return cldnn::format::os_yxi_osv16;
case dnnl::memory::format_tag::ABcde16b16a: return cldnn::format::os_is_zyx_isv16_osv16;
case dnnl::memory::format_tag::aBcd16b: return cldnn::format::o_is_yx_isv16;
case dnnl::memory::format_tag::ABcd2a8b8a2b: return cldnn::format::os_is_yx_osa2_isa8_osv8_isv2;
case dnnl::memory::format_tag::ABcd2a8b16a4b: return cldnn::format::os_is_yx_osa2_isa8_osv16_isv4;
case dnnl::memory::format_tag::ABcd2a8b16a2b: return cldnn::format::os_is_yx_osa2_isa8_osv16_isv2;
default: throw std::runtime_error(std::string("Unsupported onednn fmt ") + dnnl_fmt_tag2str((dnnl_format_tag_t)fmt));
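
The added row maps oneDNN's tag to the clDNN format introduced earlier; reading the tag's blocks from outermost to innermost (2a, 8b, 8a, 2b, where a = OFM and b = IFM) gives exactly osa2, isa8, osv8, isv2. A tiny check using the same stringifier as the error path above (dnnl_fmt_tag2str is the oneDNN C helper already used here; the snippet itself is illustrative):

#include <cstdio>
#include <dnnl.hpp>

int main() {
    // ABcd2a8b8a2b: dims A = OFM, B = IFM, c = Y, d = X, blocked outer-to-inner
    // as 2a (osa2), 8b (isa8), 8a (osv8), 2b (isv2).
    auto tag = dnnl::memory::format_tag::ABcd2a8b8a2b;
    std::printf("%s\n", dnnl_fmt_tag2str(static_cast<dnnl_format_tag_t>(tag)));
    return 0;
}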


@@ -43,6 +43,7 @@ enum class onednn_post_op_type : uint32_t {
binary_add,
binary_max,
binary_min,
binary_relu,
scale,
sum,
optimized,


@@ -363,6 +363,8 @@ kernel_selector::weights_layout to_weights_layout(format f, bool is_grouped) {
return kernel_selector::weights_layout::os_is_zyx_osv16_isv16;
case format::g_os_is_zyx_osv16_isv16:
return kernel_selector::weights_layout::g_os_is_zyx_osv16_isv16;
case format::os_is_yx_osa2_isa8_osv8_isv2:
return kernel_selector::weights_layout::os_is_yx_osa2_isa8_osv8_isv2;
case format::os_is_yx_osa2_isa8_osv16_isv4:
return kernel_selector::weights_layout::os_is_yx_osa2_isa8_osv16_isv4;
case format::g_os_is_yx_osa2_isa8_osv16_isv4:
@@ -452,6 +454,8 @@ cldnn::format::type from_weights_layout(kernel_selector::weights_layout l) {
return cldnn::format::g_os_is_zyx_osa4_isa8_osv8_isv2;
case kernel_selector::weights_layout::os_is_yx_osa4_isa8_osv8_isv4:
return cldnn::format::os_is_yx_osa4_isa8_osv8_isv4;
case kernel_selector::weights_layout::os_is_yx_osa2_isa8_osv8_isv2:
return cldnn::format::os_is_yx_osa2_isa8_osv8_isv2;
case kernel_selector::weights_layout::os_is_yx_osa2_isa8_osv16_isv2:
return cldnn::format::os_is_yx_osa2_isa8_osv16_isv2;
case kernel_selector::weights_layout::g_os_is_yx_osa2_isa8_osv16_isv2:


@@ -339,6 +339,14 @@ dnnl::post_ops program_node::try_optimize_post_ops(dnnl::post_ops& p_ops, const
break;
}
case onednn_post_op_type::binary_relu:
{
int mask;
cur_p_ops.get_params_prelu(idx, mask);
new_p_ops.append_prelu(mask);
break;
}
case onednn_post_op_type::scale:
{
break;
@@ -713,6 +721,7 @@ void program_node::init_onednn_primitive_attributes() {
type == onednn_post_op_type::binary_mul ||
type == onednn_post_op_type::binary_max ||
type == onednn_post_op_type::binary_min ||
type == onednn_post_op_type::binary_relu ||
type == onednn_post_op_type::scale ||
type == onednn_post_op_type::sum;
if (has_memory_buffers)
@@ -723,10 +732,18 @@ void program_node::init_onednn_primitive_attributes() {
auto node = cldnn_post_ops[idx].node;
if (node->is_type<activation>()) {
auto fused_desc = node->as<activation>().get_primitive();
dnnl::algorithm alg = onednn::convert_activation_func(fused_desc->activation_function);
post_ops.append_eltwise(1.0f, alg, fused_desc->additional_params.a, fused_desc->additional_params.b);
update_onednn_post_op_list(onednn_post_op_type::eltwise_act, empty_mem);
auto& a_node = node->as<activation>();
if (!a_node.get_primitive()->additional_params_input.empty()) {
auto dep_idx = cldnn_post_ops[idx].dep_start_idx;
int oc_dim = node->get_output_layout().size.feature.size();
post_ops.append_prelu(1 << oc_dim);
update_onednn_post_op_list(onednn_post_op_type::binary_relu, dep_idx);
} else {
auto fused_desc = node->as<activation>().get_primitive();
dnnl::algorithm alg = onednn::convert_activation_func(fused_desc->activation_function);
post_ops.append_eltwise(1.0f, alg, fused_desc->additional_params.a, fused_desc->additional_params.b);
update_onednn_post_op_list(onednn_post_op_type::eltwise_act, empty_mem);
}
} else if (node->is_type<eltwise>()) {
auto& e_node = node->as<eltwise>();
auto dep_idx = cldnn_post_ops[idx].dep_start_idx;
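
The branch above is the core of the change: an activation node that carries a slope input (additional_params_input non-empty) is lowered to a prelu post-op with a per-feature mask, while everything else stays on the existing eltwise path. A minimal sketch of the two resulting attribute setups (assumes the oneDNN 2.x append_eltwise signature used above; values and names are illustrative):

#include <dnnl.hpp>

// Mirror of the if/else above: PReLU (slope tensor present) vs. plain eltwise.
dnnl::primitive_attr make_fused_activation_attr(bool has_slope_input) {
    dnnl::post_ops ops;
    if (has_slope_input) {
        // 1 << 1: one slope per value along dim 1 (features) of the conv output,
        // matching post_ops.append_prelu(1 << oc_dim) with oc_dim typically 1.
        ops.append_prelu(1 << 1);
    } else {
        // oneDNN 2.x signature: (scale, algorithm, alpha, beta).
        ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_relu, 0.0f, 0.0f);
    }
    dnnl::primitive_attr attr;
    attr.set_post_ops(ops);
    return attr;
}

The slope tensor itself is registered via update_onednn_post_op_list(onednn_post_op_type::binary_relu, dep_idx) so that set_arguments() can bind it later, as shown in the earlier binary_relu hunk.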

@@ -1 +1 @@
Subproject commit 75d978369d0c5be04ec36c3cea2e00a14da1ec83
Subproject commit b2cd3a8e50a715f9326a35f4c503bd11e60235a5