Normalize_L2 relax constant input restriction (#17568)
* Normalize_L2 relax constant input restriction * Fix warning treated as error during windows build
This commit is contained in:
committed by
GitHub
parent
a880cba9b7
commit
293fccc4fe
@@ -38,8 +38,6 @@ void op::v0::NormalizeL2::validate_and_infer_types() {
|
||||
const auto& input_rank = input_pshape.rank();
|
||||
const auto& axes_rank = axes_pshape.rank();
|
||||
|
||||
NODE_VALIDATION_CHECK(this, has_and_set_equal_bounds(input_value(1)), "Input axes must be Constant type");
|
||||
|
||||
if (axes_rank.is_static()) {
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
axes_rank.get_length() <= 1,
|
||||
|
||||
@@ -39,16 +39,7 @@ TEST(type_prop, normalize_l2_axes_input_not_constant) {
|
||||
auto axes = make_shared<op::Parameter>(element::u64, Shape{1});
|
||||
float eps{1e-6f};
|
||||
auto eps_mode = op::EpsMode::ADD;
|
||||
|
||||
try {
|
||||
auto normalize = make_shared<op::v0::NormalizeL2>(data, axes, eps, eps_mode);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Invalid input tensor rank.";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input axes must be Constant type"));
|
||||
} catch (...) {
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
ASSERT_NO_THROW(auto op = make_shared<op::v0::NormalizeL2>(data, axes, eps, eps_mode));
|
||||
}
|
||||
|
||||
TEST(type_prop, normalize_l2_invalid_axes_rank) {
|
||||
|
||||
Reference in New Issue
Block a user