[CPU] Added H-Swish activation (#1445)

This commit is contained in:
Alexandra Sidorova 2020-08-25 10:19:06 +03:00 committed by GitHub
parent 393e9295cd
commit a2f0eef6aa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 135 additions and 7 deletions

View File

@ -705,7 +705,7 @@ void MKLDNNGraphOptimizer::FuseConvolutionAndActivation(MKLDNNGraph &graph) {
(activationNode->getAlgorithm() == eltwise_relu ||
(conv->getCnnLayer()->precision == Precision::FP32 &&
isOneOf(activationNode->getAlgorithm(), {eltwise_elu, eltwise_logistic, eltwise_bounded_relu, eltwise_clamp,
eltwise_swish, eltwise_mish})));
eltwise_swish, eltwise_hswish, eltwise_mish})));
};
for (int i = 0; i < graphNodes.size(); i++) {
@ -1188,7 +1188,7 @@ void MKLDNNGraphOptimizer::FuseConvolutionAndSimpleOperation(MKLDNNGraph &graph)
THROW_IE_EXCEPTION << "Cannot get activation layer " << node->getName();
return isOneOf(activationNode->getAlgorithm(), {eltwise_relu, eltwise_elu, eltwise_logistic, eltwise_bounded_relu,
eltwise_clamp, eltwise_swish, eltwise_mish});
eltwise_clamp, eltwise_swish, eltwise_hswish, eltwise_mish});
}
return false;
@ -1433,7 +1433,7 @@ void MKLDNNGraphOptimizer::FuseConvolutionSumAndConvolutionSumActivation(MKLDNNG
(activationNode->getAlgorithm() == eltwise_relu ||
(conv->getCnnLayer()->precision == Precision::FP32 &&
isOneOf(activationNode->getAlgorithm(), {eltwise_elu, eltwise_logistic, eltwise_bounded_relu, eltwise_clamp,
eltwise_swish, eltwise_mish})));
eltwise_swish, eltwise_hswish, eltwise_mish})));
#else
return false;
#endif
@ -1783,8 +1783,8 @@ void MKLDNNGraphOptimizer::FuseNormalizeAndSimpleOperation(MKLDNNGraph &graph) {
if (activationNode == nullptr)
THROW_IE_EXCEPTION << "Cannot get activation layer " << node->getName();
return isOneOf(activationNode->getAlgorithm(), {eltwise_relu, eltwise_gelu, eltwise_elu, eltwise_logistic,
eltwise_bounded_relu, eltwise_clamp, eltwise_tanh, eltwise_swish, eltwise_mish, eltwise_linear, eltwise_abs,
eltwise_square, eltwise_sqrt});
eltwise_bounded_relu, eltwise_clamp, eltwise_tanh, eltwise_swish, eltwise_hswish, eltwise_mish, eltwise_linear,
eltwise_abs, eltwise_square, eltwise_sqrt});
}
return false;
};
@ -1895,7 +1895,7 @@ void MKLDNNGraphOptimizer::FuseEltwiseAndSimple(MKLDNNGraph &graph) {
if (activationNode == nullptr)
THROW_IE_EXCEPTION << "Cannot get activation layer " << node->getName();
return isOneOf(activationNode->getAlgorithm(), {eltwise_relu, eltwise_elu, eltwise_logistic, eltwise_bounded_relu,
eltwise_clamp, eltwise_swish, eltwise_mish});
eltwise_clamp, eltwise_swish, eltwise_hswish, eltwise_mish});
}
return false;

View File

@ -75,6 +75,7 @@ static const InferenceEngine::details::caseless_unordered_map<std::string, Type>
{ "Activation", Activation },
{ "Clamp", Activation },
{ "Swish", Activation },
{ "HSwish", Activation },
{ "Mish", Activation },
{ "ScaleShift", Depthwise },
{ "PReLU", Depthwise },

View File

@ -81,6 +81,7 @@ static void Transformation(ICNNNetwork::Ptr& clonedNetwork) {
std::dynamic_pointer_cast<const ngraph::opset2::BatchToSpace>(node) ||
std::dynamic_pointer_cast<const ngraph::opset2::SpaceToBatch>(node) ||
std::dynamic_pointer_cast<const ngraph::opset3::ExtractImagePatches>(node) ||
std::dynamic_pointer_cast<const ngraph::opset4::HSwish>(node) ||
std::dynamic_pointer_cast<const ngraph::opset4::ReduceL1>(node) ||
std::dynamic_pointer_cast<const ngraph::opset4::ReduceL2>(node) ||
std::dynamic_pointer_cast<const ngraph::opset4::Pad>(node);

View File

@ -96,6 +96,11 @@ caseless_map<std::string, std::function<void(GenericLayer*, mkldnn::algorithm&,
beta = 0.0f;
algorithm = eltwise_swish;
}},
{"hswish", [](GenericLayer* activationLayer, mkldnn::algorithm& algorithm, float& alpha, float& beta) {
alpha = 0.0f;
beta = 0.0f;
algorithm = eltwise_hswish;
}},
{"mish", [](GenericLayer* activationLayer, mkldnn::algorithm& algorithm, float& alpha, float& beta) {
alpha = 0.0f;
beta = 0.0f;

View File

@ -17,6 +17,7 @@ class TRANSFORMATIONS_API HSwishFusion;
class TRANSFORMATIONS_API HSwishFusionWithReluDiv;
class TRANSFORMATIONS_API HSwishFusionWithReluMul;
class TRANSFORMATIONS_API HSwishFusionWithoutRelu;
class TRANSFORMATIONS_API HSwishFusionWithClamp;
} // namespace pass
@ -32,6 +33,7 @@ public:
add_matcher<ngraph::pass::HSwishFusionWithReluDiv>();
add_matcher<ngraph::pass::HSwishFusionWithReluMul>();
add_matcher<ngraph::pass::HSwishFusionWithoutRelu>();
add_matcher<ngraph::pass::HSwishFusionWithClamp>();
}
};
@ -61,3 +63,12 @@ public:
public:
HSwishFusionWithoutRelu();
};
/**
* @ingroup ie_transformation_common_api
* @brief HSwishFusion transformation replaces a sub-graph x * (Clamp(x + 3, 0, 6) * const(1/6)) with a HSwish op.
*/
class ngraph::pass::HSwishFusionWithClamp: public ngraph::pass::MatcherPass {
public:
HSwishFusionWithClamp();
};

