[GPU] Added post-ops support for OneDNN primitives (#7737)
[GPU] Memory lock fix
parent 302eb08dc5
commit f675df625c
@@ -159,6 +159,15 @@ struct eltwise : public primitive_base<eltwise> {
        }
    }

    bool needs_onednn_sum_post_op(layout input_layout) const {
        if (mode == eltwise_mode::sum &&
            (input_layout.size.spatial[0] > 1 || input_layout.size.spatial[1] > 1 || input_layout.size.batch[0] > 1)) {
            return true;
        }

        return false;
    }

    /// @param mode Eltwise mode.
    eltwise_mode mode;
    /// @param coefficients Blob-wise coefficient for SUM operation.
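Note: `needs_onednn_sum_post_op` decides whether a fused eltwise sum can be expressed as a oneDNN `sum` post-op (which accumulates in place into the destination buffer) rather than a `binary_add` over a per-channel operand; the decision is used later when building the oneDNN post-op list. A minimal standalone sketch of the same shape check, with simplified stand-ins for `cldnn::layout`/`tensor` (the real types differ):

```cpp
#include <cassert>

// Simplified stand-ins for cldnn::layout/tensor, for illustration only.
struct tensor { int batch0, spatial0, spatial1; };
struct layout { tensor size; };
enum class eltwise_mode { sum, prod };

// Mirrors needs_onednn_sum_post_op(): a sum over a full-sized tensor
// (more than one batch or spatial element) maps to a oneDNN sum post-op;
// a 1x1x1 per-channel operand is handled as binary_add instead.
bool needs_onednn_sum_post_op(eltwise_mode mode, const layout& input_layout) {
    return mode == eltwise_mode::sum &&
           (input_layout.size.spatial0 > 1 ||
            input_layout.size.spatial1 > 1 ||
            input_layout.size.batch0 > 1);
}

int main() {
    layout full{{1, 56, 56}};      // full feature map -> sum post-op
    layout per_channel{{1, 1, 1}}; // broadcasted operand -> binary_add
    assert(needs_onednn_sum_post_op(eltwise_mode::sum, full));
    assert(!needs_onednn_sum_post_op(eltwise_mode::sum, per_channel));
}
```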
@@ -20,15 +20,20 @@ struct quantize_params : public base_params {
    , has_post_shift(true)
    , has_pre_shift(true)
    , has_clamp(true)
    , has_min_clamp(true)
    , has_max_clamp(true)
    , per_tensor_input_range(false)
    , per_tensor_input_scale(false)
    , per_tensor_input_shift(false)
    , per_tensor_output_range(false)
    , per_tensor_output_scale(false)
    , per_tensor_output_shift(false)
    , in_lo(0.0f)
    , in_hi(0.0f)
    , in_scale(0.0f)
    , in_shift(0.0f)
    , out_lo(0.0f)
    , out_hi(0.0f)
    , out_scale(0.0f)
    , out_shift(0.0f) { }
@@ -39,10 +44,13 @@ struct quantize_params : public base_params {
    bool has_post_shift;
    bool has_pre_shift;
    bool has_clamp;
    bool has_min_clamp;
    bool has_max_clamp;

    bool per_tensor_input_range;
    bool per_tensor_input_scale;
    bool per_tensor_input_shift;
    bool per_tensor_output_range;
    bool per_tensor_output_scale;
    bool per_tensor_output_shift;
@@ -50,6 +58,8 @@ struct quantize_params : public base_params {
    float in_hi;
    float in_scale;
    float in_shift;
    float out_lo;
    float out_hi;
    float out_scale;
    float out_shift;
@@ -79,15 +89,20 @@ struct quantize_fuse_params : fuse_params {
                         bool has_post_shift,
                         bool has_pre_shift,
                         bool has_clamp,
                         bool has_min_clamp,
                         bool has_max_clamp,
                         bool per_tensor_input_range,
                         bool per_tensor_input_scale,
                         bool per_tensor_input_shift,
                         bool per_tensor_output_range,
                         bool per_tensor_output_scale,
                         bool per_tensor_output_shift,
                         float in_lo,
                         float in_hi,
                         float in_scale,
                         float in_shift,
                         float out_lo,
                         float out_hi,
                         float out_scale,
                         float out_shift)
    : fuse_params(KernelType::QUANTIZE)
@@ -96,19 +111,25 @@ struct quantize_fuse_params : fuse_params {
    , has_post_shift(has_post_shift)
    , has_pre_shift(has_pre_shift)
    , has_clamp(has_clamp)
    , has_min_clamp(has_min_clamp)
    , has_max_clamp(has_max_clamp)
    , per_tensor_input_range(per_tensor_input_range)
    , per_tensor_input_scale(per_tensor_input_scale)
    , per_tensor_input_shift(per_tensor_input_shift)
    , per_tensor_output_range(per_tensor_output_range)
    , per_tensor_output_scale(per_tensor_output_scale)
    , per_tensor_output_shift(per_tensor_output_shift)
    , in_lo(in_lo)
    , in_hi(in_hi)
    , in_scale(in_scale)
    , in_shift(in_shift)
    , out_lo(out_lo)
    , out_hi(out_hi)
    , out_scale(out_scale)
    , out_shift(out_shift) {
        size_t index = 0;
        if (has_clamp) {
        bool out_range_usage = per_tensor_output_range && out_lo < out_hi;
        if (!out_range_usage && has_clamp) {
            in_range_lo_idx = index++;
            in_range_hi_idx = index++;
        }
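Note: the constructor change above only reserves fused-input slots for the input-range buffers when the clamp cannot be folded into a per-tensor output range. A self-contained sketch of that index bookkeeping, using simplified names (the real constructor assigns several more indices afterwards):

```cpp
#include <cstddef>
#include <iostream>

// Hypothetical, reduced version of the quantize_fuse_params index assignment.
struct fuse_indices {
    std::size_t in_range_lo_idx = 0;
    std::size_t in_range_hi_idx = 0;
};

fuse_indices assign_indices(bool has_clamp, bool per_tensor_output_range,
                            float out_lo, float out_hi) {
    fuse_indices idx;
    std::size_t index = 0;
    // Same condition as in the patch: an explicit in_lo/in_hi clamp buffer is
    // only needed when the per-tensor output range cannot express the clamp.
    bool out_range_usage = per_tensor_output_range && out_lo < out_hi;
    if (!out_range_usage && has_clamp) {
        idx.in_range_lo_idx = index++;
        idx.in_range_hi_idx = index++;
    }
    return idx;
}

int main() {
    auto folded   = assign_indices(true, true, 0.f, 255.f); // clamp folded into output range
    auto explicit_buffers = assign_indices(true, false, 0.f, 0.f);
    std::cout << folded.in_range_hi_idx << " " << explicit_buffers.in_range_hi_idx << "\n"; // 0 1
}
```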
@@ -131,10 +152,13 @@ struct quantize_fuse_params : fuse_params {
    bool has_post_shift;
    bool has_pre_shift;
    bool has_clamp;
    bool has_min_clamp;
    bool has_max_clamp;

    bool per_tensor_input_range;
    bool per_tensor_input_scale;
    bool per_tensor_input_shift;
    bool per_tensor_output_range;
    bool per_tensor_output_scale;
    bool per_tensor_output_shift;
@@ -142,6 +166,8 @@ struct quantize_fuse_params : fuse_params {
    float in_hi;
    float in_scale;
    float in_shift;
    float out_lo;
    float out_hi;
    float out_scale;
    float out_shift;
@@ -1631,33 +1631,88 @@ JitConstants FusedOpsCodeGenerator::MakeOpJitConstants(const FusedOpsConfigurati
|
||||
: ConvertToType(GetInputVarName(p->in_scale_idx, is_shuffled, shuffle_var), tmp_type, vec_size);
|
||||
auto pre_shift = p->per_tensor_input_shift ? Broadcast(toCodeString(p->in_shift), tmp_type, vec_size)
|
||||
: ConvertToType(GetInputVarName(p->in_shift_idx, is_shuffled, shuffle_var), tmp_type, vec_size);
|
||||
auto in_lo = p->per_tensor_input_range ? Broadcast(toCodeString(p->in_lo), tmp_type, vec_size)
|
||||
: ConvertToType(GetInputVarName(p->in_range_lo_idx, is_shuffled, shuffle_var), tmp_type, vec_size);
|
||||
auto in_hi = p->per_tensor_input_range ? Broadcast(toCodeString(p->in_hi), tmp_type, vec_size)
|
||||
: ConvertToType(GetInputVarName(p->in_range_hi_idx, is_shuffled, shuffle_var), tmp_type, vec_size);
|
||||
|
||||
if (p->has_clamp) {
|
||||
op_decls += "\\\n\t" + tmp_type_str + " " + tmp_var + " = min(max(" + in_lo + ", " + in_converted + "), " + in_hi + ");";
|
||||
} else {
|
||||
op_decls += "\\\n\t" + tmp_type_str + " " + tmp_var + " = " + in_converted + ";";
|
||||
}
|
||||
op_decls += "\\\n\t" + tmp_var + " = " + tmp_var + "*" + pre_scale + ";";
|
||||
if (p->per_tensor_output_range && p->out_lo < p->out_hi) {
|
||||
// Input scale
|
||||
op_decls += "\\\n\t" + tmp_type_str + " " + tmp_var + " = " + in_converted + " * " + pre_scale + ";";
|
||||
|
||||
// Input shift
|
||||
if (p->has_pre_shift)
|
||||
op_decls += "\\\n\t" + tmp_var + " = " + tmp_var + " + " + pre_shift + ";";
|
||||
|
||||
// Round operation isn't needed if output type is int8/uint8 and scale coefficient in all output channels is equal to 1.0
|
||||
bool output_type_is_int8 = desc.output_tensor.GetDType() == Datatype::UINT8 || desc.output_tensor.GetDType() == Datatype::INT8;
|
||||
if (((p->has_post_scale || p->has_post_shift) && output_type_is_int8) || !output_type_is_int8)
|
||||
op_decls += "\\\n\t" + tmp_var + " = round(" + tmp_var + ");";
|
||||
|
||||
bool need_round = (p->has_post_scale || p->has_post_shift) &&
|
||||
(desc.output_tensor.GetDType() == Datatype::UINT8 || desc.output_tensor.GetDType() == Datatype::INT8);
|
||||
// Output scale
|
||||
if (p->has_post_scale)
|
||||
op_decls += "\\\n\t" + tmp_var + " = (" + tmp_var + " * " + post_scale + ");";
|
||||
|
||||
// Output shift
|
||||
if (p->has_post_shift)
|
||||
op_decls += "\\\n\t" + tmp_var + " = (" + tmp_var + " + " + post_shift + ");";
|
||||
if (need_round)
|
||||
op_decls += "\\\n\t" + tmp_var + " = round(" + tmp_var + ");";
|
||||
|
||||
// Output range
|
||||
auto out_lo = Broadcast(std::to_string(p->out_lo), tmp_type, vec_size);
|
||||
auto out_hi = Broadcast(std::to_string(p->out_hi), tmp_type, vec_size);
|
||||
|
||||
// Output clamp
|
||||
if (p->has_clamp) {
|
||||
if (p->has_min_clamp && p->has_max_clamp)
|
||||
op_decls += "\\\n\t" + tmp_var + " = clamp(" + tmp_var + ", " + out_lo + ", " + out_hi + ");";
|
||||
else if (p->has_min_clamp)
|
||||
op_decls += "\\\n\t" + tmp_var + " = max(" + tmp_var + ", " + out_lo + ");";
|
||||
else
|
||||
op_decls += "\\\n\t" + tmp_var + " = min(" + tmp_var + ", " + out_hi + ");";
|
||||
}
|
||||
|
||||
// Output conversion with rounding and saturation
|
||||
op_decls += "\\\n\t" + GetOutputType(vec_size) + " " + out_var + " = " + ConvertToOutputTypeSat(tmp_var, vec_size) + ";";
|
||||
break;
|
||||
} else {
|
||||
// Input range
|
||||
auto in_lo = p->per_tensor_input_range ? Broadcast(std::to_string(p->in_lo), tmp_type, vec_size)
|
||||
: ConvertToType(GetInputVarName(p->in_range_lo_idx, is_shuffled, shuffle_var), tmp_type, vec_size);
|
||||
auto in_hi = p->per_tensor_input_range ? Broadcast(std::to_string(p->in_hi), tmp_type, vec_size)
|
||||
: ConvertToType(GetInputVarName(p->in_range_hi_idx, is_shuffled, shuffle_var), tmp_type, vec_size);
|
||||
|
||||
// Input clamp
|
||||
if (p->has_clamp) {
|
||||
if (p->has_min_clamp && p->has_max_clamp)
|
||||
op_decls += "\\\n\t" + tmp_type_str + " " + tmp_var + " = clamp(" + in_converted + ", " + in_lo + ", " + in_hi + ");";
|
||||
else if (p->has_min_clamp)
|
||||
op_decls += "\\\n\t" + tmp_type_str + " " + tmp_var + " = max(" + in_converted + ", " + in_lo + ");";
|
||||
else
|
||||
op_decls += "\\\n\t" + tmp_type_str + " " + tmp_var + " = min(" + in_converted + ", " + in_hi + ");";
|
||||
} else {
|
||||
op_decls += "\\\n\t" + tmp_type_str + " " + tmp_var + " = " + in_converted + ";";
|
||||
}
|
||||
|
||||
// Input scale
|
||||
op_decls += "\\\n\t" + tmp_var + " = " + tmp_var + " * " + pre_scale + ";";
|
||||
|
||||
// Input shift
|
||||
if (p->has_pre_shift)
|
||||
op_decls += "\\\n\t" + tmp_var + " = " + tmp_var + " + " + pre_shift + ";";
|
||||
|
||||
// Round operation isn't needed if output type is int8/uint8 and scale coefficient in all output channels is equal to 1.0
|
||||
bool output_type_is_int8 = desc.output_tensor.GetDType() == Datatype::UINT8 || desc.output_tensor.GetDType() == Datatype::INT8;
|
||||
if (((p->has_post_scale || p->has_post_shift) && output_type_is_int8) || !output_type_is_int8)
|
||||
op_decls += "\\\n\t" + tmp_var + " = round(" + tmp_var + ");";
|
||||
|
||||
// Output scale
|
||||
if (p->has_post_scale)
|
||||
op_decls += "\\\n\t" + tmp_var + " = (" + tmp_var + " * " + post_scale + ");";
|
||||
|
||||
// Output shift
|
||||
if (p->has_post_shift)
|
||||
op_decls += "\\\n\t" + tmp_var + " = (" + tmp_var + " + " + post_shift + ");";
|
||||
|
||||
// Output conversion with rounding and saturation
|
||||
op_decls += "\\\n\t" + GetOutputType(vec_size) + " " + out_var + " = " + ConvertToOutputTypeSat(tmp_var, vec_size) + ";";
|
||||
break;
|
||||
}
|
||||
}
|
||||
case KernelType::ACTIVATION: {
|
||||
auto p = desc.GetOpParams<activation_fuse_params>();
|
||||
@@ -1871,7 +1926,7 @@ std::string FusedOpsCodeGenerator::ConvertToOutputTypeSat(std::string var, size_
    if (desc.output_tensor.GetDType() == Datatype::F32 || desc.output_tensor.GetDType() == Datatype::F16)
        return "convert_" + GetOutputType(vec_size) + "(" + var + ")";
    else
        return "convert_" + GetOutputType(vec_size) + "_sat(" + var + ")";
        return "convert_" + GetOutputType(vec_size) + "_sat_rte(" + var + ")";
}

std::vector<size_t> FusedOpsCodeGenerator::GetRequiredInputs() const {
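Note: this hunk switches the generated integer output conversion from `convert_<type>_sat` (saturate, default round-toward-zero) to `convert_<type>_sat_rte` (saturate, round to nearest even). A small host-side model of the difference for a value converted to int8; the actual conversion happens in the generated OpenCL C, so these functions are illustrative only:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>

// Host-side approximations of the two OpenCL conversions:
// convert_char_sat truncates toward zero; convert_char_sat_rte rounds to nearest even.
int8_t sat(float v)     { return static_cast<int8_t>(std::clamp(std::trunc(v), -128.f, 127.f)); }
int8_t sat_rte(float v) { return static_cast<int8_t>(std::clamp(std::nearbyint(v), -128.f, 127.f)); }

int main() {
    std::cout << int(sat(2.7f)) << " vs " << int(sat_rte(2.7f)) << "\n"; // 2 vs 3
    std::cout << int(sat(2.5f)) << " vs " << int(sat_rte(2.5f)) << "\n"; // 2 vs 2 (ties to even)
}
```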
@@ -1880,7 +1935,8 @@ std::vector<size_t> FusedOpsCodeGenerator::GetRequiredInputs() const {
        auto p = std::dynamic_pointer_cast<quantize_fuse_params>(desc.op_params);
        if (p) {
            std::vector<size_t> res = {};
            if (!p->per_tensor_input_range && p->has_clamp) {
            bool out_range_usage = p->per_tensor_output_range && p->out_lo < p->out_hi;
            if (!out_range_usage && p->has_clamp) {
                res.push_back(p->in_range_lo_idx);
                res.push_back(p->in_range_hi_idx);
            }
@@ -36,6 +36,40 @@ void basic_memory_dependencies::run(program& p) {
            add_memory_dependency(it, node);
        }

        if (node->is_type<convolution>()) {
            auto& conv = node->as<convolution>();
            bool can_reuse_eltwise_mem = false;
            size_t eltw_dep = 0;

            for (auto& fused_op : conv.get_fused_primitives()) {
                if (fused_op.node->is_type<eltwise>() && fused_op.deps.size() == 1) {
                    auto eltw_in_layout = conv.get_dependency(fused_op.dep_start_idx).get_output_layout();
                    auto conv_out_layout = node->get_output_layout();
                    if (eltw_dep > 0) {
                        can_reuse_eltwise_mem = false;
                        break;
                    }

                    if (eltw_in_layout.size == conv_out_layout.size &&
                        eltw_in_layout.format == conv_out_layout.format &&
                        eltw_in_layout.data_padding == conv_out_layout.data_padding &&
                        data_type_traits::size_of(eltw_in_layout.data_type) == data_type_traits::size_of(conv_out_layout.data_type)) {
                        eltw_dep = fused_op.dep_start_idx;
                        can_reuse_eltwise_mem = true;
                    }
                }
            }
            if (can_reuse_eltwise_mem) {
                auto& eltw_node = conv.get_dependency(eltw_dep);
                eltw_node.can_share_buffer(false);
                conv.can_share_buffer(false);
                for (auto& user : conv.get_users()) {
                    add_memory_dependency(user, &eltw_node);
                    add_memory_dependency(user, &conv);
                }
            }
        }

        // Note: we iterate over processing order, so if a primitive has a processing number greater than any of its outputs,
        // that output has to land on the primitive's restriction list. Otherwise memory reuse can corrupt final results.
        node->add_memory_dependency(past_outputs);
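Note: this pass handles the case where a convolution's fused eltwise sum accumulates directly into the convolution output. The eltwise operand may only alias the conv output if layouts match exactly; when it does, both nodes are excluded from buffer sharing and every conv user gets explicit memory dependencies on them. A reduced, self-contained sketch of the layout-compatibility check (simplified stand-in fields, not the real cldnn types):

```cpp
#include <cstddef>

// Simplified stand-in for the fields compared in basic_memory_dependencies.
struct simple_layout {
    int shape[4];          // b, f, y, x
    int format;            // format enum value in the real code
    int padding;           // collapsed padding descriptor
    std::size_t elem_size; // data_type_traits::size_of(...) in the real code
};

// The eltwise operand can only reuse (alias) the conv output buffer when
// shape, format, padding and element size all match.
bool can_reuse_eltwise_mem(const simple_layout& eltw_in, const simple_layout& conv_out) {
    for (int i = 0; i < 4; ++i)
        if (eltw_in.shape[i] != conv_out.shape[i]) return false;
    return eltw_in.format == conv_out.format &&
           eltw_in.padding == conv_out.padding &&
           eltw_in.elem_size == conv_out.elem_size;
}

int main() {
    simple_layout a{{1, 64, 56, 56}, 0, 0, 1};
    simple_layout b{{1, 64, 56, 56}, 0, 0, 2}; // fp16 vs int8: element sizes differ
    return can_reuse_eltwise_mem(a, a) && !can_reuse_eltwise_mem(a, b) ? 0 : 1;
}
```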
@@ -6,6 +6,8 @@

#include "pooling_inst.h"
#include "quantize_inst.h"
#include "reshape_inst.h"
#include "reorder_inst.h"
#include "binary_convolution_inst.h"
#include "scale_inst.h"
#include "eltwise_inst.h"
@@ -41,6 +43,57 @@ bool check_binarization(memory::ptr mem_input_low, memory::ptr mem_input_high, p
|
||||
|
||||
void prepare_quantization::prepare_scale_shift_opt(program &p, quantize_node& quantize_node) {
|
||||
const auto& stream = p.get_stream();
|
||||
|
||||
size_t out_features = static_cast<size_t>(quantize_node.get_output_layout().size.feature[0]);
|
||||
float* bias_values = nullptr;
|
||||
cldnn::memory* bias_mem_ptr = nullptr;
|
||||
bool can_merge_bias = false;
|
||||
size_t bias_depth = 0;
|
||||
|
||||
// Will try to merge bias into FQ
|
||||
auto &merge_node = quantize_node.get_dependency(0);
|
||||
|
||||
if (merge_node.is_type<eltwise>() && merge_node.get_dependencies().size() == 2) {
|
||||
auto& eltw_node = merge_node.as<eltwise>();
|
||||
auto& eltw_node_dep1 = eltw_node.get_dependency(1);
|
||||
|
||||
// Check that this is not input layout
|
||||
if (!eltw_node_dep1.is_type<input_layout>()) {
|
||||
// We should check a case with reshape / reorder nodes before bias constant data
|
||||
if (eltw_node_dep1.is_type<data>()) {
|
||||
bias_depth = 1;
|
||||
} else if (eltw_node_dep1.get_dependencies().size()) {
|
||||
auto has_extra_nodes1 = eltw_node_dep1.is_type<reshape>() || eltw_node_dep1.is_type<reorder>();
|
||||
if (has_extra_nodes1 && eltw_node_dep1.get_dependency(0).is_type<data>()) {
|
||||
bias_depth = 2;
|
||||
} else if (has_extra_nodes1 && eltw_node_dep1.get_dependency(0).get_dependencies().size()) {
|
||||
auto has_extra_nodes2 = eltw_node_dep1.get_dependency(0).is_type<reshape>() || eltw_node_dep1.get_dependency(0).is_type<reorder>();
|
||||
if (has_extra_nodes2 && eltw_node_dep1.get_dependency(0).get_dependency(0).is_type<data>())
|
||||
bias_depth = 3;
|
||||
}
|
||||
}
|
||||
|
||||
auto& dep = bias_depth == 1 ? eltw_node_dep1 :
|
||||
bias_depth == 2 ? eltw_node_dep1.get_dependency(0) :
|
||||
bias_depth == 3 ? eltw_node_dep1.get_dependency(0).get_dependency(0) :
|
||||
eltw_node_dep1;
|
||||
|
||||
if (bias_depth) {
|
||||
can_merge_bias = dep.is_constant() && dep.get_output_layout().count() == out_features && dep.get_users().size() == 1 &&
|
||||
eltw_node.get_primitive()->mode == eltwise_mode::sum && eltw_node.get_dependencies().size() == 2 &&
|
||||
eltw_node.get_dependency(0).is_type<convolution>();
|
||||
}
|
||||
|
||||
if (can_merge_bias) {
|
||||
auto &bias = dep.as<data>();
|
||||
auto &mem_bias = bias.get_attached_memory();
|
||||
bias_mem_ptr = &mem_bias;
|
||||
auto data_bias_ptr = static_cast<float*>(mem_bias.lock(stream));
|
||||
bias_values = data_bias_ptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
program_node &input_low_node = quantize_node.get_dependency(1);
|
||||
program_node &input_high_node = quantize_node.get_dependency(2);
|
||||
program_node &output_low_node = quantize_node.get_dependency(3);
|
||||
@@ -83,9 +136,11 @@ void prepare_quantization::prepare_scale_shift_opt(program &p, quantize_node& q

    auto lock_memory = [&stream] (memory::ptr memory, std::function<void(std::size_t, float)>& set_data,
                                  std::function<float(size_t)>& get_data) {
        using float_mem_lock = mem_lock<float, mem_lock_type::write>;
        using uint16_t_mem_lock = mem_lock<uint16_t, mem_lock_type::write>;
        switch (memory->get_layout().data_type) {
            case data_types::f32: {
                std::shared_ptr<mem_lock<float, mem_lock_type::write>> data_lock_ptr = std::make_shared<mem_lock<float, mem_lock_type::write>>(memory, stream);
                std::shared_ptr<float_mem_lock> data_lock_ptr = std::make_shared<float_mem_lock>(memory, stream);
                float* data = data_lock_ptr->data();
                set_data = [data] (size_t idx, float value) {
                    data[idx] = value;
@@ -93,11 +148,10 @@ void prepare_quantization::prepare_scale_shift_opt(program &p, quantize_node& q
                get_data = [data] (size_t idx) {
                    return data[idx];
                };
                return std::pair<std::shared_ptr<mem_lock<float, mem_lock_type::write>>,
                                 std::shared_ptr<mem_lock<uint16_t, mem_lock_type::write>>>(data_lock_ptr, nullptr);
                return std::pair<std::shared_ptr<float_mem_lock>, std::shared_ptr<uint16_t_mem_lock>>(data_lock_ptr, nullptr);
            }
            case data_types::f16: {
                auto data_lock_ptr = std::make_shared<mem_lock<uint16_t, mem_lock_type::write>>(memory, stream);
                std::shared_ptr<uint16_t_mem_lock> data_lock_ptr = std::make_shared<uint16_t_mem_lock>(memory, stream);
                uint16_t* data = data_lock_ptr->data();
                set_data = [data] (size_t idx, float value) {
                    data[idx] = float_to_half(value);
@@ -105,8 +159,7 @@ void prepare_quantization::prepare_scale_shift_opt(program &p, quantize_node& q
                get_data = [data] (size_t idx) {
                    return half_to_float(data[idx]);
                };
                return std::pair<std::shared_ptr<mem_lock<float, mem_lock_type::write>>,
                                 std::shared_ptr<mem_lock<uint16_t, mem_lock_type::write>>>(nullptr, data_lock_ptr);
                return std::pair<std::shared_ptr<float_mem_lock>, std::shared_ptr<uint16_t_mem_lock>>(nullptr, data_lock_ptr);
            }
            default:
                throw std::runtime_error("prepare_quantization: Unsupported precision of quantize output values");
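Note: `lock_memory` returns a pair of shared mem_lock pointers (one slot for the f32 lock, one for the f16 lock) whose only purpose is to keep the underlying memory mapped for as long as the `set_data`/`get_data` lambdas capture the raw pointer; this appears to be the "Memory lock fix" mentioned in the commit message. A reduced sketch of the same ownership pattern with a hypothetical `locked_buffer` type standing in for `cldnn::mem_lock`:

```cpp
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical locked-buffer type standing in for cldnn::mem_lock<T>.
template <typename T>
struct locked_buffer {
    std::vector<T> storage;      // pretend this is the mapped device memory
    T* data() { return storage.data(); }
};

using float_lock = std::shared_ptr<locked_buffer<float>>;
using half_lock  = std::shared_ptr<locked_buffer<uint16_t>>;

// Returns accessors plus the lock that keeps the buffer alive; the caller must
// hold the returned pair for as long as it calls set_data/get_data.
std::pair<float_lock, half_lock> lock_memory(bool is_fp16,
                                             std::function<void(std::size_t, float)>& set_data,
                                             std::function<float(std::size_t)>& get_data) {
    if (!is_fp16) {
        auto lock = std::make_shared<locked_buffer<float>>();
        lock->storage.resize(16);
        float* data = lock->data();
        set_data = [data](std::size_t idx, float v) { data[idx] = v; };
        get_data = [data](std::size_t idx) { return data[idx]; };
        return {lock, nullptr};
    }
    // f16 path omitted for brevity; it returns {nullptr, lock} the same way.
    return {nullptr, nullptr};
}

int main() {
    std::function<void(std::size_t, float)> set;
    std::function<float(std::size_t)> get;
    auto locks = lock_memory(false, set, get); // keep 'locks' alive while using set/get
    set(0, 1.5f);
    return get(0) == 1.5f ? 0 : 1;
}
```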
@@ -161,8 +214,9 @@ void prepare_quantization::prepare_scale_shift_opt(program &p, quantize_node& q

        float out_lo = get_data_output_low(get_offset_safe(mem_output_low->get_layout(), idx));
        float out_hi = get_data_output_high(get_offset_safe(mem_output_high->get_layout(), idx));
        set_data_input_scale(s_offset, (static_cast<float>(levels) - 1.f) / (in_hi - in_lo));
        set_data_input_shift(s_offset, - in_lo * (static_cast<float>(levels) - 1.f) / (in_hi - in_lo));
        float in_shift_basic = (static_cast<float>(levels) - 1.f) / (in_hi - in_lo);
        set_data_input_scale(s_offset, in_shift_basic);
        set_data_input_shift(s_offset, can_merge_bias ? (bias_values[f] - in_lo) * in_shift_basic : -in_lo * in_shift_basic);
        set_data_output_scale(s_offset, (out_hi - out_lo) / (static_cast<float>(levels) - 1.f));
        set_data_output_shift(s_offset, out_lo);
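Note: the new local `in_shift_basic` actually holds the common factor (levels - 1) / (in_hi - in_lo), which is used both as the input scale and inside the input shift; with the bias-merge path the bias value is simply folded into the shift. A worked example with assumed numbers (levels = 256, in_lo = -1, in_hi = 1, bias = 0.5; these values are illustrative, not from the patch):

```cpp
#include <iostream>

int main() {
    // Assumed example values.
    const float levels = 256.f, in_lo = -1.f, in_hi = 1.f, bias = 0.5f;

    float in_scale           = (levels - 1.f) / (in_hi - in_lo);  // 127.5
    float in_shift_no_bias   = -in_lo * in_scale;                 // 127.5
    float in_shift_with_bias = (bias - in_lo) * in_scale;         // 191.25 (bias folded in)

    std::cout << in_scale << " " << in_shift_no_bias << " " << in_shift_with_bias << "\n";
}
```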
@@ -186,12 +240,15 @@ void prepare_quantization::prepare_scale_shift_opt(program &p, quantize_node& q
    bool per_tensor_in_range = true;
    bool per_tensor_out_scale = true;
    bool per_tensor_out_shift = true;
    bool per_tensor_out_range = true;
    float in_scale_val = get_data_input_scale(0);
    float in_shift_val = get_data_input_shift(0);
    float out_scale_val = get_data_output_scale(0);
    float out_shift_val = get_data_output_shift(0);
    float in_lo_val = get_data_input_low(0);
    float in_hi_val = get_data_input_high(0);
    float out_lo_val = get_data_output_low(0);
    float out_hi_val = get_data_output_high(0);
    for (size_t i = 0; i < scales_layout.count(); i++) {
        if (in_scale_val != get_data_input_scale(i))
            per_tensor_in_scale = false;
@@ -207,9 +264,56 @@ void prepare_quantization::prepare_scale_shift_opt(program &p, quantize_node& q
|
||||
if (in_lo_val != get_data_input_low(i % mem_input_low->get_layout().count()) ||
|
||||
in_hi_val != get_data_input_high(i % mem_input_high->get_layout().count()))
|
||||
per_tensor_in_range = false;
|
||||
if (out_lo_val != get_data_output_low(i % mem_output_low->get_layout().count()) ||
|
||||
out_hi_val != get_data_output_high(i % mem_output_high->get_layout().count()))
|
||||
per_tensor_out_range = false;
|
||||
}
|
||||
|
||||
auto out_is_int8 = quantize_node.get_output_layout().data_type == data_types::i8;
|
||||
auto out_is_uint8 = quantize_node.get_output_layout().data_type == data_types::u8;
|
||||
auto out_is_fp = !(out_is_int8 || out_is_uint8);
|
||||
bool need_clamp = levels != 256 || out_is_fp;
|
||||
bool need_min_clamp = need_clamp;
|
||||
bool need_max_clamp = need_clamp;
|
||||
|
||||
// Check that we can optimize clamp operation for int8 data using saturation clamp only
|
||||
if (per_tensor_out_range && !out_is_fp && levels != 256) {
|
||||
if ((out_is_int8 && out_lo_val == -128.f) || (out_is_uint8 && out_lo_val == 0.f))
|
||||
need_min_clamp = false;
|
||||
if ((out_is_int8 && out_hi_val == 127.f) || (out_is_uint8 && out_hi_val == 255.f))
|
||||
need_max_clamp = false;
|
||||
}
|
||||
|
||||
// Check that we can merge bias into FQ input shift and if yes then
|
||||
// we remove bias from network graph
|
||||
if (can_merge_bias) {
|
||||
auto &eltw_node = merge_node.as<eltwise>();
|
||||
|
||||
// Remove bias constants and extra reshapes / reorders from the graph (dep3, dep2, dep1)
|
||||
if (bias_depth == 3) {
|
||||
auto &dep3 = eltw_node.get_dependency(1).get_dependency(0).get_dependency(0);
|
||||
p.remove_all_connections(dep3);
|
||||
p.remove_if_dangling(dep3);
|
||||
}
|
||||
|
||||
if (bias_depth >= 2) {
|
||||
auto &dep2 = eltw_node.get_dependency(1).get_dependency(0);
|
||||
p.remove_all_connections(dep2);
|
||||
p.remove_if_dangling(dep2);
|
||||
}
|
||||
|
||||
auto &dep1 = eltw_node.get_dependency(1);
|
||||
p.remove_all_connections(dep1);
|
||||
p.remove_if_dangling(dep1);
|
||||
|
||||
// Remove bias from the graph (eltwise in a "sum" mode)
|
||||
p.extract_and_remove(eltw_node);
|
||||
}
|
||||
|
||||
if (has_negative_scales) {
|
||||
if (can_merge_bias)
|
||||
bias_mem_ptr->unlock(stream);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -266,17 +370,30 @@ void prepare_quantization::prepare_scale_shift_opt(program &p, quantize_node& q
        quantize_node.set_input_shift_val(in_shift_val);
    }

    auto out_dt = quantize_node.get_output_layout().data_type;
    bool need_clamp = levels != 256 || (out_dt != data_types::u8 && out_dt != data_types::i8);
    if (need_clamp) {
        quantize_node.set_need_clamp();
    }

    if (need_min_clamp) {
        quantize_node.set_need_min_clamp();
    }

    if (need_max_clamp) {
        quantize_node.set_need_max_clamp();
    }

    if (per_tensor_in_range) {
        quantize_node.set_per_tensor_input_range();
        quantize_node.set_input_lo_val(in_lo_val);
        quantize_node.set_input_hi_val(in_hi_val);
    }

    if (per_tensor_out_range) {
        quantize_node.set_per_tensor_output_range();
        quantize_node.set_output_lo_val(out_lo_val);
        quantize_node.set_output_hi_val(out_hi_val);
    }

    if (per_tensor_out_scale) {
        quantize_node.set_per_tensor_output_scale();
        quantize_node.set_output_scale_val(out_scale_val);
@@ -286,6 +403,10 @@ void prepare_quantization::prepare_scale_shift_opt(program &p, quantize_node& q
        quantize_node.set_per_tensor_output_shift();
        quantize_node.set_output_shift_val(out_shift_val);
    }

    if (can_merge_bias) {
        bias_mem_ptr->unlock(stream);
    }
}

void prepare_quantization::handle_quantize_node(program& p, quantize_node& quantize_node) {
@@ -55,10 +55,13 @@ public:
        quantize_params.has_post_shift = arg.get_need_post_shift();
        quantize_params.has_pre_shift = arg.get_need_pre_shift();
        quantize_params.has_clamp = arg.get_need_clamp();
        quantize_params.has_min_clamp = arg.get_need_min_clamp();
        quantize_params.has_max_clamp = arg.get_need_max_clamp();

        quantize_params.per_tensor_input_range = arg.get_per_tensor_input_range();
        quantize_params.per_tensor_input_scale = arg.get_per_tensor_input_scale();
        quantize_params.per_tensor_input_shift = arg.get_per_tensor_input_shift();
        quantize_params.per_tensor_output_range = arg.get_per_tensor_output_range();
        quantize_params.per_tensor_output_scale = arg.get_per_tensor_output_scale();
        quantize_params.per_tensor_output_shift = arg.get_per_tensor_output_shift();
@@ -66,6 +69,8 @@ public:
        quantize_params.in_hi = arg.get_input_hi_val();
        quantize_params.in_scale = arg.get_input_scale_val();
        quantize_params.in_shift = arg.get_input_shift_val();
        quantize_params.out_lo = arg.get_output_lo_val();
        quantize_params.out_hi = arg.get_output_hi_val();
        quantize_params.out_scale = arg.get_output_scale_val();
        quantize_params.out_shift = arg.get_output_shift_val();
@@ -25,6 +25,36 @@
namespace cldnn {
namespace onednn {

enum class onednn_post_op_type : uint32_t {
    eltwise_act,
    eltwise_clip,
    eltwise_linear,
    eltwise_round,
    binary_mul,
    binary_add,
    binary_max,
    binary_min,
    scale,
    sum,
    optimized,
    optimized_eltwise,
    optimized_sum
};

struct onednn_post_op_desc {
    onednn_post_op_type op_type;
    size_t mem_offset;
    size_t mem_dep;
};

// This map contains information about onednn post-ops types, memory buffer offsets and dependencies
// key is cldnn::primitive_id,
// value is an instance of struct onednn_post_op_desc containing info about post-ops related to the node defined by key:
// op_type - onednn_post_op_type (enum),
// mem_offset - index of memory buffer for current post-operation,
// mem_dep - memory dependency for working with fused node
static std::unordered_map<cldnn::primitive_id, std::vector<onednn_post_op_desc>> onednn_fusing_map;

template <class PType, class DescType, class PrimDescType = dnnl::primitive_desc, class PrimType = dnnl::primitive>
struct typed_primitive_onednn_impl : public typed_primitive_impl<PType> {
    const typed_program_node<PType>& _outer;
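Note: `onednn_fusing_map` is keyed by the fused node's `primitive_id` and records, for every post-op appended to the oneDNN attributes, its type, the offset into the node's fused-memory buffers, and the dependency index used to fetch the operand at execution time. A self-contained sketch of how one entry might look; the node name and the particular post-op sequence are hypothetical:

```cpp
#include <cstddef>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

enum class onednn_post_op_type : uint32_t { eltwise_linear, binary_mul, binary_add, sum, scale };

struct onednn_post_op_desc {
    onednn_post_op_type op_type;
    std::size_t mem_offset;  // index of the fused memory buffer used by this post-op
    std::size_t mem_dep;     // dependency index of the node providing that buffer
};

int main() {
    std::unordered_map<std::string, std::vector<onednn_post_op_desc>> fusing_map;
    // Hypothetical convolution "conv1" with a fused quantize: per-channel
    // pre-scale (binary_mul), per-tensor scale/shift (eltwise_linear),
    // and a per-channel post-shift (binary_add).
    fusing_map["conv1"] = {
        {onednn_post_op_type::binary_mul,     0, 2},
        {onednn_post_op_type::eltwise_linear, 0, 0},
        {onednn_post_op_type::binary_add,     1, 3},
    };
    return static_cast<int>(fusing_map["conv1"].size()) - 3;  // 0 on success
}
```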
@@ -32,7 +62,7 @@ struct typed_primitive_onednn_impl : public typed_primitive_impl<PType> {
    std::shared_ptr<dnnl::primitive_attr> _attrs;
    PrimDescType _pd;
    PrimType _prim;
    std::unordered_map<int, dnnl::memory> _args;
    std::unordered_map<uint32_t, std::unordered_map<int, dnnl::memory>> _args;

    typed_primitive_onednn_impl(const typed_program_node<PType>& arg,
                                std::shared_ptr<DescType> desc,
@@ -67,12 +97,431 @@ protected:
|
||||
return !zp.empty() && (reinterpret_cast<const int32_t&>(zp[0]) == drsv);
|
||||
}
|
||||
|
||||
static dnnl::post_ops try_optimize_post_ops(const typed_program_node<PType>& arg, dnnl::post_ops& p_ops,
|
||||
const std::shared_ptr<dnnl::primitive_attr>& attr,
|
||||
bool& optimization_is_completed) {
|
||||
// Get current node id for creating of optimization map
|
||||
auto node_id = arg.id();
|
||||
|
||||
// Create new dnnl::post_ops object which will be filled inside the optimization process
|
||||
dnnl::post_ops optimized_p_ops;
|
||||
|
||||
// Add new post-op into optimized_p_ops structure
|
||||
auto add_post_op = [&](onednn_post_op_type type, const dnnl::post_ops& cur_p_ops, dnnl::post_ops& new_p_ops, int idx) {
|
||||
switch (type) {
|
||||
case onednn_post_op_type::eltwise_act:
|
||||
case onednn_post_op_type::eltwise_clip:
|
||||
case onednn_post_op_type::eltwise_linear:
|
||||
case onednn_post_op_type::eltwise_round:
|
||||
{
|
||||
dnnl::algorithm alg;
|
||||
float scale, alpha, beta;
|
||||
cur_p_ops.get_params_eltwise(idx, scale, alg, alpha, beta);
|
||||
new_p_ops.append_eltwise(scale, alg, alpha, beta);
|
||||
break;
|
||||
}
|
||||
|
||||
case onednn_post_op_type::binary_add:
|
||||
case onednn_post_op_type::binary_mul:
|
||||
case onednn_post_op_type::binary_max:
|
||||
case onednn_post_op_type::binary_min:
|
||||
{
|
||||
dnnl::algorithm alg;
|
||||
dnnl::memory::desc desc;
|
||||
cur_p_ops.get_params_binary(idx, alg, desc);
|
||||
new_p_ops.append_binary(alg, desc);
|
||||
break;
|
||||
}
|
||||
|
||||
case onednn_post_op_type::scale:
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
case onednn_post_op_type::sum:
|
||||
{
|
||||
float scale;
|
||||
dnnl::memory::data_type data_type;
|
||||
cur_p_ops.get_params_sum(idx, scale, data_type);
|
||||
new_p_ops.append_sum(scale, data_type);
|
||||
break;
|
||||
}
|
||||
|
||||
case onednn_post_op_type::optimized:
|
||||
case onednn_post_op_type::optimized_sum:
|
||||
case onednn_post_op_type::optimized_eltwise:
|
||||
{
|
||||
// Current operation already has been optimized => don't need extra actions
|
||||
break;
|
||||
}
|
||||
|
||||
default:
|
||||
throw std::runtime_error("Unsupported onednn post-operation type");
|
||||
}
|
||||
};
|
||||
|
||||
auto& cur_post_ops = onednn_fusing_map[node_id];
|
||||
|
||||
size_t cur_post_op_idx = 1;
|
||||
size_t prev_post_op_idx = 0;
|
||||
bool optimization_done = false;
|
||||
|
||||
// Check and update post-op map if we already optimized something
|
||||
for (size_t post_op_idx = 0; post_op_idx < cur_post_ops.size(); post_op_idx++) {
|
||||
if (cur_post_ops[post_op_idx].op_type == onednn_post_op_type::optimized_sum)
|
||||
cur_post_ops[post_op_idx].op_type = onednn_post_op_type::sum;
|
||||
else if (cur_post_ops[post_op_idx].op_type == onednn_post_op_type::optimized_eltwise)
|
||||
cur_post_ops[post_op_idx].op_type = onednn_post_op_type::eltwise_linear;
|
||||
else if (cur_post_ops[post_op_idx].op_type == onednn_post_op_type::optimized)
|
||||
cur_post_ops.erase(cur_post_ops.begin() + post_op_idx);
|
||||
}
|
||||
|
||||
// Get post-ops size for current node
|
||||
auto post_ops_size = cur_post_ops.size();
|
||||
|
||||
// Try to combine pairs of arithmetic post-ops (adds and muls) into one operation inside this cycle
|
||||
while (!optimization_done) {
|
||||
auto cur_type = cur_post_ops[cur_post_op_idx].op_type;
|
||||
auto prev_type = cur_post_ops[prev_post_op_idx].op_type;
|
||||
|
||||
// Ignore optimized operations for "previous" operation in our operation pair
|
||||
while ((prev_type == onednn_post_op_type::optimized || prev_type == onednn_post_op_type::optimized_sum ||
|
||||
prev_type == onednn_post_op_type::optimized_eltwise) && cur_post_op_idx < post_ops_size - 1) {
|
||||
prev_post_op_idx++;
|
||||
cur_post_op_idx++;
|
||||
prev_type = cur_post_ops[prev_post_op_idx].op_type;
|
||||
cur_type = cur_post_ops[cur_post_op_idx].op_type;
|
||||
}
|
||||
|
||||
// Ignore optimized operations for "current" operation in our operation pair
|
||||
while ((cur_type == onednn_post_op_type::optimized || cur_type == onednn_post_op_type::optimized_sum ||
|
||||
cur_type == onednn_post_op_type::optimized_eltwise) && cur_post_op_idx < post_ops_size - 1) {
|
||||
cur_post_op_idx++;
|
||||
cur_type = cur_post_ops[cur_post_op_idx].op_type;
|
||||
}
|
||||
|
||||
auto cur_idx = static_cast<int>(has_out_scales(attr) ? (cur_post_op_idx >= 1 ? cur_post_op_idx - 1 : 0) : cur_post_op_idx);
|
||||
auto prev_idx = static_cast<int>(has_out_scales(attr) ? (prev_post_op_idx >= 1 ? prev_post_op_idx - 1 : 0) : prev_post_op_idx);
|
||||
auto cur_type_is_optimized = cur_type == onednn_post_op_type::optimized ||
|
||||
cur_type == onednn_post_op_type::optimized_sum ||
|
||||
cur_type == onednn_post_op_type::optimized_eltwise;
|
||||
auto prev_type_is_optimized = prev_type == onednn_post_op_type::optimized ||
|
||||
prev_type == onednn_post_op_type::optimized_sum ||
|
||||
prev_type == onednn_post_op_type::optimized_eltwise;
|
||||
|
||||
// If this is the last pair and it's optimized - add the last post-op and go out from the cycle
|
||||
if (cur_post_op_idx == post_ops_size - 1 && (cur_type_is_optimized || prev_type_is_optimized)) {
|
||||
if (!prev_type_is_optimized) {
|
||||
add_post_op(prev_type, p_ops, optimized_p_ops, prev_idx);
|
||||
}
|
||||
if (!cur_type_is_optimized) {
|
||||
add_post_op(cur_type, p_ops, optimized_p_ops, cur_idx);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
auto equal_ops = cur_type == prev_type;
|
||||
auto cur_type_is_binary_add_or_mul = cur_type == onednn_post_op_type::binary_add || cur_type == onednn_post_op_type::binary_mul;
|
||||
auto prev_type_is_binary_add_or_mul = prev_type == onednn_post_op_type::binary_add || prev_type == onednn_post_op_type::binary_mul;
|
||||
|
||||
// Post-ops combinations which can be simplified
|
||||
auto eltw_and_eltw = equal_ops && cur_type == onednn_post_op_type::eltwise_linear;
|
||||
auto bin_and_eltw = cur_type_is_binary_add_or_mul && prev_type == onednn_post_op_type::eltwise_linear;
|
||||
auto eltw_and_bin = cur_type == onednn_post_op_type::eltwise_linear && prev_type_is_binary_add_or_mul;
|
||||
auto eltw_and_sum = cur_type == onednn_post_op_type::eltwise_linear && prev_type == onednn_post_op_type::sum;
|
||||
auto eltw_and_scale = cur_type == onednn_post_op_type::eltwise_linear && prev_type == onednn_post_op_type::scale;
|
||||
|
||||
auto can_try_optimize = eltw_and_eltw ||
|
||||
bin_and_eltw ||
|
||||
eltw_and_bin ||
|
||||
eltw_and_sum ||
|
||||
eltw_and_scale;
|
||||
|
||||
bool cur_ops_pair_is_optimized = false;
|
||||
|
||||
if (can_try_optimize) {
|
||||
if (eltw_and_eltw) {
|
||||
dnnl::algorithm alg;
|
||||
float cur_scale, prev_scale, cur_alpha, prev_alpha, cur_beta, prev_beta;
|
||||
|
||||
p_ops.get_params_eltwise(prev_idx, prev_scale, alg, prev_alpha, prev_beta);
|
||||
p_ops.get_params_eltwise(cur_idx, cur_scale, alg, cur_alpha, cur_beta);
|
||||
|
||||
// Eltwise + eltwise pair can be optimized only if cur_alpha is equal to 1.0f
|
||||
if (cur_alpha == 1.0f && prev_scale == cur_scale) {
|
||||
dnnl::post_ops eltw_p_op;
|
||||
eltw_p_op.append_eltwise(cur_scale, alg, prev_alpha, cur_beta + prev_beta);
|
||||
|
||||
// Combine 2 eltwises into one
|
||||
add_post_op(cur_type, eltw_p_op, optimized_p_ops, 0);
|
||||
|
||||
// Marked current and previous eltwise operations as 'optimized' (they will be ignored on the next iteration of cycle)
|
||||
cur_post_ops[cur_post_op_idx].op_type = onednn_post_op_type::optimized;
|
||||
cur_post_ops[prev_post_op_idx].op_type = onednn_post_op_type::optimized_eltwise;
|
||||
|
||||
// Set the flag if extra optimizations checking is needed
|
||||
if (cur_post_op_idx < post_ops_size - 1) {
|
||||
if (cur_post_ops[cur_post_op_idx + 1].op_type == onednn_post_op_type::eltwise_linear ||
|
||||
cur_post_ops[cur_post_op_idx + 1].op_type == onednn_post_op_type::binary_add ||
|
||||
cur_post_ops[cur_post_op_idx + 1].op_type == onednn_post_op_type::binary_mul ||
|
||||
cur_post_ops[cur_post_op_idx + 1].op_type == onednn_post_op_type::optimized_eltwise) {
|
||||
optimization_is_completed = true;
|
||||
}
|
||||
}
|
||||
cur_ops_pair_is_optimized = true;
|
||||
}
|
||||
} else if (bin_and_eltw) {
|
||||
dnnl::algorithm alg;
|
||||
dnnl::memory::desc desc;
|
||||
float scale, alpha, beta;
|
||||
|
||||
cldnn::program_node& cur_node = arg.get_dependency(cur_post_ops[cur_post_op_idx].mem_dep);
|
||||
|
||||
p_ops.get_params_binary(cur_idx, alg, desc);
|
||||
p_ops.get_params_eltwise(prev_idx, scale, alg, alpha, beta);
|
||||
|
||||
// Eltwise operations can use runtime non-constant data buffers, so check that memory buffers consist of constant data only
|
||||
auto bin_ops_can_be_optimized = cur_node.is_type<data>() && cur_node.is_constant() &&
|
||||
cur_node.get_users().size() == 1 && desc.data_type() == dnnl_f32;
|
||||
|
||||
auto bin_add_and_eltw = alpha == 1.0f && scale == 1.0f && cur_type == onednn_post_op_type::binary_add && bin_ops_can_be_optimized;
|
||||
auto bin_mul_and_eltw = beta == 0.f && scale == 1.0f && cur_type == onednn_post_op_type::binary_mul && bin_ops_can_be_optimized;
|
||||
|
||||
if (bin_add_and_eltw || bin_mul_and_eltw) {
|
||||
memory::ptr cur_bin_mem_ptr = cur_node.as<data>().get_attached_memory_ptr();
|
||||
auto& stream = cur_bin_mem_ptr->get_engine()->get_program_stream();
|
||||
mem_lock<float, mem_lock_type::write> bin_and_eltw_lock(cur_bin_mem_ptr, stream);
|
||||
|
||||
size_t cur_bin_mem_size = cur_node.get_output_layout().count();
|
||||
|
||||
// Update all binary coefficients
|
||||
if (bin_add_and_eltw) {
|
||||
for (size_t data_idx = 0; data_idx < cur_bin_mem_size; data_idx++) {
|
||||
bin_and_eltw_lock[data_idx] += beta;
|
||||
}
|
||||
} else {
|
||||
for (size_t data_idx = 0; data_idx < cur_bin_mem_size; data_idx++) {
|
||||
bin_and_eltw_lock[data_idx] *= alpha;
|
||||
}
|
||||
}
|
||||
|
||||
// Marked previous eltwise operation as 'optimized' (it will be ignored on the next iteration of cycle)
|
||||
cur_post_ops[prev_post_op_idx].op_type = onednn_post_op_type::optimized;
|
||||
|
||||
cur_ops_pair_is_optimized = true;
|
||||
}
|
||||
} else if (eltw_and_bin) {
|
||||
dnnl::algorithm alg;
|
||||
dnnl::memory::desc desc;
|
||||
float scale, alpha, beta;
|
||||
|
||||
cldnn::program_node& prev_node = arg.get_dependency(cur_post_ops[prev_post_op_idx].mem_dep);
|
||||
|
||||
p_ops.get_params_eltwise(cur_idx, scale, alg, alpha, beta);
|
||||
p_ops.get_params_binary(prev_idx, alg, desc);
|
||||
|
||||
// Eltwise operations can use runtime non-constant data buffers, so check that memory buffers consist of constant data only
|
||||
auto bin_ops_can_be_optimized = prev_node.is_type<data>() && prev_node.is_constant() &&
|
||||
prev_node.get_users().size() == 1 && desc.data_type() == dnnl_f32;
|
||||
|
||||
auto eltw_and_bin_add = alpha == 1.0f && scale == 1.0f && prev_type == onednn_post_op_type::binary_add && bin_ops_can_be_optimized;
|
||||
auto eltw_and_bin_mul = beta == 0.f && scale == 1.0f && prev_type == onednn_post_op_type::binary_mul && bin_ops_can_be_optimized;
|
||||
|
||||
if (eltw_and_bin_add || eltw_and_bin_mul) {
|
||||
memory::ptr prev_bin_mem_ptr = prev_node.as<data>().get_attached_memory_ptr();
|
||||
auto& stream = prev_bin_mem_ptr->get_engine()->get_program_stream();
|
||||
mem_lock<float, mem_lock_type::write> eltw_and_bin_lock(prev_bin_mem_ptr, stream);
|
||||
|
||||
size_t prev_bin_mem_size = prev_node.get_output_layout().count();
|
||||
|
||||
// Update all binary coefficients
|
||||
if (eltw_and_bin_add) {
|
||||
for (size_t data_idx = 0; data_idx < prev_bin_mem_size; data_idx++) {
|
||||
eltw_and_bin_lock[data_idx] += beta;
|
||||
}
|
||||
} else {
|
||||
for (size_t data_idx = 0; data_idx < prev_bin_mem_size; data_idx++) {
|
||||
eltw_and_bin_lock[data_idx] *= alpha;
|
||||
}
|
||||
}
|
||||
|
||||
// Marked current eltwise operation as 'optimized' (it will be ignored on the next iteration of cycle)
|
||||
cur_post_ops[cur_post_op_idx].op_type = onednn_post_op_type::optimized;
|
||||
|
||||
cur_ops_pair_is_optimized = true;
|
||||
}
|
||||
} else if (eltw_and_sum) {
|
||||
dnnl::algorithm alg;
|
||||
float cur_scale, prev_scale, alpha, beta;
|
||||
dnnl::memory::data_type data_type;
|
||||
|
||||
cldnn::program_node& prev_node = arg.get_dependency(cur_post_ops[prev_post_op_idx].mem_dep);
|
||||
|
||||
p_ops.get_params_eltwise(cur_idx, cur_scale, alg, alpha, beta);
|
||||
p_ops.get_params_sum(prev_idx, prev_scale, data_type);
|
||||
|
||||
// Eltwise operations can use runtime non-constant data buffers, so check that memory buffers consist of constant data only
|
||||
auto eltw_ops_can_be_optimized = prev_node.is_type<data>() && prev_node.is_constant() &&
|
||||
prev_node.get_users().size() == 1;
|
||||
|
||||
// Eltwise can be inserted into the scale field of previous sum if cur_beta is equal to 0.f
|
||||
if (beta == 0.f && cur_scale == 1.0f && eltw_ops_can_be_optimized) {
|
||||
dnnl::post_ops sum_p_op;
|
||||
sum_p_op.append_sum(alpha * prev_scale, data_type);
|
||||
|
||||
// Insert cur eltwise into sum
|
||||
add_post_op(prev_type, sum_p_op, optimized_p_ops, 0);
|
||||
|
||||
memory::ptr prev_eltw_mem_ptr = prev_node.as<data>().get_attached_memory_ptr();
|
||||
auto& stream = prev_eltw_mem_ptr->get_engine()->get_program_stream();
|
||||
mem_lock<float, mem_lock_type::write> eltw_and_sum_lock(prev_eltw_mem_ptr, stream);
|
||||
|
||||
size_t prev_eltw_mem_size = prev_node.get_output_layout().count();
|
||||
|
||||
// Also multiply sum on alpha for getting valid results
|
||||
for (size_t data_idx = 0; data_idx < prev_eltw_mem_size; data_idx++) {
|
||||
eltw_and_sum_lock[data_idx] *= alpha;
|
||||
}
|
||||
|
||||
// Marked current and previous operations as 'optimized' (they will be ignored on the next iteration of cycle)
|
||||
cur_post_ops[cur_post_op_idx].op_type = onednn_post_op_type::optimized;
|
||||
cur_post_ops[prev_post_op_idx].op_type = onednn_post_op_type::optimized_sum;
|
||||
|
||||
// Set the flag if extra optimizations checking is needed
|
||||
if (cur_post_op_idx < post_ops_size - 1) {
|
||||
if (cur_post_ops[cur_post_op_idx + 1].op_type == onednn_post_op_type::eltwise_linear ||
|
||||
cur_post_ops[cur_post_op_idx + 1].op_type == onednn_post_op_type::optimized_eltwise) {
|
||||
optimization_is_completed = true;
|
||||
}
|
||||
}
|
||||
cur_ops_pair_is_optimized = true;
|
||||
}
|
||||
} else if (eltw_and_scale) {
|
||||
dnnl::algorithm alg;
|
||||
float cur_scale, alpha, beta;
|
||||
|
||||
cldnn::program_node& prev_node = arg.get_dependency(cur_post_ops[prev_post_op_idx].mem_dep);
|
||||
|
||||
p_ops.get_params_eltwise(cur_idx, cur_scale, alg, alpha, beta);
|
||||
|
||||
// Eltwise can be inserted into output_scale if cur_beta is equal to 0.f and cur_scale is equal to 1.0f
|
||||
if (beta == 0.f && cur_scale == 1.0f && prev_node.get_output_layout().data_type == data_types::f32) {
|
||||
memory::ptr prev_scale_mem_ptr = prev_node.as<data>().get_attached_memory_ptr();
|
||||
auto& stream = prev_scale_mem_ptr->get_engine()->get_program_stream();
|
||||
mem_lock<float, mem_lock_type::write> eltw_and_scale_lock(prev_scale_mem_ptr, stream);
|
||||
|
||||
size_t prev_scale_mem_size = prev_node.get_output_layout().count();
|
||||
|
||||
// Update all scale coefficients
|
||||
for (size_t data_idx = 0; data_idx < prev_scale_mem_size; data_idx++) {
|
||||
eltw_and_scale_lock[data_idx] *= alpha;
|
||||
}
|
||||
|
||||
// Marked current eltwise operation as 'optimized' (it will be ignored on the next iteration of cycle)
|
||||
cur_post_ops[cur_post_op_idx].op_type = onednn_post_op_type::optimized;
|
||||
|
||||
cur_ops_pair_is_optimized = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If no optimizations have been applied then copy post-op info into the new optimized_p_ops structure
|
||||
if (!(has_out_scales(attr) && prev_post_op_idx == 0) && !cur_ops_pair_is_optimized) {
|
||||
add_post_op(prev_type, p_ops, optimized_p_ops, prev_idx);
|
||||
}
|
||||
|
||||
if (cur_post_op_idx == post_ops_size - 1 && !cur_ops_pair_is_optimized) {
|
||||
add_post_op(cur_type, p_ops, optimized_p_ops, cur_idx);
|
||||
optimization_done = true;
|
||||
} else if (cur_post_ops[cur_post_op_idx].op_type != onednn_post_op_type::optimized) {
|
||||
cur_post_op_idx++;
|
||||
prev_post_op_idx++;
|
||||
}
|
||||
}
|
||||
|
||||
optimization_is_completed = !optimization_is_completed;
|
||||
|
||||
return optimized_p_ops;
|
||||
}
|
||||
|
||||
void configure_post_ops_arguments(typed_primitive_inst<PType>& instance, std::unordered_map<int, dnnl::memory>& args) const {
|
||||
// Get current node id for creating of optimization map
|
||||
auto node_id = instance.id();
|
||||
auto& engine = instance.get_network().get_engine();
|
||||
auto dnnl_engine = engine.get_onednn_engine();
|
||||
|
||||
// Get current post-ops info
|
||||
dnnl::post_ops post_ops = _attrs->get_post_ops();
|
||||
|
||||
// Create onednn memory buffers for post-ops
|
||||
auto& cur_post_ops = onednn_fusing_map[node_id];
|
||||
auto post_ops_size = cur_post_ops.size();
|
||||
for (size_t post_op_idx = 0, num_of_optimized_post_ops = 0; post_op_idx < post_ops_size; post_op_idx++) {
|
||||
auto post_op_type = cur_post_ops[post_op_idx].op_type;
|
||||
auto memory_offset = cur_post_ops[post_op_idx].mem_offset;
|
||||
auto onednn_post_op_idx = has_out_scales(_attrs) && post_op_idx > 0 ? post_op_idx - 1 : post_op_idx;
|
||||
onednn_post_op_idx -= num_of_optimized_post_ops;
|
||||
|
||||
switch (post_op_type) {
|
||||
case onednn_post_op_type::eltwise_act:
|
||||
case onednn_post_op_type::eltwise_clip:
|
||||
case onednn_post_op_type::eltwise_linear:
|
||||
case onednn_post_op_type::eltwise_round:
|
||||
{
|
||||
// onednn elwise doesn't need any data from memory buffers
|
||||
break;
|
||||
}
|
||||
|
||||
case onednn_post_op_type::binary_add:
|
||||
case onednn_post_op_type::binary_mul:
|
||||
case onednn_post_op_type::binary_max:
|
||||
case onednn_post_op_type::binary_min:
|
||||
{
|
||||
auto binary_op_mem = instance.fused_memory(memory_offset);
|
||||
dnnl::algorithm alg;
|
||||
dnnl::memory::desc desc;
|
||||
post_ops.get_params_binary(static_cast<int>(onednn_post_op_idx), alg, desc);
|
||||
args.insert({DNNL_ARG_ATTR_MULTIPLE_POST_OP(static_cast<int>(onednn_post_op_idx)) | DNNL_ARG_SRC_1,
|
||||
binary_op_mem->get_onednn_memory(desc)});
|
||||
break;
|
||||
}
|
||||
|
||||
case onednn_post_op_type::scale:
|
||||
{
|
||||
auto scale_op_mem = instance.fused_memory(memory_offset);
|
||||
dnnl::memory::desc desc = onednn::layout_to_memory_desc(scale_op_mem->get_layout(), dnnl::memory::format_tag::a, true);
|
||||
args.insert({DNNL_ARG_ATTR_OUTPUT_SCALES, scale_op_mem->get_onednn_memory(desc)});
|
||||
break;
|
||||
}
|
||||
|
||||
case onednn_post_op_type::sum:
|
||||
case onednn_post_op_type::optimized_sum:
|
||||
case onednn_post_op_type::optimized_eltwise:
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
case onednn_post_op_type::optimized:
|
||||
{
|
||||
// Optimized post-op, count it to respect onednn_post_op_idx in the next operations
|
||||
num_of_optimized_post_ops++;
|
||||
break;
|
||||
}
|
||||
|
||||
default:
|
||||
throw std::runtime_error("Unsupported onednn post-operation type");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
virtual std::unordered_map<int, dnnl::memory> get_arguments(typed_primitive_inst<PType>& instance) const {
|
||||
std::unordered_map<int, dnnl::memory> args;
|
||||
auto& engine = instance.get_network().get_engine();
|
||||
auto dnnl_engine = engine.get_onednn_engine();
|
||||
|
||||
for (size_t i = 0; i < instance.inputs_memory_count(); i++) {
|
||||
auto& input = instance.input_memory(i);
|
||||
args.insert({DNNL_ARG_SRC, input.get_onednn_memory(_pd.dnnl::primitive_desc_base::src_desc(static_cast<int>(i)))});
|
||||
{
|
||||
auto& input = instance.input_memory(0);
|
||||
args.insert({DNNL_ARG_SRC, input.get_onednn_memory(_pd.dnnl::primitive_desc_base::src_desc(0))});
|
||||
}
|
||||
|
||||
{
|
||||
@@ -80,15 +529,281 @@ protected:
|
||||
args.insert({DNNL_ARG_DST, output.get_onednn_memory(_pd.dnnl::primitive_desc_base::dst_desc(0))});
|
||||
}
|
||||
|
||||
configure_post_ops_arguments(instance, args);
|
||||
|
||||
return args;
|
||||
}
|
||||
|
||||
void init_kernels() override { }
|
||||
|
||||
static std::shared_ptr<dnnl::primitive_attr> get_primitive_attributes(const typed_program_node<PType>& /* arg */) {
|
||||
static std::shared_ptr<dnnl::primitive_attr> get_primitive_attributes(const typed_program_node<PType>& arg) {
|
||||
const std::vector<fused_primitive_desc>& cldnn_post_ops = arg.get_fused_primitives();
|
||||
auto attrs = std::make_shared<dnnl::primitive_attr>();
|
||||
dnnl::post_ops post_ops;
|
||||
size_t memory_offset = 0;
|
||||
|
||||
// Create onednn post-ops list related to the current node
|
||||
std::vector<onednn_post_op_desc> fused_ops;
|
||||
|
||||
// Added this for debug purposes only
|
||||
size_t empty_mem = 0xff;
|
||||
|
||||
// Add information about post-operation into the list, update indices
|
||||
auto update_onednn_post_op_list = [&](onednn_post_op_type type, size_t m_dep) {
|
||||
onednn_post_op_desc cur_op_desc = { type, memory_offset, m_dep };
|
||||
fused_ops.push_back(cur_op_desc);
|
||||
|
||||
auto has_memory_buffers = type == onednn_post_op_type::binary_add ||
|
||||
type == onednn_post_op_type::binary_mul ||
|
||||
type == onednn_post_op_type::binary_max ||
|
||||
type == onednn_post_op_type::binary_min ||
|
||||
type == onednn_post_op_type::scale ||
|
||||
type == onednn_post_op_type::sum;
|
||||
if (has_memory_buffers)
|
||||
memory_offset++;
|
||||
};
|
||||
|
||||
for (size_t idx = 0; idx < cldnn_post_ops.size(); idx++) {
|
||||
auto node = cldnn_post_ops[idx].node;
|
||||
|
||||
if (node->is_type<activation>()) {
|
||||
auto fused_desc = node->as<activation>().get_primitive();
|
||||
dnnl::algorithm alg = onednn::convert_activation_func(fused_desc->activation_function);
|
||||
post_ops.append_eltwise(1.0f, alg, fused_desc->additional_params.a, fused_desc->additional_params.b);
|
||||
update_onednn_post_op_list(onednn_post_op_type::eltwise_act, empty_mem);
|
||||
} else if (node->is_type<eltwise>()) {
|
||||
auto& e_node = node->as<eltwise>();
|
||||
auto dep_idx = cldnn_post_ops[idx].dep_start_idx;
|
||||
auto in = arg.get_dependency(dep_idx).get_output_layout();
|
||||
|
||||
if (e_node.get_primitive()->mode == eltwise_mode::sum) {
|
||||
if (e_node.get_primitive()->needs_onednn_sum_post_op(in)) {
|
||||
post_ops.append_sum(1.0f, onednn::convert_data_type(in.data_type));
|
||||
update_onednn_post_op_list(onednn_post_op_type::sum, dep_idx);
|
||||
} else {
|
||||
dnnl::memory::desc in_desc = onednn::layout_to_memory_desc(in, dnnl::memory::format_tag::ab, true);
|
||||
post_ops.append_binary(dnnl::algorithm::binary_add, in_desc);
|
||||
update_onednn_post_op_list(onednn_post_op_type::binary_add, dep_idx);
|
||||
}
|
||||
} else {
|
||||
if (in.size.spatial[0] > 1 || in.size.spatial[1] > 1 || in.size.batch[0] > 1)
|
||||
throw std::runtime_error("Unsupported eltwise mode for fused onednn op");
|
||||
if (idx == 0 && !has_out_scales(attrs)) {
|
||||
int mask = in.count() > 1 ? 2 : 0;
|
||||
attrs->set_output_scales(mask, {DNNL_RUNTIME_F32_VAL});
|
||||
update_onednn_post_op_list(onednn_post_op_type::scale, dep_idx);
|
||||
} else {
|
||||
dnnl::memory::desc in_desc = onednn::layout_to_memory_desc(in, dnnl::memory::format_tag::ab, true);
|
||||
post_ops.append_binary(dnnl::algorithm::binary_mul, in_desc);
|
||||
update_onednn_post_op_list(onednn_post_op_type::binary_mul, dep_idx);
|
||||
}
|
||||
}
|
||||
} else if (node->is_type<quantize>()) {
|
||||
auto& q_node = node->as<quantize>();
|
||||
auto dep_idx = cldnn_post_ops[idx].dep_start_idx;
|
||||
|
||||
if (q_node.get_per_tensor_output_range() && q_node.get_output_lo_val() < q_node.get_output_hi_val()) {
|
||||
// 1. pre-scale & pre-shift
|
||||
{
|
||||
if (q_node.get_per_tensor_input_scale() && q_node.get_per_tensor_input_shift()) {
|
||||
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_input_scale_val(), q_node.get_input_shift_val());
|
||||
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
|
||||
} else {
|
||||
if (q_node.get_per_tensor_input_scale()) {
|
||||
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_input_scale_val(), 0.0f);
|
||||
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
|
||||
} else {
|
||||
auto in_scale = arg.get_dependency(dep_idx++).get_output_layout();
|
||||
if (idx == 0 && !has_out_scales(attrs) && in_scale.data_type == data_types::f32 &&
|
||||
arg.type() == convolution::type_id() &&
|
||||
!data_type_traits::is_floating_point(arg.get_dependency(0).get_output_layout().data_type)) {
|
||||
int mask = in_scale.count() > 1 ? 2 : 0;
|
||||
attrs->set_output_scales(mask, {DNNL_RUNTIME_F32_VAL});
|
||||
update_onednn_post_op_list(onednn_post_op_type::scale, dep_idx - 1);
|
||||
} else {
|
||||
dnnl::memory::desc in_scale_desc = onednn::layout_to_memory_desc(in_scale, dnnl::memory::format_tag::ab, true);
|
||||
post_ops.append_binary(dnnl::algorithm::binary_mul, in_scale_desc);
|
||||
update_onednn_post_op_list(onednn_post_op_type::binary_mul, dep_idx - 1);
|
||||
}
|
||||
}
|
||||
|
||||
if (q_node.get_need_pre_shift()) {
|
||||
if (q_node.get_per_tensor_input_shift()) {
|
||||
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, 1.0f, q_node.get_input_shift_val());
|
||||
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
|
||||
} else {
|
||||
auto in_shift = arg.get_dependency(dep_idx++).get_output_layout();
|
||||
dnnl::memory::desc in_shift_desc = onednn::layout_to_memory_desc(in_shift, dnnl::memory::format_tag::ab, true);
|
||||
post_ops.append_binary(dnnl::algorithm::binary_add, in_shift_desc);
|
||||
update_onednn_post_op_list(onednn_post_op_type::binary_add, dep_idx - 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. round
|
||||
auto out_dt = cldnn_post_ops[idx].output_layout.data_type;
|
||||
                bool output_type_is_int8 = out_dt == data_types::u8 || out_dt == data_types::i8;
                if (!output_type_is_int8) {
                    post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_round, 0.0f, 0.0f);
                    update_onednn_post_op_list(onednn_post_op_type::eltwise_round, empty_mem);
                }

                // 3. post-scale & post-shift
                if (q_node.get_need_post_scale() && q_node.get_need_post_shift() &&
                    q_node.get_per_tensor_output_scale() && q_node.get_per_tensor_output_shift()) {
                    post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_output_scale_val(), q_node.get_output_shift_val());
                    update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
                } else {
                    if (q_node.get_need_post_scale()) {
                        if (q_node.get_per_tensor_output_scale()) {
                            post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_output_scale_val(), 0.0f);
                            update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
                        } else {
                            auto out_scale = arg.get_dependency(dep_idx++).get_output_layout();
                            dnnl::memory::desc out_scale_desc = onednn::layout_to_memory_desc(out_scale, dnnl::memory::format_tag::ab, true);
                            post_ops.append_binary(dnnl::algorithm::binary_mul, out_scale_desc);
                            update_onednn_post_op_list(onednn_post_op_type::binary_mul, dep_idx - 1);
                        }
                    }

                    if (q_node.get_need_post_shift()) {
                        if (q_node.get_per_tensor_output_shift()) {
                            post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, 1.0f, q_node.get_output_shift_val());
                            update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
                        } else {
                            auto out_shift = arg.get_dependency(dep_idx++).get_output_layout();
                            dnnl::memory::desc out_shift_desc = onednn::layout_to_memory_desc(out_shift, dnnl::memory::format_tag::ab, true);
                            post_ops.append_binary(dnnl::algorithm::binary_add, out_shift_desc);
                            update_onednn_post_op_list(onednn_post_op_type::binary_add, dep_idx - 1);
                        }
                    }
                }

                // 4. clamp
                if (q_node.get_need_clamp()) {
                    float out_lo = q_node.get_need_min_clamp() ? q_node.get_output_lo_val() : data_type_traits::min<float>(out_dt);
                    float out_hi = q_node.get_need_max_clamp() ? q_node.get_output_hi_val() : data_type_traits::max<float>(out_dt);
                    post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_clip, out_lo, out_hi);
                    update_onednn_post_op_list(onednn_post_op_type::eltwise_clip, empty_mem);
                }
            } else {
                // 1. clamp
                if (q_node.get_need_clamp()) {
                    auto in_lo = arg.get_dependency(dep_idx++).get_output_layout();
                    auto in_hi = arg.get_dependency(dep_idx++).get_output_layout();
                    dnnl::algorithm clamp_max = dnnl::algorithm::binary_max;
                    dnnl::algorithm clamp_min = dnnl::algorithm::binary_min;
                    dnnl::memory::desc in_lo_desc = onednn::layout_to_memory_desc(in_lo, dnnl::memory::format_tag::ab, true);
                    dnnl::memory::desc in_hi_desc = onednn::layout_to_memory_desc(in_hi, dnnl::memory::format_tag::ab, true);

                    post_ops.append_binary(clamp_max, in_lo_desc);
                    update_onednn_post_op_list(onednn_post_op_type::binary_max, dep_idx - 2);
                    post_ops.append_binary(clamp_min, in_hi_desc);
                    update_onednn_post_op_list(onednn_post_op_type::binary_min, dep_idx - 1);
                }

                // 2. pre-scale & pre-shift
                {
                    if (q_node.get_per_tensor_input_scale() && q_node.get_per_tensor_input_shift()) {
                        post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_input_scale_val(), q_node.get_input_shift_val());
                        update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
                    } else {
                        if (q_node.get_per_tensor_input_scale()) {
                            post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_input_scale_val(), 0.0f);
                            update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
                        } else {
                            auto in_scale = arg.get_dependency(dep_idx++).get_output_layout();
                            if (idx == 0 && !q_node.get_need_clamp() && !has_out_scales(attrs) && in_scale.data_type == data_types::f32 &&
                                arg.type() == convolution::type_id() &&
                                !data_type_traits::is_floating_point(arg.get_dependency(0).get_output_layout().data_type)) {
                                int mask = in_scale.count() > 1 ? 2 : 0;
                                attrs->set_output_scales(mask, {DNNL_RUNTIME_F32_VAL});
                                update_onednn_post_op_list(onednn_post_op_type::scale, dep_idx - 1);
                            } else {
                                dnnl::memory::desc in_scale_desc = onednn::layout_to_memory_desc(in_scale, dnnl::memory::format_tag::ab, true);
                                post_ops.append_binary(dnnl::algorithm::binary_mul, in_scale_desc);
                                update_onednn_post_op_list(onednn_post_op_type::binary_mul, dep_idx - 1);
                            }
                        }

                        if (q_node.get_need_pre_shift()) {
                            if (q_node.get_per_tensor_input_shift()) {
                                post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, 1.0f, q_node.get_input_shift_val());
                                update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
                            } else {
                                auto in_shift = arg.get_dependency(dep_idx++).get_output_layout();
                                dnnl::memory::desc in_shift_desc = onednn::layout_to_memory_desc(in_shift, dnnl::memory::format_tag::ab, true);
                                post_ops.append_binary(dnnl::algorithm::binary_add, in_shift_desc);
                                update_onednn_post_op_list(onednn_post_op_type::binary_add, dep_idx - 1);
                            }
                        }
                    }
                }

                // 3. round
                {
                    post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_round, 0.0f, 0.0f);
                    update_onednn_post_op_list(onednn_post_op_type::eltwise_round, empty_mem);
                }

                // 4. post-scale & post-shift
                if (q_node.get_need_post_scale() && q_node.get_need_post_shift() &&
                    q_node.get_per_tensor_output_scale() && q_node.get_per_tensor_output_shift()) {
                    post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_output_scale_val(), q_node.get_output_shift_val());
                    update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
                } else {
                    if (q_node.get_need_post_scale()) {
                        if (q_node.get_per_tensor_output_scale()) {
                            post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_output_scale_val(), 0.0f);
                            update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
                        } else {
                            auto out_scale = arg.get_dependency(dep_idx++).get_output_layout();
                            dnnl::memory::desc out_scale_desc = onednn::layout_to_memory_desc(out_scale, dnnl::memory::format_tag::ab, true);
                            post_ops.append_binary(dnnl::algorithm::binary_mul, out_scale_desc);
                            update_onednn_post_op_list(onednn_post_op_type::binary_mul, dep_idx - 1);
                        }
                    }

                    if (q_node.get_need_post_shift()) {
                        if (q_node.get_per_tensor_output_shift()) {
                            post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, 1.0f, q_node.get_output_shift_val());
                            update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
                        } else {
                            auto out_shift = arg.get_dependency(dep_idx++).get_output_layout();
                            dnnl::memory::desc out_shift_desc = onednn::layout_to_memory_desc(out_shift, dnnl::memory::format_tag::ab, true);
                            post_ops.append_binary(dnnl::algorithm::binary_add, out_shift_desc);
                            update_onednn_post_op_list(onednn_post_op_type::binary_add, dep_idx - 1);
                        }
                    }
                }
            }
        } else {
            throw std::runtime_error("Unsupported fused op for onednn prim");
        }
    }

    // Update total onednn post-ops info
    onednn_fusing_map.emplace(arg.id(), std::move(fused_ops));

    // Trying to optimize more than 1 post-ops
    auto post_ops_size = onednn_fusing_map[arg.id()].size();

    if (post_ops_size > 1) {
        dnnl::post_ops optimized_post_ops = post_ops;
        bool optimization_is_finished = false;

        // Trying to combine multiplications and additions which are placed one after another.
        // We do it in the cycle because "eltw + eltw" cases can be simplified again in some cases.
        do {
            optimized_post_ops = try_optimize_post_ops(arg, optimized_post_ops, attrs, optimization_is_finished);
        } while (!optimization_is_finished);

        attrs->set_post_ops(optimized_post_ops);
    } else {
        // Set post-ops without any optimizations
        attrs->set_post_ops(post_ops);
    }

    return attrs;
}
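For reference, a minimal standalone sketch of the chain the per-tensor quantize path above produces; this is not part of the patch and assumes the oneDNN 2.x `dnnl::post_ops` API used in this file (`append_eltwise(scale, algorithm, alpha, beta)`):

// Illustrative sketch only -- mirrors the order used above: optional round for
// non-int8 outputs, per-tensor post-scale/post-shift, then a clamp to the output range.
#include <oneapi/dnnl/dnnl.hpp>

static dnnl::post_ops make_per_tensor_quantize_chain(bool output_is_int8,
                                                     float out_scale, float out_shift,
                                                     float out_lo, float out_hi) {
    dnnl::post_ops ops;
    if (!output_is_int8)
        ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_round, 0.0f, 0.0f);         // round
    ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, out_scale, out_shift);  // post-scale & post-shift
    ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_clip, out_lo, out_hi);          // clamp
    return ops;
}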
@ -104,7 +819,8 @@ protected:
    }

    void set_arguments_impl(typed_primitive_inst<PType>& instance) override {
        _args = get_arguments(instance);
        uint32_t net_id = instance.get_network().get_id();
        _args[net_id] = get_arguments(instance);
    }

    event::ptr execute_impl(const std::vector<event::ptr>& /* events */,
@ -113,6 +829,7 @@ protected:
        auto& engine = network.get_engine();
        auto& stream = network.get_stream();
        auto profiling = engine.configuration().enable_profiling;
        auto net_id = network.get_id();
        event::ptr event;

        if (profiling) {
@ -120,7 +837,7 @@ protected:
            event = stream.create_user_event(false);
        }

        _prim.execute(stream.get_onednn_stream(), _args);
        _prim.execute(stream.get_onednn_stream(), _args[net_id]);

        if (profiling) {
            stream.finish();
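The hunk above keys the argument map by network id instead of keeping a single shared `_args`. A simplified, hypothetical sketch of the pattern (type and member names below are illustrative, not from this header):

// Hypothetical sketch: storing oneDNN arguments per network id avoids one network's
// set_arguments() overwriting the arguments another network is about to execute with.
#include <cstdint>
#include <unordered_map>
#include <oneapi/dnnl/dnnl.hpp>

struct per_network_args_example {
    std::unordered_map<uint32_t, std::unordered_map<int, dnnl::memory>> args;  // net_id -> onednn args

    void set(uint32_t net_id, std::unordered_map<int, dnnl::memory> a) {
        args[net_id] = std::move(a);
    }

    void run(dnnl::primitive& prim, dnnl::stream& strm, uint32_t net_id) {
        prim.execute(strm, args.at(net_id));  // each network executes with its own argument set
    }
};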
@ -29,28 +29,36 @@ public:
    bool get_need_post_scale() const { return need_post_scale; }
    bool get_need_post_shift() const { return need_post_shift; }
    bool get_need_clamp() const { return need_clamp; }
    bool get_need_min_clamp() const { return need_min_clamp; }
    bool get_need_max_clamp() const { return need_max_clamp; }
    bool get_per_tensor_input_scale() const { return per_tensor_input_scale; }
    bool get_per_tensor_input_shift() const { return per_tensor_input_shift; }
    bool get_per_tensor_input_range() const { return per_tensor_input_range; }
    bool get_per_tensor_output_scale() const { return per_tensor_output_scale; }
    bool get_per_tensor_output_shift() const { return per_tensor_output_shift; }
    bool get_per_tensor_output_range() const { return per_tensor_output_range; }
    float get_input_scale_val() const { return in_scale; }
    float get_input_shift_val() const { return in_shift; }
    float get_input_lo_val() const { return in_lo; }
    float get_input_hi_val() const { return in_hi; }
    float get_output_scale_val() const { return out_scale; }
    float get_output_shift_val() const { return out_shift; }
    float get_output_lo_val() const { return out_lo; }
    float get_output_hi_val() const { return out_hi; }

    void set_scale_shift_opt() { scale_shift_opt = true; }
    void set_need_post_scale() { need_post_scale = true; }
    void set_need_post_shift() { need_post_shift = true; }
    void set_need_pre_shift() { need_pre_shift = true; }
    void set_need_clamp() { need_clamp = true; }
    void set_need_min_clamp() { need_min_clamp = true; }
    void set_need_max_clamp() { need_max_clamp = true; }
    void set_per_tensor_input_scale() { per_tensor_input_scale = true; }
    void set_per_tensor_input_shift() { per_tensor_input_shift = true; }
    void set_per_tensor_input_range() { per_tensor_input_range = true; }
    void set_per_tensor_output_scale() { per_tensor_output_scale = true; }
    void set_per_tensor_output_shift() { per_tensor_output_shift = true; }
    void set_per_tensor_output_range() { per_tensor_output_range = true; }
    // Clamp is needed to avoid inf and -inf which are converted to undefined "inf" constant in opencl
    void set_input_scale_val(float val) { in_scale = clamp(val); }
    void set_input_shift_val(float val) { in_shift = clamp(val); }
@ -58,6 +66,8 @@ public:
    void set_input_hi_val(float val) { in_hi = clamp(val); }
    void set_output_scale_val(float val) { out_scale = clamp(val); }
    void set_output_shift_val(float val) { out_shift = clamp(val); }
    void set_output_lo_val(float val) { out_lo = clamp(val); }
    void set_output_hi_val(float val) { out_hi = clamp(val); }

    std::shared_ptr<kernel_selector::fuse_params> get_fuse_params() const override {
        return std::make_shared<kernel_selector::quantize_fuse_params>(scale_shift_opt,
@ -65,15 +75,20 @@ public:
            need_post_shift,
            need_pre_shift,
            need_clamp,
            need_min_clamp,
            need_max_clamp,
            per_tensor_input_range,
            per_tensor_input_scale,
            per_tensor_input_shift,
            per_tensor_output_range,
            per_tensor_output_scale,
            per_tensor_output_shift,
            in_lo,
            in_hi,
            in_scale,
            in_shift,
            out_lo,
            out_hi,
            out_scale,
            out_shift);
    }
@ -88,10 +103,13 @@ private:
    bool need_post_shift = false;
    bool need_pre_shift = false;
    bool need_clamp = false;
    bool need_min_clamp = false;
    bool need_max_clamp = false;

    bool per_tensor_input_range = false;
    bool per_tensor_input_scale = false;
    bool per_tensor_input_shift = false;
    bool per_tensor_output_range = false;
    bool per_tensor_output_scale = false;
    bool per_tensor_output_shift = false;

@ -99,6 +117,8 @@ private:
    float in_hi = 0.0f;
    float in_scale = 0.0f;
    float in_shift = 0.0f;
    float out_lo = 0.0f;
    float out_hi = 0.0f;
    float out_scale = 0.0f;
    float out_shift = 0.0f;
};
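The separate `need_min_clamp` / `need_max_clamp` flags let a one-sided clamp fall back to the output type's own bound on the missing side, which is what the `eltwise_clip` construction earlier in this patch relies on. A small self-contained sketch of that fallback, using float limits purely for illustration:

// Illustrative only: when one side of the clamp is not needed, substitute the widest
// representable bound so the clip effectively clamps on a single side.
#include <algorithm>
#include <limits>

static float one_sided_clamp(float v, bool need_min, bool need_max, float lo, float hi) {
    const float min_bound = need_min ? lo : std::numeric_limits<float>::lowest();
    const float max_bound = need_max ? hi : std::numeric_limits<float>::max();
    return std::min(std::max(v, min_bound), max_bound);
}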
@ -24,6 +24,59 @@

using namespace cldnn;

static size_t get_post_ops_count(const program_node& node) {
    size_t onednn_post_ops_count = 0;
    for (auto& fo : node.get_fused_primitives()) {
        if (fo.node->is_type<activation>() || fo.node->is_type<eltwise>()) {
            onednn_post_ops_count++;
        } else if (fo.node->is_type<quantize>()) {
            auto& q = fo.node->as<quantize>();

            // pre-scale, pre-shift
            if (q.get_per_tensor_input_scale() && q.get_per_tensor_input_shift()) {
                onednn_post_ops_count++;
            } else {
                onednn_post_ops_count += 2;
            }

            // post-scale, post-shift
            if (q.get_need_post_scale() && q.get_need_post_shift() &&
                q.get_per_tensor_output_scale() && q.get_per_tensor_output_shift()) {
                onednn_post_ops_count++;
            } else {
                onednn_post_ops_count += 2;
            }

            auto out_dt = fo.output_layout.data_type;
            auto output_type_is_int8 = out_dt == data_types::u8 || out_dt == data_types::i8;
            auto out_range_usage = q.get_per_tensor_output_range() && q.get_output_lo_val() < q.get_output_hi_val();

            if (out_range_usage) {
                // round
                if (!output_type_is_int8) {
                    onednn_post_ops_count++;
                }

                // clamp
                if (q.get_need_clamp()) {
                    onednn_post_ops_count++;
                }
            } else {
                // clamp
                if (q.get_need_clamp()) {
                    onednn_post_ops_count += 2;
                }
                // round
                {
                    onednn_post_ops_count++;
                }
            }
        }
    }

    return onednn_post_ops_count;
}
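A hypothetical caller of the `get_post_ops_count()` helper above; the threshold and function name below are illustrative only and not part of this change:

// Sketch: an impl-selection heuristic could cap the number of fused post-ops it is
// willing to hand to oneDNN and keep the clDNN kernel otherwise.
static bool fused_chain_fits_onednn(const program_node& node, size_t max_post_ops = 32) {
    return get_post_ops_count(node) <= max_post_ops;
}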
std::pair<std::shared_ptr<reorder>, bool> reorder_factory::get_reorder(primitive_id src_id,
                                                                       const layout& in_layout,
                                                                       const layout& out_layout
@ -473,6 +473,47 @@ void network::allocate_primitives() {
    for (auto const& node : nodes_to_allocate) {
        allocate_primitive_instance(*node);
    }

    for (auto const& node : _program->get_processing_order()) {
        if (node->get_preferred_impl_type() == impl_types::onednn) {
            bool can_reuse_eltwise_mem = false;
            size_t eltw_dep = 0;

            for (auto& fused_op : node->get_fused_primitives()) {
                if (fused_op.node->is_type<eltwise>() && fused_op.deps.size() == 1) {
                    auto eltw_in_layout = node->get_dependency(fused_op.dep_start_idx).get_output_layout();
                    auto out_layout = node->get_output_layout();

                    if (eltw_in_layout.size == out_layout.size &&
                        eltw_in_layout.format == out_layout.format &&
                        eltw_in_layout.data_padding == out_layout.data_padding &&
                        data_type_traits::size_of(eltw_in_layout.data_type) == data_type_traits::size_of(out_layout.data_type)) {
                        if (eltw_dep > 0) {
                            throw std::runtime_error("Unsupported multiple full size tensors.");
                        }
                        eltw_dep = fused_op.dep_start_idx;
                        can_reuse_eltwise_mem = true;
                    }

                    if (fused_op.node->as<eltwise>().get_primitive()->needs_onednn_sum_post_op(eltw_in_layout) && !can_reuse_eltwise_mem) {
                        throw std::runtime_error("Buffer reuse is required for onednn sum post operation.");
                    }
                }
            }

            if (can_reuse_eltwise_mem) {
                auto& eltw_in = node->get_dependency(eltw_dep);
                if (_primitives.find(eltw_in.id()) != _primitives.end() && _primitives.find(node->id()) != _primitives.end()) {
                    auto& eltw_inst = _primitives.at(eltw_in.id());
                    auto& prim_inst = _primitives.at(node->id());
                    auto& eltw_mem = eltw_inst->output_memory();
                    auto new_mem = eltw_mem.get_engine()->reinterpret_buffer(eltw_mem, node->get_output_layout());

                    prim_inst->set_output_memory(new_mem);
                }
            }
        }
    }
}
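The reuse above exists because oneDNN's sum post-op accumulates into the destination buffer, so the fused eltwise input must alias the consumer's output memory. A conceptual sketch of just the aliasing step (simplified; the helper name is hypothetical):

// Conceptual sketch: reinterpret the eltwise input buffer with the consumer's output
// layout and hand it to the consumer as its output, so both views share one allocation.
static void alias_sum_input_as_output(cldnn::primitive_inst& eltw_inst,
                                      cldnn::primitive_inst& prim_inst,
                                      const cldnn::layout& out_layout) {
    auto& eltw_mem = eltw_inst.output_memory();
    auto shared_mem = eltw_mem.get_engine()->reinterpret_buffer(eltw_mem, out_layout);
    prim_inst.set_output_memory(shared_mem);  // in-place accumulation target for the sum post-op
}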

void network::build_insts_deps() {
@ -1025,9 +1025,10 @@ void program::fuse_nodes(program_node &fused_node, program_node &peer_node, std:
        quantize_node& q_node = peer_node.as<quantize>();
        if (q_node.get_scale_shift_opt()) {
            bool can_drop_input = false;
            bool out_range_usage = q_node.get_per_tensor_output_range() && q_node.get_output_lo_val() < q_node.get_output_hi_val();

            // Drop input range if clamp is not needed
            can_drop_input |= (i == 1 || i == 2) && !q_node.get_need_clamp();
            // Drop input range if we use output per-tensor range or if clamp is used for input range
            can_drop_input |= (i == 1 || i == 2) && (out_range_usage || (!out_range_usage && !q_node.get_need_clamp()));
            // Drop output range - it's not used in scale-shift-opt quantize kernel
            can_drop_input |= i == 3 || i == 4;
            // Drop tensor with input scale when we have per-tensor parameter