[GPU] Moving onednn post ops logic into cldnn::program_node (#8511)
This commit is contained in:
parent
2c9a4c59f2
commit
c565799c71
20
inference-engine/thirdparty/clDNN/src/graph_optimizer/add_onednn_optimization_attributes.cpp
vendored
Normal file
20
inference-engine/thirdparty/clDNN/src/graph_optimizer/add_onednn_optimization_attributes.cpp
vendored
Normal file
@ -0,0 +1,20 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "pass_manager.h"
|
||||
#include "program_node.h"
|
||||
|
||||
using namespace cldnn;
|
||||
|
||||
void add_onednn_optimization_attributes::run(program& p) {
|
||||
#ifdef ENABLE_ONEDNN_FOR_GPU
|
||||
for (auto& node : p.get_processing_order()) {
|
||||
if (node->get_preferred_impl_type() == impl_types::onednn) {
|
||||
node->init_onednn_primitive_attributes();
|
||||
}
|
||||
}
|
||||
#endif // ENABLE_ONEDNN_FOR_GPU
|
||||
}
|
@ -36,7 +36,7 @@ void basic_memory_dependencies::run(program& p) {
|
||||
add_memory_dependency(it, node);
|
||||
}
|
||||
|
||||
if (node->is_type<convolution>()) {
|
||||
if (node->is_type<convolution>() && node->get_preferred_impl_type() == impl_types::onednn) {
|
||||
auto& conv = node->as<convolution>();
|
||||
bool can_reuse_eltwise_mem = false;
|
||||
size_t eltw_dep = 0;
|
||||
@ -59,6 +59,7 @@ void basic_memory_dependencies::run(program& p) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (can_reuse_eltwise_mem) {
|
||||
auto& eltw_node = conv.get_dependency(eltw_dep);
|
||||
eltw_node.can_share_buffer(false);
|
||||
|
@ -73,7 +73,7 @@ protected:
|
||||
public:
|
||||
static primitive_impl* create(const concatenation_node& arg) {
|
||||
auto desc = get_concatenation_descriptor(arg);
|
||||
auto attr = get_primitive_attributes(arg);
|
||||
auto attr = arg.get_onednn_primitive_attributes();
|
||||
|
||||
std::shared_ptr<void> dummy = nullptr;
|
||||
|
||||
|
@ -47,6 +47,7 @@ protected:
|
||||
|
||||
std::unordered_map<int, dnnl::memory> get_arguments(convolution_inst& instance) const override {
|
||||
std::unordered_map<int, dnnl::memory> args = parent::get_arguments(instance);
|
||||
auto attrs = instance.get_node().get_onednn_primitive_attributes();
|
||||
|
||||
{
|
||||
auto weights = instance.weights_memory(0);
|
||||
@ -58,13 +59,13 @@ protected:
|
||||
args.insert({DNNL_ARG_BIAS, bias->get_onednn_memory(_pd.weights_desc(1))});
|
||||
}
|
||||
|
||||
if (has_zero_points(DNNL_ARG_SRC, _attrs)) {
|
||||
if (has_zero_points(DNNL_ARG_SRC, attrs)) {
|
||||
auto a_zp = instance.activations_zero_points_memory(0);
|
||||
dnnl::memory::desc desc = onednn::layout_to_memory_desc(a_zp->get_layout(), dnnl::memory::format_tag::a, true);
|
||||
args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, a_zp->get_onednn_memory(desc)});
|
||||
}
|
||||
|
||||
if (has_zero_points(DNNL_ARG_WEIGHTS, _attrs)) {
|
||||
if (has_zero_points(DNNL_ARG_WEIGHTS, attrs)) {
|
||||
auto w_zp = instance.weights_zero_points_memory(0);
|
||||
dnnl::memory::desc desc = onednn::layout_to_memory_desc(w_zp->get_layout(), dnnl::memory::format_tag::a, true);
|
||||
args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS, w_zp->get_onednn_memory(desc)});
|
||||
@ -74,7 +75,7 @@ protected:
|
||||
}
|
||||
|
||||
static std::shared_ptr<dnnl::primitive_attr> get_primitive_attributes(const typed_program_node<convolution>& arg) {
|
||||
auto attrs = parent::get_primitive_attributes(arg);
|
||||
auto attrs = arg.get_onednn_primitive_attributes();
|
||||
|
||||
if (arg.activations_zero_points_term()) {
|
||||
auto& a_zp = arg.activations_zero_points();
|
||||
|
@ -62,7 +62,7 @@ protected:
|
||||
}
|
||||
|
||||
static std::shared_ptr<dnnl::primitive_attr> get_primitive_attributes(const typed_program_node<deconvolution>& arg) {
|
||||
auto attrs = parent::get_primitive_attributes(arg);
|
||||
auto attrs = arg.get_onednn_primitive_attributes();
|
||||
|
||||
return attrs;
|
||||
}
|
||||
|
@ -142,7 +142,7 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
auto attr = get_primitive_attributes(arg);
|
||||
auto attr = arg.get_onednn_primitive_attributes();
|
||||
dnnl::primitive_desc prim_desc{&desc->data, attr.get(), engine.get_onednn_engine(), nullptr};
|
||||
|
||||
return new fully_connected_onednn(arg, desc, attr, prim_desc, get_weights_reorder(arg, prim_desc));
|
||||
|
@ -71,7 +71,7 @@ public:
|
||||
static primitive_impl* create(const pooling_node& arg) {
|
||||
auto& engine = arg.get_program().get_engine();
|
||||
auto desc = get_pooling_descriptor(arg);
|
||||
auto attr = get_primitive_attributes(arg);
|
||||
auto attr = arg.get_onednn_primitive_attributes();
|
||||
dnnl::primitive_desc prim_desc{&desc->data, attr.get(), engine.get_onednn_engine(), nullptr};
|
||||
|
||||
return new pooling_onednn(arg, desc, attr, prim_desc);
|
||||
|
@ -26,36 +26,6 @@
|
||||
namespace cldnn {
|
||||
namespace onednn {
|
||||
|
||||
enum class onednn_post_op_type : uint32_t {
|
||||
eltwise_act,
|
||||
eltwise_clip,
|
||||
eltwise_linear,
|
||||
eltwise_round,
|
||||
binary_mul,
|
||||
binary_add,
|
||||
binary_max,
|
||||
binary_min,
|
||||
scale,
|
||||
sum,
|
||||
optimized,
|
||||
optimized_eltwise,
|
||||
optimized_sum
|
||||
};
|
||||
|
||||
struct onednn_post_op_desc {
|
||||
onednn_post_op_type op_type;
|
||||
size_t mem_offset;
|
||||
size_t mem_dep;
|
||||
};
|
||||
|
||||
// This map contains information about onednn post-ops types, memory buffer offsets and dependencies
|
||||
// key is cldnn::primitive_id,
|
||||
// value is an instance of struct onednn_post_op_desc containing info about post-ops related to the node defined by key:
|
||||
// op_type - onednn_post_op_type (enum),
|
||||
// mem_offset - index of memory buffer for current post-operation,
|
||||
// mem_dep - memory dependency for working with fused node
|
||||
static std::unordered_map<cldnn::primitive_id, std::vector<onednn_post_op_desc>> onednn_fusing_map;
|
||||
|
||||
template <class PType, class DescType, class PrimDescType = dnnl::primitive_desc, class PrimType = dnnl::primitive>
|
||||
struct typed_primitive_onednn_impl : public typed_primitive_impl<PType> {
|
||||
const typed_program_node<PType>& _outer;
|
||||
@ -82,7 +52,7 @@ struct typed_primitive_onednn_impl : public typed_primitive_impl<PType> {
|
||||
protected:
|
||||
virtual bool optimized_out(typed_primitive_inst<PType>&) const { return false; }
|
||||
|
||||
static bool has_out_scales(const std::shared_ptr<dnnl::primitive_attr>& attr) {
|
||||
static bool has_output_scales(const std::shared_ptr<dnnl::primitive_attr>& attr) {
|
||||
int mask;
|
||||
std::vector<float> scales;
|
||||
attr->get_output_scales(mask, scales);
|
||||
@ -98,408 +68,22 @@ protected:
|
||||
return !zp.empty() && (reinterpret_cast<const int32_t&>(zp[0]) == drsv);
|
||||
}
|
||||
|
||||
static dnnl::post_ops try_optimize_post_ops(const typed_program_node<PType>& arg, dnnl::post_ops& p_ops,
|
||||
const std::shared_ptr<dnnl::primitive_attr>& attr,
|
||||
bool& optimization_is_completed) {
|
||||
// Get current node id for creating of optimization map
|
||||
auto node_id = arg.id();
|
||||
|
||||
// Create new dnnl::post_ops object which will be filled inside the optimization process
|
||||
dnnl::post_ops optimized_p_ops;
|
||||
|
||||
// Add new post-op into optimized_p_ops structure
|
||||
auto add_post_op = [&](onednn_post_op_type type, const dnnl::post_ops& cur_p_ops, dnnl::post_ops& new_p_ops, int idx) {
|
||||
switch (type) {
|
||||
case onednn_post_op_type::eltwise_act:
|
||||
case onednn_post_op_type::eltwise_clip:
|
||||
case onednn_post_op_type::eltwise_linear:
|
||||
case onednn_post_op_type::eltwise_round:
|
||||
{
|
||||
dnnl::algorithm alg;
|
||||
float scale, alpha, beta;
|
||||
cur_p_ops.get_params_eltwise(idx, scale, alg, alpha, beta);
|
||||
new_p_ops.append_eltwise(scale, alg, alpha, beta);
|
||||
break;
|
||||
}
|
||||
|
||||
case onednn_post_op_type::binary_add:
|
||||
case onednn_post_op_type::binary_mul:
|
||||
case onednn_post_op_type::binary_max:
|
||||
case onednn_post_op_type::binary_min:
|
||||
{
|
||||
dnnl::algorithm alg;
|
||||
dnnl::memory::desc desc;
|
||||
cur_p_ops.get_params_binary(idx, alg, desc);
|
||||
new_p_ops.append_binary(alg, desc);
|
||||
break;
|
||||
}
|
||||
|
||||
case onednn_post_op_type::scale:
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
case onednn_post_op_type::sum:
|
||||
{
|
||||
float scale;
|
||||
dnnl::memory::data_type data_type;
|
||||
cur_p_ops.get_params_sum(idx, scale, data_type);
|
||||
new_p_ops.append_sum(scale, data_type);
|
||||
break;
|
||||
}
|
||||
|
||||
case onednn_post_op_type::optimized:
|
||||
case onednn_post_op_type::optimized_sum:
|
||||
case onednn_post_op_type::optimized_eltwise:
|
||||
{
|
||||
// Current operation already has been optimized => don't need extra actions
|
||||
break;
|
||||
}
|
||||
|
||||
default:
|
||||
throw std::runtime_error("Unsupported onednn post-operation type");
|
||||
}
|
||||
};
|
||||
|
||||
// Check that post-op type is any optimized
|
||||
auto type_is_any_optimized = [](onednn_post_op_type type) -> bool {
|
||||
return type == onednn_post_op_type::optimized || type == onednn_post_op_type::optimized_sum ||
|
||||
type == onednn_post_op_type::optimized_eltwise;
|
||||
};
|
||||
|
||||
// Check that post-op type is eltwise
|
||||
auto type_is_eltwise = [](onednn_post_op_type type) -> bool {
|
||||
return type == onednn_post_op_type::eltwise_round || type == onednn_post_op_type::eltwise_linear ||
|
||||
type == onednn_post_op_type::eltwise_clip || type == onednn_post_op_type::eltwise_act;
|
||||
};
|
||||
|
||||
// Check that post-op type is binary_add or binary_mul
|
||||
auto type_is_binary_add_or_mul = [](onednn_post_op_type type) -> bool {
|
||||
return type == onednn_post_op_type::binary_add || type == onednn_post_op_type::binary_mul;
|
||||
};
|
||||
|
||||
// Simple post-op type checks
|
||||
auto type_is_optimized = [](onednn_post_op_type type) -> bool { return type == onednn_post_op_type::optimized; };
|
||||
auto type_is_eltwise_linear = [](onednn_post_op_type type) -> bool { return type == onednn_post_op_type::eltwise_linear; };
|
||||
auto type_is_optimized_eltwise = [](onednn_post_op_type type) -> bool { return type == onednn_post_op_type::optimized_eltwise; };
|
||||
auto type_is_binary_add = [](onednn_post_op_type type) -> bool { return type == onednn_post_op_type::binary_add; };
|
||||
auto type_is_binary_mul = [](onednn_post_op_type type) -> bool { return type == onednn_post_op_type::binary_mul; };
|
||||
auto type_is_sum = [](onednn_post_op_type type) -> bool { return type == onednn_post_op_type::sum; };
|
||||
auto type_is_optimized_sum = [](onednn_post_op_type type) -> bool { return type == onednn_post_op_type::optimized_sum; };
|
||||
auto type_is_scale = [](onednn_post_op_type type) -> bool { return type == onednn_post_op_type::scale; };
|
||||
|
||||
auto& cur_post_ops = onednn_fusing_map[node_id];
|
||||
|
||||
size_t cur_post_op_idx = 1;
|
||||
size_t prev_post_op_idx = 0;
|
||||
bool optimization_done = false;
|
||||
|
||||
// Check and update post-op map if we already optimized something
|
||||
for (size_t post_op_idx = 0; post_op_idx < cur_post_ops.size(); post_op_idx++) {
|
||||
if (type_is_optimized_sum(cur_post_ops[post_op_idx].op_type))
|
||||
cur_post_ops[post_op_idx].op_type = onednn_post_op_type::sum;
|
||||
else if (type_is_optimized_eltwise(cur_post_ops[post_op_idx].op_type))
|
||||
cur_post_ops[post_op_idx].op_type = onednn_post_op_type::eltwise_linear;
|
||||
else if (type_is_optimized(cur_post_ops[post_op_idx].op_type))
|
||||
cur_post_ops.erase(cur_post_ops.begin() + post_op_idx);
|
||||
}
|
||||
|
||||
// Get post-ops size for current node
|
||||
auto post_ops_size = cur_post_ops.size();
|
||||
|
||||
// Try to combine pairs of arithmetic post-ops (adds and muls) into one operation inside this cycle
|
||||
while (!optimization_done) {
|
||||
auto cur_type = cur_post_ops[cur_post_op_idx].op_type;
|
||||
auto prev_type = cur_post_ops[prev_post_op_idx].op_type;
|
||||
|
||||
// Ignore optimized operations for "previous" operation in our operation pair
|
||||
while (type_is_any_optimized(prev_type) && cur_post_op_idx < post_ops_size - 1) {
|
||||
prev_post_op_idx++;
|
||||
cur_post_op_idx++;
|
||||
prev_type = cur_post_ops[prev_post_op_idx].op_type;
|
||||
cur_type = cur_post_ops[cur_post_op_idx].op_type;
|
||||
}
|
||||
|
||||
// Ignore optimized operations for "current" operation in our operation pair
|
||||
while (type_is_any_optimized(cur_type) && cur_post_op_idx < post_ops_size - 1) {
|
||||
cur_post_op_idx++;
|
||||
cur_type = cur_post_ops[cur_post_op_idx].op_type;
|
||||
}
|
||||
|
||||
auto cur_idx = static_cast<int>(has_out_scales(attr) ? (cur_post_op_idx >= 1 ? cur_post_op_idx - 1 : 0) : cur_post_op_idx);
|
||||
auto prev_idx = static_cast<int>(has_out_scales(attr) ? (prev_post_op_idx >= 1 ? prev_post_op_idx - 1 : 0) : prev_post_op_idx);
|
||||
|
||||
// If this is the last pair and it's optimized - add the last post-op and go out from the cycle
|
||||
if (cur_post_op_idx == post_ops_size - 1 && (type_is_any_optimized(cur_type) || type_is_any_optimized(prev_type))) {
|
||||
if (!type_is_any_optimized(prev_type)) {
|
||||
add_post_op(prev_type, p_ops, optimized_p_ops, prev_idx);
|
||||
}
|
||||
if (!type_is_any_optimized(cur_type)) {
|
||||
add_post_op(cur_type, p_ops, optimized_p_ops, cur_idx);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
// Post-ops combinations which can be simplified
|
||||
auto eltw_and_eltw = type_is_eltwise(cur_type) && type_is_eltwise(prev_type);
|
||||
auto bin_and_eltw = type_is_binary_add_or_mul(cur_type) && type_is_eltwise_linear(prev_type);
|
||||
auto eltw_and_bin = type_is_eltwise_linear(cur_type) && type_is_binary_add_or_mul(prev_type);
|
||||
auto sum_and_eltw = type_is_sum(cur_type) && type_is_eltwise(prev_type);
|
||||
auto eltw_and_scale = type_is_eltwise_linear(cur_type) && type_is_scale(prev_type);
|
||||
|
||||
auto can_try_optimize = eltw_and_eltw ||
|
||||
bin_and_eltw ||
|
||||
eltw_and_bin ||
|
||||
sum_and_eltw ||
|
||||
eltw_and_scale;
|
||||
|
||||
bool cur_ops_pair_is_optimized = false;
|
||||
|
||||
if (can_try_optimize) {
|
||||
if (eltw_and_eltw) {
|
||||
dnnl::algorithm cur_alg, prev_alg;
|
||||
float cur_scale, prev_scale, cur_alpha, prev_alpha, cur_beta, prev_beta;
|
||||
|
||||
p_ops.get_params_eltwise(prev_idx, prev_scale, prev_alg, prev_alpha, prev_beta);
|
||||
p_ops.get_params_eltwise(cur_idx, cur_scale, cur_alg, cur_alpha, cur_beta);
|
||||
|
||||
auto eltw_linear_and_eltw_linear = type_is_eltwise_linear(cur_type) && type_is_eltwise_linear(prev_type);
|
||||
auto eltw_linear_and_eltw_non_linear = type_is_eltwise_linear(cur_type) && !type_is_eltwise_linear(prev_type) && cur_beta == 0;
|
||||
|
||||
// eltwise_linear + eltwise_linear combination can be optimized always
|
||||
if (eltw_linear_and_eltw_linear) {
|
||||
dnnl::post_ops eltw_p_op;
|
||||
float optimized_alpha = cur_alpha * prev_alpha * prev_scale;
|
||||
float optimized_beta = cur_alpha * prev_beta * prev_scale + cur_beta;
|
||||
float optimized_scale = cur_scale;
|
||||
eltw_p_op.append_eltwise(optimized_scale, cur_alg, optimized_alpha, optimized_beta);
|
||||
|
||||
// Combine 2 eltwises into one
|
||||
add_post_op(cur_type, eltw_p_op, optimized_p_ops, 0);
|
||||
} else if (eltw_linear_and_eltw_non_linear) {
|
||||
dnnl::post_ops eltw_p_op;
|
||||
eltw_p_op.append_eltwise(cur_scale * prev_scale * cur_alpha, prev_alg, prev_alpha, prev_beta);
|
||||
|
||||
// Combine 2 eltwises into one
|
||||
add_post_op(prev_type, eltw_p_op, optimized_p_ops, 0);
|
||||
}
|
||||
|
||||
if (eltw_linear_and_eltw_linear || eltw_linear_and_eltw_non_linear) {
|
||||
// Marked current and previous eltwise operations as 'optimized' (they will be ignored on the next iteration of cycle)
|
||||
cur_post_ops[cur_post_op_idx].op_type = onednn_post_op_type::optimized;
|
||||
cur_post_ops[prev_post_op_idx].op_type = onednn_post_op_type::optimized_eltwise;
|
||||
|
||||
// Set the flag if extra optimizations checking is needed
|
||||
if (cur_post_op_idx < post_ops_size - 1) {
|
||||
if (type_is_eltwise_linear(cur_post_ops[cur_post_op_idx + 1].op_type) ||
|
||||
type_is_binary_add_or_mul(cur_post_ops[cur_post_op_idx + 1].op_type) ||
|
||||
type_is_optimized_eltwise(cur_post_ops[cur_post_op_idx + 1].op_type)) {
|
||||
optimization_is_completed = true;
|
||||
}
|
||||
}
|
||||
|
||||
cur_ops_pair_is_optimized = true;
|
||||
}
|
||||
} else if (bin_and_eltw) {
|
||||
dnnl::algorithm alg;
|
||||
dnnl::memory::desc desc;
|
||||
float scale, alpha, beta;
|
||||
|
||||
cldnn::program_node& cur_node = arg.get_dependency(cur_post_ops[cur_post_op_idx].mem_dep);
|
||||
|
||||
p_ops.get_params_binary(cur_idx, alg, desc);
|
||||
p_ops.get_params_eltwise(prev_idx, scale, alg, alpha, beta);
|
||||
|
||||
// Eltwise operations can use runtime non-constant data buffers, so check that memory buffers consist of constant data only
|
||||
auto bin_ops_can_be_optimized = cur_node.is_type<data>() && cur_node.is_constant() &&
|
||||
cur_node.get_users().size() == 1 && desc.data_type() == dnnl_f32;
|
||||
|
||||
auto bin_add_and_eltw = alpha == 1.0f && scale == 1.0f && type_is_binary_add(cur_type) && bin_ops_can_be_optimized;
|
||||
auto bin_mul_and_eltw = beta == 0.f && type_is_binary_mul(cur_type) && bin_ops_can_be_optimized;
|
||||
|
||||
if (bin_add_and_eltw || bin_mul_and_eltw) {
|
||||
memory::ptr cur_bin_mem_ptr = cur_node.as<data>().get_attached_memory_ptr();
|
||||
if (cur_bin_mem_ptr == nullptr)
|
||||
throw std::runtime_error("OneDNN post-ops optimization error: nonexistent node for bin + eltw");
|
||||
auto& stream = cur_bin_mem_ptr->get_engine()->get_program_stream();
|
||||
mem_lock<float, mem_lock_type::write> bin_and_eltw_lock(cur_bin_mem_ptr, stream);
|
||||
|
||||
size_t cur_bin_mem_size = cur_node.get_output_layout().count();
|
||||
|
||||
// Update all binary coefficients
|
||||
if (bin_add_and_eltw) {
|
||||
for (size_t data_idx = 0; data_idx < cur_bin_mem_size; data_idx++) {
|
||||
bin_and_eltw_lock[data_idx] += beta;
|
||||
}
|
||||
} else {
|
||||
for (size_t data_idx = 0; data_idx < cur_bin_mem_size; data_idx++) {
|
||||
bin_and_eltw_lock[data_idx] *= alpha * scale;
|
||||
}
|
||||
}
|
||||
|
||||
// Marked previous eltwise operation as 'optimized' (it will be ignored on the next iteration of cycle)
|
||||
cur_post_ops[prev_post_op_idx].op_type = onednn_post_op_type::optimized;
|
||||
|
||||
cur_ops_pair_is_optimized = true;
|
||||
}
|
||||
} else if (eltw_and_bin) {
|
||||
dnnl::algorithm alg;
|
||||
dnnl::memory::desc desc;
|
||||
float scale, alpha, beta;
|
||||
|
||||
cldnn::program_node& prev_node = arg.get_dependency(cur_post_ops[prev_post_op_idx].mem_dep);
|
||||
|
||||
p_ops.get_params_eltwise(cur_idx, scale, alg, alpha, beta);
|
||||
p_ops.get_params_binary(prev_idx, alg, desc);
|
||||
|
||||
// Eltwise operations can use runtime non-constant data buffers, so check that memory buffers consist of constant data only
|
||||
auto bin_ops_can_be_optimized = prev_node.is_type<data>() && prev_node.is_constant() &&
|
||||
prev_node.get_users().size() == 1 && desc.data_type() == dnnl_f32;
|
||||
|
||||
auto eltw_and_bin_add = alpha == 1.0f && scale == 1.0f && type_is_binary_add(prev_type) && bin_ops_can_be_optimized;
|
||||
auto eltw_and_bin_mul = beta == 0.f && type_is_binary_mul(prev_type) && bin_ops_can_be_optimized;
|
||||
|
||||
if (eltw_and_bin_add || eltw_and_bin_mul) {
|
||||
memory::ptr prev_bin_mem_ptr = prev_node.as<data>().get_attached_memory_ptr();
|
||||
if (prev_bin_mem_ptr == nullptr)
|
||||
throw std::runtime_error("OneDNN post-ops optimization error: nonexistent node for eltw + bin");
|
||||
auto& stream = prev_bin_mem_ptr->get_engine()->get_program_stream();
|
||||
mem_lock<float, mem_lock_type::write> eltw_and_bin_lock(prev_bin_mem_ptr, stream);
|
||||
|
||||
size_t prev_bin_mem_size = prev_node.get_output_layout().count();
|
||||
|
||||
// Update all binary coefficients
|
||||
if (eltw_and_bin_add) {
|
||||
for (size_t data_idx = 0; data_idx < prev_bin_mem_size; data_idx++) {
|
||||
eltw_and_bin_lock[data_idx] += beta;
|
||||
}
|
||||
} else {
|
||||
for (size_t data_idx = 0; data_idx < prev_bin_mem_size; data_idx++) {
|
||||
eltw_and_bin_lock[data_idx] *= alpha * scale;
|
||||
}
|
||||
}
|
||||
|
||||
// Marked current eltwise operation as 'optimized' (it will be ignored on the next iteration of cycle)
|
||||
cur_post_ops[cur_post_op_idx].op_type = onednn_post_op_type::optimized;
|
||||
|
||||
cur_ops_pair_is_optimized = true;
|
||||
}
|
||||
} else if (sum_and_eltw) {
|
||||
dnnl::algorithm alg;
|
||||
float sum_scale, eltw_scale, alpha, beta;
|
||||
dnnl::memory::data_type data_type;
|
||||
|
||||
dnnl::algorithm next_alg;
|
||||
float next_scale, next_alpha, next_beta;
|
||||
size_t next_idx = cur_idx + 1;
|
||||
size_t next_post_op_idx = cur_post_op_idx + 1;
|
||||
|
||||
bool can_optimize_eltw_and_sum = false;
|
||||
|
||||
if (cur_post_op_idx < post_ops_size - 1) {
|
||||
auto next_type = cur_post_ops[next_post_op_idx].op_type;
|
||||
if (type_is_eltwise_linear(next_type)) {
|
||||
p_ops.get_params_eltwise(next_idx, next_scale, next_alg, next_alpha, next_beta);
|
||||
|
||||
if (next_beta == 0)
|
||||
can_optimize_eltw_and_sum = true;
|
||||
}
|
||||
}
|
||||
|
||||
// Try to optimize eltwise (any) + sum + eltwise_linear (with beta = 0) chain of operations
|
||||
if (can_optimize_eltw_and_sum) {
|
||||
p_ops.get_params_sum(cur_idx, sum_scale, data_type);
|
||||
p_ops.get_params_eltwise(prev_idx, eltw_scale, alg, alpha, beta);
|
||||
|
||||
dnnl::post_ops eltw_p_op_prev, sum_p_op;
|
||||
|
||||
eltw_p_op_prev.append_eltwise(eltw_scale * next_alpha * next_scale, alg, alpha, beta);
|
||||
sum_p_op.append_sum(sum_scale * next_alpha, data_type);
|
||||
|
||||
add_post_op(prev_type, eltw_p_op_prev, optimized_p_ops, 0);
|
||||
add_post_op(cur_type, sum_p_op, optimized_p_ops, 0);
|
||||
|
||||
// Marked current, previous and next operations as 'optimized' (they will be ignored on the next iteration of cycle)
|
||||
cur_post_ops[prev_post_op_idx].op_type = onednn_post_op_type::optimized_eltwise;
|
||||
cur_post_ops[cur_post_op_idx].op_type = onednn_post_op_type::optimized_sum;
|
||||
cur_post_ops[next_post_op_idx].op_type = onednn_post_op_type::optimized;
|
||||
|
||||
// Set the flag if extra optimizations checking is needed
|
||||
if (next_post_op_idx < post_ops_size - 1) {
|
||||
if (type_is_eltwise_linear(cur_post_ops[next_post_op_idx + 1].op_type) ||
|
||||
type_is_optimized_eltwise(cur_post_ops[next_post_op_idx + 1].op_type)) {
|
||||
optimization_is_completed = true;
|
||||
}
|
||||
}
|
||||
|
||||
cur_ops_pair_is_optimized = true;
|
||||
}
|
||||
} else if (eltw_and_scale) {
|
||||
dnnl::algorithm alg;
|
||||
float eltw_scale, alpha, beta;
|
||||
|
||||
cldnn::program_node& prev_node = arg.get_dependency(cur_post_ops[prev_post_op_idx].mem_dep);
|
||||
|
||||
p_ops.get_params_eltwise(cur_idx, eltw_scale, alg, alpha, beta);
|
||||
|
||||
// Eltwise can be inserted into the output_scale if cur_beta is equal to 0.f
|
||||
if (beta == 0.f && prev_node.get_output_layout().data_type == data_types::f32) {
|
||||
memory::ptr prev_scale_mem_ptr = prev_node.as<data>().get_attached_memory_ptr();
|
||||
if (prev_scale_mem_ptr == nullptr)
|
||||
throw std::runtime_error("OneDNN post-ops optimization error: nonexistent node for eltw + scale");
|
||||
auto& stream = prev_scale_mem_ptr->get_engine()->get_program_stream();
|
||||
mem_lock<float, mem_lock_type::write> eltw_and_scale_lock(prev_scale_mem_ptr, stream);
|
||||
|
||||
size_t prev_scale_mem_size = prev_node.get_output_layout().count();
|
||||
|
||||
// Update all scale coefficients
|
||||
for (size_t data_idx = 0; data_idx < prev_scale_mem_size; data_idx++) {
|
||||
eltw_and_scale_lock[data_idx] *= alpha * eltw_scale;
|
||||
}
|
||||
|
||||
// Marked current eltwise operation as 'optimized' (it will be ignored on the next iteration of cycle)
|
||||
cur_post_ops[cur_post_op_idx].op_type = onednn_post_op_type::optimized;
|
||||
|
||||
cur_ops_pair_is_optimized = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If no optimizations have been applied then copy post-op info into the new optimized_p_ops structure
|
||||
if (!(has_out_scales(attr) && prev_post_op_idx == 0) && !cur_ops_pair_is_optimized) {
|
||||
add_post_op(prev_type, p_ops, optimized_p_ops, prev_idx);
|
||||
}
|
||||
|
||||
if (cur_post_op_idx == post_ops_size - 1 && !cur_ops_pair_is_optimized) {
|
||||
add_post_op(cur_type, p_ops, optimized_p_ops, cur_idx);
|
||||
optimization_done = true;
|
||||
} else if (cur_post_ops[cur_post_op_idx].op_type != onednn_post_op_type::optimized) {
|
||||
cur_post_op_idx++;
|
||||
prev_post_op_idx++;
|
||||
}
|
||||
}
|
||||
|
||||
optimization_is_completed = !optimization_is_completed;
|
||||
|
||||
return optimized_p_ops;
|
||||
}
|
||||
|
||||
void configure_post_ops_arguments(typed_primitive_inst<PType>& instance, std::unordered_map<int, dnnl::memory>& args) const {
|
||||
// Get current node id for creating of optimization map
|
||||
auto node_id = instance.id();
|
||||
auto& node = instance.get_node();
|
||||
auto& engine = instance.get_network().get_engine();
|
||||
auto dnnl_engine = engine.get_onednn_engine();
|
||||
|
||||
// Get current post-ops info
|
||||
dnnl::post_ops post_ops = _attrs->get_post_ops();
|
||||
auto onednn_attrs = node.get_onednn_primitive_attributes();
|
||||
dnnl::post_ops post_ops = onednn_attrs->get_post_ops();
|
||||
|
||||
// Create onednn memory buffers for post-ops
|
||||
auto& cur_post_ops = onednn_fusing_map[node_id];
|
||||
auto& cur_post_ops = node.get_fused_primitives_onednn();
|
||||
auto post_ops_size = cur_post_ops.size();
|
||||
for (size_t post_op_idx = 0, num_of_optimized_post_ops = 0; post_op_idx < post_ops_size; post_op_idx++) {
|
||||
auto post_op_type = cur_post_ops[post_op_idx].op_type;
|
||||
auto memory_offset = cur_post_ops[post_op_idx].mem_offset;
|
||||
auto onednn_post_op_idx = has_out_scales(_attrs) && post_op_idx > 0 ? post_op_idx - 1 : post_op_idx;
|
||||
auto onednn_post_op_idx = has_output_scales(onednn_attrs) && post_op_idx > 0 ? post_op_idx - 1 : post_op_idx;
|
||||
onednn_post_op_idx -= num_of_optimized_post_ops;
|
||||
|
||||
switch (post_op_type) {
|
||||
@ -576,296 +160,6 @@ protected:
|
||||
|
||||
void init_kernels() override { }
|
||||
|
||||
static std::shared_ptr<dnnl::primitive_attr> get_primitive_attributes(const typed_program_node<PType>& arg) {
|
||||
const std::vector<fused_primitive_desc>& cldnn_post_ops = arg.get_fused_primitives();
|
||||
auto attrs = std::make_shared<dnnl::primitive_attr>();
|
||||
dnnl::post_ops post_ops;
|
||||
size_t memory_offset = 0;
|
||||
|
||||
// Create onednn post-ops list related to the current node
|
||||
std::vector<onednn_post_op_desc> fused_ops;
|
||||
|
||||
// Added this for debug purposes only
|
||||
size_t empty_mem = 0xff;
|
||||
|
||||
// Add information about post-operation into the list, update indices
|
||||
auto update_onednn_post_op_list = [&](onednn_post_op_type type, size_t m_dep) {
|
||||
onednn_post_op_desc cur_op_desc = { type, memory_offset, m_dep };
|
||||
fused_ops.push_back(cur_op_desc);
|
||||
|
||||
auto has_memory_buffers = type == onednn_post_op_type::binary_add ||
|
||||
type == onednn_post_op_type::binary_mul ||
|
||||
type == onednn_post_op_type::binary_max ||
|
||||
type == onednn_post_op_type::binary_min ||
|
||||
type == onednn_post_op_type::scale ||
|
||||
type == onednn_post_op_type::sum;
|
||||
if (has_memory_buffers)
|
||||
memory_offset++;
|
||||
};
|
||||
|
||||
for (size_t idx = 0; idx < cldnn_post_ops.size(); idx++) {
|
||||
auto node = cldnn_post_ops[idx].node;
|
||||
|
||||
if (node->is_type<activation>()) {
|
||||
auto fused_desc = node->as<activation>().get_primitive();
|
||||
dnnl::algorithm alg = onednn::convert_activation_func(fused_desc->activation_function);
|
||||
post_ops.append_eltwise(1.0f, alg, fused_desc->additional_params.a, fused_desc->additional_params.b);
|
||||
update_onednn_post_op_list(onednn_post_op_type::eltwise_act, empty_mem);
|
||||
} else if (node->is_type<eltwise>()) {
|
||||
auto& e_node = node->as<eltwise>();
|
||||
auto dep_idx = cldnn_post_ops[idx].dep_start_idx;
|
||||
auto in = arg.get_dependency(dep_idx).get_output_layout();
|
||||
|
||||
if (e_node.get_primitive()->mode == eltwise_mode::sum) {
|
||||
if (e_node.get_primitive()->needs_onednn_sum_post_op(in)) {
|
||||
post_ops.append_sum(1.0f, onednn::convert_data_type(in.data_type));
|
||||
update_onednn_post_op_list(onednn_post_op_type::sum, dep_idx);
|
||||
} else {
|
||||
dnnl::memory::desc in_desc = onednn::layout_to_memory_desc(in, dnnl::memory::format_tag::ab, true);
|
||||
post_ops.append_binary(dnnl::algorithm::binary_add, in_desc);
|
||||
update_onednn_post_op_list(onednn_post_op_type::binary_add, dep_idx);
|
||||
}
|
||||
} else {
|
||||
if (in.size.spatial[0] > 1 || in.size.spatial[1] > 1 || in.size.batch[0] > 1)
|
||||
throw std::runtime_error("Unsupported eltwise mode for fused onednn op");
|
||||
if (idx == 0 && !has_out_scales(attrs)) {
|
||||
int mask = in.count() > 1 ? 2 : 0;
|
||||
attrs->set_output_scales(mask, {DNNL_RUNTIME_F32_VAL});
|
||||
update_onednn_post_op_list(onednn_post_op_type::scale, dep_idx);
|
||||
} else {
|
||||
dnnl::memory::desc in_desc = onednn::layout_to_memory_desc(in, dnnl::memory::format_tag::ab, true);
|
||||
post_ops.append_binary(dnnl::algorithm::binary_mul, in_desc);
|
||||
update_onednn_post_op_list(onednn_post_op_type::binary_mul, dep_idx);
|
||||
}
|
||||
}
|
||||
} else if (node->is_type<quantize>()) {
|
||||
auto& q_node = node->as<quantize>();
|
||||
auto dep_idx = cldnn_post_ops[idx].dep_start_idx;
|
||||
|
||||
if (q_node.get_per_tensor_output_range() && q_node.get_output_lo_val() < q_node.get_output_hi_val()) {
|
||||
// 1. pre-scale & pre-shift
|
||||
{
|
||||
if (q_node.get_per_tensor_input_scale() && q_node.get_per_tensor_input_shift()) {
|
||||
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_input_scale_val(), q_node.get_input_shift_val());
|
||||
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
|
||||
} else {
|
||||
if (q_node.get_per_tensor_input_scale()) {
|
||||
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_input_scale_val(), 0.0f);
|
||||
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
|
||||
} else {
|
||||
auto in_scale = arg.get_dependency(dep_idx++).get_output_layout();
|
||||
if (idx == 0 && !has_out_scales(attrs) && in_scale.data_type == data_types::f32 &&
|
||||
arg.type() == convolution::type_id() &&
|
||||
!data_type_traits::is_floating_point(arg.get_dependency(0).get_output_layout().data_type)) {
|
||||
int mask = in_scale.count() > 1 ? 2 : 0;
|
||||
attrs->set_output_scales(mask, {DNNL_RUNTIME_F32_VAL});
|
||||
update_onednn_post_op_list(onednn_post_op_type::scale, dep_idx - 1);
|
||||
} else {
|
||||
dnnl::memory::desc in_scale_desc = onednn::layout_to_memory_desc(in_scale, dnnl::memory::format_tag::ab, true);
|
||||
post_ops.append_binary(dnnl::algorithm::binary_mul, in_scale_desc);
|
||||
update_onednn_post_op_list(onednn_post_op_type::binary_mul, dep_idx - 1);
|
||||
}
|
||||
}
|
||||
|
||||
if (q_node.get_need_pre_shift()) {
|
||||
if (q_node.get_per_tensor_input_shift()) {
|
||||
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, 1.0f, q_node.get_input_shift_val());
|
||||
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
|
||||
} else {
|
||||
auto in_shift = arg.get_dependency(dep_idx++).get_output_layout();
|
||||
dnnl::memory::desc in_shift_desc = onednn::layout_to_memory_desc(in_shift, dnnl::memory::format_tag::ab, true);
|
||||
post_ops.append_binary(dnnl::algorithm::binary_add, in_shift_desc);
|
||||
update_onednn_post_op_list(onednn_post_op_type::binary_add, dep_idx - 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. round
|
||||
auto out_dt = cldnn_post_ops[idx].output_layout.data_type;
|
||||
bool output_type_is_int8 = out_dt == data_types::u8 || out_dt == data_types::i8;
|
||||
if (!output_type_is_int8) {
|
||||
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_round, 0.0f, 0.0f);
|
||||
update_onednn_post_op_list(onednn_post_op_type::eltwise_round, empty_mem);
|
||||
}
|
||||
|
||||
// 3. post-scale & post-shift
|
||||
if (q_node.get_need_post_scale() && q_node.get_need_post_shift() &&
|
||||
q_node.get_per_tensor_output_scale() && q_node.get_per_tensor_output_shift()) {
|
||||
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_output_scale_val(), q_node.get_output_shift_val());
|
||||
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
|
||||
} else {
|
||||
if (q_node.get_need_post_scale()) {
|
||||
if (q_node.get_per_tensor_output_scale()) {
|
||||
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_output_scale_val(), 0.0f);
|
||||
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
|
||||
} else {
|
||||
auto out_scale = arg.get_dependency(dep_idx++).get_output_layout();
|
||||
dnnl::memory::desc out_scale_desc = onednn::layout_to_memory_desc(out_scale, dnnl::memory::format_tag::ab, true);
|
||||
post_ops.append_binary(dnnl::algorithm::binary_mul, out_scale_desc);
|
||||
update_onednn_post_op_list(onednn_post_op_type::binary_mul, dep_idx - 1);
|
||||
}
|
||||
}
|
||||
|
||||
if (q_node.get_need_post_shift()) {
|
||||
if (q_node.get_per_tensor_output_shift()) {
|
||||
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, 1.0f, q_node.get_output_shift_val());
|
||||
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
|
||||
} else {
|
||||
auto out_shift = arg.get_dependency(dep_idx++).get_output_layout();
|
||||
dnnl::memory::desc out_shift_desc = onednn::layout_to_memory_desc(out_shift, dnnl::memory::format_tag::ab, true);
|
||||
post_ops.append_binary(dnnl::algorithm::binary_add, out_shift_desc);
|
||||
update_onednn_post_op_list(onednn_post_op_type::binary_add, dep_idx - 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 4. clamp
|
||||
if (q_node.get_need_clamp()) {
|
||||
float out_lo = q_node.get_need_min_clamp() ? q_node.get_output_lo_val() : data_type_traits::min<float>(out_dt);
|
||||
float out_hi = q_node.get_need_max_clamp() ? q_node.get_output_hi_val() : data_type_traits::max<float>(out_dt);
|
||||
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_clip, out_lo, out_hi);
|
||||
update_onednn_post_op_list(onednn_post_op_type::eltwise_clip, empty_mem);
|
||||
}
|
||||
} else {
|
||||
// 1. clamp
|
||||
if (q_node.get_need_clamp()) {
|
||||
auto in_lo = arg.get_dependency(dep_idx++).get_output_layout();
|
||||
auto in_hi = arg.get_dependency(dep_idx++).get_output_layout();
|
||||
dnnl::algorithm clamp_max = dnnl::algorithm::binary_max;
|
||||
dnnl::algorithm clamp_min = dnnl::algorithm::binary_min;
|
||||
dnnl::memory::desc in_lo_desc = onednn::layout_to_memory_desc(in_lo, dnnl::memory::format_tag::ab, true);
|
||||
dnnl::memory::desc in_hi_desc = onednn::layout_to_memory_desc(in_hi, dnnl::memory::format_tag::ab, true);
|
||||
|
||||
post_ops.append_binary(clamp_max, in_lo_desc);
|
||||
update_onednn_post_op_list(onednn_post_op_type::binary_max, dep_idx - 2);
|
||||
post_ops.append_binary(clamp_min, in_hi_desc);
|
||||
update_onednn_post_op_list(onednn_post_op_type::binary_min, dep_idx - 1);
|
||||
}
|
||||
|
||||
// 2. pre-scale & pre-shift
|
||||
{
|
||||
if (q_node.get_per_tensor_input_scale() && q_node.get_per_tensor_input_shift()) {
|
||||
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_input_scale_val(), q_node.get_input_shift_val());
|
||||
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
|
||||
} else {
|
||||
if (q_node.get_per_tensor_input_scale()) {
|
||||
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_input_scale_val(), 0.0f);
|
||||
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
|
||||
} else {
|
||||
auto in_scale = arg.get_dependency(dep_idx++).get_output_layout();
|
||||
if (idx == 0 && !q_node.get_need_clamp() && !has_out_scales(attrs) && in_scale.data_type == data_types::f32 &&
|
||||
arg.type() == convolution::type_id() &&
|
||||
!data_type_traits::is_floating_point(arg.get_dependency(0).get_output_layout().data_type)) {
|
||||
int mask = in_scale.count() > 1 ? 2 : 0;
|
||||
attrs->set_output_scales(mask, {DNNL_RUNTIME_F32_VAL});
|
||||
update_onednn_post_op_list(onednn_post_op_type::scale, dep_idx - 1);
|
||||
} else {
|
||||
dnnl::memory::desc in_scale_desc = onednn::layout_to_memory_desc(in_scale, dnnl::memory::format_tag::ab, true);
|
||||
post_ops.append_binary(dnnl::algorithm::binary_mul, in_scale_desc);
|
||||
update_onednn_post_op_list(onednn_post_op_type::binary_mul, dep_idx - 1);
|
||||
}
|
||||
}
|
||||
|
||||
if (q_node.get_need_pre_shift()) {
|
||||
if (q_node.get_per_tensor_input_shift()) {
|
||||
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, 1.0f, q_node.get_input_shift_val());
|
||||
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
|
||||
} else {
|
||||
auto in_shift = arg.get_dependency(dep_idx++).get_output_layout();
|
||||
dnnl::memory::desc in_shift_desc = onednn::layout_to_memory_desc(in_shift, dnnl::memory::format_tag::ab, true);
|
||||
post_ops.append_binary(dnnl::algorithm::binary_add, in_shift_desc);
|
||||
update_onednn_post_op_list(onednn_post_op_type::binary_add, dep_idx - 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. round
|
||||
{
|
||||
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_round, 0.0f, 0.0f);
|
||||
update_onednn_post_op_list(onednn_post_op_type::eltwise_round, empty_mem);
|
||||
}
|
||||
|
||||
// 4. post-scale & post-shift
|
||||
if (q_node.get_need_post_scale() && q_node.get_need_post_shift() &&
|
||||
q_node.get_per_tensor_output_scale() && q_node.get_per_tensor_output_shift()) {
|
||||
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_output_scale_val(), q_node.get_output_shift_val());
|
||||
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
|
||||
} else {
|
||||
if (q_node.get_need_post_scale()) {
|
||||
if (q_node.get_per_tensor_output_scale()) {
|
||||
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_output_scale_val(), 0.0f);
|
||||
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
|
||||
} else {
|
||||
auto out_scale = arg.get_dependency(dep_idx++).get_output_layout();
|
||||
dnnl::memory::desc out_scale_desc = onednn::layout_to_memory_desc(out_scale, dnnl::memory::format_tag::ab, true);
|
||||
post_ops.append_binary(dnnl::algorithm::binary_mul, out_scale_desc);
|
||||
update_onednn_post_op_list(onednn_post_op_type::binary_mul, dep_idx - 1);
|
||||
}
|
||||
}
|
||||
|
||||
if (q_node.get_need_post_shift()) {
|
||||
if (q_node.get_per_tensor_output_shift()) {
|
||||
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, 1.0f, q_node.get_output_shift_val());
|
||||
update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
|
||||
} else {
|
||||
auto out_shift = arg.get_dependency(dep_idx++).get_output_layout();
|
||||
dnnl::memory::desc out_shift_desc = onednn::layout_to_memory_desc(out_shift, dnnl::memory::format_tag::ab, true);
|
||||
post_ops.append_binary(dnnl::algorithm::binary_add, out_shift_desc);
|
||||
update_onednn_post_op_list(onednn_post_op_type::binary_add, dep_idx - 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (node->is_type<reorder>()) {
|
||||
continue;
|
||||
} else {
|
||||
throw std::runtime_error("Unsupported fused op of " + node->get_primitive()->type_string() + " type for oneDNN primitive");
|
||||
}
|
||||
}
|
||||
|
||||
if (cldnn_post_ops.size() && arg.get_fused_activations_funcs().size())
|
||||
throw std::runtime_error("Unsupported mix of fused ops and activations");
|
||||
|
||||
for (size_t i = 0; i < arg.get_fused_activations_funcs().size(); i++) {
|
||||
auto activation_type = arg.get_fused_activations_funcs()[i];
|
||||
auto params = arg.get_fused_activations_params()[i];
|
||||
dnnl::algorithm alg = onednn::convert_activation_func(activation_type);
|
||||
post_ops.append_eltwise(1.0f, alg, params.a, params.b);
|
||||
update_onednn_post_op_list(onednn_post_op_type::eltwise_act, empty_mem);
|
||||
}
|
||||
|
||||
// Update total onednn post-ops info
|
||||
auto it = onednn_fusing_map.find(arg.id());
|
||||
if (it != onednn_fusing_map.end()) {
|
||||
it->second = std::move(fused_ops);
|
||||
} else {
|
||||
onednn_fusing_map.emplace(arg.id(), std::move(fused_ops));
|
||||
}
|
||||
|
||||
// Trying to optimize more than 1 post-ops
|
||||
auto post_ops_size = onednn_fusing_map[arg.id()].size();
|
||||
|
||||
if (post_ops_size > 1) {
|
||||
dnnl::post_ops optimized_post_ops = post_ops;
|
||||
bool optimization_is_finished = false;
|
||||
|
||||
// Trying to combine multiplications and additions which are placed one after another.
|
||||
// We do it in the cycle because "eltw + eltw" cases can be simplified again in some cases.
|
||||
do {
|
||||
optimized_post_ops = try_optimize_post_ops(arg, optimized_post_ops, attrs, optimization_is_finished);
|
||||
} while (!optimization_is_finished);
|
||||
|
||||
attrs->set_post_ops(optimized_post_ops);
|
||||
} else {
|
||||
// Set post-ops without any optimizations
|
||||
attrs->set_post_ops(post_ops);
|
||||
}
|
||||
|
||||
return attrs;
|
||||
}
|
||||
|
||||
event::ptr aggregate_events(const std::vector<event::ptr>& events, stream& stream, bool group = false, bool is_output = false) const {
|
||||
if (events.size() == 1 && !is_output)
|
||||
return events[0];
|
||||
|
@ -55,13 +55,13 @@ protected:
|
||||
input_md,
|
||||
engine.get_onednn_engine(),
|
||||
output_md,
|
||||
*get_primitive_attributes(arg));
|
||||
*(arg.get_onednn_primitive_attributes()));
|
||||
}
|
||||
|
||||
public:
|
||||
static primitive_impl* create(const reorder_node& arg) {
|
||||
auto desc = get_reorder_descriptor(arg);
|
||||
auto attr = get_primitive_attributes(arg);
|
||||
auto attr = arg.get_onednn_primitive_attributes();
|
||||
|
||||
std::shared_ptr<void> dummy = nullptr;
|
||||
|
||||
|
@ -393,4 +393,10 @@ private:
|
||||
void run(program& p) override;
|
||||
};
|
||||
|
||||
class add_onednn_optimization_attributes : public base_pass {
|
||||
public:
|
||||
add_onednn_optimization_attributes() : base_pass("add_onednn_optimization_attributes") {}
|
||||
void run(program& p) override;
|
||||
};
|
||||
|
||||
} // namespace cldnn
|
||||
|
@ -33,6 +33,29 @@ struct typed_program_node;
|
||||
class json_composite;
|
||||
class xml_composite;
|
||||
|
||||
#ifdef ENABLE_ONEDNN_FOR_GPU
|
||||
enum class onednn_post_op_type : uint32_t {
|
||||
eltwise_act,
|
||||
eltwise_clip,
|
||||
eltwise_linear,
|
||||
eltwise_round,
|
||||
binary_mul,
|
||||
binary_add,
|
||||
binary_max,
|
||||
binary_min,
|
||||
scale,
|
||||
sum,
|
||||
optimized,
|
||||
optimized_eltwise,
|
||||
optimized_sum
|
||||
};
|
||||
|
||||
struct fused_primitive_desc_onednn {
|
||||
onednn_post_op_type op_type; // onednn post-operation type
|
||||
size_t mem_offset; // index of a memory buffer for current post-operation
|
||||
size_t mem_dep; // memory dependency for working with fused node
|
||||
};
|
||||
#endif // ENABLE_ONEDNN_FOR_GPU
|
||||
|
||||
struct fused_primitive_desc {
|
||||
std::shared_ptr<program_node> node;
|
||||
@ -57,7 +80,7 @@ struct fused_primitive_desc {
|
||||
to API level where all primitives store only ids of related ones.
|
||||
*/
|
||||
struct program_node {
|
||||
friend struct program; // to be removed when possible
|
||||
friend struct program; // to be removed when possible
|
||||
friend class compile_graph; // to be removed when possible
|
||||
friend class graph_initializations; // to be removed when possible
|
||||
friend class pre_replace_deconv; // to be removed when possible
|
||||
@ -293,6 +316,16 @@ public:
|
||||
const std::vector<fused_primitive_desc>& get_fused_primitives() const { return fused_prims; }
|
||||
std::vector<fused_primitive_desc>& get_fused_primitives() { return fused_prims; }
|
||||
|
||||
#ifdef ENABLE_ONEDNN_FOR_GPU
|
||||
const std::shared_ptr<dnnl::primitive_attr>& get_onednn_primitive_attributes() const { return onednn_attrs; }
|
||||
std::shared_ptr<dnnl::primitive_attr>& get_onednn_primitive_attributes() { return onednn_attrs; }
|
||||
|
||||
const std::vector<fused_primitive_desc_onednn>& get_fused_primitives_onednn() const { return fused_prims_onednn; }
|
||||
std::vector<fused_primitive_desc_onednn>& get_fused_primitives_onednn() { return fused_prims_onednn; }
|
||||
|
||||
void init_onednn_primitive_attributes();
|
||||
#endif // ENABLE_ONEDNN_FOR_GPU
|
||||
|
||||
size_t get_fused_inputs_count() const {
|
||||
size_t count = 0;
|
||||
for (auto& fp : get_fused_primitives()) {
|
||||
@ -360,7 +393,26 @@ protected:
|
||||
|
||||
std::vector<fused_activation_params> fused_activations;
|
||||
std::vector<fused_primitive_desc> fused_prims;
|
||||
|
||||
void invalidate_users() const;
|
||||
|
||||
private:
|
||||
#ifdef ENABLE_ONEDNN_FOR_GPU
|
||||
std::vector<fused_primitive_desc_onednn> fused_prims_onednn;
|
||||
std::shared_ptr<dnnl::primitive_attr> onednn_attrs;
|
||||
|
||||
void add_onednn_fused_primitives(std::vector<fused_primitive_desc_onednn> descs) {
|
||||
fused_prims_onednn.erase(fused_prims_onednn.begin(), fused_prims_onednn.end());
|
||||
fused_prims_onednn.insert(fused_prims_onednn.end(), descs.begin(), descs.end());
|
||||
}
|
||||
|
||||
void add_onednn_attrs(std::shared_ptr<dnnl::primitive_attr> attrs) {
|
||||
onednn_attrs = attrs;
|
||||
}
|
||||
|
||||
bool has_out_scales(const std::shared_ptr<dnnl::primitive_attr>& attr);
|
||||
dnnl::post_ops try_optimize_post_ops(dnnl::post_ops& p_ops, const std::shared_ptr<dnnl::primitive_attr>& attr, bool& optimization_is_completed);
|
||||
#endif // ENABLE_ONEDNN_FOR_GPU
|
||||
};
|
||||
|
||||
/*
|
||||
|
@ -534,6 +534,9 @@ void program::pre_optimize_graph(bool is_internal) {
|
||||
|
||||
// check if there exists some layout incompatibilities and add an reorder node if required
|
||||
apply_opt_pass<add_required_reorders>();
|
||||
|
||||
// add optimization attributes for onednn primitives
|
||||
apply_opt_pass<add_onednn_optimization_attributes>();
|
||||
}
|
||||
|
||||
void program::post_optimize_graph(bool is_internal) {
|
||||
|
@ -5,6 +5,14 @@
|
||||
#include "program_node.h"
|
||||
#include "cldnn/graph/program.hpp"
|
||||
#include "primitive_inst.h"
|
||||
|
||||
#ifdef ENABLE_ONEDNN_FOR_GPU
|
||||
#include "convolution_inst.h"
|
||||
#include "quantize_inst.h"
|
||||
#include "reorder_inst.h"
|
||||
#include <impls/onednn/utils.hpp>
|
||||
#endif // ENABLE_ONEDNN_FOR_GPU
|
||||
|
||||
#include "to_string_utils.h"
|
||||
#include "json_object.h"
|
||||
#include <vector>
|
||||
@ -280,3 +288,699 @@ bool program_node::need_lockable_memory() const {
|
||||
|
||||
return need_lockable_mem;
|
||||
}
|
||||
|
||||
/* ----------------------------------------- */
|
||||
/* Onednn fused operations integration logic */
|
||||
/* ----------------------------------------- */
|
||||
|
||||
#ifdef ENABLE_ONEDNN_FOR_GPU
|
||||
|
||||
bool program_node::has_out_scales(const std::shared_ptr<dnnl::primitive_attr>& attr) {
|
||||
int mask;
|
||||
std::vector<float> scales;
|
||||
attr->get_output_scales(mask, scales);
|
||||
const auto drfv = reinterpret_cast<const int32_t&>(DNNL_RUNTIME_F32_VAL);
|
||||
return !scales.empty() && (reinterpret_cast<const int32_t&>(scales[0]) == drfv);
|
||||
}
|
||||
|
||||
dnnl::post_ops program_node::try_optimize_post_ops(dnnl::post_ops& p_ops, const std::shared_ptr<dnnl::primitive_attr>& attr,
|
||||
bool& optimization_is_completed) {
|
||||
// Create new dnnl::post_ops object which will be filled inside the optimization process
|
||||
dnnl::post_ops optimized_p_ops;
|
||||
|
||||
// Add new post-op into optimized_p_ops structure
|
||||
auto add_post_op = [&](onednn_post_op_type type, const dnnl::post_ops& cur_p_ops, dnnl::post_ops& new_p_ops, int idx) {
|
||||
switch (type) {
|
||||
case onednn_post_op_type::eltwise_act:
|
||||
case onednn_post_op_type::eltwise_clip:
|
||||
case onednn_post_op_type::eltwise_linear:
|
||||
case onednn_post_op_type::eltwise_round:
|
||||
{
|
||||
dnnl::algorithm alg;
|
||||
float scale, alpha, beta;
|
||||
cur_p_ops.get_params_eltwise(idx, scale, alg, alpha, beta);
|
||||
new_p_ops.append_eltwise(scale, alg, alpha, beta);
|
||||
break;
|
||||
}
|
||||
|
||||
case onednn_post_op_type::binary_add:
|
||||
case onednn_post_op_type::binary_mul:
|
||||
case onednn_post_op_type::binary_max:
|
||||
case onednn_post_op_type::binary_min:
|
||||
{
|
||||
dnnl::algorithm alg;
|
||||
dnnl::memory::desc desc;
|
||||
cur_p_ops.get_params_binary(idx, alg, desc);
|
||||
new_p_ops.append_binary(alg, desc);
|
||||
break;
|
||||
}
|
||||
|
||||
case onednn_post_op_type::scale:
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
case onednn_post_op_type::sum:
|
||||
{
|
||||
float scale;
|
||||
dnnl::memory::data_type data_type;
|
||||
cur_p_ops.get_params_sum(idx, scale, data_type);
|
||||
new_p_ops.append_sum(scale, data_type);
|
||||
break;
|
||||
}
|
||||
|
||||
case onednn_post_op_type::optimized:
|
||||
case onednn_post_op_type::optimized_sum:
|
||||
case onednn_post_op_type::optimized_eltwise:
|
||||
{
|
||||
// Current operation already has been optimized => don't need extra actions
|
||||
break;
|
||||
}
|
||||
|
||||
default:
|
||||
throw std::runtime_error("Unsupported onednn post-operation type");
|
||||
}
|
||||
};
|
||||
|
||||
// Check that post-op type is any optimized
|
||||
auto type_is_any_optimized = [](onednn_post_op_type type) -> bool {
|
||||
return type == onednn_post_op_type::optimized || type == onednn_post_op_type::optimized_sum ||
|
||||
type == onednn_post_op_type::optimized_eltwise;
|
||||
};
|
||||
|
||||
// Check that post-op type is eltwise
|
||||
auto type_is_eltwise = [](onednn_post_op_type type) -> bool {
|
||||
return type == onednn_post_op_type::eltwise_round || type == onednn_post_op_type::eltwise_linear ||
|
||||
type == onednn_post_op_type::eltwise_clip || type == onednn_post_op_type::eltwise_act;
|
||||
};
|
||||
|
||||
// Check that post-op type is binary_add or binary_mul
|
||||
auto type_is_binary_add_or_mul = [](onednn_post_op_type type) -> bool {
|
||||
return type == onednn_post_op_type::binary_add || type == onednn_post_op_type::binary_mul;
|
||||
};
|
||||
|
||||
// Simple post-op type checks
|
||||
auto type_is_optimized = [](onednn_post_op_type type) -> bool { return type == onednn_post_op_type::optimized; };
|
||||
auto type_is_eltwise_linear = [](onednn_post_op_type type) -> bool { return type == onednn_post_op_type::eltwise_linear; };
|
||||
auto type_is_optimized_eltwise = [](onednn_post_op_type type) -> bool { return type == onednn_post_op_type::optimized_eltwise; };
|
||||
auto type_is_binary_add = [](onednn_post_op_type type) -> bool { return type == onednn_post_op_type::binary_add; };
|
||||
auto type_is_binary_mul = [](onednn_post_op_type type) -> bool { return type == onednn_post_op_type::binary_mul; };
|
||||
auto type_is_sum = [](onednn_post_op_type type) -> bool { return type == onednn_post_op_type::sum; };
|
||||
auto type_is_optimized_sum = [](onednn_post_op_type type) -> bool { return type == onednn_post_op_type::optimized_sum; };
|
||||
auto type_is_scale = [](onednn_post_op_type type) -> bool { return type == onednn_post_op_type::scale; };
|
||||
|
||||
auto& cur_post_ops = get_fused_primitives_onednn();
|
||||
|
||||
size_t cur_post_op_idx = 1;
|
||||
size_t prev_post_op_idx = 0;
|
||||
bool optimization_done = false;
|
||||
|
||||
// Check and update post-op map if we already optimized something
|
||||
for (size_t post_op_idx = 0; post_op_idx < cur_post_ops.size(); post_op_idx++) {
|
||||
if (type_is_optimized_sum(cur_post_ops[post_op_idx].op_type))
|
||||
cur_post_ops[post_op_idx].op_type = onednn_post_op_type::sum;
|
||||
else if (type_is_optimized_eltwise(cur_post_ops[post_op_idx].op_type))
|
||||
cur_post_ops[post_op_idx].op_type = onednn_post_op_type::eltwise_linear;
|
||||
else if (type_is_optimized(cur_post_ops[post_op_idx].op_type))
|
||||
cur_post_ops.erase(cur_post_ops.begin() + post_op_idx);
|
||||
}
|
||||
|
||||
// Get post-ops size for current node
|
||||
auto post_ops_size = cur_post_ops.size();
|
||||
|
||||
// Try to combine pairs of arithmetic post-ops (adds and muls) into one operation inside this cycle
|
||||
while (!optimization_done) {
|
||||
auto cur_type = cur_post_ops[cur_post_op_idx].op_type;
|
||||
auto prev_type = cur_post_ops[prev_post_op_idx].op_type;
|
||||
|
||||
// Ignore optimized operations for "previous" operation in our operation pair
|
||||
while (type_is_any_optimized(prev_type) && cur_post_op_idx < post_ops_size - 1) {
|
||||
prev_post_op_idx++;
|
||||
cur_post_op_idx++;
|
||||
prev_type = cur_post_ops[prev_post_op_idx].op_type;
|
||||
cur_type = cur_post_ops[cur_post_op_idx].op_type;
|
||||
}
|
||||
|
||||
// Ignore optimized operations for "current" operation in our operation pair
|
||||
while (type_is_any_optimized(cur_type) && cur_post_op_idx < post_ops_size - 1) {
|
||||
cur_post_op_idx++;
|
||||
cur_type = cur_post_ops[cur_post_op_idx].op_type;
|
||||
}
|
||||
|
||||
auto cur_idx = static_cast<int>(has_out_scales(attr) ? (cur_post_op_idx >= 1 ? cur_post_op_idx - 1 : 0) : cur_post_op_idx);
|
||||
auto prev_idx = static_cast<int>(has_out_scales(attr) ? (prev_post_op_idx >= 1 ? prev_post_op_idx - 1 : 0) : prev_post_op_idx);
|
||||
|
||||
// If this is the last pair and it's optimized - add the last post-op and go out from the cycle
|
||||
if (cur_post_op_idx == post_ops_size - 1 && (type_is_any_optimized(cur_type) || type_is_any_optimized(prev_type))) {
|
||||
if (!type_is_any_optimized(prev_type)) {
|
||||
add_post_op(prev_type, p_ops, optimized_p_ops, prev_idx);
|
||||
}
|
||||
if (!type_is_any_optimized(cur_type)) {
|
||||
add_post_op(cur_type, p_ops, optimized_p_ops, cur_idx);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
// Post-ops combinations which can be simplified
|
||||
auto eltw_and_eltw = type_is_eltwise(cur_type) && type_is_eltwise(prev_type);
|
||||
auto bin_and_eltw = type_is_binary_add_or_mul(cur_type) && type_is_eltwise_linear(prev_type);
|
||||
auto eltw_and_bin = type_is_eltwise_linear(cur_type) && type_is_binary_add_or_mul(prev_type);
|
||||
auto sum_and_eltw = type_is_sum(cur_type) && type_is_eltwise(prev_type);
|
||||
auto eltw_and_scale = type_is_eltwise_linear(cur_type) && type_is_scale(prev_type);
|
||||
|
||||
auto can_try_optimize = eltw_and_eltw ||
|
||||
bin_and_eltw ||
|
||||
eltw_and_bin ||
|
||||
sum_and_eltw ||
|
||||
eltw_and_scale;
|
||||
|
||||
bool cur_ops_pair_is_optimized = false;
|
||||
|
||||
if (can_try_optimize) {
|
||||
if (eltw_and_eltw) {
|
||||
dnnl::algorithm cur_alg, prev_alg;
|
||||
float cur_scale, prev_scale, cur_alpha, prev_alpha, cur_beta, prev_beta;
|
||||
|
||||
p_ops.get_params_eltwise(prev_idx, prev_scale, prev_alg, prev_alpha, prev_beta);
|
||||
p_ops.get_params_eltwise(cur_idx, cur_scale, cur_alg, cur_alpha, cur_beta);
|
||||
|
||||
auto eltw_linear_and_eltw_linear = type_is_eltwise_linear(cur_type) && type_is_eltwise_linear(prev_type);
|
||||
auto eltw_linear_and_eltw_non_linear = type_is_eltwise_linear(cur_type) && !type_is_eltwise_linear(prev_type) && cur_beta == 0;
|
||||
|
||||
// eltwise_linear + eltwise_linear combination can be optimized always
|
||||
if (eltw_linear_and_eltw_linear) {
|
||||
dnnl::post_ops eltw_p_op;
|
||||
float optimized_alpha = cur_alpha * prev_alpha * prev_scale;
|
||||
float optimized_beta = cur_alpha * prev_beta * prev_scale + cur_beta;
|
||||
float optimized_scale = cur_scale;
|
||||
eltw_p_op.append_eltwise(optimized_scale, cur_alg, optimized_alpha, optimized_beta);
|
||||
|
||||
// Combine 2 eltwises into one
|
||||
add_post_op(cur_type, eltw_p_op, optimized_p_ops, 0);
|
||||
} else if (eltw_linear_and_eltw_non_linear) {
|
||||
                    dnnl::post_ops eltw_p_op;
                    eltw_p_op.append_eltwise(cur_scale * prev_scale * cur_alpha, prev_alg, prev_alpha, prev_beta);

                    // Combine 2 eltwises into one
                    add_post_op(prev_type, eltw_p_op, optimized_p_ops, 0);
                }

                if (eltw_linear_and_eltw_linear || eltw_linear_and_eltw_non_linear) {
                    // Mark current and previous eltwise operations as 'optimized' (they will be ignored on the next iteration of the cycle)
                    cur_post_ops[cur_post_op_idx].op_type = onednn_post_op_type::optimized;
                    cur_post_ops[prev_post_op_idx].op_type = onednn_post_op_type::optimized_eltwise;

                    // Set the flag if extra optimization checking is needed
                    if (cur_post_op_idx < post_ops_size - 1) {
                        if (type_is_eltwise_linear(cur_post_ops[cur_post_op_idx + 1].op_type) ||
                            type_is_binary_add_or_mul(cur_post_ops[cur_post_op_idx + 1].op_type) ||
                            type_is_optimized_eltwise(cur_post_ops[cur_post_op_idx + 1].op_type)) {
                            optimization_is_completed = true;
                        }
                    }

                    cur_ops_pair_is_optimized = true;
                }
            } else if (bin_and_eltw) {
                dnnl::algorithm alg;
                dnnl::memory::desc desc;
                float scale, alpha, beta;

                cldnn::program_node& cur_node = get_dependency(cur_post_ops[cur_post_op_idx].mem_dep);

                p_ops.get_params_binary(cur_idx, alg, desc);
                p_ops.get_params_eltwise(prev_idx, scale, alg, alpha, beta);

                // Eltwise operations can use runtime non-constant data buffers, so check that memory buffers consist of constant data only
                auto bin_ops_can_be_optimized = cur_node.is_type<data>() && cur_node.is_constant() &&
                                                cur_node.get_users().size() == 1 && desc.data_type() == dnnl_f32;

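                // With alpha == 1 and scale == 1 the preceding eltwise_linear reduces to '+ beta' (foldable into a constant binary_add operand);
                // with beta == 0 it reduces to '* alpha * scale' (foldable into a constant binary_mul operand)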
                auto bin_add_and_eltw = alpha == 1.0f && scale == 1.0f && type_is_binary_add(cur_type) && bin_ops_can_be_optimized;
                auto bin_mul_and_eltw = beta == 0.f && type_is_binary_mul(cur_type) && bin_ops_can_be_optimized;

                if (bin_add_and_eltw || bin_mul_and_eltw) {
                    memory::ptr cur_bin_mem_ptr = cur_node.as<data>().get_attached_memory_ptr();
                    if (cur_bin_mem_ptr == nullptr)
                        throw std::runtime_error("OneDNN post-ops optimization error: nonexistent node for bin + eltw");
                    auto& stream = cur_bin_mem_ptr->get_engine()->get_program_stream();
                    mem_lock<float, mem_lock_type::read_write> bin_and_eltw_lock(cur_bin_mem_ptr, stream);

                    size_t cur_bin_mem_size = cur_node.get_output_layout().count();

                    // Update all binary coefficients
                    if (bin_add_and_eltw) {
                        for (size_t data_idx = 0; data_idx < cur_bin_mem_size; data_idx++) {
                            bin_and_eltw_lock[data_idx] += beta;
                        }
                    } else {
                        for (size_t data_idx = 0; data_idx < cur_bin_mem_size; data_idx++) {
                            bin_and_eltw_lock[data_idx] *= alpha * scale;
                        }
                    }

                    // Mark previous eltwise operation as 'optimized' (it will be ignored on the next iteration of the cycle)
                    cur_post_ops[prev_post_op_idx].op_type = onednn_post_op_type::optimized;

                    cur_ops_pair_is_optimized = true;
                }
            } else if (eltw_and_bin) {
                dnnl::algorithm alg;
                dnnl::memory::desc desc;
                float scale, alpha, beta;

                cldnn::program_node& prev_node = get_dependency(cur_post_ops[prev_post_op_idx].mem_dep);

                p_ops.get_params_eltwise(cur_idx, scale, alg, alpha, beta);
                p_ops.get_params_binary(prev_idx, alg, desc);

                // Eltwise operations can use runtime non-constant data buffers, so check that memory buffers consist of constant data only
                auto bin_ops_can_be_optimized = prev_node.is_type<data>() && prev_node.is_constant() &&
                                                prev_node.get_users().size() == 1 && desc.data_type() == dnnl_f32;

                auto eltw_and_bin_add = alpha == 1.0f && scale == 1.0f && type_is_binary_add(prev_type) && bin_ops_can_be_optimized;
                auto eltw_and_bin_mul = beta == 0.f && type_is_binary_mul(prev_type) && bin_ops_can_be_optimized;

                if (eltw_and_bin_add || eltw_and_bin_mul) {
                    memory::ptr prev_bin_mem_ptr = prev_node.as<data>().get_attached_memory_ptr();
                    if (prev_bin_mem_ptr == nullptr)
                        throw std::runtime_error("OneDNN post-ops optimization error: nonexistent node for eltw + bin");
                    auto& stream = prev_bin_mem_ptr->get_engine()->get_program_stream();
                    mem_lock<float, mem_lock_type::read_write> eltw_and_bin_lock(prev_bin_mem_ptr, stream);

                    size_t prev_bin_mem_size = prev_node.get_output_layout().count();

                    // Update all binary coefficients
                    if (eltw_and_bin_add) {
                        for (size_t data_idx = 0; data_idx < prev_bin_mem_size; data_idx++) {
                            eltw_and_bin_lock[data_idx] += beta;
                        }
                    } else {
                        for (size_t data_idx = 0; data_idx < prev_bin_mem_size; data_idx++) {
                            eltw_and_bin_lock[data_idx] *= alpha * scale;
                        }
                    }

                    // Mark current eltwise operation as 'optimized' (it will be ignored on the next iteration of the cycle)
                    cur_post_ops[cur_post_op_idx].op_type = onednn_post_op_type::optimized;

                    cur_ops_pair_is_optimized = true;
                }
            } else if (sum_and_eltw) {
                dnnl::algorithm alg;
                float sum_scale, eltw_scale, alpha, beta;
                dnnl::memory::data_type data_type;

                dnnl::algorithm next_alg;
                float next_scale, next_alpha, next_beta;
                size_t next_idx = cur_idx + 1;
                size_t next_post_op_idx = cur_post_op_idx + 1;

                bool can_optimize_eltw_and_sum = false;

                if (cur_post_op_idx < post_ops_size - 1) {
                    auto next_type = cur_post_ops[next_post_op_idx].op_type;
                    if (type_is_eltwise_linear(next_type)) {
                        p_ops.get_params_eltwise(next_idx, next_scale, next_alg, next_alpha, next_beta);

                        if (next_beta == 0)
                            can_optimize_eltw_and_sum = true;
                    }
                }

                // Try to optimize eltwise (any) + sum + eltwise_linear (with beta = 0) chain of operations
                if (can_optimize_eltw_and_sum) {
                    p_ops.get_params_sum(cur_idx, sum_scale, data_type);
                    p_ops.get_params_eltwise(prev_idx, eltw_scale, alg, alpha, beta);

                    dnnl::post_ops eltw_p_op_prev, sum_p_op;

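                    // Fold the trailing eltwise_linear (y = next_alpha * x, beta == 0) into the pair:
                    // its alpha and scale are absorbed into the eltwise output scale, and its alpha into the sum scale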
                    eltw_p_op_prev.append_eltwise(eltw_scale * next_alpha * next_scale, alg, alpha, beta);
                    sum_p_op.append_sum(sum_scale * next_alpha, data_type);

                    add_post_op(prev_type, eltw_p_op_prev, optimized_p_ops, 0);
                    add_post_op(cur_type, sum_p_op, optimized_p_ops, 0);

                    // Mark current, previous and next operations as 'optimized' (they will be ignored on the next iteration of the cycle)
                    cur_post_ops[prev_post_op_idx].op_type = onednn_post_op_type::optimized_eltwise;
                    cur_post_ops[cur_post_op_idx].op_type = onednn_post_op_type::optimized_sum;
                    cur_post_ops[next_post_op_idx].op_type = onednn_post_op_type::optimized;

                    // Set the flag if extra optimization checking is needed
                    if (next_post_op_idx < post_ops_size - 1) {
                        if (type_is_eltwise_linear(cur_post_ops[next_post_op_idx + 1].op_type) ||
                            type_is_optimized_eltwise(cur_post_ops[next_post_op_idx + 1].op_type)) {
                            optimization_is_completed = true;
                        }
                    }

                    cur_ops_pair_is_optimized = true;
                }
            } else if (eltw_and_scale) {
                dnnl::algorithm alg;
                float eltw_scale, alpha, beta;

                cldnn::program_node& prev_node = get_dependency(cur_post_ops[prev_post_op_idx].mem_dep);

                p_ops.get_params_eltwise(cur_idx, eltw_scale, alg, alpha, beta);

                // Eltwise can be folded into the output_scale if cur_beta is equal to 0.f
                if (beta == 0.f && prev_node.get_output_layout().data_type == data_types::f32) {
                    memory::ptr prev_scale_mem_ptr = prev_node.as<data>().get_attached_memory_ptr();
                    if (prev_scale_mem_ptr == nullptr)
                        throw std::runtime_error("OneDNN post-ops optimization error: nonexistent node for eltw + scale");
                    auto& stream = prev_scale_mem_ptr->get_engine()->get_program_stream();
                    mem_lock<float, mem_lock_type::read_write> eltw_and_scale_lock(prev_scale_mem_ptr, stream);

                    size_t prev_scale_mem_size = prev_node.get_output_layout().count();

                    // Update all scale coefficients
                    for (size_t data_idx = 0; data_idx < prev_scale_mem_size; data_idx++) {
                        eltw_and_scale_lock[data_idx] *= alpha * eltw_scale;
                    }

                    // Mark current eltwise operation as 'optimized' (it will be ignored on the next iteration of the cycle)
                    cur_post_ops[cur_post_op_idx].op_type = onednn_post_op_type::optimized;

                    cur_ops_pair_is_optimized = true;
                }
            }
        }

        // If no optimizations have been applied, copy the post-op info into the new optimized_p_ops structure
        if (!(has_out_scales(attr) && prev_post_op_idx == 0) && !cur_ops_pair_is_optimized) {
            add_post_op(prev_type, p_ops, optimized_p_ops, prev_idx);
        }

        if (cur_post_op_idx == post_ops_size - 1 && !cur_ops_pair_is_optimized) {
            add_post_op(cur_type, p_ops, optimized_p_ops, cur_idx);
            optimization_done = true;
        } else if (cur_post_ops[cur_post_op_idx].op_type != onednn_post_op_type::optimized) {
            cur_post_op_idx++;
            prev_post_op_idx++;
        }
    }

    optimization_is_completed = !optimization_is_completed;

    add_onednn_fused_primitives(cur_post_ops);

    return optimized_p_ops;
}


void program_node::init_onednn_primitive_attributes() {
    const std::vector<fused_primitive_desc>& cldnn_post_ops = get_fused_primitives();
    auto attrs = std::make_shared<dnnl::primitive_attr>();
    dnnl::post_ops post_ops;
    size_t memory_offset = 0;

    // Create the oneDNN post-ops list related to the current node
    std::vector<fused_primitive_desc_onednn> fused_ops;

    // Sentinel memory dependency for post-ops that have no memory buffer (added for debug purposes only)
    size_t empty_mem = 0xff;

    // Add information about the post-operation into the list, update indices
    auto update_onednn_post_op_list = [&](onednn_post_op_type type, size_t m_dep) {
        fused_primitive_desc_onednn cur_op_desc = { type, memory_offset, m_dep };
        fused_ops.push_back(cur_op_desc);

        auto has_memory_buffers = type == onednn_post_op_type::binary_add ||
                                  type == onednn_post_op_type::binary_mul ||
                                  type == onednn_post_op_type::binary_max ||
                                  type == onednn_post_op_type::binary_min ||
                                  type == onednn_post_op_type::scale ||
                                  type == onednn_post_op_type::sum;
        if (has_memory_buffers)
            memory_offset++;
    };

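    // Map each fused cldnn primitive onto oneDNN attributes:
    //  - activation    -> eltwise post-op
    //  - eltwise sum   -> sum post-op (in-place accumulation) or binary_add
    //  - other eltwise -> output scales (first fused op only) or binary_mul
    //  - quantize      -> a sequence of eltwise_linear / round / clip / binary post-ops
    //  - reorder       -> skipped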
    for (size_t idx = 0; idx < cldnn_post_ops.size(); idx++) {
        auto node = cldnn_post_ops[idx].node;

        if (node->is_type<activation>()) {
            auto fused_desc = node->as<activation>().get_primitive();
            dnnl::algorithm alg = onednn::convert_activation_func(fused_desc->activation_function);
            post_ops.append_eltwise(1.0f, alg, fused_desc->additional_params.a, fused_desc->additional_params.b);
            update_onednn_post_op_list(onednn_post_op_type::eltwise_act, empty_mem);
        } else if (node->is_type<eltwise>()) {
            auto& e_node = node->as<eltwise>();
            auto dep_idx = cldnn_post_ops[idx].dep_start_idx;
            auto in = get_dependency(dep_idx).get_output_layout();

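            // Sum mode maps to an in-place oneDNN 'sum' post-op when needs_onednn_sum_post_op() allows it,
            // otherwise to a binary_add with an explicit memory operand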
            if (e_node.get_primitive()->mode == eltwise_mode::sum) {
                if (e_node.get_primitive()->needs_onednn_sum_post_op(in)) {
                    post_ops.append_sum(1.0f, onednn::convert_data_type(in.data_type));
                    update_onednn_post_op_list(onednn_post_op_type::sum, dep_idx);
                } else {
                    dnnl::memory::desc in_desc = onednn::layout_to_memory_desc(in, dnnl::memory::format_tag::ab, true);
                    post_ops.append_binary(dnnl::algorithm::binary_add, in_desc);
                    update_onednn_post_op_list(onednn_post_op_type::binary_add, dep_idx);
                }
            } else {
                if (in.size.spatial[0] > 1 || in.size.spatial[1] > 1 || in.size.batch[0] > 1)
                    throw std::runtime_error("Unsupported eltwise mode for fused onednn op");
                if (idx == 0 && !has_out_scales(attrs)) {
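                    // oneDNN output scales mask: 0 = one common scale, 2 (bit 1) = per output channel;
                    // DNNL_RUNTIME_F32_VAL defers the actual scale values to execution time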
                    int mask = in.count() > 1 ? 2 : 0;
                    attrs->set_output_scales(mask, {DNNL_RUNTIME_F32_VAL});
                    update_onednn_post_op_list(onednn_post_op_type::scale, dep_idx);
                } else {
                    dnnl::memory::desc in_desc = onednn::layout_to_memory_desc(in, dnnl::memory::format_tag::ab, true);
                    post_ops.append_binary(dnnl::algorithm::binary_mul, in_desc);
                    update_onednn_post_op_list(onednn_post_op_type::binary_mul, dep_idx);
                }
            }
        } else if (node->is_type<quantize>()) {
            auto& q_node = node->as<quantize>();
            auto dep_idx = cldnn_post_ops[idx].dep_start_idx;

            // ********************************* Common case with output range usage ********************************* //
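            // Quantize with a known per-tensor output range is decomposed into:
            // 1) pre-scale & pre-shift, 2) round (skipped for int8 outputs), 3) post-scale & post-shift, 4) clamp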
            if (q_node.get_per_tensor_output_range() && q_node.get_output_lo_val() < q_node.get_output_hi_val()) {
                // 1. pre-scale & pre-shift
                {
                    if (q_node.get_per_tensor_input_scale() && q_node.get_per_tensor_input_shift()) {
                        post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_input_scale_val(), q_node.get_input_shift_val());
                        update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
                    } else {
                        if (q_node.get_per_tensor_input_scale()) {
                            post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_input_scale_val(), 0.0f);
                            update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
                        } else {
                            auto in_scale = get_dependency(dep_idx++).get_output_layout();
                            if (idx == 0 && !has_out_scales(attrs) && in_scale.data_type == data_types::f32 &&
                                is_type<convolution>() &&
                                !data_type_traits::is_floating_point(get_dependency(0).get_output_layout().data_type)) {
                                int mask = in_scale.count() > 1 ? 2 : 0;
                                attrs->set_output_scales(mask, {DNNL_RUNTIME_F32_VAL});
                                update_onednn_post_op_list(onednn_post_op_type::scale, dep_idx - 1);
                            } else {
                                dnnl::memory::desc in_scale_desc = onednn::layout_to_memory_desc(in_scale, dnnl::memory::format_tag::ab, true);
                                post_ops.append_binary(dnnl::algorithm::binary_mul, in_scale_desc);
                                update_onednn_post_op_list(onednn_post_op_type::binary_mul, dep_idx - 1);
                            }
                        }

                        if (q_node.get_need_pre_shift()) {
                            if (q_node.get_per_tensor_input_shift()) {
                                post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, 1.0f, q_node.get_input_shift_val());
                                update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
                            } else {
                                auto in_shift = get_dependency(dep_idx++).get_output_layout();
                                dnnl::memory::desc in_shift_desc = onednn::layout_to_memory_desc(in_shift, dnnl::memory::format_tag::ab, true);
                                post_ops.append_binary(dnnl::algorithm::binary_add, in_shift_desc);
                                update_onednn_post_op_list(onednn_post_op_type::binary_add, dep_idx - 1);
                            }
                        }
                    }
                }

                // 2. round
                auto out_dt = cldnn_post_ops[idx].output_layout.data_type;
                {
                    bool output_type_is_int8 = out_dt == data_types::u8 || out_dt == data_types::i8;
                    if (!output_type_is_int8) {
                        post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_round, 0.0f, 0.0f);
                        update_onednn_post_op_list(onednn_post_op_type::eltwise_round, empty_mem);
                    }
                }

                // 3. post-scale & post-shift
                {
                    if (q_node.get_need_post_scale() && q_node.get_need_post_shift() &&
                        q_node.get_per_tensor_output_scale() && q_node.get_per_tensor_output_shift()) {
                        post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_output_scale_val(), q_node.get_output_shift_val());
                        update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
                    } else {
                        if (q_node.get_need_post_scale()) {
                            if (q_node.get_per_tensor_output_scale()) {
                                post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_output_scale_val(), 0.0f);
                                update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
                            } else {
                                auto out_scale = get_dependency(dep_idx++).get_output_layout();
                                dnnl::memory::desc out_scale_desc = onednn::layout_to_memory_desc(out_scale, dnnl::memory::format_tag::ab, true);
                                post_ops.append_binary(dnnl::algorithm::binary_mul, out_scale_desc);
                                update_onednn_post_op_list(onednn_post_op_type::binary_mul, dep_idx - 1);
                            }
                        }

                        if (q_node.get_need_post_shift()) {
                            if (q_node.get_per_tensor_output_shift()) {
                                post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, 1.0f, q_node.get_output_shift_val());
                                update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
                            } else {
                                auto out_shift = get_dependency(dep_idx++).get_output_layout();
                                dnnl::memory::desc out_shift_desc = onednn::layout_to_memory_desc(out_shift, dnnl::memory::format_tag::ab, true);
                                post_ops.append_binary(dnnl::algorithm::binary_add, out_shift_desc);
                                update_onednn_post_op_list(onednn_post_op_type::binary_add, dep_idx - 1);
                            }
                        }
                    }
                }

                // 4. clamp
                {
                    if (q_node.get_need_clamp()) {
                        float out_lo = q_node.get_need_min_clamp() ? q_node.get_output_lo_val() : data_type_traits::min<float>(out_dt);
                        float out_hi = q_node.get_need_max_clamp() ? q_node.get_output_hi_val() : data_type_traits::max<float>(out_dt);
                        post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_clip, out_lo, out_hi);
                        update_onednn_post_op_list(onednn_post_op_type::eltwise_clip, empty_mem);
                    }
                }
            // ********************************* Rare case with input range usage ********************************* //
            } else {
                // 1. clamp
                {
                    if (q_node.get_need_clamp()) {
                        auto in_lo = get_dependency(dep_idx++).get_output_layout();
                        auto in_hi = get_dependency(dep_idx++).get_output_layout();
                        dnnl::algorithm clamp_max = dnnl::algorithm::binary_max;
                        dnnl::algorithm clamp_min = dnnl::algorithm::binary_min;
                        dnnl::memory::desc in_lo_desc = onednn::layout_to_memory_desc(in_lo, dnnl::memory::format_tag::ab, true);
                        dnnl::memory::desc in_hi_desc = onednn::layout_to_memory_desc(in_hi, dnnl::memory::format_tag::ab, true);

                        post_ops.append_binary(clamp_max, in_lo_desc);
                        update_onednn_post_op_list(onednn_post_op_type::binary_max, dep_idx - 2);
                        post_ops.append_binary(clamp_min, in_hi_desc);
                        update_onednn_post_op_list(onednn_post_op_type::binary_min, dep_idx - 1);
                    }
                }

                // 2. pre-scale & pre-shift
                {
                    if (q_node.get_per_tensor_input_scale() && q_node.get_per_tensor_input_shift()) {
                        post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_input_scale_val(), q_node.get_input_shift_val());
                        update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
                    } else {
                        if (q_node.get_per_tensor_input_scale()) {
                            post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_input_scale_val(), 0.0f);
                            update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
                        } else {
                            auto in_scale = get_dependency(dep_idx++).get_output_layout();
                            if (idx == 0 && !q_node.get_need_clamp() && !has_out_scales(attrs) && in_scale.data_type == data_types::f32 &&
                                is_type<convolution>() &&
                                !data_type_traits::is_floating_point(get_dependency(0).get_output_layout().data_type)) {
                                int mask = in_scale.count() > 1 ? 2 : 0;
                                attrs->set_output_scales(mask, {DNNL_RUNTIME_F32_VAL});
                                update_onednn_post_op_list(onednn_post_op_type::scale, dep_idx - 1);
                            } else {
                                dnnl::memory::desc in_scale_desc = onednn::layout_to_memory_desc(in_scale, dnnl::memory::format_tag::ab, true);
                                post_ops.append_binary(dnnl::algorithm::binary_mul, in_scale_desc);
                                update_onednn_post_op_list(onednn_post_op_type::binary_mul, dep_idx - 1);
                            }
                        }

                        if (q_node.get_need_pre_shift()) {
                            if (q_node.get_per_tensor_input_shift()) {
                                post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, 1.0f, q_node.get_input_shift_val());
                                update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
                            } else {
                                auto in_shift = get_dependency(dep_idx++).get_output_layout();
                                dnnl::memory::desc in_shift_desc = onednn::layout_to_memory_desc(in_shift, dnnl::memory::format_tag::ab, true);
                                post_ops.append_binary(dnnl::algorithm::binary_add, in_shift_desc);
                                update_onednn_post_op_list(onednn_post_op_type::binary_add, dep_idx - 1);
                            }
                        }
                    }
                }

                // 3. round
                {
                    post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_round, 0.0f, 0.0f);
                    update_onednn_post_op_list(onednn_post_op_type::eltwise_round, empty_mem);
                }

                // 4. post-scale & post-shift
                {
                    if (q_node.get_need_post_scale() && q_node.get_need_post_shift() &&
                        q_node.get_per_tensor_output_scale() && q_node.get_per_tensor_output_shift()) {
                        post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_output_scale_val(), q_node.get_output_shift_val());
                        update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
                    } else {
                        if (q_node.get_need_post_scale()) {
                            if (q_node.get_per_tensor_output_scale()) {
                                post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, q_node.get_output_scale_val(), 0.0f);
                                update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
                            } else {
                                auto out_scale = get_dependency(dep_idx++).get_output_layout();
                                dnnl::memory::desc out_scale_desc = onednn::layout_to_memory_desc(out_scale, dnnl::memory::format_tag::ab, true);
                                post_ops.append_binary(dnnl::algorithm::binary_mul, out_scale_desc);
                                update_onednn_post_op_list(onednn_post_op_type::binary_mul, dep_idx - 1);
                            }
                        }

                        if (q_node.get_need_post_shift()) {
                            if (q_node.get_per_tensor_output_shift()) {
                                post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, 1.0f, q_node.get_output_shift_val());
                                update_onednn_post_op_list(onednn_post_op_type::eltwise_linear, empty_mem);
                            } else {
                                auto out_shift = get_dependency(dep_idx++).get_output_layout();
                                dnnl::memory::desc out_shift_desc = onednn::layout_to_memory_desc(out_shift, dnnl::memory::format_tag::ab, true);
                                post_ops.append_binary(dnnl::algorithm::binary_add, out_shift_desc);
                                update_onednn_post_op_list(onednn_post_op_type::binary_add, dep_idx - 1);
                            }
                        }
                    }
                }
            }
        } else if (node->is_type<reorder>()) {
            continue;
        } else {
            throw std::runtime_error("Unsupported fused op of " + node->get_primitive()->type_string() + " type for oneDNN primitive");
        }
    }

    if (cldnn_post_ops.size() && get_fused_activations_funcs().size())
        throw std::runtime_error("Unsupported mix of fused ops and activations");

    for (size_t i = 0; i < get_fused_activations_funcs().size(); i++) {
        auto activation_type = get_fused_activations_funcs()[i];
        auto params = get_fused_activations_params()[i];
        dnnl::algorithm alg = onednn::convert_activation_func(activation_type);
        post_ops.append_eltwise(1.0f, alg, params.a, params.b);
        update_onednn_post_op_list(onednn_post_op_type::eltwise_act, empty_mem);
    }

    // Try to optimize when there is more than one post-op
    if (fused_ops.size() > 1) {
        dnnl::post_ops optimized_post_ops = post_ops;
        bool optimization_is_finished = false;

        add_onednn_fused_primitives(fused_ops);

        // Try to combine multiplications and additions which are placed one after another.
        // We do it in a loop because some combinations become foldable only after other post-ops have been simplified
        do {
            optimized_post_ops = try_optimize_post_ops(optimized_post_ops, attrs, optimization_is_finished);
        } while (!optimization_is_finished);

        attrs->set_post_ops(optimized_post_ops);
    } else {
        // Set post-ops without any optimizations
        add_onednn_fused_primitives(fused_ops);
        attrs->set_post_ops(post_ops);
    }

    add_onednn_attrs(attrs);
}

#endif // ENABLE_ONEDNN_FOR_GPU