[PT FE] Add aten::_native_multi_head_attention (#17550)
* [PT FE] Add implementation of MHA
* [PT FE] Add tests, add scaled dot product attention
* [PT FE] Fix missing transpose for Q, K, V & output attention
* [PT FE] Fix formatting errors
* [PT FE] Fix testing class with nn.Linear
* [PT FE] Fix incorrect key transpose in dot product attention computation
* [PT FE] Fix incorrect matmul due to lack of transpose
* [PT FE] Enable support for all boolean masks
* [PT FE] Fix returned weights
* [PT FE] Remove debugging artifacts
* [PT FE] Remove unused nodes, optimize transpose nodes' usage, add comments to floating masks
* [PT FE] Further reduce node usage, return None instead of 0 for need_weights=false
* [PT FE] Allow for dynamic num_head, embed_dim
* [PT FE] Improve error comment, remove unnecessary Unsqueeze
* [PT FE] Clang format
* Update tests/layer_tests/pytorch_tests/test_native_multi_head_attention.py

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>

* [PT FE] Add mask comments, improve mask broadcasting

---------

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>
This commit is contained in:
parent c0fb831c6e
commit 3d8a620ac3

201  src/frontends/pytorch/src/op/native_multi_head_attention.cpp  (Normal file)
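For orientation, the sketch below restates in plain NumPy the self-attention math that the converter added in this commit decomposes into OpenVINO ops: the packed qkv_weight/qkv_bias is split into query/key/value projections, the heads are split out, Q·Kᵀ is scaled by 1/sqrt(embed_dim/num_head), a boolean mask (True = position masked out, as in the conversion below) is applied as an additive -inf term, and the softmax-weighted values go through the output projection. This is an illustrative approximation only, not code from the converter; the function name `mha_reference` is made up for this note.

import numpy as np

def mha_reference(q, k, v, qkv_w, qkv_b, proj_w, proj_b, num_head, mask=None):
    # q, k, v: (N, S, E); qkv_w: (3E, E); qkv_b: (3E,); proj_w: (E, E); proj_b: (E,)
    n, s, e = q.shape
    d = e // num_head
    # In-projection: the packed weight/bias is sliced into query/key/value parts
    qp = q @ qkv_w[:e].T + qkv_b[:e]
    kp = k @ qkv_w[e:2 * e].T + qkv_b[e:2 * e]
    vp = v @ qkv_w[2 * e:].T + qkv_b[2 * e:]
    # Split heads: (N, S, E) -> (N, H, S, E/H)
    qp, kp, vp = (x.reshape(n, s, num_head, d).transpose(0, 2, 1, 3) for x in (qp, kp, vp))
    logits = (qp @ kp.transpose(0, 1, 3, 2)) / np.sqrt(d)   # (N, H, S, S)
    if mask is not None:
        # Boolean mask: True marks a position that must not be attended to
        logits = np.where(mask, -np.inf, logits)
    weights = np.exp(logits - logits.max(-1, keepdims=True))
    weights = weights / weights.sum(-1, keepdims=True)       # softmax over the last dim
    out = (weights @ vp).transpose(0, 2, 1, 3).reshape(n, s, e)
    # average_attn_weights=True additionally averages `weights` over the head axis
    return out @ proj_w.T + proj_b, weights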
src/frontends/pytorch/src/op/native_multi_head_attention.cpp

@@ -0,0 +1,201 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/opsets/opset10.hpp"
#include "pt_framework_node.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_native_multi_head_attention(const NodeContext& context) {
    /*
    aten::_native_multi_head_attention(
        Tensor query,
        Tensor key,
        Tensor value,
        int64 embed_dim,
        int64 num_head,
        Tensor qkv_weight,
        Tensor qkv_bias,
        Tensor proj_weight,
        Tensor proj_bias,
        Optional[Tensor] mask = None,
        bool need_weights = true,
        bool average_attn_weights = true,
        Optional[int64] mask_type = None
    )
    */
    num_inputs_check(context, 13, 13);
    const auto query = context.get_input(0);
    const auto key = context.get_input(1);
    const auto value = context.get_input(2);
    const auto embed_dim = context.get_input(3);
    const auto num_head = context.get_input(4);
    const auto qkv_weight = context.get_input(5);
    const auto qkv_bias = context.get_input(6);
    const auto proj_weight = context.get_input(7);
    const auto proj_bias = context.get_input(8);
    const auto need_weights = context.const_input<bool>(10);
    const auto average_weights = context.const_input<bool>(11);

    const auto minus_inf =
        context.mark_node(opset10::Constant::create(element::f32, Shape{}, {-std::numeric_limits<float>::infinity()}));
    const auto embed_dim_i64 = context.mark_node(std::make_shared<opset10::Convert>(embed_dim, element::i64));
    const auto num_head_i64 = context.mark_node(std::make_shared<opset10::Convert>(num_head, element::i64));

    const auto neg_one_1d = context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {-1}));
    const auto zero_1d = context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {0}));
    const auto one_1d = context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {1}));
    const auto two_1d = context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {2}));
    const auto three_1d = context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {3}));
    const auto heads_1d = context.mark_node(std::make_shared<opset10::Unsqueeze>(num_head_i64, zero_1d));

    const auto ev_1_slice_1d = context.mark_node(std::make_shared<opset10::Multiply>(one_1d, embed_dim_i64));
    const auto ev_2_slice_1d = context.mark_node(std::make_shared<opset10::Multiply>(two_1d, embed_dim_i64));
    const auto ev_3_slice_1d = context.mark_node(std::make_shared<opset10::Multiply>(three_1d, embed_dim_i64));

    const auto qkv_shape = context.mark_node(std::make_shared<opset10::ShapeOf>(query));
    const auto batch_size = context.mark_node(std::make_shared<opset10::Gather>(qkv_shape, zero_1d, zero_1d));
    const auto seq_size = context.mark_node(std::make_shared<opset10::Gather>(qkv_shape, one_1d, zero_1d));
    const auto embed_div_heads = context.mark_node(std::make_shared<opset10::Divide>(embed_dim_i64, heads_1d, true));

    const auto query_proj_weight =
        context.mark_node(std::make_shared<opset10::Slice>(qkv_weight, zero_1d, ev_1_slice_1d, one_1d, zero_1d));
    const auto key_proj_weight =
        context.mark_node(std::make_shared<opset10::Slice>(qkv_weight, ev_1_slice_1d, ev_2_slice_1d, one_1d, zero_1d));
    const auto value_proj_weight =
        context.mark_node(std::make_shared<opset10::Slice>(qkv_weight, ev_2_slice_1d, ev_3_slice_1d, one_1d, zero_1d));
    const auto query_proj_bias =
        context.mark_node(std::make_shared<opset10::Slice>(qkv_bias, zero_1d, ev_1_slice_1d, one_1d, zero_1d));
    const auto key_proj_bias =
        context.mark_node(std::make_shared<opset10::Slice>(qkv_bias, ev_1_slice_1d, ev_2_slice_1d, one_1d, zero_1d));
    const auto value_proj_bias =
        context.mark_node(std::make_shared<opset10::Slice>(qkv_bias, ev_2_slice_1d, ev_3_slice_1d, one_1d, zero_1d));

    const auto query_weighted =
        context.mark_node(std::make_shared<opset10::MatMul>(query, query_proj_weight, false, true));
    const auto key_weighted = context.mark_node(std::make_shared<opset10::MatMul>(key, key_proj_weight, false, true));
    const auto value_weighted =
        context.mark_node(std::make_shared<opset10::MatMul>(value, value_proj_weight, false, true));

    const auto query_biased = context.mark_node(std::make_shared<opset10::Add>(query_weighted, query_proj_bias));
    const auto key_biased = context.mark_node(std::make_shared<opset10::Add>(key_weighted, key_proj_bias));
    const auto value_biased = context.mark_node(std::make_shared<opset10::Add>(value_weighted, value_proj_bias));

    const auto qkv_reshape_dims = context.mark_node(
        std::make_shared<opset10::Concat>(OutputVector{batch_size, seq_size, heads_1d, neg_one_1d}, 0));
    const auto qv_transpose_dims = context.mark_node(opset10::Constant::create(element::i64, Shape{4}, {0, 2, 1, 3}));
    const auto k_transpose_dims = context.mark_node(opset10::Constant::create(element::i64, Shape{4}, {0, 2, 3, 1}));

    const auto query_reshaped =
        context.mark_node(std::make_shared<opset10::Reshape>(query_biased, qkv_reshape_dims, false));
    const auto key_reshaped =
        context.mark_node(std::make_shared<opset10::Reshape>(key_biased, qkv_reshape_dims, false));
    const auto value_reshaped =
        context.mark_node(std::make_shared<opset10::Reshape>(value_biased, qkv_reshape_dims, false));

    const auto query_transposed =
        context.mark_node(std::make_shared<opset10::Transpose>(query_reshaped, qv_transpose_dims));
    const auto key_transposed = context.mark_node(std::make_shared<opset10::Transpose>(key_reshaped, k_transpose_dims));
    const auto value_transposed =
        context.mark_node(std::make_shared<opset10::Transpose>(value_reshaped, qv_transpose_dims));

    const auto scale_one = context.mark_node(std::make_shared<opset10::ConvertLike>(one_1d, query_transposed));
    const auto scale_dim = context.mark_node(std::make_shared<opset10::ConvertLike>(embed_div_heads, query_transposed));
    const auto scale_dim_sqrt = context.mark_node(std::make_shared<opset10::Sqrt>(scale_dim));
    const auto scale = context.mark_node(std::make_shared<opset10::Divide>(scale_one, scale_dim_sqrt));

    const auto query_key_transpose_dot_product =
        context.mark_node(std::make_shared<opset10::MatMul>(query_transposed, key_transposed));

    auto scaled_dot_product =
        context.mark_node(std::make_shared<opset10::Multiply>(query_key_transpose_dot_product, scale));

    // Mask handling
    if (!context.input_is_none(9) && !context.input_is_none(12)) {
        auto atten_mask = context.get_input(9);
        // Only allow boolean masks - PyTorch automatically converts
        // non-boolean to boolean masks
        if (atten_mask.get_element_type() == element::boolean) {
            const auto minus_inf_conv =
                context.mark_node(std::make_shared<opset10::ConvertLike>(minus_inf, scaled_dot_product));
            const auto mask_inverse = context.mark_node(std::make_shared<opset10::LogicalNot>(atten_mask));
            atten_mask = context.mark_node(std::make_shared<opset10::ConvertLike>(atten_mask, scaled_dot_product));
            atten_mask = context.mark_node(std::make_shared<opset10::Select>(mask_inverse, atten_mask, minus_inf_conv));
        } else {
            // Once int/float mask type is supported in PyTorch,
            // remove this assert to allow for such masks in OV
            FRONT_END_OP_CONVERSION_CHECK(false, "Non-boolean masks are not supported.");
            atten_mask = context.mark_node(std::make_shared<opset10::ConvertLike>(atten_mask, scaled_dot_product));
        }

        // If mask type is 1 (only key-padding) then mask's shape is (N, S).
        // It must be reshaped to (N, 1, 1, S) so that it broadcasts to the proper dims in the next step
        if (context.const_input<int64_t>(12) == 1) {
            const auto target_mask_reshape = context.mark_node(
                std::make_shared<opset10::Concat>(OutputVector{batch_size, one_1d, one_1d, seq_size}, 0));
            atten_mask = context.mark_node(std::make_shared<opset10::Reshape>(atten_mask, target_mask_reshape, false));
        }

        // All mask types should be broadcast to this shape,
        // except for type 2, which already has this shape
        if (context.const_input<int64_t>(12) != 2) {
            const auto target_mask_shape = context.mark_node(
                std::make_shared<opset10::Concat>(OutputVector{batch_size, heads_1d, seq_size, seq_size}, 0));
            atten_mask = context.mark_node(std::make_shared<opset10::Broadcast>(atten_mask, target_mask_shape));
        }
        scaled_dot_product = context.mark_node(std::make_shared<opset10::Add>(scaled_dot_product, atten_mask));
    }

    const auto scaled_dot_product_softmax =
        context.mark_node(std::make_shared<opset10::Softmax>(scaled_dot_product, -1));
    const auto scaled_dot_product_attention =
        context.mark_node(std::make_shared<opset10::MatMul>(scaled_dot_product_softmax, value_transposed));

    const auto sdp_reshape_dims =
        context.mark_node(std::make_shared<opset10::Concat>(OutputVector{batch_size, seq_size, neg_one_1d}, 0));
    // Undo transpose (transpose back to original qv shape)
    const auto scaled_dot_product_attention_transposed =
        context.mark_node(std::make_shared<opset10::Transpose>(scaled_dot_product_attention, qv_transpose_dims));
    const auto scaled_dot_product_attention_reshaped = context.mark_node(
        std::make_shared<opset10::Reshape>(scaled_dot_product_attention_transposed, sdp_reshape_dims, false));

    const auto scaled_dot_product_attention_weighted = context.mark_node(
        std::make_shared<opset10::MatMul>(scaled_dot_product_attention_reshaped, proj_weight, false, true));
    const auto scaled_dot_product_attention_biased =
        context.mark_node(std::make_shared<opset10::Add>(scaled_dot_product_attention_weighted, proj_bias));

    if (average_weights) {
        const auto target_div_shape = context.mark_node(std::make_shared<opset10::ShapeOf>(scaled_dot_product));
        const auto heads_div = context.mark_node(std::make_shared<opset10::Broadcast>(heads_1d, target_div_shape));
        const auto heads_div_conv =
            context.mark_node(std::make_shared<opset10::ConvertLike>(heads_div, scaled_dot_product));
        scaled_dot_product =
            context.mark_node(std::make_shared<opset10::Divide>(scaled_dot_product, heads_div_conv, false));
        scaled_dot_product = context.mark_node(std::make_shared<opset10::ReduceSum>(scaled_dot_product, one_1d));
    }

    if (need_weights) {
        return {scaled_dot_product_attention_biased, scaled_dot_product};
    } else {
        // When need_weights == false, return None as the second output
        const auto none = std::make_shared<PtFrameworkNode>(context.get_decoder(), context.inputs());
        auto attrs = none->get_attrs();
        attrs["none_value"] = "";
        none->set_attrs(attrs);
        const auto none_marked = context.mark_node(none);
        return {scaled_dot_product_attention_biased, none_marked};
    }
};

}  // namespace op
}  // namespace pytorch
}  // namespace frontend
}  // namespace ov
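As a shape-level illustration of the mask normalization above (sizes taken from the layer test at the bottom of this diff; a sketch of the intent, not code from the converter, and the mapping of mask_type 0/1/2 to attention, key-padding and merged masks follows the test class):

import numpy as np

N, H, S = 1, 4, 6                                # batch, num_head, sequence length
attn_mask = np.random.rand(S, S) > 0.5           # mask_type 0: (S, S)
key_pad_mask = np.random.rand(N, S) > 0.5        # mask_type 1: (N, S)
merged_mask = np.random.rand(N, H, S, S) > 0.5   # mask_type 2: already (N, H, S, S)

# mask_type 1 is first reshaped to (N, 1, 1, S) ...
key_pad_mask = key_pad_mask.reshape(N, 1, 1, S)
# ... and every type except 2 is then broadcast to the logits shape (N, H, S, S)
attn_mask = np.broadcast_to(attn_mask, (N, H, S, S))
key_pad_mask = np.broadcast_to(key_pad_mask, (N, H, S, S))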
@@ -89,6 +89,7 @@ OP_CONVERTER(translate_mean);
OP_CONVERTER(translate_meshgrid);
OP_CONVERTER(translate_min);
OP_CONVERTER(translate_narrow);
+OP_CONVERTER(translate_native_multi_head_attention);
OP_CONVERTER(translate_neg);
OP_CONVERTER(translate_new_full);
OP_CONVERTER(translate_new_ones);

@@ -160,6 +161,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
        {"aten::__range_length", op::translate_range_length},
        {"aten::_convolution", op::translate_convolution},
        {"aten::_convolution_mode", op::translate_convolution_mode},
+       {"aten::_native_multi_head_attention", op::translate_native_multi_head_attention},
        {"aten::_set_item", op::translate_set_item},
        {"aten::abs", op::translate_1to1_match_1_inputs<opset10::Abs>},
        {"aten::acos", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Acos>},
@@ -38,7 +38,7 @@ def pytest_collection_modifyitems(items):
                test.add_marker(pytest.mark.xfail(reason=mark.kwargs["reason"]))


-@pytest.mark.hookwrapper
+@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_makereport(item, call):
    pytest_html = item.config.pluginmanager.getplugin('html')
    outcome = yield
tests/layer_tests/pytorch_tests/test_native_multi_head_attention.py

@@ -0,0 +1,79 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import torch

from pytorch_layer_test_class import PytorchLayerTest

EMBED_DIM = 8
NUM_HEADS = 4
SEQ_LENGTH = 6
BATCH_SIZE = 1

NO_MASK, ATTN_MASK, KEY_PAD_MASK, MERGED_MASK = -1, 0, 1, 2

class aten_native_multi_head_attention(torch.nn.Module):
    def __init__(self, mask, need_weights, average_attn_weights) -> None:
        super().__init__()
        self.qkv = torch.nn.Linear(EMBED_DIM, 3 * EMBED_DIM, dtype=torch.float32)
        self.qkv.requires_grad_(False)
        self.proj = torch.nn.Linear(EMBED_DIM, EMBED_DIM, dtype=torch.float32)
        self.proj.requires_grad_(False)

        self.embed_dim = EMBED_DIM
        self.num_heads = NUM_HEADS
        self.need_weights = need_weights
        self.average_attn_weights = average_attn_weights

        # Currently only int masks work correctly; they are converted to bool.
        # Float masks raise a warning in PyTorch and are (incorrectly) converted to bool,
        # which later returns NaNs as MHA's output
        if mask == 0:
            self.mask = torch.from_numpy(np.random.randint(0, 2, (SEQ_LENGTH, SEQ_LENGTH)).astype(bool))
            self.mask_type = 0
        elif mask == 1:
            self.mask = torch.from_numpy(np.random.randint(0, 2, (BATCH_SIZE, SEQ_LENGTH)).astype(bool))
            self.mask_type = 1
        elif mask == 2:
            self.mask = torch.from_numpy(np.random.randint(0, 2, (BATCH_SIZE, NUM_HEADS, SEQ_LENGTH, SEQ_LENGTH)).astype(bool))
            self.mask_type = 2
        else:
            self.mask = None
            self.mask_type = None

        print(self.mask)

    def forward(self, query, key, value):
        return torch.ops.aten._native_multi_head_attention(
            query, key, value,
            embed_dim=self.embed_dim, num_head=self.num_heads,
            qkv_weight=self.qkv.weight, qkv_bias=self.qkv.bias,
            proj_weight=self.proj.weight, proj_bias=self.proj.bias,
            mask=self.mask, need_weights=self.need_weights,
            average_attn_weights=self.average_attn_weights,
            mask_type=self.mask_type
        )[0]

class TestNativeMultiHeadAttention(PytorchLayerTest):
    def _prepare_input(self):
        # NativeMHA is self-attention
        qkv_tensor = np.random.randn(BATCH_SIZE, SEQ_LENGTH, EMBED_DIM).astype(np.float32)
        return (qkv_tensor.copy(),
                qkv_tensor.copy(),
                qkv_tensor.copy())

    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.parametrize(
        "mask",
        [NO_MASK, ATTN_MASK, KEY_PAD_MASK, MERGED_MASK]
    )
    @pytest.mark.parametrize(
        ["need_weights", "average_attn_weights"],
        [[False, False], [True, False], [True, True]]
    )
    def test_native_multi_head_attention(self, ie_device, precision, ir_version, mask, need_weights, average_attn_weights):
        self._test(aten_native_multi_head_attention(mask, need_weights, average_attn_weights),
                   None, "aten::_native_multi_head_attention", ie_device, precision, ir_version)
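As a quick usage note, the module above can also be traced on its own to confirm that the aten op shows up in the TorchScript graph before conversion. This is a hypothetical check reusing the definitions from the test file, not part of this PR, and the exact graph printout may differ by PyTorch version:

import torch

model = aten_native_multi_head_attention(NO_MASK, need_weights=False, average_attn_weights=False)
x = torch.randn(BATCH_SIZE, SEQ_LENGTH, EMBED_DIM)
traced = torch.jit.trace(model, (x, x, x))
print(traced.graph)  # should contain an aten::_native_multi_head_attention node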