[TF FE] Complex type support extended for Separate Bass model. (#21477)

* Extended complex type support and fixed several errors.
* Corrected tests.
* Fixed FloorDiv and TensorListConcatV2.
* Added a FloorDiv test.
* Corrected imports.
Anastasiia Pnevskaia 2023-12-21 12:17:12 +01:00 committed by GitHub
parent 71fca88d4b
commit 12faade22f
19 changed files with 710 additions and 21 deletions
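Context for the diffs below: the TF frontend carries complex tensors through the converted graph as real tensors with one extra trailing dimension of size 2 that holds the real and imaginary parts, tagged by the ComplexTypeMark helper node. A minimal NumPy sketch of that packed layout (illustrative only, not OpenVINO API):

import numpy as np

# A complex tensor of logical shape (2, 3)...
z = np.arange(6).reshape(2, 3) + 1j * np.ones((2, 3))

# ...is propagated as a real tensor of shape (2, 3, 2), with real and
# imaginary parts packed into the trailing dimension.
packed = np.stack([z.real, z.imag], axis=-1)
assert packed.shape == (2, 3, 2)

# Unpacking the trailing pair recovers the original complex values.
restored = packed[..., 0] + 1j * packed[..., 1]
assert np.array_equal(restored, z)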

View File

@ -108,10 +108,15 @@ std::vector<TRShape> shape_infer(const StridedSlice* op,
AxisSet begin_mask = convert_mask_to_axis_set(op->get_begin_mask());
AxisSet end_mask = convert_mask_to_axis_set(op->get_end_mask());
AxisSet shrink_axis_mask = convert_mask_to_axis_set(op->get_shrink_axis_mask());
NODE_VALIDATION_CHECK(op,
input_rank + new_axis_mask.size() >= static_cast<size_t>(number_axes),
"Input rank plus number of new axis has to be at least the size of Lower "
"and Upper bounds vector.");
// If ellipsis_mask is set, the Lower and Upper bound vectors can be shorter than input rank + number of new axes,
// because the ellipsis stands in for the dimensions omitted from the begin and end inputs
if (!ellipsis_mask.size()) {
NODE_VALIDATION_CHECK(op,
input_rank + new_axis_mask.size() >= static_cast<size_t>(number_axes),
"Input rank plus number of new axis has to be at least the size of Lower "
"and Upper bounds vector.");
}
auto& out = output_shapes.front();
out.resize(0);
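The relaxed check mirrors NumPy/TensorFlow slicing semantics: an ellipsis expands to any number of omitted axes, so the begin and end vectors may legitimately be shorter than input rank plus the number of new axes. A small NumPy illustration (the C++ shape inference itself is above):

import numpy as np

x = np.zeros((3, 4, 5, 7))
# begin/end describe only 2 of the 4 axes; the ellipsis stands in for the
# omitted middle axes, so the bounds vectors are shorter than the rank.
assert x[1:2, ..., 0:3].shape == (1, 4, 5, 3)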

View File

@ -117,7 +117,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"LogicalXor", CreatorFunction(translate_binary_op<opset8::LogicalXor>)},
{"Maximum", CreatorFunction(translate_binary_op<opset8::Maximum>)},
{"Minimum", CreatorFunction(translate_binary_op<opset8::Minimum>)},
{"Mul", CreatorFunction(translate_binary_op<opset8::Multiply>)},
{"Mul", CreatorFunction(translate_mul_op)},
{"Mod", CreatorFunction(translate_binary_op<opset8::Mod>)},
{"NotEqual", CreatorFunction(translate_binary_op<opset8::NotEqual>)},
{"Pow", CreatorFunction(translate_binary_op<opset8::Power>)},
@ -321,6 +321,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"TensorListStack", CreatorFunction(translate_tensor_list_stack_op)},
{"TensorListReserve", CreatorFunction(translate_tensor_list_reserve_op)},
{"TensorListResize", CreatorFunction(translate_tensor_list_resize_op)},
{"TensorListConcatV2", CreatorFunction(translate_tensor_list_concat_v2_op)},
{"Tile", CreatorFunction(translate_tile_op)},
{"ToBool", CreatorFunction(translate_tobool_op)},
{"TopK", CreatorFunction(translate_top_k_op)},

View File

@ -62,6 +62,7 @@ OP_CONVERTER(translate_crop_and_resize_op);
OP_CONVERTER(translate_depth_to_space_op);
OP_CONVERTER(translate_depthwise_conv_2d_native_op);
OP_CONVERTER(translate_div_no_nan_op);
OP_CONVERTER(translate_mul_op);
OP_CONVERTER(translate_dynamic_partition_op);
OP_CONVERTER(translate_einsum_op);
OP_CONVERTER(translate_elu_op);
@ -154,6 +155,7 @@ OP_CONVERTER(translate_tensor_list_reserve_op);
OP_CONVERTER(translate_tensor_list_set_item_op);
OP_CONVERTER(translate_tensor_list_stack_op);
OP_CONVERTER(translate_tensor_list_resize_op);
OP_CONVERTER(translate_tensor_list_concat_v2_op);
OP_CONVERTER(translate_tile_op);
OP_CONVERTER(translate_tobool_op);
OP_CONVERTER_NAMED(translate_top_k_op);

View File

@ -3,10 +3,13 @@
//
#include "common_op_table.hpp"
#include "helper_ops/complex_type_mark.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/bitwise_and.hpp"
#include "openvino/op/bitwise_or.hpp"
#include "openvino/op/bitwise_xor.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/equal.hpp"
#include "openvino/op/floor.hpp"
@ -27,6 +30,7 @@
#include "openvino/op/prelu.hpp"
#include "openvino/op/squared_difference.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/op/unsqueeze.hpp"
using namespace std;
using namespace ov::op;
@ -47,8 +51,16 @@ OutputVector translate_binary_op(const NodeContext& node,
}
OutputVector translate_floor_div_op(const NodeContext& node) {
auto floordiv_fn = [](const Output<Node>& x, const Output<Node>& y) {
return make_shared<v0::Floor>(make_shared<v1::Divide>(x, y));
auto floordiv_fn = [](const Output<Node>& x, const Output<Node>& y) -> shared_ptr<Node> {
auto out_type = x.get_element_type();
if (out_type.is_integral()) {
auto float_x = make_shared<v0::Convert>(x, element::f32);
auto float_y = make_shared<v0::Convert>(y, element::f32);
return make_shared<v0::Convert>(make_shared<v0::Floor>(make_shared<v1::Divide>(float_x, float_y)),
out_type);
} else {
return make_shared<v0::Floor>(make_shared<v1::Divide>(x, y));
}
};
return translate_binary_op(node, floordiv_fn);
}
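The integer branch above works around truncating integer division: TensorFlow's FloorDiv rounds toward negative infinity, so the translator divides in float and applies Floor before converting back to the integer type. A quick NumPy check of the semantics being matched (assuming a plain integer divide truncates toward zero):

import numpy as np

x, y = np.int32(-7), np.int32(2)
# Truncating division, what a naive integer divide would produce:
assert int(x / y) == -3                  # trunc(-3.5) -> -3
# FloorDiv semantics expected by TensorFlow:
assert np.floor_divide(x, y) == -4
# The translator's workaround, expressed in NumPy:
assert np.int32(np.floor(np.float32(x) / np.float32(y))) == -4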
@ -60,6 +72,56 @@ OutputVector translate_binary_op(const NodeContext& node) {
});
}
OutputVector translate_mul_op(const NodeContext& node) {
default_op_checks(node, 2, {}, true);
auto lhs = node.get_input(0);
auto rhs = node.get_input(1);
auto complex_type_mark_lhs = as_type_ptr<ComplexTypeMark>(lhs.get_node_shared_ptr());
auto complex_type_mark_rhs = as_type_ptr<ComplexTypeMark>(rhs.get_node_shared_ptr());
if (complex_type_mark_lhs || complex_type_mark_rhs) {
FRONT_END_GENERAL_CHECK(complex_type_mark_lhs != nullptr && complex_type_mark_rhs != nullptr,
"Mul gox complex and non-complex inputs. Inputs should be of same type.");
lhs = complex_type_mark_lhs->input_value(0);
rhs = complex_type_mark_rhs->input_value(0);
element::Type complex_part_type_lhs = complex_type_mark_lhs->get_complex_part_type();
element::Type complex_part_type_rhs = complex_type_mark_rhs->get_complex_part_type();
FRONT_END_GENERAL_CHECK(complex_part_type_lhs == complex_part_type_rhs,
"Mul got complex inputs of different types. Inputs should be of same type.");
auto gather_index_real = make_shared<v0::Constant>(element::i32, Shape{}, 0);
auto gather_index_imag = make_shared<v0::Constant>(element::i32, Shape{}, 1);
auto minus_one = make_shared<v0::Constant>(element::i32, Shape{1}, -1);
auto lhs_real = make_shared<v8::Gather>(lhs, gather_index_real, minus_one)->output(0);
auto lhs_imag = make_shared<v8::Gather>(lhs, gather_index_imag, minus_one)->output(0);
auto rhs_real = make_shared<v8::Gather>(rhs, gather_index_real, minus_one)->output(0);
auto rhs_imag = make_shared<v8::Gather>(rhs, gather_index_imag, minus_one)->output(0);
// result_real = lhs_real * rhs_real - lhs_imag * rhs_imag
auto result_real = make_shared<v1::Subtract>(make_shared<v1::Multiply>(lhs_real, rhs_real),
make_shared<v1::Multiply>(lhs_imag, rhs_imag));
// result_imag = lhs_real * rhs_imag + lhs_imag * rhs_real
auto result_imag = make_shared<v1::Add>(make_shared<v1::Multiply>(lhs_real, rhs_imag),
make_shared<v1::Multiply>(lhs_imag, rhs_real));
auto real_unsqueeze = make_shared<v0::Unsqueeze>(result_real, minus_one);
auto imag_unsqueeze = make_shared<v0::Unsqueeze>(result_imag, minus_one);
auto concat_result = make_shared<v0::Concat>(OutputVector{real_unsqueeze, imag_unsqueeze}, -1);
set_node_name(node.get_name(), concat_result);
auto complex_result = make_shared<ComplexTypeMark>(concat_result->output(0), complex_part_type_lhs);
return {complex_result};
}
auto result = make_shared<v1::Multiply>(lhs, rhs);
set_node_name(node.get_name(), result);
return {result};
}
template OutputVector translate_binary_op<v1::Add>(const NodeContext& node);
template OutputVector translate_binary_op<v13::BitwiseAnd>(const NodeContext& node);
template OutputVector translate_binary_op<v13::BitwiseOr>(const NodeContext& node);
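translate_mul_op implements (a + bi)(c + di) = (ac - bd) + (ad + bc)i on the packed layout. A NumPy sanity check of that decomposition (illustrative sketch):

import numpy as np

rng = np.random.default_rng(0)
a, b = rng.random((2, 3)), rng.random((2, 3))   # lhs real/imag parts
c, d = rng.random((2, 3)), rng.random((2, 3))   # rhs real/imag parts

# What the translator builds from Gather/Multiply/Subtract/Add:
result_real = a * c - b * d
result_imag = a * d + b * c
packed = np.stack([result_real, result_imag], axis=-1)

# Reference: native complex multiplication.
ref = (a + 1j * b) * (c + 1j * d)
assert np.allclose(packed[..., 0], ref.real)
assert np.allclose(packed[..., 1], ref.imag)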

View File

@ -24,7 +24,7 @@ OutputVector translate_identity_op(const NodeContext& node) {
"MergeV2Checkpoints",
// TF Lite nodes
"DENSIFY"};
default_op_checks(node, 1, supported_ops);
default_op_checks(node, 1, supported_ops, true);
auto input = node.get_input(0);
// set only tensor names

View File

@ -5,6 +5,9 @@
#include "openvino/op/reshape.hpp"
#include "common_op_table.hpp"
#include "helper_ops/complex_type_mark.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
using namespace std;
using namespace ov::op;
@ -15,9 +18,25 @@ namespace tensorflow {
namespace op {
OutputVector translate_reshape_op(const NodeContext& node) {
default_op_checks(node, 2, {"Reshape"});
default_op_checks(node, 2, {"Reshape"}, true);
auto tensor = node.get_input(0);
auto complex_type_mark = as_type_ptr<ComplexTypeMark>(tensor.get_node_shared_ptr());
auto shape = node.get_input(1);
if (complex_type_mark) {
element::Type complex_part_type = complex_type_mark->get_complex_part_type();
tensor = complex_type_mark->input_value(0);
OutputVector concat_inputs;
concat_inputs.push_back(shape);
concat_inputs.push_back(make_shared<v0::Constant>(shape.get_element_type(), Shape{1}, 2));
auto concat = make_shared<v0::Concat>(concat_inputs, 0);
auto reshape = make_shared<v1::Reshape>(tensor, concat, false);
set_node_name(node.get_name(), reshape);
auto complex_reshape = make_shared<ComplexTypeMark>(reshape, complex_part_type);
return {complex_reshape->output(0)};
}
auto reshape = make_shared<v1::Reshape>(tensor, shape, false);
set_node_name(node.get_name(), reshape);
return {reshape};
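Because the complex tensor travels with the trailing real/imag dimension, the translator concatenates a constant 2 onto the requested target shape before building the Reshape. The NumPy analogue of that adjustment (a sketch of the intent, not the OV graph):

import numpy as np

packed = np.zeros((2, 6, 2))             # complex (2, 6) in packed form
target = [2, 3, 2]                       # user-requested complex shape
reshaped = packed.reshape(target + [2])  # append 2 for the real/imag pair
assert reshaped.shape == (2, 3, 2, 2)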

View File

@ -3,7 +3,9 @@
//
#include "common_op_table.hpp"
#include "helper_ops/complex_type_mark.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"
using namespace std;
using namespace ov;
@ -14,27 +16,58 @@ namespace frontend {
namespace tensorflow {
namespace op {
std::shared_ptr<v8::Slice> compute_complex_shape(const ov::Output<ov::Node>& input, element::Type out_type) {
auto shapeof = make_shared<v3::ShapeOf>(input, out_type);
auto rank = make_shared<v3::ShapeOf>(shapeof, out_type);
auto one = make_shared<v0::Constant>(element::i32, Shape{1}, 1);
auto start = make_shared<v0::Constant>(element::i32, Shape{1}, 0);
auto stop = make_shared<v1::Subtract>(rank, one);
auto step = make_shared<v0::Constant>(element::i32, Shape{1}, 1);
auto axes = make_shared<v0::Constant>(element::i32, Shape{1}, 0);
return make_shared<v8::Slice>(shapeof, start, stop, step, axes);
}
OutputVector translate_shape_op(const NodeContext& node) {
default_op_checks(node, 1, {"Shape", "ShapeN", "SHAPE"});
default_op_checks(node, 1, {"Shape", "ShapeN", "SHAPE"}, true);
auto input_size = static_cast<int>(node.get_input_size());
auto out_type = node.get_attribute<element::Type>("out_type", element::i32);
auto node_name = node.get_name();
if (input_size == 1) {
auto input = node.get_input(0);
auto shapeof = make_shared<v3::ShapeOf>(input, out_type);
set_node_name(node_name, shapeof);
return {shapeof};
auto complex_type_mark = as_type_ptr<ComplexTypeMark>(input.get_node_shared_ptr());
if (complex_type_mark) {
auto slice = compute_complex_shape(complex_type_mark->input_value(0), out_type);
set_node_name(node_name, slice);
return {slice};
} else {
auto shapeof = make_shared<v3::ShapeOf>(input, out_type);
set_node_name(node_name, shapeof);
return {shapeof};
}
}
OutputVector outputs;
for (int input_ind = 0; input_ind < input_size; ++input_ind) {
auto input = node.get_input(input_ind);
auto shapeof = make_shared<v3::ShapeOf>(input, out_type);
shapeof->set_friendly_name(node_name + "_" + to_string(input_ind));
auto shapeof_output = shapeof->output(0);
set_out_name({node_name + ":" + to_string(input_ind)}, shapeof_output);
outputs.push_back(shapeof_output);
auto complex_type_mark = as_type_ptr<ComplexTypeMark>(input.get_node_shared_ptr());
if (complex_type_mark) {
auto slice = compute_complex_shape(complex_type_mark->input_value(0), out_type);
slice->set_friendly_name(node_name + "_" + to_string(input_ind));
auto shapeof_output = slice->output(0);
set_out_name({node_name + ":" + to_string(input_ind)}, shapeof_output);
outputs.push_back(shapeof_output);
} else {
auto shapeof = make_shared<v3::ShapeOf>(input, out_type);
shapeof->set_friendly_name(node_name + "_" + to_string(input_ind));
auto shapeof_output = shapeof->output(0);
set_out_name({node_name + ":" + to_string(input_ind)}, shapeof_output);
outputs.push_back(shapeof_output);
}
}
return outputs;
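compute_complex_shape reports the logical complex shape by slicing the trailing real/imag entry off the ShapeOf result, i.e. it takes shape[0:rank-1]. A NumPy analogue (sketch):

import numpy as np

packed = np.zeros((2, 4, 3, 2))          # complex (2, 4, 3) in packed form
shape = np.array(packed.shape, dtype=np.int32)
logical_shape = shape[0:len(shape) - 1]  # Slice(start=0, stop=rank-1, step=1)
assert list(logical_shape) == [2, 4, 3]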

View File

@ -5,7 +5,11 @@
#include "openvino/op/squeeze.hpp"
#include "common_op_table.hpp"
#include "helper_ops/complex_type_mark.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/floor_mod.hpp"
#include "openvino/op/subtract.hpp"
#include "utils.hpp"
using namespace std;
using namespace ov::op;
@ -16,8 +20,10 @@ namespace tensorflow {
namespace op {
OutputVector translate_squeeze_op(const NodeContext& node) {
default_op_checks(node, 1, {"Squeeze", "SQUEEZE"});
default_op_checks(node, 1, {"Squeeze", "SQUEEZE"}, true);
auto input = node.get_input(0);
auto complex_type_mark = as_type_ptr<ComplexTypeMark>(input.get_node_shared_ptr());
std::vector<int64_t> axis;
if (node.has_attribute("axis")) {
axis = node.get_attribute<std::vector<int64_t>>("axis", {});
@ -26,6 +32,24 @@ OutputVector translate_squeeze_op(const NodeContext& node) {
axis = node.get_attribute<std::vector<int64_t>>("squeeze_dims", {});
}
auto axis_const = make_shared<v0::Constant>(element::i32, Shape{axis.size()}, axis);
if (complex_type_mark) {
element::Type complex_part_type = complex_type_mark->get_complex_part_type();
input = complex_type_mark->input_value(0);
auto input_rank = compute_subgraph_scalar_rank(input, element::i32, true);
auto const_one = make_shared<v0::Constant>(element::i32, Shape{}, 1);
auto input_rank_minus_one = make_shared<v1::Subtract>(input_rank, const_one)->output(0);
// adjust the axes to make them non-negative
auto axis_complex = make_shared<v1::FloorMod>(axis_const, input_rank_minus_one);
auto squeeze = make_shared<v0::Squeeze>(input, axis_complex);
set_node_name(node.get_name(), squeeze);
auto squeeze_complex = make_shared<ComplexTypeMark>(squeeze, complex_part_type);
return {squeeze_complex->output(0)};
}
auto squeeze = make_shared<v0::Squeeze>(input, axis_const);
set_node_name(node.get_name(), squeeze);
return {squeeze};
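The FloorMod against input_rank - 1 normalizes negative axes using the logical (complex) rank, which is one below the packed tensor's rank, so axis -1 resolves to the last logical dimension rather than the trailing real/imag pair. A worked NumPy check (sketch; Python's % has the same floor-modulo semantics as FloorMod):

import numpy as np

packed = np.zeros((2, 3, 1, 2))   # complex (2, 3, 1), packed rank 4
axis = -1                         # last *logical* axis
logical_rank = packed.ndim - 1    # input_rank - 1 in the translator
resolved = axis % logical_rank    # FloorMod(-1, 3) -> 2
assert resolved == 2
assert np.squeeze(packed, axis=resolved).shape == (2, 3, 2)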

View File

@ -7,6 +7,9 @@
#include <climits>
#include "common_op_table.hpp"
#include "helper_ops/complex_type_mark.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
using namespace std;
using namespace ov::op;
@ -17,7 +20,7 @@ namespace tensorflow {
namespace op {
OutputVector translate_strided_slice_op(const NodeContext& node) {
default_op_checks(node, 4, {"StridedSlice", "STRIDED_SLICE"});
default_op_checks(node, 4, {"StridedSlice", "STRIDED_SLICE"}, true);
auto input = node.get_input(0);
auto begin = node.get_input(1);
auto end = node.get_input(2);
@ -50,12 +53,41 @@ OutputVector translate_strided_slice_op(const NodeContext& node) {
// the masks can be of different lengths and we need to align them to the maximum length
size_t max_length = std::max(
{begin_mask.size(), end_mask.size(), new_axis_mask.size(), ellipsis_mask.size(), shrink_axis_mask.size()});
auto complex_type_mark = as_type_ptr<ComplexTypeMark>(input.get_node_shared_ptr());
element::Type complex_part_type = element::dynamic;
std::vector<int64_t> begin_axes;
if (complex_type_mark) {
complex_part_type = complex_type_mark->get_complex_part_type();
input = complex_type_mark->input_value(0);
TENSORFLOW_OP_VALIDATION(node,
as_type_ptr<v0::Constant>(node.get_input(1).get_node_shared_ptr()),
"StridedSlice for complex values is not supported with non-constant begin");
get_const_input(node, 1, &begin_axes);
max_length = std::max(begin_axes.size() + 1, max_length);
}
begin_mask.resize(max_length, 0);
end_mask.resize(max_length, 0);
new_axis_mask.resize(max_length, 0);
ellipsis_mask.resize(max_length, 0);
shrink_axis_mask.resize(max_length, 0);
if (complex_type_mark) {
auto zero = make_shared<v0::Constant>(element::i32, Shape{1}, 0);
auto one = make_shared<v0::Constant>(element::i32, Shape{1}, 1);
begin = make_shared<v0::Concat>(OutputVector{begin, zero}, 0);
end = make_shared<v0::Concat>(OutputVector{end, zero}, 0);
strides = make_shared<v0::Concat>(OutputVector{strides, one}, 0);
begin_mask[begin_axes.size()] = 1;
end_mask[begin_axes.size()] = 1;
new_axis_mask[begin_axes.size()] = 0;
ellipsis_mask[begin_axes.size()] = 0;
shrink_axis_mask[begin_axes.size()] = 0;
}
auto strided_slice = make_shared<v1::StridedSlice>(input,
begin,
end,
@ -66,6 +98,12 @@ OutputVector translate_strided_slice_op(const NodeContext& node) {
shrink_axis_mask,
ellipsis_mask);
set_node_name(node.get_name(), strided_slice);
if (complex_type_mark) {
auto complex_strided_slice = make_shared<ComplexTypeMark>(strided_slice, complex_part_type);
return {complex_strided_slice->output(0)};
}
return {strided_slice};
}
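For complex inputs the slice specification is extended by one position that keeps the trailing real/imag dimension whole: begin and end get a 0, strides gets a 1, and the begin/end mask bits are set at that position (a set mask bit tells StridedSlice to ignore the bound and take the full extent). Roughly, in NumPy terms (sketch):

import numpy as np

packed = np.zeros((2, 5, 4, 3, 2))   # complex (2, 5, 4, 3) in packed form
# User slice on the complex tensor: [1:2, 0:5:2, 2:4, 0:2].
# The translator appends a full-range slice on the last axis:
out = packed[1:2, 0:5:2, 2:4, 0:2, :]
assert out.shape == (1, 3, 2, 2, 2)
assert out.shape[-1] == 2            # real/imag pair preserved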

View File

@ -2,6 +2,8 @@
// SPDX-License-Identifier: Apache-2.0
//
#include <climits>
#include "common_op_table.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/broadcast.hpp"
@ -222,6 +224,36 @@ OutputVector translate_tensor_list_length_op(const NodeContext& node) {
return {list_length};
}
OutputVector translate_tensor_list_concat_v2_op(const NodeContext& node) {
default_op_checks(node, 2, {"TensorListConcatV2"});
auto input_handle = node.get_input(0);
auto size = node.get_input(1);
std::vector<int64_t> leading_dims;
get_const_input(node, 2, &leading_dims);
TENSORFLOW_OP_VALIDATION(node,
leading_dims.size() == 0,
"TensorListConcatV2 is not supported for non-empty leading_dims.");
TENSORFLOW_OP_VALIDATION(node,
as_type_ptr<v0::Constant>(node.get_input(1).get_node_shared_ptr()),
"TensorListConcatV2 is not supported with non-constant shape input");
std::vector<int64_t> list_element_shape;
get_const_input(node, 1, &list_element_shape);
list_element_shape[0] = list_element_shape[0] * input_handle.get_partial_shape()[0].get_max_length();
auto out = make_shared<v1::Reshape>(
input_handle,
make_shared<v0::Constant>(element::i64, Shape{list_element_shape.size()}, list_element_shape),
false);
set_node_name(node.get_name(), out);
return {out};
}
} // namespace op
} // namespace tensorflow
} // namespace frontend
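With empty leading_dims and a constant element shape, TensorListConcatV2 reduces to folding the list dimension into the first element dimension, assuming the list is backed by a tensor stacked along axis 0 (as TensorListFromTensor produces here); that is why a single Reshape suffices. A NumPy sketch of the equivalence:

import numpy as np

elements = [np.full((2, 4), i) for i in range(3)]  # list of 3 tensors [2, 4]
stacked = np.stack(elements)                       # backing tensor [3, 2, 4]

# Concatenating along axis 0 equals reshaping the stacked backing tensor.
ref = np.concatenate(elements, axis=0)             # shape [6, 4]
out = stacked.reshape(3 * 2, 4)
assert np.array_equal(out, ref)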

View File

@ -5,6 +5,11 @@
#include "openvino/op/transpose.hpp"
#include "common_op_table.hpp"
#include "helper_ops/complex_type_mark.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/subtract.hpp"
#include "utils.hpp"
using namespace std;
using namespace ov::op;
@ -15,9 +20,30 @@ namespace tensorflow {
namespace op {
OutputVector translate_transpose_op(const NodeContext& node) {
default_op_checks(node, 2, {"Transpose", "TRANSPOSE"});
default_op_checks(node, 2, {"Transpose", "TRANSPOSE"}, true);
auto x = node.get_input(0);
auto perm = node.get_input(1);
auto complex_type_mark = as_type_ptr<ComplexTypeMark>(x.get_node_shared_ptr());
if (complex_type_mark) {
element::Type complex_part_type = complex_type_mark->get_complex_part_type();
x = complex_type_mark->input_value(0);
auto input_rank = compute_subgraph_scalar_rank(x, element::i32, false);
auto const_one = make_shared<v0::Constant>(element::i32, Shape{1}, 1);
auto input_rank_minus_one = make_shared<v1::Subtract>(input_rank, const_one)->output(0);
OutputVector concat_inputs;
concat_inputs.push_back(perm);
concat_inputs.push_back(input_rank_minus_one);
auto concat = make_shared<v0::Concat>(concat_inputs, 0);
auto transpose = make_shared<v1::Transpose>(x, concat);
set_node_name(node.get_name(), transpose);
auto complex_transpose = make_shared<ComplexTypeMark>(transpose, complex_part_type);
return {complex_transpose->output(0)};
}
auto transpose = make_shared<v1::Transpose>(x, perm);
set_node_name(node.get_name(), transpose);
return {transpose};
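To leave the real/imag pair in place, the user's permutation is extended with rank - 1, the index of the packed tensor's last dimension. The NumPy equivalent (sketch):

import numpy as np

packed = np.zeros((2, 1, 3, 4, 2))    # complex (2, 1, 3, 4) in packed form
perm = [2, 0, 1, 3]                   # user permutation on the complex shape
extended = perm + [packed.ndim - 1]   # append the index of the trailing pair
assert packed.transpose(extended).shape == (3, 2, 1, 4, 2)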

View File

@ -0,0 +1,51 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pytest
from common.tf_layer_test_class import CommonTFLayerTest
class TestFloorDiv(CommonTFLayerTest):
def create_floor_div_net(self, x_shape, dtype, ir_version, use_new_frontend):
import tensorflow as tf
tf.compat.v1.reset_default_graph()
# Create the graph and model
with tf.compat.v1.Session() as sess:
x = tf.compat.v1.placeholder(dtype, x_shape, 'Input')
constant_value = np.array(-10).astype(dtype)
y = tf.constant(constant_value)
x = tf.raw_ops.Abs(x=x)
res = tf.raw_ops.FloorDiv(x=x, y=y)
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def
ref_net = None
return tf_net, ref_net
test_data = [
dict(x_shape=[], dtype=np.int32),
dict(x_shape=[2], dtype=np.int64),
dict(x_shape=[2, 4, 5], dtype=np.int32),
dict(x_shape=[], dtype=np.float32),
dict(x_shape=[2], dtype=np.float64),
dict(x_shape=[2, 4, 5], dtype=np.float32),
]
@pytest.mark.parametrize("params", test_data_1D)
@pytest.mark.nightly
@pytest.mark.precommit_tf_fe
def test_floor_div(self, params, ie_device, precision, ir_version, temp_dir,
use_new_frontend, use_old_api):
self._test(*self.create_floor_div_net(**params, ir_version=ir_version,
use_new_frontend=use_new_frontend),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api)

View File

@ -247,3 +247,58 @@ class TestMul(CommonTFLayerTest):
use_new_frontend=use_new_frontend),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
class TestComplexMul(CommonTFLayerTest):
def _prepare_input(self, inputs_info):
rng = np.random.default_rng()
assert 'param_real_1' in inputs_info
assert 'param_imag_1' in inputs_info
assert 'param_real_2' in inputs_info
assert 'param_imag_2' in inputs_info
param_real_shape_1 = inputs_info['param_real_1']
param_imag_shape_1 = inputs_info['param_imag_1']
param_real_shape_2 = inputs_info['param_real_2']
param_imag_shape_2 = inputs_info['param_imag_2']
inputs_data = {}
inputs_data['param_real_1'] = 4 * rng.random(param_real_shape_1).astype(np.float32) - 2
inputs_data['param_imag_1'] = 4 * rng.random(param_imag_shape_1).astype(np.float32) - 2
inputs_data['param_real_2'] = 4 * rng.random(param_real_shape_2).astype(np.float32) - 2
inputs_data['param_imag_2'] = 4 * rng.random(param_imag_shape_2).astype(np.float32) - 2
return inputs_data
def create_complex_mul_net(self, input_shape):
import tensorflow as tf
tf.compat.v1.reset_default_graph()
# Create the graph and model
with tf.compat.v1.Session() as sess:
param_real1 = tf.compat.v1.placeholder(np.float32, input_shape, 'param_real_1')
param_imag1 = tf.compat.v1.placeholder(np.float32, input_shape, 'param_imag_1')
param_real2 = tf.compat.v1.placeholder(np.float32, input_shape, 'param_real_2')
param_imag2 = tf.compat.v1.placeholder(np.float32, input_shape, 'param_imag_2')
complex1 = tf.raw_ops.Complex(real=param_real1, imag=param_imag1)
complex2 = tf.raw_ops.Complex(real=param_real2, imag=param_imag2)
mul = tf.raw_ops.Mul(x=complex1, y=complex2, name="complex_mul")
real = tf.raw_ops.Real(input=mul)
img = tf.raw_ops.Imag(input=mul)
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def
return tf_net, None
test_data_basic = [
dict(input_shape=[]),
dict(input_shape=[2]),
dict(input_shape=[1, 3]),
dict(input_shape=[2, 3, 4]),
dict(input_shape=[3, 4, 5, 6]),
]
@pytest.mark.parametrize("params", test_data_basic)
@pytest.mark.precommit_tf_fe
@pytest.mark.nightly
def test_complex_mul(self, params, ie_device, precision, ir_version, temp_dir,
use_new_frontend, use_old_api):
self._test(
*self.create_complex_mul_net(**params),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api)

View File

@ -49,3 +49,47 @@ class TestReshape(CommonTFLayerTest):
self._test(*self.create_reshape_net(**params),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
class TestComplexReshape(CommonTFLayerTest):
def _prepare_input(self, inputs_info):
rng = np.random.default_rng()
assert 'param_real' in inputs_info
assert 'param_imag' in inputs_info
param_real_shape_1 = inputs_info['param_real']
param_imag_shape_1 = inputs_info['param_imag']
inputs_data = {}
inputs_data['param_real'] = 4 * rng.random(param_real_shape_1).astype(np.float32) - 2
inputs_data['param_imag'] = 4 * rng.random(param_imag_shape_1).astype(np.float32) - 2
return inputs_data
def create_complex_reshape_net(self, input_shape, target_shape):
tf.compat.v1.reset_default_graph()
# Create the graph and model
with tf.compat.v1.Session() as sess:
param_real = tf.compat.v1.placeholder(np.float32, input_shape, 'param_real')
param_imag = tf.compat.v1.placeholder(np.float32, input_shape, 'param_imag')
complex = tf.raw_ops.Complex(real=param_real, imag=param_imag)
reshape = tf.raw_ops.Reshape(tensor=complex, shape=target_shape)
real = tf.raw_ops.Real(input=reshape)
img = tf.raw_ops.Imag(input=reshape)
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def
return tf_net, None
test_data_basic = [
dict(input_shape=[2, 6], target_shape=[2, 3, 2]),
dict(input_shape=[2, 4, 5], target_shape=[4, -1, 5]),
dict(input_shape=[1], target_shape=[])
]
@pytest.mark.parametrize("params", test_data_basic)
@pytest.mark.precommit_tf_fe
@pytest.mark.nightly
def test_complex_reshape(self, params, ie_device, precision, ir_version, temp_dir,
use_new_frontend, use_old_api):
self._test(
*self.create_complex_reshape_net(**params),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api)

View File

@ -51,3 +51,46 @@ class TestShape(CommonTFLayerTest):
self._test(*self.create_shape_net(**params),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
class TestComplexShape(CommonTFLayerTest):
def _prepare_input(self, inputs_info):
rng = np.random.default_rng()
assert 'param_real' in inputs_info
assert 'param_imag' in inputs_info
param_real_shape_1 = inputs_info['param_real']
param_imag_shape_1 = inputs_info['param_imag']
inputs_data = {}
inputs_data['param_real'] = 4 * rng.random(param_real_shape_1).astype(np.float32) - 2
inputs_data['param_imag'] = 4 * rng.random(param_imag_shape_1).astype(np.float32) - 2
return inputs_data
def create_complex_shape_net(self, input_shape):
tf.compat.v1.reset_default_graph()
# Create the graph and model
with tf.compat.v1.Session() as sess:
param_real = tf.compat.v1.placeholder(np.float32, input_shape, 'param_real')
param_imag = tf.compat.v1.placeholder(np.float32, input_shape, 'param_imag')
complex = tf.raw_ops.Complex(real=param_real, imag=param_imag)
out = tf.raw_ops.Shape(input=complex, name="Shape")
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def
return tf_net, None
test_data_basic = [
dict(input_shape=[]),
dict(input_shape=[2, 3]),
dict(input_shape=[2, 4, 3]),
dict(input_shape=[2, 5, 3, 6, 8]),
]
@pytest.mark.parametrize("params", test_data_basic)
@pytest.mark.precommit_tf_fe
@pytest.mark.nightly
def test_complex_shape(self, params, ie_device, precision, ir_version, temp_dir,
use_new_frontend, use_old_api):
self._test(
*self.create_complex_shape_net(**params),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api)

View File

@ -114,3 +114,49 @@ class TestSqueeze(CommonTFLayerTest):
self._test(*self.create_squeeze_net(**params),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
class TestComplexSqueeze(CommonTFLayerTest):
def _prepare_input(self, inputs_info):
rng = np.random.default_rng()
assert 'param_real' in inputs_info
assert 'param_imag' in inputs_info
param_real_shape_1 = inputs_info['param_real']
param_imag_shape_1 = inputs_info['param_imag']
inputs_data = {}
inputs_data['param_real'] = 4 * rng.random(param_real_shape_1).astype(np.float32) - 2
inputs_data['param_imag'] = 4 * rng.random(param_imag_shape_1).astype(np.float32) - 2
return inputs_data
def create_complex_squeeze_net(self, input_shape, axis):
tf.compat.v1.reset_default_graph()
# Create the graph and model
with tf.compat.v1.Session() as sess:
param_real = tf.compat.v1.placeholder(np.float32, input_shape, 'param_real')
param_imag = tf.compat.v1.placeholder(np.float32, input_shape, 'param_imag')
complex = tf.raw_ops.Complex(real=param_real, imag=param_imag)
squeeze = tf.raw_ops.Squeeze(input=complex, axis=axis)
real = tf.raw_ops.Real(input=squeeze)
img = tf.raw_ops.Imag(input=squeeze)
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def
return tf_net, None
test_data_basic = [
dict(input_shape=[1], axis=[0]),
dict(input_shape=[3, 1], axis=[]),
dict(input_shape=[2, 3, 1], axis=[-1]),
dict(input_shape=[1, 10, 1, 5], axis=[0, 2]),
dict(input_shape=[1, 22, 1, 1, 10], axis=[0, 2, -2]),
]
@pytest.mark.parametrize("params", test_data_basic)
@pytest.mark.precommit_tf_fe
@pytest.mark.nightly
def test_complex_squeeze(self, params, ie_device, precision, ir_version, temp_dir,
use_new_frontend, use_old_api):
self._test(
*self.create_complex_squeeze_net(**params),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api)

View File

@ -34,6 +34,11 @@ class TestStridedSlice(CommonTFLayerTest):
begin_mask=0, end_mask=0, ellipsis_mask=0, new_axis_mask=8, shrink_axis_mask=0),
dict(input_shape=[3, 4, 5, 7], begin_value=[2, 0, 3], end_value=[3, 0, 6], strides_value=[1, 1, 1],
begin_mask=6, end_mask=6, ellipsis_mask=2, new_axis_mask=0, shrink_axis_mask=1),
dict(input_shape=[1, 4, 7, 2], begin_value=[0, 0, 0], end_value=[0, 6, 0], strides_value=[1, 1, 1],
begin_mask=6, end_mask=4, ellipsis_mask=1, new_axis_mask=0, shrink_axis_mask=0),
dict(input_shape=[1, 4, 7, 2], begin_value=[0, 0, 0], end_value=[0, 6, 0], strides_value=[1, 1, 1],
begin_mask=6, end_mask=4, ellipsis_mask=1, new_axis_mask=8, shrink_axis_mask=0),
]
@pytest.mark.parametrize('params', test_basic_data)
@ -113,3 +118,112 @@ class TestStridedSlice(CommonTFLayerTest):
self._test(*self.create_strided_slice_net(**params),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
class TestComplexStridedSlice(CommonTFLayerTest):
def _prepare_input(self, inputs_info):
import numpy as np
rng = np.random.default_rng()
assert 'param_real' in inputs_info
assert 'param_imag' in inputs_info
param_real_shape_1 = inputs_info['param_real']
param_imag_shape_1 = inputs_info['param_imag']
inputs_data = {}
inputs_data['param_real'] = 4 * rng.random(param_real_shape_1).astype(np.float32) - 2
inputs_data['param_imag'] = 4 * rng.random(param_imag_shape_1).astype(np.float32) - 2
return inputs_data
def create_complex_strided_slice_net(self, input_shape, begin_value, end_value, strides_value, begin_mask, end_mask,
ellipsis_mask,
new_axis_mask, shrink_axis_mask):
import tensorflow as tf
import numpy as np
tf.compat.v1.reset_default_graph()
# Create the graph and model
with tf.compat.v1.Session() as sess:
param_real = tf.compat.v1.placeholder(np.float32, input_shape, 'param_real')
param_imag = tf.compat.v1.placeholder(np.float32, input_shape, 'param_imag')
complex = tf.raw_ops.Complex(real=param_real, imag=param_imag)
begin = tf.constant(begin_value, dtype=tf.int32)
end = tf.constant(end_value, dtype=tf.int32)
strides = tf.constant(strides_value, dtype=tf.int32)
strided_slice = tf.raw_ops.StridedSlice(input=complex, begin=begin, end=end, strides=strides, begin_mask=begin_mask,
end_mask=end_mask, ellipsis_mask=ellipsis_mask, new_axis_mask=new_axis_mask,
shrink_axis_mask=shrink_axis_mask)
real = tf.raw_ops.Real(input=strided_slice)
img = tf.raw_ops.Imag(input=strided_slice)
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def
return tf_net, None
test_data_basic = [
dict(input_shape=[2, 5, 4, 3], begin_value=[1, 0, 2, 0], end_value=[2, 5, 4, 2], strides_value=[1, 2, 1, 1],
begin_mask=0, end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=1),
dict(input_shape=[1, 5, 5, 3], begin_value=[0, 0, 0, 0], end_value=[1, 5, 5, 3], strides_value=[1, 2, 3, 1],
begin_mask=0, end_mask=0, ellipsis_mask=0, new_axis_mask=8, shrink_axis_mask=0),
dict(input_shape=[3, 4, 5, 7], begin_value=[2, 0, 3], end_value=[3, 0, 6], strides_value=[1, 1, 1],
begin_mask=6, end_mask=6, ellipsis_mask=2, new_axis_mask=0, shrink_axis_mask=1),
dict(input_shape=[1, 4, 7, 2], begin_value=[0, 0, 0], end_value=[0, 6, 0], strides_value=[1, 1, 1],
begin_mask=6, end_mask=4, ellipsis_mask=1, new_axis_mask=0, shrink_axis_mask=0),
dict(input_shape=[1, 3, 7, 2], begin_value=[0, 0, 0], end_value=[0, 6, 0], strides_value=[1, 1, 1],
begin_mask=6, end_mask=4, ellipsis_mask=1, new_axis_mask=8, shrink_axis_mask=0),
dict(input_shape=[1, 5], begin_value=[0, 0], end_value=[1, 5], strides_value=[1, 1], begin_mask=0,
end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=1),
dict(input_shape=[5, 1], begin_value=[0, 0], end_value=[5, 1], strides_value=[1, 1], begin_mask=0,
end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=2),
dict(input_shape=[1, 5, 3], begin_value=[0, 0, 0], end_value=[1, 5, 3], strides_value=[1, 1, 1], begin_mask=0,
end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=1),
dict(input_shape=[1, 1, 3], begin_value=[0, 0, 0], end_value=[1, 1, 3], strides_value=[1, 1, 1], begin_mask=0,
end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=2),
dict(input_shape=[1, 5, 1], begin_value=[0, 0, 0], end_value=[1, 5, 1], strides_value=[1, 1, 1], begin_mask=0,
end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=4),
dict(input_shape=[1, 1, 5, 3], begin_value=[0, 0, 0, 0], end_value=[1, 1, 5, 3], strides_value=[1, 1, 1, 1],
begin_mask=0,
end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=2),
dict(input_shape=[1, 5, 1, 3], begin_value=[0, 0, 0, 0], end_value=[1, 5, 1, 3], strides_value=[1, 1, 1, 1],
begin_mask=0,
end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=4),
dict(input_shape=[1, 5, 5, 1], begin_value=[0, 0, 0, 0], end_value=[1, 5, 1, 1], strides_value=[1, 1, 1, 1],
begin_mask=0,
end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=8),
dict(input_shape=[1, 1, 5, 5, 3], begin_value=[0, 0, 0, 0, 0], end_value=[1, 1, 5, 5, 3],
strides_value=[1, 1, 1, 1, 1],
begin_mask=0, end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=3),
dict(input_shape=[1, 5, 1, 5, 3], begin_value=[0, 0, 0, 0, 0], end_value=[1, 5, 1, 5, 3],
strides_value=[1, 1, 1, 1, 1],
begin_mask=0, end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=5),
dict(input_shape=[1, 5, 1, 5, 1], begin_value=[0, 0, 0, 0, 0], end_value=[1, 5, 1, 5, 1],
strides_value=[1, 1, 1, 1, 1],
begin_mask=0, end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=21),
dict(input_shape=[1, 5], begin_value=[0, 0], end_value=[1, 5], strides_value=[1, 1], begin_mask=0,
end_mask=0, ellipsis_mask=0, new_axis_mask=1, shrink_axis_mask=0),
dict(input_shape=[1, 5], begin_value=[0, 0], end_value=[1, 5], strides_value=[1, 1], begin_mask=0,
end_mask=0, ellipsis_mask=0, new_axis_mask=3, shrink_axis_mask=0),
dict(input_shape=[1, 5, 3], begin_value=[0, 0, 0], end_value=[1, 5, 3], strides_value=[1, 1, 1], begin_mask=0,
end_mask=0, ellipsis_mask=0, new_axis_mask=3, shrink_axis_mask=0),
dict(input_shape=[1, 5, 3], begin_value=[0, 0, 0], end_value=[1, 5, 3], strides_value=[1, 1, 1], begin_mask=0,
end_mask=0, ellipsis_mask=0, new_axis_mask=4, shrink_axis_mask=0),
dict(input_shape=[1, 5, 3], begin_value=[0, 0, 0], end_value=[1, 5, 3], strides_value=[1, 1, 1], begin_mask=0,
end_mask=0, ellipsis_mask=0, new_axis_mask=5, shrink_axis_mask=0),
dict(input_shape=[1, 5, 5, 3], begin_value=[0, 0, 0, 0], end_value=[1, 5, 5, 3], strides_value=[1, 1, 1, 1],
begin_mask=0,
end_mask=0, ellipsis_mask=0, new_axis_mask=4, shrink_axis_mask=0),
dict(input_shape=[1, 5, 5, 3], begin_value=[0, 0, 0, 0], end_value=[1, 5, 5, 3], strides_value=[1, 1, 1, 1],
begin_mask=0,
end_mask=0, ellipsis_mask=0, new_axis_mask=2, shrink_axis_mask=0),
dict(input_shape=[16, 4, 64], begin_value=[0, 0, 0, 0], end_value=[0, 0, 0, 0], strides_value=[1, 1, 1, 1],
begin_mask=19,
end_mask=19, ellipsis_mask=0, new_axis_mask=12, shrink_axis_mask=0),
]
@pytest.mark.parametrize("params", test_data_basic)
@pytest.mark.precommit_tf_fe
@pytest.mark.nightly
def test_complex_strided_slice(self, params, ie_device, precision, ir_version, temp_dir,
use_new_frontend, use_old_api):
self._test(
*self.create_complex_strided_slice_net(**params),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api)

View File

@ -0,0 +1,49 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from sys import platform
import numpy as np
import pytest
import tensorflow as tf
from common.tf_layer_test_class import CommonTFLayerTest
class TestTensorListConcatV2(CommonTFLayerTest):
def _prepare_input(self, inputs_info):
assert 'x' in inputs_info
x_shape = inputs_info['x']
inputs_data = {}
inputs_data['x'] = np.random.randint(-10, 10, x_shape).astype(self.input_type)
return inputs_data
def create_tensor_list_concat_v2(self, input_shape, input_type):
self.input_type = input_type
tf.compat.v1.reset_default_graph()
# Create the graph and model
with tf.compat.v1.Session() as sess:
x = tf.compat.v1.placeholder(input_type, input_shape, 'x')
tensor_list = tf.raw_ops.TensorListFromTensor(tensor=x,
element_shape=tf.constant(input_shape[1:], dtype=tf.int32))
tf.raw_ops.TensorListConcatV2(input_handle=tensor_list, element_shape=tf.constant(input_shape[1:], dtype=tf.int32),
element_dtype=input_type,
leading_dims=tf.constant([], dtype=tf.int64))
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def
return tf_net, None
test_data_basic = [
dict(input_shape=[10, 20], input_type=np.float32),
dict(input_shape=[2, 3, 4], input_type=np.int32),
dict(input_shape=[3, 2, 4], input_type=np.float32),
]
@pytest.mark.parametrize("params", test_data_basic)
@pytest.mark.precommit_tf_fe
@pytest.mark.nightly
@pytest.mark.skipif(platform == 'darwin', reason="Ticket - 122182")
def test_tensor_list_concat_v2_basic(self, params, ie_device, precision, ir_version, temp_dir,
use_new_frontend, use_old_api):
self._test(*self.create_tensor_list_concat_v2(**params),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api)

View File

@ -32,3 +32,48 @@ class TestTranspose(CommonTFLayerTest):
self._test(*self.create_transpose_net(**params),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
class TestComplexTranspose(CommonTFLayerTest):
def _prepare_input(self, inputs_info):
import numpy as np
rng = np.random.default_rng()
assert 'param_real' in inputs_info
assert 'param_imag' in inputs_info
param_real_shape_1 = inputs_info['param_real']
param_imag_shape_1 = inputs_info['param_imag']
inputs_data = {}
inputs_data['param_real'] = 4 * rng.random(param_real_shape_1).astype(np.float32) - 2
inputs_data['param_imag'] = 4 * rng.random(param_imag_shape_1).astype(np.float32) - 2
return inputs_data
def create_complex_transpose_net(self, input_shape, perm_value):
import numpy as np
tf.compat.v1.reset_default_graph()
# Create the graph and model
with tf.compat.v1.Session() as sess:
param_real = tf.compat.v1.placeholder(np.float32, input_shape, 'param_real')
param_imag = tf.compat.v1.placeholder(np.float32, input_shape, 'param_imag')
complex = tf.raw_ops.Complex(real=param_real, imag=param_imag)
transpose = tf.raw_ops.Transpose(x=complex, perm=perm_value)
real = tf.raw_ops.Real(input=transpose)
img = tf.raw_ops.Imag(input=transpose)
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def
return tf_net, None
test_data_basic = [
dict(input_shape=[2, 4], perm_value=[1, 0]),
dict(input_shape=[2, 1, 3, 4], perm_value=[2, 0, 1, 3]),
]
@pytest.mark.parametrize("params", test_data_basic)
@pytest.mark.precommit_tf_fe
@pytest.mark.nightly
def test_complex_transpose(self, params, ie_device, precision, ir_version, temp_dir,
use_new_frontend, use_old_api):
self._test(
*self.create_complex_transpose_net(**params),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api)