[CPU] Fix ScaledDotProductAttension accuracy problem (#21217)
This commit is contained in:
parent
493a338ad2
commit
91660b1c05
@ -123,7 +123,7 @@ struct MHAKernel {
|
||||
if (auto_causal)
|
||||
ncausal = kv_len - q_len + m + 1;
|
||||
for (size_t n = 0; n < ncausal; n++) {
|
||||
auto* k = &present_key.at({b, h, n, 0});
|
||||
auto* k = &present_key.at({b, h, n, 0}, true);
|
||||
attn_score[n] = dot_product(q, k, head_size, k_stride_s) * d_scale;
|
||||
|
||||
// apply alibi tensor
|
||||
@ -154,7 +154,7 @@ struct MHAKernel {
|
||||
// linearly combine value
|
||||
word_vec.assign(head_size, 0.0f);
|
||||
for (size_t n = 0; n < ncausal; n++) {
|
||||
auto* v = &present_value.at({b, h, n, 0});
|
||||
auto* v = &present_value.at({b, h, n, 0}, true);
|
||||
accumulate(word_vec.data(), v, head_size, attn_score[n]);
|
||||
}
|
||||
|
||||
@ -183,7 +183,7 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
|
||||
using tag = dnnl::memory::format_tag;
|
||||
using dt = dnnl::memory::data_type;
|
||||
|
||||
void prepare_prim(dnnl::stream strm, size_t B, size_t H, size_t q_len, size_t kv_len, size_t S, bool has_out_transpose) {
|
||||
void prepare_prim(dnnl::stream strm, size_t B, size_t H, size_t Hk, size_t q_len, size_t kv_len, size_t S, bool has_out_transpose) {
|
||||
auto make_dnnl_dims = [](const std::vector<size_t>& dims) {
|
||||
dnnl::memory::dims dnnl_dims(dims.size());
|
||||
for (size_t i = 0; i < dims.size(); i++)
|
||||
@ -192,7 +192,7 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
|
||||
};
|
||||
auto qkv_dt = precision_of<T>::value == ov::element::f32 ? dt::f32 : dt::bf16;
|
||||
dnnl::memory::desc cur_q_md(make_dnnl_dims({B, H, q_len, S}), qkv_dt, tag::abcd);
|
||||
dnnl::memory::desc cur_k_md(make_dnnl_dims({B, H, kv_len, S}), qkv_dt, tag::abcd);
|
||||
dnnl::memory::desc cur_k_md(make_dnnl_dims({B, Hk, kv_len, S}), qkv_dt, tag::abcd);
|
||||
if (cur_q_md == q_md && cur_k_md == k_md)
|
||||
return;
|
||||
|
||||
@ -204,7 +204,7 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
|
||||
qk_prim = dnnl::matmul(qk_pd);
|
||||
|
||||
weight_md = dnnl::memory::desc(make_dnnl_dims({B, H, q_len, kv_len}), qkv_dt, tag::abcd);
|
||||
v_md = dnnl::memory::desc(make_dnnl_dims({B, H, kv_len, S}), qkv_dt, tag::abcd);
|
||||
v_md = dnnl::memory::desc(make_dnnl_dims({B, Hk, kv_len, S}), qkv_dt, tag::abcd);
|
||||
out_md = dnnl::memory::desc(make_dnnl_dims({B, H, q_len, S}), qkv_dt, tag::abcd);
|
||||
if (has_out_transpose)
|
||||
out_md = out_md.permute_axes({0, 2, 1, 3});
|
||||
@ -259,12 +259,13 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
|
||||
auto H = query.size(1);
|
||||
auto q_len = query.size(2);
|
||||
auto head_size = query.size(3);
|
||||
auto Hk = present_key.size(1);
|
||||
auto kv_len = present_key.size(2);
|
||||
|
||||
if (d_scale == 0.0f)
|
||||
d_scale = 1.0f / sqrt(head_size);
|
||||
|
||||
prepare_prim(strm, B, H, q_len, kv_len, head_size, has_out_transpose);
|
||||
prepare_prim(strm, B, H, Hk, q_len, kv_len, head_size, has_out_transpose);
|
||||
exec_qk(strm, query, present_key);
|
||||
|
||||
PlainTensor<float> score;
|
||||
@ -495,7 +496,7 @@ struct MHASingleToken {
|
||||
std::vector<float*> cs(q_len);
|
||||
for (size_t pq = 0; pq < q_len; pq++) {
|
||||
as[pq] = &query.at({b, h, pq, 0});
|
||||
bs[pq] = &present_key.at({b_kv, h, pk, 0});
|
||||
bs[pq] = &present_key.at({b_kv, h, pk, 0}, true);
|
||||
cs[pq] = &m_attn_w.at({b, h, pq, pk});
|
||||
}
|
||||
attn_dot_products(reinterpret_cast<void**>(as.data()),
|
||||
@ -543,7 +544,7 @@ struct MHASingleToken {
|
||||
size_t idx = 0;
|
||||
for (size_t iwork = start; iwork < end; ++iwork) {
|
||||
auto b_kv = beams ? beams.at({b, pv}) : b;
|
||||
auto* v = &present_value.at({b_kv, h, pv, 0});
|
||||
auto* v = &present_value.at({b_kv, h, pv, 0}, true);
|
||||
for (size_t pq = 0; pq < q_len; pq++) {
|
||||
outs[idx] = &m_temp.at({ithr, b, pq, h, 0});
|
||||
weights[idx] = m_attn_w.at({b, h, pq, pv});
|
||||
@ -636,8 +637,8 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
|
||||
PlainTensor<T> present_key, present_value;
|
||||
|
||||
q_input.assert_dims({B, H, L1, S});
|
||||
k_input.assert_dims({B, H, L0 + L1, S});
|
||||
v_input.assert_dims({B, H, L0 + L1, S});
|
||||
k_input.assert_dims({B, 0, L0 + L1, S}, true);
|
||||
v_input.assert_dims({B, 0, L0 + L1, S}, true);
|
||||
m_query_emb = q_input;
|
||||
present_key = k_input;
|
||||
present_value = v_input;
|
||||
@ -657,9 +658,11 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
|
||||
// no attn_mask but has scale, there is a 1-d fake attn_mask
|
||||
if (input_num > 3 && attn_mask.m_rank > 1) {
|
||||
assert(attn_mask);
|
||||
auto num = std::accumulate(attn_mask.m_dims, attn_mask.m_dims + attn_mask.m_rank, size_t{1}, std::multiplies<size_t>());
|
||||
num /= B * (L0 + L1);
|
||||
attn_mask = attn_mask.reshape({B, 1, num, L0 + L1});
|
||||
// spec requires at least 3, but torch sl test does use rank 2
|
||||
if (attn_mask.m_rank == 2)
|
||||
attn_mask = attn_mask.reshape({1, 1, attn_mask.m_dims[0], attn_mask.m_dims[1]});
|
||||
else if (attn_mask.m_rank == 3)
|
||||
attn_mask = attn_mask.reshape({1, attn_mask.m_dims[0], attn_mask.m_dims[1], attn_mask.m_dims[2]});
|
||||
auto_causal = false;
|
||||
use_attn_mask = true;
|
||||
} else {
|
||||
@ -758,12 +761,19 @@ bool ScaledDotProductAttention::isSupportedOperation(const std::shared_ptr<const
|
||||
errorMessage = "Only ScaledDotProductAttention operation are supported";
|
||||
return false;
|
||||
}
|
||||
// expect shape: [B, H, L, S]
|
||||
const auto inRank = op->get_input_partial_shape(0).size();
|
||||
// expect shape of q: [B, H, L, S]
|
||||
auto inRank = op->get_input_partial_shape(0).size();
|
||||
if (inRank != 4u) {
|
||||
errorMessage = "Doesn't support 'data' input with rank: " + std::to_string(inRank);
|
||||
return false;
|
||||
}
|
||||
if (op->get_input_size() > 3) {
|
||||
inRank = op->get_input_partial_shape(3).size();
|
||||
if (inRank > 4u) {
|
||||
errorMessage = "Doesn't support 'attention mask' with rank: " + std::to_string(inRank);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// using mha should be better for static shapes
|
||||
if (!op->is_dynamic()) {
|
||||
errorMessage = "Only run in dynamic mode";
|
||||
|
@ -18,21 +18,25 @@ namespace CPULayerTestsDefinitions {
|
||||
std::string ScaledAttnLayerCPUTest::getTestCaseName(const testing::TestParamInfo<ScaledAttnCPUTestParams>& obj) {
|
||||
CPUSpecificParams cpuParams;
|
||||
ElementType inType;
|
||||
InputShape inputShape;
|
||||
std::vector<InputShape> inputShapes;
|
||||
bool is_causal;
|
||||
bool has_attn;
|
||||
bool has_scale;
|
||||
std::string targetDevice;
|
||||
std::tie(inType, inputShape, is_causal, has_attn, has_scale, targetDevice, cpuParams) = obj.param;
|
||||
std::tie(inType, inputShapes, is_causal, has_attn, has_scale, targetDevice, cpuParams) = obj.param;
|
||||
|
||||
std::ostringstream result;
|
||||
result << "netPRC=" << inType << "_";
|
||||
result << "IS=" << ov::test::utils::partialShape2str({inputShape.first}) << "_";
|
||||
result << "IS=";
|
||||
for (const auto& inputShape : inputShapes) {
|
||||
result << ov::test::utils::partialShape2str({inputShape.first}) << "_";
|
||||
}
|
||||
result << "TS=";
|
||||
for (const auto& shape : inputShape.second) {
|
||||
result << "(";
|
||||
result << ov::test::utils::vec2str(shape);
|
||||
result << ")_";
|
||||
for (const auto& shapes : inputShapes) {
|
||||
for (const auto& shape : shapes.second) {
|
||||
result << ov::test::utils::vec2str(shape);
|
||||
result << "_";
|
||||
}
|
||||
}
|
||||
result << "is_causal=" << is_causal << "_";
|
||||
result << "has_attn=" << has_attn << "_";
|
||||
@ -46,8 +50,8 @@ std::string ScaledAttnLayerCPUTest::getTestCaseName(const testing::TestParamInfo
|
||||
void ScaledAttnLayerCPUTest::SetUp() {
|
||||
ElementType inType;
|
||||
CPUSpecificParams cpuParams;
|
||||
InputShape inputShape;
|
||||
std::tie(inType, inputShape, is_causal, has_attn, has_scale, targetDevice, cpuParams) = this->GetParam();
|
||||
std::vector<InputShape> inputShapes;
|
||||
std::tie(inType, inputShapes, is_causal, has_attn, has_scale, targetDevice, cpuParams) = this->GetParam();
|
||||
|
||||
std::tie(inFmts, outFmts, priority, selectedType) = cpuParams;
|
||||
if (selectedType.empty()) {
|
||||
@ -58,12 +62,12 @@ void ScaledAttnLayerCPUTest::SetUp() {
|
||||
rel_threshold = 2e-2f;
|
||||
}
|
||||
selectedType = makeSelectedTypeStr(selectedType, inType);
|
||||
init_input_shapes({inputShape});
|
||||
init_input_shapes(inputShapes);
|
||||
ov::ParameterVector inputParams;
|
||||
// q,k,v
|
||||
inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[0]));
|
||||
inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[0]));
|
||||
inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[0]));
|
||||
inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[1]));
|
||||
inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[1]));
|
||||
inputParams[0]->set_friendly_name("q");
|
||||
inputParams[1]->set_friendly_name("k");
|
||||
inputParams[2]->set_friendly_name("v");
|
||||
@ -77,9 +81,7 @@ void ScaledAttnLayerCPUTest::SetUp() {
|
||||
inputParams.back()->set_friendly_name("scale");
|
||||
} else {
|
||||
if (has_attn) {
|
||||
// attention_mask:[B, 1, 1, L0+L1]
|
||||
ov::PartialShape attnShape{inputDynamicShapes[0][0], 1, 1, -1};
|
||||
inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, attnShape));
|
||||
inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[2]));
|
||||
inputParams.back()->set_friendly_name("attention_mask");
|
||||
}
|
||||
if (has_scale) {
|
||||
@ -106,14 +108,14 @@ void ScaledAttnLayerCPUTest::SetUp() {
|
||||
void ScaledAttnLayerCPUTest::generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) {
|
||||
std::vector<ov::Shape> shapes(3);
|
||||
shapes[0] = targetInputStaticShapes[0];
|
||||
shapes[1] = targetInputStaticShapes[0];
|
||||
shapes[2] = targetInputStaticShapes[0];
|
||||
shapes[1] = targetInputStaticShapes[1];
|
||||
shapes[2] = targetInputStaticShapes[1];
|
||||
if (!has_attn && has_scale) {
|
||||
shapes.push_back(ov::Shape{});
|
||||
shapes.push_back(ov::Shape{1});
|
||||
} else {
|
||||
if (has_attn) {
|
||||
shapes.push_back(ov::Shape{targetInputStaticShapes[0][0], 1, 1, targetInputStaticShapes[0][2]});
|
||||
shapes.push_back(targetInputStaticShapes[2]);
|
||||
}
|
||||
if (has_scale) {
|
||||
shapes.push_back(ov::Shape{1});
|
||||
|
@ -13,12 +13,12 @@ using namespace ov::test;
|
||||
|
||||
namespace CPULayerTestsDefinitions {
|
||||
|
||||
typedef std::tuple<ElementType, // netPrecision
|
||||
InputShape, // shape
|
||||
bool, // is_causal
|
||||
bool, // has_attn
|
||||
bool, // has_scale
|
||||
std::string, // targetDevice
|
||||
typedef std::tuple<ElementType, // netPrecision
|
||||
std::vector<InputShape>, // shape
|
||||
bool, // is_causal
|
||||
bool, // has_attn
|
||||
bool, // has_scale
|
||||
std::string, // targetDevice
|
||||
CPUSpecificParams>
|
||||
ScaledAttnCPUTestParams;
|
||||
|
||||
|
@ -14,10 +14,52 @@ namespace CPULayerTestsDefinitions {
|
||||
namespace ScaledAttn {
|
||||
const auto cpuSpec = CPUSpecificParams{{}, {}, {"ref_any"}, "ref_any"};
|
||||
|
||||
const std::vector<InputShape> shapes{
|
||||
{ov::test::InputShape{ov::PartialShape{-1, 8, -1, 64},
|
||||
{ov::Shape{1, 8, 100, 64}, ov::Shape{1, 8, 1, 64}, ov::Shape{2, 8, 10, 64}}}
|
||||
}
|
||||
const std::vector<std::vector<InputShape>> shapes{
|
||||
// normal case, shapes of q,k,v are same
|
||||
{
|
||||
// q shape
|
||||
{ov::test::InputShape{ov::PartialShape{-1, 8, -1, 64},
|
||||
{ov::Shape{1, 8, 100, 64}, ov::Shape{1, 8, 1, 64}, ov::Shape{2, 8, 10, 64}}}
|
||||
},
|
||||
// kv shape
|
||||
{ov::test::InputShape{ov::PartialShape{-1, 8, -1, 64},
|
||||
{ov::Shape{1, 8, 100, 64}, ov::Shape{1, 8, 1, 64}, ov::Shape{2, 8, 10, 64}}}
|
||||
},
|
||||
// attn shape: [B, 1, 1, L0+L1]
|
||||
{ov::test::InputShape{ov::PartialShape{-1, 1, 1, -1},
|
||||
{ov::Shape{1, 1, 1, 100}, ov::Shape{1, 1, 1, 1}, ov::Shape{2, 1, 1, 10}}}
|
||||
},
|
||||
},
|
||||
// heads number of kv is 1, attn mask: [B, H, L1, L0+L1]
|
||||
{
|
||||
// q shape
|
||||
{ov::test::InputShape{ov::PartialShape{-1, 8, -1, 64},
|
||||
{ov::Shape{1, 8, 100, 64}, ov::Shape{1, 8, 1, 64}, ov::Shape{2, 8, 10, 64}}}
|
||||
},
|
||||
// kv shape
|
||||
{ov::test::InputShape{ov::PartialShape{-1, 1, -1, 64},
|
||||
{ov::Shape{1, 1, 100, 64}, ov::Shape{1, 1, 1, 64}, ov::Shape{2, 1, 10, 64}}}
|
||||
},
|
||||
// attn shape
|
||||
{ov::test::InputShape{ov::PartialShape{-1, 8, -1, -1},
|
||||
{ov::Shape{1, 8, 100, 100}, ov::Shape{1, 8, 1, 1}, ov::Shape{2, 8, 10, 10}}}
|
||||
},
|
||||
},
|
||||
// heads number of kv is 1, attn mask: [H, L1, L0+L1]
|
||||
{
|
||||
// q shape
|
||||
{ov::test::InputShape{ov::PartialShape{-1, 8, -1, 64},
|
||||
{ov::Shape{1, 8, 100, 64}, ov::Shape{1, 8, 1, 64}, ov::Shape{2, 8, 10, 64}}}
|
||||
},
|
||||
// kv shape
|
||||
{ov::test::InputShape{ov::PartialShape{-1, 1, -1, 64},
|
||||
{ov::Shape{1, 1, 100, 64}, ov::Shape{1, 1, 1, 64}, ov::Shape{2, 1, 10, 64}}}
|
||||
},
|
||||
// attn shape
|
||||
{ov::test::InputShape{ov::PartialShape{8, -1, -1},
|
||||
{ov::Shape{8, 100, 100}, ov::Shape{8, 1, 1}, ov::Shape{8, 10, 10}}}
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const auto params = testing::Combine(testing::Values(ElementType::f32, ElementType::bf16),
|
||||
|
@ -75,7 +75,8 @@ class TestTimmConvertModel(TestConvertModel):
|
||||
|
||||
@pytest.mark.parametrize("name", ["mobilevitv2_050.cvnets_in1k",
|
||||
"poolformerv2_s12.sail_in1k",
|
||||
"vit_base_patch8_224.augreg_in21k"])
|
||||
"vit_base_patch8_224.augreg_in21k",
|
||||
"beit_base_patch16_224.in22k_ft_in22k"])
|
||||
@pytest.mark.precommit
|
||||
def test_convert_model_precommit(self, name, ie_device):
|
||||
self.run(name, None, ie_device)
|
||||
|
Loading…
Reference in New Issue
Block a user