[TF FE] Provide full support of TF1 Control flow and TensorArray* ops (#20270)

* [TF FE] Provide full support of TF1 Control flow and TensorArray ops

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Add missed header for TensorArrayV3 op

* Temporarily disable GRU cell fusion

* Update src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp

* Fix a case when element_shape for TensorArrayV3

* Fix translator for TensorArrayCloseV3

* Update summarize graph with TensorArrayCloseV3

* Add layer tests for TensorArrayScatterV3, Close, Size, Array

* Fix output shape for Merge node

* Remove unused variable

* Fix translator for TensorArrayConcatV3

* Fix translator for TensorArrayConcatV3

* Add layer tests for TensorArrayWriteV3, Gather, and Concat

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Add translator for GatherTree

* Fix TF FE unit-test for GatherTree

* Fix GatherTree translator

* Fix GatherTree translator to handle 1d end_token

* Fix undeclared parameter issue

* Fix GatherTree unit-test

* Add TensorArrayV3Replacer transformation

* Temporarily disable dangling transformation

* Recover RemoveMultiSubGraphOpDanglingParamsResults transformation

* Recover GRUCellFusion transformation

* Simplify check for GRUCellFusion transformation

* Use proper name for unit-tests

* Simplify translator for TensorArrayWriteV3

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Fix RemoveMultiSubgraphOpDanglingParamsResults transformation

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Additional fix for remove_multi_subgraph_op_dangling_params

* Make static TI run a dynamic subgraph

* Dedicated SL test

* Change condition to respect stat shapes

* Adjust test to cover the code path properly

* Recover fallback for still failing case GNMT

---------

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
Co-authored-by: Maksim Kutakov <maksim.kutakov@intel.com>
This commit is contained in:
Roman Kazantsev 2023-10-24 00:50:26 +04:00 committed by GitHub
parent 99dfbb400a
commit 009ef5657c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 1034 additions and 34 deletions

View File

@ -148,6 +148,15 @@ ov::pass::GRUCellFusion::GRUCellFusion() {
Bh = rg.make<ov::op::v0::Constant>(WRh.get_element_type(), Shape{1, static_cast<size_t>(hidden_size)}, 0); Bh = rg.make<ov::op::v0::Constant>(WRh.get_element_type(), Shape{1, static_cast<size_t>(hidden_size)}, 0);
} }
// perform additional check for applicability of the transformation
// without this check, process_weights can fail
if (WR.get_partial_shape()[1] != (hidden_size + input_size)) {
return false;
}
if (WRh.get_partial_shape()[1] != (hidden_size + input_size)) {
return false;
}
Output<Node> Wzrh, Rzrh, Bzrh; Output<Node> Wzrh, Rzrh, Bzrh;
if (cnt_of_consumers_of_zero_out == 1 && cnt_of_consumers_of_first_out == 2) { if (cnt_of_consumers_of_zero_out == 1 && cnt_of_consumers_of_first_out == 2) {
tie(Wzrh, Rzrh) = process_weights(rg, false, WR, WRh, input_size, hidden_size, axis_0, axis_1); tie(Wzrh, Rzrh) = process_weights(rg, false, WR, WRh, input_size, hidden_size, axis_0, axis_1);

View File

@ -116,7 +116,7 @@ bool ov::pass::RemoveMultiSubGraphOpDanglingParamsResults::run_on_model(const st
} }
// Remove inputs // Remove inputs
bool pass_required = false; bool pass_required = false;
std::set<ov::Output<ov::Node>> required_inputs; std::set<uint64_t> required_inputs_indices;
auto op_inputs = multi_subgraph_op->input_values(); auto op_inputs = multi_subgraph_op->input_values();
std::vector<std::vector<size_t>> to_remove_descriptors_indexes; std::vector<std::vector<size_t>> to_remove_descriptors_indexes;
to_remove_descriptors_indexes.resize(subgraphs_size); to_remove_descriptors_indexes.resize(subgraphs_size);
@ -133,7 +133,7 @@ bool ov::pass::RemoveMultiSubGraphOpDanglingParamsResults::run_on_model(const st
} else { } else {
// collecting required inputs is needed to detect cases where the input // collecting required inputs is needed to detect cases where the input
// is not needed in a one body, but the other one uses it (for example If case) // is not needed in a one body, but the other one uses it (for example If case)
required_inputs.insert(op_inputs[body_in_descriptors[i]->m_input_index]); // only unique required_inputs_indices.insert(body_in_descriptors[i]->m_input_index);
} }
} }
} }
@ -148,7 +148,9 @@ bool ov::pass::RemoveMultiSubGraphOpDanglingParamsResults::run_on_model(const st
} }
}; };
auto update_op_inputs_desc = [&subgraphs_size](const std::shared_ptr<op::util::MultiSubGraphOp>& op, auto update_op_inputs_desc = [&subgraphs_size](const std::shared_ptr<op::util::MultiSubGraphOp>& op,
std::set<uint64_t>& required_inputs_indices,
uint64_t removed_loop_idx) { uint64_t removed_loop_idx) {
std::set<uint64_t> new_required_inputs_indices;
for (size_t body_idx = 0; body_idx < subgraphs_size; ++body_idx) { for (size_t body_idx = 0; body_idx < subgraphs_size; ++body_idx) {
auto& descriptors = op->get_input_descriptions(static_cast<int>(body_idx)); auto& descriptors = op->get_input_descriptions(static_cast<int>(body_idx));
for (auto& desc : descriptors) { for (auto& desc : descriptors) {
@ -157,6 +159,14 @@ bool ov::pass::RemoveMultiSubGraphOpDanglingParamsResults::run_on_model(const st
} }
} }
} }
for (auto input_index : required_inputs_indices) {
if (input_index > removed_loop_idx) {
new_required_inputs_indices.insert(input_index - 1);
} else {
new_required_inputs_indices.insert(input_index);
}
}
required_inputs_indices = new_required_inputs_indices;
}; };
// Remove dangling body params and input and update input descriptors // Remove dangling body params and input and update input descriptors
for (size_t body_idx = 0; body_idx < subgraphs_size; ++body_idx) { for (size_t body_idx = 0; body_idx < subgraphs_size; ++body_idx) {
@ -174,13 +184,17 @@ bool ov::pass::RemoveMultiSubGraphOpDanglingParamsResults::run_on_model(const st
update_body_param_desc(body_in_descriptors, update_body_param_desc(body_in_descriptors,
body_in_descriptors[desc_idx]->m_body_parameter_index); body_in_descriptors[desc_idx]->m_body_parameter_index);
// remove dangling input of MultiSubGraphOp which was not removed earlier // remove dangling input of MultiSubGraphOp which was not removed earlier
auto& current_input = op_inputs[body_in_descriptors[desc_idx]->m_input_index]; auto current_input_idx = body_in_descriptors[desc_idx]->m_input_index;
if (std::count(std::begin(required_inputs), std::end(required_inputs), current_input) == 0 && auto& current_input = op_inputs[current_input_idx];
// the same input tensor can go to different input ports
if (std::count(std::begin(required_inputs_indices),
std::end(required_inputs_indices),
current_input_idx) == 0 &&
std::count(std::begin(op_inputs), std::end(op_inputs), current_input) > 0) { std::count(std::begin(op_inputs), std::end(op_inputs), current_input) > 0) {
op_inputs.erase(std::next(op_inputs.begin(), body_in_descriptors[desc_idx]->m_input_index)); op_inputs.erase(std::next(op_inputs.begin(), current_input_idx));
// Move all input indexes (in all bodies) which are after these indicated by // Move all input indexes (in all bodies) which are after these indicated by
// to_remove_descriptors_indexes and are not used in any body // to_remove_descriptors_indexes and are not used in any body
update_op_inputs_desc(multi_subgraph_op, body_in_descriptors[desc_idx]->m_input_index); update_op_inputs_desc(multi_subgraph_op, required_inputs_indices, current_input_idx);
} }
} else { } else {
updated_body_in_descriptors.emplace_back(body_in_descriptors[desc_idx]); updated_body_in_descriptors.emplace_back(body_in_descriptors[desc_idx]);

View File

@ -14,6 +14,7 @@
#include "helper_transforms/embedding_segments_feature_fusing.hpp" #include "helper_transforms/embedding_segments_feature_fusing.hpp"
#include "helper_transforms/gru_block_cell_replacer.hpp" #include "helper_transforms/gru_block_cell_replacer.hpp"
#include "helper_transforms/saved_model_unused_remover.hpp" #include "helper_transforms/saved_model_unused_remover.hpp"
#include "helper_transforms/tensor_array_v3_replacer.hpp"
#include "input_model.hpp" #include "input_model.hpp"
#include "op_table.hpp" #include "op_table.hpp"
#include "openvino/core/so_extension.hpp" #include "openvino/core/so_extension.hpp"
@ -491,6 +492,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
manager.register_pass<pass::EmbeddingSegmentSingleFeatureFusion>(); manager.register_pass<pass::EmbeddingSegmentSingleFeatureFusion>();
manager.register_pass<pass::BlockLSTMReplacer>(); manager.register_pass<pass::BlockLSTMReplacer>();
manager.register_pass<pass::GRUBlockCellReplacer>(); manager.register_pass<pass::GRUBlockCellReplacer>();
manager.register_pass<pass::TensorArrayV3Replacer>();
manager.register_pass<pass::ConstToResultRemover>(); manager.register_pass<pass::ConstToResultRemover>();
manager.register_pass<pass::SwitchMergeResolver>(); manager.register_pass<pass::SwitchMergeResolver>();
manager.register_pass<ov::pass::UnrollIf>(); manager.register_pass<ov::pass::UnrollIf>();

View File

@ -5,6 +5,8 @@
#include "helper_ops/merge.hpp" #include "helper_ops/merge.hpp"
#include "common_op_table.hpp" #include "common_op_table.hpp"
#include "helper_ops/enter.hpp"
#include "helper_ops/next_iteration.hpp"
#include "openvino/frontend/tensorflow/node_context.hpp" #include "openvino/frontend/tensorflow/node_context.hpp"
#include "openvino/op/constant.hpp" #include "openvino/op/constant.hpp"
#include "utils.hpp" #include "utils.hpp"
@ -24,20 +26,47 @@ OutputVector translate_merge_op(const NodeContext& node) {
auto node_name = node.get_name(); auto node_name = node.get_name();
default_op_checks(node, 1, {"Merge"}); default_op_checks(node, 1, {"Merge"});
int input_size = static_cast<int>(node.get_input_size()); int input_size = static_cast<int>(node.get_input_size());
OutputVector inputs; OutputVector inputs(input_size);
for (int input_ind = 0; input_ind < input_size; ++input_ind) { for (int input_ind = 0; input_ind < input_size; ++input_ind) {
inputs.push_back(node.get_input(input_ind)); inputs[input_ind] = node.get_input(input_ind);
} }
// if Merge node has just one input, there is nothing to merge // if Merge node has just one input, there is nothing to merge
// return the same input and value_index equal to 0 // return the same input and value_index equal to 0
if (inputs.size() == 1) { if (input_size == 1) {
auto value_index = make_shared<v0::Constant>(element::i32, Shape{}, 0); auto value_index = make_shared<v0::Constant>(element::i32, Shape{}, 0);
value_index->output(0).set_names({node_name + ":1"}); value_index->output(0).set_names({node_name + ":1"});
inputs[0].add_names({node_name + ":0"}); inputs[0].add_names({node_name + ":0"});
return OutputVector{inputs[0], value_index}; return OutputVector{inputs[0], value_index};
} }
// check if it is a case of TF1 While: Enter, NextIteration are going to Merge node
// in this case it can refine output shape and type for NextIteration based on Enter
if (input_size == 2) {
auto enter = as_type_ptr<Enter>(inputs[0].get_node_shared_ptr());
if (!enter) {
enter = as_type_ptr<Enter>(inputs[1].get_node_shared_ptr());
}
auto next_iteration = as_type_ptr<NextIteration>(inputs[0].get_node_shared_ptr());
if (!next_iteration) {
next_iteration = as_type_ptr<NextIteration>(inputs[1].get_node_shared_ptr());
}
if (enter && next_iteration) {
// set output type and shape for NextIteration
// borrow them from Enter output
auto enter_output_type = enter->output(0).get_element_type();
auto enter_output_shape = enter->output(0).get_partial_shape();
auto next_iteration_output_shape = PartialShape::dynamic(enter_output_shape.rank());
next_iteration->set_output_shape_and_type(next_iteration_output_shape, enter_output_type);
// reset inputs
// refines input shapes and types for Merge node
inputs[0] = enter->output(0);
inputs[1] = next_iteration->output(0);
}
}
auto merge_node = make_shared<Merge>(inputs, node.get_decoder()); auto merge_node = make_shared<Merge>(inputs, node.get_decoder());
set_node_name(node.get_name(), merge_node); set_node_name(node.get_name(), merge_node);

View File

@ -0,0 +1,332 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "common_op_table.hpp"
#include "helper_ops/enter.hpp"
#include "helper_ops/tensor_array.hpp"
#include "openvino/frontend/tensorflow/node_context.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/maximum.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/scatter_nd_update.hpp"
#include "openvino/op/scatter_update.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "utils.hpp"
using namespace std;
using namespace ov;
using namespace ov::op;
using namespace ov::frontend::tensorflow;
namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {
namespace {
// the function creates the constant imitating initial tensor array container
Output<Node> create_initial_tensor_array_constant(int64_t tensor_element_rank,
const element::Type& element_type,
Output<Node> size,
const string& node_name) {
// adjust size to have it of shape [1] for further concatenation with element shape
auto new_size_shape = make_shared<v0::Constant>(element::i32, Shape{1}, 1);
size = make_shared<v1::Reshape>(size, new_size_shape, false);
// create a vector of size element_shape.rank() with ones
// and compute a shape of initial tensor array [size, 1, ..., 1]
vector<int32_t> ones(tensor_element_rank, 1);
auto ones_const = make_shared<v0::Constant>(element::i32, Shape{ones.size()}, ones);
auto target_shape = make_shared<v0::Concat>(OutputVector{size, ones_const}, 0);
// create initial tensor array
auto scalar_value = make_shared<v0::Constant>(element_type, Shape{}, vector<int32_t>{0});
auto initial_tensor_array = make_shared<v3::Broadcast>(scalar_value, target_shape);
return initial_tensor_array->output(0);
}
} // namespace
OutputVector translate_tensor_array_v3_op(const NodeContext& node) {
// TensorArrayV3 has just one input:
// 0) size to initialize a size of tensor array
default_op_checks(node, 1, {"TensorArrayV3"});
auto dtype = node.get_attribute<element::Type>("dtype");
auto size = node.get_input(0);
auto element_shape = node.get_attribute<PartialShape>("element_shape");
if (element_shape.rank().is_static()) {
auto node_name = node.get_name();
auto new_output1 =
create_initial_tensor_array_constant(element_shape.rank().get_length(), dtype, size, node.get_name());
new_output1.set_names({node_name + ":0"});
auto new_output2 =
create_initial_tensor_array_constant(element_shape.rank().get_length(), dtype, size, node.get_name());
new_output2.set_names({node_name + ":1"});
return OutputVector{new_output1, new_output2};
}
// dynamic case when it is unable retrieve element rank from the attribute
auto tensor_array_v3 = make_shared<TensorArrayV3>(size, dtype, node.get_decoder());
set_node_name(node.get_name(), tensor_array_v3);
return tensor_array_v3->outputs();
}
OutputVector translate_tensor_array_scatter_v3_op(const NodeContext& node) {
// TensorArrayScatterV3 has four inputs:
// 0) handle, a Tensor of type resource. The handle to a TensorArray.
// 1) indices, a Tensor of type int32. The locations at which to write the tensor elements.
// 2) value, a Tensor. The concatenated tensor to write to the TensorArray
// 3) flow_in A Tensor of type float32. A float scalar that enforces proper chaining of operations.
// The operation has one output:
// 0) flow_out indicates that operation is complete and handle resource is updated
default_op_checks(node, 4, {"TensorArrayScatterV3"});
auto indices = node.get_input(1);
auto value = node.get_input(2);
// flow_in is used for transferring input tensor array
auto tensor_array = node.get_input(3);
// check if producer of tensor_array is TensorArrayV3, internal operation, still
// if yes, try to replace it with constant container
if (as_type_ptr<TensorArrayV3>(tensor_array.get_node_shared_ptr()) &&
value.get_partial_shape().rank().is_static()) {
// set tensor element rank that gets known from TensorArrayScatterV3 operation
auto tensor_array_v3 = as_type_ptr<TensorArrayV3>(tensor_array.get_node_shared_ptr());
TENSORFLOW_OP_VALIDATION(
node,
value.get_partial_shape().rank().get_length() > 0,
"[TensorFlow Frontend] internal error or inconsistent model: value to TensorArrayScatterV3 is a scalar");
int64_t tensor_element_rank = value.get_partial_shape().rank().get_length() - 1;
tensor_array_v3->set_element_rank(tensor_element_rank);
}
// compute element shape (shape of a tensor in the tensor array) using value
auto element_shape = make_shared<v3::ShapeOf>(value, element::i32)->output(0);
auto one_const = make_shared<v0::Constant>(element::i32, Shape{1}, 1);
auto max_const = make_shared<v0::Constant>(element::i32, Shape{1}, numeric_limits<int32_t>::max());
element_shape = make_shared<v8::Slice>(element_shape, one_const, max_const, one_const);
// compute size of tensor array
auto tensor_array_size = make_shared<v3::ShapeOf>(tensor_array, element::i32)->output(0);
auto zero_const = make_shared<v0::Constant>(element::i32, Shape{1}, 0);
tensor_array_size = make_shared<v8::Gather>(tensor_array_size, zero_const, zero_const);
// compute the new shape for tensor array where new tensors will be inserted
auto new_shape = make_shared<v0::Concat>(OutputVector{tensor_array_size, element_shape}, 0);
tensor_array = make_shared<v3::Broadcast>(tensor_array, new_shape);
// adjust indices for ScatterNDUpdate to have a shape [N, 1] where N is a number of indices
indices = make_shared<v0::Unsqueeze>(indices, one_const);
// compute updated tensor array using ScatterNDUpdate
// value should be of a shape [N, <elem_shape>]
auto updated_tensor_array = make_shared<v3::ScatterNDUpdate>(tensor_array, indices, value);
set_node_name(node.get_name(), updated_tensor_array);
// TensorArrayScatterV3 has just one output flow_out
// that is used for transferring updated tensor array
return {updated_tensor_array};
}
OutputVector translate_tensor_array_read_v3_op(const NodeContext& node) {
// TensorArrayReadV3 read an element from the TensorArray into the output
// and it has three inputs:
// 0) handle, a Tensor of type resource. The handle to a TensorArray.
// 1) index, a Tensor of type int32. The location from which to read the value
// 2) flow_in A Tensor of type float32. A float scalar that enforces proper chaining of operations.
// The operation has one output
// 0) read value from tensor array
default_op_checks(node, 3, {"TensorArrayReadV3"});
auto index = node.get_input(1);
// flow_in is used for transferring input tensor array
auto tensor_array = node.get_input(2);
auto dtype = node.get_attribute<element::Type>("dtype");
// adjust the index to a scalar for using Gather operation
auto new_shape = make_shared<v0::Constant>(element::i32, Shape{0}, vector<int32_t>{});
index = make_shared<v1::Reshape>(index, new_shape, false);
// gather tensor element by the required position
auto gather_axis = make_shared<v0::Constant>(element::i32, Shape{1}, 0);
Output<Node> tensor_element = make_shared<v8::Gather>(tensor_array, index, gather_axis);
tensor_element = make_shared<v0::Convert>(tensor_element, dtype);
set_node_name(node.get_name(), tensor_element.get_node_shared_ptr());
return {tensor_element};
}
OutputVector translate_tensor_array_close_v3_op(const NodeContext& node) {
// TensorArrayCloseV3 deletes the TensorArray from its resource container
// it outputs nothing
default_op_checks(node, 1, {"TensorArrayCloseV3"});
return {};
}
OutputVector translate_tensor_array_size_v3_op(const NodeContext& node) {
// TensorArraySizeV3 gets the current size of the TensorArray
// it outputs int32 scalar equal to a size of the tensor array
default_op_checks(node, 2, {"TensorArraySizeV3"});
// skip the handle by the first input
auto tensor_array = node.get_input(1);
auto size = make_shared<v3::ShapeOf>(tensor_array, element::i32)->output(0);
auto zero_const = make_shared<v0::Constant>(element::i32, Shape{1}, 0);
size = make_shared<v8::Gather>(size, zero_const, zero_const);
// size must be scalar
auto scalar_shape = make_shared<v0::Constant>(element::i32, Shape{0}, vector<int32_t>{});
size = make_shared<v1::Reshape>(size, scalar_shape, false);
set_node_name(node.get_name(), size.get_node_shared_ptr());
return {size};
}
OutputVector translate_tensor_array_gather_v3_op(const NodeContext& node) {
// TensorArrayGatherV3 gathers specific elements from the TensorArray into output
// and it has three inputs:
// 0) handle, a Tensor of type resource. The handle to a TensorArray.
// 1) indices, a Tensor of type int32. The location from which to read tensor elements
// 2) flow_in A Tensor of type float32. A float scalar that enforces proper chaining of operations.
// The operation has one output
// 0) value with read tensor elements
// it outputs int32 scalar equal to a size of the tensor array
default_op_checks(node, 3, {"TensorArrayGatherV3"});
// skip the handle by the first input
auto indices = node.get_input(1);
// flow_in serves for transferring tensor array
// handle input is ignored
auto tensor_array = node.get_input(2);
auto dtype = node.get_attribute<element::Type>("dtype");
auto element_shape = node.get_attribute<PartialShape>("element_shape", PartialShape::dynamic());
// gather tensor element by the required position
auto gather_axis = make_shared<v0::Constant>(element::i32, Shape{1}, 0);
Output<Node> tensor_element = make_shared<v8::Gather>(tensor_array, indices, gather_axis);
tensor_element = make_shared<v0::Convert>(tensor_element, dtype);
// concretize tensor_element shape if this is specified
if (tensor_element.get_partial_shape().rank().is_dynamic() && element_shape.is_static()) {
auto element_shape_value = element_shape.get_shape();
auto element_shape_const =
make_shared<v0::Constant>(element::i32, Shape{element_shape_value.size()}, element_shape_value);
auto size = make_shared<v3::ShapeOf>(tensor_array, element::i32)->output(0);
auto zero_const = make_shared<v0::Constant>(element::i32, Shape{1}, 0);
size = make_shared<v8::Gather>(size, zero_const, zero_const);
auto new_shape = make_shared<v0::Concat>(OutputVector{size, element_shape_const}, 0);
tensor_element = make_shared<v1::Reshape>(tensor_element, new_shape, false);
}
set_node_name(node.get_name(), tensor_element.get_node_shared_ptr());
return {tensor_element};
}
OutputVector translate_tensor_array_concat_v3_op(const NodeContext& node) {
// TensorArrayConcatV3 Concat the elements from the TensorArray into value
// and it has two inputs:
// 0) handle, a Tensor of type resource. The handle to a TensorArray.
// 1) flow_in A Tensor of type float32. A float scalar that enforces proper chaining of operations.
// The operation has one output
// 0) concatenated value by the first dimension
default_op_checks(node, 2, {"TensorArrayConcatV3"});
// flow_in serves for transferring tensor array
// handle input is ignored
auto tensor_array = node.get_input(1);
auto dtype = node.get_attribute<element::Type>("dtype");
// since tensor array saves tensor elements in the concatenated form by the first dimension
// and for this operation they should be concatenated by the first dimension of the tensor element
// it needs to combine the first two dimensions
// tensor array is of shape [k, n0, n1, ..., nd]
// 1. compute element shape excluding the first dimension
auto zero_const = make_shared<v0::Constant>(element::i32, Shape{1}, 0);
auto one_const = make_shared<v0::Constant>(element::i32, Shape{1}, 1);
auto two_const = make_shared<v0::Constant>(element::i32, Shape{1}, 2);
auto max_const = make_shared<v0::Constant>(element::i32, Shape{1}, numeric_limits<int32_t>::max());
auto tensor_array_shape = make_shared<v3::ShapeOf>(tensor_array, element::i64);
auto element_shape_no_two_dims = make_shared<v8::Slice>(tensor_array_shape, two_const, max_const, one_const);
// 2. compute the first and second dimensions k and n0
auto k = make_shared<v8::Gather>(tensor_array_shape, zero_const, zero_const);
auto n0 = make_shared<v8::Gather>(tensor_array_shape, one_const, zero_const);
auto k_by_n0 = make_shared<v1::Multiply>(k, n0);
// 3. compute the first output containing concatenated tensor elements
// it folds the first and second dimensions
auto new_shape = make_shared<v0::Concat>(OutputVector{k_by_n0, element_shape_no_two_dims}, 0);
auto concatenated_array = make_shared<v1::Reshape>(tensor_array, new_shape, false)->output(0);
concatenated_array = make_shared<v0::Convert>(concatenated_array, dtype);
concatenated_array.set_names({node.get_name() + ":0"});
// 4. compute the second output with length of each tensor element for the concatenation
auto lengths = make_shared<v3::Broadcast>(n0, k)->output(0);
lengths.set_names({node.get_name() + ":1"});
return {concatenated_array, lengths};
}
OutputVector translate_tensor_array_write_v3_op(const NodeContext& node) {
// TensorArrayWriteV3 pushes an element onto the tensor_array.
// and it has four inputs
// 0) handle, a Tensor of type resource. The handle to a TensorArray.
// 1) index, a Tensor of type int32. The location where to write tensor element
// 2) value, a Tensor. The tensor to write at the specified location
// 3) flow_in A Tensor of type float32. A float scalar that enforces proper chaining of operations.
// The operation has one output
// 0) read value from tensor array
default_op_checks(node, 4, {"TensorArrayWriteV3"});
auto handle = node.get_input(0);
auto index = node.get_input(1);
auto value = node.get_input(2);
// flow_in is used for transferring input tensor array
// tensor array has a rank equal to 1 + rank(element of tensor array)
// if it just initialized, its shape is equal to [tensor_array_size, 1, ..., 1]
// otherwise, it is equal to [tensor_array_size, <element shape>]
auto tensor_array = node.get_input(3);
// reshape index to have it of [1] shape
auto new_index_shape = make_shared<v0::Constant>(element::i32, Shape{1}, 1);
index = make_shared<v1::Reshape>(index, new_index_shape, false);
if (auto enter = as_type_ptr<Enter>(handle.get_node_shared_ptr())) {
if (as_type_ptr<TensorArrayV3>(enter->input_value(0).get_node_shared_ptr()) &&
value.get_partial_shape().rank().is_static()) {
// set tensor element rank that gets known from TensorArrayWriteV3 operation
auto tensor_array_v3 = as_type_ptr<TensorArrayV3>(enter->input_value(0).get_node_shared_ptr());
int64_t tensor_element_rank = value.get_partial_shape().rank().get_length();
tensor_array_v3->set_element_rank(tensor_element_rank);
}
}
// compute element shape in the input tensor array
auto tensor_array_shape = make_shared<v3::ShapeOf>(tensor_array, element::i32);
// compute the current size of tensor array
auto zero_const = make_shared<v0::Constant>(element::i32, Shape{1}, 0);
auto tensor_array_size = make_shared<v8::Gather>(tensor_array_shape, zero_const, zero_const);
// adjust tensor array to have the correct shape [size, <real element shape>] before value insertion
auto element_shape = make_shared<v3::ShapeOf>(value, element::i32);
auto new_tensor_array_shape = make_shared<v0::Concat>(OutputVector{tensor_array_size, element_shape}, 0);
tensor_array = make_shared<v3::Broadcast>(tensor_array, new_tensor_array_shape);
// update the resulted tensor using ScatterUpdate
value = make_shared<v0::Unsqueeze>(value, zero_const);
auto scatter_update = make_shared<v3::ScatterUpdate>(tensor_array, index, value, zero_const);
set_node_name(node.get_name(), scatter_update);
// use flow_out for transferring updated tensor array
return {scatter_update};
}
} // namespace op
} // namespace tensorflow
} // namespace frontend
} // namespace ov

View File

@ -46,6 +46,14 @@ TF_OP_CONVERTER(translate_sparse_segment_sum_op);
TF_OP_CONVERTER(translate_staticregexfullmatch_op); TF_OP_CONVERTER(translate_staticregexfullmatch_op);
TF_OP_CONVERTER(translate_stringjoin_op); TF_OP_CONVERTER(translate_stringjoin_op);
TF_OP_CONVERTER(translate_switch_op); TF_OP_CONVERTER(translate_switch_op);
TF_OP_CONVERTER(translate_tensor_array_close_v3_op);
TF_OP_CONVERTER(translate_tensor_array_concat_v3_op);
TF_OP_CONVERTER(translate_tensor_array_gather_v3_op);
TF_OP_CONVERTER(translate_tensor_array_read_v3_op);
TF_OP_CONVERTER(translate_tensor_array_scatter_v3_op);
TF_OP_CONVERTER(translate_tensor_array_size_v3_op);
TF_OP_CONVERTER(translate_tensor_array_v3_op);
TF_OP_CONVERTER(translate_tensor_array_write_v3_op);
TF_OP_CONVERTER(translate_varhandle_op); TF_OP_CONVERTER(translate_varhandle_op);
TF_OP_CONVERTER(translate_variable_op); TF_OP_CONVERTER(translate_variable_op);
TF_OP_CONVERTER(translate_varisinitialized_op); TF_OP_CONVERTER(translate_varisinitialized_op);
@ -174,6 +182,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"Gather", CreatorFunction(translate_gather_op)}, {"Gather", CreatorFunction(translate_gather_op)},
{"GatherV2", CreatorFunction(translate_gather_v2_op)}, {"GatherV2", CreatorFunction(translate_gather_v2_op)},
{"GatherNd", CreatorFunction(translate_gather_nd_op)}, {"GatherNd", CreatorFunction(translate_gather_nd_op)},
{"GatherTree", CreatorFunction(translate_gather_tree_op)},
{"Addons>GatherTree", CreatorFunction(translate_gather_tree_op)},
{"HashTable", CreatorFunction(translate_hash_table_op)}, {"HashTable", CreatorFunction(translate_hash_table_op)},
{"HashTableV2", CreatorFunction(translate_hash_table_op)}, {"HashTableV2", CreatorFunction(translate_hash_table_op)},
{"Identity", CreatorFunction(translate_identity_op)}, {"Identity", CreatorFunction(translate_identity_op)},
@ -269,6 +279,14 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"StatelessWhile", CreatorFunction(translate_while_op)}, {"StatelessWhile", CreatorFunction(translate_while_op)},
{"StridedSlice", CreatorFunction(translate_strided_slice_op)}, {"StridedSlice", CreatorFunction(translate_strided_slice_op)},
{"Switch", CreatorFunction(translate_switch_op)}, {"Switch", CreatorFunction(translate_switch_op)},
{"TensorArrayCloseV3", CreatorFunction(translate_tensor_array_close_v3_op)},
{"TensorArrayConcatV3", CreatorFunction(translate_tensor_array_concat_v3_op)},
{"TensorArrayGatherV3", CreatorFunction(translate_tensor_array_gather_v3_op)},
{"TensorArrayReadV3", CreatorFunction(translate_tensor_array_read_v3_op)},
{"TensorArrayScatterV3", CreatorFunction(translate_tensor_array_scatter_v3_op)},
{"TensorArraySizeV3", CreatorFunction(translate_tensor_array_size_v3_op)},
{"TensorArrayV3", CreatorFunction(translate_tensor_array_v3_op)},
{"TensorArrayWriteV3", CreatorFunction(translate_tensor_array_write_v3_op)},
{"TensorListFromTensor", CreatorFunction(translate_tensor_list_from_tensor_op)}, {"TensorListFromTensor", CreatorFunction(translate_tensor_list_from_tensor_op)},
{"TensorListGetItem", CreatorFunction(translate_tensor_list_get_item_op)}, {"TensorListGetItem", CreatorFunction(translate_tensor_list_get_item_op)},
{"TensorListLength", CreatorFunction(translate_tensor_list_length_op)}, {"TensorListLength", CreatorFunction(translate_tensor_list_length_op)},

