[PT FE]: implement scaled dot product attention (#17178)

* [PT FE]: implement scaled dot product attention

* Apply suggestions from code review

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>

* Update src/frontends/pytorch/src/op/scaled_dot_product_attention.cpp

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>

---------

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>
This commit is contained in:
Ekaterina Aidova
2023-04-26 12:51:02 +04:00
committed by GitHub
parent 5857c4438b
commit 6389f423bf
3 changed files with 152 additions and 0 deletions

View File

@@ -0,0 +1,109 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/greater_eq.hpp"
#include "openvino/op/logical_not.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/select.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/softmax.hpp"
#include "openvino/op/sqrt.hpp"
#include "openvino/op/squeeze.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace op {
using namespace ov::op;
OutputVector translate_scaled_dot_product_attention(const NodeContext& context) {
// aten::scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float
// dropout_p=0., bool is_causal=False)
num_inputs_check(context, 6, 6);
auto query = context.get_input(0);
auto key = context.get_input(1);
auto value = context.get_input(2);
auto q_shape = context.mark_node(std::make_shared<v3::ShapeOf>(query, element::i32));
auto k_shape = context.mark_node(std::make_shared<v3::ShapeOf>(key, element::i32));
auto minus_one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1}));
auto minus_two = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-2}));
auto zero_i = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
auto one_i = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
auto scale = context.mark_node(std::make_shared<v8::Gather>(q_shape, minus_one, zero_i));
scale = context.mark_node(std::make_shared<v1::ConvertLike>(scale, query));
auto sqrt_scale = context.mark_node(std::make_shared<v0::Sqrt>(scale));
auto one_f = context.mark_node(std::make_shared<v1::ConvertLike>(one_i, sqrt_scale));
auto zero_f = context.mark_node(std::make_shared<v1::ConvertLike>(zero_i, sqrt_scale));
scale = context.mark_node(std::make_shared<v1::Divide>(one_f, sqrt_scale));
auto q_scaled = context.mark_node(std::make_shared<v1::Multiply>(query, scale));
auto k_rank = context.mark_node(std::make_shared<v3::ShapeOf>(k_shape, element::i32));
auto k_last_dim = context.mark_node(std::make_shared<v1::Add>(k_rank, minus_one));
auto k_next_dim = context.mark_node(std::make_shared<v1::Add>(k_rank, minus_two));
k_rank = context.mark_node(std::make_shared<v0::Squeeze>(k_rank, zero_i));
auto minus_inf =
context.mark_node(v0::Constant::create(element::f32, Shape{}, {-std::numeric_limits<float>::infinity()}));
auto keep_dim_last = context.mark_node(std::make_shared<v0::Squeeze>(k_next_dim, zero_i));
auto k_dims_before_transpose =
context.mark_node(std::make_shared<v4::Range>(zero_i, keep_dim_last, one_i, element::i32));
auto transpose_dims = context.mark_node(
std::make_shared<v0::Concat>(OutputVector{k_dims_before_transpose, k_last_dim, k_next_dim}, 0));
auto k_transposed = context.mark_node(std::make_shared<v1::Transpose>(key, transpose_dims));
auto scaled_atten = context.mark_node(std::make_shared<v0::MatMul>(q_scaled, k_transposed));
minus_inf = context.mark_node(std::make_shared<v1::ConvertLike>(minus_inf, scaled_atten));
// two types of masks are supported. A boolean mask where a value of True indicates that the element should take
// part in attention. A float mask of the same type as query, key, value that is added to the attention score.
auto is_causal = context.const_input<bool>(5);
if (is_causal || !context.input_is_none(3)) {
Output<Node> mask;
Output<Node> atten_mask;
if (!context.input_is_none(3)) {
mask = context.get_input(3);
if (mask.get_element_type() == element::boolean) {
atten_mask = context.mark_node(std::make_shared<v1::ConvertLike>(mask, scaled_atten));
auto inv_mask = context.mark_node(std::make_shared<v1::LogicalNot>(mask));
atten_mask = context.mark_node(std::make_shared<v1::Select>(inv_mask, atten_mask, minus_inf));
} else {
atten_mask = mask;
}
} else {
auto target_s_len = context.mark_node(std::make_shared<v8::Gather>(q_shape, minus_two, zero_i));
auto source_s_len = context.mark_node(std::make_shared<v8::Gather>(k_shape, minus_two, zero_i));
auto ssl = context.mark_node(std::make_shared<v0::Unsqueeze>(source_s_len, zero_i));
auto tsl = context.mark_node(std::make_shared<v0::Unsqueeze>(target_s_len, zero_i));
auto mask_shape = context.mark_node(std::make_shared<v0::Concat>(OutputVector{tsl, ssl}, 0));
mask = context.mark_node(std::make_shared<v1::Broadcast>(minus_inf, mask_shape));
auto horizontal_range =
context.mark_node(std::make_shared<v4::Range>(zero_i, source_s_len, one_i, element::i32));
horizontal_range = context.mark_node(std::make_shared<v0::Unsqueeze>(horizontal_range, zero_i));
auto stop = context.mark_node(std::make_shared<v1::Add>(target_s_len, one_i));
auto vertical_range = context.mark_node(std::make_shared<v4::Range>(one_i, stop, one_i, element::i32));
vertical_range = context.mark_node(std::make_shared<v0::Unsqueeze>(vertical_range, one_i));
auto triu = context.mark_node(std::make_shared<v1::GreaterEqual>(horizontal_range, vertical_range));
atten_mask = context.mark_node(std::make_shared<v1::Select>(triu, mask, zero_f));
}
scaled_atten = context.mark_node(std::make_shared<v1::Add>(scaled_atten, atten_mask));
}
scaled_atten = context.mark_node(std::make_shared<v8::Softmax>(scaled_atten, -1));
return {context.mark_node(std::make_shared<v0::MatMul>(scaled_atten, value))};
};
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@@ -104,6 +104,7 @@ OP_CONVERTER(translate_roi_align);
OP_CONVERTER(translate_roll);
OP_CONVERTER(translate_rsqrt);
OP_CONVERTER(translate_rsub);
OP_CONVERTER(translate_scaled_dot_product_attention);
OP_CONVERTER(translate_select);
OP_CONVERTER(translate_set_item);
OP_CONVERTER(translate_selu);
@@ -296,6 +297,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::rsqrt", op::translate_rsqrt},
{"aten::rsub", op::translate_rsub},
{"aten::ScalarImplicit", op::skip_node},
{"aten::scaled_dot_product_attention", op::translate_scaled_dot_product_attention},
{"aten::select", op::translate_select},
{"aten::selu", op::translate_selu},
{"aten::selu_", op::inplace_op<op::translate_selu>},

