Files
openvino/tests/layer_tests/onnx_tests/test_scatter.py
Ruslan Nugmanov 236778aeec Refactor of renaming ov libraries for layer tests with key --use_new_frontend (#12846)
* refactor of renaming libraries in layer tests

* 1. adds check for old API and new FE usage
2. refactor of api_2 arg

* fix for tf_NMS test preprocessing

* take libs path from LD_LIBRARY_PATH env

* convert str to Path object

* use wheels path to libs

* print lib paths

* print lib paths

* use ov_frontend_path env

* also check if file to rename exists

* removes redundant prints

* copy instead of rename

* 1. copy instead of rename
2. adds some details to readme
2022-09-20 13:43:37 +04:00

131 lines
4.9 KiB
Python

# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pytest
from common.layer_test_class import check_ir_version
from common.onnx_layer_test_class import Caffe2OnnxLayerTest
from unit_tests.utils.graph import build_graph
class TestScatters(Caffe2OnnxLayerTest):
    """Shared model builder for the ONNX ``Scatter`` / ``ScatterElements`` layer tests.

    Subclasses set ``op`` to the ONNX operator type. ``create_net`` builds a
    single-node ONNX model around that operator. It also builds the reference
    IR graph (a ``ScatterElementsUpdate`` node) used for comparison.
    """

    # ONNX operator type passed to make_node in create_net; subclasses override it.
    op = None

    def create_net(self, input_shape, indices_shape, updates_shape, output_shape,
                   axis, ir_version):
        """
            ONNX net                      IR net

            Input->Scatter->Output   =>   Parameter->ScatterElementsUpdate->Result

        Args:
            input_shape: shape of the FLOAT ``data`` graph input.
            indices_shape: shape of the INT64 ``indices`` graph input.
            updates_shape: shape of the FLOAT ``updates`` graph input.
            output_shape: shape of the FLOAT ``output`` graph output.
            axis: value of the ONNX ``axis`` attribute. When None the attribute
                is omitted from the node, and the reference graph uses axis 0.
            ir_version: IR version. The reference graph is built only when
                ``check_ir_version(10, None, ir_version)`` is true.

        Returns:
            Tuple ``(onnx_net, ref_net)``. ``ref_net`` is None when the IR
            version check fails.
        """
        #
        #   Create ONNX model
        #

        from onnx import TensorProto, helper

        data = helper.make_tensor_value_info('data', TensorProto.FLOAT, input_shape)
        indices = helper.make_tensor_value_info('indices', TensorProto.INT64, indices_shape)
        # Use the declared updates shape. The ONNX model must describe the same
        # updates tensor as the reference graph's updates_data node below.
        updates = helper.make_tensor_value_info('updates', TensorProto.FLOAT, updates_shape)
        output = helper.make_tensor_value_info('output', TensorProto.FLOAT, output_shape)

        # Only emit the axis attribute when the case specifies one.
        params = {'axis': axis} if axis is not None else {}
        node_def = helper.make_node(
            self.op,
            inputs=['data', 'indices', 'updates'],
            outputs=['output'],
            **params,
        )

        # Create the graph (GraphProto)
        graph_def = helper.make_graph(
            [node_def],
            'test_model',
            [data, indices, updates],
            [output],
        )

        # Create the model (ModelProto)
        onnx_net = helper.make_model(graph_def, producer_name='test_model')

        #
        #   Create reference IR net
        #

        ref_net = None

        if check_ir_version(10, None, ir_version):
            nodes_attributes = {
                # The IR comparison starts from an input node. With three inputs,
                # IREngine takes the first one in lexicographical order, so the
                # numeric prefixes make the data input sort first.
                '1_input': {'kind': 'op', 'type': 'Parameter'},
                'input_data': {'shape': input_shape, 'kind': 'data'},

                '2_indices': {'kind': 'op', 'type': 'Parameter'},
                'indices_data': {'shape': indices_shape, 'kind': 'data'},

                '3_updates': {'kind': 'op', 'type': 'Parameter'},
                'updates_data': {'shape': updates_shape, 'kind': 'data'},

                # Axis constant fed to input port 3. It falls back to 0 when the
                # ONNX node carries no axis attribute.
                'const_indata': {'kind': 'data',
                                 'value': np.int64(axis) if axis is not None else np.int64(0)},
                'const': {'kind': 'op', 'type': 'Const'},
                'const_data': {'kind': 'data'},

                'node': {'kind': 'op', 'type': 'ScatterElementsUpdate'},
                'node_data': {'shape': output_shape, 'kind': 'data'},

                'result': {'kind': 'op', 'type': 'Result'}
            }
            ref_net = build_graph(nodes_attributes,
                                  [
                                      ('1_input', 'input_data'),
                                      ('input_data', 'node', {'in': 0}),
                                      ('2_indices', 'indices_data'),
                                      ('indices_data', 'node', {'in': 1}),
                                      ('3_updates', 'updates_data'),
                                      ('updates_data', 'node', {'in': 2}),
                                      ('const_indata', 'const'),
                                      ('const', 'const_data'),
                                      ('const_data', 'node', {'in': 3}),

                                      ('node', 'node_data'),
                                      ('node_data', 'result')
                                  ])

        return onnx_net, ref_net
# Parameter sets shared by TestScatter and TestScatterElements. Each dict is
# expanded into keyword arguments of create_net by the test methods.
test_data = [
    # Small 2-D case with an explicit axis attribute.
    {
        'input_shape': [1, 5],
        'indices_shape': [1, 2],
        'updates_shape': [1, 2],
        'axis': 1,
        'output_shape': [1, 5],
    },
    # Large 4-D case: indices/updates match the input shape and the
    # axis attribute is left unset (None).
    {
        'input_shape': [1, 256, 200, 272],
        'indices_shape': [1, 256, 200, 272],
        'updates_shape': [1, 256, 200, 272],
        'axis': None,
        'output_shape': [1, 256, 200, 272],
    },
]
class TestScatter(TestScatters):
    """Runs every ``test_data`` case through an ONNX ``Scatter`` node."""

    op = 'Scatter'

    @pytest.mark.parametrize("params", test_data)
    @pytest.mark.nightly
    def test_scatter(self, params, ie_device, precision, ir_version, temp_dir, use_old_api):
        # create_net returns the ONNX model together with its reference IR graph.
        onnx_net, ref_net = self.create_net(**params, ir_version=ir_version)
        self._test(onnx_net, ref_net, ie_device, precision, ir_version,
                   temp_dir=temp_dir, use_old_api=use_old_api)
class TestScatterElements(TestScatters):
    """Runs every ``test_data`` case through an ONNX ``ScatterElements`` node."""

    op = 'ScatterElements'

    @pytest.mark.parametrize("params", test_data)
    @pytest.mark.nightly
    def test_scatter_elements(self, params, ie_device, precision, ir_version, temp_dir, use_old_api):
        # create_net returns the ONNX model together with its reference IR graph.
        onnx_net, ref_net = self.create_net(**params, ir_version=ir_version)
        self._test(onnx_net, ref_net, ie_device, precision, ir_version,
                   temp_dir=temp_dir, use_old_api=use_old_api)