View File

@ -423,7 +423,7 @@ shared_ptr<v5::Loop> create_loop_for_tf_while(const std::string& while_node_name
FRONT_END_GENERAL_CHECK( FRONT_END_GENERAL_CHECK(
cond_results.size() == 1 && cond_results[0], cond_results.size() == 1 && cond_results[0],
"[TensorFlow Frontend] Internal error or inconsistent model: condition body must contain one Result node."); "[TensorFlow Frontend] Internal error or inconsistent model: condition body must contain one Result node.");
auto body_condition_output_idx = static_cast<int64_t>(body_results.size()); auto body_condition_output_idx = body_results.size();
body_model->add_results(cond_results); body_model->add_results(cond_results);
// type setting for body graph parameters is needed for TensorList support since DT_VARIANT type is present // type setting for body graph parameters is needed for TensorList support since DT_VARIANT type is present
@ -435,14 +435,18 @@ shared_ptr<v5::Loop> create_loop_for_tf_while(const std::string& while_node_name
loop->set_function(body_model); loop->set_function(body_model);
// body_results may contain less nodes than body_params that means back edge exists not for all body_params // body_results may contain less nodes than body_params that means back edge exists not for all body_params
for (size_t input_ind = 0; input_ind < static_cast<size_t>(body_condition_output_idx); ++input_ind) { for (size_t input_ind = 0; input_ind < body_condition_output_idx; ++input_ind) {
loop->set_merged_input(body_params[input_ind], ov_inputs[input_ind], body_results[input_ind]->input_value(0)); loop->set_merged_input(body_params[input_ind], ov_inputs[input_ind], body_results[input_ind]->input_value(0));
} }
loop->set_special_body_ports({-1, body_condition_output_idx}); loop->set_special_body_ports({-1, static_cast<int64_t>(body_condition_output_idx)});
// set invariant inputs for the loop
for (size_t input_ind = body_condition_output_idx; input_ind < input_size; ++input_ind) {
loop->set_invariant_input(body_params[input_ind], ov_inputs[input_ind]);
}
// set external outputs for Loop node // set external outputs for Loop node
// do not get execution condition outside of the Loop node // do not get execution condition outside of the Loop node
for (size_t output_ind = 0; output_ind < static_cast<size_t>(body_condition_output_idx); ++output_ind) { for (size_t output_ind = 0; output_ind < body_condition_output_idx; ++output_ind) {
loop->get_iter_value(body_results[output_ind]); loop->get_iter_value(body_results[output_ind]);
} }
loop->validate_and_infer_types(); loop->validate_and_infer_types();

View File

@ -15,7 +15,8 @@ static const std::vector<std::string> models{
std::string("2in_2out/2in_2out.pb"), std::string("2in_2out/2in_2out.pb"),
std::string("forward_edge_model/forward_edge_model.pbtxt"), std::string("forward_edge_model/forward_edge_model.pbtxt"),
std::string("forward_edge_model2/forward_edge_model2.pbtxt"), std::string("forward_edge_model2/forward_edge_model2.pbtxt"),
std::string("concat_with_non_constant_axis/concat_with_non_constant_axis.pbtxt")}; std::string("concat_with_non_constant_axis/concat_with_non_constant_axis.pbtxt"),
std::string("gather_tree_model/gather_tree_model.pbtxt")};
INSTANTIATE_TEST_SUITE_P(TFConvertModelTest, INSTANTIATE_TEST_SUITE_P(TFConvertModelTest,
FrontEndConvertModelTest, FrontEndConvertModelTest,

View File

@ -0,0 +1,103 @@
node {
name: "step_ids"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 20
}
dim {
size: 2
}
dim {
size: 30
}
}
}
}
}
node {
name: "parent_ids"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 20
}
dim {
size: 2
}
dim {
size: 30
}
}
}
}
}
node {
name: "max_seq_len"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 2
}
}
}
}
}
node {
name: "end_token"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "shape"
value {
shape {
}
}
}
}
node {
name: "Addons>GatherTree"
op: "Addons>GatherTree"
input: "step_ids"
input: "parent_ids"
input: "max_seq_len"
input: "end_token"
attr {
key: "T"
value {
type: DT_INT32
}
}
}

