updated to fuse activation in eltwise_vload8 (#12092)

This commit is contained in:
Eddy Kim
2022-07-12 18:51:48 +09:00
committed by GitHub
parent bbc1c26750
commit a63dad6fdd
2 changed files with 22 additions and 0 deletions

View File

@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "activation/activation_kernel_base.h"
#include "eltwise_kernel_vload8.h"
#include "kernel_selector_utils.h"
#include <string>
@@ -32,6 +33,12 @@ bool EltwiseKernel_vload8::Validate(const Params& params, const optional_params&
const auto& ewParams = static_cast<const eltwise_params&>(params);
// Only one activation can be fused.
if (ewParams.fused_ops.size() > 1 ||
(ewParams.activations.size() !=0 && ewParams.fused_ops.size() != 0)) {
return false;
}
for (size_t i = 0; i < ewParams.inputs.size(); i++) {
const auto input_layout = ewParams.inputs[i].GetLayout();
const auto batch_size = ewParams.inputs[i].Batch().v;
@@ -109,6 +116,16 @@ KernelsData EltwiseKernel_vload8::GetKernelsData(const Params& params, const opt
auto entry_point = GetEntryPoint(kernelName, newParams.layerID, params, options);
try {
// move a fused activation from fused_ops to activations
if (newParams.activations.size() == 0 &&
newParams.fused_ops.size() == 1 &&
newParams.fused_ops[0].GetType() == KernelType::ACTIVATION) {
auto p = newParams.fused_ops[0].GetOpParams<activation_fuse_params>();
base_activation_params activation_p = p->param;
newParams.activations.push_back(activation_p);
newParams.fused_ops.clear();
}
auto cldnn_jit = GetJitConstants(newParams);
jit = CreateJit(kernelName, cldnn_jit, entry_point);
} catch (const std::runtime_error&) {

View File

@@ -15,6 +15,11 @@ public:
KernelsData GetKernelsData(const Params& params, const optional_params& options) const override;
KernelsPriority GetKernelsPriority(const Params& params, const optional_params& options) const override;
ParamsKey GetSupportedKey() const override;
std::vector<FusedOpType> GetSupportedFusedOps() const override {
return {
FusedOpType::ACTIVATION
};
}
protected:
bool Validate(const Params& p, const optional_params& o) const override;