Add ReluFakeQuantize transformation (#3811)

* Add ReluFakeQuantize transformation

* address review comments

* replace constant with any_input

* use MATCHER_SCOPE macro
This commit is contained in:
Mateusz Tabaka 2021-01-20 16:50:19 +01:00 committed by GitHub
parent 33005b7741
commit d4488b9dfc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 191 additions and 0 deletions

View File

@ -0,0 +1,31 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API ReluFakeQuantizeFusion;
} // namespace pass
} // namespace ngraph
/**
 * @ingroup ie_transformation_common_api
 * @brief ReluFakeQuantizeFusion transformation replaces the subgraph
 * Relu -> FakeQuantize with a single FakeQuantize that consumes Relu's
 * input directly. The fusion is valid only under the following conditions:
 * - the 'input_low' input to FakeQuantize is a Constant
 * - 'input_low' contains only non-negative values
 * (then the FakeQuantize clamps to the same range the Relu would produce,
 * so the Relu is redundant)
 */
class ngraph::pass::ReluFakeQuantizeFusion: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
ReluFakeQuantizeFusion();
};

View File

@ -26,6 +26,7 @@
#include "transformations/common_optimizations/hsigmoid_fusion.hpp"
#include "transformations/common_optimizations/hswish_fusion.hpp"
#include "transformations/common_optimizations/convert_quantize_dequantize.hpp"
#include "transformations/common_optimizations/relu_fake_quantize_fusion.hpp"
#include "transformations/common_optimizations/clamp_fusion.hpp"
#include "transformations/op_conversions/bidirectional_sequences_decomposition.hpp"
#include "transformations/op_conversions/convert_pad_to_group_conv.hpp"
@ -122,6 +123,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
fq_fusions->add_matcher<ngraph::pass::FakeQuantizeMulFusion>();
fq_fusions->add_matcher<ngraph::pass::FakeQuantizeReshapeFusion>();
fq_fusions->add_matcher<ngraph::pass::PullTransposeThroughFQUp>();
fq_fusions->add_matcher<ngraph::pass::ReluFakeQuantizeFusion>();
fq_fusions->set_name("ngraph::pass::FakeQuantizeFusions");
manager.run_passes(f);

View File

@ -0,0 +1,60 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/relu_fake_quantize_fusion.hpp"
#include "transformations/utils/utils.hpp"
#include "itt.hpp"
#include <memory>
#include <vector>
#include <ngraph/opsets/opset5.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
NGRAPH_RTTI_DEFINITION(ngraph::pass::ReluFakeQuantizeFusion, "ReluFakeQuantizeFusion", 0);

ngraph::pass::ReluFakeQuantizeFusion::ReluFakeQuantizeFusion() {
    MATCHER_SCOPE(ReluFakeQuantizeFusion);
    // Pattern: any input -> Relu (with exactly one consumer) -> FakeQuantize
    // whose 'input_low' input is a Constant; the remaining FQ inputs are free.
    auto data_p = ngraph::pattern::any_input();
    auto relu_p = ngraph::pattern::wrap_type<opset5::Relu>({data_p}, pattern::consumers_count(1));
    auto input_low_p = ngraph::pattern::wrap_type<opset5::Constant>();
    auto fq_p = ngraph::pattern::wrap_type<opset5::FakeQuantize>({relu_p, input_low_p,
                                                                  ngraph::pattern::any_input(),
                                                                  ngraph::pattern::any_input(),
                                                                  ngraph::pattern::any_input()});

    ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
        auto pattern_map = m.get_pattern_value_map();

        // 'input_low' must be a Constant with only non-negative values;
        // otherwise dropping the Relu would change the quantization result.
        auto low_const =
            std::dynamic_pointer_cast<opset5::Constant>(pattern_map.at(input_low_p).get_node_shared_ptr());
        if (!low_const)
            return false;
        const auto low_values = low_const->cast_vector<float>();
        const bool has_negative =
            std::any_of(low_values.begin(), low_values.end(), [](float v) -> bool { return v < 0; });
        if (has_negative)
            return false;

        auto old_fq =
            std::dynamic_pointer_cast<opset5::FakeQuantize>(pattern_map.at(fq_p).get_node_shared_ptr());
        if (!old_fq)
            return false;

        // Rebuild the FakeQuantize on top of the Relu's input, bypassing the Relu.
        auto data = pattern_map.at(data_p);
        auto relu = pattern_map.at(relu_p);
        auto fused_fq = std::make_shared<ngraph::opset5::FakeQuantize>(data,
                                                                       old_fq->input_value(1),
                                                                       old_fq->input_value(2),
                                                                       old_fq->input_value(3),
                                                                       old_fq->input_value(4),
                                                                       old_fq->get_levels());
        fused_fq->set_friendly_name(old_fq->get_friendly_name());
        copy_runtime_info({relu.get_node_shared_ptr(), old_fq}, fused_fq);
        replace_node(old_fq, fused_fq);
        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(fq_p, matcher_name);
    this->register_matcher(m, callback);
}