View File

@ -72,6 +72,7 @@ OP_CONVERTER_NAMED(translate_fused_batch_norm_op);
OP_CONVERTER(translate_gather_op); OP_CONVERTER(translate_gather_op);
OP_CONVERTER(translate_gather_v2_op); OP_CONVERTER(translate_gather_v2_op);
OP_CONVERTER(translate_gather_nd_op); OP_CONVERTER(translate_gather_nd_op);
OP_CONVERTER(translate_gather_tree_op);
OP_CONVERTER(translate_identity_op); OP_CONVERTER(translate_identity_op);
OP_CONVERTER(translate_identity_n_op); OP_CONVERTER(translate_identity_n_op);
OP_CONVERTER(translate_input_arg_op); OP_CONVERTER(translate_input_arg_op);

View File

@ -33,20 +33,34 @@ public:
ov::PartialShape output_data_shape = ov::PartialShape::dynamic(); ov::PartialShape output_data_shape = ov::PartialShape::dynamic();
auto input_size = get_input_size(); auto input_size = get_input_size();
bool merge_output_shape = true;
for (size_t input_ind = 0; input_ind < input_size; ++input_ind) { for (size_t input_ind = 0; input_ind < input_size; ++input_ind) {
auto input_type = get_input_element_type(input_ind); auto input_type = get_input_element_type(input_ind);
if (input_type.is_static()) { if (input_type.is_static()) {
output_data_type = input_type; output_data_type = input_type;
} }
// check if it still needs to merge input shapes auto input_shape = get_input_partial_shape(input_ind);
// if yes, it tries to merge them if (input_shape.rank().is_dynamic()) {
if (merge_output_shape && continue;
!PartialShape::merge_into(output_data_shape, get_input_partial_shape(input_ind))) { }
merge_output_shape = false;
// reset output shape to dynamic rank if (output_data_shape.rank().is_dynamic()) {
// firstly met shape of static rank
// immediately use this shape of static rank
output_data_shape = input_shape;
} else if (output_data_shape.rank().is_static() &&
output_data_shape.rank().get_length() != input_shape.rank().get_length()) {
// different inputs have different rank means output must be of a dynamic rank
output_data_shape = ov::PartialShape::dynamic(); output_data_shape = ov::PartialShape::dynamic();
break;
} else {
auto output_rank = output_data_shape.rank().get_length();
for (int64_t dim_ind = 0; dim_ind < output_rank; ++dim_ind) {
if (input_shape[dim_ind] != output_data_shape[dim_ind]) {
// different inputs can have different dimensions so it must combine them
output_data_shape[dim_ind] = ov::Dimension::dynamic();
}
}
} }
} }

