diff --git a/src/common/transformations/src/transformations/common_optimizations/gru_cell_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/gru_cell_fusion.cpp index e5eae04c640..5b3aaec614f 100644 --- a/src/common/transformations/src/transformations/common_optimizations/gru_cell_fusion.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/gru_cell_fusion.cpp @@ -148,6 +148,15 @@ ov::pass::GRUCellFusion::GRUCellFusion() { Bh = rg.make(WRh.get_element_type(), Shape{1, static_cast(hidden_size)}, 0); } + // perform additional check for applicability of the transformation + // without this check, process_weights can fail + if (WR.get_partial_shape()[1] != (hidden_size + input_size)) { + return false; + } + if (WRh.get_partial_shape()[1] != (hidden_size + input_size)) { + return false; + } + Output Wzrh, Rzrh, Bzrh; if (cnt_of_consumers_of_zero_out == 1 && cnt_of_consumers_of_first_out == 2) { tie(Wzrh, Rzrh) = process_weights(rg, false, WR, WRh, input_size, hidden_size, axis_0, axis_1); diff --git a/src/common/transformations/src/transformations/common_optimizations/remove_multi_subgraph_op_dangling_params.cpp b/src/common/transformations/src/transformations/common_optimizations/remove_multi_subgraph_op_dangling_params.cpp index f9738929931..3304ee3718a 100644 --- a/src/common/transformations/src/transformations/common_optimizations/remove_multi_subgraph_op_dangling_params.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/remove_multi_subgraph_op_dangling_params.cpp @@ -116,7 +116,7 @@ bool ov::pass::RemoveMultiSubGraphOpDanglingParamsResults::run_on_model(const st } // Remove inputs bool pass_required = false; - std::set> required_inputs; + std::set required_inputs_indices; auto op_inputs = multi_subgraph_op->input_values(); std::vector> to_remove_descriptors_indexes; to_remove_descriptors_indexes.resize(subgraphs_size); @@ -133,7 +133,7 @@ bool ov::pass::RemoveMultiSubGraphOpDanglingParamsResults::run_on_model(const st } else { // collecting required inputs is needed to detect cases where the input // is not needed in a one body, but the other one uses it (for example If case) - required_inputs.insert(op_inputs[body_in_descriptors[i]->m_input_index]); // only unique + required_inputs_indices.insert(body_in_descriptors[i]->m_input_index); } } } @@ -148,7 +148,9 @@ bool ov::pass::RemoveMultiSubGraphOpDanglingParamsResults::run_on_model(const st } }; auto update_op_inputs_desc = [&subgraphs_size](const std::shared_ptr& op, + std::set& required_inputs_indices, uint64_t removed_loop_idx) { + std::set new_required_inputs_indices; for (size_t body_idx = 0; body_idx < subgraphs_size; ++body_idx) { auto& descriptors = op->get_input_descriptions(static_cast(body_idx)); for (auto& desc : descriptors) { @@ -157,6 +159,14 @@ bool ov::pass::RemoveMultiSubGraphOpDanglingParamsResults::run_on_model(const st } } } + for (auto input_index : required_inputs_indices) { + if (input_index > removed_loop_idx) { + new_required_inputs_indices.insert(input_index - 1); + } else { + new_required_inputs_indices.insert(input_index); + } + } + required_inputs_indices = new_required_inputs_indices; }; // Remove dangling body params and input and update input descriptors for (size_t body_idx = 0; body_idx < subgraphs_size; ++body_idx) { @@ -174,13 +184,17 @@ bool ov::pass::RemoveMultiSubGraphOpDanglingParamsResults::run_on_model(const st update_body_param_desc(body_in_descriptors, body_in_descriptors[desc_idx]->m_body_parameter_index); // remove dangling input of MultiSubGraphOp which was not removed earlier - auto& current_input = op_inputs[body_in_descriptors[desc_idx]->m_input_index]; - if (std::count(std::begin(required_inputs), std::end(required_inputs), current_input) == 0 && + auto current_input_idx = body_in_descriptors[desc_idx]->m_input_index; + auto& current_input = op_inputs[current_input_idx]; + // the same input tensor can go to different input ports + if (std::count(std::begin(required_inputs_indices), + std::end(required_inputs_indices), + current_input_idx) == 0 && std::count(std::begin(op_inputs), std::end(op_inputs), current_input) > 0) { - op_inputs.erase(std::next(op_inputs.begin(), body_in_descriptors[desc_idx]->m_input_index)); + op_inputs.erase(std::next(op_inputs.begin(), current_input_idx)); // Move all input indexes (in all bodies) which are after these indicated by // to_remove_descriptors_indexes and are not used in any body - update_op_inputs_desc(multi_subgraph_op, body_in_descriptors[desc_idx]->m_input_index); + update_op_inputs_desc(multi_subgraph_op, required_inputs_indices, current_input_idx); } } else { updated_body_in_descriptors.emplace_back(body_in_descriptors[desc_idx]); diff --git a/src/frontends/tensorflow/src/frontend.cpp b/src/frontends/tensorflow/src/frontend.cpp index 24b5824fe33..ad9b5b76bdf 100644 --- a/src/frontends/tensorflow/src/frontend.cpp +++ b/src/frontends/tensorflow/src/frontend.cpp @@ -14,6 +14,7 @@ #include "helper_transforms/embedding_segments_feature_fusing.hpp" #include "helper_transforms/gru_block_cell_replacer.hpp" #include "helper_transforms/saved_model_unused_remover.hpp" +#include "helper_transforms/tensor_array_v3_replacer.hpp" #include "input_model.hpp" #include "op_table.hpp" #include "openvino/core/so_extension.hpp" @@ -491,6 +492,7 @@ void FrontEnd::normalize(const std::shared_ptr& model) const { manager.register_pass(); manager.register_pass(); manager.register_pass(); + manager.register_pass(); manager.register_pass(); manager.register_pass(); manager.register_pass(); diff --git a/src/frontends/tensorflow/src/op/merge.cpp b/src/frontends/tensorflow/src/op/merge.cpp index 3594f93ed08..708de72aa34 100644 --- a/src/frontends/tensorflow/src/op/merge.cpp +++ b/src/frontends/tensorflow/src/op/merge.cpp @@ -5,6 +5,8 @@ #include "helper_ops/merge.hpp" #include "common_op_table.hpp" +#include "helper_ops/enter.hpp" +#include "helper_ops/next_iteration.hpp" #include "openvino/frontend/tensorflow/node_context.hpp" #include "openvino/op/constant.hpp" #include "utils.hpp" @@ -24,20 +26,47 @@ OutputVector translate_merge_op(const NodeContext& node) { auto node_name = node.get_name(); default_op_checks(node, 1, {"Merge"}); int input_size = static_cast(node.get_input_size()); - OutputVector inputs; + OutputVector inputs(input_size); for (int input_ind = 0; input_ind < input_size; ++input_ind) { - inputs.push_back(node.get_input(input_ind)); + inputs[input_ind] = node.get_input(input_ind); } // if Merge node has just one input, there is nothing to merge // return the same input and value_index equal to 0 - if (inputs.size() == 1) { + if (input_size == 1) { auto value_index = make_shared(element::i32, Shape{}, 0); value_index->output(0).set_names({node_name + ":1"}); inputs[0].add_names({node_name + ":0"}); return OutputVector{inputs[0], value_index}; } + // check if it is a case of TF1 While: Enter, NextIteration are going to Merge node + // in this case it can refine output shape and type for NextIteration based on Enter + if (input_size == 2) { + auto enter = as_type_ptr(inputs[0].get_node_shared_ptr()); + if (!enter) { + enter = as_type_ptr(inputs[1].get_node_shared_ptr()); + } + auto next_iteration = as_type_ptr(inputs[0].get_node_shared_ptr()); + if (!next_iteration) { + next_iteration = as_type_ptr(inputs[1].get_node_shared_ptr()); + } + + if (enter && next_iteration) { + // set output type and shape for NextIteration + // borrow them from Enter output + auto enter_output_type = enter->output(0).get_element_type(); + auto enter_output_shape = enter->output(0).get_partial_shape(); + auto next_iteration_output_shape = PartialShape::dynamic(enter_output_shape.rank()); + next_iteration->set_output_shape_and_type(next_iteration_output_shape, enter_output_type); + + // reset inputs + // refines input shapes and types for Merge node + inputs[0] = enter->output(0); + inputs[1] = next_iteration->output(0); + } + } + auto merge_node = make_shared(inputs, node.get_decoder()); set_node_name(node.get_name(), merge_node); diff --git a/src/frontends/tensorflow/src/op/tensor_array_operations.cpp b/src/frontends/tensorflow/src/op/tensor_array_operations.cpp new file mode 100644 index 00000000000..c1b3d6ac205 --- /dev/null +++ b/src/frontends/tensorflow/src/op/tensor_array_operations.cpp @@ -0,0 +1,332 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "common_op_table.hpp" +#include "helper_ops/enter.hpp" +#include "helper_ops/tensor_array.hpp" +#include "openvino/frontend/tensorflow/node_context.hpp" +#include "openvino/op/add.hpp" +#include "openvino/op/broadcast.hpp" +#include "openvino/op/concat.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/convert.hpp" +#include "openvino/op/gather.hpp" +#include "openvino/op/maximum.hpp" +#include "openvino/op/multiply.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/op/scatter_nd_update.hpp" +#include "openvino/op/scatter_update.hpp" +#include "openvino/op/shape_of.hpp" +#include "openvino/op/slice.hpp" +#include "openvino/op/subtract.hpp" +#include "openvino/op/unsqueeze.hpp" +#include "utils.hpp" + +using namespace std; +using namespace ov; +using namespace ov::op; +using namespace ov::frontend::tensorflow; + +namespace ov { +namespace frontend { +namespace tensorflow { +namespace op { + +namespace { +// the function creates the constant imitating initial tensor array container +Output create_initial_tensor_array_constant(int64_t tensor_element_rank, + const element::Type& element_type, + Output size, + const string& node_name) { + // adjust size to have it of shape [1] for further concatenation with element shape + auto new_size_shape = make_shared(element::i32, Shape{1}, 1); + size = make_shared(size, new_size_shape, false); + + // create a vector of size element_shape.rank() with ones + // and compute a shape of initial tensor array [size, 1, ..., 1] + vector ones(tensor_element_rank, 1); + auto ones_const = make_shared(element::i32, Shape{ones.size()}, ones); + auto target_shape = make_shared(OutputVector{size, ones_const}, 0); + + // create initial tensor array + auto scalar_value = make_shared(element_type, Shape{}, vector{0}); + auto initial_tensor_array = make_shared(scalar_value, target_shape); + + return initial_tensor_array->output(0); +} +} // namespace + +OutputVector translate_tensor_array_v3_op(const NodeContext& node) { + // TensorArrayV3 has just one input: + // 0) size to initialize a size of tensor array + default_op_checks(node, 1, {"TensorArrayV3"}); + auto dtype = node.get_attribute("dtype"); + auto size = node.get_input(0); + auto element_shape = node.get_attribute("element_shape"); + + if (element_shape.rank().is_static()) { + auto node_name = node.get_name(); + auto new_output1 = + create_initial_tensor_array_constant(element_shape.rank().get_length(), dtype, size, node.get_name()); + new_output1.set_names({node_name + ":0"}); + auto new_output2 = + create_initial_tensor_array_constant(element_shape.rank().get_length(), dtype, size, node.get_name()); + new_output2.set_names({node_name + ":1"}); + return OutputVector{new_output1, new_output2}; + } + + // dynamic case when it is unable retrieve element rank from the attribute + auto tensor_array_v3 = make_shared(size, dtype, node.get_decoder()); + set_node_name(node.get_name(), tensor_array_v3); + + return tensor_array_v3->outputs(); +} + +OutputVector translate_tensor_array_scatter_v3_op(const NodeContext& node) { + // TensorArrayScatterV3 has four inputs: + // 0) handle, a Tensor of type resource. The handle to a TensorArray. + // 1) indices, a Tensor of type int32. The locations at which to write the tensor elements. + // 2) value, a Tensor. The concatenated tensor to write to the TensorArray + // 3) flow_in A Tensor of type float32. A float scalar that enforces proper chaining of operations. + // The operation has one output: + // 0) flow_out indicates that operation is complete and handle resource is updated + default_op_checks(node, 4, {"TensorArrayScatterV3"}); + auto indices = node.get_input(1); + auto value = node.get_input(2); + // flow_in is used for transferring input tensor array + auto tensor_array = node.get_input(3); + + // check if producer of tensor_array is TensorArrayV3, internal operation, still + // if yes, try to replace it with constant container + if (as_type_ptr(tensor_array.get_node_shared_ptr()) && + value.get_partial_shape().rank().is_static()) { + // set tensor element rank that gets known from TensorArrayScatterV3 operation + auto tensor_array_v3 = as_type_ptr(tensor_array.get_node_shared_ptr()); + TENSORFLOW_OP_VALIDATION( + node, + value.get_partial_shape().rank().get_length() > 0, + "[TensorFlow Frontend] internal error or inconsistent model: value to TensorArrayScatterV3 is a scalar"); + int64_t tensor_element_rank = value.get_partial_shape().rank().get_length() - 1; + tensor_array_v3->set_element_rank(tensor_element_rank); + } + + // compute element shape (shape of a tensor in the tensor array) using value + auto element_shape = make_shared(value, element::i32)->output(0); + auto one_const = make_shared(element::i32, Shape{1}, 1); + auto max_const = make_shared(element::i32, Shape{1}, numeric_limits::max()); + element_shape = make_shared(element_shape, one_const, max_const, one_const); + + // compute size of tensor array + auto tensor_array_size = make_shared(tensor_array, element::i32)->output(0); + auto zero_const = make_shared(element::i32, Shape{1}, 0); + tensor_array_size = make_shared(tensor_array_size, zero_const, zero_const); + + // compute the new shape for tensor array where new tensors will be inserted + auto new_shape = make_shared(OutputVector{tensor_array_size, element_shape}, 0); + tensor_array = make_shared(tensor_array, new_shape); + + // adjust indices for ScatterNDUpdate to have a shape [N, 1] where N is a number of indices + indices = make_shared(indices, one_const); + + // compute updated tensor array using ScatterNDUpdate + // value should be of a shape [N, ] + auto updated_tensor_array = make_shared(tensor_array, indices, value); + set_node_name(node.get_name(), updated_tensor_array); + + // TensorArrayScatterV3 has just one output flow_out + // that is used for transferring updated tensor array + return {updated_tensor_array}; +} + +OutputVector translate_tensor_array_read_v3_op(const NodeContext& node) { + // TensorArrayReadV3 read an element from the TensorArray into the output + // and it has three inputs: + // 0) handle, a Tensor of type resource. The handle to a TensorArray. + // 1) index, a Tensor of type int32. The location from which to read the value + // 2) flow_in A Tensor of type float32. A float scalar that enforces proper chaining of operations. + // The operation has one output + // 0) read value from tensor array + default_op_checks(node, 3, {"TensorArrayReadV3"}); + auto index = node.get_input(1); + // flow_in is used for transferring input tensor array + auto tensor_array = node.get_input(2); + auto dtype = node.get_attribute("dtype"); + + // adjust the index to a scalar for using Gather operation + auto new_shape = make_shared(element::i32, Shape{0}, vector{}); + index = make_shared(index, new_shape, false); + + // gather tensor element by the required position + auto gather_axis = make_shared(element::i32, Shape{1}, 0); + Output tensor_element = make_shared(tensor_array, index, gather_axis); + tensor_element = make_shared(tensor_element, dtype); + + set_node_name(node.get_name(), tensor_element.get_node_shared_ptr()); + return {tensor_element}; +} + +OutputVector translate_tensor_array_close_v3_op(const NodeContext& node) { + // TensorArrayCloseV3 deletes the TensorArray from its resource container + // it outputs nothing + default_op_checks(node, 1, {"TensorArrayCloseV3"}); + return {}; +} + +OutputVector translate_tensor_array_size_v3_op(const NodeContext& node) { + // TensorArraySizeV3 gets the current size of the TensorArray + // it outputs int32 scalar equal to a size of the tensor array + default_op_checks(node, 2, {"TensorArraySizeV3"}); + // skip the handle by the first input + auto tensor_array = node.get_input(1); + + auto size = make_shared(tensor_array, element::i32)->output(0); + auto zero_const = make_shared(element::i32, Shape{1}, 0); + size = make_shared(size, zero_const, zero_const); + + // size must be scalar + auto scalar_shape = make_shared(element::i32, Shape{0}, vector{}); + size = make_shared(size, scalar_shape, false); + + set_node_name(node.get_name(), size.get_node_shared_ptr()); + return {size}; +} + +OutputVector translate_tensor_array_gather_v3_op(const NodeContext& node) { + // TensorArrayGatherV3 gathers specific elements from the TensorArray into output + // and it has three inputs: + // 0) handle, a Tensor of type resource. The handle to a TensorArray. + // 1) indices, a Tensor of type int32. The location from which to read tensor elements + // 2) flow_in A Tensor of type float32. A float scalar that enforces proper chaining of operations. + // The operation has one output + // 0) value with read tensor elements + // it outputs int32 scalar equal to a size of the tensor array + default_op_checks(node, 3, {"TensorArrayGatherV3"}); + // skip the handle by the first input + auto indices = node.get_input(1); + // flow_in serves for transferring tensor array + // handle input is ignored + auto tensor_array = node.get_input(2); + auto dtype = node.get_attribute("dtype"); + auto element_shape = node.get_attribute("element_shape", PartialShape::dynamic()); + + // gather tensor element by the required position + auto gather_axis = make_shared(element::i32, Shape{1}, 0); + Output tensor_element = make_shared(tensor_array, indices, gather_axis); + tensor_element = make_shared(tensor_element, dtype); + + // concretize tensor_element shape if this is specified + if (tensor_element.get_partial_shape().rank().is_dynamic() && element_shape.is_static()) { + auto element_shape_value = element_shape.get_shape(); + auto element_shape_const = + make_shared(element::i32, Shape{element_shape_value.size()}, element_shape_value); + auto size = make_shared(tensor_array, element::i32)->output(0); + auto zero_const = make_shared(element::i32, Shape{1}, 0); + size = make_shared(size, zero_const, zero_const); + auto new_shape = make_shared(OutputVector{size, element_shape_const}, 0); + tensor_element = make_shared(tensor_element, new_shape, false); + } + + set_node_name(node.get_name(), tensor_element.get_node_shared_ptr()); + return {tensor_element}; +} + +OutputVector translate_tensor_array_concat_v3_op(const NodeContext& node) { + // TensorArrayConcatV3 Concat the elements from the TensorArray into value + // and it has two inputs: + // 0) handle, a Tensor of type resource. The handle to a TensorArray. + // 1) flow_in A Tensor of type float32. A float scalar that enforces proper chaining of operations. + // The operation has one output + // 0) concatenated value by the first dimension + default_op_checks(node, 2, {"TensorArrayConcatV3"}); + // flow_in serves for transferring tensor array + // handle input is ignored + auto tensor_array = node.get_input(1); + auto dtype = node.get_attribute("dtype"); + + // since tensor array saves tensor elements in the concatenated form by the first dimension + // and for this operation they should be concatenated by the first dimension of the tensor element + // it needs to combine the first two dimensions + // tensor array is of shape [k, n0, n1, ..., nd] + // 1. compute element shape excluding the first dimension + auto zero_const = make_shared(element::i32, Shape{1}, 0); + auto one_const = make_shared(element::i32, Shape{1}, 1); + auto two_const = make_shared(element::i32, Shape{1}, 2); + auto max_const = make_shared(element::i32, Shape{1}, numeric_limits::max()); + auto tensor_array_shape = make_shared(tensor_array, element::i64); + auto element_shape_no_two_dims = make_shared(tensor_array_shape, two_const, max_const, one_const); + // 2. compute the first and second dimensions k and n0 + auto k = make_shared(tensor_array_shape, zero_const, zero_const); + auto n0 = make_shared(tensor_array_shape, one_const, zero_const); + auto k_by_n0 = make_shared(k, n0); + // 3. compute the first output containing concatenated tensor elements + // it folds the first and second dimensions + auto new_shape = make_shared(OutputVector{k_by_n0, element_shape_no_two_dims}, 0); + auto concatenated_array = make_shared(tensor_array, new_shape, false)->output(0); + concatenated_array = make_shared(concatenated_array, dtype); + concatenated_array.set_names({node.get_name() + ":0"}); + // 4. compute the second output with length of each tensor element for the concatenation + auto lengths = make_shared(n0, k)->output(0); + lengths.set_names({node.get_name() + ":1"}); + + return {concatenated_array, lengths}; +} + +OutputVector translate_tensor_array_write_v3_op(const NodeContext& node) { + // TensorArrayWriteV3 pushes an element onto the tensor_array. + // and it has four inputs + // 0) handle, a Tensor of type resource. The handle to a TensorArray. + // 1) index, a Tensor of type int32. The location where to write tensor element + // 2) value, a Tensor. The tensor to write at the specified location + // 3) flow_in A Tensor of type float32. A float scalar that enforces proper chaining of operations. + // The operation has one output + // 0) read value from tensor array + default_op_checks(node, 4, {"TensorArrayWriteV3"}); + auto handle = node.get_input(0); + auto index = node.get_input(1); + auto value = node.get_input(2); + // flow_in is used for transferring input tensor array + // tensor array has a rank equal to 1 + rank(element of tensor array) + // if it just initialized, its shape is equal to [tensor_array_size, 1, ..., 1] + // otherwise, it is equal to [tensor_array_size, ] + auto tensor_array = node.get_input(3); + + // reshape index to have it of [1] shape + auto new_index_shape = make_shared(element::i32, Shape{1}, 1); + index = make_shared(index, new_index_shape, false); + + if (auto enter = as_type_ptr(handle.get_node_shared_ptr())) { + if (as_type_ptr(enter->input_value(0).get_node_shared_ptr()) && + value.get_partial_shape().rank().is_static()) { + // set tensor element rank that gets known from TensorArrayWriteV3 operation + auto tensor_array_v3 = as_type_ptr(enter->input_value(0).get_node_shared_ptr()); + int64_t tensor_element_rank = value.get_partial_shape().rank().get_length(); + tensor_array_v3->set_element_rank(tensor_element_rank); + } + } + + // compute element shape in the input tensor array + auto tensor_array_shape = make_shared(tensor_array, element::i32); + + // compute the current size of tensor array + auto zero_const = make_shared(element::i32, Shape{1}, 0); + auto tensor_array_size = make_shared(tensor_array_shape, zero_const, zero_const); + + // adjust tensor array to have the correct shape [size, ] before value insertion + auto element_shape = make_shared(value, element::i32); + auto new_tensor_array_shape = make_shared(OutputVector{tensor_array_size, element_shape}, 0); + tensor_array = make_shared(tensor_array, new_tensor_array_shape); + + // update the resulted tensor using ScatterUpdate + value = make_shared(value, zero_const); + auto scatter_update = make_shared(tensor_array, index, value, zero_const); + + set_node_name(node.get_name(), scatter_update); + // use flow_out for transferring updated tensor array + return {scatter_update}; +} + +} // namespace op +} // namespace tensorflow +} // namespace frontend +} // namespace ov diff --git a/src/frontends/tensorflow/src/op_table.cpp b/src/frontends/tensorflow/src/op_table.cpp index 149b2d76184..3a4c570c657 100644 --- a/src/frontends/tensorflow/src/op_table.cpp +++ b/src/frontends/tensorflow/src/op_table.cpp @@ -46,6 +46,14 @@ TF_OP_CONVERTER(translate_sparse_segment_sum_op); TF_OP_CONVERTER(translate_staticregexfullmatch_op); TF_OP_CONVERTER(translate_stringjoin_op); TF_OP_CONVERTER(translate_switch_op); +TF_OP_CONVERTER(translate_tensor_array_close_v3_op); +TF_OP_CONVERTER(translate_tensor_array_concat_v3_op); +TF_OP_CONVERTER(translate_tensor_array_gather_v3_op); +TF_OP_CONVERTER(translate_tensor_array_read_v3_op); +TF_OP_CONVERTER(translate_tensor_array_scatter_v3_op); +TF_OP_CONVERTER(translate_tensor_array_size_v3_op); +TF_OP_CONVERTER(translate_tensor_array_v3_op); +TF_OP_CONVERTER(translate_tensor_array_write_v3_op); TF_OP_CONVERTER(translate_varhandle_op); TF_OP_CONVERTER(translate_variable_op); TF_OP_CONVERTER(translate_varisinitialized_op); @@ -174,6 +182,8 @@ const std::map get_supported_ops() { {"Gather", CreatorFunction(translate_gather_op)}, {"GatherV2", CreatorFunction(translate_gather_v2_op)}, {"GatherNd", CreatorFunction(translate_gather_nd_op)}, + {"GatherTree", CreatorFunction(translate_gather_tree_op)}, + {"Addons>GatherTree", CreatorFunction(translate_gather_tree_op)}, {"HashTable", CreatorFunction(translate_hash_table_op)}, {"HashTableV2", CreatorFunction(translate_hash_table_op)}, {"Identity", CreatorFunction(translate_identity_op)}, @@ -269,6 +279,14 @@ const std::map get_supported_ops() { {"StatelessWhile", CreatorFunction(translate_while_op)}, {"StridedSlice", CreatorFunction(translate_strided_slice_op)}, {"Switch", CreatorFunction(translate_switch_op)}, + {"TensorArrayCloseV3", CreatorFunction(translate_tensor_array_close_v3_op)}, + {"TensorArrayConcatV3", CreatorFunction(translate_tensor_array_concat_v3_op)}, + {"TensorArrayGatherV3", CreatorFunction(translate_tensor_array_gather_v3_op)}, + {"TensorArrayReadV3", CreatorFunction(translate_tensor_array_read_v3_op)}, + {"TensorArrayScatterV3", CreatorFunction(translate_tensor_array_scatter_v3_op)}, + {"TensorArraySizeV3", CreatorFunction(translate_tensor_array_size_v3_op)}, + {"TensorArrayV3", CreatorFunction(translate_tensor_array_v3_op)}, + {"TensorArrayWriteV3", CreatorFunction(translate_tensor_array_write_v3_op)}, {"TensorListFromTensor", CreatorFunction(translate_tensor_list_from_tensor_op)}, {"TensorListGetItem", CreatorFunction(translate_tensor_list_get_item_op)}, {"TensorListLength", CreatorFunction(translate_tensor_list_length_op)}, diff --git a/src/frontends/tensorflow/src/tf_utils.cpp b/src/frontends/tensorflow/src/tf_utils.cpp index c72e8e7bb90..e298f49f928 100644 --- a/src/frontends/tensorflow/src/tf_utils.cpp +++ b/src/frontends/tensorflow/src/tf_utils.cpp @@ -423,7 +423,7 @@ shared_ptr create_loop_for_tf_while(const std::string& while_node_name FRONT_END_GENERAL_CHECK( cond_results.size() == 1 && cond_results[0], "[TensorFlow Frontend] Internal error or inconsistent model: condition body must contain one Result node."); - auto body_condition_output_idx = static_cast(body_results.size()); + auto body_condition_output_idx = body_results.size(); body_model->add_results(cond_results); // type setting for body graph parameters is needed for TensorList support since DT_VARIANT type is present @@ -435,14 +435,18 @@ shared_ptr create_loop_for_tf_while(const std::string& while_node_name loop->set_function(body_model); // body_results may contain less nodes than body_params that means back edge exists not for all body_params - for (size_t input_ind = 0; input_ind < static_cast(body_condition_output_idx); ++input_ind) { + for (size_t input_ind = 0; input_ind < body_condition_output_idx; ++input_ind) { loop->set_merged_input(body_params[input_ind], ov_inputs[input_ind], body_results[input_ind]->input_value(0)); } - loop->set_special_body_ports({-1, body_condition_output_idx}); + loop->set_special_body_ports({-1, static_cast(body_condition_output_idx)}); + // set invariant inputs for the loop + for (size_t input_ind = body_condition_output_idx; input_ind < input_size; ++input_ind) { + loop->set_invariant_input(body_params[input_ind], ov_inputs[input_ind]); + } // set external outputs for Loop node // do not get execution condition outside of the Loop node - for (size_t output_ind = 0; output_ind < static_cast(body_condition_output_idx); ++output_ind) { + for (size_t output_ind = 0; output_ind < body_condition_output_idx; ++output_ind) { loop->get_iter_value(body_results[output_ind]); } loop->validate_and_infer_types(); diff --git a/src/frontends/tensorflow/tests/convert_model.cpp b/src/frontends/tensorflow/tests/convert_model.cpp index fc00a678496..f6ec18cf9cc 100644 --- a/src/frontends/tensorflow/tests/convert_model.cpp +++ b/src/frontends/tensorflow/tests/convert_model.cpp @@ -15,7 +15,8 @@ static const std::vector models{ std::string("2in_2out/2in_2out.pb"), std::string("forward_edge_model/forward_edge_model.pbtxt"), std::string("forward_edge_model2/forward_edge_model2.pbtxt"), - std::string("concat_with_non_constant_axis/concat_with_non_constant_axis.pbtxt")}; + std::string("concat_with_non_constant_axis/concat_with_non_constant_axis.pbtxt"), + std::string("gather_tree_model/gather_tree_model.pbtxt")}; INSTANTIATE_TEST_SUITE_P(TFConvertModelTest, FrontEndConvertModelTest, diff --git a/src/frontends/tensorflow/tests/test_models/models_pbtxt/gather_tree_model.pbtxt b/src/frontends/tensorflow/tests/test_models/models_pbtxt/gather_tree_model.pbtxt new file mode 100644 index 00000000000..54351036dd7 --- /dev/null +++ b/src/frontends/tensorflow/tests/test_models/models_pbtxt/gather_tree_model.pbtxt @@ -0,0 +1,103 @@ +node { + name: "step_ids" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 20 + } + dim { + size: 2 + } + dim { + size: 30 + } + } + } + } +} +node { + name: "parent_ids" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 20 + } + dim { + size: 2 + } + dim { + size: 30 + } + } + } + } +} +node { + name: "max_seq_len" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 2 + } + } + } + } +} +node { + name: "end_token" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "shape" + value { + shape { + } + } + } +} +node { + name: "Addons>GatherTree" + op: "Addons>GatherTree" + input: "step_ids" + input: "parent_ids" + input: "max_seq_len" + input: "end_token" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} diff --git a/src/frontends/tensorflow_common/include/common_op_table.hpp b/src/frontends/tensorflow_common/include/common_op_table.hpp index 75a9bdcafc9..29efb83547d 100644 --- a/src/frontends/tensorflow_common/include/common_op_table.hpp +++ b/src/frontends/tensorflow_common/include/common_op_table.hpp @@ -72,6 +72,7 @@ OP_CONVERTER_NAMED(translate_fused_batch_norm_op); OP_CONVERTER(translate_gather_op); OP_CONVERTER(translate_gather_v2_op); OP_CONVERTER(translate_gather_nd_op); +OP_CONVERTER(translate_gather_tree_op); OP_CONVERTER(translate_identity_op); OP_CONVERTER(translate_identity_n_op); OP_CONVERTER(translate_input_arg_op); diff --git a/src/frontends/tensorflow_common/include/helper_ops/merge.hpp b/src/frontends/tensorflow_common/include/helper_ops/merge.hpp index eb7e611f3e2..6261dd0e67c 100644 --- a/src/frontends/tensorflow_common/include/helper_ops/merge.hpp +++ b/src/frontends/tensorflow_common/include/helper_ops/merge.hpp @@ -33,20 +33,34 @@ public: ov::PartialShape output_data_shape = ov::PartialShape::dynamic(); auto input_size = get_input_size(); - bool merge_output_shape = true; for (size_t input_ind = 0; input_ind < input_size; ++input_ind) { auto input_type = get_input_element_type(input_ind); if (input_type.is_static()) { output_data_type = input_type; } - // check if it still needs to merge input shapes - // if yes, it tries to merge them - if (merge_output_shape && - !PartialShape::merge_into(output_data_shape, get_input_partial_shape(input_ind))) { - merge_output_shape = false; - // reset output shape to dynamic rank + auto input_shape = get_input_partial_shape(input_ind); + if (input_shape.rank().is_dynamic()) { + continue; + } + + if (output_data_shape.rank().is_dynamic()) { + // firstly met shape of static rank + // immediately use this shape of static rank + output_data_shape = input_shape; + } else if (output_data_shape.rank().is_static() && + output_data_shape.rank().get_length() != input_shape.rank().get_length()) { + // different inputs have different rank means output must be of a dynamic rank output_data_shape = ov::PartialShape::dynamic(); + break; + } else { + auto output_rank = output_data_shape.rank().get_length(); + for (int64_t dim_ind = 0; dim_ind < output_rank; ++dim_ind) { + if (input_shape[dim_ind] != output_data_shape[dim_ind]) { + // different inputs can have different dimensions so it must combine them + output_data_shape[dim_ind] = ov::Dimension::dynamic(); + } + } } } diff --git a/src/frontends/tensorflow_common/include/helper_ops/next_iteration.hpp b/src/frontends/tensorflow_common/include/helper_ops/next_iteration.hpp index eb262b4307a..e556c9ad447 100644 --- a/src/frontends/tensorflow_common/include/helper_ops/next_iteration.hpp +++ b/src/frontends/tensorflow_common/include/helper_ops/next_iteration.hpp @@ -43,6 +43,10 @@ public: producer_output_port_idx = m_producer_output_port_idx; } + void set_output_shape_and_type(const ov::PartialShape& output_shape, const ov::element::Type& output_type) { + set_output_type(0, output_type, output_shape); + } + private: bool m_back_edge_set; std::string m_producer_name; diff --git a/src/frontends/tensorflow_common/include/helper_ops/tensor_array.hpp b/src/frontends/tensorflow_common/include/helper_ops/tensor_array.hpp new file mode 100644 index 00000000000..030ff12d5b6 --- /dev/null +++ b/src/frontends/tensorflow_common/include/helper_ops/tensor_array.hpp @@ -0,0 +1,60 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#include "internal_operation.hpp" + +namespace ov { +namespace frontend { +namespace tensorflow { + +// Internal operation for TensorArrayV3 +// An array of Tensors of given size +// It has two outputs: +// 1. handle - resource (a reference) for tensor array +// 2. flow_out - float type will be used for storing tensor array +class TensorArrayV3 : public InternalOperation { +public: + OPENVINO_OP("TensorArrayV3", "ov::frontend::tensorflow", InternalOperation); + + TensorArrayV3(const Output& size, + const ov::element::Type element_type, + const std::shared_ptr& decoder = std::make_shared()) + : InternalOperation(decoder, OutputVector{size}, 2, "TensorArrayV3"), + m_element_type(element_type), + m_element_rank(-1) { + validate_and_infer_types(); + } + + void validate_and_infer_types() override { + set_output_type(0, m_element_type, ov::PartialShape::dynamic()); + set_output_type(1, m_element_type, ov::PartialShape::dynamic()); + } + + ov::element::Type get_element_type() const { + return m_element_type; + } + + int64_t get_element_rank() const { + return m_element_rank; + } + + void set_element_rank(int64_t element_rank) { + FRONT_END_GENERAL_CHECK( + element_rank >= 0, + "[TensorFlow Frontend] internal error: negavite element rank tries to set for TensorArrayV3"); + m_element_rank = element_rank; + } + +private: + ov::element::Type m_element_type; + int64_t m_element_rank; +}; + +} // namespace tensorflow +} // namespace frontend +} // namespace ov diff --git a/src/frontends/tensorflow_common/include/helper_transforms/tensor_array_v3_replacer.hpp b/src/frontends/tensorflow_common/include/helper_transforms/tensor_array_v3_replacer.hpp new file mode 100644 index 00000000000..42e5a0ad754 --- /dev/null +++ b/src/frontends/tensorflow_common/include/helper_transforms/tensor_array_v3_replacer.hpp @@ -0,0 +1,29 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +#include "openvino/pass/graph_rewrite.hpp" +#include "openvino/pass/pass.hpp" + +namespace ov { +namespace frontend { +namespace tensorflow { +namespace pass { + +// This transformation replaces internal operation TensorArrayV3 with a Constant +// that simulates initial state of tensor array container +class TensorArrayV3Replacer : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("ov::frontend::tensorflow::pass::TensorArrayV3Replacer"); + TensorArrayV3Replacer(); +}; + +} // namespace pass +} // namespace tensorflow +} // namespace frontend +} // namespace ov diff --git a/src/frontends/tensorflow_common/src/helper_transforms/tensor_array_v3_replacer.cpp b/src/frontends/tensorflow_common/src/helper_transforms/tensor_array_v3_replacer.cpp new file mode 100644 index 00000000000..72ed922511c --- /dev/null +++ b/src/frontends/tensorflow_common/src/helper_transforms/tensor_array_v3_replacer.cpp @@ -0,0 +1,71 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "helper_transforms/tensor_array_v3_replacer.hpp" + +#include "helper_ops/tensor_array.hpp" +#include "openvino/op/broadcast.hpp" +#include "openvino/op/concat.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "transformations/utils/utils.hpp" + +using namespace std; +using namespace ov; +using namespace ov::op; +using namespace ov::pass; + +ov::frontend::tensorflow::pass::TensorArrayV3Replacer::TensorArrayV3Replacer() { + auto tensor_array_v3 = pattern::wrap_type(); + + matcher_pass_callback callback = [=](pattern::Matcher& m) { + NodeRegistry rg; + + auto tensor_array_v3 = dynamic_pointer_cast(m.get_match_root()); + if (!tensor_array_v3) { + return false; + } + + int32_t tensor_element_rank = static_cast(tensor_array_v3->get_element_rank()); + if (tensor_element_rank < 0) { + return false; + } + + // retrieve all TensorArrayV3 inputs + auto size = tensor_array_v3->input_value(0); + auto element_type = tensor_array_v3->get_element_type(); + + // adjust size to have it of shape [1] for further concatenation with element shape + auto new_size_shape = rg.make(element::i32, Shape{1}, 1); + auto new_size = rg.make(size, new_size_shape, false); + + // create a vector of size element_shape.rank() with ones + // and compute a shape of initial tensor array [size, 1, ..., 1] + Output target_shape; + if (tensor_element_rank == 0) { + target_shape = new_size->output(0); + } else { + vector ones(tensor_element_rank, 1); + auto ones_const = rg.make(element::i32, Shape{ones.size()}, ones); + target_shape = rg.make(OutputVector{new_size, ones_const}, 0)->output(0); + } + + // create initial tensor array + auto scalar_value = make_shared(element_type, Shape{}, vector{0}); + auto initial_tensor_array = make_shared(scalar_value, target_shape); + + // preserve names of the node and the output tensor + initial_tensor_array->set_friendly_name(tensor_array_v3->get_friendly_name()); + copy_runtime_info(tensor_array_v3, rg.get()); + + ov::replace_node(tensor_array_v3, + ov::OutputVector{initial_tensor_array->output(0), initial_tensor_array->output(0)}); + return true; + }; + + auto m = + std::make_shared(tensor_array_v3, "ov::frontend::tensorflow::pass::TensorArrayV3Replacer"); + register_matcher(m, callback); +} diff --git a/src/frontends/tensorflow_common/src/op/gather_tree.cpp b/src/frontends/tensorflow_common/src/op/gather_tree.cpp new file mode 100644 index 00000000000..e349efe6784 --- /dev/null +++ b/src/frontends/tensorflow_common/src/op/gather_tree.cpp @@ -0,0 +1,39 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/op/gather_tree.hpp" + +#include "common_op_table.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/reshape.hpp" + +using namespace std; +using namespace ov::op; + +namespace ov { +namespace frontend { +namespace tensorflow { +namespace op { + +OutputVector translate_gather_tree_op(const NodeContext& node) { + default_op_checks(node, 4, {"GatherTree", "Addons>GatherTree"}); + auto step_ids = node.get_input(0); + auto parent_ids = node.get_input(1); + auto max_sequence_lengths = node.get_input(2); + auto end_token = node.get_input(3); + + // adjust end_token that must be a scalar + auto new_shape_end_token = make_shared(element::i32, Shape{0}, vector{}); + end_token = make_shared(end_token, new_shape_end_token, false); + + auto gather_tree = make_shared(step_ids, parent_ids, max_sequence_lengths, end_token); + set_node_name(node.get_name(), gather_tree); + + return {gather_tree}; +} + +} // namespace op +} // namespace tensorflow +} // namespace frontend +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/tensoriterator.cpp b/src/plugins/intel_cpu/src/nodes/tensoriterator.cpp index b38ae2fde7e..b9b7345b37f 100644 --- a/src/plugins/intel_cpu/src/nodes/tensoriterator.cpp +++ b/src/plugins/intel_cpu/src/nodes/tensoriterator.cpp @@ -513,7 +513,7 @@ void TensorIterator::createPrimitive() { lastUsedCond = initial_cond_check->getStatus(); } - if (isDynamicNode()) + if (runAsDynamic()) prepareDynamicBuffers(); Node::createPrimitive(); @@ -556,7 +556,7 @@ void TensorIterator::prepareParams() { prepareContinueCond(); prepareLoopBodyCurrentIteration(); - if (!isDynamicNode()) { + if (!runAsDynamic()) { prepareOutputPorts(); prepareBackEdges(); } @@ -568,6 +568,12 @@ void TensorIterator::prepareParams() { } void TensorIterator::execute(dnnl::stream strm) { + //Special case, the subgraph is dynamic while the node has all static shapes + if (runAsDynamic()) { + executeDynamicImpl(strm); + return; + } + sub_graph.ResetInferCount(); bool continue_cond = initial_cond_check->getStatus(); @@ -872,6 +878,10 @@ int TensorIterator::getNumIteration(const std::vector& inputPortMap, co return numIterations; } +bool TensorIterator::runAsDynamic() const { + return isDynamicNode() || Graph::Status::ReadyDynamic == sub_graph.getStatus(); +} + bool TensorIterator::created() const { return getType() == Type::TensorIterator; } diff --git a/src/plugins/intel_cpu/src/nodes/tensoriterator.h b/src/plugins/intel_cpu/src/nodes/tensoriterator.h index 8633be5c28d..104ee077f9a 100644 --- a/src/plugins/intel_cpu/src/nodes/tensoriterator.h +++ b/src/plugins/intel_cpu/src/nodes/tensoriterator.h @@ -138,6 +138,7 @@ private: void reshapeAndFillOutput(dnnl::stream strm); bool checkForInputAndBodyShapesInequality() const; int getNumIteration(const std::vector& inputPortMap, const std::vector& outputPortMap) const; + bool runAsDynamic() const; ExtensionManager::Ptr ext_mng; Graph sub_graph; diff --git a/src/plugins/intel_cpu/tests/functional/single_layer_tests/loop.cpp b/src/plugins/intel_cpu/tests/functional/single_layer_tests/loop.cpp index b92646e4581..cda499b042f 100644 --- a/src/plugins/intel_cpu/tests/functional/single_layer_tests/loop.cpp +++ b/src/plugins/intel_cpu/tests/functional/single_layer_tests/loop.cpp @@ -371,6 +371,65 @@ protected: } }; +class StaticLoopDynamicSubgraphCPUTest : public SubgraphBaseTest { + void SetUp() override { + InputShape input_shape = {{25, 1, 1}, {{25, 1, 1}}}; + InputShape input_exec_flag_shape = {{1}, {{1}}}; + targetDevice = ov::test::utils::DEVICE_CPU; + ElementType netType = ov::element::f32; + init_input_shapes({input_shape, input_exec_flag_shape}); + + ov::ParameterVector params; + params.push_back(std::make_shared(netType, inputDynamicShapes[0])); + + // exec_condition + params.push_back(std::make_shared(ov::element::boolean, inputDynamicShapes[1])); + + auto trip_count_input = std::make_shared(ov::element::i64, ov::Shape{1}, 2); + auto body_condition_const = std::make_shared(ov::element::boolean, ov::Shape{1}, true); + + // Body parameters + ov::ParameterVector body_params = {std::make_shared(netType, ov::PartialShape{25, 1, -1})}; + + // Body + auto broadcast_target_shape = std::make_shared(ov::element::i64, ov::Shape{3}, std::vector{25, 1, 256}); + auto broadcast_axis_mapping = std::make_shared(ov::element::i64, ov::Shape{1}, 0); + auto broadcast = std::make_shared(body_params[0], broadcast_target_shape); + auto body = std::make_shared(ov::OutputVector{body_condition_const, broadcast}, body_params); + + auto loop = std::make_shared(trip_count_input, params[1]); + loop->set_function(body); + loop->set_special_body_ports(ov::op::v5::Loop::SpecialBodyPorts{-1, 0}); + + loop->set_merged_input(body_params.front(), params.front(), broadcast); + + auto out0 = loop->get_iter_value(body_condition_const, -1); + auto out1 = loop->get_iter_value(broadcast, -1); + + auto result0 = std::make_shared(out0); + auto result1 = std::make_shared(out1); + function = std::make_shared(ov::ResultVector{result0, result1}, params, "loop"); + } + void generate_inputs(const std::vector& targetInputStaticShapes) override { + inputs.clear(); + const auto& funcInputs = function->inputs(); + for (size_t i = 0; i < funcInputs.size(); ++i) { + const auto& funcInput = funcInputs[i]; + ov::Tensor tensor; + + if (i == 1) { + tensor = ov::Tensor(funcInput.get_element_type(), targetInputStaticShapes[i]); + auto* dataPtr = tensor.data(); + *dataPtr = true; + } else { + tensor = ov::test::utils::create_and_fill_tensor(funcInput.get_element_type(), targetInputStaticShapes[i], 2560, 0, 256); + } + inputs.insert({funcInput.get_node_shared_ptr(), tensor}); + } + } +}; + + TEST_P(LoopLayerCPUTest, CompareWithRefs) { run(); } @@ -387,6 +446,10 @@ TEST_P(LoopForConcatLayerCPUTest, CompareWithRefs) { run(); } +TEST_F(StaticLoopDynamicSubgraphCPUTest, smoke_StaticLoopWithDynSubgraph) { + run(); +} + namespace { const std::vector inputPrecisions = { diff --git a/tests/layer_tests/common/utils/tf_utils.py b/tests/layer_tests/common/utils/tf_utils.py index fb02c3f0a1b..913048acf2e 100644 --- a/tests/layer_tests/common/utils/tf_utils.py +++ b/tests/layer_tests/common/utils/tf_utils.py @@ -98,7 +98,7 @@ def summarize_graph(model_path, output_nodes_for_freeze=None, reshape_net=None): variables = list() outputs = list() graph = load_graph(model_path, output_nodes_for_freeze) - unlikely_output_types = ['Const', 'Assign', 'NoOp', 'Placeholder', 'Assert', 'switch_t', 'switch_f'] + unlikely_output_types = ['Const', 'Assign', 'NoOp', 'Placeholder', 'Assert', 'switch_t', 'switch_f', 'TensorArrayCloseV3'] control_dependents_map = collect_control_dependencies(graph) for node in graph.as_graph_def().node: if node.op == 'Placeholder': diff --git a/tests/layer_tests/tensorflow_tests/test_tf_TensorArrayOps.py b/tests/layer_tests/tensorflow_tests/test_tf_TensorArrayOps.py new file mode 100644 index 00000000000..098f099f74d --- /dev/null +++ b/tests/layer_tests/tensorflow_tests/test_tf_TensorArrayOps.py @@ -0,0 +1,200 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import pytest +import tensorflow as tf +from common.tf_layer_test_class import CommonTFLayerTest + + +def create_tensor_array(data_shape, data_type): + size = data_shape[0] + data = tf.compat.v1.placeholder(data_type, data_shape, 'data') + indices = tf.compat.v1.placeholder(tf.int32, [size], 'indices') + size_const = tf.constant(size, dtype=tf.int32, shape=[]) + handle, flow = tf.raw_ops.TensorArrayV3(size=size_const, dtype=tf.as_dtype(data_type)) + flow = tf.raw_ops.TensorArrayScatterV3(handle=handle, indices=indices, value=data, flow_in=flow) + return handle, flow + + +class TestTensorArraySizeV3(CommonTFLayerTest): + def _prepare_input(self, inputs_info): + assert 'data' in inputs_info + assert 'indices' in inputs_info + data_shape = inputs_info['data'] + inputs_data = {} + rng = np.random.default_rng() + inputs_data['data'] = rng.integers(-10, 10, data_shape).astype(self.data_type) + inputs_data['indices'] = rng.permutation(self.size).astype(np.int32) + return inputs_data + + def create_tensor_array_size_v3(self, data_shape, data_type): + size = data_shape[0] + self.data_type = data_type + self.size = size + tf.compat.v1.reset_default_graph() + # Create the graph and model + with tf.compat.v1.Session() as sess: + handle, flow = create_tensor_array(data_shape, data_type) + tf.raw_ops.TensorArraySizeV3(handle=handle, flow_in=flow) + tf.raw_ops.TensorArrayCloseV3(handle=handle) + tf.compat.v1.global_variables_initializer() + tf_net = sess.graph_def + + return tf_net, None + + test_data_basic = [ + dict(data_shape=[5], data_type=np.float32), + dict(data_shape=[10, 20, 30], data_type=np.int32), + ] + + @pytest.mark.parametrize("params", test_data_basic) + @pytest.mark.precommit_tf_fe + @pytest.mark.nightly + def test_tensor_array_size_v3(self, params, ie_device, precision, ir_version, temp_dir, + use_new_frontend, use_old_api): + self._test(*self.create_tensor_array_size_v3(**params), + ie_device, precision, ir_version, temp_dir=temp_dir, + use_new_frontend=use_new_frontend, use_old_api=use_old_api) + + +class TestTensorArrayReadV3(CommonTFLayerTest): + def _prepare_input(self, inputs_info): + assert 'data' in inputs_info + assert 'indices' in inputs_info + data_shape = inputs_info['data'] + inputs_data = {} + rng = np.random.default_rng() + inputs_data['data'] = rng.integers(-10, 10, data_shape).astype(self.data_type) + inputs_data['index_to_read'] = rng.integers(0, data_shape[0], []).astype(np.int32) + inputs_data['indices'] = rng.permutation(self.size).astype(np.int32) + return inputs_data + + def create_tensor_array_read_v3(self, data_shape, data_type): + size = data_shape[0] + self.data_type = data_type + self.size = size + tf.compat.v1.reset_default_graph() + # Create the graph and model + with tf.compat.v1.Session() as sess: + handle, flow = create_tensor_array(data_shape, data_type) + index_to_read = tf.compat.v1.placeholder(tf.int32, [], 'index_to_read') + tf.raw_ops.TensorArrayReadV3(handle=handle, index=index_to_read, flow_in=flow, + dtype=tf.dtypes.as_dtype(data_type)) + tf.raw_ops.TensorArrayCloseV3(handle=handle) + tf.compat.v1.global_variables_initializer() + tf_net = sess.graph_def + + return tf_net, None + + test_data_basic = [ + dict(data_shape=[6], data_type=np.float32), + dict(data_shape=[8, 5, 6, 10], data_type=np.int32), + ] + + @pytest.mark.parametrize("params", test_data_basic) + @pytest.mark.precommit_tf_fe + @pytest.mark.nightly + def test_tensor_array_read_v3(self, params, ie_device, precision, ir_version, temp_dir, + use_new_frontend, use_old_api): + self._test(*self.create_tensor_array_read_v3(**params), + ie_device, precision, ir_version, temp_dir=temp_dir, + use_new_frontend=use_new_frontend, use_old_api=use_old_api) + + +class TestTensorArrayWriteGatherV3(CommonTFLayerTest): + def _prepare_input(self, inputs_info): + assert 'data' in inputs_info + assert 'indices' in inputs_info + assert 'value_to_write' in inputs_info + data_shape = inputs_info['data'] + value_shape = inputs_info['value_to_write'] + inputs_data = {} + rng = np.random.default_rng() + inputs_data['data'] = rng.integers(-10, 10, data_shape).astype(self.data_type) + inputs_data['value_to_write'] = rng.integers(-10, 10, value_shape).astype(self.data_type) + indices_data = rng.permutation(self.size).astype(np.int32) + inputs_data['indices'] = np.delete(indices_data, np.where(indices_data == self.index_to_write)) + return inputs_data + + def create_tensor_array_write_v3(self, size, data_shape, data_type, index_to_write, indices_to_gather): + self.data_type = data_type + self.size = size + self.index_to_write = index_to_write + tf.compat.v1.reset_default_graph() + # Create the graph and model + with tf.compat.v1.Session() as sess: + value_to_write = tf.compat.v1.placeholder(data_type, data_shape[1:], 'value_to_write') + index_to_write_const = tf.constant(index_to_write, dtype=tf.int32, shape=[]) + indices_to_gather_const = tf.constant(indices_to_gather, dtype=tf.int32, shape=[len(indices_to_gather)]) + data = tf.compat.v1.placeholder(data_type, data_shape, 'data') + indices = tf.compat.v1.placeholder(tf.int32, [size - 1], 'indices') + size_const = tf.constant(size, dtype=tf.int32, shape=[]) + handle, flow = tf.raw_ops.TensorArrayV3(size=size_const, dtype=tf.as_dtype(data_type)) + flow = tf.raw_ops.TensorArrayScatterV3(handle=handle, indices=indices, value=data, flow_in=flow) + flow = tf.raw_ops.TensorArrayWriteV3(handle=handle, index=index_to_write_const, + value=value_to_write, flow_in=flow) + tf.raw_ops.TensorArrayGatherV3(handle=handle, indices=indices_to_gather_const, flow_in=flow, + dtype=tf.dtypes.as_dtype(data_type)) + tf.raw_ops.TensorArrayCloseV3(handle=handle) + tf.compat.v1.global_variables_initializer() + tf_net = sess.graph_def + + return tf_net, None + + test_data_basic = [ + dict(size=7, data_shape=[6], data_type=np.float32, index_to_write=3, indices_to_gather=[0, 3, 1]), + dict(size=10, data_shape=[9, 2, 4], data_type=np.int32, index_to_write=2, indices_to_gather=[2, 1, 4, 3]), + ] + + @pytest.mark.parametrize("params", test_data_basic) + @pytest.mark.precommit_tf_fe + @pytest.mark.nightly + def test_tensor_array_write_v3(self, params, ie_device, precision, ir_version, temp_dir, + use_new_frontend, use_old_api): + self._test(*self.create_tensor_array_write_v3(**params), + ie_device, precision, ir_version, temp_dir=temp_dir, + use_new_frontend=use_new_frontend, use_old_api=use_old_api) + + +class TestTensorArrayConcatV3(CommonTFLayerTest): + def _prepare_input(self, inputs_info): + assert 'data' in inputs_info + assert 'indices' in inputs_info + data_shape = inputs_info['data'] + inputs_data = {} + rng = np.random.default_rng() + inputs_data['data'] = rng.integers(-10, 10, data_shape).astype(self.data_type) + inputs_data['indices'] = rng.permutation(self.size).astype(np.int32) + return inputs_data + + def create_tensor_array_concat_v3(self, data_shape, data_type): + size = data_shape[0] + self.data_type = data_type + self.size = size + tf.compat.v1.reset_default_graph() + # Create the graph and model + with tf.compat.v1.Session() as sess: + handle, flow = create_tensor_array(data_shape, data_type) + tensor_array_concat_v3 = tf.raw_ops.TensorArrayConcatV3(handle=handle, flow_in=flow, + dtype=tf.as_dtype(data_type)) + tf.identity(tensor_array_concat_v3[0], name='values') + tf.identity(tensor_array_concat_v3[1], name='length') + tf.raw_ops.TensorArrayCloseV3(handle=handle) + tf.compat.v1.global_variables_initializer() + tf_net = sess.graph_def + + return tf_net, None + + test_data_basic = [ + dict(data_shape=[5, 3, 11, 2], data_type=np.int32), + ] + + @pytest.mark.parametrize("params", test_data_basic) + @pytest.mark.precommit_tf_fe + @pytest.mark.nightly + def test_tensor_array_concat_v3(self, params, ie_device, precision, ir_version, temp_dir, + use_new_frontend, use_old_api): + self._test(*self.create_tensor_array_concat_v3(**params), + ie_device, precision, ir_version, temp_dir=temp_dir, + use_new_frontend=use_new_frontend, use_old_api=use_old_api) diff --git a/tools/mo/openvino/tools/mo/convert_impl.py b/tools/mo/openvino/tools/mo/convert_impl.py index ae6c39a144b..9d683f4b6ac 100644 --- a/tools/mo/openvino/tools/mo/convert_impl.py +++ b/tools/mo/openvino/tools/mo/convert_impl.py @@ -312,8 +312,8 @@ def update_fallback_with_conversion_error(use_new_frontend: bool, is_tf: bool, e conversion_error_re = r"^(\[TensorFlow\ Frontend\]\ Internal\ error\,\ no\ translator\ found\ for\ operation\(s\)\:\ )((\w+)(\,\ \w+)*)$" conversion_error_match = re.findall(conversion_error_re, ex_msg, re.MULTILINE) all_fallback_operations = [ - # corresponds to TF1 TensorList operation - "TensorArrayScatterV3", "TensorArrayV3", "TensorArraySizeV3", "TensorArrayGatherV3", + # corresponds to TF1 While operation + "LoopCond", "Enter", "NextIteration", "Exit", "Switch", "Merge", # corresponds to operations with complex tensors "FFT", "FFT2D", "FFT3D", "IFFT", "IFFT2D", "IFFT3D", "RFFT", "RFFT2D", "RFFT3D", "IRFFT", "IRFFT2D", "IRFFT3D", diff --git a/tools/mo/unit_tests/moc_tf_fe/conversion_basic_models_test.py b/tools/mo/unit_tests/moc_tf_fe/conversion_basic_models_test.py index 8d905d8f131..26ea01b77d6 100644 --- a/tools/mo/unit_tests/moc_tf_fe/conversion_basic_models_test.py +++ b/tools/mo/unit_tests/moc_tf_fe/conversion_basic_models_test.py @@ -235,17 +235,13 @@ class TestMoFreezePlaceholderTFFE(unittest.TestCase): freeze_placeholder_with_value, input_shape, only_conversion, True) - def test_conversion_failure_fallback_default(self): + def test_conversion_tf1_while_default(self): self.basic("ctc_model_based.pbtxt", None, None, None, None, None, None, True, True, False, False) - @unittest.skipIf(platform == 'darwin', reason="Ticket - 122182") - def test_conversion_failure_fallback_use_new_frontend(self): - with self.assertRaisesRegex(Exception, - "\[TensorFlow Frontend\] Internal error, no translator found for operation\(s\)\: " - "TensorArrayGatherV3\, TensorArrayReadV3\, TensorArraySizeV3\, TensorArrayV3\, TensorArrayWriteV3"): - self.basic("ctc_model_based.pbtxt", None, None, None, None, - None, None, True, True, True, False) + def test_conversion_tf1_while_use_new_frontend(self): + self.basic("ctc_model_based.pbtxt", None, None, None, None, + None, None, True, True, True, False) @unittest.skip("88349: Fix auto-pruning in legacy FE") def test_conversion_model_oneshot_iterator_use_legacy_frontend(self):