Fix FakeQuantizeMulFusion for cases with NUMPY broadcasting (#4570)
* Fix FQMul fusion
* Added transformation test
* Removed wrong test
This commit is contained in:
parent
c1925cc220
commit
ffade0d1d8
@ -67,13 +67,12 @@ ngraph::pass::FakeQuantizeMulFusion::FakeQuantizeMulFusion() {
|
||||
const auto fq_output_low_p = ngraph::pattern::any_input();
|
||||
const auto fq_output_high_p = ngraph::pattern::any_input();
|
||||
|
||||
const auto fq_node_p = ngraph::pattern::wrap_type<opset4::FakeQuantize>(
|
||||
{ngraph::pattern::any_input(),
|
||||
ngraph::pattern::any_input(),
|
||||
ngraph::pattern::any_input(),
|
||||
fq_output_low_p,
|
||||
fq_output_high_p},
|
||||
pattern::consumers_count(1));
|
||||
const auto fq_node_p = ngraph::pattern::wrap_type<opset4::FakeQuantize>({ngraph::pattern::any_input(),
|
||||
ngraph::pattern::any_input(),
|
||||
ngraph::pattern::any_input(),
|
||||
fq_output_low_p,
|
||||
fq_output_high_p},
|
||||
pattern::consumers_count(1));
|
||||
|
||||
const auto mul_constant_p = ngraph::pattern::wrap_type<opset4::Constant>();
|
||||
const auto mul_node_p = ngraph::pattern::wrap_type<opset4::Multiply>(
|
||||
@ -84,9 +83,9 @@ ngraph::pass::FakeQuantizeMulFusion::FakeQuantizeMulFusion() {
|
||||
|
||||
const auto fq_node = pattern_map.at(fq_node_p).get_node_shared_ptr();
|
||||
|
||||
const auto original_output_low = pattern_map.at(fq_output_low_p);
|
||||
const auto original_output_high = pattern_map.at(fq_output_high_p);
|
||||
const auto mul_constant = pattern_map.at(mul_constant_p);
|
||||
const auto & original_output_low = pattern_map.at(fq_output_low_p);
|
||||
const auto & original_output_high = pattern_map.at(fq_output_high_p);
|
||||
const auto & mul_constant = pattern_map.at(mul_constant_p);
|
||||
|
||||
const auto new_output_limits = get_adjusted_output_range(
|
||||
original_output_low, original_output_high, mul_constant);
|
||||
@ -98,6 +97,26 @@ ngraph::pass::FakeQuantizeMulFusion::FakeQuantizeMulFusion() {
|
||||
new_output_limits.second});
|
||||
|
||||
const auto mul_node = pattern_map.at(mul_node_p).get_node_shared_ptr();
|
||||
|
||||
// WA: this check is intended to prevent replacement when the new FQ has a shape
|
||||
// that differs from the Multiply output shape. Otherwise such a replacement
|
||||
// would lead to shape inconsistency in the remaining graph. This check must be
|
||||
// removed in the future, once FQ has a correct validate_and_infer function
|
||||
// for cases with NUMPY broadcast.
|
||||
auto fq_casted = std::dynamic_pointer_cast<opset4::FakeQuantize>(new_fq_node);
|
||||
if (!fq_casted) {
|
||||
return false;
|
||||
}
|
||||
if (fq_casted->get_auto_broadcast() == op::AutoBroadcastType::NUMPY) {
|
||||
if (fq_casted->get_output_partial_shape(0).is_dynamic() ||
|
||||
mul_node->get_output_partial_shape(0).is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
if (fq_casted->get_shape() != mul_node->get_shape()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
replace_node(mul_node, new_fq_node);
|
||||
|
||||
new_fq_node->set_friendly_name(fq_node->get_friendly_name());
|
||||
|
@ -169,13 +169,6 @@ INSTANTIATE_TEST_CASE_P(FQOutputs_1D__multiplier_3D, FQMulFusion,
|
||||
::testing::Values(ngraph::Shape{1, 3, 1}),
|
||||
::testing::Values(ngraph::Shape{1, 3, 1})));
|
||||
|
||||
// Instantiates the FQMulFusion parameterized suite with a single parameter
// tuple: every ::testing::Values() generator below holds exactly one shape, so
// ::testing::Combine() yields one combination. The first three shapes are
// {1, 1, 1, 1} and the last two are {1, 64, 1, 1}.
// NOTE(review): presumably the tuple is (FQ data, FQ input limits, FQ output
// limits, multiplier constant, expected fused output limits) - confirm against
// the FQMulFusion fixture, which is not visible here.
// NOTE(review): the hunk header (13 -> 6 lines) together with the commit
// message ("Removed wrong test") suggests this is the instantiation deleted by
// the commit - confirm.
INSTANTIATE_TEST_CASE_P(FQ_all_ones__multiplier_4D_with_channel, FQMulFusion,
|
||||
::testing::Combine(::testing::Values(ngraph::Shape{1, 1, 1, 1}),
|
||||
::testing::Values(ngraph::Shape{1, 1, 1, 1}),
|
||||
::testing::Values(ngraph::Shape{1, 1, 1, 1}),
|
||||
::testing::Values(ngraph::Shape{1, 64, 1, 1}),
|
||||
::testing::Values(ngraph::Shape{1, 64, 1, 1})));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(FQInOUt_ones__multiplier_4D_with_channel, FQMulFusion,
|
||||
::testing::Combine(::testing::Values(ngraph::Shape{1, 64, 3, 3}),
|
||||
::testing::Values(ngraph::Shape{1, 1, 1, 1}),
|
||||
@ -357,6 +350,40 @@ TEST(FQMulFusion_FQ_Mul_inputs, FQ_out_to_mul_input_2_param) {
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
// Builds FakeQuantize -> Multiply where the FQ data has shape {1, 300, 1} and
// the Multiply constant has shape {1, 300, 16}, runs InitNodeInfo and
// FakeQuantizeMulFusion on it, and verifies that the function output shape is
// still {1, 300, 16} afterwards.
// NOTE(review): per the test name this is a negative case - presumably the
// fusion is expected to be skipped because replacing Multiply with the new FQ
// would change the output shape; confirm against the pass implementation.
TEST(TransformationTests, FakeQuantizeMultiplyFusionNegative) {
|
||||
// FQ data input: shape {1, 300, 1}, filled with 0.0f
const auto data = ngraph::opset4::Constant::create(
|
||||
ngraph::element::Type_t::f32, ngraph::Shape{1, 300, 1}, {0.0f});
|
||||
// scalar (Shape{}) input and output limits for the FQ
const auto in_low = ngraph::opset4::Constant::create(
|
||||
ngraph::element::Type_t::f32, ngraph::Shape{}, {-0.5f});
|
||||
const auto in_high = ngraph::opset4::Constant::create(
|
||||
ngraph::element::Type_t::f32, ngraph::Shape{}, {0.5f});
|
||||
const auto out_low = ngraph::opset4::Constant::create(
|
||||
ngraph::element::Type_t::f32, ngraph::Shape{}, {0.0f});
|
||||
// out_high is a parameter, which means it should not be constant folded
|
||||
const auto out_high =
|
||||
std::make_shared<ngraph::opset4::Parameter>(ngraph::element::Type_t::f32, ngraph::Shape{});
|
||||
const auto fq = std::make_shared<ngraph::opset4::FakeQuantize>(
|
||||
data, in_low, in_high, out_low, out_high, 42);
|
||||
|
||||
// Multiply constant: shape {1, 300, 16}, i.e. wider in the last dimension
// than the FQ data shape {1, 300, 1}
const auto mul_value = ngraph::opset4::Constant::create(
|
||||
ngraph::element::Type_t::f32, ngraph::Shape{1, 300, 16}, {3.14f});
|
||||
// the constant is the first input of Mul and the output of FQ is the second
|
||||
const auto mul = std::make_shared<ngraph::opset4::Multiply>(mul_value, fq);
|
||||
|
||||
// out_high is the only Parameter of the function; mul is its only output
auto function = std::make_shared<ngraph::Function>(
|
||||
ngraph::OutputVector{mul}, ngraph::ParameterVector{out_high});
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::FakeQuantizeMulFusion>();
|
||||
|
||||
manager.run_passes(function);
|
||||
ASSERT_NO_THROW(check_rt_info(function));
|
||||
|
||||
// the output shape after the pass must still be {1, 300, 16}
ASSERT_EQ(function->get_output_shape(0), ngraph::Shape({1, 300, 16}));
|
||||
}
|
||||
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace LayerTestsDefinitions
|
||||
|
Loading…
Reference in New Issue
Block a user