View File

@ -43,6 +43,10 @@ public:
producer_output_port_idx = m_producer_output_port_idx; producer_output_port_idx = m_producer_output_port_idx;
} }
void set_output_shape_and_type(const ov::PartialShape& output_shape, const ov::element::Type& output_type) {
set_output_type(0, output_type, output_shape);
}
private: private:
bool m_back_edge_set; bool m_back_edge_set;
std::string m_producer_name; std::string m_producer_name;

View File

@ -0,0 +1,60 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <vector>
#include "internal_operation.hpp"
namespace ov {
namespace frontend {
namespace tensorflow {
// Internal operation for TensorArrayV3
// An array of Tensors of given size
// It has two outputs:
// 1. handle - resource (a reference) for tensor array
// 2. flow_out - float type will be used for storing tensor array
class TensorArrayV3 : public InternalOperation {
public:
OPENVINO_OP("TensorArrayV3", "ov::frontend::tensorflow", InternalOperation);
TensorArrayV3(const Output<Node>& size,
const ov::element::Type element_type,
const std::shared_ptr<DecoderBase>& decoder = std::make_shared<DecoderFake>())
: InternalOperation(decoder, OutputVector{size}, 2, "TensorArrayV3"),
m_element_type(element_type),
m_element_rank(-1) {
validate_and_infer_types();
}
void validate_and_infer_types() override {
set_output_type(0, m_element_type, ov::PartialShape::dynamic());
set_output_type(1, m_element_type, ov::PartialShape::dynamic());
}
ov::element::Type get_element_type() const {
return m_element_type;
}
int64_t get_element_rank() const {
return m_element_rank;
}
void set_element_rank(int64_t element_rank) {
FRONT_END_GENERAL_CHECK(
element_rank >= 0,
"[TensorFlow Frontend] internal error: negavite element rank tries to set for TensorArrayV3");
m_element_rank = element_rank;
}
private:
ov::element::Type m_element_type;
int64_t m_element_rank;
};
} // namespace tensorflow
} // namespace frontend
} // namespace ov

