[CPU] Added H-Swish activation (#1445)
This commit is contained in:
parent
393e9295cd
commit
a2f0eef6aa
@ -705,7 +705,7 @@ void MKLDNNGraphOptimizer::FuseConvolutionAndActivation(MKLDNNGraph &graph) {
|
||||
(activationNode->getAlgorithm() == eltwise_relu ||
|
||||
(conv->getCnnLayer()->precision == Precision::FP32 &&
|
||||
isOneOf(activationNode->getAlgorithm(), {eltwise_elu, eltwise_logistic, eltwise_bounded_relu, eltwise_clamp,
|
||||
eltwise_swish, eltwise_mish})));
|
||||
eltwise_swish, eltwise_hswish, eltwise_mish})));
|
||||
};
|
||||
|
||||
for (int i = 0; i < graphNodes.size(); i++) {
|
||||
@ -1188,7 +1188,7 @@ void MKLDNNGraphOptimizer::FuseConvolutionAndSimpleOperation(MKLDNNGraph &graph)
|
||||
THROW_IE_EXCEPTION << "Cannot get activation layer " << node->getName();
|
||||
|
||||
return isOneOf(activationNode->getAlgorithm(), {eltwise_relu, eltwise_elu, eltwise_logistic, eltwise_bounded_relu,
|
||||
eltwise_clamp, eltwise_swish, eltwise_mish});
|
||||
eltwise_clamp, eltwise_swish, eltwise_hswish, eltwise_mish});
|
||||
}
|
||||
|
||||
return false;
|
||||
@ -1433,7 +1433,7 @@ void MKLDNNGraphOptimizer::FuseConvolutionSumAndConvolutionSumActivation(MKLDNNG
|
||||
(activationNode->getAlgorithm() == eltwise_relu ||
|
||||
(conv->getCnnLayer()->precision == Precision::FP32 &&
|
||||
isOneOf(activationNode->getAlgorithm(), {eltwise_elu, eltwise_logistic, eltwise_bounded_relu, eltwise_clamp,
|
||||
eltwise_swish, eltwise_mish})));
|
||||
eltwise_swish, eltwise_hswish, eltwise_mish})));
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
@ -1783,8 +1783,8 @@ void MKLDNNGraphOptimizer::FuseNormalizeAndSimpleOperation(MKLDNNGraph &graph) {
|
||||
if (activationNode == nullptr)
|
||||
THROW_IE_EXCEPTION << "Cannot get activation layer " << node->getName();
|
||||
return isOneOf(activationNode->getAlgorithm(), {eltwise_relu, eltwise_gelu, eltwise_elu, eltwise_logistic,
|
||||
eltwise_bounded_relu, eltwise_clamp, eltwise_tanh, eltwise_swish, eltwise_mish, eltwise_linear, eltwise_abs,
|
||||
eltwise_square, eltwise_sqrt});
|
||||
eltwise_bounded_relu, eltwise_clamp, eltwise_tanh, eltwise_swish, eltwise_hswish, eltwise_mish, eltwise_linear,
|
||||
eltwise_abs, eltwise_square, eltwise_sqrt});
|
||||
}
|
||||
return false;
|
||||
};
|
||||
@ -1895,7 +1895,7 @@ void MKLDNNGraphOptimizer::FuseEltwiseAndSimple(MKLDNNGraph &graph) {
|
||||
if (activationNode == nullptr)
|
||||
THROW_IE_EXCEPTION << "Cannot get activation layer " << node->getName();
|
||||
return isOneOf(activationNode->getAlgorithm(), {eltwise_relu, eltwise_elu, eltwise_logistic, eltwise_bounded_relu,
|
||||
eltwise_clamp, eltwise_swish, eltwise_mish});
|
||||
eltwise_clamp, eltwise_swish, eltwise_hswish, eltwise_mish});
|
||||
}
|
||||
|
||||
return false;
|
||||
|
@ -75,6 +75,7 @@ static const InferenceEngine::details::caseless_unordered_map<std::string, Type>
|
||||
{ "Activation", Activation },
|
||||
{ "Clamp", Activation },
|
||||
{ "Swish", Activation },
|
||||
{ "HSwish", Activation },
|
||||
{ "Mish", Activation },
|
||||
{ "ScaleShift", Depthwise },
|
||||
{ "PReLU", Depthwise },
|
||||
|
@ -81,6 +81,7 @@ static void Transformation(ICNNNetwork::Ptr& clonedNetwork) {
|
||||
std::dynamic_pointer_cast<const ngraph::opset2::BatchToSpace>(node) ||
|
||||
std::dynamic_pointer_cast<const ngraph::opset2::SpaceToBatch>(node) ||
|
||||
std::dynamic_pointer_cast<const ngraph::opset3::ExtractImagePatches>(node) ||
|
||||
std::dynamic_pointer_cast<const ngraph::opset4::HSwish>(node) ||
|
||||
std::dynamic_pointer_cast<const ngraph::opset4::ReduceL1>(node) ||
|
||||
std::dynamic_pointer_cast<const ngraph::opset4::ReduceL2>(node) ||
|
||||
std::dynamic_pointer_cast<const ngraph::opset4::Pad>(node);
|
||||
|
@ -96,6 +96,11 @@ caseless_map<std::string, std::function<void(GenericLayer*, mkldnn::algorithm&,
|
||||
beta = 0.0f;
|
||||
algorithm = eltwise_swish;
|
||||
}},
|
||||
{"hswish", [](GenericLayer* activationLayer, mkldnn::algorithm& algorithm, float& alpha, float& beta) {
|
||||
alpha = 0.0f;
|
||||
beta = 0.0f;
|
||||
algorithm = eltwise_hswish;
|
||||
}},
|
||||
{"mish", [](GenericLayer* activationLayer, mkldnn::algorithm& algorithm, float& alpha, float& beta) {
|
||||
alpha = 0.0f;
|
||||
beta = 0.0f;
|
||||
|
@ -17,6 +17,7 @@ class TRANSFORMATIONS_API HSwishFusion;
|
||||
class TRANSFORMATIONS_API HSwishFusionWithReluDiv;
|
||||
class TRANSFORMATIONS_API HSwishFusionWithReluMul;
|
||||
class TRANSFORMATIONS_API HSwishFusionWithoutRelu;
|
||||
class TRANSFORMATIONS_API HSwishFusionWithClamp;
|
||||
|
||||
|
||||
} // namespace pass
|
||||
@ -32,6 +33,7 @@ public:
|
||||
add_matcher<ngraph::pass::HSwishFusionWithReluDiv>();
|
||||
add_matcher<ngraph::pass::HSwishFusionWithReluMul>();
|
||||
add_matcher<ngraph::pass::HSwishFusionWithoutRelu>();
|
||||
add_matcher<ngraph::pass::HSwishFusionWithClamp>();
|
||||
}
|
||||
};
|
||||
|
||||
@ -61,3 +63,12 @@ public:
|
||||
public:
|
||||
HSwishFusionWithoutRelu();
|
||||
};
|
||||
|
||||
/**
 * @ingroup ie_transformation_common_api
 * @brief HSwishFusionWithClamp transformation replaces the sub-graph
 *        x * (Clamp(x + 3, 0, 6) * const(1/6)) with a single HSwish op.
 */
