[GPU] Enable a network with onednn deconv (#9597)

This commit is contained in:
Mingyu Kim 2022-01-13 12:03:08 +09:00 committed by GitHub
parent c597bc8928
commit 3ae3c8dfb5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 54 additions and 1 deletions

View File

@ -163,6 +163,7 @@ struct format {
os_is_yx_osa2_isa8_osv16_isv2,
os_is_yx_osa2_isa8_osv16_isv4,
is_os_yx_isa2_osa8_isv8_osv2,
is_os_yx_isa4_osa8_isv8_osv4, ///< format for weights for MMAD fsv32 convolution
is_o_yx_isv32, ///< format for weights for 1x1 MMAD convolutions
is_o32_yx_isv32_swizzled_by_4, ///< format for weights for 1x1 MMAD convolutions
os_is_y_x8_osv8_isv4, ///< format for weights for 1x1 MMAD convolutions
@ -303,6 +304,7 @@ struct format {
{ os_is_yx_osa4_isa8_osv8_isv4_swizzled_by_4, { 1, 1, 2, 0, "oiyx", "oixy?", {{0, 32}, {1, 32}}}},
{ os_is_zyx_osa4_isa8_osv8_isv4_swizzled_by_4, { 1, 1, 3, 0, "oizyx", "oixyz", {{0, 32}, {1, 32}}}},
{ is_os_yx_isa2_osa8_isv8_osv2, { 1, 1, 2, 0, "ioyx", "ioxy?", {{1, 16}, {0, 16}}}},
{ is_os_yx_isa4_osa8_isv8_osv4, { 1, 1, 2, 0, "ioyx", "ioxy?", {{1, 32}, {0, 32}}}},
{ is_o_yx_isv32, { 1, 1, 2, 0, "oyxi", "oixy?", {{1, 32}}}},
{ is_o32_yx_isv32_swizzled_by_4, { 1, 1, 2, 0, "oyxi", "oixy?", {}}},
{ os_is_y_x8_osv8_isv4, { 1, 1, 2, 0, "oyxi", "oixy?", {}}},

View File

@ -9,6 +9,8 @@
#include <typeinfo>
#include <tuple>
#include <string>
#include <sstream>
#include "to_string_utils.h"
namespace cldnn {
@ -124,6 +126,19 @@ struct implementation_key<loop> {
type operator()(const layout&) { return -1; }
};
namespace {
/// Fallback used for key types that have no dedicated specialization below:
/// the key is ignored and an empty string is returned.
template <typename key_type>
std::string get_key_name(const key_type& /*unused_key*/) {
    return std::string("");
}

/// Specialization for int32_t keys: returns std::to_string of the value.
template <>
std::string get_key_name(const int32_t& int_key) {
    return std::to_string(int_key);
}

/// Specialization for (data_types, format::type) keys: joins the dt_to_str
/// rendering of element 0 and the fmt_to_str rendering of element 1 with "/".
template <>
std::string get_key_name(const std::tuple<data_types, format::type>& typed_key) {
    const auto& key_data_type = std::get<0>(typed_key);
    const auto& key_format = std::get<1>(typed_key);
    return dt_to_str(key_data_type) + "/" + fmt_to_str(key_format);
}
}  // namespace
template <typename primitive_kind>
class implementation_map {
public:
@ -147,8 +162,11 @@ public:
return factory;
}
}
std::stringstream target_impl_type_ss;
target_impl_type_ss << target_impl_type;
throw std::runtime_error(std::string("implementation_map for ") + typeid(primitive_kind).name() +
" could not find any implementation to match key");
"could not find any implementation to match key: " +
get_key_name(key) + ", impl_type: " + target_impl_type_ss.str() + ", node_id: " + primitive.id());
}
// check if for a given engine and type there exist an implementation

View File

@ -179,6 +179,11 @@ attach_deconvolution_onednn::attach_deconvolution_onednn() {
std::make_tuple(data_types::u8, format::b_fs_yx_fsv16),
std::make_tuple(data_types::i8, format::b_fs_yx_fsv16),
std::make_tuple(data_types::f32, format::b_fs_yx_fsv32),
std::make_tuple(data_types::f16, format::b_fs_yx_fsv32),
std::make_tuple(data_types::u8, format::b_fs_yx_fsv32),
std::make_tuple(data_types::i8, format::b_fs_yx_fsv32),
std::make_tuple(data_types::f32, format::bs_fs_yx_bsv16_fsv16),
std::make_tuple(data_types::f16, format::bs_fs_yx_bsv16_fsv16),
std::make_tuple(data_types::u8, format::bs_fs_yx_bsv16_fsv16),

View File

@ -277,6 +277,7 @@ static cldnn::format convert_format(dnnl::memory::format_tag fmt, bool is_groupe
case dnnl::memory::format_tag::ABcd2a8b8a2b: return cldnn::format::os_is_yx_osa2_isa8_osv8_isv2;
case dnnl::memory::format_tag::ABcd2a8b16a4b: return cldnn::format::os_is_yx_osa2_isa8_osv16_isv4;
case dnnl::memory::format_tag::ABcd2a8b16a2b: return cldnn::format::os_is_yx_osa2_isa8_osv16_isv2;
case dnnl::memory::format_tag::BAcd4b8a8b4a: return cldnn::format::is_os_yx_isa4_osa8_isv8_osv4;
default: throw std::runtime_error(std::string("Unsupported onednn fmt ") + dnnl_fmt_tag2str((dnnl_format_tag_t)fmt));
}
}

View File

@ -391,6 +391,8 @@ kernel_selector::weights_layout to_weights_layout(format f, bool is_grouped) {
return kernel_selector::weights_layout::g_os_zyx_is_osv32_isv32;
case format::is_os_yx_isa2_osa8_isv8_osv2:
return kernel_selector::weights_layout::is_os_yx_isa2_osa8_isv8_osv2;
case format::is_os_yx_isa4_osa8_isv8_osv4:
return kernel_selector::weights_layout::is_os_yx_isa4_osa8_isv8_osv4;
default:
throw std::invalid_argument("Unable to convert tensor layout " + fmt_to_str(f) + " to weights layout");
}
@ -510,6 +512,8 @@ cldnn::format::type from_weights_layout(kernel_selector::weights_layout l) {
return cldnn::format::is_os_yx_isv16_osv16;
case kernel_selector::weights_layout::is_os_yx_isa2_osa8_isv8_osv2:
return cldnn::format::is_os_yx_isa2_osa8_isv8_osv2;
case kernel_selector::weights_layout::is_os_yx_isa4_osa8_isv8_osv4:
return cldnn::format::is_os_yx_isa4_osa8_isv8_osv4;
case kernel_selector::weights_layout::os_is_yx_osv8_isv2:
return cldnn::format::os_is_yx_osv8_isv2;
case kernel_selector::weights_layout::os_is_yx_osv8_isv4:

View File

@ -1392,6 +1392,8 @@ impl_types layout_optimizer::get_preferred_impl_type(program_node& node, format
bool valid_params = valid_ic && valid_groups && onednn_valid_post_ops && valid_batch;
if (!valid_params)
impl_candidate = impl_types::ocl;
if (input_layout.data_type != deconv.get_output_layout().data_type)
impl_candidate = impl_types::ocl;
}
// [WA] oneDNN doesn't support > 32 post-ops. Remove once oneDNN improve post-ops for GPU.

View File

@ -118,6 +118,7 @@ WeightsTensor::WeightsChannelArray WeightsTensor::weightsChannelArray {{
{ WeightsLayout::is_os_zyx_isv16_osv16, { 0, 1, 2, 4, 3, -1 } },
{ WeightsLayout::is_os_yx_isv16_osv16, { 0, 1, -1, 3, 2, -1 } },
{ WeightsLayout::is_os_yx_isa2_osa8_isv8_osv2, { 0, 1, -1, 3, 2, -1 } },
{ WeightsLayout::is_os_yx_isa4_osa8_isv8_osv4, { 0, 1, -1, 3, 2, -1 } },
{ WeightsLayout::os_is_osv32_isv32_swizzled_by_4, { -1, -1, -1, 0, 1, -1 } },
{ WeightsLayout::os_is_zyx_isv8_osv16_isv2, { 0, 1, 2, 3, 4, -1 } },
{ WeightsLayout::os_is_yx_isv8_osv16_isv2, { 0, 1, -1, 2, 3, -1 } },
@ -543,6 +544,7 @@ NDims WeightsTensor::GetSimpleDims(const std::vector<size_t>& d, WeightsLayout l
case g_os_is_yx_osa4_isa8_osv8_isv4:
case os_is_yx_osa2_isa8_osv16_isv4:
case g_os_is_yx_osa2_isa8_osv16_isv4:
case is_os_yx_isa4_osa8_isv8_osv4:
newDims[3] = RoundUp(newDims[3], 32);
newDims[2] = RoundUp(newDims[2], 32);
break;

View File

@ -106,6 +106,7 @@ enum WeightsLayout {
os_is_yx_isa8_osv16_isv4, // for fully connected MMAD
os_is_zyx_isa8_osv16_isv4, // for fully connected MMAD
os_is_yx_osa4_isa8_osv8_isv4, // for MMAD convolution swizzled from ofm 0..7 to 0,4,8,12,16,20,24,28,
is_os_yx_isa4_osa8_isv8_osv4, // for onednn deconvolution
g_os_is_yx_osa4_isa8_osv8_isv4, // for MMAD convolution swizzled from ofm 0..7 to 0,4,8,12,16,20,24,28,
g_os_is_zyx_osa4_isa8_osv8_isv4, // for MMAD convolution swizzled from ofm 0..7 to 0,4,8,12,16,20,24,28,
os_is_yx_osa4_isa8_osv8_isv2, // for MMAD convolution swizzled from ofm 0..7 to 0,4,8,12,16,20,24,28,

View File

@ -748,6 +748,12 @@ inline uint get_g_is_os_yx_isa2_osa8_isv8_osv2(uint g, uint o, uint i, uint z, u
return get_g_os_is_yx_osa2_isa8_osv8_isv2(g, i, o, z, y, x, size_x, size_y, size_z, size_ofm, size_ifm, offset);
}
// Index helper for the is_os_yx_isa4_osa8_isv8_osv4 layout.
// Forwards to get_g_os_is_yx_osa4_isa8_osv8_isv4 with the `o`/`i` indices exchanged and
// with `size_ofm`/`size_ifm` passed in the opposite order to this function's own parameter
// list; every other argument is passed through unchanged.
inline uint get_g_is_os_yx_isa4_osa8_isv8_osv4(uint g, uint o, uint i, uint z, uint y, uint x,
                                               uint size_x, uint size_y, uint size_z,
                                               uint size_ifm, uint size_ofm, uint offset)
{
    const uint swapped_index = get_g_os_is_yx_osa4_isa8_osv8_isv4(
        g, i, o, z, y, x,
        size_x, size_y, size_z,
        size_ofm, size_ifm,
        offset);
    return swapped_index;
}
#define GET_FILTER_OS_IS_YX_OSA4_ISA8_OSV8_ISV4_INDEX(prefix, o, i, y, x) \
get_g_os_is_yx_osa4_isa8_osv8_isv4( \
0, o, i, 0, y, x, \
@ -911,6 +917,15 @@ inline uint get_g_is_os_yx_isa2_osa8_isv8_osv2(uint g, uint o, uint i, uint z, u
CAT(prefix, _OFM_NUM), \
CAT(prefix, _OFFSET))
// Expands to a get_g_is_os_yx_isa4_osa8_isv8_osv4(...) call for the (o, i, y, x) coordinate.
// The g and z arguments are fixed to 0 and the size_z argument to 1; the remaining size and
// offset arguments are the CAT(prefix, _SIZE_X / _SIZE_Y / _IFM_NUM / _OFM_NUM / _OFFSET)
// tokens (presumably per-tensor constants named after `prefix` -- confirm against CAT).
#define GET_FILTER_IS_OS_YX_ISA4_OSA8_ISV8_OSV4_INDEX(prefix, o, i, y, x) \
get_g_is_os_yx_isa4_osa8_isv8_osv4( \
0, o, i, 0, y, x, \
CAT(prefix, _SIZE_X), \
CAT(prefix, _SIZE_Y), \
1, \
CAT(prefix, _IFM_NUM), \
CAT(prefix, _OFM_NUM), \
CAT(prefix, _OFFSET))
inline uint get_is_o_yx_isv32_index(uint o, uint i, uint y, uint x, uint i_size, uint o_size, uint x_size, uint y_size)
{

View File

@ -205,6 +205,8 @@ inline uint FUNC(get_output_index)(uint g, uint o, uint i, uint z, uint y, uint
return GET_FILTER_OS_IS_YX_OSA2_ISA8_OSV8_ISV2_INDEX(OUTPUT, o, i, y, x);
#elif defined OUTPUT_LAYOUT_IS_OS_YX_ISA2_OSA8_ISV8_OSV2
return GET_FILTER_IS_OS_YX_ISA2_OSA8_ISV8_OSV2_INDEX(OUTPUT, o, i, y, x);
#elif defined OUTPUT_LAYOUT_IS_OS_YX_ISA4_OSA8_ISV8_OSV4
return GET_FILTER_IS_OS_YX_ISA4_OSA8_ISV8_OSV4_INDEX(OUTPUT, o, i, y, x);
#elif defined OUTPUT_LAYOUT_OS_IS_YX_OSA4_ISA8_OSV8_ISV2
return GET_FILTER_OS_IS_YX_OSA4_ISA8_OSV8_ISV2_INDEX(OUTPUT, o, i, y, x);
#elif defined OUTPUT_LAYOUT_OS_IS_ZYX_OSA4_ISA8_OSV8_ISV2

View File

@ -396,6 +396,7 @@ std::string toString(WeightsLayout layout) {
case WeightsLayout::g_os_is_yx_osa2_isa8_osv16_isv2: return "G_OS_IS_YX_OSA2_ISA8_OSV16_ISV2";
case WeightsLayout::os_is_yx_osa2_isa8_osv8_isv2: return "OS_IS_YX_OSA2_ISA8_OSV8_ISV2";
case WeightsLayout::is_os_yx_isa2_osa8_isv8_osv2: return "IS_OS_YX_ISA2_OSA8_ISV8_OSV2";
case WeightsLayout::is_os_yx_isa4_osa8_isv8_osv4: return "IS_OS_YX_ISA4_OSA8_ISV8_OSV4";
case WeightsLayout::g_os_is_yx_isv16_osv16: return "G_OS_IS_YX_ISV16_OSV16";
case WeightsLayout::g_os_is_yx_osv16_isv4: return "G_OS_IS_YX_OSV16_ISV4";
case WeightsLayout::g_os_is_zyx_osv16_isv16: return "G_OS_IS_ZYX_OSV16_ISV16";