[CPU] Fix ScaledDotProductAttension accuracy problem (#21217)

This commit is contained in:
Luo Cheng 2023-11-26 18:47:04 +08:00 committed by GitHub
parent 493a338ad2
commit 91660b1c05
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 99 additions and 44 deletions

View File

@ -123,7 +123,7 @@ struct MHAKernel {
if (auto_causal) if (auto_causal)
ncausal = kv_len - q_len + m + 1; ncausal = kv_len - q_len + m + 1;
for (size_t n = 0; n < ncausal; n++) { for (size_t n = 0; n < ncausal; n++) {
auto* k = &present_key.at({b, h, n, 0}); auto* k = &present_key.at({b, h, n, 0}, true);
attn_score[n] = dot_product(q, k, head_size, k_stride_s) * d_scale; attn_score[n] = dot_product(q, k, head_size, k_stride_s) * d_scale;
// apply alibi tensor // apply alibi tensor
@ -154,7 +154,7 @@ struct MHAKernel {
// linearly combine value // linearly combine value
word_vec.assign(head_size, 0.0f); word_vec.assign(head_size, 0.0f);
for (size_t n = 0; n < ncausal; n++) { for (size_t n = 0; n < ncausal; n++) {
auto* v = &present_value.at({b, h, n, 0}); auto* v = &present_value.at({b, h, n, 0}, true);
accumulate(word_vec.data(), v, head_size, attn_score[n]); accumulate(word_vec.data(), v, head_size, attn_score[n]);
} }
@ -183,7 +183,7 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
using tag = dnnl::memory::format_tag; using tag = dnnl::memory::format_tag;
using dt = dnnl::memory::data_type; using dt = dnnl::memory::data_type;
void prepare_prim(dnnl::stream strm, size_t B, size_t H, size_t q_len, size_t kv_len, size_t S, bool has_out_transpose) { void prepare_prim(dnnl::stream strm, size_t B, size_t H, size_t Hk, size_t q_len, size_t kv_len, size_t S, bool has_out_transpose) {
auto make_dnnl_dims = [](const std::vector<size_t>& dims) { auto make_dnnl_dims = [](const std::vector<size_t>& dims) {
dnnl::memory::dims dnnl_dims(dims.size()); dnnl::memory::dims dnnl_dims(dims.size());
for (size_t i = 0; i < dims.size(); i++) for (size_t i = 0; i < dims.size(); i++)
@ -192,7 +192,7 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
}; };
auto qkv_dt = precision_of<T>::value == ov::element::f32 ? dt::f32 : dt::bf16; auto qkv_dt = precision_of<T>::value == ov::element::f32 ? dt::f32 : dt::bf16;
dnnl::memory::desc cur_q_md(make_dnnl_dims({B, H, q_len, S}), qkv_dt, tag::abcd); dnnl::memory::desc cur_q_md(make_dnnl_dims({B, H, q_len, S}), qkv_dt, tag::abcd);
dnnl::memory::desc cur_k_md(make_dnnl_dims({B, H, kv_len, S}), qkv_dt, tag::abcd); dnnl::memory::desc cur_k_md(make_dnnl_dims({B, Hk, kv_len, S}), qkv_dt, tag::abcd);
if (cur_q_md == q_md && cur_k_md == k_md) if (cur_q_md == q_md && cur_k_md == k_md)
return; return;
@ -204,7 +204,7 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
qk_prim = dnnl::matmul(qk_pd); qk_prim = dnnl::matmul(qk_pd);
weight_md = dnnl::memory::desc(make_dnnl_dims({B, H, q_len, kv_len}), qkv_dt, tag::abcd); weight_md = dnnl::memory::desc(make_dnnl_dims({B, H, q_len, kv_len}), qkv_dt, tag::abcd);
v_md = dnnl::memory::desc(make_dnnl_dims({B, H, kv_len, S}), qkv_dt, tag::abcd); v_md = dnnl::memory::desc(make_dnnl_dims({B, Hk, kv_len, S}), qkv_dt, tag::abcd);
out_md = dnnl::memory::desc(make_dnnl_dims({B, H, q_len, S}), qkv_dt, tag::abcd); out_md = dnnl::memory::desc(make_dnnl_dims({B, H, q_len, S}), qkv_dt, tag::abcd);
if (has_out_transpose) if (has_out_transpose)
out_md = out_md.permute_axes({0, 2, 1, 3}); out_md = out_md.permute_axes({0, 2, 1, 3});
@ -259,12 +259,13 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
auto H = query.size(1); auto H = query.size(1);
auto q_len = query.size(2); auto q_len = query.size(2);
auto head_size = query.size(3); auto head_size = query.size(3);
auto Hk = present_key.size(1);
auto kv_len = present_key.size(2); auto kv_len = present_key.size(2);
if (d_scale == 0.0f) if (d_scale == 0.0f)
d_scale = 1.0f / sqrt(head_size); d_scale = 1.0f / sqrt(head_size);
prepare_prim(strm, B, H, q_len, kv_len, head_size, has_out_transpose); prepare_prim(strm, B, H, Hk, q_len, kv_len, head_size, has_out_transpose);
exec_qk(strm, query, present_key); exec_qk(strm, query, present_key);
PlainTensor<float> score; PlainTensor<float> score;
@ -495,7 +496,7 @@ struct MHASingleToken {
std::vector<float*> cs(q_len); std::vector<float*> cs(q_len);
for (size_t pq = 0; pq < q_len; pq++) { for (size_t pq = 0; pq < q_len; pq++) {
as[pq] = &query.at({b, h, pq, 0}); as[pq] = &query.at({b, h, pq, 0});
bs[pq] = &present_key.at({b_kv, h, pk, 0}); bs[pq] = &present_key.at({b_kv, h, pk, 0}, true);
cs[pq] = &m_attn_w.at({b, h, pq, pk}); cs[pq] = &m_attn_w.at({b, h, pq, pk});
} }
attn_dot_products(reinterpret_cast<void**>(as.data()), attn_dot_products(reinterpret_cast<void**>(as.data()),
@ -543,7 +544,7 @@ struct MHASingleToken {
size_t idx = 0; size_t idx = 0;
for (size_t iwork = start; iwork < end; ++iwork) { for (size_t iwork = start; iwork < end; ++iwork) {
auto b_kv = beams ? beams.at({b, pv}) : b; auto b_kv = beams ? beams.at({b, pv}) : b;
auto* v = &present_value.at({b_kv, h, pv, 0}); auto* v = &present_value.at({b_kv, h, pv, 0}, true);
for (size_t pq = 0; pq < q_len; pq++) { for (size_t pq = 0; pq < q_len; pq++) {
outs[idx] = &m_temp.at({ithr, b, pq, h, 0}); outs[idx] = &m_temp.at({ithr, b, pq, h, 0});
weights[idx] = m_attn_w.at({b, h, pq, pv}); weights[idx] = m_attn_w.at({b, h, pq, pv});
@ -636,8 +637,8 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
PlainTensor<T> present_key, present_value; PlainTensor<T> present_key, present_value;
q_input.assert_dims({B, H, L1, S}); q_input.assert_dims({B, H, L1, S});
k_input.assert_dims({B, H, L0 + L1, S}); k_input.assert_dims({B, 0, L0 + L1, S}, true);
v_input.assert_dims({B, H, L0 + L1, S}); v_input.assert_dims({B, 0, L0 + L1, S}, true);
m_query_emb = q_input; m_query_emb = q_input;
present_key = k_input; present_key = k_input;
present_value = v_input; present_value = v_input;
@ -657,9 +658,11 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
// no attn_mask but has scale, there is a 1-d fake attn_mask // no attn_mask but has scale, there is a 1-d fake attn_mask
if (input_num > 3 && attn_mask.m_rank > 1) { if (input_num > 3 && attn_mask.m_rank > 1) {
assert(attn_mask); assert(attn_mask);
auto num = std::accumulate(attn_mask.m_dims, attn_mask.m_dims + attn_mask.m_rank, size_t{1}, std::multiplies<size_t>()); // spec requires at least 3, but torch sl test does use rank 2
num /= B * (L0 + L1); if (attn_mask.m_rank == 2)
attn_mask = attn_mask.reshape({B, 1, num, L0 + L1}); attn_mask = attn_mask.reshape({1, 1, attn_mask.m_dims[0], attn_mask.m_dims[1]});
else if (attn_mask.m_rank == 3)
attn_mask = attn_mask.reshape({1, attn_mask.m_dims[0], attn_mask.m_dims[1], attn_mask.m_dims[2]});
auto_causal = false; auto_causal = false;
use_attn_mask = true; use_attn_mask = true;
} else { } else {
@ -758,12 +761,19 @@ bool ScaledDotProductAttention::isSupportedOperation(const std::shared_ptr<const
errorMessage = "Only ScaledDotProductAttention operation are supported"; errorMessage = "Only ScaledDotProductAttention operation are supported";
return false; return false;
} }
// expect shape: [B, H, L, S] // expect shape of q: [B, H, L, S]
const auto inRank = op->get_input_partial_shape(0).size(); auto inRank = op->get_input_partial_shape(0).size();
if (inRank != 4u) { if (inRank != 4u) {
errorMessage = "Doesn't support 'data' input with rank: " + std::to_string(inRank); errorMessage = "Doesn't support 'data' input with rank: " + std::to_string(inRank);
return false; return false;
} }
if (op->get_input_size() > 3) {
inRank = op->get_input_partial_shape(3).size();
if (inRank > 4u) {
errorMessage = "Doesn't support 'attention mask' with rank: " + std::to_string(inRank);
return false;
}
}
// using mha should be better for static shapes // using mha should be better for static shapes
if (!op->is_dynamic()) { if (!op->is_dynamic()) {
errorMessage = "Only run in dynamic mode"; errorMessage = "Only run in dynamic mode";

View File

@ -18,21 +18,25 @@ namespace CPULayerTestsDefinitions {
std::string ScaledAttnLayerCPUTest::getTestCaseName(const testing::TestParamInfo<ScaledAttnCPUTestParams>& obj) { std::string ScaledAttnLayerCPUTest::getTestCaseName(const testing::TestParamInfo<ScaledAttnCPUTestParams>& obj) {
CPUSpecificParams cpuParams; CPUSpecificParams cpuParams;
ElementType inType; ElementType inType;
InputShape inputShape; std::vector<InputShape> inputShapes;
bool is_causal; bool is_causal;
bool has_attn; bool has_attn;
bool has_scale; bool has_scale;
std::string targetDevice; std::string targetDevice;
std::tie(inType, inputShape, is_causal, has_attn, has_scale, targetDevice, cpuParams) = obj.param; std::tie(inType, inputShapes, is_causal, has_attn, has_scale, targetDevice, cpuParams) = obj.param;
std::ostringstream result; std::ostringstream result;
result << "netPRC=" << inType << "_"; result << "netPRC=" << inType << "_";
result << "IS=" << ov::test::utils::partialShape2str({inputShape.first}) << "_"; result << "IS=";
for (const auto& inputShape : inputShapes) {
result << ov::test::utils::partialShape2str({inputShape.first}) << "_";
}
result << "TS="; result << "TS=";
for (const auto& shape : inputShape.second) { for (const auto& shapes : inputShapes) {
result << "("; for (const auto& shape : shapes.second) {
result << ov::test::utils::vec2str(shape); result << ov::test::utils::vec2str(shape);
result << ")_"; result << "_";
}
} }
result << "is_causal=" << is_causal << "_"; result << "is_causal=" << is_causal << "_";
result << "has_attn=" << has_attn << "_"; result << "has_attn=" << has_attn << "_";
@ -46,8 +50,8 @@ std::string ScaledAttnLayerCPUTest::getTestCaseName(const testing::TestParamInfo
void ScaledAttnLayerCPUTest::SetUp() { void ScaledAttnLayerCPUTest::SetUp() {
ElementType inType; ElementType inType;
CPUSpecificParams cpuParams; CPUSpecificParams cpuParams;
InputShape inputShape; std::vector<InputShape> inputShapes;
std::tie(inType, inputShape, is_causal, has_attn, has_scale, targetDevice, cpuParams) = this->GetParam(); std::tie(inType, inputShapes, is_causal, has_attn, has_scale, targetDevice, cpuParams) = this->GetParam();
std::tie(inFmts, outFmts, priority, selectedType) = cpuParams; std::tie(inFmts, outFmts, priority, selectedType) = cpuParams;
if (selectedType.empty()) { if (selectedType.empty()) {
@ -58,12 +62,12 @@ void ScaledAttnLayerCPUTest::SetUp() {
rel_threshold = 2e-2f; rel_threshold = 2e-2f;
} }
selectedType = makeSelectedTypeStr(selectedType, inType); selectedType = makeSelectedTypeStr(selectedType, inType);
init_input_shapes({inputShape}); init_input_shapes(inputShapes);
ov::ParameterVector inputParams; ov::ParameterVector inputParams;
// q,k,v // q,k,v
inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[0])); inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[0]));
inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[0])); inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[1]));
inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[0])); inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[1]));
inputParams[0]->set_friendly_name("q"); inputParams[0]->set_friendly_name("q");
inputParams[1]->set_friendly_name("k"); inputParams[1]->set_friendly_name("k");
inputParams[2]->set_friendly_name("v"); inputParams[2]->set_friendly_name("v");
@ -77,9 +81,7 @@ void ScaledAttnLayerCPUTest::SetUp() {
inputParams.back()->set_friendly_name("scale"); inputParams.back()->set_friendly_name("scale");
} else { } else {
if (has_attn) { if (has_attn) {
// attention_mask[B, 1, 1, L0+L1] inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[2]));
ov::PartialShape attnShape{inputDynamicShapes[0][0], 1, 1, -1};
inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, attnShape));
inputParams.back()->set_friendly_name("attention_mask"); inputParams.back()->set_friendly_name("attention_mask");
} }
if (has_scale) { if (has_scale) {
@ -106,14 +108,14 @@ void ScaledAttnLayerCPUTest::SetUp() {
void ScaledAttnLayerCPUTest::generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) { void ScaledAttnLayerCPUTest::generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) {
std::vector<ov::Shape> shapes(3); std::vector<ov::Shape> shapes(3);
shapes[0] = targetInputStaticShapes[0]; shapes[0] = targetInputStaticShapes[0];
shapes[1] = targetInputStaticShapes[0]; shapes[1] = targetInputStaticShapes[1];
shapes[2] = targetInputStaticShapes[0]; shapes[2] = targetInputStaticShapes[1];
if (!has_attn && has_scale) { if (!has_attn && has_scale) {
shapes.push_back(ov::Shape{}); shapes.push_back(ov::Shape{});
shapes.push_back(ov::Shape{1}); shapes.push_back(ov::Shape{1});
} else { } else {
if (has_attn) { if (has_attn) {
shapes.push_back(ov::Shape{targetInputStaticShapes[0][0], 1, 1, targetInputStaticShapes[0][2]}); shapes.push_back(targetInputStaticShapes[2]);
} }
if (has_scale) { if (has_scale) {
shapes.push_back(ov::Shape{1}); shapes.push_back(ov::Shape{1});

View File

@ -13,12 +13,12 @@ using namespace ov::test;
namespace CPULayerTestsDefinitions { namespace CPULayerTestsDefinitions {
typedef std::tuple<ElementType, // netPrecision typedef std::tuple<ElementType, // netPrecision
InputShape, // shape std::vector<InputShape>, // shape
bool, // is_causal bool, // is_causal
bool, // has_attn bool, // has_attn
bool, // has_scale bool, // has_scale
std::string, // targetDevice std::string, // targetDevice
CPUSpecificParams> CPUSpecificParams>
ScaledAttnCPUTestParams; ScaledAttnCPUTestParams;

View File

@ -14,10 +14,52 @@ namespace CPULayerTestsDefinitions {
namespace ScaledAttn { namespace ScaledAttn {
const auto cpuSpec = CPUSpecificParams{{}, {}, {"ref_any"}, "ref_any"}; const auto cpuSpec = CPUSpecificParams{{}, {}, {"ref_any"}, "ref_any"};
const std::vector<InputShape> shapes{ const std::vector<std::vector<InputShape>> shapes{
{ov::test::InputShape{ov::PartialShape{-1, 8, -1, 64}, // normal case, shapes of q,k,v are same
{ov::Shape{1, 8, 100, 64}, ov::Shape{1, 8, 1, 64}, ov::Shape{2, 8, 10, 64}}} {
} // q shape
{ov::test::InputShape{ov::PartialShape{-1, 8, -1, 64},
{ov::Shape{1, 8, 100, 64}, ov::Shape{1, 8, 1, 64}, ov::Shape{2, 8, 10, 64}}}
},
// kv shape
{ov::test::InputShape{ov::PartialShape{-1, 8, -1, 64},
{ov::Shape{1, 8, 100, 64}, ov::Shape{1, 8, 1, 64}, ov::Shape{2, 8, 10, 64}}}
},
// attn shape: [B, 1, 1, L0+L1]
{ov::test::InputShape{ov::PartialShape{-1, 1, 1, -1},
{ov::Shape{1, 1, 1, 100}, ov::Shape{1, 1, 1, 1}, ov::Shape{2, 1, 1, 10}}}
},
},
// heads number of kv is 1, attn mask: [B, H, L1, L0+L1]
{
// q shape
{ov::test::InputShape{ov::PartialShape{-1, 8, -1, 64},
{ov::Shape{1, 8, 100, 64}, ov::Shape{1, 8, 1, 64}, ov::Shape{2, 8, 10, 64}}}
},
// kv shape
{ov::test::InputShape{ov::PartialShape{-1, 1, -1, 64},
{ov::Shape{1, 1, 100, 64}, ov::Shape{1, 1, 1, 64}, ov::Shape{2, 1, 10, 64}}}
},
// attn shape
{ov::test::InputShape{ov::PartialShape{-1, 8, -1, -1},
{ov::Shape{1, 8, 100, 100}, ov::Shape{1, 8, 1, 1}, ov::Shape{2, 8, 10, 10}}}
},
},
// heads number of kv is 1, attn mask: [H, L1, L0+L1]
{
// q shape
{ov::test::InputShape{ov::PartialShape{-1, 8, -1, 64},
{ov::Shape{1, 8, 100, 64}, ov::Shape{1, 8, 1, 64}, ov::Shape{2, 8, 10, 64}}}
},
// kv shape
{ov::test::InputShape{ov::PartialShape{-1, 1, -1, 64},
{ov::Shape{1, 1, 100, 64}, ov::Shape{1, 1, 1, 64}, ov::Shape{2, 1, 10, 64}}}
},
// attn shape
{ov::test::InputShape{ov::PartialShape{8, -1, -1},
{ov::Shape{8, 100, 100}, ov::Shape{8, 1, 1}, ov::Shape{8, 10, 10}}}
},
},
}; };
const auto params = testing::Combine(testing::Values(ElementType::f32, ElementType::bf16), const auto params = testing::Combine(testing::Values(ElementType::f32, ElementType::bf16),

View File

@ -75,7 +75,8 @@ class TestTimmConvertModel(TestConvertModel):
@pytest.mark.parametrize("name", ["mobilevitv2_050.cvnets_in1k", @pytest.mark.parametrize("name", ["mobilevitv2_050.cvnets_in1k",
"poolformerv2_s12.sail_in1k", "poolformerv2_s12.sail_in1k",
"vit_base_patch8_224.augreg_in21k"]) "vit_base_patch8_224.augreg_in21k",
"beit_base_patch16_224.in22k_ft_in22k"])
@pytest.mark.precommit @pytest.mark.precommit
def test_convert_model_precommit(self, name, ie_device): def test_convert_model_precommit(self, name, ie_device):
self.run(name, None, ie_device) self.run(name, None, ie_device)