[CPU] Fix ScaledDotProductAttention accuracy problem (#21217)

Luo Cheng 2023-11-26 18:47:04 +08:00 committed by GitHub
parent 493a338ad2
commit 91660b1c05
5 changed files with 99 additions and 44 deletions


@@ -123,7 +123,7 @@ struct MHAKernel {
if (auto_causal)
ncausal = kv_len - q_len + m + 1;
for (size_t n = 0; n < ncausal; n++) {
auto* k = &present_key.at({b, h, n, 0});
auto* k = &present_key.at({b, h, n, 0}, true);
attn_score[n] = dot_product(q, k, head_size, k_stride_s) * d_scale;
// apply alibi tensor
@@ -154,7 +154,7 @@ struct MHAKernel {
// linearly combine value
word_vec.assign(head_size, 0.0f);
for (size_t n = 0; n < ncausal; n++) {
auto* v = &present_value.at({b, h, n, 0});
auto* v = &present_value.at({b, h, n, 0}, true);
accumulate(word_vec.data(), v, head_size, attn_score[n]);
}
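
The extra boolean passed to present_key.at() and present_value.at() above looks like a broadcast switch: when the stored key/value tensor holds only one head, indexing with the query head still resolves to that single head. A minimal standalone sketch of that idea (SimpleTensor4D and its at() are illustrative, not the plugin's PlainTensor):

#include <cstddef>
#include <initializer_list>
#include <vector>

// Illustrative 4-D tensor: when allow_broadcast is set, an index into a size-1
// axis is clamped to 0, so at({b, h, n, 0}, true) stays valid even if the
// tensor stores a single kv head.
struct SimpleTensor4D {
    std::vector<float> data;
    size_t dims[4];     // e.g. {B, H, L, S}
    size_t strides[4];  // element strides per axis

    float* at(std::initializer_list<size_t> idx, bool allow_broadcast = false) {
        size_t off = 0;
        size_t i = 0;
        for (size_t d : idx) {
            if (allow_broadcast && dims[i] == 1)
                d = 0;  // broadcast this axis
            off += d * strides[i];
            ++i;
        }
        return data.data() + off;
    }
};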
@@ -183,7 +183,7 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
using tag = dnnl::memory::format_tag;
using dt = dnnl::memory::data_type;
void prepare_prim(dnnl::stream strm, size_t B, size_t H, size_t q_len, size_t kv_len, size_t S, bool has_out_transpose) {
void prepare_prim(dnnl::stream strm, size_t B, size_t H, size_t Hk, size_t q_len, size_t kv_len, size_t S, bool has_out_transpose) {
auto make_dnnl_dims = [](const std::vector<size_t>& dims) {
dnnl::memory::dims dnnl_dims(dims.size());
for (size_t i = 0; i < dims.size(); i++)
@@ -192,7 +192,7 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
};
auto qkv_dt = precision_of<T>::value == ov::element::f32 ? dt::f32 : dt::bf16;
dnnl::memory::desc cur_q_md(make_dnnl_dims({B, H, q_len, S}), qkv_dt, tag::abcd);
dnnl::memory::desc cur_k_md(make_dnnl_dims({B, H, kv_len, S}), qkv_dt, tag::abcd);
dnnl::memory::desc cur_k_md(make_dnnl_dims({B, Hk, kv_len, S}), qkv_dt, tag::abcd);
if (cur_q_md == q_md && cur_k_md == k_md)
return;
@@ -204,7 +204,7 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
qk_prim = dnnl::matmul(qk_pd);
weight_md = dnnl::memory::desc(make_dnnl_dims({B, H, q_len, kv_len}), qkv_dt, tag::abcd);
v_md = dnnl::memory::desc(make_dnnl_dims({B, H, kv_len, S}), qkv_dt, tag::abcd);
v_md = dnnl::memory::desc(make_dnnl_dims({B, Hk, kv_len, S}), qkv_dt, tag::abcd);
out_md = dnnl::memory::desc(make_dnnl_dims({B, H, q_len, S}), qkv_dt, tag::abcd);
if (has_out_transpose)
out_md = out_md.permute_axes({0, 2, 1, 3});
@@ -259,12 +259,13 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
auto H = query.size(1);
auto q_len = query.size(2);
auto head_size = query.size(3);
auto Hk = present_key.size(1);
auto kv_len = present_key.size(2);
if (d_scale == 0.0f)
d_scale = 1.0f / sqrt(head_size);
prepare_prim(strm, B, H, q_len, kv_len, head_size, has_out_transpose);
prepare_prim(strm, B, H, Hk, q_len, kv_len, head_size, has_out_transpose);
exec_qk(strm, query, present_key);
PlainTensor<float> score;
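
prepare_prim now receives the key/value head count Hk and uses it for the K and V memory descriptors, while Q, the attention weights, and the output keep H. A small sketch of building such descriptors with oneDNN (make_sdpa_descs is a made-up helper; the header path and the assumption that oneDNN matmul broadcasts a size-1 batch dimension such as Hk == 1 against H should be checked against the oneDNN version in use):

#include <cstddef>
#include <initializer_list>
#include <oneapi/dnnl/dnnl.hpp>  // header name may differ by oneDNN version

// Sketch only: descriptors analogous to the ones above, with the K/V head
// count Hk decoupled from the query head count H.
static void make_sdpa_descs(size_t B, size_t H, size_t Hk,
                            size_t q_len, size_t kv_len, size_t S) {
    using tag = dnnl::memory::format_tag;
    using dt  = dnnl::memory::data_type;
    auto d = [](std::initializer_list<size_t> v) {
        return dnnl::memory::dims(v.begin(), v.end());
    };
    dnnl::memory::desc q_md(d({B, H,  q_len,  S}),      dt::f32, tag::abcd);
    dnnl::memory::desc k_md(d({B, Hk, kv_len, S}),      dt::f32, tag::abcd);  // kv heads
    dnnl::memory::desc w_md(d({B, H,  q_len,  kv_len}), dt::f32, tag::abcd);
    dnnl::memory::desc v_md(d({B, Hk, kv_len, S}),      dt::f32, tag::abcd);  // kv heads
    dnnl::memory::desc o_md(d({B, H,  q_len,  S}),      dt::f32, tag::abcd);
    (void)q_md; (void)k_md; (void)w_md; (void)v_md; (void)o_md;
}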
@@ -495,7 +496,7 @@ struct MHASingleToken {
std::vector<float*> cs(q_len);
for (size_t pq = 0; pq < q_len; pq++) {
as[pq] = &query.at({b, h, pq, 0});
bs[pq] = &present_key.at({b_kv, h, pk, 0});
bs[pq] = &present_key.at({b_kv, h, pk, 0}, true);
cs[pq] = &m_attn_w.at({b, h, pq, pk});
}
attn_dot_products(reinterpret_cast<void**>(as.data()),
@@ -543,7 +544,7 @@ struct MHASingleToken {
size_t idx = 0;
for (size_t iwork = start; iwork < end; ++iwork) {
auto b_kv = beams ? beams.at({b, pv}) : b;
auto* v = &present_value.at({b_kv, h, pv, 0});
auto* v = &present_value.at({b_kv, h, pv, 0}, true);
for (size_t pq = 0; pq < q_len; pq++) {
outs[idx] = &m_temp.at({ithr, b, pq, h, 0});
weights[idx] = m_attn_w.at({b, h, pq, pv});
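
The single-token path gathers each value row through the beam table and reads it with the same broadcast flag before scaling it by its softmax weight. A self-contained sketch of that accumulation, assuming a plain contiguous [Hk, kv_len, S] value buffer instead of the plugin's PlainTensor (weighted_value_sum is a hypothetical name):

#include <cstddef>

// Weighted sum of value rows for one (batch, head, query) triple. When only one
// kv head is stored (Hk == 1), every query head h reads head 0, which is the
// same broadcast the extra boolean requests above.
void weighted_value_sum(const float* values,   // [Hk, kv_len, S], contiguous
                        const float* weights,  // [kv_len] softmax scores
                        float* out,            // [S], accumulated output
                        size_t Hk, size_t h, size_t kv_len, size_t S) {
    const size_t hk = (Hk == 1) ? 0 : h;
    for (size_t i = 0; i < S; i++)
        out[i] = 0.0f;
    for (size_t pv = 0; pv < kv_len; pv++) {
        const float* v = values + (hk * kv_len + pv) * S;
        for (size_t i = 0; i < S; i++)
            out[i] += weights[pv] * v[i];
    }
}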
@@ -636,8 +637,8 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
PlainTensor<T> present_key, present_value;
q_input.assert_dims({B, H, L1, S});
k_input.assert_dims({B, H, L0 + L1, S});
v_input.assert_dims({B, H, L0 + L1, S});
k_input.assert_dims({B, 0, L0 + L1, S}, true);
v_input.assert_dims({B, 0, L0 + L1, S}, true);
m_query_emb = q_input;
present_key = k_input;
present_value = v_input;
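
The relaxed assert_dims calls pass 0 for the head axis together with an extra boolean, which presumably means "accept any size on this axis", so K and V may carry fewer heads than Q. A hypothetical helper with that wildcard behaviour (assert_dims_like is illustrative, not the plugin's API):

#include <cassert>
#include <cstddef>
#include <vector>

// Dimension check where an expected value of 0 acts as a wildcard.
void assert_dims_like(const std::vector<size_t>& actual,
                      const std::vector<size_t>& expected) {
    assert(actual.size() == expected.size());
    for (size_t i = 0; i < actual.size(); i++)
        assert(expected[i] == 0 || actual[i] == expected[i]);
}

// Usage mirroring the checks above:
//   assert_dims_like(k_dims, {B, 0, L0 + L1, S});  // any kv head count
//   assert_dims_like(v_dims, {B, 0, L0 + L1, S});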
@@ -657,9 +658,11 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
// no attn_mask but has scale, there is a 1-d fake attn_mask
if (input_num > 3 && attn_mask.m_rank > 1) {
assert(attn_mask);
auto num = std::accumulate(attn_mask.m_dims, attn_mask.m_dims + attn_mask.m_rank, size_t{1}, std::multiplies<size_t>());
num /= B * (L0 + L1);
attn_mask = attn_mask.reshape({B, 1, num, L0 + L1});
// spec requires at least 3, but torch sl test does use rank 2
if (attn_mask.m_rank == 2)
attn_mask = attn_mask.reshape({1, 1, attn_mask.m_dims[0], attn_mask.m_dims[1]});
else if (attn_mask.m_rank == 3)
attn_mask = attn_mask.reshape({1, attn_mask.m_dims[0], attn_mask.m_dims[1], attn_mask.m_dims[2]});
auto_causal = false;
use_attn_mask = true;
} else {
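
The new branch pads lower-rank attention masks to rank 4 by prepending size-1 axes: rank 2 becomes {1, 1, d0, d1} and rank 3 becomes {1, d0, d1, d2}, so the kernels can always index the mask as [B, H, L1, L0+L1] with broadcasting. A standalone sketch of that padding (Shape and to_rank4_mask_shape are illustrative stand-ins for the plugin's tensor types):

#include <cstddef>
#include <vector>

using Shape = std::vector<size_t>;  // illustrative alias

// Pad a rank-2 or rank-3 mask shape to rank 4; the prepended axes get size 1
// and broadcast against B and H later on.
Shape to_rank4_mask_shape(const Shape& dims) {
    Shape out(4, 1);
    for (size_t i = 0; i < dims.size(); i++)
        out[4 - dims.size() + i] = dims[i];
    return out;
}

// to_rank4_mask_shape({L1, L0 + L1})    -> {1, 1, L1, L0 + L1}
// to_rank4_mask_shape({H, L1, L0 + L1}) -> {1, H, L1, L0 + L1}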
@@ -758,12 +761,19 @@ bool ScaledDotProductAttention::isSupportedOperation(const std::shared_ptr<const
errorMessage = "Only ScaledDotProductAttention operation are supported";
return false;
}
// expect shape: [B, H, L, S]
const auto inRank = op->get_input_partial_shape(0).size();
// expect shape of q: [B, H, L, S]
auto inRank = op->get_input_partial_shape(0).size();
if (inRank != 4u) {
errorMessage = "Doesn't support 'data' input with rank: " + std::to_string(inRank);
return false;
}
if (op->get_input_size() > 3) {
inRank = op->get_input_partial_shape(3).size();
if (inRank > 4u) {
errorMessage = "Doesn't support 'attention mask' with rank: " + std::to_string(inRank);
return false;
}
}
// using mha should be better for static shapes
if (!op->is_dynamic()) {
errorMessage = "Only run in dynamic mode";


@@ -18,21 +18,25 @@ namespace CPULayerTestsDefinitions {
std::string ScaledAttnLayerCPUTest::getTestCaseName(const testing::TestParamInfo<ScaledAttnCPUTestParams>& obj) {
CPUSpecificParams cpuParams;
ElementType inType;
InputShape inputShape;
std::vector<InputShape> inputShapes;
bool is_causal;
bool has_attn;
bool has_scale;
std::string targetDevice;
std::tie(inType, inputShape, is_causal, has_attn, has_scale, targetDevice, cpuParams) = obj.param;
std::tie(inType, inputShapes, is_causal, has_attn, has_scale, targetDevice, cpuParams) = obj.param;
std::ostringstream result;
result << "netPRC=" << inType << "_";
result << "IS=" << ov::test::utils::partialShape2str({inputShape.first}) << "_";
result << "IS=";
for (const auto& inputShape : inputShapes) {
result << ov::test::utils::partialShape2str({inputShape.first}) << "_";
}
result << "TS=";
for (const auto& shape : inputShape.second) {
result << "(";
result << ov::test::utils::vec2str(shape);
result << ")_";
for (const auto& shapes : inputShapes) {
for (const auto& shape : shapes.second) {
result << ov::test::utils::vec2str(shape);
result << "_";
}
}
result << "is_causal=" << is_causal << "_";
result << "has_attn=" << has_attn << "_";
@@ -46,8 +50,8 @@ std::string ScaledAttnLayerCPUTest::getTestCaseName(const testing::TestParamInfo
void ScaledAttnLayerCPUTest::SetUp() {
ElementType inType;
CPUSpecificParams cpuParams;
InputShape inputShape;
std::tie(inType, inputShape, is_causal, has_attn, has_scale, targetDevice, cpuParams) = this->GetParam();
std::vector<InputShape> inputShapes;
std::tie(inType, inputShapes, is_causal, has_attn, has_scale, targetDevice, cpuParams) = this->GetParam();
std::tie(inFmts, outFmts, priority, selectedType) = cpuParams;
if (selectedType.empty()) {
@@ -58,12 +62,12 @@ void ScaledAttnLayerCPUTest::SetUp() {
rel_threshold = 2e-2f;
}
selectedType = makeSelectedTypeStr(selectedType, inType);
init_input_shapes({inputShape});
init_input_shapes(inputShapes);
ov::ParameterVector inputParams;
// q,k,v
inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[0]));
inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[0]));
inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[0]));
inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[1]));
inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[1]));
inputParams[0]->set_friendly_name("q");
inputParams[1]->set_friendly_name("k");
inputParams[2]->set_friendly_name("v");
@@ -77,9 +81,7 @@ void ScaledAttnLayerCPUTest::SetUp() {
inputParams.back()->set_friendly_name("scale");
} else {
if (has_attn) {
// attention_mask[B, 1, 1, L0+L1]
ov::PartialShape attnShape{inputDynamicShapes[0][0], 1, 1, -1};
inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, attnShape));
inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[2]));
inputParams.back()->set_friendly_name("attention_mask");
}
if (has_scale) {
@@ -106,14 +108,14 @@ void ScaledAttnLayerCPUTest::SetUp() {
void ScaledAttnLayerCPUTest::generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) {
std::vector<ov::Shape> shapes(3);
shapes[0] = targetInputStaticShapes[0];
shapes[1] = targetInputStaticShapes[0];
shapes[2] = targetInputStaticShapes[0];
shapes[1] = targetInputStaticShapes[1];
shapes[2] = targetInputStaticShapes[1];
if (!has_attn && has_scale) {
shapes.push_back(ov::Shape{});
shapes.push_back(ov::Shape{1});
} else {
if (has_attn) {
shapes.push_back(ov::Shape{targetInputStaticShapes[0][0], 1, 1, targetInputStaticShapes[0][2]});
shapes.push_back(targetInputStaticShapes[2]);
}
if (has_scale) {
shapes.push_back(ov::Shape{1});


@@ -13,12 +13,12 @@ using namespace ov::test;
namespace CPULayerTestsDefinitions {
typedef std::tuple<ElementType, // netPrecision
InputShape, // shape
bool, // is_causal
bool, // has_attn
bool, // has_scale
std::string, // targetDevice
typedef std::tuple<ElementType, // netPrecision
std::vector<InputShape>, // shape
bool, // is_causal
bool, // has_attn
bool, // has_scale
std::string, // targetDevice
CPUSpecificParams>
ScaledAttnCPUTestParams;


@@ -14,10 +14,52 @@ namespace CPULayerTestsDefinitions {
namespace ScaledAttn {
const auto cpuSpec = CPUSpecificParams{{}, {}, {"ref_any"}, "ref_any"};
const std::vector<InputShape> shapes{
{ov::test::InputShape{ov::PartialShape{-1, 8, -1, 64},
{ov::Shape{1, 8, 100, 64}, ov::Shape{1, 8, 1, 64}, ov::Shape{2, 8, 10, 64}}}
}
const std::vector<std::vector<InputShape>> shapes{
// normal case, shapes of q,k,v are same
{
// q shape
{ov::test::InputShape{ov::PartialShape{-1, 8, -1, 64},
{ov::Shape{1, 8, 100, 64}, ov::Shape{1, 8, 1, 64}, ov::Shape{2, 8, 10, 64}}}
},
// kv shape
{ov::test::InputShape{ov::PartialShape{-1, 8, -1, 64},
{ov::Shape{1, 8, 100, 64}, ov::Shape{1, 8, 1, 64}, ov::Shape{2, 8, 10, 64}}}
},
// attn shape: [B, 1, 1, L0+L1]
{ov::test::InputShape{ov::PartialShape{-1, 1, 1, -1},
{ov::Shape{1, 1, 1, 100}, ov::Shape{1, 1, 1, 1}, ov::Shape{2, 1, 1, 10}}}
},
},
// heads number of kv is 1, attn mask: [B, H, L1, L0+L1]
{
// q shape
{ov::test::InputShape{ov::PartialShape{-1, 8, -1, 64},
{ov::Shape{1, 8, 100, 64}, ov::Shape{1, 8, 1, 64}, ov::Shape{2, 8, 10, 64}}}
},
// kv shape
{ov::test::InputShape{ov::PartialShape{-1, 1, -1, 64},
{ov::Shape{1, 1, 100, 64}, ov::Shape{1, 1, 1, 64}, ov::Shape{2, 1, 10, 64}}}
},
// attn shape
{ov::test::InputShape{ov::PartialShape{-1, 8, -1, -1},
{ov::Shape{1, 8, 100, 100}, ov::Shape{1, 8, 1, 1}, ov::Shape{2, 8, 10, 10}}}
},
},
// heads number of kv is 1, attn mask: [H, L1, L0+L1]
{
// q shape
{ov::test::InputShape{ov::PartialShape{-1, 8, -1, 64},
{ov::Shape{1, 8, 100, 64}, ov::Shape{1, 8, 1, 64}, ov::Shape{2, 8, 10, 64}}}
},
// kv shape
{ov::test::InputShape{ov::PartialShape{-1, 1, -1, 64},
{ov::Shape{1, 1, 100, 64}, ov::Shape{1, 1, 1, 64}, ov::Shape{2, 1, 10, 64}}}
},
// attn shape
{ov::test::InputShape{ov::PartialShape{8, -1, -1},
{ov::Shape{8, 100, 100}, ov::Shape{8, 1, 1}, ov::Shape{8, 10, 10}}}
},
},
};
const auto params = testing::Combine(testing::Values(ElementType::f32, ElementType::bf16),
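
The new shape sets cover three layouts: matching q/k/v with a [B, 1, 1, L0+L1] mask, a single kv head with a rank-4 [B, H, L1, L0+L1] mask, and a single kv head with a rank-3 [H, L1, L0+L1] mask. A naive reference of what these cases should compute, with single-head K/V and size-1 mask axes broadcast as in the kernels above (sdpa_reference is a sketch under those assumptions, not the test suite's reference path):

#include <cmath>
#include <cstddef>
#include <vector>

// Naive scaled dot-product attention over [B, H, L, S] buffers. K/V may store a
// single head (Hk == 1) that is broadcast across all H query heads, and the
// mask is addressed with size-1 axes broadcasting (Bm, Hm, Lqm may be 1).
void sdpa_reference(const float* q,     // [B, H,  Lq, S]
                    const float* k,     // [B, Hk, Lk, S]
                    const float* v,     // [B, Hk, Lk, S]
                    const float* mask,  // [Bm, Hm, Lqm, Lk]
                    float* out,         // [B, H,  Lq, S]
                    size_t B, size_t H, size_t Hk,
                    size_t Lq, size_t Lk, size_t S,
                    size_t Bm, size_t Hm, size_t Lqm) {
    const float scale = 1.0f / std::sqrt(static_cast<float>(S));
    std::vector<float> score(Lk);
    for (size_t b = 0; b < B; b++)
    for (size_t h = 0; h < H; h++)
    for (size_t m = 0; m < Lq; m++) {
        const size_t hk = (Hk == 1) ? 0 : h;
        const float* qr = q + ((b * H + h) * Lq + m) * S;
        // scores = q . k * scale + mask (mask axes broadcast via modulo)
        float max_s = -1e30f;
        for (size_t n = 0; n < Lk; n++) {
            const float* kr = k + ((b * Hk + hk) * Lk + n) * S;
            float s = 0.0f;
            for (size_t i = 0; i < S; i++) s += qr[i] * kr[i];
            s = s * scale + mask[(((b % Bm) * Hm + h % Hm) * Lqm + m % Lqm) * Lk + n];
            score[n] = s;
            if (s > max_s) max_s = s;
        }
        // numerically stable softmax over the kv axis
        float sum = 0.0f;
        for (size_t n = 0; n < Lk; n++) { score[n] = std::exp(score[n] - max_s); sum += score[n]; }
        for (size_t n = 0; n < Lk; n++) score[n] /= sum;
        // weighted sum of value rows
        float* o = out + ((b * H + h) * Lq + m) * S;
        for (size_t i = 0; i < S; i++) o[i] = 0.0f;
        for (size_t n = 0; n < Lk; n++) {
            const float* vr = v + ((b * Hk + hk) * Lk + n) * S;
            for (size_t i = 0; i < S; i++) o[i] += score[n] * vr[i];
        }
    }
}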


@@ -75,7 +75,8 @@ class TestTimmConvertModel(TestConvertModel):
@pytest.mark.parametrize("name", ["mobilevitv2_050.cvnets_in1k",
"poolformerv2_s12.sail_in1k",
"vit_base_patch8_224.augreg_in21k"])
"vit_base_patch8_224.augreg_in21k",
"beit_base_patch16_224.in22k_ft_in22k"])
@pytest.mark.precommit
def test_convert_model_precommit(self, name, ie_device):
self.run(name, None, ie_device)