Broadcast for post ops enable enable onednngemm (#16074)

* [GPU] Add data broadcasting for OneDNN binary ops for Gemm primitive
* Based on https://github.com/openvinotoolkit/openvino/pull/15790 and enable onednn gemm from support multiple users and non constant input.

--------

Signed-off-by: hyunback <hyunback.kim@intel.com>
Co-authored-by: Sergey Shlyapnikov <sergey.shlyapnikov@intel.com>
This commit is contained in:
hyunback kim
2023-03-08 13:55:51 +09:00
committed by GitHub
parent 681faadce3
commit a9cbccd829
6 changed files with 102 additions and 8 deletions

View File

@@ -173,6 +173,11 @@ inline bool any_not_zero(const std::vector<T> vec) {
return std::any_of(vec.begin(), vec.end(), [](const T& val) { return val != 0; });
}
template <typename T>
inline bool one_of(const T& val, const std::vector<T>& vec) {
return std::any_of(vec.begin(), vec.end(), [&val](const T& v) { return v == val; });
}
// Helpers to get string for types that have operator<< defined
template <typename T>
inline std::string to_string(const T& v) {

View File

@@ -7,6 +7,7 @@
#include "layout_optimizer.h"
#include "intel_gpu/graph/program.hpp"
#include "intel_gpu/runtime/debug_configuration.hpp"
#include "intel_gpu/runtime/utils.hpp"
#include "program_helpers.h"
#include "binary_convolution_inst.h"
#include "mvn_inst.h"
@@ -14,6 +15,12 @@
#include "pooling_inst.h"
#include "reshape_inst.h"
#ifdef ENABLE_ONEDNN_FOR_GPU
#include "gemm_inst.h"
#include "broadcast_inst.h"
#include <impls/onednn/utils.hpp>
#endif
#include <vector>
#include <memory>
#include <list>
@@ -958,4 +965,44 @@ void reorder_inputs::run(program& p, layout_optimizer& lo, reorder_factory& rf)
}
}
}
// WA for OneDNN binary add fusions: we need to broadcast batch dimension to avoid situation with
// batch dimension mismatch in OneDNN tensor descriptors as follow:
// * Gemm output shape: (b,f,y,x) -> OneDNN shape: (b*f,y,x)
// * Gemm fused op shape: (1,f,y,x) -> OneDNN shape: (1*f,y,x)
// If batch dimension of gemm output is not equal to 1, then OneDNN will not be able to broadcast fused op data
// correctly and we need to do it manually
#ifdef ENABLE_ONEDNN_FOR_GPU
for (auto& node : p.get_processing_order()) {
if (node->is_type<gemm>() && node->get_preferred_impl_type() == impl_types::onednn) {
for (const auto& fused_prim : node->get_fused_primitives()) {
if (fused_prim.is_type<eltwise>() &&
one_of(fused_prim.typed_desc<eltwise>()->mode, {eltwise_mode::sum, eltwise_mode::sub, eltwise_mode::prod})) {
auto& data = node->get_dependency(fused_prim.dep_start_idx);
auto gemm_layout = node->get_output_layout();
auto gemm_dims = onednn::convert_gemm_tensor(gemm_layout.get_tensor(),
cldnn::format::dimension(gemm_layout.format),
false);
auto data_layout = data.get_output_layout();
auto data_dims = onednn::convert_gemm_tensor(data_layout.get_tensor(),
cldnn::format::dimension(data_layout.format),
false);
if (gemm_dims[0] == data_dims[0])
continue;
static size_t idx = 0;
const auto prim_id = "broadcast:" + data.id() + "_broadcasted" + std::to_string(idx++);
auto broadcast_prim = std::make_shared<cldnn::broadcast>(prim_id, cldnn::input_info(data.id()), gemm_layout.get_shape(), ov::AxisSet{});
auto& broadcast_node = p.get_or_create(broadcast_prim);
p.add_intermediate(broadcast_node, *node, fused_prim.dep_start_idx, true);
broadcast_node.recalc_output_layouts(false);
}
}
}
}
#endif
}

View File

