[IE CLDNN] Improvements for SpaceToDepth (#1454)
This commit is contained in:
parent
0560b61cbd
commit
3c99c13feb
@ -73,13 +73,18 @@ InferenceEngine::ICNNNetwork::Ptr clDNNEngine::CloneNetwork(const InferenceEngin
|
||||
std::shared_ptr<ICNNNetwork> clonedNetwork = cloneNetwork(network);
|
||||
if (clonedNetwork->getFunction()) {
|
||||
const auto transformations_callback = [](const std::shared_ptr<const ::ngraph::Node> &node) -> bool {
|
||||
// DepthToSpace node implementation supports only equal input/output tensors with rank <= 5
|
||||
// Reshape->Permute->Reshape pattern in theory can change output rank, so this check is added to be sure
|
||||
// that DepthToSpace impl will handle fused case
|
||||
// that the following primitives will be handled correctly
|
||||
// DepthToSpace node implementation supports only equal input/output tensors with rank <= 5
|
||||
if (auto dtsOp = std::dynamic_pointer_cast<const ::ngraph::opset3::DepthToSpace>(node)) {
|
||||
return dtsOp->input_value(0).get_shape().size() <= 5lu && dtsOp->input_value(0).get_shape().size() == dtsOp->get_output_shape(0).size();
|
||||
}
|
||||
|
||||
// SpaceToDepth node implementation supports only equal input/output tensors with rank <= 5
|
||||
if (auto stdOp = std::dynamic_pointer_cast<const ::ngraph::opset3::SpaceToDepth>(node)) {
|
||||
return stdOp->input_value(0).get_shape().size() <= 5lu && stdOp->input_value(0).get_shape().size() == stdOp->get_output_shape(0).size();
|
||||
}
|
||||
|
||||
return std::dynamic_pointer_cast<const ::ngraph::opset2::Gelu>(node) ||
|
||||
std::dynamic_pointer_cast<const ::ngraph::opset3::ShuffleChannels>(node) ||
|
||||
std::dynamic_pointer_cast<const ::ngraph::opset2::BatchToSpace>(node) ||
|
||||
|
@ -186,7 +186,8 @@ InferenceEngine::ICNNNetwork::Ptr CLDNNGraph::GetExecGraphInfoByPrimitivesInfo(s
|
||||
{ "reduce_l1", "ReduceL1" },
|
||||
{ "reduce_l2", "ReduceL2" },
|
||||
{ "reduce_log_sum", "ReduceLogSum" },
|
||||
{ "reduce_log_sum_exp", "ReduceLogSumExp" }
|
||||
{ "reduce_log_sum_exp", "ReduceLogSumExp" },
|
||||
{ "space_to_depth", "SpaceToDepth" },
|
||||
};
|
||||
|
||||
if (type_n2l.find(cldnn_name) != type_n2l.end())
|
||||
|
@ -3979,7 +3979,7 @@ void Program::CreateSpaceToDepthPrimitive(cldnn::topology& topology, InferenceEn
|
||||
auto spaceToDepth = as<InferenceEngine::GenericLayer*> (layer);
|
||||
|
||||
size_t blockSize = static_cast<size_t>(spaceToDepth->GetParamAsUInt("block_size", 1));
|
||||
std::string modeAsString = spaceToDepth->GetParamAsString("depth_mode", "blocks_first");
|
||||
std::string modeAsString = spaceToDepth->GetParamAsString("mode", "blocks_first");
|
||||
cldnn::space_to_depth::depth_mode mode;
|
||||
mode = (modeAsString == "blocks_first") ? cldnn::space_to_depth::blocks_first : cldnn::space_to_depth::depth_first;
|
||||
|
||||
|
@ -1,18 +0,0 @@
|
||||
// Copyright (C) 2018-2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "space_to_depth_tests.hpp"
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
smoke_GPU_TestsSpaceToDepth, SpaceToDepthTests,
|
||||
::testing::Values(
|
||||
space_to_depth_test_params{ "GPU", "FP32", { 1, 1, 6, 4 }, "blocks_first", 2, { 1, 4, 3, 2 } },
|
||||
space_to_depth_test_params{ "GPU", "FP32", { 1, 1, 9, 9 }, "blocks_first", 3, { 1, 9, 3, 3 } },
|
||||
space_to_depth_test_params{ "GPU", "FP32", { 1, 2, 9, 9 }, "blocks_first", 3, { 1, 18, 3, 3 } },
|
||||
space_to_depth_test_params{ "GPU", "FP32", { 1, 10, 4096, 1024 }, "blocks_first", 4, { 1, 160, 1024, 256 } },
|
||||
space_to_depth_test_params{ "GPU", "FP32", { 1, 1, 6, 4 }, "depth_first", 2, { 1, 4, 3, 2 } },
|
||||
space_to_depth_test_params{ "GPU", "FP32", { 1, 1, 9, 9 }, "depth_first", 3, { 1, 9, 3, 3 } },
|
||||
space_to_depth_test_params{ "GPU", "FP32", { 1, 2, 9, 9 }, "depth_first", 3, { 1, 18, 3, 3 } },
|
||||
space_to_depth_test_params{ "GPU", "FP32", { 1, 10, 4096, 1024 }, "depth_first", 4, { 1, 160, 1024, 256 } }
|
||||
));
|
@ -20,72 +20,109 @@
|
||||
#include <vector>
|
||||
|
||||
namespace kernel_selector {
|
||||
ParamsKey SpaceToDepthKernelRef::GetSupportedKey() const {
|
||||
ParamsKey k;
|
||||
k.EnableInputDataType(Datatype::F16);
|
||||
k.EnableInputDataType(Datatype::F32);
|
||||
k.EnableOutputDataType(Datatype::F16);
|
||||
k.EnableOutputDataType(Datatype::F32);
|
||||
k.EnableAllInputLayout();
|
||||
k.EnableAllOutputLayout();
|
||||
k.EnableTensorOffset();
|
||||
k.EnableTensorPitches();
|
||||
k.EnableBatching();
|
||||
return k;
|
||||
ParamsKey SpaceToDepthKernelRef::GetSupportedKey() const {
|
||||
ParamsKey k;
|
||||
k.EnableInputDataType(Datatype::INT8);
|
||||
k.EnableInputDataType(Datatype::UINT8);
|
||||
k.EnableInputDataType(Datatype::F16);
|
||||
k.EnableInputDataType(Datatype::F32);
|
||||
k.EnableOutputDataType(Datatype::INT8);
|
||||
k.EnableOutputDataType(Datatype::UINT8);
|
||||
k.EnableOutputDataType(Datatype::F16);
|
||||
k.EnableOutputDataType(Datatype::F32);
|
||||
k.EnableAllInputLayout();
|
||||
k.EnableAllOutputLayout();
|
||||
k.EnableTensorOffset();
|
||||
k.EnableTensorPitches();
|
||||
k.EnableBatching();
|
||||
k.EnableDifferentTypes();
|
||||
return k;
|
||||
}
|
||||
|
||||
bool SpaceToDepthKernelRef::Validate(const Params& p, const optional_params& o) const {
|
||||
if (p.GetType() != KernelType::SPACE_TO_DEPTH ||
|
||||
o.GetType() != KernelType::SPACE_TO_DEPTH) {
|
||||
return false;
|
||||
}
|
||||
|
||||
CommonDispatchData SpaceToDepthKernelRef::SetDefault(const space_to_depth_params& params,
|
||||
const optional_params&) const {
|
||||
CommonDispatchData runInfo;
|
||||
|
||||
std::vector<size_t> global = {params.output.Batch().v,
|
||||
params.output.Feature().v,
|
||||
params.output.Y().v * params.output.X().v};
|
||||
|
||||
auto local = GetOptimalLocalWorkGroupSizes(global, params.engineInfo);
|
||||
|
||||
runInfo.gws0 = global[0];
|
||||
runInfo.gws1 = global[1];
|
||||
runInfo.gws2 = global[2];
|
||||
|
||||
runInfo.lws0 = local[0];
|
||||
runInfo.lws1 = local[1];
|
||||
runInfo.lws2 = local[2];
|
||||
|
||||
return runInfo;
|
||||
const space_to_depth_params& params = static_cast<const space_to_depth_params&>(p);
|
||||
for (auto& fused_op : params.fused_ops) {
|
||||
if (!IsFusedPrimitiveSupported(fused_op))
|
||||
return false;
|
||||
}
|
||||
|
||||
JitConstants SpaceToDepthKernelRef::GetJitConstants(const space_to_depth_params& params) const {
|
||||
JitConstants jit = MakeBaseParamsJitConstants(params);
|
||||
if (params.inputs[0].Dimentions() > 5)
|
||||
return false;
|
||||
|
||||
const size_t block_size = params.block_size;
|
||||
const size_t squared_block_size = params.block_size * params.block_size;
|
||||
const size_t blocks_first_mode = (size_t)params.depth_mode;
|
||||
return true;
|
||||
}
|
||||
|
||||
jit.AddConstant(MakeJitConstant("BLOCK_SIZE", block_size));
|
||||
jit.AddConstant(MakeJitConstant("SQUARED_BLOCK_SIZE", squared_block_size));
|
||||
jit.AddConstant(MakeJitConstant("BLOCKS_FIRST_MODE", blocks_first_mode));
|
||||
CommonDispatchData SpaceToDepthKernelRef::SetDefault(const space_to_depth_params& params,
|
||||
const optional_params&) const {
|
||||
CommonDispatchData runInfo;
|
||||
|
||||
return jit;
|
||||
std::vector<size_t> global = {params.output.Batch().v,
|
||||
params.output.Feature().v,
|
||||
params.output.Z().v * params.output.Y().v * params.output.X().v};
|
||||
|
||||
auto local = GetOptimalLocalWorkGroupSizes(global, params.engineInfo);
|
||||
|
||||
runInfo.gws0 = global[0];
|
||||
runInfo.gws1 = global[1];
|
||||
runInfo.gws2 = global[2];
|
||||
|
||||
runInfo.lws0 = local[0];
|
||||
runInfo.lws1 = local[1];
|
||||
runInfo.lws2 = local[2];
|
||||
|
||||
return runInfo;
|
||||
}
|
||||
|
||||
JitConstants SpaceToDepthKernelRef::GetJitConstants(const space_to_depth_params& params) const {
|
||||
JitConstants jit = MakeBaseParamsJitConstants(params);
|
||||
|
||||
jit.AddConstant(MakeJitConstant("BLOCK_SIZE", params.block_size));
|
||||
if (params.depth_mode == SpaceToDepthMode::BLOCKS_FIRST)
|
||||
jit.AddConstant(MakeJitConstant("BLOCKS_FIRST_MODE", true));
|
||||
else
|
||||
jit.AddConstant(MakeJitConstant("DEPTH_FIRST_MODE", true));
|
||||
|
||||
auto input = params.inputs[0];
|
||||
auto input_dt = input.GetDType();
|
||||
if (!params.fused_ops.empty()) {
|
||||
std::vector<std::string> idx_order;
|
||||
if (input.Dimentions() == 5) {
|
||||
idx_order = {"batch", "feature", "z", "y", "x"};
|
||||
} else if (input.Dimentions() == 4) {
|
||||
idx_order = {"batch", "feature", "y", "x"};
|
||||
}
|
||||
FusedOpsConfiguration conf = {"", idx_order, "in_val", input_dt, 1};
|
||||
jit.Merge(MakeFusedOpsJitConstants(params, {conf}));
|
||||
}
|
||||
|
||||
KernelsData SpaceToDepthKernelRef::GetKernelsData(const Params& params, const optional_params& options) const {
|
||||
KernelData kd = KernelData::Default<space_to_depth_params>(params);
|
||||
space_to_depth_params& newParams = *static_cast<space_to_depth_params*>(kd.params.get());
|
||||
return jit;
|
||||
}
|
||||
|
||||
assert(params.GetType() == KernelType::SPACE_TO_DEPTH);
|
||||
KernelsData SpaceToDepthKernelRef::GetKernelsData(const Params& params, const optional_params& options) const {
|
||||
KernelData kd = KernelData::Default<space_to_depth_params>(params);
|
||||
space_to_depth_params& newParams = *static_cast<space_to_depth_params*>(kd.params.get());
|
||||
|
||||
auto runInfo = SetDefault(newParams, options);
|
||||
auto entry_point = GetEntryPoint(kernelName, newParams.layerID, options);
|
||||
auto cldnn_jit = GetJitConstants(newParams);
|
||||
std::string jit = CreateJit(kernelName, cldnn_jit, entry_point);
|
||||
|
||||
auto& kernel = kd.kernels[0];
|
||||
|
||||
FillCLKernelData(kernel, runInfo, params.engineInfo, kernelName, jit, entry_point);
|
||||
|
||||
kd.estimatedTime = DONT_USE_IF_HAVE_SOMETHING_ELSE;
|
||||
|
||||
return {kd};
|
||||
if (!Validate(params, options)) {
|
||||
return {};
|
||||
}
|
||||
|
||||
auto runInfo = SetDefault(newParams, options);
|
||||
auto entry_point = GetEntryPoint(kernelName, newParams.layerID, options);
|
||||
auto cldnn_jit = GetJitConstants(newParams);
|
||||
std::string jit = CreateJit(kernelName, cldnn_jit, entry_point);
|
||||
|
||||
auto& kernel = kd.kernels[0];
|
||||
|
||||
FillCLKernelData(kernel, runInfo, params.engineInfo, kernelName, jit, entry_point,
|
||||
DEFAULT, false, false, 1, GetFusedPrimitiveInputsCount(params));
|
||||
|
||||
kd.estimatedTime = DONT_USE_IF_HAVE_SOMETHING_ELSE;
|
||||
|
||||
return {kd};
|
||||
}
|
||||
} // namespace kernel_selector
|
||||
|
@ -43,10 +43,19 @@ struct space_to_depth_optional_params : optional_params {
|
||||
class SpaceToDepthKernelRef : public common_kernel_base {
|
||||
public:
|
||||
SpaceToDepthKernelRef() : common_kernel_base("space_to_depth_ref") {}
|
||||
virtual ~SpaceToDepthKernelRef() {}
|
||||
virtual JitConstants GetJitConstants(const space_to_depth_params& params) const;
|
||||
virtual CommonDispatchData SetDefault(const space_to_depth_params& params, const optional_params&) const;
|
||||
virtual ~SpaceToDepthKernelRef() = default;
|
||||
KernelsData GetKernelsData(const Params& params, const optional_params& options) const override;
|
||||
ParamsKey GetSupportedKey() const override;
|
||||
|
||||
protected:
|
||||
virtual CommonDispatchData SetDefault(const space_to_depth_params& params, const optional_params&) const;
|
||||
virtual JitConstants GetJitConstants(const space_to_depth_params& params) const;
|
||||
virtual bool Validate(const Params& p, const optional_params& o) const;
|
||||
std::vector<FusedOpType> GetSupportedFusedOps() const override {
|
||||
return { FusedOpType::ELTWISE,
|
||||
FusedOpType::QUANTIZE,
|
||||
FusedOpType::SCALE,
|
||||
FusedOpType::ACTIVATION };
|
||||
}
|
||||
};
|
||||
} // namespace kernel_selector
|
||||
|
@ -14,22 +14,59 @@
|
||||
|
||||
#include "include/include_all.cl"
|
||||
|
||||
KERNEL(space_to_depth_ref)(const __global UNIT_TYPE* input, __global UNIT_TYPE* output)
|
||||
#if OUTPUT_DIMS == 5
|
||||
#define SPATIAL_BLOCK_SIZE (BLOCK_SIZE*BLOCK_SIZE*BLOCK_SIZE)
|
||||
#else
|
||||
#define SPATIAL_BLOCK_SIZE (BLOCK_SIZE*BLOCK_SIZE)
|
||||
#endif
|
||||
|
||||
KERNEL(space_to_depth_ref)(const __global INPUT0_TYPE* input,
|
||||
__global OUTPUT_TYPE* output
|
||||
#if HAS_FUSED_OPS_DECLS
|
||||
, FUSED_OPS_DECLS
|
||||
#endif
|
||||
)
|
||||
{
|
||||
const uint batch = get_global_id(0);
|
||||
const uint feature = get_global_id(1);
|
||||
|
||||
#if OUTPUT_DIMS == 5
|
||||
const uint z = ((uint)get_global_id(2) / OUTPUT_SIZE_X) / OUTPUT_SIZE_Y;
|
||||
const uint y = ((uint)get_global_id(2) / OUTPUT_SIZE_X) % OUTPUT_SIZE_Y;
|
||||
const uint x = (uint)get_global_id(2) % OUTPUT_SIZE_X;
|
||||
#else
|
||||
const uint z = 0;
|
||||
const uint y = (uint)get_global_id(2) / OUTPUT_SIZE_X;
|
||||
const uint x = (uint)get_global_id(2) % OUTPUT_SIZE_X;
|
||||
#endif
|
||||
|
||||
const uint input_offset = BLOCKS_FIRST_MODE * (feature / INPUT0_FEATURE_NUM) + (!BLOCKS_FIRST_MODE) * (feature % SQUARED_BLOCK_SIZE);
|
||||
#if BLOCKS_FIRST_MODE
|
||||
const uint input_offset = feature / INPUT0_FEATURE_NUM;
|
||||
const uint input_feature = feature % INPUT0_FEATURE_NUM;
|
||||
#else
|
||||
const uint input_offset = feature % SPATIAL_BLOCK_SIZE;
|
||||
const uint input_feature = feature / SPATIAL_BLOCK_SIZE;
|
||||
#endif
|
||||
|
||||
#if OUTPUT_DIMS == 5
|
||||
const uint input_z = (z * BLOCK_SIZE) + ((input_offset / BLOCK_SIZE) / BLOCK_SIZE);
|
||||
const uint input_y = (y * BLOCK_SIZE) + ((input_offset / BLOCK_SIZE) % BLOCK_SIZE);
|
||||
const uint input_x = (x * BLOCK_SIZE) + (input_offset % BLOCK_SIZE);
|
||||
const uint input_index = INPUT0_GET_INDEX(batch, input_feature, input_z, input_y, input_x);
|
||||
const uint output_index = OUTPUT_GET_INDEX(batch, feature, z, y, x);
|
||||
#else
|
||||
const uint input_z = 0;
|
||||
const uint input_y = (y * BLOCK_SIZE) + (input_offset / BLOCK_SIZE);
|
||||
const uint input_x = (x * BLOCK_SIZE) + (input_offset % BLOCK_SIZE);
|
||||
const uint input_index = INPUT0_GET_INDEX(batch, input_feature, input_y, input_x);
|
||||
const uint output_index = OUTPUT_GET_INDEX(batch, feature, y, x);
|
||||
#endif
|
||||
|
||||
const uint input_feature = BLOCKS_FIRST_MODE * (feature % INPUT0_FEATURE_NUM) + (!BLOCKS_FIRST_MODE) * (feature / SQUARED_BLOCK_SIZE);
|
||||
const uint input_feature_offset = (input_y * INPUT0_Y_PITCH) + input_x;
|
||||
|
||||
const uint input_index = INPUT0_OFFSET + (batch * INPUT0_BATCH_PITCH) + (input_feature * INPUT0_FEATURE_PITCH) + input_feature_offset;
|
||||
const uint output_index = OUTPUT_OFFSET + (batch * OUTPUT_BATCH_PITCH) + (feature * OUTPUT_FEATURE_PITCH) + (y * OUTPUT_Y_PITCH) + x;
|
||||
|
||||
output[output_index] = ACTIVATION(input[input_index], ACTIVATION_PARAMS);
|
||||
INPUT0_TYPE in_val = input[input_index];
|
||||
#if HAS_FUSED_OPS
|
||||
FUSED_OPS;
|
||||
output[output_index] = FUSED_OPS_RESULT;
|
||||
#else
|
||||
output[output_index] = ACTIVATION(in_val, ACTIVATION_PARAMS);
|
||||
#endif
|
||||
}
|
||||
|
@ -60,10 +60,26 @@ namespace detail {
|
||||
|
||||
attach_space_to_depth_gpu::attach_space_to_depth_gpu() {
|
||||
auto val_fw = space_to_depth_gpu::create;
|
||||
implementation_map<space_to_depth>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx),
|
||||
val_fw);
|
||||
implementation_map<space_to_depth>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx),
|
||||
val_fw);
|
||||
|
||||
implementation_map<space_to_depth>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfzyx), val_fw);
|
||||
implementation_map<space_to_depth>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfzyx), val_fw);
|
||||
implementation_map<space_to_depth>::add(std::make_tuple(engine_types::ocl, data_types::u8, format::bfzyx), val_fw);
|
||||
implementation_map<space_to_depth>::add(std::make_tuple(engine_types::ocl, data_types::i8, format::bfzyx), val_fw);
|
||||
|
||||
implementation_map<space_to_depth>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx), val_fw);
|
||||
implementation_map<space_to_depth>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx), val_fw);
|
||||
implementation_map<space_to_depth>::add(std::make_tuple(engine_types::ocl, data_types::u8, format::bfyx), val_fw);
|
||||
implementation_map<space_to_depth>::add(std::make_tuple(engine_types::ocl, data_types::i8, format::bfyx), val_fw);
|
||||
|
||||
implementation_map<space_to_depth>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::b_fs_yx_fsv16), val_fw);
|
||||
implementation_map<space_to_depth>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::b_fs_yx_fsv16), val_fw);
|
||||
implementation_map<space_to_depth>::add(std::make_tuple(engine_types::ocl, data_types::u8, format::b_fs_yx_fsv16), val_fw);
|
||||
implementation_map<space_to_depth>::add(std::make_tuple(engine_types::ocl, data_types::i8, format::b_fs_yx_fsv16), val_fw);
|
||||
|
||||
implementation_map<space_to_depth>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::b_fs_yx_fsv4), val_fw);
|
||||
implementation_map<space_to_depth>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::b_fs_yx_fsv4), val_fw);
|
||||
implementation_map<space_to_depth>::add(std::make_tuple(engine_types::ocl, data_types::u8, format::b_fs_yx_fsv4), val_fw);
|
||||
implementation_map<space_to_depth>::add(std::make_tuple(engine_types::ocl, data_types::i8, format::b_fs_yx_fsv4), val_fw);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
@ -42,6 +42,7 @@
|
||||
#include "scale_inst.h"
|
||||
#include "resample_inst.h"
|
||||
#include "depth_to_space_inst.h"
|
||||
#include "space_to_depth_inst.h"
|
||||
#include "gather_inst.h"
|
||||
#include "reverse_sequence_inst.h"
|
||||
#include "shuffle_channels_inst.h"
|
||||
@ -375,6 +376,8 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
|
||||
|
||||
should_fuse |= input_data.is_type<depth_to_space>();
|
||||
|
||||
should_fuse |= input_data.is_type<space_to_depth>();
|
||||
|
||||
if (!should_fuse)
|
||||
return;
|
||||
|
||||
@ -420,6 +423,8 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
|
||||
|
||||
should_fuse |= input_data.is_type<depth_to_space>();
|
||||
|
||||
should_fuse |= input_data.is_type<space_to_depth>();
|
||||
|
||||
if (!should_fuse)
|
||||
return;
|
||||
|
||||
@ -496,6 +501,8 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
|
||||
|
||||
should_fuse |= input_data.is_type<depth_to_space>() && quantize_node.get_scale_shift_opt();
|
||||
|
||||
should_fuse |= input_data.is_type<space_to_depth>() && quantize_node.get_scale_shift_opt();
|
||||
|
||||
if (!should_fuse)
|
||||
return;
|
||||
|
||||
@ -517,12 +524,12 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
|
||||
bool can_fuse_parent1 = (parent1->is_type<convolution>() && conv_supports_fusings(parent1->as<convolution>())) ||
|
||||
(parent1->is_type<mvn>() && mvn_supports_fusings(parent1->as<mvn>())) ||
|
||||
(parent1->is_type<deconvolution>()) || (parent1->is_type<permute>()) ||
|
||||
(parent1->is_type<depth_to_space>()) || (parent1->is_type<gemm>());
|
||||
(parent1->is_type<depth_to_space>()) || (parent1->is_type<space_to_depth>()) || (parent1->is_type<gemm>());
|
||||
|
||||
bool can_fuse_parent2 = (parent2->is_type<convolution>() && conv_supports_fusings(parent2->as<convolution>())) ||
|
||||
(parent2->is_type<mvn>() && mvn_supports_fusings(parent2->as<mvn>())) ||
|
||||
(parent2->is_type<deconvolution>()) || (parent2->is_type<permute>()) ||
|
||||
(parent1->is_type<depth_to_space>()) || (parent2->is_type<gemm>());
|
||||
(parent1->is_type<depth_to_space>()) || (parent1->is_type<space_to_depth>()) || (parent2->is_type<gemm>());
|
||||
|
||||
std::vector<bool> can_fuse_parents = { can_fuse_parent1, can_fuse_parent2 };
|
||||
|
||||
|
@ -36,6 +36,11 @@ layout space_to_depth_inst::calc_output_layout(space_to_depth_node const& node)
|
||||
const size_t block_size = desc->block_size;
|
||||
auto depth_mode = desc->mode;
|
||||
|
||||
auto output_type = input_layout.data_type;
|
||||
if (node.has_fused_primitives()) {
|
||||
output_type = node.get_fused_output_layout().data_type;
|
||||
}
|
||||
|
||||
if (depth_mode != space_to_depth::depth_first && depth_mode != space_to_depth::blocks_first)
|
||||
CLDNN_ERROR_MESSAGE(node.id(),
|
||||
"Invalid mode for spaceToDepth: must be \"blocks_first\" or \"depth_first\" only");
|
||||
@ -52,14 +57,34 @@ layout space_to_depth_inst::calc_output_layout(space_to_depth_node const& node)
|
||||
std::to_string(input_layout.size.spatial[0]) + ", " + std::to_string(input_layout.size.spatial[1]) +
|
||||
" (x, y). Actual block size is " + std::to_string(block_size));
|
||||
|
||||
const size_t feature = input_layout.size.feature[0] * block_size * block_size;
|
||||
const size_t y = input_layout.size.spatial[1] / block_size;
|
||||
const size_t x = input_layout.size.spatial[0] / block_size;
|
||||
|
||||
return layout{
|
||||
input_layout.data_type,
|
||||
input_format,
|
||||
tensor(TensorValue(input_layout.size.batch[0]), TensorValue(feature), TensorValue(x), TensorValue(y))};
|
||||
if (input_layout.format.dimension() == 5) {
|
||||
if (input_layout.size.spatial[2] % block_size != 0)
|
||||
CLDNN_ERROR_MESSAGE(
|
||||
node.id(),
|
||||
"Sizes of spatials z must be divisible by block size. Actual spatial sizes are " +
|
||||
std::to_string(input_layout.size.spatial[2]) +
|
||||
" (z). Block size is " + std::to_string(block_size));
|
||||
|
||||
const size_t feature = input_layout.size.feature[0] * block_size * block_size * block_size;
|
||||
const size_t z = input_layout.size.spatial[2] / block_size;
|
||||
const size_t y = input_layout.size.spatial[1] / block_size;
|
||||
const size_t x = input_layout.size.spatial[0] / block_size;
|
||||
|
||||
return layout{
|
||||
output_type,
|
||||
input_format,
|
||||
tensor(TensorValue(input_layout.size.batch[0]), TensorValue(feature), TensorValue(x), TensorValue(y), TensorValue(z))};
|
||||
} else {
|
||||
const size_t feature = input_layout.size.feature[0] * block_size * block_size;
|
||||
const size_t y = input_layout.size.spatial[1] / block_size;
|
||||
const size_t x = input_layout.size.spatial[0] / block_size;
|
||||
|
||||
return layout{
|
||||
output_type,
|
||||
input_format,
|
||||
tensor(TensorValue(input_layout.size.batch[0]), TensorValue(feature), TensorValue(x), TensorValue(y))};
|
||||
}
|
||||
}
|
||||
|
||||
std::string space_to_depth_inst::to_string(space_to_depth_node const& node) {
|
||||
|
@ -35,6 +35,7 @@
|
||||
#include "api/permute.hpp"
|
||||
#include "api/gather.hpp"
|
||||
#include "api/depth_to_space.hpp"
|
||||
#include "api/space_to_depth.hpp"
|
||||
|
||||
#include "test_utils/test_utils.h"
|
||||
|
||||
@ -4695,6 +4696,145 @@ INSTANTIATE_TEST_CASE_P(
|
||||
depth_to_space_test_params{CASE_DEPTH_TO_SPACE_I8_2, 2, 5},
|
||||
}), );
|
||||
|
||||
/* ----------------------------------------------------------------------------------------------------- */
|
||||
/* -------------------------------- SpaceToDepth cases ------------------------------------------------- */
|
||||
/* ----------------------------------------------------------------------------------------------------- */
|
||||
struct space_to_depth_params {
|
||||
tensor input_size;
|
||||
tensor output_size;
|
||||
space_to_depth::depth_mode mode;
|
||||
data_types input_type;
|
||||
format input_format;
|
||||
size_t block_size;
|
||||
data_types default_type;
|
||||
format default_format;
|
||||
size_t expected_fused_primitives;
|
||||
size_t expected_not_fused_primitives;
|
||||
};
|
||||
|
||||
#define CASE_SPACE_TO_DEPTH_F32_1 {2, 2, 8, 10}, {2, 8, 4, 5}, space_to_depth::depth_mode::blocks_first, data_types::f32, format::bfyx, 2, data_types::f32, format::bfyx
|
||||
#define CASE_SPACE_TO_DEPTH_F32_2 {1, 2, 6, 6, 6}, {1, 54, 2, 2, 2}, space_to_depth::depth_mode::depth_first, data_types::f32, format::bfzyx, 3, data_types::f32, format::bfyx
|
||||
#define CASE_SPACE_TO_DEPTH_F16_1 {1, 3, 6, 6}, {1, 12, 3, 3}, space_to_depth::depth_mode::blocks_first, data_types::f16, format::bfyx, 2, data_types::f32, format::bfyx
|
||||
#define CASE_SPACE_TO_DEPTH_F16_2 {2, 1, 3, 3}, {2, 9, 1, 1}, space_to_depth::depth_mode::blocks_first, data_types::f16, format::b_fs_yx_fsv16, 3, data_types::f32, format::bfyx
|
||||
#define CASE_SPACE_TO_DEPTH_U8_1 {2, 2, 8, 10}, {2, 8, 4, 5}, space_to_depth::depth_mode::blocks_first, data_types::u8, format::bfyx, 2, data_types::f32, format::bfyx
|
||||
#define CASE_SPACE_TO_DEPTH_U8_2 {1, 2, 6, 6, 6}, {1, 54, 2, 2, 2}, space_to_depth::depth_mode::depth_first, data_types::u8, format::bfzyx, 3, data_types::f32, format::bfyx
|
||||
#define CASE_SPACE_TO_DEPTH_I8_1 {1, 3, 6, 6}, {1, 12, 3, 3}, space_to_depth::depth_mode::blocks_first, data_types::i8, format::bfyx, 2, data_types::f32, format::bfyx
|
||||
#define CASE_SPACE_TO_DEPTH_I8_2 {2, 1, 3, 3}, {2, 9, 1, 1}, space_to_depth::depth_mode::blocks_first, data_types::i8, format::b_fs_yx_fsv16, 3, data_types::f32, format::bfyx
|
||||
|
||||
class SpaceToDepthFusingsTest : public ::BaseFusingTest<space_to_depth_params> {
|
||||
public:
|
||||
void execute(space_to_depth_params& p) {
|
||||
auto input_prim = get_mem(get_input_layout(p));
|
||||
|
||||
network network_not_fused(this->engine, this->topology_non_fused, bo_not_fused);
|
||||
network network_fused(this->engine, this->topology_fused, bo_fused);
|
||||
|
||||
network_fused.set_input_data("input", input_prim);
|
||||
network_not_fused.set_input_data("input", input_prim);
|
||||
|
||||
compare(network_not_fused, network_fused, p);
|
||||
}
|
||||
|
||||
layout get_input_layout(space_to_depth_params& p) { return layout{p.input_type, p.input_format, p.input_size}; }
|
||||
|
||||
layout get_per_channel_layout(space_to_depth_params& p) {
|
||||
return layout{p.default_type, p.default_format, tensor{1, p.output_size.feature[0], 1, 1}};
|
||||
}
|
||||
format get_input_format(space_to_depth_params &p) { return p.input_format; }
|
||||
};
|
||||
|
||||
class space_to_depth_quantize_i8 : public SpaceToDepthFusingsTest {};
|
||||
TEST_P(space_to_depth_quantize_i8, basic) {
|
||||
auto p = GetParam();
|
||||
create_topologies(input_layout("input", get_input_layout(p)),
|
||||
space_to_depth("space_to_depth", "input", p.mode, p.block_size),
|
||||
data("in_low", get_mem(get_per_channel_layout(p), min_random, 0)),
|
||||
data("in_high", get_mem(get_per_channel_layout(p), 1, max_random)),
|
||||
data("out_low", get_mem(get_single_element_layout(p), -128)),
|
||||
data("out_high", get_mem(get_single_element_layout(p), 127)),
|
||||
quantize("quant", "space_to_depth", "in_low", "in_high", "out_low", "out_high", 256, data_types::i8),
|
||||
reorder("reorder_bfyx", "quant", format::bfyx, data_types::f32));
|
||||
|
||||
tolerance = 1.f;
|
||||
execute(p);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
fusings_gpu,
|
||||
space_to_depth_quantize_i8,
|
||||
::testing::ValuesIn(std::vector<space_to_depth_params>{
|
||||
space_to_depth_params{CASE_SPACE_TO_DEPTH_F32_1, 2, 3},
|
||||
space_to_depth_params{CASE_SPACE_TO_DEPTH_F32_2, 2, 3},
|
||||
space_to_depth_params{CASE_SPACE_TO_DEPTH_F16_1, 2, 3},
|
||||
space_to_depth_params{CASE_SPACE_TO_DEPTH_F16_2, 2, 3},
|
||||
}), );
|
||||
|
||||
class space_to_depth_scale_act_eltwise_quantize_u8 : public SpaceToDepthFusingsTest {};
|
||||
TEST_P(space_to_depth_scale_act_eltwise_quantize_u8, basic) {
|
||||
auto p = GetParam();
|
||||
create_topologies(input_layout("input", get_input_layout(p)),
|
||||
space_to_depth("space_to_depth", "input", p.mode, p.block_size),
|
||||
data("scale1_data", get_mem(get_per_channel_layout(p), -0.125f)),
|
||||
scale("scale1", "space_to_depth", "scale1_data"),
|
||||
activation("actv1", "scale1", activation_func::relu),
|
||||
data("eltw_data", get_mem(layout(p.default_type, p.input_format, p.output_size))),
|
||||
eltwise("eltw", {"actv1", "eltw_data"}, eltwise_mode::sum, p.default_type),
|
||||
data("in_low", get_mem(get_per_channel_layout(p), min_random, 0)),
|
||||
data("in_high", get_mem(get_per_channel_layout(p), 1, max_random)),
|
||||
data("out_low", get_mem(get_single_element_layout(p), 0)),
|
||||
data("out_high", get_mem(get_single_element_layout(p), 255)),
|
||||
quantize("quant", "eltw", "in_low", "in_high", "out_low", "out_high", 256, data_types::u8),
|
||||
reorder("reorder_bfyx", "quant", format::bfyx, data_types::f32));
|
||||
|
||||
tolerance = 1.f;
|
||||
execute(p);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
fusings_gpu,
|
||||
space_to_depth_scale_act_eltwise_quantize_u8,
|
||||
::testing::ValuesIn(std::vector<space_to_depth_params>{
|
||||
space_to_depth_params{CASE_SPACE_TO_DEPTH_F32_1, 2, 6},
|
||||
space_to_depth_params{CASE_SPACE_TO_DEPTH_F32_2, 2, 6},
|
||||
space_to_depth_params{CASE_SPACE_TO_DEPTH_F16_1, 2, 6},
|
||||
space_to_depth_params{CASE_SPACE_TO_DEPTH_F16_2, 2, 6},
|
||||
space_to_depth_params{CASE_SPACE_TO_DEPTH_U8_1, 2, 6},
|
||||
space_to_depth_params{CASE_SPACE_TO_DEPTH_U8_2, 2, 6},
|
||||
space_to_depth_params{CASE_SPACE_TO_DEPTH_I8_1, 2, 6},
|
||||
space_to_depth_params{CASE_SPACE_TO_DEPTH_I8_2, 2, 6},
|
||||
}), );
|
||||
|
||||
|
||||
class space_to_depth_scale_act_eltw : public SpaceToDepthFusingsTest {};
|
||||
TEST_P(space_to_depth_scale_act_eltw, basic) {
|
||||
auto p = GetParam();
|
||||
create_topologies(input_layout("input", get_input_layout(p)),
|
||||
space_to_depth("space_to_depth", "input", p.mode, p.block_size),
|
||||
data("scale1_data", get_mem(get_per_channel_layout(p), -0.125f)),
|
||||
scale("scale1", "space_to_depth", "scale1_data"),
|
||||
activation("actv1", "scale1", activation_func::relu),
|
||||
data("eltw_data", get_mem(layout(p.default_type, p.input_format, p.output_size))),
|
||||
eltwise("eltw", {"actv1", "eltw_data"}, eltwise_mode::sum, p.default_type),
|
||||
reorder("reorder_bfyx", "eltw", format::bfyx, data_types::f32));
|
||||
|
||||
tolerance = 1e-5f;
|
||||
execute(p);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
fusings_gpu,
|
||||
space_to_depth_scale_act_eltw,
|
||||
::testing::ValuesIn(std::vector<space_to_depth_params>{
|
||||
space_to_depth_params{CASE_SPACE_TO_DEPTH_F32_1, 2, 5},
|
||||
space_to_depth_params{CASE_SPACE_TO_DEPTH_F32_2, 2, 5},
|
||||
space_to_depth_params{CASE_SPACE_TO_DEPTH_F16_1, 2, 5},
|
||||
space_to_depth_params{CASE_SPACE_TO_DEPTH_F16_2, 2, 5},
|
||||
space_to_depth_params{CASE_SPACE_TO_DEPTH_U8_1, 2, 5},
|
||||
space_to_depth_params{CASE_SPACE_TO_DEPTH_U8_2, 2, 5},
|
||||
space_to_depth_params{CASE_SPACE_TO_DEPTH_I8_1, 2, 5},
|
||||
space_to_depth_params{CASE_SPACE_TO_DEPTH_I8_2, 2, 5},
|
||||
}), );
|
||||
|
||||
/* ----------------------------------------------------------------------------------------------------- */
|
||||
/* ------------------------------------------ Gather cases --------------------------------------------- */
|
||||
/* ----------------------------------------------------------------------------------------------------- */
|
||||
|
@ -811,3 +811,112 @@ TEST(space_to_depth_fp32_gpu, d1199_bs3_mdf) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(space_to_depth_fp32_gpu, d1199_bs3_mdf_fsv16) {
|
||||
// Input : 1x1x9x9
|
||||
// Block size : 3
|
||||
// Output : 1x9x3x3
|
||||
// Input values in fp32
|
||||
|
||||
engine engine;
|
||||
|
||||
auto input1 = memory::allocate(engine, { data_types::f32, format::bfyx, { 1, 1, 9, 9 } });
|
||||
size_t block_size = 3;
|
||||
|
||||
set_values(input1, {
|
||||
0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f,
|
||||
10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f,
|
||||
20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f,
|
||||
30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f,
|
||||
40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f,
|
||||
50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f,
|
||||
60.0f, 61.0f, 62.0f, 63.0f, 64.0f, 65.0f, 66.0f, 67.0f, 68.0f, 69.0f,
|
||||
70.0f, 71.0f, 72.0f, 73.0f, 74.0f, 75.0f, 76.0f, 77.0f, 78.0f, 79.0f,
|
||||
80.0f
|
||||
});
|
||||
|
||||
topology topology;
|
||||
topology.add(input_layout("Input0", input1.get_layout()));
|
||||
topology.add(reorder("reorder", "Input0", format::b_fs_yx_fsv16, data_types::f32));
|
||||
topology.add(space_to_depth("space_to_depth", "reorder", space_to_depth::depth_first, block_size));
|
||||
topology.add(reorder("reorder_out", "space_to_depth", format::bfyx, data_types::f32));
|
||||
|
||||
network network(engine, topology);
|
||||
|
||||
network.set_input_data("Input0", input1);
|
||||
|
||||
auto outputs = network.execute();
|
||||
|
||||
auto output = outputs.at("reorder_out").get_memory();
|
||||
auto output_ptr = output.pointer<float>();
|
||||
|
||||
std::vector<float> expected_results = {
|
||||
0.0f, 3.0f, 6.0f, 27.0f, 30.0f, 33.0f, 54.0f, 57.0f, 60.0f, 1.0f,
|
||||
4.0f, 7.0f, 28.0f, 31.0f, 34.0f, 55.0f, 58.0f, 61.0f, 2.0f, 5.0f,
|
||||
8.0f, 29.0f, 32.0f, 35.0f, 56.0f, 59.0f, 62.0f, 9.0f, 12.0f, 15.0f,
|
||||
36.0f, 39.0f, 42.0f, 63.0f, 66.0f, 69.0f, 10.0f, 13.0f, 16.0f, 37.0f,
|
||||
40.0f, 43.0f, 64.0f, 67.0f, 70.0f, 11.0f, 14.0f, 17.0f, 38.0f, 41.0f,
|
||||
44.0f, 65.0f, 68.0f, 71.0f, 18.0f, 21.0f, 24.0f, 45.0f, 48.0f, 51.0f,
|
||||
72.0f, 75.0f, 78.0f, 19.0f, 22.0f, 25.0f, 46.0f, 49.0f, 52.0f, 73.0f,
|
||||
76.0f, 79.0f, 20.0f, 23.0f, 26.0f, 47.0f, 50.0f, 53.0f, 74.0f, 77.0f,
|
||||
80.0f
|
||||
};
|
||||
|
||||
for (size_t i = 0; i < expected_results.size(); ++i) {
|
||||
EXPECT_EQ(expected_results[i], output_ptr[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(space_to_depth_fp32_gpu, d1199_bs3_mdf_fsv4) {
|
||||
// Input : 1x1x9x9
|
||||
// Block size : 3
|
||||
// Output : 1x9x3x3
|
||||
// Input values in fp32
|
||||
|
||||
engine engine;
|
||||
|
||||
auto input1 = memory::allocate(engine, { data_types::f32, format::bfyx, { 1, 1, 9, 9 } });
|
||||
size_t block_size = 3;
|
||||
|
||||
set_values(input1, {
|
||||
0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f,
|
||||
10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f,
|
||||
20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f,
|
||||
30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f,
|
||||
40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f,
|
||||
50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f,
|
||||
60.0f, 61.0f, 62.0f, 63.0f, 64.0f, 65.0f, 66.0f, 67.0f, 68.0f, 69.0f,
|
||||
70.0f, 71.0f, 72.0f, 73.0f, 74.0f, 75.0f, 76.0f, 77.0f, 78.0f, 79.0f,
|
||||
80.0f
|
||||
});
|
||||
|
||||
topology topology;
|
||||
topology.add(input_layout("Input0", input1.get_layout()));
|
||||
topology.add(reorder("reorder", "Input0", format::b_fs_yx_fsv4, data_types::f32));
|
||||
topology.add(space_to_depth("space_to_depth", "reorder", space_to_depth::depth_first, block_size));
|
||||
topology.add(reorder("reorder_out", "space_to_depth", format::bfyx, data_types::f32));
|
||||
|
||||
network network(engine, topology);
|
||||
|
||||
network.set_input_data("Input0", input1);
|
||||
|
||||
auto outputs = network.execute();
|
||||
|
||||
auto output = outputs.at("reorder_out").get_memory();
|
||||
auto output_ptr = output.pointer<float>();
|
||||
|
||||
std::vector<float> expected_results = {
|
||||
0.0f, 3.0f, 6.0f, 27.0f, 30.0f, 33.0f, 54.0f, 57.0f, 60.0f, 1.0f,
|
||||
4.0f, 7.0f, 28.0f, 31.0f, 34.0f, 55.0f, 58.0f, 61.0f, 2.0f, 5.0f,
|
||||
8.0f, 29.0f, 32.0f, 35.0f, 56.0f, 59.0f, 62.0f, 9.0f, 12.0f, 15.0f,
|
||||
36.0f, 39.0f, 42.0f, 63.0f, 66.0f, 69.0f, 10.0f, 13.0f, 16.0f, 37.0f,
|
||||
40.0f, 43.0f, 64.0f, 67.0f, 70.0f, 11.0f, 14.0f, 17.0f, 38.0f, 41.0f,
|
||||
44.0f, 65.0f, 68.0f, 71.0f, 18.0f, 21.0f, 24.0f, 45.0f, 48.0f, 51.0f,
|
||||
72.0f, 75.0f, 78.0f, 19.0f, 22.0f, 25.0f, 46.0f, 49.0f, 52.0f, 73.0f,
|
||||
76.0f, 79.0f, 20.0f, 23.0f, 26.0f, 47.0f, 50.0f, 53.0f, 74.0f, 77.0f,
|
||||
80.0f
|
||||
};
|
||||
|
||||
for (size_t i = 0; i < expected_results.size(); ++i) {
|
||||
EXPECT_EQ(expected_results[i], output_ptr[i]);
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user