[Opset13] ScaledDotProductAttention-13 input validation improvements (#21281)
* Improve test naming to match op * Initial draft for scaleddot shape infer * Fix formatting * Improve shape & type infer * Add shape tests * Fix issues in shape infer * Fix type infer * Improve type_prop tests * Ignore attention when causal * Fix template type shape infer * Fix issues with scalar inputs + test improvement * Fix `get_node_input_partial_shapes` * Improve shape/type validation and tests * Allow for broadcastable inputs * Improve tests * Add CPU shape infer test * Add broadcast shape infer test * Use const * Use const & * Improve query input handling * Fix test issues * Change broadcast rules --------- Co-authored-by: Michal Lukaszewski <michal.lukaszewski@intel.com>
Commit 610e0fab5c (parent 3a2958b360), committed via GitHub.
@@ -0,0 +1,114 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
#pragma once
|
||||
|
||||
#include "openvino/op/scaled_dot_product_attention.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace op {
|
||||
namespace v13 {
|
||||
template <class T, class TRShape = result_shape_t<T>>
|
||||
std::vector<TRShape> shape_infer(const ScaledDotProductAttention* op,
|
||||
const std::vector<T>& input_shapes,
|
||||
const ITensorAccessor& ta = make_tensor_accessor()) {
|
||||
const bool& iscausal = op->get_causal();
|
||||
using DimType = typename T::value_type;
|
||||
const auto& inputs_count = input_shapes.size();
|
||||
const auto& has_attention_mask_input = inputs_count >= 4;
|
||||
const auto& has_scale_input = inputs_count == 5;
|
||||
NODE_VALIDATION_CHECK(op, inputs_count == 3 || has_attention_mask_input || has_scale_input);
|
||||
DimType e_dim{};
|
||||
DimType l_dim{};
|
||||
DimType s_dim{};
|
||||
DimType ev_dim{};
|
||||
|
||||
TRShape n_dims = input_shapes[0];
|
||||
const auto& n_dims_rank = n_dims.rank();
|
||||
if (n_dims_rank.is_static()) {
|
||||
NODE_SHAPE_INFER_CHECK(op,
|
||||
input_shapes,
|
||||
n_dims_rank.get_length() >= 3,
|
||||
"Query input rank length must be at least 3 or more.");
|
||||
l_dim = *(n_dims.end() - 2);
|
||||
e_dim = *(n_dims.end() - 1);
|
||||
n_dims.resize(n_dims.size() - 2);
|
||||
}
|
||||
|
||||
const auto& key = input_shapes[1];
|
||||
const auto& key_rank = key.rank();
|
||||
if (key_rank.is_static()) {
|
||||
const bool& key_input_correctness =
|
||||
key_rank.get_length() >= 3 &&
|
||||
TRShape::broadcast_merge_into(n_dims,
|
||||
TRShape(std::vector<DimType>(key.begin(), key.end() - 2)),
|
||||
AutoBroadcastType::NUMPY) &&
|
||||
DimType::merge(e_dim, e_dim, *(key.end() - 1));
|
||||
NODE_SHAPE_INFER_CHECK(op,
|
||||
input_shapes,
|
||||
key_input_correctness,
|
||||
"Key input shape not compatible with other inputs.");
|
||||
s_dim = *(key.end() - 2);
|
||||
}
|
||||
|
||||
const auto& value = input_shapes[2];
|
||||
const auto& value_rank = value.rank();
|
||||
if (value_rank.is_static()) {
|
||||
const bool& value_input_correctness =
|
||||
value_rank.get_length() >= 3 &&
|
||||
TRShape::broadcast_merge_into(n_dims,
|
||||
TRShape(std::vector<DimType>(value.begin(), value.end() - 2)),
|
||||
AutoBroadcastType::NUMPY) &&
|
||||
DimType::merge(s_dim, s_dim, *(value.end() - 2));
|
||||
NODE_SHAPE_INFER_CHECK(op,
|
||||
input_shapes,
|
||||
value_input_correctness,
|
||||
"Value input shape not compatible with other inputs.");
|
||||
ev_dim = *(value.end() - 1);
|
||||
}
|
||||
|
||||
if (has_attention_mask_input && !iscausal) {
|
||||
const auto& attention_mask = input_shapes[3];
|
||||
const auto& attention_mask_rank = attention_mask.rank();
|
||||
if (attention_mask_rank.is_static() && attention_mask_rank != 0) {
|
||||
const auto& attention_mask_rank_len = attention_mask_rank.get_length();
|
||||
bool attention_mask_input_correctness = attention_mask_rank_len >= 2 &&
|
||||
DimType::merge(l_dim, l_dim, *(attention_mask.end() - 2)) &&
|
||||
DimType::merge(s_dim, s_dim, *(attention_mask.end() - 1));
|
||||
if (attention_mask_rank_len >= 3) {
|
||||
attention_mask_input_correctness =
|
||||
attention_mask_input_correctness &&
|
||||
TRShape::broadcast_merge_into(
|
||||
n_dims,
|
||||
TRShape(std::vector<DimType>(attention_mask.begin(), attention_mask.end() - 2)),
|
||||
AutoBroadcastType::NUMPY);
|
||||
}
|
||||
NODE_SHAPE_INFER_CHECK(op,
|
||||
input_shapes,
|
||||
attention_mask_input_correctness,
|
||||
"Attention mask input shape not compatible with other inputs.");
|
||||
}
|
||||
}
|
||||
|
||||
if (has_scale_input) {
|
||||
const auto& scale_rank = input_shapes[4].rank();
|
||||
if (scale_rank.is_static() || input_shapes[4].is_static()) {
|
||||
const auto& scale_is_scalar = scale_rank.compatible(0);
|
||||
const auto& scale_has_one_elem = scale_rank.compatible(1) && input_shapes[4][0].compatible(1);
|
||||
NODE_SHAPE_INFER_CHECK(op,
|
||||
input_shapes,
|
||||
scale_is_scalar || scale_has_one_elem,
|
||||
"Scale input must be scalar or have 1 element.");
|
||||
}
|
||||
}
|
||||
|
||||
if (n_dims.rank().is_static()) {
|
||||
n_dims.push_back(l_dim);
|
||||
n_dims.push_back(ev_dim);
|
||||
}
|
||||
return {n_dims};
|
||||
}
|
||||
} // namespace v13
|
||||
} // namespace op
|
||||
} // namespace ov
|
||||
@@ -5,6 +5,7 @@
|
||||
#include "openvino/op/scaled_dot_product_attention.hpp"
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "scaled_dot_product_attention_shape_inference.hpp"
|
||||
|
||||
using namespace std;
|
||||
namespace ov {
|
||||
@@ -38,25 +39,33 @@ op::v13::ScaledDotProductAttention::ScaledDotProductAttention(const Output<Node>
|
||||
|
||||
void op::v13::ScaledDotProductAttention::validate_and_infer_types() {
|
||||
OV_OP_SCOPE(v13_ScaledDotProductAttention_validate_and_infer_types);
|
||||
NODE_VALIDATION_CHECK(this, get_input_size() >= 3 && get_input_size() <= 5);
|
||||
// TODO: More checks and accurate deduction of dimensions in case when various
|
||||
// dynamic combinations appear.
|
||||
auto query = get_input_partial_shape(0);
|
||||
auto key = get_input_partial_shape(1);
|
||||
auto value = get_input_partial_shape(2);
|
||||
|
||||
// using particular dimensions from query and value, to do that need to have them statically ranked
|
||||
if (query.rank().is_dynamic() || value.rank().is_dynamic()) {
|
||||
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
|
||||
return;
|
||||
auto out_type = get_input_element_type(0);
|
||||
const auto& input_size = get_input_size();
|
||||
const auto& causal = get_causal();
|
||||
if (input_size >= 4 && !causal) {
|
||||
const auto& attention_type = get_input_element_type(3);
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
attention_type.is_real() || attention_type == element::boolean || attention_type.is_dynamic(),
|
||||
"The element type of attention_mask must be either floating-point or boolean.");
|
||||
}
|
||||
for (size_t i = 1; i < input_size; i++) {
|
||||
const auto& element_type = get_input_element_type(i);
|
||||
if (i == 3 && (element_type == element::boolean || causal)) {
|
||||
// Skip checking attention_mask in loop when boolean or skipped to not affect merged dtype.
|
||||
continue;
|
||||
}
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
element::Type::merge(out_type, out_type, element_type),
|
||||
"Mixed input types are not supported.");
|
||||
}
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
out_type.is_real() || out_type.is_dynamic(),
|
||||
"The element type of the input tensor must be a floating-point.");
|
||||
|
||||
OPENVINO_ASSERT(query.rank().get_length() >= 3);
|
||||
OPENVINO_ASSERT(value.rank().get_length() >= 3);
|
||||
|
||||
auto dimensions = std::vector<Dimension>(query.begin(), query.end() - 1);
|
||||
dimensions.push_back(*(value.end() - 1));
|
||||
set_output_type(0, get_input_element_type(0), PartialShape(dimensions));
|
||||
const auto& input_shapes = ov::util::get_node_input_partial_shapes(*this);
|
||||
const auto& output_shapes = shape_infer(this, input_shapes);
|
||||
set_output_type(0, out_type, output_shapes[0]);
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> op::v13::ScaledDotProductAttention::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
|
||||
@@ -12,156 +12,382 @@
|
||||
using namespace ov;
|
||||
using namespace testing;
|
||||
|
||||
// Static-shape tests for 3/4/5-input variants and for a scalar attention mask.
TEST(type_prop, scaled_dot_product_attention_static_5_inputs) {
    const auto query = std::make_shared<opset13::Parameter>(element::f32, Shape{2, 3, 4});
    const auto key = std::make_shared<opset13::Parameter>(element::f32, Shape{2, 5, 4});
    const auto value = std::make_shared<opset13::Parameter>(element::f32, Shape{2, 5, 6});
    const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, Shape{1, 3, 5});
    const auto scale = std::make_shared<opset13::Parameter>(element::f32, Shape{1});
    auto causal = false;

