[OV20] InputTensorInfo::set_shape (#9059)
* InputTensorInfo::set_shape * Fix clang-format
This commit is contained in:
@@ -201,6 +201,37 @@ def test_ngraph_preprocess_spatial_static_shape():
|
||||
assert np.equal(output, expected_output).all()
|
||||
|
||||
|
||||
def test_ngraph_preprocess_set_shape():
|
||||
shape = [1, 1, 1]
|
||||
parameter_a = ops.parameter(shape, dtype=np.int32, name="A")
|
||||
model = parameter_a
|
||||
function = Function(model, [parameter_a], "TestFunction")
|
||||
|
||||
@custom_preprocess_function
|
||||
def custom_crop(out_node: Output):
|
||||
start = ops.constant(np.array([1, 1, 1]), dtype=np.int32)
|
||||
stop = ops.constant(np.array([2, 2, 2]), dtype=np.int32)
|
||||
step = ops.constant(np.array([1, 1, 1]), dtype=np.int32)
|
||||
axis = ops.constant(np.array([0, 1, 2]), dtype=np.int32)
|
||||
return ops.slice(out_node, start, stop, step, axis)
|
||||
|
||||
p = PrePostProcessor(function)
|
||||
inp = p.input()
|
||||
inp.tensor().set_shape([3, 3, 3])
|
||||
inp.preprocess().custom(custom_crop)
|
||||
function = p.build()
|
||||
|
||||
input_data = np.array([[[0, 1, 2], [3, 4, 5], [6, 7, 8]],
|
||||
[[9, 10, 11], [12, 13, 14], [15, 16, 17]],
|
||||
[[18, 19, 20], [21, 22, 23], [24, 25, 26]]]).astype(np.int32)
|
||||
expected_output = np.array([[[13]]]).astype(np.float32)
|
||||
|
||||
runtime = get_runtime()
|
||||
computation = runtime.computation(function)
|
||||
output = computation(input_data)
|
||||
assert np.equal(output, expected_output).all()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"algorithm, color_format1, color_format2, is_failing",
|
||||
[(ResizeAlgorithm.RESIZE_LINEAR, ColorFormat.UNDEFINED, ColorFormat.BGR, True),
|
||||
|
||||
Reference in New Issue
Block a user