[GPU] Moving onednn post ops logic into cldnn::program_node (#8511)

Ilya Znamenskiy 2021-11-12 22:48:55 +03:00 committed by GitHub
parent 2c9a4c59f2
commit c565799c71
13 changed files with 804 additions and 723 deletions


@@ -0,0 +1,20 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
///////////////////////////////////////////////////////////////////////////////////////////////////
#include "pass_manager.h"
#include "program_node.h"
using namespace cldnn;
void add_onednn_optimization_attributes::run(program& p) {
#ifdef ENABLE_ONEDNN_FOR_GPU
for (auto& node : p.get_processing_order()) {
if (node->get_preferred_impl_type() == impl_types::onednn) {
node->init_onednn_primitive_attributes();
}
}
#endif // ENABLE_ONEDNN_FOR_GPU
}
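The pattern this pass enables repeats across the implementation files below: each impl's static attribute builder is replaced by a read of the attributes precomputed on the node. A minimal before/after sketch, using only calls that appear in this diff:

// before: every oneDNN impl rebuilt the attributes inside its create()
auto attr = get_primitive_attributes(arg);
// after: this pass precomputes them once per node; impls just read them
auto attr = arg.get_onednn_primitive_attributes();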


@@ -36,7 +36,7 @@ void basic_memory_dependencies::run(program& p) {
add_memory_dependency(it, node);
}
-if (node->is_type<convolution>()) {
+if (node->is_type<convolution>() && node->get_preferred_impl_type() == impl_types::onednn) {
auto& conv = node->as<convolution>();
bool can_reuse_eltwise_mem = false;
size_t eltw_dep = 0;
@@ -59,6 +59,7 @@ void basic_memory_dependencies::run(program& p) {
}
}
}
if (can_reuse_eltwise_mem) {
auto& eltw_node = conv.get_dependency(eltw_dep);
eltw_node.can_share_buffer(false);
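// Rationale (sketch): with a oneDNN sum post-op the convolution accumulates
// in place, roughly dst = conv(src, weights) + dst, where dst aliases the
// fused eltwise input. That buffer is both read and written by the primitive,
// so it needs a dedicated allocation and cannot be shared.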


@@ -73,7 +73,7 @@ protected:
public:
static primitive_impl* create(const concatenation_node& arg) {
auto desc = get_concatenation_descriptor(arg);
-auto attr = get_primitive_attributes(arg);
+auto attr = arg.get_onednn_primitive_attributes();
std::shared_ptr<void> dummy = nullptr;


@@ -47,6 +47,7 @@ protected:
std::unordered_map<int, dnnl::memory> get_arguments(convolution_inst& instance) const override {
std::unordered_map<int, dnnl::memory> args = parent::get_arguments(instance);
+auto attrs = instance.get_node().get_onednn_primitive_attributes();
{
auto weights = instance.weights_memory(0);
@@ -58,13 +59,13 @@ protected:
args.insert({DNNL_ARG_BIAS, bias->get_onednn_memory(_pd.weights_desc(1))});
}
-if (has_zero_points(DNNL_ARG_SRC, _attrs)) {
+if (has_zero_points(DNNL_ARG_SRC, attrs)) {
auto a_zp = instance.activations_zero_points_memory(0);
dnnl::memory::desc desc = onednn::layout_to_memory_desc(a_zp->get_layout(), dnnl::memory::format_tag::a, true);
args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, a_zp->get_onednn_memory(desc)});
}
-if (has_zero_points(DNNL_ARG_WEIGHTS, _attrs)) {
+if (has_zero_points(DNNL_ARG_WEIGHTS, attrs)) {
auto w_zp = instance.weights_zero_points_memory(0);
dnnl::memory::desc desc = onednn::layout_to_memory_desc(w_zp->get_layout(), dnnl::memory::format_tag::a, true);
args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS, w_zp->get_onednn_memory(desc)});
@@ -74,7 +75,7 @@ protected:
}
static std::shared_ptr<dnnl::primitive_attr> get_primitive_attributes(const typed_program_node<convolution>& arg) {
-auto attrs = parent::get_primitive_attributes(arg);
+auto attrs = arg.get_onednn_primitive_attributes();
if (arg.activations_zero_points_term()) {
auto& a_zp = arg.activations_zero_points();


@@ -62,7 +62,7 @@ protected:
}
static std::shared_ptr<dnnl::primitive_attr> get_primitive_attributes(const typed_program_node<deconvolution>& arg) {
-auto attrs = parent::get_primitive_attributes(arg);
+auto attrs = arg.get_onednn_primitive_attributes();
return attrs;
}


@@ -142,7 +142,7 @@ public:
}
}
-auto attr = get_primitive_attributes(arg);
+auto attr = arg.get_onednn_primitive_attributes();
dnnl::primitive_desc prim_desc{&desc->data, attr.get(), engine.get_onednn_engine(), nullptr};
return new fully_connected_onednn(arg, desc, attr, prim_desc, get_weights_reorder(arg, prim_desc));


@@ -71,7 +71,7 @@ public:
static primitive_impl* create(const pooling_node& arg) {
auto& engine = arg.get_program().get_engine();
auto desc = get_pooling_descriptor(arg);
-auto attr = get_primitive_attributes(arg);
+auto attr = arg.get_onednn_primitive_attributes();
dnnl::primitive_desc prim_desc{&desc->data, attr.get(), engine.get_onednn_engine(), nullptr};
return new pooling_onednn(arg, desc, attr, prim_desc);


@@ -26,36 +26,6 @@
namespace cldnn {
namespace onednn {
enum class onednn_post_op_type : uint32_t {
eltwise_act,
eltwise_clip,
eltwise_linear,
eltwise_round,
binary_mul,
binary_add,
binary_max,
binary_min,
scale,
sum,
optimized,
optimized_eltwise,
optimized_sum
};
struct onednn_post_op_desc {
onednn_post_op_type op_type;
size_t mem_offset;
size_t mem_dep;
};
// This map contains information about onednn post-ops types, memory buffer offsets and dependencies
// key is cldnn::primitive_id,
// value is an instance of struct onednn_post_op_desc containing info about post-ops related to the node defined by key:
// op_type - onednn_post_op_type (enum),
// mem_offset - index of memory buffer for current post-operation,
// mem_dep - memory dependency for working with fused node
static std::unordered_map<cldnn::primitive_id, std::vector<onednn_post_op_desc>> onednn_fusing_map;
template <class PType, class DescType, class PrimDescType = dnnl::primitive_desc, class PrimType = dnnl::primitive>
struct typed_primitive_onednn_impl : public typed_primitive_impl<PType> {
const typed_program_node<PType>& _outer;
@@ -82,7 +52,7 @@ struct typed_primitive_onednn_impl : public typed_primitive_impl<PType> {
protected:
virtual bool optimized_out(typed_primitive_inst<PType>&) const { return false; }
-static bool has_out_scales(const std::shared_ptr<dnnl::primitive_attr>& attr) {
+static bool has_output_scales(const std::shared_ptr<dnnl::primitive_attr>& attr) {
int mask;
std::vector<float> scales;
attr->get_output_scales(mask, scales);
@@ -98,408 +68,22 @@ protected:
return !zp.empty() && (reinterpret_cast<const int32_t&>(zp[0]) == drsv);
}
static dnnl::post_ops try_optimize_post_ops(const typed_program_node<PType>& arg, dnnl::post_ops& p_ops,
const std::shared_ptr<dnnl::primitive_attr>& attr,
bool& optimization_is_completed) {
// Get current node id for creating of optimization map
auto node_id = arg.id();
// Create new dnnl::post_ops object which will be filled inside the optimization process
dnnl::post_ops optimized_p_ops;
// Add new post-op into optimized_p_ops structure
auto add_post_op = [&](onednn_post_op_type type, const dnnl::post_ops& cur_p_ops, dnnl::post_ops& new_p_ops, int idx) {
switch (type) {
case onednn_post_op_type::eltwise_act:
case onednn_post_op_type::eltwise_clip:
case onednn_post_op_type::eltwise_linear:
case onednn_post_op_type::eltwise_round:
{
dnnl::algorithm alg;
float scale, alpha, beta;
cur_p_ops.get_params_eltwise(idx, scale, alg, alpha, beta);
new_p_ops.append_eltwise(scale, alg, alpha, beta);
break;
}
case onednn_post_op_type::binary_add:
case onednn_post_op_type::binary_mul:
case onednn_post_op_type::binary_max:
case onednn_post_op_type::binary_min:
{
dnnl::algorithm alg;
dnnl::memory::desc desc;
cur_p_ops.get_params_binary(idx, alg, desc);
new_p_ops.append_binary(alg, desc);
break;
}
case onednn_post_op_type::scale:
{
break;
}
case onednn_post_op_type::sum:
{
float scale;
dnnl::memory::data_type data_type;
cur_p_ops.get_params_sum(idx, scale, data_type);
new_p_ops.append_sum(scale, data_type);
break;
}
case onednn_post_op_type::optimized:
case onednn_post_op_type::optimized_sum:
case onednn_post_op_type::optimized_eltwise:
{
// Current operation already has been optimized => don't need extra actions
break;
}
default:
throw std::runtime_error("Unsupported onednn post-operation type");
}
};
// Check that post-op type is any optimized
auto type_is_any_optimized = [](onednn_post_op_type type) -> bool {
return type == onednn_post_op_type::optimized || type == onednn_post_op_type::optimized_sum ||
type == onednn_post_op_type::optimized_eltwise;
};
// Check that post-op type is eltwise
auto type_is_eltwise = [](onednn_post_op_type type) -> bool {
return type == onednn_post_op_type::eltwise_round || type == onednn_post_op_type::eltwise_linear ||
type == onednn_post_op_type::eltwise_clip || type == onednn_post_op_type::eltwise_act;
};
// Check that post-op type is binary_add or binary_mul
auto type_is_binary_add_or_mul = [](onednn_post_op_type type) -> bool {
return type == onednn_post_op_type::binary_add || type == onednn_post_op_type::binary_mul;
};
// Simple post-op type checks
auto type_is_optimized = [](onednn_post_op_type type) -> bool { return type == onednn_post_op_type::optimized; };
auto type_is_eltwise_linear = [](onednn_post_op_type type) -> bool { return type == onednn_post_op_type::eltwise_linear; };
auto type_is_optimized_eltwise = [](onednn_post_op_type type) -> bool { return type == onednn_post_op_type::optimized_eltwise; };
auto type_is_binary_add = [](onednn_post_op_type type) -> bool { return type == onednn_post_op_type::binary_add; };
auto type_is_binary_mul = [](onednn_post_op_type type) -> bool { return type == onednn_post_op_type::binary_mul; };
auto type_is_sum = [](onednn_post_op_type type) -> bool { return type == onednn_post_op_type::sum; };
auto type_is_optimized_sum = [](onednn_post_op_type type) -> bool { return type == onednn_post_op_type::optimized_sum; };
auto type_is_scale = [](onednn_post_op_type type) -> bool { return type == onednn_post_op_type::scale; };
auto& cur_post_ops = onednn_fusing_map[node_id];
size_t cur_post_op_idx = 1;
size_t prev_post_op_idx = 0;
bool optimization_done = false;
// Check and update post-op map if we already optimized something
for (size_t post_op_idx = 0; post_op_idx < cur_post_ops.size(); post_op_idx++) {
if (type_is_optimized_sum(cur_post_ops[post_op_idx].op_type))
cur_post_ops[post_op_idx].op_type = onednn_post_op_type::sum;
else if (type_is_optimized_eltwise(cur_post_ops[post_op_idx].op_type))
cur_post_ops[post_op_idx].op_type = onednn_post_op_type::eltwise_linear;
else if (type_is_optimized(cur_post_ops[post_op_idx].op_type))
cur_post_ops.erase(cur_post_ops.begin() + post_op_idx);
}
// Get post-ops size for current node
auto post_ops_size = cur_post_ops.size();
// Try to combine pairs of arithmetic post-ops (adds and muls) into one operation inside this cycle
while (!optimization_done) {
auto cur_type = cur_post_ops[cur_post_op_idx].op_type;
auto prev_type = cur_post_ops[prev_post_op_idx].op_type;
// Ignore optimized operations for "previous" operation in our operation pair
while (type_is_any_optimized(prev_type) && cur_post_op_idx < post_ops_size - 1) {
prev_post_op_idx++;
cur_post_op_idx++;
prev_type = cur_post_ops[prev_post_op_idx].op_type;
cur_type = cur_post_ops[cur_post_op_idx].op_type;
}
// Ignore optimized operations for "current" operation in our operation pair
while (type_is_any_optimized(cur_type) && cur_post_op_idx < post_ops_size - 1) {
cur_post_op_idx++;
cur_type = cur_post_ops[cur_post_op_idx].op_type;
}
auto cur_idx = static_cast<int>(has_out_scales(attr) ? (cur_post_op_idx >= 1 ? cur_post_op_idx - 1 : 0) : cur_post_op_idx);
auto prev_idx = static_cast<int>(has_out_scales(attr) ? (prev_post_op_idx >= 1 ? prev_post_op_idx - 1 : 0) : prev_post_op_idx);
// If this is the last pair and it's optimized - add the last post-op and go out from the cycle
if (cur_post_op_idx == post_ops_size - 1 && (type_is_any_optimized(cur_type) || type_is_any_optimized(prev_type))) {
if (!type_is_any_optimized(prev_type)) {
add_post_op(prev_type, p_ops, optimized_p_ops, prev_idx);
}
if (!type_is_any_optimized(cur_type)) {
add_post_op(cur_type, p_ops, optimized_p_ops, cur_idx);
}
break;
}
// Post-ops combinations which can be simplified
auto eltw_and_eltw = type_is_eltwise(cur_type) && type_is_eltwise(prev_type);
auto bin_and_eltw = type_is_binary_add_or_mul(cur_type) && type_is_eltwise_linear(prev_type);
auto eltw_and_bin = type_is_eltwise_linear(cur_type) && type_is_binary_add_or_mul(prev_type);
auto sum_and_eltw = type_is_sum(cur_type) && type_is_eltwise(prev_type);
auto eltw_and_scale = type_is_eltwise_linear(cur_type) && type_is_scale(prev_type);
auto can_try_optimize = eltw_and_eltw ||
bin_and_eltw ||
eltw_and_bin ||
sum_and_eltw ||
eltw_and_scale;
bool cur_ops_pair_is_optimized = false;
if (can_try_optimize) {
if (eltw_and_eltw) {
dnnl::algorithm cur_alg, prev_alg;
float cur_scale, prev_scale, cur_alpha, prev_alpha, cur_beta, prev_beta;
p_ops.get_params_eltwise(prev_idx, prev_scale, prev_alg, prev_alpha, prev_beta);
p_ops.get_params_eltwise(cur_idx, cur_scale, cur_alg, cur_alpha, cur_beta);
auto eltw_linear_and_eltw_linear = type_is_eltwise_linear(cur_type) && type_is_eltwise_linear(prev_type);
auto eltw_linear_and_eltw_non_linear = type_is_eltwise_linear(cur_type) && !type_is_eltwise_linear(prev_type) && cur_beta == 0;
// eltwise_linear + eltwise_linear combination can be optimized always
if (eltw_linear_and_eltw_linear) {
dnnl::post_ops eltw_p_op;
float optimized_alpha = cur_alpha * prev_alpha * prev_scale;
float optimized_beta = cur_alpha * prev_beta * prev_scale + cur_beta;
float optimized_scale = cur_scale;
eltw_p_op.append_eltwise(optimized_scale, cur_alg, optimized_alpha, optimized_beta);
// Combine 2 eltwises into one
add_post_op(cur_type, eltw_p_op, optimized_p_ops, 0);
} else if (eltw_linear_and_eltw_non_linear) {
dnnl::post_ops eltw_p_op;
eltw_p_op.append_eltwise(cur_scale * prev_scale * cur_alpha, prev_alg, prev_alpha, prev_beta);
// Combine 2 eltwises into one
add_post_op(prev_type, eltw_p_op, optimized_p_ops, 0);
}
if (eltw_linear_and_eltw_linear || eltw_linear_and_eltw_non_linear) {
// Marked current and previous eltwise operations as 'optimized' (they will be ignored on the next iteration of cycle)
cur_post_ops[cur_post_op_idx].op_type = onednn_post_op_type::optimized;
cur_post_ops[prev_post_op_idx].op_type = onednn_post_op_type::optimized_eltwise;
// Set the flag if extra optimizations checking is needed
if (cur_post_op_idx < post_ops_size - 1) {
if (type_is_eltwise_linear(cur_post_ops[cur_post_op_idx + 1].op_type) ||
type_is_binary_add_or_mul(cur_post_ops[cur_post_op_idx + 1].op_type) ||
type_is_optimized_eltwise(cur_post_ops[cur_post_op_idx + 1].op_type)) {
optimization_is_completed = true;
}
}
cur_ops_pair_is_optimized = true;
}
} else if (bin_and_eltw) {
dnnl::algorithm alg;
dnnl::memory::desc desc;
float scale, alpha, beta;
cldnn::program_node& cur_node = arg.get_dependency(cur_post_ops[cur_post_op_idx].mem_dep);
p_ops.get_params_binary(cur_idx, alg, desc);
p_ops.get_params_eltwise(prev_idx, scale, alg, alpha, beta);
// Eltwise operations can use runtime non-constant data buffers, so check that memory buffers consist of constant data only
auto bin_ops_can_be_optimized = cur_node.is_type<data>() && cur_node.is_constant() &&
cur_node.get_users().size() == 1 && desc.data_type() == dnnl_f32;
auto bin_add_and_eltw = alpha == 1.0f && scale == 1.0f && type_is_binary_add(cur_type) && bin_ops_can_be_optimized;
auto bin_mul_and_eltw = beta == 0.f && type_is_binary_mul(cur_type) && bin_ops_can_be_optimized;
if (bin_add_and_eltw || bin_mul_and_eltw) {
memory::ptr cur_bin_mem_ptr = cur_node.as<data>().get_attached_memory_ptr();
if (cur_bin_mem_ptr == nullptr)
throw std::runtime_error("OneDNN post-ops optimization error: nonexistent node for bin + eltw");
auto& stream = cur_bin_mem_ptr->get_engine()->get_program_stream();
mem_lock<float, mem_lock_type::write> bin_and_eltw_lock(cur_bin_mem_ptr, stream);
size_t cur_bin_mem_size = cur_node.get_output_layout().count();
// Update all binary coefficients
if (bin_add_and_eltw) {
for (size_t data_idx = 0; data_idx < cur_bin_mem_size; data_idx++) {
bin_and_eltw_lock[data_idx] += beta;
}
} else {
for (size_t data_idx = 0; data_idx < cur_bin_mem_size; data_idx++) {
bin_and_eltw_lock[data_idx] *= alpha * scale;
}
}
// Marked previous eltwise operation as 'optimized' (it will be ignored on the next iteration of cycle)
cur_post_ops[prev_post_op_idx].op_type = onednn_post_op_type::optimized;
cur_ops_pair_is_optimized = true;
}
} else if (eltw_and_bin) {
dnnl::algorithm alg;
dnnl::memory::desc desc;
float scale, alpha, beta;
cldnn::program_node& prev_node = arg.get_dependency(cur_post_ops[prev_post_op_idx].mem_dep);
p_ops.get_params_eltwise(cur_idx, scale, alg, alpha, beta);
p_ops.get_params_binary(prev_idx, alg, desc);
// Eltwise operations can use runtime non-constant data buffers, so check that memory buffers consist of constant data only
auto bin_ops_can_be_optimized = prev_node.is_type<data>() && prev_node.is_constant() &&
prev_node.get_users().size() == 1 && desc.data_type() == dnnl_f32;
auto eltw_and_bin_add = alpha == 1.0f && scale == 1.0f && type_is_binary_add(prev_type) && bin_ops_can_be_optimized;
auto eltw_and_bin_mul = beta == 0.f && type_is_binary_mul(prev_type) && bin_ops_can_be_optimized;
if (eltw_and_bin_add || eltw_and_bin_mul) {
memory::ptr prev_bin_mem_ptr = prev_node.as<data>().get_attached_memory_ptr();
if (prev_bin_mem_ptr == nullptr)
throw std::runtime_error("OneDNN post-ops optimization error: nonexistent node for eltw + bin");
auto& stream = prev_bin_mem_ptr->get_engine()->get_program_stream();
mem_lock<float, mem_lock_type::write> eltw_and_bin_lock(prev_bin_mem_ptr, stream);
size_t prev_bin_mem_size = prev_node.get_output_layout().count();
// Update all binary coefficients
if (eltw_and_bin_add) {
for (size_t data_idx = 0; data_idx < prev_bin_mem_size; data_idx++) {
eltw_and_bin_lock[data_idx] += beta;
}
} else {
for (size_t data_idx = 0; data_idx < prev_bin_mem_size; data_idx++) {
eltw_and_bin_lock[data_idx] *= alpha * scale;
}
}
// Marked current eltwise operation as 'optimized' (it will be ignored on the next iteration of cycle)
cur_post_ops[cur_post_op_idx].op_type = onednn_post_op_type::optimized;
cur_ops_pair_is_optimized = true;
}
} else if (sum_and_eltw) {
dnnl::algorithm alg;
float sum_scale, eltw_scale, alpha, beta;
dnnl::memory::data_type data_type;
dnnl::algorithm next_alg;
float next_scale, next_alpha, next_beta;
size_t next_idx = cur_idx + 1;
size_t next_post_op_idx = cur_post_op_idx + 1;
bool can_optimize_eltw_and_sum = false;
if (cur_post_op_idx < post_ops_size - 1) {
auto next_type = cur_post_ops[next_post_op_idx].op_type;
if (type_is_eltwise_linear(next_type)) {
p_ops.get_params_eltwise(next_idx, next_scale, next_alg, next_alpha, next_beta);
if (next_beta == 0)
can_optimize_eltw_and_sum = true;
}
}
// Try to optimize eltwise (any) + sum + eltwise_linear (with beta = 0) chain of operations
if (can_optimize_eltw_and_sum) {
p_ops.get_params_sum(cur_idx, sum_scale, data_type);
p_ops.get_params_eltwise(prev_idx, eltw_scale, alg, alpha, beta);
dnnl::post_ops eltw_p_op_prev, sum_p_op;
eltw_p_op_prev.append_eltwise(eltw_scale * next_alpha * next_scale, alg, alpha, beta);
sum_p_op.append_sum(sum_scale * next_alpha, data_type);
add_post_op(prev_type, eltw_p_op_prev, optimized_p_ops, 0);
add_post_op(cur_type, sum_p_op, optimized_p_ops, 0);
// Marked current, previous and next operations as 'optimized' (they will be ignored on the next iteration of cycle)
cur_post_ops[prev_post_op_idx].op_type = onednn_post_op_type::optimized_eltwise;
cur_post_ops[cur_post_op_idx].op_type = onednn_post_op_type::optimized_sum;
cur_post_ops[next_post_op_idx].op_type = onednn_post_op_type::optimized;
// Set the flag if extra optimizations checking is needed
if (next_post_op_idx < post_ops_size - 1) {
if (type_is_eltwise_linear(cur_post_ops[next_post_op_idx + 1].op_type) ||
type_is_optimized_eltwise(cur_post_ops[next_post_op_idx + 1].op_type)) {
optimization_is_completed = true;
}
}
cur_ops_pair_is_optimized = true;
}
} else if (eltw_and_scale) {
dnnl::algorithm alg;
float eltw_scale, alpha, beta;
cldnn::program_node& prev_node = arg.get_dependency(cur_post_ops[prev_post_op_idx].mem_dep);
p_ops.get_params_eltwise(cur_idx, eltw_scale, alg, alpha, beta);
// Eltwise can be inserted into the output_scale if cur_beta is equal to 0.f
if (beta == 0.f && prev_node.get_output_layout().data_type == data_types::f32) {
memory::ptr prev_scale_mem_ptr = prev_node.as<data>().get_attached_memory_ptr();
if (prev_scale_mem_ptr == nullptr)
throw std::runtime_error("OneDNN post-ops optimization error: nonexistent node for eltw + scale");
auto& stream = prev_scale_mem_ptr->get_engine()->get_program_stream();
mem_lock<float, mem_lock_type::write> eltw_and_scale_lock(prev_scale_mem_ptr, stream);
size_t prev_scale_mem_size = prev_node.get_output_layout().count();
// Update all scale coefficients
for (size_t data_idx = 0; data_idx < prev_scale_mem_size; data_idx++) {
eltw_and_scale_lock[data_idx] *= alpha * eltw_scale;
}
// Marked current eltwise operation as 'optimized' (it will be ignored on the next iteration of cycle)
cur_post_ops[cur_post_op_idx].op_type = onednn_post_op_type::optimized;
cur_ops_pair_is_optimized = true;
}
}
}
// If no optimizations have been applied then copy post-op info into the new optimized_p_ops structure
if (!(has_out_scales(attr) && prev_post_op_idx == 0) && !cur_ops_pair_is_optimized) {
add_post_op(prev_type, p_ops, optimized_p_ops, prev_idx);
}
if (cur_post_op_idx == post_ops_size - 1 && !cur_ops_pair_is_optimized) {
add_post_op(cur_type, p_ops, optimized_p_ops, cur_idx);
optimization_done = true;
} else if (cur_post_ops[cur_post_op_idx].op_type != onednn_post_op_type::optimized) {
cur_post_op_idx++;
prev_post_op_idx++;
}
}
optimization_is_completed = !optimization_is_completed;
return optimized_p_ops;
}
void configure_post_ops_arguments(typed_primitive_inst<PType>& instance, std::unordered_map<int, dnnl::memory>& args) const {
// Get current node id for creating of optimization map
auto node_id = instance.id();
auto& node = instance.get_node();
auto& engine = instance.get_network().get_engine();
auto dnnl_engine = engine.get_onednn_engine();
// Get current post-ops info
-dnnl::post_ops post_ops = _attrs->get_post_ops();
+auto onednn_attrs = node.get_onednn_primitive_attributes();
+dnnl::post_ops post_ops = onednn_attrs->get_post_ops();
// Create onednn memory buffers for post-ops
-auto& cur_post_ops = onednn_fusing_map[node_id];
+auto& cur_post_ops = node.get_fused_primitives_onednn();
auto post_ops_size = cur_post_ops.size();
for (size_t post_op_idx = 0, num_of_optimized_post_ops = 0; post_op_idx < post_ops_size; post_op_idx++) {
auto post_op_type = cur_post_ops[post_op_idx].op_type;
auto memory_offset = cur_post_ops[post_op_idx].mem_offset;
-auto onednn_post_op_idx = has_out_scales(_attrs) && post_op_idx > 0 ? post_op_idx - 1 : post_op_idx;
+auto onednn_post_op_idx = has_output_scales(onednn_attrs) && post_op_idx > 0 ? post_op_idx - 1 : post_op_idx;
onednn_post_op_idx -= num_of_optimized_post_ops;
switch (post_op_type) {
@@ -576,296 +160,6 @@ protected:
void init_kernels() override { }
static std::shared_ptr<dnnl::primitive_attr> get_primitive_attributes(const typed_program_node<PType>& arg) {
const std::vector<fused_primitive_desc>& cldnn_post_ops = arg.get_fused_primitives();
auto attrs = std::make_shared<dnnl::primitive_attr>();
dnnl::post_ops post_ops;
size_t memory_offset = 0;
// Create onednn post-ops list related to the current node
std::vector<onednn_post_op_desc> fused_ops;
// Added this for debug purposes only
size_t empty_mem = 0xff;
// Add information about post-operation into the list, update indices
auto update_onednn_post_op_list = [&](onednn_post_op_type type, size_t m_dep) {
onednn_post_op_desc cur_op_desc = { type, memory_offset, m_dep };
fused_ops.push_back(cur_op_desc);
auto has_memory_buffers = type == onednn_post_op_type::binary_add ||
type == onednn_post_op_type::binary_mul ||
type == onednn_post_op_type::binary_max ||
type == onednn_post_op_type::binary_min ||
type == onednn_post_op_type::scale ||
type == onednn_post_op_type::sum;
if (has_memory_buffers)
memory_offset++;
};
for (size_t idx = 0; idx < cldnn_post_ops.size(); idx++) {
auto node = cldnn_post_ops[idx].node;
if (node->is_type<activation>()) {
auto fused_desc = node->as<activation>().get_primitive();
dnnl::algorithm alg = onednn::convert_activation_func(fused_desc->activation_function);
post_ops.append_eltwise(1.0f, alg, fused_desc->additional_params.a, fused_desc->additional_params.b);
update_onednn_post_op_list(onednn_post_op_type::eltwise_act, empty_mem);
} else if (node->is_type<eltwise>()) {
auto& e_node = node->as<eltwise>();
auto dep_idx = cldnn_post_ops[idx].dep_start_idx;
auto in = arg.get_dependency(dep_idx).get_output_layout();
if (e_node.get_primitive()->mode == eltwise_mode::sum) {
if (e_node.get_primitive()->needs_onednn_sum_post_op(in)) {
post_ops.append_sum(1.0f, onednn::convert_data_type(in.data_type));
update_onednn_post_op_list(onednn_post_op_type::sum, dep_idx);
} else {
dnnl::memory::desc in_desc = onednn::layout_to_memory_desc(in, dnnl::memory::format_tag::ab, true);
post_ops.append_binary(dnnl::algorithm::binary_add, in_desc);
update_onednn_post_op_list(onednn_post_op_type::binary_add, dep_idx);
}
} else {
if (in.size.spatial[0] > 1 || in.size.spatial[1] > 1 || in.size.batch[0] > 1)
throw std::runtime_error("Unsupported eltwise mode for fused onednn op");
if (idx == 0 && !has_out_scales(attrs)) {
int mask = in.count() > 1 ? 2 : 0;
attrs->set_output_scales(mask, {DNNL_RUNTIME_F32_VAL});
update_onednn_post_op_list(onednn_post_op_type::scale, dep_idx);
} else {
dnnl::memory::desc in_desc = onednn::layout_to_memory_desc(in, dnnl::memory::format_tag::ab, true);
post_ops.append_binary(dnnl::algorithm::binary_mul, in_desc);
update_onednn_post_op_list(onednn_post_op_type::binary_mul, dep_idx);
}
}
} else if (node->is_type<quantize>()) {
auto& q_node = node->as<quantize>();
auto dep_idx = cldnn_post_ops[idx].dep_start_idx;
if (q_node.get_per_tensor_output_range() && q_node.get_output_lo_val() < q_node.get_output_hi_val()) {
// 1. pre-scale & pre-shift
{
if (q_node.get_per_tensor_input_scale() && q_node.get_per_tensor_input_shift()) {
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_input_scale_val(), q_node.get_input_shift_val());
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
} else {
if (q_node.get_per_tensor_input_scale()) {
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_input_scale_val(), 0.0f);
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
} else {
auto in_scale = arg.get_dependency(dep_idx++).get_output_layout();
if (idx == 0 && !has_out_scales(attrs) && in_scale.data_type == data_types::f32 &&
arg.type() == convolution::type_id() &&
!data_type_traits::is_floating_point(arg.get_dependency(0).get_output_layout().data_type)) {
int mask = in_scale.count() > 1 ? 2 : 0;
attrs->set_output_scales(mask, {DNNL_RUNTIME_F32_VAL});
update_onednn_post_op_list(onednn_post_op_type::scale, dep_idx - 1);
} else {
dnnl::memory::desc in_scale_desc = onednn::layout_to_memory_desc(in_scale, dnnl::memory::format_tag::ab, true);
post_ops.append_binary(dnnl::algorithm::binary_mul, in_scale_desc);
update_onednn_post_op_list(onednn_post_op_type::binary_mul, dep_idx - 1);
}
}
if (q_node.get_need_pre_shift()) {
if (q_node.get_per_tensor_input_shift()) {
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, 1.0f, q_node.get_input_shift_val());
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
} else {
auto in_shift = arg.get_dependency(dep_idx++).get_output_layout();
dnnl::memory::desc in_shift_desc = onednn::layout_to_memory_desc(in_shift, dnnl::memory::format_tag::ab, true);
post_ops.append_binary(dnnl::algorithm::binary_add, in_shift_desc);
update_onednn_post_op_list(onednn_post_op_type::binary_add, dep_idx - 1);
}
}
}
}
// 2. round
auto out_dt = cldnn_post_ops[idx].output_layout.data_type;
bool output_type_is_int8 = out_dt == data_types::u8 || out_dt == data_types::i8;
if (!output_type_is_int8) {
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_round, 0.0f, 0.0f);
update_onednn_post_op_list(onednn_post_op_type::eltwise_round, empty_mem);
}
// 3. post-scale & post-shift
if (q_node.get_need_post_scale() && q_node.get_need_post_shift() &&
q_node.get_per_tensor_output_scale() && q_node.get_per_tensor_output_shift()) {
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_output_scale_val(), q_node.get_output_shift_val());
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
} else {
if (q_node.get_need_post_scale()) {
if (q_node.get_per_tensor_output_scale()) {
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_output_scale_val(), 0.0f);
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
} else {
auto out_scale = arg.get_dependency(dep_idx++).get_output_layout();
dnnl::memory::desc out_scale_desc = onednn::layout_to_memory_desc(out_scale, dnnl::memory::format_tag::ab, true);
post_ops.append_binary(dnnl::algorithm::binary_mul, out_scale_desc);
update_onednn_post_op_list(onednn_post_op_type::binary_mul, dep_idx - 1);
}
}
if (q_node.get_need_post_shift()) {
if (q_node.get_per_tensor_output_shift()) {
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, 1.0f, q_node.get_output_shift_val());
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
} else {
auto out_shift = arg.get_dependency(dep_idx++).get_output_layout();
dnnl::memory::desc out_shift_desc = onednn::layout_to_memory_desc(out_shift, dnnl::memory::format_tag::ab, true);
post_ops.append_binary(dnnl::algorithm::binary_add, out_shift_desc);
update_onednn_post_op_list(onednn_post_op_type::binary_add, dep_idx - 1);
}
}
}
// 4. clamp
if (q_node.get_need_clamp()) {
float out_lo = q_node.get_need_min_clamp() ? q_node.get_output_lo_val() : data_type_traits::min<float>(out_dt);
float out_hi = q_node.get_need_max_clamp() ? q_node.get_output_hi_val() : data_type_traits::max<float>(out_dt);
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_clip, out_lo, out_hi);
update_onednn_post_op_list(onednn_post_op_type::eltwise_clip, empty_mem);
}
} else {
// 1. clamp
if (q_node.get_need_clamp()) {
auto in_lo = arg.get_dependency(dep_idx++).get_output_layout();
auto in_hi = arg.get_dependency(dep_idx++).get_output_layout();
dnnl::algorithm clamp_max = dnnl::algorithm::binary_max;
dnnl::algorithm clamp_min = dnnl::algorithm::binary_min;
dnnl::memory::desc in_lo_desc = onednn::layout_to_memory_desc(in_lo, dnnl::memory::format_tag::ab, true);
dnnl::memory::desc in_hi_desc = onednn::layout_to_memory_desc(in_hi, dnnl::memory::format_tag::ab, true);
post_ops.append_binary(clamp_max, in_lo_desc);
update_onednn_post_op_list(onednn_post_op_type::binary_max, dep_idx - 2);
post_ops.append_binary(clamp_min, in_hi_desc);
update_onednn_post_op_list(onednn_post_op_type::binary_min, dep_idx - 1);
}
// 2. pre-scale & pre-shift
{
if (q_node.get_per_tensor_input_scale() && q_node.get_per_tensor_input_shift()) {
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_input_scale_val(), q_node.get_input_shift_val());
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
} else {
if (q_node.get_per_tensor_input_scale()) {
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_input_scale_val(), 0.0f);
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
} else {
auto in_scale = arg.get_dependency(dep_idx++).get_output_layout();
if (idx == 0 && !q_node.get_need_clamp() && !has_out_scales(attrs) && in_scale.data_type == data_types::f32 &&
arg.type() == convolution::type_id() &&
!data_type_traits::is_floating_point(arg.get_dependency(0).get_output_layout().data_type)) {
int mask = in_scale.count() > 1 ? 2 : 0;
attrs->set_output_scales(mask, {DNNL_RUNTIME_F32_VAL});
update_onednn_post_op_list(onednn_post_op_type::scale, dep_idx - 1);
} else {
dnnl::memory::desc in_scale_desc = onednn::layout_to_memory_desc(in_scale, dnnl::memory::format_tag::ab, true);
post_ops.append_binary(dnnl::algorithm::binary_mul, in_scale_desc);
update_onednn_post_op_list(onednn_post_op_type::binary_mul, dep_idx - 1);
}
}
if (q_node.get_need_pre_shift()) {
if (q_node.get_per_tensor_input_shift()) {
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, 1.0f, q_node.get_input_shift_val());
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
} else {
auto in_shift = arg.get_dependency(dep_idx++).get_output_layout();
dnnl::memory::desc in_shift_desc = onednn::layout_to_memory_desc(in_shift, dnnl::memory::format_tag::ab, true);
post_ops.append_binary(dnnl::algorithm::binary_add, in_shift_desc);
update_onednn_post_op_list(onednn_post_op_type::binary_add, dep_idx - 1);
}
}
}
}
// 3. round
{
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_round, 0.0f, 0.0f);
update_onednn_post_op_list(onednn_post_op_type::eltwise_round, empty_mem);
}
// 4. post-scale & post-shift
if (q_node.get_need_post_scale() && q_node.get_need_post_shift() &&
q_node.get_per_tensor_output_scale() && q_node.get_per_tensor_output_shift()) {
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_output_scale_val(), q_node.get_output_shift_val());
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
} else {
if (q_node.get_need_post_scale()) {
if (q_node.get_per_tensor_output_scale()) {
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_output_scale_val(), 0.0f);
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
} else {
auto out_scale = arg.get_dependency(dep_idx++).get_output_layout();
dnnl::memory::desc out_scale_desc = onednn::layout_to_memory_desc(out_scale, dnnl::memory::format_tag::ab, true);
post_ops.append_binary(dnnl::algorithm::binary_mul, out_scale_desc);
update_onednn_post_op_list(onednn_post_op_type::binary_mul, dep_idx - 1);
}
}
if (q_node.get_need_post_shift()) {
if (q_node.get_per_tensor_output_shift()) {
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, 1.0f, q_node.get_output_shift_val());
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
} else {
auto out_shift = arg.get_dependency(dep_idx++).get_output_layout();
dnnl::memory::desc out_shift_desc = onednn::layout_to_memory_desc(out_shift, dnnl::memory::format_tag::ab, true);
post_ops.append_binary(dnnl::algorithm::binary_add, out_shift_desc);
update_onednn_post_op_list(onednn_post_op_type::binary_add, dep_idx - 1);
}
}
}
}
} else if (node->is_type<reorder>()) {
continue;
} else {
throw std::runtime_error("Unsupported fused op of " + node->get_primitive()->type_string() + " type for oneDNN primitive");
}
}
if (cldnn_post_ops.size() && arg.get_fused_activations_funcs().size())
throw std::runtime_error("Unsupported mix of fused ops and activations");
for (size_t i = 0; i < arg.get_fused_activations_funcs().size(); i++) {
auto activation_type = arg.get_fused_activations_funcs()[i];
auto params = arg.get_fused_activations_params()[i];
dnnl::algorithm alg = onednn::convert_activation_func(activation_type);
post_ops.append_eltwise(1.0f, alg, params.a, params.b);
update_onednn_post_op_list(onednn_post_op_type::eltwise_act, empty_mem);
}
// Update total onednn post-ops info
auto it = onednn_fusing_map.find(arg.id());
if (it != onednn_fusing_map.end()) {
it->second = std::move(fused_ops);
} else {
onednn_fusing_map.emplace(arg.id(), std::move(fused_ops));
}
// Trying to optimize more than 1 post-ops
auto post_ops_size = onednn_fusing_map[arg.id()].size();
if (post_ops_size > 1) {
dnnl::post_ops optimized_post_ops = post_ops;
bool optimization_is_finished = false;
// Trying to combine multiplications and additions which are placed one after another.
// We do it in the cycle because "eltw + eltw" cases can be simplified again in some cases.
do {
optimized_post_ops = try_optimize_post_ops(arg, optimized_post_ops, attrs, optimization_is_finished);
} while (!optimization_is_finished);
attrs->set_post_ops(optimized_post_ops);
} else {
// Set post-ops without any optimizations
attrs->set_post_ops(post_ops);
}
return attrs;
}
event::ptr aggregate_events(const std::vector<event::ptr>& events, stream& stream, bool group = false, bool is_output = false) const {
if (events.size() == 1 && !is_output)
return events[0];


@@ -55,13 +55,13 @@ protected:
input_md,
engine.get_onednn_engine(),
output_md,
-*get_primitive_attributes(arg));
+*(arg.get_onednn_primitive_attributes()));
}
public:
static primitive_impl* create(const reorder_node& arg) {
auto desc = get_reorder_descriptor(arg);
-auto attr = get_primitive_attributes(arg);
+auto attr = arg.get_onednn_primitive_attributes();
std::shared_ptr<void> dummy = nullptr;


@@ -393,4 +393,10 @@ private:
void run(program& p) override;
};
class add_onednn_optimization_attributes : public base_pass {
public:
add_onednn_optimization_attributes() : base_pass("add_onednn_optimization_attributes") {}
void run(program& p) override;
};
} // namespace cldnn


@@ -33,6 +33,29 @@ struct typed_program_node;
class json_composite;
class xml_composite;
#ifdef ENABLE_ONEDNN_FOR_GPU
enum class onednn_post_op_type : uint32_t {
eltwise_act,
eltwise_clip,
eltwise_linear,
eltwise_round,
binary_mul,
binary_add,
binary_max,
binary_min,
scale,
sum,
optimized,
optimized_eltwise,
optimized_sum
};
struct fused_primitive_desc_onednn {
onednn_post_op_type op_type; // onednn post-operation type
size_t mem_offset; // index of a memory buffer for current post-operation
size_t mem_dep; // memory dependency for working with fused node
};
#endif // ENABLE_ONEDNN_FOR_GPU
struct fused_primitive_desc {
std::shared_ptr<program_node> node;
@@ -57,7 +80,7 @@ struct fused_primitive_desc {
to API level where all primitives store only ids of related ones.
*/
struct program_node {
-friend struct program; // to be removed when possible
+friend struct program;          // to be removed when possible
friend class compile_graph; // to be removed when possible
friend class graph_initializations; // to be removed when possible
friend class pre_replace_deconv; // to be removed when possible
@@ -293,6 +316,16 @@ public:
const std::vector<fused_primitive_desc>& get_fused_primitives() const { return fused_prims; }
std::vector<fused_primitive_desc>& get_fused_primitives() { return fused_prims; }
#ifdef ENABLE_ONEDNN_FOR_GPU
const std::shared_ptr<dnnl::primitive_attr>& get_onednn_primitive_attributes() const { return onednn_attrs; }
std::shared_ptr<dnnl::primitive_attr>& get_onednn_primitive_attributes() { return onednn_attrs; }
const std::vector<fused_primitive_desc_onednn>& get_fused_primitives_onednn() const { return fused_prims_onednn; }
std::vector<fused_primitive_desc_onednn>& get_fused_primitives_onednn() { return fused_prims_onednn; }
void init_onednn_primitive_attributes();
#endif // ENABLE_ONEDNN_FOR_GPU
size_t get_fused_inputs_count() const {
size_t count = 0;
for (auto& fp : get_fused_primitives()) {
@@ -360,7 +393,26 @@ protected:
std::vector<fused_activation_params> fused_activations;
std::vector<fused_primitive_desc> fused_prims;
void invalidate_users() const;
private:
#ifdef ENABLE_ONEDNN_FOR_GPU
std::vector<fused_primitive_desc_onednn> fused_prims_onednn;
std::shared_ptr<dnnl::primitive_attr> onednn_attrs;
void add_onednn_fused_primitives(std::vector<fused_primitive_desc_onednn> descs) {
fused_prims_onednn.erase(fused_prims_onednn.begin(), fused_prims_onednn.end());
fused_prims_onednn.insert(fused_prims_onednn.end(), descs.begin(), descs.end());
}
void add_onednn_attrs(std::shared_ptr<dnnl::primitive_attr> attrs) {
onednn_attrs = attrs;
}
bool has_out_scales(const std::shared_ptr<dnnl::primitive_attr>& attr);
dnnl::post_ops try_optimize_post_ops(dnnl::post_ops& p_ops, const std::shared_ptr<dnnl::primitive_attr>& attr, bool& optimization_is_completed);
#endif // ENABLE_ONEDNN_FOR_GPU
};
/*


@@ -534,6 +534,9 @@ void program::pre_optimize_graph(bool is_internal) {
// check if there exists some layout incompatibilities and add an reorder node if required
apply_opt_pass<add_required_reorders>();
// add optimization attributes for onednn primitives
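// note: this pass reads get_preferred_impl_type(), so preferred impl types
// must already be assigned for onednn-capable nodes by this point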
apply_opt_pass<add_onednn_optimization_attributes>();
}
void program::post_optimize_graph(bool is_internal) {


@@ -5,6 +5,14 @@
#include "program_node.h"
#include "cldnn/graph/program.hpp"
#include "primitive_inst.h"
#ifdef ENABLE_ONEDNN_FOR_GPU
#include "convolution_inst.h"
#include "quantize_inst.h"
#include "reorder_inst.h"
#include <impls/onednn/utils.hpp>
#endif // ENABLE_ONEDNN_FOR_GPU
#include "to_string_utils.h"
#include "json_object.h"
#include <vector>
@@ -280,3 +288,699 @@ bool program_node::need_lockable_memory() const {
return need_lockable_mem;
}
/* ----------------------------------------- */
/* Onednn fused operations integration logic */
/* ----------------------------------------- */
#ifdef ENABLE_ONEDNN_FOR_GPU
bool program_node::has_out_scales(const std::shared_ptr<dnnl::primitive_attr>& attr) {
int mask;
std::vector<float> scales;
attr->get_output_scales(mask, scales);
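// DNNL_RUNTIME_F32_VAL is a NaN-based "value set at runtime" sentinel, so the
// check below compares bit patterns (a plain float compare with NaN would fail).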
const auto drfv = reinterpret_cast<const int32_t&>(DNNL_RUNTIME_F32_VAL);
return !scales.empty() && (reinterpret_cast<const int32_t&>(scales[0]) == drfv);
}
dnnl::post_ops program_node::try_optimize_post_ops(dnnl::post_ops& p_ops, const std::shared_ptr<dnnl::primitive_attr>& attr,
bool& optimization_is_completed) {
// Create new dnnl::post_ops object which will be filled inside the optimization process
dnnl::post_ops optimized_p_ops;
// Add new post-op into optimized_p_ops structure
auto add_post_op = [&](onednn_post_op_type type, const dnnl::post_ops& cur_p_ops, dnnl::post_ops& new_p_ops, int idx) {
switch (type) {
case onednn_post_op_type::eltwise_act:
case onednn_post_op_type::eltwise_clip:
case onednn_post_op_type::eltwise_linear:
case onednn_post_op_type::eltwise_round:
{
dnnl::algorithm alg;
float scale, alpha, beta;
cur_p_ops.get_params_eltwise(idx, scale, alg, alpha, beta);
new_p_ops.append_eltwise(scale, alg, alpha, beta);
break;
}
case onednn_post_op_type::binary_add:
case onednn_post_op_type::binary_mul:
case onednn_post_op_type::binary_max:
case onednn_post_op_type::binary_min:
{
dnnl::algorithm alg;
dnnl::memory::desc desc;
cur_p_ops.get_params_binary(idx, alg, desc);
new_p_ops.append_binary(alg, desc);
break;
}
case onednn_post_op_type::scale:
{
break;
}
case onednn_post_op_type::sum:
{
float scale;
dnnl::memory::data_type data_type;
cur_p_ops.get_params_sum(idx, scale, data_type);
new_p_ops.append_sum(scale, data_type);
break;
}
case onednn_post_op_type::optimized:
case onednn_post_op_type::optimized_sum:
case onednn_post_op_type::optimized_eltwise:
{
// Current operation has already been optimized => no extra actions needed
break;
}
default:
throw std::runtime_error("Unsupported onednn post-operation type");
}
};
// Check whether the post-op type is one of the 'optimized' types
auto type_is_any_optimized = [](onednn_post_op_type type) -> bool {
return type == onednn_post_op_type::optimized || type == onednn_post_op_type::optimized_sum ||
type == onednn_post_op_type::optimized_eltwise;
};
// Check that post-op type is eltwise
auto type_is_eltwise = [](onednn_post_op_type type) -> bool {
return type == onednn_post_op_type::eltwise_round || type == onednn_post_op_type::eltwise_linear ||
type == onednn_post_op_type::eltwise_clip || type == onednn_post_op_type::eltwise_act;
};
// Check that post-op type is binary_add or binary_mul
auto type_is_binary_add_or_mul = [](onednn_post_op_type type) -> bool {
return type == onednn_post_op_type::binary_add || type == onednn_post_op_type::binary_mul;
};
// Simple post-op type checks
auto type_is_optimized = [](onednn_post_op_type type) -> bool { return type == onednn_post_op_type::optimized; };
auto type_is_eltwise_linear = [](onednn_post_op_type type) -> bool { return type == onednn_post_op_type::eltwise_linear; };
auto type_is_optimized_eltwise = [](onednn_post_op_type type) -> bool { return type == onednn_post_op_type::optimized_eltwise; };
auto type_is_binary_add = [](onednn_post_op_type type) -> bool { return type == onednn_post_op_type::binary_add; };
auto type_is_binary_mul = [](onednn_post_op_type type) -> bool { return type == onednn_post_op_type::binary_mul; };
auto type_is_sum = [](onednn_post_op_type type) -> bool { return type == onednn_post_op_type::sum; };
auto type_is_optimized_sum = [](onednn_post_op_type type) -> bool { return type == onednn_post_op_type::optimized_sum; };
auto type_is_scale = [](onednn_post_op_type type) -> bool { return type == onednn_post_op_type::scale; };
auto& cur_post_ops = get_fused_primitives_onednn();
size_t cur_post_op_idx = 1;
size_t prev_post_op_idx = 0;
bool optimization_done = false;
// Check and update the post-op list if something was optimized previously
for (size_t post_op_idx = 0; post_op_idx < cur_post_ops.size(); post_op_idx++) {
if (type_is_optimized_sum(cur_post_ops[post_op_idx].op_type))
cur_post_ops[post_op_idx].op_type = onednn_post_op_type::sum;
else if (type_is_optimized_eltwise(cur_post_ops[post_op_idx].op_type))
cur_post_ops[post_op_idx].op_type = onednn_post_op_type::eltwise_linear;
else if (type_is_optimized(cur_post_ops[post_op_idx].op_type))
cur_post_ops.erase(cur_post_ops.begin() + post_op_idx);
}
// Get post-ops size for current node
auto post_ops_size = cur_post_ops.size();
// Try to combine pairs of arithmetic post-ops (adds and muls) into one operation inside this loop
while (!optimization_done) {
auto cur_type = cur_post_ops[cur_post_op_idx].op_type;
auto prev_type = cur_post_ops[prev_post_op_idx].op_type;
// Skip already-optimized entries when selecting the "previous" operation of the pair
while (type_is_any_optimized(prev_type) && cur_post_op_idx < post_ops_size - 1) {
prev_post_op_idx++;
cur_post_op_idx++;
prev_type = cur_post_ops[prev_post_op_idx].op_type;
cur_type = cur_post_ops[cur_post_op_idx].op_type;
}
// Skip already-optimized entries when selecting the "current" operation of the pair
while (type_is_any_optimized(cur_type) && cur_post_op_idx < post_ops_size - 1) {
cur_post_op_idx++;
cur_type = cur_post_ops[cur_post_op_idx].op_type;
}
auto cur_idx = static_cast<int>(has_out_scales(attr) ? (cur_post_op_idx >= 1 ? cur_post_op_idx - 1 : 0) : cur_post_op_idx);
auto prev_idx = static_cast<int>(has_out_scales(attr) ? (prev_post_op_idx >= 1 ? prev_post_op_idx - 1 : 0) : prev_post_op_idx);
// If this is the last pair and it is optimized, add the last post-op and exit the loop
if (cur_post_op_idx == post_ops_size - 1 && (type_is_any_optimized(cur_type) || type_is_any_optimized(prev_type))) {
if (!type_is_any_optimized(prev_type)) {
add_post_op(prev_type, p_ops, optimized_p_ops, prev_idx);
}
if (!type_is_any_optimized(cur_type)) {
add_post_op(cur_type, p_ops, optimized_p_ops, cur_idx);
}
break;
}
// Post-ops combinations which can be simplified
auto eltw_and_eltw = type_is_eltwise(cur_type) && type_is_eltwise(prev_type);
auto bin_and_eltw = type_is_binary_add_or_mul(cur_type) && type_is_eltwise_linear(prev_type);
auto eltw_and_bin = type_is_eltwise_linear(cur_type) && type_is_binary_add_or_mul(prev_type);
auto sum_and_eltw = type_is_sum(cur_type) && type_is_eltwise(prev_type);
auto eltw_and_scale = type_is_eltwise_linear(cur_type) && type_is_scale(prev_type);
auto can_try_optimize = eltw_and_eltw ||
bin_and_eltw ||
eltw_and_bin ||
sum_and_eltw ||
eltw_and_scale;
bool cur_ops_pair_is_optimized = false;
if (can_try_optimize) {
if (eltw_and_eltw) {
dnnl::algorithm cur_alg, prev_alg;
float cur_scale, prev_scale, cur_alpha, prev_alpha, cur_beta, prev_beta;
p_ops.get_params_eltwise(prev_idx, prev_scale, prev_alg, prev_alpha, prev_beta);
p_ops.get_params_eltwise(cur_idx, cur_scale, cur_alg, cur_alpha, cur_beta);
auto eltw_linear_and_eltw_linear = type_is_eltwise_linear(cur_type) && type_is_eltwise_linear(prev_type);
auto eltw_linear_and_eltw_non_linear = type_is_eltwise_linear(cur_type) && !type_is_eltwise_linear(prev_type) && cur_beta == 0;
// An eltwise_linear + eltwise_linear combination can always be optimized
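// Derivation (prev is applied first, then cur):
//   y = cur_scale * (cur_alpha * (prev_scale * (prev_alpha * x + prev_beta)) + cur_beta)
//     = cur_scale * ((cur_alpha * prev_alpha * prev_scale) * x + (cur_alpha * prev_beta * prev_scale + cur_beta))
// i.e. a single eltwise_linear with the coefficients computed below.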
if (eltw_linear_and_eltw_linear) {
dnnl::post_ops eltw_p_op;
float optimized_alpha = cur_alpha * prev_alpha * prev_scale;
float optimized_beta = cur_alpha * prev_beta * prev_scale + cur_beta;
float optimized_scale = cur_scale;
eltw_p_op.append_eltwise(optimized_scale, cur_alg, optimized_alpha, optimized_beta);
// Combine 2 eltwises into one
add_post_op(cur_type, eltw_p_op, optimized_p_ops, 0);
} else if (eltw_linear_and_eltw_non_linear) {
dnnl::post_ops eltw_p_op;
eltw_p_op.append_eltwise(cur_scale * prev_scale * cur_alpha, prev_alg, prev_alpha, prev_beta);
// Combine 2 eltwises into one
add_post_op(prev_type, eltw_p_op, optimized_p_ops, 0);
}
if (eltw_linear_and_eltw_linear || eltw_linear_and_eltw_non_linear) {
// Mark the current and previous eltwise operations as 'optimized' (they will be ignored on the next loop iteration)
cur_post_ops[cur_post_op_idx].op_type = onednn_post_op_type::optimized;
cur_post_ops[prev_post_op_idx].op_type = onednn_post_op_type::optimized_eltwise;
// Set the flag if an extra optimization pass is needed
if (cur_post_op_idx < post_ops_size - 1) {
if (type_is_eltwise_linear(cur_post_ops[cur_post_op_idx + 1].op_type) ||
type_is_binary_add_or_mul(cur_post_ops[cur_post_op_idx + 1].op_type) ||
type_is_optimized_eltwise(cur_post_ops[cur_post_op_idx + 1].op_type)) {
optimization_is_completed = true;
}
}
cur_ops_pair_is_optimized = true;
}
} else if (bin_and_eltw) {
dnnl::algorithm alg;
dnnl::memory::desc desc;
float scale, alpha, beta;
cldnn::program_node& cur_node = get_dependency(cur_post_ops[cur_post_op_idx].mem_dep);
p_ops.get_params_binary(cur_idx, alg, desc);
p_ops.get_params_eltwise(prev_idx, scale, alg, alpha, beta);
// Eltwise operations can use runtime non-constant data buffers, so check that memory buffers consist of constant data only
auto bin_ops_can_be_optimized = cur_node.is_type<data>() && cur_node.is_constant() &&
cur_node.get_users().size() == 1 && desc.data_type() == dnnl_f32;
auto bin_add_and_eltw = alpha == 1.0f && scale == 1.0f && type_is_binary_add(cur_type) && bin_ops_can_be_optimized;
auto bin_mul_and_eltw = beta == 0.f && type_is_binary_mul(cur_type) && bin_ops_can_be_optimized;
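// Why these folds are valid (the eltwise_linear is applied before the binary op):
//   binary_add: (1.0f * (1.0f * x + beta)) + m  ==  x + (m + beta)
//   binary_mul: (scale * (alpha * x + 0.f)) * m ==  x * (m * alpha * scale)
// so the linear op can be absorbed into the constant binary buffer below.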
if (bin_add_and_eltw || bin_mul_and_eltw) {
memory::ptr cur_bin_mem_ptr = cur_node.as<data>().get_attached_memory_ptr();
if (cur_bin_mem_ptr == nullptr)
throw std::runtime_error("OneDNN post-ops optimization error: nonexistent node for bin + eltw");
auto& stream = cur_bin_mem_ptr->get_engine()->get_program_stream();
mem_lock<float, mem_lock_type::read_write> bin_and_eltw_lock(cur_bin_mem_ptr, stream);
size_t cur_bin_mem_size = cur_node.get_output_layout().count();
// Update all binary coefficients
if (bin_add_and_eltw) {
for (size_t data_idx = 0; data_idx < cur_bin_mem_size; data_idx++) {
bin_and_eltw_lock[data_idx] += beta;
}
} else {
for (size_t data_idx = 0; data_idx < cur_bin_mem_size; data_idx++) {
bin_and_eltw_lock[data_idx] *= alpha * scale;
}
}
// Mark the previous eltwise operation as 'optimized' (it will be ignored on the next loop iteration)
cur_post_ops[prev_post_op_idx].op_type = onednn_post_op_type::optimized;
cur_ops_pair_is_optimized = true;
}
} else if (eltw_and_bin) {
dnnl::algorithm alg;
dnnl::memory::desc desc;
float scale, alpha, beta;
cldnn::program_node& prev_node = get_dependency(cur_post_ops[prev_post_op_idx].mem_dep);
p_ops.get_params_eltwise(cur_idx, scale, alg, alpha, beta);
p_ops.get_params_binary(prev_idx, alg, desc);
// Eltwise operations can use runtime non-constant data buffers, so check that memory buffers consist of constant data only
auto bin_ops_can_be_optimized = prev_node.is_type<data>() && prev_node.is_constant() &&
prev_node.get_users().size() == 1 && desc.data_type() == dnnl_f32;
auto eltw_and_bin_add = alpha == 1.0f && scale == 1.0f && type_is_binary_add(prev_type) && bin_ops_can_be_optimized;
auto eltw_and_bin_mul = beta == 0.f && type_is_binary_mul(prev_type) && bin_ops_can_be_optimized;
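// Symmetric to the bin_and_eltw case above: here the binary op is applied first
// and the eltwise_linear after it, but the same constant-buffer updates apply:
//   m += beta for binary_add, m *= alpha * scale for binary_mul.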
if (eltw_and_bin_add || eltw_and_bin_mul) {
memory::ptr prev_bin_mem_ptr = prev_node.as<data>().get_attached_memory_ptr();
if (prev_bin_mem_ptr == nullptr)
throw std::runtime_error("OneDNN post-ops optimization error: nonexistent node for eltw + bin");
auto& stream = prev_bin_mem_ptr->get_engine()->get_program_stream();
mem_lock<float, mem_lock_type::read_write> eltw_and_bin_lock(prev_bin_mem_ptr, stream);
size_t prev_bin_mem_size = prev_node.get_output_layout().count();
// Update all binary coefficients
if (eltw_and_bin_add) {
for (size_t data_idx = 0; data_idx < prev_bin_mem_size; data_idx++) {
eltw_and_bin_lock[data_idx] += beta;
}
} else {
for (size_t data_idx = 0; data_idx < prev_bin_mem_size; data_idx++) {
eltw_and_bin_lock[data_idx] *= alpha * scale;
}
}
// Mark the current eltwise operation as 'optimized' (it will be ignored on the next loop iteration)
cur_post_ops[cur_post_op_idx].op_type = onednn_post_op_type::optimized;
cur_ops_pair_is_optimized = true;
}
} else if (sum_and_eltw) {
dnnl::algorithm alg;
float sum_scale, eltw_scale, alpha, beta;
dnnl::memory::data_type data_type;
dnnl::algorithm next_alg;
float next_scale, next_alpha, next_beta;
size_t next_idx = cur_idx + 1;
size_t next_post_op_idx = cur_post_op_idx + 1;
bool can_optimize_eltw_and_sum = false;
if (cur_post_op_idx < post_ops_size - 1) {
auto next_type = cur_post_ops[next_post_op_idx].op_type;
if (type_is_eltwise_linear(next_type)) {
p_ops.get_params_eltwise(next_idx, next_scale, next_alg, next_alpha, next_beta);
if (next_beta == 0)
can_optimize_eltw_and_sum = true;
}
}
// Try to optimize eltwise (any) + sum + eltwise_linear (with beta = 0) chain of operations
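// Sketch of the rewrite: since next_beta == 0, the trailing eltwise_linear is a
// pure multiplier and distributes over the accumulation. The preceding eltwise
// keeps its algorithm with its scale multiplied by next_alpha * next_scale, the
// sum scale is multiplied by next_alpha, and the trailing linear op is dropped.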
if (can_optimize_eltw_and_sum) {
p_ops.get_params_sum(cur_idx, sum_scale, data_type);
p_ops.get_params_eltwise(prev_idx, eltw_scale, alg, alpha, beta);
dnnl::post_ops eltw_p_op_prev, sum_p_op;
eltw_p_op_prev.append_eltwise(eltw_scale * next_alpha * next_scale, alg, alpha, beta);
sum_p_op.append_sum(sum_scale * next_alpha, data_type);
add_post_op(prev_type, eltw_p_op_prev, optimized_p_ops, 0);
add_post_op(cur_type, sum_p_op, optimized_p_ops, 0);
// Mark the current, previous and next operations as 'optimized' (they will be ignored on the next loop iteration)
cur_post_ops[prev_post_op_idx].op_type = onednn_post_op_type::optimized_eltwise;
cur_post_ops[cur_post_op_idx].op_type = onednn_post_op_type::optimized_sum;
cur_post_ops[next_post_op_idx].op_type = onednn_post_op_type::optimized;
// Set the flag if an extra optimization pass is needed
if (next_post_op_idx < post_ops_size - 1) {
if (type_is_eltwise_linear(cur_post_ops[next_post_op_idx + 1].op_type) ||
type_is_optimized_eltwise(cur_post_ops[next_post_op_idx + 1].op_type)) {
optimization_is_completed = true;
}
}
cur_ops_pair_is_optimized = true;
}
} else if (eltw_and_scale) {
dnnl::algorithm alg;
float eltw_scale, alpha, beta;
cldnn::program_node& prev_node = get_dependency(cur_post_ops[prev_post_op_idx].mem_dep);
p_ops.get_params_eltwise(cur_idx, eltw_scale, alg, alpha, beta);
// The eltwise_linear op can be folded into the output scales if its beta equals 0.f
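// out = eltw_scale * (alpha * (s * x) + 0.f) = (s * alpha * eltw_scale) * x,
// so each runtime output scale s absorbs the linear op's multiplier.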
if (beta == 0.f && prev_node.get_output_layout().data_type == data_types::f32) {
memory::ptr prev_scale_mem_ptr = prev_node.as<data>().get_attached_memory_ptr();
if (prev_scale_mem_ptr == nullptr)
throw std::runtime_error("OneDNN post-ops optimization error: nonexistent node for eltw + scale");
auto& stream = prev_scale_mem_ptr->get_engine()->get_program_stream();
mem_lock<float, mem_lock_type::read_write> eltw_and_scale_lock(prev_scale_mem_ptr, stream);
size_t prev_scale_mem_size = prev_node.get_output_layout().count();
// Update all scale coefficients
for (size_t data_idx = 0; data_idx < prev_scale_mem_size; data_idx++) {
eltw_and_scale_lock[data_idx] *= alpha * eltw_scale;
}
// Mark the current eltwise operation as 'optimized' (it will be ignored on the next loop iteration)
cur_post_ops[cur_post_op_idx].op_type = onednn_post_op_type::optimized;
cur_ops_pair_is_optimized = true;
}
}
}
// If no optimizations have been applied then copy post-op info into the new optimized_p_ops structure
if (!(has_out_scales(attr) && prev_post_op_idx == 0) && !cur_ops_pair_is_optimized) {
add_post_op(prev_type, p_ops, optimized_p_ops, prev_idx);
}
if (cur_post_op_idx == post_ops_size - 1 && !cur_ops_pair_is_optimized) {
add_post_op(cur_type, p_ops, optimized_p_ops, cur_idx);
optimization_done = true;
} else if (cur_post_ops[cur_post_op_idx].op_type != onednn_post_op_type::optimized) {
cur_post_op_idx++;
prev_post_op_idx++;
}
}
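// Inside the loop the flag means "further foldable patterns were detected";
// inverting it lets the caller's do/while loop read it as "optimization finished".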
optimization_is_completed = !optimization_is_completed;
add_onednn_fused_primitives(cur_post_ops);
return optimized_p_ops;
}
void program_node::init_onednn_primitive_attributes() {
const std::vector<fused_primitive_desc>& cldnn_post_ops = get_fused_primitives();
auto attrs = std::make_shared<dnnl::primitive_attr>();
dnnl::post_ops post_ops;
size_t memory_offset = 0;
// Create onednn post-ops list related to the current node
std::vector<fused_primitive_desc_onednn> fused_ops;
// Dummy memory-dependency value, added for debug purposes only
size_t empty_mem = 0xff;
// Add information about a post-operation to the list and update the buffer index
auto update_onednn_post_op_list = [&](onednn_post_op_type type, size_t m_dep) {
fused_primitive_desc_onednn cur_op_desc = { type, memory_offset, m_dep };
fused_ops.push_back(cur_op_desc);
auto has_memory_buffers = type == onednn_post_op_type::binary_add ||
type == onednn_post_op_type::binary_mul ||
type == onednn_post_op_type::binary_max ||
type == onednn_post_op_type::binary_min ||
type == onednn_post_op_type::scale ||
type == onednn_post_op_type::sum;
if (has_memory_buffers)
memory_offset++;
};
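// Translate each fused cldnn primitive into the corresponding oneDNN post-op(s)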
for (size_t idx = 0; idx < cldnn_post_ops.size(); idx++) {
auto node = cldnn_post_ops[idx].node;
if (node->is_type<activation>()) {
auto fused_desc = node->as<activation>().get_primitive();
dnnl::algorithm alg = onednn::convert_activation_func(fused_desc->activation_function);
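// A oneDNN eltwise post-op computes scale * alg(x; alpha, beta); the scale here is fixed at 1.0f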
post_ops.append_eltwise(1.0f, alg, fused_desc->additional_params.a, fused_desc->additional_params.b);
update_onednn_post_op_list(onednn_post_op_type::eltwise_act, empty_mem);
} else if (node->is_type<eltwise>()) {
auto& e_node = node->as<eltwise>();
auto dep_idx = cldnn_post_ops[idx].dep_start_idx;
auto in = get_dependency(dep_idx).get_output_layout();
if (e_node.get_primitive()->mode == eltwise_mode::sum) {
if (e_node.get_primitive()->needs_onednn_sum_post_op(in)) {
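// The sum post-op accumulates the operand into the destination: dst += scale * src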
post_ops.append_sum(1.0f, onednn::convert_data_type(in.data_type));
update_onednn_post_op_list(onednn_post_op_type::sum, dep_idx);
} else {
dnnl::memory::desc in_desc = onednn::layout_to_memory_desc(in, dnnl::memory::format_tag::ab, true);
post_ops.append_binary(dnnl::algorithm::binary_add, in_desc);
update_onednn_post_op_list(onednn_post_op_type::binary_add, dep_idx);
}
} else {
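// For other eltwise modes only per-tensor or per-channel operands are supported
// (spatial and batch dims must be 1)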
if (in.size.spatial[0] > 1 || in.size.spatial[1] > 1 || in.size.batch[0] > 1)
throw std::runtime_error("Unsupported eltwise mode for fused onednn op");
if (idx == 0 && !has_out_scales(attrs)) {
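// mask == 0 sets a single common output scale; mask == 2 (bit 1) sets per-channel scales along dim 1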
int mask = in.count() > 1 ? 2 : 0;
attrs->set_output_scales(mask, {DNNL_RUNTIME_F32_VAL});
update_onednn_post_op_list(onednn_post_op_type::scale, dep_idx);
} else {
dnnl::memory::desc in_desc = onednn::layout_to_memory_desc(in, dnnl::memory::format_tag::ab, true);
post_ops.append_binary(dnnl::algorithm::binary_mul, in_desc);
update_onednn_post_op_list(onednn_post_op_type::binary_mul, dep_idx);
}
}
} else if (node->is_type<quantize>()) {
auto& q_node = node->as<quantize>();
auto dep_idx = cldnn_post_ops[idx].dep_start_idx;
// ********************************* Common case with output range usage ********************************* //
if (q_node.get_per_tensor_output_range() && q_node.get_output_lo_val() < q_node.get_output_hi_val()) {
// 1. pre-scale & pre-shift
{
if (q_node.get_per_tensor_input_scale() && q_node.get_per_tensor_input_shift()) {
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_input_scale_val(), q_node.get_input_shift_val());
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
} else {
if (q_node.get_per_tensor_input_scale()) {
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_input_scale_val(), 0.0f);
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
} else {
auto in_scale = get_dependency(dep_idx++).get_output_layout();
if (idx == 0 && !has_out_scales(attrs) && in_scale.data_type == data_types::f32 &&
is_type<convolution>() &&
!data_type_traits::is_floating_point(get_dependency(0).get_output_layout().data_type)) {
int mask = in_scale.count() > 1 ? 2 : 0;
attrs->set_output_scales(mask, {DNNL_RUNTIME_F32_VAL});
update_onednn_post_op_list(onednn_post_op_type::scale, dep_idx - 1);
} else {
dnnl::memory::desc in_scale_desc = onednn::layout_to_memory_desc(in_scale, dnnl::memory::format_tag::ab, true);
post_ops.append_binary(dnnl::algorithm::binary_mul, in_scale_desc);
update_onednn_post_op_list(onednn_post_op_type::binary_mul, dep_idx - 1);
}
}
if (q_node.get_need_pre_shift()) {
if (q_node.get_per_tensor_input_shift()) {
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, 1.0f, q_node.get_input_shift_val());
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
} else {
auto in_shift = get_dependency(dep_idx++).get_output_layout();
dnnl::memory::desc in_shift_desc = onednn::layout_to_memory_desc(in_shift, dnnl::memory::format_tag::ab, true);
post_ops.append_binary(dnnl::algorithm::binary_add, in_shift_desc);
update_onednn_post_op_list(onednn_post_op_type::binary_add, dep_idx - 1);
}
}
}
}
// 2. round
auto out_dt = cldnn_post_ops[idx].output_layout.data_type;
{
bool output_type_is_int8 = out_dt == data_types::u8 || out_dt == data_types::i8;
if (!output_type_is_int8) {
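// Conversion to int8 already rounds, so an explicit round post-op is needed only for wider output types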
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_round, 0.0f, 0.0f);
update_onednn_post_op_list(onednn_post_op_type::eltwise_round, empty_mem);
}
}
// 3. post-scale & post-shift
{
if (q_node.get_need_post_scale() && q_node.get_need_post_shift() &&
q_node.get_per_tensor_output_scale() && q_node.get_per_tensor_output_shift()) {
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_output_scale_val(), q_node.get_output_shift_val());
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
} else {
if (q_node.get_need_post_scale()) {
if (q_node.get_per_tensor_output_scale()) {
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_output_scale_val(), 0.0f);
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
} else {
auto out_scale = get_dependency(dep_idx++).get_output_layout();
dnnl::memory::desc out_scale_desc = onednn::layout_to_memory_desc(out_scale, dnnl::memory::format_tag::ab, true);
post_ops.append_binary(dnnl::algorithm::binary_mul, out_scale_desc);
update_onednn_post_op_list(onednn_post_op_type::binary_mul, dep_idx - 1);
}
}
if (q_node.get_need_post_shift()) {
if (q_node.get_per_tensor_output_shift()) {
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, 1.0f, q_node.get_output_shift_val());
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
} else {
auto out_shift = get_dependency(dep_idx++).get_output_layout();
dnnl::memory::desc out_shift_desc = onednn::layout_to_memory_desc(out_shift, dnnl::memory::format_tag::ab, true);
post_ops.append_binary(dnnl::algorithm::binary_add, out_shift_desc);
update_onednn_post_op_list(onednn_post_op_type::binary_add, dep_idx - 1);
}
}
}
}
// 4. clamp
{
if (q_node.get_need_clamp()) {
float out_lo = q_node.get_need_min_clamp() ? q_node.get_output_lo_val() : data_type_traits::min<float>(out_dt);
float out_hi = q_node.get_need_max_clamp() ? q_node.get_output_hi_val() : data_type_traits::max<float>(out_dt);
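// eltwise_clip clamps the result to the [out_lo, out_hi] range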
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_clip, out_lo, out_hi);
update_onednn_post_op_list(onednn_post_op_type::eltwise_clip, empty_mem);
}
}
// ********************************* Rare case with input range usage ********************************* //
} else {
// 1. clamp
{
if (q_node.get_need_clamp()) {
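// Clamp with runtime per-channel bounds: max(x, in_lo) followed by min(x, in_hi)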
auto in_lo = get_dependency(dep_idx++).get_output_layout();
auto in_hi = get_dependency(dep_idx++).get_output_layout();
dnnl::algorithm clamp_max = dnnl::algorithm::binary_max;
dnnl::algorithm clamp_min = dnnl::algorithm::binary_min;
dnnl::memory::desc in_lo_desc = onednn::layout_to_memory_desc(in_lo, dnnl::memory::format_tag::ab, true);
dnnl::memory::desc in_hi_desc = onednn::layout_to_memory_desc(in_hi, dnnl::memory::format_tag::ab, true);
post_ops.append_binary(clamp_max, in_lo_desc);
update_onednn_post_op_list(onednn_post_op_type::binary_max, dep_idx - 2);
post_ops.append_binary(clamp_min, in_hi_desc);
update_onednn_post_op_list(onednn_post_op_type::binary_min, dep_idx - 1);
}
}
// 2. pre-scale & pre-shift
{
if (q_node.get_per_tensor_input_scale() && q_node.get_per_tensor_input_shift()) {
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_input_scale_val(), q_node.get_input_shift_val());
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
} else {
if (q_node.get_per_tensor_input_scale()) {
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_input_scale_val(), 0.0f);
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
} else {
auto in_scale = get_dependency(dep_idx++).get_output_layout();
if (idx == 0 && !q_node.get_need_clamp() && !has_out_scales(attrs) && in_scale.data_type == data_types::f32 &&
is_type<convolution>() &&
!data_type_traits::is_floating_point(get_dependency(0).get_output_layout().data_type)) {
int mask = in_scale.count() > 1 ? 2 : 0;
attrs->set_output_scales(mask, {DNNL_RUNTIME_F32_VAL});
update_onednn_post_op_list(onednn_post_op_type::scale, dep_idx - 1);
} else {
dnnl::memory::desc in_scale_desc = onednn::layout_to_memory_desc(in_scale, dnnl::memory::format_tag::ab, true);
post_ops.append_binary(dnnl::algorithm::binary_mul, in_scale_desc);
update_onednn_post_op_list(onednn_post_op_type::binary_mul, dep_idx - 1);
}
}
if (q_node.get_need_pre_shift()) {
if (q_node.get_per_tensor_input_shift()) {
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, 1.0f, q_node.get_input_shift_val());
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
} else {
auto in_shift = get_dependency(dep_idx++).get_output_layout();
dnnl::memory::desc in_shift_desc = onednn::layout_to_memory_desc(in_shift, dnnl::memory::format_tag::ab, true);
post_ops.append_binary(dnnl::algorithm::binary_add, in_shift_desc);
update_onednn_post_op_list(onednn_post_op_type::binary_add, dep_idx - 1);
}
}
}
}
// 3. round
{
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_round, 0.0f, 0.0f);
update_onednn_post_op_list(onednn_post_op_type::eltwise_round, empty_mem);
}
// 4. post-scale & post-shift
{
if (q_node.get_need_post_scale() && q_node.get_need_post_shift() &&
q_node.get_per_tensor_output_scale() && q_node.get_per_tensor_output_shift()) {
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_output_scale_val(), q_node.get_output_shift_val());
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
} else {
if (q_node.get_need_post_scale()) {
if (q_node.get_per_tensor_output_scale()) {
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_output_scale_val(), 0.0f);
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
} else {
auto out_scale = get_dependency(dep_idx++).get_output_layout();
dnnl::memory::desc out_scale_desc = onednn::layout_to_memory_desc(out_scale, dnnl::memory::format_tag::ab, true);
post_ops.append_binary(dnnl::algorithm::binary_mul, out_scale_desc);
update_onednn_post_op_list(onednn_post_op_type::binary_mul, dep_idx - 1);
}
}
if (q_node.get_need_post_shift()) {
if (q_node.get_per_tensor_output_shift()) {
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, 1.0f, q_node.get_output_shift_val());
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
} else {
auto out_shift = get_dependency(dep_idx++).get_output_layout();
dnnl::memory::desc out_shift_desc = onednn::layout_to_memory_desc(out_shift, dnnl::memory::format_tag::ab, true);
post_ops.append_binary(dnnl::algorithm::binary_add, out_shift_desc);
update_onednn_post_op_list(onednn_post_op_type::binary_add, dep_idx - 1);
}
}
}
}
}
} else if (node->is_type<reorder>()) {
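// Fused reorders do not map to a oneDNN post-op, so skip them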
continue;
} else {
throw std::runtime_error("Unsupported fused op of " + node->get_primitive()->type_string() + " type for oneDNN primitive");
}
}
if (cldnn_post_ops.size() && get_fused_activations_funcs().size())
throw std::runtime_error("Unsupported mix of fused ops and activations");
for (size_t i = 0; i < get_fused_activations_funcs().size(); i++) {
auto activation_type = get_fused_activations_funcs()[i];
auto params = get_fused_activations_params()[i];
dnnl::algorithm alg = onednn::convert_activation_func(activation_type);
post_ops.append_eltwise(1.0f, alg, params.a, params.b);
update_onednn_post_op_list(onednn_post_op_type::eltwise_act, empty_mem);
}
// Try to optimize when there is more than one post-op
if (fused_ops.size() > 1) {
dnnl::post_ops optimized_post_ops = post_ops;
bool optimization_is_finished = false;
add_onednn_fused_primitives(fused_ops);
// Try to combine multiplications and additions that are placed one after another.
// We do it in a loop because one optimization may open up further simplification opportunities
do {
optimized_post_ops = try_optimize_post_ops(optimized_post_ops, attrs, optimization_is_finished);
} while (!optimization_is_finished);
attrs->set_post_ops(optimized_post_ops);
} else {
// Set post-ops without any optimizations
add_onednn_fused_primitives(fused_ops);
attrs->set_post_ops(post_ops);
}
add_onednn_attrs(attrs);
}
#endif // ENABLE_ONEDNN_FOR_GPU