    const auto op =
        std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, scale, causal);
    EXPECT_EQ(op->get_output_element_type(0), element::f32);
    EXPECT_EQ(op->get_output_partial_shape(0), (Shape{2, 3, 6}));
}

TEST(type_prop, scaled_dot_product_attention_static_4_inputs) {
    const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 3, 4});
    const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 4});
    const auto value = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 6});
    const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, PartialShape{1, 3, 5});
    auto causal = false;

    const auto op = std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, causal);
    EXPECT_EQ(op->get_output_element_type(0), element::f32);
    EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 3, 6}));
}

TEST(type_prop, scaled_dot_product_attention_static_3_inputs) {
    const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 3, 4});
    const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 4});
    const auto value = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 6});
    auto causal = false;

    const auto op = std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, causal);
    EXPECT_EQ(op->get_output_element_type(0), element::f32);
    EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 3, 6}));
}

TEST(type_prop, scaled_dot_product_attention_static_3_inputs_causal) {
    const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 3, 4});
    const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 4});
    const auto value = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 6});
    auto causal = true;

    const auto op = std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, causal);
    EXPECT_EQ(op->get_output_element_type(0), element::f32);
    EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 3, 6}));
}

// A rank-0 (scalar) attention mask is accepted and does not affect the output shape.
TEST(type_prop, scaled_dot_product_scalar_attention_causal_false) {
    const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 3, 4});
    const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 4});
    const auto value = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 6});
    const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, PartialShape{});
    const auto scale = std::make_shared<opset13::Parameter>(element::f32, PartialShape{});
    auto causal = false;

    const auto op =
        std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, scale, causal);
    EXPECT_EQ(op->get_output_element_type(0), element::f32);
    EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 3, 6}));
}
|
||||
|
||||
// Tests for mask-ignoring causal mode and for broadcastable extra batch dims.
TEST(type_prop, scaled_dot_product_attention_static_ignored_attention_mask) {
    const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 3, 4});
    const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 4});
    const auto value = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 6});
    // Mask shape is incompatible on purpose; it must be ignored when causal == true.
    const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, PartialShape{7, 8, 9, 10, 11});
    const auto scale = std::make_shared<opset13::Parameter>(element::f32, PartialShape{});
    auto causal = true;

    const auto op =
        std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, scale, causal);
    EXPECT_EQ(op->get_output_element_type(0), element::f32);
    EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 3, 6}));
}

