[IE CLDNN] Added Mish operation (#1125)
Parent: 65657ea5c5
Commit: f3848b4454
@@ -575,6 +575,7 @@ Program::LayerType Program::LayerTypeFromStr(const std::string &str) {
         { "Sinh" , Sinh },
         { "Cosh" , Cosh },
         { "Swish" , Swish },
+        { "Mish" , Mish },
         { "Gelu" , Gelu },
         { "Atanh" , Atanh },
         { "Floor" , Floor },
@@ -1159,6 +1160,7 @@ void Program::CreateSingleLayerPrimitive(cldnn::topology& topology, InferenceEng
     case SoftPlus:
     case SoftSign:
     case Swish:
+    case Mish:
     case Gelu:
         CreateActivationPrimitive(topology, layer, LayerTypeFromStr(layer->type));
         break;
@@ -2767,6 +2769,8 @@ void Program::CreateActivationPrimitive(cldnn::topology& topology, InferenceEngi
         activationType = ELU;
     } else if (activation_type == "swish") {
         activationType = Swish;
+    } else if (activation_type == "mish") {
+        activationType = Mish;
     } else if (activation_type == "gelu") {
         activationType = Gelu;
     } else if (activation_type == "relu") {
@@ -2957,6 +2961,11 @@ void Program::CreateActivationPrimitive(cldnn::topology& topology, InferenceEngi
        func = cldnn::activation_func::swish;
        break;
    }
+   case Mish:
+   {
+       func = cldnn::activation_func::mish;
+       break;
+   }
    case Gelu:
    {
        func = cldnn::activation_func::gelu;
@@ -201,6 +201,7 @@ public:
         SoftPlus,
         SoftSign,
         Swish,
+        Mish,
         Gelu,
         Sin,
         Sinh,
@@ -68,6 +68,7 @@ enum class activation_func {
     softplus,       // ln(exp(val) + 1)
     softsign,       // (val/(1+|val|))
     swish,          // (val*sigmoid(val))
+    mish,           // val*tanh(ln(1 + exp(val)))
     gelu            // (0.5*val*(1 + erf(val / sqrt(2)))
 };
 
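For reference, the comment above is the definition this patch implements: mish(x) = x * tanh(ln(1 + exp(x))), i.e. x * tanh(softplus(x)). A minimal standalone sketch of that formula (illustrative only; mish_reference is not part of the patch):

    #include <cmath>

    // Scalar reference for mish(x) = x * tanh(ln(1 + exp(x))).
    // log1p(exp(x)) is softplus(x); log1p keeps precision when exp(x) is small.
    inline float mish_reference(float x) {
        return x * std::tanh(std::log1p(std::exp(x)));
    }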
@@ -150,6 +150,7 @@ enum class ActivationFunction {
     SOFTPLUS,
     SOFTSIGN,
     SWISH,
+    MISH,
     GELU
 };
 
@@ -717,6 +717,18 @@ JitConstants MakeActivationJitConstants(ActivationFunction activation_function,
                 (input / (one + exp(neg(input)))).str()));
             break;
         }
+        case ActivationFunction::MISH: {
+            std::string type_suffix = out_dt == Datatype::F32 ? "f" : "h";
+            auto bound = out_dt == Datatype::F32 ? "9.9f"_jit : "4.75h"_jit;
+            const JitTerm two("2." + type_suffix);
+            const JitTerm n((exp(input) + two) * exp(input));
+            const JitTerm common_mish_formula((input * n) / (n + two));
+
+            jitConstants.AddConstant(MakeJitConstant(
+                macro_def,
+                ternary(input.ge(bound), input, common_mish_formula).str()));
+            break;
+        }
         case ActivationFunction::GELU: {
             std::string type_suffix = out_dt == Datatype::F32 ? "f" : "h";
             const JitTerm half{"0.5" + type_suffix};
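The MISH branch emits a rational form rather than the literal tanh/log expression: with n = (exp(x) + 2) * exp(x), tanh(ln(1 + exp(x))) equals n / (n + 2), so x * n / (n + 2) is algebraically the same as x * tanh(softplus(x)) while needing only exp(). The ternary falls back to the identity once the input reaches `bound`, presumably because mish(x) approaches x for large inputs and the intermediate exp() would otherwise overflow, especially in half precision. A small host-side sketch (illustrative only, not part of the patch) that checks the identity numerically:

    #include <cmath>
    #include <cstdio>

    int main() {
        for (float x = -5.0f; x <= 5.0f; x += 1.0f) {
            float n = (std::exp(x) + 2.0f) * std::exp(x);
            float rational = x * n / (n + 2.0f);                    // form emitted by the jitter
            float direct = x * std::tanh(std::log1p(std::exp(x)));  // textbook mish
            std::printf("x = %5.1f  rational = %f  direct = %f\n", x, rational, direct);
        }
        return 0;
    }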
@@ -82,6 +82,7 @@ std::string toString(ActivationFunction activation) {
         case ActivationFunction::SOFTPLUS: method = "SOFTPLUS"; break;
         case ActivationFunction::SOFTSIGN: method = "SOFTSIGN"; break;
         case ActivationFunction::SWISH: method = "SWISH"; break;
+        case ActivationFunction::MISH: method = "MISH"; break;
         case ActivationFunction::GELU: method = "GELU"; break;
         default: break;
     }
@@ -666,6 +666,8 @@ kernel_selector::activation_function get_kernel_selector_activation_param(activa
         return kernel_selector::activation_function::HARD_SIGMOID;
     case cldnn::activation_func::swish:
         return kernel_selector::activation_function::SWISH;
+    case cldnn::activation_func::mish:
+        return kernel_selector::activation_function::MISH;
     case cldnn::activation_func::gelu:
         return kernel_selector::activation_function::GELU;
     default:
@@ -655,6 +655,46 @@ TEST(activation_f32_fw_gpu, relu_basic_bfzyx) {
     }
 }
 
+TEST(activation_f16_fw_gpu, basic_yxfb_mish) {
+    const auto& engine = get_test_engine();
+
+    auto input = memory::allocate(engine, { data_types::f16, format::yxfb, { 1, 1, 5, 4 } });
+    set_values(input,
+    { FLOAT16(0.0f), FLOAT16(-2.0f), FLOAT16(-3.0f), FLOAT16(4.0f), FLOAT16(5.0f),
+      FLOAT16(2.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(-6.0f),
+      FLOAT16(3.0f), FLOAT16(-3.0f), FLOAT16(3.0f), FLOAT16(5.0f), FLOAT16(1.0f),
+      FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(-1.0f), FLOAT16(1.0f) });
+
+    topology topology(
+        input_layout("input", input.get_layout()),
+        activation("mish", "input", activation_func::mish));
+    network network(engine, topology);
+    network.set_input_data("input", input);
+    auto outputs = network.execute();
+    EXPECT_EQ(outputs.size(), size_t(1));
+    EXPECT_EQ(outputs.begin()->first, "mish");
+
+    auto output_memory = outputs.at("mish").get_memory();
+    auto output_layout = output_memory.get_layout();
+    auto output_ptr = output_memory.pointer<FLOAT16>();
+    auto input_ptr = input.pointer<FLOAT16>();
+
+    int y_size = output_layout.size.spatial[1];
+    int x_size = output_layout.size.spatial[0];
+    int f_size = output_layout.size.feature[0];
+    int b_size = output_layout.size.batch[0];
+    EXPECT_EQ(output_layout.format, format::yxfb);
+    EXPECT_EQ(y_size, 4);
+    EXPECT_EQ(x_size, 5);
+    EXPECT_EQ(f_size, 1);
+    EXPECT_EQ(b_size, 1);
+
+    for (size_t i = 0; i < output_layout.get_linear_size(); ++i) {
+        EXPECT_NEAR((FLOAT16)((float)input_ptr[i] * std::tanh(std::log(1.f + std::exp((float)input_ptr[i])))),
+                    output_ptr[i], 1e-2f);
+    }
+}
+
 TEST(activation_f32_fw_gpu, basic_yxfb_all_functions)
 {
     // Input:
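The fp16 test evaluates its reference in single precision on the host and compares with an absolute tolerance of 1e-2, which accommodates half-precision rounding of both the inputs and the kernel output. A standalone sketch (illustrative only, not part of the patch) of how such expected values can be produced for the first row of the test inputs:

    #include <cmath>
    #include <cstdio>

    int main() {
        const float inputs[] = { 0.0f, -2.0f, -3.0f, 4.0f, 5.0f };  // first row of the test data
        for (float x : inputs) {
            float expected = x * std::tanh(std::log(1.f + std::exp(x)));  // same formula as the EXPECT_NEAR
            std::printf("mish(%5.1f) = %f\n", x, expected);
        }
        return 0;
    }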
@@ -700,6 +740,7 @@ TEST(activation_f32_fw_gpu, basic_yxfb_all_functions)
         activation_func::negative,
         activation_func::abs,
         activation_func::swish,
+        activation_func::mish,
         activation_func::gelu
     };
 
@@ -817,6 +858,10 @@ TEST(activation_f32_fw_gpu, basic_yxfb_all_functions)
             case activation_func::swish:
                 EXPECT_FLOAT_EQ((float)input_ptr[i] / (1.f + std::exp((float)(-input_ptr[i]))), output_ptr[i]);
                 break;
+            case activation_func::mish:
+                EXPECT_NEAR((float)input_ptr[i] * std::tanh(std::log(1.f + std::exp((float)input_ptr[i]))),
+                            output_ptr[i], 1e-5f);
+                break;
             case activation_func::gelu:
                 EXPECT_NEAR(0.5f * (float)input_ptr[i] * (1.f + std::erf((float)(input_ptr[i]) / std::sqrt(2.0f))),
                             output_ptr[i], 1e-5f);