diff --git a/src/plugins/intel_gpu/src/graph/impls/onednn/primitive_onednn_base.h b/src/plugins/intel_gpu/src/graph/impls/onednn/primitive_onednn_base.h index 398546c4774..617ea6ee1d8 100644 --- a/src/plugins/intel_gpu/src/graph/impls/onednn/primitive_onednn_base.h +++ b/src/plugins/intel_gpu/src/graph/impls/onednn/primitive_onednn_base.h @@ -466,6 +466,7 @@ protected: } case onednn_post_op_type::binary_add: + case onednn_post_op_type::binary_sub: case onednn_post_op_type::binary_mul: case onednn_post_op_type::binary_max: case onednn_post_op_type::binary_min: diff --git a/src/plugins/intel_gpu/src/graph/include/fused_primitive_desc.h b/src/plugins/intel_gpu/src/graph/include/fused_primitive_desc.h index 822cf499b65..5262ec9df45 100644 --- a/src/plugins/intel_gpu/src/graph/include/fused_primitive_desc.h +++ b/src/plugins/intel_gpu/src/graph/include/fused_primitive_desc.h @@ -55,6 +55,7 @@ enum class onednn_post_op_type : uint32_t { eltwise_round, binary_mul, binary_add, + binary_sub, binary_max, binary_min, binary_relu, @@ -76,6 +77,7 @@ static inline std::ostream& operator<< (std::ostream& os, onednn_post_op_type& t case onednn_post_op_type::eltwise_round: os << "eltwise_round"; break; case onednn_post_op_type::binary_mul: os << "binary_mul"; break; case onednn_post_op_type::binary_add: os << "binary_add"; break; + case onednn_post_op_type::binary_sub: os << "binary_sub"; break; case onednn_post_op_type::binary_max: os << "binary_max"; break; case onednn_post_op_type::binary_min: os << "binary_min"; break; case onednn_post_op_type::binary_relu: os << "binary_relu"; break; diff --git a/src/plugins/intel_gpu/src/graph/include/to_string_utils.h b/src/plugins/intel_gpu/src/graph/include/to_string_utils.h index 3e57954f36d..9e80c5501fe 100644 --- a/src/plugins/intel_gpu/src/graph/include/to_string_utils.h +++ b/src/plugins/intel_gpu/src/graph/include/to_string_utils.h @@ -120,6 +120,7 @@ inline std::string onednn_post_op_type_to_str(onednn_post_op_type type) { case onednn_post_op_type::eltwise_round: return "eltwise_round"; case onednn_post_op_type::binary_mul: return "binary_mul"; case onednn_post_op_type::binary_add: return "binary_add"; + case onednn_post_op_type::binary_sub: return "binary_add"; case onednn_post_op_type::binary_max: return "binary_max"; case onednn_post_op_type::binary_min: return "binary_min"; case onednn_post_op_type::binary_relu: return "binary_relu"; diff --git a/src/plugins/intel_gpu/src/graph/program_node.cpp b/src/plugins/intel_gpu/src/graph/program_node.cpp index 640c520f540..c89737c929f 100644 --- a/src/plugins/intel_gpu/src/graph/program_node.cpp +++ b/src/plugins/intel_gpu/src/graph/program_node.cpp @@ -520,6 +520,7 @@ dnnl::post_ops program_node::try_optimize_post_ops(dnnl::post_ops& p_ops, const } case onednn_post_op_type::binary_add: + case onednn_post_op_type::binary_sub: case onednn_post_op_type::binary_mul: case onednn_post_op_type::binary_max: case onednn_post_op_type::binary_min: @@ -997,6 +998,7 @@ void program_node::init_onednn_primitive_attributes() { fused_ops.push_back(cur_op_desc); auto has_memory_buffers = type == onednn_post_op_type::binary_add || + type == onednn_post_op_type::binary_sub || type == onednn_post_op_type::binary_mul || type == onednn_post_op_type::binary_max || type == onednn_post_op_type::binary_min || @@ -1054,7 +1056,11 @@ void program_node::init_onednn_primitive_attributes() { post_ops.append_binary(dnnl::algorithm::binary_add, in_desc); update_onednn_post_op_list(onednn_post_op_type::binary_add, dep_idx); } - } else { + } else if (desc.typed_desc()->mode == eltwise_mode::sub) { + dnnl::memory::desc in_desc = onednn::layout_to_memory_desc(in); + post_ops.append_binary(dnnl::algorithm::binary_sub, in_desc); + update_onednn_post_op_list(onednn_post_op_type::binary_sub, dep_idx); + } else if (desc.typed_desc()->mode == eltwise_mode::prod) { auto input_datatype = get_dependency(0).get_output_layout().data_type; // convolution using post-op output scales can only be used when i8/u8 input (which use integer accumulator) bool cant_use_output_scales = @@ -1073,6 +1079,11 @@ void program_node::init_onednn_primitive_attributes() { attrs->set_output_scales(mask, {DNNL_RUNTIME_F32_VAL}); update_onednn_post_op_list(onednn_post_op_type::scale, dep_idx); } + } else { + std::stringstream error_msg; + error_msg << "Unsupported eltwise mode: " << static_cast(desc.typed_desc()->mode) << ". "; + error_msg << desc.desc->id << " is fused node of " + this->id() + "."; + OPENVINO_ASSERT(false, error_msg.str()); } } else if (desc.is_type()) { auto dep_idx = desc.dep_start_idx;