[GPU] Add substract post-op for onednn (#14947)

This commit is contained in:
Jade Cho 2023-01-06 17:18:25 +09:00 committed by GitHub
parent 1e71bdd1d4
commit c1f6da31b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 16 additions and 1 deletions

View File

@ -466,6 +466,7 @@ protected:
}
case onednn_post_op_type::binary_add:
case onednn_post_op_type::binary_sub:
case onednn_post_op_type::binary_mul:
case onednn_post_op_type::binary_max:
case onednn_post_op_type::binary_min:

View File

@ -55,6 +55,7 @@ enum class onednn_post_op_type : uint32_t {
eltwise_round,
binary_mul,
binary_add,
binary_sub,
binary_max,
binary_min,
binary_relu,
@ -76,6 +77,7 @@ static inline std::ostream& operator<< (std::ostream& os, onednn_post_op_type& t
case onednn_post_op_type::eltwise_round: os << "eltwise_round"; break;
case onednn_post_op_type::binary_mul: os << "binary_mul"; break;
case onednn_post_op_type::binary_add: os << "binary_add"; break;
case onednn_post_op_type::binary_sub: os << "binary_sub"; break;
case onednn_post_op_type::binary_max: os << "binary_max"; break;
case onednn_post_op_type::binary_min: os << "binary_min"; break;
case onednn_post_op_type::binary_relu: os << "binary_relu"; break;

View File

@ -120,6 +120,7 @@ inline std::string onednn_post_op_type_to_str(onednn_post_op_type type) {
case onednn_post_op_type::eltwise_round: return "eltwise_round";
case onednn_post_op_type::binary_mul: return "binary_mul";
case onednn_post_op_type::binary_add: return "binary_add";
case onednn_post_op_type::binary_sub: return "binary_add";
case onednn_post_op_type::binary_max: return "binary_max";
case onednn_post_op_type::binary_min: return "binary_min";
case onednn_post_op_type::binary_relu: return "binary_relu";

View File

@ -520,6 +520,7 @@ dnnl::post_ops program_node::try_optimize_post_ops(dnnl::post_ops& p_ops, const
}
case onednn_post_op_type::binary_add:
case onednn_post_op_type::binary_sub:
case onednn_post_op_type::binary_mul:
case onednn_post_op_type::binary_max:
case onednn_post_op_type::binary_min:
@ -997,6 +998,7 @@ void program_node::init_onednn_primitive_attributes() {
fused_ops.push_back(cur_op_desc);
auto has_memory_buffers = type == onednn_post_op_type::binary_add ||
type == onednn_post_op_type::binary_sub ||
type == onednn_post_op_type::binary_mul ||
type == onednn_post_op_type::binary_max ||
type == onednn_post_op_type::binary_min ||
@ -1054,7 +1056,11 @@ void program_node::init_onednn_primitive_attributes() {
post_ops.append_binary(dnnl::algorithm::binary_add, in_desc);
update_onednn_post_op_list(onednn_post_op_type::binary_add, dep_idx);
}
} else {
} else if (desc.typed_desc<eltwise>()->mode == eltwise_mode::sub) {
dnnl::memory::desc in_desc = onednn::layout_to_memory_desc(in);
post_ops.append_binary(dnnl::algorithm::binary_sub, in_desc);
update_onednn_post_op_list(onednn_post_op_type::binary_sub, dep_idx);
} else if (desc.typed_desc<eltwise>()->mode == eltwise_mode::prod) {
auto input_datatype = get_dependency(0).get_output_layout().data_type;
// convolution using post-op output scales can only be used when i8/u8 input (which use integer accumulator)
bool cant_use_output_scales =
@ -1073,6 +1079,11 @@ void program_node::init_onednn_primitive_attributes() {
attrs->set_output_scales(mask, {DNNL_RUNTIME_F32_VAL});
update_onednn_post_op_list(onednn_post_op_type::scale, dep_idx);
}
} else {
std::stringstream error_msg;
error_msg << "Unsupported eltwise mode: " << static_cast<int>(desc.typed_desc<eltwise>()->mode) << ". ";
error_msg << desc.desc->id << " is fused node of " + this->id() + ".";
OPENVINO_ASSERT(false, error_msg.str());
}
} else if (desc.is_type<quantize>()) {
auto dep_idx = desc.dep_start_idx;