[GPU] Gemm onednn implementation (#9984)
* [GPU] Gemm onednn implementation * [GPU] Added implementation choice logic
This commit is contained in:
@@ -15,9 +15,143 @@
|
||||
namespace cldnn {
|
||||
namespace onednn {
|
||||
|
||||
// oneDNN-based implementation of the gemm (matrix multiplication) primitive.
// Maps cldnn gemm inputs onto a dnnl::matmul descriptor:
//   input 0 -> SRC, input 1 -> WEIGHTS, optional input 2 -> BIAS.
struct gemm_onednn : typed_primitive_onednn_impl<gemm, dnnl::matmul::desc> {
    using parent = typed_primitive_onednn_impl<gemm, dnnl::matmul::desc>;
    using parent::parent;

protected:
    std::unique_ptr<primitive_impl> clone() const override {
        return make_unique<gemm_onednn>(*this);
    }

    // Builds the oneDNN execution-argument map: the base arguments from the
    // parent plus the matmul-specific WEIGHTS (input 1) and, when a third
    // input is present, BIAS (input 2) memories.
    std::unordered_map<int, dnnl::memory> get_arguments(gemm_inst& instance) const override {
        std::unordered_map<int, dnnl::memory> args = parent::get_arguments(instance);

        {
            auto& weights = instance.input_memory(1);
            args.insert({DNNL_ARG_WEIGHTS, weights.get_onednn_memory(_pd.weights_desc(0))});
        }

        if (instance.inputs_memory_count() == 3) {
            // The third input is interpreted as the matmul bias.
            auto& bias = instance.input_memory(2);
            args.insert({DNNL_ARG_BIAS, bias.get_onednn_memory(_pd.weights_desc(1))});
        }

        return args;
    }

    // Returns the format tag with the two innermost (matrix) dimensions
    // swapped; this expresses a transposed gemm input via strides/layout
    // without physically reordering the data.
    static dnnl::memory::format_tag transpose_format(dnnl::memory::format_tag fmt) {
        switch (fmt) {
            case dnnl::memory::format_tag::ab: return dnnl::memory::format_tag::ba;
            case dnnl::memory::format_tag::abc: return dnnl::memory::format_tag::acb;
            case dnnl::memory::format_tag::abcd: return dnnl::memory::format_tag::abdc;
            default: throw std::runtime_error("Unsupported fmt in transpose_format gemm function");
        }
    }

    // Translates the cldnn gemm node into a dnnl::matmul descriptor.
    // Leading (batch) dimensions are collapsed into a single dim — and dropped
    // entirely when every operand's batch size is 1 — and the transpose flags
    // are expressed as format-tag swaps plus a swap of the two innermost dims.
    static std::shared_ptr<dnnl::matmul::desc> get_gemm_descriptor(const gemm_node& arg) {
        auto prim = arg.get_primitive();
        auto gemm_with_bias = prim->dependencies().size() == 3;

        auto& input0 = arg.get_dependency(0);
        auto& input1 = arg.get_dependency(1);

        auto in0_l = input0.get_output_layout();
        auto in1_l = input1.get_output_layout();
        auto out_l = arg.get_output_layout();

        // Element count outside the two innermost (spatial) dims, i.e. the
        // combined batch size of each operand.
        size_t in0_batched_size = in0_l.count() / (in0_l.size.spatial[0] * in0_l.size.spatial[1]);
        size_t in1_batched_size = in1_l.count() / (in1_l.size.spatial[0] * in1_l.size.spatial[1]);
        size_t out_batched_size = out_l.count() / (out_l.size.spatial[0] * out_l.size.spatial[1]);

        auto batched_dims_can_be_removed = in0_batched_size == 1 && in1_batched_size == 1 && out_batched_size == 1;
        if (gemm_with_bias) {
            auto bias_l = arg.get_dependency(2).get_output_layout();
            size_t bias_batched_size = bias_l.count() / (bias_l.size.spatial[0] * bias_l.size.spatial[1]);
            batched_dims_can_be_removed &= bias_batched_size == 1;
        }

        size_t rank = cldnn::format::dimension(out_l.format);

        dnnl::memory::data_type in0_dt = onednn::convert_data_type(in0_l.data_type);
        dnnl::memory::data_type in1_dt = onednn::convert_data_type(in1_l.data_type);
        dnnl::memory::data_type out_dt = onednn::convert_data_type(out_l.data_type);

        dnnl::memory::dims in0_dims = onednn::convert_gemm_tensor(in0_l.size, rank, batched_dims_can_be_removed);
        dnnl::memory::dims in1_dims = onednn::convert_gemm_tensor(in1_l.size, rank, batched_dims_can_be_removed);
        dnnl::memory::dims out_dims = onednn::convert_gemm_tensor(out_l.size, rank, batched_dims_can_be_removed);

        dnnl::memory::format_tag in0_fmt = onednn::convert_gemm_data_format(in0_dims);
        dnnl::memory::format_tag in1_fmt = onednn::convert_gemm_data_format(in1_dims);
        dnnl::memory::format_tag out_fmt = onednn::convert_gemm_data_format(out_dims);

        // Transposition: swap both the format tag (stride order) and the
        // logical sizes of the two innermost dims so the descriptor stays
        // consistent.
        if (prim->transpose_input0) {
            in0_fmt = transpose_format(in0_fmt);
            std::swap(in0_dims[in0_dims.size() - 1], in0_dims[in0_dims.size() - 2]);
        }

        if (prim->transpose_input1) {
            in1_fmt = transpose_format(in1_fmt);
            std::swap(in1_dims[in1_dims.size() - 1], in1_dims[in1_dims.size() - 2]);
        }

        dnnl::memory::desc in0_md(in0_dims, in0_dt, in0_fmt);
        dnnl::memory::desc in1_md(in1_dims, in1_dt, in1_fmt);
        dnnl::memory::desc out_md(out_dims, out_dt, out_fmt);

        if (gemm_with_bias) {
            auto bias_l = arg.get_dependency(2).get_output_layout();
            auto bias_rank = cldnn::format::dimension(bias_l.format);
            dnnl::memory::data_type bias_dt = onednn::convert_data_type(bias_l.data_type);
            dnnl::memory::dims bias_dims = onednn::convert_gemm_tensor(bias_l.size, bias_rank, batched_dims_can_be_removed);
            dnnl::memory::format_tag bias_fmt = onednn::convert_gemm_data_format(bias_dims);
            dnnl::memory::desc bias_md(bias_dims, bias_dt, bias_fmt);

            return std::make_shared<dnnl::matmul::desc>(
                in0_md,
                in1_md,
                bias_md,
                out_md);
        } else {
            return std::make_shared<dnnl::matmul::desc>(
                in0_md,
                in1_md,
                out_md);
        }
    }

public:
    // Factory entry point registered in the implementation map below.
    static primitive_impl* create(const gemm_node& arg) {
        auto& engine = arg.get_program().get_engine();
        auto desc = get_gemm_descriptor(arg);
        auto attr = arg.get_onednn_primitive_attributes();
        dnnl::primitive_desc prim_desc{&desc->data, attr.get(), engine.get_onednn_engine(), nullptr};

        return new gemm_onednn(arg, desc, attr, prim_desc);
    }
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
attach_gemm_onednn::attach_gemm_onednn() {
|
||||
implementation_map<gemm>::add(impl_types::onednn, gemm_onednn::create, {
|
||||
std::make_tuple(data_types::f32, format::bfyx),
|
||||
std::make_tuple(data_types::f16, format::bfyx),
|
||||
std::make_tuple(data_types::u8, format::bfyx),
|
||||
std::make_tuple(data_types::i8, format::bfyx),
|
||||
|
||||
std::make_tuple(data_types::f32, format::bfzyx),
|
||||
std::make_tuple(data_types::f16, format::bfzyx),
|
||||
std::make_tuple(data_types::u8, format::bfzyx),
|
||||
std::make_tuple(data_types::i8, format::bfzyx),
|
||||
|
||||
std::make_tuple(data_types::f32, format::bfwzyx),
|
||||
std::make_tuple(data_types::f16, format::bfwzyx),
|
||||
std::make_tuple(data_types::u8, format::bfwzyx),
|
||||
std::make_tuple(data_types::i8, format::bfwzyx),
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
@@ -47,6 +47,28 @@ dnnl::memory::dims convert_tensor(cldnn::tensor t, size_t dims, bool is_grouped)
|
||||
return res;
|
||||
}
|
||||
|
||||
// Converts a cldnn tensor into oneDNN matmul dims (rank <= 3).
// Any dimensions beyond the innermost three are folded into a single leading
// batch dim; when batched_dims_can_be_removed is set, a batch dim of a
// 3D result is dropped so a plain 2D matmul descriptor can be used.
dnnl::memory::dims convert_gemm_tensor(cldnn::tensor t, size_t dims, bool batched_dims_can_be_removed) {
    const auto sizes = t.sizes(default_fmt_for_dims(dims, false));
    dnnl::memory::dims result(sizes.begin(), sizes.end());

    if (dims > 3) {
        const size_t extra = dims - 3;
        // Accumulate each leading dim into its successor, then drop the
        // now-redundant leading entries, leaving a single batch dim.
        for (size_t i = 0; i < extra; ++i)
            result[i + 1] *= result[i];
        result.erase(result.begin(), result.begin() + extra);
    }

    if (batched_dims_can_be_removed && result.size() == 3)
        result.erase(result.begin());

    return result;
}
|
||||
|
||||
// Picks the plain (row-major) oneDNN format tag matching the gemm dims rank:
// abc for a batched (3D) matmul, ab for a plain 2D one.
dnnl::memory::format_tag convert_gemm_data_format(dnnl::memory::dims dims) {
    const auto rank = dims.size();
    if (rank > 3)
        throw std::runtime_error("[clDNN] Unsupported dims size for onednn gemm: should be <= 3");
    if (rank == 3)
        return dnnl::memory::format_tag::abc;
    return dnnl::memory::format_tag::ab;
}
|
||||
|
||||
|
||||
dnnl::memory::dims convert_spatials(cldnn::tensor t, size_t dims) {
|
||||
auto spatials = t.spatial;
|
||||
dnnl::memory::dims res(dims);
|
||||
@@ -75,7 +97,7 @@ dnnl::memory::data_type convert_data_type(cldnn::data_types dt) {
|
||||
case cldnn::data_types::i8: return dnnl::memory::data_type::s8;
|
||||
case cldnn::data_types::u8: return dnnl::memory::data_type::u8;
|
||||
case cldnn::data_types::i32: return dnnl::memory::data_type::s32;
|
||||
default: throw std::invalid_argument("[clDNN] Unsupported conversion from cldnn to ondnn type");
|
||||
default: throw std::invalid_argument("[clDNN] Unsupported conversion from cldnn to onednn type");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -95,7 +117,7 @@ dnnl::memory::format_tag convert_data_format(cldnn::format fmt) {
|
||||
case cldnn::format::bs_fs_yx_bsv4_fsv2: return dnnl::memory::format_tag::ABcd4a2b;
|
||||
case cldnn::format::bs_fs_yx_bsv32_fsv16: return dnnl::memory::format_tag::NChw32n16c;
|
||||
case cldnn::format::bs_fs_zyx_bsv16_fsv16: return dnnl::memory::format_tag::NCdhw16n16c;
|
||||
default: throw std::invalid_argument("[clDNN] Unsupported conversion from cldnn to ondnn layout " + fmt_to_str(fmt));
|
||||
default: throw std::invalid_argument("[clDNN] Unsupported conversion from cldnn to onednn layout " + fmt_to_str(fmt));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -23,10 +23,12 @@ void combine_bf_with_first_spatial_dim(cldnn::layout& l);
|
||||
|
||||
// cldnn -> onednn
|
||||
dnnl::memory::dims convert_tensor(cldnn::tensor t, size_t dims = 2, bool is_grouped = false);
|
||||
dnnl::memory::dims convert_gemm_tensor(cldnn::tensor t, size_t dims, bool batched_dims_can_be_removed);
|
||||
dnnl::memory::dims convert_spatials(cldnn::tensor t, size_t dims = 2);
|
||||
dnnl::memory::dims flatten_tensor(cldnn::tensor t);
|
||||
dnnl::memory::data_type convert_data_type(cldnn::data_types dt);
|
||||
dnnl::memory::format_tag convert_data_format(cldnn::format fmt);
|
||||
dnnl::memory::format_tag convert_gemm_data_format(dnnl::memory::dims dims);
|
||||
dnnl::memory::desc layout_to_memory_desc(cldnn::layout l, dnnl::memory::format_tag target_fmt = dnnl::memory::format_tag::undef, bool flatten = false);
|
||||
dnnl::algorithm convert_activation_func(cldnn::activation_func func);
|
||||
cldnn::format find_format(dnnl::memory::desc desc, bool is_grouped = false);
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
#include "generic_layer.hpp"
|
||||
#include <sstream>
|
||||
|
||||
#include "gemm_inst.h"
|
||||
#include "eltwise_inst.h"
|
||||
#include "pooling_inst.h"
|
||||
#include "one_hot_inst.h"
|
||||
@@ -1212,9 +1213,10 @@ bool layout_optimizer::are_data_types_suitable_for_onednn(program_node& node) {
|
||||
if ((in_dt == data_types::i8 || in_dt == data_types::u8) && wei_dt == data_types::i8 &&
|
||||
(out_dt == data_types::f32 || out_dt == data_types::i32 || out_dt == data_types::f16 || out_dt == data_types::i8 || out_dt == data_types::u8))
|
||||
return true;
|
||||
} else if (node.is_type<fully_connected>()) {
|
||||
auto& fc_node = node.as<fully_connected>();
|
||||
auto wei_dt = fc_node.weights().get_output_layout().data_type;
|
||||
} else if (node.is_type<fully_connected>() || node.is_type<gemm>()) {
|
||||
bool is_fc = node.is_type<fully_connected>();
|
||||
auto wei_dt = is_fc ? node.as<fully_connected>().weights().get_output_layout().data_type :
|
||||
node.as<gemm>().get_dependency(1).get_output_layout().data_type;
|
||||
|
||||
if ((in_dt == data_types::f16 && wei_dt == data_types::f16) &&
|
||||
(out_dt == data_types::f16 || out_dt == data_types::f32 || out_dt == data_types::i8))
|
||||
@@ -1486,7 +1488,8 @@ impl_types layout_optimizer::get_preferred_impl_type(program_node& node, format
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else if (node.is_type<fully_connected>()) {
|
||||
// TODO: uncomment this code when onednn gemm implementations will have real perf improvements vs cldnn
|
||||
} else if (node.is_type<fully_connected>()/* || node.is_type<gemm>()*/) {
|
||||
if (!_optimization_attributes.use_onednn_impls)
|
||||
return impl_types::ocl;
|
||||
|
||||
@@ -1498,29 +1501,75 @@ impl_types layout_optimizer::get_preferred_impl_type(program_node& node, format
|
||||
|
||||
for (auto& fo : node.get_fused_primitives()) {
|
||||
if (fo.node->is_type<eltwise>()) {
|
||||
auto in_layout = node.get_dependency(fo.dep_start_idx).get_output_layout();
|
||||
auto out_layout = node.get_output_layout();
|
||||
auto in_dt = in_layout.data_type;
|
||||
auto out_dt = out_layout.data_type;
|
||||
if ((out_layout.count() == in_layout.count()) &&
|
||||
(data_type_traits::is_floating_point(in_dt) || data_type_traits::is_floating_point(out_dt)) && in_dt != out_dt &&
|
||||
fo.node->as<eltwise>().get_primitive()->needs_onednn_sum_post_op(in_layout)) {
|
||||
impl_candidate = impl_types::ocl;
|
||||
break;
|
||||
// FC checks
|
||||
if (node.is_type<fully_connected>()) {
|
||||
auto in_layout = node.get_dependency(fo.dep_start_idx).get_output_layout();
|
||||
auto out_layout = node.get_output_layout();
|
||||
auto in_dt = in_layout.data_type;
|
||||
auto out_dt = out_layout.data_type;
|
||||
if ((out_layout.count() == in_layout.count()) &&
|
||||
(data_type_traits::is_floating_point(in_dt) || data_type_traits::is_floating_point(out_dt)) && in_dt != out_dt &&
|
||||
fo.node->as<eltwise>().get_primitive()->needs_onednn_sum_post_op(in_layout)) {
|
||||
impl_candidate = impl_types::ocl;
|
||||
break;
|
||||
}
|
||||
// Gemm checks
|
||||
// TODO: investigate why currently onednn gemm has some "sum" post-op restrictions
|
||||
// which don't correlate with the FC checks in the code above
|
||||
// Temporary WA: disable onednn gemm with sum post-op inside
|
||||
} else {
|
||||
auto& e_node = fo.node->as<eltwise>();
|
||||
if (e_node.get_primitive()->mode == eltwise_mode::sum) {
|
||||
impl_candidate = impl_types::ocl;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OneDnn doesn't support spatial dimensions for output
|
||||
auto fc_prim = node.as<fully_connected>().get_primitive();
|
||||
auto out_layout = node.get_output_layout();
|
||||
size_t rank = cldnn::format::dimension(out_layout.format);
|
||||
auto size = out_layout.size;
|
||||
for (int i = 0; i < rank - 2 - (fc_prim->input_size == 3 ? 1 : 0); i++) {
|
||||
if (size.spatial[i] != 1) {
|
||||
impl_candidate = impl_types::ocl;
|
||||
break;
|
||||
if (node.is_type<fully_connected>()) {
|
||||
auto fc_prim = node.as<fully_connected>().get_primitive();
|
||||
auto out_layout = node.get_output_layout();
|
||||
size_t rank = cldnn::format::dimension(out_layout.format);
|
||||
auto size = out_layout.size;
|
||||
// OneDnn doesn't support spatial dimensions for output
|
||||
for (int i = 0; i < rank - 2 - (fc_prim->input_size == 3 ? 1 : 0); i++) {
|
||||
if (size.spatial[i] != 1) {
|
||||
impl_candidate = impl_types::ocl;
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
impl_candidate = impl_types::ocl;
|
||||
auto gemm_prim = node.as<gemm>().get_primitive();
|
||||
auto in0_l = node.get_dependency(0).get_output_layout();
|
||||
auto in1_l = node.get_dependency(1).get_output_layout();
|
||||
auto out_l = node.get_output_layout();
|
||||
auto has_input2 = gemm_prim->dependencies().size() == 3;
|
||||
size_t in2_batched_size;
|
||||
if (has_input2) {
|
||||
auto in2_l = node.get_dependency(2).get_output_layout();
|
||||
in2_batched_size = in2_l.count() / (in2_l.size.spatial[0] * in2_l.size.spatial[1]);
|
||||
}
|
||||
size_t size_k = gemm_prim->transpose_input0 ? in0_l.size.spatial[1] : in0_l.size.spatial[0];
|
||||
|
||||
size_t in0_batched_size = in0_l.count() / (in0_l.size.spatial[0] * in0_l.size.spatial[1]);
|
||||
size_t in1_batched_size = in1_l.count() / (in1_l.size.spatial[0] * in1_l.size.spatial[1]);
|
||||
size_t out_batched_size = out_l.count() / (out_l.size.spatial[0] * out_l.size.spatial[1]);
|
||||
|
||||
auto valid_input_batch = in0_batched_size != 1 && (in1_batched_size == in0_batched_size || in1_batched_size == 1);
|
||||
auto valid_output_batch = in0_batched_size > in1_batched_size ? out_batched_size == in0_batched_size :
|
||||
out_batched_size == in1_batched_size;
|
||||
auto valid_extra_input_batch = has_input2 ? in2_batched_size == 1 || in2_batched_size == out_batched_size : true;
|
||||
auto valid_scale_factor = gemm_prim->alpha == 1.f && (has_input2 ? gemm_prim->beta == 1.f : true);
|
||||
auto unsupported_onednn_gemm = !valid_input_batch ||
|
||||
!valid_output_batch ||
|
||||
!valid_extra_input_batch ||
|
||||
!valid_scale_factor;
|
||||
|
||||
// Gemm with k < 64 is calculated via ref kernel in onednn so cldnn way is more preferable for such cases
|
||||
if (size_k < 64 || unsupported_onednn_gemm)
|
||||
impl_candidate = impl_types::ocl;
|
||||
}
|
||||
|
||||
preferred_impl = impl_candidate;
|
||||
|
||||
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user