Fix MulFakeQuantizeFusion with output low/high other than f16 (#5006)

Ticket: 51964
This commit is contained in:
Mateusz Tabaka 2021-03-30 15:53:59 +02:00 committed by GitHub
parent 170223d842
commit 0c38a9e4d3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 41 additions and 3 deletions

View File

@ -73,7 +73,6 @@ ngraph::pass::MulFakeQuantizeFusion::MulFakeQuantizeFusion() {
} else if (std::any_of(mul_const_value.begin(), mul_const_value.end(), [] (float f) -> bool { return f < 0.0f; })) {
const auto& output_low = fq->input_value(3);
const auto& output_high = fq->input_value(4);
auto zero = op::Constant::create(element::f32, Shape{}, {0.0f});
// get the mask of the values from mul_const that are less than zero
std::vector<float> less_than_zero;
less_than_zero.reserve(mul_const_value.size());
@ -84,8 +83,8 @@ ngraph::pass::MulFakeQuantizeFusion::MulFakeQuantizeFusion() {
less_than_zero.push_back(mul_const_value[i] < 0);
greater_eq_zero.push_back(mul_const_value[i] >= 0);
}
auto less_const = op::Constant::create(element::f32, const_shape, less_than_zero);
auto greater_eq_const = op::Constant::create(element::f32, const_shape, greater_eq_zero);
auto less_const = op::Constant::create(output_low.get_element_type(), const_shape, less_than_zero);
auto greater_eq_const = op::Constant::create(output_low.get_element_type(), const_shape, greater_eq_zero);
// new_output_low is defined as follows:
// output_low[i], when mul_const[i] >= 0
// output_high[i], when mul_const[i] < 0

View File

@ -218,6 +218,45 @@ TEST(TransformationTests, MulFakeQuantizeFusionConstantSomeNegative) {
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, MulFakeQuantizeFusionConstantSomeNegativeF16) {
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
Shape data_shape{1, 3, 14, 14};
{
auto data = std::make_shared<opset5::Parameter>(element::f16, data_shape);
auto mul_const = opset5::Constant::create(element::f16, Shape{3, 1, 1}, {2, 1, -2});
auto mul = std::make_shared<opset5::Multiply>(data, mul_const);
auto input_low = opset5::Constant::create(element::f16, Shape{1}, {1});
auto input_high = opset5::Constant::create(element::f16, Shape{1}, {20});
auto output_low = opset5::Constant::create(element::f16, Shape{1, 3, 1, 1}, {-10, -10, -10});
auto output_high = opset5::Constant::create(element::f16, Shape{1}, {10});
auto fq = std::make_shared<opset5::FakeQuantize>(mul, input_low,
input_high, output_low,
output_high, 20);
f = std::make_shared<Function>(NodeVector{fq}, ParameterVector{data});
pass::Manager m;
m.register_pass<pass::InitNodeInfo>();
m.register_pass<pass::MulFakeQuantizeFusion>();
m.register_pass<pass::ConstantFolding>();
m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto data = std::make_shared<opset5::Parameter>(element::f16, data_shape);
auto input_low = opset5::Constant::create(element::f16, Shape{1, 3, 1, 1}, {0.5f, 1.0f, -0.5f});
auto input_high = opset5::Constant::create(element::f16, Shape{1, 3, 1, 1}, {10.0f, 20.0f, -10.0f});
auto output_low = opset5::Constant::create(element::f16, Shape{1, 3, 1, 1}, {-10.0f, -10.0f, 10.0f});
auto output_high = opset5::Constant::create(element::f16, Shape{1, 3, 1, 1}, {10.0f, 10.0f, -10.0f});
auto fq = std::make_shared<opset5::FakeQuantize>(data, input_low,
input_high, output_low,
output_high, 20);
f_ref = std::make_shared<Function>(NodeVector{fq}, ParameterVector{data});
}
auto res = compare_functions(f, f_ref, true);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, NegativeMulFakeQuantizeFusionNotAConstant) {
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);