[GPU] GEMM dynamic (#13248)
parent 1047bb7732
commit 478939ea9e
@@ -46,12 +46,16 @@ struct gemm : public primitive_base<gemm> {
          const bool transpose_input1 = false,
          const float alpha = 1.0f,
          const float beta = 0.0f,
+         const size_t input_rank = 4,
+         const size_t weight_rank = 4,
          const padding& output_padding = padding())
        : primitive_base(id, inputs, output_padding, optional_data_type{ data_type }),
          transpose_input0(transpose_input0),
          transpose_input1(transpose_input1),
          alpha(alpha),
-         beta(beta) {
+         beta(beta),
+         input_rank(input_rank),
+         weight_rank(weight_rank) {
        if (inputs.size() != 2 && inputs.size() != 3) {
            throw std::invalid_argument("Invalid inputs count - gemm expects either two or three inputs");
        }
@@ -65,6 +69,10 @@ struct gemm : public primitive_base<gemm> {
     float alpha;
     /// @brief Variable containing BETA parameter
     float beta;
+    /// @brief First matrix rank
+    size_t input_rank;
+    /// @brief Second matrix rank
+    size_t weight_rank;
 };

 }  // namespace cldnn
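The two new parameters record the ranks of the original MatMul inputs, so shape inference can recover the true output rank even though cldnn tensors are padded to at least 4D. A minimal construction sketch follows; the include path and input ids are illustrative assumptions, only the argument order comes from this diff:

#include "intel_gpu/primitives/gemm.hpp"  // assumed header location

// Gemm over a 3D activation and a 4D weight: pass the original ranks so the
// plugin can later compute the broadcasted output rank (max of the two).
cldnn::gemm gemm_prim("gemm_dyn",                // primitive id (illustrative)
                      { "input_a", "input_b" },  // input primitive ids (illustrative)
                      cldnn::data_types::f16,
                      false,   // transpose_input0
                      false,   // transpose_input1
                      1.0f,    // alpha
                      0.0f,    // beta
                      3,       // input_rank  (rank of input 0)
                      4);      // weight_rank (rank of input 1)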
@@ -24,21 +24,55 @@ layout gemm_inst::calc_output_layout(gemm_node const& node, kernel_impl_params c
 
     auto input0_layout = impl_param.get_input_layout(0);
     auto input1_layout = impl_param.get_input_layout(1);
 
+    auto input0_shape = input0_layout.get_shape();
+    auto input1_shape = input1_layout.get_shape();
+
     bool transpose_input0 = prim->transpose_input0;
     bool transpose_input1 = prim->transpose_input1;
 
-    auto M = !transpose_input0 ? input0_layout.spatial(1) : input0_layout.spatial(0);
-    auto N = !transpose_input1 ? input1_layout.spatial(0) : input1_layout.spatial(1);
+    bool reordered = prim->input_rank > 4 || prim->weight_rank > 4;
+    size_t output_rank = std::max(prim->input_rank, prim->weight_rank);
+    size_t input_rank = reordered ? output_rank : prim->input_rank;
+    size_t weight_rank = reordered ? output_rank : prim->weight_rank;
 
-    auto output_size = input0_layout.get_tensor();
+    auto update_input_shape = [&output_rank](const ov::Shape& input_shape, size_t rank, bool transpose, bool first_input) {
+        auto input_shape_update = ov::Shape(input_shape.begin(), input_shape.begin() + std::min(rank, input_shape.size()));
+        if (input_shape_update.size() == 1) {
+            first_input ? input_shape_update.insert(input_shape_update.begin(), 1)
+                        : input_shape_update.insert(input_shape_update.end(), 1);
+            if (transpose) {
+                std::swap(input_shape_update[0], input_shape_update[1]);
+            }
+            output_rank = std::max(output_rank, rank + 1);
+        }
+        input_shape_update.insert(input_shape_update.begin(), output_rank - input_shape_update.size(), 1);
+        return input_shape_update;
+    };
 
-    for (size_t i = 1; i < prim->input_size(); ++i) {
-        auto input_layout = impl_param.get_input_layout(i);
-        output_size = tensor::max(output_size, input_layout.get_tensor());
+    auto input0_shape_update = update_input_shape(input0_shape, input_rank, transpose_input0, true);
+    auto input1_shape_update = update_input_shape(input1_shape, weight_rank, transpose_input1, false);
+
+    ov::Shape bias_shape(output_rank);
+    if (prim->input_size() == 3) {
+        bias_shape = impl_param.get_input_layout(2).get_shape();
+        bias_shape = update_input_shape(bias_shape, weight_rank, transpose_input1, false);
     }
 
-    output_size.spatial[0] = N;
-    output_size.spatial[1] = M;
+    auto output_shape = input0_shape_update;
+    for (size_t i = 0; i < output_rank; ++i) {
+        output_shape[i] = std::max(std::max(input0_shape_update[i], input1_shape_update[i]), bias_shape[i]);
+    }
+
+    size_t M = !transpose_input0 ? *(input0_shape_update.end() - 2) : input0_shape_update.back();
+    size_t N = !transpose_input1 ? input1_shape_update.back() : *(input1_shape_update.end() - 2);
+
+    output_shape[output_rank - 2] = M;
+    output_shape[output_rank - 1] = N;
+
+    size_t ones_to_add = 4 - std::min(output_shape.size(), static_cast<size_t>(4));
+    output_shape.insert(output_shape.begin(), ones_to_add, 1);
 
     auto output_type = input0_layout.data_type;
     if ((output_type == data_types::u8 || output_type == data_types::i8) && prim->output_data_type)
         output_type = *prim->output_data_type;
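For readers tracking the shape math: update_input_shape trims the incoming shape to at most `rank` dimensions, unsqueezes 1D inputs (the first input to a row vector, the second to a column vector, swapped again when transposed, which also bumps the output rank), and finally left-pads with ones up to output_rank. A self-contained sketch, with std::vector<size_t> standing in for ov::Shape (all names here are illustrative):

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

using Shape = std::vector<size_t>;  // stand-in for ov::Shape in this sketch

// Mirrors the update_input_shape lambda from calc_output_layout above.
Shape update_input_shape(const Shape& in, size_t rank, size_t& output_rank,
                         bool transpose, bool first_input) {
    Shape s(in.begin(), in.begin() + std::min(rank, in.size()));
    if (s.size() == 1) {
        // 1D: first input -> row vector [1, S]; second input -> column vector [S, 1]
        first_input ? s.insert(s.begin(), 1) : s.insert(s.end(), 1);
        if (transpose)
            std::swap(s[0], s[1]);
        output_rank = std::max(output_rank, rank + 1);
    }
    s.insert(s.begin(), output_rank - s.size(), 1);  // left-pad with ones
    return s;
}

int main() {
    size_t output_rank = 4;
    // A 1D first input [120] becomes [1, 1, 1, 120] at output_rank 4.
    for (size_t d : update_input_shape({120}, 1, output_rank, false, true))
        std::cout << d << ' ';
    std::cout << '\n';  // prints: 1 1 1 120
}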
@@ -49,7 +83,7 @@ layout gemm_inst::calc_output_layout(gemm_node const& node, kernel_impl_params c
 
     auto output_format = input0_layout.format;
 
-    return layout(output_type, output_format, output_size, prim->output_padding);
+    return layout(output_shape, output_type, output_format, prim->output_padding);
 }
 
 template<typename ShapeType>
@@ -105,51 +139,5 @@ std::string gemm_inst::to_string(gemm_node const& node) {
     return primitive_description.str();
 }
 
-gemm_inst::typed_primitive_inst(network& network, gemm_node const& node) : parent(network, node) {
-    if (is_dynamic())
-        return;
-
-    auto input0_layout = node.input(0).get_output_layout();
-    auto input1_layout = node.input(1).get_output_layout();
-    bool transpose_input0 = node.get_primitive()->transpose_input0;
-    bool transpose_input1 = node.get_primitive()->transpose_input1;
-
-    auto transposed_x0 = input0_layout.spatial(0);
-    auto transposed_y0 = input0_layout.spatial(1);
-
-    if (transpose_input0) {
-        std::swap(transposed_x0, transposed_y0);
-    }
-
-    auto transposed_x1 = input1_layout.spatial(0);
-    auto transposed_y1 = input1_layout.spatial(1);
-
-    if (transpose_input1) {
-        std::swap(transposed_x1, transposed_y1);
-    }
-
-    CLDNN_ERROR_NOT_EQUAL(node.id(),
-                          "Input 0 internal dimension size",
-                          transposed_x0,
-                          "Input 1 internal dimension size",
-                          transposed_y1,
-                          "");
-
-    if (node.inputs_count() == 3) {
-        auto input2_layout = node.input(2).get_output_layout();
-
-        CLDNN_ERROR_NOT_EQUAL(node.id(),
-                              "Input 0 external dimension size",
-                              transposed_y0,
-                              "Input 2 rows number",
-                              input2_layout.spatial(1),
-                              "");
-        CLDNN_ERROR_NOT_EQUAL(node.id(),
-                              "Input 1 external dimension size",
-                              transposed_x1,
-                              "Input 2 columns number",
-                              input2_layout.spatial(0),
-                              "");
-    }
-}
+gemm_inst::typed_primitive_inst(network& network, gemm_node const& node) : parent(network, node) {}
 }  // namespace cldnn
@@ -10,6 +10,7 @@
 #include "gemm/gemm_kernel_selector.h"
 #include "gemm/gemm_kernel_base.h"
 #include "intel_gpu/runtime/error_handler.hpp"
+#include <algorithm>
 
 namespace cldnn {
 namespace ocl {
@@ -25,13 +26,91 @@ struct gemm_impl : typed_primitive_impl_ocl<gemm> {
 public:
     static primitive_impl* create(const gemm_node& arg, const kernel_impl_params& impl_param) {
         auto desc = arg.get_primitive();
+        auto get_gemm_input_layouts = [desc](const std::vector<layout>& input_layouts, const layout& output_layout) {
+            auto get_updated_input_shape = [&](const ov::Shape& input_shape, size_t input_rank, bool transpose, bool first_input) {
+                ov::Shape updated_input_shape;
+
+                if (input_rank == 1) {
+                    updated_input_shape = { *std::max_element(input_shape.begin(), input_shape.end()) };
+                } else {
+                    updated_input_shape = ov::Shape(input_shape.begin(), input_shape.begin() + input_rank);
+                }
+
+                if (updated_input_shape.size() == 1) {
+                    first_input ? updated_input_shape.insert(updated_input_shape.begin(), 1)
+                                : updated_input_shape.insert(updated_input_shape.end(), 1);
+
+                    if (transpose) {
+                        std::swap(updated_input_shape[0], updated_input_shape[1]);
+                    }
+                }
+                size_t ones_to_add = std::max(output_layout.get_shape().size(), static_cast<size_t>(4)) - updated_input_shape.size();
+                updated_input_shape.insert(updated_input_shape.begin(), ones_to_add, 1ul);
+
+                return updated_input_shape;
+            };
+
+            auto input0_shape = input_layouts[0].get_shape();
+            auto input1_shape = input_layouts[1].get_shape();
+
+            auto updated_input0_shape = get_updated_input_shape(input0_shape, desc->input_rank, desc->transpose_input0, true);
+            auto updated_input1_shape = get_updated_input_shape(input1_shape, desc->weight_rank, desc->transpose_input1, false);
+
+            std::vector<layout> layouts = input_layouts;
+            layouts[0].set_partial_shape(updated_input0_shape);
+            layouts[1].set_partial_shape(updated_input1_shape);
+
+            if (input_layouts.size() == 3) {
+                auto bias_shape = input_layouts[2].get_shape();
+                auto updated_bias_shape = get_updated_input_shape(bias_shape, desc->weight_rank, desc->transpose_input1, false);
+                layouts[2].set_partial_shape(updated_bias_shape);
+            }
+
+            return layouts;
+        };
+
+        auto get_gemm_output_layout = [desc](const std::vector<layout>& input_layouts, const layout& output_layout) {
+            auto updated_output_layout = output_layout;
+            auto output_rank = output_layout.get_shape().size();
+            if (output_rank < 4) {
+                const auto& input0_layout = input_layouts[0];
+                const auto& input1_layout = input_layouts[1];
+
+                auto M = !desc->transpose_input0 ? input0_layout.spatial(1) : input0_layout.spatial(0);
+                auto N = !desc->transpose_input1 ? input1_layout.spatial(0) : input1_layout.spatial(1);
+
+                auto output_shape = input0_layout.get_shape();
+                for (const auto& input_layout : input_layouts) {
+                    auto input_shape = input_layout.get_shape();
+                    for (size_t i = 0; i != input_shape.size(); ++i) {
+                        output_shape[i] = std::max(output_shape[i], input_shape[i]);
+                    }
+                }
+
+                auto get_spatial_idx = [](cldnn::format format, size_t spatial_idx) {
+                    const size_t idx = (format::is_grouped(format) ? 3 : 2) + (format.spatial_num() - 1 - spatial_idx);
+                    return idx;
+                };
+
+                output_shape[get_spatial_idx(updated_output_layout.format, 0)] = N;
+                output_shape[get_spatial_idx(updated_output_layout.format, 1)] = M;
+                updated_output_layout.set_partial_shape(output_shape);
+            }
+            return updated_output_layout;
+        };
+
+        const auto input_layouts = get_gemm_input_layouts(impl_param.input_layouts, impl_param.output_layout);
+        const auto output_layout = get_gemm_output_layout(input_layouts, impl_param.output_layout);
+
         auto gemm_params = get_default_params<kernel_selector::gemm_params>(impl_param, 1);
         auto gemm_optional_params =
             get_default_optional_params<kernel_selector::gemm_optional_params>(arg.get_program());
 
-        for (size_t i = 1; i < arg.inputs_count(); i++) {
-            gemm_params.inputs.push_back(convert_data_tensor(impl_param.input_layouts[i]));
+        gemm_params.inputs.clear();
+        for (size_t i = 0; i < desc->input_size(); ++i) {
+            gemm_params.inputs.push_back(convert_data_tensor(input_layouts[i]));
         }
+        gemm_params.outputs[0] = convert_data_tensor(output_layout);
 
         gemm_params.alpha = desc->alpha;
         gemm_params.beta = desc->beta;
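The get_spatial_idx helper converts a cldnn spatial index (0 = x, 1 = y) into a position in the ov-style shape, where spatial dims come after the batch/feature (and group) dims in reverse order. A standalone check of the same arithmetic, with the format queries replaced by explicit parameters (an assumption made only so the sketch compiles on its own):

#include <cstddef>
#include <iostream>

// Same index arithmetic as the get_spatial_idx lambda above.
size_t get_spatial_idx(bool is_grouped, size_t spatial_num, size_t spatial_idx) {
    return (is_grouped ? 3 : 2) + (spatial_num - 1 - spatial_idx);
}

int main() {
    // Plain bfyx: not grouped, two spatial dims.
    std::cout << get_spatial_idx(false, 2, 0) << '\n';  // x (N) -> index 3, the last dim
    std::cout << get_spatial_idx(false, 2, 1) << '\n';  // y (M) -> index 2, next to last
}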
@@ -53,19 +53,106 @@ protected:
 
     static std::shared_ptr<dnnl::matmul::desc> get_gemm_descriptor(const kernel_impl_params& impl_params) {
         auto prim = impl_params.typed_desc<gemm>();
-        auto gemm_with_bias = prim->dependencies().size() == 3;
 
-        auto in0_l = impl_params.get_input_layout(0);
-        auto in1_l = impl_params.get_input_layout(1);
+        auto get_gemm_input_layouts = [prim](const std::vector<layout>& input_layouts) {
+            auto get_updated_input_shape = [&](const ov::Shape& input_shape, size_t input_rank, size_t output_rank, bool transpose, bool first_input) {
+                ov::Shape updated_input_shape;
+
+                if (input_rank == 1) {
+                    updated_input_shape = { *std::max_element(input_shape.begin(), input_shape.end()) };
+                } else {
+                    updated_input_shape = ov::Shape(input_shape.begin(), input_shape.begin() + input_rank);
+                }
+
+                if (updated_input_shape.size() == 1) {
+                    first_input ? updated_input_shape.insert(updated_input_shape.begin(), 1)
+                                : updated_input_shape.insert(updated_input_shape.end(), 1);
+
+                    if (transpose) {
+                        std::swap(updated_input_shape[0], updated_input_shape[1]);
+                    }
+                }
+                size_t ones_to_add = std::max(output_rank, static_cast<size_t>(4)) - updated_input_shape.size();
+                updated_input_shape.insert(updated_input_shape.begin(), ones_to_add, 1ul);
+
+                return updated_input_shape;
+            };
+
+            auto input0_shape = input_layouts[0].get_shape();
+            auto input1_shape = input_layouts[1].get_shape();
+
+            bool reordered = prim->input_rank > 4 || prim->weight_rank > 4;
+            size_t output_rank = std::max(prim->input_rank, prim->weight_rank);
+            size_t input_rank = reordered ? output_rank : prim->input_rank;
+            size_t weight_rank = reordered ? output_rank : prim->weight_rank;
+
+            auto updated_input0_shape = get_updated_input_shape(input0_shape, input_rank, output_rank, prim->transpose_input0, true);
+            auto updated_input1_shape = get_updated_input_shape(input1_shape, weight_rank, output_rank, prim->transpose_input1, false);
+
+            std::vector<layout> layouts = input_layouts;
+            layouts[0].set_partial_shape(updated_input0_shape);
+            layouts[1].set_partial_shape(updated_input1_shape);
+
+            if (input_layouts.size() == 3) {
+                auto bias_shape = input_layouts[2].get_shape();
+                auto updated_bias_shape = get_updated_input_shape(bias_shape, prim->weight_rank, output_rank, prim->transpose_input1, false);
+                layouts[2].set_partial_shape(updated_bias_shape);
+            }
+
+            return layouts;
+        };
+
+        auto get_gemm_output_layout = [prim](const std::vector<layout>& input_layouts, const layout& output_layout) {
+            auto updated_output_layout = output_layout;
+            auto output_rank = output_layout.get_shape().size();
+            if (output_rank < 4) {
+                const auto& input0_layout = input_layouts[0];
+                const auto& input1_layout = input_layouts[1];
+
+                auto M = !prim->transpose_input0 ? input0_layout.spatial(1) : input0_layout.spatial(0);
+                auto N = !prim->transpose_input1 ? input1_layout.spatial(0) : input1_layout.spatial(1);
+
+                auto output_shape = input0_layout.get_shape();
+                for (const auto& input_layout : input_layouts) {
+                    auto input_shape = input_layout.get_shape();
+                    for (size_t i = 0; i != input_shape.size(); ++i) {
+                        output_shape[i] = std::max(output_shape[i], input_shape[i]);
+                    }
+                }
+
+                auto get_spatial_idx = [](cldnn::format format, size_t spatial_idx) {
+                    const size_t idx = (format::is_grouped(format) ? 3 : 2) + (format.spatial_num() - 1 - spatial_idx);
+                    return idx;
+                };
+
+                output_shape[get_spatial_idx(updated_output_layout.format, 0)] = N;
+                output_shape[get_spatial_idx(updated_output_layout.format, 1)] = M;
+                updated_output_layout.set_partial_shape(output_shape);
+            }
+            return updated_output_layout;
+        };
+
+        auto gemm_with_bias = prim->dependencies().size() == 3;
         auto out_l = impl_params.output_layout;
+
+        std::vector<layout> in_layouts { impl_params.get_input_layout(0), impl_params.get_input_layout(1) };
+        if (gemm_with_bias) {
+            in_layouts.emplace_back(impl_params.get_input_layout(2));
+        }
+
+        in_layouts = get_gemm_input_layouts(in_layouts);
+        out_l = get_gemm_output_layout(in_layouts, out_l);
+
+        const auto& in0_l = in_layouts[0];
+        const auto& in1_l = in_layouts[1];
+
         size_t in0_batched_size = in0_l.count() / (in0_l.spatial(0) * in0_l.spatial(1));
         size_t in1_batched_size = in1_l.count() / (in1_l.spatial(0) * in1_l.spatial(1));
         size_t out_batched_size = out_l.count() / (out_l.spatial(0) * out_l.spatial(1));
 
         auto batched_dims_can_be_removed = in0_batched_size == 1 && in1_batched_size == 1 && out_batched_size == 1;
         if (gemm_with_bias) {
-            auto bias_l = impl_params.get_input_layout(2);
+            const auto& bias_l = in_layouts[2];
             size_t bias_batched_size = bias_l.count() / (bias_l.spatial(0) * bias_l.spatial(1));
             batched_dims_can_be_removed &= bias_batched_size == 1;
         }
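The batched-size expressions above decide whether oneDNN can treat the gemm as a single 2D matmul: count() / (spatial(0) * spatial(1)) is the product of all non-spatial dims, so it equals 1 exactly when batch and feature are both 1. A small sketch of that check, assuming plain 4D bfyx shapes:

#include <cstddef>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// For a 4D bfyx shape {b, f, y, x}: count / (x * y) == b * f.
size_t batched_size(const std::vector<size_t>& bfyx) {
    size_t count = std::accumulate(bfyx.begin(), bfyx.end(), size_t{1},
                                   std::multiplies<size_t>());
    return count / (bfyx[2] * bfyx[3]);  // divide out the two spatial dims
}

int main() {
    std::cout << batched_size({1, 1, 32, 120}) << '\n';  // 1: batch dims can be removed
    std::cout << batched_size({2, 4, 32, 120}) << '\n';  // 8: keep the batch dims
}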
@@ -20,7 +20,6 @@
 #include "reduce_inst.h"
 #include "one_hot_inst.h"
 #include "permute_inst.h"
-#include "gemm_inst.h"
 #include "quantize_inst.h"
 #include "mvn_inst.h"
 #include "depth_to_space_inst.h"
@@ -1546,7 +1545,9 @@ impl_types layout_optimizer::get_preferred_impl_type(program_node& node, format
         }
 
         auto gemm_prim = node.as<gemm>().get_primitive();
-        if (!node.is_dynamic()) {
+        if (node.is_dynamic()) {
+            impl_candidate = impl_types::ocl;
+        } else {
             auto in0_l = node.get_dependency(0).get_output_layout();
             auto in1_l = node.get_dependency(1).get_output_layout();
             auto out_l = node.get_output_layout();
@@ -83,11 +83,13 @@ static void CreateMatMulOp(Program& p, const std::shared_ptr<ngraph::op::v0::Mat
     bool is_fc = IsNodeOnConstPath(op->get_input_node_shared_ptr(1));
     is_fc &= std::count_if(shape_b.begin(), shape_b.end(), [](Dimension x) { return x != 1; }) <= 2;
     // TODO: This conditions can be relaxed with proper handling in FC path
-    is_fc &= rank_a > 1 && rank_b > 1 && shape_b.is_static();
+    is_fc &= rank_a > 1 && rank_b > 1;
 
     PartialShape shape_a_aligned, shape_b_aligned;
     bool aligned = false;
+    if (shape_b.is_static()) {
         std::tie(aligned, shape_a_aligned, shape_b_aligned) = get_aligned_shapes(shape_a, shape_b, op);
+    }
     is_fc &= aligned;
 
     if (is_fc) {
@@ -161,109 +163,31 @@ static void CreateMatMulOp(Program& p, const std::shared_ptr<ngraph::op::v0::Mat
             p.add_primitive(*op, outReshapePrim);
         }
     } else {
-        auto outDims = op->get_output_shape(0);
-        auto outDimsN = outDims.size();
-
-        auto gemmSpecificTensor = [](const InferenceEngine::SizeVector& dims) {
-            switch (dims.size()) {
-            case 2: return cldnn::tensor(cldnn::spatial(dims[1], dims[0]));
-            case 3: return cldnn::tensor(cldnn::batch(dims[0]), cldnn::spatial(dims[2], dims[1]));
-            case 4: return cldnn::tensor(cldnn::batch(dims[0]), cldnn::feature(dims[1]), cldnn::spatial(dims[3], dims[2]));
-            case 5: return cldnn::tensor(cldnn::batch(dims[0]), cldnn::feature(dims[1]), cldnn::spatial(dims[4], dims[3], dims[2]));
-            case 6: return cldnn::tensor(cldnn::batch(dims[0]), cldnn::feature(dims[1]), cldnn::spatial(dims[5], dims[4], dims[3], dims[2]));
-            default: IE_THROW() << "Invalid dimensions size(" << dims.size() << ") for Gemm layer";
-            }
-        };
-
-        // Preprocess inputs
-        for (size_t i = 0; i < inputPrimitives.size(); ++i) {
-            auto inputDims = op->get_input_shape(i);
-            auto inputDimsN = inputDims.size();
-
-            // Add reorder if changing number of dimensions requires changing format
-            auto targetFormat = cldnn::format::get_default_format(outDimsN);
-
-            if (targetFormat.value != cldnn::format::get_default_format(inputDimsN).value) {
-                auto reorderName = layerName + "_cldnn_in" + std::to_string(i) + "_reorder";
-                auto targetDatatype = cldnn::element_type_to_data_type(op->get_output_element_type(0));
-                auto reorderPrim = cldnn::reorder(reorderName,
-                                                  inputPrimitives[i],
-                                                  targetFormat,
-                                                  targetDatatype,
-                                                  std::vector<float>(),
-                                                  cldnn::reorder_mean_mode::subtract);
-
-                p.add_primitive(*op, reorderPrim);
-
-                inputPrimitives[i] = reorderName;
-            }
-
-            // Reshape input if they differ or gemm specific shape matches default one
-            if (inputDimsN != outDimsN || inputDimsN < 4) {
-                auto reshapeName = layerName + "_cldnn_in" + std::to_string(i) + "_reshape";
-
-                // Extend input dimensions by prepending ones
-                if (inputDimsN == 1) {
-                    // One-dimensional tensors unsqueezing is applied for each input independently.
-                    // The axes inserted in this step are not included in the output shape.
-                    // * If rank of the **first** input is equal to 1, it is always unsqueezed to 2D tensor **row vector** (regardless of `transpose_a`)
-                    //   by adding axes with size 1 at ROW_INDEX_DIM, to the **left** of the shape. For example `[S]` will be reshaped to `[1, S]`.
-                    // * If rank of the **second** input is equal to 1, it is always unsqueezed to 2D tensor **column vector** (regardless of `transpose_b`)
-                    //   by adding axes with size 1 at COL_INDEX_DIM, to the **right** of the shape. For example `[S]` will be reshaped to `[S, 1]`.
-                    bool transpose = false;
-                    if (i == 0) {
-                        transpose = op->get_transpose_a();
-                        inputDims.insert(inputDims.begin(), 1);
-                    } else {
-                        transpose = op->get_transpose_b();
-                        inputDims.insert(inputDims.end(), 1);
-                    }
-                    // Specs says that shapes must be unsqueezed regardless of tranpose flag, but primitive implementation always respects transposes
-                    // so we have to swap dimensions correspondingly to have consistent shapes.
-                    if (transpose) {
-                        std::swap(inputDims[0], inputDims[1]);
-                    }
-                }
-                if (inputDimsN < outDimsN)
-                    inputDims.insert(inputDims.begin(), outDimsN - inputDimsN, 1ul);
-
-                auto targetShape = gemmSpecificTensor(inputDims);
-
-                auto reshapePrim = cldnn::reshape(reshapeName, inputPrimitives[i], targetShape);
-
-                p.add_primitive(*op, reshapePrim);
-
-                inputPrimitives[i] = reshapeName;
-            }
-        }
-
         // Add actual gemm
         auto alpha = 1.0f;
         auto beta = 0.0f;
-        auto transA = op->get_transpose_a();
-        auto transB = op->get_transpose_b();
-
         auto gemmPrim = cldnn::gemm(layerName,
                                     inputPrimitives,
                                     cldnn::element_type_to_data_type(op->get_output_element_type(0)),
-                                    transA,
-                                    transB,
+                                    op->get_transpose_a(),
+                                    op->get_transpose_b(),
                                     alpha,
-                                    beta);
+                                    beta,
+                                    rank_a,
+                                    rank_b);
 
         p.add_primitive(*op, gemmPrim);
 
-        auto lastLayerName = layerName;
+        if (!p.use_new_shape_infer()) {
+            auto outDims = op->get_output_shape(0);
+            auto outDimsN = outDims.size();
             // Reshape output if gemm specific shape does not match default one
             if (outDimsN < 4) {
                 auto outputShape = tensor_from_dims(outDims);
                 auto outReshapeName = layerName + "_cldnn_out_reshape";
                 auto outReshapePrim = cldnn::reshape(outReshapeName, layerName, outputShape);
 
                 p.add_primitive(*op, outReshapePrim);
-            lastLayerName = outReshapeName;
+            }
         }
     }
 }
@@ -14,11 +14,6 @@ using namespace ov::test;
 
 namespace GPULayerTestsDefinitions {
 
-enum class MatMulNodeType {
-    MatMul,
-    FullyConnected
-};
-
 struct ShapeRelatedParams {
     std::vector<InputShape> inputShapes;
     std::pair<bool, bool> transpose;
@@ -374,5 +369,341 @@ const auto fullyConnectedParams4D_smoke = ::testing::Combine(::testing::ValuesIn
 INSTANTIATE_TEST_SUITE_P(smoke_FC_4D, MatMulLayerGPUTest, fullyConnectedParams4D_smoke, MatMulLayerGPUTest::getTestCaseName);
 
 } // namespace fullyConnected
+
+/* ============= MatMul ============= */
+namespace matmul {
+
+const std::vector<ShapeRelatedParams> IS = {
+    {static_shapes_to_test_representation({{1, 2, 32, 120}, {120, 5}}), {false, false}},
+    {static_shapes_to_test_representation({{1, 2, 32, 120}, {120, 5}}), {true, false}},
+    {static_shapes_to_test_representation({{1, 2, 32, 120}, {120, 5}}), {false, true}},
+    {static_shapes_to_test_representation({{1, 2, 32, 120}, {120, 5}}), {true, true}},
+
+    {static_shapes_to_test_representation({{7, 32, 120}, {3, 7, 120, 50}}), {false, false}},
+    {static_shapes_to_test_representation({{7, 32, 120}, {3, 7, 120, 50}}), {true, false}},
+    {static_shapes_to_test_representation({{7, 32, 120}, {3, 7, 120, 50}}), {false, true}},
+    {static_shapes_to_test_representation({{7, 32, 120}, {3, 7, 120, 50}}), {true, true}},
+
+    {static_shapes_to_test_representation({{10, 10, 10}, {10, 10, 10}}), {false, false}},
+    {static_shapes_to_test_representation({{10, 10, 10}, {10, 10, 10}}), {true, false}},
+    {static_shapes_to_test_representation({{10, 10, 10}, {10, 10, 10}}), {false, true}},
+    {static_shapes_to_test_representation({{10, 10, 10}, {10, 10, 10}}), {true, true}},
+
+    {static_shapes_to_test_representation({{55, 12}, {12, 55}}), {false, false}},
+    {static_shapes_to_test_representation({{55, 12}, {12, 55}}), {true, false}},
+    {static_shapes_to_test_representation({{55, 12}, {12, 55}}), {false, true}},
+    {static_shapes_to_test_representation({{55, 12}, {12, 55}}), {true, true}}
+};
+
+const std::vector<ShapeRelatedParams> IS_OneDNN = {
+    {static_shapes_to_test_representation({{2, 4, 32, 120}, {2, 4, 120, 5}}), {false, false}},
+    {static_shapes_to_test_representation({{2, 4, 32, 120}, {2, 4, 120, 5}}), {true, false}},
+    {static_shapes_to_test_representation({{2, 4, 32, 120}, {2, 4, 120, 5}}), {false, true}},
+    {static_shapes_to_test_representation({{2, 4, 32, 120}, {2, 4, 120, 5}}), {true, true}},
+
+    {static_shapes_to_test_representation({{2, 2, 32, 120}, {1, 1, 120, 5}}), {false, false}},
+    {static_shapes_to_test_representation({{2, 2, 32, 120}, {1, 1, 120, 5}}), {true, false}},
+    {static_shapes_to_test_representation({{2, 2, 32, 120}, {1, 1, 120, 5}}), {false, true}},
+    {static_shapes_to_test_representation({{2, 2, 32, 120}, {1, 1, 120, 5}}), {true, true}},
+
+    {static_shapes_to_test_representation({{12, 12}, {12, 12}}), {false, false}},
+    {static_shapes_to_test_representation({{12, 12}, {12, 12}}), {true, false}},
+    {static_shapes_to_test_representation({{12, 12}, {12, 12}}), {false, true}},
+    {static_shapes_to_test_representation({{12, 12}, {12, 12}}), {true, true}}
+};
+
+const std::vector<ShapeRelatedParams> IS_Dynamic = {
+    {
+        { // dynamic case description: each pair per input holds {{dynamic shape}, {{static shape case1}, {static shape case2}, ...}}
+            {{-1, -1}, {{55, 12}, {33, 7}}}, // input 0
+            {{-1, -1}, {{12, 55}, {7, 33}}}  // input 1
+        },
+        {false, false}
+    },
+    {
+        {
+            {{-1, -1}, {{55, 12}, {33, 7}}}, // input 0
+            {{-1, -1}, {{12, 55}, {7, 33}}}  // input 1
+        },
+        {true, false}
+    },
+    {
+        {
+            {{-1, -1}, {{55, 12}, {33, 7}}}, // input 0
+            {{-1, -1}, {{12, 55}, {7, 33}}}  // input 1
+        },
+        {false, true}
+    },
+    {
+        {
+            {{-1, -1}, {{55, 12}, {33, 7}}}, // input 0
+            {{-1, -1}, {{12, 55}, {7, 33}}}  // input 1
+        },
+        {true, true}
+    },
+
+    {
+        {
+            {{-1, -1, -1, -1}, {{1, 2, 32, 60}, {1, 2, 32, 30}}}, // input 0
+            {{-1, -1}, {{60, 5}, {30, 5}}}                        // input 1
+        },
+        {false, false}
+    },
+    {
+        {
+            {{-1, -1, -1, -1}, {{1, 2, 32, 60}, {1, 2, 32, 30}}}, // input 0
+            {{-1, -1}, {{60, 5}, {30, 5}}}                        // input 1
+        },
+        {true, false}
+    },
+    {
+        {
+            {{-1, -1, -1, -1}, {{1, 2, 32, 60}, {1, 2, 32, 30}}}, // input 0
+            {{-1, -1}, {{60, 5}, {30, 5}}}                        // input 1
+        },
+        {false, true}
+    },
+    {
+        {
+            {{-1, -1, -1, -1}, {{1, 2, 32, 60}, {1, 2, 32, 30}}}, // input 0
+            {{-1, -1}, {{60, 5}, {30, 5}}}                        // input 1
+        },
+        {true, true}
+    },
+
+    {
+        {
+            {{-1, -1, -1}, {{7, 32, 60}, {7, 32, 30}}},           // input 0
+            {{-1, -1, -1, -1}, {{3, 7, 60, 25}, {3, 7, 30, 25}}}  // input 1
+        },
+        {false, false}
+    },
+    {
+        {
+            {{-1, -1, -1}, {{7, 32, 60}, {7, 32, 30}}},           // input 0
+            {{-1, -1, -1, -1}, {{3, 7, 60, 25}, {3, 7, 30, 25}}}  // input 1
+        },
+        {true, false}
+    },
+    {
+        {
+            {{-1, -1, -1}, {{7, 32, 60}, {7, 32, 30}}},           // input 0
+            {{-1, -1, -1, -1}, {{3, 7, 60, 25}, {3, 7, 30, 25}}}  // input 1
+        },
+        {false, true}
+    },
+    {
+        {
+            {{-1, -1, -1}, {{7, 32, 60}, {7, 32, 30}}},           // input 0
+            {{-1, -1, -1, -1}, {{3, 7, 60, 25}, {3, 7, 30, 25}}}  // input 1
+        },
+        {true, true}
+    },
+
+    {
+        {
+            {{-1, -1, -1}, {{10, 10, 10}, {5, 5, 5}}}, // input 0
+            {{-1, -1, -1}, {{10, 10, 10}, {5, 5, 5}}}  // input 1
+        },
+        {false, false}
+    },
+    {
+        {
+            {{-1, -1, -1}, {{10, 10, 10}, {5, 5, 5}}}, // input 0
+            {{-1, -1, -1}, {{10, 10, 10}, {5, 5, 5}}}  // input 1
+        },
+        {true, false}
+    },
+    {
+        {
+            {{-1, -1, -1}, {{10, 10, 10}, {5, 5, 5}}}, // input 0
+            {{-1, -1, -1}, {{10, 10, 10}, {5, 5, 5}}}  // input 1
+        },
+        {false, true}
+    },
+    {
+        {
+            {{-1, -1, -1}, {{10, 10, 10}, {5, 5, 5}}}, // input 0
+            {{-1, -1, -1}, {{10, 10, 10}, {5, 5, 5}}}  // input 1
+        },
+        {true, true}
+    },
+
+    {
+        {
+            {{{1, 15}, {1, 15}, {1, 15}}, {{10, 10, 10}, {5, 5, 5}}}, // input 0
+            {{{1, 15}, {1, 15}, {1, 15}}, {{10, 10, 10}, {5, 5, 5}}}  // input 1
+        },
+        {false, false}
+    },
+    {
+        {
+            {{{1, 15}, {1, 15}, {1, 15}}, {{10, 10, 10}, {5, 5, 5}}}, // input 0
+            {{{1, 15}, {1, 15}, {1, 15}}, {{10, 10, 10}, {5, 5, 5}}}  // input 1
+        },
+        {true, false}
+    },
+    {
+        {
+            {{{1, 15}, {1, 15}, {1, 15}}, {{10, 10, 10}, {5, 5, 5}}}, // input 0
+            {{{1, 15}, {1, 15}, {1, 15}}, {{10, 10, 10}, {5, 5, 5}}}  // input 1
+        },
+        {false, true}
+    },
+    {
+        {
+            {{ -1, 16 }, {{ 4, 16 }, { 2, 16 }}},                            // input 0
+            {{ {1, 5}, 12, -1, 4 }, {{ 1, 12, 16, 4 }, { 1, 12, 16, 4 }}}    // input 1
+        },
+        {true, true}
+    },
+    {
+        {
+            {{ -1, 12, -1, 16 }, {{ 1, 12, 4, 16 }, { 2, 12, 2, 16 }}},      // input 0
+            {{ {1, 5}, 12, -1, 4 }, {{ 1, 12, 16, 4 }, { 1, 12, 16, 4 }}}    // input 1
+        },
+        {false, false}
+    }
+};
+
+const std::vector<ShapeRelatedParams> IS_Dynamic_nightly = {
+    {
+        { // dynamic case description: each pair per input holds {{dynamic shape}, {{static shape case1}, {static shape case2}, ...}}
+            {{{5, 15}, {1, 12}, {4, 15}}, {{10, 10, 10}, {5, 5, 5}}}, // input 0
+            {{{1, 13}, {3, 15}, {1, 10}}, {{10, 10, 10}, {5, 5, 5}}}  // input 1
+        },
+        {true, true}
+    },
+
+    {
+        {
+            {{ {2, 10}, {3, 15}, -1, 16 }, {{ 2, 12, 4, 16 }, { 3, 12, 2, 16 }}}, // input 0
+            {{ 1, 1, -1, 4 }, {{ 1, 1, 16, 4 }, { 1, 1, 16, 4 }}}                 // input 1
+        },
+        {true, true}
+    },
+    {
+        {
+            {{ 1, 1, -1, 16 }, {{ 1, 1, 4, 16 }, { 1, 1, 2, 16 }}},               // input 0
+            {{ {2, 5}, {3, 15}, -1, 4 }, {{ 2, 12, 16, 4 }, { 2, 12, 16, 4 }}}    // input 1
+        },
+        {false, false}
+    },
+
+    {
+        {
+            {{ -1, 16 }, {{ 4, 16 }, { 2, 16 }}},                                 // input 0
+            {{ {1, 5}, 12, -1, 4 }, {{ 1, 12, 16, 4 }, { 1, 12, 16, 4 }}}         // input 1
+        },
+        {false, false}
+    },
+    {
+        {
+            {{ -1, {2, 15}, -1, 16 }, {{ 1, 12, 4, 16 }, { 2, 12, 2, 16 }}},      // input 0
+            {{ -1, 4 }, {{ 16, 4 }, { 16, 4 }}}                                   // input 1
+        },
+        {true, true}
+    },
+    {
+        {
+            {{ -1, {1, 15}, -1, 16 }, {{ 1, 12, 4, 16 }, { 2, 12, 2, 16 }}},      // input 0
+            {{ -1, 4 }, {{ 16, 4 }, { 16, 4 }}}                                   // input 1
+        },
+        {false, false}
+    },
+    {
+        {
+            {{ {1, 3}, {1, 9}, {1, 5}, {1, 10} }, {{ 1, 7, 4, 5 }, { 1, 7, 4, 4 }}}, // input 0
+            {{ {1, 5}, {1, 7}, {1, 8}, {1, 5} }, {{ 1, 7, 5, 4 }, { 1, 7, 4, 4 }}}   // input 1
+        },
+        {true, true}
+    },
+    {
+        {
+            {{ {1, 3}, {1, 9}, {1, 5}, {1, 10} }, {{ 1, 7, 4, 5 }, { 1, 7, 4, 4 }}}, // input 0
+            {{ {1, 5}, {1, 7}, {1, 8}, {1, 5} }, {{ 1, 7, 5, 4 }, { 1, 7, 4, 4 }}}   // input 1
+        },
+        {false, false}
+    },
+
+    {
+        {
+            {{ 1, 7, 4, -1 }, {{ 1, 7, 4, 5 }, { 1, 7, 4, 4 }}}, // input 0
+            {{ 1, 7, -1, 4 }, {{ 1, 7, 5, 4 }, { 1, 7, 4, 4 }}}  // input 1
+        },
+        {true, true}
+    },
+    {
+        {
+            {{ 1, 7, 4, -1 }, {{ 1, 7, 4, 5 }, { 1, 7, 4, 4 }}}, // input 0
+            {{ 1, 7, -1, 4 }, {{ 1, 7, 5, 4 }, { 1, 7, 4, 4 }}}  // input 1
+        },
+        {false, false}
+    },
+    {
+        {
+            {{ -1, 12, -1, 16 }, {{ 1, 12, 4, 16 }, { 2, 12, 2, 16 }}},    // input 0
+            {{ {1, 5}, 12, -1, 4 }, {{ 1, 12, 16, 4 }, { 1, 12, 16, 4 }}}  // input 1
+        },
+        {true, true}
+    },
+    {
+        {
+            {{ -1, 12, -1, 16 }, {{ 1, 12, 4, 16 }, { 2, 12, 2, 16 }}},    // input 0
+            {{ {1, 5}, 12, -1, 4 }, {{ 1, 12, 16, 4 }, { 1, 12, 16, 4 }}}  // input 1
+        },
+        {true, false}
+    },
+
+    {
+        {
+            {{ -1, 12, -1, 16 }, {{ 1, 12, 4, 16 }, { 2, 12, 2, 16 }}},    // input 0
+            {{ {1, 5}, 12, -1, 4 }, {{ 1, 12, 16, 4 }, { 1, 12, 16, 4 }}}  // input 1
+        },
+        {false, true}
+    }
+};
+
+const auto testParams = ::testing::Combine(::testing::ValuesIn(IS),
+                                           ::testing::ValuesIn(netPRCs),
+                                           ::testing::Values(ElementType::undefined),
+                                           ::testing::Values(ElementType::undefined),
+                                           ::testing::Values(helpers::InputLayerType::PARAMETER),
+                                           ::testing::Values(CommonTestUtils::DEVICE_GPU),
+                                           ::testing::ValuesIn(additionalConfig));
+
+INSTANTIATE_TEST_SUITE_P(smoke_MM_Static, MatMulLayerGPUTest, testParams, MatMulLayerGPUTest::getTestCaseName);
+
+const auto testParamsOneDNN = ::testing::Combine(::testing::ValuesIn(IS_OneDNN),
+                                                 ::testing::Values(ElementType::f16),
+                                                 ::testing::Values(ElementType::undefined),
+                                                 ::testing::Values(ElementType::undefined),
+                                                 ::testing::Values(helpers::InputLayerType::PARAMETER),
+                                                 ::testing::Values(CommonTestUtils::DEVICE_GPU),
+                                                 ::testing::ValuesIn(additionalConfig));
+
+INSTANTIATE_TEST_SUITE_P(smoke_MM_Static_OneDNN, MatMulLayerGPUTest, testParamsOneDNN, MatMulLayerGPUTest::getTestCaseName);
+
+const auto testParamsDynamic = ::testing::Combine(::testing::ValuesIn(IS_Dynamic),
+                                                  ::testing::ValuesIn(netPRCs),
+                                                  ::testing::Values(ElementType::undefined),
+                                                  ::testing::Values(ElementType::undefined),
+                                                  ::testing::Values(helpers::InputLayerType::PARAMETER),
+                                                  ::testing::Values(CommonTestUtils::DEVICE_GPU),
+                                                  ::testing::ValuesIn(additionalConfig));
+
+INSTANTIATE_TEST_SUITE_P(smoke_MM_Dynamic, MatMulLayerGPUTest, testParamsDynamic, MatMulLayerGPUTest::getTestCaseName);
+
+const auto testParamsDynamic_nightly = ::testing::Combine(::testing::ValuesIn(IS_Dynamic_nightly),
+                                                          ::testing::ValuesIn(netPRCs),
+                                                          ::testing::Values(ElementType::undefined),
+                                                          ::testing::Values(ElementType::undefined),
+                                                          ::testing::Values(helpers::InputLayerType::PARAMETER),
+                                                          ::testing::Values(CommonTestUtils::DEVICE_GPU),
+                                                          ::testing::ValuesIn(additionalConfig));
+
+INSTANTIATE_TEST_SUITE_P(nightly_MM_Dynamic, MatMulLayerGPUTest, testParamsDynamic_nightly, MatMulLayerGPUTest::getTestCaseName);
+
+} // namespace matmul
 } // namespace
 } // namespace GPULayerTestsDefinitions