[PT FE]: implement scaled dot product attention (#17178)
* [PT FE]: implement scaled dot product attention
* Apply suggestions from code review
* Update src/frontends/pytorch/src/op/scaled_dot_product_attention.cpp

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>
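For context, the converter targets calls like the following (a minimal sketch; shapes follow the layer test added below):

    import torch
    import torch.nn.functional as F

    query, key, value = (torch.randn(1, 2, 8, 4) for _ in range(3))
    # Traced as aten::scaled_dot_product_attention with six inputs:
    # query, key, value, attn_mask, dropout_p, is_causal.
    out = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)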
src/frontends/pytorch/src/op/scaled_dot_product_attention.cpp (new file, 109 lines)
@@ -0,0 +1,109 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <limits>

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/greater_eq.hpp"
#include "openvino/op/logical_not.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/select.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/softmax.hpp"
#include "openvino/op/sqrt.hpp"
#include "openvino/op/squeeze.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_scaled_dot_product_attention(const NodeContext& context) {
    // aten::scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None,
    //                                    float dropout_p=0., bool is_causal=False)
    num_inputs_check(context, 6, 6);
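    // The graph below decomposes attn = softmax(Q @ K^T / sqrt(E) + mask) @ V, where E is the
    // last (embedding) dimension of Q; the mask term is only added when attn_mask is supplied
    // or is_causal is set. dropout_p is not read, since dropout is a no-op at inference.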
    auto query = context.get_input(0);
    auto key = context.get_input(1);
    auto value = context.get_input(2);
    auto q_shape = context.mark_node(std::make_shared<v3::ShapeOf>(query, element::i32));
    auto k_shape = context.mark_node(std::make_shared<v3::ShapeOf>(key, element::i32));
    auto minus_one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1}));
    auto minus_two = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-2}));
    auto zero_i = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
    auto one_i = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
    // Scale factor 1 / sqrt(E), computed from the last dimension of the query shape.
    auto scale = context.mark_node(std::make_shared<v8::Gather>(q_shape, minus_one, zero_i));
    scale = context.mark_node(std::make_shared<v1::ConvertLike>(scale, query));
    auto sqrt_scale = context.mark_node(std::make_shared<v0::Sqrt>(scale));
    auto one_f = context.mark_node(std::make_shared<v1::ConvertLike>(one_i, sqrt_scale));
    auto zero_f = context.mark_node(std::make_shared<v1::ConvertLike>(zero_i, sqrt_scale));
    scale = context.mark_node(std::make_shared<v1::Divide>(one_f, sqrt_scale));
    auto q_scaled = context.mark_node(std::make_shared<v1::Multiply>(query, scale));
    // Rank helpers used to build a permutation that swaps the last two dimensions of the key.
    auto k_rank = context.mark_node(std::make_shared<v3::ShapeOf>(k_shape, element::i32));
    auto k_last_dim = context.mark_node(std::make_shared<v1::Add>(k_rank, minus_one));
    auto k_next_dim = context.mark_node(std::make_shared<v1::Add>(k_rank, minus_two));
    k_rank = context.mark_node(std::make_shared<v0::Squeeze>(k_rank, zero_i));
    auto minus_inf =
        context.mark_node(v0::Constant::create(element::f32, Shape{}, {-std::numeric_limits<float>::infinity()}));
    auto keep_dim_last = context.mark_node(std::make_shared<v0::Squeeze>(k_next_dim, zero_i));
    auto k_dims_before_transpose =
        context.mark_node(std::make_shared<v4::Range>(zero_i, keep_dim_last, one_i, element::i32));

    // Permutation [0, ..., rank-3, rank-1, rank-2]: transpose the last two dimensions of the key.
    auto transpose_dims = context.mark_node(
        std::make_shared<v0::Concat>(OutputVector{k_dims_before_transpose, k_last_dim, k_next_dim}, 0));
    auto k_transposed = context.mark_node(std::make_shared<v1::Transpose>(key, transpose_dims));
    auto scaled_atten = context.mark_node(std::make_shared<v0::MatMul>(q_scaled, k_transposed));
    minus_inf = context.mark_node(std::make_shared<v1::ConvertLike>(minus_inf, scaled_atten));
    // Two types of masks are supported: a boolean mask, where a value of True indicates that the
    // element should take part in attention, and a float mask of the same type as query, key and
    // value, which is added to the attention scores.
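    // For example, a boolean mask row [True, False] keeps column 0 and removes column 1 from the
    // softmax, while a float mask row [0., -10.] merely penalizes column 1.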
    auto is_causal = context.const_input<bool>(5);
    if (is_causal || !context.input_is_none(3)) {
        Output<Node> mask;
        Output<Node> atten_mask;
        if (!context.input_is_none(3)) {
            mask = context.get_input(3);
            if (mask.get_element_type() == element::boolean) {
                // Assign -inf where the mask is False; kept positions receive the converted mask
                // value (a constant 1 that cancels out in the softmax).
                atten_mask = context.mark_node(std::make_shared<v1::ConvertLike>(mask, scaled_atten));
                auto inv_mask = context.mark_node(std::make_shared<v1::LogicalNot>(mask));
                atten_mask = context.mark_node(std::make_shared<v1::Select>(inv_mask, minus_inf, atten_mask));
            } else {
                atten_mask = mask;
            }
        } else {
            // is_causal: build a [target_len, source_len] mask with -inf strictly above the main
            // diagonal, so each position only attends to itself and earlier positions.
            auto target_s_len = context.mark_node(std::make_shared<v8::Gather>(q_shape, minus_two, zero_i));
            auto source_s_len = context.mark_node(std::make_shared<v8::Gather>(k_shape, minus_two, zero_i));
            auto ssl = context.mark_node(std::make_shared<v0::Unsqueeze>(source_s_len, zero_i));
            auto tsl = context.mark_node(std::make_shared<v0::Unsqueeze>(target_s_len, zero_i));
            auto mask_shape = context.mark_node(std::make_shared<v0::Concat>(OutputVector{tsl, ssl}, 0));
            mask = context.mark_node(std::make_shared<v1::Broadcast>(minus_inf, mask_shape));
            auto horizontal_range =
                context.mark_node(std::make_shared<v4::Range>(zero_i, source_s_len, one_i, element::i32));
            horizontal_range = context.mark_node(std::make_shared<v0::Unsqueeze>(horizontal_range, zero_i));
            auto stop = context.mark_node(std::make_shared<v1::Add>(target_s_len, one_i));
            auto vertical_range = context.mark_node(std::make_shared<v4::Range>(one_i, stop, one_i, element::i32));
            vertical_range = context.mark_node(std::make_shared<v0::Unsqueeze>(vertical_range, one_i));
            auto triu = context.mark_node(std::make_shared<v1::GreaterEqual>(horizontal_range, vertical_range));
            atten_mask = context.mark_node(std::make_shared<v1::Select>(triu, mask, zero_f));
        }
        scaled_atten = context.mark_node(std::make_shared<v1::Add>(scaled_atten, atten_mask));
    }
    scaled_atten = context.mark_node(std::make_shared<v8::Softmax>(scaled_atten, -1));
    return {context.mark_node(std::make_shared<v0::MatMul>(scaled_atten, value))};
}

}  // namespace op
}  // namespace pytorch
}  // namespace frontend
}  // namespace ov
src/frontends/pytorch/src/op_table.cpp
@@ -104,6 +104,7 @@ OP_CONVERTER(translate_roi_align);
OP_CONVERTER(translate_roll);
OP_CONVERTER(translate_rsqrt);
OP_CONVERTER(translate_rsub);
OP_CONVERTER(translate_scaled_dot_product_attention);
OP_CONVERTER(translate_select);
OP_CONVERTER(translate_set_item);
OP_CONVERTER(translate_selu);
@@ -296,6 +297,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
    {"aten::rsqrt", op::translate_rsqrt},
    {"aten::rsub", op::translate_rsub},
    {"aten::ScalarImplicit", op::skip_node},
    {"aten::scaled_dot_product_attention", op::translate_scaled_dot_product_attention},
    {"aten::select", op::translate_select},
    {"aten::selu", op::translate_selu},
    {"aten::selu_", op::inplace_op<op::translate_selu>},
New Python layer test (41 lines):
@@ -0,0 +1,41 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest

from pytorch_layer_test_class import PytorchLayerTest


class TestScaledDotProductAttention(PytorchLayerTest):
    def _prepare_input(self):
        return (np.random.randn(1, 2, 8, 4).astype(np.float32),
                np.random.randn(1, 2, 8, 4).astype(np.float32),
                np.random.randn(1, 2, 8, 4).astype(np.float32))

    def create_model(self, mask, is_causal):
        import torch
        import torch.nn.functional as F

        class aten_scaled_dot_product_atten(torch.nn.Module):
            def __init__(self, mask=False, is_causal=False) -> None:
                super().__init__()
                self.mask = None if not mask else torch.from_numpy(np.random.randint(0, 2, (8, 8)).astype(np.float32))
                self.is_causal = is_causal
                if is_causal and mask:
                    # `Tensor.to` is not in-place: assign the result back so the
                    # boolean-mask path is actually exercised.
                    self.mask = self.mask.to(torch.bool)
                    self.is_causal = False

            def forward(self, query, key, value):
                return F.scaled_dot_product_attention(query, key, value, attn_mask=self.mask, is_causal=self.is_causal)

        ref_net = None

        return aten_scaled_dot_product_atten(mask, is_causal), ref_net, "aten::scaled_dot_product_attention"

    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.parametrize(["mask", "is_causal"], [(False, False), (False, True), (True, True), (True, False)])
    def test_scaled_dot_product_atten(self, ie_device, precision, ir_version, mask, is_causal):
        self._test(*self.create_model(mask, is_causal), ie_device, precision, ir_version)
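For reference, the converter's decomposition mirrors PyTorch's default computation; a minimal sketch for illustration, checking the unmasked case numerically:

    import math
    import torch
    import torch.nn.functional as F

    q, k, v = (torch.randn(1, 2, 8, 4) for _ in range(3))
    # Default SDPA computes softmax(Q @ K^T / sqrt(E)) @ V with E = q.shape[-1].
    manual = torch.softmax((q @ k.transpose(-2, -1)) / math.sqrt(q.shape[-1]), dim=-1) @ v
    assert torch.allclose(F.scaled_dot_product_attention(q, k, v), manual, atol=1e-6)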