diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/fully_connected/fully_connected_kernel_bfyx_ref.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/fully_connected/fully_connected_kernel_bfyx_ref.cpp
index b134d0ad960..f112d03dee3 100644
--- a/src/plugins/intel_gpu/src/kernel_selector/kernels/fully_connected/fully_connected_kernel_bfyx_ref.cpp
+++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/fully_connected/fully_connected_kernel_bfyx_ref.cpp
@@ -103,21 +103,6 @@ bool FullyConnected_bfyx_Ref::Validate(const Params& params, const optional_para
     // int8 validation
     const auto& fc_params = static_cast<const fully_connected_params&>(params);
-    auto input_type = fc_params.inputs[0].GetDType();
-    auto output_type = fc_params.outputs[0].GetDType();
-    auto filter_type = fc_params.weights.GetDType();
-
-    // int8/uint8 inputs (quantization case)
-    // require some additional checks.
-    if ((input_type != Datatype::UINT8 && input_type != Datatype::INT8) &&
-        (output_type != Datatype::UINT8 && output_type != Datatype::INT8))
-        return true;
-
-    bool is_quantization = (input_type == Datatype::INT8 || input_type == Datatype::UINT8) &&
-                           (filter_type == WeightsType::INT8 || filter_type == WeightsType::UINT8);
-
-    if (!is_quantization)
-        return false;
 
     // We don't support 4d output
     if (fc_params.outputs[0].GetLayout() == DataLayout::bfyx &&
         fc_params.outputs[0].X().v > 1)
diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/fully_connected_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/fully_connected_gpu_test.cpp
index a172ed7a0e6..9d0ffa548be 100644
--- a/src/plugins/intel_gpu/tests/unit/test_cases/fully_connected_gpu_test.cpp
+++ b/src/plugins/intel_gpu/tests/unit/test_cases/fully_connected_gpu_test.cpp
@@ -2408,3 +2408,300 @@ TEST(fully_connected_gpu, has_cached_weights_reorder) {
     ASSERT_EQ(-2.25f, output_ptr[2]);
     ASSERT_EQ(3.0f, output_ptr[3]);
 }
+
+template <typename T, typename InputT, typename WeightsT, typename BiasT>
+VVF<T> fully_connected_types_reference(VVVVF<InputT> &input, VVVVF<WeightsT> &weights, VF<BiasT> &bias, const quantization_t& quantization, bool relu = false, T slope = 0.0f) {
+    size_t input_f = input[0].size();
+    size_t input_y = input[0][0].size();
+    size_t input_x = input[0][0][0].size();
+    size_t output_b = input.size();     // input is assumed to be bfyx
+    size_t output_f = weights.size();   // weights are assumed to be bfyx
+    VVF<T> output(output_b, VF<T>(output_f));
+    float res;
+    for (size_t b = 0; b < output_b; ++b) {
+        for (size_t n = 0; n < output_f; ++n) {
+            res = bias[n];
+            for (size_t f = 0; f < input_f; ++f) {
+                for (size_t y = 0; y < input_y; ++y) {
+                    for (size_t x = 0; x < input_x; ++x) {
+                        res += (float)input[b][f][y][x] * (float)weights[n][f][y][x];
+                    }
+                }
+            }
+            if (relu && res < (float)0)
+                res *= (float)slope;
+            if (res > quantization.output_high)
+                output[b][n] = quantization.output_high;
+            else {
+                if (res < quantization.output_low)
+                    output[b][n] = quantization.output_low;
+                else
+                    output[b][n] = (T)res;
+            }
+        }
+    }
+    return output;
+}
+
+using fully_connected_types_test_params = std::tuple<
+    size_t,        // batch_num
+    size_t,        // input_f
+    size_t,        // input_x
+    size_t,        // input_y
+    size_t,        // output_f
+    format::type   // format
+>;
+
+template <typename InputT, typename WeightsT, typename BiasT, typename OutputT>
+class fully_connected_types_test : public ::testing::Test {
+private:
+    size_t batch_num() { return _input.size(); }
+    size_t input_f() { return _input[0].size(); }
+    size_t input_y() { return _input[0][0].size(); }
+    size_t input_x() { return _input[0][0][0].size(); }
+    size_t output_f() { return _weights.size(); }
+
+    data_types input_data_type() {
+        return type_to_data_type<InputT>::value;
+    }
+
+    data_types weights_data_type() {
+        return type_to_data_type<WeightsT>::value;
+    }
+
+    bool has_bias() { return _bias.size() > 0; }
+
+public:
+    static std::string PrintToStringParamName(testing::TestParamInfo<fully_connected_types_test_params> param_info) {
+        // construct a readable name
+        return std::to_string(param_info.index) +
+               "_in_" + std::to_string(testing::get<0>(param_info.param)) +
+               "x" + std::to_string(testing::get<1>(param_info.param)) +
+               "x" + std::to_string(testing::get<2>(param_info.param)) +
+               "x" + std::to_string(testing::get<3>(param_info.param)) +
+               "_of_" + std::to_string(testing::get<4>(param_info.param)) +
+               "_" + fmt_to_str(testing::get<5>(param_info.param));
+    }
+
+    void set_input(VVVVF<InputT> _data) {
+        _input = std::move(_data);
+    }
+
+    void set_weights(VVVVF<WeightsT> _data) {
+        _weights = std::move(_data);
+    }
+
+    void set_bias(VF<BiasT> _data) {
+        _bias = std::move(_data);
+    }
+
+    void set_input_format(format::type fmt) {
+        _fmt = fmt;
+    }
+
+    void run_test(VVF<OutputT> expected) {
+        auto& engine = get_test_engine();
+
+        auto input_size = tensor(TensorValue(batch_num()), TensorValue(input_f()), TensorValue(input_x()), TensorValue(input_y()));
+        auto weights_size = tensor(TensorValue(output_f()), TensorValue(input_f()), TensorValue(input_x()), TensorValue(input_y()));
+
+        auto input_prim = engine.allocate_memory({ input_data_type(), _fmt, input_size });
+        auto weights_prim = engine.allocate_memory({ weights_data_type(), format::bfyx, weights_size });
+
+        VF<InputT> input_flattened(input_prim->get_layout().get_linear_size());
+        for (size_t bi = 0; bi < batch_num(); ++bi)
+            for (size_t fi = 0; fi < input_f(); ++fi)
+                for (size_t yi = 0; yi < input_y(); ++yi)
+                    for (size_t xi = 0; xi < input_x(); ++xi) {
+                        auto idx = tensor((int32_t)bi, (int32_t)fi, (int32_t)xi, (int32_t)yi);
+                        auto offset = input_size.get_linear_offset(idx, _fmt);
+                        input_flattened[offset] = _input[bi][fi][yi][xi];
+                    }
+
+        set_values(input_prim, input_flattened);
+        set_values(weights_prim, flatten_4d(format::bfyx, _weights));
+
+        auto bias_prim = engine.allocate_memory({ weights_data_type(), format::bfyx, tensor(feature(output_f())) });
+        set_values(bias_prim, _bias);
+
+        topology topo;
+        topo.add(data("weights", weights_prim));
+        topo.add(data("bias", bias_prim));
+
+        topo.add(input_layout("input", input_prim->get_layout()));
+
+        auto input_sizes = input_size.sizes();
+        auto last_dim = std::find_if(input_sizes.rbegin(), input_sizes.rend(),
+                                     [](tensor::value_type x) { return x != 1l; });
+        size_t input_rank = std::distance(input_sizes.begin(), last_dim.base());
+        auto fc_prim = fully_connected("output", input_info("input"), "weights", "bias", cldnn::padding(), input_rank);
+        fc_prim.output_data_types = {type_to_data_type<OutputT>::value};
+        topo.add(fc_prim);
+
+        ExecutionConfig config;
+        config.set_property(ov::intel_gpu::optimize_data(true));
+
+        network net(engine, topo, config);
+        net.set_input_data("input", input_prim);
+
+        auto output = net.execute();
+        auto out_mem = output.at("output").get_memory();
+        cldnn::mem_lock<OutputT> out_ptr(out_mem, get_test_stream());
+
+        for (size_t bi = 0; bi < batch_num(); ++bi) {
+            for (size_t fi = 0; fi < output_f(); ++fi) {
+                ASSERT_NEAR(out_ptr[bi * output_f() + fi], expected[bi][fi], 1) << "at b = " << bi << ", fi = " << fi << ", output_f() = " << output_f();
+            }
+        }
+    }
+
+private:
+    VVVVF<InputT> _input;
+    VVVVF<WeightsT> _weights;
+    VF<BiasT> _bias;
+    format::type _fmt;
+};
+
+template <typename InputT, typename WeightsT, typename BiasT, typename OutputT>
+class fc_random_types_test
+    : public fully_connected_types_test<InputT, WeightsT, BiasT, OutputT>
+    , public ::testing::WithParamInterface<fully_connected_types_test_params> {
+public:
+    void run_random_test() {
+        size_t b, in_f, in_x, in_y, out_f;
+        format::type in_fmt;
+
+        std::tie(b, in_f, in_x, in_y, out_f, in_fmt) = GetParam();
+
+        quantization_t quant_data;
+        quant_data.output_low  = std::numeric_limits<OutputT>::min();
+        quant_data.output_high = std::numeric_limits<OutputT>::max();
+
+        VVVVF<InputT> input_data = generate_random_4d<InputT>(b, in_f, in_y, in_x, 0, 127);
+        VVVVF<WeightsT> weights_data = generate_random_4d<WeightsT>(out_f, in_f, in_y, in_x, quant_data.output_low, quant_data.output_high);
+        VF<BiasT> bias_data = generate_random_1d<BiasT>(out_f, quant_data.output_low, quant_data.output_high);
+
+        this->set_input(input_data);
+        this->set_weights(weights_data);
+        this->set_bias(bias_data);
+        this->set_input_format(in_fmt);
+
+        //this->run_test(ref_fully_connected(input_data, weights_data, bias_data, quant_data));
+        this->run_test(fully_connected_types_reference<OutputT>(input_data, weights_data, bias_data, quant_data));
+    }
+};
+
+using fully_connected_types_i8_i8_test = fc_random_types_test<int8_t, int8_t, int8_t, int8_t>;
+using fully_connected_types_i8_u8_test = fc_random_types_test<int8_t, int8_t, int8_t, uint8_t>;
+using fully_connected_types_i8_f32_test = fc_random_types_test<int8_t, int8_t, int8_t, float>;
+
+using fully_connected_types_u8_i8_test = fc_random_types_test<uint8_t, int8_t, int8_t, int8_t>;
+using fully_connected_types_u8_u8_test = fc_random_types_test<uint8_t, int8_t, int8_t, uint8_t>;
+using fully_connected_types_u8_f32_test = fc_random_types_test<uint8_t, int8_t, int8_t, float>;
+
+TEST_P(fully_connected_types_i8_i8_test, random) {
+    run_random_test();
+}
+
+TEST_P(fully_connected_types_i8_u8_test, random) {
+    run_random_test();
+}
+
+TEST_P(fully_connected_types_i8_f32_test, random) {
+    run_random_test();
+}
+
+TEST_P(fully_connected_types_u8_i8_test, random) {
+    run_random_test();
+}
+
+TEST_P(fully_connected_types_u8_u8_test, random) {
+    run_random_test();
+}
+
+TEST_P(fully_connected_types_u8_f32_test, random) {
+    run_random_test();
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    basic,
+    fully_connected_types_i8_i8_test,
+    testing::Combine(
+        testing::Values(1, 2),
+        testing::Values(3, 64),
+        testing::Values(1),
+        testing::Values(1),
+        testing::Values(3, 32),
+        testing::Values(format::bfyx, format::b_fs_yx_fsv4, format::b_fs_yx_fsv16, format::b_fs_yx_fsv32)
+    ),
+    fully_connected_types_i8_i8_test::PrintToStringParamName
+);
+
+INSTANTIATE_TEST_SUITE_P(
+    basic,
+    fully_connected_types_i8_u8_test,
+    testing::Combine(
+        testing::Values(1, 2),
+        testing::Values(3, 64),
+        testing::Values(1),
+        testing::Values(1),
+        testing::Values(3, 32),
+        testing::Values(format::bfyx, format::b_fs_yx_fsv4, format::b_fs_yx_fsv16, format::b_fs_yx_fsv32)
+    ),
+    fully_connected_types_i8_u8_test::PrintToStringParamName
+);
+
+INSTANTIATE_TEST_SUITE_P(
+    basic,
+    fully_connected_types_i8_f32_test,
+    testing::Combine(
+        testing::Values(1, 2),
+        testing::Values(3, 64),
+        testing::Values(1),
+        testing::Values(1),
+        testing::Values(3, 32),
+        testing::Values(format::bfyx, format::b_fs_yx_fsv4, format::b_fs_yx_fsv16, format::b_fs_yx_fsv32)
+    ),
+    fully_connected_types_i8_f32_test::PrintToStringParamName
+);
+
+INSTANTIATE_TEST_SUITE_P(
+    basic,
+    fully_connected_types_u8_i8_test,
+    testing::Combine(
+        testing::Values(1, 2),
+        testing::Values(3, 64),
+        testing::Values(1),
+        testing::Values(1),
+        testing::Values(3, 32),
+        testing::Values(format::bfyx, format::b_fs_yx_fsv4, format::b_fs_yx_fsv16, format::b_fs_yx_fsv32)
+    ),
+    fully_connected_types_u8_i8_test::PrintToStringParamName
+);
+
+INSTANTIATE_TEST_SUITE_P(
+    basic,
+    fully_connected_types_u8_u8_test,
+    testing::Combine(
+        testing::Values(1, 2),
+        testing::Values(3, 64),
+        testing::Values(1),
+        testing::Values(1),
+        testing::Values(3, 32),
+        testing::Values(format::bfyx, format::b_fs_yx_fsv4, format::b_fs_yx_fsv16, format::b_fs_yx_fsv32)
+    ),
+    fully_connected_types_u8_u8_test::PrintToStringParamName
+);
+
+INSTANTIATE_TEST_SUITE_P(
+    basic,
+    fully_connected_types_u8_f32_test,
+    testing::Combine(
+        testing::Values(1, 2),
+        testing::Values(3, 64),
+        testing::Values(1),
+        testing::Values(1),
+        testing::Values(3, 32),
+        testing::Values(format::bfyx, format::b_fs_yx_fsv4, format::b_fs_yx_fsv16, format::b_fs_yx_fsv32)
+    ),
+    fully_connected_types_u8_f32_test::PrintToStringParamName
+);
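
Note (not part of the patch): a minimal standalone sketch of the saturation behaviour that fully_connected_types_reference exercises above. The reference accumulates in float and clamps the result to [output_low, output_high] before casting to the output type; run_random_test derives those bounds from std::numeric_limits. The helper name clamp_to_output and the sample values below are illustrative only; the sketch uses lowest() rather than the tests' min(), which is identical for the integral output types used here.

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <limits>

    // Clamp a float accumulator to the representable range of OutputT before
    // casting, mirroring the output_low/output_high clamp in the reference.
    template <typename OutputT>
    OutputT clamp_to_output(float acc) {
        const float lo = static_cast<float>(std::numeric_limits<OutputT>::lowest());
        const float hi = static_cast<float>(std::numeric_limits<OutputT>::max());
        return static_cast<OutputT>(std::min(std::max(acc, lo), hi));
    }

    int main() {
        // int8 outputs saturate instead of wrapping around:
        std::cout << static_cast<int>(clamp_to_output<int8_t>(200.0f)) << "\n";   // 127
        std::cout << static_cast<int>(clamp_to_output<int8_t>(-300.0f)) << "\n";  // -128
        // uint8 outputs clamp negative sums to zero:
        std::cout << static_cast<int>(clamp_to_output<uint8_t>(-5.0f)) << "\n";   // 0
        return 0;
    }

This is also why the tests use ASSERT_NEAR with a tolerance of 1: the GPU kernel and the float reference may round a saturating accumulation slightly differently at the clamp boundaries.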