View File

@ -0,0 +1,29 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include <utility>
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
namespace ov {
namespace frontend {
namespace tensorflow {
namespace pass {
// This transformation replaces internal operation TensorArrayV3 with a Constant
// that simulates initial state of tensor array container
class TensorArrayV3Replacer : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::frontend::tensorflow::pass::TensorArrayV3Replacer");
TensorArrayV3Replacer();
};
} // namespace pass
} // namespace tensorflow
} // namespace frontend
} // namespace ov

View File

@ -0,0 +1,71 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "helper_transforms/tensor_array_v3_replacer.hpp"
#include "helper_ops/tensor_array.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"
using namespace std;
using namespace ov;
using namespace ov::op;
using namespace ov::pass;
ov::frontend::tensorflow::pass::TensorArrayV3Replacer::TensorArrayV3Replacer() {
auto tensor_array_v3 = pattern::wrap_type<TensorArrayV3>();
matcher_pass_callback callback = [=](pattern::Matcher& m) {
NodeRegistry rg;
auto tensor_array_v3 = dynamic_pointer_cast<TensorArrayV3>(m.get_match_root());
if (!tensor_array_v3) {
return false;
}
int32_t tensor_element_rank = static_cast<int32_t>(tensor_array_v3->get_element_rank());
if (tensor_element_rank < 0) {
return false;
}
// retrieve all TensorArrayV3 inputs
auto size = tensor_array_v3->input_value(0);
auto element_type = tensor_array_v3->get_element_type();
// adjust size to have it of shape [1] for further concatenation with element shape
auto new_size_shape = rg.make<v0::Constant>(element::i32, Shape{1}, 1);
auto new_size = rg.make<v1::Reshape>(size, new_size_shape, false);
// create a vector of size element_shape.rank() with ones
// and compute a shape of initial tensor array [size, 1, ..., 1]
Output<Node> target_shape;
if (tensor_element_rank == 0) {
target_shape = new_size->output(0);
} else {
vector<int32_t> ones(tensor_element_rank, 1);
auto ones_const = rg.make<v0::Constant>(element::i32, Shape{ones.size()}, ones);
target_shape = rg.make<v0::Concat>(OutputVector{new_size, ones_const}, 0)->output(0);
}
// create initial tensor array
auto scalar_value = make_shared<v0::Constant>(element_type, Shape{}, vector<int32_t>{0});
auto initial_tensor_array = make_shared<v3::Broadcast>(scalar_value, target_shape);
// preserve names of the node and the output tensor
initial_tensor_array->set_friendly_name(tensor_array_v3->get_friendly_name());
copy_runtime_info(tensor_array_v3, rg.get());
ov::replace_node(tensor_array_v3,
ov::OutputVector{initial_tensor_array->output(0), initial_tensor_array->output(0)});
return true;
};
auto m =
std::make_shared<pattern::Matcher>(tensor_array_v3, "ov::frontend::tensorflow::pass::TensorArrayV3Replacer");
register_matcher(m, callback);
}

