[IE CLDNN] Enable DepthToSpace (#780)
Enabled DepthToSpace ngraph transformat Updated implementation to support 5d and mode parameter fsv16 direct support Functional tests for GPU
This commit is contained in:
parent
807f85f93f
commit
0022eebd71
@ -74,6 +74,13 @@ InferenceEngine::ICNNNetwork::Ptr clDNNEngine::CloneNetwork(const InferenceEngin
|
||||
std::shared_ptr<ICNNNetwork> clonedNetwork(nullptr);
|
||||
if (network.getFunction()) {
|
||||
const auto transformations_callback = [](const std::shared_ptr<const ::ngraph::Node> &node) -> bool {
|
||||
// DepthToSpace node implementation supports only equal input/output tensors with rank <= 5
|
||||
// Reshape->Permute->Reshape pattern in theory can change output rank, so this check is added to be sure
|
||||
// that DepthToSpace impl will handle fused case
|
||||
if (auto dtsOp = std::dynamic_pointer_cast<const ::ngraph::opset3::DepthToSpace>(node)) {
|
||||
return dtsOp->input_value(0).get_shape().size() <= 5lu && dtsOp->input_value(0).get_shape().size() == dtsOp->get_output_shape(0).size();
|
||||
}
|
||||
|
||||
return std::dynamic_pointer_cast<const ::ngraph::opset2::Gelu>(node) ||
|
||||
std::dynamic_pointer_cast<const ::ngraph::opset3::ShuffleChannels>(node);
|
||||
};
|
||||
|
@ -3770,11 +3770,12 @@ void Program::CreateDepthToSpacePrimitive(cldnn::topology& topology, InferenceEn
|
||||
auto depthToSpace = as<InferenceEngine::GenericLayer*> (layer);
|
||||
|
||||
size_t blockSize = static_cast<size_t>(depthToSpace->GetParamAsUInt("block_size", 2));
|
||||
std::string mode_s = depthToSpace->GetParamAsString("mode");
|
||||
|
||||
cldnn::depth_to_space_mode mode = mode_s == "depth_first" ? cldnn::depth_to_space_mode::depth_first
|
||||
: cldnn::depth_to_space_mode::blocks_first;
|
||||
|
||||
auto inputDim = depthToSpace->input().get()->getTensorDesc().getDims();
|
||||
if (inputDim.size() != 4)
|
||||
THROW_CLDNN_EXCEPTION("Unsupported size of tensor " << inputDim.size());
|
||||
|
||||
size_t blockSizeSquare = blockSize * blockSize;
|
||||
|
||||
if (inputDim[1] % blockSizeSquare != 0)
|
||||
@ -3784,7 +3785,8 @@ void Program::CreateDepthToSpacePrimitive(cldnn::topology& topology, InferenceEn
|
||||
auto depthToSpacePrim = cldnn::depth_to_space(
|
||||
depthToSpaceName,
|
||||
inputPrimitives[0],
|
||||
blockSize);
|
||||
blockSize,
|
||||
mode);
|
||||
|
||||
topology.add(depthToSpacePrim);
|
||||
AddPrimitiveToProfiler(depthToSpaceName, layer);
|
||||
|
@ -0,0 +1,53 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <vector>
|
||||
#include <ngraph/opsets/opset3.hpp>
|
||||
|
||||
#include "single_layer_tests/depth_to_space.hpp"
|
||||
#include "common_test_utils/test_constants.hpp"
|
||||
|
||||
using namespace LayerTestsDefinitions;
|
||||
using namespace ngraph::opset3;
|
||||
|
||||
namespace {
|
||||
const std::vector<InferenceEngine::Precision> inputPrecisions = {
|
||||
InferenceEngine::Precision::FP32,
|
||||
InferenceEngine::Precision::U8,
|
||||
InferenceEngine::Precision::I16,
|
||||
};
|
||||
|
||||
const std::vector<DepthToSpace::DepthToSpaceMode> modes = {
|
||||
DepthToSpace::DepthToSpaceMode::BLOCKS_FIRST,
|
||||
DepthToSpace::DepthToSpaceMode::DEPTH_FIRST};
|
||||
|
||||
const std::vector<std::vector<size_t >> inputShapesBS2 = {
|
||||
{1, 4, 1, 1}, {1, 4, 2, 2}, {1, 4, 3, 3}, {2, 32, 3, 3}, {2, 16, 5, 4},
|
||||
{1, 8, 1, 1, 1}, {1, 8, 2, 2, 2}, {1, 8, 3, 3, 3}, {2, 32, 3, 3, 3}, {2, 16, 5, 4, 6}};
|
||||
|
||||
const auto DepthToSpaceBS2 = ::testing::Combine(
|
||||
::testing::ValuesIn(inputShapesBS2),
|
||||
::testing::ValuesIn(inputPrecisions),
|
||||
::testing::ValuesIn(modes),
|
||||
::testing::Values(2),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU)
|
||||
);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(DepthToSpaceBS2, DepthToSpaceLayerTest, DepthToSpaceBS2, DepthToSpaceLayerTest::getTestCaseName);
|
||||
|
||||
const std::vector<std::vector<size_t >> inputShapesBS3 = {
|
||||
{1, 9, 1, 1}, {1, 9, 2, 2}, {1, 9, 3, 3}, {2, 36, 3, 3}, {2, 27, 5, 4},
|
||||
{1, 27, 1, 1, 1}, {1, 27, 2, 2, 2}, {1, 27, 3, 3, 3}, {2, 108, 3, 3, 3}, {2, 54, 5, 4, 6}};
|
||||
|
||||
const auto DepthToSpaceBS3 = ::testing::Combine(
|
||||
::testing::ValuesIn(inputShapesBS3),
|
||||
::testing::ValuesIn(inputPrecisions),
|
||||
::testing::ValuesIn(modes),
|
||||
::testing::Values(3),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU)
|
||||
);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(DepthToSpaceBS3, DepthToSpaceLayerTest, DepthToSpaceBS3, DepthToSpaceLayerTest::getTestCaseName);
|
||||
|
||||
} // namespace
|
@ -1,17 +0,0 @@
|
||||
// Copyright (C) 2018-2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "depth_to_space_tests.hpp"
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
smoke_GPU_TestsDepthToSpace, DepthToSpaceTests,
|
||||
::testing::Values(
|
||||
depth_to_space_test_params{ "GPU", "FP32", { 1, 4, 1, 1 }, 2, { 1, 1, 2, 2 } },
|
||||
depth_to_space_test_params{ "GPU", "FP32", { 1, 4, 2, 1 }, 2, { 1, 1, 4, 2 } },
|
||||
depth_to_space_test_params{ "GPU", "FP32", { 1, 4, 2, 2 }, 2, { 1, 1, 4, 4 } },
|
||||
depth_to_space_test_params{ "GPU", "FP32", { 1, 4, 3, 2 }, 2, { 1, 1, 6, 4 } },
|
||||
depth_to_space_test_params{ "GPU", "FP32", { 1, 9, 3, 3 }, 3, { 1, 1, 9, 9 } },
|
||||
depth_to_space_test_params{ "GPU", "FP32", { 1, 18, 3, 3 }, 3, { 1, 2, 9, 9 } },
|
||||
depth_to_space_test_params{ "GPU", "FP32", { 1, 4, 2048, 512 }, 2, { 1, 1, 4096, 1024 } }
|
||||
));
|
@ -1,5 +1,5 @@
|
||||
/*
|
||||
// Copyright (c) 2019 Intel Corporation
|
||||
// Copyright (c) 2019-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -26,6 +26,14 @@ namespace cldnn {
|
||||
/// @addtogroup cpp_primitives Primitives
|
||||
/// @{
|
||||
|
||||
/// @brief mode for the @ref depth_to_space primitive.
|
||||
enum class depth_to_space_mode : int32_t {
|
||||
/// @brief the input depth is divided to [block_size, ..., block_size, new_depth].
|
||||
blocks_first,
|
||||
/// @brief the input depth is divided to [new_depth, block_size, ..., block_size]
|
||||
depth_first
|
||||
};
|
||||
|
||||
/// @brief
|
||||
/// @details
|
||||
struct depth_to_space : public primitive_base<depth_to_space> {
|
||||
@ -35,14 +43,20 @@ struct depth_to_space : public primitive_base<depth_to_space> {
|
||||
/// @param id This primitive id.
|
||||
/// @param input Input dictionary primitive id.
|
||||
/// @param block_size Block size.
|
||||
/// @param mode Depth division mode.
|
||||
depth_to_space(const primitive_id& id,
|
||||
const primitive_id& input,
|
||||
const size_t block_size,
|
||||
const depth_to_space_mode mode,
|
||||
const padding& output_padding = padding())
|
||||
: primitive_base(id, {input}, output_padding), block_size(block_size) {}
|
||||
: primitive_base(id, {input}, output_padding)
|
||||
, block_size(block_size)
|
||||
, mode(mode) {}
|
||||
|
||||
/// @brief Block size.
|
||||
size_t block_size;
|
||||
/// @brief depth division mode
|
||||
depth_to_space_mode mode;
|
||||
};
|
||||
/// @}
|
||||
/// @}
|
||||
|
@ -371,6 +371,14 @@ enum class TileAxis {
|
||||
BATCH,
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// DepthToSpaceMode
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
enum class DepthToSpaceMode {
|
||||
BLOCKS_FIRST,
|
||||
DEPTH_FIRST,
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// ResampleType
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -27,6 +27,15 @@ bool DepthToSpaceKernelBase::Validate(const Params& p, const optional_params& o)
|
||||
return false;
|
||||
}
|
||||
|
||||
const depth_to_space_params& params = static_cast<const depth_to_space_params&>(p);
|
||||
for (auto& fused_op : params.fused_ops) {
|
||||
if (!IsFusedPrimitiveSupported(fused_op))
|
||||
return false;
|
||||
}
|
||||
|
||||
if (params.inputs[0].Dimentions() > 5)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -35,7 +44,7 @@ CommonDispatchData DepthToSpaceKernelBase::SetDefault(const depth_to_space_param
|
||||
|
||||
std::vector<size_t> global = { params.output.Batch().v,
|
||||
params.output.Feature().v,
|
||||
params.output.Y().v * params.output.X().v };
|
||||
params.output.Z().v * params.output.Y().v * params.output.X().v };
|
||||
|
||||
auto local = GetOptimalLocalWorkGroupSizes(global, params.engineInfo);
|
||||
|
||||
@ -54,6 +63,11 @@ JitConstants DepthToSpaceKernelBase::GetJitConstants(const depth_to_space_params
|
||||
JitConstants jit = MakeBaseParamsJitConstants(params);
|
||||
|
||||
jit.AddConstant(MakeJitConstant("BLOCK_SIZE", params.block_size));
|
||||
if (params.mode == DepthToSpaceMode::BLOCKS_FIRST) {
|
||||
jit.AddConstant(MakeJitConstant("BLOCKS_FIRST", 1));
|
||||
} else {
|
||||
jit.AddConstant(MakeJitConstant("DEPTH_FIRST", 1));
|
||||
}
|
||||
|
||||
return jit;
|
||||
}
|
||||
@ -73,7 +87,8 @@ KernelsData DepthToSpaceKernelBase::GetCommonKernelsData(const Params& params, c
|
||||
|
||||
auto& kernel = kd.kernels[0];
|
||||
|
||||
FillCLKernelData(kernel, runInfo, params.engineInfo, kernelName, jit, entry_point);
|
||||
FillCLKernelData(kernel, runInfo, params.engineInfo, kernelName, jit, entry_point,
|
||||
DEFAULT, false, false, 1, GetFusedPrimitiveInputsCount(params));
|
||||
|
||||
kd.estimatedTime = estimatedTime;
|
||||
|
||||
|
@ -24,8 +24,12 @@ namespace kernel_selector {
|
||||
// depth_to_space_params
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
struct depth_to_space_params : public base_params {
|
||||
depth_to_space_params() : base_params(KernelType::DEPTH_TO_SPACE), block_size(0) {}
|
||||
depth_to_space_params()
|
||||
: base_params(KernelType::DEPTH_TO_SPACE)
|
||||
, block_size(0)
|
||||
, mode(DepthToSpaceMode::DEPTH_FIRST) {}
|
||||
size_t block_size;
|
||||
DepthToSpaceMode mode;
|
||||
|
||||
virtual ParamsKey GetParamsKey() const { return base_params::GetParamsKey(); }
|
||||
};
|
||||
|
@ -38,6 +38,9 @@ bool DepthToSpaceKernelBlock2Opt::Validate(const Params& p, const optional_param
|
||||
if ((params.block_size != 2) || (params.inputs[0].X().v % 2 != 0))
|
||||
return false;
|
||||
|
||||
if (params.mode != DepthToSpaceMode::BLOCKS_FIRST)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -20,12 +20,18 @@
|
||||
#include <vector>
|
||||
|
||||
namespace kernel_selector {
|
||||
|
||||
ParamsKey DepthToSpaceKernelRef::GetSupportedKey() const {
|
||||
ParamsKey k;
|
||||
k.EnableInputDataType(Datatype::F16);
|
||||
k.EnableInputDataType(Datatype::F32);
|
||||
k.EnableInputDataType(Datatype::UINT8);
|
||||
k.EnableInputDataType(Datatype::INT8);
|
||||
k.EnableOutputDataType(Datatype::F16);
|
||||
k.EnableOutputDataType(Datatype::F32);
|
||||
k.EnableOutputDataType(Datatype::UINT8);
|
||||
k.EnableOutputDataType(Datatype::INT8);
|
||||
k.EnableDifferentTypes();
|
||||
k.EnableAllInputLayout();
|
||||
k.EnableAllOutputLayout();
|
||||
k.EnableTensorOffset();
|
||||
@ -37,4 +43,24 @@ ParamsKey DepthToSpaceKernelRef::GetSupportedKey() const {
|
||||
KernelsData DepthToSpaceKernelRef::GetKernelsData(const Params& params, const optional_params& options) const {
|
||||
return GetCommonKernelsData(params, options, FORCE_PRIORITY_9);
|
||||
}
|
||||
|
||||
JitConstants DepthToSpaceKernelRef::GetJitConstants(const depth_to_space_params& params) const {
|
||||
auto jit = Parent::GetJitConstants(params);
|
||||
auto input = params.inputs[0];
|
||||
auto input_dt = input.GetDType();
|
||||
|
||||
if (!params.fused_ops.empty()) {
|
||||
std::vector<std::string> idx_order;
|
||||
if (input.Dimentions() == 5) {
|
||||
idx_order = {"batch", "feature", "z", "y", "x"};
|
||||
} else if (input.Dimentions() == 4) {
|
||||
idx_order = {"batch", "feature", "y", "x"};
|
||||
}
|
||||
FusedOpsConfiguration conf = {"", idx_order, "in_val", input_dt, 1};
|
||||
jit.Merge(MakeFusedOpsJitConstants(params, {conf}));
|
||||
}
|
||||
|
||||
return jit;
|
||||
}
|
||||
|
||||
} // namespace kernel_selector
|
||||
|
@ -21,10 +21,21 @@
|
||||
namespace kernel_selector {
|
||||
class DepthToSpaceKernelRef : public DepthToSpaceKernelBase {
|
||||
public:
|
||||
using Parent = DepthToSpaceKernelBase;
|
||||
|
||||
DepthToSpaceKernelRef() : DepthToSpaceKernelBase("depth_to_space_ref") {}
|
||||
virtual ~DepthToSpaceKernelRef() {}
|
||||
|
||||
KernelsData GetKernelsData(const Params& params, const optional_params& options) const override;
|
||||
ParamsKey GetSupportedKey() const override;
|
||||
|
||||
protected:
|
||||
JitConstants GetJitConstants(const depth_to_space_params& params) const override;
|
||||
std::vector<FusedOpType> GetSupportedFusedOps() const override {
|
||||
return { FusedOpType::ELTWISE,
|
||||
FusedOpType::QUANTIZE,
|
||||
FusedOpType::SCALE,
|
||||
FusedOpType::ACTIVATION };
|
||||
}
|
||||
};
|
||||
} // namespace kernel_selector
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2019 Intel Corporation
|
||||
// Copyright (c) 2019-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -15,22 +15,59 @@
|
||||
|
||||
#include "include/include_all.cl"
|
||||
|
||||
KERNEL(depth_to_space_ref)(const __global UNIT_TYPE* input, __global UNIT_TYPE* output)
|
||||
KERNEL(depth_to_space_ref)(const __global INPUT0_TYPE* input,
|
||||
__global OUTPUT_TYPE* output
|
||||
#if HAS_FUSED_OPS_DECLS
|
||||
, FUSED_OPS_DECLS
|
||||
#endif
|
||||
)
|
||||
{
|
||||
const uint batch = get_global_id(0);
|
||||
const uint feature = get_global_id(1);
|
||||
const uint y = (uint)get_global_id(2) / OUTPUT_SIZE_X;
|
||||
#if OUTPUT_DIMS == 5
|
||||
const uint z = (uint)get_global_id(2) / OUTPUT_SIZE_X / OUTPUT_SIZE_Y;
|
||||
const uint y = ((uint)get_global_id(2) / OUTPUT_SIZE_X) % OUTPUT_SIZE_Y;
|
||||
const uint x = (uint)get_global_id(2) % OUTPUT_SIZE_X;
|
||||
|
||||
const uint input_z = z / BLOCK_SIZE;
|
||||
const uint offset_z = z % BLOCK_SIZE;
|
||||
#else
|
||||
const uint y = (uint)get_global_id(2) / OUTPUT_SIZE_X;
|
||||
const uint x = (uint)get_global_id(2) % OUTPUT_SIZE_X;
|
||||
#endif
|
||||
const uint input_y = y / BLOCK_SIZE;
|
||||
const uint offset_y = y % BLOCK_SIZE;
|
||||
|
||||
const uint input_x = x / BLOCK_SIZE;
|
||||
const uint offset_x = (x % BLOCK_SIZE);
|
||||
const uint offset_feature = (offset_y * BLOCK_SIZE + offset_x) * OUTPUT_FEATURE_NUM;
|
||||
|
||||
const uint output_index = OUTPUT_OFFSET + (batch * OUTPUT_BATCH_PITCH) + (feature * OUTPUT_FEATURE_PITCH) + (y * OUTPUT_Y_PITCH) + x;
|
||||
#if OUTPUT_DIMS == 5
|
||||
#if BLOCKS_FIRST
|
||||
const uint offset_feature = (offset_z*BLOCK_SIZE*BLOCK_SIZE + offset_y * BLOCK_SIZE + offset_x) * OUTPUT_FEATURE_NUM;
|
||||
const uint input_feature = feature + offset_feature;
|
||||
const uint input_index = INPUT0_OFFSET + (batch * INPUT0_BATCH_PITCH) + (input_feature * INPUT0_FEATURE_PITCH) + (input_y * INPUT0_Y_PITCH) + input_x;
|
||||
output[output_index] = ACTIVATION(input[input_index], ACTIVATION_PARAMS);
|
||||
#else // BLOCKS_FIRST
|
||||
const uint offset_feature = (offset_z*BLOCK_SIZE*BLOCK_SIZE + offset_y * BLOCK_SIZE + offset_x);
|
||||
const uint input_feature = feature * BLOCK_SIZE * BLOCK_SIZE * BLOCK_SIZE + offset_feature;
|
||||
#endif // BLOCKS_FIRST
|
||||
const uint output_index = OUTPUT_GET_INDEX(batch, feature, z, y, x);
|
||||
const uint input_index = INPUT0_GET_INDEX(batch, input_feature, input_z, input_y, input_x);
|
||||
#else
|
||||
#if BLOCKS_FIRST
|
||||
const uint offset_feature = (offset_y * BLOCK_SIZE + offset_x) * OUTPUT_FEATURE_NUM;
|
||||
const uint input_feature = feature + offset_feature;
|
||||
#else //BLOCKS_FIRST
|
||||
const uint offset_feature = (offset_y * BLOCK_SIZE + offset_x);
|
||||
const uint input_feature = feature * BLOCK_SIZE * BLOCK_SIZE + offset_feature;
|
||||
#endif // BLOCKS_FIRST
|
||||
const uint output_index = OUTPUT_GET_INDEX(batch, feature, y, x);
|
||||
const uint input_index = INPUT0_GET_INDEX(batch, input_feature, input_y, input_x);
|
||||
#endif
|
||||
|
||||
INPUT0_TYPE in_val = input[input_index];
|
||||
#if HAS_FUSED_OPS
|
||||
FUSED_OPS;
|
||||
output[output_index] = FUSED_OPS_RESULT;
|
||||
#else
|
||||
output[output_index] = ACTIVATION(in_val, ACTIVATION_PARAMS);
|
||||
#endif
|
||||
}
|
||||
|
@ -1,5 +1,5 @@
|
||||
/*
|
||||
// Copyright (c) 2019 Intel Corporation
|
||||
// Copyright (c) 2019-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -46,14 +46,25 @@ layout depth_to_space_inst::calc_output_layout(depth_to_space_node const& node)
|
||||
"The depth of the input tensor must be divisible by squared block size. Actual block size is " +
|
||||
std::to_string(block_size));
|
||||
|
||||
const size_t feature = input_layout.size.feature[0] / block_size / block_size;
|
||||
const size_t y = input_layout.size.spatial[1] * block_size;
|
||||
const size_t x = input_layout.size.spatial[0] * block_size;
|
||||
auto out_size = input_layout.size;
|
||||
if (format::spatial_num(input_layout.format) == 3) {
|
||||
const size_t feature = input_layout.size.feature[0] / block_size / block_size / block_size;
|
||||
const size_t z = input_layout.size.spatial[2] * block_size;
|
||||
const size_t y = input_layout.size.spatial[1] * block_size;
|
||||
const size_t x = input_layout.size.spatial[0] * block_size;
|
||||
out_size = tensor(TensorValue(input_layout.size.batch[0]), TensorValue(feature), TensorValue(x), TensorValue(y), TensorValue(z));
|
||||
} else {
|
||||
const size_t feature = input_layout.size.feature[0] / block_size / block_size;
|
||||
const size_t y = input_layout.size.spatial[1] * block_size;
|
||||
const size_t x = input_layout.size.spatial[0] * block_size;
|
||||
out_size = tensor(TensorValue(input_layout.size.batch[0]), TensorValue(feature), TensorValue(x), TensorValue(y));
|
||||
}
|
||||
|
||||
return layout{
|
||||
input_layout.data_type,
|
||||
input_format,
|
||||
tensor(TensorValue(input_layout.size.batch[0]), TensorValue(feature), TensorValue(x), TensorValue(y))};
|
||||
if (node.has_fused_primitives()) {
|
||||
input_layout.data_type = node.get_fused_output_layout().data_type;
|
||||
}
|
||||
|
||||
return layout{input_layout.data_type, input_format, out_size};
|
||||
}
|
||||
|
||||
std::string depth_to_space_inst::to_string(depth_to_space_node const& node) {
|
||||
@ -66,6 +77,7 @@ std::string depth_to_space_inst::to_string(depth_to_space_node const& node) {
|
||||
json_composite depth_to_space_info;
|
||||
depth_to_space_info.add("input id", input.id());
|
||||
depth_to_space_info.add("block size", desc->block_size);
|
||||
depth_to_space_info.add("mode", desc->mode == depth_to_space_mode::blocks_first ? "blocks_first" : "depth_first");
|
||||
|
||||
node_info->add("depth_to_space info", depth_to_space_info);
|
||||
node_info->dump(primitive_description);
|
||||
|
@ -1,5 +1,5 @@
|
||||
/*
|
||||
// Copyright (c) 2019 Intel Corporation
|
||||
// Copyright (c) 2019-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -21,6 +21,7 @@
|
||||
#include "depth_to_space/depth_to_space_kernel_selector.h"
|
||||
#include "depth_to_space/depth_to_space_kernel_ref.h"
|
||||
#include "error_handler.h"
|
||||
#include "common_types.h"
|
||||
|
||||
using namespace cldnn;
|
||||
|
||||
@ -37,6 +38,8 @@ public:
|
||||
get_default_optional_params<kernel_selector::depth_to_space_optional_params>(arg.get_program());
|
||||
|
||||
depth_to_space_params.block_size = arg.get_primitive()->block_size;
|
||||
depth_to_space_params.mode = arg.get_primitive()->mode == depth_to_space_mode::blocks_first ? kernel_selector::depth_to_space_mode::BLOCKS_FIRST
|
||||
: kernel_selector::depth_to_space_mode::DEPTH_FIRST;
|
||||
|
||||
auto& kernel_selector = kernel_selector::depth_to_space_kernel_selector::Instance();
|
||||
auto best_kernels = kernel_selector.GetBestKernels(depth_to_space_params, depth_to_space_optional_params);
|
||||
@ -56,10 +59,18 @@ namespace detail {
|
||||
|
||||
attach_depth_to_space_gpu::attach_depth_to_space_gpu() {
|
||||
auto val_fw = depth_to_space_gpu::create;
|
||||
implementation_map<depth_to_space>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx),
|
||||
val_fw);
|
||||
implementation_map<depth_to_space>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx),
|
||||
val_fw);
|
||||
implementation_map<depth_to_space>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx), val_fw);
|
||||
implementation_map<depth_to_space>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx), val_fw);
|
||||
implementation_map<depth_to_space>::add(std::make_tuple(engine_types::ocl, data_types::u8, format::bfyx), val_fw);
|
||||
implementation_map<depth_to_space>::add(std::make_tuple(engine_types::ocl, data_types::i8, format::bfyx), val_fw);
|
||||
implementation_map<depth_to_space>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfzyx), val_fw);
|
||||
implementation_map<depth_to_space>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfzyx), val_fw);
|
||||
implementation_map<depth_to_space>::add(std::make_tuple(engine_types::ocl, data_types::u8, format::bfzyx), val_fw);
|
||||
implementation_map<depth_to_space>::add(std::make_tuple(engine_types::ocl, data_types::i8, format::bfzyx), val_fw);
|
||||
implementation_map<depth_to_space>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::b_fs_yx_fsv16), val_fw);
|
||||
implementation_map<depth_to_space>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::b_fs_yx_fsv16), val_fw);
|
||||
implementation_map<depth_to_space>::add(std::make_tuple(engine_types::ocl, data_types::u8, format::b_fs_yx_fsv16), val_fw);
|
||||
implementation_map<depth_to_space>::add(std::make_tuple(engine_types::ocl, data_types::i8, format::b_fs_yx_fsv16), val_fw);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
@ -1,5 +1,5 @@
|
||||
/*
|
||||
// Copyright (c) 2019 Intel Corporation
|
||||
// Copyright (c) 2019-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -117,65 +117,6 @@ void graph_initializations::replace_nodes(program_impl& p) {
|
||||
p.nodes_map.erase(node->id());
|
||||
continue;
|
||||
}
|
||||
|
||||
// find sequence reshape->permute->reshape and exchange with depth to space
|
||||
if (node->is_type<reshape>()) {
|
||||
if (!p.get_options().get<build_option_type::optimize_data>()->enabled())
|
||||
continue;
|
||||
|
||||
if (node->get_users().size() == 0)
|
||||
continue;
|
||||
|
||||
auto& input_node = node->get_dependency(0);
|
||||
if (!(node->get_users().front()->is_type<permute>()) || !(input_node.is_type<reorder>()))
|
||||
continue;
|
||||
|
||||
auto input_node_layout = input_node.get_output_layout();
|
||||
if (input_node_layout.format != format::bfwzyx || input_node_layout.data_type != data_types::f16)
|
||||
continue;
|
||||
|
||||
// optimal implementation only for depth to space block size 2
|
||||
auto reshape1_layout = node->get_output_layout();
|
||||
if (reshape1_layout.size.spatial[3] != 2)
|
||||
continue;
|
||||
|
||||
auto permute_prim = node->get_users().front()->as<permute>().typed_desc();
|
||||
primitive_id permute_id = node->get_users().front()->id();
|
||||
auto& permute_node = node->get_users().front();
|
||||
|
||||
auto reshape1_prim = node->as<reshape>().typed_desc();
|
||||
primitive_id reshape1_id = node->id();
|
||||
|
||||
p.remove_connection(*node, *permute_node);
|
||||
|
||||
auto perm_node_ptr = p.nodes_map.find(permute_id)->second;
|
||||
auto perm_node = &perm_node_ptr->as<permute>();
|
||||
|
||||
auto rename_id = permute_id + "_tmp";
|
||||
p.rename(*perm_node, rename_id);
|
||||
|
||||
auto reorder_id = input_node.id() + "_reorder_for_depth_to_space";
|
||||
auto reorder_prim = std::make_shared<reorder>(reorder_id, input_node.id(), format::bfyx, input_node_layout.data_type);
|
||||
auto pixel_shuffle_prim = std::make_shared<depth_to_space>(permute_id, reorder_id, 2);
|
||||
|
||||
p.get_or_create(reorder_prim);
|
||||
p.get_or_create(pixel_shuffle_prim);
|
||||
auto reorder_depth_node_ptr = p.nodes_map.find(reorder_id)->second;
|
||||
auto pixel_shuffle_node_ptr = p.nodes_map.find(permute_id)->second;
|
||||
p.add_connection(input_node, *reorder_depth_node_ptr);
|
||||
p.add_connection(*reorder_depth_node_ptr, *pixel_shuffle_node_ptr);
|
||||
|
||||
auto deconv_node_ptr = p.nodes_map.find(rename_id)->second;
|
||||
p.replace_all_usages(*deconv_node_ptr, *pixel_shuffle_node_ptr);
|
||||
p.optimized_out.push_back(rename_id);
|
||||
p.nodes_map.erase(rename_id);
|
||||
|
||||
p.remove_connection(input_node, *node);
|
||||
p.replace_all_usages(*node, input_node);
|
||||
p.optimized_out.push_back(reshape1_id);
|
||||
p.nodes_map.erase(reshape1_id);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
/*
|
||||
// Copyright (c) 2018-2019 Intel Corporation
|
||||
// Copyright (c) 2018-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -324,7 +324,7 @@ void pre_replace_deconv::run(program_impl& p) {
|
||||
p.inputs.push_back(weights_node_conv_rpl_ptr.get());
|
||||
}
|
||||
|
||||
auto pixel_shuffle_prim = std::make_shared<depth_to_space>(deconv_id, deconv_id_conv, 2);
|
||||
auto pixel_shuffle_prim = std::make_shared<depth_to_space>(deconv_id, deconv_id_conv, 2, depth_to_space_mode::blocks_first);
|
||||
|
||||
p.get_or_create(pixel_shuffle_prim);
|
||||
auto pixel_shuffle_node_ptr = p.nodes_map.find(deconv_id)->second;
|
||||
|
@ -364,7 +364,7 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
|
||||
input_data.get_dependency(0).get_output_layout().data_type == data_types::i8);
|
||||
|
||||
should_fuse |= input_data.is_type<deconvolution>();
|
||||
|
||||
|
||||
should_fuse |= input_data.is_type<permute>();
|
||||
|
||||
should_fuse |= input_data.is_type<activation>();
|
||||
@ -373,6 +373,8 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
|
||||
|
||||
should_fuse |= input_data.is_type<gather>();
|
||||
|
||||
should_fuse |= input_data.is_type<depth_to_space>();
|
||||
|
||||
if (!should_fuse)
|
||||
return;
|
||||
|
||||
@ -407,7 +409,7 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
|
||||
input_data.get_dependency(0).get_output_layout().data_type == data_types::i8);
|
||||
|
||||
should_fuse |= input_data.is_type<deconvolution>();
|
||||
|
||||
|
||||
should_fuse |= input_data.is_type<permute>();
|
||||
|
||||
should_fuse |= input_data.is_type<activation>();
|
||||
@ -416,6 +418,8 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
|
||||
|
||||
should_fuse |= input_data.is_type<gather>();
|
||||
|
||||
should_fuse |= input_data.is_type<depth_to_space>();
|
||||
|
||||
if (!should_fuse)
|
||||
return;
|
||||
|
||||
@ -485,11 +489,12 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
|
||||
(input_data.get_dependency(0).get_output_layout().data_type == data_types::u8 ||
|
||||
input_data.get_dependency(0).get_output_layout().data_type == data_types::i8 ||
|
||||
input_data.get_output_layout().data_type == out_layout.data_type);
|
||||
|
||||
|
||||
should_fuse |= input_data.is_type<gather>() && quantize_node.get_scale_shift_opt();
|
||||
|
||||
should_fuse |= input_data.is_type<permute>() &&
|
||||
quantize_node.get_scale_shift_opt();
|
||||
should_fuse |= input_data.is_type<permute>() && quantize_node.get_scale_shift_opt();
|
||||
|
||||
should_fuse |= input_data.is_type<depth_to_space>() && quantize_node.get_scale_shift_opt();
|
||||
|
||||
if (!should_fuse)
|
||||
return;
|
||||
@ -512,12 +517,12 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
|
||||
bool can_fuse_parent1 = (parent1->is_type<convolution>() && conv_supports_fusings(parent1->as<convolution>())) ||
|
||||
(parent1->is_type<mvn>() && mvn_supports_fusings(parent1->as<mvn>())) ||
|
||||
(parent1->is_type<deconvolution>()) || (parent1->is_type<permute>()) ||
|
||||
(parent1->is_type<gemm>());
|
||||
(parent1->is_type<depth_to_space>()) || (parent1->is_type<gemm>());
|
||||
|
||||
bool can_fuse_parent2 = (parent2->is_type<convolution>() && conv_supports_fusings(parent2->as<convolution>())) ||
|
||||
(parent2->is_type<mvn>() && mvn_supports_fusings(parent2->as<mvn>())) ||
|
||||
(parent2->is_type<deconvolution>()) || (parent2->is_type<permute>()) ||
|
||||
(parent2->is_type<gemm>());
|
||||
(parent1->is_type<depth_to_space>()) || (parent2->is_type<gemm>());
|
||||
|
||||
std::vector<bool> can_fuse_parents = { can_fuse_parent1, can_fuse_parent2 };
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2016-2018 Intel Corporation
|
||||
// Copyright (c) 2016-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -80,6 +80,7 @@ using border_type = kernel_selector::BorderType;
|
||||
using gather_axis = kernel_selector::GatherAxis;
|
||||
using reduce_mode = kernel_selector::ReduceMode;
|
||||
using cum_sum_axis = kernel_selector::CumSumAxis;
|
||||
using depth_to_space_mode = kernel_selector::DepthToSpaceMode;
|
||||
|
||||
using data_tensor = kernel_selector::DataTensor;
|
||||
using weights_tensor = kernel_selector::WeightsTensor;
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2019 Intel Corporation
|
||||
// Copyright (c) 2019-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -49,7 +49,7 @@ TEST(depth_to_space_fp16_gpu, d1411_bs2) {
|
||||
topology topology;
|
||||
topology.add(input_layout("Input0", input1.get_layout()));
|
||||
topology.add(
|
||||
depth_to_space("depth_to_space", "Input0", block_size)
|
||||
depth_to_space("depth_to_space", "Input0", block_size, depth_to_space_mode::blocks_first)
|
||||
);
|
||||
|
||||
network network(engine, topology);
|
||||
@ -91,7 +91,7 @@ TEST(depth_to_space_fp16_gpu, d1421_bs2) {
|
||||
topology topology;
|
||||
topology.add(input_layout("Input0", input1.get_layout()));
|
||||
topology.add(
|
||||
depth_to_space("depth_to_space", "Input0", block_size)
|
||||
depth_to_space("depth_to_space", "Input0", block_size, depth_to_space_mode::blocks_first)
|
||||
);
|
||||
|
||||
network network(engine, topology);
|
||||
@ -146,7 +146,7 @@ TEST(depth_to_space_fp16_gpu, d1933_bs3) {
|
||||
topology topology;
|
||||
topology.add(input_layout("Input0", input1.get_layout()));
|
||||
topology.add(
|
||||
depth_to_space("depth_to_space", "Input0", block_size)
|
||||
depth_to_space("depth_to_space", "Input0", block_size, depth_to_space_mode::blocks_first)
|
||||
);
|
||||
|
||||
network network(engine, topology);
|
||||
@ -193,7 +193,7 @@ TEST(depth_to_space_fp32_gpu, d1411_bs2) {
|
||||
topology topology;
|
||||
topology.add(input_layout("Input0", input1.get_layout()));
|
||||
topology.add(
|
||||
depth_to_space("depth_to_space", "Input0", block_size)
|
||||
depth_to_space("depth_to_space", "Input0", block_size, depth_to_space_mode::blocks_first)
|
||||
);
|
||||
|
||||
network network(engine, topology);
|
||||
@ -232,7 +232,7 @@ TEST(depth_to_space_fp32_gpu, d112960540_bs2) {
|
||||
topology topology_act;
|
||||
topology_act.add(input_layout("Input0", input1.get_layout()));
|
||||
topology_act.add(
|
||||
depth_to_space("depth_to_space", "Input0", block_size)
|
||||
depth_to_space("depth_to_space", "Input0", block_size, depth_to_space_mode::blocks_first)
|
||||
);
|
||||
|
||||
network network_act(engine, topology_act);
|
||||
@ -302,7 +302,7 @@ TEST(depth_to_space_fp32_gpu, d1933_bs3) {
|
||||
topology topology;
|
||||
topology.add(input_layout("Input0", input1.get_layout()));
|
||||
topology.add(
|
||||
depth_to_space("depth_to_space", "Input0", block_size)
|
||||
depth_to_space("depth_to_space", "Input0", block_size, depth_to_space_mode::blocks_first)
|
||||
);
|
||||
|
||||
network network(engine, topology);
|
||||
@ -330,3 +330,103 @@ TEST(depth_to_space_fp32_gpu, d1933_bs3) {
|
||||
EXPECT_EQ(expected_results[i], output_ptr[i]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
TEST(depth_to_space_fp32_gpu, d1822_bs2_blocks_first) {
|
||||
// Input : 1x8x2x2
|
||||
// Block size : 2
|
||||
// Output : 1x2x4x4
|
||||
// Input values in fp32
|
||||
|
||||
engine engine;
|
||||
|
||||
auto input1 = memory::allocate(engine, { data_types::f32, format::bfyx, { 1, 8, 2, 2 } });
|
||||
size_t block_size = 2;
|
||||
|
||||
set_values(input1, {
|
||||
0.0f, 1.0f, 2.0f, 3.0f,
|
||||
4.0f, 5.0f, 6.0f, 7.0f,
|
||||
8.0f, 9.0f, 10.0f, 11.0f,
|
||||
12.0f, 13.0f, 14.0f, 15.0f,
|
||||
16.0f, 17.0f, 18.0f, 19.0f,
|
||||
20.0f, 21.0f, 22.0f, 23.0f,
|
||||
24.0f, 25.0f, 26.0f, 27.0f,
|
||||
28.0f, 29.0f, 30.0f, 31.0f
|
||||
});
|
||||
|
||||
topology topology;
|
||||
topology.add(input_layout("Input0", input1.get_layout()));
|
||||
topology.add(
|
||||
depth_to_space("depth_to_space", "Input0", block_size, depth_to_space_mode::blocks_first)
|
||||
);
|
||||
|
||||
network network(engine, topology);
|
||||
|
||||
network.set_input_data("Input0", input1);
|
||||
|
||||
auto outputs = network.execute();
|
||||
|
||||
auto output = outputs.at("depth_to_space").get_memory();
|
||||
auto output_ptr = output.pointer<float>();
|
||||
|
||||
std::vector<float> expected_results = {
|
||||
0.0f, 8.0f, 1.0f, 9.0f, 16.0f, 24.0f, 17.0f, 25.0f,
|
||||
2.0f, 10.0f, 3.0f, 11.0f, 18.0f, 26.0f, 19.0f, 27.0f,
|
||||
4.0f, 12.0f, 5.0f, 13.0f, 20.0f, 28.0f, 21.0f, 29.0f,
|
||||
6.0f, 14.0f, 7.0f, 15.0f, 22.0f, 30.0f, 23.0f, 31.0f
|
||||
};
|
||||
|
||||
for (size_t i = 0; i < expected_results.size(); ++i) {
|
||||
EXPECT_EQ(expected_results[i], output_ptr[i]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
TEST(depth_to_space_fp32_gpu, d1822_bs2_depth_first) {
|
||||
// Input : 1x8x2x2
|
||||
// Block size : 2
|
||||
// Output : 1x2x4x4
|
||||
// Input values in fp32
|
||||
|
||||
engine engine;
|
||||
|
||||
auto input1 = memory::allocate(engine, { data_types::f32, format::bfyx, { 1, 8, 2, 2 } });
|
||||
size_t block_size = 2;
|
||||
|
||||
set_values(input1, {
|
||||
0.0f, 1.0f, 2.0f, 3.0f,
|
||||
4.0f, 5.0f, 6.0f, 7.0f,
|
||||
8.0f, 9.0f, 10.0f, 11.0f,
|
||||
12.0f, 13.0f, 14.0f, 15.0f,
|
||||
16.0f, 17.0f, 18.0f, 19.0f,
|
||||
20.0f, 21.0f, 22.0f, 23.0f,
|
||||
24.0f, 25.0f, 26.0f, 27.0f,
|
||||
28.0f, 29.0f, 30.0f, 31.0f
|
||||
});
|
||||
|
||||
topology topology;
|
||||
topology.add(input_layout("Input0", input1.get_layout()));
|
||||
topology.add(
|
||||
depth_to_space("depth_to_space", "Input0", block_size, depth_to_space_mode::depth_first)
|
||||
);
|
||||
|
||||
network network(engine, topology);
|
||||
|
||||
network.set_input_data("Input0", input1);
|
||||
|
||||
auto outputs = network.execute();
|
||||
|
||||
auto output = outputs.at("depth_to_space").get_memory();
|
||||
auto output_ptr = output.pointer<float>();
|
||||
|
||||
std::vector<float> expected_results = {
|
||||
0.0f, 4.0f, 1.0f, 5.0f, 8.0f, 12.0f, 9.0f, 13.0f,
|
||||
2.0f, 6.0f, 3.0f, 7.0f, 10.0f, 14.0f, 11.0f, 15.0f,
|
||||
16.0f, 20.0f, 17.0f, 21.0f, 24.0f, 28.0f, 25.0f, 29.0f,
|
||||
18.0f, 22.0f, 19.0f, 23.0f, 26.0f, 30.0f, 27.0f, 31.0f
|
||||
};
|
||||
|
||||
for (size_t i = 0; i < expected_results.size(); ++i) {
|
||||
EXPECT_EQ(expected_results[i], output_ptr[i]);
|
||||
}
|
||||
}
|
||||
|
@ -1,5 +1,5 @@
|
||||
/*
|
||||
// Copyright (c) 2016 Intel Corporation
|
||||
// Copyright (c) 2016-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -102,7 +102,7 @@ TEST(fused_conv_eltwise, basic_image2d)
|
||||
input_layout("input2", input2.get_layout()),
|
||||
data("weights", weights),
|
||||
convolution("conv", "input", { "weights" }),
|
||||
depth_to_space("depth_to_space", "conv", 2),
|
||||
depth_to_space("depth_to_space", "conv", 2, depth_to_space_mode::blocks_first),
|
||||
eltwise("eltwise", "input2", "depth_to_space", eltwise_mode::sum)
|
||||
);
|
||||
|
||||
@ -125,7 +125,7 @@ TEST(fused_conv_eltwise, basic_image2d)
|
||||
input_layout("input2", input2.get_layout()),
|
||||
data("weights", weights),
|
||||
convolution("conv", "input", { "weights" }),
|
||||
depth_to_space("depth_to_space", "conv", 2),
|
||||
depth_to_space("depth_to_space", "conv", 2, depth_to_space_mode::blocks_first),
|
||||
eltwise("eltwise", "input2", "depth_to_space", eltwise_mode::sum),
|
||||
reorder("out", "eltwise", format::image_2d_rgba, data_types::u8));
|
||||
|
||||
|
@ -34,6 +34,7 @@
|
||||
#include "api/deconvolution.hpp"
|
||||
#include "api/permute.hpp"
|
||||
#include "api/gather.hpp"
|
||||
#include "api/depth_to_space.hpp"
|
||||
|
||||
#include "test_utils/test_utils.h"
|
||||
|
||||
@ -4097,6 +4098,145 @@ INSTANTIATE_TEST_CASE_P(DISABLED_fusings_gpu,
|
||||
pooling_test_params{CASE_POOLING_I8_3, 2, 4, pooling_mode::average, "pooling_gpu_fs_bs_yx_bsv4_fsv32"},
|
||||
}), );
|
||||
|
||||
/* ----------------------------------------------------------------------------------------------------- */
|
||||
/* -------------------------------- DepthToSpace cases ------------------------------------------------- */
|
||||
/* ----------------------------------------------------------------------------------------------------- */
|
||||
struct depth_to_space_test_params {
|
||||
tensor input_size;
|
||||
tensor output_size;
|
||||
depth_to_space_mode mode;
|
||||
data_types input_type;
|
||||
format input_format;
|
||||
size_t block_size;
|
||||
data_types default_type;
|
||||
format default_format;
|
||||
size_t expected_fused_primitives;
|
||||
size_t expected_not_fused_primitives;
|
||||
};
|
||||
|
||||
#define CASE_DEPTH_TO_SPACE_F32_1 {1, 16, 8, 10}, {1, 4, 16, 20}, depth_to_space_mode::blocks_first, data_types::f32, format::bfyx, 2, data_types::f32, format::bfyx
|
||||
#define CASE_DEPTH_TO_SPACE_F32_2 {1, 32, 8, 8}, {1, 2, 32, 32}, depth_to_space_mode::blocks_first, data_types::f32, format::b_fs_yx_fsv16, 4, data_types::f32, format::bfyx
|
||||
#define CASE_DEPTH_TO_SPACE_F16_1 {1, 12, 8, 8}, {1, 3, 16, 16}, depth_to_space_mode::blocks_first, data_types::f16, format::bfyx, 2, data_types::f32, format::bfyx
|
||||
#define CASE_DEPTH_TO_SPACE_F16_2 {1, 16, 9, 8}, {1, 1, 36, 32}, depth_to_space_mode::blocks_first, data_types::f16, format::b_fs_yx_fsv16, 4, data_types::f32, format::bfyx
|
||||
#define CASE_DEPTH_TO_SPACE_U8_1 {1, 128, 8, 8}, {1, 2, 64, 64}, depth_to_space_mode::blocks_first, data_types::u8, format::bfyx, 8, data_types::f32, format::bfyx
|
||||
#define CASE_DEPTH_TO_SPACE_U8_2 {1, 128, 4, 8}, {1, 8, 16, 32}, depth_to_space_mode::blocks_first, data_types::u8, format::b_fs_yx_fsv16, 4, data_types::f32, format::bfyx
|
||||
#define CASE_DEPTH_TO_SPACE_I8_1 {1, 16, 8, 8}, {1, 4, 16, 16}, depth_to_space_mode::blocks_first, data_types::i8, format::bfyx, 2, data_types::f32, format::bfyx
|
||||
#define CASE_DEPTH_TO_SPACE_I8_2 {1, 256, 8, 8}, {1, 4, 64, 64}, depth_to_space_mode::blocks_first, data_types::i8, format::b_fs_yx_fsv16, 8, data_types::f32, format::bfyx
|
||||
|
||||
class DepthToSpaceFusingsTest : public ::BaseFusingTest<depth_to_space_test_params> {
|
||||
public:
|
||||
void execute(depth_to_space_test_params& p) {
|
||||
auto input_prim = get_mem(get_input_layout(p));
|
||||
|
||||
network network_not_fused(this->engine, this->topology_non_fused, bo_not_fused);
|
||||
network network_fused(this->engine, this->topology_fused, bo_fused);
|
||||
|
||||
network_fused.set_input_data("input", input_prim);
|
||||
network_not_fused.set_input_data("input", input_prim);
|
||||
|
||||
compare(network_not_fused, network_fused, p);
|
||||
}
|
||||
|
||||
layout get_input_layout(depth_to_space_test_params& p) { return layout{p.input_type, p.input_format, p.input_size}; }
|
||||
|
||||
layout get_per_channel_layout(depth_to_space_test_params& p) {
|
||||
return layout{p.default_type, p.default_format, tensor{1, p.output_size.feature[0], 1, 1}};
|
||||
}
|
||||
format get_input_format(depth_to_space_test_params &p) { return p.input_format; }
|
||||
};
|
||||
|
||||
class depth_to_space_quantize_i8 : public DepthToSpaceFusingsTest {};
|
||||
TEST_P(depth_to_space_quantize_i8, basic) {
|
||||
auto p = GetParam();
|
||||
create_topologies(input_layout("input", get_input_layout(p)),
|
||||
depth_to_space("depth_to_space", "input", p.block_size, p.mode),
|
||||
data("in_low", get_mem(get_per_channel_layout(p), min_random, 0)),
|
||||
data("in_high", get_mem(get_per_channel_layout(p), 1, max_random)),
|
||||
data("out_low", get_mem(get_single_element_layout(p), -128)),
|
||||
data("out_high", get_mem(get_single_element_layout(p), 127)),
|
||||
quantize("quant", "depth_to_space", "in_low", "in_high", "out_low", "out_high", 256, data_types::i8),
|
||||
reorder("reorder_bfyx", "quant", format::bfyx, data_types::f32));
|
||||
|
||||
tolerance = 1.f;
|
||||
execute(p);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
fusings_gpu,
|
||||
depth_to_space_quantize_i8,
|
||||
::testing::ValuesIn(std::vector<depth_to_space_test_params>{
|
||||
depth_to_space_test_params{CASE_DEPTH_TO_SPACE_F32_1, 2, 3},
|
||||
depth_to_space_test_params{CASE_DEPTH_TO_SPACE_F32_2, 2, 3},
|
||||
depth_to_space_test_params{CASE_DEPTH_TO_SPACE_F16_1, 2, 3},
|
||||
depth_to_space_test_params{CASE_DEPTH_TO_SPACE_F16_2, 2, 3},
|
||||
}), );
|
||||
|
||||
class depth_to_space_scale_act_eltwise_quantize_u8 : public DepthToSpaceFusingsTest {};
|
||||
TEST_P(depth_to_space_scale_act_eltwise_quantize_u8, basic) {
|
||||
auto p = GetParam();
|
||||
create_topologies(input_layout("input", get_input_layout(p)),
|
||||
depth_to_space("depth_to_space", "input", p.block_size, p.mode),
|
||||
data("scale1_data", get_mem(get_per_channel_layout(p), -0.125f)),
|
||||
scale("scale1", "depth_to_space", "scale1_data"),
|
||||
activation("actv1", "scale1", activation_func::relu),
|
||||
data("eltw_data", get_mem(layout(p.default_type, p.input_format, p.output_size))),
|
||||
eltwise("eltw", {"actv1", "eltw_data"}, eltwise_mode::sum, p.default_type),
|
||||
data("in_low", get_mem(get_per_channel_layout(p), min_random, 0)),
|
||||
data("in_high", get_mem(get_per_channel_layout(p), 1, max_random)),
|
||||
data("out_low", get_mem(get_single_element_layout(p), 0)),
|
||||
data("out_high", get_mem(get_single_element_layout(p), 255)),
|
||||
quantize("quant", "eltw", "in_low", "in_high", "out_low", "out_high", 256, data_types::u8),
|
||||
reorder("reorder_bfyx", "quant", format::bfyx, data_types::f32));
|
||||
|
||||
tolerance = 1.f;
|
||||
execute(p);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
fusings_gpu,
|
||||
depth_to_space_scale_act_eltwise_quantize_u8,
|
||||
::testing::ValuesIn(std::vector<depth_to_space_test_params>{
|
||||
depth_to_space_test_params{CASE_DEPTH_TO_SPACE_F32_1, 2, 6},
|
||||
depth_to_space_test_params{CASE_DEPTH_TO_SPACE_F32_2, 2, 6},
|
||||
depth_to_space_test_params{CASE_DEPTH_TO_SPACE_F16_1, 2, 6},
|
||||
depth_to_space_test_params{CASE_DEPTH_TO_SPACE_F16_2, 2, 6},
|
||||
depth_to_space_test_params{CASE_DEPTH_TO_SPACE_U8_1, 2, 6},
|
||||
depth_to_space_test_params{CASE_DEPTH_TO_SPACE_U8_2, 2, 6},
|
||||
depth_to_space_test_params{CASE_DEPTH_TO_SPACE_I8_1, 2, 6},
|
||||
depth_to_space_test_params{CASE_DEPTH_TO_SPACE_I8_2, 2, 6},
|
||||
}), );
|
||||
|
||||
|
||||
class depth_to_space_scale_act_eltw : public DepthToSpaceFusingsTest {};
|
||||
TEST_P(depth_to_space_scale_act_eltw, basic) {
|
||||
auto p = GetParam();
|
||||
create_topologies(input_layout("input", get_input_layout(p)),
|
||||
depth_to_space("depth_to_space", "input", p.block_size, p.mode),
|
||||
data("scale1_data", get_mem(get_per_channel_layout(p), -0.125f)),
|
||||
scale("scale1", "depth_to_space", "scale1_data"),
|
||||
activation("actv1", "scale1", activation_func::relu),
|
||||
data("eltw_data", get_mem(layout(p.default_type, p.input_format, p.output_size))),
|
||||
eltwise("eltw", {"actv1", "eltw_data"}, eltwise_mode::sum, p.default_type),
|
||||
reorder("reorder_bfyx", "eltw", format::bfyx, data_types::f32));
|
||||
|
||||
tolerance = 1e-5f;
|
||||
execute(p);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
fusings_gpu,
|
||||
depth_to_space_scale_act_eltw,
|
||||
::testing::ValuesIn(std::vector<depth_to_space_test_params>{
|
||||
depth_to_space_test_params{CASE_DEPTH_TO_SPACE_F32_1, 2, 5},
|
||||
depth_to_space_test_params{CASE_DEPTH_TO_SPACE_F32_2, 2, 5},
|
||||
depth_to_space_test_params{CASE_DEPTH_TO_SPACE_F16_1, 2, 5},
|
||||
depth_to_space_test_params{CASE_DEPTH_TO_SPACE_F16_2, 2, 5},
|
||||
depth_to_space_test_params{CASE_DEPTH_TO_SPACE_U8_1, 2, 5},
|
||||
depth_to_space_test_params{CASE_DEPTH_TO_SPACE_U8_2, 2, 5},
|
||||
depth_to_space_test_params{CASE_DEPTH_TO_SPACE_I8_1, 2, 5},
|
||||
depth_to_space_test_params{CASE_DEPTH_TO_SPACE_I8_2, 2, 5},
|
||||
}), );
|
||||
|
||||
/* ----------------------------------------------------------------------------------------------------- */
|
||||
/* ------------------------------------------ Gather cases --------------------------------------------- */
|
||||
/* ----------------------------------------------------------------------------------------------------- */
|
||||
@ -4189,13 +4329,13 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, gather_quantize,
|
||||
gather_test_params{ CASE_GATHER_FP32_2, 2, 3 },
|
||||
gather_test_params{ CASE_GATHER_FP32_3, 2, 3 },
|
||||
gather_test_params{ CASE_GATHER_FP32_4, 2, 3 },
|
||||
gather_test_params{ CASE_GATHER_FP32_5, 2, 3 },
|
||||
|
||||
gather_test_params{ CASE_GATHER_FP32_5, 2, 3 },
|
||||
|
||||
gather_test_params{ CASE_GATHER_FP16_1, 2, 3 },
|
||||
gather_test_params{ CASE_GATHER_FP16_2, 2, 3 },
|
||||
gather_test_params{ CASE_GATHER_FP16_3, 2, 3 },
|
||||
gather_test_params{ CASE_GATHER_FP16_4, 2, 3 },
|
||||
gather_test_params{ CASE_GATHER_FP16_5, 2, 3 },
|
||||
gather_test_params{ CASE_GATHER_FP16_5, 2, 3 },
|
||||
}), );
|
||||
|
||||
class gather_scale_activation : public GatherPrimitiveFusingTest {};
|
||||
@ -4220,8 +4360,8 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, gather_scale_activation,
|
||||
gather_test_params{ CASE_GATHER_FP32_2, 2, 4 },
|
||||
gather_test_params{ CASE_GATHER_FP32_3, 2, 4 },
|
||||
gather_test_params{ CASE_GATHER_FP32_4, 2, 4 },
|
||||
gather_test_params{ CASE_GATHER_FP32_5, 2, 4 },
|
||||
|
||||
gather_test_params{ CASE_GATHER_FP32_5, 2, 4 },
|
||||
|
||||
gather_test_params{ CASE_GATHER_FP16_1, 2, 4 },
|
||||
gather_test_params{ CASE_GATHER_FP16_2, 2, 4 },
|
||||
gather_test_params{ CASE_GATHER_FP16_3, 2, 4 },
|
||||
|
Loading…
Reference in New Issue
Block a user