[LPT] Extending EliminateFakeQuantize transformation (two interval boundaries) (#17140)
* [LPT] EliminateFakeQuantize extending * tests * folding quick fix
This commit is contained in:
parent
a34ef680f2
commit
a3f14366d9
@ -9,6 +9,7 @@
|
||||
#include <ngraph/ngraph.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include "itt.hpp"
|
||||
#include "low_precision/network_helper.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
@ -46,25 +47,63 @@ bool EliminateFakeQuantizeTransformation::transform(TransformationContext& conte
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Checks whether every element of `constant` can be treated as equal to `value`.
//
// fq               - FakeQuantize whose interval is being examined; a clone of it is folded over
//                    `constant` when an element differs from `value` by more than float epsilon.
// constant         - interval constant to inspect (result of as_type_ptr<Constant>() at the call sites).
// value            - boundary value each element is compared against.
// max_diff         - largest absolute deviation from `value` that is still accepted when
//                    `exact_comparison` is false.
// exact_comparison - when true, any element further than float epsilon from `value` is rejected.
//
// Returns true when all elements match `value` within float epsilon, or when they deviate by at most
// `max_diff` and the folded FakeQuantize result still equals `value` for every element.
bool check_interval(const std::shared_ptr<opset1::FakeQuantize>& fq,
                    const std::shared_ptr<opset1::Constant>& constant,
                    const float value,
                    const float max_diff,
                    const bool exact_comparison) noexcept {
    // A failed as_type_ptr<> cast at the call site yields a null pointer; dereferencing it here
    // would be undefined behaviour, so such an interval is reported as "does not match".
    if ((fq == nullptr) || (constant == nullptr)) {
        return false;
    }

    constexpr float epsilon = std::numeric_limits<float>::epsilon();

    bool need_to_check_intervals = false;
    const auto constant_values = constant->cast_vector<float>();
    for (const auto constant_value : constant_values) {
        const auto diff = std::fabs(constant_value - value);
        if (diff > epsilon) {
            // Past this point the element is known to differ from `value`, so an exact comparison
            // fails outright and an approximate one fails when the deviation exceeds `max_diff`.
            if (exact_comparison || (diff > max_diff)) {
                return false;
            }

            need_to_check_intervals = true;
        }
    }

    if (need_to_check_intervals) {
        // Feed the interval constant itself through a copy of the FakeQuantize (as its data input)
        // and verify that the folded result is exactly `value`.
        auto tmp_fq = as_type_ptr<opset1::FakeQuantize>(fq->clone_with_new_inputs({
            constant,
            fq->get_input_node_shared_ptr(1),
            fq->get_input_node_shared_ptr(2),
            fq->get_input_node_shared_ptr(3),
            fq->get_input_node_shared_ptr(4)}));
        auto result = NetworkHelper::fold_fake_quantize(tmp_fq, false, 1);
        const auto result_constant = as_type_ptr<opset1::Constant>(result);
        if (result_constant == nullptr) {
            // Folding did not produce a constant, so the interval cannot be verified.
            return false;
        }

        const auto result_values = result_constant->cast_vector<float>();
        for (const auto result_value : result_values) {
            if (std::fabs(result_value - value) > epsilon) {
                return false;
            }
        }
    }

    return true;
}
|
||||
|
||||
// Returns true when the four interval inputs of `fakeQuantize` (inputs 1..4) match the minimum and
// maximum values reported by DataPrecision for the output element type and level count.
// Inputs 1 and 2 (input low/high) may deviate by up to one quantization step when the output type is
// integral; inputs 3 and 4 (output low/high) are always compared exactly.
bool check_intervals(const std::shared_ptr<opset1::FakeQuantize>& fakeQuantize) {
    const auto& element_type = fakeQuantize->get_output_element_type(0);
    const auto levels = fakeQuantize->get_levels();
    const auto min_value = DataPrecision::getMinValue(element_type, levels);
    const auto max_value = DataPrecision::getMaxValue(element_type, levels);
    // Width of the type range divided by the number of levels: the tolerance for a non-exact match.
    const auto max_diff = (max_value - min_value) / levels;
    // Input intervals may differ from the type's intervals only for low-precision (integral) outputs.
    const auto exact_comparison = !element_type.is_integral();

    // Compares the constant connected to input `input_index` against `expected`.
    const auto interval_matches = [&](const size_t input_index, const float expected, const bool exact) {
        return check_interval(
            fakeQuantize,
            ov::as_type_ptr<opset1::Constant>(fakeQuantize->get_input_node_shared_ptr(input_index)),
            expected,
            max_diff,
            exact);
    };

    return interval_matches(1, min_value, exact_comparison) &&
           interval_matches(2, max_value, exact_comparison) &&
           interval_matches(3, min_value, true) &&
           interval_matches(4, max_value, true);
}
|
||||
} // namespace
|
||||
|
||||
|
@ -55,7 +55,6 @@ public:
|
||||
testValues.actual.fakeQuantizeOnData1,
|
||||
testValues.actual.fakeQuantizeOnData2,
|
||||
{});
|
||||
|
||||
SimpleLowPrecisionTransformer transformer;
|
||||
transformer
|
||||
.add<ngraph::pass::low_precision::FakeQuantizeDecompositionTransformation, ngraph::opset1::FakeQuantize>(
|
||||
@ -82,7 +81,8 @@ public:
|
||||
result << testValues.inputShape << "_" << testValues.params.updatePrecisions << "_"
|
||||
<< testValues.actual.precisionBefore << "_" << testValues.actual.fakeQuantizeOnData1 << "_"
|
||||
<< testValues.actual.fakeQuantizeOnData2 << "_" << testValues.expected.precisionBefore << "_"
|
||||
<< testValues.expected.fakeQuantizeOnData1 << "_" << testValues.expected.fakeQuantizeOnData2 << "_";
|
||||
<< testValues.expected.fakeQuantizeOnData1 << "_" << testValues.expected.fakeQuantizeOnData2 << "_"
|
||||
<< testValues.expected.dequantizationOperations2;
|
||||
return result.str();
|
||||
}
|
||||
};
|
||||
@ -112,6 +112,36 @@ const std::vector<TransformationTestValues> testValues = {
|
||||
{ ov::element::f32, {}, {{0.01f}, ov::element::f32, {}} }
|
||||
}
|
||||
},
|
||||
{
|
||||
{1, 3, 16, 16},
|
||||
TestTransformationParams(true, {ngraph::element::u8}, {ngraph::element::i8}),
|
||||
{
|
||||
element::f32,
|
||||
{256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}},
|
||||
{256ul, {}, {0.f}, {2.549f}, {0.f}, {2.55f}}
|
||||
},
|
||||
{
|
||||
element::f32,
|
||||
{256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}, element::u8},
|
||||
{},
|
||||
{ ov::element::f32, {}, {{0.01f}, ov::element::f32, {}} }
|
||||
}
|
||||
},
|
||||
{
|
||||
{1, 3, 16, 16},
|
||||
TestTransformationParams(true, {ngraph::element::u8}, {ngraph::element::i8}),
|
||||
{
|
||||
element::f32,
|
||||
{256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}},
|
||||
{256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f / 2.f}}
|
||||
},
|
||||
{
|
||||
element::f32,
|
||||
{256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}, element::u8},
|
||||
{},
|
||||
{ ov::element::f32, {}, {{0.005f}, ov::element::f32, {}} }
|
||||
}
|
||||
},
|
||||
{
|
||||
{1, 3, 16, 16},
|
||||
TestTransformationParams(true, {ngraph::element::u8}, {ngraph::element::i8}),
|
||||
|
Loading…
Reference in New Issue
Block a user