[LPT] Extending EliminateFakeQuantize transformation (two interval boundaries) (#17140)

* [LPT] EliminateFakeQuantize extending

* tests

* folding quick fix
This commit is contained in:
Edward Shogulin 2023-04-24 11:58:00 +01:00 committed by GitHub
parent a34ef680f2
commit a3f14366d9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 79 additions and 10 deletions

View File

@ -9,6 +9,7 @@
#include <ngraph/ngraph.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include "itt.hpp"
#include "low_precision/network_helper.hpp"
namespace ngraph {
namespace pass {
@ -46,25 +47,63 @@ bool EliminateFakeQuantizeTransformation::transform(TransformationContext& conte
}
namespace {
// Checks whether every element of `constant` matches the expected boundary `value`,
// either exactly (within float epsilon) or within `max_diff` followed by a folding check.
//
// Parameters (five-argument form):
//   fq               - FakeQuantize whose inputs 1..4 are reused for the folding check below.
//   constant         - boundary constant under test; all of its elements are inspected.
//   value            - expected boundary value.
//   max_diff         - largest accepted |element - value| when exact comparison is not required.
//   exact_comparison - when true, any element differing from `value` by more than float epsilon fails.
// Returns true when the boundary is accepted, false otherwise.
//
// NOTE(review): this span appears to come from a diff rendering whose +/- markers were lost:
// the two-argument signature on the next line and the first `return false;` inside the loop
// look like superseded lines shown next to their replacements. Confirm against the repository
// before treating this text as compilable code.
// NOTE(review): `constant` is dereferenced without a null check, while callers pass the result of
// as_type_ptr, which this same function treats as nullable further down. Confirm inputs are
// guaranteed to be Constants at this point.
// NOTE(review): the function is declared noexcept; if cast_vector, clone_with_new_inputs or
// fold_fake_quantize can throw, that would terminate the process - TODO confirm.
bool check_interval(const std::shared_ptr<opset1::Constant>& constant, const float value) noexcept {
bool check_interval(const std::shared_ptr<opset1::FakeQuantize>& fq,
const std::shared_ptr<opset1::Constant>& constant,
const float value,
const float max_diff,
const bool exact_comparison) noexcept {
// Set once at least one element is close to, but not equal to, `value`.
bool need_to_check_intervals = false;
const auto& constant_values = constant->cast_vector<float>();
for (const auto constant_value : constant_values) {
// Elements equal to `value` within float epsilon need no further handling.
if (std::fabs(constant_value - value) > std::numeric_limits<float>::epsilon()) {
return false;
const auto diff = std::fabs(constant_value - value);
// Reject when exact comparison is required (the element already differs by more than epsilon
// here) or when the deviation exceeds the tolerance.
if ((exact_comparison && (std::fabs(constant_value - value) > std::numeric_limits<float>::epsilon())) ||
(diff > max_diff)) {
return false;
}
need_to_check_intervals = true;
}
}
if (need_to_check_intervals) {
// Feed the boundary constant itself through a copy of the FakeQuantize (as input 0, keeping
// the original inputs 1..4) and require the folded result to be exactly `value`.
auto tmp_fq = as_type_ptr<opset1::FakeQuantize>(fq->clone_with_new_inputs({
constant,
fq->get_input_node_shared_ptr(1),
fq->get_input_node_shared_ptr(2),
fq->get_input_node_shared_ptr(3),
fq->get_input_node_shared_ptr(4)}));
// NOTE(review): meaning of the `false, 1` arguments is not visible here - see NetworkHelper.
auto result = NetworkHelper::fold_fake_quantize(tmp_fq, false, 1);
const auto result_constant = as_type_ptr<opset1::Constant>(result);
// If folding did not yield a Constant, the boundary cannot be verified and is rejected.
if (result_constant == nullptr) {
return false;
}
const auto& result_values = result_constant->cast_vector<float>();
// Every folded value must match `value` within float epsilon.
for (const auto result_value : result_values) {
if (std::fabs(result_value - value) > std::numeric_limits<float>::epsilon()) {
return false;
}
}
}
return true;
}
// Returns true when the constants on inputs 1..4 of `fakeQuantize` all match the min/max values
// that DataPrecision reports for the output element type and level count.
// Inputs 1 and 3 are compared against the minimum, inputs 2 and 4 against the maximum
// (presumably input_low / input_high / output_low / output_high - confirm against the
// FakeQuantize operation specification).
//
// NOTE(review): this span appears to come from a diff rendering whose +/- markers were lost:
// `min_value` / `max_value` are declared twice and the return expression appears in both its
// two-argument and five-argument forms. Confirm against the repository which lines are current.
bool check_intervals(const std::shared_ptr<opset1::FakeQuantize>& fakeQuantize) {
const auto& element_type = fakeQuantize->get_output_element_type(0);
const auto min_value = DataPrecision::getMinValue(element_type, fakeQuantize->get_levels());
const auto max_value = DataPrecision::getMaxValue(element_type, fakeQuantize->get_levels());
const auto levels = fakeQuantize->get_levels();
const auto min_value = DataPrecision::getMinValue(element_type, levels);
const auto max_value = DataPrecision::getMaxValue(element_type, levels);
// Tolerance handed to check_interval: the full type range divided by the number of levels.
const auto max_diff = (max_value - min_value) / levels;
// input intervals can be not equal with type intervals for low precision only
const auto exact_comparison = !element_type.is_integral();
// In the five-argument calls, inputs 1 and 2 use `exact_comparison` (tolerant for integral output
// types), while inputs 3 and 4 always pass `true` and therefore must match exactly.
return
check_interval(ov::as_type_ptr<opset1::Constant>(fakeQuantize->get_input_node_shared_ptr(1)), min_value) &&
check_interval(ov::as_type_ptr<opset1::Constant>(fakeQuantize->get_input_node_shared_ptr(2)), max_value) &&
check_interval(ov::as_type_ptr<opset1::Constant>(fakeQuantize->get_input_node_shared_ptr(3)), min_value) &&
check_interval(ov::as_type_ptr<opset1::Constant>(fakeQuantize->get_input_node_shared_ptr(4)), max_value);
check_interval(fakeQuantize, ov::as_type_ptr<opset1::Constant>(fakeQuantize->get_input_node_shared_ptr(1)), min_value, max_diff, exact_comparison) &&
check_interval(fakeQuantize, ov::as_type_ptr<opset1::Constant>(fakeQuantize->get_input_node_shared_ptr(2)), max_value, max_diff, exact_comparison) &&
check_interval(fakeQuantize, ov::as_type_ptr<opset1::Constant>(fakeQuantize->get_input_node_shared_ptr(3)), min_value, max_diff, true) &&
check_interval(fakeQuantize, ov::as_type_ptr<opset1::Constant>(fakeQuantize->get_input_node_shared_ptr(4)), max_value, max_diff, true);
}
} // namespace

View File

@ -55,7 +55,6 @@ public:
testValues.actual.fakeQuantizeOnData1,
testValues.actual.fakeQuantizeOnData2,
{});
SimpleLowPrecisionTransformer transformer;
transformer
.add<ngraph::pass::low_precision::FakeQuantizeDecompositionTransformation, ngraph::opset1::FakeQuantize>(
@ -82,7 +81,8 @@ public:
result << testValues.inputShape << "_" << testValues.params.updatePrecisions << "_"
<< testValues.actual.precisionBefore << "_" << testValues.actual.fakeQuantizeOnData1 << "_"
<< testValues.actual.fakeQuantizeOnData2 << "_" << testValues.expected.precisionBefore << "_"
<< testValues.expected.fakeQuantizeOnData1 << "_" << testValues.expected.fakeQuantizeOnData2 << "_";
<< testValues.expected.fakeQuantizeOnData1 << "_" << testValues.expected.fakeQuantizeOnData2 << "_"
<< testValues.expected.dequantizationOperations2;
return result.str();
}
};
@ -112,6 +112,36 @@ const std::vector<TransformationTestValues> testValues = {
{ ov::element::f32, {}, {{0.01f}, ov::element::f32, {}} }
}
},
{
{1, 3, 16, 16},
TestTransformationParams(true, {ngraph::element::u8}, {ngraph::element::i8}),
{
element::f32,
{256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}},
{256ul, {}, {0.f}, {2.549f}, {0.f}, {2.55f}}
},
{
element::f32,
{256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}, element::u8},
{},
{ ov::element::f32, {}, {{0.01f}, ov::element::f32, {}} }
}
},
{
{1, 3, 16, 16},
TestTransformationParams(true, {ngraph::element::u8}, {ngraph::element::i8}),
{
element::f32,
{256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}},
{256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f / 2.f}}
},
{
element::f32,
{256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}, element::u8},
{},
{ ov::element::f32, {}, {{0.005f}, ov::element::f32, {}} }
}
},
{
{1, 3, 16, 16},
TestTransformationParams(true, {ngraph::element::u8}, {ngraph::element::i8}),