[GPU] Support of int8 compressed weights for matmul (#19548)
This commit is contained in: parent a1cc5e6692, commit 7e3e1e2480
@@ -0,0 +1,35 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "fully_connected.hpp"

namespace ov {
namespace intel_gpu {
namespace op {

class FullyConnectedCompressed : public FullyConnected {
public:
    OPENVINO_OP("FullyConnectedCompressed", "gpu_opset");

    FullyConnectedCompressed() = default;

    FullyConnectedCompressed(const ov::Output<Node> &A,
                             const ov::Output<Node> &B,
                             const ov::Output<Node> &decompression_scale,
                             const ov::Output<Node> &decompression_zero_point,
                             const ov::element::Type output_type = ov::element::undefined);

    FullyConnectedCompressed(const ov::Output<Node> &A,
                             const ov::Output<Node> &B,
                             const ov::Output<Node> &decompression_scale,
                             const ov::element::Type output_type = ov::element::undefined);

    std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;
};

}   // namespace op
}   // namespace intel_gpu
}   // namespace ov
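The new op can be built programmatically like any other OpenVINO node. The snippet below is a minimal, illustrative sketch only; the shapes, element types and helper name are invented for the example (they are not part of this PR), and it assumes the header added above plus the standard OpenVINO core API.

#include "intel_gpu/op/fully_connected_compressed.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/parameter.hpp"

std::shared_ptr<ov::Node> build_compressed_fc_example() {
    // Hypothetical sizes: activations [1, 4096], int8 weights [4096, 4096],
    // per-output-channel scale and zero point [4096, 1].
    auto activations = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::Shape{1, 4096});
    auto weights     = ov::op::v0::Constant::create(ov::element::u8,  ov::Shape{4096, 4096}, {0});
    auto scale       = ov::op::v0::Constant::create(ov::element::f16, ov::Shape{4096, 1}, {1.0f});
    auto zero_point  = ov::op::v0::Constant::create(ov::element::u8,  ov::Shape{4096, 1}, {0});

    // Four-input form (with zero point); the three-input overload omits the zero point.
    return std::make_shared<ov::intel_gpu::op::FullyConnectedCompressed>(
        activations, weights, scale, zero_point, ov::element::f16);
}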
@@ -262,3 +262,4 @@ REGISTER_FACTORY(internal, GenerateProposalsIEInternal);
REGISTER_FACTORY(internal, NmsStaticShapeIE8);
REGISTER_FACTORY(internal, MulticlassNmsIEInternal);
REGISTER_FACTORY(internal, FullyConnected);
REGISTER_FACTORY(internal, FullyConnectedCompressed);
@@ -76,10 +76,43 @@ struct fully_connected : public primitive_base<fully_connected> {
      weights_rank(weights_rank)
    {}

    /// @brief Constructs fully connected compressed layer.
    /// @param id This primitive id.
    /// @param input Input primitive id.
    /// @param weights Primitive id containing weights data.
    /// @param bias Primitive id containing bias data.
    /// @param decompression_scale Primitive id containing scale factors for weights decompression.
    /// @param decompression_zero_point Primitive id containing zero points for weights decompression.
    fully_connected(const primitive_id& id,
                    const input_info& input,
                    const primitive_id& weights,
                    const primitive_id& bias,
                    const primitive_id& decompression_scale,
                    const primitive_id& decompression_zero_point,
                    const data_types data_type,
                    const padding& output_padding = padding(),
                    const size_t input_size = 2,
                    const size_t weights_rank = 2)
        : primitive_base(id, { input }, {output_padding}, {optional_data_type{data_type}}),
          weights(weights),
          bias(bias),
          compressed_weights(true),
          decompression_scale(decompression_scale),
          decompression_zero_point(decompression_zero_point),
          input_size(input_size),
          weights_rank(weights_rank) {
        OPENVINO_ASSERT(!decompression_scale.empty(), "[GPU] Compressed fully connected requires at least decompression scale input");
    }

    /// @brief Primitive id containing weights data.
    primitive_id weights;
    /// @brief Primitive id containing bias data.
    primitive_id bias;

    bool compressed_weights = false;
    primitive_id decompression_scale = "";
    primitive_id decompression_zero_point = "";

    /// @brief Primitive dimension size.
    size_t input_size = 2;
    /// @brief Primitive weights rank.
@@ -90,6 +123,9 @@ struct fully_connected : public primitive_base<fully_connected> {
        seed = hash_combine(seed, input_size);
        seed = hash_combine(seed, weights_rank);
        seed = hash_combine(seed, bias.empty());
        seed = hash_combine(seed, compressed_weights);
        seed = hash_combine(seed, !decompression_scale.empty());
        seed = hash_combine(seed, !decompression_zero_point.empty());
        return seed;
    }

@@ -108,6 +144,9 @@ struct fully_connected : public primitive_base<fully_connected> {
        primitive_base<fully_connected>::save(ob);
        ob << weights;
        ob << bias;
        ob << compressed_weights;
        ob << decompression_scale;
        ob << decompression_zero_point;
        ob << input_size;
        ob << weights_rank;
    }
@@ -116,6 +155,9 @@ struct fully_connected : public primitive_base<fully_connected> {
        primitive_base<fully_connected>::load(ib);
        ib >> weights;
        ib >> bias;
        ib >> compressed_weights;
        ib >> decompression_scale;
        ib >> decompression_zero_point;
        ib >> input_size;
        ib >> weights_rank;
    }
@@ -128,6 +170,12 @@ protected:
        if (!bias.empty())
            ret.push_back(bias);

        if (!decompression_scale.empty())
            ret.push_back(decompression_scale);

        if (!decompression_zero_point.empty())
            ret.push_back(decompression_zero_point);

        return ret;
    }
};
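For reference, constructing the compressed variant from client code follows the new constructor directly; the unit tests added at the end of this PR do the same thing. A minimal sketch (the primitive ids and the data type below are placeholders, and the referenced data primitives are assumed to be added to the topology separately):

// "weights_u8", "fc_scale" and "fc_zp" are ids of data primitives holding the
// quantized weights, per-channel scales and zero points; an empty bias id means "no bias".
auto fc = cldnn::fully_connected("fc_compressed",
                                 cldnn::input_info("input"),
                                 "weights_u8",
                                 "",               // no bias
                                 "fc_scale",
                                 "fc_zp",
                                 cldnn::data_types::f16,
                                 cldnn::padding(),
                                 /*input_size*/ 2,
                                 /*weights_rank*/ 2);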
@@ -217,6 +217,11 @@ std::string fully_connected_inst::to_string(fully_connected_node const& node) {
    json_composite fc_info;
    fc_info.add("weights id", weights_id);
    fc_info.add("bias id", bias_id);
    fc_info.add("compressed weights", desc->compressed_weights ? "true" : "false");
    if (desc->compressed_weights) {
        fc_info.add("decompression scale id", desc->decompression_scale);
        fc_info.add("decompression zp id", desc->decompression_zero_point);
    }

    node_info->add("fully connected info", fc_info);
    node_info->dump(primitive_description);
@@ -466,6 +466,11 @@ void prepare_primitive_fusing::fuse_bias(program &p) {
                                                   desc->output_paddings[0],
                                                   desc->input_size);

        if (desc->compressed_weights) {
            fc_with_bias_prim->compressed_weights = true;
            fc_with_bias_prim->decompression_scale = desc->decompression_scale;
            fc_with_bias_prim->decompression_zero_point = desc->decompression_zero_point;
        }
        auto& new_fc_node = p.get_or_create(fc_with_bias_prim);
        fuse_bias_f(fc, new_fc_node, bias_node, eltw_node);
    }
@@ -26,10 +26,19 @@ struct fully_connected_impl : typed_primitive_impl_ocl<fully_connected> {
protected:
    kernel_arguments_data get_arguments(const typed_primitive_inst<fully_connected>& instance) const override {
        kernel_arguments_data args = parent::get_arguments(instance);
        const auto& desc = instance.get_typed_desc<fully_connected>();

        args.weights = instance.weights_memory();
        args.bias = instance.bias_term() ? instance.bias_memory() : nullptr;

        args.inputs = { instance.input_memory_ptr(0) };
        size_t in_id = instance.bias_term() ? 3 : 2;
        if (!desc->decompression_scale.empty())
            args.inputs.push_back(instance.dep_memory_ptr(in_id++));

        if (!desc->decompression_zero_point.empty())
            args.inputs.push_back(instance.dep_memory_ptr(in_id));

        return args;
    }

@@ -72,6 +81,27 @@ public:

        std::vector<layout> layouts{input0_layout, input1_layout};

        bool has_zp = !primitive->decompression_zero_point.empty();
        bool has_scale = !primitive->decompression_scale.empty();

        size_t offset = primitive->bias.empty() ? 2 : 3;
        const auto& weights_pshape = input1_layout.get_partial_shape();
        if (has_scale) {
            auto scale_layout = input_layouts[offset++];
            if (input1_pshape.size() != 2) {
                scale_layout.set_partial_shape(reshape_to_2d(scale_layout.get_partial_shape(), weights_pshape[0], primitive->weights_rank));
            }
            layouts.push_back(scale_layout);
        }

        if (has_zp) {
            auto zp_layout = input_layouts[offset];
            if (input1_pshape.size() != 2) {
                zp_layout.set_partial_shape(reshape_to_2d(zp_layout.get_partial_shape(), weights_pshape[0], primitive->weights_rank));
            }
            layouts.push_back(zp_layout);
        }

        return layouts;
    };
@@ -105,6 +135,17 @@ public:
        auto optional_params = get_default_weights_bias_optional_params<kernel_selector::fully_connected_optional_params>(progam);
        optional_params.allowInputReordering = true;

        bool compressed = !primitive->decompression_scale.empty();
        bool with_zp = !primitive->decompression_zero_point.empty();
        if (compressed) {
            params.compressed = true;
            params.decompression_scale = convert_data_tensor(input_layouts[2]);
            if (with_zp) {
                params.has_decompression_zp = true;
                params.decompression_zero_point = convert_data_tensor(input_layouts[3]);
            }
        }

        if (primitive->input_size != 3)
            params.outputs = { params.outputs[0].FlattenFeatureAndSpatials() };

@@ -872,6 +872,10 @@ static bool is_node_for_onednn(deconvolution_node const& node) {

static bool is_node_for_onednn(fully_connected_node const& node) {
    auto fc_prim = node.get_primitive();
    // onednn impl doesn't support compressed weights for now
    if (fc_prim->compressed_weights)
        return false;

    auto output_layout = node.get_output_layout();
    auto ps = output_layout.get_partial_shape();
    size_t non_spatial_count = 2 + (fc_prim->input_size == 3 ? 1 : 0);
@@ -39,14 +39,15 @@
#endif

// Macros for vectorized types.
#define INPUT_VEC_TYPE MAKE_VECTOR_TYPE(INPUT0_TYPE, TILE_IFM)
#define ACCUMULATOR_VEC_TYPE MAKE_VECTOR_TYPE(ACCUMULATOR_TYPE, TILE_OFM)
#define FILTER_VEC_TYPE MAKE_VECTOR_TYPE(FILTER_TYPE, TILE_K_OFM)
#define BIAS_VEC_TYPE MAKE_VECTOR_TYPE(BIAS_TYPE, TILE_OFM)
#define OUTPUT_VEC_TYPE MAKE_VECTOR_TYPE(OUTPUT_TYPE, TILE_OFM)
#define ACTIVATION_VEC_TYPE MAKE_VECTOR_TYPE(ACTIVATION_TYPE, TILE_OFM)
#define TO_OUTPUT_VEC_TYPE(x) CAT(convert_, OUTPUT_VEC_TYPE)(x)
#define TO_ACTIVATION_VEC_TYPE(x) CAT(convert_, ACTIVATION_VEC_TYPE)(x)
#define INPUT_VEC_TYPE MAKE_VECTOR_TYPE(INPUT0_TYPE, TILE_IFM)
#define ACCUMULATOR_VEC_TYPE MAKE_VECTOR_TYPE(ACCUMULATOR_TYPE, TILE_OFM)
#define FILTER_VEC_TYPE MAKE_VECTOR_TYPE(ACCUMULATOR_TYPE, TILE_K_OFM)
#define BIAS_VEC_TYPE MAKE_VECTOR_TYPE(BIAS_TYPE, TILE_OFM)
#define OUTPUT_VEC_TYPE MAKE_VECTOR_TYPE(OUTPUT_TYPE, TILE_OFM)
#define ACTIVATION_VEC_TYPE MAKE_VECTOR_TYPE(ACTIVATION_TYPE, TILE_OFM)
#define TO_OUTPUT_VEC_TYPE(x) CAT(convert_, OUTPUT_VEC_TYPE)(x)
#define TO_ACTIVATION_VEC_TYPE(x) CAT(convert_, ACTIVATION_VEC_TYPE)(x)
#define TO_FILTER_VEC_TYPE(x) CAT(convert_, FILTER_VEC_TYPE)(x)

#define INPUT_BLOCK_READ(ptr, offset) BLOCK_READN(INPUT0_TYPE, TILE_IFM, ptr, offset)
#define FILTER_BLOCK_READ(ptr, offset) BLOCK_READN(FILTER_TYPE, TILE_K_OFM, ptr, offset)
@@ -81,6 +82,12 @@ REQD_SUB_GROUP_SIZE(SIMD)
KERNEL(fc)(
    OPTIONAL_SHAPE_INFO_ARG
    const __global INPUT0_TYPE* input,
#if DECOMPRESSION_SCALE_TERM
    const __global DECOMPRESSION_SCALE_TYPE* decompression_scale,
#endif
#if DECOMPRESSION_ZP_TERM
    const __global DECOMPRESSION_ZP_TYPE* decompression_zp,
#endif
    __global OUTPUT_TYPE* output,
    const __global FILTER_TYPE* weights
#if BIAS_TERM
@@ -113,13 +120,48 @@ KERNEL(fc)(
    uint input_offset = out_b * TILE_IN_B_PITCH + INPUT0_OFFSET;
    uint weights_offset = out_f * INPUT_ELEMENTS_COUNT;

#if COMPRESSED_WEIGHTS
    #if DECOMPRESSION_SCALE_LENGTH > 1 && DECOMPRESSION_SCALE_LENGTH % SIMD == 0
        ACCUMULATOR_VEC_TYPE d_scale = BLOCK_READN(ACCUMULATOR_TYPE, TILE_OFM, decompression_scale, out_f);
    #elif DECOMPRESSION_SCALE_LENGTH > 1 && DECOMPRESSION_SCALE_LENGTH % SIMD != 0
        ACCUMULATOR_VEC_TYPE d_scale = 0;
        unroll_for(uint of = 0; of < TILE_OFM; ++of) {
            uint offset = out_f + of*SIMD + get_sub_group_local_id();
            if (offset < DECOMPRESSION_SCALE_LENGTH)
                ((ACCUMULATOR_TYPE*)(&d_scale))[of] = decompression_scale[offset];
        }
    #else
        ACCUMULATOR_VEC_TYPE d_scale = decompression_scale[0];
    #endif

    #if !DECOMPRESSION_ZP_TERM
        ACCUMULATOR_VEC_TYPE d_zp = 0;
    #elif DECOMPRESSION_ZP_LENGTH > 1 && DECOMPRESSION_ZP_LENGTH % SIMD == 0
        ACCUMULATOR_VEC_TYPE d_zp = BLOCK_READN(ACCUMULATOR_TYPE, TILE_OFM, decompression_zp, out_f);
    #elif DECOMPRESSION_ZP_LENGTH > 1 && DECOMPRESSION_ZP_LENGTH % SIMD != 0
        ACCUMULATOR_VEC_TYPE d_zp = 0;
        unroll_for(uint of = 0; of < TILE_OFM; ++of) {
            uint offset = out_f + of*SIMD + get_sub_group_local_id();
            if (offset < DECOMPRESSION_ZP_LENGTH)
                ((ACCUMULATOR_TYPE*)(&d_zp))[of] = decompression_zp[offset];
        }
    #else
        ACCUMULATOR_VEC_TYPE d_zp = decompression_zp[0];
    #endif

    ACCUMULATOR_TYPE* ds = (ACCUMULATOR_TYPE*)(&d_scale);
    ACCUMULATOR_TYPE* dzp = (ACCUMULATOR_TYPE*)(&d_zp);
#endif

#if REALIGN_FP16_OFFSET
    // For fp16 we need to ensure that all block reads are aligned to 4 byte (2 words) boundary.
    // To do this solve first input feature separately.
    {
        INPUT0_TYPE tmp_input = input[input_offset + get_sub_group_local_id() % TILE_B * TILE_IN_B_PITCH];
        MAKE_VECTOR_TYPE(FILTER_TYPE, TILE_OFM) tmp_wei = BLOCK_READN(FILTER_TYPE, TILE_OFM, weights, weights_offset);
        ACCUMULATOR_VEC_TYPE tmp_wei = TO_ACCUMULATOR_VEC_TYPE(BLOCK_READN(FILTER_TYPE, TILE_OFM, weights, weights_offset));
        #if COMPRESSED_WEIGHTS
            tmp_wei = (tmp_wei - d_zp) * d_scale;
        #endif
        unroll_for(uint bi = 0; bi < TILE_B; ++bi) {
            acc[bi] = _sub_group_shuffle(tmp_input, bi) * tmp_wei;
        }
@@ -146,7 +188,15 @@ KERNEL(fc)(
    // but significantly degrades readability and generality of code.
    // It also doesn't show a noticeable performance improvement on tested configurations.
    unroll_for(uint ki = 0; ki < (TILE_IFM * SIMD) / TILE_K; ++ki) {
        wei = FILTER_BLOCK_READ(weights, weights_offset);
        wei = TO_FILTER_VEC_TYPE(FILTER_BLOCK_READ(weights, weights_offset));
        #if COMPRESSED_WEIGHTS
            ACCUMULATOR_TYPE* w = (ACCUMULATOR_TYPE*)(&wei);
            unroll_for(uint kii = 0; kii < TILE_K; ++kii) {
                unroll_for(uint fi = 0; fi < TILE_OFM; ++fi) {
                    w[kii * TILE_OFM + fi] = (w[kii * TILE_OFM + fi] - dzp[fi]) * ds[fi];
                }
            }
        #endif
        weights_offset += TILE_K_OFM * SIMD;

        unroll_for (uint kii = 0; kii < TILE_K; ++kii) {
@@ -154,7 +204,7 @@ KERNEL(fc)(
            unroll_for (uint bi = 0; bi < TILE_B; ++bi) {
                INPUT0_TYPE in_val = _sub_group_shuffle(((INPUT0_TYPE*)(&in_0[bi]))[total_k / SIMD], total_k % SIMD);
                unroll_for (uint fi = 0; fi < TILE_OFM; ++fi) {
                    ((ACCUMULATOR_TYPE*)(&acc[bi]))[fi] += in_val * ((FILTER_TYPE*)(&wei))[kii * TILE_OFM + fi];
                    ((ACCUMULATOR_TYPE*)(&acc[bi]))[fi] += in_val * ((ACCUMULATOR_TYPE*)(&wei))[kii * TILE_OFM + fi];
                }
            }
        }
@@ -175,7 +225,15 @@ KERNEL(fc)(
    #undef LOAD_IN_0
    input_offset += TILE_IFM * SIMD - TILE_IN_B_PITCH * TILE_B;
    unroll_for(uint ki = 0; ki < CEIL_DIV(LEFTOVER_IFM, TILE_K); ++ki) {
        wei = FILTER_BLOCK_READ(weights, weights_offset);
        wei = TO_FILTER_VEC_TYPE(FILTER_BLOCK_READ(weights, weights_offset));
        #if COMPRESSED_WEIGHTS
            ACCUMULATOR_TYPE* w = (ACCUMULATOR_TYPE*)(&wei);
            unroll_for(uint kii = 0; kii < TILE_K; ++kii) {
                unroll_for(uint fi = 0; fi < TILE_OFM; ++fi) {
                    w[kii * TILE_OFM + fi] = (w[kii * TILE_OFM + fi] - dzp[fi]) * ds[fi];
                }
            }
        #endif
        weights_offset += TILE_K_OFM * SIMD;

        unroll_for (uint kii = 0; kii < TILE_K; ++kii) {
@@ -184,7 +242,7 @@ KERNEL(fc)(
            const uint total_k = ki * TILE_K + kii;
            if (total_k < LEFTOVER_IFM) {
                INPUT0_TYPE in_val = _sub_group_shuffle(((INPUT0_TYPE*)(&in_0[bi]))[total_k / SIMD], total_k % SIMD);
                ((ACCUMULATOR_TYPE*)(&acc[bi]))[fi] += in_val * ((FILTER_TYPE*)(&wei))[kii * TILE_OFM + fi];
                ((ACCUMULATOR_TYPE*)(&acc[bi]))[fi] += in_val * ((ACCUMULATOR_TYPE*)(&wei))[kii * TILE_OFM + fi];
            }
        }
    }
@@ -8,6 +8,12 @@
KERNEL(fc)(
    OPTIONAL_SHAPE_INFO_ARG
    const __global INPUT0_TYPE* input,
#if DECOMPRESSION_SCALE_TERM
    const __global DECOMPRESSION_SCALE_TYPE* decompression_scale,
#endif
#if DECOMPRESSION_ZP_TERM
    const __global DECOMPRESSION_ZP_TYPE* decompression_zp,
#endif
    __global OUTPUT_TYPE* output,
    const __global FILTER_TYPE* weights
#if BIAS_TERM
@@ -31,7 +37,19 @@ KERNEL(fc)(
        {
            const uint input0_idx = INPUT0_GET_INDEX(b, ofm, y, x);
            const uint filter_idx = GET_FILTER_INDEX(FILTER, 0, oym, y, 0, 0);
            dotProd += (ACCUMULATOR_TYPE)(input[input0_idx]) * (ACCUMULATOR_TYPE)(weights[filter_idx]);
#if COMPRESSED_WEIGHTS
            ACCUMULATOR_TYPE filter_compressed = TO_ACCUMULATOR_TYPE(weights[filter_idx]);
    #if DECOMPRESSION_ZP_TERM
            ACCUMULATOR_TYPE zp = TO_ACCUMULATOR_TYPE(decompression_zp[DECOMPRESSION_ZP_GET_INDEX_SAFE(0, oym, 0, 0)]);
    #else
            ACCUMULATOR_TYPE zp = ACCUMULATOR_VAL_ZERO;
    #endif
            DECOMPRESSION_SCALE_TYPE scale = decompression_scale[DECOMPRESSION_SCALE_GET_INDEX_SAFE(0, oym, 0, 0)];
            ACCUMULATOR_TYPE filter_val = (TO_ACCUMULATOR_TYPE(filter_compressed) - TO_ACCUMULATOR_TYPE(zp)) * scale;
            dotProd += (ACCUMULATOR_TYPE)(input[input0_idx]) * (ACCUMULATOR_TYPE)(filter_val);
#else
            dotProd += (ACCUMULATOR_TYPE)(input[input0_idx]) * (ACCUMULATOR_TYPE)(weights[filter_idx]);
#endif
        }
    }

@@ -50,7 +68,20 @@ KERNEL(fc)(
            {
                const uint input0_idx = INPUT0_GET_INDEX(b, ifm, y, x);
                const uint filter_idx = GET_FILTER_INDEX(FILTER, 0, ofm, ifm, y, x);
                dotProd += (ACCUMULATOR_TYPE)(input[input0_idx]) * (ACCUMULATOR_TYPE)(weights[filter_idx]);
#if COMPRESSED_WEIGHTS
                FILTER_TYPE filter_compressed = weights[filter_idx];
    #if DECOMPRESSION_ZP_TERM
                ACCUMULATOR_TYPE zp = decompression_zp[DECOMPRESSION_ZP_GET_INDEX_SAFE(0, ofm, 0, 0)];
    #else
                ACCUMULATOR_TYPE zp = ACCUMULATOR_VAL_ZERO;
    #endif

                DECOMPRESSION_SCALE_TYPE scale = decompression_scale[DECOMPRESSION_SCALE_GET_INDEX_SAFE(0, ofm, 0, 0)];
                ACCUMULATOR_TYPE filter_val = (TO_ACCUMULATOR_TYPE(filter_compressed) - TO_ACCUMULATOR_TYPE(zp)) * scale;
                dotProd += (ACCUMULATOR_TYPE)(input[input0_idx]) * (ACCUMULATOR_TYPE)(filter_val);
#else
                dotProd += (ACCUMULATOR_TYPE)(input[input0_idx]) * (ACCUMULATOR_TYPE)(weights[filter_idx]);
#endif
            }
        }
    }
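The reference kernel above spells out the decompression formula: each stored weight is converted to the accumulator type, the optional per-output-channel zero point is subtracted, and the result is multiplied by the per-output-channel scale before entering the dot product. A scalar C++ model of the same computation, handy for deriving expected values in tests, could look like this (plain sketch, not part of the PR):

#include <cstdint>
#include <vector>

// Decompress one output channel of quantized weights: w = (w_q - zp) * scale.
// 'scale' and 'zp' are per-output-channel values, mirroring the
// DECOMPRESSION_*_GET_INDEX_SAFE(0, ofm, 0, 0) indexing used by the kernel.
std::vector<float> decompress_row(const std::vector<uint8_t>& w_q, float scale, float zp) {
    std::vector<float> w(w_q.size());
    for (size_t i = 0; i < w_q.size(); ++i)
        w[i] = (static_cast<float>(w_q[i]) - zp) * scale;
    return w;
}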
@@ -130,6 +130,7 @@ public:
                uint32_t asym_w_quantization : 1;
                uint32_t asym_d_quantization : 1;
                uint32_t dynamic_shapes : 1;
                uint32_t compressed_weights : 1;

                union dedicated_t {
                    struct argm_t {
@@ -318,6 +319,7 @@ public:
    void EnablePoolRemainder(PoolRemainder r);
    void EnablePoolDilation() { key.restrict.val.dedicated.pooling.dilation = 1; }
    void EnablePoolIndicesOutput() { key.restrict.val.dedicated.pooling.indices_output = 1; }
    void EnableWeightsCompression() { key.restrict.val.compressed_weights = 1; }
    void EnableQuantization(QuantizationType q);
    void EnablePositionSensitivePooling() { key.restrict.val.dedicated.pooling.position_sensitive = 1; }
    void EnableDilation() { key.restrict.val.dedicated.conv.dilation = 1; }
@@ -21,6 +21,17 @@ JitConstants FullyConnectedKernelBase::GetJitConstants(const fully_connected_par
        const auto x_size = input.LogicalSize() / input.Batch().v;
        jit.AddConstant(MakeJitConstant("INPUT0_ELEMENTS_COUNT", x_size));
    }

    if (params.compressed) {
        jit.AddConstants({MakeJitConstant("COMPRESSED_WEIGHTS", 1)});
        jit.AddConstants({MakeJitConstant("DECOMPRESSION_SCALE_TERM", 1)});
        jit.AddConstants({MakeJitConstant("DECOMPRESSION_SCALE", params.decompression_scale)});
        if (params.has_decompression_zp) {
            jit.AddConstants({MakeJitConstant("DECOMPRESSION_ZP_TERM", 1)});
            jit.AddConstants({MakeJitConstant("DECOMPRESSION_ZP", params.decompression_zero_point)});
        }
    }

    return jit;
}

@@ -93,11 +104,11 @@ KernelsData FullyConnectedKernelBase::GetCommonKernelsData(const Params &params,
    auto cldnn_jit = GetJitConstants(newParams, dispatchData);
    auto jit = CreateJit(kernelName, cldnn_jit, entry_point);

    uint32_t fused_deps_total = 0;
    for (auto& fused_dep : newParams.fused_ops) {
        for (int i = 0; i < static_cast<int>(fused_dep.dep_size); i++) {
            fused_deps_total++;
        }
    int inputs_count = 1;
    if (newParams.compressed) {
        inputs_count++;
        if (newParams.has_decompression_zp)
            inputs_count++;
    }

    auto& kernel = kd.kernels[0];
@@ -110,8 +121,8 @@ KernelsData FullyConnectedKernelBase::GetCommonKernelsData(const Params &params,
                     exeMode,
                     true,
                     !orgParams.bias.empty(),
                     1,
                     fused_deps_total,
                     inputs_count,
                     GetFusedPrimitiveInputsCount(params),
                     1,
                     orgParams.outputs[0].is_dynamic());

@@ -176,10 +187,10 @@ Datatype FullyConnectedKernelBase::GetAccumulatorType(const fully_connected_para
        return Datatype::INT32;

    // If either weights or inputs are quantized, use an fp32 accumulator to avoid fp16 overflow
    if (quantized_inputs || quantized_weights)
    if ((quantized_inputs || quantized_weights) && !params.compressed)
        return Datatype::F32;

    return params.inputs[0].GetDType();
    return in_dt;
}

Datatype FullyConnectedKernelBase::GetActivationType(const fully_connected_params& params) const {
@@ -52,6 +52,7 @@ ParamsKey FullyConnected_bf_tiled::GetSupportedKey() const {
    k.EnableDifferentTypes();
    k.EnableDifferentInputWeightsTypes();
    k.EnableDynamicShapesSupport();
    k.EnableWeightsCompression();
    return k;
}

@@ -200,7 +201,9 @@ FullyConnected_bf_tiled::GetAutoTuneParams(const fully_connected_params& params,
    while (max_tile_ofm * 2 * simd <= output_f && max_tile_ofm < 4)
        max_tile_ofm *= 2;

    if (params.is_shape_agnostic) {
    if (params.compressed && params.engineInfo.supports_immad) {
        return selector.Default(tune_params(1, 1, 1, 4, 1, 1, EXE_MODE_DEFAULT));
    } else if (params.is_shape_agnostic) {
        // Use special tuning params for Gen12HP dGPUs, since these parameters demonstrate higher performance
        // due to better HW utilization (reduced TILE_OFM parameter) and better assembler kernel's code
        // generation (extended TILE_K parameter) for both FP16 and FP32 data types
@@ -36,6 +36,7 @@ ParamsKey FullyConnected_bfyx_Ref::GetSupportedKey() const {
    k.EnableBatching();
    k.EnableQuantization(QuantizationType::SYMMETRIC);
    k.EnableDynamicShapesSupport();
    k.EnableWeightsCompression();
    return k;
}
@@ -24,6 +24,10 @@ ParamsKey weight_bias_params::GetParamsKey() const {
        k.EnableBiasPerOutput();
    }

    if (compressed) {
        k.EnableWeightsCompression();
    }

    return k;
}
@@ -17,6 +17,11 @@ struct weight_bias_params : public base_params {
    WeightsTensor weights;
    MultiDataTensor bias;

    bool compressed = false;
    bool has_decompression_zp = false;
    DataTensor decompression_scale;
    DataTensor decompression_zero_point;

    ParamsKey GetParamsKey() const override;
};
@@ -6,16 +6,17 @@
#include "intel_gpu/plugin/common_utils.hpp"

#include "intel_gpu/op/fully_connected.hpp"
#include "intel_gpu/op/fully_connected_compressed.hpp"

#include "intel_gpu/primitives/fully_connected.hpp"
#include "intel_gpu/primitives/reshape.hpp"
#include "intel_gpu/primitives/reorder.hpp"


namespace ov {
namespace op {
namespace internal {
using FullyConnected = ov::intel_gpu::op::FullyConnected;
using FullyConnectedCompressed = ov::intel_gpu::op::FullyConnectedCompressed;
}  // namespace internal
}  // namespace op
}  // namespace ov
@@ -23,13 +24,37 @@ using FullyConnected = ov::intel_gpu::op::FullyConnected;
namespace ov {
namespace intel_gpu {

static void CreateFullyConnectedCompressedOp(ProgramBuilder& p, const std::shared_ptr<op::FullyConnectedCompressed>& op) {
    validate_inputs_count(op, {3, 4});
    auto inputs = p.GetInputInfo(op);
    std::string primitive_name = layer_type_name_ID(op);

    auto input_name = inputs[0].pid;
    auto weights_name = inputs[1].pid;
    auto scale_name = inputs[2].pid;
    auto zp_name = inputs.size() == 4 ? inputs[3].pid : "";

    auto fc = cldnn::fully_connected(primitive_name,
                                     cldnn::input_info(input_name),
                                     weights_name,
                                     "",
                                     scale_name,
                                     zp_name,
                                     cldnn::element_type_to_data_type(op->get_output_element_type(0)),
                                     cldnn::padding(),
                                     op->get_input_partial_shape(0).size(),
                                     op->get_input_partial_shape(1).size());

    p.add_primitive(*op, fc);
}

static void CreateFullyConnectedOp(ProgramBuilder& p, const std::shared_ptr<op::FullyConnected>& op) {
    validate_inputs_count(op, {2});
    auto inputs = p.GetInputInfo(op);
    std::string layerName = layer_type_name_ID(op);

    auto inputName = inputs[0].pid;
    auto weightsName = inputs[1].pid;
    auto input_name = inputs[0].pid;
    auto weights_name = inputs[1].pid;

    auto shape_a = op->get_input_partial_shape(0);
    auto shape_b = op->get_input_partial_shape(1);
@@ -38,8 +63,8 @@ static void CreateFullyConnectedOp(ProgramBuilder& p, const std::shared_ptr<op::
    auto rank_b = shape_b.rank().get_length();

    auto fcPrim = cldnn::fully_connected(layerName,
                                         cldnn::input_info(inputName),
                                         weightsName,
                                         cldnn::input_info(input_name),
                                         weights_name,
                                         "",
                                         cldnn::element_type_to_data_type(op->get_output_element_type(0)),
                                         cldnn::padding(),
@@ -78,6 +103,7 @@ static void CreateFullyConnectedOp(ProgramBuilder& p, const std::shared_ptr<op::
}

REGISTER_FACTORY_IMPL(internal, FullyConnected);
REGISTER_FACTORY_IMPL(internal, FullyConnectedCompressed);

}  // namespace intel_gpu
}  // namespace ov
@@ -14,6 +14,7 @@
#include "intel_gpu/runtime/debug_configuration.hpp"
#include "intel_gpu/primitives/mutable_data.hpp"
#include "intel_gpu/primitives/data.hpp"
#include "intel_gpu/op/fully_connected_compressed.hpp"

#ifdef __linux__
# include <dlfcn.h>
@@ -556,6 +557,9 @@ bool ProgramBuilder::requires_new_shape_infer(const ov::Node& op) const {
        return true;
    }

    if (ov::is_type<op::FullyConnectedCompressed>(&op))
        return true;

    for (size_t i = 0; i < op.get_output_size(); i++) {
        if (op.get_output_partial_shape(i).size() > 6)
            return true;
@@ -0,0 +1,106 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "convert_fc_to_compressed.hpp"

#include "intel_gpu/op/fully_connected.hpp"
#include "intel_gpu/op/fully_connected_compressed.hpp"

#include "openvino/op/subtract.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/pass/pattern/op/or.hpp"
#include "transformations/utils/utils.hpp"

namespace ov {
namespace intel_gpu {

ConvertFullyConnectedToFullyConnectedCompressed::ConvertFullyConnectedToFullyConnectedCompressed() {
    using namespace ov::pass::pattern;

    auto weights_m = wrap_type<ov::op::v0::Constant>(consumers_count(1));
    auto convert_m = wrap_type<ov::op::v0::Convert>({weights_m});

    auto sub_const_m = wrap_type<ov::op::v0::Constant>(consumers_count(1));
    auto subtract_m = wrap_type<ov::op::v1::Subtract>({convert_m, sub_const_m});

    auto mul_const_m = wrap_type<ov::op::v0::Constant>(consumers_count(1));
    auto mul_with_sub_m = wrap_type<ov::op::v1::Multiply>({subtract_m, mul_const_m});
    auto mul_no_sub_m = wrap_type<ov::op::v1::Multiply>({convert_m, mul_const_m});
    auto mul_m = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{mul_with_sub_m, mul_no_sub_m});

    auto transpose_const_m = wrap_type<ov::op::v0::Constant>();
    auto transpose_m = wrap_type<ov::op::v1::Transpose>({mul_m, transpose_const_m});
    auto weights_input_m = std::make_shared<ov::pass::pattern::op::Or>(ov::OutputVector{mul_m, transpose_m});

    auto data_m = any_input();
    auto fully_connected_m = wrap_type<op::FullyConnected>({data_m, weights_input_m});

    ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
        const auto& pattern_map = m.get_pattern_value_map();
        OPENVINO_ASSERT(pattern_map.count(fully_connected_m));
        OPENVINO_ASSERT(pattern_map.count(mul_const_m));
        OPENVINO_ASSERT(pattern_map.count(weights_m));
        OPENVINO_ASSERT(pattern_map.count(convert_m));
        auto fc = std::dynamic_pointer_cast<op::FullyConnected>(pattern_map.at(fully_connected_m).get_node_shared_ptr());
        if (!fc || transformation_callback(fc)) {
            return false;
        }

        const auto& fc_input_a = fc->get_input_node_shared_ptr(0);
        const auto& scale = pattern_map.at(mul_const_m).get_node_shared_ptr();
        std::shared_ptr<ov::Node> optional_zero_point = nullptr;

        ov::NodeVector nodes_to_copy_info{pattern_map.at(fully_connected_m).get_node_shared_ptr(),
                                          pattern_map.at(convert_m).get_node_shared_ptr()};
        if (pattern_map.count(mul_no_sub_m)) {
            nodes_to_copy_info.push_back(pattern_map.at(mul_no_sub_m).get_node_shared_ptr());
        }
        if (pattern_map.count(mul_with_sub_m)) {
            nodes_to_copy_info.push_back(pattern_map.at(mul_with_sub_m).get_node_shared_ptr());
        }

        const bool with_zero_point = pattern_map.count(subtract_m) > 0;
        if (with_zero_point) {
            optional_zero_point = pattern_map.at(sub_const_m).get_node_shared_ptr();
            nodes_to_copy_info.push_back(subtract_m);
        }

        std::shared_ptr<ov::Node> fc_input_b = pattern_map.at(weights_m).get_node_shared_ptr();
        if (pattern_map.count(transpose_m)) {
            const auto& transpose = pattern_map.at(transpose_m).get_node_shared_ptr();
            const auto& transpose_const = pattern_map.at(transpose_const_m).get_node_shared_ptr();
            fc_input_b = transpose->clone_with_new_inputs({ fc_input_b->output(0), transpose_const });
        }

        std::shared_ptr<ov::Node> new_fc = nullptr;
        if (with_zero_point) {
            new_fc = std::make_shared<op::FullyConnectedCompressed>(fc_input_a,
                                                                    fc_input_b,
                                                                    scale,
                                                                    optional_zero_point,
                                                                    fc->get_output_type());
        } else {
            new_fc = std::make_shared<op::FullyConnectedCompressed>(fc_input_a,
                                                                    fc_input_b,
                                                                    scale,
                                                                    fc->get_output_type());
        }

        new_fc->set_friendly_name(fc->get_friendly_name());
        ov::copy_runtime_info(nodes_to_copy_info, new_fc);
        ov::replace_node(fc, new_fc);
        return true;
    };

    auto m = std::make_shared<ov::pass::pattern::Matcher>(fully_connected_m);
    this->register_matcher(m, callback);
}

}  // namespace intel_gpu
}  // namespace ov
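For context, the decompression pattern this pass matches is the usual one produced by weight compression: a u8 Constant, a Convert, an optional Subtract of a zero point, a Multiply by a scale and an optional Transpose feeding the weights input of op::FullyConnected. A hand-built example of such a subgraph is sketched below; the shapes, element types and values are arbitrary and only illustrate the matched structure, and the helper function name is invented for the example.

#include "intel_gpu/op/fully_connected.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/subtract.hpp"

std::shared_ptr<ov::Node> build_decompression_pattern_example() {
    using namespace ov;
    auto data    = std::make_shared<op::v0::Parameter>(element::f32, Shape{1, 16});
    auto weights = op::v0::Constant::create(element::u8, Shape{32, 16}, {0});
    auto convert = std::make_shared<op::v0::Convert>(weights, element::f32);
    auto zp      = op::v0::Constant::create(element::f32, Shape{32, 1}, {8.0f});
    auto sub     = std::make_shared<op::v1::Subtract>(convert, zp);
    auto scale   = op::v0::Constant::create(element::f32, Shape{32, 1}, {0.1f});
    auto mul     = std::make_shared<op::v1::Multiply>(sub, scale);
    // After ConvertFullyConnectedToFullyConnectedCompressed runs, this subgraph collapses
    // into a single FullyConnectedCompressed(data, weights, scale, zp) node.
    return std::make_shared<intel_gpu::op::FullyConnected>(data, mul, element::f32);
}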
@@ -0,0 +1,19 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"

namespace ov {
namespace intel_gpu {

class ConvertFullyConnectedToFullyConnectedCompressed: public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("ConvertFullyConnectedToFullyConnectedCompressed", "0");
    ConvertFullyConnectedToFullyConnectedCompressed();
};

}   // namespace intel_gpu
}   // namespace ov
@@ -0,0 +1,51 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "intel_gpu/op/fully_connected_compressed.hpp"

namespace ov {
namespace intel_gpu {
namespace op {

FullyConnectedCompressed::FullyConnectedCompressed(const ov::Output<Node>& A,
                                                   const ov::Output<Node>& B,
                                                   const ov::Output<Node>& decompression_scale,
                                                   const ov::Output<Node>& decompression_zero_point,
                                                   const ov::element::Type output_type)
    : FullyConnected(A, B, output_type) {
    set_argument(2, decompression_scale);
    set_argument(3, decompression_zero_point);
    validate_and_infer_types();
}

FullyConnectedCompressed::FullyConnectedCompressed(const ov::Output<Node>& A,
                                                   const ov::Output<Node>& B,
                                                   const ov::Output<Node>& decompression_scale,
                                                   const ov::element::Type output_type)
    : FullyConnected(A, B, output_type) {
    set_argument(2, decompression_scale);
    validate_and_infer_types();
}

std::shared_ptr<ov::Node> FullyConnectedCompressed::clone_with_new_inputs(const ov::OutputVector& new_args) const {
    check_new_args_count(this, new_args);

    if (new_args.size() == 3)
        return std::make_shared<FullyConnectedCompressed>(new_args.at(0),
                                                          new_args.at(1),
                                                          new_args.at(2),
                                                          m_output_type);
    else if (new_args.size() == 4)
        return std::make_shared<FullyConnectedCompressed>(new_args.at(0),
                                                          new_args.at(1),
                                                          new_args.at(2),
                                                          new_args.at(3),
                                                          m_output_type);
    else
        OPENVINO_THROW("Unexpected inputs count for FullyConnectedCompressed op: ", new_args.size());
}

}  // namespace op
}  // namespace intel_gpu
}  // namespace ov
@@ -35,6 +35,7 @@
#include "openvino/pass/constant_folding.hpp"
#include "openvino/core/deprecated.hpp"

#include "openvino/pass/visualize_tree.hpp"
#include "transformations/einsum_decomposition.hpp"
#include "transformations/convert_pooling_to_reduce.hpp"
#include "transformations/decompose_reduce_for_false_keepdims.hpp"
@@ -46,6 +47,7 @@
#include "transformations/control_flow/unroll_tensor_iterator.hpp"
#include "transformations/resolve_names_collisions.hpp"

#include "transformations/fp16_compression/mark_decompression_convert_constant_folding.hpp"
#include "transformations/fp16_compression/convert_compression_only_to_legacy.hpp"
#include "transformations/common_optimizations/common_optimizations.hpp"
#include "transformations/common_optimizations/lin_op_sequence_fusion.hpp"
@@ -55,6 +57,7 @@
#include "transformations/common_optimizations/transpose_sinking.hpp"
#include "transformations/common_optimizations/softmax_fusion.hpp"
#include "transformations/common_optimizations/mvn_fusion.hpp"
#include "transformations/common_optimizations/compress_float_constants.hpp"

#include "transformations/op_conversions/convert_depth_to_space.hpp"
#include "transformations/op_conversions/convert_space_to_depth.hpp"
@@ -106,6 +109,7 @@

#include "plugin/transformations/convert_matmul_to_fc.hpp"
#include "plugin/transformations/move_fc_reshape_to_weights.hpp"
#include "plugin/transformations/convert_fc_to_compressed.hpp"

#include "transformations/low_precision/mark_dequantization_subgraph.hpp"
#include "low_precision/pull_reshape_through_dequantization.hpp"
@@ -147,6 +151,7 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
    bool unroll_loop = config.get_property(ov::intel_gpu::enable_loop_unrolling);
    {
        ov::pass::Manager manager;
        auto pass_config = manager.get_pass_config();
        manager.set_per_pass_validation(false);

        enableInt8 = config.get_property(ov::intel_gpu::enable_lp_transformations) && ngraph::pass::low_precision::LowPrecision::isFunctionQuantized(func);
@@ -213,6 +218,15 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
        // decompose MVNs that are not supported in GPU, so that they will be marked as precision sensitive in ConvertPrecision
        manager.register_pass<ov::pass::MVN6Decomposition>();

        auto is_matmul_output = [](const_node_ptr &node) -> bool {
            const auto outputs = node->get_output_target_inputs(0);
            return !is_type<ov::op::v0::MatMul>(outputs.begin()->get_node());
        };

        manager.register_pass<ov::pass::KeepConstAndDecompression>();
        manager.register_pass<ov::pass::MarkDequantizationSubgraph>(ov::element::TypeVector{ov::element::u8}, true);
        pass_config->set_callback<ov::pass::KeepConstAndDecompression>(is_matmul_output);

        const bool keep_precision_sensitive_in_fp32_1 = true;
        manager.register_pass<ov::pass::ConvertPrecision>(fp_convert_precision_map,
                                                          empty_fuse_map,
@@ -269,7 +283,6 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
        manager.register_pass<ov::pass::Validate>();
        manager.register_pass<ov::pass::ConvertPrecision>(int_convert_precision_map);

        auto pass_config = manager.get_pass_config();
        pass_config->disable<ov::pass::EyeDecomposition>();

        // disable conversion to legacy and use the new mixed precision
@@ -614,6 +627,7 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
        ov::pass::Manager manager;
        manager.register_pass<ov::intel_gpu::ConvertMatMulToFullyConnected>();
        manager.register_pass<ov::intel_gpu::MoveFCReshapeToWeights>();
        manager.register_pass<ov::intel_gpu::ConvertFullyConnectedToFullyConnectedCompressed>();

        manager.run_passes(func);
    }
@@ -2,7 +2,11 @@
// SPDX-License-Identifier: Apache-2.0
//

#include "intel_gpu/runtime/internal_properties.hpp"
#include "intel_gpu/runtime/layout.hpp"
#include "openvino/core/partial_shape.hpp"
#include "test_utils.h"
#include "float16.h"
#include "random_generator.hpp"
#include "network_test.h"
#include <intel_gpu/runtime/utils.hpp>
@@ -656,6 +660,172 @@ TEST(fully_connected_gpu, x_f32_relu) {
    ASSERT_EQ(0.00f, output_ptr[3]);
}

TEST(fully_connected_gpu, compressed_scale_zp_bias) {
    auto& engine = get_test_engine();

    auto input_mem = engine.allocate_memory({ {1, 2, 4}, data_types::f32, format::bfyx });
    auto weights_mem = engine.allocate_memory({ {8, 4}, data_types::f32, format::bfyx });
    auto bias_mem = engine.allocate_memory({ {1, 1, 8}, data_types::f32, format::bfyx });
    auto scale_mem = engine.allocate_memory({ {1, 1, 8}, data_types::f32, format::bfyx });
    auto zp_mem = engine.allocate_memory({ {1, 1, 8}, data_types::f32, format::bfyx });

    set_values(input_mem, { -0.5f, 2.0f, 0.5f, 1.0f,
                            0.5f, -2.0f, -0.5f, -1.0f });
    set_values(weights_mem, { 1.5f, 1.0f, 0.5f, -1.0f,
                              0.0f, 0.5f, 0.5f, -0.5f,
                              -2.0f, -0.5f, 1.0f, 1.5f,
                              -2.0f, -0.5f, 1.0f, 1.5f,
                              2.0f, 0.5f, -1.0f, -1.5f,
                              2.0f, 0.5f, -1.0f, -1.5f,
                              -1.5f, -1.0f, -0.5f, 1.0f,
                              0.0f, -0.5f, 0.5f, 0.5f });

    set_values(bias_mem, { 1.0f, -2.0f, 3.0f, -4.0f, 5.0f, -6.0f, 7.0f, 2.0f });
    set_values(scale_mem, { 2.0f, 4.0f, -2.0f, -4.0f, 0.5f, -0.5f, 2.0f, 2.0f });
    set_values(zp_mem, { 1.0f, 2.0f, 2.0f, 1.0f, 4.0f, 1.0f, 6.0f, 2.0f });

    topology topology(
        input_layout("input", input_mem->get_layout()),
        data("weights", weights_mem),
        data("bias", bias_mem),
        data("scale", scale_mem),
        data("zp", zp_mem),
        fully_connected("fc_prim", input_info("input"), "weights", "bias", "scale", "zp", data_types::f32, padding(), 3, 2)
    );

    auto config = get_test_default_config(engine);
    config.set_property(ov::intel_gpu::allow_new_shape_infer(true));

    network network(engine, topology, config);
    network.set_input_data("input", input_mem);

    auto outputs = network.execute();
    ASSERT_EQ(outputs.size(), size_t(1));
    ASSERT_EQ(outputs.begin()->first, "fc_prim");

    auto output_mem = outputs.begin()->second.get_memory();

    cldnn::mem_lock<float> output_ptr (output_mem, get_test_stream());

    ov::PartialShape expected_shape{1, 2, 8};
    ASSERT_EQ(expected_shape, output_mem->get_layout().get_partial_shape());

    std::vector<float> expected_result = {-4.0f, -23.0f, 11.0f, 0.0f, -2.0f, -3.5f, -30.0f, -10.5f,
                                          6.0f, 19.0f, -5.0f, -8.0f, 12.0f, -8.5f, 44.0f, 14.5f};

    for (size_t i = 0; i < expected_result.size(); i++) {
        ASSERT_EQ(expected_result[i], output_ptr[i]) << "i = " << i;
    }
}
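As a sanity check on the expected values above, the first output element can be reproduced by hand: row 0 of the weights is decompressed with scale 2.0 and zero point 1.0, multiplied with the first input row, and the bias is added. A tiny standalone program (not part of the PR) doing just that:

#include <cstdio>

int main() {
    const float in[4]  = {-0.5f, 2.0f, 0.5f, 1.0f};    // first input row
    const float w0[4]  = { 1.5f, 1.0f, 0.5f, -1.0f};   // first (stored) weights row
    const float scale0 = 2.0f, zp0 = 1.0f, bias0 = 1.0f;

    float acc = bias0;
    for (int k = 0; k < 4; ++k)
        acc += in[k] * (w0[k] - zp0) * scale0;         // (w - zp) * scale decompression

    std::printf("%.1f\n", acc);                        // prints -4.0 == expected_result[0]
    return 0;
}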

TEST(fully_connected_gpu, compressed_scale_bias) {
    auto& engine = get_test_engine();

    auto input_mem = engine.allocate_memory({ {1, 2, 4}, data_types::f32, format::bfyx });
    auto weights_mem = engine.allocate_memory({ {8, 4}, data_types::f32, format::bfyx });
    auto bias_mem = engine.allocate_memory({ {1, 1, 8}, data_types::f32, format::bfyx });
    auto scale_mem = engine.allocate_memory({ {1, 1, 8}, data_types::f32, format::bfyx });

    set_values(input_mem, { -0.5f, 2.0f, 0.5f, 1.0f,
                            0.5f, -2.0f, -0.5f, -1.0f });
    set_values(weights_mem, { 1.5f, 1.0f, 0.5f, -1.0f,
                              0.0f, 0.5f, 0.5f, -0.5f,
                              -2.0f, -0.5f, 1.0f, 1.5f,
                              -2.0f, -0.5f, 1.0f, 1.5f,
                              2.0f, 0.5f, -1.0f, -1.5f,
                              2.0f, 0.5f, -1.0f, -1.5f,
                              -1.5f, -1.0f, -0.5f, 1.0f,
                              0.0f, -0.5f, 0.5f, 0.5f });

    set_values(bias_mem, { 1.0f, -2.0f, 3.0f, -4.0f, 5.0f, -6.0f, 7.0f, -8.0f });
    set_values(scale_mem, { 2.0f, 4.0f, -2.0f, -4.0f, 0.5f, -0.5f, 2.0f, 1.0f });

    topology topology(
        input_layout("input", input_mem->get_layout()),
        data("weights", weights_mem),
        data("bias", bias_mem),
        data("scale", scale_mem),
        fully_connected("fc_prim", input_info("input"), "weights", "bias", "scale", "", data_types::f32, padding(), 3, 2)
    );

    auto config = get_test_default_config(engine);
    config.set_property(ov::intel_gpu::allow_new_shape_infer(true));

    network network(engine, topology, config);
    network.set_input_data("input", input_mem);

    auto outputs = network.execute();
    ASSERT_EQ(outputs.size(), size_t(1));
    ASSERT_EQ(outputs.begin()->first, "fc_prim");

    auto output_mem = outputs.begin()->second.get_memory();

    cldnn::mem_lock<float> output_ptr (output_mem, get_test_stream());

    ov::PartialShape expected_shape{1, 2, 8};
    ASSERT_EQ(expected_shape, output_mem->get_layout().get_partial_shape());

    std::vector<float> expected_result = {2.0f, 1.0f, -1.0f, -12.0f, 4.0f, -5.0f, 6.0f, -8.25f,
                                          0.0f, -5.0f, 7.0f, 4.0f, 6.0f, -7.0f, 8.0f, -7.75f};

    for (size_t i = 0; i < expected_result.size(); i++) {
        ASSERT_EQ(expected_result[i], output_ptr[i]) << "i = " << i;
    }
}

TEST(fully_connected_gpu, compressed_scale_fp16) {
    auto& engine = get_test_engine();

    auto input_mem = engine.allocate_memory({ { 2, 4}, data_types::f16, format::bfyx });
    auto weights_mem = engine.allocate_memory({ {8, 4}, data_types::f16, format::bfyx });
    auto scale_mem = engine.allocate_memory({ {1, 8}, data_types::f16, format::bfyx });

    set_values<FLOAT16>(input_mem, { FLOAT16(-0.5f), FLOAT16(2.0f),  FLOAT16(0.5f),  FLOAT16(1.0f),
                                     FLOAT16(0.5f),  FLOAT16(-2.0f), FLOAT16(-0.5f), FLOAT16(-1.0f) });
    set_values<FLOAT16>(weights_mem, {FLOAT16( 1.5f), FLOAT16( 1.0f), FLOAT16( 0.5f), FLOAT16(-1.0f),
                                      FLOAT16( 0.0f), FLOAT16( 0.5f), FLOAT16( 0.5f), FLOAT16(-0.5f),
                                      FLOAT16(-2.0f), FLOAT16(-0.5f), FLOAT16( 1.0f), FLOAT16( 1.5f),
                                      FLOAT16(-2.0f), FLOAT16(-0.5f), FLOAT16( 1.0f), FLOAT16( 1.5f),
                                      FLOAT16( 2.0f), FLOAT16( 0.5f), FLOAT16(-1.0f), FLOAT16(-1.5f),
                                      FLOAT16( 2.0f), FLOAT16( 0.5f), FLOAT16(-1.0f), FLOAT16(-1.5f),
                                      FLOAT16(-1.5f), FLOAT16(-1.0f), FLOAT16(-0.5f), FLOAT16( 1.0f),
                                      FLOAT16( 0.0f), FLOAT16(-0.5f), FLOAT16( 0.5f), FLOAT16( 0.5f) });

    set_values<FLOAT16>(scale_mem, {FLOAT16(2.0f), FLOAT16(4.0f), FLOAT16(-2.0f), FLOAT16(-4.0f), FLOAT16(0.5f), FLOAT16(-0.5f), FLOAT16(2.0f), FLOAT16(2.0f)});

    topology topology(
        input_layout("input", input_mem->get_layout()),
        data("weights", weights_mem),
        data("scale", scale_mem),
        fully_connected("fc_prim", input_info("input"), "weights", "", "scale", "", data_types::f32, padding(), 2, 2)
    );

    auto config = get_test_default_config(engine);
    config.set_property(ov::intel_gpu::allow_new_shape_infer(true));

    network network(engine, topology, config);
    network.set_input_data("input", input_mem);

    auto outputs = network.execute();
    ASSERT_EQ(outputs.size(), size_t(1));
    ASSERT_EQ(outputs.begin()->first, "fc_prim");

    auto output_mem = outputs.begin()->second.get_memory();

    cldnn::mem_lock<FLOAT16> output_ptr (output_mem, get_test_stream());

    ov::PartialShape expected_shape{2, 8};
    ASSERT_EQ(expected_shape, output_mem->get_layout().get_partial_shape());

    std::vector<FLOAT16> expected_result = {
        FLOAT16(1.0f),  FLOAT16( 3.0f), FLOAT16(-4.0f), FLOAT16(-8.0f), FLOAT16(-1.0f), FLOAT16( 1.0f), FLOAT16(-1.0f), FLOAT16(-0.5f),
        FLOAT16(-1.0f), FLOAT16(-3.0f), FLOAT16( 4.0f), FLOAT16( 8.0f), FLOAT16( 1.0f), FLOAT16(-1.0f), FLOAT16( 1.0f), FLOAT16( 0.5f)};

    for (size_t i = 0; i < expected_result.size(); i++) {
        ASSERT_FLOAT_EQ(expected_result[i], output_ptr[i]) << "i = " << i;
    }
}

TEST(fully_connected_gpu, x_f32_relu_with_negative_slope) {
    // Input : 3x1
    // Output : 4x1
@@ -71,8 +71,8 @@ public:
        const auto primitive_hash = primitve->hash();
        const auto params_hash = primitve->type->get_fake_aligned_params(*prim_inst->get_impl_params()).hash();

        ASSERT_EQ(primitive_hash, 2197080758510296176UL);
        ASSERT_EQ(params_hash, 4714860879383010855UL);
        ASSERT_EQ(primitive_hash, 6924775129729406941UL);
        ASSERT_EQ(params_hash, 8552673460001178483UL);
    }

    void test_gather_basic(bool is_caching_test) {