View File

@@ -0,0 +1,41 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pytest
from pytorch_layer_test_class import PytorchLayerTest
class TestScaledDotProductAttention(PytorchLayerTest):
def _prepare_input(self):
return (np.random.randn(1, 2, 8, 4).astype(np.float32), np.random.randn(1, 2, 8, 4).astype(np.float32), np.random.randn(1, 2, 8, 4).astype(np.float32))
def create_model(self, mask, is_causal):
import torch.nn.functional as F
import torch
class aten_scaled_dot_product_atten(torch.nn.Module):
def __init__(self, mask=False, is_causal=False) -> None:
super().__init__()
self.mask = None if not mask else torch.from_numpy(np.random.randint(0, 2, (8, 8)).astype(np.float32))
self.is_causal = is_causal
if is_causal and mask:
self.mask.to(torch.bool)
self.is_causal = False
def forward(self, query, key, value):
return F.scaled_dot_product_attention(query, key, value, attn_mask=self.mask, is_causal=self.is_causal)
ref_net = None
return aten_scaled_dot_product_atten(mask, is_causal), ref_net, "aten::scaled_dot_product_attention"
@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.parametrize(['mask', "is_causal"], [(False, False), (False, True), (True, True), (True, False)])
def test_scaled_dot_product_atten(self, ie_device, precision, ir_version, mask, is_causal):
self._test(*self.create_model(mask, is_causal),ie_device, precision, ir_version)