[GPU] Gemm onednn implementation (#9984)

* [GPU] Gemm onednn implementation

* [GPU] Added implementation choice logic
This commit is contained in:
Ilya Znamenskiy
2022-02-07 11:48:42 +03:00
committed by GitHub
parent 9f9df184c4
commit ac28063b19
5 changed files with 748 additions and 196 deletions

View File

@@ -15,9 +15,143 @@
namespace cldnn {
namespace onednn {
// oneDNN-backed implementation of the cldnn gemm primitive, expressed as a
// dnnl::matmul. Handles optional bias (third input), input transposition via
// format-tag swapping, and collapsing of trivial batch dimensions.
struct gemm_onednn : typed_primitive_onednn_impl<gemm, dnnl::matmul::desc> {
    using parent = typed_primitive_onednn_impl<gemm, dnnl::matmul::desc>;
    using parent::parent;

protected:
    std::unique_ptr<primitive_impl> clone() const override {
        return make_unique<gemm_onednn>(*this);
    }

    // Extends the base argument map with the matmul weights (input 1) and,
    // when a third input is present, the bias (input 2). Memory objects are
    // reinterpreted through the primitive descriptor's expected weights descs.
    std::unordered_map<int, dnnl::memory> get_arguments(gemm_inst& instance) const override {
        std::unordered_map<int, dnnl::memory> args = parent::get_arguments(instance);

        {
            auto& weights = instance.input_memory(1);
            args.insert({DNNL_ARG_WEIGHTS, weights.get_onednn_memory(_pd.weights_desc(0))});
        }

        if (instance.inputs_memory_count() == 3) {
            auto& bias = instance.input_memory(2);
            args.insert({DNNL_ARG_BIAS, bias.get_onednn_memory(_pd.weights_desc(1))});
        }

        return args;
    }

    // Returns the plain format tag with the last two (matrix) dimensions
    // swapped; used to express transpose_input0/transpose_input1 without
    // physically reordering data.
    static dnnl::memory::format_tag transpose_format(dnnl::memory::format_tag fmt) {
        switch (fmt) {
            case dnnl::memory::format_tag::ab: return dnnl::memory::format_tag::ba;
            case dnnl::memory::format_tag::abc: return dnnl::memory::format_tag::acb;
            case dnnl::memory::format_tag::abcd: return dnnl::memory::format_tag::abdc;
            default: throw std::runtime_error("Unsupported fmt in transpose_format gemm function");
        }
    }

    // Builds the dnnl::matmul operation descriptor from the node's input,
    // optional bias and output layouts. Batch dimensions are dropped entirely
    // when every tensor involved has a trivial (== 1) batched size.
    static std::shared_ptr<dnnl::matmul::desc> get_gemm_descriptor(const gemm_node& arg) {
        auto prim = arg.get_primitive();
        auto gemm_with_bias = prim->dependencies().size() == 3;

        auto& input0 = arg.get_dependency(0);
        auto& input1 = arg.get_dependency(1);

        auto in0_l = input0.get_output_layout();
        auto in1_l = input1.get_output_layout();
        auto out_l = arg.get_output_layout();

        // "batched size" = product of all dims except the trailing two spatials
        // (the matrix rows/cols in cldnn's layout convention).
        size_t in0_batched_size = in0_l.count() / (in0_l.size.spatial[0] * in0_l.size.spatial[1]);
        size_t in1_batched_size = in1_l.count() / (in1_l.size.spatial[0] * in1_l.size.spatial[1]);
        size_t out_batched_size = out_l.count() / (out_l.size.spatial[0] * out_l.size.spatial[1]);
        auto batched_dims_can_be_removed = in0_batched_size == 1 && in1_batched_size == 1 && out_batched_size == 1;

        if (gemm_with_bias) {
            auto bias_l = arg.get_dependency(2).get_output_layout();
            size_t bias_batched_size = bias_l.count() / (bias_l.size.spatial[0] * bias_l.size.spatial[1]);
            batched_dims_can_be_removed &= bias_batched_size == 1;
        }

        size_t rank = cldnn::format::dimension(out_l.format);

        dnnl::memory::data_type in0_dt = onednn::convert_data_type(in0_l.data_type);
        dnnl::memory::data_type in1_dt = onednn::convert_data_type(in1_l.data_type);
        dnnl::memory::data_type out_dt = onednn::convert_data_type(out_l.data_type);

        dnnl::memory::dims in0_dims = onednn::convert_gemm_tensor(in0_l.size, rank, batched_dims_can_be_removed);
        dnnl::memory::dims in1_dims = onednn::convert_gemm_tensor(in1_l.size, rank, batched_dims_can_be_removed);
        dnnl::memory::dims out_dims = onednn::convert_gemm_tensor(out_l.size, rank, batched_dims_can_be_removed);

        dnnl::memory::format_tag in0_fmt = onednn::convert_gemm_data_format(in0_dims);
        dnnl::memory::format_tag in1_fmt = onednn::convert_gemm_data_format(in1_dims);
        dnnl::memory::format_tag out_fmt = onednn::convert_gemm_data_format(out_dims);

        // Transposition is encoded by swapping both the logical dims and the
        // format tag; no data movement happens here.
        if (prim->transpose_input0) {
            in0_fmt = transpose_format(in0_fmt);
            std::swap(in0_dims[in0_dims.size() - 1], in0_dims[in0_dims.size() - 2]);
        }
        if (prim->transpose_input1) {
            in1_fmt = transpose_format(in1_fmt);
            std::swap(in1_dims[in1_dims.size() - 1], in1_dims[in1_dims.size() - 2]);
        }

        dnnl::memory::desc in0_md(in0_dims, in0_dt, in0_fmt);
        dnnl::memory::desc in1_md(in1_dims, in1_dt, in1_fmt);
        dnnl::memory::desc out_md(out_dims, out_dt, out_fmt);

        if (gemm_with_bias) {
            auto bias_l = arg.get_dependency(2).get_output_layout();
            auto bias_rank = cldnn::format::dimension(bias_l.format);
            dnnl::memory::data_type bias_dt = onednn::convert_data_type(bias_l.data_type);
            dnnl::memory::dims bias_dims = onednn::convert_gemm_tensor(bias_l.size, bias_rank, batched_dims_can_be_removed);
            dnnl::memory::format_tag bias_fmt = onednn::convert_gemm_data_format(bias_dims);
            dnnl::memory::desc bias_md(bias_dims, bias_dt, bias_fmt);

            return std::make_shared<dnnl::matmul::desc>(
                in0_md,
                in1_md,
                bias_md,
                out_md);
        } else {
            return std::make_shared<dnnl::matmul::desc>(
                in0_md,
                in1_md,
                out_md);
        }
    }

public:
    static primitive_impl* create(const gemm_node& arg) {
        auto& engine = arg.get_program().get_engine();
        auto desc = get_gemm_descriptor(arg);
        auto attr = arg.get_onednn_primitive_attributes();
        dnnl::primitive_desc prim_desc{&desc->data, attr.get(), engine.get_onednn_engine(), nullptr};

        return new gemm_onednn(arg, desc, attr, prim_desc);
    }
};
namespace detail {
// Registers the oneDNN gemm implementation for every supported
// (data type, format) combination: f32/f16/u8/i8 over the plain
// 4D (bfyx), 5D (bfzyx) and 6D (bfwzyx) layouts.
attach_gemm_onednn::attach_gemm_onednn() {
implementation_map<gemm>::add(impl_types::onednn, gemm_onednn::create, {
// 4D plain layout
std::make_tuple(data_types::f32, format::bfyx),
std::make_tuple(data_types::f16, format::bfyx),
std::make_tuple(data_types::u8, format::bfyx),
std::make_tuple(data_types::i8, format::bfyx),
// 5D plain layout
std::make_tuple(data_types::f32, format::bfzyx),
std::make_tuple(data_types::f16, format::bfzyx),
std::make_tuple(data_types::u8, format::bfzyx),
std::make_tuple(data_types::i8, format::bfzyx),
// 6D plain layout
std::make_tuple(data_types::f32, format::bfwzyx),
std::make_tuple(data_types::f16, format::bfwzyx),
std::make_tuple(data_types::u8, format::bfwzyx),
std::make_tuple(data_types::i8, format::bfwzyx),
});
}
} // namespace detail

View File

@@ -47,6 +47,28 @@ dnnl::memory::dims convert_tensor(cldnn::tensor t, size_t dims, bool is_grouped)
return res;
}
// Converts a cldnn tensor to oneDNN matmul dims of rank <= 3.
// Leading (batch-like) dimensions beyond the third are folded together, and
// the remaining batch dimension is dropped when the caller proved it trivial.
dnnl::memory::dims convert_gemm_tensor(cldnn::tensor t, size_t dims, bool batched_dims_can_be_removed) {
    const auto extents = t.sizes(default_fmt_for_dims(dims, false));
    dnnl::memory::dims res(extents.begin(), extents.end());

    // Collapse extra leading dimensions one at a time: multiply each into its
    // successor, then drop it, until at most three dimensions remain.
    while (res.size() > 3) {
        res[1] *= res[0];
        res.erase(res.begin());
    }

    if (batched_dims_can_be_removed && res.size() == 3) {
        res.erase(res.begin());
    }
    return res;
}
// Picks the plain (row-major) oneDNN format tag matching a gemm dims rank:
// abc for a 3D (batched) matmul, ab for a plain 2D one. Higher ranks must
// have been collapsed beforehand (see convert_gemm_tensor).
dnnl::memory::format_tag convert_gemm_data_format(dnnl::memory::dims dims) {
    const auto rank = dims.size();
    if (rank > 3)
        throw std::runtime_error("[clDNN] Unsupported dims size for onednn gemm: should be <= 3");
    if (rank == 3)
        return dnnl::memory::format_tag::abc;
    return dnnl::memory::format_tag::ab;
}
dnnl::memory::dims convert_spatials(cldnn::tensor t, size_t dims) {
auto spatials = t.spatial;
dnnl::memory::dims res(dims);
@@ -75,7 +97,7 @@ dnnl::memory::data_type convert_data_type(cldnn::data_types dt) {
case cldnn::data_types::i8: return dnnl::memory::data_type::s8;
case cldnn::data_types::u8: return dnnl::memory::data_type::u8;
case cldnn::data_types::i32: return dnnl::memory::data_type::s32;
default: throw std::invalid_argument("[clDNN] Unsupported conversion from cldnn to ondnn type");
default: throw std::invalid_argument("[clDNN] Unsupported conversion from cldnn to onednn type");
}
}
@@ -95,7 +117,7 @@ dnnl::memory::format_tag convert_data_format(cldnn::format fmt) {
case cldnn::format::bs_fs_yx_bsv4_fsv2: return dnnl::memory::format_tag::ABcd4a2b;
case cldnn::format::bs_fs_yx_bsv32_fsv16: return dnnl::memory::format_tag::NChw32n16c;
case cldnn::format::bs_fs_zyx_bsv16_fsv16: return dnnl::memory::format_tag::NCdhw16n16c;
default: throw std::invalid_argument("[clDNN] Unsupported conversion from cldnn to ondnn layout " + fmt_to_str(fmt));
default: throw std::invalid_argument("[clDNN] Unsupported conversion from cldnn to onednn layout " + fmt_to_str(fmt));
}
}

View File

@@ -23,10 +23,12 @@ void combine_bf_with_first_spatial_dim(cldnn::layout& l);
// cldnn -> onednn
dnnl::memory::dims convert_tensor(cldnn::tensor t, size_t dims = 2, bool is_grouped = false);
dnnl::memory::dims convert_gemm_tensor(cldnn::tensor t, size_t dims, bool batched_dims_can_be_removed);
dnnl::memory::dims convert_spatials(cldnn::tensor t, size_t dims = 2);
dnnl::memory::dims flatten_tensor(cldnn::tensor t);
dnnl::memory::data_type convert_data_type(cldnn::data_types dt);
dnnl::memory::format_tag convert_data_format(cldnn::format fmt);
dnnl::memory::format_tag convert_gemm_data_format(dnnl::memory::dims dims);
dnnl::memory::desc layout_to_memory_desc(cldnn::layout l, dnnl::memory::format_tag target_fmt = dnnl::memory::format_tag::undef, bool flatten = false);
dnnl::algorithm convert_activation_func(cldnn::activation_func func);
cldnn::format find_format(dnnl::memory::desc desc, bool is_grouped = false);

View File

@@ -13,6 +13,7 @@
#include "generic_layer.hpp"
#include <sstream>
#include "gemm_inst.h"
#include "eltwise_inst.h"
#include "pooling_inst.h"
#include "one_hot_inst.h"
@@ -1212,9 +1213,10 @@ bool layout_optimizer::are_data_types_suitable_for_onednn(program_node& node) {
if ((in_dt == data_types::i8 || in_dt == data_types::u8) && wei_dt == data_types::i8 &&
(out_dt == data_types::f32 || out_dt == data_types::i32 || out_dt == data_types::f16 || out_dt == data_types::i8 || out_dt == data_types::u8))
return true;
} else if (node.is_type<fully_connected>()) {
auto& fc_node = node.as<fully_connected>();
auto wei_dt = fc_node.weights().get_output_layout().data_type;
} else if (node.is_type<fully_connected>() || node.is_type<gemm>()) {
bool is_fc = node.is_type<fully_connected>();
auto wei_dt = is_fc ? node.as<fully_connected>().weights().get_output_layout().data_type :
node.as<gemm>().get_dependency(1).get_output_layout().data_type;
if ((in_dt == data_types::f16 && wei_dt == data_types::f16) &&
(out_dt == data_types::f16 || out_dt == data_types::f32 || out_dt == data_types::i8))
@@ -1486,7 +1488,8 @@ impl_types layout_optimizer::get_preferred_impl_type(program_node& node, format
break;
}
}
} else if (node.is_type<fully_connected>()) {
// TODO: uncomment this code when onednn gemm implementations will have real perf improvements vs cldnn
} else if (node.is_type<fully_connected>()/* || node.is_type<gemm>()*/) {
if (!_optimization_attributes.use_onednn_impls)
return impl_types::ocl;
@@ -1498,29 +1501,75 @@ impl_types layout_optimizer::get_preferred_impl_type(program_node& node, format
for (auto& fo : node.get_fused_primitives()) {
if (fo.node->is_type<eltwise>()) {
auto in_layout = node.get_dependency(fo.dep_start_idx).get_output_layout();
auto out_layout = node.get_output_layout();
auto in_dt = in_layout.data_type;
auto out_dt = out_layout.data_type;
if ((out_layout.count() == in_layout.count()) &&
(data_type_traits::is_floating_point(in_dt) || data_type_traits::is_floating_point(out_dt)) && in_dt != out_dt &&
fo.node->as<eltwise>().get_primitive()->needs_onednn_sum_post_op(in_layout)) {
impl_candidate = impl_types::ocl;
break;
// FC checks
if (node.is_type<fully_connected>()) {
auto in_layout = node.get_dependency(fo.dep_start_idx).get_output_layout();
auto out_layout = node.get_output_layout();
auto in_dt = in_layout.data_type;
auto out_dt = out_layout.data_type;
if ((out_layout.count() == in_layout.count()) &&
(data_type_traits::is_floating_point(in_dt) || data_type_traits::is_floating_point(out_dt)) && in_dt != out_dt &&
fo.node->as<eltwise>().get_primitive()->needs_onednn_sum_post_op(in_layout)) {
impl_candidate = impl_types::ocl;
break;
}
// Gemm checks
// TODO: investigate why onednn gemm currently has some "sum" post-op restrictions
// which don't correlate with the FC checks in the code above
// Temporary WA: disable onednn gemm with a sum post-op inside
} else {
auto& e_node = fo.node->as<eltwise>();
if (e_node.get_primitive()->mode == eltwise_mode::sum) {
impl_candidate = impl_types::ocl;
break;
}
}
}
}
// OneDnn doesn't support spatial dimensions for output
auto fc_prim = node.as<fully_connected>().get_primitive();
auto out_layout = node.get_output_layout();
size_t rank = cldnn::format::dimension(out_layout.format);
auto size = out_layout.size;
for (int i = 0; i < rank - 2 - (fc_prim->input_size == 3 ? 1 : 0); i++) {
if (size.spatial[i] != 1) {
impl_candidate = impl_types::ocl;
break;
if (node.is_type<fully_connected>()) {
auto fc_prim = node.as<fully_connected>().get_primitive();
auto out_layout = node.get_output_layout();
size_t rank = cldnn::format::dimension(out_layout.format);
auto size = out_layout.size;
// OneDnn doesn't support spatial dimensions for output
for (int i = 0; i < rank - 2 - (fc_prim->input_size == 3 ? 1 : 0); i++) {
if (size.spatial[i] != 1) {
impl_candidate = impl_types::ocl;
break;
}
}
} else {
impl_candidate = impl_types::ocl;
auto gemm_prim = node.as<gemm>().get_primitive();
auto in0_l = node.get_dependency(0).get_output_layout();
auto in1_l = node.get_dependency(1).get_output_layout();
auto out_l = node.get_output_layout();
auto has_input2 = gemm_prim->dependencies().size() == 3;
size_t in2_batched_size;
if (has_input2) {
auto in2_l = node.get_dependency(2).get_output_layout();
in2_batched_size = in2_l.count() / (in2_l.size.spatial[0] * in2_l.size.spatial[1]);
}
size_t size_k = gemm_prim->transpose_input0 ? in0_l.size.spatial[1] : in0_l.size.spatial[0];
size_t in0_batched_size = in0_l.count() / (in0_l.size.spatial[0] * in0_l.size.spatial[1]);
size_t in1_batched_size = in1_l.count() / (in1_l.size.spatial[0] * in1_l.size.spatial[1]);
size_t out_batched_size = out_l.count() / (out_l.size.spatial[0] * out_l.size.spatial[1]);
auto valid_input_batch = in0_batched_size != 1 && (in1_batched_size == in0_batched_size || in1_batched_size == 1);
auto valid_output_batch = in0_batched_size > in1_batched_size ? out_batched_size == in0_batched_size :
out_batched_size == in1_batched_size;
auto valid_extra_input_batch = has_input2 ? in2_batched_size == 1 || in2_batched_size == out_batched_size : true;
auto valid_scale_factor = gemm_prim->alpha == 1.f && (has_input2 ? gemm_prim->beta == 1.f : true);
auto unsupported_onednn_gemm = !valid_input_batch ||
!valid_output_batch ||
!valid_extra_input_batch ||
!valid_scale_factor;
// Gemm with k < 64 is calculated via ref kernel in onednn so cldnn way is more preferable for such cases
if (size_k < 64 || unsupported_onednn_gemm)
impl_candidate = impl_types::ocl;
}
preferred_impl = impl_candidate;

File diff suppressed because one or more lines are too long