[IE CLDNN] Added HSwish-4 operation (#1585)

Roman Lyamin 2020-08-03 10:15:43 +03:00 committed by GitHub
parent a17472fed0
commit 8245e5b6f4
8 changed files with 68 additions and 0 deletions

View File

@@ -575,6 +575,7 @@ Program::LayerType Program::LayerTypeFromStr(const std::string &str) {
         { "Sinh" , Sinh },
         { "Cosh" , Cosh },
         { "Swish" , Swish },
+        { "HSwish", HSwish },
         { "Mish" , Mish },
         { "Gelu" , Gelu },
         { "Atanh" , Atanh },
@@ -1162,6 +1163,7 @@ void Program::CreateSingleLayerPrimitive(cldnn::topology& topology, InferenceEng
     case SoftPlus:
     case SoftSign:
     case Swish:
+    case HSwish:
     case Mish:
     case Gelu:
         CreateActivationPrimitive(topology, layer, LayerTypeFromStr(layer->type));
@@ -2780,6 +2782,8 @@ void Program::CreateActivationPrimitive(cldnn::topology& topology, InferenceEngi
         activationType = ELU;
     } else if (activation_type == "swish") {
         activationType = Swish;
+    } else if (activation_type == "hswish") {
+        activationType = HSwish;
     } else if (activation_type == "mish") {
         activationType = Mish;
     } else if (activation_type == "gelu") {
@@ -2977,6 +2981,11 @@ void Program::CreateActivationPrimitive(cldnn::topology& topology, InferenceEngi
         func = cldnn::activation_func::swish;
         break;
     }
+    case HSwish:
+    {
+        func = cldnn::activation_func::hswish;
+        break;
+    }
     case Mish:
     {
         func = cldnn::activation_func::mish;
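For orientation, here is a toy sketch of the two-step lookup these hunks extend: the IR layer name is mapped to an internal layer type, which is then mapped to a clDNN activation enum value. The enum and map names below are illustrative only, not the plugin's actual types.

    #include <cassert>
    #include <map>
    #include <string>

    // Toy model of the dispatch chain: "HSwish" (IR) -> HSwish (LayerType)
    // -> hswish (cldnn activation), mirroring the three hunks above.
    enum class LayerType { Swish, HSwish, Mish };
    enum class ActivationFunc { swish, hswish, mish };

    int main() {
        const std::map<std::string, LayerType> layer_from_str = {
            { "Swish",  LayerType::Swish  },
            { "HSwish", LayerType::HSwish },
            { "Mish",   LayerType::Mish   },
        };
        const std::map<LayerType, ActivationFunc> func_from_layer = {
            { LayerType::Swish,  ActivationFunc::swish  },
            { LayerType::HSwish, ActivationFunc::hswish },
            { LayerType::Mish,   ActivationFunc::mish   },
        };
        // An IR layer named "HSwish" ends up as the hswish activation primitive.
        assert(func_from_layer.at(layer_from_str.at("HSwish")) == ActivationFunc::hswish);
        return 0;
    }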

View File

@@ -202,6 +202,7 @@ public:
     SoftPlus,
     SoftSign,
     Swish,
+    HSwish,
     Mish,
     Gelu,
     Sin,

View File

@@ -68,6 +68,7 @@ enum class activation_func {
     softplus,   // ln(exp(val) + 1)
     softsign,   // (val/(1+|val|))
     swish,      // (val*sigmoid(val))
+    hswish,     // val * min(max(0, val + 3), 6) / 6
     mish,       // val*tanh(ln(1 + exp(val)))
     gelu        // (0.5*val*(1 + erf(val / sqrt(2)))
 };
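For reference, a minimal host-side sketch of the formula documented by the new enum comment, hswish(x) = x * min(max(0, x + 3), 6) / 6, with a few spot checks on the clamped regions. The helper name is illustrative and not part of the library.

    #include <algorithm>
    #include <cassert>

    // Reference HSwish, mirroring the comment above:
    // hswish(x) = x * min(max(0, x + 3), 6) / 6
    static float hswish_ref(float x) {
        return x * std::min(std::max(0.0f, x + 3.0f), 6.0f) / 6.0f;
    }

    int main() {
        assert(hswish_ref(-4.0f) == 0.0f);                  // x + 3 <= 0: output clamped to zero
        assert(hswish_ref(4.0f)  == 4.0f);                  // x + 3 >= 6: identity region
        assert(hswish_ref(-1.5f) == -1.5f * 1.5f / 6.0f);   // middle: smooth ramp between the clamps
        return 0;
    }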

View File

@@ -150,6 +150,7 @@ enum class ActivationFunction {
     SOFTPLUS,
     SOFTSIGN,
     SWISH,
+    HSWISH,
     MISH,
     GELU
 };

View File

@@ -717,6 +717,15 @@ JitConstants MakeActivationJitConstants(ActivationFunction activation_function,
                 (input / (one + exp(neg(input)))).str()));
             break;
         }
+        case ActivationFunction::HSWISH: {
+            std::string type_suffix = out_dt == Datatype::F32 ? "f" : "h";
+            const JitTerm three("3." + type_suffix);
+            const JitTerm six("6." + type_suffix);
+            jitConstants.AddConstant(MakeJitConstant(
+                macro_def,
+                (input * min_func(max_func(zero, input + three), six) / six).str()));
+            break;
+        }
         case ActivationFunction::MISH: {
             std::string type_suffix = out_dt == Datatype::F32 ? "f" : "h";
             auto bound = out_dt == Datatype::F32 ? "9.9f"_jit : "4.75h"_jit;
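Roughly, the JitTerm expression in the HSWISH branch stringifies into a macro body like the one printed below. This sketch only rebuilds that text with plain std::string, it is not the real JitTerm/MakeJitConstant API, and the exact spelling of literals and parentheses in the generated OpenCL may differ; it is meant to show how the F32/F16 type suffix ends up on the constants.

    #include <iostream>
    #include <string>

    // Approximation of the activation macro body emitted for HSwish,
    // with "f" literals for F32 output and "h" literals for F16 output.
    static std::string hswish_macro_body(const std::string& input, bool is_f32) {
        const std::string sfx = is_f32 ? "f" : "h";
        return "(" + input + " * min(max(0." + sfx + ", " + input + " + 3." + sfx +
               "), 6." + sfx + ") / 6." + sfx + ")";
    }

    int main() {
        std::cout << hswish_macro_body("input", true) << "\n";
        // (input * min(max(0.f, input + 3.f), 6.f) / 6.f)
        std::cout << hswish_macro_body("input", false) << "\n";
        // (input * min(max(0.h, input + 3.h), 6.h) / 6.h)
        return 0;
    }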

View File

@@ -82,6 +82,7 @@ std::string toString(ActivationFunction activation) {
         case ActivationFunction::SOFTPLUS: method = "SOFTPLUS"; break;
         case ActivationFunction::SOFTSIGN: method = "SOFTSIGN"; break;
         case ActivationFunction::SWISH: method = "SWISH"; break;
+        case ActivationFunction::HSWISH: method = "HSWISH"; break;
         case ActivationFunction::MISH: method = "MISH"; break;
         case ActivationFunction::GELU: method = "GELU"; break;
         default: break;

View File

@@ -666,6 +666,8 @@ kernel_selector::activation_function get_kernel_selector_activation_param(activa
         return kernel_selector::activation_function::HARD_SIGMOID;
     case cldnn::activation_func::swish:
         return kernel_selector::activation_function::SWISH;
+    case cldnn::activation_func::hswish:
+        return kernel_selector::activation_function::HSWISH;
     case cldnn::activation_func::mish:
         return kernel_selector::activation_function::MISH;
     case cldnn::activation_func::gelu:

View File

@@ -695,6 +695,46 @@ TEST(activation_f16_fw_gpu, basic_yxfb_mish) {
     }
 }
+TEST(activation_f16_fw_gpu, basic_yxfb_hswish) {
+    const auto& engine = get_test_engine();
+    auto input = memory::allocate(engine, { data_types::f16, format::yxfb, { 1, 2, 5, 2 } });
+    set_values(input,
+        { FLOAT16(0.0f), FLOAT16(-2.0f), FLOAT16(-3.0f), FLOAT16(4.0f), FLOAT16(5.0f),
+          FLOAT16(2.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(-6.0f),
+          FLOAT16(3.0f), FLOAT16(-3.0f), FLOAT16(3.0f), FLOAT16(5.0f), FLOAT16(1.0f),
+          FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(-1.0f), FLOAT16(1.0f) });
+    topology topology(
+        input_layout("input", input.get_layout()),
+        activation("hswish", "input", activation_func::hswish));
+    network network(engine, topology);
+    network.set_input_data("input", input);
+    auto outputs = network.execute();
+    EXPECT_EQ(outputs.size(), size_t(1));
+    EXPECT_EQ(outputs.begin()->first, "hswish");
+    auto output_memory = outputs.at("hswish").get_memory();
+    auto output_layout = output_memory.get_layout();
+    auto output_ptr = output_memory.pointer<FLOAT16>();
+    auto input_ptr = input.pointer<FLOAT16>();
+    int y_size = output_layout.size.spatial[1];
+    int x_size = output_layout.size.spatial[0];
+    int f_size = output_layout.size.feature[0];
+    int b_size = output_layout.size.batch[0];
+    EXPECT_EQ(output_layout.format, format::yxfb);
+    EXPECT_EQ(y_size, 2);
+    EXPECT_EQ(x_size, 5);
+    EXPECT_EQ(f_size, 2);
+    EXPECT_EQ(b_size, 1);
+    for (size_t i = 0; i < output_layout.get_linear_size(); ++i) {
+        EXPECT_NEAR((FLOAT16)((float)input_ptr[i] * std::fmin(std::fmax(0.f, (float)input_ptr[i] + 3.f), 6.f) / 6.f),
+                    output_ptr[i], 1e-3f);
+    }
+}

 TEST(activation_f32_fw_gpu, basic_yxfb_all_functions)
 {
     // Input:
@@ -740,6 +780,7 @@ TEST(activation_f32_fw_gpu, basic_yxfb_all_functions)
         activation_func::negative,
         activation_func::abs,
         activation_func::swish,
+        activation_func::hswish,
         activation_func::mish,
         activation_func::gelu
     };
@@ -858,6 +899,9 @@ TEST(activation_f32_fw_gpu, basic_yxfb_all_functions)
         case activation_func::swish:
             EXPECT_FLOAT_EQ((float)input_ptr[i] / (1.f + std::exp((float)(-input_ptr[i]))), output_ptr[i]);
             break;
+        case activation_func::hswish:
+            EXPECT_FLOAT_EQ((float)input_ptr[i] * std::fmin(std::fmax(0.f, (float)input_ptr[i] + 3.f), 6.f) / 6.f, output_ptr[i]);
+            break;
         case activation_func::mish:
             EXPECT_NEAR((float)input_ptr[i] * std::tanh(std::log(1.f + std::exp((float)input_ptr[i]))),
                 output_ptr[i], 1e-5f);