[IE CLDNN] Always use FP32 as intermediate type for fused quantize (#877)

This commit is contained in:
Vladimir Paramuzov
2020-06-11 12:27:11 +03:00
committed by GitHub
parent c846c049e2
commit a3fce2d763
4 changed files with 26 additions and 70 deletions

View File

@@ -433,8 +433,13 @@ Datatype ConvolutionKernelBase::GetActivationType(const convolution_params& para
return Datatype::F32;
if (params.output.GetDType() == Datatype::UINT8 ||
params.output.GetDType() == Datatype::INT8)
return Datatype::F32;
params.output.GetDType() == Datatype::INT8) {
if (params.inputs[0].GetDType() == Datatype::F32) {
return Datatype::F32;
} else if (params.inputs[0].GetDType() == Datatype::F16) {
return Datatype::F16;
}
}
return GetUnitType(params);
}

View File

@@ -78,7 +78,7 @@ JitConstants PermuteKernelRef::GetJitConstants(const permute_params& params) con
std::swap(out_idx[3], out_idx[4]);
}
FusedOpsConfiguration conf = {"", out_idx, "input_var", GetUnitType(params), 1};
FusedOpsConfiguration conf = {"", out_idx, "input_var", params.inputs[0].GetDType(), 1};
jit.Merge(MakeFusedOpsJitConstants(params, {conf}));
}

View File

