From 91660b1c0557500a095f3d087b326404330f7677 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Sun, 26 Nov 2023 18:47:04 +0800 Subject: [PATCH] [CPU] Fix ScaledDotProductAttension accuracy problem (#21217) --- .../intel_cpu/src/nodes/scaled_attn.cpp | 40 +++++++++------ .../classes/scaled_attn.cpp | 38 +++++++------- .../classes/scaled_attn.hpp | 12 ++--- .../instances/x64/scaled_attn.cpp | 50 +++++++++++++++++-- .../model_hub_tests/torch_tests/test_timm.py | 3 +- 5 files changed, 99 insertions(+), 44 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp index 36a39531307..4db6a272e1a 100644 --- a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp +++ b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp @@ -123,7 +123,7 @@ struct MHAKernel { if (auto_causal) ncausal = kv_len - q_len + m + 1; for (size_t n = 0; n < ncausal; n++) { - auto* k = &present_key.at({b, h, n, 0}); + auto* k = &present_key.at({b, h, n, 0}, true); attn_score[n] = dot_product(q, k, head_size, k_stride_s) * d_scale; // apply alibi tensor @@ -154,7 +154,7 @@ struct MHAKernel { // linearly combine value word_vec.assign(head_size, 0.0f); for (size_t n = 0; n < ncausal; n++) { - auto* v = &present_value.at({b, h, n, 0}); + auto* v = &present_value.at({b, h, n, 0}, true); accumulate(word_vec.data(), v, head_size, attn_score[n]); } @@ -183,7 +183,7 @@ struct MHAKernel { using tag = dnnl::memory::format_tag; using dt = dnnl::memory::data_type; - void prepare_prim(dnnl::stream strm, size_t B, size_t H, size_t q_len, size_t kv_len, size_t S, bool has_out_transpose) { + void prepare_prim(dnnl::stream strm, size_t B, size_t H, size_t Hk, size_t q_len, size_t kv_len, size_t S, bool has_out_transpose) { auto make_dnnl_dims = [](const std::vector& dims) { dnnl::memory::dims dnnl_dims(dims.size()); for (size_t i = 0; i < dims.size(); i++) @@ -192,7 +192,7 @@ struct MHAKernel { }; auto qkv_dt = precision_of::value == ov::element::f32 ? 
dt::f32 : dt::bf16; dnnl::memory::desc cur_q_md(make_dnnl_dims({B, H, q_len, S}), qkv_dt, tag::abcd); - dnnl::memory::desc cur_k_md(make_dnnl_dims({B, H, kv_len, S}), qkv_dt, tag::abcd); + dnnl::memory::desc cur_k_md(make_dnnl_dims({B, Hk, kv_len, S}), qkv_dt, tag::abcd); if (cur_q_md == q_md && cur_k_md == k_md) return; @@ -204,7 +204,7 @@ struct MHAKernel { qk_prim = dnnl::matmul(qk_pd); weight_md = dnnl::memory::desc(make_dnnl_dims({B, H, q_len, kv_len}), qkv_dt, tag::abcd); - v_md = dnnl::memory::desc(make_dnnl_dims({B, H, kv_len, S}), qkv_dt, tag::abcd); + v_md = dnnl::memory::desc(make_dnnl_dims({B, Hk, kv_len, S}), qkv_dt, tag::abcd); out_md = dnnl::memory::desc(make_dnnl_dims({B, H, q_len, S}), qkv_dt, tag::abcd); if (has_out_transpose) out_md = out_md.permute_axes({0, 2, 1, 3}); @@ -259,12 +259,13 @@ struct MHAKernel { auto H = query.size(1); auto q_len = query.size(2); auto head_size = query.size(3); + auto Hk = present_key.size(1); auto kv_len = present_key.size(2); if (d_scale == 0.0f) d_scale = 1.0f / sqrt(head_size); - prepare_prim(strm, B, H, q_len, kv_len, head_size, has_out_transpose); + prepare_prim(strm, B, H, Hk, q_len, kv_len, head_size, has_out_transpose); exec_qk(strm, query, present_key); PlainTensor score; @@ -495,7 +496,7 @@ struct MHASingleToken { std::vector cs(q_len); for (size_t pq = 0; pq < q_len; pq++) { as[pq] = &query.at({b, h, pq, 0}); - bs[pq] = &present_key.at({b_kv, h, pk, 0}); + bs[pq] = &present_key.at({b_kv, h, pk, 0}, true); cs[pq] = &m_attn_w.at({b, h, pq, pk}); } attn_dot_products(reinterpret_cast(as.data()), @@ -543,7 +544,7 @@ struct MHASingleToken { size_t idx = 0; for (size_t iwork = start; iwork < end; ++iwork) { auto b_kv = beams ? beams.at({b, pv}) : b; - auto* v = &present_value.at({b_kv, h, pv, 0}); + auto* v = &present_value.at({b_kv, h, pv, 0}, true); for (size_t pq = 0; pq < q_len; pq++) { outs[idx] = &m_temp.at({ithr, b, pq, h, 0}); weights[idx] = m_attn_w.at({b, h, pq, pv}); @@ -636,8 +637,8 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt PlainTensor present_key, present_value; q_input.assert_dims({B, H, L1, S}); - k_input.assert_dims({B, H, L0 + L1, S}); - v_input.assert_dims({B, H, L0 + L1, S}); + k_input.assert_dims({B, 0, L0 + L1, S}, true); + v_input.assert_dims({B, 0, L0 + L1, S}, true); m_query_emb = q_input; present_key = k_input; present_value = v_input; @@ -657,9 +658,11 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt // no attn_mask but has scale, there is a 1-d fake attn_mask if (input_num > 3 && attn_mask.m_rank > 1) { assert(attn_mask); - auto num = std::accumulate(attn_mask.m_dims, attn_mask.m_dims + attn_mask.m_rank, size_t{1}, std::multiplies()); - num /= B * (L0 + L1); - attn_mask = attn_mask.reshape({B, 1, num, L0 + L1}); + // spec requires at least 3, but torch sl test does use rank 2 + if (attn_mask.m_rank == 2) + attn_mask = attn_mask.reshape({1, 1, attn_mask.m_dims[0], attn_mask.m_dims[1]}); + else if (attn_mask.m_rank == 3) + attn_mask = attn_mask.reshape({1, attn_mask.m_dims[0], attn_mask.m_dims[1], attn_mask.m_dims[2]}); auto_causal = false; use_attn_mask = true; } else { @@ -758,12 +761,19 @@ bool ScaledDotProductAttention::isSupportedOperation(const std::shared_ptrget_input_partial_shape(0).size(); + // expect shape of q: [B, H, L, S] + auto inRank = op->get_input_partial_shape(0).size(); if (inRank != 4u) { errorMessage = "Doesn't support 'data' input with rank: " + std::to_string(inRank); return false; } + if (op->get_input_size() > 
3) { + inRank = op->get_input_partial_shape(3).size(); + if (inRank > 4u) { + errorMessage = "Doesn't support 'attention mask' with rank: " + std::to_string(inRank); + return false; + } + } // using mha should be better for static shapes if (!op->is_dynamic()) { errorMessage = "Only run in dynamic mode"; diff --git a/src/plugins/intel_cpu/tests/functional/single_layer_tests/classes/scaled_attn.cpp b/src/plugins/intel_cpu/tests/functional/single_layer_tests/classes/scaled_attn.cpp index af329788348..644f94bc2bf 100644 --- a/src/plugins/intel_cpu/tests/functional/single_layer_tests/classes/scaled_attn.cpp +++ b/src/plugins/intel_cpu/tests/functional/single_layer_tests/classes/scaled_attn.cpp @@ -18,21 +18,25 @@ namespace CPULayerTestsDefinitions { std::string ScaledAttnLayerCPUTest::getTestCaseName(const testing::TestParamInfo& obj) { CPUSpecificParams cpuParams; ElementType inType; - InputShape inputShape; + std::vector inputShapes; bool is_causal; bool has_attn; bool has_scale; std::string targetDevice; - std::tie(inType, inputShape, is_causal, has_attn, has_scale, targetDevice, cpuParams) = obj.param; + std::tie(inType, inputShapes, is_causal, has_attn, has_scale, targetDevice, cpuParams) = obj.param; std::ostringstream result; result << "netPRC=" << inType << "_"; - result << "IS=" << ov::test::utils::partialShape2str({inputShape.first}) << "_"; + result << "IS="; + for (const auto& inputShape : inputShapes) { + result << ov::test::utils::partialShape2str({inputShape.first}) << "_"; + } result << "TS="; - for (const auto& shape : inputShape.second) { - result << "("; - result << ov::test::utils::vec2str(shape); - result << ")_"; + for (const auto& shapes : inputShapes) { + for (const auto& shape : shapes.second) { + result << ov::test::utils::vec2str(shape); + result << "_"; + } } result << "is_causal=" << is_causal << "_"; result << "has_attn=" << has_attn << "_"; @@ -46,8 +50,8 @@ std::string ScaledAttnLayerCPUTest::getTestCaseName(const testing::TestParamInfo void ScaledAttnLayerCPUTest::SetUp() { ElementType inType; CPUSpecificParams cpuParams; - InputShape inputShape; - std::tie(inType, inputShape, is_causal, has_attn, has_scale, targetDevice, cpuParams) = this->GetParam(); + std::vector inputShapes; + std::tie(inType, inputShapes, is_causal, has_attn, has_scale, targetDevice, cpuParams) = this->GetParam(); std::tie(inFmts, outFmts, priority, selectedType) = cpuParams; if (selectedType.empty()) { @@ -58,12 +62,12 @@ void ScaledAttnLayerCPUTest::SetUp() { rel_threshold = 2e-2f; } selectedType = makeSelectedTypeStr(selectedType, inType); - init_input_shapes({inputShape}); + init_input_shapes(inputShapes); ov::ParameterVector inputParams; // q,k,v inputParams.push_back(std::make_shared(inType, inputDynamicShapes[0])); - inputParams.push_back(std::make_shared(inType, inputDynamicShapes[0])); - inputParams.push_back(std::make_shared(inType, inputDynamicShapes[0])); + inputParams.push_back(std::make_shared(inType, inputDynamicShapes[1])); + inputParams.push_back(std::make_shared(inType, inputDynamicShapes[1])); inputParams[0]->set_friendly_name("q"); inputParams[1]->set_friendly_name("k"); inputParams[2]->set_friendly_name("v"); @@ -77,9 +81,7 @@ void ScaledAttnLayerCPUTest::SetUp() { inputParams.back()->set_friendly_name("scale"); } else { if (has_attn) { - // attention_mask:[B, 1, 1, L0+L1] - ov::PartialShape attnShape{inputDynamicShapes[0][0], 1, 1, -1}; - inputParams.push_back(std::make_shared(inType, attnShape)); + inputParams.push_back(std::make_shared(inType, 
inputDynamicShapes[2])); inputParams.back()->set_friendly_name("attention_mask"); } if (has_scale) { @@ -106,14 +108,14 @@ void ScaledAttnLayerCPUTest::SetUp() { void ScaledAttnLayerCPUTest::generate_inputs(const std::vector& targetInputStaticShapes) { std::vector shapes(3); shapes[0] = targetInputStaticShapes[0]; - shapes[1] = targetInputStaticShapes[0]; - shapes[2] = targetInputStaticShapes[0]; + shapes[1] = targetInputStaticShapes[1]; + shapes[2] = targetInputStaticShapes[1]; if (!has_attn && has_scale) { shapes.push_back(ov::Shape{}); shapes.push_back(ov::Shape{1}); } else { if (has_attn) { - shapes.push_back(ov::Shape{targetInputStaticShapes[0][0], 1, 1, targetInputStaticShapes[0][2]}); + shapes.push_back(targetInputStaticShapes[2]); } if (has_scale) { shapes.push_back(ov::Shape{1}); diff --git a/src/plugins/intel_cpu/tests/functional/single_layer_tests/classes/scaled_attn.hpp b/src/plugins/intel_cpu/tests/functional/single_layer_tests/classes/scaled_attn.hpp index 107aac79637..0a11d159b50 100644 --- a/src/plugins/intel_cpu/tests/functional/single_layer_tests/classes/scaled_attn.hpp +++ b/src/plugins/intel_cpu/tests/functional/single_layer_tests/classes/scaled_attn.hpp @@ -13,12 +13,12 @@ using namespace ov::test; namespace CPULayerTestsDefinitions { -typedef std::tuple, // shape + bool, // is_causal + bool, // has_attn + bool, // has_scale + std::string, // targetDevice CPUSpecificParams> ScaledAttnCPUTestParams; diff --git a/src/plugins/intel_cpu/tests/functional/single_layer_tests/instances/x64/scaled_attn.cpp b/src/plugins/intel_cpu/tests/functional/single_layer_tests/instances/x64/scaled_attn.cpp index e1d063ea613..a4b993c58a7 100644 --- a/src/plugins/intel_cpu/tests/functional/single_layer_tests/instances/x64/scaled_attn.cpp +++ b/src/plugins/intel_cpu/tests/functional/single_layer_tests/instances/x64/scaled_attn.cpp @@ -14,10 +14,52 @@ namespace CPULayerTestsDefinitions { namespace ScaledAttn { const auto cpuSpec = CPUSpecificParams{{}, {}, {"ref_any"}, "ref_any"}; -const std::vector shapes{ - {ov::test::InputShape{ov::PartialShape{-1, 8, -1, 64}, - {ov::Shape{1, 8, 100, 64}, ov::Shape{1, 8, 1, 64}, ov::Shape{2, 8, 10, 64}}} - } +const std::vector> shapes{ + // normal case, shapes of q,k,v are same + { + // q shape + {ov::test::InputShape{ov::PartialShape{-1, 8, -1, 64}, + {ov::Shape{1, 8, 100, 64}, ov::Shape{1, 8, 1, 64}, ov::Shape{2, 8, 10, 64}}} + }, + // kv shape + {ov::test::InputShape{ov::PartialShape{-1, 8, -1, 64}, + {ov::Shape{1, 8, 100, 64}, ov::Shape{1, 8, 1, 64}, ov::Shape{2, 8, 10, 64}}} + }, + // attn shape: [B, 1, 1, L0+L1] + {ov::test::InputShape{ov::PartialShape{-1, 1, 1, -1}, + {ov::Shape{1, 1, 1, 100}, ov::Shape{1, 1, 1, 1}, ov::Shape{2, 1, 1, 10}}} + }, + }, + // heads number of kv is 1, attn mask: [B, H, L1, L0+L1] + { + // q shape + {ov::test::InputShape{ov::PartialShape{-1, 8, -1, 64}, + {ov::Shape{1, 8, 100, 64}, ov::Shape{1, 8, 1, 64}, ov::Shape{2, 8, 10, 64}}} + }, + // kv shape + {ov::test::InputShape{ov::PartialShape{-1, 1, -1, 64}, + {ov::Shape{1, 1, 100, 64}, ov::Shape{1, 1, 1, 64}, ov::Shape{2, 1, 10, 64}}} + }, + // attn shape + {ov::test::InputShape{ov::PartialShape{-1, 8, -1, -1}, + {ov::Shape{1, 8, 100, 100}, ov::Shape{1, 8, 1, 1}, ov::Shape{2, 8, 10, 10}}} + }, + }, + // heads number of kv is 1, attn mask: [H, L1, L0+L1] + { + // q shape + {ov::test::InputShape{ov::PartialShape{-1, 8, -1, 64}, + {ov::Shape{1, 8, 100, 64}, ov::Shape{1, 8, 1, 64}, ov::Shape{2, 8, 10, 64}}} + }, + // kv shape + {ov::test::InputShape{ov::PartialShape{-1, 1, -1, 
64}, + {ov::Shape{1, 1, 100, 64}, ov::Shape{1, 1, 1, 64}, ov::Shape{2, 1, 10, 64}}} + }, + // attn shape + {ov::test::InputShape{ov::PartialShape{8, -1, -1}, + {ov::Shape{8, 100, 100}, ov::Shape{8, 1, 1}, ov::Shape{8, 10, 10}}} + }, + }, }; const auto params = testing::Combine(testing::Values(ElementType::f32, ElementType::bf16), diff --git a/tests/model_hub_tests/torch_tests/test_timm.py b/tests/model_hub_tests/torch_tests/test_timm.py index d6dd438df5c..0749926f6ef 100644 --- a/tests/model_hub_tests/torch_tests/test_timm.py +++ b/tests/model_hub_tests/torch_tests/test_timm.py @@ -75,7 +75,8 @@ class TestTimmConvertModel(TestConvertModel): @pytest.mark.parametrize("name", ["mobilevitv2_050.cvnets_in1k", "poolformerv2_s12.sail_in1k", - "vit_base_patch8_224.augreg_in21k"]) + "vit_base_patch8_224.augreg_in21k", + "beit_base_patch16_224.in22k_ft_in22k"]) @pytest.mark.precommit def test_convert_model_precommit(self, name, ie_device): self.run(name, None, ie_device)
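
The functional change in scaled_attn.cpp above is that the kernels now read key/value through broadcast-aware `PlainTensor::at(..., true)` and size the oneDNN matmul descriptors with a separate kv head count `Hk`, while the executor reshapes rank-2 and rank-3 attention masks up to rank 4. The sketch below is a minimal NumPy reference of those semantics only, not the plugin kernel; `sdpa_reference`, the repeat-based head broadcast, and the additive (float) mask treatment are illustrative assumptions.

```python
# Minimal NumPy reference for SDPA with fewer kv heads than query heads (Hk <= H)
# and a rank-2/3/4 attention mask. Illustrative only; not the CPU plugin kernel.
import numpy as np


def sdpa_reference(q, k, v, attn_mask=None, scale=None, causal=False):
    """q: [B, H, Lq, S]; k, v: [B, Hk, Lkv, S] with H % Hk == 0."""
    B, H, Lq, S = q.shape
    Hk, Lkv = k.shape[1], k.shape[2]
    if scale is None:
        scale = 1.0 / np.sqrt(S)

    # Broadcast the kv heads up to H (Hk == 1 covers the multi-query case
    # exercised by the new test shapes above).
    k = np.repeat(k, H // Hk, axis=1)
    v = np.repeat(v, H // Hk, axis=1)

    scores = (q @ k.transpose(0, 1, 3, 2)) * scale  # [B, H, Lq, Lkv]

    if attn_mask is not None:
        # Normalize rank-2 [Lq, Lkv] and rank-3 [H, Lq, Lkv] masks to rank 4,
        # mirroring the reshape added in the executor above.
        if attn_mask.ndim == 2:
            attn_mask = attn_mask.reshape(1, 1, *attn_mask.shape)
        elif attn_mask.ndim == 3:
            attn_mask = attn_mask.reshape(1, *attn_mask.shape)
        scores = scores + attn_mask  # additive float mask, broadcast over B/H

    if causal:
        # Keep n < kv_len - q_len + m + 1, matching ncausal in the kernel above.
        keep = np.tril(np.ones((Lq, Lkv), dtype=bool), k=Lkv - Lq)
        scores = np.where(keep, scores, -np.inf)

    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v  # [B, H, Lq, S]
```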
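
As a usage sketch, the shapes below match the last new test case registered above (8 query heads, a single kv head, and a rank-3 attention mask broadcast over the batch); the random values are purely illustrative.

```python
# Sanity check of the reference with the new test-case shapes (assumed inputs).
rng = np.random.default_rng(0)
q = rng.standard_normal((2, 8, 10, 64)).astype(np.float32)
k = rng.standard_normal((2, 1, 10, 64)).astype(np.float32)
v = rng.standard_normal((2, 1, 10, 64)).astype(np.float32)
mask = rng.standard_normal((8, 10, 10)).astype(np.float32)

out = sdpa_reference(q, k, v, attn_mask=mask)
assert out.shape == (2, 8, 10, 64)
```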