[GPU] Integrate PReLU fusing (#8958)
* [GPU] Add new data format: os_is_yx_osa2_isa8_osv8_isv2
* Fuse PReLU as a oneDNN post-op: add onednn_post_op_type::binary_relu
* Update the onednn_gpu submodule
commit bf3046e3a0 (parent 7ac9a8f88e)
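For context: PReLU applies a learnable per-channel slope to the negative part of its input, and fusing it means the convolution primitive applies that slope in its epilogue instead of launching a separate kernel. A minimal standalone reference of the math the fused post-op computes (not from this patch; names are illustrative):

    #include <cstddef>
    #include <vector>

    // Reference PReLU over a CHW buffer: f(x) = x for x >= 0, slope[c] * x
    // otherwise. The fused oneDNN post-op computes the same thing on the
    // convolution output before it is written back to memory.
    void prelu_reference(float* data, const std::vector<float>& slope,
                         std::size_t channels, std::size_t spatial) {
        for (std::size_t c = 0; c < channels; ++c)
            for (std::size_t s = 0; s < spatial; ++s) {
                float& v = data[c * spatial + s];
                if (v < 0.0f)
                    v *= slope[c];
            }
    }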
@@ -158,6 +158,7 @@ struct format {
     os_is_zyx_osa4_isa8_osv8_isv2,  ///< format for weights for MMAD fsv32 convolution
     os_is_zyx_osa4_isa8_osv8_isv4,  ///< format for weights for MMAD fsv32 convolution
     os_is_yx_osa4_isa8_osv8_isv4,   ///< format for weights for MMAD fsv32 convolution
+    os_is_yx_osa2_isa8_osv8_isv2,
     os_is_yx_osa2_isa8_osv16_isv2,
     os_is_yx_osa2_isa8_osv16_isv4,
     is_o_yx_isv32,                  ///< format for weights for 1x1 MMAD convolutions
@@ -291,6 +292,7 @@ struct format {
     { os_is_yx_osa4_isa8_osv8_isv4,  { 1, 1, 2, 0, "oiyx",  "oixy",  {{0, 32}, {1, 32}}}},
     { os_is_zyx_osa4_isa8_osv8_isv2, { 1, 1, 3, 0, "oizyx", "oixyz", {{0, 32}, {1, 16}}}},
     { os_is_zyx_osa4_isa8_osv8_isv4, { 1, 1, 3, 0, "oizyx", "oixyz", {{0, 32}, {1, 32}}}},
+    { os_is_yx_osa2_isa8_osv8_isv2,  { 1, 1, 2, 0, "oiyx",  "oixy?", {{0, 16}, {1, 16}}}},
     { os_is_yx_osa2_isa8_osv16_isv2, { 1, 1, 2, 0, "oiyx",  "oixy",  {{0, 32}, {1, 16}}}},
     { os_is_yx_osa2_isa8_osv16_isv4, { 1, 1, 2, 0, "oiyx",  "oixy",  {{0, 32}, {1, 32}}}},
     { os_is_zyx_isa8_osv8_isv4,      { 1, 1, 3, 0, "oizyx", "oixyz", {{1, 8}, {0, 8}, {1, 4}}}},
@@ -90,6 +90,7 @@ WeightsTensor::WeightsChannelArray WeightsTensor::weightsChannelArray {{
     { WeightsLayout::os_is_zyx_osa4_isa8_osv8_isv4,   { 0, 1, 2, 3, 4, -1 } },
     { WeightsLayout::g_os_is_yx_osa4_isa8_osv8_isv2,  { 0, 1, -1, 2, 3, 4 } },
     { WeightsLayout::g_os_is_zyx_osa4_isa8_osv8_isv2, { 0, 1, 2, 3, 4, 5 } },
+    { WeightsLayout::os_is_yx_osa2_isa8_osv8_isv2,    { 0, 1, -1, 2, 3, -1 } },
     { WeightsLayout::os_is_yx_osa2_isa8_osv16_isv4,   { 0, 1, -1, 2, 3, -1 } },
     { WeightsLayout::g_os_is_yx_osa2_isa8_osv16_isv4, { 0, 1, -1, 2, 3, 4 } },
     { WeightsLayout::os_is_yx_osa2_isa8_osv16_isv2,   { 0, 1, -1, 2, 3, -1 } },
@@ -526,6 +527,10 @@ NDims WeightsTensor::GetSimpleDims(const std::vector<size_t>& d, WeightsLayout l
         newDims[0] = RoundUp(newDims[0], 32);
         newDims[3] = RoundUp(newDims[3], 32);
         break;
+    case os_is_yx_osa2_isa8_osv8_isv2:
+        newDims[2] = RoundUp(newDims[2], 16);
+        newDims[3] = RoundUp(newDims[3], 16);
+        break;
     case os_is_yx_osa4_isa8_osv8_isv4:
     case g_os_is_yx_osa4_isa8_osv8_isv4:
     case os_is_yx_osa2_isa8_osv16_isv4:
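The new GetSimpleDims case pads both channel dimensions to multiples of 16, since one osa2_isa8_osv8_isv2 tile covers osa*osv = 2*8 = 16 output and isa*isv = 8*2 = 16 input channels. A small sketch of the rounding, assuming kernel_selector's 2D weights dim order (x, y, i, o), so newDims[2]/newDims[3] are IFM/OFM, and re-typing RoundUp as the usual next-multiple helper:

    #include <cstddef>
    #include <cstdio>

    // Assumed equivalent of kernel_selector's RoundUp: next multiple of m.
    inline std::size_t RoundUp(std::size_t v, std::size_t m) {
        return (v + m - 1) / m * m;
    }

    int main() {
        // e.g. a conv with 20 output and 3 input channels: the weight
        // tensor is padded to 32 OFM x 16 IFM before reordering.
        std::printf("ofm 20 -> %zu, ifm 3 -> %zu\n",
                    RoundUp(20, 16), RoundUp(3, 16));
    }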
@@ -112,6 +112,7 @@ enum WeightsLayout {
     os_is_zyx_osa4_isa8_osv8_isv4,   // for MMAD convolution swizzled from ofm 0..7 to 0,4,8,12,16,20,24,28,
     g_os_is_yx_osa4_isa8_osv8_isv2,  // for MMAD convolution swizzled from ofm 0..7 to 0,4,8,12,16,20,24,28,
     g_os_is_zyx_osa4_isa8_osv8_isv2, // for MMAD convolution swizzled from ofm 0..7 to 0,4,8,12,16,20,24,28,
+    os_is_yx_osa2_isa8_osv8_isv2,
     os_is_yx_osa2_isa8_osv16_isv4,
     os_is_yx_osa2_isa8_osv16_isv2,
     g_os_is_yx_osa2_isa8_osv16_isv4,
@@ -662,6 +662,34 @@ inline uint get_g_os_is_yx_osa4_isa8_osv8_isv2(uint g, uint o, uint i, uint z, u
     return idx;
 }
 
+inline uint get_g_os_is_yx_osa2_isa8_osv8_isv2(uint g, uint o, uint i, uint z, uint y, uint x,
+                                               uint size_x, uint size_y, uint size_z, uint size_ifm, uint size_ofm, uint offset)
+{
+    const uint isv_idx = i % 2;
+    const uint isa_idx = (i / 2) % 8;
+    const uint is_idx = (i / 16);
+    const uint osv_idx = o % 8;
+    const uint osa_idx = (o / 8) % 2;
+    const uint os_idx = (o / 16);
+
+    const uint ifm_16_aligned = ((size_ifm + 15)/16);
+    const uint ofm_16_aligned = ((size_ofm + 15)/16);
+
+    size_t idx = offset +
+                 isv_idx +
+                 osv_idx * 2 +
+                 isa_idx * 8 * 2 +
+                 osa_idx * 8 * 16 +
+                 x * 16 * 16 +
+                 y * size_x * 16 * 16 +
+                 z * size_y * size_x * 16 * 16 +
+                 is_idx * 16 * 16 * size_x * size_y * size_z +
+                 os_idx * 16 * 16 * ifm_16_aligned * size_x * size_y * size_z +
+                 g * 16 * 16 * ifm_16_aligned * ofm_16_aligned * size_x * size_y * size_z;
+
+    return idx;
+}
+
 inline uint get_g_os_is_yx_osa2_isa8_osv16_isv4(uint g, uint o, uint i, uint y, uint x, uint size_x, uint size_y, uint size_ifm, uint size_ofm, uint offset)
 {
     const uint isv_idx = i % 4;
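Reading the layout name right to left gives the nesting this helper encodes: the innermost step interleaves 2 input channels (isv2) under 8 output channels (osv8), wrapped by blocks of 8 input (isa8) and 2 output (osa2) channels, i.e. one 16x16 channel tile of 256 contiguous weights per spatial position. A host-side re-typing with one worked lookup (a sketch; std::uint32_t stands in for the OpenCL uint, and the size_t accumulator in the original is narrowed here):

    #include <cstdint>
    #include <cstdio>

    static std::uint32_t index_osa2_isa8_osv8_isv2(
            std::uint32_t g, std::uint32_t o, std::uint32_t i,
            std::uint32_t z, std::uint32_t y, std::uint32_t x,
            std::uint32_t size_x, std::uint32_t size_y, std::uint32_t size_z,
            std::uint32_t size_ifm, std::uint32_t size_ofm, std::uint32_t offset) {
        const std::uint32_t isv = i % 2, isa = (i / 2) % 8, is = i / 16;
        const std::uint32_t osv = o % 8, osa = (o / 8) % 2, os = o / 16;
        const std::uint32_t ifm16 = (size_ifm + 15) / 16;
        const std::uint32_t ofm16 = (size_ofm + 15) / 16;
        return offset
             + isv + osv * 2 + isa * 16 + osa * 128          // within one 16x16 tile
             + (x + size_x * (y + size_y * z)) * 256         // one tile per spatial position
             + is * 256 * size_x * size_y * size_z
             + os * 256 * ifm16 * size_x * size_y * size_z
             + g  * 256 * ifm16 * ofm16 * size_x * size_y * size_z;
    }

    int main() {
        // o = 9, i = 3 in a 1x1 filter with 16x16 channels:
        // isv = 1, isa = 1, osv = 1, osa = 1  ->  1 + 2 + 16 + 128 = 147
        std::printf("%u\n",
                    index_osa2_isa8_osv8_isv2(0, 9, 3, 0, 0, 0, 1, 1, 1, 16, 16, 0));
    }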
@@ -794,6 +822,16 @@ inline uint get_g_os_is_yx_osa2_isa8_osv16_isv2(uint g, uint o, uint i, uint y,
         CAT(prefix, _OFM_NUM), \
         CAT(prefix, _OFFSET))
 
+#define GET_FILTER_OS_IS_YX_OSA2_ISA8_OSV8_ISV2_INDEX(prefix, o, i, y, x) \
+    get_g_os_is_yx_osa2_isa8_osv8_isv2( \
+        0, o, i, 0, y, x, \
+        CAT(prefix, _SIZE_X), \
+        CAT(prefix, _SIZE_Y), \
+        1, \
+        CAT(prefix, _IFM_NUM), \
+        CAT(prefix, _OFM_NUM), \
+        CAT(prefix, _OFFSET))
+
 #define GET_FILTER_OS_IS_YX_OSA2_ISA8_OSV16_ISV4_INDEX(prefix, o, i, y, x) \
     get_g_os_is_yx_osa2_isa8_osv16_isv4( \
         0, o, i, y, x, \
@@ -201,6 +201,8 @@ inline uint FUNC(get_output_index)(uint g, uint o, uint i, uint z, uint y, uint
         return GET_FILTER_OS_IS_ZYX_OSV32_ISV4_INDEX(OUTPUT, o, i, z, y, x);
 #elif defined OUTPUT_LAYOUT_OS_IS_YX_ISA8_OSV8_ISV4_SWIZZLED_BY_4
         return GET_FILTER_OS_IS_YX_ISA8_OSV8_ISV4_SWIZZLED_BY_4_INDEX(OUTPUT, g, o, i, y, x);
+#elif defined OUTPUT_LAYOUT_OS_IS_YX_OSA2_ISA8_OSV8_ISV2
+        return GET_FILTER_OS_IS_YX_OSA2_ISA8_OSV8_ISV2_INDEX(OUTPUT, o, i, y, x);
 #elif defined OUTPUT_LAYOUT_OS_IS_YX_OSA4_ISA8_OSV8_ISV2
         return GET_FILTER_OS_IS_YX_OSA4_ISA8_OSV8_ISV2_INDEX(OUTPUT, o, i, y, x);
 #elif defined OUTPUT_LAYOUT_OS_IS_ZYX_OSA4_ISA8_OSV8_ISV2
@@ -393,6 +393,7 @@ std::string toString(WeightsLayout layout) {
         case WeightsLayout::g_os_is_yx_osa2_isa8_osv16_isv4: return "G_OS_IS_YX_OSA2_ISA8_OSV16_ISV4";
         case WeightsLayout::os_is_yx_osa2_isa8_osv16_isv2:   return "OS_IS_YX_OSA2_ISA8_OSV16_ISV2";
         case WeightsLayout::g_os_is_yx_osa2_isa8_osv16_isv2: return "G_OS_IS_YX_OSA2_ISA8_OSV16_ISV2";
+        case WeightsLayout::os_is_yx_osa2_isa8_osv8_isv2:    return "OS_IS_YX_OSA2_ISA8_OSV8_ISV2";
         case WeightsLayout::g_os_is_yx_isv16_osv16:          return "G_OS_IS_YX_ISV16_OSV16";
         case WeightsLayout::g_os_is_yx_osv16_isv4:           return "G_OS_IS_YX_OSV16_ISV4";
         case WeightsLayout::g_os_is_zyx_osv16_isv16:         return "G_OS_IS_ZYX_OSV16_ISV16";
@@ -475,6 +475,9 @@ void prepare_primitive_fusing::fuse_simple_primitives(program &p) {
     };
 
     auto conv_supports_fusings = [&](convolution_node& node) -> bool {
+        if (_lo.get_optimization_attributes().use_onednn_impls == 1)
+            return true;
+
         // Since reorder inputs is called after this pass
         // we have to check that blocked formats can be used in the network and layer is optimized for it.
         if ((node.get_output_layout().format == format::b_fs_yx_fsv16 ||
@@ -692,10 +695,6 @@ void prepare_primitive_fusing::fuse_simple_primitives(program &p) {
         if (!input_data_supports_fusings(input_data, activation_node.id()))
             return;
 
-        if ((input_data.get_users().size() != 1 || activation_node.get_dependencies().size() != 1) &&
-            _lo.get_optimization_attributes().use_onednn_impls == 1)
-            return;
-
         bool should_fuse = input_data.is_type<binary_convolution>();
 
         should_fuse |= input_data.is_type<convolution>() && conv_supports_fusings(input_data.as<convolution>());
@@ -110,6 +110,14 @@ protected:
             break;
         }
 
+        case onednn_post_op_type::binary_relu:
+        {
+            auto binary_op_mem = instance.fused_memory(memory_offset);
+            args.insert({DNNL_ARG_ATTR_MULTIPLE_POST_OP(static_cast<int>(onednn_post_op_idx)) | DNNL_ARG_WEIGHTS,
+                         binary_op_mem->get_onednn_memory(_pd.dnnl::primitive_desc_base::weights_desc(0))});
+            break;
+        }
+
         case onednn_post_op_type::scale:
         {
             auto scale_op_mem = instance.fused_memory(memory_offset);
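This execute-time branch routes the slope tensor through oneDNN's unified argument map: DNNL_ARG_ATTR_MULTIPLE_POST_OP(n) selects post-op n's argument namespace, and DNNL_ARG_WEIGHTS names the prelu slopes within it. A minimal sketch of that binding (oneDNN 2.x API assumed; the helper name is illustrative):

    #include <oneapi/dnnl/dnnl.hpp>
    #include <unordered_map>

    // Bind the slope tensor of post-op number n for primitive execution.
    void bind_prelu_slopes(std::unordered_map<int, dnnl::memory>& args,
                           int n, const dnnl::memory& slopes) {
        args.insert({DNNL_ARG_ATTR_MULTIPLE_POST_OP(n) | DNNL_ARG_WEIGHTS, slopes});
    }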
@@ -259,6 +259,7 @@ cldnn::format convert_format(dnnl::memory::format_tag fmt, bool is_grouped) {
         case dnnl::memory::format_tag::Acdb16a:       return cldnn::format::os_yxi_osv16;
         case dnnl::memory::format_tag::ABcde16b16a:   return cldnn::format::os_is_zyx_isv16_osv16;
         case dnnl::memory::format_tag::aBcd16b:       return cldnn::format::o_is_yx_isv16;
+        case dnnl::memory::format_tag::ABcd2a8b8a2b:  return cldnn::format::os_is_yx_osa2_isa8_osv8_isv2;
         case dnnl::memory::format_tag::ABcd2a8b16a4b: return cldnn::format::os_is_yx_osa2_isa8_osv16_isv4;
         case dnnl::memory::format_tag::ABcd2a8b16a2b: return cldnn::format::os_is_yx_osa2_isa8_osv16_isv2;
         default: throw std::runtime_error(std::string("Unsupported onednn fmt ") + dnnl_fmt_tag2str((dnnl_format_tag_t)fmt));
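The oneDNN tag maps one-to-one onto the clDNN name: reading ABcd2a8b8a2b outer to inner gives plain O, I, y, x dims followed by blocks of 2 along O, 8 along I, 8 along O, and 2 along I, the same nesting os_is_yx_osa2_isa8_osv8_isv2 spells out. A sketch of a weights descriptor using the tag directly (oneDNN 2.x assumed; the shape is illustrative):

    #include <oneapi/dnnl/dnnl.hpp>

    // 32x32x3x3 int8 conv weights laid out as ABcd2a8b8a2b
    // (os_is_yx_osa2_isa8_osv8_isv2 on the clDNN side).
    dnnl::memory::desc example_weights_desc() {
        return dnnl::memory::desc({32, 32, 3, 3},
                                  dnnl::memory::data_type::s8,
                                  dnnl::memory::format_tag::ABcd2a8b8a2b);
    }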
@@ -43,6 +43,7 @@ enum class onednn_post_op_type : uint32_t {
     binary_add,
     binary_max,
     binary_min,
+    binary_relu,
     scale,
     sum,
     optimized,
@@ -363,6 +363,8 @@ kernel_selector::weights_layout to_weights_layout(format f, bool is_grouped) {
             return kernel_selector::weights_layout::os_is_zyx_osv16_isv16;
         case format::g_os_is_zyx_osv16_isv16:
             return kernel_selector::weights_layout::g_os_is_zyx_osv16_isv16;
+        case format::os_is_yx_osa2_isa8_osv8_isv2:
+            return kernel_selector::weights_layout::os_is_yx_osa2_isa8_osv8_isv2;
         case format::os_is_yx_osa2_isa8_osv16_isv4:
             return kernel_selector::weights_layout::os_is_yx_osa2_isa8_osv16_isv4;
         case format::g_os_is_yx_osa2_isa8_osv16_isv4:
@@ -452,6 +454,8 @@ cldnn::format::type from_weights_layout(kernel_selector::weights_layout l) {
             return cldnn::format::g_os_is_zyx_osa4_isa8_osv8_isv2;
         case kernel_selector::weights_layout::os_is_yx_osa4_isa8_osv8_isv4:
             return cldnn::format::os_is_yx_osa4_isa8_osv8_isv4;
+        case kernel_selector::weights_layout::os_is_yx_osa2_isa8_osv8_isv2:
+            return cldnn::format::os_is_yx_osa2_isa8_osv8_isv2;
         case kernel_selector::weights_layout::os_is_yx_osa2_isa8_osv16_isv2:
            return cldnn::format::os_is_yx_osa2_isa8_osv16_isv2;
        case kernel_selector::weights_layout::g_os_is_yx_osa2_isa8_osv16_isv2:
@@ -339,6 +339,14 @@ dnnl::post_ops program_node::try_optimize_post_ops(dnnl::post_ops& p_ops, const
                 break;
             }
 
+            case onednn_post_op_type::binary_relu:
+            {
+                int mask;
+                cur_p_ops.get_params_prelu(idx, mask);
+                new_p_ops.append_prelu(mask);
+                break;
+            }
+
             case onednn_post_op_type::scale:
             {
                 break;
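In this optimization pass a prelu entry is carried over by reading its broadcast mask from the current chain and re-appending it to the rebuilt one; only the mask lives in the post-op descriptor, while the slope memory is re-bound at execution. A sketch of that round-trip (oneDNN 2.x assumed; the helper name is illustrative):

    #include <oneapi/dnnl/dnnl.hpp>

    // Copy post-op `idx` from `cur` to `out` when it is a prelu entry.
    void carry_over_prelu(const dnnl::post_ops& cur, int idx, dnnl::post_ops& out) {
        if (cur.kind(idx) == dnnl::primitive::kind::prelu) {
            int mask = 0;
            cur.get_params_prelu(idx, mask);  // which dims the slopes vary along
            out.append_prelu(mask);
        }
    }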
@@ -713,6 +721,7 @@ void program_node::init_onednn_primitive_attributes() {
                                   type == onednn_post_op_type::binary_mul ||
                                   type == onednn_post_op_type::binary_max ||
                                   type == onednn_post_op_type::binary_min ||
+                                  type == onednn_post_op_type::binary_relu ||
                                   type == onednn_post_op_type::scale ||
                                   type == onednn_post_op_type::sum;
         if (has_memory_buffers)
@@ -723,10 +732,18 @@ void program_node::init_onednn_primitive_attributes() {
         auto node = cldnn_post_ops[idx].node;
 
         if (node->is_type<activation>()) {
-            auto fused_desc = node->as<activation>().get_primitive();
-            dnnl::algorithm alg = onednn::convert_activation_func(fused_desc->activation_function);
-            post_ops.append_eltwise(1.0f, alg, fused_desc->additional_params.a, fused_desc->additional_params.b);
-            update_onednn_post_op_list(onednn_post_op_type::eltwise_act, empty_mem);
+            auto& a_node = node->as<activation>();
+            if (!a_node.get_primitive()->additional_params_input.empty()) {
+                auto dep_idx = cldnn_post_ops[idx].dep_start_idx;
+                int oc_dim = node->get_output_layout().size.feature.size();
+                post_ops.append_prelu(1 << oc_dim);
+                update_onednn_post_op_list(onednn_post_op_type::binary_relu, dep_idx);
+            } else {
+                auto fused_desc = node->as<activation>().get_primitive();
+                dnnl::algorithm alg = onednn::convert_activation_func(fused_desc->activation_function);
+                post_ops.append_eltwise(1.0f, alg, fused_desc->additional_params.a, fused_desc->additional_params.b);
+                update_onednn_post_op_list(onednn_post_op_type::eltwise_act, empty_mem);
+            }
         } else if (node->is_type<eltwise>()) {
             auto& e_node = node->as<eltwise>();
             auto dep_idx = cldnn_post_ops[idx].dep_start_idx;
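Activations that carry a slope input (additional_params_input) now take the prelu path. The mask 1 << oc_dim, with oc_dim evaluating to 1 for single-feature layouts, declares that the slope tensor has a real extent only along the feature axis, so oneDNN broadcasts it over batch and spatial dims. A minimal attribute built the same way (a sketch, oneDNN 2.x assumed):

    #include <oneapi/dnnl/dnnl.hpp>

    // Convolution attributes with a per-channel prelu epilogue. For an
    // NCHW-like layout the channel axis is logical dim 1, so the
    // broadcast mask is 1 << 1.
    dnnl::primitive_attr make_prelu_attr() {
        dnnl::post_ops ops;
        ops.append_prelu(1 << 1);
        dnnl::primitive_attr attr;
        attr.set_post_ops(ops);
        return attr;
    }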
thirdparty/onednn_gpu (vendored submodule)
@@ -1 +1 @@
-Subproject commit 75d978369d0c5be04ec36c3cea2e00a14da1ec83
+Subproject commit b2cd3a8e50a715f9326a35f4c503bd11e60235a5