Add a check for negative axis attr out of bounds (#4543)

This commit is contained in:
Bartosz Lesniewski
2021-03-15 11:41:43 +01:00
committed by GitHub
parent 24fb09edb3
commit 635ffc760a
2 changed files with 40 additions and 2 deletions

View File

@@ -71,10 +71,14 @@ void op::Concat::validate_and_infer_types()
}
auto concat_axis = get_concatenation_axis();
NODE_VALIDATION_CHECK(this,
concat_axis < this_input_rank.get_length(),
concat_axis < this_input_rank.get_length() && concat_axis >= 0,
"Concatenation axis (",
concat_axis,
") is out of bounds for ",
") is out of bounds [",
-this_input_rank.get_length(),
", ",
this_input_rank.get_length() - 1,
"] for ",
"argument ",
i,
", which has shape ",

View File

@@ -367,3 +367,37 @@ TEST(type_prop, concat_partial_all_static_with_concat_axis_static_dims_incompati
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, concat_partial_negative_axis_correct)
{
auto param0 = make_shared<op::Parameter>(element::f32, Shape{3, 2, 4});
auto param1 = make_shared<op::Parameter>(element::f32, Shape{7, 2, 4});
auto param2 = make_shared<op::Parameter>(element::f32, Shape{2, 2, 4});
auto c = make_shared<op::Concat>(NodeVector{param0, param1, param2}, -3);
ASSERT_EQ(c->get_element_type(), element::f32);
ASSERT_EQ(c->get_shape(), (Shape{12, 2, 4}));
}
TEST(type_prop, concat_partial_negative_axis_incorrect)
{
auto param0 = make_shared<op::Parameter>(element::f32, Shape{2, 3, 4});
auto param1 = make_shared<op::Parameter>(element::f32, Shape{2, 7, 4});
auto param2 = make_shared<op::Parameter>(element::f32, Shape{2, 2, 4});
try
{
auto c = make_shared<op::Concat>(NodeVector{param0, param1, param2}, -4);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect negative axis value not detected (out of bounds)";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Concatenation axis (-1) is out of bounds"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}