[IE CLDNN] Performance / accuracy fixes (#3729)
- Added linear_onnx mode support into resample_opt kernel. - Fixed byxf layout check. - Added Resample + Eltwise fusing support - Update dequantize merge pass to work with eltwise instead of scale - Fixed uninitialized m_maxBatch value for query mode - Fixed missing AddPrimitiveToProfiler for DeformablePSRoiPooling - Fixed 0d gather - Added WA for Resample+Eltwise fusing Co-authored-by: Gleb Kazantaev <gleb.nnstu@gmail.com>
This commit is contained in:
parent
93e9eefaec
commit
08afa4fd97
@ -54,6 +54,7 @@
|
||||
#include <transformations/op_conversions/convert_previous_nms_to_nms_5.hpp>
|
||||
#include <transformations/op_conversions/convert_nms_to_nms_ie_internal.hpp>
|
||||
#include <transformations/op_conversions/convert_interpolate1_to_interpolate4.hpp>
|
||||
#include <transformations/op_conversions/convert_gather_0d.hpp>
|
||||
#include <transformations/convert_precision.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <transformations/rt_info/fused_names_attribute.hpp>
|
||||
@ -155,6 +156,7 @@ InferenceEngine::CNNNetwork clDNNEngine::CloneAndTransformNetwork(const Inferenc
|
||||
manager.register_pass<ngraph::pass::ConvertNMS3ToNMS5>();
|
||||
manager.register_pass<ngraph::pass::ConvertNMS4ToNMS5>();
|
||||
manager.register_pass<ngraph::pass::ConvertNMSToNMSIEInternal>();
|
||||
manager.register_pass<ngraph::pass::ConvertGather0D>();
|
||||
|
||||
std::vector<std::pair<ngraph::element::Type, ngraph::element::Type>> convert_precision_list {
|
||||
{ngraph::element::i64, ngraph::element::i32},
|
||||
@ -164,7 +166,7 @@ InferenceEngine::CNNNetwork clDNNEngine::CloneAndTransformNetwork(const Inferenc
|
||||
{ngraph::element::boolean, ngraph::element::u8},
|
||||
};
|
||||
|
||||
for (auto & precision : convert_precision_list) {
|
||||
for (auto& precision : convert_precision_list) {
|
||||
manager.register_pass<ngraph::pass::ConvertPrecision>(precision.first, precision.second);
|
||||
}
|
||||
|
||||
|
@ -71,7 +71,7 @@ public:
|
||||
class Program {
|
||||
public:
|
||||
Program(InferenceEngine::CNNNetwork& network, std::shared_ptr<const cldnn::engine> engine, const Config& config);
|
||||
Program() : m_config({}), m_engine(nullptr), m_curBatch(-1), queryMode(false) {}
|
||||
Program() : m_config({}), m_engine(nullptr), m_curBatch(-1), queryMode(false), m_max_batch(1) {}
|
||||
|
||||
static const cldnn::primitive_id m_preProcessTag;
|
||||
static const cldnn::primitive_id m_meanValuesTag;
|
||||
|
@ -45,20 +45,21 @@ void CreateDeformablePSROIPoolingOp(Program& p, const std::shared_ptr<ngraph::op
|
||||
bool position_sensitive = true;
|
||||
|
||||
auto psROIPoolingPrim = cldnn::roi_pooling(layerName,
|
||||
inputPrimitives,
|
||||
mode,
|
||||
position_sensitive,
|
||||
pooled_width,
|
||||
pooled_height,
|
||||
spatial_scale,
|
||||
trans_std,
|
||||
no_trans,
|
||||
part_size,
|
||||
group_size,
|
||||
output_dim,
|
||||
spatial_bins_x,
|
||||
spatial_bins_y);
|
||||
inputPrimitives,
|
||||
mode,
|
||||
position_sensitive,
|
||||
pooled_width,
|
||||
pooled_height,
|
||||
spatial_scale,
|
||||
trans_std,
|
||||
no_trans,
|
||||
part_size,
|
||||
group_size,
|
||||
output_dim,
|
||||
spatial_bins_x,
|
||||
spatial_bins_y);
|
||||
p.AddPrimitive(psROIPoolingPrim);
|
||||
p.AddPrimitiveToProfiler(op);
|
||||
}
|
||||
|
||||
void CreatePSROIPoolingOp(Program& p, const std::shared_ptr<ngraph::op::v0::PSROIPooling>& op) {
|
||||
|
@ -31,14 +31,23 @@ void CreateStridedSliceOp(Program& p, const std::shared_ptr<ngraph::op::v1::Stri
|
||||
break;
|
||||
}
|
||||
|
||||
bool valid_mask = true;
|
||||
for (auto& m : op->get_begin_mask()) {
|
||||
if (m != 0)
|
||||
if (m != 0) {
|
||||
valid_mask = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& m : op->get_end_mask()) {
|
||||
if (m != 0)
|
||||
if (m != 0) {
|
||||
valid_mask = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!valid_mask) {
|
||||
break;
|
||||
}
|
||||
|
||||
auto input_shape = op->get_input_shape(0);
|
||||
@ -186,7 +195,7 @@ void CreateStridedSliceOp(Program& p, const std::shared_ptr<ngraph::op::v1::Stri
|
||||
uniq_id++;
|
||||
}
|
||||
|
||||
if (axes.size() != 4) {
|
||||
if (axes.size() > 4) {
|
||||
break;
|
||||
}
|
||||
|
||||
|
@ -0,0 +1,31 @@
|
||||
// Copyright (C) 2018-2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include <transformations_visibility.hpp>
|
||||
|
||||
#include <ngraph/pass/graph_rewrite.hpp>
|
||||
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API ConvertGather0D;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief ConvertGather0D decomposes v1::Gather operation into v0::Unsqueeze + v1::Gather + v0::Squeeze pattern when gather indices is scalar
|
||||
*/
|
||||
class ngraph::pass::ConvertGather0D : public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
ConvertGather0D();
|
||||
};
|
@ -0,0 +1,52 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/op_conversions/convert_gather_0d.hpp"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertGather0D, "ConvertGather0D", 0);
|
||||
|
||||
ngraph::pass::ConvertGather0D::ConvertGather0D() {
|
||||
auto gather = ngraph::pattern::wrap_type<opset1::Gather>();
|
||||
|
||||
ngraph::matcher_pass_callback callback = [](pattern::Matcher &m) {
|
||||
auto gather = std::dynamic_pointer_cast<ngraph::opset1::Gather>(m.get_match_root());
|
||||
if (!gather) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto axes_constant = std::dynamic_pointer_cast<ngraph::opset1::Constant>(gather->input_value(2).get_node_shared_ptr());
|
||||
if (!axes_constant) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// if the input with indices is scalar we need to unsqueeze it to 1D so plugins which do not support 0D can
|
||||
// execute this layer. Then we need to squeeze the axis dimension to restore original shape of gather output
|
||||
auto indices = gather->input_value(1);
|
||||
const auto indices_rank = indices.get_partial_shape().rank();
|
||||
if (indices_rank.is_dynamic() || indices_rank.get_length() != 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto axis = axes_constant->cast_vector<int64_t>()[0];
|
||||
indices = std::make_shared<ngraph::opset1::Unsqueeze>(indices, opset1::Constant::create(element::i64, Shape{1}, {0}));
|
||||
auto gather_new = std::make_shared<ngraph::opset1::Gather>(gather->input_value(0), indices, axes_constant);
|
||||
auto sq = std::make_shared<ngraph::opset1::Squeeze>(gather_new, opset1::Constant::create(element::i64, Shape{1}, {axis}));
|
||||
sq->set_friendly_name(gather->get_friendly_name());
|
||||
|
||||
ngraph::copy_runtime_info(gather, {indices.get_node_shared_ptr(), gather_new, sq});
|
||||
ngraph::replace_node(gather, sq);
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m1 = std::make_shared<ngraph::pattern::Matcher>(gather, "ConvertGather0D");
|
||||
this->register_matcher(m1, callback);
|
||||
}
|
@ -0,0 +1,85 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <transformations/op_conversions/convert_gather_0d.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
using namespace testing;
|
||||
using namespace ngraph;
|
||||
|
||||
TEST(TransformationTests, ConvertGather0DStatic1) {
|
||||
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto input = std::make_shared<opset1::Parameter>(element::f32, Shape{6, 12, 10, 24});
|
||||
auto indices = std::make_shared<opset1::Parameter>(element::f32, Shape{15, 4, 20, 28});
|
||||
auto axis_const = opset1::Constant::create(element::i64, Shape{}, {1});
|
||||
auto gather = std::make_shared<opset1::Gather>(input, indices, axis_const);
|
||||
|
||||
f = std::make_shared<Function>(NodeVector{gather}, ParameterVector{input, indices});
|
||||
|
||||
pass::Manager manager;
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertGather0D>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_TRUE(f->get_output_partial_shape(0).is_static()) << "Shape " << f->get_output_partial_shape(0) << " should be static";
|
||||
}
|
||||
|
||||
{
|
||||
auto input = std::make_shared<opset1::Parameter>(element::f32, Shape{6, 12, 10, 24});
|
||||
auto indices = std::make_shared<opset1::Parameter>(element::f32, Shape{15, 4, 20, 28});
|
||||
auto axis_const = opset1::Constant::create(element::i64, Shape{}, {1});
|
||||
auto gather = std::make_shared<opset1::Gather>(input, indices, axis_const);
|
||||
|
||||
f_ref = std::make_shared<Function>(NodeVector{gather}, ParameterVector{input, indices});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertGather0DStatic2) {
|
||||
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto input = std::make_shared<opset1::Parameter>(element::f32, Shape{6, 12, 10, 24});
|
||||
auto indices = std::make_shared<opset1::Parameter>(element::f32, Shape{});
|
||||
auto axis_const = opset1::Constant::create(element::i64, Shape{}, {1});
|
||||
auto gather = std::make_shared<opset1::Gather>(input, indices, axis_const);
|
||||
|
||||
f = std::make_shared<Function>(NodeVector{gather}, ParameterVector{input, indices});
|
||||
|
||||
pass::Manager manager;
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertGather0D>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_TRUE(f->get_output_partial_shape(0).is_static()) << "Shape " << f->get_output_partial_shape(0) << " should be static";
|
||||
}
|
||||
|
||||
{
|
||||
auto input = std::make_shared<opset1::Parameter>(element::f32, Shape{6, 12, 10, 24});
|
||||
auto indices = std::make_shared<opset1::Parameter>(element::f32, Shape{});
|
||||
auto axis_const = opset1::Constant::create(element::i64, Shape{}, {1});
|
||||
auto unsqueeze = std::make_shared<opset1::Unsqueeze>(indices, opset1::Constant::create(element::i64, Shape{1}, {0}));
|
||||
auto gather = std::make_shared<opset1::Gather>(input, unsqueeze, axis_const);
|
||||
auto squeeze = std::make_shared<opset1::Squeeze>(gather, opset1::Constant::create(element::i64, Shape{1}, {1}));
|
||||
|
||||
f_ref = std::make_shared<Function>(NodeVector{squeeze}, ParameterVector{input, indices});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
@ -46,6 +46,7 @@ ParamsKey ResampleKernelOpt::GetSupportedKey() const {
|
||||
k.EnableBatching();
|
||||
k.EnableReampleType(ResampleType::BILINEAR_INTERP);
|
||||
k.EnableReampleType(ResampleType::NEAREST_NEIGHBOR);
|
||||
k.EnableReampleType(ResampleType::LINEAR_ONNX);
|
||||
k.EnableSubGroup();
|
||||
k.EnableSubGroupShort();
|
||||
return k;
|
||||
|
@ -34,6 +34,7 @@ protected:
|
||||
std::vector<FusedOpType> GetSupportedFusedOps() const override {
|
||||
return { FusedOpType::QUANTIZE,
|
||||
FusedOpType::SCALE,
|
||||
FusedOpType::ELTWISE,
|
||||
FusedOpType::ACTIVATION };
|
||||
}
|
||||
private:
|
||||
|
@ -29,6 +29,7 @@ public:
|
||||
std::vector<FusedOpType> GetSupportedFusedOps() const override {
|
||||
return { FusedOpType::QUANTIZE,
|
||||
FusedOpType::SCALE,
|
||||
FusedOpType::ELTWISE,
|
||||
FusedOpType::ACTIVATION };
|
||||
}
|
||||
|
||||
|
@ -28,6 +28,23 @@
|
||||
#define OUT_VEC_TYPE MAKE_VECTOR_TYPE(OUTPUT_TYPE, VEC_SIZE)
|
||||
#define TO_OUT_VEC_TYPE(x) CAT(convert_, OUT_VEC_TYPE)(x)
|
||||
|
||||
inline float FUNC(get_original_coordinate)(float num, float scale, int length_resized, int length_original)
|
||||
{
|
||||
#if defined(COORD_TRANS_MODE_HALF_PIXEL)
|
||||
return (num + 0.5f) * scale - 0.5f;
|
||||
#elif defined(COORD_TRANS_MODE_PYTORCH_HALF_PIXEL)
|
||||
return (length_resized > 1) ? (num + 0.5f) * scale - 0.5f : 0.f;
|
||||
#elif defined(COORD_TRANS_MODE_ASYMMETRIC)
|
||||
return num * scale;
|
||||
#elif defined(COORD_TRANS_MODE_TF_HALF_PIXEL_FOR_NN)
|
||||
return (num + 0.5f) * scale;
|
||||
#elif defined(COORD_TRANS_MODE_ALIGN_CORNERS)
|
||||
return (length_resized != 1) ? num * (length_original - 1) / (length_resized - 1) : 0.f;
|
||||
#else
|
||||
#error [clDNN resample_opt.cl]: coordinate transformation mode - not supported
|
||||
#endif
|
||||
}
|
||||
|
||||
__attribute__((intel_reqd_sub_group_size(SUB_GROUP_SIZE)))
|
||||
KERNEL (resample_opt)(__global INPUT0_TYPE* input,
|
||||
__global OUTPUT_TYPE* output
|
||||
@ -56,7 +73,7 @@ KERNEL (resample_opt)(__global INPUT0_TYPE* input,
|
||||
const int iy = floor(y * SCALES[3]);
|
||||
|
||||
in_vec_t res = READ_FUNC(input, INPUT0_GET_INDEX(b, feature_block, iy, ix));
|
||||
#else
|
||||
#elif defined(SAMPLE_TYPE_INTERP)
|
||||
const ACCUMULATOR_TYPE ix = TO_ACCUMULATOR_TYPE(SCALES[4]) * (x + out_x);
|
||||
const ACCUMULATOR_TYPE iy = TO_ACCUMULATOR_TYPE(SCALES[3]) * y;
|
||||
|
||||
@ -76,6 +93,52 @@ KERNEL (resample_opt)(__global INPUT0_TYPE* input,
|
||||
const acc_vec_t top = TO_ACC_VEC_TYPE(top_left) + (TO_ACC_VEC_TYPE(top_right) - TO_ACC_VEC_TYPE(top_left)) * dx;
|
||||
const acc_vec_t bottom = TO_ACC_VEC_TYPE(bottom_left) + (TO_ACC_VEC_TYPE(bottom_right) - TO_ACC_VEC_TYPE(bottom_left)) * dx;
|
||||
acc_vec_t res = top + (bottom - top) * dy;
|
||||
#else // defined(SAMPLE_TYPE_LINEAR_ONNX)
|
||||
const int PADDED_Y = INPUT0_SIZE_Y + PADS_BEGIN[3] + PADS_END[3];
|
||||
const int PADDED_X = INPUT0_SIZE_X + PADS_BEGIN[4] + PADS_END[4];
|
||||
|
||||
const ACCUMULATOR_TYPE ix = FUNC_CALL(get_original_coordinate)(x + out_x, SCALES[4], OUTPUT_SIZE_X, PADDED_X);
|
||||
const ACCUMULATOR_TYPE iy = FUNC_CALL(get_original_coordinate)(y, SCALES[3], OUTPUT_SIZE_Y, PADDED_Y);
|
||||
|
||||
float in_y = fmax(0, fmin(iy, PADDED_Y - 1));
|
||||
float in_x = fmax(0, fmin(ix, PADDED_X - 1));
|
||||
int in_y1 = min((int)in_y, PADDED_Y - 1);
|
||||
int in_y2 = min(in_y1 + 1, PADDED_Y - 1);
|
||||
int in_x1 = min((int)in_x, PADDED_X - 1);
|
||||
int in_x2 = min(in_x1 + 1, PADDED_X - 1);
|
||||
const ACCUMULATOR_TYPE dx1 = (in_x1 != in_x2) ? TO_ACCUMULATOR_TYPE(fabs(in_x - in_x1)) : 0.5f;
|
||||
const ACCUMULATOR_TYPE dx2 = (in_x1 != in_x2) ? TO_ACCUMULATOR_TYPE(fabs(in_x - in_x2)) : 0.5f;
|
||||
const ACCUMULATOR_TYPE dy1 = (in_y1 != in_y2) ? TO_ACCUMULATOR_TYPE(fabs(in_y - in_y1)) : 0.5f;
|
||||
const ACCUMULATOR_TYPE dy2 = (in_y1 != in_y2) ? TO_ACCUMULATOR_TYPE(fabs(in_y - in_y2)) : 0.5f;
|
||||
#if PADDING_USED == 1
|
||||
in_y1 -= PADS_BEGIN[3];
|
||||
in_y2 -= PADS_BEGIN[3];
|
||||
in_x1 -= PADS_BEGIN[4];
|
||||
in_x2 -= PADS_BEGIN[4];
|
||||
|
||||
bool tlOutOfBounds = in_y1 < 0 || in_y1 >= in_size[3] || in_x1 < 0 || in_x1 >= in_size[4];
|
||||
bool trOutOfBounds = in_y1 < 0 || in_y1 >= in_size[3] || in_x2 < 0 || in_x2 >= in_size[4];
|
||||
bool blOutOfBounds = in_y2 < 0 || in_y2 >= in_size[3] || in_x1 < 0 || in_x1 >= in_size[4];
|
||||
bool brOutOfBounds = in_y2 < 0 || in_y2 >= in_size[3] || in_x2 < 0 || in_x2 >= in_size[4];
|
||||
#endif // PADDING_USED == 1
|
||||
const acc_vec_t top_left = TO_ACC_VEC_TYPE(READ_FUNC(input, INPUT0_GET_INDEX(b, feature_block, in_y1, in_x1)));
|
||||
const acc_vec_t top_right = TO_ACC_VEC_TYPE(READ_FUNC(input, INPUT0_GET_INDEX(b, feature_block, in_y1, in_x2)));
|
||||
const acc_vec_t bottom_left = TO_ACC_VEC_TYPE(READ_FUNC(input, INPUT0_GET_INDEX(b, feature_block, in_y2, in_x1)));
|
||||
const acc_vec_t bottom_right = TO_ACC_VEC_TYPE(READ_FUNC(input, INPUT0_GET_INDEX(b, feature_block, in_y2, in_x2)));
|
||||
#if PADDING_USED == 1
|
||||
if (tlOutOfBounds)
|
||||
top_left = INPUT0_VAL_ZERO;
|
||||
if (trOutOfBounds)
|
||||
top_right = INPUT0_VAL_ZERO;
|
||||
if (blOutOfBounds)
|
||||
bottom_left = INPUT0_VAL_ZERO;
|
||||
if (brOutOfBounds)
|
||||
bottom_right = INPUT0_VAL_ZERO;
|
||||
#endif // PADDING_USED == 1
|
||||
acc_vec_t res = TO_ACC_VEC_TYPE(dx2 * dy2 * top_left) +
|
||||
TO_ACC_VEC_TYPE(dx1 * dy2 * top_right) +
|
||||
TO_ACC_VEC_TYPE(dx2 * dy1 * bottom_left) +
|
||||
TO_ACC_VEC_TYPE(dx1 * dy1 * bottom_right);
|
||||
#endif
|
||||
#if HAS_FUSED_OPS
|
||||
FUSED_OPS;
|
||||
|
@ -132,9 +132,18 @@ KERNEL (resample_gpu_ref)(__global INPUT0_TYPE* input,
|
||||
interp_val = INPUT0_VAL_ZERO;
|
||||
#endif
|
||||
#if HAS_FUSED_OPS
|
||||
#define batch (out_coords[0])
|
||||
#define OF_ID (out_coords[1] + pi)
|
||||
#define oz (out_coords[2])
|
||||
#define oy (out_coords[3])
|
||||
#define ox (out_coords[4])
|
||||
FUSED_OPS;
|
||||
res[pi] = FUSED_OPS_RESULT;
|
||||
#undef batch
|
||||
#undef OF_ID
|
||||
#undef oz
|
||||
#undef oy
|
||||
#undef ox
|
||||
#else // HAS_FUSED_OPS
|
||||
res[pi] = ACTIVATION(interp_val, ACTIVATION_PARAMS);
|
||||
#endif // HAS_FUSED_OPS
|
||||
@ -170,9 +179,18 @@ KERNEL (resample_gpu_ref)(__global INPUT0_TYPE* input,
|
||||
interp_val = INPUT0_VAL_ZERO;
|
||||
#endif
|
||||
#if HAS_FUSED_OPS
|
||||
#define batch (out_coords[0])
|
||||
#define OF_ID (out_coords[1])
|
||||
#define oz (out_coords[2])
|
||||
#define oy (out_coords[3])
|
||||
#define ox (out_coords[4])
|
||||
FUSED_OPS;
|
||||
OUTPUT_TYPE res = FUSED_OPS_RESULT;
|
||||
#undef batch
|
||||
#undef OF_ID
|
||||
#undef oz
|
||||
#undef oy
|
||||
#undef ox
|
||||
#else // HAS_FUSED_OPS
|
||||
OUTPUT_TYPE res = TO_OUTPUT_TYPE(ACTIVATION(interp_val, ACTIVATION_PARAMS));
|
||||
#endif // HAS_FUSED_OPS
|
||||
@ -227,9 +245,18 @@ KERNEL (resample_gpu_ref)(__global INPUT0_TYPE* input,
|
||||
}
|
||||
|
||||
#if HAS_FUSED_OPS
|
||||
#define batch (out_coords[0])
|
||||
#define OF_ID (out_coords[1])
|
||||
#define oz (out_coords[2])
|
||||
#define oy (out_coords[3])
|
||||
#define ox (out_coords[4])
|
||||
FUSED_OPS;
|
||||
OUTPUT_TYPE res = FUSED_OPS_RESULT;
|
||||
#undef batch
|
||||
#undef OF_ID
|
||||
#undef oz
|
||||
#undef oy
|
||||
#undef ox
|
||||
#else // HAS_FUSED_OPS
|
||||
OUTPUT_TYPE res = ACTIVATION(TO_OUTPUT_TYPE(interp_val), ACTIVATION_PARAMS);
|
||||
#endif // HAS_FUSED_OPS
|
||||
@ -296,6 +323,7 @@ KERNEL (resample_gpu_ref)(__global INPUT0_TYPE* input,
|
||||
#define OF_ID (in_f)
|
||||
FUSED_OPS;
|
||||
OUTPUT_TYPE res = FUSED_OPS_RESULT;
|
||||
#undef OF_ID
|
||||
#else
|
||||
OUTPUT_TYPE res = ACTIVATION(TO_OUTPUT_TYPE(interp_val), ACTIVATION_PARAMS);
|
||||
#endif
|
||||
@ -337,6 +365,7 @@ KERNEL (resample_gpu_ref)(__global INPUT0_TYPE* input,
|
||||
#define OF_ID (in_f)
|
||||
FUSED_OPS;
|
||||
OUTPUT_TYPE res = FUSED_OPS_RESULT;
|
||||
#undef OF_ID
|
||||
#else
|
||||
OUTPUT_TYPE res = ACTIVATION(TO_OUTPUT_TYPE(interp_val), ACTIVATION_PARAMS);
|
||||
#endif
|
||||
@ -465,6 +494,7 @@ KERNEL (resample_gpu_ref)(__global INPUT0_TYPE* input,
|
||||
#define OF_ID (feature + f)
|
||||
FUSED_OPS;
|
||||
OUTPUT_TYPE res = FUSED_OPS_RESULT;
|
||||
#undef OF_ID
|
||||
#else
|
||||
OUTPUT_TYPE res = ACTIVATION(TO_OUTPUT_TYPE(interp_val), ACTIVATION_PARAMS);
|
||||
#endif
|
||||
|
@ -22,7 +22,6 @@
|
||||
#include "program_node.h"
|
||||
#include "mutable_data_inst.h"
|
||||
#include "concatenation_inst.h"
|
||||
#include "scale_inst.h"
|
||||
#include "tensor_type.h"
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
@ -729,6 +729,7 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
|
||||
(parents[i]->is_type<mvn>() && mvn_supports_fusings(parents[i]->as<mvn>())) ||
|
||||
(parents[i]->is_type<deconvolution>()) ||
|
||||
(parents[i]->is_type<permute>()) ||
|
||||
(parents[i]->is_type<resample>()) ||
|
||||
(parents[i]->is_type<space_to_depth>()) ||
|
||||
(parents[i]->is_type<gemm>() && gemm_supports_fusings(parents[i]->as<gemm>())) ||
|
||||
(parents[i]->is_type<batch_to_space>()) ||
|
||||
@ -804,6 +805,12 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
|
||||
p.get_processing_order().get_processing_number(peer_node))
|
||||
recalc_processing_order = true;
|
||||
|
||||
// [WA]: Resample + Eltwise fusing causes accuracy issues without processing order update.
|
||||
// As in both cases processing order is valid, the issue might be connected with memory pool
|
||||
if (fused_node->is_type<resample>()) {
|
||||
recalc_processing_order = true;
|
||||
}
|
||||
|
||||
p.fuse_nodes(*fused_node, node);
|
||||
};
|
||||
|
||||
|
@ -455,6 +455,9 @@ void prepare_quantization::prepare_asymmetric_quantization(program_impl &p) {
|
||||
if (node.get_dependencies().size() != 2 || prim->mode != eltwise_mode::sub)
|
||||
return false;
|
||||
|
||||
if (node.get_users().size() != 1)
|
||||
return false;
|
||||
|
||||
auto in0_layout = node.get_dependency(0).get_output_layout();
|
||||
auto in1_layout = node.get_dependency(1).get_output_layout();
|
||||
|
||||
@ -729,33 +732,40 @@ void prepare_quantization::prepare_dequantize_merge(program_impl &p) {
|
||||
if (node->is_output())
|
||||
continue;
|
||||
|
||||
program_helpers::do_for_types<scale>(*node, [&p](scale_node& node) {
|
||||
program_helpers::do_for_types<eltwise>(*node, [&p](eltwise_node& node) {
|
||||
for (size_t i = 1; i < node.get_dependencies().size(); i++) {
|
||||
if (!node.get_dependency(i).is_type<data>()) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
auto get_scale_shift_mem = [](const scale_node& scale, size_t dep_id) -> memory_impl& {
|
||||
if (dep_id >= scale.get_dependencies().size())
|
||||
CLDNN_ERROR_MESSAGE(scale.id(), "Invalid dependency id in dequantize optimization");
|
||||
auto get_scale_shift_mem = [](const eltwise_node& eltw, size_t dep_id) -> memory_impl& {
|
||||
if (dep_id >= eltw.get_dependencies().size())
|
||||
CLDNN_ERROR_MESSAGE(eltw.id(), "Invalid dependency id in dequantize optimization");
|
||||
|
||||
return scale.get_dependency(dep_id).as<data>().get_attached_memory();
|
||||
return eltw.get_dependency(dep_id).as<data>().get_attached_memory();
|
||||
};
|
||||
|
||||
auto eltw_mode = node.get_primitive()->mode;
|
||||
if (eltw_mode != eltwise_mode::sum && eltw_mode != eltwise_mode::prod)
|
||||
return;
|
||||
|
||||
auto& input = node.input();
|
||||
|
||||
for (auto& user : input.get_users()) {
|
||||
if (user == &node)
|
||||
continue;
|
||||
|
||||
if (!user->is_type<scale>() || user->get_dependencies().size() != node.get_dependencies().size())
|
||||
if (!user->is_type<eltwise>() || user->get_dependencies().size() != node.get_dependencies().size())
|
||||
continue;
|
||||
|
||||
auto& eltwise_dep = user->as<eltwise>();
|
||||
if (eltwise_dep.get_primitive()->mode != node.get_primitive()->mode)
|
||||
continue;
|
||||
|
||||
auto& scale_dep = user->as<scale>();
|
||||
bool valid_scale_node = true;
|
||||
for (size_t i = 1; i < scale_dep.get_dependencies().size(); i++) {
|
||||
if (!scale_dep.get_dependency(i).is_type<data>()) {
|
||||
for (size_t i = 1; i < eltwise_dep.get_dependencies().size(); i++) {
|
||||
if (!eltwise_dep.get_dependency(i).is_type<data>()) {
|
||||
valid_scale_node = false;
|
||||
}
|
||||
}
|
||||
@ -765,7 +775,7 @@ void prepare_quantization::prepare_dequantize_merge(program_impl &p) {
|
||||
|
||||
bool same_params = true;
|
||||
for (size_t i = 1; i < node.get_dependencies().size(); i++) {
|
||||
auto& mem0 = get_scale_shift_mem(user->as<scale>(), i);
|
||||
auto& mem0 = get_scale_shift_mem(eltwise_dep, i);
|
||||
auto& mem1 = get_scale_shift_mem(node, i);
|
||||
|
||||
auto ptr0 = static_cast<uint8_t*>(mem0.lock());
|
||||
|
@ -225,6 +225,8 @@ inline std::string fmt_to_str(format fmt) {
|
||||
return "g_os_zyx_is_osv32_isv16";
|
||||
case format::g_os_zyx_is_osv32_isv32:
|
||||
return "g_os_zyx_is_osv32_isv32";
|
||||
case format::gs_oi_yxs_gsv32_yxsv4:
|
||||
return "gs_oi_yxs_gsv32_yxsv4";
|
||||
default:
|
||||
return "unknown (" + std::to_string(fmt.value) + ")";
|
||||
}
|
||||
|
@ -630,7 +630,8 @@ bool layout_optimizer::deps_for_convolution_byxf_opt(program_node const& node, u
|
||||
conv_dep)) {
|
||||
return false;
|
||||
}
|
||||
} else if (!dep->is_type<pooling>() && (!dep->is_type<eltwise>() || !is_scale_shift(dep->as<eltwise>()))) {
|
||||
} else if ((!dep->is_type<pooling>() && !dep->is_type<eltwise>()) ||
|
||||
(dep->is_type<eltwise>() && is_scale_shift(dep->as<eltwise>()))) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -64,6 +64,8 @@ std::string resample_inst::to_string(resample_node const& node) {
|
||||
resample_info.add("resample_type:", "caffe_bilinear_interp");
|
||||
else if (desc->operation_type == resample_type::cubic)
|
||||
resample_info.add("resample_type:", "cubic");
|
||||
else if (desc->operation_type == resample_type::linear_onnx)
|
||||
resample_info.add("resample_type:", "linear_onnx");
|
||||
else
|
||||
resample_info.add("resample_type:", "not supported sample type");
|
||||
|
||||
|
@ -1492,38 +1492,40 @@ TEST_P(conv_int8_scale_shift_swish, basic) {
|
||||
data("scale_data", get_mem(get_per_channel_layout(p), 1.0f/p.kernel.count())),
|
||||
data("shift_data", get_mem(get_per_channel_layout(p), 1)),
|
||||
convolution("conv_prim", "input", {"weights"}, {"bias"}, p.groups, p.stride, p.pad, p.dilation),
|
||||
scale("scale0", "conv_prim", "scale_data", "shift_data"),
|
||||
scale("scale1", "conv_prim", "scale_data", "shift_data"),
|
||||
activation("sigmoid", "scale0", activation_func::logistic),
|
||||
eltwise("mul", {"scale1", "sigmoid"}, eltwise_mode::prod),
|
||||
eltwise("scale0", {"conv_prim", "scale_data"}, eltwise_mode::prod),
|
||||
eltwise("scale1", {"conv_prim", "scale_data"}, eltwise_mode::prod),
|
||||
eltwise("shift0", {"scale0", "shift_data"}, eltwise_mode::sum),
|
||||
eltwise("shift1", {"scale1", "shift_data"}, eltwise_mode::sum),
|
||||
activation("sigmoid", "shift0", activation_func::logistic),
|
||||
eltwise("mul", {"shift1", "sigmoid"}, eltwise_mode::prod),
|
||||
reorder("reorder_bfyx", "mul", p.default_format, data_types::f32)
|
||||
);
|
||||
|
||||
tolerance = 1e-5f;
|
||||
tolerance = 1e-4f;
|
||||
execute(p);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(fusings_gpu, conv_int8_scale_shift_swish,
|
||||
::testing::ValuesIn(std::vector<bc_test_params>{
|
||||
bc_test_params{CASE_CONV_U8S8_1, 2, 6},
|
||||
bc_test_params{CASE_CONV_U8S8_2, 2, 6},
|
||||
bc_test_params{CASE_CONV_U8S8_3, 2, 6},
|
||||
bc_test_params{CASE_CONV_U8S8_4, 2, 6},
|
||||
bc_test_params{CASE_CONV_S8S8_1, 2, 6},
|
||||
bc_test_params{CASE_CONV_S8S8_2, 2, 6},
|
||||
bc_test_params{CASE_CONV_S8S8_3, 2, 6},
|
||||
bc_test_params{CASE_CONV_S8S8_4, 2, 6},
|
||||
bc_test_params{CASE_CONV_U8S8_1, 2, 8},
|
||||
bc_test_params{CASE_CONV_U8S8_2, 2, 8},
|
||||
bc_test_params{CASE_CONV_U8S8_3, 2, 8},
|
||||
bc_test_params{CASE_CONV_U8S8_4, 2, 8},
|
||||
bc_test_params{CASE_CONV_S8S8_1, 2, 8},
|
||||
bc_test_params{CASE_CONV_S8S8_2, 2, 8},
|
||||
bc_test_params{CASE_CONV_S8S8_3, 2, 8},
|
||||
bc_test_params{CASE_CONV_S8S8_4, 2, 8},
|
||||
|
||||
bc_test_params{CASE_CONV3D_U8S8_1, 2, 6},
|
||||
bc_test_params{CASE_CONV3D_U8S8_2, 2, 6},
|
||||
bc_test_params{CASE_CONV3D_U8S8_3, 2, 6},
|
||||
bc_test_params{CASE_CONV3D_U8S8_4, 2, 6},
|
||||
bc_test_params{CASE_CONV3D_U8S8_5, 2, 6},
|
||||
bc_test_params{CASE_CONV3D_S8S8_1, 2, 6},
|
||||
bc_test_params{CASE_CONV3D_S8S8_2, 2, 6},
|
||||
bc_test_params{CASE_CONV3D_S8S8_3, 2, 6},
|
||||
bc_test_params{CASE_CONV3D_S8S8_4, 2, 6},
|
||||
bc_test_params{CASE_CONV3D_S8S8_5, 2, 6},
|
||||
bc_test_params{CASE_CONV3D_U8S8_1, 2, 8},
|
||||
bc_test_params{CASE_CONV3D_U8S8_2, 2, 8},
|
||||
bc_test_params{CASE_CONV3D_U8S8_3, 2, 8},
|
||||
bc_test_params{CASE_CONV3D_U8S8_4, 2, 8},
|
||||
bc_test_params{CASE_CONV3D_U8S8_5, 2, 8},
|
||||
bc_test_params{CASE_CONV3D_S8S8_1, 2, 8},
|
||||
bc_test_params{CASE_CONV3D_S8S8_2, 2, 8},
|
||||
bc_test_params{CASE_CONV3D_S8S8_3, 2, 8},
|
||||
bc_test_params{CASE_CONV3D_S8S8_4, 2, 8},
|
||||
bc_test_params{CASE_CONV3D_S8S8_5, 2, 8},
|
||||
}), );
|
||||
|
||||
class conv_int8_prelu_eltwise : public ConvFusingTest {};
|
||||
@ -2988,53 +2990,55 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, resample_quantize,
|
||||
// resample_test_params{ CASE_RESAMPLE_FP16_9, 2, 3 },
|
||||
}), );
|
||||
|
||||
class resample_scale_activation : public ResamplePrimitiveFusingTest {};
|
||||
TEST_P(resample_scale_activation, basic) {
|
||||
class resample_scale_activation_eltwise : public ResamplePrimitiveFusingTest {};
|
||||
TEST_P(resample_scale_activation_eltwise, basic) {
|
||||
auto p = GetParam();
|
||||
create_topologies(input_layout("input", get_input_layout(p)),
|
||||
data("scale_data", get_mem(get_per_channel_layout(p), -10, 10)),
|
||||
data("eltwise_data", get_mem(get_output_layout(p), -10, 10)),
|
||||
resample("resample_prim", "input", p.out_shape, p.in_shape.feature[0], p.type),
|
||||
scale("scale", "resample_prim", "scale_data"),
|
||||
activation("activation", "scale", activation_func::abs),
|
||||
reorder("reorder_bfyx", "activation", p.default_format, data_types::f32)
|
||||
eltwise("eltwise", { "activation", "eltwise_data"}, eltwise_mode::sum),
|
||||
reorder("reorder_bfyx", "eltwise", p.default_format, data_types::f32)
|
||||
);
|
||||
|
||||
tolerance = 1e-5f;
|
||||
execute(p);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(fusings_gpu, resample_scale_activation,
|
||||
INSTANTIATE_TEST_CASE_P(fusings_gpu, resample_scale_activation_eltwise,
|
||||
::testing::ValuesIn(std::vector<resample_test_params>{
|
||||
resample_test_params{ CASE_RESAMPLE_FP32_1, 2, 4 },
|
||||
resample_test_params{ CASE_RESAMPLE_FP32_2, 2, 4 },
|
||||
resample_test_params{ CASE_RESAMPLE_FP32_3, 2, 4 },
|
||||
resample_test_params{ CASE_RESAMPLE_FP32_4, 2, 4 },
|
||||
resample_test_params{ CASE_RESAMPLE_FP32_5, 2, 4 },
|
||||
resample_test_params{ CASE_RESAMPLE_FP32_6, 2, 4 },
|
||||
resample_test_params{ CASE_RESAMPLE_FP32_7, 2, 4 },
|
||||
resample_test_params{ CASE_RESAMPLE_FP32_8, 2, 4 },
|
||||
resample_test_params{ CASE_RESAMPLE_FP32_9, 2, 4 },
|
||||
resample_test_params{ CASE_RESAMPLE_FP32_1, 2, 5 },
|
||||
resample_test_params{ CASE_RESAMPLE_FP32_2, 2, 5 },
|
||||
resample_test_params{ CASE_RESAMPLE_FP32_3, 2, 5 },
|
||||
resample_test_params{ CASE_RESAMPLE_FP32_4, 2, 5 },
|
||||
resample_test_params{ CASE_RESAMPLE_FP32_5, 2, 5 },
|
||||
resample_test_params{ CASE_RESAMPLE_FP32_6, 2, 5 },
|
||||
resample_test_params{ CASE_RESAMPLE_FP32_7, 2, 5 },
|
||||
resample_test_params{ CASE_RESAMPLE_FP32_8, 2, 5 },
|
||||
resample_test_params{ CASE_RESAMPLE_FP32_9, 2, 5 },
|
||||
|
||||
resample_test_params{ CASE_RESAMPLE_FP16_1, 2, 4 },
|
||||
resample_test_params{ CASE_RESAMPLE_FP16_2, 2, 4 },
|
||||
resample_test_params{ CASE_RESAMPLE_FP16_3, 2, 4 },
|
||||
resample_test_params{ CASE_RESAMPLE_FP16_4, 2, 4 },
|
||||
resample_test_params{ CASE_RESAMPLE_FP16_5, 2, 4 },
|
||||
resample_test_params{ CASE_RESAMPLE_FP16_6, 2, 4 },
|
||||
resample_test_params{ CASE_RESAMPLE_FP16_7, 2, 4 },
|
||||
resample_test_params{ CASE_RESAMPLE_FP16_8, 2, 4 },
|
||||
resample_test_params{ CASE_RESAMPLE_FP16_9, 2, 4 },
|
||||
resample_test_params{ CASE_RESAMPLE_FP16_10, 2, 4 },
|
||||
resample_test_params{ CASE_RESAMPLE_FP16_1, 2, 5 },
|
||||
resample_test_params{ CASE_RESAMPLE_FP16_2, 2, 5 },
|
||||
resample_test_params{ CASE_RESAMPLE_FP16_3, 2, 5 },
|
||||
resample_test_params{ CASE_RESAMPLE_FP16_4, 2, 5 },
|
||||
resample_test_params{ CASE_RESAMPLE_FP16_5, 2, 5 },
|
||||
resample_test_params{ CASE_RESAMPLE_FP16_6, 2, 5 },
|
||||
resample_test_params{ CASE_RESAMPLE_FP16_7, 2, 5 },
|
||||
resample_test_params{ CASE_RESAMPLE_FP16_8, 2, 5 },
|
||||
resample_test_params{ CASE_RESAMPLE_FP16_9, 2, 5 },
|
||||
resample_test_params{ CASE_RESAMPLE_FP16_10, 2, 5 },
|
||||
|
||||
resample_test_params{ CASE_RESAMPLE_I8_1, 2, 4 },
|
||||
resample_test_params{ CASE_RESAMPLE_I8_2, 2, 4 },
|
||||
resample_test_params{ CASE_RESAMPLE_I8_3, 2, 4 },
|
||||
resample_test_params{ CASE_RESAMPLE_I8_4, 2, 4 },
|
||||
resample_test_params{ CASE_RESAMPLE_I8_1, 2, 5 },
|
||||
resample_test_params{ CASE_RESAMPLE_I8_2, 2, 5 },
|
||||
resample_test_params{ CASE_RESAMPLE_I8_3, 2, 5 },
|
||||
resample_test_params{ CASE_RESAMPLE_I8_4, 2, 5 },
|
||||
|
||||
resample_test_params{ CASE_RESAMPLE_U8_1, 2, 4 },
|
||||
resample_test_params{ CASE_RESAMPLE_U8_2, 2, 4 },
|
||||
resample_test_params{ CASE_RESAMPLE_U8_3, 2, 4 },
|
||||
resample_test_params{ CASE_RESAMPLE_U8_4, 2, 4 },
|
||||
resample_test_params{ CASE_RESAMPLE_U8_1, 2, 5 },
|
||||
resample_test_params{ CASE_RESAMPLE_U8_2, 2, 5 },
|
||||
resample_test_params{ CASE_RESAMPLE_U8_3, 2, 5 },
|
||||
resample_test_params{ CASE_RESAMPLE_U8_4, 2, 5 },
|
||||
}), );
|
||||
|
||||
class resample_quantize_concat : public ResamplePrimitiveFusingTest {};
|
||||
|
Loading…
Reference in New Issue
Block a user