LLM and SD additional ops (#20435)

* TorchFX: New ops added (baddbmm, leaky_relu_)

* TorchFX: Initial scaled_dot_product_flash_attention support

* Code formatting: scaled_dot_product_attention translation

* TorchFX unit test enabled for SDPA

* Typo fix in a comment line

---------

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>
Mustafa Cavus 2023-10-19 10:21:28 -07:00 committed by GitHub
parent 5018be82c3
commit 3d5fe8d446
4 changed files with 37 additions and 7 deletions
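For orientation before the diffs: a minimal PyTorch sketch exercising the three newly supported ops (baddbmm, the in-place leaky_relu_, and scaled_dot_product_attention, which the FX pipeline typically lowers to aten._scaled_dot_product_flash_attention). The torch.compile line is an assumption about how the OpenVINO backend is invoked, so it is left commented out:

import torch
import torch.nn.functional as F

class TinyAttention(torch.nn.Module):
    def forward(self, q, k, v, bias):
        # aten.baddbmm.default: bias + q @ k^T, batched
        scores = torch.baddbmm(bias, q, k.transpose(-2, -1))
        # aten.leaky_relu_.default: in-place variant
        scores.leaky_relu_()
        # lowered to aten._scaled_dot_product_flash_attention.default by FX
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return out + scores @ v

q = k = v = torch.randn(2, 4, 8)
bias = torch.zeros(2, 4, 4)
# compiled = torch.compile(TinyAttention(), backend="openvino")  # assumed backend name
print(TinyAttention()(q, k, v, bias).shape)  # torch.Size([2, 4, 8])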


@@ -41,6 +41,7 @@ class OperatorSupport(OperatorSupport):
"torch.ops.aten.arange.default": None,
"torch.ops.aten.argmax.default": None,
"torch.ops.aten.avg_pool2d.default": None,
"torch.ops.aten.baddbmm.default": None,
"torch.ops.aten.bitwise_and.Tensor": None,
"torch.ops.aten.bmm.default": None,
"torch.ops.aten.cat.default": None,
@@ -67,6 +68,7 @@ class OperatorSupport(OperatorSupport):
"torch.ops.aten.hardswish_.default": None,
"torch.ops.aten.hardtanh_.default": None,
"torch.ops.aten.index.Tensor": None,
"torch.ops.aten.leaky_relu_.default": None,
"torch.ops.aten.lift_fresh_copy.default": None,
"torch.ops.aten.linalg_vector_norm.default": None,
"torch.ops.aten.lt.Tensor": None,
@@ -89,6 +91,7 @@ class OperatorSupport(OperatorSupport):
"torch.ops.aten.relu.default": None,
"torch.ops.aten.relu_.default": None,
"torch.ops.aten.rsub.Scalar": None,
"torch.ops.aten._scaled_dot_product_flash_attention.default": None,
"torch.ops.aten.select.int": None,
"torch.ops.aten.sigmoid.default": None,
"torch.ops.aten.silu.default": None,


@@ -15,6 +15,7 @@
#include "openvino/op/matmul.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/select.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/softmax.hpp"
@@ -22,6 +23,7 @@
#include "openvino/op/squeeze.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "openvino/op/util/framework_node.hpp"
#include "utils.hpp"
namespace ov {
@@ -31,10 +33,7 @@ namespace op {
using namespace ov::op;
-OutputVector translate_scaled_dot_product_attention(const NodeContext& context) {
-// aten::scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float
-// dropout_p=0., bool is_causal=False)
-num_inputs_check(context, 6, 6);
+std::shared_ptr<ov::Node> translate_scaled_dot_product_attention_common(const NodeContext& context) {
auto query = context.get_input(0);
auto key = context.get_input(1);
auto value = context.get_input(2);
@@ -68,7 +67,10 @@ OutputVector translate_scaled_dot_product_attention(const NodeContext& context)
minus_inf = context.mark_node(std::make_shared<v1::ConvertLike>(minus_inf, scaled_atten));
// two types of masks are supported. A boolean mask where a value of True indicates that the element should take
// part in attention. A float mask of the same type as query, key, value that is added to the attention score.
-auto is_causal = context.const_input<bool>(5);
+auto is_causal = false;
+if (!context.input_is_none(5)) {
+is_causal = context.const_input<bool>(5);
+}
if (is_causal || !context.input_is_none(3)) {
Output<Node> mask;
Output<Node> atten_mask;
@@ -100,10 +102,30 @@ OutputVector translate_scaled_dot_product_attention(const NodeContext& context)
scaled_atten = context.mark_node(std::make_shared<v1::Add>(scaled_atten, atten_mask));
}
scaled_atten = context.mark_node(std::make_shared<v8::Softmax>(scaled_atten, -1));
-return {context.mark_node(std::make_shared<v0::MatMul>(scaled_atten, value))};
+return context.mark_node(std::make_shared<v0::MatMul>(scaled_atten, value));
};
+OutputVector translate_scaled_dot_product_attention(const NodeContext& context) {
+// aten::scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float
+// dropout_p=0., bool is_causal=False)
+num_inputs_check(context, 6, 6);
+return {translate_scaled_dot_product_attention_common(context)};
+};
+OutputVector translate_scaled_dot_product_attention_fx(const NodeContext& context) {
+// aten::scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float
+// dropout_p=0., bool is_causal=False)
+num_inputs_check(context, 3, 6);
+auto output = translate_scaled_dot_product_attention_common(context);
+// TODO: scaled_dot_product_flash_attention has 9 outputs but in most cases only
+// the first output is used. The rest of the outputs should be returned properly as
+// needed.
+ov::OutputVector out_vec;
+out_vec.push_back(output);
+return {context.mark_node(make_list_construct(out_vec))};
+};
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
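In math terms, the shared translation above builds standard scaled dot-product attention, with M the additive mask assembled either from the attn_mask input (boolean masks select 0 or -inf, float masks are added as-is) or from the is_causal flag:

\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}} + M\right) V

The _fx variant differs only in plumbing: aten._scaled_dot_product_flash_attention formally has 9 outputs, so the single computed result is wrapped in a list-construct node until the remaining outputs are modeled.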


@@ -213,6 +213,7 @@ OP_CONVERTER(translate_group_norm_fx);
OP_CONVERTER(translate_index_fx);
OP_CONVERTER(translate_layer_norm_fx);
OP_CONVERTER(translate_max_poolnd_fx);
+OP_CONVERTER(translate_scaled_dot_product_attention_fx);
OP_CONVERTER(translate_slice_fx);
OP_CONVERTER(translate_softmax_fx);
OP_CONVERTER(translate_transpose_fx);
@@ -555,6 +556,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
{"aten.arange.default", op::translate_arange_fx},
{"aten.argmax.default", op::translate_argmax},
{"aten.avg_pool2d.default", op::translate_avg_poolnd},
{"aten.baddbmm.default", op::translate_addmm},
{"aten.bitwise_and.Tensor", op::translate_bitwise_and},
{"aten.bmm.default", op::translate_1to1_match_2_inputs_align_types<opset10::MatMul>},
{"aten.cat.default", op::translate_cat_fx},
@@ -581,6 +583,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
{"aten.hardswish_.default", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::HSwish>>},
{"aten.hardtanh_.default", op::inplace_op<op::translate_hardtanh>},
{"aten.index.Tensor", op::translate_index_fx},
{"aten.leaky_relu_.default", op::inplace_op<op::translate_1to1_match_2_inputs<opset10::PRelu>>},
{"aten.lift_fresh_copy.default", op::skip_node},
{"aten.linalg_vector_norm.default", op::translate_linalg_vector_norm},
{"aten.log.default", op::translate_log},
@@ -603,6 +606,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
{"aten.relu.default", op::translate_1to1_match_1_inputs<opset10::Relu>},
{"aten.relu_.default", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Relu>>},
{"aten.rsub.Scalar", op::translate_rsub},
{"aten._scaled_dot_product_flash_attention.default", op::translate_scaled_dot_product_attention_fx},
{"aten.select.int", op::translate_select},
{"aten.sigmoid.default", op::translate_1to1_match_1_inputs<opset10::Sigmoid>},
{"aten.silu.default", op::translate_1to1_match_1_inputs<opset10::Swish>},


@@ -36,6 +36,7 @@ class TestScaledDotProductAttention(PytorchLayerTest):
@pytest.mark.nightly
@pytest.mark.precommit
+@pytest.mark.precommit_fx_backend
@pytest.mark.parametrize(['mask', "is_causal"], [(False, False), (False, True), (True, True), (True, False)])
def test_scaled_dot_product_atten(self, ie_device, precision, ir_version, mask, is_causal):
self._test(*self.create_model(mask, is_causal), ie_device, precision, ir_version)
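create_model is defined earlier in this test file and is not part of the hunk. A hypothetical sketch of what the four (mask, is_causal) combinations plausibly exercise (names and details are assumptions): PyTorch raises an error when both attn_mask and is_causal are set, so the causal flag has to yield once an explicit mask is supplied.

import torch
import torch.nn.functional as F

def create_model(mask, is_causal):
    class aten_scaled_dot_product_atten(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # SDPA forbids passing attn_mask together with is_causal=True
            self.mask = torch.ones(8, 8, dtype=torch.float32) if mask else None
            self.is_causal = is_causal and not mask

        def forward(self, query, key, value):
            return F.scaled_dot_product_attention(
                query, key, value, attn_mask=self.mask, is_causal=self.is_causal)

    # (model, ref_net, op_name) follows the usual PytorchLayerTest convention
    return aten_scaled_dot_product_atten(), None, "aten::scaled_dot_product_attention"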