Fix FakeQuantizeMulFusion for cases with NUMPY broadcasting (#4570)

* Fix FQMul fusion

* Added transformation test

* Removed wrong test
This commit is contained in:
Gleb Kazantaev 2021-03-03 11:34:29 +03:00 committed by GitHub
parent c1925cc220
commit ffade0d1d8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 63 additions and 17 deletions

View File

@ -67,13 +67,12 @@ ngraph::pass::FakeQuantizeMulFusion::FakeQuantizeMulFusion() {
const auto fq_output_low_p = ngraph::pattern::any_input();
const auto fq_output_high_p = ngraph::pattern::any_input();
const auto fq_node_p = ngraph::pattern::wrap_type<opset4::FakeQuantize>(
{ngraph::pattern::any_input(),
ngraph::pattern::any_input(),
ngraph::pattern::any_input(),
fq_output_low_p,
fq_output_high_p},
pattern::consumers_count(1));
const auto fq_node_p = ngraph::pattern::wrap_type<opset4::FakeQuantize>({ngraph::pattern::any_input(),
ngraph::pattern::any_input(),
ngraph::pattern::any_input(),
fq_output_low_p,
fq_output_high_p},
pattern::consumers_count(1));
const auto mul_constant_p = ngraph::pattern::wrap_type<opset4::Constant>();
const auto mul_node_p = ngraph::pattern::wrap_type<opset4::Multiply>(
@ -84,9 +83,9 @@ ngraph::pass::FakeQuantizeMulFusion::FakeQuantizeMulFusion() {
const auto fq_node = pattern_map.at(fq_node_p).get_node_shared_ptr();
const auto original_output_low = pattern_map.at(fq_output_low_p);
const auto original_output_high = pattern_map.at(fq_output_high_p);
const auto mul_constant = pattern_map.at(mul_constant_p);
const auto & original_output_low = pattern_map.at(fq_output_low_p);
const auto & original_output_high = pattern_map.at(fq_output_high_p);
const auto & mul_constant = pattern_map.at(mul_constant_p);
const auto new_output_limits = get_adjusted_output_range(
original_output_low, original_output_high, mul_constant);
@ -98,6 +97,26 @@ ngraph::pass::FakeQuantizeMulFusion::FakeQuantizeMulFusion() {
new_output_limits.second});
const auto mul_node = pattern_map.at(mul_node_p).get_node_shared_ptr();
// WA: this check is intended to prevent replacement when new FQ has shape
// which is different to Multiply output shape. Otherwise such replacement
// will lead to shape inconsistency in remaining graph. This check must be
// removed in future when FQ will have correct validate_and_infer function
// for cases with NUMPY broadcast.
auto fq_casted = std::dynamic_pointer_cast<opset4::FakeQuantize>(new_fq_node);
if (!fq_casted) {
return false;
}
if (fq_casted->get_auto_broadcast() == op::AutoBroadcastType::NUMPY) {
if (fq_casted->get_output_partial_shape(0).is_dynamic() ||
mul_node->get_output_partial_shape(0).is_dynamic()) {
return false;
}
if (fq_casted->get_shape() != mul_node->get_shape()) {
return false;
}
}
replace_node(mul_node, new_fq_node);
new_fq_node->set_friendly_name(fq_node->get_friendly_name());

View File

@ -169,13 +169,6 @@ INSTANTIATE_TEST_CASE_P(FQOutputs_1D__multiplier_3D, FQMulFusion,
::testing::Values(ngraph::Shape{1, 3, 1}),
::testing::Values(ngraph::Shape{1, 3, 1})));
// Parameterization: all FQ input/output range shapes are scalar-like
// {1, 1, 1, 1} while the multiplier is per-channel {1, 64, 1, 1}, so the
// fused FQ's output limits would take the broadcast {1, 64, 1, 1} shape.
// NOTE(review): per the commit message ("Removed wrong test") this case sits
// on the removed side of the diff — the pass now rejects fusions whose NUMPY
// broadcast would change the output shape, making this expectation invalid.
INSTANTIATE_TEST_CASE_P(FQ_all_ones__multiplier_4D_with_channel, FQMulFusion,
::testing::Combine(::testing::Values(ngraph::Shape{1, 1, 1, 1}),
::testing::Values(ngraph::Shape{1, 1, 1, 1}),
::testing::Values(ngraph::Shape{1, 1, 1, 1}),
::testing::Values(ngraph::Shape{1, 64, 1, 1}),
::testing::Values(ngraph::Shape{1, 64, 1, 1})));
INSTANTIATE_TEST_CASE_P(FQInOUt_ones__multiplier_4D_with_channel, FQMulFusion,
::testing::Combine(::testing::Values(ngraph::Shape{1, 64, 3, 3}),
::testing::Values(ngraph::Shape{1, 1, 1, 1}),
@ -357,6 +350,40 @@ TEST(FQMulFusion_FQ_Mul_inputs, FQ_out_to_mul_input_2_param) {
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, FakeQuantizeMultiplyFusionNegative) {
    // Negative case: the fusion must NOT fire, because replacing the Multiply
    // with an adjusted FakeQuantize would change the graph's output shape.

    // out_high is a Parameter (not a Constant) so the adjusted output range
    // cannot be constant folded away.
    const auto output_high =
        std::make_shared<ngraph::opset4::Parameter>(ngraph::element::Type_t::f32, ngraph::Shape{});
    const auto fq_data = ngraph::opset4::Constant::create(
        ngraph::element::Type_t::f32, ngraph::Shape{1, 300, 1}, {0.0f});
    const auto input_low = ngraph::opset4::Constant::create(
        ngraph::element::Type_t::f32, ngraph::Shape{}, {-0.5f});
    const auto input_high = ngraph::opset4::Constant::create(
        ngraph::element::Type_t::f32, ngraph::Shape{}, {0.5f});
    const auto output_low = ngraph::opset4::Constant::create(
        ngraph::element::Type_t::f32, ngraph::Shape{}, {0.0f});
    const auto fake_quantize = std::make_shared<ngraph::opset4::FakeQuantize>(
        fq_data, input_low, input_high, output_low, output_high, 42);

    // The multiplier shape {1, 300, 16} NUMPY-broadcasts the FQ output
    // {1, 300, 1}; note the FQ result feeds the SECOND Multiply input.
    const auto multiplier = ngraph::opset4::Constant::create(
        ngraph::element::Type_t::f32, ngraph::Shape{1, 300, 16}, {3.14f});
    const auto multiply = std::make_shared<ngraph::opset4::Multiply>(multiplier, fake_quantize);

    auto model = std::make_shared<ngraph::Function>(
        ngraph::OutputVector{multiply}, ngraph::ParameterVector{output_high});

    ngraph::pass::Manager manager;
    manager.register_pass<ngraph::pass::InitNodeInfo>();
    manager.register_pass<ngraph::pass::FakeQuantizeMulFusion>();
    manager.run_passes(model);

    // The Multiply must survive the pass: the function output keeps the
    // broadcast {1, 300, 16} shape and runtime info stays consistent.
    ASSERT_NO_THROW(check_rt_info(model));
    ASSERT_EQ(model->get_output_shape(0), ngraph::Shape({1, 300, 16}));
}
} // namespace
} // namespace LayerTestsDefinitions