Files
openvino/ngraph/python/tests/test_ngraph/test_data_movement.py
2020-10-14 12:20:22 +03:00

218 lines
5.6 KiB
Python

# ******************************************************************************
# Copyright 2017-2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
import numpy as np
import pytest
import ngraph as ng
from ngraph.impl import Type
from tests.runtime import get_runtime
from tests.test_ngraph.util import run_op_node
def test_reverse_sequence():
    """ReverseSequence: for each index b of axis 2 (batch_axis), reverse the first
    seq_lengths[b] elements along axis 1 (sequence_axis)."""
    # Shape [2, 3, 4, 2]; the last axis is [value, 0] padding.
    input_data = np.array(
        [
            [
                [[0, 0], [3, 0], [6, 0], [9, 0]],
                [[1, 0], [4, 0], [7, 0], [10, 0]],
                [[2, 0], [5, 0], [8, 0], [11, 0]],
            ],
            [
                [[12, 0], [15, 0], [18, 0], [21, 0]],
                [[13, 0], [16, 0], [19, 0], [22, 0]],
                [[14, 0], [17, 0], [20, 0], [23, 0]],
            ],
        ],
        dtype=np.int32,
    )
    # One length per index of the batch axis (axis 2, size 4).
    seq_lengths = np.array([1, 2, 1, 2], dtype=np.int32)
    batch_axis = 2
    sequence_axis = 1

    input_param = ng.parameter(input_data.shape, name="input", dtype=np.int32)
    seq_lengths_param = ng.parameter(seq_lengths.shape, name="sequence lengths", dtype=np.int32)
    model = ng.reverse_sequence(input_param, seq_lengths_param, batch_axis, sequence_axis)

    runtime = get_runtime()
    computation = runtime.computation(model, input_param, seq_lengths_param)
    result = computation(input_data, seq_lengths)

    # Batch columns 1 and 3 (length 2) have their sequence rows 0 and 1 swapped;
    # columns 0 and 2 (length 1) are unchanged.
    expected = np.array(
        [
            [
                [[0, 0], [4, 0], [6, 0], [10, 0]],
                [[1, 0], [3, 0], [7, 0], [9, 0]],
                [[2, 0], [5, 0], [8, 0], [11, 0]],
            ],
            [
                [[12, 0], [16, 0], [18, 0], [22, 0]],
                [[13, 0], [15, 0], [19, 0], [21, 0]],
                [[14, 0], [17, 0], [20, 0], [23, 0]],
            ],
        ],
        dtype=np.int32,
    # Extra leading axis of size 1: `result` is compared as-is, presumably a
    # list holding the single output tensor -- TODO confirm against tests.runtime.
    ).reshape([1, 2, 3, 4, 2])
    assert np.allclose(result, expected)
def test_pad_edge():
    """Pad a 3x4 int32 matrix in "edge" mode: 0 rows / 1 column before,
    2 rows / 3 columns after, replicating the border values."""
    # Explicit int32: np.arange's default integer dtype is platform-dependent,
    # while the parameter below is declared as int32.
    input_data = np.arange(1, 13, dtype=np.int32).reshape([3, 4])
    pads_begin = np.array([0, 1], dtype=np.int32)
    pads_end = np.array([2, 3], dtype=np.int32)

    input_param = ng.parameter(input_data.shape, name="input", dtype=np.int32)
    model = ng.pad(input_param, pads_begin, pads_end, "edge")

    runtime = get_runtime()
    computation = runtime.computation(model, input_param)
    result = computation(input_data)

    expected = np.array(
        [
            [1, 1, 2, 3, 4, 4, 4, 4],
            [5, 5, 6, 7, 8, 8, 8, 8],
            [9, 9, 10, 11, 12, 12, 12, 12],
            [9, 9, 10, 11, 12, 12, 12, 12],
            [9, 9, 10, 11, 12, 12, 12, 12],
        ]
    )
    assert np.allclose(result, expected)
@pytest.mark.xfail(reason="AssertionError")
def test_pad_constant():
    """Pad a 3x4 int32 matrix in "constant" mode with the fill value 100:
    0 rows / 1 column before, 2 rows / 3 columns after."""
    # Explicit int32: np.arange's default integer dtype is platform-dependent,
    # while the parameter below is declared as int32.
    input_data = np.arange(1, 13, dtype=np.int32).reshape([3, 4])
    pads_begin = np.array([0, 1], dtype=np.int32)
    pads_end = np.array([2, 3], dtype=np.int32)

    input_param = ng.parameter(input_data.shape, name="input", dtype=np.int32)
    model = ng.pad(input_param, pads_begin, pads_end, "constant", arg_pad_value=np.array(100, dtype=np.int32))

    runtime = get_runtime()
    computation = runtime.computation(model, input_param)
    result = computation(input_data)

    expected = np.array(
        [
            [100, 1, 2, 3, 4, 100, 100, 100],
            [100, 5, 6, 7, 8, 100, 100, 100],
            [100, 9, 10, 11, 12, 100, 100, 100],
            [100, 100, 100, 100, 100, 100, 100, 100],
            [100, 100, 100, 100, 100, 100, 100, 100],
        ]
    )
    assert np.allclose(result, expected)
def test_select():
    """Select picks from `then_node` where `cond` is True, else from `else_node`."""
    cond = np.array([[False, False], [True, False], [True, True]])
    then_node = np.array([[-1, 0], [1, 2], [3, 4]], dtype=np.int32)
    else_node = np.array([[11, 10], [9, 8], [7, 6]], dtype=np.int32)
    expected = np.array([[11, 10], [1, 8], [3, 4]], dtype=np.int32)

    result = run_op_node([cond, then_node, else_node], ng.select)
    assert np.allclose(result, expected)
def test_gather_nd():
    """Check GatherND node construction: type name, output count, inferred
    output shape and element type, with two leading batch dimensions."""
    batch_dims = 2
    data_param = ng.parameter([2, 10, 80, 30, 50], dtype=np.float32, name="data")
    indices_param = ng.parameter([2, 10, 30, 40, 2], dtype=np.int32, name="indices")

    gather = ng.gather_nd(data_param, indices_param, batch_dims)

    assert gather.get_type_name() == "GatherND"
    assert gather.get_output_size() == 1
    # 20 = 2 * 10 (the two batch dims), 30 x 40 from the indices, and 50 is the
    # data dim left after each 2-element index tuple consumes dims 80 and 30.
    assert list(gather.get_output_shape(0)) == [20, 30, 40, 50]
    assert gather.get_output_element_type(0) == Type.f32