Supports case min == max for clamp (#5196)

This commit is contained in:
Ilya Churaev 2021-04-13 07:19:21 +03:00 committed by GitHub
parent 64fd5734fe
commit c0764a5d0a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 5 deletions

View File

@ -134,8 +134,8 @@ void op::Clamp::validate_and_infer_types()
"Input element type must be numeric. Got: ",
input_et);
NODE_VALIDATION_CHECK(this,
m_min < m_max,
"Attribute 'min' must be less than 'max'. Got: ",
m_min <= m_max,
"Attribute 'min' must be less or equal than 'max'. Got: ",
m_min,
" and ",
m_max);

View File

@ -75,19 +75,30 @@ TEST(type_prop, clamp_invalid_element_type)
}
}
TEST(type_prop, clamp_equal_attributes)
{
auto data = make_shared<op::Parameter>(element::f64, Shape{2, 2});
auto clamp = make_shared<op::Clamp>(data, 1.0, 1.0);
ASSERT_EQ(clamp->get_element_type(), element::f64);
ASSERT_EQ(clamp->get_min(), 1.0);
ASSERT_EQ(clamp->get_max(), 1.0);
ASSERT_EQ(clamp->get_output_shape(0), (Shape{2, 2}));
}
TEST(type_prop, clamp_invalid_attributes)
{
auto data = make_shared<op::Parameter>(element::f64, Shape{2, 2});
try
{
auto clamp = make_shared<op::Clamp>(data, 1.0, 1.0);
auto clamp = make_shared<op::Clamp>(data, 2.0, 1.0);
// Attribute 'max' not greater than 'min'
FAIL() << "Attribute 'min' equal to 'max' not detected";
FAIL() << "Attribute 'min' bigger than 'max' not detected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Attribute 'min' must be less than 'max'");
EXPECT_HAS_SUBSTRING(error.what(), "Attribute 'min' must be less or equal than 'max'");
}
catch (...)
{