TEST(type_prop, scaled_dot_product_attention_static_5_inputs_extra_batch) {
    const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 7, 3, 4});
    const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 7, 5, 4});
    const auto value = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 7, 5, 6});
    const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, PartialShape{1, 1, 3, 5});
    const auto scale = std::make_shared<opset13::Parameter>(element::f32, PartialShape{});
    auto causal = false;

    const auto op =
        std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, scale, causal);
    EXPECT_EQ(op->get_output_element_type(0), element::f32);
    EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 7, 3, 6}));
}

// Batch dims of query/key/value only need to be NUMPY-broadcastable, not equal.
TEST(type_prop, scaled_dot_product_attention_static_4_inputs_extra_batch) {
    const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 7, 3, 4});
    const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape{7, 5, 4});
    const auto value = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 1, 5, 6});
    const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, PartialShape{1, 1, 3, 5});
    auto causal = false;

    const auto op = std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, causal);
    EXPECT_EQ(op->get_output_element_type(0), element::f32);
    EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 7, 3, 6}));
}

TEST(type_prop, scaled_dot_product_attention_static_3_inputs_extra_batch) {
    const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape{7, 3, 4});
    const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 7, 5, 4});
    const auto value = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 7, 5, 6});
    auto causal = false;

    const auto op = std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, causal);
    EXPECT_EQ(op->get_output_element_type(0), element::f32);
    EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 7, 3, 6}));
}