View File

@ -0,0 +1,98 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <queue>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset5.hpp>
#include <transformations/common_optimizations/relu_fake_quantize_fusion.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
#include <ngraph/pass/manager.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
using namespace ngraph;
TEST(TransformationTests, ReluFakeQuantizeFusion) {
    // Positive case: input_low is a non-negative Constant, so the Relu
    // must be fused away and FakeQuantize fed by the Parameter directly.
    std::shared_ptr<Function> func(nullptr), ref_func(nullptr);
    const Shape data_shape{1, 3, 14, 14};
    {
        auto param = std::make_shared<opset5::Parameter>(element::f32, data_shape);
        auto relu = std::make_shared<opset5::Relu>(param);
        auto in_low = opset5::Constant::create(element::f32, Shape{1, 3, 1, 1}, {0, 0, 0});
        auto in_high = opset5::Constant::create(element::f32, Shape{1, 3, 1, 1}, {20, 20, 20});
        auto out_low = opset5::Constant::create(element::f32, Shape{}, {0});
        auto out_high = opset5::Constant::create(element::f32, Shape{}, {10});
        auto fq = std::make_shared<opset5::FakeQuantize>(relu, in_low, in_high, out_low, out_high, 11);
        func = std::make_shared<Function>(NodeVector{fq}, ParameterVector{param});

        pass::Manager manager;
        manager.register_pass<pass::InitNodeInfo>();
        manager.register_pass<pass::ReluFakeQuantizeFusion>();
        manager.run_passes(func);
        ASSERT_NO_THROW(check_rt_info(func));
    }
    {
        // Reference graph: same FakeQuantize, but no Relu in between.
        auto param = std::make_shared<opset5::Parameter>(element::f32, data_shape);
        auto in_low = opset5::Constant::create(element::f32, Shape{1, 3, 1, 1}, {0, 0, 0});
        auto in_high = opset5::Constant::create(element::f32, Shape{1, 3, 1, 1}, {20, 20, 20});
        auto out_low = opset5::Constant::create(element::f32, Shape{}, {0});
        auto out_high = opset5::Constant::create(element::f32, Shape{}, {10});
        auto fq = std::make_shared<opset5::FakeQuantize>(param, in_low, in_high, out_low, out_high, 11);
        ref_func = std::make_shared<Function>(NodeVector{fq}, ParameterVector{param});
    }

    auto res = compare_functions(func, ref_func);
    ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, ReluFakeQuantizeFusionNegativeInputLow) {
    // Negative case: input_low contains negative values, so the fusion
    // must NOT fire and the graph stays Relu -> FakeQuantize.
    std::shared_ptr<Function> func(nullptr), ref_func(nullptr);
    const Shape data_shape{1, 3, 14, 14};
    {
        auto param = std::make_shared<opset5::Parameter>(element::f32, data_shape);
        auto relu = std::make_shared<opset5::Relu>(param);
        auto in_low = opset5::Constant::create(element::f32, Shape{1, 3, 1, 1}, {2, -2, -2});
        auto in_high = opset5::Constant::create(element::f32, Shape{1, 3, 1, 1}, {20, 20, 20});
        auto out_low = opset5::Constant::create(element::f32, Shape{}, {0});
        auto out_high = opset5::Constant::create(element::f32, Shape{}, {10});
        auto fq = std::make_shared<opset5::FakeQuantize>(relu, in_low, in_high, out_low, out_high, 11);
        func = std::make_shared<Function>(NodeVector{fq}, ParameterVector{param});

        pass::Manager manager;
        manager.register_pass<pass::InitNodeInfo>();
        manager.register_pass<pass::ReluFakeQuantizeFusion>();
        manager.run_passes(func);
        ASSERT_NO_THROW(check_rt_info(func));
    }
    {
        // Reference graph is identical to the input graph (no change expected).
        auto param = std::make_shared<opset5::Parameter>(element::f32, data_shape);
        auto relu = std::make_shared<opset5::Relu>(param);
        auto in_low = opset5::Constant::create(element::f32, Shape{1, 3, 1, 1}, {2, -2, -2});
        auto in_high = opset5::Constant::create(element::f32, Shape{1, 3, 1, 1}, {20, 20, 20});
        auto out_low = opset5::Constant::create(element::f32, Shape{}, {0});
        auto out_high = opset5::Constant::create(element::f32, Shape{}, {10});
        auto fq = std::make_shared<opset5::FakeQuantize>(relu, in_low, in_high, out_low, out_high, 11);
        ref_func = std::make_shared<Function>(NodeVector{fq}, ParameterVector{param});
    }

    auto res = compare_functions(func, ref_func);
    ASSERT_TRUE(res.first) << res.second;
}