[GPU] Fix a bug in post-operation optimization (#9443)

This commit is contained in:
Jade Cho
2021-12-28 23:48:29 +09:00
committed by GitHub
parent c6bc4d0045
commit da1261a1d8
4 changed files with 51 additions and 9 deletions

View File

@@ -128,7 +128,10 @@ protected:
case onednn_post_op_type::sum:
case onednn_post_op_type::optimized_sum:
case onednn_post_op_type::optimized_eltwise:
case onednn_post_op_type::optimized_eltwise_linear:
case onednn_post_op_type::optimized_eltwise_act:
case onednn_post_op_type::optimized_eltwise_round:
case onednn_post_op_type::optimized_eltwise_clip:
{
break;
}

View File

@@ -47,7 +47,10 @@ enum class onednn_post_op_type : uint32_t {
scale,
sum,
optimized,
optimized_eltwise,
optimized_eltwise_act,
optimized_eltwise_clip,
optimized_eltwise_linear,
optimized_eltwise_round,
optimized_sum
};

View File

@@ -289,7 +289,10 @@ inline std::string onednn_post_op_type_to_str(onednn_post_op_type type) {
case onednn_post_op_type::scale: return "scale";
case onednn_post_op_type::sum: return "sum";
case onednn_post_op_type::optimized: return "optimized";
case onednn_post_op_type::optimized_eltwise: return "optimized_eltwise";
case onednn_post_op_type::optimized_eltwise_act: return "optimized_eltwise_act";
case onednn_post_op_type::optimized_eltwise_linear: return "optimized_eltwise_linear";
case onednn_post_op_type::optimized_eltwise_clip: return "optimized_eltwise_clip";
case onednn_post_op_type::optimized_eltwise_round: return "optimized_eltwise_round";
case onednn_post_op_type::optimized_sum: return "optimized_sum";
default: return "unknown";
}

View File

@@ -378,7 +378,10 @@ dnnl::post_ops program_node::try_optimize_post_ops(dnnl::post_ops& p_ops, const
case onednn_post_op_type::optimized:
case onednn_post_op_type::optimized_sum:
case onednn_post_op_type::optimized_eltwise:
case onednn_post_op_type::optimized_eltwise_act:
case onednn_post_op_type::optimized_eltwise_linear:
case onednn_post_op_type::optimized_eltwise_clip:
case onednn_post_op_type::optimized_eltwise_round:
{
// The current operation has already been optimized, so no extra action is needed
break;
@@ -392,7 +395,10 @@ dnnl::post_ops program_node::try_optimize_post_ops(dnnl::post_ops& p_ops, const
// Check that post-op type is any optimized
auto type_is_any_optimized = [](onednn_post_op_type type) -> bool {
return type == onednn_post_op_type::optimized || type == onednn_post_op_type::optimized_sum ||
type == onednn_post_op_type::optimized_eltwise;
type == onednn_post_op_type::optimized_eltwise_act ||
type == onednn_post_op_type::optimized_eltwise_linear ||
type == onednn_post_op_type::optimized_eltwise_clip ||
type == onednn_post_op_type::optimized_eltwise_round;
};
// Check that post-op type is eltwise
@@ -409,13 +415,28 @@ dnnl::post_ops program_node::try_optimize_post_ops(dnnl::post_ops& p_ops, const
// Simple post-op type checks
auto type_is_optimized = [](onednn_post_op_type type) -> bool { return type == onednn_post_op_type::optimized; };
auto type_is_eltwise_linear = [](onednn_post_op_type type) -> bool { return type == onednn_post_op_type::eltwise_linear; };
auto type_is_optimized_eltwise = [](onednn_post_op_type type) -> bool { return type == onednn_post_op_type::optimized_eltwise; };
// Returns true when the post-op type is one of the four optimized eltwise markers
// (act, linear, round or clip); every other type yields false.
auto type_is_optimized_eltwise = [](onednn_post_op_type type) -> bool {
    switch (type) {
        case onednn_post_op_type::optimized_eltwise_act:
        case onednn_post_op_type::optimized_eltwise_linear:
        case onednn_post_op_type::optimized_eltwise_round:
        case onednn_post_op_type::optimized_eltwise_clip:
            return true;
        default:
            return false;
    }
};
auto type_is_binary_add = [](onednn_post_op_type type) -> bool { return type == onednn_post_op_type::binary_add; };
auto type_is_binary_mul = [](onednn_post_op_type type) -> bool { return type == onednn_post_op_type::binary_mul; };
auto type_is_sum = [](onednn_post_op_type type) -> bool { return type == onednn_post_op_type::sum; };
auto type_is_optimized_sum = [](onednn_post_op_type type) -> bool { return type == onednn_post_op_type::optimized_sum; };
auto type_is_scale = [](onednn_post_op_type type) -> bool { return type == onednn_post_op_type::scale; };
auto get_eltwise_type = [](onednn_post_op_type type) {
switch (type) {
case onednn_post_op_type::optimized_eltwise_act: return onednn_post_op_type::eltwise_act;
case onednn_post_op_type::optimized_eltwise_clip: return onednn_post_op_type::eltwise_clip;
case onednn_post_op_type::optimized_eltwise_linear: return onednn_post_op_type::eltwise_linear;
case onednn_post_op_type::optimized_eltwise_round: return onednn_post_op_type::eltwise_round;
default:
throw std::runtime_error("Unsupported optimized eltwise post-operation type");
break;
}
};
auto& cur_post_ops = get_fused_primitives_onednn();
size_t cur_post_op_idx = 1;
@@ -427,7 +448,7 @@ dnnl::post_ops program_node::try_optimize_post_ops(dnnl::post_ops& p_ops, const
if (type_is_optimized_sum(cur_post_ops[post_op_idx].op_type))
cur_post_ops[post_op_idx].op_type = onednn_post_op_type::sum;
else if (type_is_optimized_eltwise(cur_post_ops[post_op_idx].op_type))
cur_post_ops[post_op_idx].op_type = onednn_post_op_type::eltwise_linear;
cur_post_ops[post_op_idx].op_type = get_eltwise_type(cur_post_ops[post_op_idx].op_type);
else if (type_is_optimized(cur_post_ops[post_op_idx].op_type))
cur_post_ops.erase(cur_post_ops.begin() + post_op_idx);
}
@@ -435,6 +456,18 @@ dnnl::post_ops program_node::try_optimize_post_ops(dnnl::post_ops& p_ops, const
// Get post-ops size for current node
auto post_ops_size = cur_post_ops.size();
auto get_optimized_eltwise_type = [](onednn_post_op_type type) {
switch (type) {
case onednn_post_op_type::eltwise_linear: return onednn_post_op_type::optimized_eltwise_linear;
case onednn_post_op_type::eltwise_act: return onednn_post_op_type::optimized_eltwise_act;
case onednn_post_op_type::eltwise_round: return onednn_post_op_type::optimized_eltwise_round;
case onednn_post_op_type::eltwise_clip: return onednn_post_op_type::optimized_eltwise_clip;
default:
throw std::runtime_error("Unsupported optimized eltwise post-operation type");
break;
}
};
// Try to combine pairs of arithmetic post-ops (adds and muls) into one operation inside this cycle
while (!optimization_done) {
auto cur_type = cur_post_ops[cur_post_op_idx].op_type;
@@ -516,7 +549,7 @@ dnnl::post_ops program_node::try_optimize_post_ops(dnnl::post_ops& p_ops, const
if (eltw_linear_and_eltw_linear || eltw_linear_and_eltw_non_linear) {
// Mark the current and previous eltwise operations as 'optimized' (they will be ignored on the next iteration of the cycle)
cur_post_ops[cur_post_op_idx].op_type = onednn_post_op_type::optimized;
cur_post_ops[prev_post_op_idx].op_type = onednn_post_op_type::optimized_eltwise;
cur_post_ops[prev_post_op_idx].op_type = get_optimized_eltwise_type(prev_type);
// Set the flag if extra optimizations checking is needed
if (cur_post_op_idx < post_ops_size - 1) {
@@ -649,7 +682,7 @@ dnnl::post_ops program_node::try_optimize_post_ops(dnnl::post_ops& p_ops, const
add_post_op(cur_type, sum_p_op, optimized_p_ops, 0);
// Mark the current, previous and next operations as 'optimized' (they will be ignored on the next iteration of the cycle)
cur_post_ops[prev_post_op_idx].op_type = onednn_post_op_type::optimized_eltwise;
cur_post_ops[prev_post_op_idx].op_type = get_optimized_eltwise_type(prev_type);
cur_post_ops[cur_post_op_idx].op_type = onednn_post_op_type::optimized_sum;
cur_post_ops[next_post_op_idx].op_type = onednn_post_op_type::optimized;