[IE CLDNN] Added Mish operation (#1125)

Roman Lyamin 2020-07-09 16:57:59 +03:00 committed by GitHub
parent 65657ea5c5
commit f3848b4454
8 changed files with 72 additions and 0 deletions

View File

@@ -575,6 +575,7 @@ Program::LayerType Program::LayerTypeFromStr(const std::string &str) {
{ "Sinh" , Sinh },
{ "Cosh" , Cosh },
{ "Swish" , Swish },
{ "Mish" , Mish },
{ "Gelu" , Gelu },
{ "Atanh" , Atanh },
{ "Floor" , Floor },
@@ -1159,6 +1160,7 @@ void Program::CreateSingleLayerPrimitive(cldnn::topology& topology, InferenceEng
case SoftPlus:
case SoftSign:
case Swish:
case Mish:
case Gelu:
CreateActivationPrimitive(topology, layer, LayerTypeFromStr(layer->type));
break;
@@ -2767,6 +2769,8 @@ void Program::CreateActivationPrimitive(cldnn::topology& topology, InferenceEngi
activationType = ELU;
} else if (activation_type == "swish") {
activationType = Swish;
} else if (activation_type == "mish") {
activationType = Mish;
} else if (activation_type == "gelu") {
activationType = Gelu;
} else if (activation_type == "relu") {
@@ -2957,6 +2961,11 @@ void Program::CreateActivationPrimitive(cldnn::topology& topology, InferenceEngi
func = cldnn::activation_func::swish;
break;
}
case Mish:
{
func = cldnn::activation_func::mish;
break;
}
case Gelu:
{
func = cldnn::activation_func::gelu;

View File

@@ -201,6 +201,7 @@ public:
SoftPlus,
SoftSign,
Swish,
Mish,
Gelu,
Sin,
Sinh,

View File

@@ -68,6 +68,7 @@ enum class activation_func {
softplus, // ln(exp(val) + 1)
softsign, // (val/(1+|val|))
swish, // (val*sigmoid(val))
mish, // val*tanh(ln(1 + exp(val)))
gelu // (0.5*val*(1 + erf(val / sqrt(2)))
};
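For reference, and not part of the commit: the new enum entry's comment gives mish(val) = val*tanh(ln(1 + exp(val))). A minimal plain-C++ sketch of that definition (the helper name mish_ref is illustrative, not from the codebase) can serve as a host-side oracle when checking kernel output.

#include <cmath>

// Host-side reference for the mish entry above;
// std::log1p(std::exp(x)) is the softplus term ln(1 + exp(x)).
inline float mish_ref(float x) {
    return x * std::tanh(std::log1p(std::exp(x)));
}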

View File

@@ -150,6 +150,7 @@ enum class ActivationFunction {
SOFTPLUS,
SOFTSIGN,
SWISH,
MISH,
GELU
};

View File

@@ -717,6 +717,18 @@ JitConstants MakeActivationJitConstants(ActivationFunction activation_function,
(input / (one + exp(neg(input)))).str()));
break;
}
case ActivationFunction::MISH: {
std::string type_suffix = out_dt == Datatype::F32 ? "f" : "h";
auto bound = out_dt == Datatype::F32 ? "9.9f"_jit : "4.75h"_jit;
const JitTerm two("2." + type_suffix);
const JitTerm n((exp(input) + two) * exp(input));
const JitTerm common_mish_formula((input * n) / (n + two));
jitConstants.AddConstant(MakeJitConstant(
macro_def,
ternary(input.ge(bound), input, common_mish_formula).str()));
break;
}
case ActivationFunction::GELU: {
std::string type_suffix = out_dt == Datatype::F32 ? "f" : "h";
const JitTerm half{"0.5" + type_suffix};

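A note on the MISH case above (not part of the commit): the generated kernel code avoids calling tanh and log directly. With e = exp(x) and n = (e + 2) * e, tanh(ln(1 + e)) = ((1 + e)^2 - 1) / ((1 + e)^2 + 1) = n / (n + 2), so mish(x) = x * n / (n + 2), which is what the JitTerm builds. Above the bound (9.9 for fp32, 4.75 for fp16) the macro returns x itself; the bound appears chosen so that tanh(ln(1 + exp(x))) is already 1 to within the datatype's precision there, and, for fp16, so the squared exponential stays within half-precision range. A small standalone C++ sketch, separate from the codebase, that checks the rearrangement against the direct definition:

#include <cmath>
#include <cstdio>

int main() {
    for (float x : {-6.0f, -1.0f, 0.0f, 1.0f, 3.0f, 4.5f}) {
        // Direct definition: x * tanh(ln(1 + exp(x)))
        float direct = x * std::tanh(std::log1p(std::exp(x)));
        // Rearranged form used by the jitted kernel: x * n / (n + 2), with n = (exp(x) + 2) * exp(x)
        float e = std::exp(x);
        float n = (e + 2.0f) * e;
        float rearranged = x * n / (n + 2.0f);
        std::printf("x = %5.1f  direct = %.7f  rearranged = %.7f\n", x, direct, rearranged);
    }
    return 0;
}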
View File

@@ -82,6 +82,7 @@ std::string toString(ActivationFunction activation) {
case ActivationFunction::SOFTPLUS: method = "SOFTPLUS"; break;
case ActivationFunction::SOFTSIGN: method = "SOFTSIGN"; break;
case ActivationFunction::SWISH: method = "SWISH"; break;
case ActivationFunction::MISH: method = "MISH"; break;
case ActivationFunction::GELU: method = "GELU"; break;
default: break;
}

View File

@@ -666,6 +666,8 @@ kernel_selector::activation_function get_kernel_selector_activation_param(activa
return kernel_selector::activation_function::HARD_SIGMOID;
case cldnn::activation_func::swish:
return kernel_selector::activation_function::SWISH;
case cldnn::activation_func::mish:
return kernel_selector::activation_function::MISH;
case cldnn::activation_func::gelu:
return kernel_selector::activation_function::GELU;
default:

View File

@@ -655,6 +655,46 @@ TEST(activation_f32_fw_gpu, relu_basic_bfzyx) {
}
}
TEST(activation_f16_fw_gpu, basic_yxfb_mish) {
const auto& engine = get_test_engine();
auto input = memory::allocate(engine, { data_types::f16, format::yxfb, { 1, 1, 5, 4 } });
set_values(input,
{ FLOAT16(0.0f), FLOAT16(-2.0f), FLOAT16(-3.0f), FLOAT16(4.0f), FLOAT16(5.0f),
FLOAT16(2.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(-6.0f),
FLOAT16(3.0f), FLOAT16(-3.0f), FLOAT16(3.0f), FLOAT16(5.0f), FLOAT16(1.0f),
FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(-1.0f), FLOAT16(1.0f) });
topology topology(
input_layout("input", input.get_layout()),
activation("mish", "input", activation_func::mish));
network network(engine, topology);
network.set_input_data("input", input);
auto outputs = network.execute();
EXPECT_EQ(outputs.size(), size_t(1));
EXPECT_EQ(outputs.begin()->first, "mish");
auto output_memory = outputs.at("mish").get_memory();
auto output_layout = output_memory.get_layout();
auto output_ptr = output_memory.pointer<FLOAT16>();
auto input_ptr = input.pointer<FLOAT16>();
int y_size = output_layout.size.spatial[1];
int x_size = output_layout.size.spatial[0];
int f_size = output_layout.size.feature[0];
int b_size = output_layout.size.batch[0];
EXPECT_EQ(output_layout.format, format::yxfb);
EXPECT_EQ(y_size, 4);
EXPECT_EQ(x_size, 5);
EXPECT_EQ(f_size, 1);
EXPECT_EQ(b_size, 1);
for (size_t i = 0; i < output_layout.get_linear_size(); ++i) {
EXPECT_NEAR((FLOAT16)((float)input_ptr[i] * std::tanh(std::log(1.f + std::exp((float)input_ptr[i])))),
output_ptr[i], 1e-2f);
}
}
TEST(activation_f32_fw_gpu, basic_yxfb_all_functions)
{
// Input:
@@ -700,6 +740,7 @@ TEST(activation_f32_fw_gpu, basic_yxfb_all_functions)
activation_func::negative,
activation_func::abs,
activation_func::swish,
activation_func::mish,
activation_func::gelu
};
@@ -817,6 +858,10 @@ TEST(activation_f32_fw_gpu, basic_yxfb_all_functions)
case activation_func::swish:
EXPECT_FLOAT_EQ((float)input_ptr[i] / (1.f + std::exp((float)(-input_ptr[i]))), output_ptr[i]);
break;
case activation_func::mish:
EXPECT_NEAR((float)input_ptr[i] * std::tanh(std::log(1.f + std::exp((float)input_ptr[i]))),
output_ptr[i], 1e-5f);
break;
case activation_func::gelu:
EXPECT_NEAR(0.5f * (float)input_ptr[i] * (1.f + std::erf((float)(input_ptr[i]) / std::sqrt(2.0f))),
output_ptr[i], 1e-5f);