@@ -1531,10 +1531,6 @@ impl_types layout_optimizer::get_preferred_impl_type(program_node& node, format
if (node.is_type<fully_connected>()) {
if (!is_node_for_onednn(node.as<fully_connected>()))
impl_candidate = impl_types::ocl;
} else {
if (node.is_dynamic()) {
impl_candidate = impl_types::ocl;
}
}
preferred_impl = impl_candidate;

View File

@@ -97,9 +97,9 @@ add_fusing_type onednn_add_fusing_helpers::get_add_fusing_type(
if (!desc.is_type<eltwise>()) {
return add_fusing_type::not_supported;
}
if (desc.typed_desc<eltwise>()->mode != eltwise_mode::sum) {
return add_fusing_type::not_supported;
}
if (desc.typed_desc<eltwise>()->mode != eltwise_mode::sum) {
return add_fusing_type::not_supported;
}
auto& dep_node = p_node.get_dependency(desc.dep_start_idx);
auto p_layout = p_node.get_output_layout();

View File

@@ -971,7 +971,7 @@ void program_node::init_onednn_primitive_attributes() {
update_onednn_post_op_list(op_type, dep_idx);
} else if (is_type<gemm>()) {
size_t rank = cldnn::format::dimension(in.format);
dnnl::memory::dims dims = onednn::convert_gemm_tensor(in.get_tensor(), rank, in.batch() > 1);
dnnl::memory::dims dims = onednn::convert_gemm_tensor(in.get_tensor(), rank, in.batch() == 1);
dnnl::memory::data_type dt = onednn::convert_data_type(in.data_type);
dnnl::memory::format_tag fmt = onednn::convert_gemm_data_format(dims);
post_ops.append_binary(alg, dnnl::memory::desc(dims, dt, fmt));

View File

@@ -10,10 +10,12 @@
#include <intel_gpu/primitives/eltwise.hpp>
#include <intel_gpu/primitives/gemm.hpp>
#include <intel_gpu/primitives/data.hpp>
#include <intel_gpu/runtime/tensor.hpp>
#include <cmath>
using namespace cldnn;
using namespace ::details;
using namespace ::tests;
namespace {
@@ -31,6 +33,8 @@ struct gemm_test_params {
size_t expected_fused_primitives;
size_t expected_not_fused_primitives;
std::string kernel_name;
dim_vec_kind broadcast_kind;
eltwise_mode eltwise_m;
};
class GemmFusingTest : public ::BaseFusingTest<gemm_test_params> {
@@ -108,6 +112,7 @@ public:
#define CASE_GEMM_2IN_FP16_2 { { 1, 1, 31, 31 }, { 1, 1, 31, 31 } }, { 1, 1, 31, 31 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_GEMM_2IN_FP16_3 { { 1, 1, 64, 64 }, { 1, 1, 64, 64 } }, { 1, 1, 64, 64 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_GEMM_2IN_FP16_4 { { 1, 2, 64, 128 }, { 1, 2, 256, 64 } }, { 1, 2, 256, 128 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_GEMM_2IN_FP16_5 { { 2, 3, 2, 2 }, { 2, 3, 2, 2 } }, { 2, 3, 2, 2 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_GEMM_2IN_U8U8_1 { { 1, 1, 2, 2 }, { 1, 1, 2, 2 } }, { 1, 1, 2, 2 }, tensor{ 1 }, tensor{ 0 }, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_2IN_U8U8_2 { { 1, 2, 64, 128 }, { 1, 2, 256, 64 } }, { 1, 2, 256, 128 }, tensor{ 1 }, tensor{ 0 }, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_2IN_U8U8_3 { { 1, 1, 16, 32 }, { 1, 1, 32, 16 } }, { 1, 1, 32, 32 }, tensor{ 1 }, tensor{ 0 }, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
@@ -275,6 +280,47 @@ INSTANTIATE_TEST_SUITE_P(fusings_gpu, gemm_2in_scale, ::testing::ValuesIn(std::v
gemm_test_params{ CASE_GEMM_2IN_U8U8_3, 3, 4 },
}));
class gemm_2in_add : public GemmFusingTest {};
TEST_P(gemm_2in_add, eltwise_postop) {
auto p = GetParam();
if (engine.get_device_info().supports_immad) {
ov::intel_gpu::ImplementationDesc gemmv_impl = { cldnn::format::type::any, "", impl_types::onednn };
cfg_fused.set_property(ov::intel_gpu::force_implementations(ov::intel_gpu::ImplForcingMap{ { "gemm_prim", gemmv_impl } }));
cfg_fused.set_property(ov::intel_gpu::queue_type(QueueTypes::in_order));
}
auto add_data_layout = get_output_layout(p);
auto add_data_size = add_data_layout.get_tensor();
if (p.broadcast_kind == dim_vec_kind::batch)
add_data_size.batch[0] = 1;
else
add_data_size.feature[0] = 1;
add_data_layout.set_tensor(add_data_size);
create_topologies(
input_layout("input0", get_input_layout(p, 0)),
input_layout("input1", get_input_layout(p, 1)),
data("add_data", get_mem(add_data_layout, 1.0f/p.kernel.count())),
gemm("gemm_prim", { input_info("input0"), input_info("input1") }, data_types::f32),
eltwise("add_prim", { input_info("gemm_prim"), input_info("add_data") }, p.eltwise_m, p.default_type),
reorder("reorder_bfyx", input_info("add_prim"), p.default_format, data_types::f32)
);
tolerance = default_tolerance(p.default_type);
execute(p);
}
INSTANTIATE_TEST_SUITE_P(fusings_gpu, gemm_2in_add, ::testing::ValuesIn(std::vector<gemm_test_params>{
gemm_test_params{ CASE_GEMM_2IN_FP16_5, 3, 4, "", dim_vec_kind::batch, eltwise_mode::sum },
gemm_test_params{ CASE_GEMM_2IN_FP16_5, 3, 4, "", dim_vec_kind::batch, eltwise_mode::prod },
gemm_test_params{ CASE_GEMM_2IN_FP16_5, 3, 4, "", dim_vec_kind::batch, eltwise_mode::sub },
gemm_test_params{ CASE_GEMM_2IN_FP16_5, 3, 4, "", dim_vec_kind::feature, eltwise_mode::sum },
gemm_test_params{ CASE_GEMM_2IN_FP16_5, 3, 4, "", dim_vec_kind::feature, eltwise_mode::prod },
gemm_test_params{ CASE_GEMM_2IN_FP16_5, 3, 4, "", dim_vec_kind::feature, eltwise_mode::sub },
}));
class gemm_2in_act_scale_quantize_i8 : public GemmFusingTest {};
TEST_P(gemm_2in_act_scale_quantize_i8, basic) {
// TODO: Fix me, refer PR(#15873)