[GPU] Add substract post-op for onednn (#14947)
This commit is contained in:
parent
1e71bdd1d4
commit
c1f6da31b6
@ -466,6 +466,7 @@ protected:
|
||||
}
|
||||
|
||||
case onednn_post_op_type::binary_add:
|
||||
case onednn_post_op_type::binary_sub:
|
||||
case onednn_post_op_type::binary_mul:
|
||||
case onednn_post_op_type::binary_max:
|
||||
case onednn_post_op_type::binary_min:
|
||||
|
@ -55,6 +55,7 @@ enum class onednn_post_op_type : uint32_t {
|
||||
eltwise_round,
|
||||
binary_mul,
|
||||
binary_add,
|
||||
binary_sub,
|
||||
binary_max,
|
||||
binary_min,
|
||||
binary_relu,
|
||||
@ -76,6 +77,7 @@ static inline std::ostream& operator<< (std::ostream& os, onednn_post_op_type& t
|
||||
case onednn_post_op_type::eltwise_round: os << "eltwise_round"; break;
|
||||
case onednn_post_op_type::binary_mul: os << "binary_mul"; break;
|
||||
case onednn_post_op_type::binary_add: os << "binary_add"; break;
|
||||
case onednn_post_op_type::binary_sub: os << "binary_sub"; break;
|
||||
case onednn_post_op_type::binary_max: os << "binary_max"; break;
|
||||
case onednn_post_op_type::binary_min: os << "binary_min"; break;
|
||||
case onednn_post_op_type::binary_relu: os << "binary_relu"; break;
|
||||
|
@ -120,6 +120,7 @@ inline std::string onednn_post_op_type_to_str(onednn_post_op_type type) {
|
||||
case onednn_post_op_type::eltwise_round: return "eltwise_round";
|
||||
case onednn_post_op_type::binary_mul: return "binary_mul";
|
||||
case onednn_post_op_type::binary_add: return "binary_add";
|
||||
case onednn_post_op_type::binary_sub: return "binary_add";
|
||||
case onednn_post_op_type::binary_max: return "binary_max";
|
||||
case onednn_post_op_type::binary_min: return "binary_min";
|
||||
case onednn_post_op_type::binary_relu: return "binary_relu";
|
||||
|
@ -520,6 +520,7 @@ dnnl::post_ops program_node::try_optimize_post_ops(dnnl::post_ops& p_ops, const
|
||||
}
|
||||
|
||||
case onednn_post_op_type::binary_add:
|
||||
case onednn_post_op_type::binary_sub:
|
||||
case onednn_post_op_type::binary_mul:
|
||||
case onednn_post_op_type::binary_max:
|
||||
case onednn_post_op_type::binary_min:
|
||||
@ -997,6 +998,7 @@ void program_node::init_onednn_primitive_attributes() {
|
||||
fused_ops.push_back(cur_op_desc);
|
||||
|
||||
auto has_memory_buffers = type == onednn_post_op_type::binary_add ||
|
||||
type == onednn_post_op_type::binary_sub ||
|
||||
type == onednn_post_op_type::binary_mul ||
|
||||
type == onednn_post_op_type::binary_max ||
|
||||
type == onednn_post_op_type::binary_min ||
|
||||
@ -1054,7 +1056,11 @@ void program_node::init_onednn_primitive_attributes() {
|
||||
post_ops.append_binary(dnnl::algorithm::binary_add, in_desc);
|
||||
update_onednn_post_op_list(onednn_post_op_type::binary_add, dep_idx);
|
||||
}
|
||||
} else {
|
||||
} else if (desc.typed_desc<eltwise>()->mode == eltwise_mode::sub) {
|
||||
dnnl::memory::desc in_desc = onednn::layout_to_memory_desc(in);
|
||||
post_ops.append_binary(dnnl::algorithm::binary_sub, in_desc);
|
||||
update_onednn_post_op_list(onednn_post_op_type::binary_sub, dep_idx);
|
||||
} else if (desc.typed_desc<eltwise>()->mode == eltwise_mode::prod) {
|
||||
auto input_datatype = get_dependency(0).get_output_layout().data_type;
|
||||
// convolution using post-op output scales can only be used when i8/u8 input (which use integer accumulator)
|
||||
bool cant_use_output_scales =
|
||||
@ -1073,6 +1079,11 @@ void program_node::init_onednn_primitive_attributes() {
|
||||
attrs->set_output_scales(mask, {DNNL_RUNTIME_F32_VAL});
|
||||
update_onednn_post_op_list(onednn_post_op_type::scale, dep_idx);
|
||||
}
|
||||
} else {
|
||||
std::stringstream error_msg;
|
||||
error_msg << "Unsupported eltwise mode: " << static_cast<int>(desc.typed_desc<eltwise>()->mode) << ". ";
|
||||
error_msg << desc.desc->id << " is fused node of " + this->id() + ".";
|
||||
OPENVINO_ASSERT(false, error_msg.str());
|
||||
}
|
||||
} else if (desc.is_type<quantize>()) {
|
||||
auto dep_idx = desc.dep_start_idx;
|
||||
|
Loading…
Reference in New Issue
Block a user