[ShapeInference] CTCGreedyDecoder shape infer improvements (#15474)
* Add setter for ctc_merge_repeated * shape infer improvements * Add type prop tests * Add cpu shape infer tests * Tests refactor
This commit is contained in:
parent
18905ada20
commit
d94dae79d8
@ -32,6 +32,10 @@ public:
|
||||
return m_ctc_merge_repeated;
|
||||
}
|
||||
|
||||
void set_ctc_merge_repeated(bool ctc_merge_repeated) {
|
||||
m_ctc_merge_repeated = ctc_merge_repeated;
|
||||
}
|
||||
|
||||
private:
|
||||
bool m_ctc_merge_repeated{true};
|
||||
};
|
||||
|
@ -8,26 +8,27 @@ namespace ov {
|
||||
namespace op {
|
||||
namespace v0 {
|
||||
|
||||
template <class T>
|
||||
void shape_infer(const CTCGreedyDecoder* op, const std::vector<T>& input_shapes, std::vector<T>& output_shapes) {
|
||||
NODE_VALIDATION_CHECK(op, input_shapes.size() == 2 && output_shapes.size() == 1);
|
||||
using DimType = typename std::iterator_traits<typename T::iterator>::value_type;
|
||||
// output dynamic rank tensor if all inputs are of dynamic rank
|
||||
template <class TShape>
|
||||
std::vector<TShape> shape_infer(const CTCGreedyDecoder* op, const std::vector<TShape>& input_shapes) {
|
||||
NODE_VALIDATION_CHECK(op, input_shapes.size() == 2);
|
||||
using DimType = typename TShape::value_type;
|
||||
|
||||
// Output shape rank is always static and equal to 4
|
||||
// The last two output shape dimensions are always static and equal to 1
|
||||
std::vector<DimType> output_dims(4);
|
||||
output_dims[2] = 1;
|
||||
output_dims[3] = 1;
|
||||
|
||||
const auto& logits_pshape = input_shapes[0];
|
||||
const auto& seq_mask_pshape = input_shapes[1];
|
||||
auto& output_shape = output_shapes[0];
|
||||
output_shape.resize(4);
|
||||
output_shape[2] = 1;
|
||||
output_shape[3] = 1;
|
||||
|
||||
if (logits_pshape.rank().is_dynamic() && seq_mask_pshape.rank().is_dynamic()) {
|
||||
return;
|
||||
return {TShape(std::move(output_dims))};
|
||||
}
|
||||
|
||||
// validate input shapes and compute output shape
|
||||
auto& batch_size = output_shape[0];
|
||||
auto& time_size = output_shape[1];
|
||||
// check ranks of input tensors
|
||||
auto& batch_size = output_dims[0];
|
||||
auto& time_size = output_dims[1];
|
||||
|
||||
if (logits_pshape.rank().is_static()) {
|
||||
NODE_VALIDATION_CHECK(op, logits_pshape.rank().compatible(3), "The rank of logits tensor must be equal to 3.");
|
||||
time_size = logits_pshape[0];
|
||||
@ -45,6 +46,14 @@ void shape_infer(const CTCGreedyDecoder* op, const std::vector<T>& input_shapes,
|
||||
DimType::merge(batch_size, batch_size, seq_mask_pshape[1]),
|
||||
"The second dimensions of input tensors must match.");
|
||||
}
|
||||
return {TShape(std::move(output_dims))};
|
||||
}
|
||||
|
||||
template <class TShape>
|
||||
void shape_infer(const CTCGreedyDecoder* op,
|
||||
const std::vector<TShape>& input_shapes,
|
||||
std::vector<TShape>& output_shapes) {
|
||||
output_shapes = shape_infer(op, input_shapes);
|
||||
}
|
||||
} // namespace v0
|
||||
} // namespace op
|
||||
|
@ -2,87 +2,165 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "dimension_tracker.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "util/type_prop.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
using namespace testing;
|
||||
|
||||
TEST(type_prop, ctc_greedy_decoder_default_constructor) {
|
||||
PartialShape data_shape{100, 3, 1200};
|
||||
PartialShape seq_mask_shape{100, 3};
|
||||
PartialShape expected_shape{3, 100, 1, 1};
|
||||
|
||||
auto op = make_shared<op::CTCGreedyDecoder>();
|
||||
|
||||
auto data = make_shared<op::Parameter>(element::f32, data_shape);
|
||||
auto seq_mask = make_shared<op::Parameter>(element::f32, seq_mask_shape);
|
||||
op->set_arguments(OutputVector{data, seq_mask});
|
||||
|
||||
op->set_ctc_merge_repeated(false);
|
||||
EXPECT_EQ(op->get_ctc_merge_repeated(), false);
|
||||
|
||||
op->set_ctc_merge_repeated(true);
|
||||
EXPECT_EQ(op->get_ctc_merge_repeated(), true);
|
||||
|
||||
op->validate_and_infer_types();
|
||||
|
||||
EXPECT_EQ(op->get_element_type(), element::f32);
|
||||
EXPECT_EQ(op->get_output_partial_shape(0), expected_shape);
|
||||
}
|
||||
|
||||
TEST(type_prop, ctc_greedy_decoder_static_shapes) {
|
||||
PartialShape logits_shape{100, 3, 1200};
|
||||
PartialShape seq_mask_shape{100, 3};
|
||||
Shape out_shape{3, 100, 1, 1};
|
||||
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
|
||||
auto I = make_shared<op::Parameter>(element::f32, seq_mask_shape);
|
||||
auto G = make_shared<op::CTCGreedyDecoder>(P, I, false);
|
||||
ASSERT_EQ(G->get_element_type(), element::f32);
|
||||
ASSERT_EQ(G->get_shape(), out_shape);
|
||||
auto data = make_shared<op::Parameter>(element::f32, logits_shape);
|
||||
auto seq_mask = make_shared<op::Parameter>(element::f32, seq_mask_shape);
|
||||
auto op = make_shared<op::CTCGreedyDecoder>(data, seq_mask, false);
|
||||
EXPECT_EQ(op->get_element_type(), element::f32);
|
||||
EXPECT_EQ(op->get_shape(), out_shape);
|
||||
}
|
||||
|
||||
TEST(type_prop, ctc_greedy_decoder_interval_labeled_dims_all) {
|
||||
PartialShape data_shape{{1, 100}, {2, 6}, {600, 1200}};
|
||||
PartialShape seq_mask_shape{{10, 1000}, {4, 8}};
|
||||
PartialShape expected_shape{{4, 6}, {10, 100}, 1, 1};
|
||||
|
||||
set_shape_labels(data_shape, 10);
|
||||
set_shape_labels(seq_mask_shape, 20);
|
||||
|
||||
auto data = make_shared<op::Parameter>(element::f32, data_shape);
|
||||
auto seq_mask = make_shared<op::Parameter>(element::f32, seq_mask_shape);
|
||||
auto op = make_shared<op::CTCGreedyDecoder>(data, seq_mask, false);
|
||||
|
||||
const auto& out_shape = op->get_output_partial_shape(0);
|
||||
EXPECT_EQ(op->get_element_type(), element::f32);
|
||||
EXPECT_EQ(out_shape, expected_shape);
|
||||
EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(21, 20, ov::no_label, ov::no_label));
|
||||
}
|
||||
|
||||
TEST(type_prop, ctc_greedy_decoder_interval_labeled_dims_data) {
|
||||
PartialShape data_shape{{1, 100}, {2, 6}, {600, 1200}};
|
||||
PartialShape seq_mask_shape{{10, 1000}, {4, 8}};
|
||||
PartialShape expected_shape{{4, 6}, {10, 100}, 1, 1};
|
||||
|
||||
set_shape_labels(data_shape, 10);
|
||||
|
||||
auto data = make_shared<op::Parameter>(element::f32, data_shape);
|
||||
auto seq_mask = make_shared<op::Parameter>(element::f32, seq_mask_shape);
|
||||
auto op = make_shared<op::CTCGreedyDecoder>(data, seq_mask, false);
|
||||
|
||||
const auto& out_shape = op->get_output_partial_shape(0);
|
||||
EXPECT_EQ(op->get_element_type(), element::f32);
|
||||
EXPECT_EQ(out_shape, expected_shape);
|
||||
EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(11, 10, ov::no_label, ov::no_label));
|
||||
}
|
||||
|
||||
TEST(type_prop, ctc_greedy_decoder_interval_labeled_dims_mask) {
|
||||
PartialShape data_shape{{1, 100}, {2, 6}, {600, 1200}};
|
||||
PartialShape seq_mask_shape{{10, 1000}, {4, 8}};
|
||||
PartialShape expected_shape{{4, 6}, {10, 100}, 1, 1};
|
||||
|
||||
set_shape_labels(seq_mask_shape, 20);
|
||||
|
||||
auto data = make_shared<op::Parameter>(element::f32, data_shape);
|
||||
auto seq_mask = make_shared<op::Parameter>(element::f32, seq_mask_shape);
|
||||
auto op = make_shared<op::CTCGreedyDecoder>(data, seq_mask, false);
|
||||
|
||||
const auto& out_shape = op->get_output_partial_shape(0);
|
||||
EXPECT_EQ(op->get_output_element_type(0), element::f32);
|
||||
EXPECT_EQ(out_shape, expected_shape);
|
||||
EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(21, 20, ov::no_label, ov::no_label));
|
||||
}
|
||||
|
||||
TEST(type_prop, ctc_greedy_decoder_output_static_shape1) {
|
||||
PartialShape logits_shape{Dimension::dynamic(), 3, 1200};
|
||||
PartialShape seq_mask_shape{100, 3};
|
||||
Shape out_shape{3, 100, 1, 1};
|
||||
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
|
||||
auto I = make_shared<op::Parameter>(element::f32, seq_mask_shape);
|
||||
auto G = make_shared<op::CTCGreedyDecoder>(P, I, false);
|
||||
ASSERT_EQ(G->get_element_type(), element::f32);
|
||||
ASSERT_EQ(G->get_shape(), out_shape);
|
||||
auto data = make_shared<op::Parameter>(element::f32, logits_shape);
|
||||
auto seq_mask = make_shared<op::Parameter>(element::f32, seq_mask_shape);
|
||||
auto op = make_shared<op::CTCGreedyDecoder>(data, seq_mask, false);
|
||||
|
||||
EXPECT_EQ(op->get_output_element_type(0), element::f32);
|
||||
EXPECT_EQ(op->get_shape(), out_shape);
|
||||
}
|
||||
|
||||
TEST(type_prop, ctc_greedy_decoder_output_static_shape2) {
|
||||
PartialShape logits_shape{Dimension::dynamic(), 3, 1200};
|
||||
PartialShape seq_mask_shape{100, Dimension::dynamic()};
|
||||
Shape out_shape{3, 100, 1, 1};
|
||||
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
|
||||
auto I = make_shared<op::Parameter>(element::f32, seq_mask_shape);
|
||||
auto G = make_shared<op::CTCGreedyDecoder>(P, I, false);
|
||||
ASSERT_EQ(G->get_element_type(), element::f32);
|
||||
ASSERT_EQ(G->get_shape(), out_shape);
|
||||
auto data = make_shared<op::Parameter>(element::f32, logits_shape);
|
||||
auto seq_mask = make_shared<op::Parameter>(element::f32, seq_mask_shape);
|
||||
auto op = make_shared<op::CTCGreedyDecoder>(data, seq_mask, false);
|
||||
EXPECT_EQ(op->get_element_type(), element::f32);
|
||||
EXPECT_EQ(op->get_shape(), out_shape);
|
||||
}
|
||||
|
||||
TEST(type_prop, ctc_greedy_decoder_dynamic_shapes) {
|
||||
PartialShape logits_shape{Dimension::dynamic(), Dimension::dynamic(), 1200};
|
||||
PartialShape seq_mask_shape{Dimension::dynamic(), Dimension::dynamic()};
|
||||
PartialShape out_shape{Dimension::dynamic(), Dimension::dynamic(), 1, 1};
|
||||
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
|
||||
auto I = make_shared<op::Parameter>(element::f32, seq_mask_shape);
|
||||
auto G = make_shared<op::CTCGreedyDecoder>(P, I, false);
|
||||
ASSERT_EQ(G->get_element_type(), element::f32);
|
||||
ASSERT_TRUE(G->get_output_partial_shape(0).same_scheme(out_shape));
|
||||
auto data = make_shared<op::Parameter>(element::f32, logits_shape);
|
||||
auto seq_mask = make_shared<op::Parameter>(element::f32, seq_mask_shape);
|
||||
auto op = make_shared<op::CTCGreedyDecoder>(data, seq_mask, false);
|
||||
EXPECT_EQ(op->get_element_type(), element::f32);
|
||||
ASSERT_TRUE(op->get_output_partial_shape(0).same_scheme(out_shape));
|
||||
}
|
||||
|
||||
TEST(type_prop, ctc_greedy_decoder_dynamic_ranks1) {
|
||||
PartialShape logits_shape = PartialShape::dynamic();
|
||||
PartialShape seq_mask_shape{100, Dimension::dynamic()};
|
||||
PartialShape out_shape{Dimension::dynamic(), 100, 1, 1};
|
||||
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
|
||||
auto I = make_shared<op::Parameter>(element::f32, seq_mask_shape);
|
||||
auto G = make_shared<op::CTCGreedyDecoder>(P, I, false);
|
||||
ASSERT_EQ(G->get_element_type(), element::f32);
|
||||
ASSERT_TRUE(G->get_output_partial_shape(0).same_scheme(out_shape));
|
||||
auto data = make_shared<op::Parameter>(element::f32, logits_shape);
|
||||
auto seq_mask = make_shared<op::Parameter>(element::f32, seq_mask_shape);
|
||||
auto op = make_shared<op::CTCGreedyDecoder>(data, seq_mask, false);
|
||||
EXPECT_EQ(op->get_element_type(), element::f32);
|
||||
ASSERT_TRUE(op->get_output_partial_shape(0).same_scheme(out_shape));
|
||||
}
|
||||
|
||||
TEST(type_prop, ctc_greedy_decoder_dynamic_ranks2) {
|
||||
PartialShape logits_shape = PartialShape::dynamic();
|
||||
PartialShape seq_mask_shape = PartialShape::dynamic();
|
||||
PartialShape out_shape{Dimension::dynamic(), Dimension::dynamic(), 1, 1};
|
||||
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
|
||||
auto I = make_shared<op::Parameter>(element::f32, seq_mask_shape);
|
||||
auto G = make_shared<op::CTCGreedyDecoder>(P, I, false);
|
||||
ASSERT_EQ(G->get_element_type(), element::f32);
|
||||
ASSERT_TRUE(G->get_output_partial_shape(0).same_scheme(out_shape));
|
||||
auto data = make_shared<op::Parameter>(element::f32, logits_shape);
|
||||
auto seq_mask = make_shared<op::Parameter>(element::f32, seq_mask_shape);
|
||||
auto op = make_shared<op::CTCGreedyDecoder>(data, seq_mask, false);
|
||||
EXPECT_EQ(op->get_element_type(), element::f32);
|
||||
ASSERT_TRUE(op->get_output_partial_shape(0).same_scheme(out_shape));
|
||||
}
|
||||
|
||||
TEST(type_prop, ctc_greedy_decoder_incorrect_rank) {
|
||||
PartialShape logits_shape{Dimension::dynamic(), 3, 1200, 5};
|
||||
PartialShape seq_mask_shape{100, 3};
|
||||
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
|
||||
auto I = make_shared<op::Parameter>(element::f32, seq_mask_shape);
|
||||
auto data = make_shared<op::Parameter>(element::f32, logits_shape);
|
||||
auto seq_mask = make_shared<op::Parameter>(element::f32, seq_mask_shape);
|
||||
|
||||
try {
|
||||
auto G = make_shared<op::CTCGreedyDecoder>(P, I, false);
|
||||
auto op = make_shared<op::CTCGreedyDecoder>(data, seq_mask, false);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Incorrect indices rank";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
@ -95,11 +173,11 @@ TEST(type_prop, ctc_greedy_decoder_incorrect_rank) {
|
||||
TEST(type_prop, ctc_greedy_decoder_incorrect_rank2) {
|
||||
PartialShape logits_shape{Dimension::dynamic(), 3, 1200};
|
||||
PartialShape seq_mask_shape{100, 3, 2};
|
||||
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
|
||||
auto I = make_shared<op::Parameter>(element::f32, seq_mask_shape);
|
||||
auto data = make_shared<op::Parameter>(element::f32, logits_shape);
|
||||
auto seq_mask = make_shared<op::Parameter>(element::f32, seq_mask_shape);
|
||||
|
||||
try {
|
||||
auto G = make_shared<op::CTCGreedyDecoder>(P, I, false);
|
||||
auto op = make_shared<op::CTCGreedyDecoder>(data, seq_mask, false);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Incorrect indices rank";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
@ -112,11 +190,11 @@ TEST(type_prop, ctc_greedy_decoder_incorrect_rank2) {
|
||||
TEST(type_prop, ctc_greedy_decoder_mismatched_dim1) {
|
||||
PartialShape logits_shape{100, 4, 1200};
|
||||
PartialShape seq_mask_shape{100, 3};
|
||||
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
|
||||
auto I = make_shared<op::Parameter>(element::f32, seq_mask_shape);
|
||||
auto data = make_shared<op::Parameter>(element::f32, logits_shape);
|
||||
auto seq_mask = make_shared<op::Parameter>(element::f32, seq_mask_shape);
|
||||
|
||||
try {
|
||||
auto G = make_shared<op::CTCGreedyDecoder>(P, I, false);
|
||||
auto op = make_shared<op::CTCGreedyDecoder>(data, seq_mask, false);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Incorrect indices rank";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
@ -129,11 +207,11 @@ TEST(type_prop, ctc_greedy_decoder_mismatched_dim1) {
|
||||
TEST(type_prop, ctc_greedy_decoder_mismatched_dim2) {
|
||||
PartialShape logits_shape{101, 3, 1200};
|
||||
PartialShape seq_mask_shape{100, 3};
|
||||
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
|
||||
auto I = make_shared<op::Parameter>(element::f32, seq_mask_shape);
|
||||
auto data = make_shared<op::Parameter>(element::f32, logits_shape);
|
||||
auto seq_mask = make_shared<op::Parameter>(element::f32, seq_mask_shape);
|
||||
|
||||
try {
|
||||
auto G = make_shared<op::CTCGreedyDecoder>(P, I, false);
|
||||
auto op = make_shared<op::CTCGreedyDecoder>(data, seq_mask, false);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Incorrect indices rank";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
|
@ -1,25 +0,0 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <ctc_greedy_decoder_shape_inference.hpp>
|
||||
#include <openvino/op/ops.hpp>
|
||||
#include <openvino/op/parameter.hpp>
|
||||
#include <utils/shape_inference/shape_inference.hpp>
|
||||
#include <utils/shape_inference/static_shape.hpp>
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::intel_cpu;
|
||||
|
||||
TEST(StaticShapeInferenceTest, CtcGreedyDecoderTest) {
|
||||
auto P = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
|
||||
auto I = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1, -1});
|
||||
auto G = std::make_shared<op::v0::CTCGreedyDecoder>(P, I, false);
|
||||
// Test StaticShape
|
||||
std::vector<StaticShape> static_input_shapes = {StaticShape{100, 3, 1200}, StaticShape{100, 3}},
|
||||
static_output_shapes = {StaticShape{}};
|
||||
shape_inference(G.get(), static_input_shapes, static_output_shapes);
|
||||
ASSERT_EQ(static_output_shapes[0], StaticShape({3, 100, 1, 1}));
|
||||
}
|
@ -0,0 +1,64 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
#include "ctc_greedy_decoder_shape_inference.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "common_test_utils/test_assertions.hpp"
|
||||
#include "openvino/op/ops.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::intel_cpu;
|
||||
using namespace testing;
|
||||
|
||||
class CTCGreedyDecoderV0StaticShapeInferenceTest : public OpStaticShapeInferenceTest<op::v0::CTCGreedyDecoder> {
|
||||
void SetUp() override {
|
||||
output_shapes.resize(1);
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(CTCGreedyDecoderV0StaticShapeInferenceTest, basic) {
|
||||
auto data = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
|
||||
auto seq_mask = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{-1, -1});
|
||||
op = make_op(data, seq_mask, false);
|
||||
|
||||
input_shapes = {StaticShape{100, 3, 1200}, StaticShape{100, 3}};
|
||||
|
||||
shape_inference(op.get(), input_shapes, output_shapes);
|
||||
EXPECT_EQ(output_shapes[0], StaticShape({3, 100, 1, 1}));
|
||||
}
|
||||
|
||||
TEST_F(CTCGreedyDecoderV0StaticShapeInferenceTest, decoder_default_ctor) {
|
||||
op = make_op();
|
||||
|
||||
input_shapes = {StaticShape{100, 3, 1200}, StaticShape{100, 3}};
|
||||
|
||||
shape_infer(op.get(), input_shapes, output_shapes);
|
||||
EXPECT_EQ(output_shapes[0], StaticShape({3, 100, 1, 1}));
|
||||
}
|
||||
|
||||
TEST_F(CTCGreedyDecoderV0StaticShapeInferenceTest, incompatible_batch) {
|
||||
auto data = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto seq_mask = std::make_shared<op::v0::Parameter>(element::i32, PartialShape::dynamic());
|
||||
op = make_op(data, seq_mask, false);
|
||||
|
||||
input_shapes = {StaticShape{10, 3, 1200}, StaticShape{100, 3}};
|
||||
|
||||
OV_EXPECT_THROW(shape_inference(op.get(), input_shapes, output_shapes),
|
||||
NodeValidationFailure,
|
||||
HasSubstr("The first dimensions of input tensors must match"))
|
||||
}
|
||||
|
||||
TEST_F(CTCGreedyDecoderV0StaticShapeInferenceTest, incompatible_t_dim) {
|
||||
auto data = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto seq_mask = std::make_shared<op::v0::Parameter>(element::i32, PartialShape::dynamic());
|
||||
op = make_op(data, seq_mask, false);
|
||||
|
||||
input_shapes = {StaticShape{100, 3, 1200}, StaticShape{100, 5}};
|
||||
|
||||
OV_EXPECT_THROW(shape_inference(op.get(), input_shapes, output_shapes),
|
||||
NodeValidationFailure,
|
||||
HasSubstr("The second dimensions of input tensors must match"))
|
||||
}
|
Loading…
Reference in New Issue
Block a user