View File

@ -0,0 +1,39 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/op/gather_tree.hpp"
#include "common_op_table.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/reshape.hpp"
using namespace std;
using namespace ov::op;
namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {
OutputVector translate_gather_tree_op(const NodeContext& node) {
default_op_checks(node, 4, {"GatherTree", "Addons>GatherTree"});
auto step_ids = node.get_input(0);
auto parent_ids = node.get_input(1);
auto max_sequence_lengths = node.get_input(2);
auto end_token = node.get_input(3);
// adjust end_token that must be a scalar
auto new_shape_end_token = make_shared<v0::Constant>(element::i32, Shape{0}, vector<int32_t>{});
end_token = make_shared<v1::Reshape>(end_token, new_shape_end_token, false);
auto gather_tree = make_shared<v1::GatherTree>(step_ids, parent_ids, max_sequence_lengths, end_token);
set_node_name(node.get_name(), gather_tree);
return {gather_tree};
}
} // namespace op
} // namespace tensorflow
} // namespace frontend
} // namespace ov

View File

@ -513,7 +513,7 @@ void TensorIterator::createPrimitive() {
lastUsedCond = initial_cond_check->getStatus(); lastUsedCond = initial_cond_check->getStatus();
} }
if (isDynamicNode()) if (runAsDynamic())
prepareDynamicBuffers(); prepareDynamicBuffers();
Node::createPrimitive(); Node::createPrimitive();
@ -556,7 +556,7 @@ void TensorIterator::prepareParams() {
prepareContinueCond(); prepareContinueCond();
prepareLoopBodyCurrentIteration(); prepareLoopBodyCurrentIteration();
if (!isDynamicNode()) { if (!runAsDynamic()) {
prepareOutputPorts(); prepareOutputPorts();
prepareBackEdges(); prepareBackEdges();
} }
@ -568,6 +568,12 @@ void TensorIterator::prepareParams() {
} }
void TensorIterator::execute(dnnl::stream strm) { void TensorIterator::execute(dnnl::stream strm) {
//Special case, the subgraph is dynamic while the node has all static shapes
if (runAsDynamic()) {
executeDynamicImpl(strm);
return;
}
sub_graph.ResetInferCount(); sub_graph.ResetInferCount();
bool continue_cond = initial_cond_check->getStatus(); bool continue_cond = initial_cond_check->getStatus();
@ -872,6 +878,10 @@ int TensorIterator::getNumIteration(const std::vector<PortMap>& inputPortMap, co
return numIterations; return numIterations;
} }
bool TensorIterator::runAsDynamic() const {
return isDynamicNode() || Graph::Status::ReadyDynamic == sub_graph.getStatus();
}
bool TensorIterator::created() const { bool TensorIterator::created() const {
return getType() == Type::TensorIterator; return getType() == Type::TensorIterator;
} }

View File

@ -138,6 +138,7 @@ private:
void reshapeAndFillOutput(dnnl::stream strm); void reshapeAndFillOutput(dnnl::stream strm);
bool checkForInputAndBodyShapesInequality() const; bool checkForInputAndBodyShapesInequality() const;
int getNumIteration(const std::vector<PortMap>& inputPortMap, const std::vector<PortMap>& outputPortMap) const; int getNumIteration(const std::vector<PortMap>& inputPortMap, const std::vector<PortMap>& outputPortMap) const;
bool runAsDynamic() const;
ExtensionManager::Ptr ext_mng; ExtensionManager::Ptr ext_mng;
Graph sub_graph; Graph sub_graph;

View File

@ -371,6 +371,65 @@ protected:
} }
}; };
class StaticLoopDynamicSubgraphCPUTest : public SubgraphBaseTest {
void SetUp() override {
InputShape input_shape = {{25, 1, 1}, {{25, 1, 1}}};
InputShape input_exec_flag_shape = {{1}, {{1}}};
targetDevice = ov::test::utils::DEVICE_CPU;
ElementType netType = ov::element::f32;
init_input_shapes({input_shape, input_exec_flag_shape});
ov::ParameterVector params;
params.push_back(std::make_shared<ov::op::v0::Parameter>(netType, inputDynamicShapes[0]));
// exec_condition
params.push_back(std::make_shared<ov::op::v0::Parameter>(ov::element::boolean, inputDynamicShapes[1]));
auto trip_count_input = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, 2);
auto body_condition_const = std::make_shared<ov::op::v0::Constant>(ov::element::boolean, ov::Shape{1}, true);
// Body parameters
ov::ParameterVector body_params = {std::make_shared<ov::op::v0::Parameter>(netType, ov::PartialShape{25, 1, -1})};
// Body
auto broadcast_target_shape = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{3}, std::vector<int64_t>{25, 1, 256});
auto broadcast_axis_mapping = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, 0);
auto broadcast = std::make_shared<ov::op::v3::Broadcast>(body_params[0], broadcast_target_shape);
auto body = std::make_shared<ov::Model>(ov::OutputVector{body_condition_const, broadcast}, body_params);
auto loop = std::make_shared<ov::op::v5::Loop>(trip_count_input, params[1]);
loop->set_function(body);
loop->set_special_body_ports(ov::op::v5::Loop::SpecialBodyPorts{-1, 0});
loop->set_merged_input(body_params.front(), params.front(), broadcast);
auto out0 = loop->get_iter_value(body_condition_const, -1);
auto out1 = loop->get_iter_value(broadcast, -1);
auto result0 = std::make_shared<ov::op::v0::Result>(out0);
auto result1 = std::make_shared<ov::op::v0::Result>(out1);
function = std::make_shared<ov::Model>(ov::ResultVector{result0, result1}, params, "loop");
}
void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override {
inputs.clear();
const auto& funcInputs = function->inputs();
for (size_t i = 0; i < funcInputs.size(); ++i) {
const auto& funcInput = funcInputs[i];
ov::Tensor tensor;
if (i == 1) {
tensor = ov::Tensor(funcInput.get_element_type(), targetInputStaticShapes[i]);
auto* dataPtr = tensor.data<bool>();
*dataPtr = true;
} else {
tensor = ov::test::utils::create_and_fill_tensor(funcInput.get_element_type(), targetInputStaticShapes[i], 2560, 0, 256);
}
inputs.insert({funcInput.get_node_shared_ptr(), tensor});
}
}
};
TEST_P(LoopLayerCPUTest, CompareWithRefs) { TEST_P(LoopLayerCPUTest, CompareWithRefs) {
run(); run();
} }
@ -387,6 +446,10 @@ TEST_P(LoopForConcatLayerCPUTest, CompareWithRefs) {
run(); run();
} }
TEST_F(StaticLoopDynamicSubgraphCPUTest, smoke_StaticLoopWithDynSubgraph) {
run();
}
namespace { namespace {
const std::vector<ElementType> inputPrecisions = { const std::vector<ElementType> inputPrecisions = {

View File

@ -98,7 +98,7 @@ def summarize_graph(model_path, output_nodes_for_freeze=None, reshape_net=None):
variables = list() variables = list()
outputs = list() outputs = list()
graph = load_graph(model_path, output_nodes_for_freeze) graph = load_graph(model_path, output_nodes_for_freeze)
unlikely_output_types = ['Const', 'Assign', 'NoOp', 'Placeholder', 'Assert', 'switch_t', 'switch_f'] unlikely_output_types = ['Const', 'Assign', 'NoOp', 'Placeholder', 'Assert', 'switch_t', 'switch_f', 'TensorArrayCloseV3']
control_dependents_map = collect_control_dependencies(graph) control_dependents_map = collect_control_dependencies(graph)
for node in graph.as_graph_def().node: for node in graph.as_graph_def().node:
if node.op == 'Placeholder': if node.op == 'Placeholder':

View File

@ -0,0 +1,200 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pytest
import tensorflow as tf
from common.tf_layer_test_class import CommonTFLayerTest
def create_tensor_array(data_shape, data_type):
size = data_shape[0]
data = tf.compat.v1.placeholder(data_type, data_shape, 'data')
indices = tf.compat.v1.placeholder(tf.int32, [size], 'indices')
size_const = tf.constant(size, dtype=tf.int32, shape=[])
handle, flow = tf.raw_ops.TensorArrayV3(size=size_const, dtype=tf.as_dtype(data_type))
flow = tf.raw_ops.TensorArrayScatterV3(handle=handle, indices=indices, value=data, flow_in=flow)
return handle, flow
class TestTensorArraySizeV3(CommonTFLayerTest):
def _prepare_input(self, inputs_info):
assert 'data' in inputs_info
assert 'indices' in inputs_info
data_shape = inputs_info['data']
inputs_data = {}
rng = np.random.default_rng()
inputs_data['data'] = rng.integers(-10, 10, data_shape).astype(self.data_type)
inputs_data['indices'] = rng.permutation(self.size).astype(np.int32)
return inputs_data
def create_tensor_array_size_v3(self, data_shape, data_type):
size = data_shape[0]
self.data_type = data_type
self.size = size
tf.compat.v1.reset_default_graph()
# Create the graph and model
with tf.compat.v1.Session() as sess:
handle, flow = create_tensor_array(data_shape, data_type)
tf.raw_ops.TensorArraySizeV3(handle=handle, flow_in=flow)
tf.raw_ops.TensorArrayCloseV3(handle=handle)
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def
return tf_net, None
test_data_basic = [
dict(data_shape=[5], data_type=np.float32),
dict(data_shape=[10, 20, 30], data_type=np.int32),
]
@pytest.mark.parametrize("params", test_data_basic)
@pytest.mark.precommit_tf_fe
@pytest.mark.nightly
def test_tensor_array_size_v3(self, params, ie_device, precision, ir_version, temp_dir,
use_new_frontend, use_old_api):
self._test(*self.create_tensor_array_size_v3(**params),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
class TestTensorArrayReadV3(CommonTFLayerTest):
def _prepare_input(self, inputs_info):
assert 'data' in inputs_info
assert 'indices' in inputs_info
data_shape = inputs_info['data']
inputs_data = {}
rng = np.random.default_rng()
inputs_data['data'] = rng.integers(-10, 10, data_shape).astype(self.data_type)
inputs_data['index_to_read'] = rng.integers(0, data_shape[0], []).astype(np.int32)
inputs_data['indices'] = rng.permutation(self.size).astype(np.int32)
return inputs_data
def create_tensor_array_read_v3(self, data_shape, data_type):
size = data_shape[0]
self.data_type = data_type
self.size = size
tf.compat.v1.reset_default_graph()
# Create the graph and model
with tf.compat.v1.Session() as sess:
handle, flow = create_tensor_array(data_shape, data_type)
index_to_read = tf.compat.v1.placeholder(tf.int32, [], 'index_to_read')
tf.raw_ops.TensorArrayReadV3(handle=handle, index=index_to_read, flow_in=flow,
dtype=tf.dtypes.as_dtype(data_type))
tf.raw_ops.TensorArrayCloseV3(handle=handle)
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def
return tf_net, None
test_data_basic = [
dict(data_shape=[6], data_type=np.float32),
dict(data_shape=[8, 5, 6, 10], data_type=np.int32),
]
@pytest.mark.parametrize("params", test_data_basic)
@pytest.mark.precommit_tf_fe
@pytest.mark.nightly
def test_tensor_array_read_v3(self, params, ie_device, precision, ir_version, temp_dir,
use_new_frontend, use_old_api):
self._test(*self.create_tensor_array_read_v3(**params),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
class TestTensorArrayWriteGatherV3(CommonTFLayerTest):
def _prepare_input(self, inputs_info):
assert 'data' in inputs_info
assert 'indices' in inputs_info
assert 'value_to_write' in inputs_info
data_shape = inputs_info['data']
value_shape = inputs_info['value_to_write']
inputs_data = {}
rng = np.random.default_rng()
inputs_data['data'] = rng.integers(-10, 10, data_shape).astype(self.data_type)
inputs_data['value_to_write'] = rng.integers(-10, 10, value_shape).astype(self.data_type)
indices_data = rng.permutation(self.size).astype(np.int32)
inputs_data['indices'] = np.delete(indices_data, np.where(indices_data == self.index_to_write))
return inputs_data
def create_tensor_array_write_v3(self, size, data_shape, data_type, index_to_write, indices_to_gather):
self.data_type = data_type
self.size = size
self.index_to_write = index_to_write
tf.compat.v1.reset_default_graph()
# Create the graph and model
with tf.compat.v1.Session() as sess:
value_to_write = tf.compat.v1.placeholder(data_type, data_shape[1:], 'value_to_write')
index_to_write_const = tf.constant(index_to_write, dtype=tf.int32, shape=[])
indices_to_gather_const = tf.constant(indices_to_gather, dtype=tf.int32, shape=[len(indices_to_gather)])
data = tf.compat.v1.placeholder(data_type, data_shape, 'data')
indices = tf.compat.v1.placeholder(tf.int32, [size - 1], 'indices')
size_const = tf.constant(size, dtype=tf.int32, shape=[])
handle, flow = tf.raw_ops.TensorArrayV3(size=size_const, dtype=tf.as_dtype(data_type))
flow = tf.raw_ops.TensorArrayScatterV3(handle=handle, indices=indices, value=data, flow_in=flow)
flow = tf.raw_ops.TensorArrayWriteV3(handle=handle, index=index_to_write_const,
value=value_to_write, flow_in=flow)
tf.raw_ops.TensorArrayGatherV3(handle=handle, indices=indices_to_gather_const, flow_in=flow,
dtype=tf.dtypes.as_dtype(data_type))
tf.raw_ops.TensorArrayCloseV3(handle=handle)
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def
return tf_net, None
test_data_basic = [
dict(size=7, data_shape=[6], data_type=np.float32, index_to_write=3, indices_to_gather=[0, 3, 1]),
dict(size=10, data_shape=[9, 2, 4], data_type=np.int32, index_to_write=2, indices_to_gather=[2, 1, 4, 3]),
]
@pytest.mark.parametrize("params", test_data_basic)
@pytest.mark.precommit_tf_fe
@pytest.mark.nightly
def test_tensor_array_write_v3(self, params, ie_device, precision, ir_version, temp_dir,
use_new_frontend, use_old_api):
self._test(*self.create_tensor_array_write_v3(**params),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
class TestTensorArrayConcatV3(CommonTFLayerTest):
def _prepare_input(self, inputs_info):
assert 'data' in inputs_info
assert 'indices' in inputs_info
data_shape = inputs_info['data']
inputs_data = {}
rng = np.random.default_rng()
inputs_data['data'] = rng.integers(-10, 10, data_shape).astype(self.data_type)
inputs_data['indices'] = rng.permutation(self.size).astype(np.int32)
return inputs_data
def create_tensor_array_concat_v3(self, data_shape, data_type):
size = data_shape[0]
self.data_type = data_type
self.size = size
tf.compat.v1.reset_default_graph()
# Create the graph and model
with tf.compat.v1.Session() as sess:
handle, flow = create_tensor_array(data_shape, data_type)
tensor_array_concat_v3 = tf.raw_ops.TensorArrayConcatV3(handle=handle, flow_in=flow,
dtype=tf.as_dtype(data_type))
tf.identity(tensor_array_concat_v3[0], name='values')
tf.identity(tensor_array_concat_v3[1], name='length')
tf.raw_ops.TensorArrayCloseV3(handle=handle)
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def
return tf_net, None
test_data_basic = [
dict(data_shape=[5, 3, 11, 2], data_type=np.int32),
]
@pytest.mark.parametrize("params", test_data_basic)
@pytest.mark.precommit_tf_fe
@pytest.mark.nightly
def test_tensor_array_concat_v3(self, params, ie_device, precision, ir_version, temp_dir,
use_new_frontend, use_old_api):
self._test(*self.create_tensor_array_concat_v3(**params),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api)

View File

@ -312,8 +312,8 @@ def update_fallback_with_conversion_error(use_new_frontend: bool, is_tf: bool, e
conversion_error_re = r"^(\[TensorFlow\ Frontend\]\ Internal\ error\,\ no\ translator\ found\ for\ operation\(s\)\:\ )((\w+)(\,\ \w+)*)$" conversion_error_re = r"^(\[TensorFlow\ Frontend\]\ Internal\ error\,\ no\ translator\ found\ for\ operation\(s\)\:\ )((\w+)(\,\ \w+)*)$"
conversion_error_match = re.findall(conversion_error_re, ex_msg, re.MULTILINE) conversion_error_match = re.findall(conversion_error_re, ex_msg, re.MULTILINE)
all_fallback_operations = [ all_fallback_operations = [
# corresponds to TF1 TensorList operation # corresponds to TF1 While operation
"TensorArrayScatterV3", "TensorArrayV3", "TensorArraySizeV3", "TensorArrayGatherV3", "LoopCond", "Enter", "NextIteration", "Exit", "Switch", "Merge",
# corresponds to operations with complex tensors # corresponds to operations with complex tensors
"FFT", "FFT2D", "FFT3D", "IFFT", "IFFT2D", "IFFT3D", "FFT", "FFT2D", "FFT3D", "IFFT", "IFFT2D", "IFFT3D",
"RFFT", "RFFT2D", "RFFT3D", "IRFFT", "IRFFT2D", "IRFFT3D", "RFFT", "RFFT2D", "RFFT3D", "IRFFT", "IRFFT2D", "IRFFT3D",

View File

@ -235,17 +235,13 @@ class TestMoFreezePlaceholderTFFE(unittest.TestCase):
freeze_placeholder_with_value, freeze_placeholder_with_value,
input_shape, only_conversion, True) input_shape, only_conversion, True)
def test_conversion_failure_fallback_default(self): def test_conversion_tf1_while_default(self):
self.basic("ctc_model_based.pbtxt", None, None, None, None, self.basic("ctc_model_based.pbtxt", None, None, None, None,
None, None, True, True, False, False) None, None, True, True, False, False)
@unittest.skipIf(platform == 'darwin', reason="Ticket - 122182") def test_conversion_tf1_while_use_new_frontend(self):
def test_conversion_failure_fallback_use_new_frontend(self): self.basic("ctc_model_based.pbtxt", None, None, None, None,
with self.assertRaisesRegex(Exception, None, None, True, True, True, False)
"\[TensorFlow Frontend\] Internal error, no translator found for operation\(s\)\: "
"TensorArrayGatherV3\, TensorArrayReadV3\, TensorArraySizeV3\, TensorArrayV3\, TensorArrayWriteV3"):
self.basic("ctc_model_based.pbtxt", None, None, None, None,
None, None, True, True, True, False)
@unittest.skip("88349: Fix auto-pruning in legacy FE") @unittest.skip("88349: Fix auto-pruning in legacy FE")
def test_conversion_model_oneshot_iterator_use_legacy_frontend(self): def test_conversion_model_oneshot_iterator_use_legacy_frontend(self):