|
||||
class ngraph::pass::HSwishFusionWithClamp: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
// Builds the sub-graph pattern and registers the fusion matcher with its callback.
HSwishFusionWithClamp();
|
||||
};
|
||||
|
@ -178,3 +178,46 @@ ngraph::pass::HSwishFusionWithoutRelu::HSwishFusionWithoutRelu() {
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(mul, "HSwishWithoutReluFusion");
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
|
||||
/**
 * Registers a matcher that replaces the sub-graph
 *     x * (Clamp(x + 3, 0, 6) * const(1/6))
 * with a single HSwish op.
 *
 * The callback rejects the match (returns false, graph untouched) unless:
 *   - the Add constant equals 3,
 *   - the Multiply constant equals 1/6 (compared with a 1e-4 tolerance),
 *   - the matched Clamp has bounds exactly [0, 6].
 */
ngraph::pass::HSwishFusionWithClamp::HSwishFusionWithClamp() {
    auto input = ngraph::pattern::any_input();
    auto add_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
    // The bounds passed here are only needed to construct the pattern node; the bounds of the
    // Clamp that was actually matched are validated explicitly inside the callback.
    auto clamp = std::make_shared<ngraph::op::v0::Clamp>(add, 0.0f, 6.0f);
    auto mul_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto mul_first = std::make_shared<ngraph::opset4::Multiply>(clamp, mul_constant);
    auto mul_second = std::make_shared<ngraph::opset4::Multiply>(input, mul_first);

    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
        auto &pattern_to_output = m.get_pattern_value_map();
        auto x_output = pattern_to_output.at(input);

        auto add_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
        auto mul_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(mul_constant).get_node_shared_ptr());

        // Defensive: never hand a null constant to the value checks below.
        if (!add_const_value || !mul_const_value) {
            return false;
        }

        bool valid_constant_values = check_constant_value(add_const_value, 3.0)
                                     && check_constant_value(mul_const_value, (1.0/6.0), 0.0001);

        if (!valid_constant_values) {
            return false;
        }

        // HSwish is only equivalent to the matched sub-graph when the Clamp range is exactly [0, 6];
        // a Clamp with any other bounds must not be fused.
        auto clamp_node = std::dynamic_pointer_cast<ngraph::op::v0::Clamp>(pattern_to_output.at(clamp).get_node_shared_ptr());
        if (!clamp_node || clamp_node->get_min() != 0.0 || clamp_node->get_max() != 6.0) {
            return false;
        }

        auto hswish = std::make_shared<ngraph::opset4::HSwish>(x_output);

        // The fused op takes over the name and runtime info of the sub-graph it replaces.
        hswish->set_friendly_name(m.get_match_root()->get_friendly_name());
        ngraph::copy_runtime_info({ pattern_to_output.at(add_constant).get_node_shared_ptr(),
                                    pattern_to_output.at(add).get_node_shared_ptr(),
                                    pattern_to_output.at(clamp).get_node_shared_ptr(),
                                    pattern_to_output.at(mul_constant).get_node_shared_ptr(),
                                    pattern_to_output.at(mul_first).get_node_shared_ptr(),
                                    pattern_to_output.at(mul_second).get_node_shared_ptr()
                                  },
                                  hswish);
        ngraph::replace_node(m.get_match_root(), hswish);
        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(mul_second, "HSwishWithClampFusion");
    register_matcher(m, callback);
}
|
||||
|
@ -151,6 +151,37 @@ TEST(TransformationTests, HSwishFusionWithoutRelu) {
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
// Positive case: x * (Clamp(x + 3, 0, 6) * (1/6)) must collapse into a single HSwish op.
TEST(TransformationTests, HSwishFusionWithClamp) {
    std::shared_ptr<ngraph::Function> f, f_ref;

    // Function under test: build the Clamp-based sub-graph and run the fusion pass on it.
    {
        auto data = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto three = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.0});
        auto shifted = std::make_shared<ngraph::opset4::Add>(data, three);
        auto clipped = std::make_shared<ngraph::op::v0::Clamp>(shifted, 0.0f, 6.0f);
        auto one_sixth = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {1.0 / 6.0});
        auto scaled = std::make_shared<ngraph::opset4::Multiply>(clipped, one_sixth);
        auto product = std::make_shared<ngraph::opset4::Multiply>(data, scaled);

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{product}, ngraph::ParameterVector{data});

        ngraph::pass::Manager pass_manager;
        pass_manager.register_pass<ngraph::pass::InitNodeInfo>();
        pass_manager.register_pass<ngraph::pass::HSwishFusionWithClamp>();
        pass_manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    // Reference function: the same parameter fed straight into HSwish.
    {
        auto data = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto fused = std::make_shared<ngraph::opset4::HSwish>(data);

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{fused}, ngraph::ParameterVector{data});
    }

    const auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}
