[CPU] SDPA supports multi-query and different input layout (#21513)

parent eff9ba76ba
commit 17fb201433
@@ -139,7 +139,11 @@ void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query,
     auto q_len = query.size(2);
     auto S = query.size(3);
     auto kv_len = present_key.size(2);
-
+    auto h_group_num = present_key.size(1);
+    size_t h_each_group_len = 1;
+    if (h_group_num != H) {
+        h_each_group_len = H / h_group_num;
+    }
     if (d_scale == 0.0f)
         d_scale = 1.0f / sqrt(S);
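The added lines are the heart of the multi-query (grouped-query) support: present_key now carries h_group_num KV heads, and h_each_group_len query heads share each of them. A minimal sketch of the mapping, assuming H is an exact multiple of h_group_num (the kernel relies on the same assumption):

    // Hypothetical helper, not part of the commit: which KV group serves query head h?
    size_t kv_group_of(size_t h, size_t H, size_t h_group_num) {
        size_t h_each_group_len = H / h_group_num;  // query heads per KV head; 1 when H == h_group_num
        return h / h_each_group_len;                // e.g. H = 8, h_group_num = 2: heads 0..3 -> group 0, 4..7 -> group 1
    }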
@@ -149,20 +153,21 @@ void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query,

     bool is_abcd = present_key.stride(1) >= present_key.stride(2);
     size_t dim0 = is_abcd ? B : kv_len;
-    size_t dim1 = is_abcd ? H : B;
-    size_t dim2 = is_abcd ? kv_len : H;
+    size_t dim1 = is_abcd ? h_group_num : B;
+    size_t dim2 = is_abcd ? kv_len : h_group_num;

     parallel_for3d(dim0, dim1, dim2, [&](size_t d0, size_t d1, size_t d2) {
         size_t b = is_abcd ? d0 : d1;
-        size_t h = is_abcd ? d1 : d2;
+        size_t h_group = is_abcd ? d1 : d2;
         size_t pk = is_abcd ? d2 : d0;

         // which batch item should be used at position pk?
         auto b_kv = beams ? beams.at<int32_t>({b, pk}) : b;
         for (size_t pq = 0; pq < q_len; pq++) {
-            buf_attn_w.at<float>({b, h, pq, pk}) = dot_product(&query.at<T>({b, h, pq, 0}),
-                                                               &present_key.at<T2>({b_kv, h, pk, 0}, true),
-                                                               S);
+            for (size_t h = h_group * h_each_group_len; h < (h_group + 1) * h_each_group_len; h++) {
+                buf_attn_w.at<float>({b, h, pq, pk}) =
+                    dot_product(&query.at<T>({b, h, pq, 0}), &present_key.at<T2>({b_kv, h_group, pk, 0}, true), S);
+            }
         }
     });
@@ -190,29 +195,31 @@ void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query,
     // buf_attn_w {B, H, q_len, kv_len}
     parallel_nt_static(nthr, [&](const size_t ithr, const size_t nthr) {
         size_t start{0}, end{0};
-        splitter(B * H * kv_len, nthr, ithr, start, end);
+        splitter(B * h_group_num * kv_len, nthr, ithr, start, end);

         memset(&buf_attn_score.at<float>({ithr, 0, 0, 0, 0}), 0, buf_attn_score.stride(0) * sizeof(float));

-        size_t b, h, pv;
+        size_t b, h_group, pv;
         if (start < end) {
             if (is_abcd)
-                parallel_it_init(start, b, B, h, H, pv, kv_len);
+                parallel_it_init(start, b, B, h_group, h_group_num, pv, kv_len);
             else
-                parallel_it_init(start, pv, kv_len, b, B, h, H);
+                parallel_it_init(start, pv, kv_len, b, B, h_group, h_group_num);
             for (size_t iwork = start; iwork < end; ++iwork) {
                 auto b_kv = beams ? beams.at<int32_t>({b, pv}) : b;
-                auto* v = &present_value.at<T2>({b_kv, h, pv, 0}, true);
+                auto* v = &present_value.at<T2>({b_kv, h_group, pv, 0}, true);
                 for (size_t pq = 0; pq < q_len; pq++) {
-                    attn_acc_value(&buf_attn_score.at<float>({ithr, b, pq, h, 0}),
-                                   buf_attn_w.at<float>({b, h, pq, pv}),
-                                   v,
-                                   S);
+                    for (size_t h = h_group * h_each_group_len; h < (h_group + 1) * h_each_group_len; h++) {
+                        attn_acc_value(&buf_attn_score.at<float>({ithr, b, pq, h, 0}),
+                                       buf_attn_w.at<float>({b, h, pq, pv}),
+                                       v,
+                                       S);
+                    }
                 }
                 if (is_abcd)
-                    parallel_it_step(b, B, h, H, pv, kv_len);
+                    parallel_it_step(b, B, h_group, h_group_num, pv, kv_len);
                 else
-                    parallel_it_step(pv, kv_len, b, B, h, H);
+                    parallel_it_step(pv, kv_len, b, B, h_group, h_group_num);
             }
         }
     });
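The is_abcd flag in both loops infers the physical KV-cache layout from strides rather than from shape: in a plain [B,H,L,S] (abcd) buffer the H stride exceeds the L stride, while in the cabd arrangement used for the stateful cache the L dimension is outermost. The parallel loop order is then chosen so the outermost parallel dimension matches the outermost physical dimension. A minimal sketch of the test, with hypothetical strides assuming dense packing and B=1, H=8, L=10, S=64:

    // abcd ([B,H,L,S] dense): strides = {H*L*S, L*S, S, 1} -> stride(1) = 640 >= stride(2) = 64
    // cabd (L outermost):     strides = {H*S, S, B*H*S, 1}  -> stride(1) = 64  <  stride(2) = 512
    bool is_abcd_layout(const size_t* strides) {
        return strides[1] >= strides[2];
    }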
@@ -548,6 +548,7 @@ void MemoryInputSDPA::initSupportedPrimitiveDescriptors() {

     // Since this is a very specialized implementation, let's mimic SDPA precision and set cabd layout
     precision = SDPA->getOriginalInputPrecisionAtPort(childPort);
+    // Just a placeholder here; the actual layout is obtained at initOptimalPrimitiveDescriptor
     ArbitraryOrderDescCreator cabdDescCreator({2, 0, 1, 3});

     PortConfig outPortConfig;
@@ -573,7 +574,6 @@ void MemoryInputSDPA::initOptimalPrimitiveDescriptor() {
                     "failed initOptimalPrimitiveDescriptor() call, preferable primitive descriptor is not set");

     const auto& childConfig = childPd->getConfig();
-    auto childPrecision = childConfig.inConfs[childEdge->getOutputNum()].getMemDesc()->getPrecision();

     auto selectedPd = getSelectedPrimitiveDescriptor();
     OPENVINO_ASSERT(selectedPd,
@@ -582,8 +582,9 @@ void MemoryInputSDPA::initOptimalPrimitiveDescriptor() {
                     " failed initOptimalPrimitiveDescriptor() call, preferable primitive descriptor is not set");

     auto config = selectedPd->getConfig();
-    auto memDesc = config.outConfs.front().getMemDesc();
-    auto newMemDesc = memDesc->cloneWithNewPrecision(childPrecision);
+    // The physical layout varies across models, e.g. [LBHS] for chatglm, [BHLS] for Llama;
+    // the SDPA node knows the details, so trust the layout config provided by SDPA
+    auto newMemDesc = childConfig.inConfs.back().getMemDesc();
     config.outConfs.front().setMemDesc(newMemDesc);
     // bypass any checks, we enforce the child descriptor precision
     selectedPd->setConfig(config);
@@ -512,10 +512,16 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
         v_input.assert_dims({B, 0, L1, S}, true);
         auto past_k_idx = inputs.size() - 2;
         auto past_k_mem = inputs[past_k_idx + 0];
-        L0 = past_k_mem->getStaticDims()[2];
+        const auto& permute_axes = config.config.permute_axes;
+        L0 = permute_axes.empty() ? past_k_mem->getStaticDims()[2] : past_k_mem->getStaticDims()[permute_axes[2]];
         // [B, H, L0, S]
         past_k_output.reset(outputs[1]);
         past_v_output.reset(outputs[2]);
+        if (!permute_axes.empty()) {
+            // [L, B, H, S] -> [B, H, L, S]
+            past_k_output = past_k_output.permute(permute_axes);
+            past_v_output = past_v_output.permute(permute_axes);
+        }
         attn_memcpy(k_input, v_input, past_k_output.slice(2, L0, L0 + L1), past_v_output.slice(2, L0, L0 + L1));
         if (!config.is_concat_inplaced) {
             PlainTensor past_k_input, past_v_input;
@@ -560,12 +566,18 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
         }

         // q: [B, H, L1, S]
+        const auto& permute_axes = config.config.permute_axes;
+        PlainTensor present_key, present_value;
+        if (!permute_axes.empty()) {
+            q_input = q_input.permute(permute_axes);
+            k_input = k_input.permute(permute_axes);
+            v_input = v_input.permute(permute_axes);
+        }
         B = q_input.size(0);
         H = q_input.size(1);
         L1 = q_input.size(2);
         S = q_input.size(-1);

-        PlainTensor present_key, present_value;
         concat_pastkv(inputs, outputs, k_input, v_input, present_key, present_value);

         ov::intel_cpu::PlainTensor output_emb(outputs[0]);
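PlainTensor::permute here is presumably a metadata-only view change: no data moves, only the per-axis bookkeeping is reordered, which is what makes routing every layout through a canonical [B,H,L,S] view cheap. A minimal sketch of the idea, assuming a tensor described by a shape/stride pair (the StridedView name is hypothetical):

    // Hypothetical sketch: permuting a strided tensor view without copying data.
    #include <cstddef>
    #include <vector>

    struct StridedView {
        std::vector<size_t> shape, strides;
    };

    StridedView permute(const StridedView& t, const std::vector<size_t>& order) {
        StridedView out;
        for (size_t axis : order) {              // e.g. order {1,2,0,3}: [L,B,H,S] view -> [B,H,L,S] view
            out.shape.push_back(t.shape[axis]);
            out.strides.push_back(t.strides[axis]);
        }
        return out;                              // same underlying buffer, reordered indexing
    }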
@@ -634,9 +646,20 @@ ScaledDotProductAttention::ScaledDotProductAttention(const std::shared_ptr<ngrap
 void ScaledDotProductAttention::initSupportedPrimitiveDescriptors() {
     if (!supportedPrimitiveDescriptors.empty())
         return;
-    auto rtPrecision = getOriginalInputPrecisionAtPort(0);
+    rtPrecision = getOriginalInputPrecisionAtPort(0);
     auto orginSDPInputNumber = getOriginalInputsNumber() - (m_config.config.fuse_concat ? 2 : 0);
+
+    size_t H_idx = 1;
+    if (!m_config.config.permute_axes.empty()) {
+        H_idx = m_config.config.permute_axes[1];
+    }
+    const auto& qDims = getInputShapeAtPort(0).getDims();
+    const auto& kDims = getInputShapeAtPort(1).getDims();
+    // if multi-query, enforce fp32 TODO: support BF16
+    if (qDims[H_idx] != kDims[H_idx]) {
+        rtPrecision = ov::element::f32;
+    }

     bool enableKVCacheFP16 = m_config.config.fuse_concat && mayiuse(cpu_isa_t::avx2) && rtPrecision != ov::element::bf16;

     auto kvCachePrecision = enableKVCacheFP16 ? ov::element::f16 : rtPrecision;
@@ -669,17 +692,25 @@ void ScaledDotProductAttention::initSupportedPrimitiveDescriptors() {
     }

     if (m_config.config.fuse_concat) {
-        ArbitraryOrderDescCreator cabdDescCreator({2, 0, 1, 3});
-
-        config.inConfs[orginSDPInputNumber + 0].setMemDesc(cabdDescCreator.createSharedDesc(
+        ArbitraryOrderDescCreator layoutDescCreator({2, 0, 1, 3});
+        const auto& permute_axes = m_config.config.permute_axes;
+        if (!permute_axes.empty()) {
+            // [L,B,H,S] -> permute[1,2,0,3] -> [B,H,L,S]
+            // The actual index of B is permute[0], H is permute[1], L is permute[2], S is permute[3]
+            layoutDescCreator = ArbitraryOrderDescCreator({static_cast<size_t>(permute_axes[2]),
+                                                           static_cast<size_t>(permute_axes[0]),
+                                                           static_cast<size_t>(permute_axes[1]),
+                                                           static_cast<size_t>(permute_axes[3])});
+        }
+        config.inConfs[orginSDPInputNumber + 0].setMemDesc(layoutDescCreator.createSharedDesc(
             kvCachePrecision, getInputShapeAtPort(orginSDPInputNumber + 0)));
-        config.inConfs[orginSDPInputNumber + 1].setMemDesc(cabdDescCreator.createSharedDesc(
+        config.inConfs[orginSDPInputNumber + 1].setMemDesc(layoutDescCreator.createSharedDesc(
             kvCachePrecision, getInputShapeAtPort(orginSDPInputNumber + 1)));

-        config.outConfs[1].setMemDesc(cabdDescCreator.createSharedDesc(
+        config.outConfs[1].setMemDesc(layoutDescCreator.createSharedDesc(
             kvCachePrecision, getOutputShapeAtPort(1)));
         config.outConfs[1].inPlace(orginSDPInputNumber + 0);
-        config.outConfs[2].setMemDesc(cabdDescCreator.createSharedDesc(
+        config.outConfs[2].setMemDesc(layoutDescCreator.createSharedDesc(
             kvCachePrecision, getOutputShapeAtPort(2)));
         config.outConfs[2].inPlace(orginSDPInputNumber + 1);
     }
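To see why the creator order becomes {permute[2], permute[0], permute[1], permute[3]}: the KV-cache blob keeps L outermost (that is what the default cabd order {2, 0, 1, 3} over a [B,H,L,S] input means), and for a transposed input the same physical arrangement has to be expressed over the input's own axes. A worked check for the ChatGLM case from the tests, where permute_axes = {1, 2, 0, 3}:

    // For permute_axes = {1, 2, 0, 3} ([L,B,H,S] -> [B,H,L,S]):
    // B sits at original axis permute[0] = 1, H at permute[1] = 2, L at permute[2] = 0, S at permute[3] = 3.
    size_t permute_axes_sample[4] = {1, 2, 0, 3};
    size_t creator_order[4] = {permute_axes_sample[2],   // L, kept outermost
                               permute_axes_sample[0],   // B
                               permute_axes_sample[1],   // H
                               permute_axes_sample[3]};  // S
    // creator_order == {0, 1, 2, 3}: a [L,B,H,S] input is cached in its natural order,
    // which is the same physical layout the default {2, 0, 1, 3} gives a [B,H,L,S] input.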
@@ -712,7 +743,6 @@ void ScaledDotProductAttention::createPrimitive() {

         m_config.is_concat_inplaced = desc->getConfig().outConfs[1].inPlace() >= 0;
     }
-    auto rtPrecision = getOriginalInputPrecisionAtPort(0);

     if (rtPrecision == ov::element::bf16) {
         m_executor = std::make_shared<AttentionExecutor<KT_ONEDNN, ov::bfloat16>>(m_config);
@@ -54,6 +54,7 @@ private:
     Config m_config;
     std::shared_ptr<Executor> m_executor;
     template <KernelTypes KType, typename T> struct AttentionExecutor;
+    ov::element::Type rtPrecision;
 };

 } // namespace node
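Promoting rtPrecision to a class member is what ties the two hunks above together: the multi-query fp32 fallback chosen in initSupportedPrimitiveDescriptors() must also drive executor selection in createPrimitive(), where the local re-read of the original input precision was removed. A comment-style trace of the hazard the member avoids (hypothetical bf16 model with mismatched q/k head counts):

    // initSupportedPrimitiveDescriptors():  rtPrecision = f32;   // multi-query forces fp32
    // createPrimitive() before this change: auto rtPrecision = getOriginalInputPrecisionAtPort(0);
    //                                       // sees bf16 again and would pick the bf16 executor
    // createPrimitive() after this change:  reads the member, sees f32, picks the fp32 executor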
@@ -29,18 +29,43 @@ void ov::intel_cpu::ScaledDotProductAttentionWithKVCache::validate_and_infer_typ
     // [B, H, L0, S]
     auto past_kv_ps = get_input_partial_shape(input_num - 1);

+    auto output_logits = q_ps;
     NODE_VALIDATION_CHECK(this, m_config.output_BLHxS == false);
     NODE_VALIDATION_CHECK(this, q_ps.size() >= 3);
+    // permute_axes maps from the original layout to [B, H, L, S]
+    const auto& permute_axes = this->m_config.permute_axes;
     if (past_kv_ps.rank().is_static()) {
+        const size_t length_index = permute_axes.empty() ? q_ps.size() - 2 : permute_axes[permute_axes.size() - 2];
+        const size_t head_num_index = permute_axes.empty() ? q_ps.size() - 3 : permute_axes[permute_axes.size() - 3];
         NODE_VALIDATION_CHECK(this, q_ps.size() == past_kv_ps.size());
         for (size_t i = 0; i < q_ps.size(); i++) {
-            if (i == q_ps.size() - 2)
-                continue;
-            NODE_VALIDATION_CHECK(this, q_ps[i].compatible(past_kv_ps[i]));
+            if (i == head_num_index) {
+                if (q_ps[i].is_static() && past_kv_ps[i].is_static()) {
+                    NODE_VALIDATION_CHECK(this,
+                                          q_ps[i].get_length() % past_kv_ps[i].get_length() == 0,
+                                          "shape not compatible at index ",
+                                          i);
+                }
+            } else if (i == length_index) {
+                continue;
+            } else {
+                NODE_VALIDATION_CHECK(this,
+                                      q_ps[i].compatible(past_kv_ps[i]),
+                                      "shape not compatible at index ",
+                                      i);
+            }
         }
-        past_kv_ps[q_ps.size() - 2] += q_ps[q_ps.size() - 2];
+        past_kv_ps[length_index] += q_ps[length_index];
     }
-    set_output_type(0, get_input_element_type(0), q_ps);
+    if (!permute_axes.empty()) {
+        if (q_ps.rank().is_static()) {
+            // q_ps needs permute to BHLS
+            for (size_t i = 0; i < q_ps.size(); i++) {
+                output_logits[i] = q_ps[permute_axes[i]];
+            }
+        }
+    }
+    set_output_type(0, get_input_element_type(0), output_logits);
     set_output_type(1, get_input_element_type(input_num - 1), past_kv_ps);
     set_output_type(2, get_input_element_type(input_num - 1), past_kv_ps);
 }
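A worked example of this inference, using the ChatGLM multi-query shapes from the tests below (permute_axes = {1, 2, 0, 3}, q of [L1=10, B=1, H=8, S=64], past_kv of [L0=0, B=1, H=2, S=64]):

    // Traced by hand from the test data:
    // head_num_index = permute_axes[4 - 3] = permute_axes[1] = 2  -> q[2] = 8 vs past_kv[2] = 2, 8 % 2 == 0, OK
    // length_index   = permute_axes[4 - 2] = permute_axes[2] = 0  -> skipped in the compatibility loop
    // past_kv_ps[0] += q_ps[0]                                    -> past_kv becomes [10, 1, 2, 64]
    // output_logits[i] = q_ps[permute_axes[i]]                    -> [q[1], q[2], q[0], q[3]] = [1, 8, 10, 64]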
@@ -52,6 +77,7 @@ bool ov::intel_cpu::ScaledDotProductAttentionWithKVCache::visit_attributes(ov::A
     visitor.on_attribute("fuse_causal_attn", m_config.fuse_causal_attn);
     visitor.on_attribute("is_causal", m_config.is_causal);
     visitor.on_attribute("fuse_concat", m_config.fuse_concat);
+    visitor.on_attribute("permute_axes", m_config.permute_axes);
     visitor.finish_structure();
     return true;
 }
@@ -21,11 +21,13 @@ public:
     ScaledDotProductAttentionWithKVCache() = default;

     struct Config {
         bool output_BLHxS = false;        // true implies that output is [B,L,H*S]

         bool fuse_causal_attn = false;    // fuse causal mask and attn mask into attn_mask
         bool is_causal = false;           // apply causal mask internally
         bool fuse_concat = false;         // fuse (concat->sdp) ==> sdp
+        std::vector<size_t> permute_axes; // non-empty means the input is transposed; after permutation the layout is [B,H,L,S],
+                                          // e.g. [L,B,H,S] -> permute[1, 2, 0, 3] -> [B, H, L, S]
     };

     ScaledDotProductAttentionWithKVCache(const OutputVector& args, const Config& cfg);

@@ -47,4 +49,4 @@ private:
 };

 } // namespace intel_cpu
 } // namespace ov
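The concrete orders exercised by the tests at the end of this change cover the three layouts the comment alludes to: Llama already produces [B,H,L,S] (identity order), QWen feeds [B,L,H,S], and ChatGLM feeds [L,B,H,S]:

    // Orders taken from the test data below:
    std::vector<size_t> llama   = {0, 1, 2, 3};  // [B,H,L1,S] -> no-op
    std::vector<size_t> qwen    = {0, 2, 1, 3};  // [B,L1,H,S] -> [B,H,L1,S]
    std::vector<size_t> chatglm = {1, 2, 0, 3};  // [L1,B,H,S] -> [B,H,L1,S]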
@@ -0,0 +1,176 @@ (new file: stateful_transpose_sdpa_fusion.cpp)
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "stateful_transpose_sdpa_fusion.hpp"

#include <cstdint>
#include <limits>
#include <openvino/core/rt_info.hpp>
#include <openvino/opsets/opset1.hpp>
#include <openvino/opsets/opset13.hpp>
#include <openvino/opsets/opset6.hpp>
#include <openvino/opsets/opset8.hpp>
#include <openvino/pass/pattern/op/or.hpp>
#include <openvino/pass/pattern/op/wrap_type.hpp>
#include <transformations/utils/utils.hpp>

#include "itt.hpp"
#include "ov_ops/type_relaxed.hpp"
#include "transformations/cpu_opset/common/op/sdpa.hpp"

namespace ov {
namespace intel_cpu {

StatefulTransposeSDPAFusion::StatefulTransposeSDPAFusion() {
    MATCHER_SCOPE(StatefulTransposeSDPAFusion);
    using namespace ov::pass::pattern;

    auto past_k = wrap_type<opset6::ReadValue>();
    auto past_v = wrap_type<opset6::ReadValue>();
    auto convert_past_k = wrap_type<opset1::Convert>({past_k});
    auto convert_past_v = wrap_type<opset1::Convert>({past_v});
    auto concat_input_k = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{past_k, convert_past_k});
    auto concat_input_v = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{past_v, convert_past_v});
    auto concat_k = wrap_type<opset6::Concat>({concat_input_k, any_input()});
    auto concat_v = wrap_type<opset6::Concat>({concat_input_v, any_input()});

    // multi-query branch
    auto reshape_k = wrap_type<opset6::Reshape>({concat_k, any_input()});
    auto reshape_v = wrap_type<opset6::Reshape>({concat_v, any_input()});
    auto constant_k = wrap_type<opset6::Constant>();
    auto constant_v = wrap_type<opset6::Constant>();
    auto multiply_k = wrap_type<opset6::Multiply>({reshape_k, constant_k});
    auto multiply_v = wrap_type<opset6::Multiply>({reshape_v, constant_v});
    auto reshape1_k = wrap_type<opset6::Reshape>({multiply_k, any_input()});
    auto reshape1_v = wrap_type<opset6::Reshape>({multiply_v, any_input()});

    auto transpose_k_input = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{reshape1_k, concat_k});
    auto transpose_v_input = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{reshape1_v, concat_v});
    auto order_k = wrap_type<opset6::Constant>();
    auto order_v = wrap_type<opset6::Constant>();
    auto transpose_k = wrap_type<opset6::Transpose>({transpose_k_input, order_k});
    auto transpose_v = wrap_type<opset6::Transpose>({transpose_v_input, order_v});

    auto order_q = wrap_type<opset6::Constant>();
    auto q_input = any_input();
    auto transpose_q = wrap_type<opset6::Transpose>({q_input, order_q});
    auto sdp0 = wrap_type<opset13::ScaledDotProductAttention>({transpose_q, transpose_k, transpose_v});
    auto sdp1 = wrap_type<opset13::ScaledDotProductAttention>({transpose_q, transpose_k, transpose_v, any_input()});
    auto sdp2 =
        wrap_type<opset13::ScaledDotProductAttention>({transpose_q, transpose_k, transpose_v, any_input(), any_input()});
    auto sdp = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{sdp0, sdp1, sdp2});

    ov::matcher_pass_callback callback = [=](Matcher& m) {
        const auto& pattern_map = m.get_pattern_value_map();
        auto root = m.get_match_root();
        auto find_assign = [&](const ov::Output<ov::Node>& out, opset6::Assign*& assign, opset1::Convert*& cvt) {
            auto present_to = out.get_target_inputs();
            if (present_to.size() != 2)
                return;
            for (auto& to : present_to) {
                auto to_node = to.get_node();
                if (auto convert = dynamic_cast<opset1::Convert*>(to_node)) {
                    auto cvt_targets = convert->get_output_target_inputs(0);
                    if (cvt_targets.size() == 1) {
                        to_node = cvt_targets.begin()->get_node();
                        cvt = convert;
                    }
                }
                assign = dynamic_cast<opset6::Assign*>(to_node);
                if (assign)
                    return;
            }
        };

        std::shared_ptr<opset1::Convert> read_cvt_k_node, read_cvt_v_node;
        const auto sdp_node = ov::as_type_ptr<opset13::ScaledDotProductAttention>(root);
        const auto past_k_node = ov::as_type_ptr<opset6::ReadValue>(pattern_map.at(past_k).get_node_shared_ptr());
        const auto past_v_node = ov::as_type_ptr<opset6::ReadValue>(pattern_map.at(past_v).get_node_shared_ptr());
        const auto concat_k_node = ov::as_type_ptr<opset6::Concat>(pattern_map.at(concat_k).get_node_shared_ptr());
        const auto concat_v_node = ov::as_type_ptr<opset6::Concat>(pattern_map.at(concat_v).get_node_shared_ptr());
        if (pattern_map.count(convert_past_k)) {
            read_cvt_k_node = ov::as_type_ptr<opset1::Convert>(pattern_map.at(convert_past_k).get_node_shared_ptr());
            read_cvt_v_node = ov::as_type_ptr<opset1::Convert>(pattern_map.at(convert_past_v).get_node_shared_ptr());
        }

        // check that the broadcast arg is all ones (a pure repeat, numerically a no-op)
        auto check_bcst = [&](const std::shared_ptr<Node>& ptr) {
            const auto constant_node = ov::as_type_ptr<opset6::Constant>(ptr);
            const auto& bcst_arg = constant_node->cast_vector<float>();
            return std::all_of(bcst_arg.begin(), bcst_arg.end(), [](float f) {
                return f == 1.0f;
            });
        };

        if (pattern_map.count(constant_k)) {
            if (!check_bcst(pattern_map.at(constant_k).get_node_shared_ptr()))
                return false;
        }

        if (pattern_map.count(constant_v)) {
            if (!check_bcst(pattern_map.at(constant_v).get_node_shared_ptr()))
                return false;
        }

        opset6::Assign* assign_k_node = nullptr, *assign_v_node = nullptr;
        opset1::Convert* assign_cvt_k_node = nullptr, *assign_cvt_v_node = nullptr;
        find_assign(concat_k_node, assign_k_node, assign_cvt_k_node);
        if (!assign_k_node)
            return false;
        if (past_k_node->get_variable_id() != assign_k_node->get_variable_id())
            return false;

        find_assign(concat_v_node, assign_v_node, assign_cvt_v_node);
        if (!assign_v_node)
            return false;
        if (past_v_node->get_variable_id() != assign_v_node->get_variable_id())
            return false;
        auto args = sdp_node->input_values();
        args[0] = pattern_map.at(q_input).get_node_shared_ptr()->output(0);
        args[1] = concat_k_node->input_value(1);
        args[2] = concat_v_node->input_value(1);
        args.push_back(read_cvt_k_node ? read_cvt_k_node->output(0) : past_k_node->output(0));
        args.push_back(read_cvt_v_node ? read_cvt_v_node->output(0) : past_v_node->output(0));
        ov::intel_cpu::ScaledDotProductAttentionWithKVCache::Config config;

        const auto order_q_node = ov::as_type_ptr<opset6::Constant>(pattern_map.at(order_q).get_node_shared_ptr());
        const auto order_k_node = ov::as_type_ptr<opset6::Constant>(pattern_map.at(order_k).get_node_shared_ptr());
        const auto order_v_node = ov::as_type_ptr<opset6::Constant>(pattern_map.at(order_v).get_node_shared_ptr());

        const auto& permute_q = order_q_node->cast_vector<int32_t>();
        const auto& permute_k = order_k_node->cast_vector<int32_t>();
        const auto& permute_v = order_v_node->cast_vector<int32_t>();
        if (permute_q != permute_k || permute_q != permute_v) {
            return false;
        }

        config.is_causal = sdp_node->get_causal();
        config.fuse_concat = true;

        config.permute_axes.resize(permute_q.size());
        for (size_t i = 0; i < permute_q.size(); i++) {
            config.permute_axes[i] = static_cast<size_t>(permute_q[i]);
        }
        auto& old_node = sdp_node;
        auto new_node = std::make_shared<ov::intel_cpu::ScaledDotProductAttentionWithKVCache>(args, config);
        new_node->set_friendly_name(old_node->get_friendly_name());
        ov::replace_node(old_node, {new_node->output(0)});
        if (assign_cvt_k_node)
            assign_cvt_k_node->set_arguments({new_node->output(1)});
        else
            assign_k_node->set_arguments({new_node->output(1)});

        if (assign_cvt_v_node)
            assign_cvt_v_node->set_arguments({new_node->output(2)});
        else
            assign_v_node->set_arguments({new_node->output(2)});

        return true;
    };

    auto m = std::make_shared<ov::pass::pattern::Matcher>(sdp, matcher_name);
    this->register_matcher(m, callback);
}

} // namespace intel_cpu
} // namespace ov
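For reference, a matcher pass like this is applied through the standard pass manager; the CPU plugin does so in Transformations::PostLpt(), shown further below. A minimal usage sketch, assuming `model` is a std::shared_ptr<ov::Model> containing the stateful transpose-SDPA subgraph:

    ov::pass::Manager manager;
    manager.register_pass<ov::intel_cpu::StatefulTransposeSDPAFusion>();
    manager.run_passes(model);  // rewrites matched subgraphs into ScaledDotProductAttentionWithKVCache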
@@ -0,0 +1,18 @@ (new file: stateful_transpose_sdpa_fusion.hpp)
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <ngraph/pass/graph_rewrite.hpp>

namespace ov {
namespace intel_cpu {
class StatefulTransposeSDPAFusion : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("StatefulTransposeSDPAFusion", "0");
    StatefulTransposeSDPAFusion();
};

} // namespace intel_cpu
} // namespace ov
@@ -114,6 +114,7 @@
 #include "transformations/cpu_opset/common/pass/swap_convert_transpose.hpp"
 #include "transformations/cpu_opset/common/pass/rope_fusion.hpp"
 #include "transformations/cpu_opset/common/pass/stateful_sdpa_fusion.hpp"
+#include "transformations/cpu_opset/common/pass/stateful_transpose_sdpa_fusion.hpp"

 // Snippets
 #include "snippets/pass/tokenization.hpp"

@@ -661,6 +662,7 @@ void Transformations::PostLpt() {
     CPU_REGISTER_PASS_X64(postLPTPassManager, RoPEFusion);

     CPU_REGISTER_PASS_X64(postLPTPassManager, StatefulSDPAFusion);
+    CPU_REGISTER_PASS_X64(postLPTPassManager, StatefulTransposeSDPAFusion);
     postLPTPassManager.run_passes(model);
 }
@@ -0,0 +1,297 @@ (new test file: ConcatMultiQuerySDPTest)
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <openvino/opsets/opset13.hpp>
#include <transformations/op_conversions/scaled_dot_product_attention_decomposition.hpp>

#include "ov_models/builders.hpp"
#include "ov_models/utils/ov_helpers.hpp"
#include "shared_test_classes/base/layer_test_utils.hpp"
#include "shared_test_classes/base/ov_subgraph.hpp"
#include "test_utils/cpu_test_utils.hpp"
#include "common_test_utils/include/common_test_utils/ov_tensor_utils.hpp"

using namespace ov::test;
using namespace ngraph;
using namespace CPUTestUtils;
using namespace InferenceEngine;

namespace SubgraphTestsDefinitions {

using InputShapeAndTransposeOrder = std::pair<std::vector<InputShape>, std::vector<size_t>>;
using ConcatMultiQuerySDPParams = std::tuple<ElementType,
                                             InputShapeAndTransposeOrder,
                                             bool  // has ShapeOf
                                             >;
// Subgraph:
/*                          Parameter
 *                              |
 *      Parameter  ReadValue    |   ReadValue  Parameter
 *           \       /          |       \        /
 *            \     /           |        \      /
 *            Concat        Transpose     Concat
 *            /    \            |        /     \
 *           /  MultiQuery      |  MultiQuery   \
 *          /       \           |      /         \
 *         /     Transpose      | Transpose       \
 *        /           \         |    /             \
 *     Assign     ScaledDotProductAttention      Assign
 *                              |
 *                          Transpose
 *                              |
 *                           Reshape
 *                              |
 *                             Add
 *                              |
 *                            Result
 */

class ConcatMultiQuerySDPTest : public testing::WithParamInterface<ConcatMultiQuerySDPParams>,
                                virtual public ov::test::SubgraphBaseTest,
                                public CPUTestsBase {
public:
    static std::string getTestCaseName(const testing::TestParamInfo<ConcatMultiQuerySDPParams>& obj) {
        ElementType inType;
        InputShapeAndTransposeOrder inputShapeAndOrders;
        bool hasShapeof;
        std::tie(inType, inputShapeAndOrders, hasShapeof) = obj.param;
        std::ostringstream result;
        std::vector<InputShape>& inputShapes = inputShapeAndOrders.first;
        std::vector<size_t>& transposeOrder = inputShapeAndOrders.second;
        result << "IS=";
        for (const auto& shape : inputShapes) {
            result << ov::test::utils::partialShape2str({shape.first}) << "_";
        }
        result << "TS=";
        for (const auto& shape : inputShapes) {
            result << "(";
            if (!shape.second.empty()) {
                for (const auto& itr : shape.second) {
                    result << ov::test::utils::vec2str(itr);
                }
            }
            result << ")_";
        }
        result << "Prc=" << inType << "_";
        result << "HasShapeOf=" << hasShapeof;
        result << "TransposeOrder=";
        result << "(";
        for (const auto& itr : transposeOrder) {
            result << itr << ",";
        }
        result << ")";

        return result.str();
    }

    void SetUp() override {
        InputShapeAndTransposeOrder inputShapeAndOrders;
        bool hasShapeOf;
        ElementType inType;
        std::tie(inType, inputShapeAndOrders, hasShapeOf) = this->GetParam();
        std::vector<InputShape>& inputShapes = inputShapeAndOrders.first;
        std::vector<size_t>& transposeOrder = inputShapeAndOrders.second;
        targetDevice = ov::test::utils::DEVICE_CPU;
        rel_threshold = 1e-2f;
        configuration[ov::hint::inference_precision.name()] = ov::element::f32;
        if (inType == ElementType::bf16) {
            configuration[ov::hint::inference_precision.name()] = ov::element::bf16;
            rel_threshold = 0.01f;
        }
        init_input_shapes(inputShapes);
        ov::ParameterVector inputParams;
        // q, k, v
        inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[0]));
        inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[1]));
        inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[1]));
        inputParams[0]->set_friendly_name("q");
        inputParams[1]->set_friendly_name("k");
        inputParams[2]->set_friendly_name("v");
        // pastkv init
        inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[2]));
        auto var_k = std::make_shared<ov::op::util::Variable>(
            ov::op::util::VariableInfo{inputDynamicShapes[2], inType, "pastk"});
        auto pastk = std::make_shared<ov::op::v6::ReadValue>(inputParams[3], var_k);
        pastk->set_friendly_name("pastk_r");
        auto var_v = std::make_shared<ov::op::util::Variable>(
            ov::op::util::VariableInfo{inputDynamicShapes[2], inType, "pastv"});
        auto pastv = std::make_shared<ov::op::v6::ReadValue>(inputParams[3], var_v);
        pastv->set_friendly_name("pastv_r");
        std::shared_ptr<Node> pastk_shapeof, pastv_shapeof;
        if (hasShapeOf) {
            pastk_shapeof = std::make_shared<ov::op::v0::ShapeOf>(pastk);
            pastv_shapeof = std::make_shared<ov::op::v0::ShapeOf>(pastv);
        }

        // pre-SDPA transpose
        auto preOrder = op::v0::Constant::create(ov::element::i32, {4}, transposeOrder);
        auto transposeQ = std::make_shared<ov::op::v1::Transpose>(inputParams[0], preOrder);

        auto concat_axis = transposeOrder[2];
        auto concatK = std::make_shared<ov::op::v0::Concat>(OutputVector{pastk, inputParams[1]}, concat_axis);
        auto concatV = std::make_shared<ov::op::v0::Concat>(OutputVector{pastv, inputParams[2]}, concat_axis);

        // expand the 2 KV heads to the 8 query heads: Unsqueeze -> Multiply by ones -> Reshape
        auto unsquezeAxis = op::v0::Constant::create(ov::element::i32, {}, {-2});
        auto unsqueezeK = std::make_shared<ov::op::v0::Unsqueeze>(concatK, unsquezeAxis);
        auto unsqueezeV = std::make_shared<ov::op::v0::Unsqueeze>(concatV, unsquezeAxis);

        auto targetShape = op::v0::Constant::create(inType, {1, 1, 1, 4, 1}, {1});
        auto broadcastK = std::make_shared<ov::op::v1::Multiply>(unsqueezeK, targetShape);
        auto broadcastV = std::make_shared<ov::op::v1::Multiply>(unsqueezeV, targetShape);

        auto target4D = op::v0::Constant::create(ov::element::i32, {4}, {0, 0, 8, 64});

        auto reshapeK = std::make_shared<ov::op::v1::Reshape>(broadcastK, target4D, true);
        auto reshapeV = std::make_shared<ov::op::v1::Reshape>(broadcastV, target4D, true);

        auto transposeK = std::make_shared<ov::op::v1::Transpose>(reshapeK, preOrder);
        auto transposeV = std::make_shared<ov::op::v1::Transpose>(reshapeV, preOrder);

        auto sdp = std::make_shared<ov::opset13::ScaledDotProductAttention>(transposeQ, transposeK, transposeV, false);
        sdp->set_friendly_name("mha");

        // post-SDPA transpose + reshape
        auto get_reshape_order = [](const ov::PartialShape& qkv_shape,
                                    const std::vector<size_t>& transposeOrder) -> std::vector<size_t> {
            assert(transposeOrder.size() == 4);
            auto H = qkv_shape[transposeOrder[1]].get_length();
            auto S = qkv_shape[transposeOrder[3]].get_length();
            return std::vector<size_t>{0, 0, static_cast<size_t>(H * S)};
        };
        const auto reshapeOrder = get_reshape_order(inputDynamicShapes[0], transposeOrder);

        auto postOrder =
            ov::op::v0::Constant::create(ov::element::i32, {4}, std::vector<size_t>{0, 2, 1, 3});  // BHLS -> BLHS
        auto transposeSDP = std::make_shared<ov::op::v1::Transpose>(sdp, postOrder);

        auto constReshape = ov::op::v0::Constant::create(ov::element::i32, {3}, reshapeOrder);
        auto reshapeSDP = std::make_shared<ov::op::v1::Reshape>(transposeSDP, constReshape, true);  // BLHS -> B,L,HxS

        auto add = std::make_shared<ov::op::v1::Add>(reshapeSDP, op::v0::Constant::create(inType, {1}, {1.0f}));
        auto pastk_assign = std::make_shared<ov::op::v6::Assign>(concatK, var_k);
        auto pastv_assign = std::make_shared<ov::op::v6::Assign>(concatV, var_v);
        pastk_assign->set_friendly_name("pastk_w");
        pastv_assign->set_friendly_name("pastv_w");

        ov::OutputVector results{add};
        if (hasShapeOf) {
            results.push_back(pastk_shapeof);
            results.push_back(pastv_shapeof);
        }
        SinkVector sinks{pastk_assign, pastv_assign};
        function = std::make_shared<Function>(results, sinks, inputParams, "ConcatTransposeSDP");
        targetDevice = ov::test::utils::DEVICE_CPU;

        functionRefs = function->clone();
        pass::Manager manager;
        // decompose ScaledDotProductAttention for the reference model
        manager.register_pass<ov::pass::ScaledDotProductAttentionDecomposition>();
        manager.run_passes(functionRefs);
    }
    void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override {
        std::vector<ov::Shape> shapes(4);
        shapes[0] = targetInputStaticShapes[0];
        shapes[1] = targetInputStaticShapes[1];
        shapes[2] = targetInputStaticShapes[1];
        shapes[3] = targetInputStaticShapes[2];
        SubgraphBaseTest::generate_inputs(shapes);
    }
    template <typename IT, typename T>
    void strided_iota(IT first, size_t n, T value, T stride) {
        // fill [first, first + n) with value, value + stride, value + 2*stride, ...
        for (size_t i = 0; i < n; i++) {
            *first++ = value;
            value += stride;
        }
    }
    void generate(int idx, const std::vector<ov::Shape>& targetInputStaticShapes) {
        inputs.clear();
        auto create_input = [this](std::shared_ptr<ov::op::v0::Parameter> param, ov::Shape shape, float val) {
            if (param->get_element_type() == element::f32) {
                ov::Tensor t{ov::element::f32, shape};
                strided_iota(static_cast<float*>(t.data()), t.get_size(), val, 0.1f);
                inputs.insert({param, t});
            } else {
                ov::Tensor t{ov::element::bf16, shape};
                strided_iota(static_cast<ov::bfloat16*>(t.data()), t.get_size(), val, 0.1f);
                inputs.insert({param, t});
            }
        };
        // q, k, v
        create_input(function->get_parameters()[0], targetInputStaticShapes[0], idx + 1.0f);
        create_input(function->get_parameters()[1], targetInputStaticShapes[1], idx + 2.0f);
        create_input(function->get_parameters()[2], targetInputStaticShapes[1], idx + 3.0f);
        create_input(function->get_parameters()[3], targetInputStaticShapes[2], idx + 4.0f);
    }
    void prepare() {
        compile_model();
        inferRequest = compiledModel.create_infer_request();
        ASSERT_TRUE(inferRequest);
    }
    void reset() {
        for (auto&& state : inferRequest.query_state()) {
            state.reset();
        }
    }
    std::vector<ov::Tensor> run_test(std::shared_ptr<ov::Model> model) {
        function = model;
        prepare();
        std::vector<ov::Tensor> outputs;
        int idx = 0;
        for (auto&& shapes : targetStaticShapes) {
            generate(idx++, shapes);
            for (const auto& input : inputs) {
                inferRequest.set_tensor(input.first, input.second);
            }
            inferRequest.infer();
            auto outputTensor = inferRequest.get_output_tensor(0);
            ov::Tensor copy{outputTensor.get_element_type(), outputTensor.get_shape()};
            outputTensor.copy_to(copy);
            outputs.push_back(copy);
        }
        reset();

        return outputs;
    }
};

TEST_P(ConcatMultiQuerySDPTest, CompareWithRefs) {
    auto actualOutputs = run_test(function);
    CheckNumberOfNodesWithType(compiledModel, "ScaledDotProductAttention", 1);
    CheckNumberOfNodesWithType(compiledModel, "Concatenation", 0);
    if (configuration[ov::hint::inference_precision.name()] == ov::element::bf16) {
        CheckNumberOfNodesWithType(compiledModel, "Reorder", 5);
    } else {
        CheckNumberOfNodesWithType(compiledModel, "Reorder", 0);
    }
    CheckNumberOfNodesWithType(compiledModel, "Transpose", 1);
    auto expectedOutputs = run_test(functionRefs);
    CheckNumberOfNodesWithType(compiledModel, "ScaledDotProductAttention", 0);
    for (size_t i = 0; i < actualOutputs.size(); i++) {
        ov::test::utils::compare(expectedOutputs[i], actualOutputs[i], abs_threshold, rel_threshold);
    }
}

namespace {
const std::vector<InputShapeAndTransposeOrder> inputShapeAndReorders = {{
    {// inputShapes ChatGLM
     {
         // L1, B, H, S
         {{-1, 1, 8, 64}, {{10, 1, 8, 64}, {1, 1, 8, 64}, {1, 1, 8, 64}, {20, 1, 8, 64}, {1, 1, 8, 64}}},
         {{-1, 1, 2, 64}, {{10, 1, 2, 64}, {1, 1, 2, 64}, {1, 1, 2, 64}, {20, 1, 2, 64}, {1, 1, 2, 64}}},
         // L0, B, H, S
         {{-1, 1, 2, 64}, {{0, 1, 2, 64}, {10, 1, 2, 64}, {11, 1, 2, 64}, {12, 1, 2, 64}, {32, 1, 2, 64}}},
     },
     // transposeOrder
     {1, 2, 0, 3}},
}};
// TODO: the BF16 test is disabled due to CI machine limitations
INSTANTIATE_TEST_SUITE_P(smoke_ConcatMultiQuerySDPTest,
                         ConcatMultiQuerySDPTest,
                         ::testing::Combine(::testing::Values(ElementType::f32),
                                            ::testing::ValuesIn(inputShapeAndReorders),
                                            ::testing::Values(true, false)),
                         ConcatMultiQuerySDPTest::getTestCaseName);
}  // namespace
}  // namespace SubgraphTestsDefinitions
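The Unsqueeze/Multiply/Reshape chain in the test above emulates how multi-query models broadcast 2 KV heads up to the 8 query heads, which is exactly the pattern StatefulTransposeSDPAFusion's check_bcst accepts (an all-ones multiplier is a pure repeat). Tracing the ChatGLM shapes by hand:

    // concatK:                          [L, 1, 2, 64]     two KV heads
    // Unsqueeze(-2):                    [L, 1, 2, 1, 64]
    // Multiply by ones {1, 1, 1, 4, 1}: [L, 1, 2, 4, 64]  numeric no-op, only repeats
    // Reshape {0, 0, 8, 64}:            [L, 1, 8, 64]     eight heads matching the query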
@@ -0,0 +1,296 @@ (new test file: ConcatSDPTransposeTest)
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <openvino/opsets/opset13.hpp>
#include <transformations/op_conversions/scaled_dot_product_attention_decomposition.hpp>

#include "ov_models/builders.hpp"
#include "ov_models/utils/ov_helpers.hpp"
#include "shared_test_classes/base/layer_test_utils.hpp"
#include "shared_test_classes/base/ov_subgraph.hpp"
#include "test_utils/cpu_test_utils.hpp"
#include "common_test_utils/include/common_test_utils/ov_tensor_utils.hpp"

using namespace ov::test;
using namespace ngraph;
using namespace CPUTestUtils;
using namespace InferenceEngine;

namespace SubgraphTestsDefinitions {

using InputShapeAndTransposeOrder = std::pair<std::vector<InputShape>, std::vector<size_t>>;
using ConcatSDPTransposeTestParams = std::tuple<ElementType,
                                                InputShapeAndTransposeOrder,
                                                bool  // has ShapeOf
                                                >;
// Subgraph:
/*                          Parameter
 *                              |
 *      Parameter  ReadValue    |   ReadValue  Parameter
 *           \       /          |       \        /
 *            \     /           |        \      /
 *            Concat        Transpose     Concat
 *            /     \           |        /     \
 *           /   Transpose      | Transpose     \
 *          /         \         |    /           \
 *     Assign     ScaledDotProductAttention    Assign
 *                              |
 *                          Transpose
 *                              |
 *                           Reshape
 *                              |
 *                             Add
 *                              |
 *                            Result
 */

class ConcatSDPTransposeTest : public testing::WithParamInterface<ConcatSDPTransposeTestParams>,
                               virtual public ov::test::SubgraphBaseTest,
                               public CPUTestsBase {
public:
    static std::string getTestCaseName(const testing::TestParamInfo<ConcatSDPTransposeTestParams>& obj) {
        ElementType inType;
        InputShapeAndTransposeOrder inputShapeAndOrders;
        bool hasShapeof;
        std::tie(inType, inputShapeAndOrders, hasShapeof) = obj.param;
        std::ostringstream result;
        std::vector<InputShape>& inputShapes = inputShapeAndOrders.first;
        std::vector<size_t>& transposeOrder = inputShapeAndOrders.second;
        result << "IS=";
        for (const auto& shape : inputShapes) {
            result << ov::test::utils::partialShape2str({shape.first}) << "_";
        }
        result << "TS=";
        for (const auto& shape : inputShapes) {
            result << "(";
            if (!shape.second.empty()) {
                for (const auto& itr : shape.second) {
                    result << ov::test::utils::vec2str(itr);
                }
            }
            result << ")_";
        }
        result << "Prc=" << inType << "_";
        result << "HasShapeOf=" << hasShapeof;
        result << "TransposeOrder=";
        result << "(";
        for (const auto& itr : transposeOrder) {
            result << itr << ",";
        }
        result << ")";

        return result.str();
    }

    void SetUp() override {
        ElementType inType;
        InputShapeAndTransposeOrder inputShapeAndOrders;
        bool hasShapeOf;
        std::tie(inType, inputShapeAndOrders, hasShapeOf) = this->GetParam();
        std::vector<InputShape>& inputShapes = inputShapeAndOrders.first;
        std::vector<size_t>& transposeOrder = inputShapeAndOrders.second;
        targetDevice = ov::test::utils::DEVICE_CPU;
        rel_threshold = 1e-2f;
        configuration[ov::hint::inference_precision.name()] = ov::element::f32;
        if (inType == ElementType::bf16) {
            configuration[ov::hint::inference_precision.name()] = ov::element::bf16;
            rel_threshold = 0.01f;
        }
        init_input_shapes(inputShapes);
        ov::ParameterVector inputParams;
        // q, k, v
        inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[0]));
        inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[0]));
        inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[0]));
        inputParams[0]->set_friendly_name("q");
        inputParams[1]->set_friendly_name("k");
        inputParams[2]->set_friendly_name("v");
        // pastkv init
        inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[1]));
        auto var_k = std::make_shared<ov::op::util::Variable>(
            ov::op::util::VariableInfo{inputDynamicShapes[1], inType, "pastk"});
        auto pastk = std::make_shared<ov::op::v6::ReadValue>(inputParams[3], var_k);
        pastk->set_friendly_name("pastk_r");
        auto var_v = std::make_shared<ov::op::util::Variable>(
            ov::op::util::VariableInfo{inputDynamicShapes[1], inType, "pastv"});
        auto pastv = std::make_shared<ov::op::v6::ReadValue>(inputParams[3], var_v);
        pastv->set_friendly_name("pastv_r");
        std::shared_ptr<Node> pastk_shapeof, pastv_shapeof;
        if (hasShapeOf) {
            pastk_shapeof = std::make_shared<ov::op::v0::ShapeOf>(pastk);
            pastv_shapeof = std::make_shared<ov::op::v0::ShapeOf>(pastv);
        }

        // pre-SDPA transpose
        auto preOrder = ov::op::v0::Constant::create(ov::element::i32, {4}, transposeOrder);
        auto transposeQ = std::make_shared<ov::op::v1::Transpose>(inputParams[0], preOrder);

        auto concat_axis = transposeOrder[2];
        auto concatK = std::make_shared<ov::op::v0::Concat>(OutputVector{pastk, inputParams[1]}, concat_axis);
        auto concatV = std::make_shared<ov::op::v0::Concat>(OutputVector{pastv, inputParams[2]}, concat_axis);
        auto transposeK = std::make_shared<ov::op::v1::Transpose>(concatK, preOrder);
        auto transposeV = std::make_shared<ov::op::v1::Transpose>(concatV, preOrder);

        auto sdp = std::make_shared<ov::opset13::ScaledDotProductAttention>(transposeQ, transposeK, transposeV, false);
        sdp->set_friendly_name("mha");

        // post-SDPA transpose + reshape
        auto get_reshape_order = [](const ov::PartialShape& qkv_shape,
                                    const std::vector<size_t>& transposeOrder) -> std::vector<size_t> {
            assert(transposeOrder.size() == 4);
            auto H = qkv_shape[transposeOrder[1]].get_length();
            auto S = qkv_shape[transposeOrder[3]].get_length();
            return std::vector<size_t>{0, 0, static_cast<size_t>(H * S)};
        };
        const auto reshapeOrder = get_reshape_order(inputDynamicShapes[0], transposeOrder);

        auto postOrder =
            ov::op::v0::Constant::create(ov::element::i32, {4}, std::vector<size_t>{0, 2, 1, 3});  // BHLS -> BLHS
        auto transposeSDP = std::make_shared<ov::op::v1::Transpose>(sdp, postOrder);

        auto constReshape = ov::op::v0::Constant::create(ov::element::i32, {3}, reshapeOrder);
        auto reshapeSDP = std::make_shared<ov::op::v1::Reshape>(transposeSDP, constReshape, true);  // BLHS -> B,L,HxS

        auto add = std::make_shared<ov::op::v1::Add>(reshapeSDP, op::v0::Constant::create(inType, {1}, {1.0f}));
        auto pastk_assign = std::make_shared<ov::op::v6::Assign>(concatK, var_k);
        auto pastv_assign = std::make_shared<ov::op::v6::Assign>(concatV, var_v);
        pastk_assign->set_friendly_name("pastk_w");
        pastv_assign->set_friendly_name("pastv_w");

        ov::OutputVector results{add};
        if (hasShapeOf) {
            results.push_back(pastk_shapeof);
            results.push_back(pastv_shapeof);
        }
        SinkVector sinks{pastk_assign, pastv_assign};
        function = std::make_shared<Function>(results, sinks, inputParams, "ConcatTransposeSDP");
        targetDevice = ov::test::utils::DEVICE_CPU;

        functionRefs = function->clone();
        pass::Manager manager;
        // decompose ScaledDotProductAttention for the reference model
        manager.register_pass<ov::pass::ScaledDotProductAttentionDecomposition>();
        manager.run_passes(functionRefs);
    }
    void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override {
        std::vector<ov::Shape> shapes(4);
        shapes[0] = targetInputStaticShapes[0];
        shapes[1] = targetInputStaticShapes[0];
        shapes[2] = targetInputStaticShapes[0];
        shapes[3] = targetInputStaticShapes[1];
        SubgraphBaseTest::generate_inputs(shapes);
    }
    template <typename IT, typename T>
    void strided_iota(IT first, size_t n, T value, T stride) {
        // fill [first, first + n) with value, value + stride, value + 2*stride, ...
        for (size_t i = 0; i < n; i++) {
            *first++ = value;
            value += stride;
        }
    }
    void generate(int idx, const std::vector<ov::Shape>& targetInputStaticShapes) {
        inputs.clear();
        auto create_input = [this](std::shared_ptr<op::v0::Parameter> param, ov::Shape shape, float val) {
            if (param->get_element_type() == element::f32) {
                ov::Tensor t{ov::element::f32, shape};
                strided_iota(static_cast<float*>(t.data()), t.get_size(), val, 0.1f);
                inputs.insert({param, t});
            } else {
                ov::Tensor t{ov::element::bf16, shape};
                strided_iota(static_cast<ov::bfloat16*>(t.data()), t.get_size(), val, 0.1f);
                inputs.insert({param, t});
            }
        };
        // q, k, v
        create_input(function->get_parameters()[0], targetInputStaticShapes[0], idx + 1.0f);
        create_input(function->get_parameters()[1], targetInputStaticShapes[0], idx + 2.0f);
        create_input(function->get_parameters()[2], targetInputStaticShapes[0], idx + 3.0f);
        create_input(function->get_parameters()[3], targetInputStaticShapes[1], idx + 4.0f);
    }
    void prepare() {
        compile_model();
        inferRequest = compiledModel.create_infer_request();
        ASSERT_TRUE(inferRequest);
    }
    void reset() {
        for (auto&& state : inferRequest.query_state()) {
            state.reset();
        }
    }
    std::vector<ov::Tensor> run_test(std::shared_ptr<ov::Model> model) {
        function = model;
        prepare();
        std::vector<ov::Tensor> outputs;
        int idx = 0;
        for (auto&& shapes : targetStaticShapes) {
            generate(idx++, shapes);
            for (const auto& input : inputs) {
                inferRequest.set_tensor(input.first, input.second);
            }
            inferRequest.infer();
            auto outputTensor = inferRequest.get_output_tensor(0);
            ov::Tensor copy{outputTensor.get_element_type(), outputTensor.get_shape()};
            outputTensor.copy_to(copy);
            outputs.push_back(copy);
        }
        reset();

        return outputs;
    }
};

TEST_P(ConcatSDPTransposeTest, CompareWithRefs) {
    auto actualOutputs = run_test(function);
    CheckNumberOfNodesWithType(compiledModel, "ScaledDotProductAttention", 1);
    CheckNumberOfNodesWithType(compiledModel, "Concatenation", 0);
    CheckNumberOfNodesWithType(compiledModel, "Reorder", 0);
    CheckNumberOfNodesWithType(compiledModel, "Transpose", 1);
    auto expectedOutputs = run_test(functionRefs);
    CheckNumberOfNodesWithType(compiledModel, "ScaledDotProductAttention", 0);
    for (size_t i = 0; i < actualOutputs.size(); i++) {
        ov::test::utils::compare(expectedOutputs[i], actualOutputs[i], abs_threshold, rel_threshold);
    }
}

namespace {
const std::vector<InputShapeAndTransposeOrder> inputShapeAndReorders = {
    {
        // inputShapes Llama
        {
            // B, H, L1, S
            {{1, 8, -1, 64}, {{1, 8, 10, 64}, {1, 8, 1, 64}, {1, 8, 1, 64}, {1, 8, 20, 64}, {1, 8, 1, 64}}},
            // B, H, L0, S
            {{1, 8, -1, 64}, {{1, 8, 0, 64}, {1, 8, 10, 64}, {1, 8, 11, 64}, {1, 8, 12, 64}, {1, 8, 32, 64}}},
        },
        // transposeOrder
        {0, 1, 2, 3}},
    {// inputShapes QWen
     {
         // B, L1, H, S
         {{1, -1, 8, 64}, {{1, 10, 8, 64}, {1, 1, 8, 64}, {1, 1, 8, 64}, {1, 20, 8, 64}, {1, 1, 8, 64}}},
         // B, L0, H, S
         {{1, -1, 8, 64}, {{1, 0, 8, 64}, {1, 10, 8, 64}, {1, 11, 8, 64}, {1, 12, 8, 64}, {1, 32, 8, 64}}},
     },
     // transposeOrder
     {0, 2, 1, 3}},
    {// inputShapes ChatGLM
     {
         // L1, B, H, S
         {{-1, 1, 8, 64}, {{10, 1, 8, 64}, {1, 1, 8, 64}, {1, 1, 8, 64}, {20, 1, 8, 64}, {1, 1, 8, 64}}},
         // L0, B, H, S
         {{-1, 1, 8, 64}, {{0, 1, 8, 64}, {10, 1, 8, 64}, {11, 1, 8, 64}, {12, 1, 8, 64}, {32, 1, 8, 64}}},
     },
     // transposeOrder
     {1, 2, 0, 3}},
};

INSTANTIATE_TEST_SUITE_P(smoke_ConcatSDPTransposeTest,
                         ConcatSDPTransposeTest,
                         ::testing::Combine(::testing::Values(ElementType::f32),
                                            ::testing::ValuesIn(inputShapeAndReorders),
                                            ::testing::Values(true, false)),
                         ConcatSDPTransposeTest::getTestCaseName);

}  // namespace
}  // namespace SubgraphTestsDefinitions