LLM and SD additional ops (#20435)
* TorchFX: New ops added (baddbmm, leaky_relu_)
* TorchFX: Initial scaled_dot_product_flash_attention
* Code Formatting: scaled_dot_product_attention translation
* TorchFX unit test enabled for SDPA
* Typo fix in comment line

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>

Commit 3d5fe8d446 (parent 5018be82c3)
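For orientation, a minimal PyTorch sketch (not part of the commit; module, shapes, and values are illustrative only) exercising the three newly supported ops:

import torch
import torch.nn.functional as F

q = torch.randn(2, 4, 8, 16)  # (batch, heads, seq, head_dim), illustrative shapes
k = torch.randn(2, 4, 8, 16)
v = torch.randn(2, 4, 8, 16)

# This call typically lowers to aten._scaled_dot_product_flash_attention
# in TorchFX graphs on CPU.
attn = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)

# aten.baddbmm.default: bias plus batched matrix multiply.
bias = torch.randn(2, 8, 8)
scores = torch.baddbmm(bias, torch.randn(2, 8, 16), torch.randn(2, 16, 8))

# aten.leaky_relu_.default: the in-place variant registered by this PR.
scores.leaky_relu_(0.01)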
@@ -41,6 +41,7 @@ class OperatorSupport(OperatorSupport):
     "torch.ops.aten.arange.default": None,
     "torch.ops.aten.argmax.default": None,
     "torch.ops.aten.avg_pool2d.default": None,
+    "torch.ops.aten.baddbmm.default": None,
     "torch.ops.aten.bitwise_and.Tensor": None,
     "torch.ops.aten.bmm.default": None,
     "torch.ops.aten.cat.default": None,

@@ -67,6 +68,7 @@ class OperatorSupport(OperatorSupport):
     "torch.ops.aten.hardswish_.default": None,
     "torch.ops.aten.hardtanh_.default": None,
     "torch.ops.aten.index.Tensor": None,
+    "torch.ops.aten.leaky_relu_.default": None,
     "torch.ops.aten.lift_fresh_copy.default": None,
     "torch.ops.aten.linalg_vector_norm.default": None,
     "torch.ops.aten.lt.Tensor": None,

@@ -89,6 +91,7 @@ class OperatorSupport(OperatorSupport):
     "torch.ops.aten.relu.default": None,
     "torch.ops.aten.relu_.default": None,
     "torch.ops.aten.rsub.Scalar": None,
+    "torch.ops.aten._scaled_dot_product_flash_attention.default": None,
     "torch.ops.aten.select.int": None,
     "torch.ops.aten.sigmoid.default": None,
     "torch.ops.aten.silu.default": None,
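These entries extend the TorchFX partitioner's supported-op list, so subgraphs containing them can be claimed by the OpenVINO backend. A hedged usage sketch, assuming the openvino.torch entry point described in OpenVINO's torch.compile documentation:

import torch
import openvino.torch  # registers the "openvino" torch.compile backend (per OpenVINO docs)

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.LeakyReLU()).eval()
compiled = torch.compile(model, backend="openvino")
out = compiled(torch.randn(1, 16))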
@@ -15,6 +15,7 @@
 #include "openvino/op/matmul.hpp"
 #include "openvino/op/multiply.hpp"
 #include "openvino/op/range.hpp"
+#include "openvino/op/reshape.hpp"
 #include "openvino/op/select.hpp"
 #include "openvino/op/shape_of.hpp"
 #include "openvino/op/softmax.hpp"
@@ -22,6 +23,7 @@
 #include "openvino/op/squeeze.hpp"
 #include "openvino/op/transpose.hpp"
 #include "openvino/op/unsqueeze.hpp"
+#include "openvino/op/util/framework_node.hpp"
 #include "utils.hpp"

 namespace ov {
@@ -31,10 +33,7 @@ namespace op {

 using namespace ov::op;

-OutputVector translate_scaled_dot_product_attention(const NodeContext& context) {
-    // aten::scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float
-    // dropout_p=0., bool is_causal=False)
-    num_inputs_check(context, 6, 6);
+std::shared_ptr<ov::Node> translate_scaled_dot_product_attention_common(const NodeContext& context) {
     auto query = context.get_input(0);
     auto key = context.get_input(1);
     auto value = context.get_input(2);
@@ -68,7 +67,10 @@ OutputVector translate_scaled_dot_product_attention(const NodeContext& context)
     minus_inf = context.mark_node(std::make_shared<v1::ConvertLike>(minus_inf, scaled_atten));
     // two types of masks are supported. A boolean mask where a value of True indicates that the element should take
     // part in attention. A float mask of the same type as query, key, value that is added to the attention score.
-    auto is_causal = context.const_input<bool>(5);
+    auto is_causal = false;
+    if (!context.input_is_none(5)) {
+        is_causal = context.const_input<bool>(5);
+    }
     if (is_causal || !context.input_is_none(3)) {
         Output<Node> mask;
         Output<Node> atten_mask;
@@ -100,10 +102,30 @@ OutputVector translate_scaled_dot_product_attention(const NodeContext& context)
         scaled_atten = context.mark_node(std::make_shared<v1::Add>(scaled_atten, atten_mask));
     }
     scaled_atten = context.mark_node(std::make_shared<v8::Softmax>(scaled_atten, -1));
-    return {context.mark_node(std::make_shared<v0::MatMul>(scaled_atten, value))};
+    return context.mark_node(std::make_shared<v0::MatMul>(scaled_atten, value));
 };

+OutputVector translate_scaled_dot_product_attention(const NodeContext& context) {
+    // aten::scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float
+    // dropout_p=0., bool is_causal=False)
+    num_inputs_check(context, 6, 6);
+    return {translate_scaled_dot_product_attention_common(context)};
+};
+
+OutputVector translate_scaled_dot_product_attention_fx(const NodeContext& context) {
+    // aten::scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float
+    // dropout_p=0., bool is_causal=False)
+    num_inputs_check(context, 3, 6);
+    auto output = translate_scaled_dot_product_attention_common(context);
+    // TODO: scaled_dot_product_flash_attention has 9 outputs but for most cases only
+    // the first input is used. Rest of the outputs should be returned properly as
+    // needed.
+    ov::OutputVector out_vec;
+    out_vec.push_back(output);
+    return {context.mark_node(make_list_construct(out_vec))};
+};
+
 }  // namespace op
 }  // namespace pytorch
 }  // namespace frontend
 }  // namespace ov
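The common helper above implements the standard attention formula softmax(QK^T / sqrt(d) + mask) V. As a rough reference (my own sketch, ignoring dropout, mirroring the two mask kinds named in the code comments):

import math
import torch

def sdpa_reference(q, k, v, attn_mask=None, is_causal=False):
    # scale scores by 1/sqrt(head_dim), as in the translated graph
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    if is_causal:
        # boolean causal mask: True means "may attend"
        L, S = q.shape[-2], k.shape[-2]
        attn_mask = torch.ones(L, S, dtype=torch.bool).tril()
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # boolean mask: masked-out positions get -inf before softmax
            scores = scores.masked_fill(~attn_mask, float("-inf"))
        else:
            # float mask: added directly to the attention scores
            scores = scores + attn_mask
    return torch.softmax(scores, dim=-1) @ v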
@@ -213,6 +213,7 @@ OP_CONVERTER(translate_group_norm_fx);
 OP_CONVERTER(translate_index_fx);
 OP_CONVERTER(translate_layer_norm_fx);
 OP_CONVERTER(translate_max_poolnd_fx);
+OP_CONVERTER(translate_scaled_dot_product_attention_fx);
 OP_CONVERTER(translate_slice_fx);
 OP_CONVERTER(translate_softmax_fx);
 OP_CONVERTER(translate_transpose_fx);

@@ -555,6 +556,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
     {"aten.arange.default", op::translate_arange_fx},
     {"aten.argmax.default", op::translate_argmax},
     {"aten.avg_pool2d.default", op::translate_avg_poolnd},
+    {"aten.baddbmm.default", op::translate_addmm},
     {"aten.bitwise_and.Tensor", op::translate_bitwise_and},
     {"aten.bmm.default", op::translate_1to1_match_2_inputs_align_types<opset10::MatMul>},
     {"aten.cat.default", op::translate_cat_fx},

@@ -581,6 +583,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
     {"aten.hardswish_.default", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::HSwish>>},
     {"aten.hardtanh_.default", op::inplace_op<op::translate_hardtanh>},
     {"aten.index.Tensor", op::translate_index_fx},
+    {"aten.leaky_relu_.default", op::inplace_op<op::translate_1to1_match_2_inputs<opset10::PRelu>>},
     {"aten.lift_fresh_copy.default", op::skip_node},
     {"aten.linalg_vector_norm.default", op::translate_linalg_vector_norm},
     {"aten.log.default", op::translate_log},

@@ -603,6 +606,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
     {"aten.relu.default", op::translate_1to1_match_1_inputs<opset10::Relu>},
     {"aten.relu_.default", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Relu>>},
     {"aten.rsub.Scalar", op::translate_rsub},
+    {"aten._scaled_dot_product_flash_attention.default", op::translate_scaled_dot_product_attention_fx},
     {"aten.select.int", op::translate_select},
     {"aten.sigmoid.default", op::translate_1to1_match_1_inputs<opset10::Sigmoid>},
     {"aten.silu.default", op::translate_1to1_match_1_inputs<opset10::Swish>},
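Note that aten.baddbmm.default is routed to translate_addmm, which fits because baddbmm computes the same affine expression batch-wise; likewise leaky_relu_ maps to PRelu, since a PRelu with a scalar slope is exactly leaky ReLU. A quick sanity sketch (mine, not from the commit):

import torch
import torch.nn.functional as F

b1 = torch.randn(2, 3, 4)
b2 = torch.randn(2, 4, 5)
bias = torch.randn(2, 3, 5)
beta, alpha = 0.5, 2.0

# baddbmm(input, batch1, batch2) == beta * input + alpha * (batch1 @ batch2)
out = torch.baddbmm(bias, b1, b2, beta=beta, alpha=alpha)
ref = beta * bias + alpha * torch.bmm(b1, b2)
assert torch.allclose(out, ref, atol=1e-6)

# PRelu with a single scalar slope reproduces leaky_relu.
x = torch.randn(5)
assert torch.allclose(F.leaky_relu(x, 0.01), F.prelu(x, torch.tensor([0.01])))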
@@ -36,6 +36,7 @@ class TestScaledDotProductAttention(PytorchLayerTest):

     @pytest.mark.nightly
     @pytest.mark.precommit
+    @pytest.mark.precommit_fx_backend
     @pytest.mark.parametrize(['mask', "is_causal"], [(False, False), (False, True), (True, True), (True, False)])
     def test_scaled_dot_product_atten(self, ie_device, precision, ir_version, mask, is_causal):
         self._test(*self.create_model(mask, is_causal),ie_device, precision, ir_version)
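The newly added precommit_fx_backend marker runs this SDPA test under the TorchFX backend across the four mask/is_causal combinations. For illustration, a hypothetical minimal module in the spirit of the test's create_model (the real implementation lives in the test file):

import torch
import torch.nn.functional as F

# Hypothetical reconstruction, not the actual test code.
class CausalSDPA(torch.nn.Module):
    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)

model = CausalSDPA()
q = k = v = torch.randn(1, 2, 8, 16)
out = model(q, k, v)
print(out.shape)  # torch.Size([1, 2, 8, 16])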