# ****************************************************************************** # Copyright 2017-2020 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ****************************************************************************** import numpy as np import pytest import ngraph as ng from ngraph.impl import Type from tests.runtime import get_runtime from tests.test_ngraph.util import run_op_node def test_reverse_sequence(): input_data = np.array( [ 0, 0, 3, 0, 6, 0, 9, 0, 1, 0, 4, 0, 7, 0, 10, 0, 2, 0, 5, 0, 8, 0, 11, 0, 12, 0, 15, 0, 18, 0, 21, 0, 13, 0, 16, 0, 19, 0, 22, 0, 14, 0, 17, 0, 20, 0, 23, 0, ], dtype=np.int32, ).reshape([2, 3, 4, 2]) seq_lenghts = np.array([1, 2, 1, 2], dtype=np.int32) batch_axis = 2 sequence_axis = 1 input_param = ng.parameter(input_data.shape, name="input", dtype=np.int32) seq_lengths_param = ng.parameter(seq_lenghts.shape, name="sequence lengths", dtype=np.int32) model = ng.reverse_sequence(input_param, seq_lengths_param, batch_axis, sequence_axis) runtime = get_runtime() computation = runtime.computation(model, input_param, seq_lengths_param) result = computation(input_data, seq_lenghts) expected = np.array( [ 0, 0, 4, 0, 6, 0, 10, 0, 1, 0, 3, 0, 7, 0, 9, 0, 2, 0, 5, 0, 8, 0, 11, 0, 12, 0, 16, 0, 18, 0, 22, 0, 13, 0, 15, 0, 19, 0, 21, 0, 14, 0, 17, 0, 20, 0, 23, 0, ], ).reshape([1, 2, 3, 4, 2]) assert np.allclose(result, expected) def test_pad_edge(): input_data = np.arange(1, 13).reshape([3, 4]) pads_begin = np.array([0, 1], dtype=np.int32) pads_end = np.array([2, 3], dtype=np.int32) input_param = ng.parameter(input_data.shape, name="input", dtype=np.int32) model = ng.pad(input_param, pads_begin, pads_end, "edge") runtime = get_runtime() computation = runtime.computation(model, input_param) result = computation(input_data) expected = np.array( [ [1, 1, 2, 3, 4, 4, 4, 4], [5, 5, 6, 7, 8, 8, 8, 8], [9, 9, 10, 11, 12, 12, 12, 12], [9, 9, 10, 11, 12, 12, 12, 12], [9, 9, 10, 11, 12, 12, 12, 12], ] ) assert np.allclose(result, expected) @pytest.mark.xfail(reason="AssertionError") def test_pad_constant(): input_data = np.arange(1, 13).reshape([3, 4]) pads_begin = np.array([0, 1], dtype=np.int32) pads_end = np.array([2, 3], dtype=np.int32) input_param = ng.parameter(input_data.shape, name="input", dtype=np.int32) model = ng.pad(input_param, pads_begin, pads_end, "constant", arg_pad_value=np.array(100, dtype=np.int32)) runtime = get_runtime() computation = runtime.computation(model, input_param) result = computation(input_data) expected = np.array( [ [100, 1, 2, 3, 4, 100, 100, 100], [100, 5, 6, 7, 8, 100, 100, 100], [100, 9, 10, 11, 12, 100, 100, 100], [100, 100, 100, 100, 100, 100, 100, 100], [100, 100, 100, 100, 100, 100, 100, 100], ] ) assert np.allclose(result, expected) def test_select(): cond = np.array([[False, False], [True, False], [True, True]]) then_node = np.array([[-1, 0], [1, 2], [3, 4]], dtype=np.int32) else_node = np.array([[11, 10], [9, 8], [7, 6]], dtype=np.int32) excepted = np.array([[11, 10], [1, 8], [3, 4]], dtype=np.int32) result = run_op_node([cond, then_node, else_node], ng.select) assert np.allclose(result, excepted) def test_gather_nd(): indices_type = np.int32 data_dtype = np.float32 data = ng.parameter([2, 10, 80, 30, 50], dtype=data_dtype, name="data") indices = ng.parameter([2, 10, 30, 40, 2], dtype=indices_type, name="indices") batch_dims = 2 expected_shape = [20, 30, 40, 50] node = ng.gather_nd(data, indices, batch_dims) assert node.get_type_name() == "GatherND" assert node.get_output_size() == 1 assert list(node.get_output_shape(0)) == expected_shape assert node.get_output_element_type(0) == Type.f32