Duplicated: OneDNN`s Convolutions and Deconvolutions support (#8029)

* [GPU] Fix incorrect reusage of OneDNN postops configurations

* [GPU] Add OneDNN's convolutions and deconvolutions support

* [GPU] Do not run fusing unittest when imad is not supported

Co-authored-by: Sergey, Shlyapnikov <sergey.shlyapnikov@intel.com>
This commit is contained in:
Mingyu Kim 2021-10-16 17:24:44 +09:00 committed by GitHub
parent b2977632d7
commit ae2913d3b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 1010 additions and 156 deletions

View File

@ -84,7 +84,12 @@ InferenceEngine::CNNNetwork clDNNEngine::CloneAndTransformNetwork(const Inferenc
if (clonedNetwork.getFunction()) { if (clonedNetwork.getFunction()) {
auto nGraphFunc = clonedNetwork.getFunction(); auto nGraphFunc = clonedNetwork.getFunction();
TransformationsPipeline transformations(config); auto transformation_config = CLDNNPlugin::Config(config);
#ifdef ENABLE_ONEDNN_FOR_GPU
if (GetDeviceInfo(config.key_config_map).supports_immad)
transformation_config.enable_fp16_for_quantized_models = false;
#endif
TransformationsPipeline transformations(transformation_config);
transformations.apply(nGraphFunc); transformations.apply(nGraphFunc);
} }

View File

@ -459,7 +459,7 @@ inline uint get_bs_fs_zyx_bsv_fsv_index(uint b, uint f, uint z, uint y, uint x,
} }
#define GET_DATA_BS_FS_YX_BSV16_FSV16_INDEX(prefix, b, f, y, x) \ #define GET_DATA_BS_FS_YX_BSV16_FSV16_INDEX(prefix, b, f, y, x) \
get_bs_fs_zyx_bsv_fsv_index( \ get_bs_fs_zyx_bsv_fsv_index( \
b, f, 0, y, x, \ b, f, 0, y, x, \
CAT(prefix, _SIZE_X), \ CAT(prefix, _SIZE_X), \
CAT(prefix, _SIZE_Y), \ CAT(prefix, _SIZE_Y), \
@ -475,7 +475,7 @@ inline uint get_bs_fs_zyx_bsv_fsv_index(uint b, uint f, uint z, uint y, uint x,
CAT(prefix, _PAD_AFTER_SIZE_X), 16, 16) CAT(prefix, _PAD_AFTER_SIZE_X), 16, 16)
#define GET_DATA_BS_FS_YX_BSV32_FSV32_INDEX(prefix, b, f, y, x) \ #define GET_DATA_BS_FS_YX_BSV32_FSV32_INDEX(prefix, b, f, y, x) \
get_bs_fs_zyx_bsv_fsv_index( \ get_bs_fs_zyx_bsv_fsv_index( \
b, f, 0, y, x, \ b, f, 0, y, x, \
CAT(prefix, _SIZE_X), \ CAT(prefix, _SIZE_X), \
CAT(prefix, _SIZE_Y), \ CAT(prefix, _SIZE_Y), \
@ -491,7 +491,7 @@ inline uint get_bs_fs_zyx_bsv_fsv_index(uint b, uint f, uint z, uint y, uint x,
CAT(prefix, _PAD_AFTER_SIZE_X), 32, 32) CAT(prefix, _PAD_AFTER_SIZE_X), 32, 32)
#define GET_DATA_BS_FS_YX_BSV4_FSV4_INDEX(prefix, b, f, y, x) \ #define GET_DATA_BS_FS_YX_BSV4_FSV4_INDEX(prefix, b, f, y, x) \
get_bs_fs_zyx_bsv_fsv_index( \ get_bs_fs_zyx_bsv_fsv_index( \
b, f, 0, y, x, \ b, f, 0, y, x, \
CAT(prefix, _SIZE_X), \ CAT(prefix, _SIZE_X), \
CAT(prefix, _SIZE_Y), \ CAT(prefix, _SIZE_Y), \
@ -507,7 +507,7 @@ inline uint get_bs_fs_zyx_bsv_fsv_index(uint b, uint f, uint z, uint y, uint x,
CAT(prefix, _PAD_AFTER_SIZE_X), 4, 4) CAT(prefix, _PAD_AFTER_SIZE_X), 4, 4)
#define GET_DATA_BS_FS_YX_BSV4_FSV2_INDEX(prefix, b, f, y, x) \ #define GET_DATA_BS_FS_YX_BSV4_FSV2_INDEX(prefix, b, f, y, x) \
get_bs_fs_zyx_bsv_fsv_index( \ get_bs_fs_zyx_bsv_fsv_index( \
b, f, 0, y, x, \ b, f, 0, y, x, \
CAT(prefix, _SIZE_X), \ CAT(prefix, _SIZE_X), \
CAT(prefix, _SIZE_Y), \ CAT(prefix, _SIZE_Y), \
@ -523,7 +523,7 @@ inline uint get_bs_fs_zyx_bsv_fsv_index(uint b, uint f, uint z, uint y, uint x,
CAT(prefix, _PAD_AFTER_SIZE_X), 4, 2) CAT(prefix, _PAD_AFTER_SIZE_X), 4, 2)
#define GET_DATA_BS_FS_YX_BSV32_FSV16_INDEX(prefix, b, f, y, x) \ #define GET_DATA_BS_FS_YX_BSV32_FSV16_INDEX(prefix, b, f, y, x) \
get_bs_fs_zyx_bsv_fsv_index( \ get_bs_fs_zyx_bsv_fsv_index( \
b, f, 0, y, x, \ b, f, 0, y, x, \
CAT(prefix, _SIZE_X), \ CAT(prefix, _SIZE_X), \
CAT(prefix, _SIZE_Y), \ CAT(prefix, _SIZE_Y), \
@ -539,7 +539,7 @@ inline uint get_bs_fs_zyx_bsv_fsv_index(uint b, uint f, uint z, uint y, uint x,
CAT(prefix, _PAD_AFTER_SIZE_X), 32, 16) CAT(prefix, _PAD_AFTER_SIZE_X), 32, 16)
#define GET_DATA_BS_FS_ZYX_BSV16_FSV16_INDEX(prefix, b, f, z, y, x) \ #define GET_DATA_BS_FS_ZYX_BSV16_FSV16_INDEX(prefix, b, f, z, y, x) \
get_bs_fs_zyx_bsv_fsv_index( \ get_bs_fs_zyx_bsv_fsv_index( \
b, f, z, y, x, \ b, f, z, y, x, \
CAT(prefix, _SIZE_X), \ CAT(prefix, _SIZE_X), \
CAT(prefix, _SIZE_Y), \ CAT(prefix, _SIZE_Y), \
@ -555,7 +555,7 @@ inline uint get_bs_fs_zyx_bsv_fsv_index(uint b, uint f, uint z, uint y, uint x,
CAT(prefix, _PAD_AFTER_SIZE_X), 16, 16) CAT(prefix, _PAD_AFTER_SIZE_X), 16, 16)
#define GET_DATA_BS_FS_YX_BSV16_FSV16_INDEX_SAFE(prefix, b, f, y, x) \ #define GET_DATA_BS_FS_YX_BSV16_FSV16_INDEX_SAFE(prefix, b, f, y, x) \
get_bs_fs_zyx_bsv_fsv_index_safe( \ get_bs_fs_zyx_bsv_fsv_index_safe( \
b, f, 0, y, x, \ b, f, 0, y, x, \
CAT(prefix, _SIZE_X), \ CAT(prefix, _SIZE_X), \
CAT(prefix, _SIZE_Y), \ CAT(prefix, _SIZE_Y), \
@ -572,7 +572,7 @@ inline uint get_bs_fs_zyx_bsv_fsv_index(uint b, uint f, uint z, uint y, uint x,
CAT(prefix, _PAD_AFTER_SIZE_X), 16, 16) CAT(prefix, _PAD_AFTER_SIZE_X), 16, 16)
#define GET_DATA_BS_FS_YX_BSV32_FSV32_INDEX_SAFE(prefix, b, f, y, x) \ #define GET_DATA_BS_FS_YX_BSV32_FSV32_INDEX_SAFE(prefix, b, f, y, x) \
FUNC_CALL(get_bs_fs_zyx_bsv_fsv_index_safe)( \ get_bs_fs_zyx_bsv_fsv_index_safe( \
b, f, 0, y, x, \ b, f, 0, y, x, \
CAT(prefix, _SIZE_X), \ CAT(prefix, _SIZE_X), \
CAT(prefix, _SIZE_Y), \ CAT(prefix, _SIZE_Y), \
@ -589,7 +589,7 @@ inline uint get_bs_fs_zyx_bsv_fsv_index(uint b, uint f, uint z, uint y, uint x,
CAT(prefix, _PAD_AFTER_SIZE_X), 32, 32) CAT(prefix, _PAD_AFTER_SIZE_X), 32, 32)
#define GET_DATA_BS_FS_YX_BSV4_FSV4_INDEX_SAFE(prefix, b, f, y, x) \ #define GET_DATA_BS_FS_YX_BSV4_FSV4_INDEX_SAFE(prefix, b, f, y, x) \
FUNC_CALL(get_bs_fs_zyx_bsv_fsv_index_safe)( \ get_bs_fs_zyx_bsv_fsv_index_safe( \
b, f, 0, y, x, \ b, f, 0, y, x, \
CAT(prefix, _SIZE_X), \ CAT(prefix, _SIZE_X), \
CAT(prefix, _SIZE_Y), \ CAT(prefix, _SIZE_Y), \
@ -606,7 +606,7 @@ inline uint get_bs_fs_zyx_bsv_fsv_index(uint b, uint f, uint z, uint y, uint x,
CAT(prefix, _PAD_AFTER_SIZE_X), 4, 4) CAT(prefix, _PAD_AFTER_SIZE_X), 4, 4)
#define GET_DATA_BS_FS_YX_BSV4_FSV2_INDEX_SAFE(prefix, b, f, y, x) \ #define GET_DATA_BS_FS_YX_BSV4_FSV2_INDEX_SAFE(prefix, b, f, y, x) \
FUNC_CALL(get_bs_fs_zyx_bsv_fsv_index_safe)( \ get_bs_fs_zyx_bsv_fsv_index_safe( \
b, f, 0, y, x, \ b, f, 0, y, x, \
CAT(prefix, _SIZE_X), \ CAT(prefix, _SIZE_X), \
CAT(prefix, _SIZE_Y), \ CAT(prefix, _SIZE_Y), \
@ -623,7 +623,7 @@ inline uint get_bs_fs_zyx_bsv_fsv_index(uint b, uint f, uint z, uint y, uint x,
CAT(prefix, _PAD_AFTER_SIZE_X), 4, 2) CAT(prefix, _PAD_AFTER_SIZE_X), 4, 2)
#define GET_DATA_BS_FS_YX_BSV32_FSV16_INDEX_SAFE(prefix, b, f, y, x) \ #define GET_DATA_BS_FS_YX_BSV32_FSV16_INDEX_SAFE(prefix, b, f, y, x) \
FUNC_CALL(get_bs_fs_zyx_bsv_fsv_index_safe)( \ get_bs_fs_zyx_bsv_fsv_index_safe( \
b, f, 0, y, x, \ b, f, 0, y, x, \
CAT(prefix, _SIZE_X), \ CAT(prefix, _SIZE_X), \
CAT(prefix, _SIZE_Y), \ CAT(prefix, _SIZE_Y), \
@ -640,7 +640,7 @@ inline uint get_bs_fs_zyx_bsv_fsv_index(uint b, uint f, uint z, uint y, uint x,
CAT(prefix, _PAD_AFTER_SIZE_X), 32, 16) CAT(prefix, _PAD_AFTER_SIZE_X), 32, 16)
#define GET_DATA_BS_FS_ZYX_BSV16_FSV16_INDEX_SAFE(prefix, b, f, z, y, x) \ #define GET_DATA_BS_FS_ZYX_BSV16_FSV16_INDEX_SAFE(prefix, b, f, z, y, x) \
get_bs_fs_zyx_bsv_fsv_index_safe( \ get_bs_fs_zyx_bsv_fsv_index_safe( \
b, f, z, y, x, \ b, f, z, y, x, \
CAT(prefix, _SIZE_X ), \ CAT(prefix, _SIZE_X ), \
CAT(prefix, _SIZE_Y), \ CAT(prefix, _SIZE_Y), \

View File

@ -295,11 +295,57 @@ void prepare_primitive_fusing::fuse_bias(program &p) {
new_node.recalc_output_layout(); new_node.recalc_output_layout();
}; };
auto recalculate_biases = [&](data_node& original_node, data_node& new_node) -> bool {
auto original_mem = original_node.get_attached_memory_ptr();
auto new_mem = new_node.get_attached_memory_ptr();
if (original_mem->count() != new_mem->count() || original_mem->get_layout().data_type != new_mem->get_layout().data_type)
return false;
switch (original_mem->get_layout().data_type) {
case data_types::f32: {
mem_lock<float, mem_lock_type::write> original_bias_mem(original_mem, p.get_stream());
mem_lock<float, mem_lock_type::read> new_bias_mem(new_mem, p.get_stream());
float* original_data = original_bias_mem.data();
float* new_data = new_bias_mem.data();
for (size_t i = 0; i < original_bias_mem.size(); i++)
original_data[i] += new_data[i];
break;
}
case data_types::f16: {
mem_lock<uint16_t, mem_lock_type::write> original_bias_mem(original_mem, p.get_stream());
mem_lock<uint16_t, mem_lock_type::read> new_bias_mem(new_mem, p.get_stream());
uint16_t* original_data = original_bias_mem.data();
uint16_t* new_data = new_bias_mem.data();
for (size_t i = 0; i < original_bias_mem.size(); i++) {
float new_val = half_to_float(original_data[i]) + half_to_float(new_data[i]);
original_data[i] = float_to_half(new_val);
}
break;
}
default:
return false;
}
return true;
};
if (replace_candidate.is_type<convolution>()) { if (replace_candidate.is_type<convolution>()) {
auto& conv = replace_candidate.as<convolution>(); auto& conv = replace_candidate.as<convolution>();
auto desc = conv.get_primitive(); auto desc = conv.get_primitive();
std::vector<primitive_id> biases = {bias_name}; std::vector<primitive_id> biases = {bias_name};
// If the primitive has biases, then we try to combine the values, or do nothing and keep as fused sum.
if (conv.bias_term()) {
if (conv.bias().is_type<data>() && bias_node.is_type<data>()) {
if (recalculate_biases(conv.bias().as<data>(), bias_node.as<data>())) {
p.replace_all_usages(eltw_node, conv);
p.add_optimized_primitive_info(eltw_node.id(), {conv.id()});
p.remove_all_connections(eltw_node);
p.remove_if_dangling(eltw_node);
}
}
continue;
}
auto conv_with_bias_prim = std::make_shared<convolution>(desc->id + "_tmp", auto conv_with_bias_prim = std::make_shared<convolution>(desc->id + "_tmp",
desc->input[0], desc->input[0],
desc->weights, desc->weights,
@ -325,6 +371,19 @@ void prepare_primitive_fusing::fuse_bias(program &p) {
auto desc = deconv.get_primitive(); auto desc = deconv.get_primitive();
std::vector<primitive_id> biases = {bias_name}; std::vector<primitive_id> biases = {bias_name};
// If the primitive has biases, then we try to combine the values, or do nothing and keep as fused sum.
if (deconv.bias_term()) {
if (deconv.bias().is_type<data>() && bias_node.is_type<data>()) {
if (recalculate_biases(deconv.bias().as<data>(), bias_node.as<data>())) {
p.replace_all_usages(eltw_node, deconv);
p.add_optimized_primitive_info(eltw_node.id(), {deconv.id()});
p.remove_all_connections(eltw_node);
p.remove_if_dangling(eltw_node);
}
}
continue;
}
auto deconv_with_bias_prim = std::make_shared<deconvolution>(desc->id + "_tmp", auto deconv_with_bias_prim = std::make_shared<deconvolution>(desc->id + "_tmp",
desc->input[0], desc->input[0],
desc->weights, desc->weights,
@ -340,6 +399,20 @@ void prepare_primitive_fusing::fuse_bias(program &p) {
} else if (replace_candidate.is_type<fully_connected>()) { } else if (replace_candidate.is_type<fully_connected>()) {
auto& fc = replace_candidate.as<fully_connected>(); auto& fc = replace_candidate.as<fully_connected>();
auto desc = fc.get_primitive(); auto desc = fc.get_primitive();
// If the primitive has biases, then we try to combine the values, or do nothing and keep as fused sum.
if (fc.bias_term()) {
if (fc.bias().is_type<data>() && bias_node.is_type<data>()) {
if (recalculate_biases(fc.bias().as<data>(), bias_node.as<data>())) {
p.replace_all_usages(eltw_node, fc);
p.add_optimized_primitive_info(eltw_node.id(), {fc.id()});
p.remove_all_connections(eltw_node);
p.remove_if_dangling(eltw_node);
}
}
continue;
}
auto fc_with_bias_prim = std::make_shared<fully_connected>(desc->id + "_tmp", auto fc_with_bias_prim = std::make_shared<fully_connected>(desc->id + "_tmp",
desc->input[0], desc->input[0],
desc->weights, desc->weights,

View File

@ -85,7 +85,7 @@ protected:
} else { } else {
throw std::runtime_error("Unsupported data type for activations zero points for oneDNN convolution"); throw std::runtime_error("Unsupported data type for activations zero points for oneDNN convolution");
} }
a_zp.as<data>().attach_memory(s32_mem); a_zp.as<data>().attach_memory(s32_mem, false);
int mask = a_zp.get_output_layout().count() > 1 ? 2 : 0; int mask = a_zp.get_output_layout().count() > 1 ? 2 : 0;
@ -211,6 +211,51 @@ attach_convolution_onednn::attach_convolution_onednn() {
std::make_tuple(data_types::f16, format::bfyx), std::make_tuple(data_types::f16, format::bfyx),
std::make_tuple(data_types::u8, format::bfyx), std::make_tuple(data_types::u8, format::bfyx),
std::make_tuple(data_types::i8, format::bfyx), std::make_tuple(data_types::i8, format::bfyx),
std::make_tuple(data_types::f32, format::bfzyx),
std::make_tuple(data_types::f16, format::bfzyx),
std::make_tuple(data_types::u8, format::bfzyx),
std::make_tuple(data_types::i8, format::bfzyx),
std::make_tuple(data_types::f32, format::b_fs_yx_fsv16),
std::make_tuple(data_types::f16, format::b_fs_yx_fsv16),
std::make_tuple(data_types::u8, format::b_fs_yx_fsv16),
std::make_tuple(data_types::i8, format::b_fs_yx_fsv16),
std::make_tuple(data_types::f32, format::b_fs_zyx_fsv16),
std::make_tuple(data_types::f16, format::b_fs_zyx_fsv16),
std::make_tuple(data_types::u8, format::b_fs_zyx_fsv16),
std::make_tuple(data_types::i8, format::b_fs_zyx_fsv16),
std::make_tuple(data_types::f32, format::b_fs_yx_fsv32),
std::make_tuple(data_types::f16, format::b_fs_yx_fsv32),
std::make_tuple(data_types::u8, format::b_fs_yx_fsv32),
std::make_tuple(data_types::i8, format::b_fs_yx_fsv32),
std::make_tuple(data_types::f32, format::bs_fs_yx_bsv16_fsv16),
std::make_tuple(data_types::f16, format::bs_fs_yx_bsv16_fsv16),
std::make_tuple(data_types::u8, format::bs_fs_yx_bsv16_fsv16),
std::make_tuple(data_types::i8, format::bs_fs_yx_bsv16_fsv16),
std::make_tuple(data_types::f32, format::bs_fs_yx_bsv32_fsv16),
std::make_tuple(data_types::f16, format::bs_fs_yx_bsv32_fsv16),
std::make_tuple(data_types::u8, format::bs_fs_yx_bsv32_fsv16),
std::make_tuple(data_types::i8, format::bs_fs_yx_bsv32_fsv16),
std::make_tuple(data_types::f32, format::bs_fs_yx_bsv32_fsv32),
std::make_tuple(data_types::f16, format::bs_fs_yx_bsv32_fsv32),
std::make_tuple(data_types::u8, format::bs_fs_yx_bsv32_fsv32),
std::make_tuple(data_types::i8, format::bs_fs_yx_bsv32_fsv32),
std::make_tuple(data_types::f32, format::bs_fs_yx_bsv4_fsv4),
std::make_tuple(data_types::f16, format::bs_fs_yx_bsv4_fsv4),
std::make_tuple(data_types::u8, format::bs_fs_yx_bsv4_fsv4),
std::make_tuple(data_types::i8, format::bs_fs_yx_bsv4_fsv4),
std::make_tuple(data_types::f32, format::bs_fs_yx_bsv4_fsv2),
std::make_tuple(data_types::f16, format::bs_fs_yx_bsv4_fsv2),
std::make_tuple(data_types::u8, format::bs_fs_yx_bsv4_fsv2),
std::make_tuple(data_types::i8, format::bs_fs_yx_bsv4_fsv2),
}); });
} }

View File

@ -0,0 +1,212 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "deconvolution_inst.h"
#include "eltwise_inst.h"
#include "quantize_inst.h"
#include "primitive_onednn_base.h"
#include "impls/implementation_map.hpp"
#include "kernel_selector_common.h"
#include <oneapi/dnnl/dnnl.hpp>
#include <algorithm>
#include <memory>
namespace cldnn {
namespace onednn {
struct deconvolution_onednn : typed_primitive_onednn_impl<deconvolution, dnnl::deconvolution_forward::desc> {
using parent = typed_primitive_onednn_impl<deconvolution, dnnl::deconvolution_forward::desc>;
using parent::parent;
protected:
std::unique_ptr<primitive_impl> clone() const override {
return make_unique<deconvolution_onednn>(*this);
}
bool validate_impl(const typed_primitive_inst<deconvolution>& instance) const override {
bool res = true;
auto outer_id = _outer.id();
auto data_type = instance.node.input().get_output_layout().data_type;
// Integer signed/unsigned is ok for convoluiton
CLDNN_ERROR_DATA_TYPES_MISMATCH_IGNORE_SIGN(outer_id,
"Input memory",
data_type,
"filter memory",
instance.weights_memory(0)->get_layout().data_type,
"");
return res;
}
std::unordered_map<int, dnnl::memory> get_arguments(deconvolution_inst& instance) const override {
std::unordered_map<int, dnnl::memory> args = parent::get_arguments(instance);
auto& engine = instance.get_network().get_engine();
auto onednn_engine = engine.get_onednn_engine();
{
auto weights = instance.weights_memory(0);
args.insert({DNNL_ARG_WEIGHTS, weights->get_onednn_memory(_pd.weights_desc(0))});
}
if (instance.bias_term()) {
auto bias = instance.bias_memory(0);
args.insert({DNNL_ARG_BIAS, bias->get_onednn_memory(_pd.weights_desc(1))});
}
return args;
}
static std::shared_ptr<dnnl::primitive_attr> get_primitive_attributes(const typed_program_node<deconvolution>& arg) {
auto attrs = parent::get_primitive_attributes(arg);
return attrs;
}
static kernel_selector::WeightsReorderParams get_weights_reorder(const deconvolution_node& arg, const dnnl::primitive_desc& pd) {
kernel_selector::WeightsReorderParams weights_reorder_params;
auto& reorderKS = kernel_selector::ReorderWeightsKernelSelctor::Instance();
kernel_selector::reorder_weights_params r_params;
auto cldnn_prim = arg.get_primitive();
auto weights_layout = arg.get_dependency(1).get_output_layout();
auto grouped_weights = format::is_grouped(weights_layout.format) || arg.get_primitive()->grouped_weights_shape;
cldnn::format out_fmt = onednn::convert_format(onednn::get_format_by_desc(pd.weights_desc(0)), grouped_weights);
kernel_selector::WeightsLayout reqLayout = to_weights_layout(out_fmt, cldnn_prim->grouped_weights_shape);
set_params(arg, r_params);
r_params.layerID = arg.id() + "_reorder_";
r_params.input = convert_weights_tensor(weights_layout, cldnn_prim->grouped_weights_shape);
r_params.output = r_params.input.TransformIgnorePadding(reqLayout, r_params.input.GetDType(), arg.get_groups(), false);
r_params.rotate_180 = false;
kernel_selector::reorder_optional_params op;
kernel_selector::KernelsData kernels_data = reorderKS.GetBestKernels(r_params, op);
if (kernels_data.empty()) {
throw std::runtime_error("No suitable kernel found for weights reorder from " +
kernel_selector::toString(r_params.input.GetLayout()) + " to " +
kernel_selector::toString(r_params.output.GetLayout()));
}
weights_reorder_params.engine = kernel_selector::WeightsReorderParams::Engine::GPU;
weights_reorder_params.clKernel = std::make_shared<kernel_selector::clKernelData>(kernels_data[0].kernels[0]);
weights_reorder_params.dest = r_params.output;
return weights_reorder_params;
}
static std::shared_ptr<dnnl::deconvolution_forward::desc> get_deconvolution_descriptor(const deconvolution_node& arg) {
auto prim = arg.get_primitive();
auto& input = arg.get_dependency(0);
auto& weights = arg.get_dependency(1);
auto spatials_rank = cldnn::format::spatial_num(input.get_output_layout().format);
auto stride = onednn::convert_spatials(prim->stride, spatials_rank);
auto dilation = onednn::convert_spatials(cldnn::tensor{1}, spatials_rank);
auto pad_l = onednn::convert_spatials(prim->input_offset, spatials_rank);
auto pad_r = onednn::convert_spatials(prim->input_offset, spatials_rank);
auto input_md = onednn::layout_to_memory_desc(input.get_output_layout());
auto weights_md = onednn::layout_to_memory_desc(weights.get_output_layout(), dnnl::memory::format_tag::any);
auto output_md = onednn::layout_to_memory_desc(arg.get_output_layout());
auto grouped_weights = format::is_grouped(weights.get_output_layout().format) || prim->grouped_weights_shape;
for (size_t i = 0; i < dilation.size(); i++) {
dilation[i]--;
pad_l[i] = -pad_l[i];
int weights_offset = (grouped_weights ? 3 : 2) + static_cast<int>(i);
auto os = output_md.dims()[2 + i];
auto is = input_md.dims()[2 + i];
auto ks = weights_md.dims()[weights_offset];
auto kernel_range = 1 + (ks - 1) * (dilation[i] + 1);
pad_r[i] = (is - 1) * stride[i] - os + kernel_range - pad_l[i];
}
if (arg.bias_term()) {
auto bias_md = onednn::layout_to_memory_desc(arg.get_dependency(2).get_output_layout(), dnnl::memory::format_tag::any, true);
return std::make_shared<dnnl::deconvolution_forward::desc>(
dnnl::prop_kind::forward_inference,
dnnl::algorithm::deconvolution_direct,
input_md,
weights_md,
bias_md,
output_md,
stride,
dilation,
pad_l,
pad_r);
} else {
return std::make_shared<dnnl::deconvolution_forward::desc>(
dnnl::prop_kind::forward_inference,
dnnl::algorithm::deconvolution_direct,
input_md,
weights_md,
output_md,
stride,
dilation,
pad_l,
pad_r);
}
}
public:
static primitive_impl* create(const deconvolution_node& arg) {
auto& engine = arg.get_program().get_engine();
auto desc = get_deconvolution_descriptor(arg);
auto attr = get_primitive_attributes(arg);
dnnl::primitive_desc prim_desc{&desc->data, attr.get(), engine.get_onednn_engine(), nullptr};
return new deconvolution_onednn(arg, desc, attr, prim_desc, get_weights_reorder(arg, prim_desc));
}
};
namespace detail {
attach_deconvolution_onednn::attach_deconvolution_onednn() {
implementation_map<deconvolution>::add(impl_types::onednn, deconvolution_onednn::create, {
std::make_tuple(data_types::f32, format::bfyx),
std::make_tuple(data_types::f16, format::bfyx),
std::make_tuple(data_types::u8, format::bfyx),
std::make_tuple(data_types::i8, format::bfyx),
std::make_tuple(data_types::f32, format::b_fs_yx_fsv16),
std::make_tuple(data_types::f16, format::b_fs_yx_fsv16),
std::make_tuple(data_types::u8, format::b_fs_yx_fsv16),
std::make_tuple(data_types::i8, format::b_fs_yx_fsv16),
std::make_tuple(data_types::f32, format::bs_fs_yx_bsv16_fsv16),
std::make_tuple(data_types::f16, format::bs_fs_yx_bsv16_fsv16),
std::make_tuple(data_types::u8, format::bs_fs_yx_bsv16_fsv16),
std::make_tuple(data_types::i8, format::bs_fs_yx_bsv16_fsv16),
std::make_tuple(data_types::f32, format::bs_fs_yx_bsv32_fsv16),
std::make_tuple(data_types::f16, format::bs_fs_yx_bsv32_fsv16),
std::make_tuple(data_types::u8, format::bs_fs_yx_bsv32_fsv16),
std::make_tuple(data_types::i8, format::bs_fs_yx_bsv32_fsv16),
std::make_tuple(data_types::f32, format::bs_fs_yx_bsv32_fsv32),
std::make_tuple(data_types::f16, format::bs_fs_yx_bsv32_fsv32),
std::make_tuple(data_types::u8, format::bs_fs_yx_bsv32_fsv32),
std::make_tuple(data_types::i8, format::bs_fs_yx_bsv32_fsv32),
std::make_tuple(data_types::f32, format::bs_fs_yx_bsv4_fsv4),
std::make_tuple(data_types::f16, format::bs_fs_yx_bsv4_fsv4),
std::make_tuple(data_types::u8, format::bs_fs_yx_bsv4_fsv4),
std::make_tuple(data_types::i8, format::bs_fs_yx_bsv4_fsv4),
std::make_tuple(data_types::f32, format::bs_fs_yx_bsv4_fsv2),
std::make_tuple(data_types::f16, format::bs_fs_yx_bsv4_fsv2),
std::make_tuple(data_types::u8, format::bs_fs_yx_bsv4_fsv2),
std::make_tuple(data_types::i8, format::bs_fs_yx_bsv4_fsv2),
});
}
} // namespace detail
} // namespace onednn
} // namespace cldnn

View File

@ -12,6 +12,7 @@
#include "utils.hpp" #include "utils.hpp"
#include "quantize_inst.h" #include "quantize_inst.h"
#include "reorder_inst.h"
#include "reorder/reorder_weights_kernel_selector.h" #include "reorder/reorder_weights_kernel_selector.h"
#include "reorder/reorder_kernel_base.h" #include "reorder/reorder_kernel_base.h"
@ -778,8 +779,10 @@ protected:
} }
} }
} }
} else if (node->is_type<reorder>()) {
continue;
} else { } else {
throw std::runtime_error("Unsupported fused op for onednn prim"); throw std::runtime_error("Unsupported fused op of " + node->get_primitive()->type_string() + " type for oneDNN primitive");
} }
} }
@ -795,7 +798,12 @@ protected:
} }
// Update total onednn post-ops info // Update total onednn post-ops info
onednn_fusing_map.emplace(arg.id(), std::move(fused_ops)); auto it = onednn_fusing_map.find(arg.id());
if (it != onednn_fusing_map.end()) {
it->second = std::move(fused_ops);
} else {
onednn_fusing_map.emplace(arg.id(), std::move(fused_ops));
}
// Trying to optimize more than 1 post-ops // Trying to optimize more than 1 post-ops
auto post_ops_size = onednn_fusing_map[arg.id()].size(); auto post_ops_size = onednn_fusing_map[arg.id()].size();

View File

@ -12,6 +12,7 @@ namespace onednn {
void register_implementations() { void register_implementations() {
REGISTER_ONEDNN_IMPL(convolution); REGISTER_ONEDNN_IMPL(convolution);
REGISTER_ONEDNN_IMPL(deconvolution);
REGISTER_ONEDNN_IMPL(concatenation); REGISTER_ONEDNN_IMPL(concatenation);
REGISTER_ONEDNN_IMPL(eltwise); REGISTER_ONEDNN_IMPL(eltwise);
REGISTER_ONEDNN_IMPL(gemm); REGISTER_ONEDNN_IMPL(gemm);

View File

@ -18,6 +18,7 @@ namespace detail {
} }
REGISTER_ONEDNN_IMPL(convolution); REGISTER_ONEDNN_IMPL(convolution);
REGISTER_ONEDNN_IMPL(deconvolution);
REGISTER_ONEDNN_IMPL(concatenation); REGISTER_ONEDNN_IMPL(concatenation);
REGISTER_ONEDNN_IMPL(eltwise); REGISTER_ONEDNN_IMPL(eltwise);
REGISTER_ONEDNN_IMPL(gemm); REGISTER_ONEDNN_IMPL(gemm);

View File

@ -153,6 +153,10 @@ private:
const layout& output_layout, const layout& output_layout,
const layout& weights_layout, const layout& weights_layout,
std::shared_ptr<const convolution> conv); std::shared_ptr<const convolution> conv);
bool convolution_bs_fs_yx_bsv32_fsv32_opt(const layout &input_layout,
const layout& output_layout,
const layout& weights_layout,
std::shared_ptr<const convolution> conv);
bool convolution_fs_b_yx_fsv32_opt(const layout& input_layout, bool convolution_fs_b_yx_fsv32_opt(const layout& input_layout,
const layout& output_layout, const layout& output_layout,
const layout& weights_layout, const layout& weights_layout,
@ -173,6 +177,7 @@ public:
format get_preferred_format(program_node& node); format get_preferred_format(program_node& node);
impl_types get_preferred_impl_type(program_node& node, format preferred_format); impl_types get_preferred_impl_type(program_node& node, format preferred_format);
bool are_data_types_suitable_for_onednn(program_node& node);
bool is_format_supported(program_node& node, format::type fmt); bool is_format_supported(program_node& node, format::type fmt);
// Returns whether reorder between "prev" with format fmt_prev and "next" with format fmt_next // Returns whether reorder between "prev" with format fmt_prev and "next" with format fmt_next

View File

@ -739,83 +739,158 @@ layout layout_optimizer::get_expected_layout(layout const& current_layout,
auto expected_data_type = current_layout.data_type; auto expected_data_type = current_layout.data_type;
auto expected_format = current_layout.format; auto expected_format = current_layout.format;
auto input_layout = node.get_dependency(0).get_output_layout(); auto input_layout = node.get_dependency(0).get_output_layout();
auto output_layout = node.calc_output_layout();
const float cond_denom = _total_conv > 0 ? 1.0f / static_cast<float>(_total_conv) : 1.0f; const float cond_denom = _total_conv > 0 ? 1.0f / static_cast<float>(_total_conv) : 1.0f;
auto output_layout = node.calc_output_layout(); bool is_dw = input_layout.size.feature[0] == static_cast<int>(prim->groups);
int ofm_per_group = output_layout.size.feature[0] / prim->groups;
int ifm_per_group = input_layout.size.feature[0] / prim->groups;
int compute_block = 32;
bool valid_grouped = !is_dw && prim->groups > 1 && (ofm_per_group % compute_block == 0 && ifm_per_group % compute_block == 0);
bool valid_int8_dw = is_dw && output_layout.size.batch[0] % 16 == 0;
bool non_grouped = prim->groups == 1;
bool is_2d = input_layout.format.spatial_num() == 2;
bool onednn_valid_post_ops = get_post_ops_count(node) <= 32;
bool use_onednn_impls = _optimization_attributes.use_onednn_impls;
bool i8_u8_input = input_layout.data_type == data_types::u8 || input_layout.data_type == data_types::i8;
if ((input_layout.data_type == data_types::u8 || input_layout.data_type == data_types::i8)) { if (use_onednn_impls && onednn_valid_post_ops) {
if ((_optimization_attributes.bs_fs_yx_bsv16_fsv16_network && expected_tensor.batch[0] % 16 == 0 && for (auto& fo : node.get_fused_primitives()) {
convolution_bs_fs_yx_bsv16_fsv16_opt(input_layout, output_layout, weights_layout, prim))) { if (fo.node->is_type<eltwise>()) {
expected_format = cldnn::format::bs_fs_yx_bsv16_fsv16; auto in_layout = node.get_dependency(fo.dep_start_idx).get_output_layout();
} else if ((_optimization_attributes.b_fs_yx_fsv16_network && auto out_layout = node.get_output_layout();
convolution_b_fs_yx_fsv16_opt(input_layout, output_layout, weights_layout, prim))) { auto in_dt = in_layout.data_type;
expected_format = cldnn::format::b_fs_yx_fsv16; auto out_dt = out_layout.data_type;
} else if ((_optimization_attributes.b_fs_zyx_fsv16_network && if ((out_layout.count() == in_layout.count()) &&
convolution_b_fs_zyx_fsv16_opt(input_layout, output_layout, weights_layout, prim))) { (data_type_traits::is_floating_point(in_dt) || data_type_traits::is_floating_point(out_dt)) && in_dt != out_dt) {
expected_format = cldnn::format::b_fs_zyx_fsv16; onednn_valid_post_ops = false;
} else { break;
expected_format = imad_case(node); }
}
} }
expected_tensor = current_layout.size; }
} else if (_optimization_attributes.b_fs_zyx_fsv16_network &&
convolution_b_fs_zyx_fsv16_opt(input_layout, output_layout, weights_layout, prim)) {
expected_tensor = current_layout.size;
if ((current_layout.data_type == data_types::f32 && expected_tensor.batch[0] % 16 == 0) ||
(current_layout.data_type == data_types::f16 && expected_tensor.batch[0] % 32 == 0))
expected_format = cldnn::format::bs_fs_zyx_bsv16_fsv16;
else
expected_format = cldnn::format::b_fs_zyx_fsv16;
} else if (current_layout.format == format::bfzyx) { if (use_onednn_impls) {
expected_tensor = current_layout.size; /* ***************************** OneDNN impls format selection part ****************************** */
expected_format = cldnn::format::bfzyx; if (i8_u8_input) {
} else if (_optimization_attributes.bs_fs_yx_bsv16_fsv16_network && if ((non_grouped || valid_grouped || valid_int8_dw) && onednn_valid_post_ops && is_2d) {
convolution_bs_fs_yx_bsv16_fsv16_opt(node.input().get_output_layout(), output_layout, weights_layout, prim)) { if (input_layout.size.batch[0] % 16 == 0)
expected_tensor = current_layout.size; expected_format = cldnn::format::bs_fs_yx_bsv32_fsv32;
expected_format = cldnn::format::bs_fs_yx_bsv16_fsv16; else
} else if (_optimization_attributes.fs_b_yx_fsv32_network && !node.get_transposed() && expected_format = cldnn::format::b_fs_yx_fsv32;
((convolution_fs_b_yx_fsv32_opt(input_layout, } else if ((_optimization_attributes.b_fs_yx_fsv16_network &&
output_layout, convolution_b_fs_yx_fsv16_opt(input_layout, output_layout, weights_layout, prim)) && is_2d) {
weights_layout, prim) || if (is_dw)
(((node.get_dependency(0).is_type<convolution>() && is_format_optimized(node.get_dependency(0).as<convolution>(), format::fs_b_yx_fsv32)) expected_format = cldnn::format::b_fs_yx_fsv32;
|| (_optimized_conv_count.at({format::fs_b_yx_fsv32, false}) * cond_denom > 0.8f)) && else
convolution_fs_b_yx_fsv32_opt(input_layout, expected_format = cldnn::format::b_fs_yx_fsv16;
output_layout, } else {
weights_layout, prim, true))))) { expected_format = imad_case(node);
// Chose fs_b_yx_fsv32 layout in two cases: 1-st: the current conv primitive totally supports fs_b_yx_fsv32 layout }
// 2-nd: the previous conv primitive supports fs_b_yx_fsv32 layout and expected_tensor = current_layout.size;
// current conv primitives supports this one with weak restrictions - } else if (input_layout.data_type == data_types::f16 && is_2d) {
// that should be cheaper than reordering data to another layout
expected_tensor = current_layout.size;
expected_format = format::fs_b_yx_fsv32;
} else if (should_select_b_fs_yx_fsv16_layout(node, weights_layout)) {
expected_tensor = current_layout.size;
expected_format = cldnn::format::b_fs_yx_fsv16;
} else if (current_layout.data_type == data_types::f16 &&
layout_optimizer::convolution_byxf_opt(input_layout, current_layout, weights_layout, node) &&
(users_for_convolution_byxf_opt(node, 2) || deps_for_convolution_byxf_opt(node, 2)) &&
// todo: remove this condition when yxfb optimizations will be disabled
current_layout.format != cldnn::format::yxfb && current_layout.size.batch[0] == 1) {
expected_tensor = current_layout.size;
expected_format = cldnn::format::byxf;
} else if (current_layout.format == format::b_fs_yx_fsv4 ||
current_layout.format == format::os_is_yx_osv16_isv4) {
// imad case
// nothing to do, just go out from here.
} else if (layout_optimizer::convolution_bfyx_opt(current_layout, weights_layout, prim) ||
(_output_size_handling_enabled && prim->with_output_size) || node.get_transposed()) {
{
expected_tensor = current_layout.size; expected_tensor = current_layout.size;
if (current_layout.format == format::b_fs_zyx_fsv16 || current_layout.format == format::bs_fs_zyx_bsv16_fsv16)
expected_format = cldnn::format::bfzyx;
else
expected_format = cldnn::format::bfyx;
}
if (input_layout.size.batch[0] >= 16 && onednn_valid_post_ops) {
if (output_layout.data_type == input_layout.data_type) {
if (non_grouped || valid_grouped || is_dw) {
expected_format = cldnn::format::bs_fs_yx_bsv32_fsv16;
} else {
expected_format = cldnn::format::b_fs_yx_fsv16;
}
} else {
expected_format = cldnn::format::bs_fs_yx_bsv16_fsv16;
}
} else {
expected_format = cldnn::format::b_fs_yx_fsv16;
}
} else if (input_layout.data_type == data_types::f16 &&
convolution_bs_fs_yx_bsv16_fsv16_opt(input_layout, output_layout, weights_layout, prim) &&
(output_layout.data_type == input_layout.data_type ||
!data_type_traits::is_floating_point(input_layout.data_type))) {
expected_tensor = current_layout.size;
if (prim->groups == 1 || (output_layout.size.feature[0] % 16 == 0 && input_layout.size.feature[0] % 16 == 0)) {
expected_format = cldnn::format::bs_fs_yx_bsv32_fsv16;
} else {
expected_format = cldnn::format::bs_fs_yx_bsv16_fsv16;
}
}
} else { } else {
expected_tensor = current_layout.size; /* *************************** Native impls format selection part ************************** */
expected_format = cldnn::format::yxfb; if (i8_u8_input) {
if ((_optimization_attributes.bs_fs_yx_bsv16_fsv16_network && expected_tensor.batch[0] % 16 == 0 &&
convolution_bs_fs_yx_bsv16_fsv16_opt(input_layout, output_layout, weights_layout, prim))) {
expected_format = cldnn::format::bs_fs_yx_bsv16_fsv16;
} else if ((_optimization_attributes.b_fs_yx_fsv16_network &&
convolution_b_fs_yx_fsv16_opt(input_layout, output_layout, weights_layout, prim))) {
expected_format = cldnn::format::b_fs_yx_fsv16;
} else if ((_optimization_attributes.b_fs_zyx_fsv16_network &&
convolution_b_fs_zyx_fsv16_opt(input_layout, output_layout, weights_layout, prim))) {
expected_format = cldnn::format::b_fs_zyx_fsv16;
} else {
expected_format = imad_case(node);
}
expected_tensor = current_layout.size;
} else if (_optimization_attributes.b_fs_zyx_fsv16_network &&
convolution_b_fs_zyx_fsv16_opt(input_layout, output_layout, weights_layout, prim)) {
expected_tensor = current_layout.size;
if ((current_layout.data_type == data_types::f32 && expected_tensor.batch[0] % 16 == 0) ||
(current_layout.data_type == data_types::f16 && expected_tensor.batch[0] % 32 == 0))
expected_format = cldnn::format::bs_fs_zyx_bsv16_fsv16;
else
expected_format = cldnn::format::b_fs_zyx_fsv16;
} else if (current_layout.format == format::bfzyx) {
expected_tensor = current_layout.size;
expected_format = cldnn::format::bfzyx;
} else if (_optimization_attributes.bs_fs_yx_bsv16_fsv16_network &&
convolution_bs_fs_yx_bsv16_fsv16_opt(node.input().get_output_layout(), output_layout, weights_layout, prim)) {
expected_tensor = current_layout.size;
expected_format = cldnn::format::bs_fs_yx_bsv16_fsv16;
} else if (_optimization_attributes.fs_b_yx_fsv32_network && !node.get_transposed() &&
((convolution_fs_b_yx_fsv32_opt(input_layout,
output_layout,
weights_layout, prim) ||
(((node.get_dependency(0).is_type<convolution>() && is_format_optimized(node.get_dependency(0).as<convolution>(), format::fs_b_yx_fsv32))
|| (_optimized_conv_count.at({format::fs_b_yx_fsv32, false}) * cond_denom > 0.8f)) &&
convolution_fs_b_yx_fsv32_opt(input_layout,
output_layout,
weights_layout, prim, true))))) {
// Chose fs_b_yx_fsv32 layout in two cases: 1-st: the current conv primitive totally supports fs_b_yx_fsv32 layout
// 2-nd: the previous conv primitive supports fs_b_yx_fsv32 layout and
// current conv primitives supports this one with weak restrictions -
// that should be cheaper than reordering data to another layout
expected_tensor = current_layout.size;
expected_format = format::fs_b_yx_fsv32;
} else if (should_select_b_fs_yx_fsv16_layout(node, weights_layout)) {
expected_tensor = current_layout.size;
expected_format = cldnn::format::b_fs_yx_fsv16;
} else if (current_layout.data_type == data_types::f16 &&
layout_optimizer::convolution_byxf_opt(input_layout, current_layout, weights_layout, node) &&
(users_for_convolution_byxf_opt(node, 2) || deps_for_convolution_byxf_opt(node, 2)) &&
// todo: remove this condition when yxfb optimizations will be disabled
current_layout.format != cldnn::format::yxfb && current_layout.size.batch[0] == 1) {
expected_tensor = current_layout.size;
expected_format = cldnn::format::byxf;
} else if (current_layout.format == format::b_fs_yx_fsv4 ||
current_layout.format == format::os_is_yx_osv16_isv4) {
// imad case
// nothing to do, just go out from here.
} else if (layout_optimizer::convolution_bfyx_opt(current_layout, weights_layout, prim) ||
(_output_size_handling_enabled && prim->with_output_size) || node.get_transposed()) {
{
expected_tensor = current_layout.size;
if (current_layout.format == format::b_fs_zyx_fsv16 || current_layout.format == format::bs_fs_zyx_bsv16_fsv16)
expected_format = cldnn::format::bfzyx;
else
expected_format = cldnn::format::bfyx;
}
} else {
expected_tensor = current_layout.size;
expected_format = cldnn::format::yxfb;
}
} }
return layout(expected_data_type, expected_format, expected_tensor); return layout(expected_data_type, expected_format, expected_tensor);
@ -828,8 +903,36 @@ layout layout_optimizer::get_expected_layout(layout const& current_layout,
auto expected_tensor = current_layout.size; auto expected_tensor = current_layout.size;
auto expected_data_type = current_layout.data_type; auto expected_data_type = current_layout.data_type;
auto expected_format = current_layout.format; auto expected_format = current_layout.format;
auto input_layout = node.get_dependency(0).get_output_layout();
auto spatial_dims_num = input_layout.format.spatial_num();
auto is_2d = spatial_dims_num == 2;
bool use_onednn_impls = _optimization_attributes.use_onednn_impls;
if (_optimization_attributes.b_fs_zyx_fsv16_network && bool onednn_valid_dt = input_layout.data_type == data_types::i8 ||
input_layout.data_type == data_types::u8 ||
input_layout.data_type == data_types::f16;
bool onednn_valid_params = onednn_valid_dt &&
input_layout.size.feature[0] >= 16 &&
prim->groups == 1 &&
get_post_ops_count(node) <= 32 &&
input_layout.size.batch[0] < 16; // oneDNNs optimized kernel doesn't support big batches yet
if (use_onednn_impls && onednn_valid_params && spatial_dims_num <= 3) {
if (input_layout.data_type == data_types::f16) {
if (input_layout.size.batch[0] < 16) {
expected_format = is_2d ? cldnn::format::b_fs_yx_fsv16 : cldnn::format::b_fs_zyx_fsv16;
} else {
expected_format = is_2d ? cldnn::format::bs_fs_yx_bsv32_fsv16 : cldnn::format::bs_fs_zyx_bsv32_fsv16;
}
} else {
if (input_layout.size.batch[0] < 16) {
expected_format = is_2d ? cldnn::format::b_fs_yx_fsv32 : cldnn::format::b_fs_zyx_fsv32;
} else {
expected_format = is_2d ? cldnn::format::bs_fs_yx_bsv32_fsv32 : cldnn::format::bs_fs_zyx_bsv32_fsv32;
}
}
} else if (_optimization_attributes.b_fs_zyx_fsv16_network &&
deconvolution_b_fs_zyx_fsv16_opt(current_layout, output_or_weights_layout, prim)) { deconvolution_b_fs_zyx_fsv16_opt(current_layout, output_or_weights_layout, prim)) {
expected_tensor = current_layout.size; expected_tensor = current_layout.size;
if ((current_layout.data_type == data_types::f32 && expected_tensor.batch[0] % 16 == 0) || if ((current_layout.data_type == data_types::f32 && expected_tensor.batch[0] % 16 == 0) ||
@ -875,6 +978,41 @@ layout layout_optimizer::get_expected_layout(layout const& current_layout,
return layout(expected_data_type, expected_format, expected_tensor); return layout(expected_data_type, expected_format, expected_tensor);
} }
bool layout_optimizer::are_data_types_suitable_for_onednn(program_node& node) {
auto in_dt = node.get_dependency(0).get_output_layout().data_type;
auto out_dt = node.get_output_layout().data_type;
if (in_dt == data_types::f32)
return false;
if (node.is_type<pooling>()) {
if (!data_type_traits::is_floating_point(in_dt) && in_dt != out_dt)
return false;
if ((in_dt == data_types::i8 || in_dt == data_types::u8) && out_dt != data_types::f32)
return true;
if (in_dt == data_types::f16 || out_dt == data_types::f16)
return true;
if (out_dt == data_types::f32)
return true;
if (in_dt == data_types::i32 || out_dt == data_types::i32)
return true;
if ((in_dt == data_types::i8 || out_dt == data_types::i8) || (in_dt == data_types::u8 || out_dt == data_types::u8))
return true;
} else if (node.is_type<convolution>() || node.is_type<deconvolution>()) {
bool is_conv = node.is_type<convolution>();
auto wei_dt = is_conv ? node.as<convolution>().weights().get_output_layout().data_type :
node.as<deconvolution>().weights().get_output_layout().data_type;
if ((in_dt == data_types::f16 && wei_dt == data_types::f16) && (out_dt == data_types::f16 || out_dt == data_types::f32 || out_dt == data_types::i8))
return true;
if ((in_dt == data_types::i8 || in_dt == data_types::u8) && wei_dt == data_types::i8 &&
(out_dt == data_types::f32 || out_dt == data_types::i32 || out_dt == data_types::i8 || out_dt == data_types::u8))
return true;
}
return false;
}
impl_types layout_optimizer::get_preferred_impl_type(program_node& node, format preferred_format) { impl_types layout_optimizer::get_preferred_impl_type(program_node& node, format preferred_format) {
impl_types preferred_impl = impl_types::any; impl_types preferred_impl = impl_types::any;
if (!_forcing_map.empty() && _forcing_map.count(node.id()) != 0) { if (!_forcing_map.empty() && _forcing_map.count(node.id()) != 0) {
@ -938,7 +1076,7 @@ impl_types layout_optimizer::get_preferred_impl_type(program_node& node, format
if (input_layout.format.dimension() != output_layout.format.dimension()) { if (input_layout.format.dimension() != output_layout.format.dimension()) {
preferred_impl = impl_types::ocl; preferred_impl = impl_types::ocl;
} }
} else if (node.is_type<pooling>()) { } else if (node.is_type<pooling>() || node.is_type<convolution>() || node.is_type<deconvolution>()) {
if (!_optimization_attributes.use_onednn_impls) if (!_optimization_attributes.use_onednn_impls)
return impl_types::ocl; return impl_types::ocl;
@ -964,79 +1102,79 @@ impl_types layout_optimizer::get_preferred_impl_type(program_node& node, format
impl_candidate = impl_types::ocl; impl_candidate = impl_types::ocl;
} }
// if (node.is_type<convolution>()) { if (node.is_type<convolution>()) {
// // onednn doesn't have good support for groups with fsv16 fmt // oneDNN doesn't have good support for groups with fsv16 fmt
// auto& conv = node.as<convolution>(); auto& conv = node.as<convolution>();
// auto input_layout = conv.input().get_output_layout(); auto input_layout = conv.input().get_output_layout();
// bool fp16_input = input_layout.data_type == data_types::f16; bool fp16_input = input_layout.data_type == data_types::f16;
// bool has_groups = conv.get_primitive()->groups > 1; bool has_groups = conv.get_primitive()->groups > 1;
// bool is_depthwise = conv.get_primitive()->groups == input_layout.size.feature[0]; bool is_depthwise = conv.get_primitive()->groups == input_layout.size.feature[0];
// bool first_conv = input_layout.size.feature[0] <= 4; bool first_conv = input_layout.size.feature[0] <= 4;
// bool enable_onednn_dw_fp16_conv = fp16_input && is_depthwise; bool enable_onednn_dw_fp16_conv = fp16_input && is_depthwise;
// if (((has_groups && !enable_onednn_dw_fp16_conv) || first_conv) && if (((has_groups && !enable_onednn_dw_fp16_conv) || first_conv) &&
// (conv.get_output_layout().format == format::b_fs_yx_fsv16)) { (conv.get_output_layout().format == format::b_fs_yx_fsv16)) {
// impl_candidate = impl_types::ocl; impl_candidate = impl_types::ocl;
// } }
// } }
// if (node.is_type<deconvolution>()) { if (node.is_type<deconvolution>()) {
// auto& deconv = node.as<deconvolution>(); auto& deconv = node.as<deconvolution>();
// auto input_layout = deconv.input().get_output_layout(); auto input_layout = deconv.input().get_output_layout();
// bool valid_ic = input_layout.size.feature[0] >= 16; bool valid_ic = input_layout.size.feature[0] >= 16;
// bool valid_groups = deconv.get_primitive()->groups == 1; bool valid_groups = deconv.get_primitive()->groups == 1;
// bool valid_post_ops = get_post_ops_count(node) <= 10; bool onednn_valid_post_ops = get_post_ops_count(node) <= 32;
// bool valid_batch = input_layout.size.batch[0] < 16; // oneDNNs optimized kernel doesn't support big batches yet bool valid_batch = input_layout.size.batch[0] < 16; // oneDNN's optimized kernel doesn't support big batches yet
// bool valid_params = valid_ic && valid_groups && valid_post_ops && valid_batch; bool valid_params = valid_ic && valid_groups && onednn_valid_post_ops && valid_batch;
// if (!valid_params) if (!valid_params)
// impl_candidate = impl_types::ocl; impl_candidate = impl_types::ocl;
// } }
// // [WA] onednn doesn't support > 32 post-ops. Remove once onednn improve post-ops for GPU. // [WA] oneDNN doesn't support > 32 post-ops. Remove once oneDNN improve post-ops for GPU.
// if (get_post_ops_count(node) > 32) { if (get_post_ops_count(node) > 32) {
// impl_candidate = impl_types::ocl; impl_candidate = impl_types::ocl;
// } }
// if (!data_types_are_suitable_for_onednn(node)) { if (!are_data_types_suitable_for_onednn(node)) {
// impl_candidate = impl_types::ocl; impl_candidate = impl_types::ocl;
// } }
// for (auto& fo : node.get_fused_primitives()) { for (auto& fo : node.get_fused_primitives()) {
// if (fo.node->is_type<eltwise>()) { if (fo.node->is_type<eltwise>()) {
// auto in_layout = node.get_dependency(fo.dep_start_idx).get_output_layout(); auto in_layout = node.get_dependency(fo.dep_start_idx).get_output_layout();
// auto out_layout = node.get_output_layout(); auto out_layout = node.get_output_layout();
// auto in_dt = in_layout.data_type; auto in_dt = in_layout.data_type;
// auto out_dt = out_layout.data_type; auto out_dt = out_layout.data_type;
// if ((out_layout.count() == in_layout.count()) && if ((out_layout.count() == in_layout.count()) &&
// (data_type_traits::is_floating_point(in_dt) || data_type_traits::is_floating_point(out_dt)) && in_dt != out_dt) { (data_type_traits::is_floating_point(in_dt) || data_type_traits::is_floating_point(out_dt)) && in_dt != out_dt) {
// impl_candidate = impl_types::ocl; impl_candidate = impl_types::ocl;
// break; break;
// } }
// } else if (fo.node->is_type<activation>()) { } else if (fo.node->is_type<activation>()) {
// // Some activations aren't implemented in oneDNN // Some activations aren't implemented in oneDNN
// auto activation_prim = fo.node->as<activation>().get_primitive(); auto activation_prim = fo.node->as<activation>().get_primitive();
// if (activation_prim->activation_function == activation_func::negative || if (activation_prim->activation_function == activation_func::negative ||
// activation_prim->activation_function == activation_func::negation || activation_prim->activation_function == activation_func::negation ||
// activation_prim->activation_function == activation_func::sign) activation_prim->activation_function == activation_func::sign)
// impl_candidate = impl_types::ocl; impl_candidate = impl_types::ocl;
// } }
// } }
// // oneDNN doesn't support asymmetric weights quantization // oneDNN doesn't support asymmetric weights quantization
// if (node.is_type<convolution>() && node.as<convolution>().weights_zero_points_term()) if (node.is_type<convolution>() && node.as<convolution>().weights_zero_points_term())
// impl_candidate = impl_types::ocl; impl_candidate = impl_types::ocl;
// // OneDNN doesn't support sum post ops for deconvolutions // oneDNN doesn't support sum post ops for deconvolutions
// if (node.is_type<deconvolution>() && impl_candidate == impl_types::onednn) { if (node.is_type<deconvolution>() && impl_candidate == impl_types::onednn) {
// for (auto& fused_op : node.get_fused_primitives()) { for (auto& fused_op : node.get_fused_primitives()) {
// if (fused_op.node->is_type<eltwise>() && fused_op.deps.size() == 1) { if (fused_op.node->is_type<eltwise>() && fused_op.deps.size() == 1) {
// auto eltw_in_layout = node.get_dependency(fused_op.dep_start_idx).get_output_layout(); auto eltw_in_layout = node.get_dependency(fused_op.dep_start_idx).get_output_layout();
// if (fused_op.node->as<eltwise>().get_primitive()->needs_onednn_sum_post_op(eltw_in_layout)) { if (fused_op.node->as<eltwise>().get_primitive()->needs_onednn_sum_post_op(eltw_in_layout)) {
// impl_candidate = impl_types::ocl; impl_candidate = impl_types::ocl;
// break; break;
// } }
// } }
// } }
// } }
preferred_impl = impl_candidate; preferred_impl = impl_candidate;
} else if (node.is_type<concatenation>()) { } else if (node.is_type<concatenation>()) {

View File

@ -2838,3 +2838,56 @@ INSTANTIATE_TEST_SUITE_P(DISABLED_extended, deconvolution_random_test, testing::
.add_all_3d(data_types::i8, data_types::i8, data_types::f32, format::bs_fs_zyx_bsv16_fsv16, format::bs_fs_zyx_bsv16_fsv16) .add_all_3d(data_types::i8, data_types::i8, data_types::f32, format::bs_fs_zyx_bsv16_fsv16, format::bs_fs_zyx_bsv16_fsv16)
.add_all_3d(data_types::u8, data_types::i8, data_types::f32, format::bs_fs_zyx_bsv16_fsv16, format::bs_fs_zyx_bsv16_fsv16) .add_all_3d(data_types::u8, data_types::i8, data_types::f32, format::bs_fs_zyx_bsv16_fsv16, format::bs_fs_zyx_bsv16_fsv16)
), deconvolution_random_test_params::print_params); ), deconvolution_random_test_params::print_params);
#ifdef ENABLE_ONEDNN_FOR_GPU
TEST(deconvolution_f32_fw_gpu_onednn, basic_wsiz2x2_in2x2x1x1_stride2_nopad) {
// Filter : 1x1
// Input : 2x2
// Output : 4x4
// Stride : 2x2
auto& engine = get_onednn_test_engine();
auto input = engine.allocate_memory({ data_types::f32, format::yxfb, { 1, 1, 2, 2 } });
auto weights = engine.allocate_memory({ data_types::f32, format::oiyx, { 1, 1, 2, 2 } });
auto biases = engine.allocate_memory({ data_types::f32, format::bfyx, { 1, 1, 1, 1 } });
set_values(input, { 8.f, 0.5f, 6.f, 9.f });
set_values(weights, { -2.0f, 0.5f, 3.5f, 1.5f });
set_values(biases, { 1.0f });
topology topology(
input_layout("input", input->get_layout()),
data("weights", weights),
data("biases", biases),
deconvolution("deconv", "input", { "weights" }, { "biases" }, { 1,1,2,2 })
);
build_options bo;
implementation_desc conv_impl = { format::yxfb, "", impl_types::onednn };
bo.set_option(build_option::force_implementations({ {"deconv", conv_impl} }));
network network(engine, topology, bo);
network.set_input_data("input", input);
auto outputs = network.execute();
EXPECT_EQ(outputs.size(), size_t(1));
EXPECT_EQ(outputs.begin()->first, "deconv");
auto output_prim = outputs.begin()->second.get_memory();
cldnn::mem_lock<float> output_ptr (output_prim, get_test_stream());
std::vector<float> expected_output_vec = {
-15.f, 5.f, 0.f, 1.25f,
29.f, 13.f, 2.75f, 1.75,
-11.f, 4.f, -17.f, 5.5f,
22.f, 10.f, 32.5f, 14.5f
};
for (unsigned int i = 0; i < expected_output_vec.size(); i++)
{
EXPECT_FLOAT_EQ(expected_output_vec[i], output_ptr[i]);
}
}
#endif

View File

@ -27,6 +27,7 @@
#include <cldnn/primitives/batch_to_space.hpp> #include <cldnn/primitives/batch_to_space.hpp>
#include <cldnn/primitives/space_to_batch.hpp> #include <cldnn/primitives/space_to_batch.hpp>
#include <cldnn/primitives/reduce.hpp> #include <cldnn/primitives/reduce.hpp>
#include <cldnn/primitives/crop.hpp>
#include <cmath> #include <cmath>
@ -130,7 +131,11 @@ struct normalize_test_params {
template<typename T> template<typename T>
class BaseFusingTest : public ::testing::TestWithParam<T> { class BaseFusingTest : public ::testing::TestWithParam<T> {
public: public:
#ifdef ENABLE_ONEDNN_FOR_GPU
cldnn::engine& engine = get_onednn_test_engine();
#else
cldnn::engine& engine = get_test_engine(); cldnn::engine& engine = get_test_engine();
#endif
cldnn::topology topology_fused; cldnn::topology topology_fused;
cldnn::topology topology_non_fused; cldnn::topology topology_non_fused;
cldnn::build_options bo_fused; cldnn::build_options bo_fused;
@ -234,6 +239,12 @@ public:
} else if (l.data_type == data_types::f32) { } else if (l.data_type == data_types::f32) {
VF<float> rnd_vec(s.count(), fill_value); VF<float> rnd_vec(s.count(), fill_value);
set_values(prim, rnd_vec); set_values(prim, rnd_vec);
} else if (l.data_type == data_types::u8) {
VF<uint8_t> rnd_vec(s.count(), static_cast<uint8_t>(fill_value));
set_values(prim, rnd_vec);
} else if (l.data_type == data_types::i8) {
VF<int8_t> rnd_vec(s.count(), static_cast<int8_t>(fill_value));
set_values(prim, rnd_vec);
} else { } else {
throw std::runtime_error("get_mem: Unsupported precision"); throw std::runtime_error("get_mem: Unsupported precision");
} }
@ -274,8 +285,10 @@ public:
} else if (l.data_type == data_types::i8) { } else if (l.data_type == data_types::i8) {
VF<int8_t> rnd_vec = generate_random_1d<int8_t>(s.count(), min, max); VF<int8_t> rnd_vec = generate_random_1d<int8_t>(s.count(), min, max);
set_values(prim, rnd_vec); set_values(prim, rnd_vec);
} } else if (l.data_type == data_types::u8) {
else if (l.data_type == data_types::bin) { VF<uint8_t> rnd_vec = generate_random_1d<uint8_t>(s.count(), min, max);
set_values(prim, rnd_vec);
} else if (l.data_type == data_types::bin) {
VF<int32_t> rnd_vec = generate_random_1d<int32_t>(s.count() / 32, min, max); VF<int32_t> rnd_vec = generate_random_1d<int32_t>(s.count() / 32, min, max);
set_values(prim, rnd_vec); set_values(prim, rnd_vec);
} }
@ -528,6 +541,10 @@ public:
#define CASE_CONV_U8S8_8 {1, 3, 4, 5}, {1, 32, 4, 5}, {1, 1, 3, 3}, tensor{1}, tensor{0, 0, -1, -1, 0, 0}, tensor{1}, 1, data_types::u8, format::bfyx, data_types::i8, format::bfyx, data_types::f32, format::bfyx #define CASE_CONV_U8S8_8 {1, 3, 4, 5}, {1, 32, 4, 5}, {1, 1, 3, 3}, tensor{1}, tensor{0, 0, -1, -1, 0, 0}, tensor{1}, 1, data_types::u8, format::bfyx, data_types::i8, format::bfyx, data_types::f32, format::bfyx
#define CASE_CONV_U8S8_9 {16, 32, 5, 5}, {16, 32, 3, 3}, {1, 1, 1, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::u8, format::bs_fs_yx_bsv16_fsv16, data_types::i8, format::os_is_yx_osv16_isv16, data_types::f32, format::bfyx #define CASE_CONV_U8S8_9 {16, 32, 5, 5}, {16, 32, 3, 3}, {1, 1, 1, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::u8, format::bs_fs_yx_bsv16_fsv16, data_types::i8, format::os_is_yx_osv16_isv16, data_types::f32, format::bfyx
#define CASE_CONV_U8S8_10 {16, 32, 5, 5}, {16, 32, 3, 3}, {1, 1, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::u8, format::bs_fs_yx_bsv16_fsv16, data_types::i8, format::os_is_yx_osv16_isv16, data_types::f32, format::bfyx #define CASE_CONV_U8S8_10 {16, 32, 5, 5}, {16, 32, 3, 3}, {1, 1, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::u8, format::bs_fs_yx_bsv16_fsv16, data_types::i8, format::os_is_yx_osv16_isv16, data_types::f32, format::bfyx
#define CASE_CONV_U8S8_11 {32, 15, 4, 5}, {32, 30, 2, 3}, {1, 1, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::u8, format::bfyx, data_types::i8, format::bfyx, data_types::f32, format::bfyx
#define CASE_CONV_U8S8_12 {32, 15, 5, 5}, {32, 30, 3, 3}, {1, 1, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::u8, format::bfyx, data_types::i8, format::bfyx, data_types::f32, format::bfyx
#define CASE_CONV_U8S8_13 {32, 16, 4, 5}, {32, 32, 4, 5}, {1, 1, 1, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::u8, format::bfyx, data_types::i8, format::bfyx, data_types::f32, format::bfyx
#define CASE_CONV_U8S8_14 {32, 17, 4, 5}, {32, 17, 4, 5}, {1, 1, 3, 3}, tensor{1}, tensor{0, 0, -1, -1, 0, 0}, tensor{1}, 17, data_types::u8, format::bfyx, data_types::i8, format::goiyx, data_types::f32, format::bfyx
#define CASE_CONV_S8S8_1 {1, 15, 4, 5}, {1, 30, 2, 3}, {1, 1, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::i8, format::bfyx, data_types::i8, format::bfyx, data_types::f32, format::bfyx #define CASE_CONV_S8S8_1 {1, 15, 4, 5}, {1, 30, 2, 3}, {1, 1, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::i8, format::bfyx, data_types::i8, format::bfyx, data_types::f32, format::bfyx
#define CASE_CONV_S8S8_2 {1, 15, 5, 5}, {1, 30, 3, 3}, {1, 1, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::i8, format::bfyx, data_types::i8, format::bfyx, data_types::f32, format::bfyx #define CASE_CONV_S8S8_2 {1, 15, 5, 5}, {1, 30, 3, 3}, {1, 1, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::i8, format::bfyx, data_types::i8, format::bfyx, data_types::f32, format::bfyx
@ -540,6 +557,10 @@ public:
#define CASE_CONV_S8S8_9 {16, 32, 5, 5}, {16, 32, 3, 3}, {1, 1, 1, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::i8, format::bs_fs_yx_bsv16_fsv16, data_types::i8, format::os_is_yx_osv16_isv16, data_types::f32, format::bfyx #define CASE_CONV_S8S8_9 {16, 32, 5, 5}, {16, 32, 3, 3}, {1, 1, 1, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::i8, format::bs_fs_yx_bsv16_fsv16, data_types::i8, format::os_is_yx_osv16_isv16, data_types::f32, format::bfyx
#define CASE_CONV_S8S8_10 {16, 32, 5, 5}, {16, 32, 3, 3}, {1, 1, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::i8, format::bs_fs_yx_bsv16_fsv16, data_types::i8, format::os_is_yx_osv16_isv16, data_types::f32, format::bfyx #define CASE_CONV_S8S8_10 {16, 32, 5, 5}, {16, 32, 3, 3}, {1, 1, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::i8, format::bs_fs_yx_bsv16_fsv16, data_types::i8, format::os_is_yx_osv16_isv16, data_types::f32, format::bfyx
#define CASE_CONV_S8S8_11 {1, 4, 1280, 720}, {1, 4, 1280, 720}, {1, 1, 5, 5}, tensor{1}, tensor{0, 0, -2, -2}, tensor{1}, 1, data_types::i8, format::b_fs_yx_fsv4, data_types::i8, format::os_is_yx_osv16_isv4, data_types::f32, format::bfyx #define CASE_CONV_S8S8_11 {1, 4, 1280, 720}, {1, 4, 1280, 720}, {1, 1, 5, 5}, tensor{1}, tensor{0, 0, -2, -2}, tensor{1}, 1, data_types::i8, format::b_fs_yx_fsv4, data_types::i8, format::os_is_yx_osv16_isv4, data_types::f32, format::bfyx
#define CASE_CONV_S8S8_12 {32, 15, 4, 5}, {32, 30, 2, 3}, {1, 1, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::i8, format::bfyx, data_types::i8, format::bfyx, data_types::f32, format::bfyx
#define CASE_CONV_S8S8_13 {32, 15, 5, 5}, {32, 30, 3, 3}, {1, 1, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::i8, format::bfyx, data_types::i8, format::bfyx, data_types::f32, format::bfyx
#define CASE_CONV_S8S8_14 {32, 16, 4, 5}, {32, 32, 4, 5}, {1, 1, 1, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::i8, format::bfyx, data_types::i8, format::bfyx, data_types::f32, format::bfyx
#define CASE_CONV_S8S8_15 {32, 17, 4, 5}, {32, 17, 4, 5}, {1, 1, 3, 3}, tensor{1}, tensor{0, 0, -1, -1, 0, 0}, tensor{1}, 17, data_types::i8, format::bfyx, data_types::i8, format::goiyx, data_types::f32, format::bfyx
#define CASE_CONV3D_U8S8_1 {1, 15, 5, 4, 5}, {1, 30, 3, 2, 3}, {1, 1, 3, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::u8, format::bfzyx, data_types::i8, format::bfzyx, data_types::f32, format::bfzyx #define CASE_CONV3D_U8S8_1 {1, 15, 5, 4, 5}, {1, 30, 3, 2, 3}, {1, 1, 3, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::u8, format::bfzyx, data_types::i8, format::bfzyx, data_types::f32, format::bfzyx
#define CASE_CONV3D_U8S8_2 {1, 15, 5, 5, 5}, {1, 30, 3, 3, 3}, {1, 1, 3, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::u8, format::bfzyx, data_types::i8, format::bfzyx, data_types::f32, format::bfzyx #define CASE_CONV3D_U8S8_2 {1, 15, 5, 5, 5}, {1, 30, 3, 3, 3}, {1, 1, 3, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::u8, format::bfzyx, data_types::i8, format::bfzyx, data_types::f32, format::bfzyx
@ -771,6 +792,29 @@ INSTANTIATE_TEST_SUITE_P(fusings_gpu, conv_fp32_bias,
bc_test_params{CASE_CONV_FP16_10, 2, 3}, bc_test_params{CASE_CONV_FP16_10, 2, 3},
})); }));
class conv_fp32_double_bias : public ConvFusingTest {};
TEST_P(conv_fp32_double_bias, basic) {
auto p = GetParam();
create_topologies(input_layout("input", get_input_layout(p)),
data("weights", get_mem(get_weights_layout(p))),
data("bias1", get_mem(get_bias_layout(p))),
data("bias2", get_mem(get_bias_layout(p))),
convolution("conv_prim", "input", {"weights"}, std::vector<primitive_id>{}, p.groups, p.stride, p.pad, p.dilation),
eltwise("add_bias1", {"conv_prim", "bias1"}, eltwise_mode::sum),
eltwise("add_bias2", {"add_bias1", "bias2"}, eltwise_mode::sum),
reorder("reorder_bfyx", "add_bias2", p.default_format, data_types::f32)
);
tolerance = 1e-5f;
execute(p);
}
INSTANTIATE_TEST_SUITE_P(fusings_gpu, conv_fp32_double_bias,
::testing::ValuesIn(std::vector<bc_test_params>{
bc_test_params{CASE_CONV_U8S8_1, 2, 4},
bc_test_params{CASE_CONV_S8S8_1, 2, 4},
}));
class conv_fp32_prelu_eltwise : public ConvFusingTest {}; class conv_fp32_prelu_eltwise : public ConvFusingTest {};
TEST_P(conv_fp32_prelu_eltwise, basic_sum) { TEST_P(conv_fp32_prelu_eltwise, basic_sum) {
auto p = GetParam(); auto p = GetParam();
@ -4864,7 +4908,7 @@ TEST_P(deconv_scale_actv_quant_u8_eltw_scale_actv_quant_i8, basic) {
reorder("out", "quant2", p.default_format, data_types::f32) reorder("out", "quant2", p.default_format, data_types::f32)
); );
tolerance = 1.0f; tolerance = 2.1f;
execute(p); execute(p);
} }
@ -8795,3 +8839,272 @@ INSTANTIATE_TEST_SUITE_P(fusings_gpu, gather_elements_activation_scale_eltwise,
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_2, 2, 5 }, gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_2, 2, 5 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_3, 2, 5 }, gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_3, 2, 5 },
})); }));
#ifdef ENABLE_ONEDNN_FOR_GPU
class ConvFusingTestOneDNN : public WeightsPrimitiveFusingTest<bc_test_params> {
public:
void execute(bc_test_params& p) {
// Onednn post operation has issue in a machine that does not support imad.
if (!engine.get_device_info().supports_imad)
return;
auto input_prim = p.data_type == data_types::u8 ? get_mem(get_input_layout(p), 0, 10) : get_mem(get_input_layout(p));
auto impl_forcing_bo = bo_fused.get<build_option_type::force_implementations>();
const auto& impl_forcing = impl_forcing_bo->forcing;
auto forcing_format = p.input_format;
for (auto& forcing : impl_forcing) {
if (forcing.first == "conv_prim") {
forcing_format = forcing.second.output_format;
}
}
implementation_desc conv_impl = { forcing_format, "", impl_types::onednn };
bo_fused.set_option(build_option::force_implementations({ {"conv_prim", conv_impl} }));
network network_not_fused(this->engine, this->topology_non_fused, bo_not_fused);
network network_fused(this->engine, this->topology_fused, bo_fused);
network_fused.set_input_data("input", input_prim);
network_not_fused.set_input_data("input", input_prim);
compare(network_not_fused, network_fused, p);
auto find_conv = [](primitive_info& p) -> bool {
if (p.original_id == "conv_prim")
return true;
return false;
};
auto pi_fused = network_fused.get_primitives_info();
auto info_fused = std::find_if(pi_fused.begin(), pi_fused.end(), find_conv);
if (info_fused != pi_fused.end())
std::cout << "kernel: " << info_fused->kernel_id << std::endl;
}
};
class conv_int8_eltwise_onednn : public ConvFusingTestOneDNN {};
TEST_P(conv_int8_eltwise_onednn, u8_eltwise_sum_out) {
auto p = GetParam();
auto shift_layout = get_output_layout(p);
shift_layout.data_type = data_types::f32;
create_topologies(input_layout("input", get_input_layout(p)),
data("weights", get_mem(get_weights_layout(p), 0, 2)),
data("bias", get_mem(get_bias_layout(p))),
data("shift_data", get_mem(shift_layout)),
convolution("conv_prim", "input", {"weights"}, {"bias"}, p.groups, p.stride, p.pad, p.dilation),
eltwise("shift", {"conv_prim", "shift_data"}, eltwise_mode::sum, data_types::f32),
// Add 'not fusable' primitive to be able to test full size tensor sum
crop("crop", "shift", get_output_layout(p).size, {0, 0, 0, 0}),
reorder("reorder_bfyx", "crop", p.default_format, data_types::f32)
);
tolerance = 1.f;
execute(p);
}
TEST_P(conv_int8_eltwise_onednn, u8_eltwise_prod_out) {
auto p = GetParam();
create_topologies(input_layout("input", get_input_layout(p)),
data("weights", get_mem(get_weights_layout(p), -2, 2)),
data("bias", get_mem(get_bias_layout(p))),
data("scale_data", get_mem(get_per_channel_layout(p), 1.0f/p.kernel.count()) ),
convolution("conv_prim", "input", {"weights"}, {"bias"}, p.groups, p.stride, p.pad, p.dilation),
eltwise("scale", {"conv_prim", "scale_data"}, eltwise_mode::prod, data_types::u8),
crop("crop", "scale", get_output_layout(p).size, {0, 0, 0, 0}),
reorder("reorder_bfyx", "crop", p.default_format, data_types::f32)
);
tolerance = 1.f;
execute(p);
}
INSTANTIATE_TEST_SUITE_P(fusings_gpu, conv_int8_eltwise_onednn,
::testing::ValuesIn(std::vector<bc_test_params>{
bc_test_params{CASE_CONV_U8S8_1, 2, 4},
bc_test_params{CASE_CONV_U8S8_2, 2, 4},
bc_test_params{CASE_CONV_U8S8_3, 2, 4},
bc_test_params{CASE_CONV_S8S8_1, 2, 4},
bc_test_params{CASE_CONV_S8S8_2, 2, 4},
bc_test_params{CASE_CONV_S8S8_3, 2, 4},
bc_test_params{CASE_CONV_U8S8_11, 2, 4},
bc_test_params{CASE_CONV_U8S8_12, 2, 4},
bc_test_params{CASE_CONV_U8S8_13, 2, 4},
bc_test_params{CASE_CONV_S8S8_12, 2, 4},
bc_test_params{CASE_CONV_S8S8_13, 2, 4},
bc_test_params{CASE_CONV_S8S8_14, 2, 4},
bc_test_params{CASE_CONV3D_U8S8_1, 3, 4},
bc_test_params{CASE_CONV3D_U8S8_2, 3, 4},
bc_test_params{CASE_CONV3D_U8S8_3, 3, 4},
bc_test_params{CASE_CONV3D_U8S8_5, 3, 4},
bc_test_params{CASE_CONV3D_S8S8_1, 3, 4},
bc_test_params{CASE_CONV3D_S8S8_2, 3, 4},
bc_test_params{CASE_CONV3D_S8S8_3, 3, 4},
bc_test_params{CASE_CONV3D_S8S8_5, 3, 4},
}));
class conv_fp32_activation_onednn : public ConvFusingTestOneDNN {};
TEST_P(conv_fp32_activation_onednn, basic) {
auto p = GetParam();
create_topologies(input_layout("input", get_input_layout(p)),
data("weights", get_mem(get_weights_layout(p))),
data("bias", get_mem(get_bias_layout(p))),
convolution("conv_prim", "input", {"weights"}, {"bias"}, p.groups, p.stride, p.pad, p.dilation),
activation("activation", "conv_prim", activation_func::abs),
reorder("reorder_bfyx", "activation", p.default_format, data_types::f32)
);
tolerance = 1e-2f;
execute(p);
}
INSTANTIATE_TEST_SUITE_P(fusings_gpu, conv_fp32_activation_onednn,
::testing::ValuesIn(std::vector<bc_test_params>{
bc_test_params{CASE_CONV_FP16_1, 2, 3},
bc_test_params{CASE_CONV_FP16_2, 2, 3},
bc_test_params{CASE_CONV_FP16_3, 2, 3},
bc_test_params{CASE_CONV_FP16_4, 2, 3},
}));
class conv_int8_quantize_u8_onednn : public ConvFusingTestOneDNN {};
TEST_P(conv_int8_quantize_u8_onednn, per_channel) {
auto p = GetParam();
create_topologies(input_layout("input", get_input_layout(p)),
data("weights", get_mem(get_weights_layout(p), -2, 2)),
data("bias", get_mem(get_bias_layout(p))),
data("in_lo", get_mem(get_per_channel_layout(p), -10, 0)),
data("in_hi", get_mem(get_per_channel_layout(p), 0, 10)),
data("out_lo", get_mem(get_single_element_layout(p), 0)),
data("out_hi", get_mem(get_single_element_layout(p), 255)),
convolution("conv_prim", "input", {"weights"}, {"bias"}, p.groups, p.stride, p.pad, p.dilation),
quantize("quantize", "conv_prim", "in_lo", "in_hi", "out_lo", "out_hi", 256, data_types::u8),
reorder("reorder_bfyx", "quantize", p.default_format, data_types::f32)
);
tolerance = 1.0f;
execute(p);
}
TEST_P(conv_int8_quantize_u8_onednn, per_tensor) {
auto p = GetParam();
create_topologies(input_layout("input", get_input_layout(p)),
data("weights", get_mem(get_weights_layout(p), -2, 2)),
data("bias", get_mem(get_bias_layout(p), 0)),
data("in_lo", get_mem(get_single_element_layout(p), -10)),
data("in_hi", get_mem(get_single_element_layout(p), 10)),
data("out_lo", get_mem(get_single_element_layout(p), 0)),
data("out_hi", get_mem(get_single_element_layout(p), 255)),
convolution("conv_prim", "input", {"weights"}, {"bias"}, p.groups, p.stride, p.pad, p.dilation),
quantize("quantize", "conv_prim", "in_lo", "in_hi", "out_lo", "out_hi", 256, data_types::u8),
reorder("reorder_bfyx", "quantize", p.default_format, data_types::f32)
);
tolerance = 1.0f;
execute(p);
}
INSTANTIATE_TEST_SUITE_P(fusings_gpu, conv_int8_quantize_u8_onednn,
::testing::ValuesIn(std::vector<bc_test_params>{
bc_test_params{CASE_CONV_U8S8_1, 2, 3},
bc_test_params{CASE_CONV_U8S8_2, 2, 3},
bc_test_params{CASE_CONV_U8S8_3, 2, 3},
bc_test_params{CASE_CONV_S8S8_1, 2, 3},
bc_test_params{CASE_CONV_S8S8_2, 2, 3},
bc_test_params{CASE_CONV_S8S8_3, 2, 3},
}));
class conv_int8_activation_eltwise_quantize_onednn : public ConvFusingTestOneDNN {};
TEST_P(conv_int8_activation_eltwise_quantize_onednn, bsv32_fsv32) {
auto p = GetParam();
layout eltwise_layout = get_output_layout(p);
eltwise_layout.format = format::bs_fs_yx_bsv32_fsv32;
create_topologies(input_layout("input", get_input_layout(p)),
data("weights", get_mem(get_weights_layout(p), -1, 1)),
data("bias", get_mem(get_bias_layout(p))),
data("eltwise_data", get_mem(eltwise_layout, -0.5, 0.5)),
data("in_lo", get_mem(get_per_channel_layout(p), min_random, 0)),
data("in_hi", get_mem(get_per_channel_layout(p), 1, max_random)),
data("out_lo", get_mem(get_single_element_layout(p), -127)),
data("out_hi", get_mem(get_single_element_layout(p), 127)),
convolution("conv_prim", "input", { "weights" }, { "bias" }, p.groups, p.stride, p.pad, p.dilation),
activation("activation", "conv_prim", activation_func::abs),
eltwise("eltwise", "activation", "eltwise_data", eltwise_mode::sum),
quantize("quantize", "eltwise", "in_lo", "in_hi", "out_lo", "out_hi", 255, data_types::i8),
reorder("reorder_bfyx", "quantize", p.default_format, data_types::f32)
);
implementation_desc conv_impl = { format::bs_fs_yx_bsv32_fsv32, "", impl_types::onednn };
bo_fused.set_option(build_option::force_implementations({ {"conv_prim", conv_impl} }));
tolerance = 1.f;
execute(p);
}
INSTANTIATE_TEST_SUITE_P(fusings_gpu, conv_int8_activation_eltwise_quantize_onednn,
::testing::ValuesIn(std::vector<bc_test_params>{
bc_test_params{CASE_CONV_U8S8_1, 2, 5},
bc_test_params{CASE_CONV_U8S8_2, 2, 5},
bc_test_params{CASE_CONV_U8S8_3, 2, 5},
bc_test_params{CASE_CONV_U8S8_4, 2, 5},
bc_test_params{CASE_CONV_U8S8_7, 2, 5},
bc_test_params{CASE_CONV_U8S8_8, 2, 5},
bc_test_params{CASE_CONV_U8S8_11, 2, 5},
bc_test_params{CASE_CONV_U8S8_12, 2, 5},
bc_test_params{CASE_CONV_U8S8_13, 2, 5},
bc_test_params{CASE_CONV_U8S8_14, 2, 5},
bc_test_params{CASE_CONV_S8S8_1, 2, 5},
bc_test_params{CASE_CONV_S8S8_2, 2, 5},
bc_test_params{CASE_CONV_S8S8_3, 2, 5},
bc_test_params{CASE_CONV_S8S8_4, 2, 5},
bc_test_params{CASE_CONV_S8S8_7, 2, 5},
bc_test_params{CASE_CONV_S8S8_8, 2, 5},
bc_test_params{CASE_CONV_S8S8_12, 2, 5},
bc_test_params{CASE_CONV_S8S8_13, 2, 5},
bc_test_params{CASE_CONV_S8S8_14, 2, 5},
bc_test_params{CASE_CONV_S8S8_15, 2, 5},
}));
class conv_int8_scale_shift_swish_onednn : public ConvFusingTestOneDNN {};
TEST_P(conv_int8_scale_shift_swish_onednn, bsv32_fsv32) {
auto p = GetParam();
create_topologies(input_layout("input", get_input_layout(p)),
data("weights", get_mem(get_weights_layout(p), -1, 1)),
data("bias", get_mem(get_bias_layout(p))),
data("scale_data", get_mem(get_per_channel_layout(p), 1.0f/p.kernel.count())),
data("shift_data", get_mem(get_per_channel_layout(p), 1)),
convolution("conv_prim", "input", {"weights"}, {"bias"}, p.groups, p.stride, p.pad, p.dilation),
eltwise("scale0", {"conv_prim", "scale_data"}, eltwise_mode::sum),
eltwise("shift0", {"scale0", "shift_data"}, eltwise_mode::sum),
activation("sigmoid", "shift0", activation_func::swish),
eltwise("scale1", {"sigmoid", "scale_data"}, eltwise_mode::sum),
eltwise("shift1", {"scale1", "shift_data"}, eltwise_mode::sum),
reorder("reorder_bfyx", "shift1", p.default_format, data_types::f32)
);
implementation_desc conv_impl = { format::bs_fs_yx_bsv32_fsv32, "", impl_types::onednn };
bo_fused.set_option(build_option::force_implementations({ {"conv_prim", conv_impl} }));
tolerance = 1.f;
execute(p);
}
INSTANTIATE_TEST_SUITE_P(fusings_gpu, conv_int8_scale_shift_swish_onednn,
::testing::ValuesIn(std::vector<bc_test_params>{
bc_test_params{CASE_CONV_U8S8_1, 2, 7},
bc_test_params{CASE_CONV_U8S8_2, 2, 7},
bc_test_params{CASE_CONV_S8S8_1, 2, 7},
bc_test_params{CASE_CONV_S8S8_2, 2, 7},
bc_test_params{CASE_CONV_U8S8_11, 2, 7},
bc_test_params{CASE_CONV_U8S8_12, 2, 7},
bc_test_params{CASE_CONV_U8S8_14, 2, 7},
bc_test_params{CASE_CONV_S8S8_12, 2, 7},
bc_test_params{CASE_CONV_S8S8_13, 2, 7},
bc_test_params{CASE_CONV_S8S8_15, 2, 7},
}));
#endif