[GPU] FullyConnected dynamic (#13015)

* [GPU] FullyConnected dynamic

* [GPU] Fix FC OneDNN usage
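
The recurring piece of this change is a small reshape_to_2d helper, duplicated below in calc_output_layout, the OCL and OneDNN fully_connected implementations, and the OneDNN weights-reorder path: any input or weights shape whose rank is not 2 is collapsed to { total / feature, feature } before kernel parameters are built, with feature taken from dimension min(input_size, 4) - 1 of the input. A minimal standalone sketch of that helper (a plain std::vector stands in for ov::PartialShape):

#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Sketch of the reshape_to_2d helper introduced by this commit; the real one
// takes an ov::PartialShape, a plain vector keeps the example standalone.
static std::vector<int64_t> reshape_to_2d(const std::vector<int64_t>& shape, int64_t feature) {
    const int64_t total = std::accumulate(shape.begin(), shape.end(),
                                          int64_t{1}, std::multiplies<int64_t>());
    // Collapse every dimension except the innermost feature axis into the batch axis.
    return { total / feature, feature };
}

int main() {
    // A 4D input {2, 3, 4, 5} with feature = 5 becomes the 2D shape {24, 5},
    // which is what the fully_connected kernels expect.
    for (int64_t d : reshape_to_2d({2, 3, 4, 5}, 5))
        std::cout << d << ' ';
    std::cout << std::endl;  // prints: 24 5
}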
Roman Lyamin 2022-10-11 16:51:18 +04:00 committed by GitHub
parent 2d4d80a444
commit 9c6ad77852
14 changed files with 713 additions and 301 deletions

View File

@ -8,6 +8,7 @@
#include "intel_gpu/runtime/error_handler.hpp"
#include "json_object.h"
#include <string>
#include <algorithm>
#include "matmul_shape_inference.hpp"
@ -94,7 +95,9 @@ layout fully_connected_inst::calc_output_layout(fully_connected_node const& node
auto desc = impl_param.typed_desc<fully_connected>();
auto input_layout = impl_param.get_input_layout();
auto input_pshape = input_layout.get_partial_shape();
auto weights_layout = *impl_param.weights_layout;
auto weights_pshape = weights_layout.get_partial_shape();
auto output_type = input_layout.data_type;
if ((output_type == data_types::u8 || output_type == data_types::i8) && desc->output_data_type)
output_type = *desc->output_data_type;
@ -103,10 +106,23 @@ layout fully_connected_inst::calc_output_layout(fully_connected_node const& node
output_type = impl_param.get_fused_output_layout().data_type;
}
if (input_layout.is_dynamic()) {
auto rank = input_layout.get_rank();
format output_format = format::get_default_format(rank);
return layout(ov::PartialShape::dynamic(rank), output_type, output_format);
auto reshape_to_2d = [](const ov::PartialShape& shape, int64_t feature) {
auto staticShape = shape.to_shape();
size_t total = std::accumulate(staticShape.begin(), staticShape.end(), 1, std::multiplies<size_t>());
std::vector<int64_t> reshapeSize = { static_cast<int64_t>(total) / feature, feature };
return reshapeSize;
};
int64_t feature = input_pshape[std::min(desc->input_size, static_cast<size_t>(4)) - 1].get_length();
if (desc->input_size == 3) {
feature = std::max({input_layout.spatial(0), input_layout.spatial(1), input_layout.spatial(2)});
}
if (desc->input_size > 3) {
input_layout.set_partial_shape(reshape_to_2d(input_pshape, feature));
}
if (weights_pshape.size() != 2) {
weights_layout.set_partial_shape(reshape_to_2d(weights_pshape, feature));
}
auto output_size = tensor(input_layout.batch(), weights_layout.batch(), 1, 1);

View File

@ -745,17 +745,6 @@ void reorder_inputs::run(program& p, layout_optimizer& lo, reorder_factory& rf)
}
}
// Change input data of fully-connected node from bx to bf
if (input_layout.is_static() && format::is_simple_data_format(input_layout.format) && weights.is_constant() && input_layout.format.dimension() == 4 &&
input_layout.feature() == 1 && input_layout.spatial(0) != 1 && input_layout.spatial(1) == 1) {
auto new_tensor = input_layout.get_tensor();
new_tensor.feature[0] = input_layout.spatial(0);
new_tensor.spatial[0] = 1;
auto new_reshape = std::make_shared<reshape>("reorder:Reshape_bf_" + fc_node.id() + "_for_input", input.id(), new_tensor);
auto& new_reorder_node = p.get_or_create(new_reshape);
p.add_intermediate(new_reorder_node, fc_node, 0);
}
// Change weights type i32 to f32
auto weights_layout = weights.get_output_layout();
if (weights_layout.data_type == data_types::i32) {

View File

@ -17,6 +17,7 @@
#include "intel_gpu/primitives/reorder.hpp"
#include "intel_gpu/primitives/input_layout.hpp"
#include <memory>
#include <algorithm>
namespace cldnn {
namespace ocl {
@ -42,7 +43,59 @@ protected:
public:
static primitive_impl* create(const fully_connected_node& arg, const kernel_impl_params& impl_param) {
const auto primitive = arg.get_primitive();
auto fc_params = get_weights_bias_default_params<kernel_selector::fully_connected_params>(impl_param);
auto get_fc_input_layouts = [primitive](const std::vector<layout>& input_layouts) {
auto reshape_to_2d = [](const ov::PartialShape& shape, int64_t feature) {
auto staticShape = shape.to_shape();
size_t total = std::accumulate(staticShape.begin(), staticShape.end(), 1, std::multiplies<size_t>());
std::vector<int64_t> reshapeSize = { static_cast<int64_t>(total) / feature, feature };
return reshapeSize;
};
auto input0_layout = input_layouts[0];
auto input1_layout = input_layouts[1];
auto input0_pshape = input0_layout.get_partial_shape();
auto input1_pshape = input1_layout.get_partial_shape();
int64_t feature = input0_pshape[std::min(primitive->input_size, static_cast<size_t>(4)) - 1ul].get_length();
if (primitive->input_size > 3) {
input0_layout.set_partial_shape(reshape_to_2d(input0_pshape, feature));
}
if (input1_pshape.size() != 2) {
input1_layout.set_partial_shape(reshape_to_2d(input1_pshape, feature));
}
std::vector<layout> layouts{input0_layout, input1_layout};
return layouts;
};
auto get_fc_output_layout = [primitive](const std::vector<layout>& input_layouts, const layout& output_layout) {
auto updated_out_layout = output_layout;
ov::PartialShape updated_out_pshape { input_layouts[0].get_partial_shape().begin()->get_length(),
input_layouts[1].get_partial_shape().begin()->get_length() };
if (primitive->input_size == 3) {
updated_out_pshape = { input_layouts[0].get_partial_shape().begin()->get_length(),
(input_layouts[0].get_partial_shape().begin() + 1)->get_length(),
input_layouts[1].get_partial_shape().begin()->get_length() };
}
updated_out_layout.set_partial_shape(updated_out_pshape);
return updated_out_layout;
};
auto updated_impl_param = impl_param;
const auto input_layouts = get_fc_input_layouts(impl_param.input_layouts);
updated_impl_param.input_layouts[0] = input_layouts[0];
updated_impl_param.input_layouts[1] = input_layouts[1];
updated_impl_param.weights_layout = input_layouts[1];
updated_impl_param.output_layout = get_fc_output_layout(input_layouts, impl_param.output_layout);
auto fc_params = get_weights_bias_default_params<kernel_selector::fully_connected_params>(updated_impl_param);
auto fc_optional_params =
get_default_weights_bias_optional_params<kernel_selector::fully_connected_optional_params>(
arg.get_program());
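The updated impl params therefore carry 2D (or, for input_size == 3, 3D) layouts before get_weights_bias_default_params is called. A small standalone sketch of the output-shape rule from get_fc_output_layout above, with hypothetical shapes:

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Sketch of get_fc_output_layout's shape rule: by default the output is
// { input_batch, weights_batch }; for input_size == 3 the middle input
// dimension is preserved.
std::vector<int64_t> fc_output_shape(const std::vector<int64_t>& in,
                                     const std::vector<int64_t>& w,
                                     size_t input_size) {
    if (input_size == 3)
        return { in[0], in[1], w[0] };
    return { in[0], w[0] };
}

int main() {
    auto o2 = fc_output_shape({24, 5}, {10, 5}, 2);       // -> {24, 10}
    auto o3 = fc_output_shape({2, 32, 64}, {10, 64}, 3);  // -> {2, 32, 10}
    std::cout << o2[0] << 'x' << o2[1] << "  "
              << o3[0] << 'x' << o3[1] << 'x' << o3[2] << std::endl;
}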

View File

@ -19,6 +19,14 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected, dnn
using parent = typed_primitive_onednn_impl<fully_connected, dnnl::inner_product_forward::desc>;
using parent::parent;
private:
static std::vector<int64_t> reshape_to_2d(const ov::PartialShape& shape, int64_t feature) {
auto staticShape = shape.to_shape();
size_t total = std::accumulate(staticShape.begin(), staticShape.end(), 1, std::multiplies<size_t>());
std::vector<int64_t> reshapeSize = { static_cast<int64_t>(total) / feature, feature };
return reshapeSize;
}
protected:
std::unique_ptr<primitive_impl> clone() const override {
return make_unique<fully_connected_onednn>(*this);
@ -58,9 +66,20 @@ protected:
}
static kernel_selector::WeightsReorderParams get_weights_reorder(const kernel_impl_params& impl_params, const dnnl::primitive_desc& pd) {
auto input_layout = impl_params.get_input_layout(0);
auto weights_layout = impl_params.get_input_layout(1);
auto cldnn_prim = impl_params.typed_desc<fully_connected>();
auto input_pshape = input_layout.get_partial_shape();
auto weights_pshape = weights_layout.get_partial_shape();
int64_t feature = input_pshape[std::min(cldnn_prim->input_size, static_cast<size_t>(4)) - 1].get_length();
if (cldnn_prim->input_size == 3) {
feature = std::max({input_layout.spatial(0), input_layout.spatial(1), input_layout.spatial(2)});
}
if (weights_pshape.size() != 2) {
weights_layout.set_partial_shape(reshape_to_2d(weights_pshape, feature));
}
kernel_selector::WeightsReorderParams weights_reorder_params;
auto& reorderKS = kernel_selector::ReorderWeightsKernelSelctor::Instance();
kernel_selector::reorder_weights_params r_params;
@ -98,6 +117,26 @@ protected:
auto weights_layout = impl_params.get_input_layout(1);
auto output_layout = impl_params.output_layout;
auto input_pshape = input_layout.get_partial_shape();
auto weights_pshape = weights_layout.get_partial_shape();
int64_t feature = input_pshape[std::min(prim->input_size, static_cast<size_t>(4)) - 1].get_length();
if (prim->input_size == 3) {
feature = std::max({input_layout.spatial(0), input_layout.spatial(1), input_layout.spatial(2)});
}
if (prim->input_size > 3) {
input_layout.set_partial_shape(reshape_to_2d(input_pshape, feature));
}
if (weights_pshape.size() != 2) {
weights_layout.set_partial_shape(reshape_to_2d(weights_pshape, feature));
}
if (prim->input_size == 3) {
output_layout.set_partial_shape({ input_layout.batch(), input_layout.feature(), weights_layout.batch(), 1 });
} else {
output_layout.set_partial_shape({ input_layout.batch(), weights_layout.batch() });
}
if (prim->input_size == 3) {
combine_bf_with_first_spatial_dim(input_layout);
combine_bf_with_first_spatial_dim(output_layout);

View File

@ -162,15 +162,16 @@ std::string convert_data_format_string(cldnn::format fmt) {
}
void combine_bf_with_first_spatial_dim(cldnn::layout& l) {
auto rank = cldnn::format::dimension(l.format);
auto last_spatial_dim_idx = rank - 2 - 1;
auto t = l.get_tensor();
t.batch[0] *= l.feature();
t.feature[0] = t.spatial[last_spatial_dim_idx];
t.spatial[last_spatial_dim_idx] = 1;
l.set_tensor(t);
auto pshape = l.get_shape();
ov::Shape new_shape{1, 1};
for (size_t i = 0; i < pshape.size(); ++i) {
if (i < 2) {
new_shape[0] *= pshape[i];
} else {
new_shape[1] *= pshape[i];
}
}
l.set_partial_shape(new_shape);
}
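The reworked combine_bf_with_first_spatial_dim above no longer goes through cldnn::tensor: it folds the first two dimensions of the shape into the new batch and every remaining dimension into the new feature. A standalone sketch with a worked example (a plain shape vector stands in for cldnn::layout):

#include <cstddef>
#include <iostream>
#include <vector>

// Sketch of the reworked combine_bf_with_first_spatial_dim logic.
std::vector<size_t> combine_bf_with_first_spatial_dim(const std::vector<size_t>& shape) {
    std::vector<size_t> new_shape{1, 1};
    for (size_t i = 0; i < shape.size(); ++i) {
        if (i < 2)
            new_shape[0] *= shape[i];   // batch and feature fold into dim 0
        else
            new_shape[1] *= shape[i];   // remaining dims fold into dim 1
    }
    return new_shape;
}

int main() {
    // A 3D fully-connected case {b=2, f=3, x=64} collapses to {6, 64}.
    auto s = combine_bf_with_first_spatial_dim({2, 3, 64});
    std::cout << s[0] << ' ' << s[1] << std::endl;  // prints: 6 64
}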
int64_t get_f_offset(cldnn::layout&& l, dnnl::memory::desc&& desc) {
@ -426,6 +427,7 @@ static cldnn::format convert_format(dnnl::memory::format_tag fmt, bool is_groupe
case dnnl::memory::format_tag::ab: return cldnn::format::oiyx;
case dnnl::memory::format_tag::abcd: return cldnn::format::oiyx;
case dnnl::memory::format_tag::bacd: return cldnn::format::ioyx;
case dnnl::memory::format_tag::bcda: return cldnn::format::iyxo;
case dnnl::memory::format_tag::BAcd16b16a: return cldnn::format::is_os_yx_isv16_osv16;
case dnnl::memory::format_tag::ABcd16b16a: return cldnn::format::os_is_yx_isv16_osv16;
case dnnl::memory::format_tag::abcde: return cldnn::format::oizyx;

View File

@ -55,6 +55,8 @@ class typed_primitive_inst<reorder> : public typed_primitive_inst_base<reorder>
using parent = typed_primitive_inst_base<reorder>;
public:
template<typename ShapeType>
static std::vector<layout> calc_output_layouts(reorder_node const& /*node*/, const kernel_impl_params& impl_param);
static layout calc_output_layout(reorder_node const& node, kernel_impl_params const& impl_param);
static std::string to_string(reorder_node const& node);

View File

@ -259,6 +259,7 @@ kernel_selector::weights_layout to_weights_layout(format f, bool is_grouped) {
return kernel_selector::weights_layout::oiyx;
case format::ioyx:
return kernel_selector::weights_layout::ioyx;
case format::iyxo:
case format::fyxb:
return kernel_selector::weights_layout::iyxo;
case format::byxf:

View File

@ -163,13 +163,21 @@ layout reorder_inst::calc_output_layout(reorder_node const& node, kernel_impl_pa
// TODO Shouldn't transform be called every time ifmt != ofmt?
return layout(odt, ofmt, input_layout.get_tensor().transform(ofmt, 1), op);
} else {
if (input_layout.is_static())
return layout(odt, ofmt, input_layout.get_tensor(), op);
else
return layout(input_layout.get_partial_shape(), odt, ofmt, op);
return layout(odt, ofmt, input_layout.get_tensor(), op);
}
}
template<typename ShapeType>
std::vector<layout> reorder_inst::calc_output_layouts(reorder_node const& /*node*/, const kernel_impl_params& impl_param) {
auto desc = impl_param.typed_desc<reorder>();
auto input_layout = impl_param.get_input_layout();
auto ifmt = input_layout.format;
auto ofmt = desc->output_format == format::any ? ifmt : desc->output_format;
return { layout(input_layout.get<ShapeType>(), desc->output_data_type.value(), ofmt, desc->output_padding) };
}
std::string reorder_inst::to_string(reorder_node const& node) {
auto desc = node.get_primitive();
auto mean = desc->mean;

View File

@ -185,9 +185,9 @@ void createClDnnConstant(Program& p, const ngraph::Shape& constDims, const std::
constTensor = getConstTensor(newDims);
}
cldnn::layout constLayout = cldnn::layout(cldnn::element_type_to_data_type(op->get_output_element_type(0)),
constFormat,
constTensor);
cldnn::data_types out_dtype = cldnn::element_type_to_data_type(op->get_output_element_type(0));
cldnn::layout constLayout = p.use_new_shape_infer() ? cldnn::layout(newDims, out_dtype, constFormat) :
cldnn::layout(out_dtype, constFormat, constTensor);
cldnn::primitive_id initialconstPrimID = layer_type_name_ID(op);
cldnn::primitive_id constPrimID;

View File

@ -25,15 +25,23 @@ namespace intel_gpu {
* for example: [2, 32, 64] [3, 64, 64] it will raise an exception.
*/
static std::pair<ngraph::Shape, ngraph::Shape> get_aligned_shapes(const ngraph::Shape& shape_a,
const ngraph::Shape& shape_b,
const std::shared_ptr<ngraph::op::v0::MatMul>& matmul) {
ngraph::Shape shape_a_aligned(shape_a), shape_b_aligned(shape_b);
size_t max_size = std::max(shape_a_aligned.size(), shape_b_aligned.size());
for (size_t i = 0, cnt = max_size - shape_a_aligned.size(); i < cnt; ++i)
static std::tuple<bool, PartialShape, PartialShape> get_aligned_shapes(const PartialShape& shape_a,
const PartialShape& shape_b,
const std::shared_ptr<ngraph::op::v0::MatMul>& matmul) {
PartialShape shape_a_aligned(shape_a), shape_b_aligned(shape_b);
auto rank_a = shape_a_aligned.rank().get_length();
auto rank_b = shape_b_aligned.rank().get_length();
size_t max_size = std::max(rank_a, rank_b);
if (max_size == 1) {
return std::make_tuple(false, shape_a_aligned, shape_b_aligned);
}
for (size_t i = 0, cnt = max_size - rank_a; i < cnt; ++i) {
shape_a_aligned.insert(shape_a_aligned.begin(), 1);
for (size_t i = 0, cnt = max_size - shape_b_aligned.size(); i < cnt; ++i)
}
for (size_t i = 0, cnt = max_size - rank_b; i < cnt; ++i) {
shape_b_aligned.insert(shape_b_aligned.begin(), 1);
}
if (matmul->get_transpose_a()) {
std::swap(*(shape_a_aligned.end() - 1), *(shape_a_aligned.end() - 2));
@ -43,14 +51,22 @@ static std::pair<ngraph::Shape, ngraph::Shape> get_aligned_shapes(const ngraph::
}
for (size_t i = 0; i < max_size - 2; ++i) {
if (shape_a_aligned[i] != shape_b_aligned[i] && shape_a_aligned[i] > 1 && shape_b_aligned[i] > 1) {
IE_THROW() << "Shapes can't be aligned: " << shape_a_aligned << " " << shape_b_aligned;
auto a_dim = shape_a_aligned[i], b_dim = shape_b_aligned[i];
if (a_dim.is_dynamic()) {
if (b_dim == 1) {
shape_a_aligned[i] = shape_b_aligned[i] = a_dim;
} else {
return std::make_tuple(false, shape_a_aligned, shape_b_aligned);
}
} else {
if (a_dim != b_dim && a_dim.get_length() > 1 && b_dim.get_length() > 1) {
IE_THROW() << "Shapes can't be aligned: " << shape_a_aligned << " " << shape_b_aligned;
}
auto max_value = std::max(a_dim.get_length(), b_dim.get_length());
shape_a_aligned[i] = shape_b_aligned[i] = max_value;
}
size_t max_value = std::max(shape_a_aligned[i], shape_b_aligned[i]);
shape_a_aligned[i] = shape_b_aligned[i] = max_value;
}
return {shape_a_aligned, shape_b_aligned};
return {true, shape_a_aligned, shape_b_aligned};
}
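get_aligned_shapes now reports failure instead of always throwing, so CreateMatMulOp can clear is_fc and fall back to the generic gemm path when a dynamic batch dimension cannot be aligned. A standalone sketch of the alignment rule, with -1 standing in for a dynamic dimension (transpose handling omitted; weights are static on this path):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <tuple>
#include <vector>

// Sketch of the batch-dimension alignment from get_aligned_shapes; the real
// code works on ov::PartialShape and throws on incompatible static dims.
std::tuple<bool, std::vector<int64_t>, std::vector<int64_t>>
align_batch_dims(std::vector<int64_t> a, std::vector<int64_t> b) {
    size_t max_size = std::max(a.size(), b.size());
    if (max_size < 2)
        return {false, a, b};                         // rank-1 operands are rejected
    while (a.size() < max_size) a.insert(a.begin(), 1);
    while (b.size() < max_size) b.insert(b.begin(), 1);
    for (size_t i = 0; i + 2 < max_size; ++i) {
        if (a[i] == -1) {
            if (b[i] != 1)
                return {false, a, b};                 // dynamic dim only aligns with 1
            b[i] = -1;
        } else {
            if (a[i] != b[i] && a[i] > 1 && b[i] > 1)
                return {false, a, b};                 // the real code throws here
            a[i] = b[i] = std::max(a[i], b[i]);
        }
    }
    return {true, a, b};
}

int main() {
    // {-1, 32, 64} x {64, 10}: b is padded to {1, 64, 10}; the dynamic batch dim
    // of a aligns against the broadcastable 1, so the FC path stays valid.
    std::cout << std::boolalpha
              << std::get<0>(align_batch_dims({-1, 32, 64}, {64, 10})) << std::endl;
}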
static void CreateMatMulOp(Program& p, const std::shared_ptr<ngraph::op::v0::MatMul>& op) {
@ -58,97 +74,67 @@ static void CreateMatMulOp(Program& p, const std::shared_ptr<ngraph::op::v0::Mat
auto inputPrimitives = p.GetInputPrimitiveIDs(op);
std::string layerName = layer_type_name_ID(op);
auto shape_a = op->get_input_shape(0);
auto shape_b = op->get_input_shape(1);
auto shape_a = op->get_input_partial_shape(0);
auto shape_b = op->get_input_partial_shape(1);
auto rank_a = shape_a.rank().get_length();
auto rank_b = shape_b.rank().get_length();
bool is_fc = IsNodeOnConstPath(op->get_input_node_shared_ptr(1));
is_fc &= std::count_if(shape_b.begin(), shape_b.end(), [](size_t x) { return x != 1; }) <= 2;
is_fc &= std::count_if(shape_b.begin(), shape_b.end(), [](Dimension x) { return x != 1; }) <= 2;
// TODO: This conditions can be relaxed with proper handling in FC path
is_fc &= shape_b.size() > 1 && shape_a.size() > 1;
is_fc &= rank_a > 1 && rank_b > 1 && shape_b.is_static();
PartialShape shape_a_aligned, shape_b_aligned;
bool aligned = false;
std::tie(aligned, shape_a_aligned, shape_b_aligned) = get_aligned_shapes(shape_a, shape_b, op);
is_fc &= aligned;
if (is_fc) {
ngraph::Shape shape_a_aligned, shape_b_aligned;
std::tie(shape_a_aligned, shape_b_aligned) = get_aligned_shapes(shape_a, shape_b, op);
if (shape_a_aligned.size() < 2 || shape_b_aligned.size() < 2) {
IE_THROW() << "MatMul " << op->get_friendly_name() << " shapes are inconsistent.";
}
size_t K = *(shape_a_aligned.end() - 1);
auto inputName = inputPrimitives[0];
auto weightsName = inputPrimitives[1];
// Weights normalization
if (!op->get_transpose_b()) {
std::vector<uint16_t> transpose_order(shape_b.size());
auto create_transpose = [&](const std::string& transposeName, const std::string& transposeInputName, size_t rank) {
std::vector<uint16_t> transpose_order(rank);
std::iota(transpose_order.begin(), transpose_order.end(), 0);
std::swap(*(transpose_order.end() - 1), *(transpose_order.end() - 2));
for (auto o = transpose_order.size(); o < 4; o++)
transpose_order.push_back((uint16_t)o);
auto permuteName = op->get_friendly_name() + "/transpose_b";
auto permutePrim = cldnn::permute(permuteName,
weightsName,
auto permutePrim = cldnn::permute(transposeName,
transposeInputName,
transpose_order);
p.add_primitive(*op, permutePrim);
weightsName = permuteName;
};
// Weights normalization
if (!op->get_transpose_b()) {
auto transposeName = op->get_friendly_name() + "/transpose_b";
create_transpose(transposeName, weightsName, rank_b);
weightsName = transposeName;
}
// Input normalization
if (op->get_transpose_a()) {
std::vector<uint16_t> transpose_order(shape_a.size());
std::iota(transpose_order.begin(), transpose_order.end(), 0);
std::swap(*(transpose_order.end() - 1), *(transpose_order.end() - 2));
for (auto o = transpose_order.size(); o < 4; o++)
transpose_order.push_back((uint16_t)o);
auto permuteName = op->get_friendly_name() + "/transpose_a";
auto permutePrim = cldnn::permute(permuteName,
inputName,
transpose_order);
p.add_primitive(*op, permutePrim);
inputName = permuteName;
auto transposeName = op->get_friendly_name() + "/transpose_a";
create_transpose(transposeName, inputName, rank_a);
inputName = transposeName;
}
bool reshape_fc = shape_a_aligned.size() > 3;
auto reshape_to_2d = [&](const ngraph::Shape& shape, std::string inputName, size_t features, std::string suffix) -> std::string {
size_t total = std::accumulate(shape.begin(), shape.end(), size_t(1), std::multiplies<size_t>());
std::vector<size_t> reshapeSize = { total / features, features };
if (total != reshapeSize[0] * reshapeSize[1])
IE_THROW() << "Inconsistent reshape in Matmul op: " << op->get_friendly_name();
auto reshapeInName = op->get_friendly_name() + suffix;
auto reshapeInPrim = cldnn::reshape(reshapeInName,
inputName,
tensor_from_dims(reshapeSize));
p.add_primitive(*op, reshapeInPrim);
return reshapeInName;
};
if (reshape_fc) {
inputName = reshape_to_2d(shape_a, inputName, K, "_cldnn_reshape_in");
}
if (shape_b.size() != 2) {
weightsName = reshape_to_2d(shape_b, weightsName, K, "_cldnn_reshape_weights");
}
auto input_rank = reshape_fc ? 2 : shape_a.size();
auto fcPrim = cldnn::fully_connected(layerName,
inputName,
weightsName,
"",
cldnn::element_type_to_data_type(op->get_output_element_type(0)),
cldnn::padding(),
input_rank);
shape_a.size());
p.add_primitive(*op, fcPrim);
auto lastLayerName = layerName;
if (reshape_fc) {
if (shape_a_aligned.size() > 3 && !p.use_new_shape_infer()) {
auto lastLayerName = layerName;
auto outReshapeName = layerName + "_cldnn_out_reshape";
// add reorder
@ -158,28 +144,21 @@ static void CreateMatMulOp(Program& p, const std::shared_ptr<ngraph::op::v0::Mat
if (outDims.size() > 4) {
cldnn::format outputFormat = cldnn::format::bfyx;
switch (outDims.size()) {
case 5: outputFormat = cldnn::format::bfzyx; break;
case 6: outputFormat = cldnn::format::bfwzyx; break;
default: break;
case 5: outputFormat = cldnn::format::bfzyx; break;
case 6: outputFormat = cldnn::format::bfwzyx; break;
default: break;
}
cldnn::primitive_id reorderId = "reorder:" + outReshapeName + "_reorder";
cldnn::layout outputLayout(cldnn::element_type_to_data_type(op->get_output_element_type(0)), outputFormat, outTensor);
auto reorder_prim = cldnn::reorder(reorderId,
layerName,
outputLayout,
std::vector<float>(),
cldnn::reorder_mean_mode::subtract);
auto reorder_prim = cldnn::reorder(reorderId, layerName, outputLayout);
p.add_primitive(*op, reorder_prim);
lastLayerName = reorderId;
}
// add reshape
auto outReshapePrim = cldnn::reshape(outReshapeName, lastLayerName, outTensor);
p.add_primitive(*op, outReshapePrim);
lastLayerName = outReshapeName;
}
} else {
auto outDims = op->get_output_shape(0);

View File

@ -109,11 +109,7 @@ static void CreateLSTMCellOp(Program& p, const std::shared_ptr<ngraph::op::v4::L
cldnn::layout inputLayout = cldnn::layout(lstm_dtype, cldnn::format::bfyx, inputShape);
cldnn::layout hiddenLayout = cldnn::layout(lstm_dtype, cldnn::format::bfyx, inStateShape);
p.add_primitive(*op, cldnn::reshape(inReshapeID, inputPrimitives[0], inputShape));
p.add_primitive(*op, cldnn::reorder(permuteID,
inReshapeID,
inputLayout,
std::vector<float>(),
cldnn::reorder_mean_mode::subtract));
p.add_primitive(*op, cldnn::reorder(permuteID, inReshapeID, inputLayout));
std::string hiddenInResh = inHiddenReshapeID + "_1";
@ -121,22 +117,13 @@ static void CreateLSTMCellOp(Program& p, const std::shared_ptr<ngraph::op::v4::L
std::string cellInResh = inHiddenReshapeID + "_2";
std::string cellInStr = inHiddenReorderID + "_2";
p.add_primitive(*op, cldnn::reshape(hiddenInResh, inputPrimitives[1], inStateShape));
p.add_primitive(*op, cldnn::reorder(hiddenInStr,
hiddenInResh,
hiddenLayout,
std::vector<float>(),
cldnn::reorder_mean_mode::subtract));
p.add_primitive(*op, cldnn::reorder(hiddenInStr, hiddenInResh, hiddenLayout));
p.add_primitive(*op, cldnn::reshape(cellInResh, inputPrimitives[2], inStateShape));
p.add_primitive(*op, cldnn::reorder(cellInStr,
cellInResh,
hiddenLayout,
std::vector<float>(),
cldnn::reorder_mean_mode::subtract));
p.add_primitive(*op, cldnn::reorder(cellInStr, cellInResh, hiddenLayout));
p.add_primitive(*op, cldnn::concatenation(input_concatID,
{ permuteID, hiddenInStr },
3));
cldnn::tensor gemmSz = cldnn::tensor{ lstm_batch_size, 1, 4 * lstm_hidden_size, 1 };
cldnn::layout gemmLayout = cldnn::layout(lstm_dtype, cldnn::format::bfyx, gemmSz);
cldnn::tensor hiddenSz = cldnn::tensor{ lstm_batch_size, 1, lstm_hidden_size, 1 };
@ -149,13 +136,13 @@ static void CreateLSTMCellOp(Program& p, const std::shared_ptr<ngraph::op::v4::L
cldnn::primitive_id WRconcatID = layerName + "_WRconcat";
p.add_primitive(*op, cldnn::concatenation(WRconcatID, { weightID, recurrentID }, 1));
p.add_primitive(*op, cldnn::fully_connected(lstm_fc_id, input_concatID, WRconcatID, hasBias ? biasID : ""));
cldnn::primitive_id FCInputReshapeID = "Reshape_bf_" + lstm_fc_id + "_for_input";
cldnn::tensor FCInputReshapeSz = { lstm_batch_size, inputShape.spatial[0] + inStateShape.spatial[0], 1, 1 };
p.add_primitive(*op, cldnn::reshape(FCInputReshapeID, input_concatID, FCInputReshapeSz));
p.add_primitive(*op, cldnn::fully_connected(lstm_fc_id, FCInputReshapeID, WRconcatID, hasBias ? biasID : ""));
p.add_primitive(*op, cldnn::reshape(gemmReshapeID, lstm_fc_id, gemmSz));
p.add_primitive(*op, cldnn::reorder(gemmReorderID,
gemmReshapeID,
gemmLayout,
std::vector<float>(),
cldnn::reorder_mean_mode::subtract));
p.add_primitive(*op, cldnn::reorder(gemmReorderID, gemmReshapeID, gemmLayout));
p.add_primitive(*op, cldnn::lstm_elt(lstm_elt_id, gemmReorderID, cellInStr, clip, 0, activations,
activation_params, cldnn::lstm_weights_order::fizo, 0));
@ -220,11 +207,7 @@ static void CreateLSTMSequenceOp(Program& p, const std::shared_ptr<ngraph::op::v
cldnn::tensor inStateShape = { lstm_batch_size, 1, lstm_hidden_size, 1 };
cldnn::layout inputLayout = cldnn::layout(lstm_dtype, cldnn::format::bfyx, inputShape);
p.add_primitive(*op, cldnn::reshape(inReshapeID, inputPrimitives[0], inputShape));
p.add_primitive(*op, cldnn::reorder(permuteID,
inReshapeID,
inputLayout,
std::vector<float>(),
cldnn::reorder_mean_mode::subtract));
p.add_primitive(*op, cldnn::reorder(permuteID, inReshapeID, inputLayout));
p.add_primitive(*op, cldnn::reshape(inHiddenStateID, inputPrimitives[1], inStateShape));
p.add_primitive(*op, cldnn::reshape(inCellStateID, inputPrimitives[2], inStateShape));
@ -249,6 +232,7 @@ static void CreateLSTMSequenceOp(Program& p, const std::shared_ptr<ngraph::op::v
const std::string id_str = std::to_string(i);
cldnn::primitive_id concatID = layerName + "_inputConcat" + id_str;
cldnn::primitive_id lstm_fc_id = layerName + "_fully_connected" + id_str;
cldnn::primitive_id fc_input_resh_id = "Reshape_bf_" + lstm_fc_id + "_for_input" + id_str;
cldnn::primitive_id lstm_fc_resh_id = layerName + "_gemmReshape" + id_str;
cldnn::primitive_id lstm_fc_reor_id = layerName + "_gemmReorder" + id_str;
cldnn::primitive_id lstm_elt_id = layerName + "_lstm_elt" + id_str;
@ -263,14 +247,15 @@ static void CreateLSTMSequenceOp(Program& p, const std::shared_ptr<ngraph::op::v
p.add_primitive(*op, cldnn::crop(inputCrop_id, permuteID, crop_tensor, offset_tensor));
p.add_primitive(*op, cldnn::concatenation(concatID, { inputCrop_id, hiddenStr }, 3));
p.add_primitive(*op, cldnn::fully_connected(lstm_fc_id, concatID, WRreshapeID, biasID));
cldnn::tensor fc_input_resh_tensor = { crop_tensor.batch[0], crop_tensor.spatial[0] + inStateShape.spatial[0],
crop_tensor.feature[0], crop_tensor.spatial[1]};
p.add_primitive(*op, cldnn::reshape(fc_input_resh_id, concatID, fc_input_resh_tensor));
p.add_primitive(*op, cldnn::fully_connected(lstm_fc_id, fc_input_resh_id, WRreshapeID, biasID));
p.add_primitive(*op, cldnn::reshape(lstm_fc_resh_id, lstm_fc_id, gemmSz));
p.add_primitive(*op, cldnn::reorder(lstm_fc_reor_id,
lstm_fc_resh_id,
gemmLayout,
std::vector<float>(),
cldnn::reorder_mean_mode::subtract));
p.add_primitive(*op, cldnn::reorder(lstm_fc_reor_id, lstm_fc_resh_id, gemmLayout));
p.add_primitive(*op, cldnn::lstm_elt(lstm_elt_id, lstm_fc_reor_id, cellStr, clip, 0, activations,
activation_params, cldnn::lstm_weights_order::fizo, 0));

View File

@ -248,7 +248,7 @@ std::map<std::string, std::string> Plugin::ConvertPerfHintsToConfig(
}
IExecutableNetworkInternal::Ptr Plugin::LoadExeNetworkImpl(const InferenceEngine::CNNNetwork &network,
const std::map<std::string, std::string> &orig_config) {
const std::map<std::string, std::string> &orig_config) {
OV_ITT_SCOPED_TASK(itt::domains::intel_gpu_plugin, "Plugin::LoadExeNetworkImpl");
// verification of supported input
InferenceEngine::InputsDataMap _networkInputs = network.getInputsInfo();

View File

@ -188,13 +188,13 @@ TEST(fully_connected_gpu, no_biases) {
// Output:
// 2.5 2.75 0.75 7
const int32_t input_x = 3, input_b = 1, // size of the whole input buffer
weight_b = 4, weight_x = 3; // size of the whole weights buffer
const int32_t input_f = 3, input_b = 1, // size of the whole input buffer
weight_b = 4, weight_f = 3; // size of the whole weights buffer
auto& engine = get_test_engine();
auto input_prim = engine.allocate_memory({ data_types::f32, format::yxfb, { input_b, 1, input_x, 1 } });
auto weights_prim = engine.allocate_memory({ data_types::f32, format::bfyx, { weight_b, 1, weight_x, 1 } });
auto input_prim = engine.allocate_memory({ data_types::f32, format::yxfb, { input_b, input_f, 1, 1 } });
auto weights_prim = engine.allocate_memory({ data_types::f32, format::bfyx, { weight_b, weight_f, 1, 1 } });
set_values(input_prim, { -0.5f, 2.0f, 0.5f });
set_values(weights_prim, { 1.5f, 1.0f, 0.5f, -1.0f, 0.0f, 0.5f, 0.5f, -0.5f, -2.0f, -0.5f, 1.0f, 1.5f });
@ -245,22 +245,22 @@ TEST(fully_connected_gpu, no_biases_int8) {
// Output:
// 18 -32 12 -52
const int32_t input_x = 3, input_b = 1, // size of the whole input buffer
weight_b = 4, weight_x = 3; // size of the whole weights buffer
const int32_t input_f = 3, input_b = 1, // size of the whole input buffer
weight_b = 4, weight_f = 3; // size of the whole weights buffer
auto& engine = get_test_engine();
auto input_prim = engine.allocate_memory({ data_types::f32, format::bfyx, { input_b, 1, input_x, 1 } });
auto weights_prim = engine.allocate_memory({ data_types::i8, format::bfyx, { weight_b, 1, weight_x, 1 } });
auto input_prim = engine.allocate_memory({ data_types::f32, format::bfyx, { input_b, input_f, 1, 1 } });
auto weights_prim = engine.allocate_memory({ data_types::i8, format::bfyx, { weight_b, weight_f, 1, 1 } });
set_values(input_prim, { 8.4f, 2.3f, -4.49f });
set_values<char>(weights_prim, { 2, 1, 0, -3, -2, 1, 0, -2, -4, -5, 10, 8 });
auto input = input_layout("input", input_prim->get_layout());
auto w_data = data("weights", weights_prim);
auto ri = reorder("reorder_to_int", "input", { data_types::i8, format::bfyx, { input_b, 1, input_x, 1 } });
auto ri = reorder("reorder_to_int", "input", { data_types::i8, format::bfyx, { input_b, input_f, 1, 1 } });
auto fc = fully_connected("fc_prim", "reorder_to_int", "weights");
auto rf = reorder("reorder_to_float", "fc_prim", { data_types::f32, format::bfyx, { input_b, 1, 4, 1 } });
auto rf = reorder("reorder_to_float", "fc_prim", { data_types::f32, format::bfyx, { input_b, weight_b, 1, 1 } });
topology topology;
topology.add(input);
topology.add(w_data);
@ -306,13 +306,13 @@ TEST(fully_connected_gpu, xb_f32_batch_1) {
// 2.5 2.75 0.75 7
const int32_t output_f = 4, // size of the whole output buffer
input_x = 3, input_b = 1, // size of the whole input buffer
weight_b = 4, weight_x = 3; // size of the whole weights buffer
input_f = 3, input_b = 1, // size of the whole input buffer
weight_b = 4, weight_f = 3; // size of the whole weights buffer
auto& engine = get_test_engine();
auto input_prim = engine.allocate_memory({ data_types::f32, format::yxfb, { input_b, 1, input_x, 1 } });
auto weights_prim = engine.allocate_memory({ data_types::f32, format::bfyx, { weight_b, 1, weight_x, 1 } });
auto input_prim = engine.allocate_memory({ data_types::f32, format::yxfb, { input_b, input_f, 1, 1 } });
auto weights_prim = engine.allocate_memory({ data_types::f32, format::bfyx, { weight_b, weight_f, 1, 1 } });
auto bias_prim = engine.allocate_memory({ data_types::f32, format::bfyx, { 1, 1, output_f, 1} });
set_values(input_prim, { -0.5f, 2.0f, 0.5f });
@ -366,13 +366,13 @@ TEST(fully_connected_gpu, xb_f32_batch_2) {
// 4 1 2.75 5
const int32_t output_f = 4, // size of the whole output buffer
input_x = 3, input_b = 2, // size of the whole input buffer
weight_b = 4, weight_x = 3; // size of the whole weights buffer
input_f = 3, input_b = 2, // size of the whole input buffer
weight_b = 4, weight_f = 3; // size of the whole weights buffer
auto& engine = get_test_engine();
auto input_prim = engine.allocate_memory({ data_types::f32, format::yxfb, { input_b,1, input_x, 1 } });
auto weights_prim = engine.allocate_memory({ data_types::f32, format::bfyx, { weight_b, 1, weight_x, 1 } });
auto input_prim = engine.allocate_memory({ data_types::f32, format::yxfb, { input_b, input_f, 1, 1 } });
auto weights_prim = engine.allocate_memory({ data_types::f32, format::bfyx, { weight_b, weight_f, 1, 1 } });
auto bias_prim = engine.allocate_memory({ data_types::f32, format::bfyx, { 1, 1, output_f, 1 } });
set_values(input_prim, { -0.5f, 1.0f, 2.0f, 1.5f, 0.5f, 0.0f });
@ -427,14 +427,14 @@ TEST(fully_connected_gpu, x_f32) {
// 2.5 2.75 0.75 7
const int32_t output_f = 4, // size of the whole output buffer
input_x = 3, // size of the whole input buffer
weight_b = 4, weight_x = 3; // size of the whole weights buffer
input_f = 3, // size of the whole input buffer
weight_b = 4, weight_f = 3; // size of the whole weights buffer
auto& engine = get_test_engine();
auto input_prim = engine.allocate_memory({ data_types::f32, format::bfyx, { 1, 1, input_x, 1 } });
auto input_prim = engine.allocate_memory({ data_types::f32, format::bfyx, { 1, input_f, 1, 1 } });
//auto output_prim = memory::allocate({ memory::format::xb_f32, { output_b, { { output_f } }, { 1 } } });
auto weights_prim = engine.allocate_memory({ data_types::f32, format::bfyx, { weight_b, 1, weight_x, 1 } });
auto weights_prim = engine.allocate_memory({ data_types::f32, format::bfyx, { weight_b, weight_f, 1, 1 } });
auto bias_prim = engine.allocate_memory({ data_types::f32, format::bfyx, { 1, 1, output_f, 1 } });
set_values(input_prim, { -0.5f, 2.0f, 0.5f });
@ -465,60 +465,6 @@ TEST(fully_connected_gpu, x_f32) {
EXPECT_EQ(7.00f, output_ptr[3]);
}
TEST(fully_connected_gpu, yxfn_f32) {
// Input : 1x2x1x2 - 1 batch 2 feature maps of size 2x1
// Output : 2x1 - 2 batches 1 neuron each
// Weights: 2x2x1x2 - 2 neurons with weights of 2 feature maps of size 2x1
//
// Input:
// 1 -2 f0: b0
// 3 -4 f1: b0
// Weights:
// 1 -1 n0: fm0
// 2 0 n0: fm1
// 3 4 n1: fm0
// 0.5 5 n1: fm1
//
// Biases:
// 1.0 -5
//
// Output:
// 10 -28.5
auto& engine = get_test_engine();
auto input_prim = engine.allocate_memory({ data_types::f32, format::yxfb, { 1, 2, 2, 1 } });
//auto output_prim = memory::allocate({ memory::format::xb_f32, { 2 , { { 1 } }, 1 } });
auto weights_prim = engine.allocate_memory({ data_types::f32, format::bfyx, { 2, 2, 2, 1 } });
auto bias_prim = engine.allocate_memory({ data_types::f32, format::bfyx, { 1, 1, 2, 1 } });
set_values(input_prim, { 1.f, 3.f, -2.f, -4.f });
set_values(weights_prim, { 1.f, -1.f, 2.0f, 0.f, 3.0f, 4.0f, 0.5f, 5.0f });
set_values(bias_prim, { 1.0f, -5.0f });
topology topology(
input_layout("input", input_prim->get_layout()),
data("weights", weights_prim),
data("bias", bias_prim),
fully_connected("fc_prim", "input", "weights", "bias")
);
network network(engine, topology);
network.set_input_data("input", input_prim);
auto outputs = network.execute();
EXPECT_EQ(outputs.size(), size_t(1));
EXPECT_EQ(outputs.begin()->first, "fc_prim");
auto output_prim = outputs.begin()->second.get_memory();
cldnn::mem_lock<float> output_ptr (output_prim, get_test_stream());
EXPECT_EQ(10, output_ptr[0]);
EXPECT_EQ(-28.5, output_ptr[1]);
}
TEST(fully_connected_gpu, xb_f32_batch_1_relu) {
// Input : 3x1
// Output : 4x1
@ -541,14 +487,14 @@ TEST(fully_connected_gpu, xb_f32_batch_1_relu) {
// 2.5 0 0.75 0
const int32_t output_f = 4, // size of the whole output buffer
input_x = 3, input_b = 1, // size of the whole input buffer
weight_b = 4, weight_x = 3; // size of the whole weights buffer
input_f = 3, input_b = 1, // size of the whole input buffer
weight_b = 4, weight_f = 3; // size of the whole weights buffer
auto& engine = get_test_engine();
auto input_prim = engine.allocate_memory({ data_types::f32, format::yxfb, { input_b, 1, input_x, 1 } });
auto input_prim = engine.allocate_memory({ data_types::f32, format::yxfb, { input_b, input_f, 1, 1 } });
//auto output_prim = memory::allocate({ memory::format::xb_f32, { output_b, { { output_f } }, { 1 } } });
auto weights_prim = engine.allocate_memory({ data_types::f32, format::bfyx, { weight_b, 1, weight_x, 1 } });
auto weights_prim = engine.allocate_memory({ data_types::f32, format::bfyx, { weight_b, weight_f, 1, 1 } });
auto bias_prim = engine.allocate_memory({ data_types::f32, format::bfyx, { 1, 1, output_f, 1 } });
set_values(input_prim, { -0.5f, 2.0f, 0.5f });
@ -603,14 +549,14 @@ TEST(fully_connected_gpu, xb_f32_batch_2_relu) {
// 4 0 2.75 0
const int32_t output_f = 4, // size of the whole output buffer
input_x = 3, input_b = 2, // size of the whole input buffer
weight_b = 4, weight_x = 3; // size of the whole weights buffer
input_f = 3, input_b = 2, // size of the whole input buffer
weight_b = 4, weight_f = 3; // size of the whole weights buffer
auto& engine = get_test_engine();
auto input_prim = engine.allocate_memory({ data_types::f32, format::yxfb, { input_b, 1, input_x, 1 } });
auto input_prim = engine.allocate_memory({ data_types::f32, format::yxfb, { input_b, input_f, 1, 1 } });
//auto output_prim = memory::allocate({ memory::format::xb_f32, { output_b, { { output_f } }, { 1 } } });
auto weights_prim = engine.allocate_memory({ data_types::f32, format::bfyx, { weight_b, 1, weight_x, 1 } });
auto weights_prim = engine.allocate_memory({ data_types::f32, format::bfyx, { weight_b, weight_f, 1, 1 } });
auto bias_prim = engine.allocate_memory({ data_types::f32, format::bfyx, { 1, 1, output_f, 1 } });
set_values(input_prim, { -0.5f, 1.0f, 2.0f, 1.5f, 0.5f, 0.0f });
@ -666,14 +612,14 @@ TEST(fully_connected_gpu, x_f32_relu) {
// 2.5 0 0.75 0
const int32_t output_f = 4, // size of the whole output buffer
input_x = 3, // size of the whole input buffer
weight_b = 4, weight_x = 3; // size of the whole weights buffer
input_f = 3, // size of the whole input buffer
weight_b = 4, weight_y = 3; // size of the whole weights buffer
auto& engine = get_test_engine();
auto input_prim = engine.allocate_memory({ data_types::f32, format::bfyx, { 1, 1, input_x, 1 } });
auto input_prim = engine.allocate_memory({ data_types::f32, format::bfyx, { 1, input_f, 1, 1 } });
//auto output_prim = memory::allocate({ memory::format::x_f32, { 1 , { { output_f } }, 1 } });
auto weights_prim = engine.allocate_memory({ data_types::f32, format::bfyx, { weight_b, 1, weight_x, 1 } });
auto weights_prim = engine.allocate_memory({ data_types::f32, format::bfyx, { weight_b, weight_y, 1, 1 } });
auto bias_prim = engine.allocate_memory({ data_types::f32, format::bfyx, { 1, 1, output_f, 1 } });
set_values(input_prim, { -0.5f, 2.0f, 0.5f });
@ -726,14 +672,14 @@ TEST(fully_connected_gpu, x_f32_relu_with_negative_slope) {
// 2.5 -0.125 0.75 -0.1
const int32_t output_f = 4, // size of the whole output buffer
input_x = 3, // size of the whole input buffer
weight_b = 4, weight_x = 3; // size of the whole weights buffer
input_f = 3, // size of the whole input buffer
weight_b = 4, weight_f = 3; // size of the whole weights buffer
auto& engine = get_test_engine();
auto input_prim = engine.allocate_memory({ data_types::f32, format::bfyx, { 1, 1, input_x, 1 } });
auto input_prim = engine.allocate_memory({ data_types::f32, format::bfyx, { 1, input_f, 1, 1 } });
//auto output_prim = memory::allocate({ memory::format::x_f32, { 1 , { { output_f } }, 1 } });
auto weights_prim = engine.allocate_memory({ data_types::f32, format::bfyx, { weight_b, 1, weight_x, 1 } });
auto weights_prim = engine.allocate_memory({ data_types::f32, format::bfyx, { weight_b, weight_f, 1, 1 } });
auto bias_prim = engine.allocate_memory({ data_types::f32, format::bfyx, { 1, 1, output_f, 1 } });
set_values(input_prim, { -0.5f, 2.0f, 0.5f });
@ -1024,11 +970,10 @@ TEST(fully_connected_gpu, DISABLED_fs_byx_fsv32_b34)
}
}
using shared_dims = std::tuple<size_t, size_t, size_t>;
using fully_connected_test_params = std::tuple<
size_t, // batch_num
size_t, // input_f
size_t, // input_x
size_t, // input_y
shared_dims, // input_f input_x input_y
size_t, // output_f
format::type, // input format
format::type, // output format
@ -1038,11 +983,13 @@ using fully_connected_test_params = std::tuple<
template <typename InputT, typename WeightsT, typename BiasT, typename OutputT>
struct fully_connected_random_test : ::testing::TestWithParam<fully_connected_test_params> {
void run_test() {
shared_dims dims;
size_t batch, input_f, input_x, input_y, output_f;
format::type input_format, output_format;
std::string kernel;
std::tie(batch, input_f, input_x, input_y, output_f, input_format, output_format, kernel) = GetParam();
std::tie(batch, dims, output_f, input_format, output_format, kernel) = GetParam();
std::tie(input_f, input_x, input_y) = dims;
auto input_data = generate_smart_random_4d<InputT>(batch, input_f, input_y, input_x);
auto weights_data = generate_smart_random_4d<WeightsT>(output_f, input_f, input_y, input_x);
@ -1071,9 +1018,8 @@ INSTANTIATE_TEST_SUITE_P(
fully_connected_random_test_f32,
::testing::Combine(
::testing::Values(1, 2),
::testing::Values(3, 32),
::testing::Values(1, 3),
::testing::Values(1, 3),
::testing::Values(shared_dims{3, 1, 1},
shared_dims{32, 1, 1}),
::testing::Values(3, 32),
::testing::Values(format::bfyx, format::yxfb),
::testing::Values(format::any),
@ -1085,9 +1031,8 @@ INSTANTIATE_TEST_SUITE_P(
fully_connected_random_test_f32,
::testing::Combine(
::testing::Values(2, 8),
::testing::Values(3, 32),
::testing::Values(1, 3),
::testing::Values(1, 3),
::testing::Values(shared_dims{3, 1, 1},
shared_dims{32, 1, 1}),
::testing::Values(3, 32),
::testing::Values(format::bfyx),
::testing::Values(format::bfyx),
@ -1105,9 +1050,8 @@ INSTANTIATE_TEST_SUITE_P(
// Batch 1 is disabled due to sporadic failures in `fully_connected_gpu_bs_f_bsv16_b1`
// - there are nans in output.
::testing::Values(2),
::testing::Values(3, 32),
::testing::Values(1, 3),
::testing::Values(1, 3),
::testing::Values(shared_dims{3, 1, 1},
shared_dims{32, 1, 1}),
::testing::Values(3, 32),
::testing::Values(format::bfyx),
::testing::Values(format::any),
@ -1119,9 +1063,8 @@ INSTANTIATE_TEST_SUITE_P(
fully_connected_random_test_f16,
::testing::Combine(
::testing::Values(1, 2),
::testing::Values(3, 32),
::testing::Values(1, 3),
::testing::Values(1, 3),
::testing::Values(shared_dims{3, 1, 1},
shared_dims{32, 1, 1}),
::testing::Values(3, 32),
::testing::Values(format::yxfb),
::testing::Values(format::any),
@ -1133,9 +1076,8 @@ INSTANTIATE_TEST_SUITE_P(
fully_connected_random_test_f16,
::testing::Combine(
::testing::Values(2, 8),
::testing::Values(3, 32),
::testing::Values(1, 3),
::testing::Values(1, 3),
::testing::Values(shared_dims{3, 1, 1},
shared_dims{32, 1, 1}),
::testing::Values(3, 32),
::testing::Values(format::bfyx),
::testing::Values(format::bfyx),
@ -1146,11 +1088,13 @@ INSTANTIATE_TEST_SUITE_P(
template <typename InputT, typename WeightsT, typename BiasT, typename OutputT>
struct fully_connected_random_test_3d : ::testing::TestWithParam<fully_connected_test_params> {
void run_test() {
shared_dims dims;
size_t batch, input_f, input_x, input_y, output_y;
format::type input_format, output_format;
std::string kernel;
std::tie(batch, input_f, input_x, input_y, output_y, input_format, output_format, kernel) = GetParam();
std::tie(batch, dims, output_y, input_format, output_format, kernel) = GetParam();
std::tie(input_f, input_x, input_y) = dims;
auto input_data = generate_smart_random_4d<InputT>(batch, input_f, input_y, input_x);
auto weights_data = generate_smart_random_4d<WeightsT>(output_y, input_y, 1, 1);
@ -1181,9 +1125,10 @@ INSTANTIATE_TEST_SUITE_P(
fully_connected_random_test_f32_3d,
::testing::Combine(
::testing::Values(1, 3),
::testing::Values(1, 3),
::testing::Values(1),
::testing::Values(1, 3, 16),
::testing::Values(shared_dims{1, 1, 1},
shared_dims{1, 1, 3},
shared_dims{3, 1, 1},
shared_dims{3, 1, 3}),
::testing::Values(1, 3, 16),
::testing::Values(format::bfyx),
::testing::Values(format::any),
@ -1195,9 +1140,10 @@ INSTANTIATE_TEST_SUITE_P(
fully_connected_random_test_f32_3d,
::testing::Combine(
::testing::Values(1, 2),
::testing::Values(64, 65),
::testing::Values(1),
::testing::Values(64, 65, 128),
::testing::Values(shared_dims{64, 1, 65},
shared_dims{64, 1, 128},
shared_dims{65, 1, 65},
shared_dims{65, 1, 128}),
::testing::Values(1, 32, 64),
::testing::Values(format::bfyx),
::testing::Values(format::any),
@ -1209,9 +1155,10 @@ INSTANTIATE_TEST_SUITE_P(
fully_connected_random_test_f32_3d,
::testing::Combine(
::testing::Values(3),
::testing::Values(16, 17, 32),
::testing::Values(1),
::testing::Values(17, 32),
::testing::Values(shared_dims{16, 1, 17},
shared_dims{16, 1, 32},
shared_dims{32, 1, 17},
shared_dims{32, 1, 32}),
::testing::Values(17, 32),
::testing::Values(format::bfyx),
::testing::Values(format::any),
@ -1227,9 +1174,10 @@ INSTANTIATE_TEST_SUITE_P(
fully_connected_random_test_f16_3d,
::testing::Combine(
::testing::Values(1, 3),
::testing::Values(1, 3),
::testing::Values(1),
::testing::Values(1, 3, 16),
::testing::Values(shared_dims{1, 1, 1},
shared_dims{1, 1, 16},
shared_dims{3, 1, 1},
shared_dims{3, 1, 16}),
::testing::Values(1, 3, 16),
::testing::Values(format::bfyx),
::testing::Values(format::any),
@ -1245,9 +1193,10 @@ INSTANTIATE_TEST_SUITE_P(
fully_connected_random_test_i8_3d,
::testing::Combine(
::testing::Values(1, 3),
::testing::Values(1, 3),
::testing::Values(1),
::testing::Values(1, 3, 16),
::testing::Values(shared_dims{1, 1, 1},
shared_dims{1, 1, 16},
shared_dims{3, 1, 1},
shared_dims{3, 1, 16}),
::testing::Values(1, 3, 16),
::testing::Values(format::bfyx),
::testing::Values(format::any),
@ -1259,9 +1208,10 @@ INSTANTIATE_TEST_SUITE_P(
fully_connected_random_test_i8_3d,
::testing::Combine(
::testing::Values(1, 2),
::testing::Values(64, 65),
::testing::Values(1),
::testing::Values(64, 65, 128),
::testing::Values(shared_dims{64, 1, 65},
shared_dims{64, 1, 128},
shared_dims{65, 1, 65},
shared_dims{65, 1, 128}),
::testing::Values(1, 32, 64),
::testing::Values(format::bfyx),
::testing::Values(format::any),
@ -1273,9 +1223,10 @@ INSTANTIATE_TEST_SUITE_P(
fully_connected_random_test_i8_3d,
::testing::Combine(
::testing::Values(1, 3),
::testing::Values(16, 17),
::testing::Values(1),
::testing::Values(17, 32),
::testing::Values(shared_dims{16, 1, 17},
shared_dims{16, 1, 32},
shared_dims{32, 1, 17},
shared_dims{32, 1, 32}),
::testing::Values(17, 32),
::testing::Values(format::bfyx),
::testing::Values(format::any),
@ -1388,7 +1339,12 @@ public:
topo.add(data("bias", bias_prim));
topo.add(input_layout("input", input_prim->get_layout()));
auto fc_prim = fully_connected("fc_prim", "input", "weights", "bias");
auto input_sizes = input_size.sizes();
auto last_dim = std::find_if(input_sizes.rbegin(), input_sizes.rend(),
[](tensor::value_type x) { return x != 1l; });
size_t input_rank = std::distance(input_sizes.begin(), last_dim.base());
auto fc_prim = fully_connected("fc_prim", "input", "weights", "bias", cldnn::padding(), input_rank);
fc_prim.output_data_type = type_to_data_type<OutputT>::value;
topo.add(fc_prim);
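The test above derives the fully_connected input_rank from the tensor sizes: the rank is the position just past the last dimension that differs from 1. A minimal sketch of that computation:

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <iterator>
#include <vector>

int main() {
    // For a bfyx tensor with sizes {2, 3, 1, 1} the last non-1 dimension is at
    // index 1, so the derived input_rank is 2.
    std::vector<long> sizes{2, 3, 1, 1};
    auto last_dim = std::find_if(sizes.rbegin(), sizes.rend(),
                                 [](long x) { return x != 1l; });
    size_t input_rank = std::distance(sizes.begin(), last_dim.base());
    std::cout << input_rank << std::endl;  // prints: 2
}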
@ -1575,8 +1531,8 @@ INSTANTIATE_TEST_SUITE_P(
testing::Combine(
testing::Values(1, 2),
testing::Values(3, 64),
testing::Values(1, 3),
testing::Values(1, 3),
testing::Values(1),
testing::Values(1),
testing::Values(3, 32),
testing::Values(format::bfyx, format::b_fs_yx_fsv4, format::b_fs_yx_fsv16, format::b_fs_yx_fsv32)
),
@ -1603,8 +1559,8 @@ INSTANTIATE_TEST_SUITE_P(
testing::Combine(
testing::Values(1, 2),
testing::Values(3, 64),
testing::Values(1, 3),
testing::Values(1, 3),
testing::Values(1),
testing::Values(1),
testing::Values(3, 32),
testing::Values(format::bfyx, format::b_fs_yx_fsv4, format::b_fs_yx_fsv16, format::b_fs_yx_fsv32)
),
@ -1617,8 +1573,8 @@ INSTANTIATE_TEST_SUITE_P(
testing::Combine(
testing::Values(1, 2),
testing::Values(3, 32),
testing::Values(1, 3),
testing::Values(1, 3),
testing::Values(1),
testing::Values(1),
testing::Values(3, 32),
testing::Values(format::bfyx, format::b_fs_yx_fsv4, format::b_fs_yx_fsv32)
),
@ -1631,8 +1587,8 @@ INSTANTIATE_TEST_SUITE_P(
testing::Combine(
testing::Values(1, 2),
testing::Values(3, 32),
testing::Values(1, 3),
testing::Values(1, 3),
testing::Values(1),
testing::Values(1),
testing::Values(3, 32),
testing::Values(format::bfyx, format::b_fs_yx_fsv4, format::b_fs_yx_fsv32)
),
@ -1645,8 +1601,8 @@ INSTANTIATE_TEST_SUITE_P(
testing::Combine(
testing::Values(1, 2),
testing::Values(3, 32),
testing::Values(1, 3),
testing::Values(1, 3),
testing::Values(1),
testing::Values(1),
testing::Values(3, 32),
testing::Values(format::bfyx, format::b_fs_yx_fsv4, format::b_fs_yx_fsv32)
),
@ -1659,8 +1615,8 @@ INSTANTIATE_TEST_SUITE_P(
testing::Combine(
testing::Values(1, 2),
testing::Values(3, 32),
testing::Values(1, 3),
testing::Values(1, 3),
testing::Values(1),
testing::Values(1),
testing::Values(3, 32),
testing::Values(format::bfyx, format::b_fs_yx_fsv4, format::b_fs_yx_fsv32)
),
@ -1673,25 +1629,25 @@ TEST(fully_connected_onednn_gpu, no_biases_int8) {
// Output : 4x1
// Weights: 4x3
const int32_t input_x = 3, input_b = 1, // size of the whole input buffer
weight_b = 4, weight_x = 3; // size of the whole weights buffer
const int32_t input_f = 3, input_b = 1, // size of the whole input buffer
weight_b = 4, weight_f = 3; // size of the whole weights buffer
auto& engine = get_onednn_test_engine();
if (!engine.get_device_info().supports_immad)
return;
// Change input data of fully-connected node from bx to bf
auto input_prim = engine.allocate_memory({ data_types::f32, format::bfyx, { input_b, 1, input_x, 1 } });
auto weights_prim = engine.allocate_memory({ data_types::i8, format::bfyx, { weight_b, weight_x, 1, 1 } });
auto input_prim = engine.allocate_memory({ data_types::f32, format::bfyx, { input_b, input_f, 1, 1 } });
auto weights_prim = engine.allocate_memory({ data_types::i8, format::bfyx, { weight_b, weight_f, 1, 1 } });
set_values(input_prim, { 8.4f, 2.3f, -4.49f });
set_values<char>(weights_prim, { 2, 1, 0, -3, -2, 1, 0, -2, -4, -5, 10, 8 });
auto input = input_layout("input", input_prim->get_layout());
auto w_data = data("weights", weights_prim);
auto ri = reorder("reorder_to_int", "input", { data_types::i8, format::bfyx, { input_b, 1, input_x, 1 } });
auto ri = reorder("reorder_to_int", "input", { data_types::i8, format::bfyx, { input_b, input_f, 1, 1 } });
auto fc = fully_connected("fc_prim", "reorder_to_int", "weights");
auto rf = reorder("reorder_to_float", "fc_prim", { data_types::f32, format::bfyx, { input_b, 1, 4, 1 } });
auto rf = reorder("reorder_to_float", "fc_prim", { data_types::f32, format::bfyx, { input_b, 4, 1, 1 } });
topology topology;
topology.add(input);
topology.add(w_data);
@ -1793,6 +1749,7 @@ TEST(fully_connected_gpu, dynamic) {
build_options options;
options.set_option(build_option::optimize_data(true));
options.set_option(cldnn::build_option::allow_new_shape_infer(true));
network network(engine, topology, options);
network.set_input_data("input", input_data);
@ -1841,6 +1798,7 @@ TEST(fully_connected_gpu, dynamic_multi_inference_same_shape) {
build_options options;
options.set_option(build_option::optimize_data(true));
options.set_option(cldnn::build_option::allow_new_shape_infer(true));
network network(engine, topology, options);
{
@ -1918,6 +1876,7 @@ TEST(fully_connected_gpu, dynamic_multi_inference_different_shape) {
build_options options;
options.set_option(build_option::optimize_data(true));
options.set_option(cldnn::build_option::allow_new_shape_infer(true));
network network(engine, topology, options);
{
@ -2000,6 +1959,7 @@ TEST(fully_connected_gpu, dynamic_multi_inference_multiple_shapes) {
build_options options;
options.set_option(build_option::optimize_data(true));
options.set_option(cldnn::build_option::allow_new_shape_infer(true));
network network(engine, topology, options);
// Call different shape multiple times to ensure caching works fine

View File

@ -0,0 +1,378 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "shared_test_classes/single_layer/mat_mul.hpp"
#include "shared_test_classes/base/ov_subgraph.hpp"
#include "ie_precision.hpp"
#include "ngraph_functions/builders.hpp"
#include <string>
using namespace ngraph;
using namespace InferenceEngine;
using namespace ov::test;
namespace GPULayerTestsDefinitions {
enum class MatMulNodeType {
MatMul,
FullyConnected
};
struct ShapeRelatedParams {
std::vector<InputShape> inputShapes;
std::pair<bool, bool> transpose;
};
typedef std::tuple<
ShapeRelatedParams,
ElementType, // Network precision
ElementType, // Input precision
ElementType, // Output precision
ngraph::helpers::InputLayerType, // Secondary input type
TargetDevice, // Device name
std::map<std::string, std::string> // Additional network configuration
> MatMulLayerTestParamsSet;
class MatMulLayerGPUTest : public testing::WithParamInterface<MatMulLayerTestParamsSet>,
virtual public SubgraphBaseTest {
public:
static std::string getTestCaseName(const testing::TestParamInfo<MatMulLayerTestParamsSet>& obj) {
MatMulLayerTestParamsSet basicParamsSet = obj.param;
ElementType netType;
ElementType inType, outType;
ShapeRelatedParams shapeRelatedParams;
ngraph::helpers::InputLayerType secondaryInputType;
TargetDevice targetDevice;
std::map<std::string, std::string> additionalConfig;
std::tie(shapeRelatedParams, netType, inType, outType, secondaryInputType, targetDevice, additionalConfig) =
basicParamsSet;
std::ostringstream result;
result << "IS=";
for (const auto& shape : shapeRelatedParams.inputShapes) {
result << CommonTestUtils::partialShape2str({shape.first}) << "_";
}
result << "TS=";
for (const auto& shape : shapeRelatedParams.inputShapes) {
result << "(";
if (!shape.second.empty()) {
auto itr = shape.second.begin();
do {
result << CommonTestUtils::vec2str(*itr);
} while (++itr != shape.second.end() && result << "_");
}
result << ")_";
}
result << "transpose_a=" << shapeRelatedParams.transpose.first << "_";
result << "transpose_b=" << shapeRelatedParams.transpose.second << "_";
result << "secondaryInputType=" << secondaryInputType << "_";
result << "netPRC=" << netType << "_";
result << "inPRC=" << inType << "_";
result << "outPRC=" << outType << "_";
result << "trgDev=" << targetDevice;
result << "config=(";
for (const auto configEntry : additionalConfig) {
result << configEntry.first << ", " << configEntry.second << ":";
}
result << ")";
return result.str();
}
protected:
template<typename T>
void transpose(T& shape) {
IE_ASSERT(shape.size() > 1);
std::swap(*(shape.end() - 1), *(shape.end() - 2));
}
void SetUp() override {
MatMulLayerTestParamsSet basicParamsSet = this->GetParam();
ShapeRelatedParams shapeRelatedParams;
ElementType netType;
helpers::InputLayerType secondaryInputType;
std::map<std::string, std::string> additionalConfig;
std::tie(shapeRelatedParams, netType, inType, outType, secondaryInputType, targetDevice, additionalConfig) = basicParamsSet;
init_input_shapes(shapeRelatedParams.inputShapes);
bool transpA = shapeRelatedParams.transpose.first;
bool transpB = shapeRelatedParams.transpose.second;
if (transpA) {
transpose(inputDynamicShapes[0]);
for (auto& shapes : targetStaticShapes) {
transpose(shapes[0]);
}
}
if (transpB) {
transpose(inputDynamicShapes[1]);
for (auto& shapes : targetStaticShapes) {
transpose(shapes[1]);
}
}
const auto& inShapeA = inputDynamicShapes[0];
const auto& inShapeB = inputDynamicShapes[1];
configuration.insert(additionalConfig.begin(), additionalConfig.end());
auto params = builder::makeDynamicParams(netType, {inShapeA});
auto matrixB = builder::makeDynamicInputLayer(netType, secondaryInputType, inShapeB);
if (secondaryInputType == helpers::InputLayerType::PARAMETER) {
params.push_back(std::dynamic_pointer_cast<opset1::Parameter>(matrixB));
}
auto paramOuts = helpers::convert2OutputVector(helpers::castOps2Nodes<opset1::Parameter>(params));
auto matMul = builder::makeMatMul(paramOuts[0], matrixB, transpA, transpB);
auto makeFunction = [](const ngraph::element::Type &ngPrc, ngraph::ParameterVector &params, const std::shared_ptr<ngraph::Node> &lastNode) {
ngraph::ResultVector results;
for (int i = 0; i < lastNode->get_output_size(); i++)
results.push_back(std::make_shared<ngraph::opset1::Result>(lastNode->output(i)));
return std::make_shared<ngraph::Function>(results, params, "MatMul");
};
function = makeFunction(netType, params, matMul);
}
};
TEST_P(MatMulLayerGPUTest, CompareWithRefs) {
SKIP_IF_CURRENT_TEST_IS_DISABLED()
run();
}
namespace {
/* ============= Common params ============= */
std::map<std::string, std::string> emptyAdditionalConfig;
std::vector<std::map<std::string, std::string>> additionalConfig {
std::map<std::string, std::string>{/* empty config */},
};
const std::vector<ElementType> netPRCs {
ElementType::f32,
};
/* ============= FullyConnected ============= */
namespace fullyConnected {
const std::vector<ShapeRelatedParams> IS2D_smoke = {
{static_shapes_to_test_representation({{59, 1}, {1, 120}}), {false, true}},
{static_shapes_to_test_representation({{59, 1}, {1, 120}}), {true, true}},
{static_shapes_to_test_representation({{59, 120}, {120, 1}}), {false, false}},
{static_shapes_to_test_representation({{59, 120}, {120, 1}}), {true, true}},
{static_shapes_to_test_representation({{1, 120}, {120, 59}}), {false, false}},
{static_shapes_to_test_representation({{1, 120}, {120, 59}}), {true, false}},
{static_shapes_to_test_representation({{71, 128}, {128, 20}}), {true, false}},
{static_shapes_to_test_representation({{71, 128}, {128, 20}}), {false, true}},
{
{
{{-1, -1}, {{20, 60}, {20, 60}}},
{{60, 120}, {{60, 120}, {60, 120}}}
},
{false, false}
},
{
{
{{{0, 100}, {0, 12}}, {{20, 1}, {14, 1}, {20, 1}, {14, 1}}},
{{1, 120}, {{1, 120}, {1, 120}, {1, 120}, {1, 120}}}
},
{true, true}
}
};
const std::vector<ShapeRelatedParams> IS2D_nightly = {
{static_shapes_to_test_representation({{59, 1}, {1, 120}}), {false, false}},
{static_shapes_to_test_representation({{59, 1}, {1, 120}}), {true, false}},
{static_shapes_to_test_representation({{59, 120}, {120, 1}}), {true, false}},
{static_shapes_to_test_representation({{59, 120}, {120, 1}}), {false, true}},
{static_shapes_to_test_representation({{1, 120}, {120, 59}}), {true, true}},
{static_shapes_to_test_representation({{1, 120}, {120, 59}}), {false, true}},
{static_shapes_to_test_representation({{71, 128}, {128, 20}}), {true, true}},
{static_shapes_to_test_representation({{71, 128}, {128, 20}}), {false, false}},
{
{
{{-1, -1}, {{71, 128}, {50, 128}}},
{{128, 20}, {{128, 20}, {128, 20}}}
},
{false, false}
},
{
{
{{-1, 59}, {{10, 59}, {15, 59}, {15, 59}}},
{{59, 1}, {{59, 1}, {59, 1}, {59, 1}}}
},
{true, false}
},
{
{
{{{0, 120}, 59}, {{5, 59}, {11, 59}, {5, 59}, {10, 59}}},
{{59, 120}, {{59, 120}, {59, 120}, {59, 120}, {59, 120}}}
},
{false, true}
}
};
const auto testParams2D_smoke = ::testing::Combine(::testing::ValuesIn(IS2D_smoke),
::testing::Values(ElementType::f32),
::testing::Values(ElementType::undefined),
::testing::Values(ElementType::undefined),
::testing::Values(helpers::InputLayerType::CONSTANT),
::testing::Values(CommonTestUtils::DEVICE_GPU),
::testing::Values(emptyAdditionalConfig));
INSTANTIATE_TEST_SUITE_P(smoke_FC_2D, MatMulLayerGPUTest, testParams2D_smoke, MatMulLayerGPUTest::getTestCaseName);
const auto testParams2D_nightly = ::testing::Combine(::testing::ValuesIn(IS2D_nightly),
::testing::Values(ElementType::f32),
::testing::Values(ElementType::undefined),
::testing::Values(ElementType::undefined),
::testing::Values(helpers::InputLayerType::CONSTANT),
::testing::Values(CommonTestUtils::DEVICE_GPU),
::testing::Values(emptyAdditionalConfig));
INSTANTIATE_TEST_SUITE_P(nightly_FC_2D, MatMulLayerGPUTest, testParams2D_nightly, MatMulLayerGPUTest::getTestCaseName);
const std::vector<ShapeRelatedParams> IS3D_smoke = {
{static_shapes_to_test_representation({{1, 32, 120}, {120, 5}}), {false, false}},
{static_shapes_to_test_representation({{1, 32, 120}, {120, 5}}), {false, true}},
{static_shapes_to_test_representation({{1, 32, 120}, {120, 50}}), {true, false}},
{static_shapes_to_test_representation({{1, 32, 120}, {120, 50}}), {false, true}},
{
{
{{1, 5, 32}, {{1, 5, 32}, {1, 5, 32}}},
{{32, 3}, {{32, 3}, {32, 3}}}
},
{false, true}
},
{static_shapes_to_test_representation({{1, 429}, {1, 429, 1}}), {true, true}},
{
{
{{-1, -1}, {{1, 129}, {2, 129}, {1, 129}, {2, 129}}},
{{1, 129, 1}, {{1, 129, 1}, {1, 129, 1}, {1, 129, 1}, {1, 129, 1}}}
},
{true, true}
},
{
{
{{{0, 60}, {0, 60}, {0, 60}}, {{1, 3, 14}, {1, 7, 14}}},
{{14, 10}, {{14, 10}, {14, 10}}}
},
{true, true}
}
};
const std::vector<ShapeRelatedParams> IS3D_nightly = {
{static_shapes_to_test_representation({{1, 32, 120}, {120, 5}}), {true, false}},
{static_shapes_to_test_representation({{1, 32, 120}, {120, 5}}), {true, true}},
{static_shapes_to_test_representation({{1, 32, 120}, {120, 50}}), {false, false}},
{static_shapes_to_test_representation({{1, 32, 120}, {120, 50}}), {true, true}},
{
{
{{-1, -1, -1}, {{1, 32, 120}, {1, 12, 120}}},
{{120, 3}, {{120, 3}, {120, 3}}}
},
{false, false}
},
{
{
{{-1, -1, 50}, {{1, 2, 50}, {1, 10, 50}, {1, 2, 50}, {2, 2, 50}}},
{{50, 7}, {{50, 7}, {50, 7}, {50, 7}, {50, 7}}}
},
{true, false}
},
{
{
{{-1, -1, 32}, {{1, 5, 32}, {1, 5, 32}}},
{{32, 3}, {{32, 3}, {32, 3}}}
},
{false, true}
}
};
const auto fullyConnectedParams3D_smoke = ::testing::Combine(::testing::ValuesIn(IS3D_smoke),
::testing::Values(ElementType::f32),
::testing::Values(ElementType::undefined),
::testing::Values(ElementType::undefined),
::testing::Values(helpers::InputLayerType::CONSTANT),
::testing::Values(CommonTestUtils::DEVICE_GPU),
::testing::Values(emptyAdditionalConfig));
INSTANTIATE_TEST_SUITE_P(smoke_FC_3D, MatMulLayerGPUTest, fullyConnectedParams3D_smoke, MatMulLayerGPUTest::getTestCaseName);
const auto fullyConnectedParams3D_nightly = ::testing::Combine(::testing::ValuesIn(IS3D_nightly),
::testing::Values(ElementType::f32),
::testing::Values(ElementType::undefined),
::testing::Values(ElementType::undefined),
::testing::Values(helpers::InputLayerType::CONSTANT),
::testing::Values(CommonTestUtils::DEVICE_GPU),
::testing::Values(emptyAdditionalConfig));
INSTANTIATE_TEST_SUITE_P(nightly_FC_3D, MatMulLayerGPUTest, fullyConnectedParams3D_nightly, MatMulLayerGPUTest::getTestCaseName);
const std::vector<ShapeRelatedParams> IS4D_smoke = {
{
{
{{-1, -1, -1, -1}, {{1, 32, 20, 120}, {1, 12, 20, 120}}},
{{120, 3}, {{120, 3}, {120, 3}}}
},
{false, false}
},
{
{
{{-1, -1, -1, 50}, {{1, 1, 4, 50}, {1, 5, 10, 50}, {1, 2, 5, 50}, {2, 2, 2, 50}}},
{{50, 7}, {{50, 7}, {50, 7}, {50, 7}, {50, 7}}}
},
{true, false}
},
{
{
{{-1, -1, -1, 32}, {{1, 1, 5, 32}, {1, 2, 5, 32}}},
{{32, 3}, {{32, 3}, {32, 3}}}
},
{false, true}
},
{
{
{{{0, 60}, {0, 60}, {0, 60}, {0, 60}}, {{1, 3, 6, 14}, {1, 7, 10, 14}}},
{{14, 10}, {{14, 10}, {14, 10}}}
},
{true, true}
}
};
const auto fullyConnectedParams4D_smoke = ::testing::Combine(::testing::ValuesIn(IS4D_smoke),
::testing::Values(ElementType::f32),
::testing::Values(ElementType::undefined),
::testing::Values(ElementType::undefined),
::testing::Values(helpers::InputLayerType::CONSTANT),
::testing::Values(CommonTestUtils::DEVICE_GPU),
::testing::Values(emptyAdditionalConfig));
INSTANTIATE_TEST_SUITE_P(smoke_FC_4D, MatMulLayerGPUTest, fullyConnectedParams4D_smoke, MatMulLayerGPUTest::getTestCaseName);
} // namespace fullyConnected
} // namespace
} // namespace GPULayerTestsDefinitions