[Opset13] ScaledDotProductAttention-13 input validation improvements (#21281)

* Improve test naming to match op

* Initial draft for scaleddot shape infer

* Fix formatting

* Improve shape & type infer

* Add shape tests

* Fix issues in shape infer

* Fix type infer

* Improve type_prop tests

* Ignore attention when causal

* Fix template type shape infer

* Fix issues with scalar inputs + test improvement

* Fix `get_node_input_partial_shapes`

* Improve shape/type validation and tests

* Allow for broadcastable inputs

* Improve tests

* Add CPU shape infer test

* Add broadcast shape infer test

* Use const

* Use const &

* Improve query input handling

* Fix test issues

* Change broadcast rules

---------

Co-authored-by: Michal Lukaszewski <michal.lukaszewski@intel.com>
This commit is contained in:
Mateusz Mikolajczyk
2023-12-18 10:30:54 +01:00
committed by GitHub
parent 3a2958b360
commit 610e0fab5c
6 changed files with 541 additions and 120 deletions

View File

@@ -0,0 +1,114 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/op/scaled_dot_product_attention.hpp"
#include "utils.hpp"
namespace ov {
namespace op {
namespace v13 {
template <class T, class TRShape = result_shape_t<T>>
std::vector<TRShape> shape_infer(const ScaledDotProductAttention* op,
const std::vector<T>& input_shapes,
const ITensorAccessor& ta = make_tensor_accessor()) {
const bool& iscausal = op->get_causal();
using DimType = typename T::value_type;
const auto& inputs_count = input_shapes.size();
const auto& has_attention_mask_input = inputs_count >= 4;
const auto& has_scale_input = inputs_count == 5;
NODE_VALIDATION_CHECK(op, inputs_count == 3 || has_attention_mask_input || has_scale_input);
DimType e_dim{};
DimType l_dim{};
DimType s_dim{};
DimType ev_dim{};
TRShape n_dims = input_shapes[0];
const auto& n_dims_rank = n_dims.rank();
if (n_dims_rank.is_static()) {
NODE_SHAPE_INFER_CHECK(op,
input_shapes,
n_dims_rank.get_length() >= 3,
"Query input rank length must be at least 3 or more.");
l_dim = *(n_dims.end() - 2);
e_dim = *(n_dims.end() - 1);
n_dims.resize(n_dims.size() - 2);
}
const auto& key = input_shapes[1];
const auto& key_rank = key.rank();
if (key_rank.is_static()) {
const bool& key_input_correctness =
key_rank.get_length() >= 3 &&
TRShape::broadcast_merge_into(n_dims,
TRShape(std::vector<DimType>(key.begin(), key.end() - 2)),
AutoBroadcastType::NUMPY) &&
DimType::merge(e_dim, e_dim, *(key.end() - 1));
NODE_SHAPE_INFER_CHECK(op,
input_shapes,
key_input_correctness,
"Key input shape not compatible with other inputs.");
s_dim = *(key.end() - 2);
}
const auto& value = input_shapes[2];
const auto& value_rank = value.rank();
if (value_rank.is_static()) {
const bool& value_input_correctness =
value_rank.get_length() >= 3 &&
TRShape::broadcast_merge_into(n_dims,
TRShape(std::vector<DimType>(value.begin(), value.end() - 2)),
AutoBroadcastType::NUMPY) &&
DimType::merge(s_dim, s_dim, *(value.end() - 2));
NODE_SHAPE_INFER_CHECK(op,
input_shapes,
value_input_correctness,
"Value input shape not compatible with other inputs.");
ev_dim = *(value.end() - 1);
}
if (has_attention_mask_input && !iscausal) {
const auto& attention_mask = input_shapes[3];
const auto& attention_mask_rank = attention_mask.rank();
if (attention_mask_rank.is_static() && attention_mask_rank != 0) {
const auto& attention_mask_rank_len = attention_mask_rank.get_length();
bool attention_mask_input_correctness = attention_mask_rank_len >= 2 &&
DimType::merge(l_dim, l_dim, *(attention_mask.end() - 2)) &&
DimType::merge(s_dim, s_dim, *(attention_mask.end() - 1));
if (attention_mask_rank_len >= 3) {
attention_mask_input_correctness =
attention_mask_input_correctness &&
TRShape::broadcast_merge_into(
n_dims,
TRShape(std::vector<DimType>(attention_mask.begin(), attention_mask.end() - 2)),
AutoBroadcastType::NUMPY);
}
NODE_SHAPE_INFER_CHECK(op,
input_shapes,
attention_mask_input_correctness,
"Attention mask input shape not compatible with other inputs.");
}
}
if (has_scale_input) {
const auto& scale_rank = input_shapes[4].rank();
if (scale_rank.is_static() || input_shapes[4].is_static()) {
const auto& scale_is_scalar = scale_rank.compatible(0);
const auto& scale_has_one_elem = scale_rank.compatible(1) && input_shapes[4][0].compatible(1);
NODE_SHAPE_INFER_CHECK(op,
input_shapes,
scale_is_scalar || scale_has_one_elem,
"Scale input must be scalar or have 1 element.");
}
}
if (n_dims.rank().is_static()) {
n_dims.push_back(l_dim);
n_dims.push_back(ev_dim);
}
return {n_dims};
}
} // namespace v13
} // namespace op
} // namespace ov

View File

@@ -5,6 +5,7 @@
#include "openvino/op/scaled_dot_product_attention.hpp"
#include "itt.hpp"
#include "scaled_dot_product_attention_shape_inference.hpp"
using namespace std;
namespace ov {
@@ -38,25 +39,33 @@ op::v13::ScaledDotProductAttention::ScaledDotProductAttention(const Output<Node>
void op::v13::ScaledDotProductAttention::validate_and_infer_types() {
OV_OP_SCOPE(v13_ScaledDotProductAttention_validate_and_infer_types);
NODE_VALIDATION_CHECK(this, get_input_size() >= 3 && get_input_size() <= 5);
// TODO: More checks and accurate deduction of dimensions in case when various
// dynamic combinations appear.
auto query = get_input_partial_shape(0);
auto key = get_input_partial_shape(1);
auto value = get_input_partial_shape(2);
// using particular dimensions from query and value, to do that need to have them statically ranked
if (query.rank().is_dynamic() || value.rank().is_dynamic()) {
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
return;
auto out_type = get_input_element_type(0);
const auto& input_size = get_input_size();
const auto& causal = get_causal();
if (input_size >= 4 && !causal) {
const auto& attention_type = get_input_element_type(3);
NODE_VALIDATION_CHECK(
this,
attention_type.is_real() || attention_type == element::boolean || attention_type.is_dynamic(),
"The element type of attention_mask must be either floating-point or boolean.");
}
for (size_t i = 1; i < input_size; i++) {
const auto& element_type = get_input_element_type(i);
if (i == 3 && (element_type == element::boolean || causal)) {
// Skip checking attention_mask in loop when boolean or skipped to not affect merged dtype.
continue;
}
NODE_VALIDATION_CHECK(this,
element::Type::merge(out_type, out_type, element_type),
"Mixed input types are not supported.");
}
NODE_VALIDATION_CHECK(this,
out_type.is_real() || out_type.is_dynamic(),
"The element type of the input tensor must be a floating-point.");
OPENVINO_ASSERT(query.rank().get_length() >= 3);
OPENVINO_ASSERT(value.rank().get_length() >= 3);
auto dimensions = std::vector<Dimension>(query.begin(), query.end() - 1);
dimensions.push_back(*(value.end() - 1));
set_output_type(0, get_input_element_type(0), PartialShape(dimensions));
const auto& input_shapes = ov::util::get_node_input_partial_shapes(*this);
const auto& output_shapes = shape_infer(this, input_shapes);
set_output_type(0, out_type, output_shapes[0]);
}
std::shared_ptr<Node> op::v13::ScaledDotProductAttention::clone_with_new_inputs(const OutputVector& new_args) const {

View File

@@ -12,156 +12,382 @@
using namespace ov;
using namespace testing;
TEST(type_prop, scale_dot_product_attention_static_5_inputs) {
TEST(type_prop, scaled_dot_product_attention_static_5_inputs) {
const auto query = std::make_shared<opset13::Parameter>(element::f32, Shape{2, 3, 4});
const auto key = std::make_shared<opset13::Parameter>(element::f32, Shape{2, 5, 4});
const auto value = std::make_shared<opset13::Parameter>(element::f32, Shape{2, 5, 6});
const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, Shape{1, 3, 5});
const auto scale = std::make_shared<opset13::Parameter>(element::f32, Shape{});
const auto scale = std::make_shared<opset13::Parameter>(element::f32, Shape{1});
auto causal = false;
const auto gn =
const auto op =
std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, scale, causal);
EXPECT_EQ(gn->get_element_type(), element::f32);
EXPECT_EQ(gn->get_shape(), (Shape{2, 3, 6}));
EXPECT_EQ(op->get_output_element_type(0), element::f32);
EXPECT_EQ(op->get_output_partial_shape(0), (Shape{2, 3, 6}));
}
TEST(type_prop, scale_dot_product_attention_static_4_inputs) {
const auto query = std::make_shared<opset13::Parameter>(element::f32, Shape{2, 3, 4});
const auto key = std::make_shared<opset13::Parameter>(element::f32, Shape{2, 5, 4});
const auto value = std::make_shared<opset13::Parameter>(element::f32, Shape{2, 5, 6});
const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, Shape{1, 3, 5});
TEST(type_prop, scaled_dot_product_attention_static_4_inputs) {
const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 3, 4});
const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 4});
const auto value = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 6});
const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, PartialShape{1, 3, 5});
auto causal = false;
const auto gn = std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, causal);
EXPECT_EQ(gn->get_element_type(), element::f32);
EXPECT_EQ(gn->get_shape(), (Shape{2, 3, 6}));
const auto op = std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, causal);
EXPECT_EQ(op->get_output_element_type(0), element::f32);
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 3, 6}));
}
TEST(type_prop, scale_dot_product_attention_static_3_inputs) {
const auto query = std::make_shared<opset13::Parameter>(element::f32, Shape{2, 3, 4});
const auto key = std::make_shared<opset13::Parameter>(element::f32, Shape{2, 5, 4});
const auto value = std::make_shared<opset13::Parameter>(element::f32, Shape{2, 5, 6});
TEST(type_prop, scaled_dot_product_attention_static_3_inputs) {
const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 3, 4});
const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 4});
const auto value = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 6});
auto causal = false;
const auto gn = std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, causal);
EXPECT_EQ(gn->get_element_type(), element::f32);
EXPECT_EQ(gn->get_shape(), (Shape{2, 3, 6}));
const auto op = std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, causal);
EXPECT_EQ(op->get_output_element_type(0), element::f32);
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 3, 6}));
}
TEST(type_prop, scale_dot_product_attention_static_3_inputs_causal) {
const auto query = std::make_shared<opset13::Parameter>(element::f32, Shape{2, 3, 4});
const auto key = std::make_shared<opset13::Parameter>(element::f32, Shape{2, 5, 4});
const auto value = std::make_shared<opset13::Parameter>(element::f32, Shape{2, 5, 6});
TEST(type_prop, scaled_dot_product_attention_static_3_inputs_causal) {
const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 3, 4});
const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 4});
const auto value = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 6});
auto causal = true;
const auto gn = std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, causal);
EXPECT_EQ(gn->get_element_type(), element::f32);
EXPECT_EQ(gn->get_shape(), (Shape{2, 3, 6}));
const auto op = std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, causal);
EXPECT_EQ(op->get_output_element_type(0), element::f32);
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 3, 6}));
}
TEST(type_prop, scale_dot_product_attention_static_ignored_attention_mask) {
const auto query = std::make_shared<opset13::Parameter>(element::f32, Shape{2, 3, 4});
const auto key = std::make_shared<opset13::Parameter>(element::f32, Shape{2, 5, 4});
const auto value = std::make_shared<opset13::Parameter>(element::f32, Shape{2, 5, 6});
const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, Shape{7, 8, 9, 10, 11});
const auto scale = std::make_shared<opset13::Parameter>(element::f32, Shape{});
TEST(type_prop, scaled_dot_product_scalar_attention_causal_false) {
const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 3, 4});
const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 4});
const auto value = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 6});
const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, PartialShape{});
const auto scale = std::make_shared<opset13::Parameter>(element::f32, PartialShape{});
auto causal = false;
const auto gn =
const auto op =
std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, scale, causal);
EXPECT_EQ(gn->get_element_type(), element::f32);
EXPECT_EQ(gn->get_shape(), (Shape{2, 3, 6}));
EXPECT_EQ(op->get_output_element_type(0), element::f32);
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 3, 6}));
}
TEST(type_prop, scale_dot_product_attention_static_5_inputs_extra_batch) {
const auto query = std::make_shared<opset13::Parameter>(element::f32, Shape{2, 7, 3, 4});
const auto key = std::make_shared<opset13::Parameter>(element::f32, Shape{2, 7, 5, 4});
const auto value = std::make_shared<opset13::Parameter>(element::f32, Shape{2, 7, 5, 6});
const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, Shape{1, 1, 3, 5});
const auto scale = std::make_shared<opset13::Parameter>(element::f32, Shape{});
auto causal = false;
const auto gn =
std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, scale, causal);
EXPECT_EQ(gn->get_element_type(), element::f32);
EXPECT_EQ(gn->get_shape(), (Shape{2, 7, 3, 6}));
}
TEST(type_prop, scale_dot_product_attention_static_4_inputs_extra_batch) {
const auto query = std::make_shared<opset13::Parameter>(element::f32, Shape{2, 7, 3, 4});
const auto key = std::make_shared<opset13::Parameter>(element::f32, Shape{2, 7, 5, 4});
const auto value = std::make_shared<opset13::Parameter>(element::f32, Shape{2, 7, 5, 6});
const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, Shape{1, 1, 3, 5});
auto causal = false;
const auto gn = std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, causal);
EXPECT_EQ(gn->get_element_type(), element::f32);
EXPECT_EQ(gn->get_shape(), (Shape{2, 7, 3, 6}));
}
TEST(type_prop, scale_dot_product_attention_static_3_inputs_extra_batch) {
const auto query = std::make_shared<opset13::Parameter>(element::f32, Shape{2, 7, 3, 4});
const auto key = std::make_shared<opset13::Parameter>(element::f32, Shape{2, 7, 5, 4});
const auto value = std::make_shared<opset13::Parameter>(element::f32, Shape{2, 7, 5, 6});
auto causal = false;
const auto gn = std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, causal);
EXPECT_EQ(gn->get_element_type(), element::f32);
EXPECT_EQ(gn->get_shape(), (Shape{2, 7, 3, 6}));
}
TEST(type_prop, scale_dot_product_attention_static_3_inputs_extra_batch_causal_true) {
const auto query = std::make_shared<opset13::Parameter>(element::f32, Shape{2, 7, 3, 4});
const auto key = std::make_shared<opset13::Parameter>(element::f32, Shape{2, 7, 5, 4});
const auto value = std::make_shared<opset13::Parameter>(element::f32, Shape{2, 7, 5, 6});
TEST(type_prop, scaled_dot_product_attention_static_ignored_attention_mask) {
const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 3, 4});
const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 4});
const auto value = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 6});
const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, PartialShape{7, 8, 9, 10, 11});
const auto scale = std::make_shared<opset13::Parameter>(element::f32, PartialShape{});
auto causal = true;
const auto gn = std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, causal);
EXPECT_EQ(gn->get_element_type(), element::f32);
EXPECT_EQ(gn->get_shape(), (Shape{2, 7, 3, 6}));
const auto op =
std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, scale, causal);
EXPECT_EQ(op->get_output_element_type(0), element::f32);
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 3, 6}));
}
TEST(type_prop, scale_dot_product_attention_static_ignored_attention_mask_extra_batch) {
const auto query = std::make_shared<opset13::Parameter>(element::f32, Shape{2, 7, 3, 4});
const auto key = std::make_shared<opset13::Parameter>(element::f32, Shape{2, 7, 5, 4});
const auto value = std::make_shared<opset13::Parameter>(element::f32, Shape{2, 7, 5, 6});
const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, Shape{7, 8, 9, 10, 11});
const auto scale = std::make_shared<opset13::Parameter>(element::f32, Shape{});
TEST(type_prop, scaled_dot_product_attention_static_5_inputs_extra_batch) {
const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 7, 3, 4});
const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 7, 5, 4});
const auto value = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 7, 5, 6});
const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, PartialShape{1, 1, 3, 5});
const auto scale = std::make_shared<opset13::Parameter>(element::f32, PartialShape{});
auto causal = false;
const auto gn =
const auto op =
std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, scale, causal);
EXPECT_EQ(gn->get_element_type(), element::f32);
EXPECT_EQ(gn->get_shape(), (Shape{2, 7, 3, 6}));
EXPECT_EQ(op->get_output_element_type(0), element::f32);
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 7, 3, 6}));
}
TEST(type_prop, scale_dot_product_attention_dynamic_3d) {
TEST(type_prop, scaled_dot_product_attention_static_4_inputs_extra_batch) {
const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 7, 3, 4});
const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape{7, 5, 4});
const auto value = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 1, 5, 6});
const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, PartialShape{1, 1, 3, 5});
auto causal = false;
const auto op = std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, causal);
EXPECT_EQ(op->get_output_element_type(0), element::f32);
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 7, 3, 6}));
}
TEST(type_prop, scaled_dot_product_attention_static_3_inputs_extra_batch) {
const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape{7, 3, 4});
const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 7, 5, 4});
const auto value = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 7, 5, 6});
auto causal = false;
const auto op = std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, causal);
EXPECT_EQ(op->get_output_element_type(0), element::f32);
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 7, 3, 6}));
}
TEST(type_prop, scaled_dot_product_attention_static_3_inputs_extra_batch_causal_true) {
const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 7, 3, 4});
const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 7, 5, 4});
const auto value = std::make_shared<opset13::Parameter>(element::f32, PartialShape{7, 5, 6});
auto causal = true;
const auto op = std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, causal);
EXPECT_EQ(op->get_output_element_type(0), element::f32);
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 7, 3, 6}));
}
TEST(type_prop, scaled_dot_product_attention_static_ignored_attention_mask_extra_batch) {
const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 7, 3, 4});
const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 1, 5, 4});
const auto value = std::make_shared<opset13::Parameter>(element::f32, PartialShape{1, 7, 5, 6});
const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, PartialShape{7, 8, 9, 10, 11});
const auto scale = std::make_shared<opset13::Parameter>(element::f32, PartialShape{});
auto causal = true;
const auto op =
std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, scale, causal);
EXPECT_EQ(op->get_output_element_type(0), element::f32);
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 7, 3, 6}));
}
TEST(type_prop, scaled_dot_product_attention_5_inputs_dynamic_rank) {
const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape::dynamic());
const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape::dynamic());
const auto value = std::make_shared<opset13::Parameter>(element::f32, PartialShape::dynamic());
const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, PartialShape::dynamic());
const auto scale = std::make_shared<opset13::Parameter>(element::f32, PartialShape::dynamic());
auto causal = false;
const auto op =
std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, scale, causal);
EXPECT_EQ(op->get_output_element_type(0), element::f32);
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape::dynamic()));
}
TEST(type_prop, scaled_dot_product_attention_dynamic_3d) {
const auto dynamic = PartialShape{-1, -1, -1};
const auto query = std::make_shared<opset13::Parameter>(element::f32, dynamic);
const auto key = std::make_shared<opset13::Parameter>(element::f32, dynamic);
const auto value = std::make_shared<opset13::Parameter>(element::f32, dynamic);
const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, dynamic);
const auto scale = std::make_shared<opset13::Parameter>(element::f32, Shape{});
const auto scale = std::make_shared<opset13::Parameter>(element::f32, PartialShape{});
auto causal = false;
const auto gn =
const auto op =
std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, scale, causal);
EXPECT_EQ(gn->get_element_type(), element::f32);
EXPECT_EQ(gn->get_output_partial_shape(0), (dynamic));
EXPECT_EQ(op->get_output_element_type(0), element::f32);
EXPECT_EQ(op->get_output_partial_shape(0), (dynamic));
}
TEST(type_prop, scale_dot_product_attention_dynamic_4d) {
TEST(type_prop, scaled_dot_product_attention_dynamic_4d) {
const auto dynamic = PartialShape{-1, -1, -1, -1};
const auto query = std::make_shared<opset13::Parameter>(element::f32, dynamic);
const auto key = std::make_shared<opset13::Parameter>(element::f32, dynamic);
const auto value = std::make_shared<opset13::Parameter>(element::f32, dynamic);
const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, dynamic);
const auto scale = std::make_shared<opset13::Parameter>(element::f32, Shape{});
const auto scale = std::make_shared<opset13::Parameter>(element::f32, PartialShape{});
auto causal = false;
const auto gn =
const auto op =
std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, scale, causal);
EXPECT_EQ(gn->get_element_type(), element::f32);
EXPECT_EQ(gn->get_output_partial_shape(0), (dynamic));
EXPECT_EQ(op->get_output_element_type(0), element::f32);
EXPECT_EQ(op->get_output_partial_shape(0), (dynamic));
}
TEST(type_prop, scaled_dot_product_attention_mixed_shape_infer_5_inputs) {
PartialShape query_shape{{2, 4}, 3, {2, 5}, 4};
set_shape_labels(query_shape, 10);
PartialShape key_shape{{4, 8}, {2, 4}, 5, 4};
set_shape_labels(key_shape, 20);
PartialShape value_shape{{2, 4}, 3, 5, {3, 7}};
set_shape_labels(value_shape, 40);
PartialShape attention_mask_shape{{2, 7}, 3, {4, 7}, 5};
set_shape_labels(attention_mask_shape, 50);
const auto query = std::make_shared<opset13::Parameter>(element::dynamic, query_shape);
const auto key = std::make_shared<opset13::Parameter>(element::f64, key_shape);
const auto value = std::make_shared<opset13::Parameter>(element::dynamic, value_shape);
const auto attention_mask = std::make_shared<opset13::Parameter>(element::f64, attention_mask_shape);
const auto scale = std::make_shared<opset13::Parameter>(element::f64, PartialShape{-1});
auto causal = false;
const auto op =
std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, scale, causal);
EXPECT_EQ(op->get_output_element_type(0), element::f64);
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{4, 3, {4, 5}, {3, 7}}));
EXPECT_THAT(get_shape_labels(op->get_output_partial_shape(0)), testing::ElementsAre(50, 51, 52, 43));
}
TEST(type_prop, scaled_dot_product_attention_mixed_shape_infer_5_inputs_ignore_attention) {
const auto query = std::make_shared<opset13::Parameter>(element::dynamic, PartialShape{{1, 4}, 3, {1, 5}, 4});
const auto key = std::make_shared<opset13::Parameter>(element::f64, PartialShape{{4, 8}, {1, 4}, 5, 4});
const auto value = std::make_shared<opset13::Parameter>(element::dynamic, PartialShape{{2, 4}, 3, 5, {3, 7}});
const auto attention_mask = std::make_shared<opset13::Parameter>(element::i64, PartialShape{57, 3, {4, 7}, 5});
const auto scale = std::make_shared<opset13::Parameter>(element::f64, PartialShape{});
auto causal = true;
const auto op =
std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, scale, causal);
EXPECT_EQ(op->get_output_element_type(0), element::f64);
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{4, 3, {1, 5}, {3, 7}}));
}
TEST(type_prop, scaled_dot_product_attention_infer_5_dynamic_attn_partial) {
const auto dynamic = PartialShape::dynamic();
const auto query = std::make_shared<opset13::Parameter>(element::f32, dynamic);
const auto key = std::make_shared<opset13::Parameter>(element::f32, dynamic);
const auto value = std::make_shared<opset13::Parameter>(element::f32, dynamic);
const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, PartialShape{1, -1, 5, 7});
const auto scale = std::make_shared<opset13::Parameter>(element::f32, PartialShape{});
auto causal = false;
const auto op =
std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, scale, causal);
EXPECT_EQ(op->get_output_element_type(0), element::f32);
EXPECT_EQ(op->get_output_partial_shape(0), (dynamic));
}
TEST(type_prop, scaled_dot_product_attention_mixed_shape_infer_4_inputs) {
const auto query = std::make_shared<opset13::Parameter>(element::dynamic, PartialShape{{1, 4}, 4, {2, 5}, 4});
const auto key = std::make_shared<opset13::Parameter>(element::f64, PartialShape{{2, 8}, {1, 4}, 5, 4});
const auto value = std::make_shared<opset13::Parameter>(element::dynamic, PartialShape{{2, 4}, 4, 5, {3, 7}});
const auto attention_mask = std::make_shared<opset13::Parameter>(element::f64, PartialShape{4, {4, 7}, 5});
auto causal = false;
const auto op = std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, causal);
EXPECT_EQ(op->get_output_element_type(0), element::f64);
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{{2, 4}, 4, {4, 5}, {3, 7}}));
}
TEST(type_prop, scaled_dot_product_attention_type_infer_5_inputs) {
const auto query = std::make_shared<opset13::Parameter>(element::dynamic, PartialShape{2, 3, 4});
const auto key = std::make_shared<opset13::Parameter>(element::dynamic, PartialShape{2, 5, 4});
const auto value = std::make_shared<opset13::Parameter>(element::dynamic, PartialShape{2, 5, 6});
const auto attention_mask = std::make_shared<opset13::Parameter>(element::dynamic, PartialShape{1, 3, 5});
const auto scale = std::make_shared<opset13::Parameter>(element::f32, PartialShape{});
auto causal = false;
const auto op =
std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, scale, causal);
EXPECT_EQ(op->get_output_element_type(0), element::f32);
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 3, 6}));
}
TEST(type_prop, scaled_dot_product_attention_type_infer_4_inputs) {
const auto query = std::make_shared<opset13::Parameter>(element::dynamic, PartialShape{2, 3, 4});
const auto key = std::make_shared<opset13::Parameter>(element::f64, PartialShape{2, 5, 4});
const auto value = std::make_shared<opset13::Parameter>(element::dynamic, PartialShape{2, 5, 6});
const auto attention_mask = std::make_shared<opset13::Parameter>(element::f64, PartialShape{3, 5});
auto causal = false;
const auto op = std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, causal);
EXPECT_EQ(op->get_output_element_type(0), element::f64);
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 3, 6}));
}
TEST(type_prop, scaled_dot_product_attention_type_infer_4_inputs_bool_attention) {
const auto query = std::make_shared<opset13::Parameter>(element::dynamic, PartialShape{2, 3, 4});
const auto key = std::make_shared<opset13::Parameter>(element::f64, PartialShape{2, 5, 4});
const auto value = std::make_shared<opset13::Parameter>(element::dynamic, PartialShape{2, 5, 6});
const auto attention_mask = std::make_shared<opset13::Parameter>(element::boolean, PartialShape{1, 3, 5});
auto causal = false;
const auto op = std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, causal);
EXPECT_EQ(op->get_output_element_type(0), element::f64);
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 3, 6}));
}
TEST(type_prop, scaled_dot_product_unsupported_key_shape) {
const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 3, 4});
const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape{3, 5, 4});
const auto value = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 6});
const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, PartialShape{3, 3, 3, 5});
auto causal = false;
OV_EXPECT_THROW(
auto op = std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, causal),
AssertFailure,
testing::HasSubstr("Key input shape not compatible with other inputs."));
}
TEST(type_prop, scaled_dot_product_unsupported_value_shape) {
const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 3, 4});
const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 4});
const auto value = std::make_shared<opset13::Parameter>(element::f32, PartialShape{3, 5, 6});
const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, PartialShape{3, 3, 3, 5});
auto causal = false;
OV_EXPECT_THROW(
auto op = std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, causal),
AssertFailure,
testing::HasSubstr("Value input shape not compatible with other inputs."));
}
TEST(type_prop, scaled_dot_product_unsupported_attention_shape) {
const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 3, 4});
const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 4});
const auto value = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 6});
const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, PartialShape{3, 3, 3, 5});
auto causal = false;
OV_EXPECT_THROW(
auto op = std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, causal),
AssertFailure,
testing::HasSubstr("Attention mask input shape not compatible with other inputs."));
}
TEST(type_prop, scaled_dot_product_unsupported_scale_shape) {
const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 3, 4});
const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 4});
const auto value = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 6});
const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, PartialShape{3, 5});
const auto scale = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2});
auto causal = false;
OV_EXPECT_THROW(
auto op =
std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, scale, causal),
AssertFailure,
testing::HasSubstr("Scale input must be scalar or have 1 element."));
}
TEST(type_prop, scaled_dot_product_unsupported_dtype) {
const auto query = std::make_shared<opset13::Parameter>(element::i32, PartialShape{2, 3, 4});
const auto key = std::make_shared<opset13::Parameter>(element::i32, PartialShape{2, 5, 4});
const auto value = std::make_shared<opset13::Parameter>(element::i32, PartialShape{2, 5, 6});
const auto attention_mask = std::make_shared<opset13::Parameter>(element::boolean, PartialShape{3, 3, 3, 5});
const auto scale = std::make_shared<opset13::Parameter>(element::i32, PartialShape{});
auto causal = false;
OV_EXPECT_THROW(
auto op =
std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, scale, causal),
AssertFailure,
testing::HasSubstr("The element type of the input tensor must be a floating-point."));
}
TEST(type_prop, scaled_dot_product_unsupported_value_dtype_mixed) {
const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 3, 4});
const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 4});
const auto value = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 6});
const auto attention_mask = std::make_shared<opset13::Parameter>(element::f64, PartialShape{3, 3, 3, 5});
const auto scale = std::make_shared<opset13::Parameter>(element::f32, PartialShape{});
auto causal = false;
OV_EXPECT_THROW(
auto op =
std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, scale, causal),
AssertFailure,
testing::HasSubstr("Mixed input types are not supported."));
}
TEST(type_prop, scaled_dot_product_unsuported_attention_type) {
const auto query = std::make_shared<opset13::Parameter>(element::dynamic, PartialShape{2, 3, 4});
const auto key = std::make_shared<opset13::Parameter>(element::f64, PartialShape{2, 5, 4});
const auto value = std::make_shared<opset13::Parameter>(element::dynamic, PartialShape{2, 5, 6});
const auto attention_mask = std::make_shared<opset13::Parameter>(element::i32, PartialShape{1, 3, 5});
auto causal = false;
OV_EXPECT_THROW(
auto op = std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, causal),
AssertFailure,
testing::HasSubstr("The element type of attention_mask must be either floating-point or boolean."));
}

