[Coverity Scan Issue] fix cs issues in scaled_attn.cpp (#21270)
* [Coverity Scan Issue] fix cs issues in scaled_attn.cpp

  Signed-off-by: Zhai, Xuejun <xuejun.zhai@intel.com>

* Fix review comments

  Signed-off-by: Zhai, Xuejun <xuejun.zhai@intel.com>

---------

Signed-off-by: Zhai, Xuejun <xuejun.zhai@intel.com>
This commit is contained in:
parent
72cb4e4820
commit
058001eb84
@ -232,7 +232,7 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
|
||||
}
|
||||
|
||||
PlainTensor<uint8_t> causal_mask;
|
||||
bool select_nfltmax_at_0; // set attn_score to -FLT_MAX when causal_mask[...] equal to this
|
||||
bool select_nfltmax_at_0 = false; // set attn_score to -FLT_MAX when causal_mask[...] equal to this
|
||||
void set_causal_mask(PlainTensor<uint8_t> mask, bool _select_nfltmax_at_0) {
|
||||
causal_mask = mask;
|
||||
select_nfltmax_at_0 = _select_nfltmax_at_0;
|
||||
@ -300,6 +300,7 @@ struct MHAKernel<ScaledDotProductAttention::KT_MLAS, float> {
|
||||
|
||||
MHAKernel() {
|
||||
m_block_size = 4;
|
||||
select_nfltmax_at_0 = false;
|
||||
qk_buffers.resize(parallel_get_max_threads(), PlainTensor<float>(true));
|
||||
}
|
||||
|
||||
@ -342,7 +343,9 @@ struct MHAKernel<ScaledDotProductAttention::KT_MLAS, float> {
|
||||
auto m_blocks = (q_len + m_block_size - 1) / m_block_size;
|
||||
|
||||
parallel_for3d(B, H, m_blocks, [&](size_t b, size_t h, size_t m_blk) {
|
||||
size_t thread_id = static_cast<size_t>(parallel_get_thread_num());
|
||||
auto thread_id = parallel_get_thread_num();
|
||||
if (thread_id < 0)
|
||||
OPENVINO_THROW("The calling thread isn't initialized!");
|
||||
auto& qk_buf = qk_buffers[thread_id];
|
||||
|
||||
auto m_start = m_blk * m_block_size;
|
||||
@ -450,7 +453,7 @@ struct MHASingleToken {
|
||||
PlainTensor<float> m_attn_w;
|
||||
PlainTensor<float> m_temp;
|
||||
|
||||
MHASingleToken() : m_attn_w(true), m_temp(true) {}
|
||||
MHASingleToken() : m_attn_w(true), m_temp(true), select_nfltmax_at_0(false) {}
|
||||
|
||||
PlainTensor<uint8_t> causal_mask;
|
||||
bool select_nfltmax_at_0; // set attn_score to -FLT_MAX when causal_mask[...] equal to this
|
||||
|
Loading…
Reference in New Issue
Block a user