[GPU] Handle unsupported eltwise fusion for onednn gemm in dynamic cases (#16875)

* [GPU] Handle unsupported eltwise fusion for onednn gemm in dynamic cases

* Update src/plugins/intel_gpu/tests/fusions/gemm_fusion_test.cpp

Co-authored-by: Sergey Shlyapnikov <Sergeishlyapnikov@gmail.com>

---------

Co-authored-by: Sergey Shlyapnikov <Sergeishlyapnikov@gmail.com>
This commit is contained in:
Vladimir Paramuzov
2023-04-13 15:55:44 +04:00
committed by GitHub
parent 656428bc4f
commit 5299f26168
3 changed files with 139 additions and 73 deletions

View File

@@ -995,11 +995,15 @@ void reorder_inputs::run(program& p, layout_optimizer& lo, reorder_factory& rf)
auto& data = node->get_dependency(fused_prim.dep_start_idx);
auto gemm_layout = node->get_output_layout();
auto data_layout = data.get_output_layout();
if (gemm_layout.is_dynamic() || data_layout.is_dynamic())
continue;
auto gemm_dims = onednn::convert_gemm_tensor(gemm_layout.get_tensor(),
cldnn::format::dimension(gemm_layout.format),
false);
auto data_layout = data.get_output_layout();
auto data_dims = onednn::convert_gemm_tensor(data_layout.get_tensor(),
cldnn::format::dimension(data_layout.format),
false);

View File

@@ -32,6 +32,10 @@
#include <memory>
#include <algorithm>
#ifdef ENABLE_ONEDNN_FOR_GPU
#include <impls/onednn/utils.hpp>
#endif
namespace cldnn {
namespace {
@@ -1065,6 +1069,29 @@ bool primitive_inst::is_valid_fusion() const {
auto merged_shape = out_pshape;
auto can_broadcast = ov::PartialShape::broadcast_merge_into(merged_shape, dep_pshape, fd.typed_desc<eltwise>()->broadcast_spec);
#ifdef ENABLE_ONEDNN_FOR_GPU
// WA for OneDNN binary add fusions: we need to broadcast batch dimension to avoid situation with
// batch dimension mismatch in OneDNN tensor descriptors as follows:
// * Gemm output shape: (b,f,y,x) -> OneDNN shape: (b*f,y,x)
// * Gemm fused op shape: (1,f,y,x) -> OneDNN shape: (1*f,y,x)
// If batch dimension of gemm output is not equal to 1, then OneDNN will not be able to broadcast fused op data
// correctly and we need to do it manually
if (_node->is_type<gemm>() && _node->get_preferred_impl_type() == impl_types::onednn) {
auto gemm_layout = _impl_params->get_output_layout();
auto data_layout = dep.first->_impl_params->get_output_layout();
auto gemm_dims = onednn::convert_gemm_tensor(gemm_layout.get_tensor(),
cldnn::format::dimension(gemm_layout.format),
false);
auto data_dims = onednn::convert_gemm_tensor(data_layout.get_tensor(),
cldnn::format::dimension(data_layout.format),
false);
if (gemm_dims[0] != data_dims[0])
return false;
}
#endif
// We check that broadcasting of the extra input is possible and that it doesn't change the output shape. If the output shape is changed, then
// some dimension of dep_pshape is greater than out_pshape
if (!can_broadcast || merged_shape != out_pshape)

View File

@@ -20,10 +20,8 @@ using namespace ::tests;
namespace {
struct gemm_test_params {
std::vector<tensor> in_shapes;
tensor out_shape;
tensor kernel;
tensor pad;
std::vector<ov::PartialShape> in_shapes;
ov::PartialShape out_shape;
data_types data_type_in0;
data_types data_type_in1;
data_types data_type_in2;
@@ -40,7 +38,10 @@ struct gemm_test_params {
class GemmFusingTest : public ::BaseFusingTest<gemm_test_params> {
public:
void execute(gemm_test_params& p, bool is_caching_test = false) {
void execute(gemm_test_params& p, bool is_dynamic, bool is_caching_test = false) {
cfg_not_fused.set_property(ov::intel_gpu::allow_new_shape_infer(is_dynamic));
cfg_fused.set_property(ov::intel_gpu::allow_new_shape_infer(is_dynamic));
auto input0_prim = get_mem(get_input_layout(p, 0));
auto input1_prim = get_mem(get_input_layout(p, 1));
@@ -67,22 +68,20 @@ public:
}
layout get_input_layout(gemm_test_params& p, int in_no) {
auto pad = p.pad;
std::vector<int> pad_ = { 0, 0, pad.spatial[0], pad.spatial[1] };
if (in_no == 0)
return layout{ p.data_type_in0, p.input_format, p.in_shapes.at(0), padding{ pad_ } };
return layout{ p.in_shapes.at(0), p.data_type_in0, p.input_format };
else if (in_no == 1)
return layout{ p.data_type_in1, p.input_format, p.in_shapes.at(1), padding{ pad_ } };
return layout{ p.in_shapes.at(1), p.data_type_in1, p.input_format };
else
return layout{ p.data_type_in2, p.input_format, p.in_shapes.at(2), padding{ pad_ } };
return layout{ p.in_shapes.at(2), p.data_type_in2, p.input_format };
}
layout get_per_channel_layout(gemm_test_params& p) {
return layout{ p.default_type, p.default_format, tensor{ 1, p.in_shapes.at(0).feature[0], 1, 1 } };
return layout{ov::PartialShape{ 1, p.in_shapes[0][1], 1, 1 }, p.default_type, p.default_format };
}
layout get_output_layout(gemm_test_params& p) {
return layout{ p.default_type, p.input_format, p.out_shape };
return layout{ p.out_shape, p.default_type, p.input_format };
}
};
@@ -92,43 +91,43 @@ public:
/* ---------------------------------------- Gemm cases ------------------------------------------------- */
/* ----------------------------------------------------------------------------------------------------- */
#define CASE_GEMM_3IN_FP32_1 { { 1, 1, 2, 2 }, { 1, 1, 2, 2 }, { 1, 1, 2, 2 } }, { 1, 1, 2, 2 }, tensor{ 1 }, tensor{ 0 }, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_3IN_FP32_2 { { 1, 1, 63, 63 }, { 1, 1, 63, 63 }, { 1, 1, 63, 63 } }, { 1, 1, 63, 63 }, tensor{ 1 }, tensor{ 0 }, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_3IN_FP32_3 { { 1, 1, 128, 128 }, { 1, 1, 128, 128 }, { 1, 1, 128, 128 } }, { 1, 1, 128, 128 }, tensor{ 1 }, tensor{ 0 }, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_3IN_FP32_4 { { 1, 2, 64, 128 }, { 1, 2, 256, 64 }, { 1, 2, 256, 128 } }, { 1, 2, 256, 128 }, tensor{ 1 }, tensor{ 0 }, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_3IN_FP16_1 { { 1, 1, 2, 2 }, { 1, 1, 2, 2 }, { 1, 1, 2, 2 } }, { 1, 1, 2, 2 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_GEMM_3IN_FP16_2 { { 1, 1, 31, 31 }, { 1, 1, 31, 31 }, { 1, 1, 31, 31 } }, { 1, 1, 31, 31 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_GEMM_3IN_FP16_3 { { 1, 1, 64, 64 }, { 1, 1, 64, 64 }, { 1, 1, 64, 64 } }, { 1, 1, 64, 64 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_GEMM_3IN_FP16_4 { { 1, 2, 64, 128 }, { 1, 2, 256, 64 }, { 1, 2, 256, 128 } }, { 1, 2, 256, 128 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_GEMM_3IN_S8S8_1 { { 1, 1, 2, 2 }, { 1, 1, 2, 2 }, { 1, 1, 2, 2 } }, { 1, 1, 2, 2 }, tensor{ 1 }, tensor{ 0 }, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_3IN_S8S8_2 { { 1, 2, 64, 128 }, { 1, 2, 256, 64 }, { 1, 2, 256, 128 } }, { 1, 2, 256, 128 }, tensor{ 1 }, tensor{ 0 }, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_3IN_S8S8_3 { { 1, 1, 8, 16 }, { 1, 1, 32, 8 }, { 1, 1, 32, 16 } }, { 1, 1, 32, 16 }, tensor{ 1 }, tensor{ 0 }, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_3IN_FP32_1 { { 1, 1, 2, 2 }, { 1, 1, 2, 2 }, { 1, 1, 2, 2 } }, { 1, 1, 2, 2 }, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_3IN_FP32_2 { { 1, 1, 63, 63 }, { 1, 1, 63, 63 }, { 1, 1, 63, 63 } }, { 1, 1, 63, 63 }, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_3IN_FP32_3 { { 1, 1, 128, 128 }, { 1, 1, 128, 128 }, { 1, 1, 128, 128 } }, { 1, 1, 128, 128 }, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_3IN_FP32_4 { { 1, 2, 128, 64 }, { 1, 2, 64, 256 }, { 1, 2, 128, 256 } }, { 1, 2, 128, 256 }, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_3IN_FP16_1 { { 1, 1, 2, 2 }, { 1, 1, 2, 2 }, { 1, 1, 2, 2 } }, { 1, 1, 2, 2 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_GEMM_3IN_FP16_2 { { 1, 1, 31, 31 }, { 1, 1, 31, 31 }, { 1, 1, 31, 31 } }, { 1, 1, 31, 31 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_GEMM_3IN_FP16_3 { { 1, 1, 64, 64 }, { 1, 1, 64, 64 }, { 1, 1, 64, 64 } }, { 1, 1, 64, 64 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_GEMM_3IN_FP16_4 { { 1, 2, 128, 64 }, { 1, 2, 64, 256 }, { 1, 2, 128, 256 } }, { 1, 2, 128, 256 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_GEMM_3IN_S8S8_1 { { 1, 1, 2, 2 }, { 1, 1, 2, 2 }, { 1, 1, 2, 2 } }, { 1, 1, 2, 2 }, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_3IN_S8S8_2 { { 1, 2, 128, 64 }, { 1, 2, 64, 256 }, { 1, 2, 128, 256 } }, { 1, 2, 128, 256 }, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_3IN_S8S8_3 { { 1, 1, 16, 8 }, { 1, 1, 8, 32 }, { 1, 1, 16, 32 } }, { 1, 1, 16, 32 }, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_2IN_FP32_1 { { 1, 1, 2, 2 }, { 1, 1, 2, 2 } }, { 1, 1, 2, 2 }, tensor{ 1 }, tensor{ 0 }, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_2IN_FP32_2 { { 1, 1, 63, 63 }, { 1, 1, 63, 63 } }, { 1, 1, 63, 63 }, tensor{ 1 }, tensor{ 0 }, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_2IN_FP32_3 { { 1, 1, 128, 128 }, { 1, 1, 128, 128 } }, { 1, 1, 128, 128 }, tensor{ 1 }, tensor{ 0 }, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_2IN_FP32_4 { { 1, 2, 64, 128 }, { 1, 2, 256, 64 } }, { 1, 2, 256, 128 }, tensor{ 1 }, tensor{ 0 }, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_2IN_FP16_1 { { 1, 1, 2, 2 }, { 1, 1, 2, 2 } }, { 1, 1, 2, 2 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_GEMM_2IN_FP16_2 { { 1, 1, 31, 31 }, { 1, 1, 31, 31 } }, { 1, 1, 31, 31 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_GEMM_2IN_FP16_3 { { 1, 1, 64, 64 }, { 1, 1, 64, 64 } }, { 1, 1, 64, 64 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_GEMM_2IN_FP16_4 { { 1, 2, 64, 128 }, { 1, 2, 256, 64 } }, { 1, 2, 256, 128 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_GEMM_2IN_FP16_5 { { 2, 3, 2, 2 }, { 2, 3, 2, 2 } }, { 2, 3, 2, 2 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_GEMM_2IN_FP16_5D_1 { { 2, 3, 4, 6, 5 }, { 2, 3, 6, 4, 5 } }, { 2, 3, 6, 6, 5 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
#define CASE_GEMM_2IN_FP16_6D_1 { { 2, 3, 7, 5, 3, 2 }, { 2, 3, 5, 7, 3, 2 } }, { 2, 3, 5, 5, 3, 2 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfwzyx, data_types::f16, format::bfwzyx
#define CASE_GEMM_2IN_FP32_1 { { 1, 1, 2, 2 }, { 1, 1, 2, 2 } }, { 1, 1, 2, 2 }, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_2IN_FP32_2 { { 1, 1, 63, 63 }, { 1, 1, 63, 63 } }, { 1, 1, 63, 63 }, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_2IN_FP32_3 { { 1, 1, 128, 128 }, { 1, 1, 128, 128 } }, { 1, 1, 128, 128 }, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_2IN_FP32_4 { { 1, 2, 128, 64 }, { 1, 2, 64, 256 } }, { 1, 2, 128, 256 }, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_2IN_FP16_1 { { 1, 1, 2, 2 }, { 1, 1, 2, 2 } }, { 1, 1, 2, 2 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_GEMM_2IN_FP16_2 { { 1, 1, 31, 31 }, { 1, 1, 31, 31 } }, { 1, 1, 31, 31 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_GEMM_2IN_FP16_3 { { 1, 1, 64, 64 }, { 1, 1, 64, 64 } }, { 1, 1, 64, 64 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_GEMM_2IN_FP16_4 { { 1, 2, 128, 64 }, { 1, 2, 64, 256 } }, { 1, 2, 128, 256 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_GEMM_2IN_FP16_5 { { 2, 3, 2, 2 }, { 2, 3, 2, 2 } }, { 2, 3, 2, 2 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_GEMM_2IN_FP16_5D_1 { { 2, 3, 5, 6, 4 }, { 2, 3, 5, 4, 6} }, { 2, 3, 5, 6, 6 }, data_types::f16, data_types::f16, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
#define CASE_GEMM_2IN_FP16_6D_1 { { 2, 3, 2, 3, 5, 7 }, { 2, 3, 2, 3, 7, 5 } }, { 2, 3, 2, 3, 5, 5 }, data_types::f16, data_types::f16, data_types::f16, format::bfwzyx, data_types::f16, format::bfwzyx
#define CASE_GEMM_2IN_U8U8_1 { { 1, 1, 2, 2 }, { 1, 1, 2, 2 } }, { 1, 1, 2, 2 }, tensor{ 1 }, tensor{ 0 }, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_2IN_U8U8_2 { { 1, 2, 64, 128 }, { 1, 2, 256, 64 } }, { 1, 2, 256, 128 }, tensor{ 1 }, tensor{ 0 }, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_2IN_U8U8_3 { { 1, 1, 16, 32 }, { 1, 1, 32, 16 } }, { 1, 1, 32, 32 }, tensor{ 1 }, tensor{ 0 }, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_2IN_U8U8_1 { { 1, 1, 2, 2 }, { 1, 1, 2, 2 } }, { 1, 1, 2, 2 }, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_2IN_U8U8_2 { { 1, 2, 128, 64 }, { 1, 2, 64, 256 } }, { 1, 2, 128, 256 }, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_2IN_U8U8_3 { { 1, 1, 16, 32 }, { 1, 1, 32, 16 } }, { 1, 1, 32, 32 }, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_2IN_U8S8_1 { { 1, 1, 4, 2 }, { 1, 1, 8, 4 } }, { 1, 1, 8, 4 }, tensor{ 1 }, tensor{ 0 }, data_types::u8, data_types::i8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_2IN_S8U8_1 { { 1, 2, 64, 128 }, { 1, 2, 256, 64 } }, { 1, 2, 256, 128 }, tensor{ 1 }, tensor{ 0 }, data_types::i8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_2IN_U8S8_1 { { 1, 1, 2, 4 }, { 1, 1, 4, 8 } }, { 1, 1, 2, 8 }, data_types::u8, data_types::i8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_2IN_S8U8_1 { { 1, 2, 128, 64 }, { 1, 2, 64, 256 } }, { 1, 2, 128, 256 }, data_types::i8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_ELTWISE_2IN_FP32_1 { { 1, 1, 4, 4 }, { 1, 1, 4, 4 } }, { 1, 1, 4, 4 }, tensor{ 1 }, tensor{ 0 }, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_ELTWISE_2IN_FP16_1 { { 1, 1, 32, 32 }, { 1, 1, 32, 32 } }, { 1, 1, 32, 32 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_GEMM_ELTWISE_2IN_FP16_2 { { 1, 1, 1024, 1024 }, { 1, 1, 1024, 1024 } }, { 1, 1, 1024, 1024 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_GEMM_ELTWISE_2IN_U8S8_1 { { 1, 1, 4, 4 }, { 1, 1, 4, 4 } }, { 1, 1, 4, 4 }, tensor{ 1 }, tensor{ 0 }, data_types::u8, data_types::i8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_ELTWISE_2IN_S8U8_1 { { 1, 1, 32, 32 }, { 1, 1, 32, 32 } }, { 1, 1, 32, 32 }, tensor{ 1 }, tensor{ 0 }, data_types::i8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_ELTWISE_2IN_U8S8_2 { { 1, 1, 1024, 1024 }, { 1, 1, 1024, 1024 } }, { 1, 1, 1024, 1024 }, tensor{ 1 }, tensor{ 0 }, data_types::u8, data_types::i8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_ELTWISE_2IN_FP32_1 { { 1, 1, 4, 4 }, { 1, 1, 4, 4 } }, { 1, 1, 4, 4 }, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_ELTWISE_2IN_FP16_1 { { 1, 1, 32, 32 }, { 1, 1, 32, 32 } }, { 1, 1, 32, 32 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_GEMM_ELTWISE_2IN_FP16_2 { { 1, 1, 1024, 1024 }, { 1, 1, 1024, 1024 } }, { 1, 1, 1024, 1024 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_GEMM_ELTWISE_2IN_U8S8_1 { { 1, 1, 4, 4 }, { 1, 1, 4, 4 } }, { 1, 1, 4, 4 }, data_types::u8, data_types::i8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_ELTWISE_2IN_S8U8_1 { { 1, 1, 32, 32 }, { 1, 1, 32, 32 } }, { 1, 1, 32, 32 }, data_types::i8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_ELTWISE_2IN_U8S8_2 { { 1, 1, 1024, 1024 }, { 1, 1, 1024, 1024 } }, { 1, 1, 1024, 1024 }, data_types::u8, data_types::i8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
class gemm_3in_quantize_i8 : public GemmFusingTest {};
TEST_P(gemm_3in_quantize_i8, basic) {
@@ -151,7 +150,7 @@ TEST_P(gemm_3in_quantize_i8, basic) {
);
tolerance = default_tolerance(data_types::i8);
execute(p);
execute(p, false);
}
INSTANTIATE_TEST_SUITE_P(fusings_gpu, gemm_3in_quantize_i8, ::testing::ValuesIn(std::vector<gemm_test_params>{
@@ -185,7 +184,7 @@ TEST_P(gemm_2in_quantize_u8, basic) {
);
tolerance = default_tolerance(data_types::u8);
execute(p);
execute(p, false);
}
INSTANTIATE_TEST_SUITE_P(fusings_gpu, gemm_2in_quantize_u8, ::testing::ValuesIn(std::vector<gemm_test_params>{
@@ -222,7 +221,7 @@ TEST_P(gemm_2in_quantize_float_in, basic) {
cfg_fused.set_property(ov::intel_gpu::force_implementations(ov::intel_gpu::ImplForcingMap{ { "gemm_prim", gemm_impl } }));
tolerance = default_tolerance(data_types::u8);
execute(p);
execute(p, false);
}
INSTANTIATE_TEST_SUITE_P(fusings_gpu, gemm_2in_quantize_float_in, ::testing::ValuesIn(std::vector<gemm_test_params>{
@@ -244,14 +243,14 @@ TEST_P(gemm_2in_scale, basic) {
create_topologies(
input_layout("input0", get_input_layout(p, 0)),
input_layout("input1", get_input_layout(p, 1)),
data("scale_data", get_mem(get_per_channel_layout(p), 1.0f/p.kernel.count())),
data("scale_data", get_mem(get_per_channel_layout(p), 0.5f)),
gemm("gemm_prim", { input_info("input0"), input_info("input1") }, data_types::f32),
eltwise("scale", { input_info("gemm_prim"), input_info("scale_data") }, eltwise_mode::prod, p.default_type),
reorder("reorder_bfyx", input_info("scale"), p.default_format, data_types::f32)
);
tolerance = default_tolerance(p.default_type);
execute(p);
execute(p, false);
}
TEST_P(gemm_2in_scale, fp16_scale_out) {
@@ -259,14 +258,14 @@ TEST_P(gemm_2in_scale, fp16_scale_out) {
create_topologies(
input_layout("input0", get_input_layout(p, 0)),
input_layout("input1", get_input_layout(p, 1)),
data("scale_data", get_mem(get_per_channel_layout(p), 1.0f/p.kernel.count())),
data("scale_data", get_mem(get_per_channel_layout(p), 0.5f)),
gemm("gemm_prim", { input_info("input0"), input_info("input1") }, data_types::f32),
eltwise("scale", { input_info("gemm_prim"), input_info("scale_data") }, eltwise_mode::prod, data_types::f16),
reorder("reorder_bfyx", input_info("scale"), p.default_format, data_types::f32)
);
tolerance = default_tolerance(p.default_type);
execute(p);
execute(p, false);
}
INSTANTIATE_TEST_SUITE_P(fusings_gpu, gemm_2in_scale, ::testing::ValuesIn(std::vector<gemm_test_params>{
@@ -294,12 +293,12 @@ TEST_P(gemm_2in_add, eltwise_postop) {
}
auto add_data_layout = get_output_layout(p);
auto add_data_size = add_data_layout.get_tensor();
auto add_data_size = add_data_layout.get_partial_shape();
if (p.broadcast_kind == dim_vec_kind::batch)
add_data_size.batch[0] = 1;
add_data_size[0] = 1;
else
add_data_size.feature[0] = 1;
add_data_layout.set_tensor(add_data_size);
add_data_size[1] = 1;
add_data_layout.set_partial_shape(add_data_size);
auto in_layout0 = get_input_layout(p, 0);
auto in_layout1 = get_input_layout(p, 1);
@@ -307,14 +306,50 @@ TEST_P(gemm_2in_add, eltwise_postop) {
create_topologies(
input_layout("input0", in_layout0),
input_layout("input1", in_layout1),
data("add_data", get_mem(add_data_layout, 1.0f/p.kernel.count())),
data("add_data", get_mem(add_data_layout, 0.5f)),
gemm("gemm_prim", { input_info("input0"), input_info("input1") }, data_types::f32, false, false, 1.f, 0.f, in_layout0.get_rank(), in_layout1.get_rank()),
eltwise("add_prim", { input_info("gemm_prim"), input_info("add_data") }, p.eltwise_m, p.default_type),
reorder("reorder_bfyx", input_info("add_prim"), p.default_format, data_types::f32)
);
tolerance = default_tolerance(p.default_type);
execute(p);
execute(p, false);
}
TEST_P(gemm_2in_add, eltwise_postop_dynamic) {
auto p = GetParam();
if (engine.get_device_info().supports_immad) {
ov::intel_gpu::ImplementationDesc gemmv_impl = { cldnn::format::type::any, "", impl_types::onednn };
cfg_fused.set_property(ov::intel_gpu::force_implementations(ov::intel_gpu::ImplForcingMap{ { "gemm_prim", gemmv_impl } }));
cfg_fused.set_property(ov::intel_gpu::use_only_static_kernels_for_dynamic_shape(true));
}
auto add_data_layout = get_output_layout(p);
auto add_data_size = add_data_layout.get_partial_shape();
if (p.broadcast_kind == dim_vec_kind::batch)
add_data_size[0] = 1;
else
add_data_size[1] = 1;
add_data_layout.set_partial_shape(add_data_size);
auto in_layout0 = get_input_layout(p, 0);
auto in_layout1 = get_input_layout(p, 1);
in_layout0.set_partial_shape(ov::PartialShape::dynamic(p.in_shapes[0].size()));
in_layout1.set_partial_shape(ov::PartialShape::dynamic(p.in_shapes[1].size()));
create_topologies(
input_layout("input0", in_layout0),
input_layout("input1", in_layout1),
data("add_data", get_mem(add_data_layout, 0.5f)),
gemm("gemm_prim", { input_info("input0"), input_info("input1") }, data_types::f32, false, false, 1.f, 0.f, in_layout0.get_rank(), in_layout1.get_rank()),
eltwise("add_prim", { input_info("gemm_prim"), input_info("add_data") }, p.eltwise_m, p.default_type),
reorder("reorder_bfyx", input_info("add_prim"), p.default_format, data_types::f32)
);
tolerance = default_tolerance(p.default_type);
execute(p, true);
}
TEST_P(gemm_2in_add, eltwise_postop_cached) {
@@ -326,12 +361,12 @@ TEST_P(gemm_2in_add, eltwise_postop_cached) {
}
auto add_data_layout = get_output_layout(p);
auto add_data_size = add_data_layout.get_tensor();
auto add_data_size = add_data_layout.get_partial_shape();
if (p.broadcast_kind == dim_vec_kind::batch)
add_data_size.batch[0] = 1;
add_data_size[0] = 1;
else
add_data_size.feature[0] = 1;
add_data_layout.set_tensor(add_data_size);
add_data_size[1] = 1;
add_data_layout.set_partial_shape(add_data_size);
auto in_layout0 = get_input_layout(p, 0);
auto in_layout1 = get_input_layout(p, 1);
@@ -339,14 +374,14 @@ TEST_P(gemm_2in_add, eltwise_postop_cached) {
create_topologies(
input_layout("input0", in_layout0),
input_layout("input1", in_layout1),
data("add_data", get_mem(add_data_layout, 1.0f/p.kernel.count())),
data("add_data", get_mem(add_data_layout, 0.5f)),
gemm("gemm_prim", { input_info("input0"), input_info("input1") }, data_types::f32, false, false, 1.f, 0.f, in_layout0.get_rank(), in_layout1.get_rank()),
eltwise("add_prim", { input_info("gemm_prim"), input_info("add_data") }, p.eltwise_m, p.default_type),
reorder("reorder_bfyx", input_info("add_prim"), p.default_format, data_types::f32)
);
tolerance = default_tolerance(p.default_type);
execute(p, true);
execute(p, false, true);
}
INSTANTIATE_TEST_SUITE_P(fusings_gpu, gemm_2in_add, ::testing::ValuesIn(std::vector<gemm_test_params>{
@@ -377,7 +412,7 @@ TEST_P(gemm_2in_act_scale_quantize_i8, basic) {
data("in_hi", get_mem(get_per_channel_layout(p), 1, max_random)),
data("out_lo", get_mem(get_single_element_layout(p), -127)),
data("out_hi", get_mem(get_single_element_layout(p), 127)),
data("scale_data", get_mem(get_per_channel_layout(p), 1.0f / p.kernel.count() / 255)),
data("scale_data", get_mem(get_per_channel_layout(p), 1.0f / 255.f)),
gemm("gemm_prim", { input_info("input0"), input_info("input1") }, data_types::f32),
activation("activation", input_info("gemm_prim"), activation_func::exp),
eltwise("scale", { input_info("activation"), input_info("scale_data") }, eltwise_mode::prod, p.default_type),
@@ -387,7 +422,7 @@ TEST_P(gemm_2in_act_scale_quantize_i8, basic) {
);
tolerance = default_tolerance(data_types::i8);
execute(p);
execute(p, false);
}
INSTANTIATE_TEST_SUITE_P(fusings_gpu, gemm_2in_act_scale_quantize_i8, ::testing::ValuesIn(std::vector<gemm_test_params>{
@@ -413,7 +448,7 @@ TEST_P(gemm_2in_act_scale_quantize_eltwise_i8, basic) {
data("in_hi", get_mem(get_per_channel_layout(p), 1, max_random)),
data("out_lo", get_mem(get_single_element_layout(p), -127)),
data("out_hi", get_mem(get_single_element_layout(p), 127)),
data("scale_data", get_mem(get_per_channel_layout(p), 1.0f / p.kernel.count() / 255)),
data("scale_data", get_mem(get_per_channel_layout(p), 1.0f / 255.f)),
data("eltwise_data", get_mem(get_output_layout(p))),
gemm("gemm_prim", { input_info("input0"), input_info("input1") }, data_types::f32),
activation("activation", input_info("gemm_prim"), activation_func::exp),
@@ -425,7 +460,7 @@ TEST_P(gemm_2in_act_scale_quantize_eltwise_i8, basic) {
);
tolerance = default_tolerance(data_types::i8);
execute(p);
execute(p, false);
}
INSTANTIATE_TEST_SUITE_P(fusings_gpu, gemm_2in_act_scale_quantize_eltwise_i8, ::testing::ValuesIn(std::vector<gemm_test_params>{
@@ -441,7 +476,7 @@ TEST_P(gemm_2in_act_scale_eltwise, basic) {
create_topologies(
input_layout("input0", get_input_layout(p, 0)),
input_layout("input1", get_input_layout(p, 1)),
data("scale_data", get_mem(get_per_channel_layout(p), 1.0f / p.kernel.count() / 255)),
data("scale_data", get_mem(get_per_channel_layout(p), 1.0f / 255.f)),
data("eltwise_data", get_mem(get_output_layout(p))),
gemm("gemm_prim", { input_info("input0"), input_info("input1") }, data_types::f32),
eltwise("scale", { input_info("gemm_prim"), input_info("scale_data") }, eltwise_mode::prod, p.default_type),
@@ -451,7 +486,7 @@ TEST_P(gemm_2in_act_scale_eltwise, basic) {
);
tolerance = default_tolerance(p.default_type);
execute(p);
execute(p, false);
}
TEST_P(gemm_2in_act_scale_eltwise, broadcast_eltwise) {
@@ -459,7 +494,7 @@ TEST_P(gemm_2in_act_scale_eltwise, broadcast_eltwise) {
create_topologies(
input_layout("input0", get_input_layout(p, 0)),
input_layout("input1", get_input_layout(p, 1)),
data("scale_data", get_mem(get_per_channel_layout(p), 1.0f / p.kernel.count() / 255)),
data("scale_data", get_mem(get_per_channel_layout(p), 1.0f / 255.f)),
data("eltwise_data", get_mem(get_single_element_layout(p))),
gemm("gemm_prim", { input_info("input0"), input_info("input1") }, data_types::f32),
eltwise("scale", { input_info("gemm_prim"), input_info("scale_data") }, eltwise_mode::prod, p.default_type),
@@ -469,7 +504,7 @@ TEST_P(gemm_2in_act_scale_eltwise, broadcast_eltwise) {
);
tolerance = default_tolerance(p.default_type);
execute(p);
execute(p, false);
}
INSTANTIATE_TEST_SUITE_P(