[IE CLDNN] Fixed scatter update op & reshape kernel (#4106)

This commit is contained in:
Taylor Yeonbok Lee 2021-02-02 20:57:32 +09:00 committed by GitHub
parent 9c1651b5ad
commit 38fab0265d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 253 additions and 165 deletions

View File

@ -156,10 +156,10 @@ REGISTER_FACTORY(v3, EmbeddingBagOffsetsSum);
REGISTER_FACTORY(v3, EmbeddingBagPackedSum);
REGISTER_FACTORY(v3, EmbeddingSegmentsSum);
REGISTER_FACTORY(v3, ExtractImagePatches);
REGISTER_FACTORY(v3, ScatterUpdate);
// REGISTER_FACTORY(v3, NonMaxSuppression); Supported via v3 -> v5 internal conversion
// ----------------------------- Unsupported v3 ops ----------------------------- //
// REGISTER_FACTORY(v3, ScatterUpdate); // There is the scatter_update primitive, but seems like it produces wrong results
// REGISTER_FACTORY(v3, Assign);
// REGISTER_FACTORY(v3, Bucketize);
// REGISTER_FACTORY(v3, GRUCell);
@ -167,7 +167,6 @@ REGISTER_FACTORY(v3, ExtractImagePatches);
// REGISTER_FACTORY(v3, ROIAlign);
// REGISTER_FACTORY(v3, ReadValue);
// REGISTER_FACTORY(v3, ScatterElementsUpdate);
// REGISTER_FACTORY(v3, ScatterUpdate);
// REGISTER_FACTORY(v3, ScatterNDUpdate);
// REGISTER_FACTORY(v3, ShapeOf);
// REGISTER_FACTORY(v3, TopK);

View File

@ -25,8 +25,9 @@ const std::vector<InferenceEngine::Precision> idxPrecisions = {
// map<inputShape, map<indicesShape, axis>>
std::map<std::vector<size_t>, std::map<std::vector<size_t>, std::vector<int>>> axesShapeInShape {
{{10, 16, 12, 15}, {{{2, 4}, {0, 1, 2, 3}}, {{8}, {-1, -2, -3, -4}}}},
{{10, 9, 10, 9, 10}, {{{8}, {-3, -1, 0, 2, 4}}, {{4, 2}, {-2, 2}}}},
{{10, 16, 12, 15}, {{{2, 2, 2}, {0, 1, 2, 3}}, {{2, 4}, {0, 1, 2, 3}}, {{8}, {0, 1, 2, 3}}}},
{{10, 9, 10, 9, 10}, {{{8}, {0, 1, 2, 3, 4}}, {{4, 2}, {0, 1, 2, 3, 4}}}},
{{10, 9, 10, 9, 10, 12}, {{{8}, {0, 1, 2, 3, 4, 5}}}},
};
//indices should not be random value
const std::vector<std::vector<int64_t>> idxValue = {

View File

@ -48,7 +48,6 @@ std::vector<std::string> disabledTestPatterns() {
R"(.*(LSTMSequence).*mode=CONVERT_TO_TI_RAND_SEQ_LEN.*)",
R"(.*(smoke_DetectionOutput3In).*)",
R"(.*(smoke_DetectionOutput5In).*)",
R"(.*(ScatterUpdateLayerTest).*)",
// INT8 StridedSlice not supported
R"(.*(LPT/StridedSliceTransformation).*)",

View File

@ -67,22 +67,6 @@ ParamsKey ScatterUpdateKernelRef::GetSupportedKey() const {
return k;
}
static size_t GetNonEmptyDimsNumber(const DataTensor& data_tensor) {
if (data_tensor.LogicalSize() != 1) {
// Count the number of "one size" dimensions starting with X to Batch
size_t one_size_dims = 0;
for (auto& i : data_tensor.GetDims()) {
if (i.v == 1)
one_size_dims++;
else
break;
}
return data_tensor.Dimentions() - one_size_dims;
} else {
return 1;
}
}
static inline std::string GetOrderString(std::vector<std::string>& order) {
std::string order_str = order[0];
for (size_t i = 1; i < order.size(); i++)
@ -104,99 +88,90 @@ static inline std::vector<std::string> GetDefaultOrder(size_t size) {
return default_order;
}
static std::string GetUpdatesIndexOrder(const scatter_update_params& params, size_t axis) {
std::vector<std::string> default_order = GetDefaultOrder(params.output.GetDims().size());
for (unsigned int i = 0; i < params.inputs[2].GetDims().size() - params.output.GetDims().size(); i++)
default_order.push_back("0");
size_t indices_non_empty_dims = GetNonEmptyDimsNumber(params.inputs[1]);
std::string FYX_indices_size = "(INPUT1_FEATURE_NUM * INPUT1_SIZE_Y * INPUT1_SIZE_X)";
std::string YX_indices_size = "(INPUT1_SIZE_Y * INPUT1_SIZE_X)";
std::string X_indices_size = "(INPUT1_SIZE_X)";
// Shift indices of ScatterUpdate updates input related to Indices dims
for (size_t i = default_order.size() - 1; i > (axis + indices_non_empty_dims - 1); i--)
default_order[i] = default_order[i - indices_non_empty_dims + 1];
// Insert Indices indexes in axis dimention in the Update index order
for (size_t i = axis; i < (axis + indices_non_empty_dims) && i < default_order.size(); i++) {
switch(i - axis) {
case 0:
default_order[i] = "(OUTPUT_INDEX_ON_AXIS /" + FYX_indices_size + ")";
break;
case 1:
default_order[i] = "((OUTPUT_INDEX_ON_AXIS %" + FYX_indices_size + ")/" + YX_indices_size + ")";
break;
case 2:
default_order[i] = "(((OUTPUT_INDEX_ON_AXIS %" + FYX_indices_size + ")%" + YX_indices_size + ")/" + X_indices_size + ")";
break;
case 3:
default_order[i] = "(((OUTPUT_INDEX_ON_AXIS %" + FYX_indices_size + ")%" + YX_indices_size + ")%" + X_indices_size + ")";
break;
}
static inline std::string GetAxisName(size_t size, size_t axis) {
std::vector<std::string> axis_names;;
if (size <= 4) {
axis_names = {"BATCH", "FEATURE", "Y", "X"};
} else if (size == 5) {
axis_names = {"BATCH", "FEATURE", "Z", "Y", "X"};
} else if (size == 6) {
axis_names = {"BATCH", "FEATURE", "W", "Z", "Y", "X"};
}
return axis_names[axis];
}
static std::string GetUpdatesIndexOrder(const scatter_update_params& params) {
std::vector<std::string> default_order = GetDefaultOrder(params.output.GetDims().size());
return GetOrderString(default_order);
}
CommonDispatchData ScatterUpdateKernelRef::SetDefault(const scatter_update_params& params, const optional_params&, bool is_second) const {
CommonDispatchData dispatchData;
const auto& output = params.output;
const size_t indices_size = params.inputs[1].LogicalSize();
switch (params.inputs[0].GetLayout()) {
case DataLayout::bfyx:
dispatchData.gws = {output.X().v, output.Y().v, output.Feature().v * output.Batch().v};
if (is_second) {
if (params.axis == ScatterUpdateAxis::BATCH)
dispatchData.gws[2] = indices_size * output.Feature().v;
else if (params.axis == ScatterUpdateAxis::FEATURE)
dispatchData.gws[2] = indices_size * output.Batch().v;
else if (params.axis == ScatterUpdateAxis::Y)
dispatchData.gws[1] = indices_size;
else
dispatchData.gws[0] = indices_size;
if (!is_second) {
switch (output.GetLayout()) {
case DataLayout::bfyx:
dispatchData.gws = {output.X().v, output.Y().v, output.Feature().v * output.Batch().v};
break;
case DataLayout::bfzyx:
dispatchData.gws = {output.X().v * output.Y().v, output.Z().v, output.Feature().v * output.Batch().v};
break;
case DataLayout::bfwzyx:
dispatchData.gws = {output.X().v * output.Y().v, output.Z().v * output.W().v, output.Feature().v * output.Batch().v};
break;
default:
throw std::runtime_error("Unsupported combination\n");
break;
}
break;
case DataLayout::bfzyx:
dispatchData.gws = {output.X().v * output.Y().v, output.Z().v, output.Feature().v * output.Batch().v};
if (is_second) {
if (params.axis == ScatterUpdateAxis::BATCH)
dispatchData.gws[2] = indices_size * output.Feature().v;
else if (params.axis == ScatterUpdateAxis::FEATURE)
dispatchData.gws[2] = indices_size * output.Batch().v;
else if (params.axis == ScatterUpdateAxis::Z)
dispatchData.gws[1] = indices_size;
else if (params.axis == ScatterUpdateAxis::Y)
dispatchData.gws[0] = indices_size * output.X().v;
else
dispatchData.gws[0] = indices_size * output.Y().v;
} else {
// second iteration
// Each work item is for each tensor in input2.
// Not using input2's shape info directly, because the input2's shape might be reordered from the reordering pass.
// Instead, we reconsider update2's dimension with input1's shape which is shrinked as 1d.
// e.g., axis = b, input0(10, 9, 10, 9, 10) && input1(4, 2) => input2(8, 9, 10, 9, 10
const size_t indices_size = params.inputs[1].LogicalSize();
switch (output.GetLayout()) {
case DataLayout::bfyx:
if (params.axis == ScatterUpdateAxis::BATCH)
dispatchData.gws = {output.X().v, output.Y().v, output.Feature().v * indices_size};
else if (params.axis == ScatterUpdateAxis::FEATURE)
dispatchData.gws = {output.X().v, output.Y().v, indices_size * output.Batch().v};
else if (params.axis == ScatterUpdateAxis::Y)
dispatchData.gws = {output.X().v, indices_size, output.Feature().v * output.Batch().v};
else if (params.axis == ScatterUpdateAxis::X)
dispatchData.gws = {indices_size, output.Y().v, output.Feature().v * output.Batch().v};
break;
case DataLayout::bfzyx:
if (params.axis == ScatterUpdateAxis::BATCH)
dispatchData.gws = {output.X().v * output.Y().v, output.Z().v, output.Feature().v * indices_size};
else if (params.axis == ScatterUpdateAxis::FEATURE)
dispatchData.gws = {output.X().v * output.Y().v, output.Z().v, indices_size * output.Batch().v};
else if (params.axis == ScatterUpdateAxis::Z)
dispatchData.gws = {output.X().v * output.Y().v, indices_size, output.Feature().v * output.Batch().v};
else if (params.axis == ScatterUpdateAxis::Y)
dispatchData.gws = {output.X().v * indices_size, output.Z().v, output.Feature().v * output.Batch().v};
else if (params.axis == ScatterUpdateAxis::X)
dispatchData.gws = {indices_size * output.Y().v, output.Z().v, output.Feature().v * output.Batch().v};
break;
case DataLayout::bfwzyx:
if (params.axis == ScatterUpdateAxis::BATCH)
dispatchData.gws = {output.X().v * output.Y().v, output.Z().v * output.W().v, output.Feature().v * indices_size};
else if (params.axis == ScatterUpdateAxis::FEATURE)
dispatchData.gws = {output.X().v * output.Y().v, output.Z().v * output.W().v, indices_size * output.Batch().v};
else if (params.axis == ScatterUpdateAxis::W)
dispatchData.gws = {output.X().v * output.Y().v, output.Z().v * indices_size, output.Feature().v * output.Batch().v};
else if (params.axis == ScatterUpdateAxis::Z)
dispatchData.gws = {output.X().v * output.Y().v, indices_size * output.W().v, output.Feature().v * output.Batch().v};
else if (params.axis == ScatterUpdateAxis::Y)
dispatchData.gws = {output.X().v * indices_size, output.Z().v * output.W().v, output.Feature().v * output.Batch().v};
else if (params.axis == ScatterUpdateAxis::X)
dispatchData.gws = {indices_size * output.Y().v, output.Z().v * output.W().v, output.Feature().v * output.Batch().v};
break;
default:
throw std::runtime_error("Unsupported combination\n");
break;
}
break;
case DataLayout::bfwzyx:
dispatchData.gws = {output.X().v * output.Y().v, output.Z().v * output.W().v, output.Feature().v * output.Batch().v};
if (is_second) {
if (params.axis == ScatterUpdateAxis::BATCH)
dispatchData.gws[2] = indices_size * output.Feature().v;
else if (params.axis == ScatterUpdateAxis::FEATURE)
dispatchData.gws[2] = indices_size * output.Batch().v;
else if (params.axis == ScatterUpdateAxis::Z)
dispatchData.gws[1] = indices_size * output.W().v;
else if (params.axis == ScatterUpdateAxis::W)
dispatchData.gws[1] = indices_size * output.Z().v;
else if (params.axis == ScatterUpdateAxis::Y)
dispatchData.gws[0] = indices_size * output.X().v;
else
dispatchData.gws[0] = indices_size * output.Y().v;
}
break;
default: break;
}
dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, params.engineInfo);
return dispatchData;
@ -208,24 +183,51 @@ static std::string GetOutputIndexOnAxis(const scatter_update_params& params, siz
}
static std::vector<std::string> GetVectorSecondOutputIndexOrder(const scatter_update_params& params, size_t axis) {
std::vector<std::string> default_order = GetDefaultOrder(params.output.GetDims().size());
default_order[axis] = "convert_int(indices[OUTPUT_INDEX_ON_AXIS])";
return default_order;
auto output_order = GetDefaultOrder(params.output.GetDims().size());
output_order[axis] = "convert_int(indices[OUTPUT_INDEX_ON_AXIS])";
return output_order;
}
static std::string GetSecondIterOutputIndexOrder(const scatter_update_params& params, size_t axis) {
std::vector<std::string> default_order = GetDefaultOrder(params.output.GetDims().size());
default_order[axis] = "convert_int(indices[OUTPUT_INDEX_ON_AXIS])";
return GetOrderString(default_order);
auto output_order = GetVectorSecondOutputIndexOrder(params, axis);
return GetOrderString(output_order);
}
JitConstants ScatterUpdateKernelRef::GetJitConstants(const scatter_update_params& params) const {
size_t axis_value = GetScatterUpdateChannelIndex(params);
JitConstants jit = MakeBaseParamsJitConstants(params);
jit.AddConstant(MakeJitConstant("UPDATES_INDEX_ORDER", GetUpdatesIndexOrder(params, GetScatterUpdateChannelIndex(params))));
jit.AddConstant(MakeJitConstant("UPDATES_INDEX_ORDER", GetUpdatesIndexOrder(params)));
jit.AddConstant(MakeJitConstant("SECOND_ITER_OUTPUT_INDEX_ORDER", GetSecondIterOutputIndexOrder(params, GetScatterUpdateChannelIndex(params))));
jit.AddConstant(MakeJitConstant("OUTPUT_INDEX_ON_AXIS", GetOutputIndexOnAxis(params, GetScatterUpdateChannelIndex(params))));
jit.AddConstant(MakeJitConstant("AXIS_VALUE", GetScatterUpdateChannelIndex(params)));
jit.AddConstant(MakeJitConstant("AXIS_VALUE", axis_value));
jit.AddConstant(MakeJitConstant("INDICES_SIZE", params.inputs[1].LogicalSize()));
auto default_order = GetDefaultOrder(params.output.GetDims().size());
size_t dims = default_order.size();
std::string get_update_idx = "(INPUT2_OFFSET)";
std::string output_size_feature = "OUTPUT_FEATURE_NUM";
for (size_t i = 0; i < dims; ++i) {
if (i >= axis_value) {
std::string def_pitch = "UPDATES_" + GetAxisName(dims, i) + "_PITCH";
std::string src_pitch = "(OUTPUT_" + GetAxisName(dims, i) + "_PITCH)";
jit.AddConstant(MakeJitConstant(def_pitch, src_pitch));
} else if (i == (axis_value - 1)) {
std::string def_pitch = "UPDATES_" + GetAxisName(dims, i) + "_PITCH";
std::string src_pitch = "(OUTPUT_" + GetAxisName(dims, i + 1) + "_PITCH * INDICES_SIZE)";
jit.AddConstant(MakeJitConstant(def_pitch, src_pitch));
} else { // i < axis_value - 1
std::string def_pitch = "UPDATES_" + GetAxisName(dims, i) + "_PITCH" + "";
std::string output_size_name;
if (i == 0) output_size_name = "OUTPUT_FEATURE_NUM";
else output_size_name = "OUTPUT_SIZE_" + GetAxisName(dims, i + 1);
std::string src_pitch = "(UPDATES_" + GetAxisName(dims, i + 1) + "_PITCH * " + output_size_name + ")";
jit.AddConstant(MakeJitConstant(def_pitch, src_pitch));
}
get_update_idx = get_update_idx + " + (" + default_order[i] + ")*(UPDATES_" + GetAxisName(dims, i) + "_PITCH)";
}
jit.AddConstant(MakeJitConstant("GET_UPDATES_INDEX(idx_order)", get_update_idx));
if (!params.fused_ops.empty()) {
FusedOpsConfiguration conf1 = { "_FIRST_KERNEL", GetDefaultOrder(params.output.GetDims().size()), "val", params.inputs[0].GetDType() };
@ -248,6 +250,10 @@ bool ScatterUpdateKernelRef::Validate(const Params& p, const optional_params& o)
return false;
}
if (params.output.PitchesDifferFromLogicalDims() || params.inputs[2].PitchesDifferFromLogicalDims()) {
return false;
}
return true;
}

View File

@ -47,7 +47,8 @@ public:
KernelsPriority GetKernelsPriority(const Params& params, const optional_params& options) const override;
ParamsKey GetSupportedKey() const override;
std::vector<FusedOpType> GetSupportedFusedOps() const override {
return { FusedOpType::QUANTIZE,
return { FusedOpType::ELTWISE,
FusedOpType::QUANTIZE,
FusedOpType::SCALE,
FusedOpType::ACTIVATION };
}

View File

@ -1,5 +1,5 @@
/*
// Copyright (c) 2016 Intel Corporation
// Copyright (c) 2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -125,6 +125,32 @@ inline uint8 FUNC(reshape_6_to_4)(uint o, uint i, uint w, uint z, uint y, uint x
return (uint8)(0, dst_b, dst_f, 0, 0, dst_y, dst_x, 0);
}
inline uint8 FUNC(reshape_6_to_5)(uint o, uint i, uint w, uint z, uint y, uint x,
uint src_size_f, uint src_size_w, uint src_size_z, uint src_size_y, uint src_size_x,
uint dst_size_f, uint dst_size_z, uint dst_size_y, uint dst_size_x)
{
const uint src_pitch_x = 1;
const uint src_pitch_y = src_pitch_x * src_size_x;
const uint src_pitch_z = src_pitch_y * src_size_y;
const uint src_pitch_w = src_pitch_z * src_size_z;
const uint src_pitch_f = src_pitch_w * src_size_w;
const uint src_pitch_b = src_pitch_f * src_size_f;
uint flat_idx = x * src_pitch_x + y * src_pitch_y + z * src_pitch_z + w * src_pitch_w + i * src_pitch_f + o * src_pitch_b;
uint dst_x = flat_idx % dst_size_x;
flat_idx /= dst_size_x;
uint dst_y = flat_idx % dst_size_y;
flat_idx /= dst_size_y;
uint dst_z = flat_idx % dst_size_z;
flat_idx /= dst_size_z;
uint dst_f = flat_idx % dst_size_f;
flat_idx /= dst_size_f;
uint dst_b = flat_idx;
return (uint8)(0, dst_b, dst_f, 0, dst_z, dst_y, dst_x, 0);
}
inline uint8 FUNC(reshape_grouped)(uint g, uint o, uint i, uint z, uint y, uint x, uint src_size_ofm, uint dst_size_ofm)
{
const uint flat_ofm = g * src_size_ofm + o;
@ -167,6 +193,10 @@ inline uint8 FUNC(reshape_dims)(
{
return FUNC_CALL(reshape_5_to_4)(o, i, z, y, x, src_size_f, src_size_z, src_size_y, src_size_x, dst_size_f, dst_size_y, dst_size_x);
}
else if (src_dims == 6 && dst_dims == 5)
{
return FUNC_CALL(reshape_6_to_5)(o, i, w, z, y, x, src_size_f, src_size_w, src_size_z, src_size_y, src_size_x, dst_size_f, dst_size_z, dst_size_y, dst_size_x);
}
return (uint8)(0, o, i, w, z, y, x, 0);
}

View File

@ -15,8 +15,15 @@
#include "include/include_all.cl"
#define GET_UPDATES_INDEX(prefix, idx_order) CAT(prefix, _GET_INDEX)(idx_order)
#define AXIS_B (0)
#define AXIS_F (1)
#define AXIS_W (2)
#define AXIS_Z (OUTPUT_DIMS - 3)
#define AXIS_Y (OUTPUT_DIMS - 2)
#define AXIS_X (OUTPUT_DIMS - 1)
#define GET_OUTPUT_INDEX(idx_order) OUTPUT_GET_INDEX(idx_order)
#if OUTPUT_DIMS == 4
#define ORDER b,f,y,x
#elif OUTPUT_DIMS == 5
@ -37,7 +44,6 @@ KERNEL(scatter_update_ref)(const __global INPUT0_TYPE* dictionary,
const uint dim0 = get_global_id(0);
const uint dim1 = get_global_id(1);
const uint dim2 = get_global_id(2);
#ifndef IS_SECOND_ITER // First kernel
#if OUTPUT_DIMS == 4
const uint x = dim0;
@ -58,8 +64,9 @@ KERNEL(scatter_update_ref)(const __global INPUT0_TYPE* dictionary,
const uint f = dim2 % OUTPUT_FEATURE_NUM;
const uint b = dim2 / OUTPUT_FEATURE_NUM;
#endif
const uint output_idx = GET_OUTPUT_INDEX(ORDER);
INPUT0_TYPE val = dictionary[output_idx];
#if HAS_FUSED_OPS
FUSED_OPS_FIRST_KERNEL;
@ -69,70 +76,64 @@ KERNEL(scatter_update_ref)(const __global INPUT0_TYPE* dictionary,
#endif
#else // Second kernel
#if OUTPUT_DIMS == 4
const uint x = dim0;
#if (OUTPUT_DIMS == 4)
// bf|y|x
#if (AXIS_VALUE == AXIS_F)
const uint b = dim2 / INDICES_SIZE;
const uint f = dim2 % INDICES_SIZE;
#else
const uint b = dim2 / OUTPUT_FEATURE_NUM;
const uint f = dim2 % OUTPUT_FEATURE_NUM;
#endif
const uint y = dim1;
#if AXIS_VALUE == 0
const uint f = dim2 % OUTPUT_FEATURE_NUM;
const uint b = dim2 / OUTPUT_FEATURE_NUM;
const uint x = dim0;
#elif (OUTPUT_DIMS == 5)
// bf|z|yx
#if (AXIS_VALUE == AXIS_F)
const uint b = dim2 / INDICES_SIZE;
const uint f = dim2 % INDICES_SIZE;
#else
const uint f = dim2 / OUTPUT_BATCH_NUM;
const uint b = dim2 % OUTPUT_BATCH_NUM;
const uint b = dim2 / OUTPUT_FEATURE_NUM;
const uint f = dim2 % OUTPUT_FEATURE_NUM;
#endif
#elif OUTPUT_DIMS == 5
const uint z = dim1;
#if AXIS_VALUE == 1
const uint f = dim2 / OUTPUT_BATCH_NUM;
const uint b = dim2 % OUTPUT_BATCH_NUM;
const uint x = dim0 % OUTPUT_SIZE_X;
const uint y = dim0 / OUTPUT_SIZE_X;
#elif AXIS_VALUE == 4
const uint f = dim2 % OUTPUT_FEATURE_NUM;
const uint b = dim2 / OUTPUT_FEATURE_NUM;
const uint x = dim0 / OUTPUT_SIZE_Y;
const uint y = dim0 % OUTPUT_SIZE_Y;
#if (AXIS_VALUE == AXIS_X)
const uint y = dim0 / INDICES_SIZE;
const uint x = dim0 % INDICES_SIZE;
#else
const uint f = dim2 % OUTPUT_FEATURE_NUM;
const uint b = dim2 / OUTPUT_FEATURE_NUM;
const uint x = dim0 % OUTPUT_SIZE_X;
const uint y = dim0 / OUTPUT_SIZE_X;
const uint x = dim0 % OUTPUT_SIZE_X;
#endif
#elif OUTPUT_DIMS == 6
#if AXIS_VALUE == 1
const uint f = dim2 / OUTPUT_BATCH_NUM;
const uint b = dim2 % OUTPUT_BATCH_NUM;
const uint x = dim0 % OUTPUT_SIZE_X;
const uint y = dim0 / OUTPUT_SIZE_X;
const uint z = dim1 % OUTPUT_SIZE_Z;
const uint w = dim1 / OUTPUT_SIZE_Z;
#elif AXIS_VALUE == 3
const uint f = dim2 % OUTPUT_FEATURE_NUM;
const uint b = dim2 / OUTPUT_FEATURE_NUM;
const uint x = dim0 % OUTPUT_SIZE_X;
const uint y = dim0 / OUTPUT_SIZE_X;
const uint z = dim1 / OUTPUT_SIZE_W;
const uint w = dim1 % OUTPUT_SIZE_W;
#elif AXIS_VALUE == 5
const uint f = dim2 % OUTPUT_FEATURE_NUM;
const uint b = dim2 / OUTPUT_FEATURE_NUM;
const uint x = dim0 / OUTPUT_SIZE_Y;
const uint y = dim0 % OUTPUT_SIZE_Y;
const uint z = dim1 % OUTPUT_SIZE_Z;
const uint w = dim1 / OUTPUT_SIZE_Z;
#elif (OUTPUT_DIMS == 6)
// bf|wz|yx
#if (AXIS_VALUE == AXIS_F)
const uint b = dim2 / INDICES_SIZE;
const uint f = dim2 % INDICES_SIZE;
#else
const uint f = dim2 % OUTPUT_FEATURE_NUM;
const uint b = dim2 / OUTPUT_FEATURE_NUM;
const uint x = dim0 % OUTPUT_SIZE_X;
const uint y = dim0 / OUTPUT_SIZE_X;
const uint z = dim1 % OUTPUT_SIZE_Z;
const uint f = dim2 % OUTPUT_FEATURE_NUM;
#endif
#if (AXIS_VALUE == AXIS_Z)
const uint w = dim1 / INDICES_SIZE;
const uint z = dim1 % INDICES_SIZE;
#else
const uint w = dim1 / OUTPUT_SIZE_Z;
const uint z = dim1 % OUTPUT_SIZE_Z;
#endif
#if (AXIS_VALUE == AXIS_X)
const uint y = dim0 / INDICES_SIZE;
const uint x = dim0 % INDICES_SIZE;
#else
const uint y = dim0 / OUTPUT_SIZE_X;
const uint x = dim0 % OUTPUT_SIZE_X;
#endif
#endif
const uint output_idx = GET_OUTPUT_INDEX(SECOND_ITER_OUTPUT_INDEX_ORDER);
const uint updates_idx = GET_UPDATES_INDEX(INPUT2, UPDATES_INDEX_ORDER);
const uint updates_idx = GET_UPDATES_INDEX(UPDATES_INDEX_ORDER);
INPUT2_TYPE val = updates[updates_idx];
#if HAS_FUSED_OPS
FUSED_OPS_SECOND_KERNEL;
output[output_idx] = TO_OUTPUT_TYPE(FUSED_OPS_RESULT_SECOND_KERNEL);
@ -142,5 +143,10 @@ KERNEL(scatter_update_ref)(const __global INPUT0_TYPE* dictionary,
#endif
}
#undef GET_UPDATES_INDEX
#undef GET_OUTPUT_INDEX
#undef AXIS_B
#undef AXIS_F
#undef AXIS_W
#undef AXIS_Z
#undef AXIS_Y
#undef AXIS_X

View File

@ -5743,6 +5743,52 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, scatter_update_scale_activation,
scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP16_5, 2, 4 },
}), );
class scatter_update_scale_activation_eltwise : public ScatterUpdatePrimitiveFusingTest {};
TEST_P(scatter_update_scale_activation_eltwise, basic) {
auto p = GetParam();
create_topologies(input_layout("input", get_input_layout(p)),
data("scatter_update_indices", get_repeatless_mem(get_indices_layout(p), 0, static_cast<int>(get_axis_dim(p)) - 1)),
data("scatter_update_updates", get_mem(get_updates_layout(p), 0, 1000)),
data("scale_data", get_mem(get_per_channel_layout(p), -10, 10)),
data("eltw_data", get_mem(layout(p.default_type, p.default_format, p.dictionary_shape))),
scatter_update("scatter_update_prim", "input", "scatter_update_indices", "scatter_update_updates", p.axis),
activation("activation", "scatter_update_prim", activation_func::abs),
eltwise("eltw", {"activation", "eltw_data"}, eltwise_mode::sum, p.default_type),
scale("scale", "eltw", "scale_data"),
reorder("reorder_bfyx", "scale", p.default_format, data_types::f32)
);
tolerance = 1e-5f;
execute(p);
}
INSTANTIATE_TEST_CASE_P(fusings_gpu, scatter_update_scale_activation_eltwise,
::testing::ValuesIn(std::vector<scatter_update_test_params> {
scatter_update_test_params{ CASE_SCATTER_UPDATE_FP32_1, 3, 5 },
scatter_update_test_params{ CASE_SCATTER_UPDATE_FP32_2, 3, 5 },
scatter_update_test_params{ CASE_SCATTER_UPDATE_FP32_3, 3, 5 },
scatter_update_test_params{ CASE_SCATTER_UPDATE_FP32_4, 3, 5 },
scatter_update_test_params{ CASE_SCATTER_UPDATE_FP32_5, 3, 5 },
scatter_update_test_params{ CASE_SCATTER_UPDATE_FP16_1, 3, 5 },
scatter_update_test_params{ CASE_SCATTER_UPDATE_FP16_2, 3, 5 },
scatter_update_test_params{ CASE_SCATTER_UPDATE_FP16_3, 3, 5 },
scatter_update_test_params{ CASE_SCATTER_UPDATE_FP16_4, 3, 5 },
scatter_update_test_params{ CASE_SCATTER_UPDATE_FP16_5, 3, 5 },
scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP32_1, 3, 5 },
scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP32_2, 3, 5 },
scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP32_3, 3, 5 },
scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP32_4, 3, 5 },
scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP32_5, 3, 5 },
scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP16_1, 3, 5 },
scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP16_2, 3, 5 },
scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP16_3, 3, 5 },
scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP16_4, 3, 5 },
scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP16_5, 3, 5 },
}), );
/* ------------------------------------------------------------------------------------------------------------ */
/* ---------------------------------------- PERMUTE FUSE cases -------------------------------------------------- */
/* ------------------------------------------------------------------------------------------------------------ */