Normalize_L2 relax constant input restriction (#17568)

* Normalize_L2 relax constant input restriction

* Fix warning treated as error during windows build
This commit is contained in:
Evgenya Stepyreva
2023-05-17 16:28:50 +04:00
committed by GitHub
parent a880cba9b7
commit 293fccc4fe
2 changed files with 1 additions and 12 deletions

View File

@@ -38,8 +38,6 @@ void op::v0::NormalizeL2::validate_and_infer_types() {
const auto& input_rank = input_pshape.rank();
const auto& axes_rank = axes_pshape.rank();
NODE_VALIDATION_CHECK(this, has_and_set_equal_bounds(input_value(1)), "Input axes must be Constant type");
if (axes_rank.is_static()) {
NODE_VALIDATION_CHECK(this,
axes_rank.get_length() <= 1,

View File

@@ -39,16 +39,7 @@ TEST(type_prop, normalize_l2_axes_input_not_constant) {
auto axes = make_shared<op::Parameter>(element::u64, Shape{1});
float eps{1e-6f};
auto eps_mode = op::EpsMode::ADD;
try {
auto normalize = make_shared<op::v0::NormalizeL2>(data, axes, eps, eps_mode);
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input tensor rank.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input axes must be Constant type"));
} catch (...) {
FAIL() << "Deduced type check failed for unexpected reason";
}
ASSERT_NO_THROW(auto op = make_shared<op::v0::NormalizeL2>(data, axes, eps, eps_mode));
}
TEST(type_prop, normalize_l2_invalid_axes_rank) {