[GPU] Swap XY axis for 1D conv (#14362)

Jade Cho 2023-01-05 14:58:55 +09:00 committed by GitHub
parent e422b5acb4
commit 9427623046
6 changed files with 207 additions and 0 deletions
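The idea behind the change: a 1D convolution whose spatial extent sits on the Y axis (X == 1) computes exactly the same values if input, output and weights are relabeled so that the extent sits on X, and X is the axis the GPU convolution kernels are typically blocked and vectorized along. The standalone sketch below (plain C++ with illustrative names only, not the OpenVINO/clDNN API) checks that the relabeling is pure bookkeeping:

// Standalone sketch: a "valid" convolution over a (y, x) plane stored row-major.
#include <cassert>
#include <cstddef>
#include <vector>

static std::vector<float> conv2d_valid(const std::vector<float>& in, size_t in_y, size_t in_x,
                                       const std::vector<float>& w, size_t k_y, size_t k_x) {
    const size_t out_y = in_y - k_y + 1, out_x = in_x - k_x + 1;
    std::vector<float> out(out_y * out_x, 0.f);
    for (size_t oy = 0; oy < out_y; ++oy)
        for (size_t ox = 0; ox < out_x; ++ox)
            for (size_t ky = 0; ky < k_y; ++ky)
                for (size_t kx = 0; kx < k_x; ++kx)
                    out[oy * out_x + ox] += in[(oy + ky) * in_x + (ox + kx)] * w[ky * k_x + kx];
    return out;
}

int main() {
    const std::vector<float> data{1, 2, 3, 4, 5};  // 1D signal
    const std::vector<float> filt{1, 2, 1};        // 1D filter
    // Same buffers, two labelings: extent on Y (X == 1) vs. extent on X (Y == 1).
    const auto on_y = conv2d_valid(data, /*in_y=*/5, /*in_x=*/1, filt, /*k_y=*/3, /*k_x=*/1);
    const auto on_x = conv2d_valid(data, /*in_y=*/1, /*in_x=*/5, filt, /*k_y=*/1, /*k_x=*/3);
    assert(on_y == on_x);  // relabeling the 1D extent from Y to X does not change the result
    return 0;
}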

View File

@@ -166,6 +166,54 @@ public:
conv_params.quantization = kernel_selector::QuantizationType::NONE;
}
auto can_swap_xy = [&](kernel_selector::convolution_params& cp) -> bool {
    // Only bfyx convolutions whose single-element spatial axis is X (with no X padding)
    // and whose 1D extent lies on Y are candidates.
    if (cp.inputs[0].GetLayout() == kernel_selector::Tensor::DataLayout::bfyx
        && cp.inputs[0].X().v == 1 && cp.inputs[0].Y().v > 1
        && cp.inputs[0].X().pad.Total() == 0
        && cp.outputs[0].GetLayout() == kernel_selector::Tensor::DataLayout::bfyx
        && cp.outputs[0].X().v == 1 && cp.outputs[0].Y().v > 1
        && cp.weights.X().v == 1 && cp.weights.Y().v > 1
        && !((cp.groups == cp.inputs[0].Feature().v && cp.inputs[0].Feature().v == cp.outputs[0].Feature().v && cp.split == 1)
             || (cp.split == cp.inputs[0].Feature().v && cp.groups == 1))) {  // Don't swap if it is a depthwise conv
        // Every fused-op tensor must also have a padding-free, single-element X axis.
        auto can_swap = [](const kernel_selector::Tensor::DataTensor& dt) -> bool {
            auto x_channel_idx = kernel_selector::Tensor::DataTensor::Channelndex(dt.GetLayout(),
                                                                                  kernel_selector::Tensor::DataChannelName::X);
            auto x_axis_dim = dt.GetDims()[x_channel_idx];
            return (x_axis_dim.pad.Total() == 0 && x_axis_dim.v == 1);
        };
        for (auto& desc : cp.fused_ops) {
            if (!can_swap(desc.output_tensor)) {
                return false;
            }
            for (size_t i = 0; i < desc.tensors.size(); i++) {
                if (!can_swap(desc.tensors[i])) {
                    return false;
                }
            }
        }
        return true;
    }
    return false;
};
// Swap XY axes: the 1D extent moves from Y to X on the input, output, weights and fused-op
// tensors, and the kernel/stride/padding/dilation parameters are rebuilt with x and y exchanged.
if (can_swap_xy(conv_params) && primitive->deformable_mode == false) {
    conv_params.inputs[0].SwapXY();
    conv_params.outputs[0].SwapXY();
    conv_params.weights.SwapXY();
    for (auto& desc : conv_params.fused_ops) {
        desc.output_tensor.SwapXY();
        for (size_t i = 0; i < desc.tensors.size(); i++) {
            desc.tensors[i].SwapXY();
        }
    }
    conv_params.filterSize = { ky, kx, kz };
    conv_params.padding = {pad_y, pad_x, pad_z};
    conv_params.stride = {stride_y, stride_x, stride_z};
    conv_params.dilation = {dilation_y, dilation_x, dilation_z};
}
auto format = impl_param.get_output_layout().format;
if (format == format::b_fs_zyx_fsv16 ||
format == format::bs_fs_zyx_bsv16_fsv16 ||

View File

@@ -892,6 +892,7 @@ kernel_selector::weights_tensor convert_weights_tensor(const layout& l, bool is_
const auto d = t[tensor_index];
vec[i] = static_cast<size_t>(d);
}
return kernel_selector::weights_tensor(vec, ks_type, ks_layout);
}

View File

@@ -550,6 +550,28 @@ DataTensor DataTensor::FlattenEverything() const {
return res;
}
void DataTensor::SwapXY() {
    DataLayout l = Tensor::bfyx;
    auto x = X();
    auto y = Y();

    if (GetLayout() != DataLayout::bfyx)
        throw std::runtime_error("DataTensor::SwapXY: only bfyx layout is supported.");
    if (x.pad.Total() != 0 || x.v != 1)
        throw std::runtime_error("DataTensor::SwapXY: X must be a single element without padding.");

    // Swap the X and Y axes: the old Y becomes the innermost (X) axis, the old X steps over
    // the whole padded Y extent, so the swap is metadata-only and moves no data.
    y.pitch = 1;
    x.pitch = y.v + y.pad.Total();
    std::vector<Dim> vec(ChannelsCount(l));
    vec[Channelndex(l, DataChannelName::X)] = y;
    vec[Channelndex(l, DataChannelName::Y)] = x;
    vec[Channelndex(l, DataChannelName::FEATURE)] = Feature();
    vec[Channelndex(l, DataChannelName::BATCH)] = Batch();
    *this = {vec, dtype, l};
}
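A minimal standalone sketch of the pitch bookkeeping done by DataTensor::SwapXY above (plain C++ structs with illustrative names, not the kernel_selector types): re-reading the same flat buffer with the swapped extents and pitches visits exactly the same elements, so the swap requires no copy.

#include <cassert>
#include <cstddef>
#include <vector>

struct Dim { size_t v; size_t pad_total; size_t pitch; };

// Read one (b, f) plane of a flat buffer via explicit dims/pitches, outer Y / inner X order.
static std::vector<float> read_plane(const std::vector<float>& buf, const Dim& y, const Dim& x) {
    std::vector<float> out;
    for (size_t yi = 0; yi < y.v; ++yi)
        for (size_t xi = 0; xi < x.v; ++xi)
            out.push_back(buf[yi * y.pitch + xi * x.pitch]);
    return out;
}

int main() {
    const std::vector<float> buf{10, 20, 30, 40, 50};  // bfyx plane with X = 1, Y = 5, no padding

    Dim x{1, 0, 1};                   // X is innermost before the swap
    Dim y{5, 0, x.v + x.pad_total};   // Y pitch is 1 because X is a singleton
    const auto before = read_plane(buf, y, x);

    // The swap: old Y becomes the new X (pitch 1), old X becomes the new Y and
    // steps over the whole former Y extent.
    Dim new_x = y;  new_x.pitch = 1;
    Dim new_y = x;  new_y.pitch = y.v + y.pad_total;  // == 5
    const auto after = read_plane(buf, new_y, new_x);

    assert(before == after);  // same elements in the same order: no data movement needed
    return 0;
}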
NDims WeightsTensor::GetSimpleDims(const std::vector<size_t>& d, WeightsLayout l) {
std::vector<size_t> newDims = d;
@@ -1051,5 +1073,22 @@ WeightsTensor WeightsTensor::TransformIgnorePadding(WeightsLayout l, WeightsType
return {vec, t, l};
}
void WeightsTensor::SwapXY() {
    auto x = X();

    if (x.pad.Total() != 0 || x.v != 1)
        throw std::runtime_error("WeightsTensor::SwapXY: X must be a single element without padding.");

    // Collect the current extents, exchange X and Y, and rebuild the tensor so that the
    // pitches are recomputed for the new shape.
    std::vector<size_t> vec;
    for (auto& d : dims) {
        vec.push_back(d.v);
    }
    auto x_index = Channelndex(layout, WeightsChannelName::X);
    auto y_index = Channelndex(layout, WeightsChannelName::Y);
    std::swap(vec[x_index], vec[y_index]);
    *this = {vec, dtype, layout};
}
} // namespace Tensor
} // namespace kernel_selector
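For the weights, the swap only exchanges the X and Y extents and rebuilds the descriptor. A short standalone sketch of that idea follows (plain C++, illustrative names; the dense {X, Y, IFM, OFM} dim order is assumed purely for illustration):

#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>

int main() {
    // A 1x3 kernel for a 1D convolution along Y, 16 input / 32 output channels.
    std::array<size_t, 4> dims{1, 3, 16, 32};   // {X, Y, IFM, OFM}
    std::swap(dims[0], dims[1]);                // X <-> Y, mirroring WeightsTensor::SwapXY above
    assert(dims[0] == 3 && dims[1] == 1);       // the kernel now extends along X

    // Rebuilding from the swapped extents recomputes dense pitches innermost-to-outermost,
    // which is all the (padding-free) weights need.
    std::array<size_t, 4> pitch{};
    size_t p = 1;
    for (size_t i = 0; i < dims.size(); ++i) { pitch[i] = p; p *= dims[i]; }
    assert(pitch[3] == 3 * 1 * 16);             // OFM pitch spans one full filter
    return 0;
}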

View File

@@ -616,6 +616,7 @@ struct DataTensor : public TensorBaseT<Datatype, DataLayout> {
DataTensor TransformIgnorePadding(DataLayout l) const;
DataTensor FlattenFeatureAndSpatials() const;
DataTensor FlattenEverything() const;
void SwapXY();
static inline Dim Extract(DataLayout l, DataChannelName channel, const NDims& d) {
return TensorBaseT::Extract(dataChannelArray, l, channel, d);
@@ -658,6 +659,8 @@ struct WeightsTensor : TensorBaseT<WeightsType, WeightsLayout> {
Dim OFM() const { return Extract(layout, WeightsChannelName::OFM, dims); }
Dim G() const { return Extract(layout, WeightsChannelName::G, dims); }
void SwapXY();
static inline Dim Extract(WeightsLayout l, WeightsChannelName channel, const NDims& d) {
return TensorBaseT::Extract(weightsChannelArray, l, channel, d);
}

View File

@@ -1634,6 +1634,37 @@ INSTANTIATE_TEST_SUITE_P(fusings_gpu, conv_fp32_group_conv_eltwise_sum, ::testin
conv_eltw_test_params{ CASE_GROUP_CONV_ELTW_FP32_1, 3, 2, 3 },
}));
class conv_swap_xy_with_eltwise_diff_sizes : public ConvEltwTest {};
TEST_P(conv_swap_xy_with_eltwise_diff_sizes, basic) {
    auto p = GetParam();
    create_topologies(
        input_layout("input", get_input_layout(p)),
        data("weights", get_mem(get_weights_layout(p))),
        data("bias", get_mem(get_bias_layout(p))),
        data("eltwise_data", get_mem(layout{ p.data_type, p.input_format, p.eltw_shape })),
        convolution("conv_prim", input_info("input"), { "weights" }, { "bias" }, p.groups, p.stride, p.pad, p.dilation),
        activation("activation", input_info("conv_prim"), activation_func::relu_negative_slope),
        eltwise("sum", { input_info("activation"), input_info("eltwise_data") }, eltwise_mode::sum, data_types::f16),
        reorder("reorder_bfyx", input_info("sum"), p.default_format, data_types::f16)
    );
    tolerance = default_tolerance(p.default_type);
    execute(p);
}
// in_shape; out_shape; eltw_shape; kernel; stride; pad; dilation; groups; data_type; input_format; weights_type; weights_format; default_type; default_format;
#define CASE_CONV_ELTW_FP16_SWAP_XY_1 { 1, 16, 1, 5 }, { 1, 32, 1, 7 }, { 1, 32, 1, 1 }, { 1, 1, 1, 3 }, { 1, 1 }, { 2, 0 }, { 1, 1 }, 1, data_types::f16, format::bfyx, data_types::f16, format::os_iyx_osv16, data_types::f16, format::bfyx
#define CASE_CONV_ELTW_FP16_SWAP_XY_2 { 1, 16, 1, 5 }, { 1, 32, 1, 7 }, { 1, 32, 1, 7 }, { 1, 1, 1, 3 }, { 1, 1 }, { 2, 0 }, { 1, 1 }, 1, data_types::f16, format::bfyx, data_types::f16, format::os_iyx_osv16, data_types::f16, format::bfyx
#define CASE_CONV_ELTW_FP32_SWAP_XY_1 { 3, 16, 1, 5 }, { 3, 32, 1, 7 }, { 1, 32, 1, 1 }, { 1, 1, 1, 3 }, { 1, 1 }, { 2, 0 }, { 1, 1 }, 1, data_types::f32, format::bfyx, data_types::f32, format::os_iyx_osv16, data_types::f32, format::bfyx
#define CASE_CONV_ELTW_FP32_SWAP_XY_2 { 3, 16, 1, 5 }, { 3, 32, 1, 7 }, { 3, 32, 1, 7 }, { 1, 1, 1, 3 }, { 1, 1 }, { 2, 0 }, { 1, 1 }, 1, data_types::f32, format::bfyx, data_types::f32, format::os_iyx_osv16, data_types::f32, format::bfyx
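As a sanity check on these shapes (standard convolution output-size arithmetic, nothing specific to this change), take CASE_CONV_ELTW_FP16_SWAP_XY_1 and read pad as { pad_y, pad_x } = { 2, 0 }: the 1D extent sits on Y, so out_y = (in_y + 2 * pad_y - dilation_y * (k_y - 1) - 1) / stride_y + 1 = (5 + 4 - 2 - 1) / 1 + 1 = 7, which matches the { 1, 32, 1, 7 } out_shape, while X stays 1, exactly the pattern the swap-XY path targets.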
INSTANTIATE_TEST_SUITE_P(fusings_gpu, conv_swap_xy_with_eltwise_diff_sizes, ::testing::ValuesIn(std::vector<conv_eltw_test_params>{
conv_eltw_test_params{ CASE_CONV_ELTW_FP16_SWAP_XY_1, 3, 3, 4 },
conv_eltw_test_params{ CASE_CONV_ELTW_FP16_SWAP_XY_2, 3, 3, 4 },
conv_eltw_test_params{ CASE_CONV_ELTW_FP32_SWAP_XY_1, 3, 3, 4 },
conv_eltw_test_params{ CASE_CONV_ELTW_FP32_SWAP_XY_2, 3, 3, 4 },
}));
class conv_scale_activation_eltwise_fp32_quantize_i8 : public ConvEltwTest {};
TEST_P(conv_scale_activation_eltwise_fp32_quantize_i8, basic) {
auto p = GetParam();

View File

@@ -23,6 +23,8 @@
#include <fstream>
#include <tuple>
#include "convolution_inst.h"
using namespace cldnn;
using namespace ::tests;
@@ -9319,3 +9321,86 @@ TEST(convolution_f32_gpu, convolution_gpu_bfyx_f16_depthwise_x_bloxk_size_1) {
TEST(export_import_convolution_f32_gpu, convolution_gpu_bfyx_f16_depthwise_x_bloxk_size_1) {
test_convolution_f32_gpu_convolution_gpu_bfyx_f16_depthwise_x_bloxk_size_1<FLOAT16>(true);
}
TEST(convolution_f32_fw_gpu, basic_convolution_no_bias_swap_xy) {
// Filter : 2x2x1x3
// Stride : 1x1
// Input : 1x2x1x5
// Output : 1x2x1x3
//
// Input:
// 1 1
// 2 2
// 3 3
// 4 4
// 5 5
//
// Filter:
// 1 1 1 1
// 2 2 2 2
// 1 1 1 1
//
// Output:
// 16 16
// 24 24
// 32 32
auto& engine = get_test_engine();
auto input = engine.allocate_memory({ data_types::f32, format::bfyx, { 1, 2, 1, 5 } });
auto weights = engine.allocate_memory({ data_types::f32, format::bfyx, { 2, 2, 1, 3 } });
set_values(input, { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f });
set_values(weights, { 1.0f, 2.0f, 1.0f, 1.0f, 2.0f, 1.0f, 1.0f, 2.0f, 1.0f, 1.0f, 2.0f, 1.0f});
VVVF<float> output_vec = {{{ 16.0f }, { 24.0f }, { 32.0f }}, {{ 16.0f }, { 24.0f }, { 32.0f }}};
topology topology(
input_layout("input", input->get_layout()),
data("weights", weights),
convolution("conv", input_info("input"), { "weights" }, { 1, 1 }));
network network(engine, topology);
network.set_input_data("input", input);
auto outputs = network.execute();
ASSERT_EQ(outputs.size(), size_t(1));
ASSERT_EQ(outputs.begin()->first, "conv");
const auto& const_net = network;
const auto conv_inst = const_net.get_primitive("conv");
ASSERT_TRUE(conv_inst != nullptr);
auto output_memory = outputs.at("conv").get_memory();
auto output_layout = output_memory->get_layout();
cldnn::mem_lock<float> output_ptr(output_memory, get_test_stream());
int y_size = output_layout.spatial(1);
int x_size = output_layout.spatial(0);
int f_size = output_layout.feature();
int b_size = output_layout.batch();
ASSERT_EQ(output_layout.format, format::bfyx);
ASSERT_EQ(y_size, 3);
ASSERT_EQ(x_size, 1);
ASSERT_EQ(f_size, 2);
ASSERT_EQ(b_size, 1);
for (int f = 0; f < f_size; ++f) {
    for (int y = 0; y < y_size; ++y) {
        for (int x = 0; x < x_size; ++x) {
            // linear bfyx offset (b == 1)
            ASSERT_EQ(output_vec[f][y][x], output_ptr[(f * y_size + y) * x_size + x]);
        }
    }
}
auto inst = network.get_primitive("conv");
const auto& node = inst->get_node();
auto selected_impl = node.type()->choose_impl(node);
bool found_define = false;
for (auto& s : selected_impl->get_kernels_source()) {
    if (s != nullptr && !s->get_str().empty()
        && s->get_str().find("#define INPUT0_SIZE_X 5") != std::string::npos)
        found_define = true;
}
EXPECT_TRUE(found_define);
}