[GPU] GEMM dynamic (#13248)

Roman Lyamin 2022-10-21 16:19:35 +04:00 committed by GitHub
parent 1047bb7732
commit 478939ea9e
7 changed files with 587 additions and 169 deletions

View File

@@ -46,12 +46,16 @@ struct gemm : public primitive_base<gemm> {
const bool transpose_input1 = false,
const float alpha = 1.0f,
const float beta = 0.0f,
const size_t input_rank = 4,
const size_t weight_rank = 4,
const padding& output_padding = padding())
: primitive_base(id, inputs, output_padding, optional_data_type{ data_type }),
transpose_input0(transpose_input0),
transpose_input1(transpose_input1),
alpha(alpha),
- beta(beta) {
+ beta(beta),
+ input_rank(input_rank),
+ weight_rank(weight_rank) {
if (inputs.size() != 2 && inputs.size() != 3) {
throw std::invalid_argument("Invalid inputs count - gemm expects either two or three inputs");
}
@@ -65,6 +69,10 @@ struct gemm : public primitive_base<gemm> {
float alpha;
/// @brief Variable containing BETA parameter
float beta;
/// @brief First matrix rank
size_t input_rank;
/// @brief Second matrix rank
size_t weight_rank;
};
} // namespace cldnn
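
The two new rank fields record the original MatMul ranks so the shape logic elsewhere in this change can realign the inputs. A minimal construction sketch (the input ids are hypothetical, and the leading parameters are assumed to keep their existing order: id, inputs, optional output data type, then the arguments visible in the hunk above):

cldnn::gemm gemm_prim("gemm",
                      { input0_id, input1_id },  // hypothetical input primitive ids
                      cldnn::data_types::f16,    // output data type
                      false,                     // transpose_input0
                      false,                     // transpose_input1
                      1.0f,                      // alpha
                      0.0f,                      // beta
                      3,                         // input_rank: first input is 3-D
                      4);                        // weight_rank: second input is 4-D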

View File

@@ -24,21 +24,55 @@ layout gemm_inst::calc_output_layout(gemm_node const& node, kernel_impl_params const& impl_param) {
auto input0_layout = impl_param.get_input_layout(0);
auto input1_layout = impl_param.get_input_layout(1);
auto input0_shape = input0_layout.get_shape();
auto input1_shape = input1_layout.get_shape();
bool transpose_input0 = prim->transpose_input0;
bool transpose_input1 = prim->transpose_input1;
- auto M = !transpose_input0 ? input0_layout.spatial(1) : input0_layout.spatial(0);
- auto N = !transpose_input1 ? input1_layout.spatial(0) : input1_layout.spatial(1);
+ bool reordered = prim->input_rank > 4 || prim->weight_rank > 4;
+ size_t output_rank = std::max(prim->input_rank, prim->weight_rank);
+ size_t input_rank = reordered ? output_rank : prim->input_rank;
+ size_t weight_rank = reordered ? output_rank : prim->weight_rank;
- auto output_size = input0_layout.get_tensor();
auto update_input_shape = [&output_rank](const ov::Shape& input_shape, size_t rank, bool transpose, bool first_input) {
auto input_shape_update = ov::Shape(input_shape.begin(), input_shape.begin() + std::min(rank, input_shape.size()));
if (input_shape_update.size() == 1) {
first_input ? input_shape_update.insert(input_shape_update.begin(), 1)
: input_shape_update.insert(input_shape_update.end(), 1);
if (transpose) {
std::swap(input_shape_update[0], input_shape_update[1]);
}
output_rank = std::max(output_rank, rank + 1);
}
input_shape_update.insert(input_shape_update.begin(), output_rank - input_shape_update.size(), 1);
return input_shape_update;
};
- for (size_t i = 1; i < prim->input_size(); ++i) {
- auto input_layout = impl_param.get_input_layout(i);
- output_size = tensor::max(output_size, input_layout.get_tensor());
- }
+ auto input0_shape_update = update_input_shape(input0_shape, input_rank, transpose_input0, true);
+ auto input1_shape_update = update_input_shape(input1_shape, weight_rank, transpose_input1, false);
ov::Shape bias_shape(output_rank);
if (prim->input_size() == 3) {
bias_shape = impl_param.get_input_layout(2).get_shape();
bias_shape = update_input_shape(bias_shape, weight_rank, transpose_input1, false);
}
- output_size.spatial[0] = N;
- output_size.spatial[1] = M;
auto output_shape = input0_shape_update;
for (size_t i = 0; i < output_rank; ++i) {
output_shape[i] = std::max(std::max(input0_shape_update[i], input1_shape_update[i]), bias_shape[i]);
}
size_t M = !transpose_input0 ? *(input0_shape_update.end() - 2) : input0_shape_update.back();
size_t N = !transpose_input1 ? input1_shape_update.back() : *(input1_shape_update.end() - 2);
output_shape[output_rank - 2] = M;
output_shape[output_rank - 1] = N;
size_t ones_to_add = 4 - std::min(output_shape.size(), static_cast<size_t>(4));
output_shape.insert(output_shape.begin(), ones_to_add, 1);
auto output_type = input0_layout.data_type;
if ((output_type == data_types::u8 || output_type == data_types::i8) && prim->output_data_type)
output_type = *prim->output_data_type;
@@ -49,7 +83,7 @@ layout gemm_inst::calc_output_layout(gemm_node const& node, kernel_impl_params const& impl_param) {
auto output_format = input0_layout.format;
- return layout(output_type, output_format, output_size, prim->output_padding);
+ return layout(output_shape, output_type, output_format, prim->output_padding);
}
template<typename ShapeType>
@@ -105,51 +139,5 @@ std::string gemm_inst::to_string(gemm_node const& node) {
return primitive_description.str();
}
- gemm_inst::typed_primitive_inst(network& network, gemm_node const& node) : parent(network, node) {
- if (is_dynamic())
- return;
- auto input0_layout = node.input(0).get_output_layout();
- auto input1_layout = node.input(1).get_output_layout();
- bool transpose_input0 = node.get_primitive()->transpose_input0;
- bool transpose_input1 = node.get_primitive()->transpose_input1;
- auto transposed_x0 = input0_layout.spatial(0);
- auto transposed_y0 = input0_layout.spatial(1);
- if (transpose_input0) {
- std::swap(transposed_x0, transposed_y0);
- }
- auto transposed_x1 = input1_layout.spatial(0);
- auto transposed_y1 = input1_layout.spatial(1);
- if (transpose_input1) {
- std::swap(transposed_x1, transposed_y1);
- }
- CLDNN_ERROR_NOT_EQUAL(node.id(),
- "Input 0 internal dimension size",
- transposed_x0,
- "Input 1 internal dimension size",
- transposed_y1,
- "");
- if (node.inputs_count() == 3) {
- auto input2_layout = node.input(2).get_output_layout();
- CLDNN_ERROR_NOT_EQUAL(node.id(),
- "Input 0 external dimension size",
- transposed_y0,
- "Input 2 rows number",
- input2_layout.spatial(1),
- "");
- CLDNN_ERROR_NOT_EQUAL(node.id(),
- "Input 1 external dimension size",
- transposed_x1,
- "Input 2 columns number",
- input2_layout.spatial(0),
- "");
- }
- }
+ gemm_inst::typed_primitive_inst(network& network, gemm_node const& node) : parent(network, node) {}
} // namespace cldnn
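
The alignment rule in calc_output_layout is easier to read in isolation. Below is a standalone sketch of the same logic on plain std::vector, not the cldnn helper itself; it assumes, as the caller above guarantees, that output_rank is at least the size of the unsqueezed shape:

#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

// Trim to the requested rank, unsqueeze 1-D inputs per the MatMul spec
// ([S] -> [1, S] for the first input, [S] -> [S, 1] for the second,
// swapped when transposed), then left-pad with ones up to output_rank.
std::vector<size_t> align_shape(std::vector<size_t> shape, size_t rank,
                                size_t output_rank, bool transpose, bool first_input) {
    shape.resize(std::min(rank, shape.size()));
    if (shape.size() == 1) {
        if (first_input)
            shape.insert(shape.begin(), 1);   // row vector [1, S]
        else
            shape.push_back(1);               // column vector [S, 1]
        if (transpose)
            std::swap(shape[0], shape[1]);
    }
    shape.insert(shape.begin(), output_rank - shape.size(), 1);
    return shape;
}
// align_shape({120}, 1, 4, false, true)  -> {1, 1, 1, 120}
// align_shape({120}, 1, 4, false, false) -> {1, 1, 120, 1}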

View File

@@ -10,6 +10,7 @@
#include "gemm/gemm_kernel_selector.h"
#include "gemm/gemm_kernel_base.h"
#include "intel_gpu/runtime/error_handler.hpp"
+ #include <algorithm>
namespace cldnn {
namespace ocl {
@@ -25,13 +26,91 @@ struct gemm_impl : typed_primitive_impl_ocl<gemm> {
public:
static primitive_impl* create(const gemm_node& arg, const kernel_impl_params& impl_param) {
auto desc = arg.get_primitive();
auto get_gemm_input_layouts = [desc](const std::vector<layout>& input_layouts, const layout& output_layout) {
auto get_updated_input_shape = [&](const ov::Shape& input_shape, size_t input_rank, bool transpose, bool first_input) {
ov::Shape updated_input_shape;
if (input_rank == 1) {
updated_input_shape = { *std::max_element(input_shape.begin(), input_shape.end()) };
} else {
updated_input_shape = ov::Shape(input_shape.begin(), input_shape.begin() + input_rank);
}
if (updated_input_shape.size() == 1) {
first_input ? updated_input_shape.insert(updated_input_shape.begin(), 1)
: updated_input_shape.insert(updated_input_shape.end(), 1);
if (transpose) {
std::swap(updated_input_shape[0], updated_input_shape[1]);
}
}
size_t ones_to_add = std::max(output_layout.get_shape().size(), static_cast<size_t>(4)) - updated_input_shape.size();
updated_input_shape.insert(updated_input_shape.begin(), ones_to_add, 1ul);
return updated_input_shape;
};
auto input0_shape = input_layouts[0].get_shape();
auto input1_shape = input_layouts[1].get_shape();
auto updated_input0_shape = get_updated_input_shape(input0_shape, desc->input_rank, desc->transpose_input0, true);
auto updated_input1_shape = get_updated_input_shape(input1_shape, desc->weight_rank, desc->transpose_input1, false);
std::vector<layout> layouts = input_layouts;
layouts[0].set_partial_shape(updated_input0_shape);
layouts[1].set_partial_shape(updated_input1_shape);
if (input_layouts.size() == 3) {
auto bias_shape = input_layouts[2].get_shape();
auto updated_bias_shape = get_updated_input_shape(bias_shape, desc->weight_rank, desc->transpose_input1, false);
layouts[2].set_partial_shape(updated_bias_shape);
}
return layouts;
};
auto get_gemm_output_layout = [desc](const std::vector<layout>& input_layouts, const layout& output_layout) {
auto updated_output_layout = output_layout;
auto output_rank = output_layout.get_shape().size();
if (output_rank < 4) {
const auto& input0_layout = input_layouts[0];
const auto& input1_layout = input_layouts[1];
auto M = !desc->transpose_input0 ? input0_layout.spatial(1) : input0_layout.spatial(0);
auto N = !desc->transpose_input1 ? input1_layout.spatial(0) : input1_layout.spatial(1);
auto output_shape = input0_layout.get_shape();
for (const auto& input_layout : input_layouts) {
auto input_shape = input_layout.get_shape();
for (size_t i = 0; i != input_shape.size(); ++i) {
output_shape[i] = std::max(output_shape[i], input_shape[i]);
}
}
auto get_spatial_idx = [](cldnn::format format, size_t spatial_idx) {
const size_t idx = (format::is_grouped(format) ? 3 : 2) + (format.spatial_num() - 1 - spatial_idx);
return idx;
};
output_shape[get_spatial_idx(updated_output_layout.format, 0)] = N;
output_shape[get_spatial_idx(updated_output_layout.format, 1)] = M;
updated_output_layout.set_partial_shape(output_shape);
}
return updated_output_layout;
};
const auto input_layouts = get_gemm_input_layouts(impl_param.input_layouts, impl_param.output_layout);
const auto output_layout = get_gemm_output_layout(input_layouts, impl_param.output_layout);
auto gemm_params = get_default_params<kernel_selector::gemm_params>(impl_param, 1);
auto gemm_optional_params =
get_default_optional_params<kernel_selector::gemm_optional_params>(arg.get_program());
- for (size_t i = 1; i < arg.inputs_count(); i++) {
- gemm_params.inputs.push_back(convert_data_tensor(impl_param.input_layouts[i]));
+ gemm_params.inputs.clear();
+ for (size_t i = 0; i < desc->input_size(); ++i) {
+ gemm_params.inputs.push_back(convert_data_tensor(input_layouts[i]));
}
+ gemm_params.outputs[0] = convert_data_tensor(output_layout);
gemm_params.alpha = desc->alpha;
gemm_params.beta = desc->beta;
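
For reference, the get_spatial_idx lambda above maps a cldnn spatial axis onto a shape index: spatial axes are stored innermost-first, so for a plain (non-grouped) 4-D layout spatial 0 (x, the N dimension) lands at shape index 3 and spatial 1 (y, the M dimension) at index 2. A condensed standalone sketch of that mapping, assuming cldnn's usual [batch, feature, ..., y, x] ordering:

#include <cstddef>

size_t spatial_to_shape_idx(bool grouped, size_t spatial_num, size_t spatial_idx) {
    // Grouped formats carry an extra leading groups dimension, shifting
    // the first spatial position from index 2 to index 3.
    return (grouped ? 3 : 2) + (spatial_num - 1 - spatial_idx);
}
// spatial_to_shape_idx(false, 2, 0) == 3   // x -> N
// spatial_to_shape_idx(false, 2, 1) == 2   // y -> M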

View File

@@ -53,19 +53,106 @@ protected:
static std::shared_ptr<dnnl::matmul::desc> get_gemm_descriptor(const kernel_impl_params& impl_params) {
auto prim = impl_params.typed_desc<gemm>();
- auto gemm_with_bias = prim->dependencies().size() == 3;
- auto in0_l = impl_params.get_input_layout(0);
- auto in1_l = impl_params.get_input_layout(1);
auto get_gemm_input_layouts = [prim](const std::vector<layout>& input_layouts) {
auto get_updated_input_shape = [&](const ov::Shape& input_shape, size_t input_rank, size_t output_rank, bool transpose, bool first_input) {
ov::Shape updated_input_shape;
if (input_rank == 1) {
updated_input_shape = { *std::max_element(input_shape.begin(), input_shape.end()) };
} else {
updated_input_shape = ov::Shape(input_shape.begin(), input_shape.begin() + input_rank);
}
if (updated_input_shape.size() == 1) {
first_input ? updated_input_shape.insert(updated_input_shape.begin(), 1)
: updated_input_shape.insert(updated_input_shape.end(), 1);
if (transpose) {
std::swap(updated_input_shape[0], updated_input_shape[1]);
}
}
size_t ones_to_add = std::max(output_rank, static_cast<size_t>(4)) - updated_input_shape.size();
updated_input_shape.insert(updated_input_shape.begin(), ones_to_add, 1ul);
return updated_input_shape;
};
auto input0_shape = input_layouts[0].get_shape();
auto input1_shape = input_layouts[1].get_shape();
bool reordered = prim->input_rank > 4 || prim->weight_rank > 4;
size_t output_rank = std::max(prim->input_rank, prim->weight_rank);
size_t input_rank = reordered ? output_rank : prim->input_rank;
size_t weight_rank = reordered ? output_rank : prim->weight_rank;
auto updated_input0_shape = get_updated_input_shape(input0_shape, input_rank, output_rank, prim->transpose_input0, true);
auto updated_input1_shape = get_updated_input_shape(input1_shape, weight_rank, output_rank, prim->transpose_input1, false);
std::vector<layout> layouts = input_layouts;
layouts[0].set_partial_shape(updated_input0_shape);
layouts[1].set_partial_shape(updated_input1_shape);
if (input_layouts.size() == 3) {
auto bias_shape = input_layouts[2].get_shape();
auto updated_bias_shape = get_updated_input_shape(bias_shape, prim->weight_rank, output_rank, prim->transpose_input1, false);
layouts[2].set_partial_shape(updated_bias_shape);
}
return layouts;
};
auto get_gemm_output_layout = [prim](const std::vector<layout>& input_layouts, const layout& output_layout) {
auto updated_output_layout = output_layout;
auto output_rank = output_layout.get_shape().size();
if (output_rank < 4) {
const auto& input0_layout = input_layouts[0];
const auto& input1_layout = input_layouts[1];
auto M = !prim->transpose_input0 ? input0_layout.spatial(1) : input0_layout.spatial(0);
auto N = !prim->transpose_input1 ? input1_layout.spatial(0) : input1_layout.spatial(1);
auto output_shape = input0_layout.get_shape();
for (const auto& input_layout : input_layouts) {
auto input_shape = input_layout.get_shape();
for (size_t i = 0; i != input_shape.size(); ++i) {
output_shape[i] = std::max(output_shape[i], input_shape[i]);
}
}
auto get_spatial_idx = [](cldnn::format format, size_t spatial_idx) {
const size_t idx = (format::is_grouped(format) ? 3 : 2) + (format.spatial_num() - 1 - spatial_idx);
return idx;
};
output_shape[get_spatial_idx(updated_output_layout.format, 0)] = N;
output_shape[get_spatial_idx(updated_output_layout.format, 1)] = M;
updated_output_layout.set_partial_shape(output_shape);
}
return updated_output_layout;
};
auto gemm_with_bias = prim->dependencies().size() == 3;
auto out_l = impl_params.output_layout;
std::vector<layout> in_layouts { impl_params.get_input_layout(0), impl_params.get_input_layout(1) };
if (gemm_with_bias) {
in_layouts.emplace_back(impl_params.get_input_layout(2));
}
in_layouts = get_gemm_input_layouts(in_layouts);
out_l = get_gemm_output_layout(in_layouts, out_l);
const auto& in0_l = in_layouts[0];
const auto& in1_l = in_layouts[1];
size_t in0_batched_size = in0_l.count() / (in0_l.spatial(0) * in0_l.spatial(1));
size_t in1_batched_size = in1_l.count() / (in1_l.spatial(0) * in1_l.spatial(1));
size_t out_batched_size = out_l.count() / (out_l.spatial(0) * out_l.spatial(1));
auto batched_dims_can_be_removed = in0_batched_size == 1 && in1_batched_size == 1 && out_batched_size == 1;
if (gemm_with_bias) {
- auto bias_l = impl_params.get_input_layout(2);
+ const auto& bias_l = in_layouts[2];
size_t bias_batched_size = bias_l.count() / (bias_l.spatial(0) * bias_l.spatial(1));
batched_dims_can_be_removed &= bias_batched_size == 1;
}
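
The batched-size arithmetic above divides each operand's element count by its two innermost (matrix) extents; what is left is the product of the batch-like dims, and if that is 1 for every operand the whole GEMM collapses to a single 2-D matmul. A plain-integer sketch of the same computation (not the cldnn::layout API):

#include <cstddef>
#include <vector>

size_t batched_size(const std::vector<size_t>& dims) {
    size_t count = 1;
    for (size_t d : dims) count *= d;     // total element count
    size_t m = dims[dims.size() - 2];     // matrix rows
    size_t n = dims[dims.size() - 1];     // matrix cols
    return count / (m * n);               // product of the batch dims
}
// batched_size({1, 1, 32, 120}) == 1 -> batch dims removable
// batched_size({2, 4, 32, 120}) == 8 -> keep the batched matmul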

View File

@@ -20,7 +20,6 @@
#include "reduce_inst.h"
#include "one_hot_inst.h"
#include "permute_inst.h"
#include "gemm_inst.h"
#include "quantize_inst.h"
#include "mvn_inst.h"
#include "depth_to_space_inst.h"
@@ -1546,7 +1545,9 @@ impl_types layout_optimizer::get_preferred_impl_type(program_node& node, format
}
auto gemm_prim = node.as<gemm>().get_primitive();
- if (!node.is_dynamic()) {
+ if (node.is_dynamic()) {
+ impl_candidate = impl_types::ocl;
+ } else {
auto in0_l = node.get_dependency(0).get_output_layout();
auto in1_l = node.get_dependency(1).get_output_layout();
auto out_l = node.get_output_layout();
@@ -1564,13 +1565,13 @@ impl_types layout_optimizer::get_preferred_impl_type(program_node& node, format
auto valid_input_batch = in0_batched_size != 1 && (in1_batched_size == in0_batched_size || in1_batched_size == 1);
auto valid_output_batch = in0_batched_size > in1_batched_size ? out_batched_size == in0_batched_size :
out_batched_size == in1_batched_size;
auto valid_extra_input_batch = has_input2 ? in2_batched_size == 1 || in2_batched_size == out_batched_size : true;
auto valid_scale_factor = gemm_prim->alpha == 1.f && (has_input2 ? gemm_prim->beta == 1.f : true);
auto unsupported_onednn_gemm = !valid_input_batch ||
!valid_output_batch ||
!valid_extra_input_batch ||
!valid_scale_factor;
bool is_u8_i8 = data_type_traits::is_i8_u8(in0_l.data_type) && data_type_traits::is_i8_u8(in1_l.data_type);
bool use_ops_cldnn_kernel = is_u8_i8 || (in0_l.spatial(0) % 16 == 0 && in0_l.spatial(1) % 16 == 0 &&
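
Condensed, the selection rule this hunk introduces: a dynamic GEMM always takes the OCL implementation (the oneDNN path needs static shapes), and a static one falls back to OCL only when it violates the batching or scale-factor constraints checked above. A sketch with a stand-in enum rather than the real impl_types:

enum class impl_kind { ocl, onednn };   // stand-in for cldnn::impl_types

impl_kind pick_gemm_impl(bool is_dynamic, bool unsupported_onednn_gemm) {
    if (is_dynamic)
        return impl_kind::ocl;          // oneDNN GEMM requires static shapes
    return unsupported_onednn_gemm ? impl_kind::ocl : impl_kind::onednn;
}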

View File

@@ -83,11 +83,13 @@ static void CreateMatMulOp(Program& p, const std::shared_ptr<ngraph::op::v0::MatMul>& op) {
bool is_fc = IsNodeOnConstPath(op->get_input_node_shared_ptr(1));
is_fc &= std::count_if(shape_b.begin(), shape_b.end(), [](Dimension x) { return x != 1; }) <= 2;
// TODO: These conditions can be relaxed with proper handling in the FC path
is_fc &= rank_a > 1 && rank_b > 1 && shape_b.is_static();
is_fc &= rank_a > 1 && rank_b > 1;
PartialShape shape_a_aligned, shape_b_aligned;
bool aligned = false;
- std::tie(aligned, shape_a_aligned, shape_b_aligned) = get_aligned_shapes(shape_a, shape_b, op);
+ if (shape_b.is_static()) {
+ std::tie(aligned, shape_a_aligned, shape_b_aligned) = get_aligned_shapes(shape_a, shape_b, op);
+ }
is_fc &= aligned;
if (is_fc) {
@@ -161,109 +163,31 @@ static void CreateMatMulOp(Program& p, const std::shared_ptr<ngraph::op::v0::MatMul>& op) {
p.add_primitive(*op, outReshapePrim);
}
} else {
- auto outDims = op->get_output_shape(0);
- auto outDimsN = outDims.size();
- auto gemmSpecificTensor = [](const InferenceEngine::SizeVector& dims) {
- switch (dims.size()) {
- case 2: return cldnn::tensor(cldnn::spatial(dims[1], dims[0]));
- case 3: return cldnn::tensor(cldnn::batch(dims[0]), cldnn::spatial(dims[2], dims[1]));
- case 4: return cldnn::tensor(cldnn::batch(dims[0]), cldnn::feature(dims[1]), cldnn::spatial(dims[3], dims[2]));
- case 5: return cldnn::tensor(cldnn::batch(dims[0]), cldnn::feature(dims[1]), cldnn::spatial(dims[4], dims[3], dims[2]));
- case 6: return cldnn::tensor(cldnn::batch(dims[0]), cldnn::feature(dims[1]), cldnn::spatial(dims[5], dims[4], dims[3], dims[2]));
- default: IE_THROW() << "Invalid dimensions size(" << dims.size() << ") for Gemm layer";
- }
- };
- // Preprocess inputs
- for (size_t i = 0; i < inputPrimitives.size(); ++i) {
- auto inputDims = op->get_input_shape(i);
- auto inputDimsN = inputDims.size();
- // Add reorder if changing number of dimensions requires changing format
- auto targetFormat = cldnn::format::get_default_format(outDimsN);
- if (targetFormat.value != cldnn::format::get_default_format(inputDimsN).value) {
- auto reorderName = layerName + "_cldnn_in" + std::to_string(i) + "_reorder";
- auto targetDatatype = cldnn::element_type_to_data_type(op->get_output_element_type(0));
- auto reorderPrim = cldnn::reorder(reorderName,
- inputPrimitives[i],
- targetFormat,
- targetDatatype,
- std::vector<float>(),
- cldnn::reorder_mean_mode::subtract);
- p.add_primitive(*op, reorderPrim);
- inputPrimitives[i] = reorderName;
- }
- // Reshape input if they differ or gemm specific shape matches default one
- if (inputDimsN != outDimsN || inputDimsN < 4) {
- auto reshapeName = layerName + "_cldnn_in" + std::to_string(i) + "_reshape";
- // Extend input dimensions by prepending ones
- if (inputDimsN == 1) {
- // One-dimensional tensors unsqueezing is applied for each input independently.
- // The axes inserted in this step are not included in the output shape.
- // * If rank of the **first** input is equal to 1, it is always unsqueezed to 2D tensor **row vector** (regardless of `transpose_a`)
- // by adding axes with size 1 at ROW_INDEX_DIM, to the **left** of the shape. For example `[S]` will be reshaped to `[1, S]`.
- // * If rank of the **second** input is equal to 1, it is always unsqueezed to 2D tensor **column vector** (regardless of `transpose_b`)
- // by adding axes with size 1 at COL_INDEX_DIM, to the **right** of the shape. For example `[S]` will be reshaped to `[S, 1]`.
- bool transpose = false;
- if (i == 0) {
- transpose = op->get_transpose_a();
- inputDims.insert(inputDims.begin(), 1);
- } else {
- transpose = op->get_transpose_b();
- inputDims.insert(inputDims.end(), 1);
- }
- // The spec says that shapes must be unsqueezed regardless of the transpose flag, but the primitive implementation always respects transposes,
- // so we have to swap dimensions correspondingly to have consistent shapes.
- if (transpose) {
- std::swap(inputDims[0], inputDims[1]);
- }
- }
- if (inputDimsN < outDimsN)
- inputDims.insert(inputDims.begin(), outDimsN - inputDimsN, 1ul);
- auto targetShape = gemmSpecificTensor(inputDims);
- auto reshapePrim = cldnn::reshape(reshapeName, inputPrimitives[i], targetShape);
- p.add_primitive(*op, reshapePrim);
- inputPrimitives[i] = reshapeName;
- }
- }
// Add actual gemm
auto alpha = 1.0f;
auto beta = 0.0f;
- auto transA = op->get_transpose_a();
- auto transB = op->get_transpose_b();
auto gemmPrim = cldnn::gemm(layerName,
inputPrimitives,
cldnn::element_type_to_data_type(op->get_output_element_type(0)),
- transA,
- transB,
+ op->get_transpose_a(),
+ op->get_transpose_b(),
alpha,
- beta);
+ beta,
+ rank_a,
+ rank_b);
p.add_primitive(*op, gemmPrim);
- auto lastLayerName = layerName;
- // Reshape output if gemm specific shape does not match default one
- if (outDimsN < 4) {
- auto outputShape = tensor_from_dims(outDims);
- auto outReshapeName = layerName + "_cldnn_out_reshape";
- auto outReshapePrim = cldnn::reshape(outReshapeName, layerName, outputShape);
- p.add_primitive(*op, outReshapePrim);
- lastLayerName = outReshapeName;
- }
+ if (!p.use_new_shape_infer()) {
+ auto outDims = op->get_output_shape(0);
+ auto outDimsN = outDims.size();
+ // Reshape output if gemm specific shape does not match default one
+ if (outDimsN < 4) {
+ auto outputShape = tensor_from_dims(outDims);
+ auto outReshapeName = layerName + "_cldnn_out_reshape";
+ auto outReshapePrim = cldnn::reshape(outReshapeName, layerName, outputShape);
+ p.add_primitive(*op, outReshapePrim);
+ }
+ }
}
}

View File

@@ -14,11 +14,6 @@ using namespace ov::test;
namespace GPULayerTestsDefinitions {
- enum class MatMulNodeType {
- MatMul,
- FullyConnected
- };
struct ShapeRelatedParams {
std::vector<InputShape> inputShapes;
std::pair<bool, bool> transpose;
@@ -374,5 +369,341 @@ const auto fullyConnectedParams4D_smoke = ::testing::Combine(::testing::ValuesIn
INSTANTIATE_TEST_SUITE_P(smoke_FC_4D, MatMulLayerGPUTest, fullyConnectedParams4D_smoke, MatMulLayerGPUTest::getTestCaseName);
} // namespace fullyConnected
/* ============= MatMul ============= */
namespace matmul {
const std::vector<ShapeRelatedParams> IS = {
{static_shapes_to_test_representation({{1, 2, 32, 120}, {120, 5}}), {false, false}},
{static_shapes_to_test_representation({{1, 2, 32, 120}, {120, 5}}), {true, false}},
{static_shapes_to_test_representation({{1, 2, 32, 120}, {120, 5}}), {false, true}},
{static_shapes_to_test_representation({{1, 2, 32, 120}, {120, 5}}), {true, true}},
{static_shapes_to_test_representation({{7, 32, 120}, {3, 7, 120, 50}}), {false, false}},
{static_shapes_to_test_representation({{7, 32, 120}, {3, 7, 120, 50}}), {true, false}},
{static_shapes_to_test_representation({{7, 32, 120}, {3, 7, 120, 50}}), {false, true}},
{static_shapes_to_test_representation({{7, 32, 120}, {3, 7, 120, 50}}), {true, true}},
{static_shapes_to_test_representation({{10, 10, 10}, {10, 10, 10}}), {false, false}},
{static_shapes_to_test_representation({{10, 10, 10}, {10, 10, 10}}), {true, false}},
{static_shapes_to_test_representation({{10, 10, 10}, {10, 10, 10}}), {false, true}},
{static_shapes_to_test_representation({{10, 10, 10}, {10, 10, 10}}), {true, true}},
{static_shapes_to_test_representation({{55, 12}, {12, 55}}), {false, false}},
{static_shapes_to_test_representation({{55, 12}, {12, 55}}), {true, false}},
{static_shapes_to_test_representation({{55, 12}, {12, 55}}), {false, true}},
{static_shapes_to_test_representation({{55, 12}, {12, 55}}), {true, true}}
};
const std::vector<ShapeRelatedParams> IS_OneDNN = {
{static_shapes_to_test_representation({{2, 4, 32, 120}, {2, 4, 120, 5}}), {false, false}},
{static_shapes_to_test_representation({{2, 4, 32, 120}, {2, 4, 120, 5}}), {true, false}},
{static_shapes_to_test_representation({{2, 4, 32, 120}, {2, 4, 120, 5}}), {false, true}},
{static_shapes_to_test_representation({{2, 4, 32, 120}, {2, 4, 120, 5}}), {true, true}},
{static_shapes_to_test_representation({{2, 2, 32, 120}, {1, 1, 120, 5}}), {false, false}},
{static_shapes_to_test_representation({{2, 2, 32, 120}, {1, 1, 120, 5}}), {true, false}},
{static_shapes_to_test_representation({{2, 2, 32, 120}, {1, 1, 120, 5}}), {false, true}},
{static_shapes_to_test_representation({{2, 2, 32, 120}, {1, 1, 120, 5}}), {true, true}},
{static_shapes_to_test_representation({{12, 12}, {12, 12}}), {false, false}},
{static_shapes_to_test_representation({{12, 12}, {12, 12}}), {true, false}},
{static_shapes_to_test_representation({{12, 12}, {12, 12}}), {false, true}},
{static_shapes_to_test_representation({{12, 12}, {12, 12}}), {true, true}}
};
const std::vector<ShapeRelatedParams> IS_Dynamic = {
{
{ //dynamic case description each pair per each input has {{dynamic shape}, {{static shape case1}, {static shape case2}, ...}
{{-1, -1}, {{55, 12}, {33, 7}}}, // input 0
{{-1, -1}, {{12, 55}, {7, 33}}} // input 1
},
{false, false}
},
{
{ //dynamic case description each pair per each input has {{dynamic shape}, {{static shape case1}, {static shape case2}, ...}
{{-1, -1}, {{55, 12}, {33, 7}}}, // input 0
{{-1, -1}, {{12, 55}, {7, 33}}} // input 1
},
{true, false}
},
{
{ //dynamic case description each pair per each input has {{dynamic shape}, {{static shape case1}, {static shape case2}, ...}
{{-1, -1}, {{55, 12}, {33, 7}}}, // input 0
{{-1, -1}, {{12, 55}, {7, 33}}} // input 1
},
{false, true}
},
{
{ //dynamic case description each pair per each input has {{dynamic shape}, {{static shape case1}, {static shape case2}, ...}
{{-1, -1}, {{55, 12}, {33, 7}}}, // input 0
{{-1, -1}, {{12, 55}, {7, 33}}} // input 1
},
{true, true}
},
{
{ //dynamic case description each pair per each input has {{dynamic shape}, {{static shape case1}, {static shape case2}, ...}
{{-1, -1, -1, -1}, {{1, 2, 32, 60}, {1, 2, 32, 30}}}, // input 0
{{-1, -1}, {{60, 5}, {30, 5}}} // input 1
},
{false, false}
},
{
{ //dynamic case description each pair per each input has {{dynamic shape}, {{static shape case1}, {static shape case2}, ...}
{{-1, -1, -1, -1}, {{1, 2, 32, 60}, {1, 2, 32, 30}}}, // input 0
{{-1, -1}, {{60, 5}, {30, 5}}} // input 1
},
{true, false}
},
{
{ //dynamic case description each pair per each input has {{dynamic shape}, {{static shape case1}, {static shape case2}, ...}
{{-1, -1, -1, -1}, {{1, 2, 32, 60}, {1, 2, 32, 30}}}, // input 0
{{-1, -1}, {{60, 5}, {30, 5}}} // input 1
},
{false, true}
},
{
{ //dynamic case description each pair per each input has {{dynamic shape}, {{static shape case1}, {static shape case2}, ...}
{{-1, -1, -1, -1}, {{1, 2, 32, 60}, {1, 2, 32, 30}}}, // input 0
{{-1, -1}, {{60, 5}, {30, 5}}} // input 1
},
{true, true}
},
{
{ //dynamic case description each pair per each input has {{dynamic shape}, {{static shape case1}, {static shape case2}, ...}
{{-1, -1, -1}, {{7, 32, 60}, {7, 32, 30}}}, // input 0
{{-1, -1, -1, -1}, {{3, 7, 60, 25}, {3, 7, 30, 25}}} // input 1
},
{false, false}
},
{
{ //dynamic case description each pair per each input has {{dynamic shape}, {{static shape case1}, {static shape case2}, ...}
{{-1, -1, -1}, {{7, 32, 60}, {7, 32, 30}}}, // input 0
{{-1, -1, -1, -1}, {{3, 7, 60, 25}, {3, 7, 30, 25}}} // input 1
},
{true, false}
},
{
{ //dynamic case description each pair per each input has {{dynamic shape}, {{static shape case1}, {static shape case2}, ...}
{{-1, -1, -1}, {{7, 32, 60}, {7, 32, 30}}}, // input 0
{{-1, -1, -1, -1}, {{3, 7, 60, 25}, {3, 7, 30, 25}}} // input 1
},
{false, true}
},
{
{ //dynamic case description each pair per each input has {{dynamic shape}, {{static shape case1}, {static shape case2}, ...}
{{-1, -1, -1}, {{7, 32, 60}, {7, 32, 30}}}, // input 0
{{-1, -1, -1, -1}, {{3, 7, 60, 25}, {3, 7, 30, 25}}} // input 1
},
{true, true}
},
{
{ //dynamic case description each pair per each input has {{dynamic shape}, {{static shape case1}, {static shape case2}, ...}
{{-1, -1, -1}, {{10, 10, 10}, {5, 5, 5}}}, // input 0
{{-1, -1, -1}, {{10, 10, 10}, {5, 5, 5}}} // input 1
},
{false, false}
},
{
{ //dynamic case description each pair per each input has {{dynamic shape}, {{static shape case1}, {static shape case2}, ...}
{{-1, -1, -1}, {{10, 10, 10}, {5, 5, 5}}}, // input 0
{{-1, -1, -1}, {{10, 10, 10}, {5, 5, 5}}} // input 1
},
{true, false}
},
{
{ //dynamic case description each pair per each input has {{dynamic shape}, {{static shape case1}, {static shape case2}, ...}
{{-1, -1, -1}, {{10, 10, 10}, {5, 5, 5}}}, // input 0
{{-1, -1, -1}, {{10, 10, 10}, {5, 5, 5}}} // input 1
},
{false, true}
},
{
{ //dynamic case description each pair per each input has {{dynamic shape}, {{static shape case1}, {static shape case2}, ...}
{{-1, -1, -1}, {{10, 10, 10}, {5, 5, 5}}}, // input 0
{{-1, -1, -1}, {{10, 10, 10}, {5, 5, 5}}} // input 1
},
{true, true}
},
{
{ //dynamic case description each pair per each input has {{dynamic shape}, {{static shape case1}, {static shape case2}, ...}
{{{1, 15}, {1, 15}, {1, 15}}, {{10, 10, 10}, {5, 5, 5}}}, // input 0
{{{1, 15}, {1, 15}, {1, 15}}, {{10, 10, 10}, {5, 5, 5}}} // input 1
},
{false, false}
},
{
{ //dynamic case description each pair per each input has {{dynamic shape}, {{static shape case1}, {static shape case2}, ...}
{{{1, 15}, {1, 15}, {1, 15}}, {{10, 10, 10}, {5, 5, 5}}}, // input 0
{{{1, 15}, {1, 15}, {1, 15}}, {{10, 10, 10}, {5, 5, 5}}} // input 1
},
{true, false}
},
{
{ //dynamic case description each pair per each input has {{dynamic shape}, {{static shape case1}, {static shape case2}, ...}
{{{1, 15}, {1, 15}, {1, 15}}, {{10, 10, 10}, {5, 5, 5}}}, // input 0
{{{1, 15}, {1, 15}, {1, 15}}, {{10, 10, 10}, {5, 5, 5}}} // input 1
},
{false, true}
},
{
{ //dynamic case description each pair per each input has {{dynamic shape}, {{static shape case1}, {static shape case2}, ...}
{{ -1, 16 }, {{ 4, 16 }, { 2, 16 }}}, // input 0
{{ {1, 5}, 12, -1, 4 }, {{ 1, 12, 16, 4 }, { 1, 12, 16, 4 }}} // input 1
},
{true, true}
},
{
{ //dynamic case description each pair per each input has {{dynamic shape}, {{static shape case1}, {static shape case2}, ...}
{{ -1, 12, -1, 16 }, {{ 1, 12, 4, 16 }, { 2, 12, 2, 16 }}}, // input 0
{{ {1, 5}, 12, -1, 4 }, {{ 1, 12, 16, 4 }, { 1, 12, 16, 4 }}} // input 1
},
{false, false}
}
};
const std::vector<ShapeRelatedParams> IS_Dynamic_nightly = {
{
{ //dynamic case description each pair per each input has {{dynamic shape}, {{static shape case1}, {static shape case2}, ...}
{{{5, 15}, {1, 12}, {4, 15}}, {{10, 10, 10}, {5, 5, 5}}}, // input 0
{{{1, 13}, {3, 15}, {1, 10}}, {{10, 10, 10}, {5, 5, 5}}} // input 1
},
{true, true}
},
{
{ //dynamic case description each pair per each input has {{dynamic shape}, {{static shape case1}, {static shape case2}, ...}
{{ {2, 10}, {3, 15}, -1, 16 }, {{ 2, 12, 4, 16 }, { 3, 12, 2, 16 }}}, // input 0
{{ 1, 1, -1, 4 }, {{ 1, 1, 16, 4 }, { 1, 1, 16, 4 }}} // input 1
},
{true, true}
},
{
{ //dynamic case description each pair per each input has {{dynamic shape}, {{static shape case1}, {static shape case2}, ...}
{{ 1, 1, -1, 16 }, {{ 1, 1, 4, 16 }, { 1, 1, 2, 16 }}}, // input 0
{{ {2, 5}, {3, 15}, -1, 4 }, {{ 2, 12, 16, 4 }, { 2, 12, 16, 4 }}} // input 1
},
{false, false}
},
{
{ //dynamic case description each pair per each input has {{dynamic shape}, {{static shape case1}, {static shape case2}, ...}
{{ -1, 16 }, {{ 4, 16 }, { 2, 16 }}}, // input 0
{{ {1, 5}, 12, -1, 4 }, {{ 1, 12, 16, 4 }, { 1, 12, 16, 4 }}} // input 1
},
{false, false}
},
{
{ //dynamic case description each pair per each input has {{dynamic shape}, {{static shape case1}, {static shape case2}, ...}
{{ -1, {2, 15}, -1, 16 }, {{ 1, 12, 4, 16 }, { 2, 12, 2, 16 }}}, // input 0
{{ -1, 4 }, {{ 16, 4 }, { 16, 4 }}} // input 1
},
{true, true}
},
{
{ //dynamic case description each pair per each input has {{dynamic shape}, {{static shape case1}, {static shape case2}, ...}
{{ -1, {1, 15}, -1, 16 }, {{ 1, 12, 4, 16 }, { 2, 12, 2, 16 }}}, // input 0
{{ -1, 4 }, {{ 16, 4 }, { 16, 4 }}} // input 1
},
{false, false}
},
{
{ //dynamic case description each pair per each input has {{dynamic shape}, {{static shape case1}, {static shape case2}, ...}
{{ {1, 3}, {1, 9}, {1, 5}, {1, 10} }, {{ 1, 7, 4, 5 }, { 1, 7, 4, 4 }}}, // input 0
{{ {1, 5}, {1, 7}, {1, 8}, {1, 5} }, {{ 1, 7, 5, 4 }, { 1, 7, 4, 4 }}} // input 1
},
{true, true}
},
{
{ //dynamic case description each pair per each input has {{dynamic shape}, {{static shape case1}, {static shape case2}, ...}
{{ {1, 3}, {1, 9}, {1, 5}, {1, 10} }, {{ 1, 7, 4, 5 }, { 1, 7, 4, 4 }}}, // input 0
{{ {1, 5}, {1, 7}, {1, 8}, {1, 5} }, {{ 1, 7, 5, 4 }, { 1, 7, 4, 4 }}} // input 1
},
{false, false}
},
{
{ //dynamic case description each pair per each input has {{dynamic shape}, {{static shape case1}, {static shape case2}, ...}
{{ 1, 7, 4, -1 }, {{ 1, 7, 4, 5 }, { 1, 7, 4, 4 }}}, // input 0
{{ 1, 7, -1, 4 }, {{ 1, 7, 5, 4 }, { 1, 7, 4, 4 }}} // input 1
},
{true, true}
},
{
{ //dynamic case description each pair per each input has {{dynamic shape}, {{static shape case1}, {static shape case2}, ...}
{{ 1, 7, 4, -1 }, {{ 1, 7, 4, 5 }, { 1, 7, 4, 4 }}}, // input 0
{{ 1, 7, -1, 4 }, {{ 1, 7, 5, 4 }, { 1, 7, 4, 4 }}} // input 1
},
{false, false}
},
{
{ //dynamic case description each pair per each input has {{dynamic shape}, {{static shape case1}, {static shape case2}, ...}
{{ -1, 12, -1, 16 }, {{ 1, 12, 4, 16 }, { 2, 12, 2, 16 }}}, // input 0
{{ {1, 5}, 12, -1, 4 }, {{ 1, 12, 16, 4 }, { 1, 12, 16, 4 }}} // input 1
},
{true, true}
},
{
{ //dynamic case description each pair per each input has {{dynamic shape}, {{static shape case1}, {static shape case2}, ...}
{{ -1, 12, -1, 16 }, {{ 1, 12, 4, 16 }, { 2, 12, 2, 16 }}}, // input 0
{{ {1, 5}, 12, -1, 4 }, {{ 1, 12, 16, 4 }, { 1, 12, 16, 4 }}} // input 1
},
{true, false}
},
{
{ //dynamic case description each pair per each input has {{dynamic shape}, {{static shape case1}, {static shape case2}, ...}
{{ -1, 12, -1, 16 }, {{ 1, 12, 4, 16 }, { 2, 12, 2, 16 }}}, // input 0
{{ {1, 5}, 12, -1, 4 }, {{ 1, 12, 16, 4 }, { 1, 12, 16, 4 }}} // input 1
},
{false, true}
}
};
const auto testParams = ::testing::Combine(::testing::ValuesIn(IS),
::testing::ValuesIn(netPRCs),
::testing::Values(ElementType::undefined),
::testing::Values(ElementType::undefined),
::testing::Values(helpers::InputLayerType::PARAMETER),
::testing::Values(CommonTestUtils::DEVICE_GPU),
::testing::ValuesIn(additionalConfig));
INSTANTIATE_TEST_SUITE_P(smoke_MM_Static, MatMulLayerGPUTest, testParams, MatMulLayerGPUTest::getTestCaseName);
const auto testParamsOneDNN = ::testing::Combine(::testing::ValuesIn(IS_OneDNN),
::testing::Values(ElementType::f16),
::testing::Values(ElementType::undefined),
::testing::Values(ElementType::undefined),
::testing::Values(helpers::InputLayerType::PARAMETER),
::testing::Values(CommonTestUtils::DEVICE_GPU),
::testing::ValuesIn(additionalConfig));
INSTANTIATE_TEST_SUITE_P(smoke_MM_Static_OneDNN, MatMulLayerGPUTest, testParamsOneDNN, MatMulLayerGPUTest::getTestCaseName);
const auto testParamsDynamic = ::testing::Combine(::testing::ValuesIn(IS_Dynamic),
::testing::ValuesIn(netPRCs),
::testing::Values(ElementType::undefined),
::testing::Values(ElementType::undefined),
::testing::Values(helpers::InputLayerType::PARAMETER),
::testing::Values(CommonTestUtils::DEVICE_GPU),
::testing::ValuesIn(additionalConfig));
INSTANTIATE_TEST_SUITE_P(smoke_MM_Dynamic, MatMulLayerGPUTest, testParamsDynamic, MatMulLayerGPUTest::getTestCaseName);
const auto testParamsDynamic_nightly = ::testing::Combine(::testing::ValuesIn(IS_Dynamic_nightly),
::testing::ValuesIn(netPRCs),
::testing::Values(ElementType::undefined),
::testing::Values(ElementType::undefined),
::testing::Values(helpers::InputLayerType::PARAMETER),
::testing::Values(CommonTestUtils::DEVICE_GPU),
::testing::ValuesIn(additionalConfig));
INSTANTIATE_TEST_SUITE_P(nightly_MM_Dynamic, MatMulLayerGPUTest, testParamsDynamic_nightly, MatMulLayerGPUTest::getTestCaseName);
} // namespace matmul
} // namespace
} // namespace GPULayerTestsDefinitions