Add Clamp fusion transformation (#3756)
* Add Clamp fusion transformation It fuses Maximum->Minimum subgraph to Clamp operator. Ticket: 44783 * address review comments * update year in headers
This commit is contained in:
parent
00c57a3bdf
commit
bab68b65c7
@ -0,0 +1,35 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include <transformations_visibility.hpp>
|
||||||
|
|
||||||
|
#include <ngraph/pass/graph_rewrite.hpp>
|
||||||
|
|
||||||
|
namespace ngraph {
|
||||||
|
namespace pass {
|
||||||
|
|
||||||
|
class TRANSFORMATIONS_API ClampFusion;
|
||||||
|
|
||||||
|
} // namespace pass
|
||||||
|
} // namespace ngraph
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @ingroup ie_transformation_common_api
|
||||||
|
* @brief ClampFusion transformation replaces following graph:
|
||||||
|
* Maximum->Minimum to Clamp
|
||||||
|
* Restrictions:
|
||||||
|
* - one of the parameters to Maximum is a scalar constant
|
||||||
|
* - one of the parameters to Minimum is a scalar constant
|
||||||
|
*/
|
||||||
|
|
||||||
|
class ngraph::pass::ClampFusion: public ngraph::pass::MatcherPass {
|
||||||
|
public:
|
||||||
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
ClampFusion();
|
||||||
|
};
|
@ -0,0 +1,58 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "transformations/common_optimizations/clamp_fusion.hpp"
|
||||||
|
#include "transformations/utils/utils.hpp"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include <ngraph/opsets/opset5.hpp>
|
||||||
|
#include <ngraph/rt_info.hpp>
|
||||||
|
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||||
|
|
||||||
|
|
||||||
|
NGRAPH_RTTI_DEFINITION(ngraph::pass::ClampFusion, "ClampFusion", 0);
|
||||||
|
|
||||||
|
ngraph::pass::ClampFusion::ClampFusion() {
|
||||||
|
auto data_pattern = ngraph::pattern::any_input();
|
||||||
|
auto min_const_pattern = ngraph::pattern::wrap_type<opset5::Constant>();
|
||||||
|
auto max_const_pattern = ngraph::pattern::wrap_type<opset5::Constant>();
|
||||||
|
auto max_pattern = ngraph::pattern::wrap_type<opset5::Maximum>({data_pattern, min_const_pattern}, pattern::consumers_count(1));
|
||||||
|
auto min_pattern = ngraph::pattern::wrap_type<opset5::Minimum>({max_pattern, max_const_pattern});
|
||||||
|
|
||||||
|
ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||||
|
auto pattern_map = m.get_pattern_value_map();
|
||||||
|
auto data = pattern_map.at(data_pattern);
|
||||||
|
auto min_const = std::dynamic_pointer_cast<opset5::Constant>(pattern_map.at(min_const_pattern).get_node_shared_ptr());
|
||||||
|
if (!min_const)
|
||||||
|
return false;
|
||||||
|
if (shape_size(min_const->get_shape()) != 1)
|
||||||
|
return false;
|
||||||
|
auto max_const = std::dynamic_pointer_cast<opset5::Constant>(pattern_map.at(max_const_pattern).get_node_shared_ptr());
|
||||||
|
if (!max_const)
|
||||||
|
return false;
|
||||||
|
if (shape_size(max_const->get_shape()) != 1)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
double min_value = min_const->cast_vector<double>()[0];
|
||||||
|
double max_value = max_const->cast_vector<double>()[0];
|
||||||
|
|
||||||
|
auto clamp = std::make_shared<ngraph::opset5::Clamp>(data, min_value, max_value);
|
||||||
|
auto minimum = pattern_map.at(min_pattern);
|
||||||
|
clamp->set_friendly_name(minimum.get_node()->get_friendly_name());
|
||||||
|
|
||||||
|
copy_runtime_info({
|
||||||
|
pattern_map.at(max_pattern).get_node_shared_ptr(),
|
||||||
|
minimum.get_node_shared_ptr()
|
||||||
|
},
|
||||||
|
clamp);
|
||||||
|
replace_node(minimum.get_node_shared_ptr(), clamp);
|
||||||
|
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto m = std::make_shared<ngraph::pattern::Matcher>(min_pattern, "ClampFusion");
|
||||||
|
this->register_matcher(m, callback);
|
||||||
|
}
|
@ -26,6 +26,7 @@
|
|||||||
#include "transformations/common_optimizations/hsigmoid_fusion.hpp"
|
#include "transformations/common_optimizations/hsigmoid_fusion.hpp"
|
||||||
#include "transformations/common_optimizations/hswish_fusion.hpp"
|
#include "transformations/common_optimizations/hswish_fusion.hpp"
|
||||||
#include "transformations/common_optimizations/convert_quantize_dequantize.hpp"
|
#include "transformations/common_optimizations/convert_quantize_dequantize.hpp"
|
||||||
|
#include "transformations/common_optimizations/clamp_fusion.hpp"
|
||||||
#include "transformations/op_conversions/bidirectional_sequences_decomposition.hpp"
|
#include "transformations/op_conversions/bidirectional_sequences_decomposition.hpp"
|
||||||
#include "transformations/op_conversions/convert_pad_to_group_conv.hpp"
|
#include "transformations/op_conversions/convert_pad_to_group_conv.hpp"
|
||||||
#include "transformations/op_conversions/convert_divide.hpp"
|
#include "transformations/op_conversions/convert_divide.hpp"
|
||||||
@ -76,6 +77,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
|
|||||||
manager.register_pass<ngraph::pass::HSigmoidFusion>();
|
manager.register_pass<ngraph::pass::HSigmoidFusion>();
|
||||||
manager.register_pass<ngraph::pass::ConvertPadToGroupConvolution, false>();
|
manager.register_pass<ngraph::pass::ConvertPadToGroupConvolution, false>();
|
||||||
manager.register_pass<ngraph::pass::NormalizeL2Fusion>();
|
manager.register_pass<ngraph::pass::NormalizeL2Fusion>();
|
||||||
|
manager.register_pass<ngraph::pass::ClampFusion>();
|
||||||
|
|
||||||
auto decomp = manager.register_pass<ngraph::pass::GraphRewrite>();
|
auto decomp = manager.register_pass<ngraph::pass::GraphRewrite>();
|
||||||
decomp->add_matcher<ngraph::pass::BidirectionalLSTMSequenceDecomposition>();
|
decomp->add_matcher<ngraph::pass::BidirectionalLSTMSequenceDecomposition>();
|
||||||
|
@ -0,0 +1,108 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include <queue>
|
||||||
|
|
||||||
|
#include <ngraph/function.hpp>
|
||||||
|
#include <ngraph/opsets/opset5.hpp>
|
||||||
|
#include <transformations/common_optimizations/clamp_fusion.hpp>
|
||||||
|
#include <transformations/init_node_info.hpp>
|
||||||
|
#include <transformations/utils/utils.hpp>
|
||||||
|
#include <ngraph/pass/manager.hpp>
|
||||||
|
#include <ngraph/pass/constant_folding.hpp>
|
||||||
|
|
||||||
|
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||||
|
|
||||||
|
|
||||||
|
using namespace testing;
|
||||||
|
using namespace ngraph;
|
||||||
|
|
||||||
|
|
||||||
|
TEST(TransformationTests, ClampFusion) {
|
||||||
|
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<opset5::Parameter>(element::f32, Shape{2, 2});
|
||||||
|
auto min_const = opset5::Constant::create(element::f32, Shape{1}, {0.1});
|
||||||
|
auto max_const = opset5::Constant::create(element::f32, Shape{1}, {5});
|
||||||
|
auto max = std::make_shared<opset5::Maximum>(data, min_const);
|
||||||
|
auto min = std::make_shared<opset5::Minimum>(max, max_const);
|
||||||
|
f = std::make_shared<Function>(NodeVector{min}, ParameterVector{data});
|
||||||
|
|
||||||
|
pass::Manager m;
|
||||||
|
m.register_pass<pass::InitNodeInfo>();
|
||||||
|
m.register_pass<pass::ClampFusion>();
|
||||||
|
m.run_passes(f);
|
||||||
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<opset1::Parameter>(element::f32, Shape{2, 2});
|
||||||
|
auto clamp = std::make_shared<opset5::Clamp>(data, 0.1, 5);
|
||||||
|
f_ref = std::make_shared<Function>(NodeVector{clamp}, ParameterVector{data});
|
||||||
|
}
|
||||||
|
|
||||||
|
auto res = compare_functions(f, f_ref);
|
||||||
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TransformationTests, ClampFusionScalars) {
|
||||||
|
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<opset5::Parameter>(element::f32, Shape{2, 2});
|
||||||
|
auto min_const = opset5::Constant::create(element::f32, Shape{}, {0.1});
|
||||||
|
auto max_const = opset5::Constant::create(element::f32, Shape{}, {5});
|
||||||
|
auto max = std::make_shared<opset5::Maximum>(data, min_const);
|
||||||
|
auto min = std::make_shared<opset5::Minimum>(max, max_const);
|
||||||
|
f = std::make_shared<Function>(NodeVector{min}, ParameterVector{data});
|
||||||
|
|
||||||
|
pass::Manager m;
|
||||||
|
m.register_pass<pass::InitNodeInfo>();
|
||||||
|
m.register_pass<pass::ClampFusion>();
|
||||||
|
m.run_passes(f);
|
||||||
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<opset1::Parameter>(element::f32, Shape{2, 2});
|
||||||
|
auto clamp = std::make_shared<opset5::Clamp>(data, 0.1, 5);
|
||||||
|
f_ref = std::make_shared<Function>(NodeVector{clamp}, ParameterVector{data});
|
||||||
|
}
|
||||||
|
|
||||||
|
auto res = compare_functions(f, f_ref);
|
||||||
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TransformationTests, ClampFusionNonConstMin) {
|
||||||
|
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<opset5::Parameter>(element::f32, Shape{2, 2});
|
||||||
|
auto min_val = std::make_shared<opset5::Parameter>(element::f32, Shape{});
|
||||||
|
auto max_const = opset5::Constant::create(element::f32, Shape{}, {5});
|
||||||
|
auto max = std::make_shared<opset5::Maximum>(data, min_val);
|
||||||
|
auto min = std::make_shared<opset5::Minimum>(max, max_const);
|
||||||
|
f = std::make_shared<Function>(NodeVector{min}, ParameterVector{data, min_val});
|
||||||
|
|
||||||
|
pass::Manager m;
|
||||||
|
m.register_pass<pass::InitNodeInfo>();
|
||||||
|
m.register_pass<pass::ClampFusion>();
|
||||||
|
m.run_passes(f);
|
||||||
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<opset5::Parameter>(element::f32, Shape{2, 2});
|
||||||
|
auto min_val = std::make_shared<opset5::Parameter>(element::f32, Shape{});
|
||||||
|
auto max_const = opset5::Constant::create(element::f32, Shape{}, {5});
|
||||||
|
auto max = std::make_shared<opset5::Maximum>(data, min_val);
|
||||||
|
auto min = std::make_shared<opset5::Minimum>(max, max_const);
|
||||||
|
f_ref = std::make_shared<Function>(NodeVector{min}, ParameterVector{data, min_val});
|
||||||
|
}
|
||||||
|
|
||||||
|
auto res = compare_functions(f, f_ref);
|
||||||
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user