[GPU] additional checks fixed for fully_connected (#18068)

Author: Andrei Gorbachev, 2023-06-15 06:11:38 +01:00 (committed by GitHub)
parent 3e63ab0dc3
commit 52834659c4
2 changed files with 297 additions and 15 deletions

@@ -103,21 +103,6 @@ bool FullyConnected_bfyx_Ref::Validate(const Params& params, const optional_params&
    // int8 validation
    const auto& fc_params = static_cast<const fully_connected_params&>(params);
    auto input_type = fc_params.inputs[0].GetDType();
    auto output_type = fc_params.outputs[0].GetDType();
    auto filter_type = fc_params.weights.GetDType();
    // int8/uint8 inputs (quantization case) require some additional checks.
    if ((input_type != Datatype::UINT8 && input_type != Datatype::INT8) &&
        (output_type != Datatype::UINT8 && output_type != Datatype::INT8))
        return true;
    bool is_quantization = (input_type == Datatype::INT8 || input_type == Datatype::UINT8) &&
                           (filter_type == WeightsType::INT8 || filter_type == WeightsType::UINT8);
    if (!is_quantization)
        return false;
    // We don't support 4d output
    if (fc_params.outputs[0].GetLayout() == DataLayout::bfyx && fc_params.outputs[0].X().v > 1)
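
Read in isolation, the removed gate amounts to the predicate sketched below. This is a self-contained illustration, not the real kernel_selector code: the enums are re-declared with only the members the check reads, and the trailing 4d-output check is left out of the sketch.

    #include <cstdio>

    // Hypothetical stand-ins for the kernel_selector enums; names mirror the
    // diff above but these are not the real OpenVINO headers.
    enum class Datatype { F32, INT8, UINT8 };
    enum class WeightsType { F32, INT8, UINT8 };

    // Mirrors the removed gate: graphs with no int8/uint8 activations or
    // outputs pass unconditionally; quantized activations additionally
    // require int8/uint8 weights.
    bool passes_int8_checks(Datatype input, Datatype output, WeightsType filter) {
        bool in_q  = input  == Datatype::INT8 || input  == Datatype::UINT8;
        bool out_q = output == Datatype::INT8 || output == Datatype::UINT8;
        if (!in_q && !out_q)
            return true;
        bool w_q = filter == WeightsType::INT8 || filter == WeightsType::UINT8;
        return in_q && w_q;
    }

    int main() {
        // u8 activations with f32 weights fail the quantization requirement.
        std::printf("%d\n", passes_int8_checks(Datatype::UINT8, Datatype::INT8, WeightsType::F32));  // 0
    }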

@@ -2408,3 +2408,300 @@ TEST(fully_connected_gpu, has_cached_weights_reorder) {
    ASSERT_EQ(-2.25f, output_ptr[2]);
    ASSERT_EQ(3.0f, output_ptr[3]);
}

template <typename InputT, typename T>
VVF<T> fully_connected_types_reference(VVVVF<InputT> &input, VVVVF<T> &weights, VF<T> &bias, const quantization_t& quantization, bool relu = false, T slope = 0.0f) {
    size_t input_f = input[0].size();
    size_t input_y = input[0][0].size();
    size_t input_x = input[0][0][0].size();
    size_t output_b = input.size();    // input is assumed to be bfyx
    size_t output_f = weights.size();  // weights are assumed to be bfyx
    VVF<T> output(output_b, VF<T>(output_f));
    float res;
    for (size_t b = 0; b < output_b; ++b) {
        for (size_t n = 0; n < output_f; ++n) {
            res = bias[n];
            for (size_t f = 0; f < input_f; ++f) {
                for (size_t y = 0; y < input_y; ++y) {
                    for (size_t x = 0; x < input_x; ++x) {
                        res += (float)input[b][f][y][x] * (float)weights[n][f][y][x];
                    }
                }
            }
            if (relu && res < (float)0)
                res *= (float)slope;
            if (res > quantization.output_high)
                output[b][n] = quantization.output_high;
            else {
                if (res < quantization.output_low)
                    output[b][n] = quantization.output_low;
                else
                    output[b][n] = (T)res;
            }
        }
    }
    return output;
}
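
As a quick sanity check on the saturation the reference models: the accumulator runs in float and is clamped to the quantization bounds only at the end. A minimal sketch of that behavior, independent of the test utilities:

    #include <algorithm>
    #include <cstdio>

    int main() {
        // b=1, f=2, y=x=1: two uint8 activations of 100, int8 weights of 2, bias 0.
        float res = 0.0f + 100.0f * 2.0f + 100.0f * 2.0f;        // accumulate in float, as the reference does
        const float output_low = -128.0f, output_high = 127.0f;  // int8 bounds used by run_random_test
        float clamped = std::min(std::max(res, output_low), output_high);
        std::printf("%g\n", clamped);  // prints 127: 400 saturates to the int8 maximum
    }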

using fully_connected_types_test_params = std::tuple<
    size_t,       // batch_num
    size_t,       // input_f
    size_t,       // input_x
    size_t,       // input_y
    size_t,       // output_f
    format::type  // format
>;

template <typename InputT, typename WeightsT>
class fully_connected_types_test : public ::testing::Test {
private:
    size_t batch_num() { return _input.size(); }
    size_t input_f() { return _input[0].size(); }
    size_t input_y() { return _input[0][0].size(); }
    size_t input_x() { return _input[0][0][0].size(); }
    size_t output_f() { return _weights.size(); }

    data_types input_data_type() {
        return type_to_data_type<InputT>::value;
    }

    data_types weights_data_type() {
        return type_to_data_type<WeightsT>::value;
    }

    bool has_bias() { return _bias.size() > 0; }

public:
    static std::string PrintToStringParamName(testing::TestParamInfo<fully_connected_types_test_params> param_info) {
        // construct a readable name
        return std::to_string(param_info.index) + "_in_" + std::to_string(testing::get<0>(param_info.param))
               + "x" + std::to_string(testing::get<1>(param_info.param))
               + "x" + std::to_string(testing::get<2>(param_info.param))
               + "x" + std::to_string(testing::get<3>(param_info.param))
               + "_of_" + std::to_string(testing::get<4>(param_info.param))
               + "_" + fmt_to_str(testing::get<5>(param_info.param));
    }

    void set_input(VVVVF<InputT> _data) {
        _input = std::move(_data);
    }

    void set_weights(VVVVF<WeightsT> _data) {
        _weights = std::move(_data);
    }

    void set_bias(VF<WeightsT> _data) {
        _bias = std::move(_data);
    }

    void set_input_format(format::type fmt) {
        _fmt = fmt;
    }

    void run_test(VVF<WeightsT> expected) {
        auto& engine = get_test_engine();
        auto input_size = tensor(TensorValue(batch_num()), TensorValue(input_f()), TensorValue(input_x()), TensorValue(input_y()));
        auto weights_size = tensor(TensorValue(output_f()), TensorValue(input_f()), TensorValue(input_x()), TensorValue(input_y()));
        auto input_prim = engine.allocate_memory({ input_data_type(), _fmt, input_size });
        auto weights_prim = engine.allocate_memory({ weights_data_type(), format::bfyx, weights_size });

        VF<InputT> input_flattened(input_prim->get_layout().get_linear_size());
        for (size_t bi = 0; bi < batch_num(); ++bi)
            for (size_t fi = 0; fi < input_f(); ++fi)
                for (size_t yi = 0; yi < input_y(); ++yi)
                    for (size_t xi = 0; xi < input_x(); ++xi) {
                        auto idx = tensor((int32_t)bi, (int32_t)fi, (int32_t)xi, (int32_t)yi);
                        auto offset = input_size.get_linear_offset(idx, _fmt);
                        input_flattened[offset] = _input[bi][fi][yi][xi];
                    }

        set_values(input_prim, input_flattened);
        set_values(weights_prim, flatten_4d(format::bfyx, _weights));

        auto bias_prim = engine.allocate_memory({ weights_data_type(), format::bfyx, tensor(feature(output_f())) });
        set_values(bias_prim, _bias);

        topology topo;
        topo.add(data("weights", weights_prim));
        topo.add(data("bias", bias_prim));
        topo.add(input_layout("input", input_prim->get_layout()));

        auto input_sizes = input_size.sizes();
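        // Deduce the logical input rank by trimming trailing dimensions equal
        // to 1, scanning from the innermost axis outwards.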
        auto last_dim = std::find_if(input_sizes.rbegin(), input_sizes.rend(),
                                     [](tensor::value_type x) { return x != 1l; });
        size_t input_rank = std::distance(input_sizes.begin(), last_dim.base());
        auto fc_prim = fully_connected("output", input_info("input"), "weights", "bias", cldnn::padding(), input_rank);
        fc_prim.output_data_types = {type_to_data_type<WeightsT>::value};
        topo.add(fc_prim);

        ExecutionConfig config;
        config.set_property(ov::intel_gpu::optimize_data(true));
        network net(engine, topo, config);
        net.set_input_data("input", input_prim);

        auto output = net.execute();
        auto out_mem = output.at("output").get_memory();
        cldnn::mem_lock<WeightsT> out_ptr(out_mem, get_test_stream());
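        // Allow an absolute difference of 1 against the reference; rounding of
        // the float accumulator can differ from the device's quantized math.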
        for (size_t bi = 0; bi < batch_num(); ++bi) {
            for (size_t fi = 0; fi < output_f(); ++fi) {
                ASSERT_NEAR(out_ptr[bi * output_f() + fi], expected[bi][fi], 1) << "at b = " << bi << ", fi = " << fi << ", output_f() = " << output_f();
            }
        }
    }

private:
    VVVVF<InputT> _input;
    VVVVF<WeightsT> _weights;
    VF<WeightsT> _bias;
    format::type _fmt;
};

template <typename InputT, typename WeightsT>
class fc_random_types_test
    : public fully_connected_types_test<InputT, WeightsT>
    , public ::testing::WithParamInterface<fully_connected_types_test_params> {
public:
    void run_random_test() {
        size_t b, in_f, in_x, in_y, out_f;
        format::type in_fmt;
        std::tie(b, in_f, in_x, in_y, out_f, in_fmt) = GetParam();

        quantization_t quant_data;
        quant_data.output_low  = std::numeric_limits<WeightsT>::min();
        quant_data.output_high = std::numeric_limits<WeightsT>::max();
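        // For the integral WeightsT used here, min()/max() span the full
        // representable range (e.g. [-128, 127] for int8_t); note that for a
        // floating-point WeightsT, min() is the smallest positive value, not
        // the lowest.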
        VVVVF<InputT> input_data = generate_random_4d<InputT>(b, in_f, in_y, in_x, 0, 127);
        VVVVF<WeightsT> weights_data = generate_random_4d<WeightsT>(out_f, in_f, in_y, in_x, quant_data.output_low, quant_data.output_high);
        VF<WeightsT> bias_data = generate_random_1d<WeightsT>(out_f, quant_data.output_low, quant_data.output_high);

        this->set_input(input_data);
        this->set_weights(weights_data);
        this->set_bias(bias_data);
        this->set_input_format(in_fmt);

        // this->run_test(ref_fully_connected<WeightsT, float, InputT, WeightsT>(input_data, weights_data, bias_data, quant_data));
        this->run_test(fully_connected_types_reference<InputT, WeightsT>(input_data, weights_data, bias_data, quant_data));
    }
};

using fully_connected_types_i8_i8_test  = fc_random_types_test<int8_t, int8_t>;
using fully_connected_types_i8_u8_test  = fc_random_types_test<int8_t, uint8_t>;
using fully_connected_types_i8_f32_test = fc_random_types_test<int8_t, float>;
using fully_connected_types_u8_i8_test  = fc_random_types_test<uint8_t, int8_t>;
using fully_connected_types_u8_u8_test  = fc_random_types_test<uint8_t, uint8_t>;
using fully_connected_types_u8_f32_test = fc_random_types_test<uint8_t, float>;

TEST_P(fully_connected_types_i8_i8_test, random) {
    run_random_test();
}

TEST_P(fully_connected_types_i8_u8_test, random) {
    run_random_test();
}

TEST_P(fully_connected_types_i8_f32_test, random) {
    run_random_test();
}

TEST_P(fully_connected_types_u8_i8_test, random) {
    run_random_test();
}

TEST_P(fully_connected_types_u8_u8_test, random) {
    run_random_test();
}

TEST_P(fully_connected_types_u8_f32_test, random) {
    run_random_test();
}

INSTANTIATE_TEST_SUITE_P(
    basic,
    fully_connected_types_i8_i8_test,
    testing::Combine(
        testing::Values(1, 2),
        testing::Values(3, 64),
        testing::Values(1),
        testing::Values(1),
        testing::Values(3, 32),
        testing::Values(format::bfyx, format::b_fs_yx_fsv4, format::b_fs_yx_fsv16, format::b_fs_yx_fsv32)
    ),
    fully_connected_types_i8_i8_test::PrintToStringParamName
);

INSTANTIATE_TEST_SUITE_P(
    basic,
    fully_connected_types_i8_u8_test,
    testing::Combine(
        testing::Values(1, 2),
        testing::Values(3, 64),
        testing::Values(1),
        testing::Values(1),
        testing::Values(3, 32),
        testing::Values(format::bfyx, format::b_fs_yx_fsv4, format::b_fs_yx_fsv16, format::b_fs_yx_fsv32)
    ),
    fully_connected_types_i8_u8_test::PrintToStringParamName
);

INSTANTIATE_TEST_SUITE_P(
    basic,
    fully_connected_types_i8_f32_test,
    testing::Combine(
        testing::Values(1, 2),
        testing::Values(3, 64),
        testing::Values(1),
        testing::Values(1),
        testing::Values(3, 32),
        testing::Values(format::bfyx, format::b_fs_yx_fsv4, format::b_fs_yx_fsv16, format::b_fs_yx_fsv32)
    ),
    fully_connected_types_i8_f32_test::PrintToStringParamName
);

INSTANTIATE_TEST_SUITE_P(
    basic,
    fully_connected_types_u8_i8_test,
    testing::Combine(
        testing::Values(1, 2),
        testing::Values(3, 64),
        testing::Values(1),
        testing::Values(1),
        testing::Values(3, 32),
        testing::Values(format::bfyx, format::b_fs_yx_fsv4, format::b_fs_yx_fsv16, format::b_fs_yx_fsv32)
    ),
    fully_connected_types_u8_i8_test::PrintToStringParamName
);

INSTANTIATE_TEST_SUITE_P(
    basic,
    fully_connected_types_u8_u8_test,
    testing::Combine(
        testing::Values(1, 2),
        testing::Values(3, 64),
        testing::Values(1),
        testing::Values(1),
        testing::Values(3, 32),
        testing::Values(format::bfyx, format::b_fs_yx_fsv4, format::b_fs_yx_fsv16, format::b_fs_yx_fsv32)
    ),
    fully_connected_types_u8_u8_test::PrintToStringParamName
);

INSTANTIATE_TEST_SUITE_P(
    basic,
    fully_connected_types_u8_f32_test,
    testing::Combine(
        testing::Values(1, 2),
        testing::Values(3, 64),
        testing::Values(1),
        testing::Values(1),
        testing::Values(3, 32),
        testing::Values(format::bfyx, format::b_fs_yx_fsv4, format::b_fs_yx_fsv16, format::b_fs_yx_fsv32)
    ),
    fully_connected_types_u8_f32_test::PrintToStringParamName
);
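
To run a single instantiation locally, standard gtest filtering applies; the binary name below is illustrative, not taken from this commit (use whatever target the GPU plugin's unit tests build to):

    ./ov_gpu_unit_tests --gtest_filter='basic/fully_connected_types_u8_i8_test.random/*'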