[CPU] SDPA supports multi-query and different input layout (#21513)

Zhang Yi 2023-12-16 07:38:04 +08:00 committed by GitHub
parent eff9ba76ba
commit 17fb201433
11 changed files with 898 additions and 42 deletions

View File

@@ -139,7 +139,11 @@ void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query,
auto q_len = query.size(2);
auto S = query.size(3);
auto kv_len = present_key.size(2);
auto h_group_num = present_key.size(1);
size_t h_each_group_len = 1;
if (h_group_num != H) {
h_each_group_len = H / h_group_num;
}
if (d_scale == 0.0f)
d_scale = 1.0f / sqrt(S);
@@ -149,20 +153,21 @@ void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query,
bool is_abcd = present_key.stride(1) >= present_key.stride(2);
size_t dim0 = is_abcd ? B : kv_len;
size_t dim1 = is_abcd ? H : B;
size_t dim2 = is_abcd ? kv_len : H;
size_t dim1 = is_abcd ? h_group_num : B;
size_t dim2 = is_abcd ? kv_len : h_group_num;
parallel_for3d(dim0, dim1, dim2, [&](size_t d0, size_t d1, size_t d2) {
size_t b = is_abcd ? d0 : d1;
size_t h = is_abcd ? d1 : d2;
size_t h_group = is_abcd ? d1 : d2;
size_t pk = is_abcd ? d2 : d0;
// which batch item should be used at position pk?
auto b_kv = beams ? beams.at<int32_t>({b, pk}) : b;
for (size_t pq = 0; pq < q_len; pq++) {
buf_attn_w.at<float>({b, h, pq, pk}) = dot_product(&query.at<T>({b, h, pq, 0}),
&present_key.at<T2>({b_kv, h, pk, 0}, true),
S);
for (size_t h = h_group * h_each_group_len; h < (h_group + 1) * h_each_group_len; h++) {
buf_attn_w.at<float>({b, h, pq, pk}) =
dot_product(&query.at<T>({b, h, pq, 0}), &present_key.at<T2>({b_kv, h_group, pk, 0}, true), S);
}
}
});
@@ -190,29 +195,31 @@ void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query,
// buf_attn_w {B, H, q_len, kv_len}
parallel_nt_static(nthr, [&](const size_t ithr, const size_t nthr) {
size_t start{0}, end{0};
splitter(B * H * kv_len, nthr, ithr, start, end);
splitter(B * h_group_num * kv_len, nthr, ithr, start, end);
memset(&buf_attn_score.at<float>({ithr, 0, 0, 0, 0}), 0, buf_attn_score.stride(0) * sizeof(float));
size_t b, h, pv;
size_t b, h_group, pv;
if (start < end) {
if (is_abcd)
parallel_it_init(start, b, B, h, H, pv, kv_len);
parallel_it_init(start, b, B, h_group, h_group_num, pv, kv_len);
else
parallel_it_init(start, pv, kv_len, b, B, h, H);
parallel_it_init(start, pv, kv_len, b, B, h_group, h_group_num);
for (size_t iwork = start; iwork < end; ++iwork) {
auto b_kv = beams ? beams.at<int32_t>({b, pv}) : b;
auto* v = &present_value.at<T2>({b_kv, h, pv, 0}, true);
auto* v = &present_value.at<T2>({b_kv, h_group, pv, 0}, true);
for (size_t pq = 0; pq < q_len; pq++) {
attn_acc_value(&buf_attn_score.at<float>({ithr, b, pq, h, 0}),
buf_attn_w.at<float>({b, h, pq, pv}),
v,
S);
for (size_t h = h_group * h_each_group_len; h < (h_group + 1) * h_each_group_len; h++) {
attn_acc_value(&buf_attn_score.at<float>({ithr, b, pq, h, 0}),
buf_attn_w.at<float>({b, h, pq, pv}),
v,
S);
}
}
if (is_abcd)
parallel_it_step(b, B, h, H, pv, kv_len);
parallel_it_step(b, B, h_group, h_group_num, pv, kv_len);
else
parallel_it_step(pv, kv_len, b, B, h, H);
parallel_it_step(pv, kv_len, b, B, h_group, h_group_num);
}
}
});
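For intuition, here is a minimal standalone sketch of the grouped/multi-query head mapping the hunks above implement (illustrative only; mq_logits is a hypothetical helper, not the plugin kernel): H query heads share h_group_num KV heads, so each KV head serves h_each_group_len = H / h_group_num consecutive query heads.

#include <cstddef>
#include <vector>

// query: [H][S] rows, key: [h_group_num][S] shared KV rows.
// Returns one attention logit per query head for a single (batch, position) pair.
std::vector<float> mq_logits(const std::vector<std::vector<float>>& query,
                             const std::vector<std::vector<float>>& key) {
    const size_t H = query.size(), h_group_num = key.size();
    const size_t h_each_group_len = H / h_group_num;  // 1 when H == h_group_num; assumes H % h_group_num == 0
    std::vector<float> out(H, 0.0f);
    for (size_t h_group = 0; h_group < h_group_num; h_group++) {
        // every query head in this group reads the same key row h_group
        for (size_t h = h_group * h_each_group_len; h < (h_group + 1) * h_each_group_len; h++) {
            for (size_t i = 0; i < key[h_group].size(); i++)
                out[h] += query[h][i] * key[h_group][i];  // dot_product(query[h], key[h_group])
        }
    }
    return out;
}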

View File

@@ -548,6 +548,7 @@ void MemoryInputSDPA::initSupportedPrimitiveDescriptors() {
// Since this is a very specialized implementation, let's mimic SDPA precision and set cabd layout
precision = SDPA->getOriginalInputPrecisionAtPort(childPort);
// Just use a placeholder here; the actual layout is obtained in initOptimalPrimitiveDescriptor
ArbitraryOrderDescCreator cabdDescCreator({2, 0, 1, 3});
PortConfig outPortConfig;
@@ -573,7 +574,6 @@ void MemoryInputSDPA::initOptimalPrimitiveDescriptor() {
"failed initOptimalPrimitiveDescriptor() call, preferable primitive descriptor is not set");
const auto& childConfig = childPd->getConfig();
auto childPrecision = childConfig.inConfs[childEdge->getOutputNum()].getMemDesc()->getPrecision();
auto selectedPd = getSelectedPrimitiveDescriptor();
OPENVINO_ASSERT(selectedPd,
@@ -582,8 +582,9 @@
" failed initOptimalPrimitiveDescriptor() call, preferable primitive descriptor is not set");
auto config = selectedPd->getConfig();
auto memDesc = config.outConfs.front().getMemDesc();
auto newMemDesc = memDesc->cloneWithNewPrecision(childPrecision);
// The physical layout varies across models, e.g. [LBHS] for ChatGLM, [BHLS] for Llama
// The SDPA node knows the details, so trust the layout config provided by SDPA
auto newMemDesc = childConfig.inConfs.back().getMemDesc();
config.outConfs.front().setMemDesc(newMemDesc);
// bypass any checks; we enforce the child descriptor precision
selectedPd->setConfig(config);

View File

@@ -512,10 +512,16 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
v_input.assert_dims({B, 0, L1, S}, true);
auto past_k_idx = inputs.size() - 2;
auto past_k_mem = inputs[past_k_idx + 0];
L0 = past_k_mem->getStaticDims()[2];
const auto& permute_axes = config.config.permute_axes;
L0 = permute_axes.empty() ? past_k_mem->getStaticDims()[2] : past_k_mem->getStaticDims()[permute_axes[2]];
// [B, H, L0, S]
past_k_output.reset(outputs[1]);
past_v_output.reset(outputs[2]);
if (!permute_axes.empty()) {
// [L, B, H, S] -> [B, H, L, S]
past_k_output = past_k_output.permute(permute_axes);
past_v_output = past_v_output.permute(permute_axes);
}
attn_memcpy(k_input, v_input, past_k_output.slice(2, L0, L0 + L1), past_v_output.slice(2, L0, L0 + L1));
if (!config.is_concat_inplaced) {
PlainTensor past_k_input, past_v_input;
@@ -560,12 +566,18 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
}
// q: [B, H, L1, S]
const auto & permute_axes = config.config.permute_axes;
PlainTensor present_key, present_value;
if (!permute_axes.empty()) {
q_input = q_input.permute(permute_axes);
k_input = k_input.permute(permute_axes);
v_input = v_input.permute(permute_axes);
}
B = q_input.size(0);
H = q_input.size(1);
L1 = q_input.size(2);
S = q_input.size(-1);
PlainTensor present_key, present_value;
concat_pastkv(inputs, outputs, k_input, v_input, present_key, present_value);
ov::intel_cpu::PlainTensor output_emb(outputs[0]);
@@ -634,9 +646,20 @@ ScaledDotProductAttention::ScaledDotProductAttention(const std::shared_ptr<ngrap
void ScaledDotProductAttention::initSupportedPrimitiveDescriptors() {
if (!supportedPrimitiveDescriptors.empty())
return;
auto rtPrecision = getOriginalInputPrecisionAtPort(0);
rtPrecision = getOriginalInputPrecisionAtPort(0);
auto orginSDPInputNumber = getOriginalInputsNumber() - (m_config.config.fuse_concat ? 2 : 0);
size_t H_idx = 1;
if (!m_config.config.permute_axes.empty()) {
H_idx = m_config.config.permute_axes[1];
}
const auto& qDims = getInputShapeAtPort(0).getDims();
const auto& kDims = getInputShapeAtPort(1).getDims();
// if multi-query, enforce fp32. TODO: support BF16
if (qDims[H_idx] != kDims[H_idx]) {
rtPrecision = ov::element::f32;
}
bool enableKVCacheFP16 = m_config.config.fuse_concat && mayiuse(cpu_isa_t::avx2) && rtPrecision != ov::element::bf16;
auto kvCachePrecision = enableKVCacheFP16 ? ov::element::f16 : rtPrecision;
@@ -669,17 +692,25 @@ void ScaledDotProductAttention::initSupportedPrimitiveDescriptors() {
}
if (m_config.config.fuse_concat) {
ArbitraryOrderDescCreator cabdDescCreator({2, 0, 1, 3});
config.inConfs[orginSDPInputNumber + 0].setMemDesc(cabdDescCreator.createSharedDesc(
ArbitraryOrderDescCreator layoutDescCreator({2, 0, 1, 3});
const auto& permute_axes = m_config.config.permute_axes;
if (!permute_axes.empty()) {
// [L,B,H,S] -> permute[1,2,0,3] -> [B,H,L,S]
// The actual index of B is permute[0], H is permute[1], L is permute[2], S is permute[3]
layoutDescCreator = ArbitraryOrderDescCreator({static_cast<size_t>(permute_axes[2]),
static_cast<size_t>(permute_axes[0]),
static_cast<size_t>(permute_axes[1]),
static_cast<size_t>(permute_axes[3])});
}
config.inConfs[orginSDPInputNumber + 0].setMemDesc(layoutDescCreator.createSharedDesc(
kvCachePrecision, getInputShapeAtPort(orginSDPInputNumber + 0)));
config.inConfs[orginSDPInputNumber + 1].setMemDesc(cabdDescCreator.createSharedDesc(
config.inConfs[orginSDPInputNumber + 1].setMemDesc(layoutDescCreator.createSharedDesc(
kvCachePrecision, getInputShapeAtPort(orginSDPInputNumber + 1)));
config.outConfs[1].setMemDesc(cabdDescCreator.createSharedDesc(
config.outConfs[1].setMemDesc(layoutDescCreator.createSharedDesc(
kvCachePrecision, getOutputShapeAtPort(1)));
config.outConfs[1].inPlace(orginSDPInputNumber + 0);
config.outConfs[2].setMemDesc(cabdDescCreator.createSharedDesc(
config.outConfs[2].setMemDesc(layoutDescCreator.createSharedDesc(
kvCachePrecision, getOutputShapeAtPort(2)));
config.outConfs[2].inPlace(orginSDPInputNumber + 1);
}
@@ -712,7 +743,6 @@ void ScaledDotProductAttention::createPrimitive() {
m_config.is_concat_inplaced = desc->getConfig().outConfs[1].inPlace() >= 0;
}
auto rtPrecision = getOriginalInputPrecisionAtPort(0);
if (rtPrecision == ov::element::bf16) {
m_executor = std::make_shared<AttentionExecutor<KT_ONEDNN, ov::bfloat16>>(m_config);
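To make the layout bookkeeping above concrete, here is a sketch of how the KV-cache descriptor order is derived from permute_axes (kv_cache_order is an assumed helper name, not the plugin API; the logic mirrors initSupportedPrimitiveDescriptors above):

#include <cstddef>
#include <vector>

// permute_axes maps the input layout to [B, H, L, S]; the returned order lists
// the physical axes in L, B, H, S sequence (the default cabd layout is {2, 0, 1, 3}).
std::vector<size_t> kv_cache_order(const std::vector<size_t>& permute_axes) {
    if (permute_axes.empty())
        return {2, 0, 1, 3};  // input is already [B, H, L, S]
    return {permute_axes[2], permute_axes[0], permute_axes[1], permute_axes[3]};
}
// e.g. ChatGLM [L,B,H,S] with permute_axes = {1, 2, 0, 3} yields {0, 1, 2, 3}:
// the cache stays L-major, matching how the state physically arrives.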

View File

@@ -54,6 +54,7 @@ private:
Config m_config;
std::shared_ptr<Executor> m_executor;
template <KernelTypes KType, typename T> struct AttentionExecutor;
ov::element::Type rtPrecision;
};
} // namespace node

View File

@@ -29,18 +29,43 @@ void ov::intel_cpu::ScaledDotProductAttentionWithKVCache::validate_and_infer_typ
// [B, H, L0, S]
auto past_kv_ps = get_input_partial_shape(input_num - 1);
auto output_logits = q_ps;
NODE_VALIDATION_CHECK(this, m_config.output_BLHxS == false);
NODE_VALIDATION_CHECK(this, q_ps.size() >= 3);
// permute_axes maps the original layout to [B, H, L, S]
const auto& permute_axes = this->m_config.permute_axes;
if (past_kv_ps.rank().is_static()) {
const size_t length_index = permute_axes.empty() ? q_ps.size() - 2 : permute_axes[permute_axes.size() - 2];
const size_t head_num_index = permute_axes.empty() ? q_ps.size() - 3 : permute_axes[permute_axes.size() - 3];
NODE_VALIDATION_CHECK(this, q_ps.size() == past_kv_ps.size());
for (size_t i = 0; i < q_ps.size(); i++) {
if (i == q_ps.size() - 2)
if (i == head_num_index) {
if (q_ps[i].is_static() && past_kv_ps[i].is_static()) {
NODE_VALIDATION_CHECK(this,
q_ps[i].get_length() % past_kv_ps[i].get_length() == 0,
"shape not compatiable at index ",
i);
}
} else if (i == length_index) {
continue;
NODE_VALIDATION_CHECK(this, q_ps[i].compatible(past_kv_ps[i]));
} else {
NODE_VALIDATION_CHECK(this,
q_ps[i].compatible(past_kv_ps[i]),
"shape not compatiable at index ",
i);
}
}
past_kv_ps[q_ps.size() - 2] += q_ps[q_ps.size() - 2];
past_kv_ps[length_index] += q_ps[length_index];
}
set_output_type(0, get_input_element_type(0), q_ps);
if (!permute_axes.empty()) {
if (q_ps.rank().is_static()) {
// q_ps needs permute to BHLS
for (size_t i = 0; i < q_ps.size(); i++) {
output_logits[i] = q_ps[permute_axes[i]];
}
}
}
set_output_type(0, get_input_element_type(0), output_logits);
set_output_type(1, get_input_element_type(input_num - 1), past_kv_ps);
set_output_type(2, get_input_element_type(input_num - 1), past_kv_ps);
}
@@ -52,6 +77,7 @@ bool ov::intel_cpu::ScaledDotProductAttentionWithKVCache::visit_attributes(ov::A
visitor.on_attribute("fuse_causal_attn", m_config.fuse_causal_attn);
visitor.on_attribute("is_causal", m_config.is_causal);
visitor.on_attribute("fuse_concat", m_config.fuse_concat);
visitor.on_attribute("permute_axes", m_config.permute_axes);
visitor.finish_structure();
return true;
}
}
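A worked example of the shape inference above, with illustrative ChatGLM-style values q_ps = [L1=1, B=1, H=32, S=128], past_kv_ps = [L0=10, B=1, H_kv=2, S=128] and permute_axes = {1, 2, 0, 3}:
- length_index = permute_axes[2] = 0 and head_num_index = permute_axes[1] = 2
- head check: q_ps[2] = 32 divides evenly by past_kv_ps[2] = 2 (multi-query: 16 query heads per KV head)
- past_kv_ps[0] += q_ps[0], so the cache length grows from 10 to 11
- output_logits[i] = q_ps[permute_axes[i]] gives [1, 32, 1, 128], i.e. [B, H, L1, S]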

View File

@@ -21,11 +21,13 @@ public:
ScaledDotProductAttentionWithKVCache() = default;
struct Config {
bool output_BLHxS = false; // true implies that output is [B,L,H*S]
bool fuse_causal_attn = false; // fuse causal mask and attn mask into attn_mask
bool is_causal = false; // apply causal mask internally
bool fuse_concat = false; // fuse (concat->sdp) ==> sdp
std::vector<size_t> permute_axes; // non-empty means the input is transposed; the permuted output is [B,H,L,S]
// e.g. [L,B,H,S] -> permute[1, 2, 0, 3] -> [B, H, L, S]
};
ScaledDotProductAttentionWithKVCache(const OutputVector& args, const Config& cfg);
@@ -47,4 +49,4 @@ private:
};
} // namespace intel_cpu
} // namespace ov

View File

@@ -0,0 +1,176 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "stateful_transpose_sdpa_fusion.hpp"
#include <cstdint>
#include <limits>
#include <openvino/core/rt_info.hpp>
#include <openvino/opsets/opset1.hpp>
#include <openvino/opsets/opset13.hpp>
#include <openvino/opsets/opset6.hpp>
#include <openvino/opsets/opset8.hpp>
#include <openvino/pass/pattern/op/or.hpp>
#include <openvino/pass/pattern/op/wrap_type.hpp>
#include <transformations/utils/utils.hpp>
#include "itt.hpp"
#include "ov_ops/type_relaxed.hpp"
#include "transformations/cpu_opset/common/op/sdpa.hpp"
namespace ov {
namespace intel_cpu {
StatefulTransposeSDPAFusion::StatefulTransposeSDPAFusion() {
MATCHER_SCOPE(StatefulTransposeSDPAFusion);
using namespace ov::pass::pattern;
auto past_k = wrap_type<opset6::ReadValue>();
auto past_v = wrap_type<opset6::ReadValue>();
auto convert_past_k = wrap_type<opset1::Convert>({past_k});
auto convert_past_v = wrap_type<opset1::Convert>({past_v});
auto concat_input_k = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{past_k, convert_past_k});
auto concat_input_v = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{past_v, convert_past_v});
auto concat_k = wrap_type<opset6::Concat>({concat_input_k, any_input()});
auto concat_v = wrap_type<opset6::Concat>({concat_input_v, any_input()});
// multi-query branch
auto reshape_k = wrap_type<opset6::Reshape>({concat_k, any_input()});
auto reshape_v = wrap_type<opset6::Reshape>({concat_v, any_input()});
auto constant_k = wrap_type<opset6::Constant>();
auto constant_v = wrap_type<opset6::Constant>();
auto multiply_k = wrap_type<opset6::Multiply>({reshape_k, constant_k});
auto multiply_v = wrap_type<opset6::Multiply>({reshape_v, constant_v});
auto reshape1_k = wrap_type<opset6::Reshape>({multiply_k, any_input()});
auto reshape1_v = wrap_type<opset6::Reshape>({multiply_v, any_input()});
auto transpose_k_input = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{reshape1_k, concat_k});
auto transpose_v_input = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{reshape1_v, concat_v});
auto order_k = wrap_type<opset6::Constant>();
auto order_v = wrap_type<opset6::Constant>();
auto transpose_k = wrap_type<opset6::Transpose>({transpose_k_input, order_k});
auto transpose_v = wrap_type<opset6::Transpose>({transpose_v_input, order_v});
auto order_q = wrap_type<opset6::Constant>();
auto q_input = any_input();
auto transpose_q = wrap_type<opset6::Transpose>({q_input, order_q});
auto sdp0 = wrap_type<opset13::ScaledDotProductAttention>({transpose_q, transpose_k, transpose_v});
auto sdp1 = wrap_type<opset13::ScaledDotProductAttention>({transpose_q, transpose_k, transpose_v, any_input()});
auto sdp2 = wrap_type<opset13::ScaledDotProductAttention>({transpose_q, transpose_k, transpose_v, any_input(), any_input()});
auto sdp = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{sdp0, sdp1, sdp2});
ov::matcher_pass_callback callback = [=](Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
auto root = m.get_match_root();
auto find_assign = [&](const ov::Output<ov::Node>& out, opset6::Assign*& assign, opset1::Convert*& cvt) {
auto present_to = out.get_target_inputs();
if (present_to.size() != 2)
return;
for (auto& to : present_to) {
auto to_node = to.get_node();
if (auto convert = dynamic_cast<opset1::Convert*>(to_node)) {
auto cvt_targets = convert->get_output_target_inputs(0);
if (cvt_targets.size() == 1) {
to_node = cvt_targets.begin()->get_node();
cvt = convert;
}
}
assign = dynamic_cast<opset6::Assign*>(to_node);
if (assign)
return;
}
};
std::shared_ptr<opset1::Convert> read_cvt_k_node, read_cvt_v_node;
const auto sdp_node = ov::as_type_ptr<opset13::ScaledDotProductAttention>(root);
const auto past_k_node = ov::as_type_ptr<opset6::ReadValue>(pattern_map.at(past_k).get_node_shared_ptr());
const auto past_v_node = ov::as_type_ptr<opset6::ReadValue>(pattern_map.at(past_v).get_node_shared_ptr());
const auto concat_k_node = ov::as_type_ptr<opset6::Concat>(pattern_map.at(concat_k).get_node_shared_ptr());
const auto concat_v_node = ov::as_type_ptr<opset6::Concat>(pattern_map.at(concat_v).get_node_shared_ptr());
if (pattern_map.count(convert_past_k)) {
read_cvt_k_node = ov::as_type_ptr<opset1::Convert>(pattern_map.at(convert_past_k).get_node_shared_ptr());
read_cvt_v_node = ov::as_type_ptr<opset1::Convert>(pattern_map.at(convert_past_v).get_node_shared_ptr());
}
// check broadcast arg has all ones
auto check_bcst = [&](const std::shared_ptr<Node>& ptr) {
const auto constant_node = ov::as_type_ptr<opset6::Constant>(ptr);
const auto& bcst_arg = constant_node->cast_vector<float>();
return std::all_of(bcst_arg.begin(), bcst_arg.end(), [](float val) {
return val == 1.0f;
});
};
if (pattern_map.count(constant_k)) {
if (!check_bcst(pattern_map.at(constant_k).get_node_shared_ptr()))
return false;
}
if (pattern_map.count(constant_v)) {
if (!check_bcst(pattern_map.at(constant_v).get_node_shared_ptr()))
return false;
}
opset6::Assign* assign_k_node = nullptr, *assign_v_node = nullptr;
opset1::Convert* assign_cvt_k_node = nullptr, *assign_cvt_v_node = nullptr;
find_assign(concat_k_node, assign_k_node, assign_cvt_k_node);
if (!assign_k_node)
return false;
if (past_k_node->get_variable_id() != assign_k_node->get_variable_id())
return false;
find_assign(concat_v_node, assign_v_node, assign_cvt_v_node);
if (!assign_v_node)
return false;
if (past_v_node->get_variable_id() != assign_v_node->get_variable_id())
return false;
auto args = sdp_node->input_values();
args[0] = pattern_map.at(q_input).get_node_shared_ptr()->output(0);
args[1] = concat_k_node->input_value(1);
args[2] = concat_v_node->input_value(1);
args.push_back(read_cvt_k_node ? read_cvt_k_node->output(0) : past_k_node->output(0));
args.push_back(read_cvt_v_node ? read_cvt_v_node->output(0) : past_v_node->output(0));
ov::intel_cpu::ScaledDotProductAttentionWithKVCache::Config config;
const auto order_q_node = ov::as_type_ptr<opset6::Constant>(pattern_map.at(order_q).get_node_shared_ptr());
const auto order_k_node = ov::as_type_ptr<opset6::Constant>(pattern_map.at(order_k).get_node_shared_ptr());
const auto order_v_node = ov::as_type_ptr<opset6::Constant>(pattern_map.at(order_v).get_node_shared_ptr());
const auto& permute_q = order_q_node->cast_vector<int32_t>();
const auto& permute_k = order_k_node->cast_vector<int32_t>();
const auto& permute_v = order_v_node->cast_vector<int32_t>();
if (permute_q != permute_k || permute_q != permute_v) {
return false;
}
config.is_causal = sdp_node->get_causal();
config.fuse_concat = true;
config.permute_axes.resize(permute_q.size());
for (size_t i = 0; i < permute_q.size(); i++) {
config.permute_axes[i] = static_cast<size_t>(permute_q[i]);
}
auto& old_node = sdp_node;
auto new_node = std::make_shared<ov::intel_cpu::ScaledDotProductAttentionWithKVCache>(args, config);
new_node->set_friendly_name(old_node->get_friendly_name());
ov::replace_node(old_node, {new_node->output(0)});
if (assign_cvt_k_node)
assign_cvt_k_node->set_arguments({new_node->output(1)});
else
assign_k_node->set_arguments({new_node->output(1)});
if (assign_cvt_v_node)
assign_cvt_v_node->set_arguments({new_node->output(2)});
else
assign_v_node->set_arguments({new_node->output(2)});
return true;
};
auto m = std::make_shared<ov::pass::pattern::Matcher>(sdp, matcher_name);
this->register_matcher(m, callback);
}
} // namespace intel_cpu
} // namespace ov
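For reference, the new pass is an ordinary MatcherPass and is applied through a pass manager, mirroring its registration in Transformations::PostLpt later in this commit:

ov::pass::Manager manager;
manager.register_pass<ov::intel_cpu::StatefulTransposeSDPAFusion>();
manager.run_passes(model);  // model: std::shared_ptr<ov::Model>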

View File

@@ -0,0 +1,18 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <ngraph/pass/graph_rewrite.hpp>
namespace ov {
namespace intel_cpu {
class StatefulTransposeSDPAFusion : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("StatefulTransposeSDPAFusion", "0");
StatefulTransposeSDPAFusion();
};
} // namespace intel_cpu
} // namespace ov

View File

@@ -114,6 +114,7 @@
#include "transformations/cpu_opset/common/pass/swap_convert_transpose.hpp"
#include "transformations/cpu_opset/common/pass/rope_fusion.hpp"
#include "transformations/cpu_opset/common/pass/stateful_sdpa_fusion.hpp"
#include "transformations/cpu_opset/common/pass/stateful_transpose_sdpa_fusion.hpp"
// Snippets
#include "snippets/pass/tokenization.hpp"
@@ -661,6 +662,7 @@ void Transformations::PostLpt() {
CPU_REGISTER_PASS_X64(postLPTPassManager, RoPEFusion);
CPU_REGISTER_PASS_X64(postLPTPassManager, StatefulSDPAFusion);
CPU_REGISTER_PASS_X64(postLPTPassManager, StatefulTransposeSDPAFusion);
postLPTPassManager.run_passes(model);
}

View File

@@ -0,0 +1,297 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <openvino/opsets/opset13.hpp>
#include <transformations/op_conversions/scaled_dot_product_attention_decomposition.hpp>
#include "ov_models/builders.hpp"
#include "ov_models/utils/ov_helpers.hpp"
#include "shared_test_classes/base/layer_test_utils.hpp"
#include "shared_test_classes/base/ov_subgraph.hpp"
#include "test_utils/cpu_test_utils.hpp"
#include "common_test_utils/include/common_test_utils/ov_tensor_utils.hpp"
using namespace ov::test;
using namespace ngraph;
using namespace CPUTestUtils;
using namespace InferenceEngine;
namespace SubgraphTestsDefinitions {
using InputShapeAndTransposeOrder = std::pair<std::vector<InputShape>, std::vector<size_t>>;
using ConcatMultiQuerySDPParams = std::tuple<ElementType,
InputShapeAndTransposeOrder,
bool // has ShapeOf
>;
// Subgraph:
/* Parameter
* |
* Parameter ReadValue | ReadValue Parameter
* \ / | \ /
* \ / | \ /
* Concat Transpose Concat
* / \ | / \
* / \ | / \
* / MultiQuery | MultiQuery \
* / \ | / \
* / Transpose | Transpose \
* / \ | / \
* Assign ScaledDotProductAttention Assign
* |
* Transpose
* |
* Reshape
* |
* Add
* |
* Result
*/
class ConcatMultiQuerySDPTest : public testing::WithParamInterface<ConcatMultiQuerySDPParams>,
virtual public ov::test::SubgraphBaseTest,
public CPUTestsBase {
public:
static std::string getTestCaseName(const testing::TestParamInfo<ConcatMultiQuerySDPParams>& obj) {
ElementType inType;
InputShapeAndTransposeOrder inputShapeAndOrders;
bool hasShapeof;
std::tie(inType, inputShapeAndOrders, hasShapeof) = obj.param;
std::ostringstream result;
std::vector<InputShape>& inputShapes = inputShapeAndOrders.first;
std::vector<size_t>& transposeOrder = inputShapeAndOrders.second;
result << "IS=";
for (const auto& shape : inputShapes) {
result << ov::test::utils::partialShape2str({shape.first}) << "_";
}
result << "TS=";
for (const auto& shape : inputShapes) {
result << "(";
if (!shape.second.empty()) {
for (const auto& itr : shape.second) {
result << ov::test::utils::vec2str(itr);
}
}
result << ")_";
}
result << "Prc=" << inType << "_";
result << "HasShapeOf=" << hasShapeof;
result << "TransposeOrder=";
result << "(";
for (const auto& itr : transposeOrder) {
result << itr << ",";
}
result << ")";
return result.str();
}
void SetUp() override {
InputShapeAndTransposeOrder inputShapeAndOrders;
bool hasShapeOf;
ElementType inType;
std::tie(inType, inputShapeAndOrders, hasShapeOf) = this->GetParam();
std::vector<InputShape>& inputShapes = inputShapeAndOrders.first;
std::vector<size_t>& transposeOrder = inputShapeAndOrders.second;
targetDevice = ov::test::utils::DEVICE_CPU;
rel_threshold = 1e-2f;
configuration[ov::hint::inference_precision.name()] = ov::element::f32;
if (inType == ElementType::bf16) {
configuration[ov::hint::inference_precision.name()] = ov::element::bf16;
rel_threshold = 0.01f;
}
init_input_shapes(inputShapes);
ov::ParameterVector inputParams;
// q,k,v
inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[0]));
inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[1]));
inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[1]));
inputParams[0]->set_friendly_name("q");
inputParams[1]->set_friendly_name("k");
inputParams[2]->set_friendly_name("v");
// pastkv init
inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[2]));
auto var_k = std::make_shared<ov::op::util::Variable>(
ov::op::util::VariableInfo{inputDynamicShapes[2], inType, "pastk"});
auto pastk = std::make_shared<ov::op::v6::ReadValue>(inputParams[3], var_k);
pastk->set_friendly_name("pastk_r");
auto var_v = std::make_shared<ov::op::util::Variable>(
ov::op::util::VariableInfo{inputDynamicShapes[2], inType, "pastv"});
auto pastv = std::make_shared<ov::op::v6::ReadValue>(inputParams[3], var_v);
pastv->set_friendly_name("pastv_r");
std::shared_ptr<Node> pastk_shapeof, pastv_shapeof;
if (hasShapeOf) {
pastk_shapeof = std::make_shared<ov::op::v0::ShapeOf>(pastk);
pastv_shapeof = std::make_shared<ov::op::v0::ShapeOf>(pastv);
}
// pre SDPA transpose
auto preOrder = op::v0::Constant::create(ov::element::i32, {4}, transposeOrder);
auto transposeQ = std::make_shared<ov::op::v1::Transpose>(inputParams[0], preOrder);
auto concat_axis = transposeOrder[2];
auto concatK = std::make_shared<ov::op::v0::Concat>(OutputVector{pastk, inputParams[1]}, concat_axis);
auto concatV = std::make_shared<ov::op::v0::Concat>(OutputVector{pastv, inputParams[2]}, concat_axis);
auto unsqueezeAxis = op::v0::Constant::create(ov::element::i32, {}, {-2});
auto unsqueezeK = std::make_shared<ov::op::v0::Unsqueeze>(concatK, unsqueezeAxis);
auto unsqueezeV = std::make_shared<ov::op::v0::Unsqueeze>(concatV, unsqueezeAxis);
auto targetShape = op::v0::Constant::create(inType, {1, 1, 1, 4, 1}, {1});
auto broadcastK = std::make_shared<ov::op::v1::Multiply>(unsqueezeK, targetShape);
auto broadcastV = std::make_shared<ov::op::v1::Multiply>(unsqueezeV, targetShape);
auto target4D = op::v0::Constant::create(ov::element::i32, {4}, {0, 0, 8, 64});
auto reshapeK = std::make_shared<ov::op::v1::Reshape>(broadcastK, target4D, true);
auto reshapeV = std::make_shared<ov::op::v1::Reshape>(broadcastV, target4D, true);
auto transposeK = std::make_shared<ov::op::v1::Transpose>(reshapeK, preOrder);
auto transposeV = std::make_shared<ov::op::v1::Transpose>(reshapeV, preOrder);
auto sdp = std::make_shared<ov::opset13::ScaledDotProductAttention>(transposeQ, transposeK, transposeV, false);
sdp->set_friendly_name("mha");
// post SDPA transpose + reshape
auto get_reshape_order = [](const ov::PartialShape& qkv_shape,
const std::vector<size_t>& transposeOrder) -> std::vector<size_t> {
assert(transposeOrder.size() == 4);
auto H = qkv_shape[transposeOrder[1]].get_length();
auto S = qkv_shape[transposeOrder[3]].get_length();
return std::vector<size_t>{0, 0, static_cast<size_t>(H * S)};
};
const auto reshapeOrder = get_reshape_order(inputDynamicShapes[0], transposeOrder);
auto postOrder =
ov::op::v0::Constant::create(ov::element::i32, {4}, std::vector<size_t>{0, 2, 1, 3}); // BHLS -> BLHS
auto transposeSDP = std::make_shared<ov::op::v1::Transpose>(sdp, postOrder);
auto constReshape = ov::op::v0::Constant::create(ov::element::i32, {3}, reshapeOrder);
auto reshapeSDP = std::make_shared<ov::op::v1::Reshape>(transposeSDP, constReshape, true); // BLHS -> B,L,HxS
auto add = std::make_shared<ov::op::v1::Add>(reshapeSDP, op::v0::Constant::create(inType, {1}, {1.0f}));
auto pastk_assign = std::make_shared<ov::op::v6::Assign>(concatK, var_k);
auto pastv_assign = std::make_shared<ov::op::v6::Assign>(concatV, var_v);
pastk_assign->set_friendly_name("pastk_w");
pastv_assign->set_friendly_name("pastv_w");
ov::OutputVector results{add};
if (hasShapeOf) {
results.push_back(pastk_shapeof);
results.push_back(pastv_shapeof);
}
SinkVector sinks{pastk_assign, pastv_assign};
function = std::make_shared<Function>(results, sinks, inputParams, "ConcatTranposeSDP");
targetDevice = ov::test::utils::DEVICE_CPU;
functionRefs = function->clone();
pass::Manager manager;
// decompose ScaledDotProductAttention
manager.register_pass<ov::pass::ScaledDotProductAttentionDecomposition>();
manager.run_passes(functionRefs);
}
void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override {
std::vector<ov::Shape> shapes(4);
shapes[0] = targetInputStaticShapes[0];
shapes[1] = targetInputStaticShapes[1];
shapes[2] = targetInputStaticShapes[1];
shapes[3] = targetInputStaticShapes[2];
SubgraphBaseTest::generate_inputs(shapes);
}
template <typename IT, typename T>
void strided_iota(IT first, size_t n, T value, T stride) {
for (size_t i = 0; i < n; i++) {
*first++ = value;
value += stride;
}
}
void generate(int idx, const std::vector<ov::Shape>& targetInputStaticShapes) {
inputs.clear();
auto create_input = [this](std::shared_ptr<ov::op::v0::Parameter> param, ov::Shape shape, float val) {
if (param->get_element_type() == element::f32) {
ov::Tensor t{ov::element::f32, shape};
strided_iota(static_cast<float*>(t.data()), t.get_size(), val, 0.1f);
inputs.insert({param, t});
} else {
ov::Tensor t{ov::element::bf16, shape};
strided_iota(static_cast<ov::bfloat16*>(t.data()), t.get_size(), val, 0.1f);
inputs.insert({param, t});
}
};
// q, k, v
create_input(function->get_parameters()[0], targetInputStaticShapes[0], idx + 1.0f);
create_input(function->get_parameters()[1], targetInputStaticShapes[1], idx + 2.0f);
create_input(function->get_parameters()[2], targetInputStaticShapes[1], idx + 3.0f);
create_input(function->get_parameters()[3], targetInputStaticShapes[2], idx + 4.0f);
}
void prepare() {
compile_model();
inferRequest = compiledModel.create_infer_request();
ASSERT_TRUE(inferRequest);
}
void reset() {
for (auto&& state : inferRequest.query_state()) {
state.reset();
}
}
std::vector<ov::Tensor> run_test(std::shared_ptr<ov::Model> model) {
function = model;
prepare();
std::vector<ov::Tensor> outputs;
int idx = 0;
for (auto&& shapes : targetStaticShapes) {
generate(idx++, shapes);
for (const auto& input : inputs) {
inferRequest.set_tensor(input.first, input.second);
}
inferRequest.infer();
auto outputTensor = inferRequest.get_output_tensor(0);
ov::Tensor copy{outputTensor.get_element_type(), outputTensor.get_shape()};
outputTensor.copy_to(copy);
outputs.push_back(copy);
}
reset();
return outputs;
}
};
TEST_P(ConcatMultiQuerySDPTest, CompareWithRefs) {
auto actualOutputs = run_test(function);
CheckNumberOfNodesWithType(compiledModel, "ScaledDotProductAttention", 1);
CheckNumberOfNodesWithType(compiledModel, "Concatenation", 0);
if (configuration[ov::hint::inference_precision.name()] == ov::element::bf16) {
CheckNumberOfNodesWithType(compiledModel, "Reorder", 5);
} else {
CheckNumberOfNodesWithType(compiledModel, "Reorder", 0);
}
CheckNumberOfNodesWithType(compiledModel, "Transpose", 1);
auto expectedOutputs = run_test(functionRefs);
CheckNumberOfNodesWithType(compiledModel, "ScaledDotProductAttention", 0);
for (size_t i = 0; i < actualOutputs.size(); i++) {
ov::test::utils::compare(expectedOutputs[i], actualOutputs[i], abs_threshold, rel_threshold);
}
}
namespace {
const std::vector<InputShapeAndTransposeOrder> inputShapeAndReorders = {{
{// inputShapes ChatGLM
{
// L1, B, H, S
{{-1, 1, 8, 64}, {{10, 1, 8, 64}, {1, 1, 8, 64}, {1, 1, 8, 64}, {20, 1, 8, 64}, {1, 1, 8, 64}}},
{{-1, 1, 2, 64}, {{10, 1, 2, 64}, {1, 1, 2, 64}, {1, 1, 2, 64}, {20, 1, 2, 64}, {1, 1, 2, 64}}},
// L0, B, H, S
{{-1, 1, 2, 64}, {{0, 1, 2, 64}, {10, 1, 2, 64}, {11, 1, 2, 64}, {12, 1, 2, 64}, {32, 1, 2, 64}}},
},
// transposeOrder
{1, 2, 0, 3}},
}};
// TODO: BF16 test is disabled due to CI machine limitation
INSTANTIATE_TEST_SUITE_P(smoke_ConcatMultiQuerySDPTest,
ConcatMultiQuerySDPTest,
::testing::Combine(::testing::Values(ElementType::f32),
::testing::ValuesIn(inputShapeAndReorders),
::testing::Values(true, false)),
ConcatMultiQuerySDPTest::getTestCaseName);
} // namespace
} // namespace SubgraphTestsDefinitions
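For clarity, the multi-query branch of this test broadcasts the 2 KV heads up to the 8 query heads using only shape ops; with the ChatGLM shapes above, the K path works out as:

concatK:                      [L, 1, 2, 64]     // 2 KV heads
Unsqueeze(axis = -2):         [L, 1, 2, 1, 64]
Multiply by ones {1,1,1,4,1}: [L, 1, 2, 4, 64]  // replicate each KV head 4x
Reshape {0, 0, 8, 64}:        [L, 1, 8, 64]     // special_zero keeps L and B
Transpose {1, 2, 0, 3}:       [1, 8, L, 64]     // [B, H, L, S] for SDPA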

View File

@@ -0,0 +1,296 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <openvino/opsets/opset13.hpp>
#include <transformations/op_conversions/scaled_dot_product_attention_decomposition.hpp>
#include "ov_models/builders.hpp"
#include "ov_models/utils/ov_helpers.hpp"
#include "shared_test_classes/base/layer_test_utils.hpp"
#include "shared_test_classes/base/ov_subgraph.hpp"
#include "test_utils/cpu_test_utils.hpp"
#include "common_test_utils/include/common_test_utils/ov_tensor_utils.hpp"
using namespace ov::test;
using namespace ngraph;
using namespace CPUTestUtils;
using namespace InferenceEngine;
namespace SubgraphTestsDefinitions {
using InputShapeAndTransposeOrder = std::pair<std::vector<InputShape>, std::vector<size_t>>;
using ConcatSDPTransposeTestParams = std::tuple<ElementType,
InputShapeAndTransposeOrder,
bool // has ShapeOf
>;
// Subgraph:
/* Parameter
* |
* Parameter ReadValue | ReadValue Parameter
* \ / | \ /
* \ / | \ /
* Concat Transpose Concat
* / \ | / \
* / \ | / \
* / Transpose | Transpose \
* / \ | / \
* Assign ScaledDotProductAttention Assign
* |
* Transpose
* |
* Reshape
* |
* Add
* |
* Result
*/
class ConcatSDPTransposeTest : public testing::WithParamInterface<ConcatSDPTransposeTestParams>,
virtual public ov::test::SubgraphBaseTest,
public CPUTestsBase {
public:
static std::string getTestCaseName(const testing::TestParamInfo<ConcatSDPTransposeTestParams>& obj) {
ElementType inType;
InputShapeAndTransposeOrder inputShapeAndOrders;
bool hasShapeof;
std::tie(inType, inputShapeAndOrders, hasShapeof) = obj.param;
std::ostringstream result;
std::vector<InputShape>& inputShapes = inputShapeAndOrders.first;
std::vector<size_t>& transposeOrder = inputShapeAndOrders.second;
result << "IS=";
for (const auto& shape : inputShapes) {
result << ov::test::utils::partialShape2str({shape.first}) << "_";
}
result << "TS=";
for (const auto& shape : inputShapes) {
result << "(";
if (!shape.second.empty()) {
for (const auto& itr : shape.second) {
result << ov::test::utils::vec2str(itr);
}
}
result << ")_";
}
result << "Prc=" << inType << "_";
result << "HasShapeOf=" << hasShapeof;
result << "TransposeOrder=";
result << "(";
for (const auto& itr : transposeOrder) {
result << itr << ",";
}
result << ")";
return result.str();
}
void SetUp() override {
ElementType inType;
InputShapeAndTransposeOrder inputShapeAndOrders;
bool hasShapeOf;
std::tie(inType, inputShapeAndOrders, hasShapeOf) = this->GetParam();
std::vector<InputShape>& inputShapes = inputShapeAndOrders.first;
std::vector<size_t>& transposeOrder = inputShapeAndOrders.second;
targetDevice = ov::test::utils::DEVICE_CPU;
rel_threshold = 1e-2f;
configuration[ov::hint::inference_precision.name()] = ov::element::f32;
if (inType == ElementType::bf16) {
configuration[ov::hint::inference_precision.name()] = ov::element::bf16;
rel_threshold = 0.01f;
}
init_input_shapes(inputShapes);
ov::ParameterVector inputParams;
// q,k,v
inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[0]));
inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[0]));
inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[0]));
inputParams[0]->set_friendly_name("q");
inputParams[1]->set_friendly_name("k");
inputParams[2]->set_friendly_name("v");
// pastkv init
inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[1]));
auto var_k = std::make_shared<ov::op::util::Variable>(
ov::op::util::VariableInfo{inputDynamicShapes[1], inType, "pastk"});
auto pastk = std::make_shared<ov::op::v6::ReadValue>(inputParams[3], var_k);
pastk->set_friendly_name("pastk_r");
auto var_v = std::make_shared<ov::op::util::Variable>(
ov::op::util::VariableInfo{inputDynamicShapes[1], inType, "pastv"});
auto pastv = std::make_shared<ov::op::v6::ReadValue>(inputParams[3], var_v);
pastv->set_friendly_name("pastv_r");
std::shared_ptr<Node> pastk_shapeof, pastv_shapeof;
if (hasShapeOf) {
pastk_shapeof = std::make_shared<ov::op::v0::ShapeOf>(pastk);
pastv_shapeof = std::make_shared<ov::op::v0::ShapeOf>(pastv);
}
// pre SDPA transpose
auto preOrder = ov::op::v0::Constant::create(ov::element::i32, {4}, transposeOrder);
auto transposeQ = std::make_shared<ov::op::v1::Transpose>(inputParams[0], preOrder);
auto concat_axis = transposeOrder[2];
auto concatK = std::make_shared<ov::op::v0::Concat>(OutputVector{pastk, inputParams[1]}, concat_axis);
auto concatV = std::make_shared<ov::op::v0::Concat>(OutputVector{pastv, inputParams[2]}, concat_axis);
auto transposeK = std::make_shared<ov::op::v1::Transpose>(concatK, preOrder);
auto transposeV = std::make_shared<ov::op::v1::Transpose>(concatV, preOrder);
auto sdp = std::make_shared<ov::opset13::ScaledDotProductAttention>(transposeQ, transposeK, transposeV, false);
sdp->set_friendly_name("mha");
// post SDPA transpose + reshape
auto get_reshape_order = [](const ov::PartialShape& qkv_shape,
const std::vector<size_t>& transposeOrder) -> std::vector<size_t> {
assert(transposeOrder.size() == 4);
auto H = qkv_shape[transposeOrder[1]].get_length();
auto S = qkv_shape[transposeOrder[3]].get_length();
return std::vector<size_t>{0, 0, static_cast<size_t>(H * S)};
};
const auto reshapeOrder = get_reshape_order(inputDynamicShapes[0], transposeOrder);
auto postOrder =
ov::op::v0::Constant::create(ov::element::i32, {4}, std::vector<size_t>{0, 2, 1, 3}); // BHLS -> BLHS
auto transposeSDP = std::make_shared<ov::op::v1::Transpose>(sdp, postOrder);
auto constReshape = ov::op::v0::Constant::create(ov::element::i32, {3}, reshapeOrder);
auto reshapeSDP = std::make_shared<ov::op::v1::Reshape>(transposeSDP, constReshape, true); // BLHS -> B,L,HxS
auto add = std::make_shared<ov::op::v1::Add>(reshapeSDP, op::v0::Constant::create(inType, {1}, {1.0f}));
auto pastk_assign = std::make_shared<ov::op::v6::Assign>(concatK, var_k);
auto pastv_assign = std::make_shared<ov::op::v6::Assign>(concatV, var_v);
pastk_assign->set_friendly_name("pastk_w");
pastv_assign->set_friendly_name("pastv_w");
ov::OutputVector results{add};
if (hasShapeOf) {
results.push_back(pastk_shapeof);
results.push_back(pastv_shapeof);
}
SinkVector sinks{pastk_assign, pastv_assign};
function = std::make_shared<Function>(results, sinks, inputParams, "ConcatTranposeSDP");
targetDevice = ov::test::utils::DEVICE_CPU;
functionRefs = function->clone();
pass::Manager manager;
// decompose ScaledDotProductAttention
manager.register_pass<ov::pass::ScaledDotProductAttentionDecomposition>();
manager.run_passes(functionRefs);
}
void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override {
std::vector<ov::Shape> shapes(4);
shapes[0] = targetInputStaticShapes[0];
shapes[1] = targetInputStaticShapes[0];
shapes[2] = targetInputStaticShapes[0];
shapes[3] = targetInputStaticShapes[1];
SubgraphBaseTest::generate_inputs(shapes);
}
template <typename IT, typename T>
void strided_iota(IT first, size_t n, T value, T stride) {
for (size_t i = 0; i < n; i++) {
*first++ = value;
value += stride;
}
}
void generate(int idx, const std::vector<ov::Shape>& targetInputStaticShapes) {
inputs.clear();
auto create_input = [this](std::shared_ptr<op::v0::Parameter> param, ov::Shape shape, float val) {
if (param->get_element_type() == element::f32) {
ov::Tensor t{ov::element::f32, shape};
strided_iota(static_cast<float*>(t.data()), t.get_size(), val, 0.1f);
inputs.insert({param, t});
} else {
ov::Tensor t{ov::element::bf16, shape};
strided_iota(static_cast<ov::bfloat16*>(t.data()), t.get_size(), val, 0.1f);
inputs.insert({param, t});
}
};
// q, k, v
create_input(function->get_parameters()[0], targetInputStaticShapes[0], idx + 1.0f);
create_input(function->get_parameters()[1], targetInputStaticShapes[0], idx + 2.0f);
create_input(function->get_parameters()[2], targetInputStaticShapes[0], idx + 3.0f);
create_input(function->get_parameters()[3], targetInputStaticShapes[1], idx + 4.0f);
}
void prepare() {
compile_model();
inferRequest = compiledModel.create_infer_request();
ASSERT_TRUE(inferRequest);
}
void reset() {
for (auto&& state : inferRequest.query_state()) {
state.reset();
}
}
std::vector<ov::Tensor> run_test(std::shared_ptr<ov::Model> model) {
function = model;
prepare();
std::vector<ov::Tensor> outputs;
int idx = 0;
for (auto&& shapes : targetStaticShapes) {
generate(idx++, shapes);
for (const auto& input : inputs) {
inferRequest.set_tensor(input.first, input.second);
}
inferRequest.infer();
auto outputTensor = inferRequest.get_output_tensor(0);
ov::Tensor copy{outputTensor.get_element_type(), outputTensor.get_shape()};
outputTensor.copy_to(copy);
outputs.push_back(copy);
}
reset();
return outputs;
}
};
TEST_P(ConcatSDPTransposeTest, CompareWithRefs) {
auto actualOutputs = run_test(function);
CheckNumberOfNodesWithType(compiledModel, "ScaledDotProductAttention", 1);
CheckNumberOfNodesWithType(compiledModel, "Concatenation", 0);
CheckNumberOfNodesWithType(compiledModel, "Reorder", 0);
CheckNumberOfNodesWithType(compiledModel, "Transpose", 1);
auto expectedOutputs = run_test(functionRefs);
CheckNumberOfNodesWithType(compiledModel, "ScaledDotProductAttention", 0);
for (size_t i = 0; i < actualOutputs.size(); i++) {
ov::test::utils::compare(expectedOutputs[i], actualOutputs[i], abs_threshold, rel_threshold);
}
}
namespace {
const std::vector<InputShapeAndTransposeOrder> inputShapeAndReorders = {
{
// inputShapes Llama
{
// B, H, L1, S
{{1, 8, -1, 64}, {{1, 8, 10, 64}, {1, 8, 1, 64}, {1, 8, 1, 64}, {1, 8, 20, 64}, {1, 8, 1, 64}}},
// B, H, L0, S
{{1, 8, -1, 64}, {{1, 8, 0, 64}, {1, 8, 10, 64}, {1, 8, 11, 64}, {1, 8, 12, 64}, {1, 8, 32, 64}}},
},
// transposeOrder
{0, 1, 2, 3}},
{// inputShapes QWen
{
// B, L1, H, S
{{1, -1, 8, 64}, {{1, 10, 8, 64}, {1, 1, 8, 64}, {1, 1, 8, 64}, {1, 20, 8, 64}, {1, 1, 8, 64}}},
// B, L0, H, S
{{1, -1, 8, 64}, {{1, 0, 8, 64}, {1, 10, 8, 64}, {1, 11, 8, 64}, {1, 12, 8, 64}, {1, 32, 8, 64}}},
},
// transposeOrder
{0, 2, 1, 3}},
{// inputShapes ChatGLM
{
// L1, B, H, S
{{-1, 1, 8, 64}, {{10, 1, 8, 64}, {1, 1, 8, 64}, {1, 1, 8, 64}, {20, 1, 8, 64}, {1, 1, 8, 64}}},
// L0, B, H, S
{{-1, 1, 8, 64}, {{0, 1, 8, 64}, {10, 1, 8, 64}, {11, 1, 8, 64}, {12, 1, 8, 64}, {32, 1, 8, 64}}},
},
// transposeOrder
{1, 2, 0, 3}},
};
INSTANTIATE_TEST_SUITE_P(smoke_ConcatSDPTransposeTest,
ConcatSDPTransposeTest,
::testing::Combine(::testing::Values(ElementType::f32),
::testing::ValuesIn(inputShapeAndReorders),
::testing::Values(true, false)),
ConcatSDPTransposeTest::getTestCaseName);
} // namespace
} // namespace SubgraphTestsDefinitions
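Each transposeOrder above maps the model's native QKV layout onto the canonical [B, H, L, S] layout the SDPA kernel expects:

Llama:   [B, H, L, S] + order {0, 1, 2, 3} -> [B, H, L, S] (identity)
QWen:    [B, L, H, S] + order {0, 2, 1, 3} -> [B, H, L, S]
ChatGLM: [L, B, H, S] + order {1, 2, 0, 3} -> [B, H, L, S]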