TEST(type_prop, scaled_dot_product_attention_static_3_inputs_extra_batch_causal_true) {
    const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 7, 3, 4});
    const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 7, 5, 4});
    const auto value = std::make_shared<opset13::Parameter>(element::f32, PartialShape{7, 5, 6});
    auto causal = true;

    const auto op = std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, causal);
    EXPECT_EQ(op->get_output_element_type(0), element::f32);
    EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 7, 3, 6}));
}

TEST(type_prop, scaled_dot_product_attention_static_ignored_attention_mask_extra_batch) {
    const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 7, 3, 4});
    const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 1, 5, 4});
    const auto value = std::make_shared<opset13::Parameter>(element::f32, PartialShape{1, 7, 5, 6});
    // Mask shape is incompatible on purpose; it must be ignored when causal == true.
    const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, PartialShape{7, 8, 9, 10, 11});
    const auto scale = std::make_shared<opset13::Parameter>(element::f32, PartialShape{});
    auto causal = true;

    const auto op =
        std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, scale, causal);
    EXPECT_EQ(op->get_output_element_type(0), element::f32);
    EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 7, 3, 6}));
}
|
||||
|
||||
// Fully dynamic inputs must propagate a fully dynamic output shape.
TEST(type_prop, scaled_dot_product_attention_5_inputs_dynamic_rank) {
    const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape::dynamic());
    const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape::dynamic());
    const auto value = std::make_shared<opset13::Parameter>(element::f32, PartialShape::dynamic());
    const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, PartialShape::dynamic());
    const auto scale = std::make_shared<opset13::Parameter>(element::f32, PartialShape::dynamic());
    auto causal = false;

    const auto op =
        std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, scale, causal);
    EXPECT_EQ(op->get_output_element_type(0), element::f32);
    EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape::dynamic()));
}

TEST(type_prop, scaled_dot_product_attention_dynamic_3d) {
    const auto dynamic = PartialShape{-1, -1, -1};
    const auto query = std::make_shared<opset13::Parameter>(element::f32, dynamic);
    const auto key = std::make_shared<opset13::Parameter>(element::f32, dynamic);
    const auto value = std::make_shared<opset13::Parameter>(element::f32, dynamic);
    const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, dynamic);
    const auto scale = std::make_shared<opset13::Parameter>(element::f32, PartialShape{});
    auto causal = false;

    const auto op =
        std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, scale, causal);
    EXPECT_EQ(op->get_output_element_type(0), element::f32);
    EXPECT_EQ(op->get_output_partial_shape(0), (dynamic));
}

