[IE CLDNN] Always use FP32 as intermediate type for fused quantize (#877)
This commit is contained in:
committed by
GitHub
parent
c846c049e2
commit
a3fce2d763
@@ -433,8 +433,13 @@ Datatype ConvolutionKernelBase::GetActivationType(const convolution_params& para
|
||||
return Datatype::F32;
|
||||
|
||||
if (params.output.GetDType() == Datatype::UINT8 ||
|
||||
params.output.GetDType() == Datatype::INT8)
|
||||
return Datatype::F32;
|
||||
params.output.GetDType() == Datatype::INT8) {
|
||||
if (params.inputs[0].GetDType() == Datatype::F32) {
|
||||
return Datatype::F32;
|
||||
} else if (params.inputs[0].GetDType() == Datatype::F16) {
|
||||
return Datatype::F16;
|
||||
}
|
||||
}
|
||||
|
||||
return GetUnitType(params);
|
||||
}
|
||||
|
||||
@@ -78,7 +78,7 @@ JitConstants PermuteKernelRef::GetJitConstants(const permute_params& params) con
|
||||
std::swap(out_idx[3], out_idx[4]);
|
||||
}
|
||||
|
||||
FusedOpsConfiguration conf = {"", out_idx, "input_var", GetUnitType(params), 1};
|
||||
FusedOpsConfiguration conf = {"", out_idx, "input_var", params.inputs[0].GetDType(), 1};
|
||||
jit.Merge(MakeFusedOpsJitConstants(params, {conf}));
|
||||
}
|
||||
|
||||
|
||||
@@ -1153,29 +1153,27 @@ JitConstants FusedOpsCodeGenerator::MakeOpJitConstants(const FusedOpsConfigurati
|
||||
if (!p)
|
||||
throw std::runtime_error("[clDNN] Quantize fuse params can't be nullptr");
|
||||
|
||||
// We can't convert inputs to output data type, because it might be equal to UINT8 or INT8, so we convert the data
|
||||
// to the zero tensor's (input_lo) type
|
||||
std::string in_converted = in_var;
|
||||
Datatype tmp_type = desc.tensors.empty() ? in_type : desc.tensors[0].GetDType();
|
||||
Datatype tmp_type = Datatype::F32;
|
||||
std::string tmp_type_str = GetType(tmp_type, vec_size);
|
||||
std::string tmp_var = out_var + "_tmp";
|
||||
|
||||
if (in_type != tmp_type) {
|
||||
in_converted = ConvertToType(in_var, desc.tensors[0].GetDType(), vec_size);
|
||||
in_converted = ConvertToType(in_var, tmp_type, vec_size);
|
||||
}
|
||||
|
||||
auto post_scale = p->per_tensor_output_scale ? Broadcast(std::to_string(p->out_scale), tmp_type, vec_size)
|
||||
: GetInputVarName(p->out_scale_idx);
|
||||
: ConvertToType(GetInputVarName(p->out_scale_idx), tmp_type, vec_size);
|
||||
auto post_shift = p->per_tensor_output_shift ? Broadcast(std::to_string(p->out_shift), tmp_type, vec_size)
|
||||
: GetInputVarName(p->out_shift_idx);
|
||||
: ConvertToType(GetInputVarName(p->out_shift_idx), tmp_type, vec_size);
|
||||
auto pre_scale = p->per_tensor_input_scale ? Broadcast(std::to_string(p->in_scale), tmp_type, vec_size)
|
||||
: GetInputVarName(p->in_scale_idx);
|
||||
: ConvertToType(GetInputVarName(p->in_scale_idx), tmp_type, vec_size);
|
||||
auto pre_shift = p->per_tensor_input_shift ? Broadcast(std::to_string(p->in_shift), tmp_type, vec_size)
|
||||
: GetInputVarName(p->in_shift_idx);
|
||||
: ConvertToType(GetInputVarName(p->in_shift_idx), tmp_type, vec_size);
|
||||
auto in_lo = p->per_tensor_input_range ? Broadcast(std::to_string(p->in_lo), tmp_type, vec_size)
|
||||
: GetInputVarName(p->in_range_lo_idx);
|
||||
: ConvertToType(GetInputVarName(p->in_range_lo_idx), tmp_type, vec_size);
|
||||
auto in_hi = p->per_tensor_input_range ? Broadcast(std::to_string(p->in_hi), tmp_type, vec_size)
|
||||
: GetInputVarName(p->in_range_hi_idx);
|
||||
: ConvertToType(GetInputVarName(p->in_range_hi_idx), tmp_type, vec_size);
|
||||
|
||||
if (p->has_clamp) {
|
||||
op_decls += "\\\n\t" + tmp_type_str + " " + tmp_var + " = min(max(" + in_lo + ", " + in_converted + "), " + in_hi + ");";
|
||||
|
||||
@@ -4526,21 +4526,17 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, permute_activation_scale_eltwise,
|
||||
permute_params{CASE_PERMUTE_U8_3D_3, 2, 5},
|
||||
}), );
|
||||
|
||||
class permute_scale_eltwise_quant_u8: public PermuteFusingTest {};
|
||||
TEST_P(permute_scale_eltwise_quant_u8, vector_ops) {
|
||||
class permute_quant_u8: public PermuteFusingTest {};
|
||||
TEST_P(permute_quant_u8, basic) {
|
||||
auto p = GetParam();
|
||||
create_topologies(
|
||||
input_layout("input", get_input_layout(p)),
|
||||
data("scale_data", get_mem(get_per_channel_layout(p))),
|
||||
data("eltwise_data", get_mem(layout{ p.data_type, p.input_format, p.out_shape})),
|
||||
data("in_lo", get_mem(get_per_channel_layout(p), min_random, 0)),
|
||||
data("in_hi", get_mem(get_per_channel_layout(p), 1, max_random)),
|
||||
data("in_lo", get_mem(get_single_element_layout(p), min_random, 0)),
|
||||
data("in_hi", get_mem(get_single_element_layout(p), 1, max_random)),
|
||||
data("out_lo", get_mem(get_single_element_layout(p), 0)),
|
||||
data("out_hi", get_mem(get_single_element_layout(p), 255)),
|
||||
permute("permute", "input", p.permute_order),
|
||||
scale("scale1", "permute", "scale_data"),
|
||||
eltwise("eltwise", "scale1", "eltwise_data", eltwise_mode::sum),
|
||||
quantize("quant", "eltwise", "in_lo", "in_hi", "out_lo", "out_hi", 256, data_types::u8),
|
||||
quantize("quant", "permute", "in_lo", "in_hi", "out_lo", "out_hi", 256, data_types::u8),
|
||||
reorder("reorder_bfyx", "quant", p.default_format, p.default_type)
|
||||
);
|
||||
|
||||
@@ -4548,56 +4544,13 @@ TEST_P(permute_scale_eltwise_quant_u8, vector_ops) {
|
||||
execute(p);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(fusings_gpu, permute_scale_eltwise_quant_u8,
|
||||
INSTANTIATE_TEST_CASE_P(fusings_gpu, permute_quant_u8,
|
||||
::testing::ValuesIn(std::vector<permute_params> {
|
||||
permute_params{CASE_PERMUTE_F32_0, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F32_1, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F32_2, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F32_3, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F32_4, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F32_5, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F32_6, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F32_7, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F32_0, 2, 3},
|
||||
permute_params{CASE_PERMUTE_F32_1, 2, 3},
|
||||
|
||||
permute_params{CASE_PERMUTE_F16_0, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F16_1, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F16_2, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F16_3, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F16_4, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F16_5, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F16_6, 2, 5},
|
||||
|
||||
permute_params{CASE_PERMUTE_S8_0, 2, 5},
|
||||
permute_params{CASE_PERMUTE_S8_1, 2, 5},
|
||||
permute_params{CASE_PERMUTE_S8_2, 2, 5},
|
||||
permute_params{CASE_PERMUTE_S8_3, 2, 5},
|
||||
|
||||
permute_params{CASE_PERMUTE_U8_0, 2, 5},
|
||||
permute_params{CASE_PERMUTE_U8_1, 2, 5},
|
||||
permute_params{CASE_PERMUTE_U8_2, 2, 5},
|
||||
permute_params{CASE_PERMUTE_U8_3, 2, 5},
|
||||
|
||||
permute_params{CASE_PERMUTE_F32_3D_0, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F32_3D_1, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F32_3D_2, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F32_3D_3, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F32_3D_4, 2, 5},
|
||||
|
||||
permute_params{CASE_PERMUTE_F16_3D_0, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F16_3D_1, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F16_3D_2, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F16_3D_3, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F16_3D_4, 2, 5},
|
||||
|
||||
permute_params{CASE_PERMUTE_S8_3D_0, 2, 5},
|
||||
permute_params{CASE_PERMUTE_S8_3D_1, 2, 5},
|
||||
permute_params{CASE_PERMUTE_S8_3D_2, 2, 5},
|
||||
permute_params{CASE_PERMUTE_S8_3D_3, 2, 5},
|
||||
|
||||
permute_params{CASE_PERMUTE_U8_3D_0, 2, 5},
|
||||
permute_params{CASE_PERMUTE_U8_3D_1, 2, 5},
|
||||
permute_params{CASE_PERMUTE_U8_3D_2, 2, 5},
|
||||
permute_params{CASE_PERMUTE_U8_3D_3, 2, 5},
|
||||
permute_params{CASE_PERMUTE_F16_0, 2, 3},
|
||||
permute_params{CASE_PERMUTE_F16_1, 2, 3},
|
||||
}), );
|
||||
|
||||
class permute_scale_actv_eltw_scale_actv_quant_i8: public PermuteFusingTest {};
|
||||
|
||||
Reference in New Issue
Block a user