[GPU] Handle unsupported eltwise fusion for onednn gemm in dynamic cases (#16875)
* [GPU] Handle unsupported eltwise fusion for onednn gemm in dynamic cases * Update src/plugins/intel_gpu/tests/fusions/gemm_fusion_test.cpp Co-authored-by: Sergey Shlyapnikov <Sergeishlyapnikov@gmail.com> --------- Co-authored-by: Sergey Shlyapnikov <Sergeishlyapnikov@gmail.com>
This commit is contained in:
committed by
GitHub
parent
656428bc4f
commit
5299f26168
@@ -995,11 +995,15 @@ void reorder_inputs::run(program& p, layout_optimizer& lo, reorder_factory& rf)
|
||||
auto& data = node->get_dependency(fused_prim.dep_start_idx);
|
||||
|
||||
auto gemm_layout = node->get_output_layout();
|
||||
auto data_layout = data.get_output_layout();
|
||||
|
||||
if (gemm_layout.is_dynamic() || data_layout.is_dynamic())
|
||||
continue;
|
||||
|
||||
auto gemm_dims = onednn::convert_gemm_tensor(gemm_layout.get_tensor(),
|
||||
cldnn::format::dimension(gemm_layout.format),
|
||||
false);
|
||||
|
||||
auto data_layout = data.get_output_layout();
|
||||
auto data_dims = onednn::convert_gemm_tensor(data_layout.get_tensor(),
|
||||
cldnn::format::dimension(data_layout.format),
|
||||
false);
|
||||
|
||||
@@ -32,6 +32,10 @@
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
|
||||
#ifdef ENABLE_ONEDNN_FOR_GPU
|
||||
#include <impls/onednn/utils.hpp>
|
||||
#endif
|
||||
|
||||
namespace cldnn {
|
||||
namespace {
|
||||
|
||||
@@ -1065,6 +1069,29 @@ bool primitive_inst::is_valid_fusion() const {
|
||||
auto merged_shape = out_pshape;
|
||||
auto can_broadcast = ov::PartialShape::broadcast_merge_into(merged_shape, dep_pshape, fd.typed_desc<eltwise>()->broadcast_spec);
|
||||
|
||||
#ifdef ENABLE_ONEDNN_FOR_GPU
|
||||
// WA for OneDNN binary add fusions: we need to broadcast batch dimension to avoid situation with
|
||||
// batch dimension mismatch in OneDNN tensor descriptors as follow:
|
||||
// * Gemm output shape: (b,f,y,x) -> OneDNN shape: (b*f,y,x)
|
||||
// * Gemm fused op shape: (1,f,y,x) -> OneDNN shape: (1*f,y,x)
|
||||
// If batch dimension of gemm output is not equal to 1, then OneDNN will not be able to broadcast fused op data
|
||||
// correctly and we need to do it manually
|
||||
if (_node->is_type<gemm>() && _node->get_preferred_impl_type() == impl_types::onednn) {
|
||||
auto gemm_layout = _impl_params->get_output_layout();
|
||||
auto data_layout = dep.first->_impl_params->get_output_layout();
|
||||
auto gemm_dims = onednn::convert_gemm_tensor(gemm_layout.get_tensor(),
|
||||
cldnn::format::dimension(gemm_layout.format),
|
||||
false);
|
||||
|
||||
auto data_dims = onednn::convert_gemm_tensor(data_layout.get_tensor(),
|
||||
cldnn::format::dimension(data_layout.format),
|
||||
false);
|
||||
|
||||
if (gemm_dims[0] != data_dims[0])
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
|
||||
// We check that broadcasting of extra input is possible and it doesn't change output shape. If it output shape is changed, then
|
||||
// some dimension of dep_pshape is greater than out_pshape
|
||||
if (!can_broadcast || merged_shape != out_pshape)
|
||||
|
||||
@@ -20,10 +20,8 @@ using namespace ::tests;
|
||||
|
||||
namespace {
|
||||
struct gemm_test_params {
|
||||
std::vector<tensor> in_shapes;
|
||||
tensor out_shape;
|
||||
tensor kernel;
|
||||
tensor pad;
|
||||
std::vector<ov::PartialShape> in_shapes;
|
||||
ov::PartialShape out_shape;
|
||||
data_types data_type_in0;
|
||||
data_types data_type_in1;
|
||||
data_types data_type_in2;
|
||||
@@ -40,7 +38,10 @@ struct gemm_test_params {
|
||||
class GemmFusingTest : public ::BaseFusingTest<gemm_test_params> {
|
||||
public:
|
||||
|
||||
void execute(gemm_test_params& p, bool is_caching_test = false) {
|
||||
void execute(gemm_test_params& p, bool is_dynamic, bool is_caching_test = false) {
|
||||
cfg_not_fused.set_property(ov::intel_gpu::allow_new_shape_infer(is_dynamic));
|
||||
cfg_fused.set_property(ov::intel_gpu::allow_new_shape_infer(is_dynamic));
|
||||
|
||||
auto input0_prim = get_mem(get_input_layout(p, 0));
|
||||
auto input1_prim = get_mem(get_input_layout(p, 1));
|
||||
|
||||
@@ -67,22 +68,20 @@ public:
|
||||
}
|
||||
|
||||
layout get_input_layout(gemm_test_params& p, int in_no) {
|
||||
auto pad = p.pad;
|
||||
std::vector<int> pad_ = { 0, 0, pad.spatial[0], pad.spatial[1] };
|
||||
if (in_no == 0)
|
||||
return layout{ p.data_type_in0, p.input_format, p.in_shapes.at(0), padding{ pad_ } };
|
||||
return layout{ p.in_shapes.at(0), p.data_type_in0, p.input_format };
|
||||
else if (in_no == 1)
|
||||
return layout{ p.data_type_in1, p.input_format, p.in_shapes.at(1), padding{ pad_ } };
|
||||
return layout{ p.in_shapes.at(1), p.data_type_in1, p.input_format };
|
||||
else
|
||||
return layout{ p.data_type_in2, p.input_format, p.in_shapes.at(2), padding{ pad_ } };
|
||||
return layout{ p.in_shapes.at(2), p.data_type_in2, p.input_format };
|
||||
}
|
||||
|
||||
layout get_per_channel_layout(gemm_test_params& p) {
|
||||
return layout{ p.default_type, p.default_format, tensor{ 1, p.in_shapes.at(0).feature[0], 1, 1 } };
|
||||
return layout{ov::PartialShape{ 1, p.in_shapes[0][1], 1, 1 }, p.default_type, p.default_format };
|
||||
}
|
||||
|
||||
layout get_output_layout(gemm_test_params& p) {
|
||||
return layout{ p.default_type, p.input_format, p.out_shape };
|
||||
return layout{ p.out_shape, p.default_type, p.input_format };
|
||||
}
|
||||
};
|
||||
|
||||
@@ -92,43 +91,43 @@ public:
|
||||
/* ---------------------------------------- Gemm cases ------------------------------------------------- */
|
||||
/* ----------------------------------------------------------------------------------------------------- */
|
||||
|
||||
#define CASE_GEMM_3IN_FP32_1 { { 1, 1, 2, 2 }, { 1, 1, 2, 2 }, { 1, 1, 2, 2 } }, { 1, 1, 2, 2 }, tensor{ 1 }, tensor{ 0 }, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
|
||||
#define CASE_GEMM_3IN_FP32_2 { { 1, 1, 63, 63 }, { 1, 1, 63, 63 }, { 1, 1, 63, 63 } }, { 1, 1, 63, 63 }, tensor{ 1 }, tensor{ 0 }, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
|
||||
#define CASE_GEMM_3IN_FP32_3 { { 1, 1, 128, 128 }, { 1, 1, 128, 128 }, { 1, 1, 128, 128 } }, { 1, 1, 128, 128 }, tensor{ 1 }, tensor{ 0 }, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
|
||||
#define CASE_GEMM_3IN_FP32_4 { { 1, 2, 64, 128 }, { 1, 2, 256, 64 }, { 1, 2, 256, 128 } }, { 1, 2, 256, 128 }, tensor{ 1 }, tensor{ 0 }, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
|
||||
#define CASE_GEMM_3IN_FP16_1 { { 1, 1, 2, 2 }, { 1, 1, 2, 2 }, { 1, 1, 2, 2 } }, { 1, 1, 2, 2 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
|
||||
#define CASE_GEMM_3IN_FP16_2 { { 1, 1, 31, 31 }, { 1, 1, 31, 31 }, { 1, 1, 31, 31 } }, { 1, 1, 31, 31 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
|
||||
#define CASE_GEMM_3IN_FP16_3 { { 1, 1, 64, 64 }, { 1, 1, 64, 64 }, { 1, 1, 64, 64 } }, { 1, 1, 64, 64 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
|
||||
#define CASE_GEMM_3IN_FP16_4 { { 1, 2, 64, 128 }, { 1, 2, 256, 64 }, { 1, 2, 256, 128 } }, { 1, 2, 256, 128 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
|
||||
#define CASE_GEMM_3IN_S8S8_1 { { 1, 1, 2, 2 }, { 1, 1, 2, 2 }, { 1, 1, 2, 2 } }, { 1, 1, 2, 2 }, tensor{ 1 }, tensor{ 0 }, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx
|
||||
#define CASE_GEMM_3IN_S8S8_2 { { 1, 2, 64, 128 }, { 1, 2, 256, 64 }, { 1, 2, 256, 128 } }, { 1, 2, 256, 128 }, tensor{ 1 }, tensor{ 0 }, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx
|
||||
#define CASE_GEMM_3IN_S8S8_3 { { 1, 1, 8, 16 }, { 1, 1, 32, 8 }, { 1, 1, 32, 16 } }, { 1, 1, 32, 16 }, tensor{ 1 }, tensor{ 0 }, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx
|
||||
#define CASE_GEMM_3IN_FP32_1 { { 1, 1, 2, 2 }, { 1, 1, 2, 2 }, { 1, 1, 2, 2 } }, { 1, 1, 2, 2 }, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
|
||||
#define CASE_GEMM_3IN_FP32_2 { { 1, 1, 63, 63 }, { 1, 1, 63, 63 }, { 1, 1, 63, 63 } }, { 1, 1, 63, 63 }, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
|
||||
#define CASE_GEMM_3IN_FP32_3 { { 1, 1, 128, 128 }, { 1, 1, 128, 128 }, { 1, 1, 128, 128 } }, { 1, 1, 128, 128 }, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
|
||||
#define CASE_GEMM_3IN_FP32_4 { { 1, 2, 128, 64 }, { 1, 2, 64, 256 }, { 1, 2, 128, 256 } }, { 1, 2, 128, 256 }, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
|
||||
#define CASE_GEMM_3IN_FP16_1 { { 1, 1, 2, 2 }, { 1, 1, 2, 2 }, { 1, 1, 2, 2 } }, { 1, 1, 2, 2 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
|
||||
#define CASE_GEMM_3IN_FP16_2 { { 1, 1, 31, 31 }, { 1, 1, 31, 31 }, { 1, 1, 31, 31 } }, { 1, 1, 31, 31 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
|
||||
#define CASE_GEMM_3IN_FP16_3 { { 1, 1, 64, 64 }, { 1, 1, 64, 64 }, { 1, 1, 64, 64 } }, { 1, 1, 64, 64 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
|
||||
#define CASE_GEMM_3IN_FP16_4 { { 1, 2, 128, 64 }, { 1, 2, 64, 256 }, { 1, 2, 128, 256 } }, { 1, 2, 128, 256 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
|
||||
#define CASE_GEMM_3IN_S8S8_1 { { 1, 1, 2, 2 }, { 1, 1, 2, 2 }, { 1, 1, 2, 2 } }, { 1, 1, 2, 2 }, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx
|
||||
#define CASE_GEMM_3IN_S8S8_2 { { 1, 2, 128, 64 }, { 1, 2, 64, 256 }, { 1, 2, 128, 256 } }, { 1, 2, 128, 256 }, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx
|
||||
#define CASE_GEMM_3IN_S8S8_3 { { 1, 1, 16, 8 }, { 1, 1, 8, 32 }, { 1, 1, 16, 32 } }, { 1, 1, 16, 32 }, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx
|
||||
|
||||
#define CASE_GEMM_2IN_FP32_1 { { 1, 1, 2, 2 }, { 1, 1, 2, 2 } }, { 1, 1, 2, 2 }, tensor{ 1 }, tensor{ 0 }, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
|
||||
#define CASE_GEMM_2IN_FP32_2 { { 1, 1, 63, 63 }, { 1, 1, 63, 63 } }, { 1, 1, 63, 63 }, tensor{ 1 }, tensor{ 0 }, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
|
||||
#define CASE_GEMM_2IN_FP32_3 { { 1, 1, 128, 128 }, { 1, 1, 128, 128 } }, { 1, 1, 128, 128 }, tensor{ 1 }, tensor{ 0 }, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
|
||||
#define CASE_GEMM_2IN_FP32_4 { { 1, 2, 64, 128 }, { 1, 2, 256, 64 } }, { 1, 2, 256, 128 }, tensor{ 1 }, tensor{ 0 }, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
|
||||
#define CASE_GEMM_2IN_FP16_1 { { 1, 1, 2, 2 }, { 1, 1, 2, 2 } }, { 1, 1, 2, 2 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
|
||||
#define CASE_GEMM_2IN_FP16_2 { { 1, 1, 31, 31 }, { 1, 1, 31, 31 } }, { 1, 1, 31, 31 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
|
||||
#define CASE_GEMM_2IN_FP16_3 { { 1, 1, 64, 64 }, { 1, 1, 64, 64 } }, { 1, 1, 64, 64 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
|
||||
#define CASE_GEMM_2IN_FP16_4 { { 1, 2, 64, 128 }, { 1, 2, 256, 64 } }, { 1, 2, 256, 128 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
|
||||
#define CASE_GEMM_2IN_FP16_5 { { 2, 3, 2, 2 }, { 2, 3, 2, 2 } }, { 2, 3, 2, 2 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
|
||||
#define CASE_GEMM_2IN_FP16_5D_1 { { 2, 3, 4, 6, 5 }, { 2, 3, 6, 4, 5 } }, { 2, 3, 6, 6, 5 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
|
||||
#define CASE_GEMM_2IN_FP16_6D_1 { { 2, 3, 7, 5, 3, 2 }, { 2, 3, 5, 7, 3, 2 } }, { 2, 3, 5, 5, 3, 2 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfwzyx, data_types::f16, format::bfwzyx
|
||||
#define CASE_GEMM_2IN_FP32_1 { { 1, 1, 2, 2 }, { 1, 1, 2, 2 } }, { 1, 1, 2, 2 }, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
|
||||
#define CASE_GEMM_2IN_FP32_2 { { 1, 1, 63, 63 }, { 1, 1, 63, 63 } }, { 1, 1, 63, 63 }, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
|
||||
#define CASE_GEMM_2IN_FP32_3 { { 1, 1, 128, 128 }, { 1, 1, 128, 128 } }, { 1, 1, 128, 128 }, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
|
||||
#define CASE_GEMM_2IN_FP32_4 { { 1, 2, 128, 64 }, { 1, 2, 64, 256 } }, { 1, 2, 128, 256 }, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
|
||||
#define CASE_GEMM_2IN_FP16_1 { { 1, 1, 2, 2 }, { 1, 1, 2, 2 } }, { 1, 1, 2, 2 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
|
||||
#define CASE_GEMM_2IN_FP16_2 { { 1, 1, 31, 31 }, { 1, 1, 31, 31 } }, { 1, 1, 31, 31 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
|
||||
#define CASE_GEMM_2IN_FP16_3 { { 1, 1, 64, 64 }, { 1, 1, 64, 64 } }, { 1, 1, 64, 64 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
|
||||
#define CASE_GEMM_2IN_FP16_4 { { 1, 2, 128, 64 }, { 1, 2, 64, 256 } }, { 1, 2, 128, 256 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
|
||||
#define CASE_GEMM_2IN_FP16_5 { { 2, 3, 2, 2 }, { 2, 3, 2, 2 } }, { 2, 3, 2, 2 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
|
||||
#define CASE_GEMM_2IN_FP16_5D_1 { { 2, 3, 5, 6, 4 }, { 2, 3, 5, 4, 6} }, { 2, 3, 5, 6, 6 }, data_types::f16, data_types::f16, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
|
||||
#define CASE_GEMM_2IN_FP16_6D_1 { { 2, 3, 2, 3, 5, 7 }, { 2, 3, 2, 3, 7, 5 } }, { 2, 3, 2, 3, 5, 5 }, data_types::f16, data_types::f16, data_types::f16, format::bfwzyx, data_types::f16, format::bfwzyx
|
||||
|
||||
#define CASE_GEMM_2IN_U8U8_1 { { 1, 1, 2, 2 }, { 1, 1, 2, 2 } }, { 1, 1, 2, 2 }, tensor{ 1 }, tensor{ 0 }, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
|
||||
#define CASE_GEMM_2IN_U8U8_2 { { 1, 2, 64, 128 }, { 1, 2, 256, 64 } }, { 1, 2, 256, 128 }, tensor{ 1 }, tensor{ 0 }, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
|
||||
#define CASE_GEMM_2IN_U8U8_3 { { 1, 1, 16, 32 }, { 1, 1, 32, 16 } }, { 1, 1, 32, 32 }, tensor{ 1 }, tensor{ 0 }, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
|
||||
#define CASE_GEMM_2IN_U8U8_1 { { 1, 1, 2, 2 }, { 1, 1, 2, 2 } }, { 1, 1, 2, 2 }, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
|
||||
#define CASE_GEMM_2IN_U8U8_2 { { 1, 2, 128, 64 }, { 1, 2, 64, 256 } }, { 1, 2, 128, 256 }, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
|
||||
#define CASE_GEMM_2IN_U8U8_3 { { 1, 1, 16, 32 }, { 1, 1, 32, 16 } }, { 1, 1, 32, 32 }, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
|
||||
|
||||
#define CASE_GEMM_2IN_U8S8_1 { { 1, 1, 4, 2 }, { 1, 1, 8, 4 } }, { 1, 1, 8, 4 }, tensor{ 1 }, tensor{ 0 }, data_types::u8, data_types::i8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
|
||||
#define CASE_GEMM_2IN_S8U8_1 { { 1, 2, 64, 128 }, { 1, 2, 256, 64 } }, { 1, 2, 256, 128 }, tensor{ 1 }, tensor{ 0 }, data_types::i8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
|
||||
#define CASE_GEMM_2IN_U8S8_1 { { 1, 1, 2, 4 }, { 1, 1, 4, 8 } }, { 1, 1, 2, 8 }, data_types::u8, data_types::i8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
|
||||
#define CASE_GEMM_2IN_S8U8_1 { { 1, 2, 128, 64 }, { 1, 2, 64, 256 } }, { 1, 2, 128, 256 }, data_types::i8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
|
||||
|
||||
#define CASE_GEMM_ELTWISE_2IN_FP32_1 { { 1, 1, 4, 4 }, { 1, 1, 4, 4 } }, { 1, 1, 4, 4 }, tensor{ 1 }, tensor{ 0 }, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
|
||||
#define CASE_GEMM_ELTWISE_2IN_FP16_1 { { 1, 1, 32, 32 }, { 1, 1, 32, 32 } }, { 1, 1, 32, 32 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
|
||||
#define CASE_GEMM_ELTWISE_2IN_FP16_2 { { 1, 1, 1024, 1024 }, { 1, 1, 1024, 1024 } }, { 1, 1, 1024, 1024 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
|
||||
#define CASE_GEMM_ELTWISE_2IN_U8S8_1 { { 1, 1, 4, 4 }, { 1, 1, 4, 4 } }, { 1, 1, 4, 4 }, tensor{ 1 }, tensor{ 0 }, data_types::u8, data_types::i8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
|
||||
#define CASE_GEMM_ELTWISE_2IN_S8U8_1 { { 1, 1, 32, 32 }, { 1, 1, 32, 32 } }, { 1, 1, 32, 32 }, tensor{ 1 }, tensor{ 0 }, data_types::i8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
|
||||
#define CASE_GEMM_ELTWISE_2IN_U8S8_2 { { 1, 1, 1024, 1024 }, { 1, 1, 1024, 1024 } }, { 1, 1, 1024, 1024 }, tensor{ 1 }, tensor{ 0 }, data_types::u8, data_types::i8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
|
||||
#define CASE_GEMM_ELTWISE_2IN_FP32_1 { { 1, 1, 4, 4 }, { 1, 1, 4, 4 } }, { 1, 1, 4, 4 }, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
|
||||
#define CASE_GEMM_ELTWISE_2IN_FP16_1 { { 1, 1, 32, 32 }, { 1, 1, 32, 32 } }, { 1, 1, 32, 32 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
|
||||
#define CASE_GEMM_ELTWISE_2IN_FP16_2 { { 1, 1, 1024, 1024 }, { 1, 1, 1024, 1024 } }, { 1, 1, 1024, 1024 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
|
||||
#define CASE_GEMM_ELTWISE_2IN_U8S8_1 { { 1, 1, 4, 4 }, { 1, 1, 4, 4 } }, { 1, 1, 4, 4 }, data_types::u8, data_types::i8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
|
||||
#define CASE_GEMM_ELTWISE_2IN_S8U8_1 { { 1, 1, 32, 32 }, { 1, 1, 32, 32 } }, { 1, 1, 32, 32 }, data_types::i8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
|
||||
#define CASE_GEMM_ELTWISE_2IN_U8S8_2 { { 1, 1, 1024, 1024 }, { 1, 1, 1024, 1024 } }, { 1, 1, 1024, 1024 }, data_types::u8, data_types::i8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
|
||||
|
||||
class gemm_3in_quantize_i8 : public GemmFusingTest {};
|
||||
TEST_P(gemm_3in_quantize_i8, basic) {
|
||||
@@ -151,7 +150,7 @@ TEST_P(gemm_3in_quantize_i8, basic) {
|
||||
);
|
||||
|
||||
tolerance = default_tolerance(data_types::i8);
|
||||
execute(p);
|
||||
execute(p, false);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(fusings_gpu, gemm_3in_quantize_i8, ::testing::ValuesIn(std::vector<gemm_test_params>{
|
||||
@@ -185,7 +184,7 @@ TEST_P(gemm_2in_quantize_u8, basic) {
|
||||
);
|
||||
|
||||
tolerance = default_tolerance(data_types::u8);
|
||||
execute(p);
|
||||
execute(p, false);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(fusings_gpu, gemm_2in_quantize_u8, ::testing::ValuesIn(std::vector<gemm_test_params>{
|
||||
@@ -222,7 +221,7 @@ TEST_P(gemm_2in_quantize_float_in, basic) {
|
||||
cfg_fused.set_property(ov::intel_gpu::force_implementations(ov::intel_gpu::ImplForcingMap{ { "gemm_prim", gemm_impl } }));
|
||||
|
||||
tolerance = default_tolerance(data_types::u8);
|
||||
execute(p);
|
||||
execute(p, false);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(fusings_gpu, gemm_2in_quantize_float_in, ::testing::ValuesIn(std::vector<gemm_test_params>{
|
||||
@@ -244,14 +243,14 @@ TEST_P(gemm_2in_scale, basic) {
|
||||
create_topologies(
|
||||
input_layout("input0", get_input_layout(p, 0)),
|
||||
input_layout("input1", get_input_layout(p, 1)),
|
||||
data("scale_data", get_mem(get_per_channel_layout(p), 1.0f/p.kernel.count())),
|
||||
data("scale_data", get_mem(get_per_channel_layout(p), 0.5f)),
|
||||
gemm("gemm_prim", { input_info("input0"), input_info("input1") }, data_types::f32),
|
||||
eltwise("scale", { input_info("gemm_prim"), input_info("scale_data") }, eltwise_mode::prod, p.default_type),
|
||||
reorder("reorder_bfyx", input_info("scale"), p.default_format, data_types::f32)
|
||||
);
|
||||
|
||||
tolerance = default_tolerance(p.default_type);
|
||||
execute(p);
|
||||
execute(p, false);
|
||||
}
|
||||
|
||||
TEST_P(gemm_2in_scale, fp16_scale_out) {
|
||||
@@ -259,14 +258,14 @@ TEST_P(gemm_2in_scale, fp16_scale_out) {
|
||||
create_topologies(
|
||||
input_layout("input0", get_input_layout(p, 0)),
|
||||
input_layout("input1", get_input_layout(p, 1)),
|
||||
data("scale_data", get_mem(get_per_channel_layout(p), 1.0f/p.kernel.count())),
|
||||
data("scale_data", get_mem(get_per_channel_layout(p), 0.5f)),
|
||||
gemm("gemm_prim", { input_info("input0"), input_info("input1") }, data_types::f32),
|
||||
eltwise("scale", { input_info("gemm_prim"), input_info("scale_data") }, eltwise_mode::prod, data_types::f16),
|
||||
reorder("reorder_bfyx", input_info("scale"), p.default_format, data_types::f32)
|
||||
);
|
||||
|
||||
tolerance = default_tolerance(p.default_type);
|
||||
execute(p);
|
||||
execute(p, false);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(fusings_gpu, gemm_2in_scale, ::testing::ValuesIn(std::vector<gemm_test_params>{
|
||||
@@ -294,12 +293,12 @@ TEST_P(gemm_2in_add, eltwise_postop) {
|
||||
}
|
||||
|
||||
auto add_data_layout = get_output_layout(p);
|
||||
auto add_data_size = add_data_layout.get_tensor();
|
||||
auto add_data_size = add_data_layout.get_partial_shape();
|
||||
if (p.broadcast_kind == dim_vec_kind::batch)
|
||||
add_data_size.batch[0] = 1;
|
||||
add_data_size[0] = 1;
|
||||
else
|
||||
add_data_size.feature[0] = 1;
|
||||
add_data_layout.set_tensor(add_data_size);
|
||||
add_data_size[1] = 1;
|
||||
add_data_layout.set_partial_shape(add_data_size);
|
||||
|
||||
auto in_layout0 = get_input_layout(p, 0);
|
||||
auto in_layout1 = get_input_layout(p, 1);
|
||||
@@ -307,14 +306,50 @@ TEST_P(gemm_2in_add, eltwise_postop) {
|
||||
create_topologies(
|
||||
input_layout("input0", in_layout0),
|
||||
input_layout("input1", in_layout1),
|
||||
data("add_data", get_mem(add_data_layout, 1.0f/p.kernel.count())),
|
||||
data("add_data", get_mem(add_data_layout, 0.5f)),
|
||||
gemm("gemm_prim", { input_info("input0"), input_info("input1") }, data_types::f32, false, false, 1.f, 0.f, in_layout0.get_rank(), in_layout1.get_rank()),
|
||||
eltwise("add_prim", { input_info("gemm_prim"), input_info("add_data") }, p.eltwise_m, p.default_type),
|
||||
reorder("reorder_bfyx", input_info("add_prim"), p.default_format, data_types::f32)
|
||||
);
|
||||
|
||||
tolerance = default_tolerance(p.default_type);
|
||||
execute(p);
|
||||
execute(p, false);
|
||||
}
|
||||
|
||||
TEST_P(gemm_2in_add, eltwise_postop_dynamic) {
|
||||
auto p = GetParam();
|
||||
|
||||
if (engine.get_device_info().supports_immad) {
|
||||
ov::intel_gpu::ImplementationDesc gemmv_impl = { cldnn::format::type::any, "", impl_types::onednn };
|
||||
cfg_fused.set_property(ov::intel_gpu::force_implementations(ov::intel_gpu::ImplForcingMap{ { "gemm_prim", gemmv_impl } }));
|
||||
cfg_fused.set_property(ov::intel_gpu::use_only_static_kernels_for_dynamic_shape(true));
|
||||
}
|
||||
|
||||
auto add_data_layout = get_output_layout(p);
|
||||
auto add_data_size = add_data_layout.get_partial_shape();
|
||||
if (p.broadcast_kind == dim_vec_kind::batch)
|
||||
add_data_size[0] = 1;
|
||||
else
|
||||
add_data_size[1] = 1;
|
||||
add_data_layout.set_partial_shape(add_data_size);
|
||||
|
||||
auto in_layout0 = get_input_layout(p, 0);
|
||||
auto in_layout1 = get_input_layout(p, 1);
|
||||
|
||||
in_layout0.set_partial_shape(ov::PartialShape::dynamic(p.in_shapes[0].size()));
|
||||
in_layout1.set_partial_shape(ov::PartialShape::dynamic(p.in_shapes[1].size()));
|
||||
|
||||
create_topologies(
|
||||
input_layout("input0", in_layout0),
|
||||
input_layout("input1", in_layout1),
|
||||
data("add_data", get_mem(add_data_layout, 0.5f)),
|
||||
gemm("gemm_prim", { input_info("input0"), input_info("input1") }, data_types::f32, false, false, 1.f, 0.f, in_layout0.get_rank(), in_layout1.get_rank()),
|
||||
eltwise("add_prim", { input_info("gemm_prim"), input_info("add_data") }, p.eltwise_m, p.default_type),
|
||||
reorder("reorder_bfyx", input_info("add_prim"), p.default_format, data_types::f32)
|
||||
);
|
||||
|
||||
tolerance = default_tolerance(p.default_type);
|
||||
execute(p, true);
|
||||
}
|
||||
|
||||
TEST_P(gemm_2in_add, eltwise_postop_cached) {
|
||||
@@ -326,12 +361,12 @@ TEST_P(gemm_2in_add, eltwise_postop_cached) {
|
||||
}
|
||||
|
||||
auto add_data_layout = get_output_layout(p);
|
||||
auto add_data_size = add_data_layout.get_tensor();
|
||||
auto add_data_size = add_data_layout.get_partial_shape();
|
||||
if (p.broadcast_kind == dim_vec_kind::batch)
|
||||
add_data_size.batch[0] = 1;
|
||||
add_data_size[0] = 1;
|
||||
else
|
||||
add_data_size.feature[0] = 1;
|
||||
add_data_layout.set_tensor(add_data_size);
|
||||
add_data_size[1] = 1;
|
||||
add_data_layout.set_partial_shape(add_data_size);
|
||||
|
||||
auto in_layout0 = get_input_layout(p, 0);
|
||||
auto in_layout1 = get_input_layout(p, 1);
|
||||
@@ -339,14 +374,14 @@ TEST_P(gemm_2in_add, eltwise_postop_cached) {
|
||||
create_topologies(
|
||||
input_layout("input0", in_layout0),
|
||||
input_layout("input1", in_layout1),
|
||||
data("add_data", get_mem(add_data_layout, 1.0f/p.kernel.count())),
|
||||
data("add_data", get_mem(add_data_layout, 0.5f)),
|
||||
gemm("gemm_prim", { input_info("input0"), input_info("input1") }, data_types::f32, false, false, 1.f, 0.f, in_layout0.get_rank(), in_layout1.get_rank()),
|
||||
eltwise("add_prim", { input_info("gemm_prim"), input_info("add_data") }, p.eltwise_m, p.default_type),
|
||||
reorder("reorder_bfyx", input_info("add_prim"), p.default_format, data_types::f32)
|
||||
);
|
||||
|
||||
tolerance = default_tolerance(p.default_type);
|
||||
execute(p, true);
|
||||
execute(p, false, true);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(fusings_gpu, gemm_2in_add, ::testing::ValuesIn(std::vector<gemm_test_params>{
|
||||
@@ -377,7 +412,7 @@ TEST_P(gemm_2in_act_scale_quantize_i8, basic) {
|
||||
data("in_hi", get_mem(get_per_channel_layout(p), 1, max_random)),
|
||||
data("out_lo", get_mem(get_single_element_layout(p), -127)),
|
||||
data("out_hi", get_mem(get_single_element_layout(p), 127)),
|
||||
data("scale_data", get_mem(get_per_channel_layout(p), 1.0f / p.kernel.count() / 255)),
|
||||
data("scale_data", get_mem(get_per_channel_layout(p), 1.0f / 255.f)),
|
||||
gemm("gemm_prim", { input_info("input0"), input_info("input1") }, data_types::f32),
|
||||
activation("activation", input_info("gemm_prim"), activation_func::exp),
|
||||
eltwise("scale", { input_info("activation"), input_info("scale_data") }, eltwise_mode::prod, p.default_type),
|
||||
@@ -387,7 +422,7 @@ TEST_P(gemm_2in_act_scale_quantize_i8, basic) {
|
||||
);
|
||||
|
||||
tolerance = default_tolerance(data_types::i8);
|
||||
execute(p);
|
||||
execute(p, false);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(fusings_gpu, gemm_2in_act_scale_quantize_i8, ::testing::ValuesIn(std::vector<gemm_test_params>{
|
||||
@@ -413,7 +448,7 @@ TEST_P(gemm_2in_act_scale_quantize_eltwise_i8, basic) {
|
||||
data("in_hi", get_mem(get_per_channel_layout(p), 1, max_random)),
|
||||
data("out_lo", get_mem(get_single_element_layout(p), -127)),
|
||||
data("out_hi", get_mem(get_single_element_layout(p), 127)),
|
||||
data("scale_data", get_mem(get_per_channel_layout(p), 1.0f / p.kernel.count() / 255)),
|
||||
data("scale_data", get_mem(get_per_channel_layout(p), 1.0f / 255.f)),
|
||||
data("eltwise_data", get_mem(get_output_layout(p))),
|
||||
gemm("gemm_prim", { input_info("input0"), input_info("input1") }, data_types::f32),
|
||||
activation("activation", input_info("gemm_prim"), activation_func::exp),
|
||||
@@ -425,7 +460,7 @@ TEST_P(gemm_2in_act_scale_quantize_eltwise_i8, basic) {
|
||||
);
|
||||
|
||||
tolerance = default_tolerance(data_types::i8);
|
||||
execute(p);
|
||||
execute(p, false);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(fusings_gpu, gemm_2in_act_scale_quantize_eltwise_i8, ::testing::ValuesIn(std::vector<gemm_test_params>{
|
||||
@@ -441,7 +476,7 @@ TEST_P(gemm_2in_act_scale_eltwise, basic) {
|
||||
create_topologies(
|
||||
input_layout("input0", get_input_layout(p, 0)),
|
||||
input_layout("input1", get_input_layout(p, 1)),
|
||||
data("scale_data", get_mem(get_per_channel_layout(p), 1.0f / p.kernel.count() / 255)),
|
||||
data("scale_data", get_mem(get_per_channel_layout(p), 1.0f / 255.f)),
|
||||
data("eltwise_data", get_mem(get_output_layout(p))),
|
||||
gemm("gemm_prim", { input_info("input0"), input_info("input1") }, data_types::f32),
|
||||
eltwise("scale", { input_info("gemm_prim"), input_info("scale_data") }, eltwise_mode::prod, p.default_type),
|
||||
@@ -451,7 +486,7 @@ TEST_P(gemm_2in_act_scale_eltwise, basic) {
|
||||
);
|
||||
|
||||
tolerance = default_tolerance(p.default_type);
|
||||
execute(p);
|
||||
execute(p, false);
|
||||
}
|
||||
|
||||
TEST_P(gemm_2in_act_scale_eltwise, broadcast_eltwise) {
|
||||
@@ -459,7 +494,7 @@ TEST_P(gemm_2in_act_scale_eltwise, broadcast_eltwise) {
|
||||
create_topologies(
|
||||
input_layout("input0", get_input_layout(p, 0)),
|
||||
input_layout("input1", get_input_layout(p, 1)),
|
||||
data("scale_data", get_mem(get_per_channel_layout(p), 1.0f / p.kernel.count() / 255)),
|
||||
data("scale_data", get_mem(get_per_channel_layout(p), 1.0f / 255.f)),
|
||||
data("eltwise_data", get_mem(get_single_element_layout(p))),
|
||||
gemm("gemm_prim", { input_info("input0"), input_info("input1") }, data_types::f32),
|
||||
eltwise("scale", { input_info("gemm_prim"), input_info("scale_data") }, eltwise_mode::prod, p.default_type),
|
||||
@@ -469,7 +504,7 @@ TEST_P(gemm_2in_act_scale_eltwise, broadcast_eltwise) {
|
||||
);
|
||||
|
||||
tolerance = default_tolerance(p.default_type);
|
||||
execute(p);
|
||||
execute(p, false);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
|
||||
Reference in New Issue
Block a user