TEST(type_prop, scaled_dot_product_attention_dynamic_4d) {
    const auto dynamic = PartialShape{-1, -1, -1, -1};
    const auto query = std::make_shared<opset13::Parameter>(element::f32, dynamic);
    const auto key = std::make_shared<opset13::Parameter>(element::f32, dynamic);
    const auto value = std::make_shared<opset13::Parameter>(element::f32, dynamic);
    const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, dynamic);
    const auto scale = std::make_shared<opset13::Parameter>(element::f32, PartialShape{});
    auto causal = false;

    const auto op =
        std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, scale, causal);
    EXPECT_EQ(op->get_output_element_type(0), element::f32);
    EXPECT_EQ(op->get_output_partial_shape(0), (dynamic));
}
|
||||
|
||||
TEST(type_prop, scaled_dot_product_attention_mixed_shape_infer_5_inputs) {
|
||||
PartialShape query_shape{{2, 4}, 3, {2, 5}, 4};
|
||||
set_shape_labels(query_shape, 10);
|
||||
PartialShape key_shape{{4, 8}, {2, 4}, 5, 4};
|
||||
set_shape_labels(key_shape, 20);
|
||||
PartialShape value_shape{{2, 4}, 3, 5, {3, 7}};
|
||||
set_shape_labels(value_shape, 40);
|
||||
PartialShape attention_mask_shape{{2, 7}, 3, {4, 7}, 5};
|
||||
set_shape_labels(attention_mask_shape, 50);
|
||||
const auto query = std::make_shared<opset13::Parameter>(element::dynamic, query_shape);
|
||||
const auto key = std::make_shared<opset13::Parameter>(element::f64, key_shape);
|
||||
const auto value = std::make_shared<opset13::Parameter>(element::dynamic, value_shape);
|
||||
const auto attention_mask = std::make_shared<opset13::Parameter>(element::f64, attention_mask_shape);
|
||||
const auto scale = std::make_shared<opset13::Parameter>(element::f64, PartialShape{-1});
|
||||
auto causal = false;
|
||||
|
||||
const auto op =
|
||||
std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, scale, causal);
|
||||
EXPECT_EQ(op->get_output_element_type(0), element::f64);
|
||||
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{4, 3, {4, 5}, {3, 7}}));
|
||||
EXPECT_THAT(get_shape_labels(op->get_output_partial_shape(0)), testing::ElementsAre(50, 51, 52, 43));
|
||||
}
|
||||
|
||||
TEST(type_prop, scaled_dot_product_attention_mixed_shape_infer_5_inputs_ignore_attention) {
|
||||
const auto query = std::make_shared<opset13::Parameter>(element::dynamic, PartialShape{{1, 4}, 3, {1, 5}, 4});
|
||||
const auto key = std::make_shared<opset13::Parameter>(element::f64, PartialShape{{4, 8}, {1, 4}, 5, 4});
|
||||
const auto value = std::make_shared<opset13::Parameter>(element::dynamic, PartialShape{{2, 4}, 3, 5, {3, 7}});
|
||||
const auto attention_mask = std::make_shared<opset13::Parameter>(element::i64, PartialShape{57, 3, {4, 7}, 5});
|
||||
const auto scale = std::make_shared<opset13::Parameter>(element::f64, PartialShape{});
|
||||
auto causal = true;
|
||||
|
||||
const auto op =
|
||||
std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, scale, causal);
|
||||
EXPECT_EQ(op->get_output_element_type(0), element::f64);
|
||||
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{4, 3, {1, 5}, {3, 7}}));
|
||||
}
|
||||
|
||||
TEST(type_prop, scaled_dot_product_attention_infer_5_dynamic_attn_partial) {
|
||||
const auto dynamic = PartialShape::dynamic();
|
||||
const auto query = std::make_shared<opset13::Parameter>(element::f32, dynamic);
|
||||
const auto key = std::make_shared<opset13::Parameter>(element::f32, dynamic);
|
||||
const auto value = std::make_shared<opset13::Parameter>(element::f32, dynamic);
|
||||
const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, PartialShape{1, -1, 5, 7});
|
||||
const auto scale = std::make_shared<opset13::Parameter>(element::f32, PartialShape{});
|
||||
auto causal = false;
|
||||
|
||||
const auto op =
|
||||
std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, scale, causal);
|
||||
EXPECT_EQ(op->get_output_element_type(0), element::f32);
|
||||
EXPECT_EQ(op->get_output_partial_shape(0), (dynamic));
|
||||
}
|
||||
|
||||
TEST(type_prop, scaled_dot_product_attention_mixed_shape_infer_4_inputs) {
|
||||
const auto query = std::make_shared<opset13::Parameter>(element::dynamic, PartialShape{{1, 4}, 4, {2, 5}, 4});
|
||||
const auto key = std::make_shared<opset13::Parameter>(element::f64, PartialShape{{2, 8}, {1, 4}, 5, 4});
|
||||
const auto value = std::make_shared<opset13::Parameter>(element::dynamic, PartialShape{{2, 4}, 4, 5, {3, 7}});
|
||||
const auto attention_mask = std::make_shared<opset13::Parameter>(element::f64, PartialShape{4, {4, 7}, 5});
|
||||
auto causal = false;
|
||||
|
||||
const auto op = std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, causal);
|
||||
EXPECT_EQ(op->get_output_element_type(0), element::f64);
|
||||
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{{2, 4}, 4, {4, 5}, {3, 7}}));
|
||||
}
|
||||
|
||||
TEST(type_prop, scaled_dot_product_attention_type_infer_5_inputs) {
|
||||
const auto query = std::make_shared<opset13::Parameter>(element::dynamic, PartialShape{2, 3, 4});
|
||||
const auto key = std::make_shared<opset13::Parameter>(element::dynamic, PartialShape{2, 5, 4});
|
||||
const auto value = std::make_shared<opset13::Parameter>(element::dynamic, PartialShape{2, 5, 6});
|
||||
const auto attention_mask = std::make_shared<opset13::Parameter>(element::dynamic, PartialShape{1, 3, 5});
|
||||
const auto scale = std::make_shared<opset13::Parameter>(element::f32, PartialShape{});
|
||||
auto causal = false;
|
||||
|
||||
const auto op =
|
||||
std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, scale, causal);
|
||||
EXPECT_EQ(op->get_output_element_type(0), element::f32);
|
||||
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 3, 6}));
|
||||
}
|
||||
|
||||
TEST(type_prop, scaled_dot_product_attention_type_infer_4_inputs) {
    // 4-input variant (no scale): f64 comes from key and attention_mask while
    // query/value are dynamic; the output type is expected to be f64.
    const auto query = std::make_shared<opset13::Parameter>(element::dynamic, PartialShape{2, 3, 4});
    const auto key = std::make_shared<opset13::Parameter>(element::f64, PartialShape{2, 5, 4});
    const auto value = std::make_shared<opset13::Parameter>(element::dynamic, PartialShape{2, 5, 6});
    const auto attention_mask = std::make_shared<opset13::Parameter>(element::f64, PartialShape{3, 5});
    constexpr bool causal = false;

    const auto op = std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, causal);

    EXPECT_EQ(op->get_output_element_type(0), element::f64);
    EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 3, 6}));
}
TEST(type_prop, scaled_dot_product_attention_type_infer_4_inputs_bool_attention) {
    // A boolean attention_mask is accepted and does not participate in output
    // type inference; the output type is expected to follow the f64 key.
    const auto query = std::make_shared<opset13::Parameter>(element::dynamic, PartialShape{2, 3, 4});
    const auto key = std::make_shared<opset13::Parameter>(element::f64, PartialShape{2, 5, 4});
    const auto value = std::make_shared<opset13::Parameter>(element::dynamic, PartialShape{2, 5, 6});
    const auto attention_mask = std::make_shared<opset13::Parameter>(element::boolean, PartialShape{1, 3, 5});
    constexpr bool causal = false;

    const auto op = std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, causal);

    EXPECT_EQ(op->get_output_element_type(0), element::f64);
    EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 3, 6}));
}
TEST(type_prop, scaled_dot_product_unsupported_key_shape) {
    // Key batch dim (3) conflicts with query/value batch dim (2): construction
    // must fail with a key-specific validation message.
    const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 3, 4});
    const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape{3, 5, 4});
    const auto value = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 6});
    const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, PartialShape{3, 3, 3, 5});
    constexpr bool causal = false;

    OV_EXPECT_THROW(
        auto op = std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, causal),
        AssertFailure,
        testing::HasSubstr("Key input shape not compatible with other inputs."));
}
TEST(type_prop, scaled_dot_product_unsupported_value_shape) {
    // Value batch dim (3) conflicts with query/key batch dim (2): construction
    // must fail with a value-specific validation message.
    const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 3, 4});
    const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 4});
    const auto value = std::make_shared<opset13::Parameter>(element::f32, PartialShape{3, 5, 6});
    const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, PartialShape{3, 3, 3, 5});
    constexpr bool causal = false;

    OV_EXPECT_THROW(
        auto op = std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, causal),
        AssertFailure,
        testing::HasSubstr("Value input shape not compatible with other inputs."));
}
TEST(type_prop, scaled_dot_product_unsupported_attention_shape) {
    // Query/key/value agree, but the rank-4 attention_mask {3, 3, 3, 5} does
    // not fit them: construction must fail with a mask-specific message.
    const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 3, 4});
    const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 4});
    const auto value = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 6});
    const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, PartialShape{3, 3, 3, 5});
    constexpr bool causal = false;

    OV_EXPECT_THROW(
        auto op = std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, causal),
        AssertFailure,
        testing::HasSubstr("Attention mask input shape not compatible with other inputs."));
}
TEST(type_prop, scaled_dot_product_unsupported_scale_shape) {
    // Scale must be a scalar or a single-element tensor; a 2-element scale is
    // expected to be rejected.
    const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 3, 4});
    const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 4});
    const auto value = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 6});
    const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, PartialShape{3, 5});
    const auto scale = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2});
    constexpr bool causal = false;

    OV_EXPECT_THROW(
        auto op =
            std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, scale, causal),
        AssertFailure,
        testing::HasSubstr("Scale input must be scalar or have 1 element."));
}
TEST(type_prop, scaled_dot_product_unsupported_dtype) {
    // Integer (i32) query/key/value/scale types are expected to be rejected:
    // the op requires floating-point data inputs.
    const auto query = std::make_shared<opset13::Parameter>(element::i32, PartialShape{2, 3, 4});
    const auto key = std::make_shared<opset13::Parameter>(element::i32, PartialShape{2, 5, 4});
    const auto value = std::make_shared<opset13::Parameter>(element::i32, PartialShape{2, 5, 6});
    const auto attention_mask = std::make_shared<opset13::Parameter>(element::boolean, PartialShape{3, 3, 3, 5});
    const auto scale = std::make_shared<opset13::Parameter>(element::i32, PartialShape{});
    constexpr bool causal = false;

    OV_EXPECT_THROW(
        auto op =
            std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, scale, causal),
        AssertFailure,
        testing::HasSubstr("The element type of the input tensor must be a floating-point."));
}
TEST(type_prop, scaled_dot_product_unsupported_value_dtype_mixed) {
    // An f64 attention_mask combined with f32 query/key/value/scale is
    // expected to be rejected as a mixed-type configuration.
    const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 3, 4});
    const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 4});
    const auto value = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 6});
    const auto attention_mask = std::make_shared<opset13::Parameter>(element::f64, PartialShape{3, 3, 3, 5});
    const auto scale = std::make_shared<opset13::Parameter>(element::f32, PartialShape{});
    constexpr bool causal = false;

    OV_EXPECT_THROW(
        auto op =
            std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, scale, causal),
        AssertFailure,
        testing::HasSubstr("Mixed input types are not supported."));
}
// NOTE(review): test name fixed from "unsuported_attention_type" (typo) to
// match the naming of the other "unsupported_*" negative tests in this file.
TEST(type_prop, scaled_dot_product_unsupported_attention_type) {
    // attention_mask must be floating-point or boolean; an i32 mask is
    // expected to be rejected with a mask-type-specific message.
    const auto query = std::make_shared<opset13::Parameter>(element::dynamic, PartialShape{2, 3, 4});
    const auto key = std::make_shared<opset13::Parameter>(element::f64, PartialShape{2, 5, 4});
    const auto value = std::make_shared<opset13::Parameter>(element::dynamic, PartialShape{2, 5, 6});
    const auto attention_mask = std::make_shared<opset13::Parameter>(element::i32, PartialShape{1, 3, 5});
    constexpr bool causal = false;

    OV_EXPECT_THROW(
        auto op = std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, causal),
        AssertFailure,
        testing::HasSubstr("The element type of attention_mask must be either floating-point or boolean."));
}
@@ -94,6 +94,7 @@
|
||||
#include "roi_align_shape_inference.hpp"
|
||||
#include "roi_pooling_shape_inference.hpp"
|
||||
#include "roll_shape_inference.hpp"
|
||||
#include "scaled_dot_product_attention_shape_inference.hpp"
|
||||
#include "scatter_elements_update_shape_inference.hpp"
|
||||
#include "scatter_nd_base_shape_inference.hpp"
|
||||
#include "select_shape_inference.hpp"
|
||||
@@ -398,6 +399,7 @@ template <>
|
||||
const IStaticShapeInferFactory::TRegistry IStaticShapeInferFactory::registry{
|
||||
// opset13
|
||||
_OV_OP_SHAPE_INFER_MASK_REG(opset13::Multinomial, ShapeInferTA, util::bit::mask(1)),
|
||||
_OV_OP_SHAPE_INFER_MASK_REG(opset13::ScaledDotProductAttention, ShapeInferTA, util::bit::mask(3, 5)),
|
||||
// opset12
|
||||
_OV_OP_SHAPE_INFER_MASK_REG(opset12::Pad, ShapeInferTA, util::bit::mask(1, 2)),
|
||||
_OV_OP_SHAPE_INFER_MASK_REG(opset12::ScatterElementsUpdate, ShapeInferTA, util::bit::mask(3)),
|
||||
|
||||
@@ -23,9 +23,9 @@ const std::vector<std::vector<InputShape>> shapes{
|
||||
{ov::test::InputShape{ov::PartialShape{-1, 8, -1, 64},
|
||||
{ov::Shape{1, 8, 100, 64}, ov::Shape{1, 8, 1, 64}, ov::Shape{2, 8, 10, 64}}}
|
||||
},
|
||||
// attn shape: [B, 1, 1, L0+L1]
|
||||
{ov::test::InputShape{ov::PartialShape{-1, 1, 1, -1},
|
||||
{ov::Shape{1, 1, 1, 100}, ov::Shape{1, 1, 1, 1}, ov::Shape{2, 1, 1, 10}}}
|
||||
// attn shape: [B, 1, -1, L0+L1]
|
||||
{ov::test::InputShape{ov::PartialShape{-1, 1, -1, -1},
|
||||
{ov::Shape{1, 1, 100, 100}, ov::Shape{1, 1, 1, 1}, ov::Shape{2, 1, 10, 10}}}
|
||||
},
|
||||
},
|
||||
// heads number of kv is 1, attn mask: [B, H, L1, L0+L1]
|
||||
@@ -75,4 +75,4 @@ INSTANTIATE_TEST_SUITE_P(smoke_ScaledAttn_CPU,
|
||||
|
||||
} // namespace ScaledAttn
|
||||
} // namespace test
|
||||
} // namespace ov
|
||||
} // namespace ov
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
|
||||
#include "common_test_utils/test_assertions.hpp"
|
||||
#include "openvino/opsets/opset13.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::intel_cpu;
|
||||
using namespace testing;
|
||||
|
||||
// Fixture wiring the common static-shape-inference test harness to
// ScaledDotProductAttention-13.
class ScaledDotProductAttentionV13StaticShapeInferenceTest
    : public OpStaticShapeInferenceTest<op::v13::ScaledDotProductAttention> {};