View File

@ -178,3 +178,46 @@ ngraph::pass::HSwishFusionWithoutRelu::HSwishFusionWithoutRelu() {
auto m = std::make_shared<ngraph::pattern::Matcher>(mul, "HSwishWithoutReluFusion");
register_matcher(m, callback);
}
ngraph::pass::HSwishFusionWithClamp::HSwishFusionWithClamp() {
// Replaces a sub-graph x * (Clamp(x + 3, 0, 6) * const(1/6)) with a HSwish op.
auto input = ngraph::pattern::any_input();
auto add_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
auto clamp = std::make_shared<ngraph::op::v0::Clamp>(add, 0.0f, 6.0f);
auto mul_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto mul_first = std::make_shared<ngraph::opset4::Multiply>(clamp, mul_constant);
auto mul_second = std::make_shared<ngraph::opset4::Multiply>(input, mul_first);
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
auto &pattern_to_output = m.get_pattern_value_map();
auto x_output = pattern_to_output.at(input);
auto add_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
auto mul_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(mul_constant).get_node_shared_ptr());
bool valid_constant_values = check_constant_value(add_const_value, 3.0)
&& check_constant_value(mul_const_value, (1.0/6.0), 0.0001);
if (!valid_constant_values) {
return false;
}
auto hswish = std::make_shared<ngraph::opset4::HSwish>(x_output);
hswish->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::copy_runtime_info({ pattern_to_output.at(add_constant).get_node_shared_ptr(),
pattern_to_output.at(add).get_node_shared_ptr(),
pattern_to_output.at(clamp).get_node_shared_ptr(),
pattern_to_output.at(mul_constant).get_node_shared_ptr(),
pattern_to_output.at(mul_first).get_node_shared_ptr(),
pattern_to_output.at(mul_second).get_node_shared_ptr()
},
hswish);
ngraph::replace_node(m.get_match_root(), hswish);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(mul_second, "HSwishWithClampFusion");
register_matcher(m, callback);
}

View File

@ -151,6 +151,37 @@ TEST(TransformationTests, HSwishFusionWithoutRelu) {
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, HSwishFusionWithClamp) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.0});
auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
auto clamp = std::make_shared<ngraph::op::v0::Clamp>(add, 0.0f, 6.0f);
auto mul_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {1.0 / 6.0});
auto mul_first = std::make_shared<ngraph::opset4::Multiply>(clamp, mul_constant);
auto mul_second = std::make_shared<ngraph::opset4::Multiply>(input, mul_first);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul_second}, ngraph::ParameterVector{input});
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::HSwishFusionWithClamp>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
auto hswish = std::make_shared<ngraph::opset4::HSwish>(input);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{hswish}, ngraph::ParameterVector{input});
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, HSwishFusionWithReluMulWrongConstValue) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
@ -272,3 +303,39 @@ TEST(TransformationTests, HSwishFusionWithoutReluWrongConstValue) {
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, HSwishFusionWithClampWrongConstValue) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.11});
auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
auto clamp = std::make_shared<ngraph::op::v0::Clamp>(add, 0.11f, 6.02f);
auto mul_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.98 / 6.15});
auto mul_first = std::make_shared<ngraph::opset4::Multiply>(clamp, mul_constant);
auto mul_second = std::make_shared<ngraph::opset4::Multiply>(input, mul_first);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul_second}, ngraph::ParameterVector{input});
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::HSwishFusionWithoutRelu>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.11});
auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
auto clamp = std::make_shared<ngraph::op::v0::Clamp>(add, 0.11f, 6.02f);
auto mul_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.98 / 6.15});
auto mul_first = std::make_shared<ngraph::opset4::Multiply>(clamp, mul_constant);
auto mul_second = std::make_shared<ngraph::opset4::Multiply>(input, mul_first);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul_second}, ngraph::ParameterVector{input});
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}

@ -1 +1 @@
Subproject commit 1f967a094353b30d65d96a3fe1721d8dccf02278
Subproject commit eb54063189a33a10c4aa90311788e6fbb4cdf2f6