[TF FE] Refactor RandomUniform support and provide more test coverage (#14847)
Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent
1ef17c507d
commit
36a16c8441
@ -13,27 +13,37 @@ namespace frontend {
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace op {
|
namespace op {
|
||||||
ov::OutputVector translate_random_uniform_op(const NodeContext& node) {
|
ov::OutputVector translate_random_uniform_op(const NodeContext& node) {
|
||||||
|
default_op_checks(node, 1, {"RandomUniform"});
|
||||||
auto shape = node.get_input(0);
|
auto shape = node.get_input(0);
|
||||||
|
|
||||||
|
// retrieve attributes
|
||||||
auto seed = node.get_attribute<int64_t>("seed", 0);
|
auto seed = node.get_attribute<int64_t>("seed", 0);
|
||||||
auto seed2 = node.get_attribute<int64_t>("seed2", 0);
|
auto seed2 = node.get_attribute<int64_t>("seed2", 0);
|
||||||
auto minval_const = make_shared<Constant>(element::f32, Shape{}, 0);
|
auto output_type = node.get_attribute<ov::element::Type>("dtype");
|
||||||
auto maxval_const = make_shared<Constant>(element::f32, Shape{}, 1);
|
|
||||||
auto ng_et = node.get_attribute<ov::element::Type>("dtype");
|
auto minval = make_shared<Constant>(output_type, Shape{}, 0);
|
||||||
auto res = std::make_shared<RandomUniform>(shape, minval_const, maxval_const, ng_et, seed, seed2);
|
auto maxval = make_shared<Constant>(output_type, Shape{}, 1);
|
||||||
set_node_name(node.get_name(), res);
|
auto random = std::make_shared<RandomUniform>(shape, minval, maxval, output_type, seed, seed2);
|
||||||
return res->outputs();
|
|
||||||
|
set_node_name(node.get_name(), random);
|
||||||
|
return random->outputs();
|
||||||
}
|
}
|
||||||
|
|
||||||
ov::OutputVector translate_random_uniform_int_op(const NodeContext& node) {
|
ov::OutputVector translate_random_uniform_int_op(const NodeContext& node) {
|
||||||
|
default_op_checks(node, 3, {"RandomUniformInt"});
|
||||||
auto shape = node.get_input(0);
|
auto shape = node.get_input(0);
|
||||||
auto minval = node.get_input(1);
|
auto minval = node.get_input(1);
|
||||||
auto maxval = node.get_input(2);
|
auto maxval = node.get_input(2);
|
||||||
|
|
||||||
|
// retrieve attributes
|
||||||
auto seed = node.get_attribute<int64_t>("seed", 0);
|
auto seed = node.get_attribute<int64_t>("seed", 0);
|
||||||
auto seed2 = node.get_attribute<int64_t>("seed2", 0);
|
auto seed2 = node.get_attribute<int64_t>("seed2", 0);
|
||||||
auto ng_et = minval.get_element_type();
|
|
||||||
auto res = std::make_shared<RandomUniform>(shape, minval, maxval, ng_et, seed, seed2);
|
auto output_type = minval.get_element_type();
|
||||||
set_node_name(node.get_name(), res);
|
auto random = std::make_shared<RandomUniform>(shape, minval, maxval, output_type, seed, seed2);
|
||||||
return res->outputs();
|
|
||||||
|
set_node_name(node.get_name(), random);
|
||||||
|
return random->outputs();
|
||||||
}
|
}
|
||||||
} // namespace op
|
} // namespace op
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -5,14 +5,13 @@ import pytest
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from common.layer_test_class import check_ir_version
|
from common.layer_test_class import check_ir_version
|
||||||
from common.tf_layer_test_class import CommonTFLayerTest
|
from common.tf_layer_test_class import CommonTFLayerTest
|
||||||
from common.utils.tf_utils import permute_nchw_to_nhwc
|
|
||||||
from openvino.tools.mo.front.common.partial_infer.utils import int64_array
|
|
||||||
|
|
||||||
|
from openvino.tools.mo.front.common.partial_infer.utils import int64_array
|
||||||
from unit_tests.utils.graph import build_graph, regular_op_with_shaped_data, connect, \
|
from unit_tests.utils.graph import build_graph, regular_op_with_shaped_data, connect, \
|
||||||
shaped_data, connect_front
|
shaped_data, connect_front
|
||||||
|
|
||||||
|
|
||||||
class TestTFRandomUniform(CommonTFLayerTest):
|
class TestRandomUniform(CommonTFLayerTest):
|
||||||
def create_tf_random_uniform_net(self, global_seed, op_seed, x_shape, min_val, max_val,
|
def create_tf_random_uniform_net(self, global_seed, op_seed, x_shape, min_val, max_val,
|
||||||
input_type, precision,
|
input_type, precision,
|
||||||
ir_version, use_new_frontend):
|
ir_version, use_new_frontend):
|
||||||
@ -20,23 +19,18 @@ class TestTFRandomUniform(CommonTFLayerTest):
|
|||||||
|
|
||||||
# Create the graph and model
|
# Create the graph and model
|
||||||
with tf.compat.v1.Session() as sess:
|
with tf.compat.v1.Session() as sess:
|
||||||
tf_x_shape = x_shape.copy()
|
x = tf.compat.v1.placeholder(input_type, x_shape, 'Input')
|
||||||
|
|
||||||
tf_x_shape = permute_nchw_to_nhwc(tf_x_shape, use_new_frontend)
|
|
||||||
|
|
||||||
x = tf.compat.v1.placeholder(input_type, tf_x_shape, 'Input')
|
|
||||||
if global_seed is not None:
|
if global_seed is not None:
|
||||||
tf.compat.v1.random.set_random_seed(global_seed)
|
tf.compat.v1.random.set_random_seed(global_seed)
|
||||||
random_uniform = tf.random.uniform(tf_x_shape, seed=op_seed, dtype=input_type,
|
tf.random.uniform(x_shape, seed=op_seed, dtype=input_type,
|
||||||
minval=min_val,
|
minval=min_val,
|
||||||
maxval=max_val) + x
|
maxval=max_val) + x
|
||||||
|
|
||||||
tf.compat.v1.global_variables_initializer()
|
tf.compat.v1.global_variables_initializer()
|
||||||
tf_net = sess.graph_def
|
tf_net = sess.graph_def
|
||||||
|
|
||||||
ref_net = None
|
ref_net = None
|
||||||
if check_ir_version(10, None, ir_version) and not use_new_frontend:
|
if check_ir_version(10, None, ir_version):
|
||||||
|
|
||||||
const_for_layer_tests = lambda name, value, shape, shape1: {
|
const_for_layer_tests = lambda name, value, shape, shape1: {
|
||||||
**{name + '_dd': {'kind': 'data', 'value': value, 'shape': shape1}},
|
**{name + '_dd': {'kind': 'data', 'value': value, 'shape': shape1}},
|
||||||
**{name: {'kind': 'op', 'type': 'Const'}},
|
**{name: {'kind': 'op', 'type': 'Const'}},
|
||||||
@ -83,25 +77,41 @@ class TestTFRandomUniform(CommonTFLayerTest):
|
|||||||
|
|
||||||
return tf_net, ref_net
|
return tf_net, ref_net
|
||||||
|
|
||||||
test_data = [pytest.param(
|
test_data_basic = [
|
||||||
dict(global_seed=32465, op_seed=48971, min_val=0.0, max_val=1.0, x_shape=[3, 7],
|
dict(global_seed=32465, op_seed=48971, min_val=0.0, max_val=1.0, x_shape=[3, 7],
|
||||||
input_type=tf.float32),
|
input_type=tf.float32),
|
||||||
marks=pytest.mark.precommit),
|
dict(global_seed=78132, op_seed=None, min_val=-200, max_val=-50, x_shape=[5, 8],
|
||||||
dict(global_seed=None, op_seed=56197, min_val=-100, max_val=100, x_shape=[6],
|
input_type=tf.int32)
|
||||||
input_type=tf.float32),
|
]
|
||||||
dict(global_seed=None, op_seed=56197, min_val=-100, max_val=100, x_shape=[1, 2, 1, 1],
|
|
||||||
input_type=tf.float32),
|
|
||||||
pytest.param(dict(global_seed=78132, op_seed=None, min_val=-200, max_val=-50, x_shape=[5, 8],
|
|
||||||
input_type=tf.int32), marks=pytest.mark.precommit_tf_fe),
|
|
||||||
dict(global_seed=4571, op_seed=48971, min_val=1.5, max_val=2.3, x_shape=[7],
|
|
||||||
input_type=tf.float32),
|
|
||||||
dict(global_seed=32465, op_seed=12335, min_val=-150, max_val=-100, x_shape=[18],
|
|
||||||
input_type=tf.int32)]
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("params", test_data)
|
@pytest.mark.parametrize("params", test_data_basic)
|
||||||
@pytest.mark.nightly
|
@pytest.mark.nightly
|
||||||
def test_tf_random_uniform(self, params, ie_device, precision, ir_version, temp_dir,
|
@pytest.mark.precommit
|
||||||
use_new_frontend, use_old_api):
|
@pytest.mark.precommit_tf_fe
|
||||||
|
def test_random_uniform_basic(self, params, ie_device, precision, ir_version, temp_dir,
|
||||||
|
use_new_frontend, use_old_api):
|
||||||
|
if ie_device == 'GPU':
|
||||||
|
pytest.skip("RandomUniform is not supported on GPU")
|
||||||
|
self._test(
|
||||||
|
*self.create_tf_random_uniform_net(**params, precision=precision, ir_version=ir_version,
|
||||||
|
use_new_frontend=use_new_frontend), ie_device,
|
||||||
|
precision, temp_dir=temp_dir, ir_version=ir_version, use_new_frontend=use_new_frontend,
|
||||||
|
use_old_api=use_old_api, **params)
|
||||||
|
|
||||||
|
test_data_other = [
|
||||||
|
dict(global_seed=None, op_seed=56197, min_val=-100, max_val=100, x_shape=[6],
|
||||||
|
input_type=tf.float32),
|
||||||
|
dict(global_seed=None, op_seed=56197, min_val=-100, max_val=100, x_shape=[1, 2, 1, 1],
|
||||||
|
input_type=tf.float32),
|
||||||
|
dict(global_seed=4571, op_seed=48971, min_val=1.5, max_val=2.3, x_shape=[7],
|
||||||
|
input_type=tf.float32),
|
||||||
|
dict(global_seed=32465, op_seed=12335, min_val=-150, max_val=-100, x_shape=[18],
|
||||||
|
input_type=tf.int32)]
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("params", test_data_other)
|
||||||
|
@pytest.mark.nightly
|
||||||
|
def test_random_uniform_other(self, params, ie_device, precision, ir_version, temp_dir,
|
||||||
|
use_new_frontend, use_old_api):
|
||||||
if ie_device == 'GPU':
|
if ie_device == 'GPU':
|
||||||
pytest.skip("RandomUniform is not supported on GPU")
|
pytest.skip("RandomUniform is not supported on GPU")
|
||||||
self._test(
|
self._test(
|
||||||
|
Loading…
Reference in New Issue
Block a user