TEST_F(ScaledDotProductAttentionV13StaticShapeInferenceTest, default_ctor) {
    // Default-constructed op: all information comes from the input shapes.
    // Query is rank 4 while key/value are rank 3, exercising batch broadcast.
    op = make_op();

    input_shapes = ShapeVector{{3, 2, 3, 4}, {2, 5, 4}, {1, 5, 6}, {1, 3, 5}, {}};
    output_shapes = shape_inference(op.get(), input_shapes);

    EXPECT_EQ(output_shapes.size(), 1);
    EXPECT_EQ(output_shapes.front(), StaticShape({3, 2, 3, 6}));
}
TEST_F(ScaledDotProductAttentionV13StaticShapeInferenceTest, dynamic_shapes) {
    // Op is built over fully dynamic parameters; the concrete static shapes
    // are provided only at shape_inference time.
    const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape{-1, -1, -1});
    const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape{-1, -1, -1});
    const auto value = std::make_shared<opset13::Parameter>(element::f32, PartialShape{-1, -1, -1});
    const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, PartialShape{-1, -1, -1});
    const auto scale = std::make_shared<opset13::Parameter>(element::f32, PartialShape{-1});
    constexpr bool causal = false;
    op = make_op(query, key, value, attention_mask, scale, causal);

    input_shapes = ShapeVector{{2, 3, 4}, {2, 5, 4}, {2, 5, 6}, {1, 3, 5}, {}};
    output_shapes = shape_inference(op.get(), input_shapes);

    EXPECT_EQ(output_shapes.size(), 1);
    EXPECT_EQ(output_shapes.front(), StaticShape({2, 3, 6}));
}
TEST_F(ScaledDotProductAttentionV13StaticShapeInferenceTest, static_shapes) {
    // Op is built over fully static parameters; shape_inference with matching
    // static shapes (single-element scale) must reproduce the expected output.
    const auto query = std::make_shared<opset13::Parameter>(element::f32, Shape{2, 3, 4});
    const auto key = std::make_shared<opset13::Parameter>(element::f32, Shape{2, 5, 4});
    const auto value = std::make_shared<opset13::Parameter>(element::f32, Shape{2, 5, 6});
    const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, Shape{1, 3, 5});
    const auto scale = std::make_shared<opset13::Parameter>(element::f32, Shape{1});
    constexpr bool causal = false;
    op = make_op(query, key, value, attention_mask, scale, causal);

    input_shapes = ShapeVector{{2, 3, 4}, {2, 5, 4}, {2, 5, 6}, {1, 3, 5}, {1}};
    output_shapes = shape_inference(op.get(), input_shapes);

    EXPECT_EQ(output_shapes.size(), 1);
    EXPECT_EQ(output_shapes.front(), StaticShape({2, 3, 6}));
}
TEST_F(ScaledDotProductAttentionV13StaticShapeInferenceTest, mixed_shapes) {
    // Op is built over a mix of interval, dynamic, and static parameter
    // shapes; static shape inference must still resolve the output.
    const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, {2, 3}, 4});
    const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape{-1, {2, 7}, -1});
    const auto value = std::make_shared<opset13::Parameter>(element::f32, Shape{2, 5, 6});
    const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, PartialShape{1, {3, 5}, 5});
    const auto scale = std::make_shared<opset13::Parameter>(element::f32, Shape{});
    constexpr bool causal = false;
    op = make_op(query, key, value, attention_mask, scale, causal);

    input_shapes = ShapeVector{{2, 3, 4}, {2, 5, 4}, {2, 5, 6}, {1, 3, 5}, {}};
    output_shapes = shape_inference(op.get(), input_shapes);

    EXPECT_EQ(output_shapes.size(), 1);
    EXPECT_EQ(output_shapes.front(), StaticShape({2, 3, 6}));
}
Reference in New Issue
Block a user