From 3d8a620ac3219aa4f86c3379d08841b39912c59b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Piotr=20Krzemi=C5=84ski?=
Date: Mon, 5 Jun 2023 10:55:03 +0200
Subject: [PATCH] [PT FE] Add aten::_native_multi_head_attention (#17550)

* [PT FE] Add implementation of MHA

* [PT FE] Add tests, add scaled dot product attention

* [PT FE] Fix missing transpose for Q,K,V & output Attention

* [PT FE] Formatting errors

* [PT FE] Fix testing class with nn.Linear

* [PT FE] Fix incorrect key transpose in dot product attention computation

* [PT FE] Fix incorrect matmul due to lack of transpose

* [PT FE] Enable support for all boolean masks

* [PT FE] Fix returned weights

* [PT FE] Remove debugging artifacts

* [PT FE] Remove unused nodes, optimize transpose nodes' usage, add comments to floating masks

* [PT FE] Further reduce node usage, return None instead of 0 for return_weights=false

* [PT FE] Allow for dynamic num_head, embed_dim

* [PT FE] Improve error comment, remove unnecessary Unsqueeze

* [PT FE] Clang format

* Update tests/layer_tests/pytorch_tests/test_native_multi_head_attention.py

Co-authored-by: Maxim Vafin

* [PT FE] Add masks comments, improve mask broadcasting

---------

Co-authored-by: Maxim Vafin
---
 .../src/op/native_multi_head_attention.cpp    | 201 ++++++++++++++++++
 src/frontends/pytorch/src/op_table.cpp        |   2 +
 tests/layer_tests/conftest.py                 |   2 +-
 .../test_native_multi_head_attention.py       |  79 +++++++
 4 files changed, 283 insertions(+), 1 deletion(-)
 create mode 100644 src/frontends/pytorch/src/op/native_multi_head_attention.cpp
 create mode 100644 tests/layer_tests/pytorch_tests/test_native_multi_head_attention.py

diff --git a/src/frontends/pytorch/src/op/native_multi_head_attention.cpp b/src/frontends/pytorch/src/op/native_multi_head_attention.cpp
new file mode 100644
index 00000000000..6ecc798b439
--- /dev/null
+++ b/src/frontends/pytorch/src/op/native_multi_head_attention.cpp
@@ -0,0 +1,201 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "openvino/frontend/pytorch/node_context.hpp"
+#include "openvino/opsets/opset10.hpp"
+#include "pt_framework_node.hpp"
+#include "utils.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace op {
+
+using namespace ov::op;
+
+OutputVector translate_native_multi_head_attention(const NodeContext& context) {
+    /*
+    aten::_native_multi_head_attention(
+        Tensor query,
+        Tensor key,
+        Tensor value,
+        int64 embed_dim,
+        int64 num_head,
+        Tensor qkv_weight,
+        Tensor qkv_bias,
+        Tensor proj_weight,
+        Tensor proj_bias,
+        Optional[Tensor] mask = None,
+        bool need_weights = true,
+        bool average_attn_weights = true,
+        Optional[int64] mask_type = None
+    )
+    */
+    num_inputs_check(context, 13, 13);
+    const auto query = context.get_input(0);
+    const auto key = context.get_input(1);
+    const auto value = context.get_input(2);
+    const auto embed_dim = context.get_input(3);
+    const auto num_head = context.get_input(4);
+    const auto qkv_weight = context.get_input(5);
+    const auto qkv_bias = context.get_input(6);
+    const auto proj_weight = context.get_input(7);
+    const auto proj_bias = context.get_input(8);
+    const auto need_weights = context.const_input<bool>(10);
+    const auto average_weights = context.const_input<bool>(11);
+
+    const auto minus_inf =
+        context.mark_node(opset10::Constant::create(element::f32, Shape{}, {-std::numeric_limits<float>::infinity()}));
+    const auto embed_dim_i64 = context.mark_node(std::make_shared<opset10::Convert>(embed_dim, element::i64));
+    const auto num_head_i64 = context.mark_node(std::make_shared<opset10::Convert>(num_head, element::i64));
+
+    const auto neg_one_1d = context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {-1}));
+    const auto zero_1d = context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {0}));
+    const auto one_1d = context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {1}));
+    const auto two_1d = context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {2}));
+    const auto three_1d = context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {3}));
+    const auto heads_1d = context.mark_node(std::make_shared<opset10::Unsqueeze>(num_head_i64, zero_1d));
+
+    const auto ev_1_slice_1d = context.mark_node(std::make_shared<opset10::Multiply>(one_1d, embed_dim_i64));
+    const auto ev_2_slice_1d = context.mark_node(std::make_shared<opset10::Multiply>(two_1d, embed_dim_i64));
+    const auto ev_3_slice_1d = context.mark_node(std::make_shared<opset10::Multiply>(three_1d, embed_dim_i64));
+
+    const auto qkv_shape = context.mark_node(std::make_shared<opset10::ShapeOf>(query));
+    const auto batch_size = context.mark_node(std::make_shared<opset10::Gather>(qkv_shape, zero_1d, zero_1d));
+    const auto seq_size = context.mark_node(std::make_shared<opset10::Gather>(qkv_shape, one_1d, zero_1d));
+    const auto embed_div_heads = context.mark_node(std::make_shared<opset10::Divide>(embed_dim_i64, heads_1d, true));
+
+    const auto query_proj_weight =
+        context.mark_node(std::make_shared<opset10::Slice>(qkv_weight, zero_1d, ev_1_slice_1d, one_1d, zero_1d));
+    const auto key_proj_weight =
+        context.mark_node(std::make_shared<opset10::Slice>(qkv_weight, ev_1_slice_1d, ev_2_slice_1d, one_1d, zero_1d));
+    const auto value_proj_weight =
+        context.mark_node(std::make_shared<opset10::Slice>(qkv_weight, ev_2_slice_1d, ev_3_slice_1d, one_1d, zero_1d));
+    const auto query_proj_bias =
+        context.mark_node(std::make_shared<opset10::Slice>(qkv_bias, zero_1d, ev_1_slice_1d, one_1d, zero_1d));
+    const auto key_proj_bias =
+        context.mark_node(std::make_shared<opset10::Slice>(qkv_bias, ev_1_slice_1d, ev_2_slice_1d, one_1d, zero_1d));
+    const auto value_proj_bias =
+        context.mark_node(std::make_shared<opset10::Slice>(qkv_bias, ev_2_slice_1d, ev_3_slice_1d, one_1d, zero_1d));
+
+    const auto query_weighted =
+        context.mark_node(std::make_shared<opset10::MatMul>(query, query_proj_weight, false, true));
+    const auto key_weighted = context.mark_node(std::make_shared<opset10::MatMul>(key, key_proj_weight, false, true));
+    const auto value_weighted =
+        context.mark_node(std::make_shared<opset10::MatMul>(value, value_proj_weight, false, true));
+
+    const auto query_biased = context.mark_node(std::make_shared<opset10::Add>(query_weighted, query_proj_bias));
+    const auto key_biased = context.mark_node(std::make_shared<opset10::Add>(key_weighted, key_proj_bias));
+    const auto value_biased = context.mark_node(std::make_shared<opset10::Add>(value_weighted, value_proj_bias));
+
+    const auto qkv_reshape_dims = context.mark_node(
+        std::make_shared<opset10::Concat>(OutputVector{batch_size, seq_size, heads_1d, neg_one_1d}, 0));
+    const auto qv_transpose_dims = context.mark_node(opset10::Constant::create(element::i64, Shape{4}, {0, 2, 1, 3}));
+    const auto k_transpose_dims = context.mark_node(opset10::Constant::create(element::i64, Shape{4}, {0, 2, 3, 1}));
+
+    const auto query_reshaped =
+        context.mark_node(std::make_shared<opset10::Reshape>(query_biased, qkv_reshape_dims, false));
+    const auto key_reshaped =
+        context.mark_node(std::make_shared<opset10::Reshape>(key_biased, qkv_reshape_dims, false));
+    const auto value_reshaped =
+        context.mark_node(std::make_shared<opset10::Reshape>(value_biased, qkv_reshape_dims, false));
+
+    const auto query_transposed =
+        context.mark_node(std::make_shared<opset10::Transpose>(query_reshaped, qv_transpose_dims));
+    const auto key_transposed = context.mark_node(std::make_shared<opset10::Transpose>(key_reshaped, k_transpose_dims));
+    const auto value_transposed =
+        context.mark_node(std::make_shared<opset10::Transpose>(value_reshaped, qv_transpose_dims));
+
+    const auto scale_one = context.mark_node(std::make_shared<opset10::ConvertLike>(one_1d, query_transposed));
+    const auto scale_dim = context.mark_node(std::make_shared<opset10::ConvertLike>(embed_div_heads, query_transposed));
+    const auto scale_dim_sqrt = context.mark_node(std::make_shared<opset10::Sqrt>(scale_dim));
+    const auto scale = context.mark_node(std::make_shared<opset10::Divide>(scale_one, scale_dim_sqrt));
+
+    const auto query_key_transpose_dot_product =
+        context.mark_node(std::make_shared<opset10::MatMul>(query_transposed, key_transposed));
+
+    auto scaled_dot_product =
+        context.mark_node(std::make_shared<opset10::Multiply>(query_key_transpose_dot_product, scale));
+
+    // Mask handling
+    if (!context.input_is_none(9) && !context.input_is_none(12)) {
+        auto atten_mask = context.get_input(9);
+        // Only allow boolean masks - PyTorch automatically converts
+        // non-boolean masks to boolean masks
+        if (atten_mask.get_element_type() == element::boolean) {
+            const auto minus_inf_conv =
+                context.mark_node(std::make_shared<opset10::ConvertLike>(minus_inf, scaled_dot_product));
+            const auto mask_inverse = context.mark_node(std::make_shared<opset10::LogicalNot>(atten_mask));
+            atten_mask = context.mark_node(std::make_shared<opset10::ConvertLike>(atten_mask, scaled_dot_product));
+            atten_mask = context.mark_node(std::make_shared<opset10::Select>(mask_inverse, atten_mask, minus_inf_conv));
+        } else {
+            // Once int/float mask type is supported in PyTorch,
+            // remove this assert to allow for such masks in OV
+            FRONT_END_OP_CONVERSION_CHECK(false, "Non-boolean masks are not supported.");
+            atten_mask = context.mark_node(std::make_shared<opset10::ConvertLike>(atten_mask, scaled_dot_product));
+        }
+
+        // If mask type is 1 (key-padding only) then the mask's shape is (N, S).
+        // It must be reshaped to (N, 1, 1, S) so that it broadcasts to the proper dims in the next step
+        if (context.const_input<int64_t>(12) == 1) {
+            const auto target_mask_reshape = context.mark_node(
+                std::make_shared<opset10::Concat>(OutputVector{batch_size, one_1d, one_1d, seq_size}, 0));
+            atten_mask = context.mark_node(std::make_shared<opset10::Reshape>(atten_mask, target_mask_reshape, false));
+        }
+
+        // All mask types should be broadcast to this shape,
+        // except for type 2, which already has this shape
+        if (context.const_input<int64_t>(12) != 2) {
+            const auto target_mask_shape = context.mark_node(
+                std::make_shared<opset10::Concat>(OutputVector{batch_size, heads_1d, seq_size, seq_size}, 0));
+            atten_mask = context.mark_node(std::make_shared<opset10::Broadcast>(atten_mask, target_mask_shape));
+        }
+        scaled_dot_product = context.mark_node(std::make_shared<opset10::Add>(scaled_dot_product, atten_mask));
+    }
+
+    const auto scaled_dot_product_softmax =
+        context.mark_node(std::make_shared<opset10::Softmax>(scaled_dot_product, -1));
+    const auto scaled_dot_product_attention =
+        context.mark_node(std::make_shared<opset10::MatMul>(scaled_dot_product_softmax, value_transposed));
+
+    const auto sdp_reshape_dims =
+        context.mark_node(std::make_shared<opset10::Concat>(OutputVector{batch_size, seq_size, neg_one_1d}, 0));
+    // Undo transpose (transpose back to original qv shape)
+    const auto scaled_dot_product_attention_transposed =
+        context.mark_node(std::make_shared<opset10::Transpose>(scaled_dot_product_attention, qv_transpose_dims));
+    const auto scaled_dot_product_attention_reshaped = context.mark_node(
+        std::make_shared<opset10::Reshape>(scaled_dot_product_attention_transposed, sdp_reshape_dims, false));
+
+    const auto scaled_dot_product_attention_weighted = context.mark_node(
+        std::make_shared<opset10::MatMul>(scaled_dot_product_attention_reshaped, proj_weight, false, true));
+    const auto scaled_dot_product_attention_biased =
+        context.mark_node(std::make_shared<opset10::Add>(scaled_dot_product_attention_weighted, proj_bias));
+
+    if (average_weights) {
+        const auto target_div_shape =
+            context.mark_node(std::make_shared<opset10::ShapeOf>(scaled_dot_product));
+        const auto heads_div = context.mark_node(std::make_shared<opset10::Broadcast>(heads_1d, target_div_shape));
+        const auto heads_div_conv =
+            context.mark_node(std::make_shared<opset10::ConvertLike>(heads_div, scaled_dot_product));
+        scaled_dot_product =
+            context.mark_node(std::make_shared<opset10::Divide>(scaled_dot_product, heads_div_conv, false));
+        scaled_dot_product = context.mark_node(std::make_shared<opset10::ReduceSum>(scaled_dot_product, one_1d));
+    }
+
+    if (need_weights) {
+        return {scaled_dot_product_attention_biased, scaled_dot_product};
+    } else {
+        // When need_weights == false, return None as the second output
+        const auto none = std::make_shared<PtFrameworkNode>(context.get_decoder(), context.inputs());
+        auto attrs = none->get_attrs();
+        attrs["none_value"] = "";
+        none->set_attrs(attrs);
+        const auto none_marked = context.mark_node(none);
+        return {scaled_dot_product_attention_biased, none_marked};
+    }
+};
+
+}  // namespace op
+}  // namespace pytorch
+}  // namespace frontend
+}  // namespace ov
diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index c0e00c0dd5b..48eecc550e1 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -89,6 +89,7 @@ OP_CONVERTER(translate_mean);
 OP_CONVERTER(translate_meshgrid);
 OP_CONVERTER(translate_min);
 OP_CONVERTER(translate_narrow);
+OP_CONVERTER(translate_native_multi_head_attention);
 OP_CONVERTER(translate_neg);
 OP_CONVERTER(translate_new_full);
 OP_CONVERTER(translate_new_ones);
@@ -160,6 +161,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"aten::__range_length", op::translate_range_length},
         {"aten::_convolution", op::translate_convolution},
         {"aten::_convolution_mode", op::translate_convolution_mode},
+        {"aten::_native_multi_head_attention", op::translate_native_multi_head_attention},
         {"aten::_set_item", op::translate_set_item},
         {"aten::abs", op::translate_1to1_match_1_inputs<opset10::Abs>},
         {"aten::acos", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Acos>},
diff --git a/tests/layer_tests/conftest.py b/tests/layer_tests/conftest.py
index 89827737ffd..5c9e4ea9cc7 100644
--- a/tests/layer_tests/conftest.py
+++ b/tests/layer_tests/conftest.py
@@ -38,7 +38,7 @@ def pytest_collection_modifyitems(items):
                 test.add_marker(pytest.mark.xfail(reason=mark.kwargs["reason"]))
 
 
-@pytest.mark.hookwrapper
+@pytest.hookimpl(hookwrapper=True)
 def pytest_runtest_makereport(item, call):
     pytest_html = item.config.pluginmanager.getplugin('html')
     outcome = yield
diff --git a/tests/layer_tests/pytorch_tests/test_native_multi_head_attention.py b/tests/layer_tests/pytorch_tests/test_native_multi_head_attention.py
new file mode 100644
index 00000000000..95304109604
--- /dev/null
+++ b/tests/layer_tests/pytorch_tests/test_native_multi_head_attention.py
@@ -0,0 +1,79 @@
+# Copyright (C) 2018-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import pytest
+import torch
+
+from pytorch_layer_test_class import PytorchLayerTest
+
+EMBED_DIM = 8
+NUM_HEADS = 4
+SEQ_LENGTH = 6
+BATCH_SIZE = 1
+
+NO_MASK, ATTN_MASK, KEY_PAD_MASK, MERGED_MASK = -1, 0, 1, 2
+
+class aten_native_multi_head_attention(torch.nn.Module):
+    def __init__(self, mask, need_weights, average_attn_weights) -> None:
+        super().__init__()
+        self.qkv = torch.nn.Linear(EMBED_DIM, 3 * EMBED_DIM, dtype = torch.float32)
+        self.qkv.requires_grad_(False)
+        self.proj = torch.nn.Linear(EMBED_DIM, EMBED_DIM, dtype = torch.float32)
+        self.proj.requires_grad_(False)
+
+        self.embed_dim = EMBED_DIM
+        self.num_heads = NUM_HEADS
+        self.need_weights = need_weights
+        self.average_attn_weights = average_attn_weights
+
+        # Currently only int masks work correctly; they are converted to bool.
+        # Float masks raise a warning in PyTorch and are (incorrectly) converted to bool,
+        # which later produces NaNs in the MHA output
+        if mask == 0:
+            self.mask = torch.from_numpy(np.random.randint(0, 2, (SEQ_LENGTH, SEQ_LENGTH)).astype(np.bool))
+            self.mask_type = 0
+        elif mask == 1:
+            self.mask = torch.from_numpy(np.random.randint(0, 2, (BATCH_SIZE, SEQ_LENGTH)).astype(np.bool))
+            self.mask_type = 1
+        elif mask == 2:
+            self.mask = torch.from_numpy(np.random.randint(0, 2, (BATCH_SIZE, NUM_HEADS, SEQ_LENGTH, SEQ_LENGTH)).astype(np.bool))
+            self.mask_type = 2
+        else:
+            self.mask = None
+            self.mask_type = None
+
+        print(self.mask)
+
+    def forward(self, query, key, value):
+        return torch.ops.aten._native_multi_head_attention(
+            query, key, value,
+            embed_dim=self.embed_dim, num_head=self.num_heads,
+            qkv_weight=self.qkv.weight, qkv_bias=self.qkv.bias,
+            proj_weight=self.proj.weight, proj_bias=self.proj.bias,
+            mask = self.mask, need_weights=self.need_weights,
+            average_attn_weights = self.average_attn_weights,
+            mask_type = self.mask_type
+        )[0]
+
+class TestNativeMultiHeadAttention(PytorchLayerTest):
+    def _prepare_input(self):
+        # NativeMHA is self-attention
+        qkv_tensor = np.random.randn(BATCH_SIZE, SEQ_LENGTH, EMBED_DIM).astype(np.float32)
+        return (qkv_tensor.copy(),
+                qkv_tensor.copy(),
+                qkv_tensor.copy())
+
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    @pytest.mark.parametrize(
+        "mask",
+        [NO_MASK, ATTN_MASK, KEY_PAD_MASK, MERGED_MASK]
+    )
+    @pytest.mark.parametrize(
+        ["need_weights", "average_attn_weights"],
+        [[False, False], [True, False], [True, True]]
+    )
+    def test_native_multi_head_attention(self, ie_device, precision, ir_version, mask, need_weights, average_attn_weights):
+        self._test(aten_native_multi_head_attention(mask, need_weights, average_attn_weights),
+                   None, "aten::_native_multi_head_attention", ie_device, precision, ir_version)
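
Editor's note: the following standalone sketch is not part of the patch. It is a minimal, hypothetical example of how the new converter could be exercised outside the layer-test harness: a small self-attention module calls aten::_native_multi_head_attention directly, is traced and converted with OpenVINO, and its output is compared against PyTorch eager mode. It assumes an OpenVINO build that contains this patch and the openvino.convert_model / openvino.compile_model Python API (2023.1 or newer); the SelfAttention module, tensor shapes, and "CPU" device are illustrative only.

# Hypothetical usage sketch (not part of this patch): trace a module that calls
# aten::_native_multi_head_attention and run it through OpenVINO.
# Assumes a build containing this converter and openvino >= 2023.1.
import numpy as np
import openvino as ov
import torch


class SelfAttention(torch.nn.Module):
    """Illustrative self-attention wrapper around the raw aten op."""

    def __init__(self, embed_dim=8, num_heads=4):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.qkv = torch.nn.Linear(embed_dim, 3 * embed_dim)
        self.proj = torch.nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        # Self-attention: query == key == value; only the attention output is used.
        return torch.ops.aten._native_multi_head_attention(
            x, x, x,
            embed_dim=self.embed_dim, num_head=self.num_heads,
            qkv_weight=self.qkv.weight, qkv_bias=self.qkv.bias,
            proj_weight=self.proj.weight, proj_bias=self.proj.bias,
            mask=None, need_weights=False, average_attn_weights=True,
            mask_type=None)[0]


if __name__ == "__main__":
    model = SelfAttention().eval()
    x = torch.randn(1, 6, 8, dtype=torch.float32)  # (batch, seq, embed_dim)

    with torch.no_grad():
        reference = model(x).numpy()

    # convert_model traces the module with the example input; tracing is where
    # translate_native_multi_head_attention is invoked.
    ov_model = ov.convert_model(model, example_input=(x,))
    compiled = ov.compile_model(ov_model, "CPU")
    result = compiled((x.numpy(),))[0]

    print("max abs diff vs. PyTorch:", float(np.max(np.abs(reference - result))))

The layer test above covers the same conversion path through PytorchLayerTest; this sketch only illustrates the standalone convert-and-run flow.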