@@ -1153,29 +1153,27 @@ JitConstants FusedOpsCodeGenerator::MakeOpJitConstants(const FusedOpsConfigurati
if (!p)
throw std::runtime_error("[clDNN] Quantize fuse params can't be nullptr");
// We can't convert inputs to output data type, because it might be equal to UINT8 or INT8, so we convert the data
// to the zero tensor's (input_lo) type
std::string in_converted = in_var;
Datatype tmp_type = desc.tensors.empty() ? in_type : desc.tensors[0].GetDType();
Datatype tmp_type = Datatype::F32;
std::string tmp_type_str = GetType(tmp_type, vec_size);
std::string tmp_var = out_var + "_tmp";
if (in_type != tmp_type) {
in_converted = ConvertToType(in_var, desc.tensors[0].GetDType(), vec_size);
in_converted = ConvertToType(in_var, tmp_type, vec_size);
}
auto post_scale = p->per_tensor_output_scale ? Broadcast(std::to_string(p->out_scale), tmp_type, vec_size)
: GetInputVarName(p->out_scale_idx);
: ConvertToType(GetInputVarName(p->out_scale_idx), tmp_type, vec_size);
auto post_shift = p->per_tensor_output_shift ? Broadcast(std::to_string(p->out_shift), tmp_type, vec_size)
: GetInputVarName(p->out_shift_idx);
: ConvertToType(GetInputVarName(p->out_shift_idx), tmp_type, vec_size);
auto pre_scale = p->per_tensor_input_scale ? Broadcast(std::to_string(p->in_scale), tmp_type, vec_size)
: GetInputVarName(p->in_scale_idx);
: ConvertToType(GetInputVarName(p->in_scale_idx), tmp_type, vec_size);
auto pre_shift = p->per_tensor_input_shift ? Broadcast(std::to_string(p->in_shift), tmp_type, vec_size)
: GetInputVarName(p->in_shift_idx);
: ConvertToType(GetInputVarName(p->in_shift_idx), tmp_type, vec_size);
auto in_lo = p->per_tensor_input_range ? Broadcast(std::to_string(p->in_lo), tmp_type, vec_size)
: GetInputVarName(p->in_range_lo_idx);
: ConvertToType(GetInputVarName(p->in_range_lo_idx), tmp_type, vec_size);
auto in_hi = p->per_tensor_input_range ? Broadcast(std::to_string(p->in_hi), tmp_type, vec_size)
: GetInputVarName(p->in_range_hi_idx);
: ConvertToType(GetInputVarName(p->in_range_hi_idx), tmp_type, vec_size);
if (p->has_clamp) {
op_decls += "\\\n\t" + tmp_type_str + " " + tmp_var + " = min(max(" + in_lo + ", " + in_converted + "), " + in_hi + ");";

View File

@@ -4526,21 +4526,17 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, permute_activation_scale_eltwise,
permute_params{CASE_PERMUTE_U8_3D_3, 2, 5},
}), );
class permute_scale_eltwise_quant_u8: public PermuteFusingTest {};
TEST_P(permute_scale_eltwise_quant_u8, vector_ops) {
class permute_quant_u8: public PermuteFusingTest {};
TEST_P(permute_quant_u8, basic) {
auto p = GetParam();
create_topologies(
input_layout("input", get_input_layout(p)),
data("scale_data", get_mem(get_per_channel_layout(p))),
data("eltwise_data", get_mem(layout{ p.data_type, p.input_format, p.out_shape})),
data("in_lo", get_mem(get_per_channel_layout(p), min_random, 0)),
data("in_hi", get_mem(get_per_channel_layout(p), 1, max_random)),
data("in_lo", get_mem(get_single_element_layout(p), min_random, 0)),
data("in_hi", get_mem(get_single_element_layout(p), 1, max_random)),
data("out_lo", get_mem(get_single_element_layout(p), 0)),
data("out_hi", get_mem(get_single_element_layout(p), 255)),
permute("permute", "input", p.permute_order),
scale("scale1", "permute", "scale_data"),
eltwise("eltwise", "scale1", "eltwise_data", eltwise_mode::sum),
quantize("quant", "eltwise", "in_lo", "in_hi", "out_lo", "out_hi", 256, data_types::u8),
quantize("quant", "permute", "in_lo", "in_hi", "out_lo", "out_hi", 256, data_types::u8),
reorder("reorder_bfyx", "quant", p.default_format, p.default_type)
);
@@ -4548,56 +4544,13 @@ TEST_P(permute_scale_eltwise_quant_u8, vector_ops) {
execute(p);
}
INSTANTIATE_TEST_CASE_P(fusings_gpu, permute_scale_eltwise_quant_u8,
INSTANTIATE_TEST_CASE_P(fusings_gpu, permute_quant_u8,
::testing::ValuesIn(std::vector<permute_params> {
permute_params{CASE_PERMUTE_F32_0, 2, 5},
permute_params{CASE_PERMUTE_F32_1, 2, 5},
permute_params{CASE_PERMUTE_F32_2, 2, 5},
permute_params{CASE_PERMUTE_F32_3, 2, 5},
permute_params{CASE_PERMUTE_F32_4, 2, 5},
permute_params{CASE_PERMUTE_F32_5, 2, 5},
permute_params{CASE_PERMUTE_F32_6, 2, 5},
permute_params{CASE_PERMUTE_F32_7, 2, 5},
permute_params{CASE_PERMUTE_F32_0, 2, 3},
permute_params{CASE_PERMUTE_F32_1, 2, 3},
permute_params{CASE_PERMUTE_F16_0, 2, 5},
permute_params{CASE_PERMUTE_F16_1, 2, 5},
permute_params{CASE_PERMUTE_F16_2, 2, 5},
permute_params{CASE_PERMUTE_F16_3, 2, 5},
permute_params{CASE_PERMUTE_F16_4, 2, 5},
permute_params{CASE_PERMUTE_F16_5, 2, 5},
permute_params{CASE_PERMUTE_F16_6, 2, 5},
permute_params{CASE_PERMUTE_S8_0, 2, 5},
permute_params{CASE_PERMUTE_S8_1, 2, 5},
permute_params{CASE_PERMUTE_S8_2, 2, 5},
permute_params{CASE_PERMUTE_S8_3, 2, 5},
permute_params{CASE_PERMUTE_U8_0, 2, 5},
permute_params{CASE_PERMUTE_U8_1, 2, 5},
permute_params{CASE_PERMUTE_U8_2, 2, 5},
permute_params{CASE_PERMUTE_U8_3, 2, 5},
permute_params{CASE_PERMUTE_F32_3D_0, 2, 5},
permute_params{CASE_PERMUTE_F32_3D_1, 2, 5},
permute_params{CASE_PERMUTE_F32_3D_2, 2, 5},
permute_params{CASE_PERMUTE_F32_3D_3, 2, 5},
permute_params{CASE_PERMUTE_F32_3D_4, 2, 5},
permute_params{CASE_PERMUTE_F16_3D_0, 2, 5},
permute_params{CASE_PERMUTE_F16_3D_1, 2, 5},
permute_params{CASE_PERMUTE_F16_3D_2, 2, 5},
permute_params{CASE_PERMUTE_F16_3D_3, 2, 5},
permute_params{CASE_PERMUTE_F16_3D_4, 2, 5},
permute_params{CASE_PERMUTE_S8_3D_0, 2, 5},
permute_params{CASE_PERMUTE_S8_3D_1, 2, 5},
permute_params{CASE_PERMUTE_S8_3D_2, 2, 5},
permute_params{CASE_PERMUTE_S8_3D_3, 2, 5},
permute_params{CASE_PERMUTE_U8_3D_0, 2, 5},
permute_params{CASE_PERMUTE_U8_3D_1, 2, 5},
permute_params{CASE_PERMUTE_U8_3D_2, 2, 5},
permute_params{CASE_PERMUTE_U8_3D_3, 2, 5},
permute_params{CASE_PERMUTE_F16_0, 2, 3},
permute_params{CASE_PERMUTE_F16_1, 2, 3},
}), );
class permute_scale_actv_eltw_scale_actv_quant_i8: public PermuteFusingTest {};