Add a check for negative axis attr out of bounds (#4543)
This commit is contained in:
committed by
GitHub
parent
24fb09edb3
commit
635ffc760a
@@ -71,10 +71,14 @@ void op::Concat::validate_and_infer_types()
|
||||
}
|
||||
auto concat_axis = get_concatenation_axis();
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
concat_axis < this_input_rank.get_length(),
|
||||
concat_axis < this_input_rank.get_length() && concat_axis >= 0,
|
||||
"Concatenation axis (",
|
||||
concat_axis,
|
||||
") is out of bounds for ",
|
||||
") is out of bounds [",
|
||||
-this_input_rank.get_length(),
|
||||
", ",
|
||||
this_input_rank.get_length() - 1,
|
||||
"] for ",
|
||||
"argument ",
|
||||
i,
|
||||
", which has shape ",
|
||||
|
||||
@@ -367,3 +367,37 @@ TEST(type_prop, concat_partial_all_static_with_concat_axis_static_dims_incompati
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, concat_partial_negative_axis_correct)
|
||||
{
|
||||
auto param0 = make_shared<op::Parameter>(element::f32, Shape{3, 2, 4});
|
||||
auto param1 = make_shared<op::Parameter>(element::f32, Shape{7, 2, 4});
|
||||
auto param2 = make_shared<op::Parameter>(element::f32, Shape{2, 2, 4});
|
||||
|
||||
auto c = make_shared<op::Concat>(NodeVector{param0, param1, param2}, -3);
|
||||
|
||||
ASSERT_EQ(c->get_element_type(), element::f32);
|
||||
ASSERT_EQ(c->get_shape(), (Shape{12, 2, 4}));
|
||||
}
|
||||
|
||||
TEST(type_prop, concat_partial_negative_axis_incorrect)
|
||||
{
|
||||
auto param0 = make_shared<op::Parameter>(element::f32, Shape{2, 3, 4});
|
||||
auto param1 = make_shared<op::Parameter>(element::f32, Shape{2, 7, 4});
|
||||
auto param2 = make_shared<op::Parameter>(element::f32, Shape{2, 2, 4});
|
||||
|
||||
try
|
||||
{
|
||||
auto c = make_shared<op::Concat>(NodeVector{param0, param1, param2}, -4);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Incorrect negative axis value not detected (out of bounds)";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Concatenation axis (-1) is out of bounds"));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user