[GPU] Added post-ops support for OneDNN primitives (#7737)

[GPU] Memory lock fix
Ilya Znamenskiy 2021-10-01 06:18:00 +03:00 committed by GitHub
parent 302eb08dc5
commit f675df625c
11 changed files with 1130 additions and 47 deletions

View File

@ -159,6 +159,15 @@ struct eltwise : public primitive_base<eltwise> {
}
}
bool needs_onednn_sum_post_op(layout input_layout) const {
if (mode == eltwise_mode::sum &&
(input_layout.size.spatial[0] > 1 || input_layout.size.spatial[1] > 1 || input_layout.size.batch[0] > 1)) {
return true;
}
return false;
}
/// @param mode Eltwise mode.
eltwise_mode mode;
/// @param coefficients Blob-wise coefficient for SUM operation.
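The helper added above decides whether a fused eltwise sum maps to a oneDNN sum post-op (a full-size second tensor accumulated in place) or to a broadcast binary_add. Below is a minimal standalone sketch of that shape rule, with a simplified stand-in for cldnn::layout (illustrative only, not the real API; the eltwise mode check is omitted):

#include <cstdio>

// Hypothetical stand-in for the relevant part of cldnn::layout.
struct dims { int batch, feature, spatial_x, spatial_y; };

// Mirrors needs_onednn_sum_post_op: an operand with real batch/spatial extent is a
// full-size tensor and maps to a oneDNN sum post-op; a 1 x F x 1 x 1 operand is a
// per-channel broadcast and maps to a binary_add post-op instead.
bool needs_sum_post_op(const dims& d) {
    return d.spatial_x > 1 || d.spatial_y > 1 || d.batch > 1;
}

int main() {
    dims residual{1, 64, 56, 56};   // full-size residual connection -> sum
    dims per_channel{1, 64, 1, 1};  // per-channel operand -> binary_add
    std::printf("%s\n", needs_sum_post_op(residual) ? "sum" : "binary_add");
    std::printf("%s\n", needs_sum_post_op(per_channel) ? "sum" : "binary_add");
    return 0;
}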

View File

@ -20,15 +20,20 @@ struct quantize_params : public base_params {
, has_post_shift(true)
, has_pre_shift(true)
, has_clamp(true)
, has_min_clamp(true)
, has_max_clamp(true)
, per_tensor_input_range(false)
, per_tensor_input_scale(false)
, per_tensor_input_shift(false)
, per_tensor_output_range(false)
, per_tensor_output_scale(false)
, per_tensor_output_shift(false)
, in_lo(0.0f)
, in_hi(0.0f)
, in_scale(0.0f)
, in_shift(0.0f)
, out_lo(0.0f)
, out_hi(0.0f)
, out_scale(0.0f)
, out_shift(0.0f) { }
@ -39,10 +44,13 @@ struct quantize_params : public base_params {
bool has_post_shift;
bool has_pre_shift;
bool has_clamp;
bool has_min_clamp;
bool has_max_clamp;
bool per_tensor_input_range;
bool per_tensor_input_scale;
bool per_tensor_input_shift;
bool per_tensor_output_range;
bool per_tensor_output_scale;
bool per_tensor_output_shift;
@ -50,6 +58,8 @@ struct quantize_params : public base_params {
float in_hi;
float in_scale;
float in_shift;
float out_lo;
float out_hi;
float out_scale;
float out_shift;
@ -79,15 +89,20 @@ struct quantize_fuse_params : fuse_params {
bool has_post_shift,
bool has_pre_shift,
bool has_clamp,
bool has_min_clamp,
bool has_max_clamp,
bool per_tensor_input_range,
bool per_tensor_input_scale,
bool per_tensor_input_shift,
bool per_tensor_output_range,
bool per_tensor_output_scale,
bool per_tensor_output_shift,
float in_lo,
float in_hi,
float in_scale,
float in_shift,
float out_lo,
float out_hi,
float out_scale,
float out_shift)
: fuse_params(KernelType::QUANTIZE)
@ -96,19 +111,25 @@ struct quantize_fuse_params : fuse_params {
, has_post_shift(has_post_shift)
, has_pre_shift(has_pre_shift)
, has_clamp(has_clamp)
, has_min_clamp(has_min_clamp)
, has_max_clamp(has_max_clamp)
, per_tensor_input_range(per_tensor_input_range)
, per_tensor_input_scale(per_tensor_input_scale)
, per_tensor_input_shift(per_tensor_input_shift)
, per_tensor_output_range(per_tensor_output_range)
, per_tensor_output_scale(per_tensor_output_scale)
, per_tensor_output_shift(per_tensor_output_shift)
, in_lo(in_lo)
, in_hi(in_hi)
, in_scale(in_scale)
, in_shift(in_shift)
, out_lo(out_lo)
, out_hi(out_hi)
, out_scale(out_scale)
, out_shift(out_shift) {
size_t index = 0;
if (has_clamp) {
bool out_range_usage = per_tensor_output_range && out_lo < out_hi;
if (!out_range_usage && has_clamp) {
in_range_lo_idx = index++;
in_range_hi_idx = index++;
}
@ -131,10 +152,13 @@ struct quantize_fuse_params : fuse_params {
bool has_post_shift;
bool has_pre_shift;
bool has_clamp;
bool has_min_clamp;
bool has_max_clamp;
bool per_tensor_input_range;
bool per_tensor_input_scale;
bool per_tensor_input_shift;
bool per_tensor_output_range;
bool per_tensor_output_scale;
bool per_tensor_output_shift;
@ -142,6 +166,8 @@ struct quantize_fuse_params : fuse_params {
float in_hi;
float in_scale;
float in_shift;
float out_lo;
float out_hi;
float out_scale;
float out_shift;

View File

@ -1631,33 +1631,88 @@ JitConstants FusedOpsCodeGenerator::MakeOpJitConstants(const FusedOpsConfigurati
: ConvertToType(GetInputVarName(p->in_scale_idx, is_shuffled, shuffle_var), tmp_type, vec_size);
auto pre_shift = p->per_tensor_input_shift ? Broadcast(toCodeString(p->in_shift), tmp_type, vec_size)
: ConvertToType(GetInputVarName(p->in_shift_idx, is_shuffled, shuffle_var), tmp_type, vec_size);
auto in_lo = p->per_tensor_input_range ? Broadcast(toCodeString(p->in_lo), tmp_type, vec_size)
: ConvertToType(GetInputVarName(p->in_range_lo_idx, is_shuffled, shuffle_var), tmp_type, vec_size);
auto in_hi = p->per_tensor_input_range ? Broadcast(toCodeString(p->in_hi), tmp_type, vec_size)
: ConvertToType(GetInputVarName(p->in_range_hi_idx, is_shuffled, shuffle_var), tmp_type, vec_size);
if (p->has_clamp) {
op_decls += "\\\n\t" + tmp_type_str + " " + tmp_var + " = min(max(" + in_lo + ", " + in_converted + "), " + in_hi + ");";
} else {
op_decls += "\\\n\t" + tmp_type_str + " " + tmp_var + " = " + in_converted + ";";
}
op_decls += "\\\n\t" + tmp_var + " = " + tmp_var + "*" + pre_scale + ";";
if (p->per_tensor_output_range && p->out_lo < p->out_hi) {
// Input scale
op_decls += "\\\n\t" + tmp_type_str + " " + tmp_var + " = " + in_converted + " * " + pre_scale + ";";
// Input shift
if (p->has_pre_shift)
op_decls += "\\\n\t" + tmp_var + " = " + tmp_var + " + " + pre_shift + ";";
// Round operation isn't needed if output type is int8/uint8 and scale coefficient in all output channels is equal to 1.0
bool output_type_is_int8 = desc.output_tensor.GetDType() == Datatype::UINT8 || desc.output_tensor.GetDType() == Datatype::INT8;
if (((p->has_post_scale || p->has_post_shift) && output_type_is_int8) || !output_type_is_int8)
op_decls += "\\\n\t" + tmp_var + " = round(" + tmp_var + ");";
bool need_round = (p->has_post_scale || p->has_post_shift) &&
(desc.output_tensor.GetDType() == Datatype::UINT8 || desc.output_tensor.GetDType() == Datatype::INT8);
// Output scale
if (p->has_post_scale)
op_decls += "\\\n\t" + tmp_var + " = (" + tmp_var + " * " + post_scale + ");";
// Output shift
if (p->has_post_shift)
op_decls += "\\\n\t" + tmp_var + " = (" + tmp_var + " + " + post_shift + ");";
if (need_round)
op_decls += "\\\n\t" + tmp_var + " = round(" + tmp_var + ");";
// Output range
auto out_lo = Broadcast(std::to_string(p->out_lo), tmp_type, vec_size);
auto out_hi = Broadcast(std::to_string(p->out_hi), tmp_type, vec_size);
// Output clamp
if (p->has_clamp) {
if (p->has_min_clamp && p->has_max_clamp)
op_decls += "\\\n\t" + tmp_var + " = clamp(" + tmp_var + ", " + out_lo + ", " + out_hi + ");";
else if (p->has_min_clamp)
op_decls += "\\\n\t" + tmp_var + " = max(" + tmp_var + ", " + out_lo + ");";
else
op_decls += "\\\n\t" + tmp_var + " = min(" + tmp_var + ", " + out_hi + ");";
}
// Output conversion with rounding and saturation
op_decls += "\\\n\t" + GetOutputType(vec_size) + " " + out_var + " = " + ConvertToOutputTypeSat(tmp_var, vec_size) + ";";
break;
} else {
// Input range
auto in_lo = p->per_tensor_input_range ? Broadcast(std::to_string(p->in_lo), tmp_type, vec_size)
: ConvertToType(GetInputVarName(p->in_range_lo_idx, is_shuffled, shuffle_var), tmp_type, vec_size);
auto in_hi = p->per_tensor_input_range ? Broadcast(std::to_string(p->in_hi), tmp_type, vec_size)
: ConvertToType(GetInputVarName(p->in_range_hi_idx, is_shuffled, shuffle_var), tmp_type, vec_size);
// Input clamp
if (p->has_clamp) {
if (p->has_min_clamp && p->has_max_clamp)
op_decls += "\\\n\t" + tmp_type_str + " " + tmp_var + " = clamp(" + in_converted + ", " + in_lo + ", " + in_hi + ");";
else if (p->has_min_clamp)
op_decls += "\\\n\t" + tmp_type_str + " " + tmp_var + " = max(" + in_converted + ", " + in_lo + ");";
else
op_decls += "\\\n\t" + tmp_type_str + " " + tmp_var + " = min(" + in_converted + ", " + in_hi + ");";
} else {
op_decls += "\\\n\t" + tmp_type_str + " " + tmp_var + " = " + in_converted + ";";
}
// Input scale
op_decls += "\\\n\t" + tmp_var + " = " + tmp_var + " * " + pre_scale + ";";
// Input shift
if (p->has_pre_shift)
op_decls += "\\\n\t" + tmp_var + " = " + tmp_var + " + " + pre_shift + ";";
// Round operation isn't needed if output type is int8/uint8 and scale coefficient in all output channels is equal to 1.0
bool output_type_is_int8 = desc.output_tensor.GetDType() == Datatype::UINT8 || desc.output_tensor.GetDType() == Datatype::INT8;
if (((p->has_post_scale || p->has_post_shift) && output_type_is_int8) || !output_type_is_int8)
op_decls += "\\\n\t" + tmp_var + " = round(" + tmp_var + ");";
// Output scale
if (p->has_post_scale)
op_decls += "\\\n\t" + tmp_var + " = (" + tmp_var + " * " + post_scale + ");";
// Output shift
if (p->has_post_shift)
op_decls += "\\\n\t" + tmp_var + " = (" + tmp_var + " + " + post_shift + ");";
// Output conversion with rounding and saturation
op_decls += "\\\n\t" + GetOutputType(vec_size) + " " + out_var + " = " + ConvertToOutputTypeSat(tmp_var, vec_size) + ";";
break;
}
}
case KernelType::ACTIVATION: {
auto p = desc.GetOpParams<activation_fuse_params>();
@ -1871,7 +1926,7 @@ std::string FusedOpsCodeGenerator::ConvertToOutputTypeSat(std::string var, size_
if (desc.output_tensor.GetDType() == Datatype::F32 || desc.output_tensor.GetDType() == Datatype::F16)
return "convert_" + GetOutputType(vec_size) + "(" + var + ")";
else
return "convert_" + GetOutputType(vec_size) + "_sat(" + var + ")";
return "convert_" + GetOutputType(vec_size) + "_sat_rte(" + var + ")";
}
std::vector<size_t> FusedOpsCodeGenerator::GetRequiredInputs() const {
@ -1880,7 +1935,8 @@ std::vector<size_t> FusedOpsCodeGenerator::GetRequiredInputs() const {
auto p = std::dynamic_pointer_cast<quantize_fuse_params>(desc.op_params);
if (p) {
std::vector<size_t> res = {};
if (!p->per_tensor_input_range && p->has_clamp) {
bool out_range_usage = p->per_tensor_output_range && p->out_lo < p->out_hi;
if (!out_range_usage && p->has_clamp) {
res.push_back(p->in_range_lo_idx);
res.push_back(p->in_range_hi_idx);
}
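For reference, here is a compact standalone model of the arithmetic emitted by the per-tensor-output-range branch above for an int8 output: input scale/shift, optional output scale/shift followed by an explicit round, clamp to the output range, then a saturating round-to-nearest-even conversion. The parameter values are illustrative, not taken from any kernel:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Illustrative per-tensor quantize parameters (hypothetical values).
struct fq_params {
    float in_scale, in_shift;    // pre-scale / pre-shift
    float out_scale, out_shift;  // post-scale / post-shift
    float out_lo, out_hi;        // per-tensor output range
    bool has_post_scale, has_post_shift;
};

int8_t quantize_i8(float x, const fq_params& p) {
    float t = x * p.in_scale + p.in_shift;           // input scale and shift
    if (p.has_post_scale) t *= p.out_scale;          // output scale
    if (p.has_post_shift) t += p.out_shift;          // output shift
    if (p.has_post_scale || p.has_post_shift)
        t = std::round(t);                           // explicit round (int8 output path)
    t = std::min(std::max(t, p.out_lo), p.out_hi);   // clamp to the output range
    // Final conversion saturates and rounds to nearest even (convert_..._sat_rte).
    t = std::min(std::max(std::nearbyint(t), -128.f), 127.f);
    return static_cast<int8_t>(t);
}

int main() {
    fq_params p{12.75f, 0.f, 1.f, 0.f, -128.f, 127.f, false, false};
    std::printf("%d\n", quantize_i8(3.3f, p));       // 3.3 * 12.75 = 42.075 -> 42
    return 0;
}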

View File

@ -36,6 +36,40 @@ void basic_memory_dependencies::run(program& p) {
add_memory_dependency(it, node);
}
if (node->is_type<convolution>()) {
auto& conv = node->as<convolution>();
bool can_reuse_eltwise_mem = false;
size_t eltw_dep = 0;
for (auto& fused_op : conv.get_fused_primitives()) {
if (fused_op.node->is_type<eltwise>() && fused_op.deps.size() == 1) {
auto eltw_in_layout = conv.get_dependency(fused_op.dep_start_idx).get_output_layout();
auto conv_out_layout = node->get_output_layout();
if (eltw_dep > 0) {
can_reuse_eltwise_mem = false;
break;
}
if (eltw_in_layout.size == conv_out_layout.size &&
eltw_in_layout.format == conv_out_layout.format &&
eltw_in_layout.data_padding == conv_out_layout.data_padding &&
data_type_traits::size_of(eltw_in_layout.data_type) == data_type_traits::size_of(conv_out_layout.data_type)) {
eltw_dep = fused_op.dep_start_idx;
can_reuse_eltwise_mem = true;
}
}
}
if (can_reuse_eltwise_mem) {
auto& eltw_node = conv.get_dependency(eltw_dep);
eltw_node.can_share_buffer(false);
conv.can_share_buffer(false);
for (auto& user : conv.get_users()) {
add_memory_dependency(user, &eltw_node);
add_memory_dependency(user, &conv);
}
}
}
// Note: we iterate over the processing order, which means that if a primitive's processing number is greater than that of any output,
// that output has to land on the primitive's restriction list. Otherwise memory reuse can corrupt the final results.
node->add_memory_dependency(past_outputs);

View File

@ -6,6 +6,8 @@
#include "pooling_inst.h"
#include "quantize_inst.h"
#include "reshape_inst.h"
#include "reorder_inst.h"
#include "binary_convolution_inst.h"
#include "scale_inst.h"
#include "eltwise_inst.h"
@ -41,6 +43,57 @@ bool check_binarization(memory::ptr mem_input_low, memory::ptr mem_input_high, p
void prepare_quantization::prepare_scale_shift_opt(program &p, quantize_node& quantize_node) {
const auto& stream = p.get_stream();
size_t out_features = static_cast<size_t>(quantize_node.get_output_layout().size.feature[0]);
float* bias_values = nullptr;
cldnn::memory* bias_mem_ptr = nullptr;
bool can_merge_bias = false;
size_t bias_depth = 0;
// Try to merge a preceding bias (eltwise sum with constant data) into the FQ node
auto &merge_node = quantize_node.get_dependency(0);
if (merge_node.is_type<eltwise>() && merge_node.get_dependencies().size() == 2) {
auto& eltw_node = merge_node.as<eltwise>();
auto& eltw_node_dep1 = eltw_node.get_dependency(1);
// Check that this is not an input_layout node
if (!eltw_node_dep1.is_type<input_layout>()) {
// The bias constant data may be hidden behind reshape / reorder nodes, so check up to two levels deep
if (eltw_node_dep1.is_type<data>()) {
bias_depth = 1;
} else if (eltw_node_dep1.get_dependencies().size()) {
auto has_extra_nodes1 = eltw_node_dep1.is_type<reshape>() || eltw_node_dep1.is_type<reorder>();
if (has_extra_nodes1 && eltw_node_dep1.get_dependency(0).is_type<data>()) {
bias_depth = 2;
} else if (has_extra_nodes1 && eltw_node_dep1.get_dependency(0).get_dependencies().size()) {
auto has_extra_nodes2 = eltw_node_dep1.get_dependency(0).is_type<reshape>() || eltw_node_dep1.get_dependency(0).is_type<reorder>();
if (has_extra_nodes2 && eltw_node_dep1.get_dependency(0).get_dependency(0).is_type<data>())
bias_depth = 3;
}
}
auto& dep = bias_depth == 1 ? eltw_node_dep1 :
bias_depth == 2 ? eltw_node_dep1.get_dependency(0) :
bias_depth == 3 ? eltw_node_dep1.get_dependency(0).get_dependency(0) :
eltw_node_dep1;
if (bias_depth) {
can_merge_bias = dep.is_constant() && dep.get_output_layout().count() == out_features && dep.get_users().size() == 1 &&
eltw_node.get_primitive()->mode == eltwise_mode::sum && eltw_node.get_dependencies().size() == 2 &&
eltw_node.get_dependency(0).is_type<convolution>();
}
if (can_merge_bias) {
auto &bias = dep.as<data>();
auto &mem_bias = bias.get_attached_memory();
bias_mem_ptr = &mem_bias;
auto data_bias_ptr = static_cast<float*>(mem_bias.lock(stream));
bias_values = data_bias_ptr;
}
}
}
program_node &input_low_node = quantize_node.get_dependency(1);
program_node &input_high_node = quantize_node.get_dependency(2);
program_node &output_low_node = quantize_node.get_dependency(3);
@ -83,9 +136,11 @@ void prepare_quantization::prepare_scale_shift_opt(program &p, quantize_node& q
auto lock_memory = [&stream] (memory::ptr memory, std::function<void(std::size_t, float)>& set_data,
std::function<float(size_t)>& get_data) {
using float_mem_lock = mem_lock<float, mem_lock_type::write>;
using uint16_t_mem_lock = mem_lock<uint16_t, mem_lock_type::write>;
switch (memory->get_layout().data_type) {
case data_types::f32: {
std::shared_ptr<mem_lock<float, mem_lock_type::write>> data_lock_ptr = std::make_shared<mem_lock<float, mem_lock_type::write>>(memory, stream);
std::shared_ptr<float_mem_lock> data_lock_ptr = std::make_shared<float_mem_lock>(memory, stream);
float* data = data_lock_ptr->data();
set_data = [data] (size_t idx, float value) {
data[idx] = value;
@ -93,11 +148,10 @@ void prepare_quantization::prepare_scale_shift_opt(program &p, quantize_node& q
get_data = [data] (size_t idx) {
return data[idx];
};
return std::pair<std::shared_ptr<mem_lock<float, mem_lock_type::write>>,
std::shared_ptr<mem_lock<uint16_t, mem_lock_type::write>>>(data_lock_ptr, nullptr);
return std::pair<std::shared_ptr<float_mem_lock>, std::shared_ptr<uint16_t_mem_lock>>(data_lock_ptr, nullptr);
}
case data_types::f16: {
auto data_lock_ptr = std::make_shared<mem_lock<uint16_t, mem_lock_type::write>>(memory, stream);
std::shared_ptr<uint16_t_mem_lock> data_lock_ptr = std::make_shared<uint16_t_mem_lock>(memory, stream);
uint16_t* data = data_lock_ptr->data();
set_data = [data] (size_t idx, float value) {
data[idx] = float_to_half(value);
@ -105,8 +159,7 @@ void prepare_quantization::prepare_scale_shift_opt(program &p, quantize_node& q
get_data = [data] (size_t idx) {
return half_to_float(data[idx]);
};
return std::pair<std::shared_ptr<mem_lock<float, mem_lock_type::write>>,
std::shared_ptr<mem_lock<uint16_t, mem_lock_type::write>>>(nullptr, data_lock_ptr);
return std::pair<std::shared_ptr<float_mem_lock>, std::shared_ptr<uint16_t_mem_lock>>(nullptr, data_lock_ptr);
}
default:
throw std::runtime_error("prepare_quantization: Unsupported precision of quantize output values");
@ -161,8 +214,9 @@ void prepare_quantization::prepare_scale_shift_opt(program &p, quantize_node& q
float out_lo = get_data_output_low(get_offset_safe(mem_output_low->get_layout(), idx));
float out_hi = get_data_output_high(get_offset_safe(mem_output_high->get_layout(), idx));
set_data_input_scale(s_offset, (static_cast<float>(levels) - 1.f) / (in_hi - in_lo));
set_data_input_shift(s_offset, - in_lo * (static_cast<float>(levels) - 1.f) / (in_hi - in_lo));
float in_shift_basic = (static_cast<float>(levels) - 1.f) / (in_hi - in_lo);
set_data_input_scale(s_offset, in_shift_basic);
set_data_input_shift(s_offset, can_merge_bias ? (bias_values[f] - in_lo) * in_shift_basic : -in_lo * in_shift_basic);
set_data_output_scale(s_offset, (out_hi - out_lo) / (static_cast<float>(levels) - 1.f));
set_data_output_shift(s_offset, out_lo);
@ -186,12 +240,15 @@ void prepare_quantization::prepare_scale_shift_opt(program &p, quantize_node& q
bool per_tensor_in_range = true;
bool per_tensor_out_scale = true;
bool per_tensor_out_shift = true;
bool per_tensor_out_range = true;
float in_scale_val = get_data_input_scale(0);
float in_shift_val = get_data_input_shift(0);
float out_scale_val = get_data_output_scale(0);
float out_shift_val = get_data_output_shift(0);
float in_lo_val = get_data_input_low(0);
float in_hi_val = get_data_input_high(0);
float out_lo_val = get_data_output_low(0);
float out_hi_val = get_data_output_high(0);
for (size_t i = 0; i < scales_layout.count(); i++) {
if (in_scale_val != get_data_input_scale(i))
per_tensor_in_scale = false;
@ -207,9 +264,56 @@ void prepare_quantization::prepare_scale_shift_opt(program &p, quantize_node& q
if (in_lo_val != get_data_input_low(i % mem_input_low->get_layout().count()) ||
in_hi_val != get_data_input_high(i % mem_input_high->get_layout().count()))
per_tensor_in_range = false;
if (out_lo_val != get_data_output_low(i % mem_output_low->get_layout().count()) ||
out_hi_val != get_data_output_high(i % mem_output_high->get_layout().count()))
per_tensor_out_range = false;
}
auto out_is_int8 = quantize_node.get_output_layout().data_type == data_types::i8;
auto out_is_uint8 = quantize_node.get_output_layout().data_type == data_types::u8;
auto out_is_fp = !(out_is_int8 || out_is_uint8);
bool need_clamp = levels != 256 || out_is_fp;
bool need_min_clamp = need_clamp;
bool need_max_clamp = need_clamp;
// Check that we can optimize clamp operation for int8 data using saturation clamp only
if (per_tensor_out_range && !out_is_fp && levels != 256) {
if ((out_is_int8 && out_lo_val == -128.f) || (out_is_uint8 && out_lo_val == 0.f))
need_min_clamp = false;
if ((out_is_int8 && out_hi_val == 127.f) || (out_is_uint8 && out_hi_val == 255.f))
need_max_clamp = false;
}
// If the bias has been merged into the FQ input shift, remove the bias subgraph from the network graph
if (can_merge_bias) {
auto &eltw_node = merge_node.as<eltwise>();
// Remove bias constants and extra reshapes / reorders from the graph (dep3, dep2, dep1)
if (bias_depth == 3) {
auto &dep3 = eltw_node.get_dependency(1).get_dependency(0).get_dependency(0);
p.remove_all_connections(dep3);
p.remove_if_dangling(dep3);
}
if (bias_depth >= 2) {
auto &dep2 = eltw_node.get_dependency(1).get_dependency(0);
p.remove_all_connections(dep2);
p.remove_if_dangling(dep2);
}
auto &dep1 = eltw_node.get_dependency(1);
p.remove_all_connections(dep1);
p.remove_if_dangling(dep1);
// Remove bias from the graph (eltwise in a "sum" mode)
p.extract_and_remove(eltw_node);
}
if (has_negative_scales) {
if (can_merge_bias)
bias_mem_ptr->unlock(stream);
return;
}
@ -266,17 +370,30 @@ void prepare_quantization::prepare_scale_shift_opt(program &p, quantize_node& q
quantize_node.set_input_shift_val(in_shift_val);
}
auto out_dt = quantize_node.get_output_layout().data_type;
bool need_clamp = levels != 256 || (out_dt != data_types::u8 && out_dt != data_types::i8);
if (need_clamp) {
quantize_node.set_need_clamp();
}
if (need_min_clamp) {
quantize_node.set_need_min_clamp();
}
if (need_max_clamp) {
quantize_node.set_need_max_clamp();
}
if (per_tensor_in_range) {
quantize_node.set_per_tensor_input_range();
quantize_node.set_input_lo_val(in_lo_val);
quantize_node.set_input_hi_val(in_hi_val);
}
if (per_tensor_out_range) {
quantize_node.set_per_tensor_output_range();
quantize_node.set_output_lo_val(out_lo_val);
quantize_node.set_output_hi_val(out_hi_val);
}
if (per_tensor_out_scale) {
quantize_node.set_per_tensor_output_scale();
quantize_node.set_output_scale_val(out_scale_val);
@ -286,6 +403,10 @@ void prepare_quantization::prepare_scale_shift_opt(program &p, quantize_node& q
quantize_node.set_per_tensor_output_shift();
quantize_node.set_output_shift_val(out_shift_val);
}
if (can_merge_bias) {
bias_mem_ptr->unlock(stream);
}
}
void prepare_quantization::handle_quantize_node(program& p, quantize_node& quantize_node) {
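Despite its name, the in_shift_basic factor above is the standard FakeQuantize input scale; the merged bias simply folds an additive constant into the input shift. A small numeric sketch of the parameters this pass computes, with illustrative range and bias values:

#include <cstdio>

int main() {
    const float levels = 256.f;
    const float in_lo = -1.f, in_hi = 1.f;        // illustrative input range
    const float out_lo = -128.f, out_hi = 127.f;  // illustrative output range
    const float bias = 0.25f;                     // bias merged from a preceding eltwise sum

    const float in_scale = (levels - 1.f) / (in_hi - in_lo);          // 127.5
    // Without a merged bias: input_shift = -in_lo * in_scale.
    // With a merged bias b:  input_shift = (b - in_lo) * in_scale, because
    // (x + b - in_lo) * in_scale == x * in_scale + (b - in_lo) * in_scale.
    const float in_shift_plain  = -in_lo * in_scale;                  // 127.5
    const float in_shift_merged = (bias - in_lo) * in_scale;          // 159.375

    const float out_scale = (out_hi - out_lo) / (levels - 1.f);       // 1.0
    const float out_shift = out_lo;                                   // -128

    std::printf("in_scale=%g in_shift=%g (with merged bias: %g)\n",
                in_scale, in_shift_plain, in_shift_merged);
    std::printf("out_scale=%g out_shift=%g\n", out_scale, out_shift);
    return 0;
}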

View File

@ -55,10 +55,13 @@ public:
quantize_params.has_post_shift = arg.get_need_post_shift();
quantize_params.has_pre_shift = arg.get_need_pre_shift();
quantize_params.has_clamp = arg.get_need_clamp();
quantize_params.has_min_clamp = arg.get_need_min_clamp();
quantize_params.has_max_clamp = arg.get_need_max_clamp();
quantize_params.per_tensor_input_range = arg.get_per_tensor_input_range();
quantize_params.per_tensor_input_scale = arg.get_per_tensor_input_scale();
quantize_params.per_tensor_input_shift = arg.get_per_tensor_input_shift();
quantize_params.per_tensor_output_range = arg.get_per_tensor_output_range();
quantize_params.per_tensor_output_scale = arg.get_per_tensor_output_scale();
quantize_params.per_tensor_output_shift = arg.get_per_tensor_output_shift();
@ -66,6 +69,8 @@ public:
quantize_params.in_hi = arg.get_input_hi_val();
quantize_params.in_scale = arg.get_input_scale_val();
quantize_params.in_shift = arg.get_input_shift_val();
quantize_params.out_lo = arg.get_output_lo_val();
quantize_params.out_hi = arg.get_output_hi_val();
quantize_params.out_scale = arg.get_output_scale_val();
quantize_params.out_shift = arg.get_output_shift_val();

View File

@ -25,6 +25,36 @@
namespace cldnn {
namespace onednn {
enum class onednn_post_op_type : uint32_t {
eltwise_act,
eltwise_clip,
eltwise_linear,
eltwise_round,
binary_mul,
binary_add,
binary_max,
binary_min,
scale,
sum,
optimized,
optimized_eltwise,
optimized_sum
};
struct onednn_post_op_desc {
onednn_post_op_type op_type;
size_t mem_offset;
size_t mem_dep;
};
// This map holds information about onednn post-op types, memory buffer offsets and dependencies.
// The key is a cldnn::primitive_id; the value is a vector of onednn_post_op_desc entries describing
// the post-ops fused into the node identified by the key:
// op_type - onednn_post_op_type (enum),
// mem_offset - index of the memory buffer used by this post-operation,
// mem_dep - memory dependency used when working with the fused node
static std::unordered_map<cldnn::primitive_id, std::vector<onednn_post_op_desc>> onednn_fusing_map;
template <class PType, class DescType, class PrimDescType = dnnl::primitive_desc, class PrimType = dnnl::primitive>
struct typed_primitive_onednn_impl : public typed_primitive_impl<PType> {
const typed_program_node<PType>& _outer;
@ -32,7 +62,7 @@ struct typed_primitive_onednn_impl : public typed_primitive_impl<PType> {
std::shared_ptr<dnnl::primitive_attr> _attrs;
PrimDescType _pd;
PrimType _prim;
std::unordered_map<int, dnnl::memory> _args;
std::unordered_map<uint32_t, std::unordered_map<int, dnnl::memory>> _args;
typed_primitive_onednn_impl(const typed_program_node<PType>& arg,
std::shared_ptr<DescType> desc,
@ -67,12 +97,431 @@ protected:
return !zp.empty() && (reinterpret_cast<const int32_t&>(zp[0]) == drsv);
}
static dnnl::post_ops try_optimize_post_ops(const typed_program_node<PType>& arg, dnnl::post_ops& p_ops,
const std::shared_ptr<dnnl::primitive_attr>& attr,
bool& optimization_is_completed) {
// Get the current node id used as the key into the fusing map
auto node_id = arg.id();
// Create new dnnl::post_ops object which will be filled inside the optimization process
dnnl::post_ops optimized_p_ops;
// Add new post-op into optimized_p_ops structure
auto add_post_op = [&](onednn_post_op_type type, const dnnl::post_ops& cur_p_ops, dnnl::post_ops& new_p_ops, int idx) {
switch (type) {
case onednn_post_op_type::eltwise_act:
case onednn_post_op_type::eltwise_clip:
case onednn_post_op_type::eltwise_linear:
case onednn_post_op_type::eltwise_round:
{
dnnl::algorithm alg;
float scale, alpha, beta;
cur_p_ops.get_params_eltwise(idx, scale, alg, alpha, beta);
new_p_ops.append_eltwise(scale, alg, alpha, beta);
break;
}
case onednn_post_op_type::binary_add:
case onednn_post_op_type::binary_mul:
case onednn_post_op_type::binary_max:
case onednn_post_op_type::binary_min:
{
dnnl::algorithm alg;
dnnl::memory::desc desc;
cur_p_ops.get_params_binary(idx, alg, desc);
new_p_ops.append_binary(alg, desc);
break;
}
case onednn_post_op_type::scale:
{
break;
}
case onednn_post_op_type::sum:
{
float scale;
dnnl::memory::data_type data_type;
cur_p_ops.get_params_sum(idx, scale, data_type);
new_p_ops.append_sum(scale, data_type);
break;
}
case onednn_post_op_type::optimized:
case onednn_post_op_type::optimized_sum:
case onednn_post_op_type::optimized_eltwise:
{
// The current operation has already been optimized => no extra actions are needed
break;
}
default:
throw std::runtime_error("Unsupported onednn post-operation type");
}
};
auto& cur_post_ops = onednn_fusing_map[node_id];
size_t cur_post_op_idx = 1;
size_t prev_post_op_idx = 0;
bool optimization_done = false;
// Check and update post-op map if we already optimized something
for (size_t post_op_idx = 0; post_op_idx < cur_post_ops.size(); post_op_idx++) {
if (cur_post_ops[post_op_idx].op_type == onednn_post_op_type::optimized_sum)
cur_post_ops[post_op_idx].op_type = onednn_post_op_type::sum;
else if (cur_post_ops[post_op_idx].op_type == onednn_post_op_type::optimized_eltwise)
cur_post_ops[post_op_idx].op_type = onednn_post_op_type::eltwise_linear;
else if (cur_post_ops[post_op_idx].op_type == onednn_post_op_type::optimized)
cur_post_ops.erase(cur_post_ops.begin() + post_op_idx);
}
// Get post-ops size for current node
auto post_ops_size = cur_post_ops.size();
// Try to combine pairs of arithmetic post-ops (adds and muls) into one operation inside this loop
while (!optimization_done) {
auto cur_type = cur_post_ops[cur_post_op_idx].op_type;
auto prev_type = cur_post_ops[prev_post_op_idx].op_type;
// Ignore optimized operations for "previous" operation in our operation pair
while ((prev_type == onednn_post_op_type::optimized || prev_type == onednn_post_op_type::optimized_sum ||
prev_type == onednn_post_op_type::optimized_eltwise) && cur_post_op_idx < post_ops_size - 1) {
prev_post_op_idx++;
cur_post_op_idx++;
prev_type = cur_post_ops[prev_post_op_idx].op_type;
cur_type = cur_post_ops[cur_post_op_idx].op_type;
}
// Ignore optimized operations for "current" operation in our operation pair
while ((cur_type == onednn_post_op_type::optimized || cur_type == onednn_post_op_type::optimized_sum ||
cur_type == onednn_post_op_type::optimized_eltwise) && cur_post_op_idx < post_ops_size - 1) {
cur_post_op_idx++;
cur_type = cur_post_ops[cur_post_op_idx].op_type;
}
auto cur_idx = static_cast<int>(has_out_scales(attr) ? (cur_post_op_idx >= 1 ? cur_post_op_idx - 1 : 0) : cur_post_op_idx);
auto prev_idx = static_cast<int>(has_out_scales(attr) ? (prev_post_op_idx >= 1 ? prev_post_op_idx - 1 : 0) : prev_post_op_idx);
auto cur_type_is_optimized = cur_type == onednn_post_op_type::optimized ||
cur_type == onednn_post_op_type::optimized_sum ||
cur_type == onednn_post_op_type::optimized_eltwise;
auto prev_type_is_optimized = prev_type == onednn_post_op_type::optimized ||
prev_type == onednn_post_op_type::optimized_sum ||
prev_type == onednn_post_op_type::optimized_eltwise;
// If this is the last pair and it is already optimized, add the remaining post-op and exit the loop
if (cur_post_op_idx == post_ops_size - 1 && (cur_type_is_optimized || prev_type_is_optimized)) {
if (!prev_type_is_optimized) {
add_post_op(prev_type, p_ops, optimized_p_ops, prev_idx);
}
if (!cur_type_is_optimized) {
add_post_op(cur_type, p_ops, optimized_p_ops, cur_idx);
}
break;
}
auto equal_ops = cur_type == prev_type;
auto cur_type_is_binary_add_or_mul = cur_type == onednn_post_op_type::binary_add || cur_type == onednn_post_op_type::binary_mul;
auto prev_type_is_binary_add_or_mul = prev_type == onednn_post_op_type::binary_add || prev_type == onednn_post_op_type::binary_mul;
// Post-ops combinations which can be simplified
auto eltw_and_eltw = equal_ops && cur_type == onednn_post_op_type::eltwise_linear;
auto bin_and_eltw = cur_type_is_binary_add_or_mul && prev_type == onednn_post_op_type::eltwise_linear;
auto eltw_and_bin = cur_type == onednn_post_op_type::eltwise_linear && prev_type_is_binary_add_or_mul;
auto eltw_and_sum = cur_type == onednn_post_op_type::eltwise_linear && prev_type == onednn_post_op_type::sum;
auto eltw_and_scale = cur_type == onednn_post_op_type::eltwise_linear && prev_type == onednn_post_op_type::scale;
auto can_try_optimize = eltw_and_eltw ||
bin_and_eltw ||
eltw_and_bin ||
eltw_and_sum ||
eltw_and_scale;
bool cur_ops_pair_is_optimized = false;
if (can_try_optimize) {
if (eltw_and_eltw) {
dnnl::algorithm alg;
float cur_scale, prev_scale, cur_alpha, prev_alpha, cur_beta, prev_beta;
p_ops.get_params_eltwise(prev_idx, prev_scale, alg, prev_alpha, prev_beta);
p_ops.get_params_eltwise(cur_idx, cur_scale, alg, cur_alpha, cur_beta);
// Eltwise + eltwise pair can be optimized only if cur_alpha is equal to 1.0f
if (cur_alpha == 1.0f && prev_scale == cur_scale) {
dnnl::post_ops eltw_p_op;
eltw_p_op.append_eltwise(cur_scale, alg, prev_alpha, cur_beta + prev_beta);
// Combine 2 eltwises into one
add_post_op(cur_type, eltw_p_op, optimized_p_ops, 0);
// Mark the current and previous eltwise operations as 'optimized' (they will be ignored on the next loop iteration)
cur_post_ops[cur_post_op_idx].op_type = onednn_post_op_type::optimized;
cur_post_ops[prev_post_op_idx].op_type = onednn_post_op_type::optimized_eltwise;
// Set the flag if another optimization pass is needed
if (cur_post_op_idx < post_ops_size - 1) {
if (cur_post_ops[cur_post_op_idx + 1].op_type == onednn_post_op_type::eltwise_linear ||
cur_post_ops[cur_post_op_idx + 1].op_type == onednn_post_op_type::binary_add ||
cur_post_ops[cur_post_op_idx + 1].op_type == onednn_post_op_type::binary_mul ||
cur_post_ops[cur_post_op_idx + 1].op_type == onednn_post_op_type::optimized_eltwise) {
optimization_is_completed = true;
}
}
cur_ops_pair_is_optimized = true;
}
} else if (bin_and_eltw) {
dnnl::algorithm alg;
dnnl::memory::desc desc;
float scale, alpha, beta;
cldnn::program_node& cur_node = arg.get_dependency(cur_post_ops[cur_post_op_idx].mem_dep);
p_ops.get_params_binary(cur_idx, alg, desc);
p_ops.get_params_eltwise(prev_idx, scale, alg, alpha, beta);
// Eltwise operations can use runtime non-constant data buffers, so check that memory buffers consist of constant data only
auto bin_ops_can_be_optimized = cur_node.is_type<data>() && cur_node.is_constant() &&
cur_node.get_users().size() == 1 && desc.data_type() == dnnl_f32;
auto bin_add_and_eltw = alpha == 1.0f && scale == 1.0f && cur_type == onednn_post_op_type::binary_add && bin_ops_can_be_optimized;
auto bin_mul_and_eltw = beta == 0.f && scale == 1.0f && cur_type == onednn_post_op_type::binary_mul && bin_ops_can_be_optimized;
if (bin_add_and_eltw || bin_mul_and_eltw) {
memory::ptr cur_bin_mem_ptr = cur_node.as<data>().get_attached_memory_ptr();
auto& stream = cur_bin_mem_ptr->get_engine()->get_program_stream();
mem_lock<float, mem_lock_type::write> bin_and_eltw_lock(cur_bin_mem_ptr, stream);
size_t cur_bin_mem_size = cur_node.get_output_layout().count();
// Update all binary coefficients
if (bin_add_and_eltw) {
for (size_t data_idx = 0; data_idx < cur_bin_mem_size; data_idx++) {
bin_and_eltw_lock[data_idx] += beta;
}
} else {
for (size_t data_idx = 0; data_idx < cur_bin_mem_size; data_idx++) {
bin_and_eltw_lock[data_idx] *= alpha;
}
}
// Mark the previous eltwise operation as 'optimized' (it will be ignored on the next loop iteration)
cur_post_ops[prev_post_op_idx].op_type = onednn_post_op_type::optimized;
cur_ops_pair_is_optimized = true;
}
} else if (eltw_and_bin) {
dnnl::algorithm alg;
dnnl::memory::desc desc;
float scale, alpha, beta;
cldnn::program_node& prev_node = arg.get_dependency(cur_post_ops[prev_post_op_idx].mem_dep);
p_ops.get_params_eltwise(cur_idx, scale, alg, alpha, beta);
p_ops.get_params_binary(prev_idx, alg, desc);
// Eltwise operations can use runtime non-constant data buffers, so check that memory buffers consist of constant data only
auto bin_ops_can_be_optimized = prev_node.is_type<data>() && prev_node.is_constant() &&
prev_node.get_users().size() == 1 && desc.data_type() == dnnl_f32;
auto eltw_and_bin_add = alpha == 1.0f && scale == 1.0f && prev_type == onednn_post_op_type::binary_add && bin_ops_can_be_optimized;
auto eltw_and_bin_mul = beta == 0.f && scale == 1.0f && prev_type == onednn_post_op_type::binary_mul && bin_ops_can_be_optimized;
if (eltw_and_bin_add || eltw_and_bin_mul) {
memory::ptr prev_bin_mem_ptr = prev_node.as<data>().get_attached_memory_ptr();
auto& stream = prev_bin_mem_ptr->get_engine()->get_program_stream();
mem_lock<float, mem_lock_type::write> eltw_and_bin_lock(prev_bin_mem_ptr, stream);
size_t prev_bin_mem_size = prev_node.get_output_layout().count();
// Update all binary coefficients
if (eltw_and_bin_add) {
for (size_t data_idx = 0; data_idx < prev_bin_mem_size; data_idx++) {
eltw_and_bin_lock[data_idx] += beta;
}
} else {
for (size_t data_idx = 0; data_idx < prev_bin_mem_size; data_idx++) {
eltw_and_bin_lock[data_idx] *= alpha;
}
}
// Mark the current eltwise operation as 'optimized' (it will be ignored on the next loop iteration)
cur_post_ops[cur_post_op_idx].op_type = onednn_post_op_type::optimized;
cur_ops_pair_is_optimized = true;
}
} else if (eltw_and_sum) {
dnnl::algorithm alg;
float cur_scale, prev_scale, alpha, beta;
dnnl::memory::data_type data_type;
cldnn::program_node& prev_node = arg.get_dependency(cur_post_ops[prev_post_op_idx].mem_dep);
p_ops.get_params_eltwise(cur_idx, cur_scale, alg, alpha, beta);
p_ops.get_params_sum(prev_idx, prev_scale, data_type);
// Eltwise operations can use runtime non-constant data buffers, so check that memory buffers consist of constant data only
auto eltw_ops_can_be_optimized = prev_node.is_type<data>() && prev_node.is_constant() &&
prev_node.get_users().size() == 1;
// Eltwise can be inserted into the scale field of previous sum if cur_beta is equal to 0.f
if (beta == 0.f && cur_scale == 1.0f && eltw_ops_can_be_optimized) {
dnnl::post_ops sum_p_op;
sum_p_op.append_sum(alpha * prev_scale, data_type);
// Insert cur eltwise into sum
add_post_op(prev_type, sum_p_op, optimized_p_ops, 0);
memory::ptr prev_eltw_mem_ptr = prev_node.as<data>().get_attached_memory_ptr();
auto& stream = prev_eltw_mem_ptr->get_engine()->get_program_stream();
mem_lock<float, mem_lock_type::write> eltw_and_sum_lock(prev_eltw_mem_ptr, stream);
size_t prev_eltw_mem_size = prev_node.get_output_layout().count();
// Also multiply the sum data by alpha to get valid results
for (size_t data_idx = 0; data_idx < prev_eltw_mem_size; data_idx++) {
eltw_and_sum_lock[data_idx] *= alpha;
}
// Mark the current and previous operations as 'optimized' (they will be ignored on the next loop iteration)
cur_post_ops[cur_post_op_idx].op_type = onednn_post_op_type::optimized;
cur_post_ops[prev_post_op_idx].op_type = onednn_post_op_type::optimized_sum;
// Set the flag if another optimization pass is needed
if (cur_post_op_idx < post_ops_size - 1) {
if (cur_post_ops[cur_post_op_idx + 1].op_type == onednn_post_op_type::eltwise_linear ||
cur_post_ops[cur_post_op_idx + 1].op_type == onednn_post_op_type::optimized_eltwise) {
optimization_is_completed = true;
}
}
cur_ops_pair_is_optimized = true;
}
} else if (eltw_and_scale) {
dnnl::algorithm alg;
float cur_scale, alpha, beta;
cldnn::program_node& prev_node = arg.get_dependency(cur_post_ops[prev_post_op_idx].mem_dep);
p_ops.get_params_eltwise(cur_idx, cur_scale, alg, alpha, beta);
// Eltwise can be inserted into output_scale if cur_beta is equal to 0.f and cur_scale is equal to 1.0f
if (beta == 0.f && cur_scale == 1.0f && prev_node.get_output_layout().data_type == data_types::f32) {
memory::ptr prev_scale_mem_ptr = prev_node.as<data>().get_attached_memory_ptr();
auto& stream = prev_scale_mem_ptr->get_engine()->get_program_stream();
mem_lock<float, mem_lock_type::write> eltw_and_scale_lock(prev_scale_mem_ptr, stream);
size_t prev_scale_mem_size = prev_node.get_output_layout().count();
// Update all scale coefficients
for (size_t data_idx = 0; data_idx < prev_scale_mem_size; data_idx++) {
eltw_and_scale_lock[data_idx] *= alpha;
}
// Mark the current eltwise operation as 'optimized' (it will be ignored on the next loop iteration)
cur_post_ops[cur_post_op_idx].op_type = onednn_post_op_type::optimized;
cur_ops_pair_is_optimized = true;
}
}
}
// If no optimizations have been applied then copy post-op info into the new optimized_p_ops structure
if (!(has_out_scales(attr) && prev_post_op_idx == 0) && !cur_ops_pair_is_optimized) {
add_post_op(prev_type, p_ops, optimized_p_ops, prev_idx);
}
if (cur_post_op_idx == post_ops_size - 1 && !cur_ops_pair_is_optimized) {
add_post_op(cur_type, p_ops, optimized_p_ops, cur_idx);
optimization_done = true;
} else if (cur_post_ops[cur_post_op_idx].op_type != onednn_post_op_type::optimized) {
cur_post_op_idx++;
prev_post_op_idx++;
}
}
optimization_is_completed = !optimization_is_completed;
return optimized_p_ops;
}
void configure_post_ops_arguments(typed_primitive_inst<PType>& instance, std::unordered_map<int, dnnl::memory>& args) const {
// Get the current node id used as the key into the fusing map
auto node_id = instance.id();
auto& engine = instance.get_network().get_engine();
auto dnnl_engine = engine.get_onednn_engine();
// Get current post-ops info
dnnl::post_ops post_ops = _attrs->get_post_ops();
// Create onednn memory buffers for post-ops
auto& cur_post_ops = onednn_fusing_map[node_id];
auto post_ops_size = cur_post_ops.size();
for (size_t post_op_idx = 0, num_of_optimized_post_ops = 0; post_op_idx < post_ops_size; post_op_idx++) {
auto post_op_type = cur_post_ops[post_op_idx].op_type;
auto memory_offset = cur_post_ops[post_op_idx].mem_offset;
auto onednn_post_op_idx = has_out_scales(_attrs) && post_op_idx > 0 ? post_op_idx - 1 : post_op_idx;
onednn_post_op_idx -= num_of_optimized_post_ops;
switch (post_op_type) {
case onednn_post_op_type::eltwise_act:
case onednn_post_op_type::eltwise_clip:
case onednn_post_op_type::eltwise_linear:
case onednn_post_op_type::eltwise_round:
{
// onednn eltwise doesn't need any data from memory buffers
break;
}
case onednn_post_op_type::binary_add:
case onednn_post_op_type::binary_mul:
case onednn_post_op_type::binary_max:
case onednn_post_op_type::binary_min:
{
auto binary_op_mem = instance.fused_memory(memory_offset);
dnnl::algorithm alg;
dnnl::memory::desc desc;
post_ops.get_params_binary(static_cast<int>(onednn_post_op_idx), alg, desc);
args.insert({DNNL_ARG_ATTR_MULTIPLE_POST_OP(static_cast<int>(onednn_post_op_idx)) | DNNL_ARG_SRC_1,
binary_op_mem->get_onednn_memory(desc)});
break;
}
case onednn_post_op_type::scale:
{
auto scale_op_mem = instance.fused_memory(memory_offset);
dnnl::memory::desc desc = onednn::layout_to_memory_desc(scale_op_mem->get_layout(), dnnl::memory::format_tag::a, true);
args.insert({DNNL_ARG_ATTR_OUTPUT_SCALES, scale_op_mem->get_onednn_memory(desc)});
break;
}
case onednn_post_op_type::sum:
case onednn_post_op_type::optimized_sum:
case onednn_post_op_type::optimized_eltwise:
{
break;
}
case onednn_post_op_type::optimized:
{
// Optimized post-op: count it so that onednn_post_op_idx stays correct for the following operations
num_of_optimized_post_ops++;
break;
}
default:
throw std::runtime_error("Unsupported onednn post-operation type");
}
}
}
virtual std::unordered_map<int, dnnl::memory> get_arguments(typed_primitive_inst<PType>& instance) const {
std::unordered_map<int, dnnl::memory> args;
auto& engine = instance.get_network().get_engine();
auto dnnl_engine = engine.get_onednn_engine();
for (size_t i = 0; i < instance.inputs_memory_count(); i++) {
auto& input = instance.input_memory(i);
args.insert({DNNL_ARG_SRC, input.get_onednn_memory(_pd.dnnl::primitive_desc_base::src_desc(static_cast<int>(i)))});
{
auto& input = instance.input_memory(0);
args.insert({DNNL_ARG_SRC, input.get_onednn_memory(_pd.dnnl::primitive_desc_base::src_desc(0))});
}
{
@ -80,15 +529,281 @@ protected:
args.insert({DNNL_ARG_DST, output.get_onednn_memory(_pd.dnnl::primitive_desc_base::dst_desc(0))});
}
configure_post_ops_arguments(instance, args);
return args;
}
void init_kernels() override { }
static std::shared_ptr<dnnl::primitive_attr> get_primitive_attributes(const typed_program_node<PType>& /* arg */) {
static std::shared_ptr<dnnl::primitive_attr> get_primitive_attributes(const typed_program_node<PType>& arg) {
const std::vector<fused_primitive_desc>& cldnn_post_ops = arg.get_fused_primitives();
auto attrs = std::make_shared<dnnl::primitive_attr>();
dnnl::post_ops post_ops;
size_t memory_offset = 0;
// Create onednn post-ops list related to the current node
std::vector<onednn_post_op_desc> fused_ops;
// Added this for debug purposes only
size_t empty_mem = 0xff;
// Add information about post-operation into the list, update indices
auto update_onednn_post_op_list = [&](onednn_post_op_type type, size_t m_dep) {
onednn_post_op_desc cur_op_desc = { type, memory_offset, m_dep };
fused_ops.push_back(cur_op_desc);
auto has_memory_buffers = type == onednn_post_op_type::binary_add ||
type == onednn_post_op_type::binary_mul ||
type == onednn_post_op_type::binary_max ||
type == onednn_post_op_type::binary_min ||
type == onednn_post_op_type::scale ||
type == onednn_post_op_type::sum;
if (has_memory_buffers)
memory_offset++;
};
for (size_t idx = 0; idx < cldnn_post_ops.size(); idx++) {
auto node = cldnn_post_ops[idx].node;
if (node->is_type<activation>()) {
auto fused_desc = node->as<activation>().get_primitive();
dnnl::algorithm alg = onednn::convert_activation_func(fused_desc->activation_function);
post_ops.append_eltwise(1.0f, alg, fused_desc->additional_params.a, fused_desc->additional_params.b);
update_onednn_post_op_list(onednn_post_op_type::eltwise_act, empty_mem);
} else if (node->is_type<eltwise>()) {
auto& e_node = node->as<eltwise>();
auto dep_idx = cldnn_post_ops[idx].dep_start_idx;
auto in = arg.get_dependency(dep_idx).get_output_layout();
if (e_node.get_primitive()->mode == eltwise_mode::sum) {
if (e_node.get_primitive()->needs_onednn_sum_post_op(in)) {
post_ops.append_sum(1.0f, onednn::convert_data_type(in.data_type));
update_onednn_post_op_list(onednn_post_op_type::sum, dep_idx);
} else {
dnnl::memory::desc in_desc = onednn::layout_to_memory_desc(in, dnnl::memory::format_tag::ab, true);
post_ops.append_binary(dnnl::algorithm::binary_add, in_desc);
update_onednn_post_op_list(onednn_post_op_type::binary_add, dep_idx);
}
} else {
if (in.size.spatial[0] > 1 || in.size.spatial[1] > 1 || in.size.batch[0] > 1)
throw std::runtime_error("Unsupported eltwise mode for fused onednn op");
if (idx == 0 && !has_out_scales(attrs)) {
int mask = in.count() > 1 ? 2 : 0;
attrs->set_output_scales(mask, {DNNL_RUNTIME_F32_VAL});
update_onednn_post_op_list(onednn_post_op_type::scale, dep_idx);
} else {
dnnl::memory::desc in_desc = onednn::layout_to_memory_desc(in, dnnl::memory::format_tag::ab, true);
post_ops.append_binary(dnnl::algorithm::binary_mul, in_desc);
update_onednn_post_op_list(onednn_post_op_type::binary_mul, dep_idx);
}
}
} else if (node->is_type<quantize>()) {
auto& q_node = node->as<quantize>();
auto dep_idx = cldnn_post_ops[idx].dep_start_idx;
if (q_node.get_per_tensor_output_range() && q_node.get_output_lo_val() < q_node.get_output_hi_val()) {
// 1. pre-scale & pre-shift
{
if (q_node.get_per_tensor_input_scale() && q_node.get_per_tensor_input_shift()) {
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_input_scale_val(), q_node.get_input_shift_val());
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
} else {
if (q_node.get_per_tensor_input_scale()) {
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_input_scale_val(), 0.0f);
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
} else {
auto in_scale = arg.get_dependency(dep_idx++).get_output_layout();
if (idx == 0 && !has_out_scales(attrs) && in_scale.data_type == data_types::f32 &&
arg.type() == convolution::type_id() &&
!data_type_traits::is_floating_point(arg.get_dependency(0).get_output_layout().data_type)) {
int mask = in_scale.count() > 1 ? 2 : 0;
attrs->set_output_scales(mask, {DNNL_RUNTIME_F32_VAL});
update_onednn_post_op_list(onednn_post_op_type::scale, dep_idx - 1);
} else {
dnnl::memory::desc in_scale_desc = onednn::layout_to_memory_desc(in_scale, dnnl::memory::format_tag::ab, true);
post_ops.append_binary(dnnl::algorithm::binary_mul, in_scale_desc);
update_onednn_post_op_list(onednn_post_op_type::binary_mul, dep_idx - 1);
}
}
if (q_node.get_need_pre_shift()) {
if (q_node.get_per_tensor_input_shift()) {
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, 1.0f, q_node.get_input_shift_val());
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
} else {
auto in_shift = arg.get_dependency(dep_idx++).get_output_layout();
dnnl::memory::desc in_shift_desc = onednn::layout_to_memory_desc(in_shift, dnnl::memory::format_tag::ab, true);
post_ops.append_binary(dnnl::algorithm::binary_add, in_shift_desc);
update_onednn_post_op_list(onednn_post_op_type::binary_add, dep_idx - 1);
}
}
}
}
// 2. round
auto out_dt = cldnn_post_ops[idx].output_layout.data_type;
bool output_type_is_int8 = out_dt == data_types::u8 || out_dt == data_types::i8;
if (!output_type_is_int8) {
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_round, 0.0f, 0.0f);
update_onednn_post_op_list(onednn_post_op_type::eltwise_round, empty_mem);
}
// 3. post-scale & post-shift
if (q_node.get_need_post_scale() && q_node.get_need_post_shift() &&
q_node.get_per_tensor_output_scale() && q_node.get_per_tensor_output_shift()) {
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_output_scale_val(), q_node.get_output_shift_val());
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
} else {
if (q_node.get_need_post_scale()) {
if (q_node.get_per_tensor_output_scale()) {
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_output_scale_val(), 0.0f);
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
} else {
auto out_scale = arg.get_dependency(dep_idx++).get_output_layout();
dnnl::memory::desc out_scale_desc = onednn::layout_to_memory_desc(out_scale, dnnl::memory::format_tag::ab, true);
post_ops.append_binary(dnnl::algorithm::binary_mul, out_scale_desc);
update_onednn_post_op_list(onednn_post_op_type::binary_mul, dep_idx - 1);
}
}
if (q_node.get_need_post_shift()) {
if (q_node.get_per_tensor_output_shift()) {
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, 1.0f, q_node.get_output_shift_val());
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
} else {
auto out_shift = arg.get_dependency(dep_idx++).get_output_layout();
dnnl::memory::desc out_shift_desc = onednn::layout_to_memory_desc(out_shift, dnnl::memory::format_tag::ab, true);
post_ops.append_binary(dnnl::algorithm::binary_add, out_shift_desc);
update_onednn_post_op_list(onednn_post_op_type::binary_add, dep_idx - 1);
}
}
}
// 4. clamp
if (q_node.get_need_clamp()) {
float out_lo = q_node.get_need_min_clamp() ? q_node.get_output_lo_val() : data_type_traits::min<float>(out_dt);
float out_hi = q_node.get_need_max_clamp() ? q_node.get_output_hi_val() : data_type_traits::max<float>(out_dt);
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_clip, out_lo, out_hi);
update_onednn_post_op_list(onednn_post_op_type::eltwise_clip, empty_mem);
}
} else {
// 1. clamp
if (q_node.get_need_clamp()) {
auto in_lo = arg.get_dependency(dep_idx++).get_output_layout();
auto in_hi = arg.get_dependency(dep_idx++).get_output_layout();
dnnl::algorithm clamp_max = dnnl::algorithm::binary_max;
dnnl::algorithm clamp_min = dnnl::algorithm::binary_min;
dnnl::memory::desc in_lo_desc = onednn::layout_to_memory_desc(in_lo, dnnl::memory::format_tag::ab, true);
dnnl::memory::desc in_hi_desc = onednn::layout_to_memory_desc(in_hi, dnnl::memory::format_tag::ab, true);
post_ops.append_binary(clamp_max, in_lo_desc);
update_onednn_post_op_list(onednn_post_op_type::binary_max, dep_idx - 2);
post_ops.append_binary(clamp_min, in_hi_desc);
update_onednn_post_op_list(onednn_post_op_type::binary_min, dep_idx - 1);
}
// 2. pre-scale & pre-shift
{
if (q_node.get_per_tensor_input_scale() && q_node.get_per_tensor_input_shift()) {
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_input_scale_val(), q_node.get_input_shift_val());
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
} else {
if (q_node.get_per_tensor_input_scale()) {
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_input_scale_val(), 0.0f);
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
} else {
auto in_scale = arg.get_dependency(dep_idx++).get_output_layout();
if (idx == 0 && !q_node.get_need_clamp() && !has_out_scales(attrs) && in_scale.data_type == data_types::f32 &&
arg.type() == convolution::type_id() &&
!data_type_traits::is_floating_point(arg.get_dependency(0).get_output_layout().data_type)) {
int mask = in_scale.count() > 1 ? 2 : 0;
attrs->set_output_scales(mask, {DNNL_RUNTIME_F32_VAL});
update_onednn_post_op_list(onednn_post_op_type::scale, dep_idx - 1);
} else {
dnnl::memory::desc in_scale_desc = onednn::layout_to_memory_desc(in_scale, dnnl::memory::format_tag::ab, true);
post_ops.append_binary(dnnl::algorithm::binary_mul, in_scale_desc);
update_onednn_post_op_list(onednn_post_op_type::binary_mul, dep_idx - 1);
}
}
if (q_node.get_need_pre_shift()) {
if (q_node.get_per_tensor_input_shift()) {
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, 1.0f, q_node.get_input_shift_val());
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
} else {
auto in_shift = arg.get_dependency(dep_idx++).get_output_layout();
dnnl::memory::desc in_shift_desc = onednn::layout_to_memory_desc(in_shift, dnnl::memory::format_tag::ab, true);
post_ops.append_binary(dnnl::algorithm::binary_add, in_shift_desc);
update_onednn_post_op_list(onednn_post_op_type::binary_add, dep_idx - 1);
}
}
}
}
// 3. round
{
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_round, 0.0f, 0.0f);
update_onednn_post_op_list(onednn_post_op_type::eltwise_round, empty_mem);
}
// 4. post-scale & post-shift
if (q_node.get_need_post_scale() && q_node.get_need_post_shift() &&
q_node.get_per_tensor_output_scale() && q_node.get_per_tensor_output_shift()) {
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_output_scale_val(), q_node.get_output_shift_val());
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
} else {
if (q_node.get_need_post_scale()) {
if (q_node.get_per_tensor_output_scale()) {
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_output_scale_val(), 0.0f);
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
} else {
auto out_scale = arg.get_dependency(dep_idx++).get_output_layout();
dnnl::memory::desc out_scale_desc = onednn::layout_to_memory_desc(out_scale, dnnl::memory::format_tag::ab, true);
post_ops.append_binary(dnnl::algorithm::binary_mul, out_scale_desc);
update_onednn_post_op_list(onednn_post_op_type::binary_mul, dep_idx - 1);
}
}
if (q_node.get_need_post_shift()) {
if (q_node.get_per_tensor_output_shift()) {
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, 1.0f, q_node.get_output_shift_val());
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
} else {
auto out_shift = arg.get_dependency(dep_idx++).get_output_layout();
dnnl::memory::desc out_shift_desc = onednn::layout_to_memory_desc(out_shift, dnnl::memory::format_tag::ab, true);
post_ops.append_binary(dnnl::algorithm::binary_add, out_shift_desc);
update_onednn_post_op_list(onednn_post_op_type::binary_add, dep_idx - 1);
}
}
}
}
} else {
throw std::runtime_error("Unsupported fused op for onednn prim");
}
}
// Update total onednn post-ops info
onednn_fusing_map.emplace(arg.id(), std::move(fused_ops));
// Try to optimize when there is more than one post-op
auto post_ops_size = onednn_fusing_map[arg.id()].size();
if (post_ops_size > 1) {
dnnl::post_ops optimized_post_ops = post_ops;
bool optimization_is_finished = false;
// Try to combine multiplications and additions that are placed one after another.
// This is done in a loop because "eltw + eltw" pairs can sometimes be simplified again.
do {
optimized_post_ops = try_optimize_post_ops(arg, optimized_post_ops, attrs, optimization_is_finished);
} while (!optimization_is_finished);
attrs->set_post_ops(optimized_post_ops);
} else {
// Set post-ops without any optimizations
attrs->set_post_ops(post_ops);
}
return attrs;
}
@ -104,7 +819,8 @@ protected:
}
void set_arguments_impl(typed_primitive_inst<PType>& instance) override {
_args = get_arguments(instance);
uint32_t net_id = instance.get_network().get_id();
_args[net_id] = get_arguments(instance);
}
event::ptr execute_impl(const std::vector<event::ptr>& /* events */,
@ -113,6 +829,7 @@ protected:
auto& engine = network.get_engine();
auto& stream = network.get_stream();
auto profiling = engine.configuration().enable_profiling;
auto net_id = network.get_id();
event::ptr event;
if (profiling) {
@ -120,7 +837,7 @@ protected:
event = stream.create_user_event(false);
}
_prim.execute(stream.get_onednn_stream(), _args);
_prim.execute(stream.get_onednn_stream(), _args[net_id]);
if (profiling) {
stream.finish();
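A minimal sketch of the oneDNN calls that the mapping in get_primitive_attributes above reduces to for one typical fused chain: an activation, a per-channel multiply, and a full-size residual sum. It assumes the oneDNN 2.x post-op signatures already used in this file; the channel count, algorithm choice and header path are illustrative:

#include <oneapi/dnnl/dnnl.hpp>

int main() {
    dnnl::post_ops ops;

    // Fused activation -> eltwise post-op (scale, algorithm, alpha, beta).
    ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_relu, 0.0f, 0.0f);

    // Per-channel multiply -> binary_mul with a 1 x C f32 operand (C = 64 here).
    dnnl::memory::desc scale_md({1, 64}, dnnl::memory::data_type::f32,
                                dnnl::memory::format_tag::ab);
    ops.append_binary(dnnl::algorithm::binary_mul, scale_md);

    // Full-size residual add -> sum post-op; it accumulates into the destination
    // buffer, which is why the eltwise input memory must be reused for the output.
    ops.append_sum(1.0f, dnnl::memory::data_type::f32);

    dnnl::primitive_attr attr;
    attr.set_post_ops(ops);
    return 0;
}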

View File

@ -29,28 +29,36 @@ public:
bool get_need_post_scale() const { return need_post_scale; }
bool get_need_post_shift() const { return need_post_shift; }
bool get_need_clamp() const { return need_clamp; }
bool get_need_min_clamp() const { return need_min_clamp; }
bool get_need_max_clamp() const { return need_max_clamp; }
bool get_per_tensor_input_scale() const { return per_tensor_input_scale; }
bool get_per_tensor_input_shift() const { return per_tensor_input_shift; }
bool get_per_tensor_input_range() const { return per_tensor_input_range; }
bool get_per_tensor_output_scale() const { return per_tensor_output_scale; }
bool get_per_tensor_output_shift() const { return per_tensor_output_shift; }
bool get_per_tensor_output_range() const { return per_tensor_output_range; }
float get_input_scale_val() const { return in_scale; }
float get_input_shift_val() const { return in_shift; }
float get_input_lo_val() const { return in_lo; }
float get_input_hi_val() const { return in_hi; }
float get_output_scale_val() const { return out_scale; }
float get_output_shift_val() const { return out_shift; }
float get_output_lo_val() const { return out_lo; }
float get_output_hi_val() const { return out_hi; }
void set_scale_shift_opt() { scale_shift_opt = true; }
void set_need_post_scale() { need_post_scale = true; }
void set_need_post_shift() { need_post_shift = true; }
void set_need_pre_shift() { need_pre_shift = true; }
void set_need_clamp() { need_clamp = true; }
void set_need_min_clamp() { need_min_clamp = true; }
void set_need_max_clamp() { need_max_clamp = true; }
void set_per_tensor_input_scale() { per_tensor_input_scale = true; }
void set_per_tensor_input_shift() { per_tensor_input_shift = true; }
void set_per_tensor_input_range() { per_tensor_input_range = true; }
void set_per_tensor_output_scale() { per_tensor_output_scale = true; }
void set_per_tensor_output_shift() { per_tensor_output_shift = true; }
void set_per_tensor_output_range() { per_tensor_output_range = true; }
// Clamp is needed to avoid inf and -inf which are converted to undefined "inf" constant in opencl
void set_input_scale_val(float val) { in_scale = clamp(val); }
void set_input_shift_val(float val) { in_shift = clamp(val); }
@ -58,6 +66,8 @@ public:
void set_input_hi_val(float val) { in_hi = clamp(val); }
void set_output_scale_val(float val) { out_scale = clamp(val); }
void set_output_shift_val(float val) { out_shift = clamp(val); }
void set_output_lo_val(float val) { out_lo = clamp(val); }
void set_output_hi_val(float val) { out_hi = clamp(val); }
std::shared_ptr<kernel_selector::fuse_params> get_fuse_params() const override {
return std::make_shared<kernel_selector::quantize_fuse_params>(scale_shift_opt,
@ -65,15 +75,20 @@ public:
need_post_shift,
need_pre_shift,
need_clamp,
need_min_clamp,
need_max_clamp,
per_tensor_input_range,
per_tensor_input_scale,
per_tensor_input_shift,
per_tensor_output_range,
per_tensor_output_scale,
per_tensor_output_shift,
in_lo,
in_hi,
in_scale,
in_shift,
out_lo,
out_hi,
out_scale,
out_shift);
}
@ -88,10 +103,13 @@ private:
bool need_post_shift = false;
bool need_pre_shift = false;
bool need_clamp = false;
bool need_min_clamp = false;
bool need_max_clamp = false;
bool per_tensor_input_range = false;
bool per_tensor_input_scale = false;
bool per_tensor_input_shift = false;
bool per_tensor_output_range = false;
bool per_tensor_output_scale = false;
bool per_tensor_output_shift = false;
@ -99,6 +117,8 @@ private:
float in_hi = 0.0f;
float in_scale = 0.0f;
float in_shift = 0.0f;
float out_lo = 0.0f;
float out_hi = 0.0f;
float out_scale = 0.0f;
float out_shift = 0.0f;
};

View File

@ -24,6 +24,59 @@
using namespace cldnn;
static size_t get_post_ops_count(const program_node& node) {
size_t onednn_post_ops_count = 0;
for (auto& fo : node.get_fused_primitives()) {
if (fo.node->is_type<activation>() || fo.node->is_type<eltwise>()) {
onednn_post_ops_count++;
} else if (fo.node->is_type<quantize>()) {
auto& q = fo.node->as<quantize>();
// pre-scale, pre-shift
if (q.get_per_tensor_input_scale() && q.get_per_tensor_input_shift()) {
onednn_post_ops_count++;
} else {
onednn_post_ops_count += 2;
}
// post-scale, post-shift
if (q.get_need_post_scale() && q.get_need_post_shift() &&
q.get_per_tensor_output_scale() && q.get_per_tensor_output_shift()) {
onednn_post_ops_count++;
} else {
onednn_post_ops_count += 2;
}
auto out_dt = fo.output_layout.data_type;
auto output_type_is_int8 = out_dt == data_types::u8 || out_dt == data_types::i8;
auto out_range_usage = q.get_per_tensor_output_range() && q.get_output_lo_val() < q.get_output_hi_val();
if (out_range_usage) {
// round
if (!output_type_is_int8) {
onednn_post_ops_count++;
}
// clamp
if (q.get_need_clamp()) {
onednn_post_ops_count++;
}
} else {
// clamp
if (q.get_need_clamp()) {
onednn_post_ops_count += 2;
}
// round
{
onednn_post_ops_count++;
}
}
}
}
return onednn_post_ops_count;
}
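As a worked example of this counting: a fused quantize with per-tensor input scale and shift, per-tensor output scale and shift, an int8 output and a usable per-tensor output range that still needs a clamp contributes 1 (pre-scale/shift) + 1 (post-scale/shift) + 1 (clamp) = 3 post-ops. With per-channel parameters and no usable output range it would contribute 2 + 2 + 2 (clamp) + 1 (round) = 7.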
std::pair<std::shared_ptr<reorder>, bool> reorder_factory::get_reorder(primitive_id src_id,
const layout& in_layout,
const layout& out_layout

View File

@ -473,6 +473,47 @@ void network::allocate_primitives() {
for (auto const& node : nodes_to_allocate) {
allocate_primitive_instance(*node);
}
for (auto const& node : _program->get_processing_order()) {
if (node->get_preferred_impl_type() == impl_types::onednn) {
bool can_reuse_eltwise_mem = false;
size_t eltw_dep = 0;
for (auto& fused_op : node->get_fused_primitives()) {
if (fused_op.node->is_type<eltwise>() && fused_op.deps.size() == 1) {
auto eltw_in_layout = node->get_dependency(fused_op.dep_start_idx).get_output_layout();
auto out_layout = node->get_output_layout();
if (eltw_in_layout.size == out_layout.size &&
eltw_in_layout.format == out_layout.format &&
eltw_in_layout.data_padding == out_layout.data_padding &&
data_type_traits::size_of(eltw_in_layout.data_type) == data_type_traits::size_of(out_layout.data_type)) {
if (eltw_dep > 0) {
throw std::runtime_error("Unsupported multiple full size tensors.");
}
eltw_dep = fused_op.dep_start_idx;
can_reuse_eltwise_mem = true;
}
if (fused_op.node->as<eltwise>().get_primitive()->needs_onednn_sum_post_op(eltw_in_layout) && !can_reuse_eltwise_mem) {
throw std::runtime_error("Buffer reuse is required for onednn sum post operation.");
}
}
}
if (can_reuse_eltwise_mem) {
auto& eltw_in = node->get_dependency(eltw_dep);
if (_primitives.find(eltw_in.id()) != _primitives.end() && _primitives.find(node->id()) != _primitives.end()) {
auto& eltw_inst = _primitives.at(eltw_in.id());
auto& prim_inst = _primitives.at(node->id());
auto& eltw_mem = eltw_inst->output_memory();
auto new_mem = eltw_mem.get_engine()->reinterpret_buffer(eltw_mem, node->get_output_layout());
prim_inst->set_output_memory(new_mem);
}
}
}
}
}
void network::build_insts_deps() {
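The buffer sharing above exists because a oneDNN sum post-op accumulates into the primitive's destination: the fused eltwise operand must already live in the memory the convolution writes to, hence the reinterpret of the eltwise instance's buffer as the node's output. A small standalone illustration of that contract, with plain arrays standing in for the shared buffers:

#include <cstdio>

// Models what a oneDNN sum post-op does: dst[i] = conv_result[i] + scale * dst[i],
// i.e. it reads from and accumulates into the destination buffer itself.
void conv_with_sum_post_op(const float* conv_result, float* dst, int n, float scale) {
    for (int i = 0; i < n; ++i)
        dst[i] = conv_result[i] + scale * dst[i];
}

int main() {
    float residual[4] = {1.f, 2.f, 3.f, 4.f};     // eltwise (residual) input
    float conv_out[4] = {10.f, 20.f, 30.f, 40.f}; // freshly computed convolution result

    // The allocation pass above makes the convolution output alias the residual's
    // buffer, so when the primitive runs, dst already holds the residual data.
    conv_with_sum_post_op(conv_out, residual, 4, 1.0f);

    for (float v : residual) std::printf("%g ", v); // 11 22 33 44
    std::printf("\n");
    return 0;
}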

View File

@ -1025,9 +1025,10 @@ void program::fuse_nodes(program_node &fused_node, program_node &peer_node, std:
quantize_node& q_node = peer_node.as<quantize>();
if (q_node.get_scale_shift_opt()) {
bool can_drop_input = false;
bool out_range_usage = q_node.get_per_tensor_output_range() && q_node.get_output_lo_val() < q_node.get_output_hi_val();
// Drop input range if clamp is not needed
can_drop_input |= (i == 1 || i == 2) && !q_node.get_need_clamp();
// Drop the input range if the per-tensor output range is used for clamping, or if clamp is not needed at all
can_drop_input |= (i == 1 || i == 2) && (out_range_usage || (!out_range_usage && !q_node.get_need_clamp()));
// Drop output range - it's not used in scale-shift-opt quantize kernel
can_drop_input |= i == 3 || i == 4;
// Drop tensor with input scale when we have per-tensor parameter