[IE CLDNN] Performance / accuracy fixes (#3729)

- Added linear_onnx mode support to the resample_opt kernel.
- Fixed the byxf layout check.
- Added Resample + Eltwise fusing support.
- Updated the dequantize merge pass to work with eltwise instead of scale.
- Fixed the uninitialized m_max_batch value in query mode.
- Fixed the missing AddPrimitiveToProfiler call for DeformablePSROIPooling.
- Fixed 0D Gather.
- Added a workaround for Resample + Eltwise fusing.

Co-authored-by: Gleb Kazantaev <gleb.nnstu@gmail.com>
Vladimir Paramuzov, 2021-01-14 15:10:11 +03:00, committed by GitHub
parent 93e9eefaec
commit 08afa4fd97
19 changed files with 386 additions and 85 deletions


@@ -54,6 +54,7 @@
#include <transformations/op_conversions/convert_previous_nms_to_nms_5.hpp>
#include <transformations/op_conversions/convert_nms_to_nms_ie_internal.hpp>
#include <transformations/op_conversions/convert_interpolate1_to_interpolate4.hpp>
#include <transformations/op_conversions/convert_gather_0d.hpp>
#include <transformations/convert_precision.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/rt_info/fused_names_attribute.hpp>
@@ -155,6 +156,7 @@ InferenceEngine::CNNNetwork clDNNEngine::CloneAndTransformNetwork(const Inferenc
manager.register_pass<ngraph::pass::ConvertNMS3ToNMS5>();
manager.register_pass<ngraph::pass::ConvertNMS4ToNMS5>();
manager.register_pass<ngraph::pass::ConvertNMSToNMSIEInternal>();
manager.register_pass<ngraph::pass::ConvertGather0D>();
std::vector<std::pair<ngraph::element::Type, ngraph::element::Type>> convert_precision_list {
{ngraph::element::i64, ngraph::element::i32},
@@ -164,7 +166,7 @@ InferenceEngine::CNNNetwork clDNNEngine::CloneAndTransformNetwork(const Inferenc
{ngraph::element::boolean, ngraph::element::u8},
};
for (auto & precision : convert_precision_list) {
for (auto& precision : convert_precision_list) {
manager.register_pass<ngraph::pass::ConvertPrecision>(precision.first, precision.second);
}


@@ -71,7 +71,7 @@ public:
class Program {
public:
Program(InferenceEngine::CNNNetwork& network, std::shared_ptr<const cldnn::engine> engine, const Config& config);
Program() : m_config({}), m_engine(nullptr), m_curBatch(-1), queryMode(false) {}
Program() : m_config({}), m_engine(nullptr), m_curBatch(-1), queryMode(false), m_max_batch(1) {}
static const cldnn::primitive_id m_preProcessTag;
static const cldnn::primitive_id m_meanValuesTag;


@@ -45,20 +45,21 @@ void CreateDeformablePSROIPoolingOp(Program& p, const std::shared_ptr<ngraph::op
bool position_sensitive = true;
auto psROIPoolingPrim = cldnn::roi_pooling(layerName,
inputPrimitives,
mode,
position_sensitive,
pooled_width,
pooled_height,
spatial_scale,
trans_std,
no_trans,
part_size,
group_size,
output_dim,
spatial_bins_x,
spatial_bins_y);
inputPrimitives,
mode,
position_sensitive,
pooled_width,
pooled_height,
spatial_scale,
trans_std,
no_trans,
part_size,
group_size,
output_dim,
spatial_bins_x,
spatial_bins_y);
p.AddPrimitive(psROIPoolingPrim);
p.AddPrimitiveToProfiler(op);
}
void CreatePSROIPoolingOp(Program& p, const std::shared_ptr<ngraph::op::v0::PSROIPooling>& op) {


@@ -31,14 +31,23 @@ void CreateStridedSliceOp(Program& p, const std::shared_ptr<ngraph::op::v1::Stri
break;
}
bool valid_mask = true;
for (auto& m : op->get_begin_mask()) {
if (m != 0)
if (m != 0) {
valid_mask = false;
break;
}
}
for (auto& m : op->get_end_mask()) {
if (m != 0)
if (m != 0) {
valid_mask = false;
break;
}
}
if (!valid_mask) {
break;
}
auto input_shape = op->get_input_shape(0);
@@ -186,7 +195,7 @@ void CreateStridedSliceOp(Program& p, const std::shared_ptr<ngraph::op::v1::Stri
uniq_id++;
}
if (axes.size() != 4) {
if (axes.size() > 4) {
break;
}


@@ -0,0 +1,31 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <vector>
#include <memory>
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API ConvertGather0D;
} // namespace pass
} // namespace ngraph
/**
* @ingroup ie_transformation_common_api
 * @brief ConvertGather0D decomposes the v1::Gather operation into a v0::Unsqueeze + v1::Gather + v0::Squeeze pattern when the Gather indices input is a scalar.
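 * For example (shapes as in the unit tests below): Gather(data[6,12,10,24], indices=scalar, axis=1)
 * becomes Unsqueeze(indices) -> [1] -> Gather -> [6,1,10,24] -> Squeeze(axis=1) -> [6,10,24].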
*/
class ngraph::pass::ConvertGather0D : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
ConvertGather0D();
};


@@ -0,0 +1,52 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/op_conversions/convert_gather_0d.hpp"
#include <memory>
#include <vector>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertGather0D, "ConvertGather0D", 0);
ngraph::pass::ConvertGather0D::ConvertGather0D() {
auto gather = ngraph::pattern::wrap_type<opset1::Gather>();
ngraph::matcher_pass_callback callback = [](pattern::Matcher &m) {
auto gather = std::dynamic_pointer_cast<ngraph::opset1::Gather>(m.get_match_root());
if (!gather) {
return false;
}
auto axes_constant = std::dynamic_pointer_cast<ngraph::opset1::Constant>(gather->input_value(2).get_node_shared_ptr());
if (!axes_constant) {
return false;
}
// If the indices input is a scalar, we need to unsqueeze it to 1D so that plugins which do not
// support 0D inputs can execute this layer. Then we squeeze the axis dimension to restore the
// original shape of the Gather output.
auto indices = gather->input_value(1);
const auto indices_rank = indices.get_partial_shape().rank();
if (indices_rank.is_dynamic() || indices_rank.get_length() != 0) {
return false;
}
auto axis = axes_constant->cast_vector<int64_t>()[0];
indices = std::make_shared<ngraph::opset1::Unsqueeze>(indices, opset1::Constant::create(element::i64, Shape{1}, {0}));
auto gather_new = std::make_shared<ngraph::opset1::Gather>(gather->input_value(0), indices, axes_constant);
auto sq = std::make_shared<ngraph::opset1::Squeeze>(gather_new, opset1::Constant::create(element::i64, Shape{1}, {axis}));
sq->set_friendly_name(gather->get_friendly_name());
ngraph::copy_runtime_info(gather, {indices.get_node_shared_ptr(), gather_new, sq});
ngraph::replace_node(gather, sq);
return true;
};
auto m1 = std::make_shared<ngraph::pattern::Matcher>(gather, "ConvertGather0D");
this->register_matcher(m1, callback);
}


@@ -0,0 +1,85 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <queue>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <transformations/op_conversions/convert_gather_0d.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
#include <ngraph/pass/manager.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
using namespace ngraph;
TEST(TransformationTests, ConvertGather0DStatic1) {
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<opset1::Parameter>(element::f32, Shape{6, 12, 10, 24});
auto indices = std::make_shared<opset1::Parameter>(element::f32, Shape{15, 4, 20, 28});
auto axis_const = opset1::Constant::create(element::i64, Shape{}, {1});
auto gather = std::make_shared<opset1::Gather>(input, indices, axis_const);
f = std::make_shared<Function>(NodeVector{gather}, ParameterVector{input, indices});
pass::Manager manager;
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertGather0D>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_TRUE(f->get_output_partial_shape(0).is_static()) << "Shape " << f->get_output_partial_shape(0) << " should be static";
}
{
auto input = std::make_shared<opset1::Parameter>(element::f32, Shape{6, 12, 10, 24});
auto indices = std::make_shared<opset1::Parameter>(element::f32, Shape{15, 4, 20, 28});
auto axis_const = opset1::Constant::create(element::i64, Shape{}, {1});
auto gather = std::make_shared<opset1::Gather>(input, indices, axis_const);
f_ref = std::make_shared<Function>(NodeVector{gather}, ParameterVector{input, indices});
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, ConvertGather0DStatic2) {
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<opset1::Parameter>(element::f32, Shape{6, 12, 10, 24});
auto indices = std::make_shared<opset1::Parameter>(element::f32, Shape{});
auto axis_const = opset1::Constant::create(element::i64, Shape{}, {1});
auto gather = std::make_shared<opset1::Gather>(input, indices, axis_const);
f = std::make_shared<Function>(NodeVector{gather}, ParameterVector{input, indices});
pass::Manager manager;
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertGather0D>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_TRUE(f->get_output_partial_shape(0).is_static()) << "Shape " << f->get_output_partial_shape(0) << " should be static";
}
{
auto input = std::make_shared<opset1::Parameter>(element::f32, Shape{6, 12, 10, 24});
auto indices = std::make_shared<opset1::Parameter>(element::f32, Shape{});
auto axis_const = opset1::Constant::create(element::i64, Shape{}, {1});
auto unsqueeze = std::make_shared<opset1::Unsqueeze>(indices, opset1::Constant::create(element::i64, Shape{1}, {0}));
auto gather = std::make_shared<opset1::Gather>(input, unsqueeze, axis_const);
auto squeeze = std::make_shared<opset1::Squeeze>(gather, opset1::Constant::create(element::i64, Shape{1}, {1}));
f_ref = std::make_shared<Function>(NodeVector{squeeze}, ParameterVector{input, indices});
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}


@@ -46,6 +46,7 @@ ParamsKey ResampleKernelOpt::GetSupportedKey() const {
k.EnableBatching();
k.EnableReampleType(ResampleType::BILINEAR_INTERP);
k.EnableReampleType(ResampleType::NEAREST_NEIGHBOR);
k.EnableReampleType(ResampleType::LINEAR_ONNX);
k.EnableSubGroup();
k.EnableSubGroupShort();
return k;


@@ -34,6 +34,7 @@ protected:
std::vector<FusedOpType> GetSupportedFusedOps() const override {
return { FusedOpType::QUANTIZE,
FusedOpType::SCALE,
FusedOpType::ELTWISE,
FusedOpType::ACTIVATION };
}
private:


@@ -29,6 +29,7 @@ public:
std::vector<FusedOpType> GetSupportedFusedOps() const override {
return { FusedOpType::QUANTIZE,
FusedOpType::SCALE,
FusedOpType::ELTWISE,
FusedOpType::ACTIVATION };
}


@@ -28,6 +28,23 @@
#define OUT_VEC_TYPE MAKE_VECTOR_TYPE(OUTPUT_TYPE, VEC_SIZE)
#define TO_OUT_VEC_TYPE(x) CAT(convert_, OUT_VEC_TYPE)(x)
inline float FUNC(get_original_coordinate)(float num, float scale, int length_resized, int length_original)
{
#if defined(COORD_TRANS_MODE_HALF_PIXEL)
return (num + 0.5f) * scale - 0.5f;
#elif defined(COORD_TRANS_MODE_PYTORCH_HALF_PIXEL)
return (length_resized > 1) ? (num + 0.5f) * scale - 0.5f : 0.f;
#elif defined(COORD_TRANS_MODE_ASYMMETRIC)
return num * scale;
#elif defined(COORD_TRANS_MODE_TF_HALF_PIXEL_FOR_NN)
return (num + 0.5f) * scale;
#elif defined(COORD_TRANS_MODE_ALIGN_CORNERS)
return (length_resized != 1) ? num * (length_original - 1) / (length_resized - 1) : 0.f;
#else
#error [clDNN resample_opt.cl]: coordinate transformation mode - not supported
#endif
}
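// Worked example (illustration only): SCALES holds input/output size ratios, as in the
// nearest-neighbor path below (floor(x * SCALES[4])). For HALF_PIXEL a 2x upscale has
// scale = 0.5, so output x = 3 maps to (3 + 0.5) * 0.5 - 0.5 = 1.25 in the input.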
__attribute__((intel_reqd_sub_group_size(SUB_GROUP_SIZE)))
KERNEL (resample_opt)(__global INPUT0_TYPE* input,
__global OUTPUT_TYPE* output
@@ -56,7 +73,7 @@ KERNEL (resample_opt)(__global INPUT0_TYPE* input,
const int iy = floor(y * SCALES[3]);
in_vec_t res = READ_FUNC(input, INPUT0_GET_INDEX(b, feature_block, iy, ix));
#else
#elif defined(SAMPLE_TYPE_INTERP)
const ACCUMULATOR_TYPE ix = TO_ACCUMULATOR_TYPE(SCALES[4]) * (x + out_x);
const ACCUMULATOR_TYPE iy = TO_ACCUMULATOR_TYPE(SCALES[3]) * y;
@@ -76,6 +93,52 @@ KERNEL (resample_opt)(__global INPUT0_TYPE* input,
const acc_vec_t top = TO_ACC_VEC_TYPE(top_left) + (TO_ACC_VEC_TYPE(top_right) - TO_ACC_VEC_TYPE(top_left)) * dx;
const acc_vec_t bottom = TO_ACC_VEC_TYPE(bottom_left) + (TO_ACC_VEC_TYPE(bottom_right) - TO_ACC_VEC_TYPE(bottom_left)) * dx;
acc_vec_t res = top + (bottom - top) * dy;
#else // defined(SAMPLE_TYPE_LINEAR_ONNX)
const int PADDED_Y = INPUT0_SIZE_Y + PADS_BEGIN[3] + PADS_END[3];
const int PADDED_X = INPUT0_SIZE_X + PADS_BEGIN[4] + PADS_END[4];
const ACCUMULATOR_TYPE ix = FUNC_CALL(get_original_coordinate)(x + out_x, SCALES[4], OUTPUT_SIZE_X, PADDED_X);
const ACCUMULATOR_TYPE iy = FUNC_CALL(get_original_coordinate)(y, SCALES[3], OUTPUT_SIZE_Y, PADDED_Y);
float in_y = fmax(0, fmin(iy, PADDED_Y - 1));
float in_x = fmax(0, fmin(ix, PADDED_X - 1));
int in_y1 = min((int)in_y, PADDED_Y - 1);
int in_y2 = min(in_y1 + 1, PADDED_Y - 1);
int in_x1 = min((int)in_x, PADDED_X - 1);
int in_x2 = min(in_x1 + 1, PADDED_X - 1);
const ACCUMULATOR_TYPE dx1 = (in_x1 != in_x2) ? TO_ACCUMULATOR_TYPE(fabs(in_x - in_x1)) : 0.5f;
const ACCUMULATOR_TYPE dx2 = (in_x1 != in_x2) ? TO_ACCUMULATOR_TYPE(fabs(in_x - in_x2)) : 0.5f;
const ACCUMULATOR_TYPE dy1 = (in_y1 != in_y2) ? TO_ACCUMULATOR_TYPE(fabs(in_y - in_y1)) : 0.5f;
const ACCUMULATOR_TYPE dy2 = (in_y1 != in_y2) ? TO_ACCUMULATOR_TYPE(fabs(in_y - in_y2)) : 0.5f;
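// Standard bilinear weights: dx1/dy1 are the distances to the lower-index neighbors and dx2/dy2
// to the higher-index ones, so each corner sample below is weighted by the distances to the
// opposite corner; the 0.5f fallback splits the weight evenly when both neighbors coincide.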
#if PADDING_USED == 1
in_y1 -= PADS_BEGIN[3];
in_y2 -= PADS_BEGIN[3];
in_x1 -= PADS_BEGIN[4];
in_x2 -= PADS_BEGIN[4];
bool tlOutOfBounds = in_y1 < 0 || in_y1 >= in_size[3] || in_x1 < 0 || in_x1 >= in_size[4];
bool trOutOfBounds = in_y1 < 0 || in_y1 >= in_size[3] || in_x2 < 0 || in_x2 >= in_size[4];
bool blOutOfBounds = in_y2 < 0 || in_y2 >= in_size[3] || in_x1 < 0 || in_x1 >= in_size[4];
bool brOutOfBounds = in_y2 < 0 || in_y2 >= in_size[3] || in_x2 < 0 || in_x2 >= in_size[4];
#endif // PADDING_USED == 1
acc_vec_t top_left = TO_ACC_VEC_TYPE(READ_FUNC(input, INPUT0_GET_INDEX(b, feature_block, in_y1, in_x1)));
acc_vec_t top_right = TO_ACC_VEC_TYPE(READ_FUNC(input, INPUT0_GET_INDEX(b, feature_block, in_y1, in_x2)));
acc_vec_t bottom_left = TO_ACC_VEC_TYPE(READ_FUNC(input, INPUT0_GET_INDEX(b, feature_block, in_y2, in_x1)));
acc_vec_t bottom_right = TO_ACC_VEC_TYPE(READ_FUNC(input, INPUT0_GET_INDEX(b, feature_block, in_y2, in_x2)));
#if PADDING_USED == 1
if (tlOutOfBounds)
top_left = INPUT0_VAL_ZERO;
if (trOutOfBounds)
top_right = INPUT0_VAL_ZERO;
if (blOutOfBounds)
bottom_left = INPUT0_VAL_ZERO;
if (brOutOfBounds)
bottom_right = INPUT0_VAL_ZERO;
#endif // PADDING_USED == 1
acc_vec_t res = TO_ACC_VEC_TYPE(dx2 * dy2 * top_left) +
TO_ACC_VEC_TYPE(dx1 * dy2 * top_right) +
TO_ACC_VEC_TYPE(dx2 * dy1 * bottom_left) +
TO_ACC_VEC_TYPE(dx1 * dy1 * bottom_right);
#endif
#if HAS_FUSED_OPS
FUSED_OPS;


@@ -132,9 +132,18 @@ KERNEL (resample_gpu_ref)(__global INPUT0_TYPE* input,
interp_val = INPUT0_VAL_ZERO;
#endif
#if HAS_FUSED_OPS
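// FUSED_OPS is code generated at kernel build time that refers to the output coordinates by
// fixed names (batch, OF_ID, oz, oy, ox), so those names are mapped onto out_coords for the
// duration of the macro and undefined again afterwards.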
#define batch (out_coords[0])
#define OF_ID (out_coords[1] + pi)
#define oz (out_coords[2])
#define oy (out_coords[3])
#define ox (out_coords[4])
FUSED_OPS;
res[pi] = FUSED_OPS_RESULT;
#undef batch
#undef OF_ID
#undef oz
#undef oy
#undef ox
#else // HAS_FUSED_OPS
res[pi] = ACTIVATION(interp_val, ACTIVATION_PARAMS);
#endif // HAS_FUSED_OPS
@@ -170,9 +179,18 @@ KERNEL (resample_gpu_ref)(__global INPUT0_TYPE* input,
interp_val = INPUT0_VAL_ZERO;
#endif
#if HAS_FUSED_OPS
#define batch (out_coords[0])
#define OF_ID (out_coords[1])
#define oz (out_coords[2])
#define oy (out_coords[3])
#define ox (out_coords[4])
FUSED_OPS;
OUTPUT_TYPE res = FUSED_OPS_RESULT;
#undef batch
#undef OF_ID
#undef oz
#undef oy
#undef ox
#else // HAS_FUSED_OPS
OUTPUT_TYPE res = TO_OUTPUT_TYPE(ACTIVATION(interp_val, ACTIVATION_PARAMS));
#endif // HAS_FUSED_OPS
@@ -227,9 +245,18 @@ KERNEL (resample_gpu_ref)(__global INPUT0_TYPE* input,
}
#if HAS_FUSED_OPS
#define batch (out_coords[0])
#define OF_ID (out_coords[1])
#define oz (out_coords[2])
#define oy (out_coords[3])
#define ox (out_coords[4])
FUSED_OPS;
OUTPUT_TYPE res = FUSED_OPS_RESULT;
#undef batch
#undef OF_ID
#undef oz
#undef oy
#undef ox
#else // HAS_FUSED_OPS
OUTPUT_TYPE res = ACTIVATION(TO_OUTPUT_TYPE(interp_val), ACTIVATION_PARAMS);
#endif // HAS_FUSED_OPS
@@ -296,6 +323,7 @@ KERNEL (resample_gpu_ref)(__global INPUT0_TYPE* input,
#define OF_ID (in_f)
FUSED_OPS;
OUTPUT_TYPE res = FUSED_OPS_RESULT;
#undef OF_ID
#else
OUTPUT_TYPE res = ACTIVATION(TO_OUTPUT_TYPE(interp_val), ACTIVATION_PARAMS);
#endif
@@ -337,6 +365,7 @@ KERNEL (resample_gpu_ref)(__global INPUT0_TYPE* input,
#define OF_ID (in_f)
FUSED_OPS;
OUTPUT_TYPE res = FUSED_OPS_RESULT;
#undef OF_ID
#else
OUTPUT_TYPE res = ACTIVATION(TO_OUTPUT_TYPE(interp_val), ACTIVATION_PARAMS);
#endif
@@ -465,6 +494,7 @@ KERNEL (resample_gpu_ref)(__global INPUT0_TYPE* input,
#define OF_ID (feature + f)
FUSED_OPS;
OUTPUT_TYPE res = FUSED_OPS_RESULT;
#undef OF_ID
#else
OUTPUT_TYPE res = ACTIVATION(TO_OUTPUT_TYPE(interp_val), ACTIVATION_PARAMS);
#endif


@@ -22,7 +22,6 @@
#include "program_node.h"
#include "mutable_data_inst.h"
#include "concatenation_inst.h"
#include "scale_inst.h"
#include "tensor_type.h"
#include <memory>
#include <vector>


@@ -729,6 +729,7 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
(parents[i]->is_type<mvn>() && mvn_supports_fusings(parents[i]->as<mvn>())) ||
(parents[i]->is_type<deconvolution>()) ||
(parents[i]->is_type<permute>()) ||
(parents[i]->is_type<resample>()) ||
(parents[i]->is_type<space_to_depth>()) ||
(parents[i]->is_type<gemm>() && gemm_supports_fusings(parents[i]->as<gemm>())) ||
(parents[i]->is_type<batch_to_space>()) ||
@@ -804,6 +805,12 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
p.get_processing_order().get_processing_number(peer_node))
recalc_processing_order = true;
// [WA]: Resample + Eltwise fusing causes accuracy issues without a processing order update.
// As the processing order is valid in both cases, the issue might be connected with the memory pool.
if (fused_node->is_type<resample>()) {
recalc_processing_order = true;
}
p.fuse_nodes(*fused_node, node);
};


@@ -455,6 +455,9 @@ void prepare_quantization::prepare_asymmetric_quantization(program_impl &p) {
if (node.get_dependencies().size() != 2 || prim->mode != eltwise_mode::sub)
return false;
if (node.get_users().size() != 1)
return false;
auto in0_layout = node.get_dependency(0).get_output_layout();
auto in1_layout = node.get_dependency(1).get_output_layout();
@@ -729,33 +732,40 @@ void prepare_quantization::prepare_dequantize_merge(program_impl &p) {
if (node->is_output())
continue;
program_helpers::do_for_types<scale>(*node, [&p](scale_node& node) {
program_helpers::do_for_types<eltwise>(*node, [&p](eltwise_node& node) {
for (size_t i = 1; i < node.get_dependencies().size(); i++) {
if (!node.get_dependency(i).is_type<data>()) {
return;
}
}
auto get_scale_shift_mem = [](const scale_node& scale, size_t dep_id) -> memory_impl& {
if (dep_id >= scale.get_dependencies().size())
CLDNN_ERROR_MESSAGE(scale.id(), "Invalid dependency id in dequantize optimization");
auto get_scale_shift_mem = [](const eltwise_node& eltw, size_t dep_id) -> memory_impl& {
if (dep_id >= eltw.get_dependencies().size())
CLDNN_ERROR_MESSAGE(eltw.id(), "Invalid dependency id in dequantize optimization");
return scale.get_dependency(dep_id).as<data>().get_attached_memory();
return eltw.get_dependency(dep_id).as<data>().get_attached_memory();
};
auto eltw_mode = node.get_primitive()->mode;
if (eltw_mode != eltwise_mode::sum && eltw_mode != eltwise_mode::prod)
return;
auto& input = node.input();
for (auto& user : input.get_users()) {
if (user == &node)
continue;
if (!user->is_type<scale>() || user->get_dependencies().size() != node.get_dependencies().size())
if (!user->is_type<eltwise>() || user->get_dependencies().size() != node.get_dependencies().size())
continue;
auto& eltwise_dep = user->as<eltwise>();
if (eltwise_dep.get_primitive()->mode != node.get_primitive()->mode)
continue;
auto& scale_dep = user->as<scale>();
bool valid_scale_node = true;
for (size_t i = 1; i < scale_dep.get_dependencies().size(); i++) {
if (!scale_dep.get_dependency(i).is_type<data>()) {
for (size_t i = 1; i < eltwise_dep.get_dependencies().size(); i++) {
if (!eltwise_dep.get_dependency(i).is_type<data>()) {
valid_scale_node = false;
}
}
@@ -765,7 +775,7 @@ void prepare_quantization::prepare_dequantize_merge(program_impl &p) {
bool same_params = true;
for (size_t i = 1; i < node.get_dependencies().size(); i++) {
auto& mem0 = get_scale_shift_mem(user->as<scale>(), i);
auto& mem0 = get_scale_shift_mem(eltwise_dep, i);
auto& mem1 = get_scale_shift_mem(node, i);
auto ptr0 = static_cast<uint8_t*>(mem0.lock());


@@ -225,6 +225,8 @@ inline std::string fmt_to_str(format fmt) {
return "g_os_zyx_is_osv32_isv16";
case format::g_os_zyx_is_osv32_isv32:
return "g_os_zyx_is_osv32_isv32";
case format::gs_oi_yxs_gsv32_yxsv4:
return "gs_oi_yxs_gsv32_yxsv4";
default:
return "unknown (" + std::to_string(fmt.value) + ")";
}


@@ -630,7 +630,8 @@ bool layout_optimizer::deps_for_convolution_byxf_opt(program_node const& node, u
conv_dep)) {
return false;
}
} else if (!dep->is_type<pooling>() && (!dep->is_type<eltwise>() || !is_scale_shift(dep->as<eltwise>()))) {
} else if ((!dep->is_type<pooling>() && !dep->is_type<eltwise>()) ||
(dep->is_type<eltwise>() && is_scale_shift(dep->as<eltwise>()))) {
return false;
}


@@ -64,6 +64,8 @@ std::string resample_inst::to_string(resample_node const& node) {
resample_info.add("resample_type:", "caffe_bilinear_interp");
else if (desc->operation_type == resample_type::cubic)
resample_info.add("resample_type:", "cubic");
else if (desc->operation_type == resample_type::linear_onnx)
resample_info.add("resample_type:", "linear_onnx");
else
resample_info.add("resample_type:", "not supported sample type");


@@ -1492,38 +1492,40 @@ TEST_P(conv_int8_scale_shift_swish, basic) {
data("scale_data", get_mem(get_per_channel_layout(p), 1.0f/p.kernel.count())),
data("shift_data", get_mem(get_per_channel_layout(p), 1)),
convolution("conv_prim", "input", {"weights"}, {"bias"}, p.groups, p.stride, p.pad, p.dilation),
scale("scale0", "conv_prim", "scale_data", "shift_data"),
scale("scale1", "conv_prim", "scale_data", "shift_data"),
activation("sigmoid", "scale0", activation_func::logistic),
eltwise("mul", {"scale1", "sigmoid"}, eltwise_mode::prod),
eltwise("scale0", {"conv_prim", "scale_data"}, eltwise_mode::prod),
eltwise("scale1", {"conv_prim", "scale_data"}, eltwise_mode::prod),
eltwise("shift0", {"scale0", "shift_data"}, eltwise_mode::sum),
eltwise("shift1", {"scale1", "shift_data"}, eltwise_mode::sum),
activation("sigmoid", "shift0", activation_func::logistic),
eltwise("mul", {"shift1", "sigmoid"}, eltwise_mode::prod),
reorder("reorder_bfyx", "mul", p.default_format, data_types::f32)
);
tolerance = 1e-5f;
tolerance = 1e-4f;
execute(p);
}
INSTANTIATE_TEST_CASE_P(fusings_gpu, conv_int8_scale_shift_swish,
::testing::ValuesIn(std::vector<bc_test_params>{
bc_test_params{CASE_CONV_U8S8_1, 2, 6},
bc_test_params{CASE_CONV_U8S8_2, 2, 6},
bc_test_params{CASE_CONV_U8S8_3, 2, 6},
bc_test_params{CASE_CONV_U8S8_4, 2, 6},
bc_test_params{CASE_CONV_S8S8_1, 2, 6},
bc_test_params{CASE_CONV_S8S8_2, 2, 6},
bc_test_params{CASE_CONV_S8S8_3, 2, 6},
bc_test_params{CASE_CONV_S8S8_4, 2, 6},
bc_test_params{CASE_CONV_U8S8_1, 2, 8},
bc_test_params{CASE_CONV_U8S8_2, 2, 8},
bc_test_params{CASE_CONV_U8S8_3, 2, 8},
bc_test_params{CASE_CONV_U8S8_4, 2, 8},
bc_test_params{CASE_CONV_S8S8_1, 2, 8},
bc_test_params{CASE_CONV_S8S8_2, 2, 8},
bc_test_params{CASE_CONV_S8S8_3, 2, 8},
bc_test_params{CASE_CONV_S8S8_4, 2, 8},
bc_test_params{CASE_CONV3D_U8S8_1, 2, 6},
bc_test_params{CASE_CONV3D_U8S8_2, 2, 6},
bc_test_params{CASE_CONV3D_U8S8_3, 2, 6},
bc_test_params{CASE_CONV3D_U8S8_4, 2, 6},
bc_test_params{CASE_CONV3D_U8S8_5, 2, 6},
bc_test_params{CASE_CONV3D_S8S8_1, 2, 6},
bc_test_params{CASE_CONV3D_S8S8_2, 2, 6},
bc_test_params{CASE_CONV3D_S8S8_3, 2, 6},
bc_test_params{CASE_CONV3D_S8S8_4, 2, 6},
bc_test_params{CASE_CONV3D_S8S8_5, 2, 6},
bc_test_params{CASE_CONV3D_U8S8_1, 2, 8},
bc_test_params{CASE_CONV3D_U8S8_2, 2, 8},
bc_test_params{CASE_CONV3D_U8S8_3, 2, 8},
bc_test_params{CASE_CONV3D_U8S8_4, 2, 8},
bc_test_params{CASE_CONV3D_U8S8_5, 2, 8},
bc_test_params{CASE_CONV3D_S8S8_1, 2, 8},
bc_test_params{CASE_CONV3D_S8S8_2, 2, 8},
bc_test_params{CASE_CONV3D_S8S8_3, 2, 8},
bc_test_params{CASE_CONV3D_S8S8_4, 2, 8},
bc_test_params{CASE_CONV3D_S8S8_5, 2, 8},
}), );
class conv_int8_prelu_eltwise : public ConvFusingTest {};
@@ -2988,53 +2990,55 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, resample_quantize,
// resample_test_params{ CASE_RESAMPLE_FP16_9, 2, 3 },
}), );
class resample_scale_activation : public ResamplePrimitiveFusingTest {};
TEST_P(resample_scale_activation, basic) {
class resample_scale_activation_eltwise : public ResamplePrimitiveFusingTest {};
TEST_P(resample_scale_activation_eltwise, basic) {
auto p = GetParam();
create_topologies(input_layout("input", get_input_layout(p)),
data("scale_data", get_mem(get_per_channel_layout(p), -10, 10)),
data("eltwise_data", get_mem(get_output_layout(p), -10, 10)),
resample("resample_prim", "input", p.out_shape, p.in_shape.feature[0], p.type),
scale("scale", "resample_prim", "scale_data"),
activation("activation", "scale", activation_func::abs),
reorder("reorder_bfyx", "activation", p.default_format, data_types::f32)
eltwise("eltwise", { "activation", "eltwise_data"}, eltwise_mode::sum),
reorder("reorder_bfyx", "eltwise", p.default_format, data_types::f32)
);
tolerance = 1e-5f;
execute(p);
}
INSTANTIATE_TEST_CASE_P(fusings_gpu, resample_scale_activation,
INSTANTIATE_TEST_CASE_P(fusings_gpu, resample_scale_activation_eltwise,
::testing::ValuesIn(std::vector<resample_test_params>{
resample_test_params{ CASE_RESAMPLE_FP32_1, 2, 4 },
resample_test_params{ CASE_RESAMPLE_FP32_2, 2, 4 },
resample_test_params{ CASE_RESAMPLE_FP32_3, 2, 4 },
resample_test_params{ CASE_RESAMPLE_FP32_4, 2, 4 },
resample_test_params{ CASE_RESAMPLE_FP32_5, 2, 4 },
resample_test_params{ CASE_RESAMPLE_FP32_6, 2, 4 },
resample_test_params{ CASE_RESAMPLE_FP32_7, 2, 4 },
resample_test_params{ CASE_RESAMPLE_FP32_8, 2, 4 },
resample_test_params{ CASE_RESAMPLE_FP32_9, 2, 4 },
resample_test_params{ CASE_RESAMPLE_FP32_1, 2, 5 },
resample_test_params{ CASE_RESAMPLE_FP32_2, 2, 5 },
resample_test_params{ CASE_RESAMPLE_FP32_3, 2, 5 },
resample_test_params{ CASE_RESAMPLE_FP32_4, 2, 5 },
resample_test_params{ CASE_RESAMPLE_FP32_5, 2, 5 },
resample_test_params{ CASE_RESAMPLE_FP32_6, 2, 5 },
resample_test_params{ CASE_RESAMPLE_FP32_7, 2, 5 },
resample_test_params{ CASE_RESAMPLE_FP32_8, 2, 5 },
resample_test_params{ CASE_RESAMPLE_FP32_9, 2, 5 },
resample_test_params{ CASE_RESAMPLE_FP16_1, 2, 4 },
resample_test_params{ CASE_RESAMPLE_FP16_2, 2, 4 },
resample_test_params{ CASE_RESAMPLE_FP16_3, 2, 4 },
resample_test_params{ CASE_RESAMPLE_FP16_4, 2, 4 },
resample_test_params{ CASE_RESAMPLE_FP16_5, 2, 4 },
resample_test_params{ CASE_RESAMPLE_FP16_6, 2, 4 },
resample_test_params{ CASE_RESAMPLE_FP16_7, 2, 4 },
resample_test_params{ CASE_RESAMPLE_FP16_8, 2, 4 },
resample_test_params{ CASE_RESAMPLE_FP16_9, 2, 4 },
resample_test_params{ CASE_RESAMPLE_FP16_10, 2, 4 },
resample_test_params{ CASE_RESAMPLE_FP16_1, 2, 5 },
resample_test_params{ CASE_RESAMPLE_FP16_2, 2, 5 },
resample_test_params{ CASE_RESAMPLE_FP16_3, 2, 5 },
resample_test_params{ CASE_RESAMPLE_FP16_4, 2, 5 },
resample_test_params{ CASE_RESAMPLE_FP16_5, 2, 5 },
resample_test_params{ CASE_RESAMPLE_FP16_6, 2, 5 },
resample_test_params{ CASE_RESAMPLE_FP16_7, 2, 5 },
resample_test_params{ CASE_RESAMPLE_FP16_8, 2, 5 },
resample_test_params{ CASE_RESAMPLE_FP16_9, 2, 5 },
resample_test_params{ CASE_RESAMPLE_FP16_10, 2, 5 },
resample_test_params{ CASE_RESAMPLE_I8_1, 2, 4 },
resample_test_params{ CASE_RESAMPLE_I8_2, 2, 4 },
resample_test_params{ CASE_RESAMPLE_I8_3, 2, 4 },
resample_test_params{ CASE_RESAMPLE_I8_4, 2, 4 },
resample_test_params{ CASE_RESAMPLE_I8_1, 2, 5 },
resample_test_params{ CASE_RESAMPLE_I8_2, 2, 5 },
resample_test_params{ CASE_RESAMPLE_I8_3, 2, 5 },
resample_test_params{ CASE_RESAMPLE_I8_4, 2, 5 },
resample_test_params{ CASE_RESAMPLE_U8_1, 2, 4 },
resample_test_params{ CASE_RESAMPLE_U8_2, 2, 4 },
resample_test_params{ CASE_RESAMPLE_U8_3, 2, 4 },
resample_test_params{ CASE_RESAMPLE_U8_4, 2, 4 },
resample_test_params{ CASE_RESAMPLE_U8_1, 2, 5 },
resample_test_params{ CASE_RESAMPLE_U8_2, 2, 5 },
resample_test_params{ CASE_RESAMPLE_U8_3, 2, 5 },
resample_test_params{ CASE_RESAMPLE_U8_4, 2, 5 },
}), );
class resample_quantize_concat : public ResamplePrimitiveFusingTest {};