View File

@@ -94,6 +94,7 @@
#include "roi_align_shape_inference.hpp"
#include "roi_pooling_shape_inference.hpp"
#include "roll_shape_inference.hpp"
#include "scaled_dot_product_attention_shape_inference.hpp"
#include "scatter_elements_update_shape_inference.hpp"
#include "scatter_nd_base_shape_inference.hpp"
#include "select_shape_inference.hpp"
@@ -398,6 +399,7 @@ template <>
const IStaticShapeInferFactory::TRegistry IStaticShapeInferFactory::registry{
// opset13
_OV_OP_SHAPE_INFER_MASK_REG(opset13::Multinomial, ShapeInferTA, util::bit::mask(1)),
_OV_OP_SHAPE_INFER_MASK_REG(opset13::ScaledDotProductAttention, ShapeInferTA, util::bit::mask(3, 5)),
// opset12
_OV_OP_SHAPE_INFER_MASK_REG(opset12::Pad, ShapeInferTA, util::bit::mask(1, 2)),
_OV_OP_SHAPE_INFER_MASK_REG(opset12::ScatterElementsUpdate, ShapeInferTA, util::bit::mask(3)),

View File

@@ -23,9 +23,9 @@ const std::vector<std::vector<InputShape>> shapes{
{ov::test::InputShape{ov::PartialShape{-1, 8, -1, 64},
{ov::Shape{1, 8, 100, 64}, ov::Shape{1, 8, 1, 64}, ov::Shape{2, 8, 10, 64}}}
},
// attn shape: [B, 1, 1, L0+L1]
{ov::test::InputShape{ov::PartialShape{-1, 1, 1, -1},
{ov::Shape{1, 1, 1, 100}, ov::Shape{1, 1, 1, 1}, ov::Shape{2, 1, 1, 10}}}
// attn shape: [B, 1, -1, L0+L1]
{ov::test::InputShape{ov::PartialShape{-1, 1, -1, -1},
{ov::Shape{1, 1, 100, 100}, ov::Shape{1, 1, 1, 1}, ov::Shape{2, 1, 10, 10}}}
},
},
// heads number of kv is 1, attn mask: [B, H, L1, L0+L1]
@@ -75,4 +75,4 @@ INSTANTIATE_TEST_SUITE_P(smoke_ScaledAttn_CPU,
} // namespace ScaledAttn
} // namespace test
} // namespace ov
} // namespace ov

