Fix MulFakeQuantizeFusion with output low/high other than f32 (#5006)
Ticket: 51964
This commit is contained in:
parent
170223d842
commit
0c38a9e4d3
@ -73,7 +73,6 @@ ngraph::pass::MulFakeQuantizeFusion::MulFakeQuantizeFusion() {
|
||||
} else if (std::any_of(mul_const_value.begin(), mul_const_value.end(), [] (float f) -> bool { return f < 0.0f; })) {
|
||||
const auto& output_low = fq->input_value(3);
|
||||
const auto& output_high = fq->input_value(4);
|
||||
auto zero = op::Constant::create(element::f32, Shape{}, {0.0f});
|
||||
// get the mask of the values from mul_const that are less than zero
|
||||
std::vector<float> less_than_zero;
|
||||
less_than_zero.reserve(mul_const_value.size());
|
||||
@ -84,8 +83,8 @@ ngraph::pass::MulFakeQuantizeFusion::MulFakeQuantizeFusion() {
|
||||
less_than_zero.push_back(mul_const_value[i] < 0);
|
||||
greater_eq_zero.push_back(mul_const_value[i] >= 0);
|
||||
}
|
||||
auto less_const = op::Constant::create(element::f32, const_shape, less_than_zero);
|
||||
auto greater_eq_const = op::Constant::create(element::f32, const_shape, greater_eq_zero);
|
||||
auto less_const = op::Constant::create(output_low.get_element_type(), const_shape, less_than_zero);
|
||||
auto greater_eq_const = op::Constant::create(output_low.get_element_type(), const_shape, greater_eq_zero);
|
||||
// new_output_low is defined as follows:
|
||||
// output_low[i], when mul_const[i] >= 0
|
||||
// output_high[i], when mul_const[i] < 0
|
||||
|
@ -218,6 +218,45 @@ TEST(TransformationTests, MulFakeQuantizeFusionConstantSomeNegative) {
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
// Checks MulFakeQuantizeFusion on a graph built entirely from f16 nodes whose
// Multiply constant mixes positive and negative per-channel values ({2, 1, -2}).
// The transformed function is compared against a hand-built reference that
// contains only a FakeQuantize applied directly to the input parameter.
TEST(TransformationTests, MulFakeQuantizeFusionConstantSomeNegativeF16) {
|
||||
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
|
||||
|
||||
Shape data_shape{1, 3, 14, 14};
|
||||
// Graph under test: Parameter -> Multiply(const) -> FakeQuantize, then run the passes.
{
|
||||
auto data = std::make_shared<opset5::Parameter>(element::f16, data_shape);
|
||||
// Per-channel multiplier; the last channel is negative.
auto mul_const = opset5::Constant::create(element::f16, Shape{3, 1, 1}, {2, 1, -2});
|
||||
auto mul = std::make_shared<opset5::Multiply>(data, mul_const);
|
||||
auto input_low = opset5::Constant::create(element::f16, Shape{1}, {1});
|
||||
auto input_high = opset5::Constant::create(element::f16, Shape{1}, {20});
|
||||
// output_low is given per channel, output_high as a single value.
auto output_low = opset5::Constant::create(element::f16, Shape{1, 3, 1, 1}, {-10, -10, -10});
|
||||
auto output_high = opset5::Constant::create(element::f16, Shape{1}, {10});
|
||||
auto fq = std::make_shared<opset5::FakeQuantize>(mul, input_low,
|
||||
input_high, output_low,
|
||||
output_high, 20);
|
||||
f = std::make_shared<Function>(NodeVector{fq}, ParameterVector{data});
|
||||
// Pass order: InitNodeInfo, the fusion under test, then ConstantFolding.
pass::Manager m;
|
||||
m.register_pass<pass::InitNodeInfo>();
|
||||
m.register_pass<pass::MulFakeQuantizeFusion>();
|
||||
m.register_pass<pass::ConstantFolding>();
|
||||
m.run_passes(f);
|
||||
// The transformed function must pass check_rt_info without throwing.
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
// Reference graph: FakeQuantize fed directly by the parameter, no Multiply.
{
|
||||
auto data = std::make_shared<opset5::Parameter>(element::f16, data_shape);
|
||||
// Reference input range equals the original range (1, 20) divided per channel
// by the multiplier {2, 1, -2}.
auto input_low = opset5::Constant::create(element::f16, Shape{1, 3, 1, 1}, {0.5f, 1.0f, -0.5f});
|
||||
auto input_high = opset5::Constant::create(element::f16, Shape{1, 3, 1, 1}, {10.0f, 20.0f, -10.0f});
|
||||
// For the channel with the negative multiplier the output bounds are swapped
// relative to the original (-10, 10).
auto output_low = opset5::Constant::create(element::f16, Shape{1, 3, 1, 1}, {-10.0f, -10.0f, 10.0f});
|
||||
auto output_high = opset5::Constant::create(element::f16, Shape{1, 3, 1, 1}, {10.0f, 10.0f, -10.0f});
|
||||
auto fq = std::make_shared<opset5::FakeQuantize>(data, input_low,
|
||||
input_high, output_low,
|
||||
output_high, 20);
|
||||
f_ref = std::make_shared<Function>(NodeVector{fq}, ParameterVector{data});
|
||||
}
|
||||
|
||||
// NOTE(review): the third argument is `true`; presumably it enables comparison
// of constant values — confirm against the compare_functions declaration.
auto res = compare_functions(f, f_ref, true);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, NegativeMulFakeQuantizeFusionNotAConstant) {
|
||||
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user