|
||||
|
||||
TEST(TransformationTests, HSwishFusionWithReluMulWrongConstValue) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
@ -272,3 +303,39 @@ TEST(TransformationTests, HSwishFusionWithoutReluWrongConstValue) {
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
// Negative case for the Clamp-based pattern: with a wrong Add constant (3.11), wrong Clamp
// bounds (0.11, 6.02) and a wrong Multiply constant (0.98 / 6.15) the HSwishFusionWithClamp
// pass must leave the graph unchanged.
TEST(TransformationTests, HSwishFusionWithClampWrongConstValue) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.11});
        auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
        auto clamp = std::make_shared<ngraph::op::v0::Clamp>(add, 0.11f, 6.02f);
        auto mul_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.98 / 6.15});
        auto mul_first = std::make_shared<ngraph::opset4::Multiply>(clamp, mul_constant);
        auto mul_second = std::make_shared<ngraph::opset4::Multiply>(input, mul_first);

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul_second}, ngraph::ParameterVector{input});

        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        // Run the pass this test is named after; registering HSwishFusionWithoutRelu here
        // would never exercise the Clamp matcher's rejection path.
        manager.register_pass<ngraph::pass::HSwishFusionWithClamp>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    // Reference: the identical, unfused sub-graph.
    {
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.11});
        auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
        auto clamp = std::make_shared<ngraph::op::v0::Clamp>(add, 0.11f, 6.02f);
        auto mul_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.98 / 6.15});
        auto mul_first = std::make_shared<ngraph::opset4::Multiply>(clamp, mul_constant);
        auto mul_second = std::make_shared<ngraph::opset4::Multiply>(input, mul_first);

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul_second}, ngraph::ParameterVector{input});
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}
|
||||
|
2
inference-engine/thirdparty/mkl-dnn
vendored
2
inference-engine/thirdparty/mkl-dnn
vendored
@ -1 +1 @@
|
||||
Subproject commit 1f967a094353b30d65d96a3fe1721d8dccf02278
|
||||
Subproject commit eb54063189a33a10c4aa90311788e6fbb4cdf2f6
|
Loading…
Reference in New Issue
Block a user