Files
openvino/ngraph/python/tests/test_ngraph/test_input_validation.py

170 lines
5.4 KiB
Python

# ******************************************************************************
# Copyright 2017-2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
import numpy as np
import pytest
from ngraph.exceptions import UserInputError
from ngraph.utils.input_validation import (
_check_value,
check_valid_attribute,
check_valid_attributes,
is_non_negative_value,
is_positive_value,
)
@pytest.mark.parametrize("dtype", [np.int8, np.int16, np.int32, np.int64, np.float32, np.float64])
def test_is_positive_value_signed_type(dtype):
assert is_positive_value(dtype(16))
assert not is_positive_value(dtype(-16))
@pytest.mark.parametrize("dtype", [np.uint8, np.uint16, np.uint32, np.uint64])
def test_is_positive_value_unsigned_type(dtype):
assert is_positive_value(dtype(16))
@pytest.mark.parametrize("dtype", [np.int8, np.int16, np.int32, np.int64, np.float32, np.float64])
def test_is_non_negative_value_signed_type(dtype):
assert is_non_negative_value(dtype(16))
assert is_non_negative_value(dtype(0))
assert not is_non_negative_value(dtype(-1))
assert not is_non_negative_value(dtype(-16))
@pytest.mark.parametrize("dtype", [np.uint8, np.uint16, np.uint32, np.uint64])
def test_is_non_negative_value_unsigned_type(dtype):
assert is_non_negative_value(dtype(16))
assert is_non_negative_value(dtype(0))
@pytest.mark.parametrize(
"value, val_type",
[
(np.int8(64), np.integer),
(np.int16(64), np.integer),
(np.int32(64), np.integer),
(np.int64(64), np.integer),
(np.uint8(64), np.unsignedinteger),
(np.uint16(64), np.unsignedinteger),
(np.uint32(64), np.unsignedinteger),
(np.uint64(64), np.unsignedinteger),
(np.float32(64), np.floating),
(np.float64(64), np.floating),
],
)
def test_check_value(value, val_type):
def is_even(x):
return x % 2 == 0
assert _check_value("TestOp", "test_attr", value, val_type, is_even)
@pytest.mark.parametrize(
"value, val_type",
[
(np.int8(64), np.floating),
(np.int16(64), np.floating),
(np.int32(64), np.floating),
(np.int64(64), np.floating),
(np.uint8(64), np.floating),
(np.uint16(64), np.floating),
(np.uint32(64), np.floating),
(np.uint64(64), np.floating),
(np.float32(64), np.integer),
(np.float64(64), np.integer),
],
)
def test_check_value_fail_type(value, val_type):
try:
_check_value("TestOp", "test_attr", value, val_type, None)
except UserInputError:
pass
else:
raise AssertionError("Type validation has unexpectedly passed.")
@pytest.mark.parametrize(
"value, val_type",
[
(np.int8(61), np.integer),
(np.int16(61), np.integer),
(np.int32(61), np.integer),
(np.int64(61), np.integer),
(np.uint8(61), np.unsignedinteger),
(np.uint16(61), np.unsignedinteger),
(np.uint32(61), np.unsignedinteger),
(np.uint64(61), np.unsignedinteger),
(np.float32(61), np.floating),
(np.float64(61), np.floating),
],
)
def test_check_value_fail_cond(value, val_type):
def is_even(x):
return x % 2 == 0
try:
_check_value("TestOp", "test_attr", value, val_type, is_even)
except UserInputError:
pass
else:
raise AssertionError("Condition validation has unexpectedly passed.")
def test_check_valid_attribute():
attr_dict = {
"mode": "bilinear",
"coefficients": [1, 2, 3, 4, 5],
}
assert check_valid_attribute("TestOp", attr_dict, "width", np.unsignedinteger, required=False)
assert check_valid_attribute("TestOp", attr_dict, "mode", np.str_, required=True)
assert check_valid_attribute("TestOp", attr_dict, "coefficients", np.integer, required=True)
try:
check_valid_attribute("TestOp", attr_dict, "alpha", np.floating, required=True)
except UserInputError:
pass
else:
raise AssertionError("Validation of missing required attribute has unexpectedly passed.")
def test_check_valid_attributes():
attr_dict = {
"mode": "bilinear",
"coefficients": [1, 2, 3, 4, 5],
}
def _is_supported_mode(x):
return x in ["linear", "area", "cubic", "bilinear"]
requirements = [
("width", False, np.unsignedinteger, None),
("mode", True, np.str_, _is_supported_mode),
("coefficients", True, np.integer, lambda x: x > 0),
("alpha", False, np.float64, None),
]
assert check_valid_attributes("TestOp", attr_dict, requirements)
requirements[3] = ("alpha", True, np.float64, None)
try:
check_valid_attributes("TestOp", attr_dict, requirements)
except UserInputError:
pass
else:
raise AssertionError("Validation of missing required attribute has unexpectedly passed.")