|
|
|
|
@@ -378,7 +378,10 @@ dnnl::post_ops program_node::try_optimize_post_ops(dnnl::post_ops& p_ops, const
|
|
|
|
|
|
|
|
|
|
case onednn_post_op_type::optimized:
|
|
|
|
|
case onednn_post_op_type::optimized_sum:
|
|
|
|
|
case onednn_post_op_type::optimized_eltwise:
|
|
|
|
|
case onednn_post_op_type::optimized_eltwise_act:
|
|
|
|
|
case onednn_post_op_type::optimized_eltwise_linear:
|
|
|
|
|
case onednn_post_op_type::optimized_eltwise_clip:
|
|
|
|
|
case onednn_post_op_type::optimized_eltwise_round:
|
|
|
|
|
{
|
|
|
|
|
// Current operation already has been optimized => don't need extra actions
|
|
|
|
|
break;
|
|
|
|
|
@@ -392,7 +395,10 @@ dnnl::post_ops program_node::try_optimize_post_ops(dnnl::post_ops& p_ops, const
|
|
|
|
|
// Check that post-op type is any optimized
|
|
|
|
|
auto type_is_any_optimized = [](onednn_post_op_type type) -> bool {
|
|
|
|
|
return type == onednn_post_op_type::optimized || type == onednn_post_op_type::optimized_sum ||
|
|
|
|
|
type == onednn_post_op_type::optimized_eltwise;
|
|
|
|
|
type == onednn_post_op_type::optimized_eltwise_act ||
|
|
|
|
|
type == onednn_post_op_type::optimized_eltwise_linear ||
|
|
|
|
|
type == onednn_post_op_type::optimized_eltwise_clip ||
|
|
|
|
|
type == onednn_post_op_type::optimized_eltwise_round;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Check that post-op type is eltwise
|
|
|
|
|
@@ -409,13 +415,28 @@ dnnl::post_ops program_node::try_optimize_post_ops(dnnl::post_ops& p_ops, const
|
|
|
|
|
// Simple post-op type checks
|
|
|
|
|
// Exact-match predicate: the post-op carries the plain `optimized` marker.
auto type_is_optimized = [](onednn_post_op_type op) -> bool {
    return op == onednn_post_op_type::optimized;
};
// Exact-match predicate: the post-op is a (not yet optimized) linear eltwise.
auto type_is_eltwise_linear = [](onednn_post_op_type op) -> bool {
    return op == onednn_post_op_type::eltwise_linear;
};
|
|
|
|
|
// True for each per-kind "optimized eltwise" marker (act / linear / round / clip).
// These are exactly the values that get_eltwise_type() knows how to map back to a
// regular eltwise type; anything else would hit its throwing default branch.
// Defined once only: a second lambda with the same name in this scope is a
// redefinition error.
auto type_is_optimized_eltwise = [](onednn_post_op_type type) -> bool {
    return type == onednn_post_op_type::optimized_eltwise_act ||
           type == onednn_post_op_type::optimized_eltwise_linear ||
           type == onednn_post_op_type::optimized_eltwise_round ||
           type == onednn_post_op_type::optimized_eltwise_clip;
};
|
|
|
|
|
// Exact-match predicates, one per post-op kind.
auto type_is_binary_add = [](onednn_post_op_type op) -> bool {
    return op == onednn_post_op_type::binary_add;
};
auto type_is_binary_mul = [](onednn_post_op_type op) -> bool {
    return op == onednn_post_op_type::binary_mul;
};
auto type_is_sum = [](onednn_post_op_type op) -> bool {
    return op == onednn_post_op_type::sum;
};
auto type_is_optimized_sum = [](onednn_post_op_type op) -> bool {
    return op == onednn_post_op_type::optimized_sum;
};
auto type_is_scale = [](onednn_post_op_type op) -> bool {
    return op == onednn_post_op_type::scale;
};
|
|
|
|
|
|
|
|
|
|
auto get_eltwise_type = [](onednn_post_op_type type) {
|
|
|
|
|
switch (type) {
|
|
|
|
|
case onednn_post_op_type::optimized_eltwise_act: return onednn_post_op_type::eltwise_act;
|
|
|
|
|
case onednn_post_op_type::optimized_eltwise_clip: return onednn_post_op_type::eltwise_clip;
|
|
|
|
|
case onednn_post_op_type::optimized_eltwise_linear: return onednn_post_op_type::eltwise_linear;
|
|
|
|
|
case onednn_post_op_type::optimized_eltwise_round: return onednn_post_op_type::eltwise_round;
|
|
|
|
|
default:
|
|
|
|
|
throw std::runtime_error("Unsupported optimized eltwise post-operation type");
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
auto& cur_post_ops = get_fused_primitives_onednn();
|
|
|
|
|
|
|
|
|
|
size_t cur_post_op_idx = 1;
|
|
|
|
|
@@ -427,7 +448,7 @@ dnnl::post_ops program_node::try_optimize_post_ops(dnnl::post_ops& p_ops, const
|
|
|
|
|
if (type_is_optimized_sum(cur_post_ops[post_op_idx].op_type))
|
|
|
|
|
cur_post_ops[post_op_idx].op_type = onednn_post_op_type::sum;
|
|
|
|
|
else if (type_is_optimized_eltwise(cur_post_ops[post_op_idx].op_type))
|
|
|
|
|
cur_post_ops[post_op_idx].op_type = onednn_post_op_type::eltwise_linear;
|
|
|
|
|
cur_post_ops[post_op_idx].op_type = get_eltwise_type(cur_post_ops[post_op_idx].op_type);
|
|
|
|
|
else if (type_is_optimized(cur_post_ops[post_op_idx].op_type))
|
|
|
|
|
cur_post_ops.erase(cur_post_ops.begin() + post_op_idx);
|
|
|
|
|
}
|
|
|
|
|
@@ -435,6 +456,18 @@ dnnl::post_ops program_node::try_optimize_post_ops(dnnl::post_ops& p_ops, const
|
|
|
|
|
// Get post-ops size for current node
|
|
|
|
|
auto post_ops_size = cur_post_ops.size();
|
|
|
|
|
|
|
|
|
|
auto get_optimized_eltwise_type = [](onednn_post_op_type type) {
|
|
|
|
|
switch (type) {
|
|
|
|
|
case onednn_post_op_type::eltwise_linear: return onednn_post_op_type::optimized_eltwise_linear;
|
|
|
|
|
case onednn_post_op_type::eltwise_act: return onednn_post_op_type::optimized_eltwise_act;
|
|
|
|
|
case onednn_post_op_type::eltwise_round: return onednn_post_op_type::optimized_eltwise_round;
|
|
|
|
|
case onednn_post_op_type::eltwise_clip: return onednn_post_op_type::optimized_eltwise_clip;
|
|
|
|
|
default:
|
|
|
|
|
throw std::runtime_error("Unsupported optimized eltwise post-operation type");
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Try to combine pairs of arithmetic post-ops (adds and muls) into one operation inside this cycle
|
|
|
|
|
while (!optimization_done) {
|
|
|
|
|
auto cur_type = cur_post_ops[cur_post_op_idx].op_type;
|
|
|
|
|
@@ -516,7 +549,7 @@ dnnl::post_ops program_node::try_optimize_post_ops(dnnl::post_ops& p_ops, const
|
|
|
|
|
if (eltw_linear_and_eltw_linear || eltw_linear_and_eltw_non_linear) {
|
|
|
|
|
// Marked current and previous eltwise operations as 'optimized' (they will be ignored on the next iteration of cycle)
|
|
|
|
|
cur_post_ops[cur_post_op_idx].op_type = onednn_post_op_type::optimized;
|
|
|
|
|
cur_post_ops[prev_post_op_idx].op_type = onednn_post_op_type::optimized_eltwise;
|
|
|
|
|
cur_post_ops[prev_post_op_idx].op_type = get_optimized_eltwise_type(prev_type);
|
|
|
|
|
|
|
|
|
|
// Set the flag if extra optimizations checking is needed
|
|
|
|
|
if (cur_post_op_idx < post_ops_size - 1) {
|
|
|
|
|
@@ -649,7 +682,7 @@ dnnl::post_ops program_node::try_optimize_post_ops(dnnl::post_ops& p_ops, const
|
|
|
|
|
add_post_op(cur_type, sum_p_op, optimized_p_ops, 0);
|
|
|
|
|
|
|
|
|
|
// Marked current, previous and next operations as 'optimized' (they will be ignored on the next iteration of cycle)
|
|
|
|
|
cur_post_ops[prev_post_op_idx].op_type = onednn_post_op_type::optimized_eltwise;
|
|
|
|
|
cur_post_ops[prev_post_op_idx].op_type = get_optimized_eltwise_type(prev_type);
|
|
|
|
|
cur_post_ops[cur_post_op_idx].op_type = onednn_post_op_type::optimized_sum;
|
|
|
|
|
cur_post_ops[next_post_op_idx].op_type = onednn_post_op_type::optimized;
|
|
|
|
|
|
|
|
|
|
|