View File

@@ -0,0 +1,70 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gmock/gmock.h>
#include "common_test_utils/test_assertions.hpp"
#include "openvino/opsets/opset13.hpp"
#include "utils.hpp"
using namespace ov;
using namespace ov::intel_cpu;
using namespace testing;
class ScaledDotProductAttentionV13StaticShapeInferenceTest : public OpStaticShapeInferenceTest<op::v13::ScaledDotProductAttention> {
};
TEST_F(ScaledDotProductAttentionV13StaticShapeInferenceTest, default_ctor) {
op = make_op();
input_shapes = ShapeVector{{3, 2, 3, 4}, {2, 5, 4}, {1, 5, 6}, {1, 3, 5}, {}};
output_shapes = shape_inference(op.get(), input_shapes);
EXPECT_EQ(output_shapes.size(), 1);
EXPECT_EQ(output_shapes.front(), StaticShape({3, 2, 3, 6}));
}
TEST_F(ScaledDotProductAttentionV13StaticShapeInferenceTest, dynamic_shapes) {
const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape{-1, -1, -1});
const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape{-1, -1, -1});
const auto value = std::make_shared<opset13::Parameter>(element::f32, PartialShape{-1, -1, -1});
const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, PartialShape{-1, -1, -1});
const auto scale = std::make_shared<opset13::Parameter>(element::f32, PartialShape{-1});
auto causal = false;
op = make_op(query, key, value, attention_mask, scale, causal);
input_shapes = ShapeVector{{2, 3, 4}, {2, 5, 4}, {2, 5, 6}, {1, 3, 5}, {}};
output_shapes = shape_inference(op.get(), input_shapes);
EXPECT_EQ(output_shapes.size(), 1);
EXPECT_EQ(output_shapes.front(), StaticShape({2, 3, 6}));
}
TEST_F(ScaledDotProductAttentionV13StaticShapeInferenceTest, static_shapes) {
const auto query = std::make_shared<opset13::Parameter>(element::f32, Shape{2, 3, 4});
const auto key = std::make_shared<opset13::Parameter>(element::f32, Shape{2, 5, 4});
const auto value = std::make_shared<opset13::Parameter>(element::f32, Shape{2, 5, 6});
const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, Shape{1, 3, 5});
const auto scale = std::make_shared<opset13::Parameter>(element::f32, Shape{1});
auto causal = false;
op = make_op(query, key, value, attention_mask, scale, causal);
input_shapes = ShapeVector{{2, 3, 4}, {2, 5, 4}, {2, 5, 6}, {1, 3, 5}, {1}};
output_shapes = shape_inference(op.get(), input_shapes);
EXPECT_EQ(output_shapes.size(), 1);
EXPECT_EQ(output_shapes.front(), StaticShape({2, 3, 6}));
}
TEST_F(ScaledDotProductAttentionV13StaticShapeInferenceTest, mixed_shapes) {
const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, {2, 3}, 4});
const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape{-1, {2, 7}, -1});
const auto value = std::make_shared<opset13::Parameter>(element::f32, Shape{2, 5, 6});
const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, PartialShape{1, {3, 5}, 5});
const auto scale = std::make_shared<opset13::Parameter>(element::f32, Shape{});
auto causal = false;
op = make_op(query, key, value, attention_mask, scale, causal);
input_shapes = ShapeVector{{2, 3, 4}, {2, 5, 4}, {2, 5, 6}, {1, 3, 5}, {}};
output_shapes = shape_inference(op.get(), input_shapes);
EXPECT_EQ(output_shapes.size(), 1);
EXPECT_EQ(output_shapes.front(), StaticShape({2, 3, 6}));
}