[IE CLDNN] Fixed scatter update op & reshape kernel (#4106)
This commit is contained in:
parent
9c1651b5ad
commit
38fab0265d
@ -156,10 +156,10 @@ REGISTER_FACTORY(v3, EmbeddingBagOffsetsSum);
|
||||
REGISTER_FACTORY(v3, EmbeddingBagPackedSum);
|
||||
REGISTER_FACTORY(v3, EmbeddingSegmentsSum);
|
||||
REGISTER_FACTORY(v3, ExtractImagePatches);
|
||||
REGISTER_FACTORY(v3, ScatterUpdate);
|
||||
// REGISTER_FACTORY(v3, NonMaxSuppression); Supported via v3 -> v5 internal conversion
|
||||
|
||||
// ----------------------------- Unsupported v3 ops ----------------------------- //
|
||||
// REGISTER_FACTORY(v3, ScatterUpdate); // There is the scatter_update primitive, but seems like it produces wrong results
|
||||
// REGISTER_FACTORY(v3, Assign);
|
||||
// REGISTER_FACTORY(v3, Bucketize);
|
||||
// REGISTER_FACTORY(v3, GRUCell);
|
||||
@ -167,7 +167,6 @@ REGISTER_FACTORY(v3, ExtractImagePatches);
|
||||
// REGISTER_FACTORY(v3, ROIAlign);
|
||||
// REGISTER_FACTORY(v3, ReadValue);
|
||||
// REGISTER_FACTORY(v3, ScatterElementsUpdate);
|
||||
// REGISTER_FACTORY(v3, ScatterUpdate);
|
||||
// REGISTER_FACTORY(v3, ScatterNDUpdate);
|
||||
// REGISTER_FACTORY(v3, ShapeOf);
|
||||
// REGISTER_FACTORY(v3, TopK);
|
||||
|
@ -25,8 +25,9 @@ const std::vector<InferenceEngine::Precision> idxPrecisions = {
|
||||
|
||||
// map<inputShape, map<indicesShape, axis>>
|
||||
// ScatterUpdate test configurations: every input (data) shape is mapped to the indices
// shapes exercised with it, and every indices shape to the list of axes to test.
// Each input shape must be listed exactly once: std::map stores a single entry per key,
// so a second row with the same input shape would be dropped silently and never tested.
std::map<std::vector<size_t>, std::map<std::vector<size_t>, std::vector<int>>> axesShapeInShape {
    {{10, 16, 12, 15}, {{{2, 2, 2}, {0, 1, 2, 3}}, {{2, 4}, {0, 1, 2, 3}}, {{8}, {0, 1, 2, 3}}}},
    {{10, 9, 10, 9, 10}, {{{8}, {0, 1, 2, 3, 4}}, {{4, 2}, {0, 1, 2, 3, 4}}}},
    {{10, 9, 10, 9, 10, 12}, {{{8}, {0, 1, 2, 3, 4, 5}}}},
};
|
||||
// Indices must not be random values.
|
||||
const std::vector<std::vector<int64_t>> idxValue = {
|
||||
|
@ -48,7 +48,6 @@ std::vector<std::string> disabledTestPatterns() {
|
||||
R"(.*(LSTMSequence).*mode=CONVERT_TO_TI_RAND_SEQ_LEN.*)",
|
||||
R"(.*(smoke_DetectionOutput3In).*)",
|
||||
R"(.*(smoke_DetectionOutput5In).*)",
|
||||
R"(.*(ScatterUpdateLayerTest).*)",
|
||||
|
||||
// INT8 StridedSlice not supported
|
||||
R"(.*(LPT/StridedSliceTransformation).*)",
|
||||
|
@ -67,22 +67,6 @@ ParamsKey ScatterUpdateKernelRef::GetSupportedKey() const {
|
||||
return k;
|
||||
}
|
||||
|
||||
static size_t GetNonEmptyDimsNumber(const DataTensor& data_tensor) {
|
||||
if (data_tensor.LogicalSize() != 1) {
|
||||
// Count the number of "one size" dimensions starting with X to Batch
|
||||
size_t one_size_dims = 0;
|
||||
for (auto& i : data_tensor.GetDims()) {
|
||||
if (i.v == 1)
|
||||
one_size_dims++;
|
||||
else
|
||||
break;
|
||||
}
|
||||
return data_tensor.Dimentions() - one_size_dims;
|
||||
} else {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
static inline std::string GetOrderString(std::vector<std::string>& order) {
|
||||
std::string order_str = order[0];
|
||||
for (size_t i = 1; i < order.size(); i++)
|
||||
@ -104,99 +88,90 @@ static inline std::vector<std::string> GetDefaultOrder(size_t size) {
|
||||
return default_order;
|
||||
}
|
||||
|
||||
static std::string GetUpdatesIndexOrder(const scatter_update_params& params, size_t axis) {
|
||||
std::vector<std::string> default_order = GetDefaultOrder(params.output.GetDims().size());
|
||||
|
||||
for (unsigned int i = 0; i < params.inputs[2].GetDims().size() - params.output.GetDims().size(); i++)
|
||||
default_order.push_back("0");
|
||||
|
||||
size_t indices_non_empty_dims = GetNonEmptyDimsNumber(params.inputs[1]);
|
||||
std::string FYX_indices_size = "(INPUT1_FEATURE_NUM * INPUT1_SIZE_Y * INPUT1_SIZE_X)";
|
||||
std::string YX_indices_size = "(INPUT1_SIZE_Y * INPUT1_SIZE_X)";
|
||||
std::string X_indices_size = "(INPUT1_SIZE_X)";
|
||||
|
||||
// Shift indices of ScatterUpdate updates input related to Indices dims
|
||||
for (size_t i = default_order.size() - 1; i > (axis + indices_non_empty_dims - 1); i--)
|
||||
default_order[i] = default_order[i - indices_non_empty_dims + 1];
|
||||
|
||||
// Insert Indices indexes in axis dimention in the Update index order
|
||||
for (size_t i = axis; i < (axis + indices_non_empty_dims) && i < default_order.size(); i++) {
|
||||
switch(i - axis) {
|
||||
case 0:
|
||||
default_order[i] = "(OUTPUT_INDEX_ON_AXIS /" + FYX_indices_size + ")";
|
||||
break;
|
||||
case 1:
|
||||
default_order[i] = "((OUTPUT_INDEX_ON_AXIS %" + FYX_indices_size + ")/" + YX_indices_size + ")";
|
||||
break;
|
||||
case 2:
|
||||
default_order[i] = "(((OUTPUT_INDEX_ON_AXIS %" + FYX_indices_size + ")%" + YX_indices_size + ")/" + X_indices_size + ")";
|
||||
break;
|
||||
case 3:
|
||||
default_order[i] = "(((OUTPUT_INDEX_ON_AXIS %" + FYX_indices_size + ")%" + YX_indices_size + ")%" + X_indices_size + ")";
|
||||
break;
|
||||
}
|
||||
// Returns the upper-case name of dimension `axis` for a tensor of rank `size`:
//   size <= 4 : BATCH, FEATURE, Y, X
//   size == 5 : BATCH, FEATURE, Z, Y, X
//   size == 6 : BATCH, FEATURE, W, Z, Y, X
// The name is meant to be spliced into macro identifiers such as "OUTPUT_<name>_PITCH".
//
// Throws std::runtime_error when `size` is greater than 6 or `axis` does not address an
// entry of the selected table; indexing the table unchecked would be undefined behaviour.
static inline std::string GetAxisName(size_t size, size_t axis) {
    std::vector<std::string> axis_names;
    if (size <= 4) {
        axis_names = {"BATCH", "FEATURE", "Y", "X"};
    } else if (size == 5) {
        axis_names = {"BATCH", "FEATURE", "Z", "Y", "X"};
    } else if (size == 6) {
        axis_names = {"BATCH", "FEATURE", "W", "Z", "Y", "X"};
    }

    // axis_names is empty for unsupported ranks, so this single check covers both cases.
    if (axis >= axis_names.size()) {
        throw std::runtime_error("Unsupported combination\n");
    }
    return axis_names[axis];
}
|
||||
|
||||
static std::string GetUpdatesIndexOrder(const scatter_update_params& params) {
|
||||
std::vector<std::string> default_order = GetDefaultOrder(params.output.GetDims().size());
|
||||
return GetOrderString(default_order);
|
||||
}
|
||||
|
||||
CommonDispatchData ScatterUpdateKernelRef::SetDefault(const scatter_update_params& params, const optional_params&, bool is_second) const {
|
||||
CommonDispatchData dispatchData;
|
||||
const auto& output = params.output;
|
||||
|
||||
const size_t indices_size = params.inputs[1].LogicalSize();
|
||||
|
||||
switch (params.inputs[0].GetLayout()) {
|
||||
case DataLayout::bfyx:
|
||||
dispatchData.gws = {output.X().v, output.Y().v, output.Feature().v * output.Batch().v};
|
||||
if (is_second) {
|
||||
if (params.axis == ScatterUpdateAxis::BATCH)
|
||||
dispatchData.gws[2] = indices_size * output.Feature().v;
|
||||
else if (params.axis == ScatterUpdateAxis::FEATURE)
|
||||
dispatchData.gws[2] = indices_size * output.Batch().v;
|
||||
else if (params.axis == ScatterUpdateAxis::Y)
|
||||
dispatchData.gws[1] = indices_size;
|
||||
else
|
||||
dispatchData.gws[0] = indices_size;
|
||||
if (!is_second) {
|
||||
switch (output.GetLayout()) {
|
||||
case DataLayout::bfyx:
|
||||
dispatchData.gws = {output.X().v, output.Y().v, output.Feature().v * output.Batch().v};
|
||||
break;
|
||||
case DataLayout::bfzyx:
|
||||
dispatchData.gws = {output.X().v * output.Y().v, output.Z().v, output.Feature().v * output.Batch().v};
|
||||
break;
|
||||
case DataLayout::bfwzyx:
|
||||
dispatchData.gws = {output.X().v * output.Y().v, output.Z().v * output.W().v, output.Feature().v * output.Batch().v};
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("Unsupported combination\n");
|
||||
break;
|
||||
}
|
||||
break;
|
||||
|
||||
case DataLayout::bfzyx:
|
||||
dispatchData.gws = {output.X().v * output.Y().v, output.Z().v, output.Feature().v * output.Batch().v};
|
||||
if (is_second) {
|
||||
if (params.axis == ScatterUpdateAxis::BATCH)
|
||||
dispatchData.gws[2] = indices_size * output.Feature().v;
|
||||
else if (params.axis == ScatterUpdateAxis::FEATURE)
|
||||
dispatchData.gws[2] = indices_size * output.Batch().v;
|
||||
else if (params.axis == ScatterUpdateAxis::Z)
|
||||
dispatchData.gws[1] = indices_size;
|
||||
else if (params.axis == ScatterUpdateAxis::Y)
|
||||
dispatchData.gws[0] = indices_size * output.X().v;
|
||||
else
|
||||
dispatchData.gws[0] = indices_size * output.Y().v;
|
||||
} else {
|
||||
// second iteration
|
||||
// Each work item is for each tensor in input2.
|
||||
// Not using input2's shape info directly, because the input2's shape might be reordered from the reordering pass.
|
||||
// Instead, we reconstruct input2's dimensions from input1's shape, which is flattened to 1D.
|
||||
// e.g., axis = b, input0(10, 9, 10, 9, 10) && input1(4, 2) => input2(8, 9, 10, 9, 10)
|
||||
const size_t indices_size = params.inputs[1].LogicalSize();
|
||||
switch (output.GetLayout()) {
|
||||
case DataLayout::bfyx:
|
||||
if (params.axis == ScatterUpdateAxis::BATCH)
|
||||
dispatchData.gws = {output.X().v, output.Y().v, output.Feature().v * indices_size};
|
||||
else if (params.axis == ScatterUpdateAxis::FEATURE)
|
||||
dispatchData.gws = {output.X().v, output.Y().v, indices_size * output.Batch().v};
|
||||
else if (params.axis == ScatterUpdateAxis::Y)
|
||||
dispatchData.gws = {output.X().v, indices_size, output.Feature().v * output.Batch().v};
|
||||
else if (params.axis == ScatterUpdateAxis::X)
|
||||
dispatchData.gws = {indices_size, output.Y().v, output.Feature().v * output.Batch().v};
|
||||
break;
|
||||
case DataLayout::bfzyx:
|
||||
if (params.axis == ScatterUpdateAxis::BATCH)
|
||||
dispatchData.gws = {output.X().v * output.Y().v, output.Z().v, output.Feature().v * indices_size};
|
||||
else if (params.axis == ScatterUpdateAxis::FEATURE)
|
||||
dispatchData.gws = {output.X().v * output.Y().v, output.Z().v, indices_size * output.Batch().v};
|
||||
else if (params.axis == ScatterUpdateAxis::Z)
|
||||
dispatchData.gws = {output.X().v * output.Y().v, indices_size, output.Feature().v * output.Batch().v};
|
||||
else if (params.axis == ScatterUpdateAxis::Y)
|
||||
dispatchData.gws = {output.X().v * indices_size, output.Z().v, output.Feature().v * output.Batch().v};
|
||||
else if (params.axis == ScatterUpdateAxis::X)
|
||||
dispatchData.gws = {indices_size * output.Y().v, output.Z().v, output.Feature().v * output.Batch().v};
|
||||
break;
|
||||
case DataLayout::bfwzyx:
|
||||
if (params.axis == ScatterUpdateAxis::BATCH)
|
||||
dispatchData.gws = {output.X().v * output.Y().v, output.Z().v * output.W().v, output.Feature().v * indices_size};
|
||||
else if (params.axis == ScatterUpdateAxis::FEATURE)
|
||||
dispatchData.gws = {output.X().v * output.Y().v, output.Z().v * output.W().v, indices_size * output.Batch().v};
|
||||
else if (params.axis == ScatterUpdateAxis::W)
|
||||
dispatchData.gws = {output.X().v * output.Y().v, output.Z().v * indices_size, output.Feature().v * output.Batch().v};
|
||||
else if (params.axis == ScatterUpdateAxis::Z)
|
||||
dispatchData.gws = {output.X().v * output.Y().v, indices_size * output.W().v, output.Feature().v * output.Batch().v};
|
||||
else if (params.axis == ScatterUpdateAxis::Y)
|
||||
dispatchData.gws = {output.X().v * indices_size, output.Z().v * output.W().v, output.Feature().v * output.Batch().v};
|
||||
else if (params.axis == ScatterUpdateAxis::X)
|
||||
dispatchData.gws = {indices_size * output.Y().v, output.Z().v * output.W().v, output.Feature().v * output.Batch().v};
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("Unsupported combination\n");
|
||||
break;
|
||||
}
|
||||
break;
|
||||
|
||||
case DataLayout::bfwzyx:
|
||||
dispatchData.gws = {output.X().v * output.Y().v, output.Z().v * output.W().v, output.Feature().v * output.Batch().v};
|
||||
if (is_second) {
|
||||
if (params.axis == ScatterUpdateAxis::BATCH)
|
||||
dispatchData.gws[2] = indices_size * output.Feature().v;
|
||||
else if (params.axis == ScatterUpdateAxis::FEATURE)
|
||||
dispatchData.gws[2] = indices_size * output.Batch().v;
|
||||
else if (params.axis == ScatterUpdateAxis::Z)
|
||||
dispatchData.gws[1] = indices_size * output.W().v;
|
||||
else if (params.axis == ScatterUpdateAxis::W)
|
||||
dispatchData.gws[1] = indices_size * output.Z().v;
|
||||
else if (params.axis == ScatterUpdateAxis::Y)
|
||||
dispatchData.gws[0] = indices_size * output.X().v;
|
||||
else
|
||||
dispatchData.gws[0] = indices_size * output.Y().v;
|
||||
}
|
||||
break;
|
||||
default: break;
|
||||
}
|
||||
|
||||
dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, params.engineInfo);
|
||||
|
||||
return dispatchData;
|
||||
@ -208,24 +183,51 @@ static std::string GetOutputIndexOnAxis(const scatter_update_params& params, siz
|
||||
}
|
||||
|
||||
static std::vector<std::string> GetVectorSecondOutputIndexOrder(const scatter_update_params& params, size_t axis) {
|
||||
std::vector<std::string> default_order = GetDefaultOrder(params.output.GetDims().size());
|
||||
default_order[axis] = "convert_int(indices[OUTPUT_INDEX_ON_AXIS])";
|
||||
return default_order;
|
||||
auto output_order = GetDefaultOrder(params.output.GetDims().size());
|
||||
output_order[axis] = "convert_int(indices[OUTPUT_INDEX_ON_AXIS])";
|
||||
return output_order;
|
||||
}
|
||||
|
||||
static std::string GetSecondIterOutputIndexOrder(const scatter_update_params& params, size_t axis) {
|
||||
std::vector<std::string> default_order = GetDefaultOrder(params.output.GetDims().size());
|
||||
default_order[axis] = "convert_int(indices[OUTPUT_INDEX_ON_AXIS])";
|
||||
return GetOrderString(default_order);
|
||||
auto output_order = GetVectorSecondOutputIndexOrder(params, axis);
|
||||
return GetOrderString(output_order);
|
||||
}
|
||||
|
||||
JitConstants ScatterUpdateKernelRef::GetJitConstants(const scatter_update_params& params) const {
|
||||
size_t axis_value = GetScatterUpdateChannelIndex(params);
|
||||
|
||||
JitConstants jit = MakeBaseParamsJitConstants(params);
|
||||
|
||||
jit.AddConstant(MakeJitConstant("UPDATES_INDEX_ORDER", GetUpdatesIndexOrder(params, GetScatterUpdateChannelIndex(params))));
|
||||
jit.AddConstant(MakeJitConstant("UPDATES_INDEX_ORDER", GetUpdatesIndexOrder(params)));
|
||||
jit.AddConstant(MakeJitConstant("SECOND_ITER_OUTPUT_INDEX_ORDER", GetSecondIterOutputIndexOrder(params, GetScatterUpdateChannelIndex(params))));
|
||||
jit.AddConstant(MakeJitConstant("OUTPUT_INDEX_ON_AXIS", GetOutputIndexOnAxis(params, GetScatterUpdateChannelIndex(params))));
|
||||
jit.AddConstant(MakeJitConstant("AXIS_VALUE", GetScatterUpdateChannelIndex(params)));
|
||||
jit.AddConstant(MakeJitConstant("AXIS_VALUE", axis_value));
|
||||
jit.AddConstant(MakeJitConstant("INDICES_SIZE", params.inputs[1].LogicalSize()));
|
||||
|
||||
auto default_order = GetDefaultOrder(params.output.GetDims().size());
|
||||
size_t dims = default_order.size();
|
||||
std::string get_update_idx = "(INPUT2_OFFSET)";
|
||||
std::string output_size_feature = "OUTPUT_FEATURE_NUM";
|
||||
for (size_t i = 0; i < dims; ++i) {
|
||||
if (i >= axis_value) {
|
||||
std::string def_pitch = "UPDATES_" + GetAxisName(dims, i) + "_PITCH";
|
||||
std::string src_pitch = "(OUTPUT_" + GetAxisName(dims, i) + "_PITCH)";
|
||||
jit.AddConstant(MakeJitConstant(def_pitch, src_pitch));
|
||||
} else if (i == (axis_value - 1)) {
|
||||
std::string def_pitch = "UPDATES_" + GetAxisName(dims, i) + "_PITCH";
|
||||
std::string src_pitch = "(OUTPUT_" + GetAxisName(dims, i + 1) + "_PITCH * INDICES_SIZE)";
|
||||
jit.AddConstant(MakeJitConstant(def_pitch, src_pitch));
|
||||
} else { // i < axis_value - 1
|
||||
std::string def_pitch = "UPDATES_" + GetAxisName(dims, i) + "_PITCH" + "";
|
||||
std::string output_size_name;
|
||||
if (i == 0) output_size_name = "OUTPUT_FEATURE_NUM";
|
||||
else output_size_name = "OUTPUT_SIZE_" + GetAxisName(dims, i + 1);
|
||||
std::string src_pitch = "(UPDATES_" + GetAxisName(dims, i + 1) + "_PITCH * " + output_size_name + ")";
|
||||
jit.AddConstant(MakeJitConstant(def_pitch, src_pitch));
|
||||
}
|
||||
get_update_idx = get_update_idx + " + (" + default_order[i] + ")*(UPDATES_" + GetAxisName(dims, i) + "_PITCH)";
|
||||
}
|
||||
jit.AddConstant(MakeJitConstant("GET_UPDATES_INDEX(idx_order)", get_update_idx));
|
||||
|
||||
if (!params.fused_ops.empty()) {
|
||||
FusedOpsConfiguration conf1 = { "_FIRST_KERNEL", GetDefaultOrder(params.output.GetDims().size()), "val", params.inputs[0].GetDType() };
|
||||
@ -248,6 +250,10 @@ bool ScatterUpdateKernelRef::Validate(const Params& p, const optional_params& o)
|
||||
return false;
|
||||
}
|
||||
|
||||
if (params.output.PitchesDifferFromLogicalDims() || params.inputs[2].PitchesDifferFromLogicalDims()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -47,7 +47,8 @@ public:
|
||||
KernelsPriority GetKernelsPriority(const Params& params, const optional_params& options) const override;
|
||||
ParamsKey GetSupportedKey() const override;
|
||||
std::vector<FusedOpType> GetSupportedFusedOps() const override {
|
||||
return { FusedOpType::QUANTIZE,
|
||||
return { FusedOpType::ELTWISE,
|
||||
FusedOpType::QUANTIZE,
|
||||
FusedOpType::SCALE,
|
||||
FusedOpType::ACTIVATION };
|
||||
}
|
||||
|
@ -1,5 +1,5 @@
|
||||
/*
|
||||
// Copyright (c) 2016 Intel Corporation
|
||||
// Copyright (c) 2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -125,6 +125,32 @@ inline uint8 FUNC(reshape_6_to_4)(uint o, uint i, uint w, uint z, uint y, uint x
|
||||
return (uint8)(0, dst_b, dst_f, 0, 0, dst_y, dst_x, 0);
|
||||
}
|
||||
|
||||
// Converts a 6D source coordinate (o, i, w, z, y, x) into the 5D destination coordinate
// that has the same linear element offset. Offsets are derived from the sizes alone,
// i.e. each pitch is the product of the sizes of all inner dimensions.
// The result is packed as (0, b, f, 0, z, y, x, 0).
inline uint8 FUNC(reshape_6_to_5)(uint o, uint i, uint w, uint z, uint y, uint x,
                                  uint src_size_f, uint src_size_w, uint src_size_z, uint src_size_y, uint src_size_x,
                                  uint dst_size_f, uint dst_size_z, uint dst_size_y, uint dst_size_x)
{
    // Linear offset in the source, written in nested form; this equals
    // x + y*X + z*X*Y + w*X*Y*Z + i*X*Y*Z*W + o*X*Y*Z*W*F for the source sizes.
    uint remaining = ((((o * src_size_f + i) * src_size_w + w) * src_size_z + z) * src_size_y + y) * src_size_x + x;

    // Peel destination coordinates off the offset, innermost dimension first.
    const uint dst_x = remaining % dst_size_x;
    remaining /= dst_size_x;
    const uint dst_y = remaining % dst_size_y;
    remaining /= dst_size_y;
    const uint dst_z = remaining % dst_size_z;
    remaining /= dst_size_z;
    const uint dst_f = remaining % dst_size_f;
    remaining /= dst_size_f;
    const uint dst_b = remaining;

    return (uint8)(0, dst_b, dst_f, 0, dst_z, dst_y, dst_x, 0);
}
|
||||
|
||||
|
||||
inline uint8 FUNC(reshape_grouped)(uint g, uint o, uint i, uint z, uint y, uint x, uint src_size_ofm, uint dst_size_ofm)
|
||||
{
|
||||
const uint flat_ofm = g * src_size_ofm + o;
|
||||
@ -167,6 +193,10 @@ inline uint8 FUNC(reshape_dims)(
|
||||
{
|
||||
return FUNC_CALL(reshape_5_to_4)(o, i, z, y, x, src_size_f, src_size_z, src_size_y, src_size_x, dst_size_f, dst_size_y, dst_size_x);
|
||||
}
|
||||
else if (src_dims == 6 && dst_dims == 5)
|
||||
{
|
||||
return FUNC_CALL(reshape_6_to_5)(o, i, w, z, y, x, src_size_f, src_size_w, src_size_z, src_size_y, src_size_x, dst_size_f, dst_size_z, dst_size_y, dst_size_x);
|
||||
}
|
||||
|
||||
return (uint8)(0, o, i, w, z, y, x, 0);
|
||||
}
|
||||
|
@ -15,8 +15,15 @@
|
||||
|
||||
#include "include/include_all.cl"
|
||||
|
||||
#define GET_UPDATES_INDEX(prefix, idx_order) CAT(prefix, _GET_INDEX)(idx_order)
|
||||
#define AXIS_B (0)
|
||||
#define AXIS_F (1)
|
||||
#define AXIS_W (2)
|
||||
#define AXIS_Z (OUTPUT_DIMS - 3)
|
||||
#define AXIS_Y (OUTPUT_DIMS - 2)
|
||||
#define AXIS_X (OUTPUT_DIMS - 1)
|
||||
|
||||
#define GET_OUTPUT_INDEX(idx_order) OUTPUT_GET_INDEX(idx_order)
|
||||
|
||||
#if OUTPUT_DIMS == 4
|
||||
#define ORDER b,f,y,x
|
||||
#elif OUTPUT_DIMS == 5
|
||||
@ -37,7 +44,6 @@ KERNEL(scatter_update_ref)(const __global INPUT0_TYPE* dictionary,
|
||||
const uint dim0 = get_global_id(0);
|
||||
const uint dim1 = get_global_id(1);
|
||||
const uint dim2 = get_global_id(2);
|
||||
|
||||
#ifndef IS_SECOND_ITER // First kernel
|
||||
#if OUTPUT_DIMS == 4
|
||||
const uint x = dim0;
|
||||
@ -58,8 +64,9 @@ KERNEL(scatter_update_ref)(const __global INPUT0_TYPE* dictionary,
|
||||
const uint f = dim2 % OUTPUT_FEATURE_NUM;
|
||||
const uint b = dim2 / OUTPUT_FEATURE_NUM;
|
||||
#endif
|
||||
|
||||
|
||||
const uint output_idx = GET_OUTPUT_INDEX(ORDER);
|
||||
|
||||
INPUT0_TYPE val = dictionary[output_idx];
|
||||
#if HAS_FUSED_OPS
|
||||
FUSED_OPS_FIRST_KERNEL;
|
||||
@ -69,70 +76,64 @@ KERNEL(scatter_update_ref)(const __global INPUT0_TYPE* dictionary,
|
||||
#endif
|
||||
|
||||
#else // Second kernel
|
||||
#if OUTPUT_DIMS == 4
|
||||
const uint x = dim0;
|
||||
#if (OUTPUT_DIMS == 4)
|
||||
// bf|y|x
|
||||
#if (AXIS_VALUE == AXIS_F)
|
||||
const uint b = dim2 / INDICES_SIZE;
|
||||
const uint f = dim2 % INDICES_SIZE;
|
||||
#else
|
||||
const uint b = dim2 / OUTPUT_FEATURE_NUM;
|
||||
const uint f = dim2 % OUTPUT_FEATURE_NUM;
|
||||
#endif
|
||||
const uint y = dim1;
|
||||
#if AXIS_VALUE == 0
|
||||
const uint f = dim2 % OUTPUT_FEATURE_NUM;
|
||||
const uint b = dim2 / OUTPUT_FEATURE_NUM;
|
||||
const uint x = dim0;
|
||||
#elif (OUTPUT_DIMS == 5)
|
||||
// bf|z|yx
|
||||
#if (AXIS_VALUE == AXIS_F)
|
||||
const uint b = dim2 / INDICES_SIZE;
|
||||
const uint f = dim2 % INDICES_SIZE;
|
||||
#else
|
||||
const uint f = dim2 / OUTPUT_BATCH_NUM;
|
||||
const uint b = dim2 % OUTPUT_BATCH_NUM;
|
||||
const uint b = dim2 / OUTPUT_FEATURE_NUM;
|
||||
const uint f = dim2 % OUTPUT_FEATURE_NUM;
|
||||
#endif
|
||||
#elif OUTPUT_DIMS == 5
|
||||
const uint z = dim1;
|
||||
#if AXIS_VALUE == 1
|
||||
const uint f = dim2 / OUTPUT_BATCH_NUM;
|
||||
const uint b = dim2 % OUTPUT_BATCH_NUM;
|
||||
const uint x = dim0 % OUTPUT_SIZE_X;
|
||||
const uint y = dim0 / OUTPUT_SIZE_X;
|
||||
#elif AXIS_VALUE == 4
|
||||
const uint f = dim2 % OUTPUT_FEATURE_NUM;
|
||||
const uint b = dim2 / OUTPUT_FEATURE_NUM;
|
||||
const uint x = dim0 / OUTPUT_SIZE_Y;
|
||||
const uint y = dim0 % OUTPUT_SIZE_Y;
|
||||
#if (AXIS_VALUE == AXIS_X)
|
||||
const uint y = dim0 / INDICES_SIZE;
|
||||
const uint x = dim0 % INDICES_SIZE;
|
||||
#else
|
||||
const uint f = dim2 % OUTPUT_FEATURE_NUM;
|
||||
const uint b = dim2 / OUTPUT_FEATURE_NUM;
|
||||
const uint x = dim0 % OUTPUT_SIZE_X;
|
||||
const uint y = dim0 / OUTPUT_SIZE_X;
|
||||
const uint x = dim0 % OUTPUT_SIZE_X;
|
||||
#endif
|
||||
#elif OUTPUT_DIMS == 6
|
||||
#if AXIS_VALUE == 1
|
||||
const uint f = dim2 / OUTPUT_BATCH_NUM;
|
||||
const uint b = dim2 % OUTPUT_BATCH_NUM;
|
||||
const uint x = dim0 % OUTPUT_SIZE_X;
|
||||
const uint y = dim0 / OUTPUT_SIZE_X;
|
||||
const uint z = dim1 % OUTPUT_SIZE_Z;
|
||||
const uint w = dim1 / OUTPUT_SIZE_Z;
|
||||
#elif AXIS_VALUE == 3
|
||||
const uint f = dim2 % OUTPUT_FEATURE_NUM;
|
||||
const uint b = dim2 / OUTPUT_FEATURE_NUM;
|
||||
const uint x = dim0 % OUTPUT_SIZE_X;
|
||||
const uint y = dim0 / OUTPUT_SIZE_X;
|
||||
const uint z = dim1 / OUTPUT_SIZE_W;
|
||||
const uint w = dim1 % OUTPUT_SIZE_W;
|
||||
#elif AXIS_VALUE == 5
|
||||
const uint f = dim2 % OUTPUT_FEATURE_NUM;
|
||||
const uint b = dim2 / OUTPUT_FEATURE_NUM;
|
||||
const uint x = dim0 / OUTPUT_SIZE_Y;
|
||||
const uint y = dim0 % OUTPUT_SIZE_Y;
|
||||
const uint z = dim1 % OUTPUT_SIZE_Z;
|
||||
const uint w = dim1 / OUTPUT_SIZE_Z;
|
||||
#elif (OUTPUT_DIMS == 6)
|
||||
// bf|wz|yx
|
||||
#if (AXIS_VALUE == AXIS_F)
|
||||
const uint b = dim2 / INDICES_SIZE;
|
||||
const uint f = dim2 % INDICES_SIZE;
|
||||
#else
|
||||
const uint f = dim2 % OUTPUT_FEATURE_NUM;
|
||||
const uint b = dim2 / OUTPUT_FEATURE_NUM;
|
||||
const uint x = dim0 % OUTPUT_SIZE_X;
|
||||
const uint y = dim0 / OUTPUT_SIZE_X;
|
||||
const uint z = dim1 % OUTPUT_SIZE_Z;
|
||||
const uint f = dim2 % OUTPUT_FEATURE_NUM;
|
||||
#endif
|
||||
#if (AXIS_VALUE == AXIS_Z)
|
||||
const uint w = dim1 / INDICES_SIZE;
|
||||
const uint z = dim1 % INDICES_SIZE;
|
||||
#else
|
||||
const uint w = dim1 / OUTPUT_SIZE_Z;
|
||||
const uint z = dim1 % OUTPUT_SIZE_Z;
|
||||
#endif
|
||||
#if (AXIS_VALUE == AXIS_X)
|
||||
const uint y = dim0 / INDICES_SIZE;
|
||||
const uint x = dim0 % INDICES_SIZE;
|
||||
#else
|
||||
const uint y = dim0 / OUTPUT_SIZE_X;
|
||||
const uint x = dim0 % OUTPUT_SIZE_X;
|
||||
#endif
|
||||
#endif
|
||||
|
||||
const uint output_idx = GET_OUTPUT_INDEX(SECOND_ITER_OUTPUT_INDEX_ORDER);
|
||||
const uint updates_idx = GET_UPDATES_INDEX(INPUT2, UPDATES_INDEX_ORDER);
|
||||
const uint updates_idx = GET_UPDATES_INDEX(UPDATES_INDEX_ORDER);
|
||||
|
||||
INPUT2_TYPE val = updates[updates_idx];
|
||||
|
||||
#if HAS_FUSED_OPS
|
||||
FUSED_OPS_SECOND_KERNEL;
|
||||
output[output_idx] = TO_OUTPUT_TYPE(FUSED_OPS_RESULT_SECOND_KERNEL);
|
||||
@ -142,5 +143,10 @@ KERNEL(scatter_update_ref)(const __global INPUT0_TYPE* dictionary,
|
||||
#endif
|
||||
}
|
||||
|
||||
#undef GET_UPDATES_INDEX
|
||||
#undef GET_OUTPUT_INDEX
|
||||
#undef AXIS_B
|
||||
#undef AXIS_F
|
||||
#undef AXIS_W
|
||||
#undef AXIS_Z
|
||||
#undef AXIS_Y
|
||||
#undef AXIS_X
|
||||
|
@ -5743,6 +5743,52 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, scatter_update_scale_activation,
|
||||
scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP16_5, 2, 4 },
|
||||
}), );
|
||||
|
||||
// Fixture for the fusing case below; it adds nothing to ScatterUpdatePrimitiveFusingTest
// and only gives the parameterized test its own name.
class scatter_update_scale_activation_eltwise : public ScatterUpdatePrimitiveFusingTest {};
|
||||
// Builds the chain scatter_update -> activation(abs) -> eltwise(sum) -> scale -> reorder(f32)
// and runs it through execute(p) with a tolerance of 1e-5.
TEST_P(scatter_update_scale_activation_eltwise, basic) {
|
||||
auto p = GetParam();
|
||||
create_topologies(input_layout("input", get_input_layout(p)),
|
||||
// Index values are generated in [0, axis_dim - 1]; presumably without repeats,
// judging by the helper's name (confirm against get_repeatless_mem).
data("scatter_update_indices", get_repeatless_mem(get_indices_layout(p), 0, static_cast<int>(get_axis_dim(p)) - 1)),
|
||||
data("scatter_update_updates", get_mem(get_updates_layout(p), 0, 1000)),
|
||||
data("scale_data", get_mem(get_per_channel_layout(p), -10, 10)),
|
||||
// Second eltwise operand uses the dictionary shape in the default type and format.
data("eltw_data", get_mem(layout(p.default_type, p.default_format, p.dictionary_shape))),
|
||||
scatter_update("scatter_update_prim", "input", "scatter_update_indices", "scatter_update_updates", p.axis),
|
||||
activation("activation", "scatter_update_prim", activation_func::abs),
|
||||
eltwise("eltw", {"activation", "eltw_data"}, eltwise_mode::sum, p.default_type),
|
||||
scale("scale", "eltw", "scale_data"),
|
||||
// Final reorder converts the result to f32 in the default format.
reorder("reorder_bfyx", "scale", p.default_format, data_types::f32)
|
||||
);
|
||||
tolerance = 1e-5f;
|
||||
execute(p);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(fusings_gpu, scatter_update_scale_activation_eltwise,
|
||||
::testing::ValuesIn(std::vector<scatter_update_test_params> {
|
||||
scatter_update_test_params{ CASE_SCATTER_UPDATE_FP32_1, 3, 5 },
|
||||
scatter_update_test_params{ CASE_SCATTER_UPDATE_FP32_2, 3, 5 },
|
||||
scatter_update_test_params{ CASE_SCATTER_UPDATE_FP32_3, 3, 5 },
|
||||
scatter_update_test_params{ CASE_SCATTER_UPDATE_FP32_4, 3, 5 },
|
||||
scatter_update_test_params{ CASE_SCATTER_UPDATE_FP32_5, 3, 5 },
|
||||
|
||||
scatter_update_test_params{ CASE_SCATTER_UPDATE_FP16_1, 3, 5 },
|
||||
scatter_update_test_params{ CASE_SCATTER_UPDATE_FP16_2, 3, 5 },
|
||||
scatter_update_test_params{ CASE_SCATTER_UPDATE_FP16_3, 3, 5 },
|
||||
scatter_update_test_params{ CASE_SCATTER_UPDATE_FP16_4, 3, 5 },
|
||||
scatter_update_test_params{ CASE_SCATTER_UPDATE_FP16_5, 3, 5 },
|
||||
|
||||
|
||||
scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP32_1, 3, 5 },
|
||||
scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP32_2, 3, 5 },
|
||||
scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP32_3, 3, 5 },
|
||||
scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP32_4, 3, 5 },
|
||||
scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP32_5, 3, 5 },
|
||||
|
||||
scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP16_1, 3, 5 },
|
||||
scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP16_2, 3, 5 },
|
||||
scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP16_3, 3, 5 },
|
||||
scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP16_4, 3, 5 },
|
||||
scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP16_5, 3, 5 },
|
||||
|
||||
}), );
|
||||
/* ------------------------------------------------------------------------------------------------------------ */
|
||||
/* ---------------------------------------- PERMUTE FUSE cases -------------------------------------------------- */
|
||||
/* ------------------------------------------------------------------------------------------------------------ */
|
||||
|
Loading…
Reference in New Issue
Block a user