Add ReluFakeQuantize transformation (#3811)
* Add ReluFakeQuantize transformation
* address review comments
* replace constant with any_input
* use MATCHER_SCOPE macro
parent 33005b7741
commit d4488b9dfc
@ -0,0 +1,31 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <transformations_visibility.hpp>

#include <ngraph/pass/graph_rewrite.hpp>

namespace ngraph {
namespace pass {

class TRANSFORMATIONS_API ReluFakeQuantizeFusion;

} // namespace pass
} // namespace ngraph

/**
 * @ingroup ie_transformation_common_api
 * @brief ReluFakeQuantizeFusion transformation replaces the subgraph
 * Relu -> FakeQuantize with a single FakeQuantize, under the following conditions:
 * - the 'input_low' input of FakeQuantize is a Constant
 * - 'input_low' has only non-negative values
 */

class ngraph::pass::ReluFakeQuantizeFusion: public ngraph::pass::MatcherPass {
public:
    NGRAPH_RTTI_DECLARATION;
    ReluFakeQuantizeFusion();
};
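For intuition, here is a minimal numeric sketch (not part of the commit) of why the Relu becomes redundant under these conditions: when 'input_low' is non-negative, FakeQuantize already maps every value below 'input_low' to 'output_low', so clamping negatives to zero beforehand cannot change the result. The reference formula below follows the FakeQuantize operation spec; the helper 'fake_quantize' and the chosen parameters are illustrative, not library code.

// Sketch only: hand-written FakeQuantize reference, assuming the documented
// per-element semantics of the operation.
#include <algorithm>
#include <cassert>
#include <cmath>

static float fake_quantize(float x, float in_lo, float in_hi,
                           float out_lo, float out_hi, int levels) {
    if (x <= std::min(in_lo, in_hi))
        return out_lo;                      // everything below in_lo collapses to out_lo
    if (x > std::max(in_lo, in_hi))
        return out_hi;
    const float q = std::round((x - in_lo) / (in_hi - in_lo) * (levels - 1));
    return q / (levels - 1) * (out_hi - out_lo) + out_lo;
}

int main() {
    const float in_lo = 0.f, in_hi = 20.f, out_lo = 0.f, out_hi = 10.f;
    const int levels = 11;
    for (float x = -5.f; x <= 25.f; x += 0.5f) {
        const float relu_x = std::max(x, 0.f);  // what a preceding Relu would produce
        // With in_lo >= 0, negative inputs already land in the out_lo branch,
        // so Relu -> FakeQuantize equals FakeQuantize alone.
        assert(fake_quantize(relu_x, in_lo, in_hi, out_lo, out_hi, levels) ==
               fake_quantize(x,      in_lo, in_hi, out_lo, out_hi, levels));
    }
    return 0;
}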
@ -26,6 +26,7 @@
#include "transformations/common_optimizations/hsigmoid_fusion.hpp"
#include "transformations/common_optimizations/hswish_fusion.hpp"
#include "transformations/common_optimizations/convert_quantize_dequantize.hpp"
#include "transformations/common_optimizations/relu_fake_quantize_fusion.hpp"
#include "transformations/common_optimizations/clamp_fusion.hpp"
#include "transformations/op_conversions/bidirectional_sequences_decomposition.hpp"
#include "transformations/op_conversions/convert_pad_to_group_conv.hpp"
@ -122,6 +123,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
    fq_fusions->add_matcher<ngraph::pass::FakeQuantizeMulFusion>();
    fq_fusions->add_matcher<ngraph::pass::FakeQuantizeReshapeFusion>();
    fq_fusions->add_matcher<ngraph::pass::PullTransposeThroughFQUp>();
    fq_fusions->add_matcher<ngraph::pass::ReluFakeQuantizeFusion>();
    fq_fusions->set_name("ngraph::pass::FakeQuantizeFusions");

    manager.run_passes(f);
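For context, the new matcher joins the shared "fq_fusions" GraphRewrite above, so all FakeQuantize-related fusions are applied within a single graph traversal. Below is a minimal sketch (not part of the commit) of wiring the pass into a pass::Manager the same way; the helper names 'register_fq_fusions' and 'run' are illustrative.

// Sketch only: registering ReluFakeQuantizeFusion via a GraphRewrite,
// mirroring the CommonOptimizations hunk above.
#include <memory>

#include <ngraph/function.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/common_optimizations/relu_fake_quantize_fusion.hpp>

void register_fq_fusions(ngraph::pass::Manager& manager) {
    auto fq_fusions = manager.register_pass<ngraph::pass::GraphRewrite>();
    fq_fusions->add_matcher<ngraph::pass::ReluFakeQuantizeFusion>();
    fq_fusions->set_name("ngraph::pass::FakeQuantizeFusions");
}

void run(const std::shared_ptr<ngraph::Function>& f) {
    ngraph::pass::Manager manager;
    register_fq_fusions(manager);
    manager.run_passes(f);  // rewrites matching Relu -> FakeQuantize subgraphs in place
}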
@ -0,0 +1,60 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/common_optimizations/relu_fake_quantize_fusion.hpp"
#include "transformations/utils/utils.hpp"
#include "itt.hpp"

#include <memory>
#include <vector>

#include <ngraph/opsets/opset5.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>

NGRAPH_RTTI_DEFINITION(ngraph::pass::ReluFakeQuantizeFusion, "ReluFakeQuantizeFusion", 0);

ngraph::pass::ReluFakeQuantizeFusion::ReluFakeQuantizeFusion() {
    MATCHER_SCOPE(ReluFakeQuantizeFusion);
    auto data_pattern = ngraph::pattern::any_input();
    auto relu_pattern = ngraph::pattern::wrap_type<opset5::Relu>({data_pattern}, pattern::consumers_count(1));
    auto input_low_pattern = ngraph::pattern::wrap_type<opset5::Constant>();
    auto fq_pattern = ngraph::pattern::wrap_type<opset5::FakeQuantize>({relu_pattern, input_low_pattern,
                                                                        ngraph::pattern::any_input(),
                                                                        ngraph::pattern::any_input(),
                                                                        ngraph::pattern::any_input()});

    ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
        auto pattern_map = m.get_pattern_value_map();
        auto data = pattern_map[data_pattern];
        auto relu = pattern_map[relu_pattern];
        auto input_low = pattern_map[input_low_pattern];
        auto input_low_const = std::dynamic_pointer_cast<opset5::Constant>(input_low.get_node_shared_ptr());
        if (!input_low_const)
            return false;
        auto input_low_values = input_low_const->cast_vector<float>();
        if (std::any_of(input_low_values.begin(), input_low_values.end(), [] (float f) -> bool { return f < 0; }))
            return false;
        auto fq = std::dynamic_pointer_cast<opset5::FakeQuantize>(pattern_map[fq_pattern].get_node_shared_ptr());
        if (!fq)
            return false;

        auto new_fq = std::make_shared<ngraph::opset5::FakeQuantize>(data,
                                                                     fq->input_value(1),
                                                                     fq->input_value(2),
                                                                     fq->input_value(3),
                                                                     fq->input_value(4),
                                                                     fq->get_levels());
        new_fq->set_friendly_name(fq->get_friendly_name());

        copy_runtime_info({relu.get_node_shared_ptr(), fq}, new_fq);
        replace_node(fq, new_fq);

        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(fq_pattern, matcher_name);
    this->register_matcher(m, callback);
}
@ -0,0 +1,98 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <gtest/gtest.h>

#include <string>
#include <memory>
#include <queue>

#include <ngraph/function.hpp>
#include <ngraph/opsets/opset5.hpp>
#include <transformations/common_optimizations/relu_fake_quantize_fusion.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
#include <ngraph/pass/manager.hpp>

#include "common_test_utils/ngraph_test_utils.hpp"

using namespace testing;
using namespace ngraph;

TEST(TransformationTests, ReluFakeQuantizeFusion) {
    std::shared_ptr<Function> f(nullptr), f_ref(nullptr);

    Shape data_shape{1, 3, 14, 14};
    {
        auto data = std::make_shared<opset5::Parameter>(element::f32, data_shape);
        auto relu = std::make_shared<opset5::Relu>(data);
        auto input_low = opset5::Constant::create(element::f32, Shape{1, 3, 1, 1}, {0, 0, 0});
        auto input_high = opset5::Constant::create(element::f32, Shape{1, 3, 1, 1}, {20, 20, 20});
        auto output_low = opset5::Constant::create(element::f32, Shape{}, {0});
        auto output_high = opset5::Constant::create(element::f32, Shape{}, {10});
        auto fq = std::make_shared<opset5::FakeQuantize>(relu, input_low,
                                                         input_high, output_low,
                                                         output_high, 11);
        f = std::make_shared<Function>(NodeVector{fq}, ParameterVector{data});
        pass::Manager m;
        m.register_pass<pass::InitNodeInfo>();
        m.register_pass<pass::ReluFakeQuantizeFusion>();
        m.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }
    {
        auto data = std::make_shared<opset5::Parameter>(element::f32, data_shape);
        auto input_low = opset5::Constant::create(element::f32, Shape{1, 3, 1, 1}, {0, 0, 0});
        auto input_high = opset5::Constant::create(element::f32, Shape{1, 3, 1, 1}, {20, 20, 20});
        auto output_low = opset5::Constant::create(element::f32, Shape{}, {0});
        auto output_high = opset5::Constant::create(element::f32, Shape{}, {10});
        auto fq = std::make_shared<opset5::FakeQuantize>(data, input_low,
                                                         input_high, output_low,
                                                         output_high, 11);
        f_ref = std::make_shared<Function>(NodeVector{fq}, ParameterVector{data});
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}

TEST(TransformationTests, ReluFakeQuantizeFusionNegativeInputLow) {
    std::shared_ptr<Function> f(nullptr), f_ref(nullptr);

    Shape data_shape{1, 3, 14, 14};
    {
        auto data = std::make_shared<opset5::Parameter>(element::f32, data_shape);
        auto relu = std::make_shared<opset5::Relu>(data);
        auto input_low = opset5::Constant::create(element::f32, Shape{1, 3, 1, 1}, {2, -2, -2});
        auto input_high = opset5::Constant::create(element::f32, Shape{1, 3, 1, 1}, {20, 20, 20});
        auto output_low = opset5::Constant::create(element::f32, Shape{}, {0});
        auto output_high = opset5::Constant::create(element::f32, Shape{}, {10});
        auto fq = std::make_shared<opset5::FakeQuantize>(relu, input_low,
                                                         input_high, output_low,
                                                         output_high, 11);
        f = std::make_shared<Function>(NodeVector{fq}, ParameterVector{data});
        pass::Manager m;
        m.register_pass<pass::InitNodeInfo>();
        m.register_pass<pass::ReluFakeQuantizeFusion>();
        m.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }
    {
        auto data = std::make_shared<opset5::Parameter>(element::f32, data_shape);
        auto relu = std::make_shared<opset5::Relu>(data);
        auto input_low = opset5::Constant::create(element::f32, Shape{1, 3, 1, 1}, {2, -2, -2});
        auto input_high = opset5::Constant::create(element::f32, Shape{1, 3, 1, 1}, {20, 20, 20});
        auto output_low = opset5::Constant::create(element::f32, Shape{}, {0});
        auto output_high = opset5::Constant::create(element::f32, Shape{}, {10});
        auto fq = std::make_shared<opset5::FakeQuantize>(relu, input_low,
                                                         input_high, output_low,
                                                         output_high, 11);
        f_ref = std::make_shared<Function>(NodeVector{fq}, ParameterVector{data});
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}