From bab68b65c771e7bfb9a98143736192f57bdc7f57 Mon Sep 17 00:00:00 2001
From: Mateusz Tabaka
Date: Fri, 15 Jan 2021 15:11:45 +0100
Subject: [PATCH] Add Clamp fusion transformation (#3756)

* Add Clamp fusion transformation

It fuses a Maximum->Minimum subgraph into a Clamp operator.

Ticket: 44783

* address review comments

* update year in headers
---
 .../common_optimizations/clamp_fusion.hpp     |  35 ++++++
 .../common_optimizations/clamp_fusion.cpp     |  58 ++++++++++
 .../common_optimizations.cpp                  |   2 +
 .../transformations/clamp_fusion.cpp          | 108 ++++++++++++++++++
 4 files changed, 203 insertions(+)
 create mode 100644 inference-engine/src/transformations/include/transformations/common_optimizations/clamp_fusion.hpp
 create mode 100644 inference-engine/src/transformations/src/transformations/common_optimizations/clamp_fusion.cpp
 create mode 100644 inference-engine/tests/functional/inference_engine/transformations/clamp_fusion.cpp

diff --git a/inference-engine/src/transformations/include/transformations/common_optimizations/clamp_fusion.hpp b/inference-engine/src/transformations/include/transformations/common_optimizations/clamp_fusion.hpp
new file mode 100644
index 00000000000..019ab6d4952
--- /dev/null
+++ b/inference-engine/src/transformations/include/transformations/common_optimizations/clamp_fusion.hpp
@@ -0,0 +1,35 @@
+// Copyright (C) 2021 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <vector>
+#include <memory>
+
+#include <transformations_visibility.hpp>
+
+#include <ngraph/pass/graph_rewrite.hpp>
+
+namespace ngraph {
+namespace pass {
+
+class TRANSFORMATIONS_API ClampFusion;
+
+}  // namespace pass
+}  // namespace ngraph
+
+/**
+ * @ingroup ie_transformation_common_api
+ * @brief ClampFusion transformation replaces the following graph:
+ * Maximum->Minimum with Clamp
+ * Restrictions:
+ * - one of the parameters to Maximum is a scalar constant
+ * - one of the parameters to Minimum is a scalar constant
+ */
+
+class ngraph::pass::ClampFusion: public ngraph::pass::MatcherPass {
+public:
+    NGRAPH_RTTI_DECLARATION;
+    ClampFusion();
+};
diff --git a/inference-engine/src/transformations/src/transformations/common_optimizations/clamp_fusion.cpp b/inference-engine/src/transformations/src/transformations/common_optimizations/clamp_fusion.cpp
new file mode 100644
index 00000000000..00e79f73afd
--- /dev/null
+++ b/inference-engine/src/transformations/src/transformations/common_optimizations/clamp_fusion.cpp
@@ -0,0 +1,58 @@
+// Copyright (C) 2021 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "transformations/common_optimizations/clamp_fusion.hpp"
+#include "transformations/utils/utils.hpp"
+
+#include <memory>
+#include <vector>
+
+#include <ngraph/opsets/opset5.hpp>
+#include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
+
+
+NGRAPH_RTTI_DEFINITION(ngraph::pass::ClampFusion, "ClampFusion", 0);
+
+ngraph::pass::ClampFusion::ClampFusion() {
+    auto data_pattern = ngraph::pattern::any_input();
+    auto min_const_pattern = ngraph::pattern::wrap_type<opset5::Constant>();
+    auto max_const_pattern = ngraph::pattern::wrap_type<opset5::Constant>();
+    auto max_pattern = ngraph::pattern::wrap_type<opset5::Maximum>({data_pattern, min_const_pattern}, pattern::consumers_count(1));
+    auto min_pattern = ngraph::pattern::wrap_type<opset5::Minimum>({max_pattern, max_const_pattern});
+
+    ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
+        auto pattern_map = m.get_pattern_value_map();
+        auto data = pattern_map.at(data_pattern);
+        auto min_const = std::dynamic_pointer_cast<opset5::Constant>(pattern_map.at(min_const_pattern).get_node_shared_ptr());
+        if (!min_const)
+            return false;
+        if (shape_size(min_const->get_shape()) != 1)
+            return false;
+        auto max_const = std::dynamic_pointer_cast<opset5::Constant>(pattern_map.at(max_const_pattern).get_node_shared_ptr());
+        if (!max_const)
+            return false;
+        if (shape_size(max_const->get_shape()) != 1)
+            return false;
+
+        double min_value = min_const->cast_vector<double>()[0];
+        double max_value = max_const->cast_vector<double>()[0];
+
+        auto clamp = std::make_shared<opset5::Clamp>(data, min_value, max_value);
+        auto minimum = pattern_map.at(min_pattern);
+        clamp->set_friendly_name(minimum.get_node()->get_friendly_name());
+
+        copy_runtime_info({
+                            pattern_map.at(max_pattern).get_node_shared_ptr(),
+                            minimum.get_node_shared_ptr()
+                          },
+                          clamp);
+        replace_node(minimum.get_node_shared_ptr(), clamp);
+
+        return true;
+    };
+
+    auto m = std::make_shared<ngraph::pattern::Matcher>(min_pattern, "ClampFusion");
+    this->register_matcher(m, callback);
+}
diff --git a/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp b/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp
index 10f37907c59..b89a9e78689 100644
--- a/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp
+++ b/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp
@@ -26,6 +26,7 @@
 #include "transformations/common_optimizations/hsigmoid_fusion.hpp"
 #include "transformations/common_optimizations/hswish_fusion.hpp"
 #include "transformations/common_optimizations/convert_quantize_dequantize.hpp"
+#include "transformations/common_optimizations/clamp_fusion.hpp"
 #include "transformations/op_conversions/bidirectional_sequences_decomposition.hpp"
 #include "transformations/op_conversions/convert_pad_to_group_conv.hpp"
 #include "transformations/op_conversions/convert_divide.hpp"
@@ -76,6 +77,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::Function> f) {
     manager.register_pass();
     manager.register_pass();
     manager.register_pass();
+    manager.register_pass<ngraph::pass::ClampFusion>();
 
     auto decomp = manager.register_pass<ngraph::pass::GraphRewrite>();
     decomp->add_matcher();
diff --git a/inference-engine/tests/functional/inference_engine/transformations/clamp_fusion.cpp b/inference-engine/tests/functional/inference_engine/transformations/clamp_fusion.cpp
new file mode 100644
index 00000000000..34e11e890e3
--- /dev/null
+++ b/inference-engine/tests/functional/inference_engine/transformations/clamp_fusion.cpp
@@ -0,0 +1,108 @@
+// Copyright (C) 2021 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+
+#include <string>
+#include <memory>
+#include <queue>
+
+#include <ngraph/function.hpp>
+#include <ngraph/opsets/opset5.hpp>
+#include <ngraph/pass/manager.hpp>
+#include <ngraph/pass/constant_folding.hpp>
+#include <transformations/common_optimizations/clamp_fusion.hpp>
+#include <transformations/init_node_info.hpp>
+#include <transformations/utils/utils.hpp>
+
+#include "common_test_utils/ngraph_test_utils.hpp"
+
+
+using namespace testing;
+using namespace ngraph;
+
+
+TEST(TransformationTests, ClampFusion) {
+    std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
+    {
+        auto data = std::make_shared<opset5::Parameter>(element::f32, Shape{2, 2});
+        auto min_const = opset5::Constant::create(element::f32, Shape{1}, {0.1});
+        auto max_const = opset5::Constant::create(element::f32, Shape{1}, {5});
+        auto max = std::make_shared<opset5::Maximum>(data, min_const);
+        auto min = std::make_shared<opset5::Minimum>(max, max_const);
+        f = std::make_shared<Function>(NodeVector{min}, ParameterVector{data});
+
+        pass::Manager m;
+        m.register_pass<pass::InitNodeInfo>();
+        m.register_pass<pass::ClampFusion>();
+        m.run_passes(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+
+    {
+        auto data = std::make_shared<opset5::Parameter>(element::f32, Shape{2, 2});
+        auto clamp = std::make_shared<opset5::Clamp>(data, 0.1, 5);
+        f_ref = std::make_shared<Function>(NodeVector{clamp}, ParameterVector{data});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(TransformationTests, ClampFusionScalars) {
+    std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
+    {
+        auto data = std::make_shared<opset5::Parameter>(element::f32, Shape{2, 2});
+        auto min_const = opset5::Constant::create(element::f32, Shape{}, {0.1});
+        auto max_const = opset5::Constant::create(element::f32, Shape{}, {5});
+        auto max = std::make_shared<opset5::Maximum>(data, min_const);
+        auto min = std::make_shared<opset5::Minimum>(max, max_const);
+        f = std::make_shared<Function>(NodeVector{min}, ParameterVector{data});
+
+        pass::Manager m;
+        m.register_pass<pass::InitNodeInfo>();
+        m.register_pass<pass::ClampFusion>();
+        m.run_passes(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+
+    {
+        auto data = std::make_shared<opset5::Parameter>(element::f32, Shape{2, 2});
+        auto clamp = std::make_shared<opset5::Clamp>(data, 0.1, 5);
+        f_ref = std::make_shared<Function>(NodeVector{clamp}, ParameterVector{data});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(TransformationTests, ClampFusionNonConstMin) {
+    std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
+    {
+        auto data = std::make_shared<opset5::Parameter>(element::f32, Shape{2, 2});
+        auto min_val = std::make_shared<opset5::Parameter>(element::f32, Shape{});
+        auto max_const = opset5::Constant::create(element::f32, Shape{}, {5});
+        auto max = std::make_shared<opset5::Maximum>(data, min_val);
+        auto min = std::make_shared<opset5::Minimum>(max, max_const);
+        f = std::make_shared<Function>(NodeVector{min}, ParameterVector{data, min_val});
+
+        pass::Manager m;
+        m.register_pass<pass::InitNodeInfo>();
+        m.register_pass<pass::ClampFusion>();
+        m.run_passes(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+
+    {
+        auto data = std::make_shared<opset5::Parameter>(element::f32, Shape{2, 2});
+        auto min_val = std::make_shared<opset5::Parameter>(element::f32, Shape{});
+        auto max_const = opset5::Constant::create(element::f32, Shape{}, {5});
+        auto max = std::make_shared<opset5::Maximum>(data, min_val);
+        auto min = std::make_shared<opset5::Minimum>(max, max_const);
+        f_ref = std::make_shared<Function>(NodeVector{min}, ParameterVector{data, min_val});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
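
Usage note: below is a minimal sketch of running the new pass on its own, outside
the CommonOptimizations pipeline, distilled from the tests above. The helper name
build_and_fuse, the tensor shape and the clamp bounds are illustrative only; every
API call used here appears verbatim in the patch.

    #include <memory>
    #include <ngraph/function.hpp>
    #include <ngraph/opsets/opset5.hpp>
    #include <ngraph/pass/manager.hpp>
    #include <transformations/common_optimizations/clamp_fusion.hpp>
    #include <transformations/init_node_info.hpp>

    using namespace ngraph;

    std::shared_ptr<Function> build_and_fuse() {
        // Build the subgraph ClampFusion matches: Maximum(data, lower) -> Minimum(..., upper),
        // where both bounds are scalar constants.
        auto data = std::make_shared<opset5::Parameter>(element::f32, Shape{2, 2});
        auto lower = opset5::Constant::create(element::f32, Shape{}, {0.1});
        auto upper = opset5::Constant::create(element::f32, Shape{}, {5});
        auto max = std::make_shared<opset5::Maximum>(data, lower);
        auto min = std::make_shared<opset5::Minimum>(max, upper);
        auto f = std::make_shared<Function>(NodeVector{min}, ParameterVector{data});

        // Run the matcher pass; the Maximum/Minimum pair is replaced by a single Clamp(data, 0.1, 5).
        pass::Manager manager;
        manager.register_pass<pass::InitNodeInfo>();
        manager.register_pass<pass::ClampFusion>();
        manager.run_passes(f);
        return f;
    }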