updated to fuse activation in eltwise_vload8 (#12092)
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "activation/activation_kernel_base.h"
|
||||
#include "eltwise_kernel_vload8.h"
|
||||
#include "kernel_selector_utils.h"
|
||||
#include <string>
|
||||
@@ -32,6 +33,12 @@ bool EltwiseKernel_vload8::Validate(const Params& params, const optional_params&
|
||||
|
||||
const auto& ewParams = static_cast<const eltwise_params&>(params);
|
||||
|
||||
// Only one activation can be fused.
|
||||
if (ewParams.fused_ops.size() > 1 ||
|
||||
(ewParams.activations.size() !=0 && ewParams.fused_ops.size() != 0)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < ewParams.inputs.size(); i++) {
|
||||
const auto input_layout = ewParams.inputs[i].GetLayout();
|
||||
const auto batch_size = ewParams.inputs[i].Batch().v;
|
||||
@@ -109,6 +116,16 @@ KernelsData EltwiseKernel_vload8::GetKernelsData(const Params& params, const opt
|
||||
auto entry_point = GetEntryPoint(kernelName, newParams.layerID, params, options);
|
||||
|
||||
try {
|
||||
// move a fused activation from fused_ops to activations
|
||||
if (newParams.activations.size() == 0 &&
|
||||
newParams.fused_ops.size() == 1 &&
|
||||
newParams.fused_ops[0].GetType() == KernelType::ACTIVATION) {
|
||||
auto p = newParams.fused_ops[0].GetOpParams<activation_fuse_params>();
|
||||
base_activation_params activation_p = p->param;
|
||||
newParams.activations.push_back(activation_p);
|
||||
newParams.fused_ops.clear();
|
||||
}
|
||||
|
||||
auto cldnn_jit = GetJitConstants(newParams);
|
||||
jit = CreateJit(kernelName, cldnn_jit, entry_point);
|
||||
} catch (const std::runtime_error&) {
|
||||
|
||||
@@ -15,6 +15,11 @@ public:
|
||||
KernelsData GetKernelsData(const Params& params, const optional_params& options) const override;
|
||||
KernelsPriority GetKernelsPriority(const Params& params, const optional_params& options) const override;
|
||||
ParamsKey GetSupportedKey() const override;
|
||||
std::vector<FusedOpType> GetSupportedFusedOps() const override {
|
||||
return {
|
||||
FusedOpType::ACTIVATION
|
||||
};
|
||||
}
|
||||
|
||||
protected:
|
||||
bool Validate(const Params& p, const optional_params& o) const override;
|
||||
|
||||
Reference in New Issue
Block a user