Files
openvino/tests/layer_tests/tensorflow_tests/test_tf_Select.py
Roman Kazantsev af6ed211d6 [TF FE] Support TF2 Object Detection models (#14979)
* [TF FE] Support TF2 Object detection models

To support out-of-the-box (OOB) conversion of object detection (OD) models (Faster R-CNN, SSD models), several fixes were made
to the Select, BroadcastArgs, Slice, and Concat operations.
Tests are implemented for each case.

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Switch off Transpose Sinking that breaks some model conversion

* Apply code-review feedback: copyright and extra commented out code

* Mention that this is a workaround for Concat

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
2023-01-09 17:36:42 +03:00

55 lines
2.3 KiB
Python

# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pytest
import tensorflow as tf
from common.tf_layer_test_class import CommonTFLayerTest
class TestSelect(CommonTFLayerTest):
    """Layer tests covering conversion of the TensorFlow ``Select`` operation."""

    def _prepare_input(self, inputs_info):
        """Generate random feed data for the ``cond``, ``x`` and ``y`` inputs.

        ``inputs_info`` must contain the keys ``cond``, ``x`` and ``y``; each
        value is used as the shape of the generated array.  The condition is a
        random boolean array, while ``x`` and ``y`` hold integer-valued
        float32 numbers drawn from [-100, 100).
        """
        required_inputs = (
            ('cond', "Test error: inputs_info must contain `cond`"),
            ('x', "Test error: inputs_info must contain `x`"),
            ('y', "Test error: inputs_info must contain `y`"),
        )
        for input_name, error_message in required_inputs:
            assert input_name in inputs_info, error_message

        # Look up every shape first, then draw the data in cond, x, y order.
        cond_dims, x_dims, y_dims = (inputs_info[name] for name in ('cond', 'x', 'y'))

        def random_floats(dims):
            return np.random.randint(-100, 100, dims).astype(np.float32)

        return {
            'cond': np.random.randint(0, 2, cond_dims).astype(bool),
            'x': random_floats(x_dims),
            'y': random_floats(y_dims),
        }

    def create_select_net(self, cond_shape, x_shape, y_shape):
        """Build a TF1-style graph holding a single ``Select`` node.

        Three placeholders named ``cond`` (bool), ``x`` and ``y`` (float32)
        are created with the given shapes and fed to ``tf.raw_ops.Select``.

        Returns a ``(graph_def, None)`` tuple; the second item is always
        ``None`` (presumably the reference-network slot consumed by
        ``_test`` -- confirm against ``CommonTFLayerTest``).
        """
        tf1 = tf.compat.v1
        tf1.reset_default_graph()

        # Create the graph and model
        with tf1.Session() as sess:
            cond = tf1.placeholder(dtype=tf.bool, shape=cond_shape, name='cond')
            x = tf1.placeholder(dtype=tf.float32, shape=x_shape, name='x')
            y = tf1.placeholder(dtype=tf.float32, shape=y_shape, name='y')
            tf.raw_ops.Select(condition=cond, x=x, y=y, name='select')
            tf1.global_variables_initializer()
            graph_def = sess.graph_def

        return graph_def, None

    # Shape combinations: all-scalar inputs, scalar condition with tensor
    # branches, 1D condition sized like the first dimension of the branches,
    # and condition with the same shape as the branches.
    test_data_basic = [
        {'cond_shape': [], 'x_shape': [], 'y_shape': []},
        {'cond_shape': [], 'x_shape': [3, 2, 4], 'y_shape': [3, 2, 4]},
        {'cond_shape': [2], 'x_shape': [2, 4, 5], 'y_shape': [2, 4, 5]},
        {'cond_shape': [2, 3, 4], 'x_shape': [2, 3, 4], 'y_shape': [2, 3, 4]},
    ]

    @pytest.mark.parametrize("params", test_data_basic)
    @pytest.mark.precommit_tf_fe
    @pytest.mark.nightly
    def test_select_basic(self, params, ie_device, precision, ir_version, temp_dir,
                          use_new_frontend, use_old_api):
        """Run the ``Select`` graph through ``_test``; skipped for the legacy frontend."""
        if not use_new_frontend:
            pytest.skip("Select tests are not passing for the legacy frontend.")

        net_and_ref = self.create_select_net(**params)
        self._test(
            *net_and_ref,
            ie_device,
            precision,
            ir_version,
            temp_dir=temp_dir,
            use_new_frontend=use_new_frontend,
            use_old_api=use_old_api,
        )