[PT FE] Add aten::_native_multi_head_attention (#17550)

* [PT FE] Add implementation of MHA

* [PT FE] Add tests, add scaled dot product attention

* [PT FE] Fix missing transpose for Q,K,V & output Attention

* [PT FE] Formatting errors

* [PT FE] Fix testing class with nn.Linear

* [PT FE] Fix incorrect key transpose in dot product attention computation

* [PT FE] Fix incorrect matmul due to lack of transpose

* [PT FE] Enable support for all boolean masks

* [PT FE] Fix returned weights

* [PT FE] Remove debugging artifacts

* [PT FE] Remove unused nodes, optimize transpose nodes' usage, add comments to floating masks

* [PT FE] Further reduce node usage, return None instead of 0 for need_weights=false

* [PT FE] Allow for dynamic num_head, embed_dim

* [PT FE] Improve error comment, remove unnecessary Unsqueeze

* [PT FE] Clang format

* Update tests/layer_tests/pytorch_tests/test_native_multi_head_attention.py

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>

* [PT FE] Add masks comments, improve mask broadcasting

---------

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>
Piotr Krzemiński 2023-06-05 10:55:03 +02:00 committed by GitHub
parent c0fb831c6e
commit 3d8a620ac3
4 changed files with 283 additions and 1 deletion


@@ -0,0 +1,201 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/opsets/opset10.hpp"
#include "pt_framework_node.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace op {
using namespace ov::op;
OutputVector translate_native_multi_head_attention(const NodeContext& context) {
/*
aten::_native_multi_head_attention(
Tensor query,
Tensor key,
Tensor value,
int64 embed_dim,
int64 num_head,
Tensor qkv_weight,
Tensor qkv_bias,
Tensor proj_weight,
Tensor proj_bias,
Optional[Tensor] mask = None,
bool need_weights = true,
bool average_attn_weights = true,
Optional[int64] mask_type = None
)
*/
num_inputs_check(context, 13, 13);
const auto query = context.get_input(0);
const auto key = context.get_input(1);
const auto value = context.get_input(2);
const auto embed_dim = context.get_input(3);
const auto num_head = context.get_input(4);
const auto qkv_weight = context.get_input(5);
const auto qkv_bias = context.get_input(6);
const auto proj_weight = context.get_input(7);
const auto proj_bias = context.get_input(8);
const auto need_weights = context.const_input<bool>(10);
const auto average_weights = context.const_input<bool>(11);
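// Reusable scalar/1-D constants: -inf fills masked attention logits before the softmax,
// the index constants drive the slicing and reshaping below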
const auto minus_inf =
context.mark_node(opset10::Constant::create(element::f32, Shape{}, {-std::numeric_limits<float>::infinity()}));
const auto embed_dim_i64 = context.mark_node(std::make_shared<opset10::Convert>(embed_dim, element::i64));
const auto num_head_i64 = context.mark_node(std::make_shared<opset10::Convert>(num_head, element::i64));
const auto neg_one_1d = context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {-1}));
const auto zero_1d = context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {0}));
const auto one_1d = context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {1}));
const auto two_1d = context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {2}));
const auto three_1d = context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {3}));
const auto heads_1d = context.mark_node(std::make_shared<opset10::Unsqueeze>(num_head_i64, zero_1d));
const auto ev_1_slice_1d = context.mark_node(std::make_shared<opset10::Multiply>(one_1d, embed_dim_i64));
const auto ev_2_slice_1d = context.mark_node(std::make_shared<opset10::Multiply>(two_1d, embed_dim_i64));
const auto ev_3_slice_1d = context.mark_node(std::make_shared<opset10::Multiply>(three_1d, embed_dim_i64));
const auto qkv_shape = context.mark_node(std::make_shared<opset10::ShapeOf>(query));
const auto batch_size = context.mark_node(std::make_shared<opset10::Gather>(qkv_shape, zero_1d, zero_1d));
const auto seq_size = context.mark_node(std::make_shared<opset10::Gather>(qkv_shape, one_1d, zero_1d));
const auto embed_div_heads = context.mark_node(std::make_shared<opset10::Divide>(embed_dim_i64, heads_1d, true));
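// qkv_weight packs the Q, K and V in-projection weights as [3 * embed_dim, embed_dim];
// slice it (and qkv_bias) along dim 0 into embed_dim-sized chunks per projection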
const auto query_proj_weight =
context.mark_node(std::make_shared<opset10::Slice>(qkv_weight, zero_1d, ev_1_slice_1d, one_1d, zero_1d));
const auto key_proj_weight =
context.mark_node(std::make_shared<opset10::Slice>(qkv_weight, ev_1_slice_1d, ev_2_slice_1d, one_1d, zero_1d));
const auto value_proj_weight =
context.mark_node(std::make_shared<opset10::Slice>(qkv_weight, ev_2_slice_1d, ev_3_slice_1d, one_1d, zero_1d));
const auto query_proj_bias =
context.mark_node(std::make_shared<opset10::Slice>(qkv_bias, zero_1d, ev_1_slice_1d, one_1d, zero_1d));
const auto key_proj_bias =
context.mark_node(std::make_shared<opset10::Slice>(qkv_bias, ev_1_slice_1d, ev_2_slice_1d, one_1d, zero_1d));
const auto value_proj_bias =
context.mark_node(std::make_shared<opset10::Slice>(qkv_bias, ev_2_slice_1d, ev_3_slice_1d, one_1d, zero_1d));
const auto query_weighted =
context.mark_node(std::make_shared<opset10::MatMul>(query, query_proj_weight, false, true));
const auto key_weighted = context.mark_node(std::make_shared<opset10::MatMul>(key, key_proj_weight, false, true));
const auto value_weighted =
context.mark_node(std::make_shared<opset10::MatMul>(value, value_proj_weight, false, true));
const auto query_biased = context.mark_node(std::make_shared<opset10::Add>(query_weighted, query_proj_bias));
const auto key_biased = context.mark_node(std::make_shared<opset10::Add>(key_weighted, key_proj_bias));
const auto value_biased = context.mark_node(std::make_shared<opset10::Add>(value_weighted, value_proj_bias));
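// Reshape projected Q/K/V from [batch, seq, embed_dim] to [batch, seq, num_heads, head_dim],
// then transpose Q/V to [batch, num_heads, seq, head_dim] and K to [batch, num_heads, head_dim, seq]
// so that Q x K directly yields [batch, num_heads, seq, seq] attention logits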
const auto qkv_reshape_dims = context.mark_node(
std::make_shared<opset10::Concat>(OutputVector{batch_size, seq_size, heads_1d, neg_one_1d}, 0));
const auto qv_transpose_dims = context.mark_node(opset10::Constant::create(element::i64, Shape{4}, {0, 2, 1, 3}));
const auto k_transpose_dims = context.mark_node(opset10::Constant::create(element::i64, Shape{4}, {0, 2, 3, 1}));
const auto query_reshaped =
context.mark_node(std::make_shared<opset10::Reshape>(query_biased, qkv_reshape_dims, false));
const auto key_reshaped =
context.mark_node(std::make_shared<opset10::Reshape>(key_biased, qkv_reshape_dims, false));
const auto value_reshaped =
context.mark_node(std::make_shared<opset10::Reshape>(value_biased, qkv_reshape_dims, false));
const auto query_transposed =
context.mark_node(std::make_shared<opset10::Transpose>(query_reshaped, qv_transpose_dims));
const auto key_transposed = context.mark_node(std::make_shared<opset10::Transpose>(key_reshaped, k_transpose_dims));
const auto value_transposed =
context.mark_node(std::make_shared<opset10::Transpose>(value_reshaped, qv_transpose_dims));
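// Attention logits are scaled by 1/sqrt(head_dim), where head_dim = embed_dim / num_heads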
const auto scale_one = context.mark_node(std::make_shared<opset10::ConvertLike>(one_1d, query_transposed));
const auto scale_dim = context.mark_node(std::make_shared<opset10::ConvertLike>(embed_div_heads, query_transposed));
const auto scale_dim_sqrt = context.mark_node(std::make_shared<opset10::Sqrt>(scale_dim));
const auto scale = context.mark_node(std::make_shared<opset10::Divide>(scale_one, scale_dim_sqrt));
const auto query_key_transpose_dot_product =
context.mark_node(std::make_shared<opset10::MatMul>(query_transposed, key_transposed));
auto scaled_dot_product =
context.mark_node(std::make_shared<opset10::Multiply>(query_key_transpose_dot_product, scale));
// Mask handling
if (!context.input_is_none(9) && !context.input_is_none(12)) {
auto atten_mask = context.get_input(9);
// Only allow boolean masks - PyTorch automatically converts
// non-boolean to boolean masks
if (atten_mask.get_element_type() == element::boolean) {
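// A True entry marks a position to be ignored: Select keeps the converted mask value (0)
// where the mask is False and substitutes -inf where it is True, so masked logits
// vanish after the softmax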
const auto minus_inf_conv =
context.mark_node(std::make_shared<opset10::ConvertLike>(minus_inf, scaled_dot_product));
const auto mask_inverse = context.mark_node(std::make_shared<opset10::LogicalNot>(atten_mask));
atten_mask = context.mark_node(std::make_shared<opset10::ConvertLike>(atten_mask, scaled_dot_product));
atten_mask = context.mark_node(std::make_shared<opset10::Select>(mask_inverse, atten_mask, minus_inf_conv));
} else {
// Once int/float mask type is supported in PyTorch,
// remove this assert to allow for such masks in OV
FRONT_END_OP_CONVERSION_CHECK(false, "Non-boolean masks are not supported.");
atten_mask = context.mark_node(std::make_shared<opset10::ConvertLike>(atten_mask, scaled_dot_product));
}
// If mask type is 1 (only key-padding) then the mask's shape is (N, S).
// It must be reshaped to (N, 1, 1, S) so it broadcasts to the proper dims in the next step
if (context.const_input<int64_t>(12) == 1) {
const auto target_mask_reshape = context.mark_node(
std::make_shared<opset10::Concat>(OutputVector{batch_size, one_1d, one_1d, seq_size}, 0));
atten_mask = context.mark_node(std::make_shared<opset10::Reshape>(atten_mask, target_mask_reshape, false));
}
// All mask types should be broadcast to this shape,
// except for type 2, which already has it
if (context.const_input<int64_t>(12) != 2) {
const auto target_mask_shape = context.mark_node(
std::make_shared<opset10::Concat>(OutputVector{batch_size, heads_1d, seq_size, seq_size}, 0));
atten_mask = context.mark_node(std::make_shared<opset10::Broadcast>(atten_mask, target_mask_shape));
}
scaled_dot_product = context.mark_node(std::make_shared<opset10::Add>(scaled_dot_product, atten_mask));
}
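// Softmax over the last (key) axis produces the attention weights;
// multiplying by V gives the per-head attention output [batch, num_heads, seq, head_dim]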
const auto scaled_dot_product_softmax =
context.mark_node(std::make_shared<opset10::Softmax>(scaled_dot_product, -1));
const auto scaled_dot_product_attention =
context.mark_node(std::make_shared<opset10::MatMul>(scaled_dot_product_softmax, value_transposed));
const auto sdp_reshape_dims =
context.mark_node(std::make_shared<opset10::Concat>(OutputVector{batch_size, seq_size, neg_one_1d}, 0));
// Undo transpose (transpose back to original qv shape)
const auto scaled_dot_product_attention_transposed =
context.mark_node(std::make_shared<opset10::Transpose>(scaled_dot_product_attention, qv_transpose_dims));
const auto scaled_dot_product_attention_reshaped = context.mark_node(
std::make_shared<opset10::Reshape>(scaled_dot_product_attention_transposed, sdp_reshape_dims, false));
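// Output projection: [batch, seq, embed_dim] x proj_weight^T + proj_bias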
const auto scaled_dot_product_attention_weighted = context.mark_node(
std::make_shared<opset10::MatMul>(scaled_dot_product_attention_reshaped, proj_weight, false, true));
const auto scaled_dot_product_attention_biased =
context.mark_node(std::make_shared<opset10::Add>(scaled_dot_product_attention_weighted, proj_bias));
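// average_attn_weights: divide the returned weights by num_heads and sum over the
// head axis (axis 1), i.e. average the per-head attention weights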
if (average_weights) {
const auto target_div_shape = context.mark_node(std::make_shared<opset10::ShapeOf>(scaled_dot_product));
const auto heads_div = context.mark_node(std::make_shared<opset10::Broadcast>(heads_1d, target_div_shape));
const auto heads_div_conv =
context.mark_node(std::make_shared<opset10::ConvertLike>(heads_div, scaled_dot_product));
scaled_dot_product =
context.mark_node(std::make_shared<opset10::Divide>(scaled_dot_product, heads_div_conv, false));
scaled_dot_product = context.mark_node(std::make_shared<opset10::ReduceSum>(scaled_dot_product, one_1d));
}
if (need_weights) {
return {scaled_dot_product_attention_biased, scaled_dot_product};
} else {
// When need_weights == false, return None as the second output
const auto none = std::make_shared<PtFrameworkNode>(context.get_decoder(), context.inputs());
auto attrs = none->get_attrs();
attrs["none_value"] = "";
none->set_attrs(attrs);
const auto none_marked = context.mark_node(none);
return {scaled_dot_product_attention_biased, none_marked};
}
}
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov


@@ -89,6 +89,7 @@ OP_CONVERTER(translate_mean);
OP_CONVERTER(translate_meshgrid);
OP_CONVERTER(translate_min);
OP_CONVERTER(translate_narrow);
OP_CONVERTER(translate_native_multi_head_attention);
OP_CONVERTER(translate_neg);
OP_CONVERTER(translate_new_full);
OP_CONVERTER(translate_new_ones);
@@ -160,6 +161,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::__range_length", op::translate_range_length},
{"aten::_convolution", op::translate_convolution},
{"aten::_convolution_mode", op::translate_convolution_mode},
{"aten::_native_multi_head_attention", op::translate_native_multi_head_attention},
{"aten::_set_item", op::translate_set_item},
{"aten::abs", op::translate_1to1_match_1_inputs<opset10::Abs>},
{"aten::acos", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Acos>},


@@ -38,7 +38,7 @@ def pytest_collection_modifyitems(items):
test.add_marker(pytest.mark.xfail(reason=mark.kwargs["reason"]))
@pytest.mark.hookwrapper
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_makereport(item, call):
pytest_html = item.config.pluginmanager.getplugin('html')
outcome = yield


@@ -0,0 +1,79 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pytest
import torch
from pytorch_layer_test_class import PytorchLayerTest
EMBED_DIM = 8
NUM_HEADS = 4
SEQ_LENGTH = 6
BATCH_SIZE = 1
NO_MASK, ATTN_MASK, KEY_PAD_MASK, MERGED_MASK = -1, 0, 1, 2
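# Mask variants exercised below: ATTN_MASK is a (S, S) attention mask, KEY_PAD_MASK is a
# (N, S) key-padding mask, MERGED_MASK is a pre-merged (N, H, S, S) mask, NO_MASK disables masking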
class aten_native_multi_head_attention(torch.nn.Module):
def __init__(self, mask, need_weights, average_attn_weights) -> None:
super().__init__()
self.qkv = torch.nn.Linear(EMBED_DIM, 3 * EMBED_DIM, dtype=torch.float32)
self.qkv.requires_grad_(False)
self.proj = torch.nn.Linear(EMBED_DIM, EMBED_DIM, dtype=torch.float32)
self.proj.requires_grad_(False)
self.embed_dim = EMBED_DIM
self.num_heads = NUM_HEADS
self.need_weights = need_weights
self.average_attn_weights = average_attn_weights
# Currently only int masks work correctly; they are converted to bool.
# Float masks raise a warning in PyTorch and are (incorrectly) converted to bool,
# which later produces NaNs in MHA's output
if mask == 0:
self.mask = torch.from_numpy(np.random.randint(0, 2, (SEQ_LENGTH, SEQ_LENGTH)).astype(bool))
self.mask_type = 0
elif mask == 1:
self.mask = torch.from_numpy(np.random.randint(0, 2, (BATCH_SIZE, SEQ_LENGTH)).astype(bool))
self.mask_type = 1
elif mask == 2:
self.mask = torch.from_numpy(np.random.randint(0, 2, (BATCH_SIZE, NUM_HEADS, SEQ_LENGTH, SEQ_LENGTH)).astype(bool))
self.mask_type = 2
else:
self.mask = None
self.mask_type = None
print(self.mask)
def forward(self, query, key, value):
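# The op returns (attention_output, attention_weights); only the attention output is compared here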
return torch.ops.aten._native_multi_head_attention(
query, key, value,
embed_dim=self.embed_dim, num_head=self.num_heads,
qkv_weight=self.qkv.weight, qkv_bias=self.qkv.bias,
proj_weight=self.proj.weight, proj_bias=self.proj.bias,
mask=self.mask, need_weights=self.need_weights,
average_attn_weights=self.average_attn_weights,
mask_type=self.mask_type
)[0]
class TestNativeMultiHeadAttention(PytorchLayerTest):
def _prepare_input(self):
# NativeMHA is self-attention
qkv_tensor = np.random.randn(BATCH_SIZE, SEQ_LENGTH, EMBED_DIM).astype(np.float32)
return (qkv_tensor.copy(),
qkv_tensor.copy(),
qkv_tensor.copy())
@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.parametrize(
"mask",
[NO_MASK, ATTN_MASK, KEY_PAD_MASK, MERGED_MASK]
)
@pytest.mark.parametrize(
["need_weights", "average_attn_weights"],
[[False, False], [True, False], [True, True]]
)
def test_native_multi_head_attention(self, ie_device, precision, ir_version, mask, need_weights, average_attn_weights):
self._test(aten_native_multi_head_attention(mask, need_weights, average_attn_weights),
None, "aten::_native_multi_head_attention", ie_device, precision, ir_version)