[TF FE] Provide full support of TF1 Control flow and TensorArray* ops (#20270)
* [TF FE] Provide full support of TF1 Control flow and TensorArray ops
  Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
* Add missed header for TensorArrayV3 op
* Temporarily disable GRU cell fusion
* Update src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp
* Fix a case when element_shape for TensorArrayV3
* Fix translator for TensorArrayCloseV3
* Update summarize graph with TensorArrayCloseV3
* Add layer tests for TensorArrayScatterV3, Close, Size, Array
* Fix output shape for Merge node
* Remove unused variable
* Fix translator for TensorArrayConcatV3
* Fix translator for TensorArrayConcatV3
* Add layer tests for TensorArrayWriteV3, Gather, and Concat
  Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
* Add translator for GatherTree
* Fix TF FE unit-test for GatherTree
* Fix GatherTree translator
* Fix GatherTree translator to handle 1d end_token
* Fix undeclared parameter issue
* Fix GatherTree unit-test
* Add TensorArrayV3Replacer transformation
* Temporarily disable dangling transformation
* Recover RemoveMultiSubGraphOpDanglingParamsResults transformation
* Recover GRUCellFusion transformation
* Simplify check for GRUCellFusion transformation
* Use proper name for unit-tests
* Simplify translator for TensorArrayWriteV3
  Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
* Fix RemoveMultiSubgraphOpDanglingParamsResults transformation
  Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
* Additional fix for remove_multi_subgraph_op_dangling_params
* Make static TI run a dynamic subgraph
* Dedicated SL test
* Change condition to respect stat shapes
* Adjust test to cover the code path properly
* Recover fallback for still failing case GNMT
---------
Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
Co-authored-by: Maksim Kutakov <maksim.kutakov@intel.com>
Parent: 99dfbb400a
Commit: 009ef5657c
@@ -148,6 +148,15 @@ ov::pass::GRUCellFusion::GRUCellFusion() {
            Bh = rg.make<ov::op::v0::Constant>(WRh.get_element_type(), Shape{1, static_cast<size_t>(hidden_size)}, 0);
        }

+       // perform additional check for applicability of the transformation
+       // without this check, process_weights can fail
+       if (WR.get_partial_shape()[1] != (hidden_size + input_size)) {
+           return false;
+       }
+       if (WRh.get_partial_shape()[1] != (hidden_size + input_size)) {
+           return false;
+       }
+
        Output<Node> Wzrh, Rzrh, Bzrh;
        if (cnt_of_consumers_of_zero_out == 1 && cnt_of_consumers_of_first_out == 2) {
            tie(Wzrh, Rzrh) = process_weights(rg, false, WR, WRh, input_size, hidden_size, axis_0, axis_1);
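A note on the applicability check added above: GRUCellFusion expects WR and WRh to hold the W and R weights concatenated along the second axis, so that dimension must equal input_size + hidden_size; otherwise process_weights would split them incorrectly. A minimal standalone sketch of the same condition using the public ov::PartialShape API (the helper name and variables here are hypothetical, not taken from the pass):

#include "openvino/core/partial_shape.hpp"

// returns true when a fused GRU weight tensor has the layout the pass relies on
bool gru_weights_applicable(const ov::PartialShape& wr_shape, int64_t input_size, int64_t hidden_size) {
    return wr_shape.rank().is_static() && wr_shape.rank().get_length() == 2 &&
           wr_shape[1].is_static() && wr_shape[1].get_length() == input_size + hidden_size;
}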
@@ -116,7 +116,7 @@ bool ov::pass::RemoveMultiSubGraphOpDanglingParamsResults::run_on_model(const st
        }
        // Remove inputs
        bool pass_required = false;
-       std::set<ov::Output<ov::Node>> required_inputs;
+       std::set<uint64_t> required_inputs_indices;
        auto op_inputs = multi_subgraph_op->input_values();
        std::vector<std::vector<size_t>> to_remove_descriptors_indexes;
        to_remove_descriptors_indexes.resize(subgraphs_size);
@@ -133,7 +133,7 @@ bool ov::pass::RemoveMultiSubGraphOpDanglingParamsResults::run_on_model(const st
                } else {
                    // collecting required inputs is needed to detect cases where the input
                    // is not needed in a one body, but the other one uses it (for example If case)
-                   required_inputs.insert(op_inputs[body_in_descriptors[i]->m_input_index]); // only unique
+                   required_inputs_indices.insert(body_in_descriptors[i]->m_input_index);
                }
            }
        }
@@ -148,7 +148,9 @@ bool ov::pass::RemoveMultiSubGraphOpDanglingParamsResults::run_on_model(const st
            }
        };
        auto update_op_inputs_desc = [&subgraphs_size](const std::shared_ptr<op::util::MultiSubGraphOp>& op,
+                                                      std::set<uint64_t>& required_inputs_indices,
                                                       uint64_t removed_loop_idx) {
+           std::set<uint64_t> new_required_inputs_indices;
            for (size_t body_idx = 0; body_idx < subgraphs_size; ++body_idx) {
                auto& descriptors = op->get_input_descriptions(static_cast<int>(body_idx));
                for (auto& desc : descriptors) {
@@ -157,6 +159,14 @@ bool ov::pass::RemoveMultiSubGraphOpDanglingParamsResults::run_on_model(const st
                    }
                }
            }
+           for (auto input_index : required_inputs_indices) {
+               if (input_index > removed_loop_idx) {
+                   new_required_inputs_indices.insert(input_index - 1);
+               } else {
+                   new_required_inputs_indices.insert(input_index);
+               }
+           }
+           required_inputs_indices = new_required_inputs_indices;
        };
        // Remove dangling body params and input and update input descriptors
        for (size_t body_idx = 0; body_idx < subgraphs_size; ++body_idx) {
@@ -174,13 +184,17 @@ bool ov::pass::RemoveMultiSubGraphOpDanglingParamsResults::run_on_model(const st
                    update_body_param_desc(body_in_descriptors,
                                           body_in_descriptors[desc_idx]->m_body_parameter_index);
                    // remove dangling input of MultiSubGraphOp which was not removed earlier
-                   auto& current_input = op_inputs[body_in_descriptors[desc_idx]->m_input_index];
-                   if (std::count(std::begin(required_inputs), std::end(required_inputs), current_input) == 0 &&
+                   auto current_input_idx = body_in_descriptors[desc_idx]->m_input_index;
+                   auto& current_input = op_inputs[current_input_idx];
+                   // the same input tensor can go to different input ports
+                   if (std::count(std::begin(required_inputs_indices),
+                                  std::end(required_inputs_indices),
+                                  current_input_idx) == 0 &&
                        std::count(std::begin(op_inputs), std::end(op_inputs), current_input) > 0) {
-                       op_inputs.erase(std::next(op_inputs.begin(), body_in_descriptors[desc_idx]->m_input_index));
+                       op_inputs.erase(std::next(op_inputs.begin(), current_input_idx));
                        // Move all input indexes (in all bodies) which are after these indicated by
                        // to_remove_descriptors_indexes and are not used in any body
-                       update_op_inputs_desc(multi_subgraph_op, body_in_descriptors[desc_idx]->m_input_index);
+                       update_op_inputs_desc(multi_subgraph_op, required_inputs_indices, current_input_idx);
                    }
                } else {
                    updated_body_in_descriptors.emplace_back(body_in_descriptors[desc_idx]);
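The index bookkeeping added above matters because erasing an input of the MultiSubGraphOp shifts every later input one position to the left, so the set of still-required input indices has to be renumbered after each removal. A small standalone sketch of just that rule (a hypothetical helper, not part of the pass itself):

#include <cstdint>
#include <set>

// renumber required input indices after the input at removed_idx has been erased
std::set<uint64_t> shift_after_removal(const std::set<uint64_t>& required, uint64_t removed_idx) {
    std::set<uint64_t> updated;
    for (auto idx : required) {
        updated.insert(idx > removed_idx ? idx - 1 : idx);
    }
    return updated;
}
// e.g. required = {0, 2, 5} with removed_idx = 1 gives {0, 1, 4}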
@@ -14,6 +14,7 @@
#include "helper_transforms/embedding_segments_feature_fusing.hpp"
#include "helper_transforms/gru_block_cell_replacer.hpp"
#include "helper_transforms/saved_model_unused_remover.hpp"
+#include "helper_transforms/tensor_array_v3_replacer.hpp"
#include "input_model.hpp"
#include "op_table.hpp"
#include "openvino/core/so_extension.hpp"
@@ -491,6 +492,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
    manager.register_pass<pass::EmbeddingSegmentSingleFeatureFusion>();
    manager.register_pass<pass::BlockLSTMReplacer>();
    manager.register_pass<pass::GRUBlockCellReplacer>();
+   manager.register_pass<pass::TensorArrayV3Replacer>();
    manager.register_pass<pass::ConstToResultRemover>();
    manager.register_pass<pass::SwitchMergeResolver>();
    manager.register_pass<ov::pass::UnrollIf>();
@@ -5,6 +5,8 @@
#include "helper_ops/merge.hpp"

#include "common_op_table.hpp"
+#include "helper_ops/enter.hpp"
+#include "helper_ops/next_iteration.hpp"
#include "openvino/frontend/tensorflow/node_context.hpp"
#include "openvino/op/constant.hpp"
#include "utils.hpp"
@@ -24,20 +26,47 @@ OutputVector translate_merge_op(const NodeContext& node) {
    auto node_name = node.get_name();
    default_op_checks(node, 1, {"Merge"});
    int input_size = static_cast<int>(node.get_input_size());
-   OutputVector inputs;
+   OutputVector inputs(input_size);
    for (int input_ind = 0; input_ind < input_size; ++input_ind) {
-       inputs.push_back(node.get_input(input_ind));
+       inputs[input_ind] = node.get_input(input_ind);
    }

    // if Merge node has just one input, there is nothing to merge
    // return the same input and value_index equal to 0
-   if (inputs.size() == 1) {
+   if (input_size == 1) {
        auto value_index = make_shared<v0::Constant>(element::i32, Shape{}, 0);
        value_index->output(0).set_names({node_name + ":1"});
        inputs[0].add_names({node_name + ":0"});
        return OutputVector{inputs[0], value_index};
    }

+   // check if it is a case of TF1 While: Enter, NextIteration are going to Merge node
+   // in this case it can refine output shape and type for NextIteration based on Enter
+   if (input_size == 2) {
+       auto enter = as_type_ptr<Enter>(inputs[0].get_node_shared_ptr());
+       if (!enter) {
+           enter = as_type_ptr<Enter>(inputs[1].get_node_shared_ptr());
+       }
+       auto next_iteration = as_type_ptr<NextIteration>(inputs[0].get_node_shared_ptr());
+       if (!next_iteration) {
+           next_iteration = as_type_ptr<NextIteration>(inputs[1].get_node_shared_ptr());
+       }
+
+       if (enter && next_iteration) {
+           // set output type and shape for NextIteration
+           // borrow them from Enter output
+           auto enter_output_type = enter->output(0).get_element_type();
+           auto enter_output_shape = enter->output(0).get_partial_shape();
+           auto next_iteration_output_shape = PartialShape::dynamic(enter_output_shape.rank());
+           next_iteration->set_output_shape_and_type(next_iteration_output_shape, enter_output_type);
+
+           // reset inputs
+           // refines input shapes and types for Merge node
+           inputs[0] = enter->output(0);
+           inputs[1] = next_iteration->output(0);
+       }
+   }
+
    auto merge_node = make_shared<Merge>(inputs, node.get_decoder());
    set_node_name(node.get_name(), merge_node);
src/frontends/tensorflow/src/op/tensor_array_operations.cpp (new file)
@@ -0,0 +1,332 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "common_op_table.hpp"
#include "helper_ops/enter.hpp"
#include "helper_ops/tensor_array.hpp"
#include "openvino/frontend/tensorflow/node_context.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/maximum.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/scatter_nd_update.hpp"
#include "openvino/op/scatter_update.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "utils.hpp"

using namespace std;
using namespace ov;
using namespace ov::op;
using namespace ov::frontend::tensorflow;

namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {

namespace {
// the function creates the constant imitating initial tensor array container
Output<Node> create_initial_tensor_array_constant(int64_t tensor_element_rank,
                                                  const element::Type& element_type,
                                                  Output<Node> size,
                                                  const string& node_name) {
    // adjust size to have it of shape [1] for further concatenation with element shape
    auto new_size_shape = make_shared<v0::Constant>(element::i32, Shape{1}, 1);
    size = make_shared<v1::Reshape>(size, new_size_shape, false);

    // create a vector of size element_shape.rank() with ones
    // and compute a shape of initial tensor array [size, 1, ..., 1]
    vector<int32_t> ones(tensor_element_rank, 1);
    auto ones_const = make_shared<v0::Constant>(element::i32, Shape{ones.size()}, ones);
    auto target_shape = make_shared<v0::Concat>(OutputVector{size, ones_const}, 0);

    // create initial tensor array
    auto scalar_value = make_shared<v0::Constant>(element_type, Shape{}, vector<int32_t>{0});
    auto initial_tensor_array = make_shared<v3::Broadcast>(scalar_value, target_shape);

    return initial_tensor_array->output(0);
}
}  // namespace

OutputVector translate_tensor_array_v3_op(const NodeContext& node) {
    // TensorArrayV3 has just one input:
    // 0) size to initialize a size of tensor array
    default_op_checks(node, 1, {"TensorArrayV3"});
    auto dtype = node.get_attribute<element::Type>("dtype");
    auto size = node.get_input(0);
    auto element_shape = node.get_attribute<PartialShape>("element_shape");

    if (element_shape.rank().is_static()) {
        auto node_name = node.get_name();
        auto new_output1 =
            create_initial_tensor_array_constant(element_shape.rank().get_length(), dtype, size, node.get_name());
        new_output1.set_names({node_name + ":0"});
        auto new_output2 =
            create_initial_tensor_array_constant(element_shape.rank().get_length(), dtype, size, node.get_name());
        new_output2.set_names({node_name + ":1"});
        return OutputVector{new_output1, new_output2};
    }

    // dynamic case when it is unable retrieve element rank from the attribute
    auto tensor_array_v3 = make_shared<TensorArrayV3>(size, dtype, node.get_decoder());
    set_node_name(node.get_name(), tensor_array_v3);

    return tensor_array_v3->outputs();
}

OutputVector translate_tensor_array_scatter_v3_op(const NodeContext& node) {
    // TensorArrayScatterV3 has four inputs:
    // 0) handle, a Tensor of type resource. The handle to a TensorArray.
    // 1) indices, a Tensor of type int32. The locations at which to write the tensor elements.
    // 2) value, a Tensor. The concatenated tensor to write to the TensorArray
    // 3) flow_in A Tensor of type float32. A float scalar that enforces proper chaining of operations.
    // The operation has one output:
    // 0) flow_out indicates that operation is complete and handle resource is updated
    default_op_checks(node, 4, {"TensorArrayScatterV3"});
    auto indices = node.get_input(1);
    auto value = node.get_input(2);
    // flow_in is used for transferring input tensor array
    auto tensor_array = node.get_input(3);

    // check if producer of tensor_array is TensorArrayV3, internal operation, still
    // if yes, try to replace it with constant container
    if (as_type_ptr<TensorArrayV3>(tensor_array.get_node_shared_ptr()) &&
        value.get_partial_shape().rank().is_static()) {
        // set tensor element rank that gets known from TensorArrayScatterV3 operation
        auto tensor_array_v3 = as_type_ptr<TensorArrayV3>(tensor_array.get_node_shared_ptr());
        TENSORFLOW_OP_VALIDATION(
            node,
            value.get_partial_shape().rank().get_length() > 0,
            "[TensorFlow Frontend] internal error or inconsistent model: value to TensorArrayScatterV3 is a scalar");
        int64_t tensor_element_rank = value.get_partial_shape().rank().get_length() - 1;
        tensor_array_v3->set_element_rank(tensor_element_rank);
    }

    // compute element shape (shape of a tensor in the tensor array) using value
    auto element_shape = make_shared<v3::ShapeOf>(value, element::i32)->output(0);
    auto one_const = make_shared<v0::Constant>(element::i32, Shape{1}, 1);
    auto max_const = make_shared<v0::Constant>(element::i32, Shape{1}, numeric_limits<int32_t>::max());
    element_shape = make_shared<v8::Slice>(element_shape, one_const, max_const, one_const);

    // compute size of tensor array
    auto tensor_array_size = make_shared<v3::ShapeOf>(tensor_array, element::i32)->output(0);
    auto zero_const = make_shared<v0::Constant>(element::i32, Shape{1}, 0);
    tensor_array_size = make_shared<v8::Gather>(tensor_array_size, zero_const, zero_const);

    // compute the new shape for tensor array where new tensors will be inserted
    auto new_shape = make_shared<v0::Concat>(OutputVector{tensor_array_size, element_shape}, 0);
    tensor_array = make_shared<v3::Broadcast>(tensor_array, new_shape);

    // adjust indices for ScatterNDUpdate to have a shape [N, 1] where N is a number of indices
    indices = make_shared<v0::Unsqueeze>(indices, one_const);

    // compute updated tensor array using ScatterNDUpdate
    // value should be of a shape [N, <elem_shape>]
    auto updated_tensor_array = make_shared<v3::ScatterNDUpdate>(tensor_array, indices, value);
    set_node_name(node.get_name(), updated_tensor_array);

    // TensorArrayScatterV3 has just one output flow_out
    // that is used for transferring updated tensor array
    return {updated_tensor_array};
}

OutputVector translate_tensor_array_read_v3_op(const NodeContext& node) {
    // TensorArrayReadV3 read an element from the TensorArray into the output
    // and it has three inputs:
    // 0) handle, a Tensor of type resource. The handle to a TensorArray.
    // 1) index, a Tensor of type int32. The location from which to read the value
    // 2) flow_in A Tensor of type float32. A float scalar that enforces proper chaining of operations.
    // The operation has one output
    // 0) read value from tensor array
    default_op_checks(node, 3, {"TensorArrayReadV3"});
    auto index = node.get_input(1);
    // flow_in is used for transferring input tensor array
    auto tensor_array = node.get_input(2);
    auto dtype = node.get_attribute<element::Type>("dtype");

    // adjust the index to a scalar for using Gather operation
    auto new_shape = make_shared<v0::Constant>(element::i32, Shape{0}, vector<int32_t>{});
    index = make_shared<v1::Reshape>(index, new_shape, false);

    // gather tensor element by the required position
    auto gather_axis = make_shared<v0::Constant>(element::i32, Shape{1}, 0);
    Output<Node> tensor_element = make_shared<v8::Gather>(tensor_array, index, gather_axis);
    tensor_element = make_shared<v0::Convert>(tensor_element, dtype);

    set_node_name(node.get_name(), tensor_element.get_node_shared_ptr());
    return {tensor_element};
}

OutputVector translate_tensor_array_close_v3_op(const NodeContext& node) {
    // TensorArrayCloseV3 deletes the TensorArray from its resource container
    // it outputs nothing
    default_op_checks(node, 1, {"TensorArrayCloseV3"});
    return {};
}

OutputVector translate_tensor_array_size_v3_op(const NodeContext& node) {
    // TensorArraySizeV3 gets the current size of the TensorArray
    // it outputs int32 scalar equal to a size of the tensor array
    default_op_checks(node, 2, {"TensorArraySizeV3"});
    // skip the handle by the first input
    auto tensor_array = node.get_input(1);

    auto size = make_shared<v3::ShapeOf>(tensor_array, element::i32)->output(0);
    auto zero_const = make_shared<v0::Constant>(element::i32, Shape{1}, 0);
    size = make_shared<v8::Gather>(size, zero_const, zero_const);

    // size must be scalar
    auto scalar_shape = make_shared<v0::Constant>(element::i32, Shape{0}, vector<int32_t>{});
    size = make_shared<v1::Reshape>(size, scalar_shape, false);

    set_node_name(node.get_name(), size.get_node_shared_ptr());
    return {size};
}

OutputVector translate_tensor_array_gather_v3_op(const NodeContext& node) {
    // TensorArrayGatherV3 gathers specific elements from the TensorArray into output
    // and it has three inputs:
    // 0) handle, a Tensor of type resource. The handle to a TensorArray.
    // 1) indices, a Tensor of type int32. The location from which to read tensor elements
    // 2) flow_in A Tensor of type float32. A float scalar that enforces proper chaining of operations.
    // The operation has one output
    // 0) value with read tensor elements
    // it outputs int32 scalar equal to a size of the tensor array
    default_op_checks(node, 3, {"TensorArrayGatherV3"});
    // skip the handle by the first input
    auto indices = node.get_input(1);
    // flow_in serves for transferring tensor array
    // handle input is ignored
    auto tensor_array = node.get_input(2);
    auto dtype = node.get_attribute<element::Type>("dtype");
    auto element_shape = node.get_attribute<PartialShape>("element_shape", PartialShape::dynamic());

    // gather tensor element by the required position
    auto gather_axis = make_shared<v0::Constant>(element::i32, Shape{1}, 0);
    Output<Node> tensor_element = make_shared<v8::Gather>(tensor_array, indices, gather_axis);
    tensor_element = make_shared<v0::Convert>(tensor_element, dtype);

    // concretize tensor_element shape if this is specified
    if (tensor_element.get_partial_shape().rank().is_dynamic() && element_shape.is_static()) {
        auto element_shape_value = element_shape.get_shape();
        auto element_shape_const =
            make_shared<v0::Constant>(element::i32, Shape{element_shape_value.size()}, element_shape_value);
        auto size = make_shared<v3::ShapeOf>(tensor_array, element::i32)->output(0);
        auto zero_const = make_shared<v0::Constant>(element::i32, Shape{1}, 0);
        size = make_shared<v8::Gather>(size, zero_const, zero_const);
        auto new_shape = make_shared<v0::Concat>(OutputVector{size, element_shape_const}, 0);
        tensor_element = make_shared<v1::Reshape>(tensor_element, new_shape, false);
    }

    set_node_name(node.get_name(), tensor_element.get_node_shared_ptr());
    return {tensor_element};
}

OutputVector translate_tensor_array_concat_v3_op(const NodeContext& node) {
    // TensorArrayConcatV3 Concat the elements from the TensorArray into value
    // and it has two inputs:
    // 0) handle, a Tensor of type resource. The handle to a TensorArray.
    // 1) flow_in A Tensor of type float32. A float scalar that enforces proper chaining of operations.
    // The operation has one output
    // 0) concatenated value by the first dimension
    default_op_checks(node, 2, {"TensorArrayConcatV3"});
    // flow_in serves for transferring tensor array
    // handle input is ignored
    auto tensor_array = node.get_input(1);
    auto dtype = node.get_attribute<element::Type>("dtype");

    // since tensor array saves tensor elements in the concatenated form by the first dimension
    // and for this operation they should be concatenated by the first dimension of the tensor element
    // it needs to combine the first two dimensions
    // tensor array is of shape [k, n0, n1, ..., nd]
    // 1. compute element shape excluding the first dimension
    auto zero_const = make_shared<v0::Constant>(element::i32, Shape{1}, 0);
    auto one_const = make_shared<v0::Constant>(element::i32, Shape{1}, 1);
    auto two_const = make_shared<v0::Constant>(element::i32, Shape{1}, 2);
    auto max_const = make_shared<v0::Constant>(element::i32, Shape{1}, numeric_limits<int32_t>::max());
    auto tensor_array_shape = make_shared<v3::ShapeOf>(tensor_array, element::i64);
    auto element_shape_no_two_dims = make_shared<v8::Slice>(tensor_array_shape, two_const, max_const, one_const);
    // 2. compute the first and second dimensions k and n0
    auto k = make_shared<v8::Gather>(tensor_array_shape, zero_const, zero_const);
    auto n0 = make_shared<v8::Gather>(tensor_array_shape, one_const, zero_const);
    auto k_by_n0 = make_shared<v1::Multiply>(k, n0);
    // 3. compute the first output containing concatenated tensor elements
    // it folds the first and second dimensions
    auto new_shape = make_shared<v0::Concat>(OutputVector{k_by_n0, element_shape_no_two_dims}, 0);
    auto concatenated_array = make_shared<v1::Reshape>(tensor_array, new_shape, false)->output(0);
    concatenated_array = make_shared<v0::Convert>(concatenated_array, dtype);
    concatenated_array.set_names({node.get_name() + ":0"});
    // 4. compute the second output with length of each tensor element for the concatenation
    auto lengths = make_shared<v3::Broadcast>(n0, k)->output(0);
    lengths.set_names({node.get_name() + ":1"});

    return {concatenated_array, lengths};
}

OutputVector translate_tensor_array_write_v3_op(const NodeContext& node) {
    // TensorArrayWriteV3 pushes an element onto the tensor_array.
    // and it has four inputs
    // 0) handle, a Tensor of type resource. The handle to a TensorArray.
    // 1) index, a Tensor of type int32. The location where to write tensor element
    // 2) value, a Tensor. The tensor to write at the specified location
    // 3) flow_in A Tensor of type float32. A float scalar that enforces proper chaining of operations.
    // The operation has one output
    // 0) read value from tensor array
    default_op_checks(node, 4, {"TensorArrayWriteV3"});
    auto handle = node.get_input(0);
    auto index = node.get_input(1);
    auto value = node.get_input(2);
    // flow_in is used for transferring input tensor array
    // tensor array has a rank equal to 1 + rank(element of tensor array)
    // if it just initialized, its shape is equal to [tensor_array_size, 1, ..., 1]
    // otherwise, it is equal to [tensor_array_size, <element shape>]
    auto tensor_array = node.get_input(3);

    // reshape index to have it of [1] shape
    auto new_index_shape = make_shared<v0::Constant>(element::i32, Shape{1}, 1);
    index = make_shared<v1::Reshape>(index, new_index_shape, false);

    if (auto enter = as_type_ptr<Enter>(handle.get_node_shared_ptr())) {
        if (as_type_ptr<TensorArrayV3>(enter->input_value(0).get_node_shared_ptr()) &&
            value.get_partial_shape().rank().is_static()) {
            // set tensor element rank that gets known from TensorArrayWriteV3 operation
            auto tensor_array_v3 = as_type_ptr<TensorArrayV3>(enter->input_value(0).get_node_shared_ptr());
            int64_t tensor_element_rank = value.get_partial_shape().rank().get_length();
            tensor_array_v3->set_element_rank(tensor_element_rank);
        }
    }

    // compute element shape in the input tensor array
    auto tensor_array_shape = make_shared<v3::ShapeOf>(tensor_array, element::i32);

    // compute the current size of tensor array
    auto zero_const = make_shared<v0::Constant>(element::i32, Shape{1}, 0);
    auto tensor_array_size = make_shared<v8::Gather>(tensor_array_shape, zero_const, zero_const);

    // adjust tensor array to have the correct shape [size, <real element shape>] before value insertion
    auto element_shape = make_shared<v3::ShapeOf>(value, element::i32);
    auto new_tensor_array_shape = make_shared<v0::Concat>(OutputVector{tensor_array_size, element_shape}, 0);
    tensor_array = make_shared<v3::Broadcast>(tensor_array, new_tensor_array_shape);

    // update the resulted tensor using ScatterUpdate
    value = make_shared<v0::Unsqueeze>(value, zero_const);
    auto scatter_update = make_shared<v3::ScatterUpdate>(tensor_array, index, value, zero_const);

    set_node_name(node.get_name(), scatter_update);
    // use flow_out for transferring updated tensor array
    return {scatter_update};
}

}  // namespace op
}  // namespace tensorflow
}  // namespace frontend
}  // namespace ov
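A usage sketch, not part of the diff: with these translators registered, a TF1 frozen graph that uses TensorArray* and TF1 control-flow ops can be read through the regular OpenVINO API; the model path below is hypothetical.

#include <openvino/openvino.hpp>

int main() {
    ov::Core core;
    // the TensorFlow frontend now handles TensorArray*, Merge/Enter/NextIteration, GatherTree, etc.
    auto model = core.read_model("model.pb");
    auto compiled = core.compile_model(model, "CPU");
    return 0;
}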
@@ -46,6 +46,14 @@ TF_OP_CONVERTER(translate_sparse_segment_sum_op);
TF_OP_CONVERTER(translate_staticregexfullmatch_op);
TF_OP_CONVERTER(translate_stringjoin_op);
TF_OP_CONVERTER(translate_switch_op);
+TF_OP_CONVERTER(translate_tensor_array_close_v3_op);
+TF_OP_CONVERTER(translate_tensor_array_concat_v3_op);
+TF_OP_CONVERTER(translate_tensor_array_gather_v3_op);
+TF_OP_CONVERTER(translate_tensor_array_read_v3_op);
+TF_OP_CONVERTER(translate_tensor_array_scatter_v3_op);
+TF_OP_CONVERTER(translate_tensor_array_size_v3_op);
+TF_OP_CONVERTER(translate_tensor_array_v3_op);
+TF_OP_CONVERTER(translate_tensor_array_write_v3_op);
TF_OP_CONVERTER(translate_varhandle_op);
TF_OP_CONVERTER(translate_variable_op);
TF_OP_CONVERTER(translate_varisinitialized_op);
@@ -174,6 +182,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
        {"Gather", CreatorFunction(translate_gather_op)},
        {"GatherV2", CreatorFunction(translate_gather_v2_op)},
        {"GatherNd", CreatorFunction(translate_gather_nd_op)},
+       {"GatherTree", CreatorFunction(translate_gather_tree_op)},
+       {"Addons>GatherTree", CreatorFunction(translate_gather_tree_op)},
        {"HashTable", CreatorFunction(translate_hash_table_op)},
        {"HashTableV2", CreatorFunction(translate_hash_table_op)},
        {"Identity", CreatorFunction(translate_identity_op)},
@@ -269,6 +279,14 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
        {"StatelessWhile", CreatorFunction(translate_while_op)},
        {"StridedSlice", CreatorFunction(translate_strided_slice_op)},
        {"Switch", CreatorFunction(translate_switch_op)},
+       {"TensorArrayCloseV3", CreatorFunction(translate_tensor_array_close_v3_op)},
+       {"TensorArrayConcatV3", CreatorFunction(translate_tensor_array_concat_v3_op)},
+       {"TensorArrayGatherV3", CreatorFunction(translate_tensor_array_gather_v3_op)},
+       {"TensorArrayReadV3", CreatorFunction(translate_tensor_array_read_v3_op)},
+       {"TensorArrayScatterV3", CreatorFunction(translate_tensor_array_scatter_v3_op)},
+       {"TensorArraySizeV3", CreatorFunction(translate_tensor_array_size_v3_op)},
+       {"TensorArrayV3", CreatorFunction(translate_tensor_array_v3_op)},
+       {"TensorArrayWriteV3", CreatorFunction(translate_tensor_array_write_v3_op)},
        {"TensorListFromTensor", CreatorFunction(translate_tensor_list_from_tensor_op)},
        {"TensorListGetItem", CreatorFunction(translate_tensor_list_get_item_op)},
        {"TensorListLength", CreatorFunction(translate_tensor_list_length_op)},
@@ -423,7 +423,7 @@ shared_ptr<v5::Loop> create_loop_for_tf_while(const std::string& while_node_name
    FRONT_END_GENERAL_CHECK(
        cond_results.size() == 1 && cond_results[0],
        "[TensorFlow Frontend] Internal error or inconsistent model: condition body must contain one Result node.");
-   auto body_condition_output_idx = static_cast<int64_t>(body_results.size());
+   auto body_condition_output_idx = body_results.size();
    body_model->add_results(cond_results);

    // type setting for body graph parameters is needed for TensorList support since DT_VARIANT type is present
@@ -435,14 +435,18 @@ shared_ptr<v5::Loop> create_loop_for_tf_while(const std::string& while_node_name
    loop->set_function(body_model);

    // body_results may contain less nodes than body_params that means back edge exists not for all body_params
-   for (size_t input_ind = 0; input_ind < static_cast<size_t>(body_condition_output_idx); ++input_ind) {
+   for (size_t input_ind = 0; input_ind < body_condition_output_idx; ++input_ind) {
        loop->set_merged_input(body_params[input_ind], ov_inputs[input_ind], body_results[input_ind]->input_value(0));
    }
-   loop->set_special_body_ports({-1, body_condition_output_idx});
+   loop->set_special_body_ports({-1, static_cast<int64_t>(body_condition_output_idx)});
+   // set invariant inputs for the loop
+   for (size_t input_ind = body_condition_output_idx; input_ind < input_size; ++input_ind) {
+       loop->set_invariant_input(body_params[input_ind], ov_inputs[input_ind]);
+   }

    // set external outputs for Loop node
    // do not get execution condition outside of the Loop node
-   for (size_t output_ind = 0; output_ind < static_cast<size_t>(body_condition_output_idx); ++output_ind) {
+   for (size_t output_ind = 0; output_ind < body_condition_output_idx; ++output_ind) {
        loop->get_iter_value(body_results[output_ind]);
    }
    loop->validate_and_infer_types();
@@ -15,7 +15,8 @@ static const std::vector<std::string> models{
    std::string("2in_2out/2in_2out.pb"),
    std::string("forward_edge_model/forward_edge_model.pbtxt"),
    std::string("forward_edge_model2/forward_edge_model2.pbtxt"),
-   std::string("concat_with_non_constant_axis/concat_with_non_constant_axis.pbtxt")};
+   std::string("concat_with_non_constant_axis/concat_with_non_constant_axis.pbtxt"),
+   std::string("gather_tree_model/gather_tree_model.pbtxt")};

INSTANTIATE_TEST_SUITE_P(TFConvertModelTest,
                         FrontEndConvertModelTest,
gather_tree_model.pbtxt (new file)
@@ -0,0 +1,103 @@
node {
  name: "step_ids"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        dim {
          size: 20
        }
        dim {
          size: 2
        }
        dim {
          size: 30
        }
      }
    }
  }
}
node {
  name: "parent_ids"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        dim {
          size: 20
        }
        dim {
          size: 2
        }
        dim {
          size: 30
        }
      }
    }
  }
}
node {
  name: "max_seq_len"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        dim {
          size: 2
        }
      }
    }
  }
}
node {
  name: "end_token"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "shape"
    value {
      shape {
      }
    }
  }
}
node {
  name: "Addons>GatherTree"
  op: "Addons>GatherTree"
  input: "step_ids"
  input: "parent_ids"
  input: "max_seq_len"
  input: "end_token"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
@@ -72,6 +72,7 @@ OP_CONVERTER_NAMED(translate_fused_batch_norm_op);
OP_CONVERTER(translate_gather_op);
OP_CONVERTER(translate_gather_v2_op);
OP_CONVERTER(translate_gather_nd_op);
+OP_CONVERTER(translate_gather_tree_op);
OP_CONVERTER(translate_identity_op);
OP_CONVERTER(translate_identity_n_op);
OP_CONVERTER(translate_input_arg_op);
@@ -33,20 +33,34 @@ public:
        ov::PartialShape output_data_shape = ov::PartialShape::dynamic();

        auto input_size = get_input_size();
-       bool merge_output_shape = true;
        for (size_t input_ind = 0; input_ind < input_size; ++input_ind) {
            auto input_type = get_input_element_type(input_ind);
            if (input_type.is_static()) {
                output_data_type = input_type;
            }

-           // check if it still needs to merge input shapes
-           // if yes, it tries to merge them
-           if (merge_output_shape &&
-               !PartialShape::merge_into(output_data_shape, get_input_partial_shape(input_ind))) {
-               merge_output_shape = false;
-               // reset output shape to dynamic rank
+           auto input_shape = get_input_partial_shape(input_ind);
+           if (input_shape.rank().is_dynamic()) {
+               continue;
+           }
+
+           if (output_data_shape.rank().is_dynamic()) {
+               // firstly met shape of static rank
+               // immediately use this shape of static rank
+               output_data_shape = input_shape;
+           } else if (output_data_shape.rank().is_static() &&
+                      output_data_shape.rank().get_length() != input_shape.rank().get_length()) {
+               // different inputs have different rank means output must be of a dynamic rank
                output_data_shape = ov::PartialShape::dynamic();
+               break;
+           } else {
+               auto output_rank = output_data_shape.rank().get_length();
+               for (int64_t dim_ind = 0; dim_ind < output_rank; ++dim_ind) {
+                   if (input_shape[dim_ind] != output_data_shape[dim_ind]) {
+                       // different inputs can have different dimensions so it must combine them
+                       output_data_shape[dim_ind] = ov::Dimension::dynamic();
+                   }
+               }
            }
        }

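A simplified sketch of the per-pair effect of the revised shape rule above, using the public ov::PartialShape API (the helper is hypothetical; the real operation folds all inputs and simply skips those with dynamic rank):

#include "openvino/core/partial_shape.hpp"

ov::PartialShape merge_like_above(const ov::PartialShape& a, const ov::PartialShape& b) {
    // different static ranks -> the merged shape degrades to dynamic rank
    if (a.rank().is_dynamic() || b.rank().is_dynamic() ||
        a.rank().get_length() != b.rank().get_length()) {
        return ov::PartialShape::dynamic();
    }
    // equal ranks -> keep matching dimensions, relax the differing ones
    ov::PartialShape result = a;
    for (int64_t i = 0; i < a.rank().get_length(); ++i) {
        if (a[i] != b[i]) {
            result[i] = ov::Dimension::dynamic();
        }
    }
    return result;
}
// e.g. {2, 3} and {2, 5} -> {2, ?};  {2, 3} and {2, 3, 4} -> dynamic rank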
@@ -43,6 +43,10 @@
        producer_output_port_idx = m_producer_output_port_idx;
    }

+   void set_output_shape_and_type(const ov::PartialShape& output_shape, const ov::element::Type& output_type) {
+       set_output_type(0, output_type, output_shape);
+   }
+
private:
    bool m_back_edge_set;
    std::string m_producer_name;
@@ -0,0 +1,60 @@ (new file)
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <vector>

#include "internal_operation.hpp"

namespace ov {
namespace frontend {
namespace tensorflow {

// Internal operation for TensorArrayV3
// An array of Tensors of given size
// It has two outputs:
// 1. handle - resource (a reference) for tensor array
// 2. flow_out - float type will be used for storing tensor array
class TensorArrayV3 : public InternalOperation {
public:
    OPENVINO_OP("TensorArrayV3", "ov::frontend::tensorflow", InternalOperation);

    TensorArrayV3(const Output<Node>& size,
                  const ov::element::Type element_type,
                  const std::shared_ptr<DecoderBase>& decoder = std::make_shared<DecoderFake>())
        : InternalOperation(decoder, OutputVector{size}, 2, "TensorArrayV3"),
          m_element_type(element_type),
          m_element_rank(-1) {
        validate_and_infer_types();
    }

    void validate_and_infer_types() override {
        set_output_type(0, m_element_type, ov::PartialShape::dynamic());
        set_output_type(1, m_element_type, ov::PartialShape::dynamic());
    }

    ov::element::Type get_element_type() const {
        return m_element_type;
    }

    int64_t get_element_rank() const {
        return m_element_rank;
    }

    void set_element_rank(int64_t element_rank) {
        FRONT_END_GENERAL_CHECK(
            element_rank >= 0,
            "[TensorFlow Frontend] internal error: negavite element rank tries to set for TensorArrayV3");
        m_element_rank = element_rank;
    }

private:
    ov::element::Type m_element_type;
    int64_t m_element_rank;
};

}  // namespace tensorflow
}  // namespace frontend
}  // namespace ov
@@ -0,0 +1,29 @@ (new file)
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <memory>
#include <utility>

#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"

namespace ov {
namespace frontend {
namespace tensorflow {
namespace pass {

// This transformation replaces internal operation TensorArrayV3 with a Constant
// that simulates initial state of tensor array container
class TensorArrayV3Replacer : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("ov::frontend::tensorflow::pass::TensorArrayV3Replacer");
    TensorArrayV3Replacer();
};

}  // namespace pass
}  // namespace tensorflow
}  // namespace frontend
}  // namespace ov
@@ -0,0 +1,71 @@ (new file)
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "helper_transforms/tensor_array_v3_replacer.hpp"

#include "helper_ops/tensor_array.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"

using namespace std;
using namespace ov;
using namespace ov::op;
using namespace ov::pass;

ov::frontend::tensorflow::pass::TensorArrayV3Replacer::TensorArrayV3Replacer() {
    auto tensor_array_v3 = pattern::wrap_type<TensorArrayV3>();

    matcher_pass_callback callback = [=](pattern::Matcher& m) {
        NodeRegistry rg;

        auto tensor_array_v3 = dynamic_pointer_cast<TensorArrayV3>(m.get_match_root());
        if (!tensor_array_v3) {
            return false;
        }

        int32_t tensor_element_rank = static_cast<int32_t>(tensor_array_v3->get_element_rank());
        if (tensor_element_rank < 0) {
            return false;
        }

        // retrieve all TensorArrayV3 inputs
        auto size = tensor_array_v3->input_value(0);
        auto element_type = tensor_array_v3->get_element_type();

        // adjust size to have it of shape [1] for further concatenation with element shape
        auto new_size_shape = rg.make<v0::Constant>(element::i32, Shape{1}, 1);
        auto new_size = rg.make<v1::Reshape>(size, new_size_shape, false);

        // create a vector of size element_shape.rank() with ones
        // and compute a shape of initial tensor array [size, 1, ..., 1]
        Output<Node> target_shape;
        if (tensor_element_rank == 0) {
            target_shape = new_size->output(0);
        } else {
            vector<int32_t> ones(tensor_element_rank, 1);
            auto ones_const = rg.make<v0::Constant>(element::i32, Shape{ones.size()}, ones);
            target_shape = rg.make<v0::Concat>(OutputVector{new_size, ones_const}, 0)->output(0);
        }

        // create initial tensor array
        auto scalar_value = make_shared<v0::Constant>(element_type, Shape{}, vector<int32_t>{0});
        auto initial_tensor_array = make_shared<v3::Broadcast>(scalar_value, target_shape);

        // preserve names of the node and the output tensor
        initial_tensor_array->set_friendly_name(tensor_array_v3->get_friendly_name());
        copy_runtime_info(tensor_array_v3, rg.get());

        ov::replace_node(tensor_array_v3,
                         ov::OutputVector{initial_tensor_array->output(0), initial_tensor_array->output(0)});
        return true;
    };

    auto m =
        std::make_shared<pattern::Matcher>(tensor_array_v3, "ov::frontend::tensorflow::pass::TensorArrayV3Replacer");
    register_matcher(m, callback);
}
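A usage sketch, not part of the diff: inside the frontend the replacer runs as part of normalize() (see the FrontEnd::normalize hunk above), but since it is an ordinary MatcherPass it can also be applied through ov::pass::Manager, assuming the helper_transforms headers are available to the build:

#include "helper_transforms/tensor_array_v3_replacer.hpp"
#include "openvino/pass/manager.hpp"

void run_tensor_array_replacer(const std::shared_ptr<ov::Model>& model) {
    ov::pass::Manager manager;
    manager.register_pass<ov::frontend::tensorflow::pass::TensorArrayV3Replacer>();
    manager.run_passes(model);
}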
src/frontends/tensorflow_common/src/op/gather_tree.cpp (new file)
@@ -0,0 +1,39 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/op/gather_tree.hpp"

#include "common_op_table.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/reshape.hpp"

using namespace std;
using namespace ov::op;

namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {

OutputVector translate_gather_tree_op(const NodeContext& node) {
    default_op_checks(node, 4, {"GatherTree", "Addons>GatherTree"});
    auto step_ids = node.get_input(0);
    auto parent_ids = node.get_input(1);
    auto max_sequence_lengths = node.get_input(2);
    auto end_token = node.get_input(3);

    // adjust end_token that must be a scalar
    auto new_shape_end_token = make_shared<v0::Constant>(element::i32, Shape{0}, vector<int32_t>{});
    end_token = make_shared<v1::Reshape>(end_token, new_shape_end_token, false);

    auto gather_tree = make_shared<v1::GatherTree>(step_ids, parent_ids, max_sequence_lengths, end_token);
    set_node_name(node.get_name(), gather_tree);

    return {gather_tree};
}

}  // namespace op
}  // namespace tensorflow
}  // namespace frontend
}  // namespace ov
@@ -513,7 +513,7 @@ void TensorIterator::createPrimitive() {
        lastUsedCond = initial_cond_check->getStatus();
    }

-   if (isDynamicNode())
+   if (runAsDynamic())
        prepareDynamicBuffers();

    Node::createPrimitive();
@@ -556,7 +556,7 @@ void TensorIterator::prepareParams() {
    prepareContinueCond();
    prepareLoopBodyCurrentIteration();

-   if (!isDynamicNode()) {
+   if (!runAsDynamic()) {
        prepareOutputPorts();
        prepareBackEdges();
    }
@@ -568,6 +568,12 @@ void TensorIterator::prepareParams() {
}

void TensorIterator::execute(dnnl::stream strm) {
+   //Special case, the subgraph is dynamic while the node has all static shapes
+   if (runAsDynamic()) {
+       executeDynamicImpl(strm);
+       return;
+   }
+
    sub_graph.ResetInferCount();

    bool continue_cond = initial_cond_check->getStatus();
@@ -872,6 +878,10 @@ int TensorIterator::getNumIteration(const std::vector<PortMap>& inputPortMap, co
    return numIterations;
}

+bool TensorIterator::runAsDynamic() const {
+    return isDynamicNode() || Graph::Status::ReadyDynamic == sub_graph.getStatus();
+}
+
bool TensorIterator::created() const {
    return getType() == Type::TensorIterator;
}
@@ -138,6 +138,7 @@ private:
    void reshapeAndFillOutput(dnnl::stream strm);
    bool checkForInputAndBodyShapesInequality() const;
    int getNumIteration(const std::vector<PortMap>& inputPortMap, const std::vector<PortMap>& outputPortMap) const;
+   bool runAsDynamic() const;

    ExtensionManager::Ptr ext_mng;
    Graph sub_graph;
@@ -371,6 +371,65 @@ protected:
    }
};
+
+class StaticLoopDynamicSubgraphCPUTest : public SubgraphBaseTest {
+    void SetUp() override {
+        InputShape input_shape = {{25, 1, 1}, {{25, 1, 1}}};
+        InputShape input_exec_flag_shape = {{1}, {{1}}};
+        targetDevice = ov::test::utils::DEVICE_CPU;
+        ElementType netType = ov::element::f32;
+        init_input_shapes({input_shape, input_exec_flag_shape});
+
+        ov::ParameterVector params;
+        params.push_back(std::make_shared<ov::op::v0::Parameter>(netType, inputDynamicShapes[0]));
+
+        // exec_condition
+        params.push_back(std::make_shared<ov::op::v0::Parameter>(ov::element::boolean, inputDynamicShapes[1]));
+
+        auto trip_count_input = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, 2);
+        auto body_condition_const = std::make_shared<ov::op::v0::Constant>(ov::element::boolean, ov::Shape{1}, true);
+
+        // Body parameters
+        ov::ParameterVector body_params = {std::make_shared<ov::op::v0::Parameter>(netType, ov::PartialShape{25, 1, -1})};
+
+        // Body
+        auto broadcast_target_shape = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{3}, std::vector<int64_t>{25, 1, 256});
+        auto broadcast_axis_mapping = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, 0);
+        auto broadcast = std::make_shared<ov::op::v3::Broadcast>(body_params[0], broadcast_target_shape);
+        auto body = std::make_shared<ov::Model>(ov::OutputVector{body_condition_const, broadcast}, body_params);
+
+        auto loop = std::make_shared<ov::op::v5::Loop>(trip_count_input, params[1]);
+        loop->set_function(body);
+        loop->set_special_body_ports(ov::op::v5::Loop::SpecialBodyPorts{-1, 0});
+
+        loop->set_merged_input(body_params.front(), params.front(), broadcast);
+
+        auto out0 = loop->get_iter_value(body_condition_const, -1);
+        auto out1 = loop->get_iter_value(broadcast, -1);
+
+        auto result0 = std::make_shared<ov::op::v0::Result>(out0);
+        auto result1 = std::make_shared<ov::op::v0::Result>(out1);
+        function = std::make_shared<ov::Model>(ov::ResultVector{result0, result1}, params, "loop");
+    }
+    void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override {
+        inputs.clear();
+        const auto& funcInputs = function->inputs();
+        for (size_t i = 0; i < funcInputs.size(); ++i) {
+            const auto& funcInput = funcInputs[i];
+            ov::Tensor tensor;
+
+            if (i == 1) {
+                tensor = ov::Tensor(funcInput.get_element_type(), targetInputStaticShapes[i]);
+                auto* dataPtr = tensor.data<bool>();
+                *dataPtr = true;
+            } else {
+                tensor = ov::test::utils::create_and_fill_tensor(funcInput.get_element_type(), targetInputStaticShapes[i], 2560, 0, 256);
+            }
+            inputs.insert({funcInput.get_node_shared_ptr(), tensor});
+        }
+    }
+};
+
+
TEST_P(LoopLayerCPUTest, CompareWithRefs) {
    run();
}
@@ -387,6 +446,10 @@ TEST_P(LoopForConcatLayerCPUTest, CompareWithRefs) {
    run();
}

+TEST_F(StaticLoopDynamicSubgraphCPUTest, smoke_StaticLoopWithDynSubgraph) {
+    run();
+}
+
namespace {

const std::vector<ElementType> inputPrecisions = {
@@ -98,7 +98,7 @@ def summarize_graph(model_path, output_nodes_for_freeze=None, reshape_net=None):
    variables = list()
    outputs = list()
    graph = load_graph(model_path, output_nodes_for_freeze)
-   unlikely_output_types = ['Const', 'Assign', 'NoOp', 'Placeholder', 'Assert', 'switch_t', 'switch_f']
+   unlikely_output_types = ['Const', 'Assign', 'NoOp', 'Placeholder', 'Assert', 'switch_t', 'switch_f', 'TensorArrayCloseV3']
    control_dependents_map = collect_control_dependencies(graph)
    for node in graph.as_graph_def().node:
        if node.op == 'Placeholder':
tests/layer_tests/tensorflow_tests/test_tf_TensorArrayOps.py (new file, 200 lines)
@@ -0,0 +1,200 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import tensorflow as tf
from common.tf_layer_test_class import CommonTFLayerTest


def create_tensor_array(data_shape, data_type):
    size = data_shape[0]
    data = tf.compat.v1.placeholder(data_type, data_shape, 'data')
    indices = tf.compat.v1.placeholder(tf.int32, [size], 'indices')
    size_const = tf.constant(size, dtype=tf.int32, shape=[])
    handle, flow = tf.raw_ops.TensorArrayV3(size=size_const, dtype=tf.as_dtype(data_type))
    flow = tf.raw_ops.TensorArrayScatterV3(handle=handle, indices=indices, value=data, flow_in=flow)
    return handle, flow


class TestTensorArraySizeV3(CommonTFLayerTest):
    def _prepare_input(self, inputs_info):
        assert 'data' in inputs_info
        assert 'indices' in inputs_info
        data_shape = inputs_info['data']
        inputs_data = {}
        rng = np.random.default_rng()
        inputs_data['data'] = rng.integers(-10, 10, data_shape).astype(self.data_type)
        inputs_data['indices'] = rng.permutation(self.size).astype(np.int32)
        return inputs_data

    def create_tensor_array_size_v3(self, data_shape, data_type):
        size = data_shape[0]
        self.data_type = data_type
        self.size = size
        tf.compat.v1.reset_default_graph()
        # Create the graph and model
        with tf.compat.v1.Session() as sess:
            handle, flow = create_tensor_array(data_shape, data_type)
            tf.raw_ops.TensorArraySizeV3(handle=handle, flow_in=flow)
            tf.raw_ops.TensorArrayCloseV3(handle=handle)
            tf.compat.v1.global_variables_initializer()
            tf_net = sess.graph_def

        return tf_net, None

    test_data_basic = [
        dict(data_shape=[5], data_type=np.float32),
        dict(data_shape=[10, 20, 30], data_type=np.int32),
    ]

    @pytest.mark.parametrize("params", test_data_basic)
    @pytest.mark.precommit_tf_fe
    @pytest.mark.nightly
    def test_tensor_array_size_v3(self, params, ie_device, precision, ir_version, temp_dir,
                                  use_new_frontend, use_old_api):
        self._test(*self.create_tensor_array_size_v3(**params),
                   ie_device, precision, ir_version, temp_dir=temp_dir,
                   use_new_frontend=use_new_frontend, use_old_api=use_old_api)


class TestTensorArrayReadV3(CommonTFLayerTest):
    def _prepare_input(self, inputs_info):
        assert 'data' in inputs_info
        assert 'indices' in inputs_info
        data_shape = inputs_info['data']
        inputs_data = {}
        rng = np.random.default_rng()
        inputs_data['data'] = rng.integers(-10, 10, data_shape).astype(self.data_type)
        inputs_data['index_to_read'] = rng.integers(0, data_shape[0], []).astype(np.int32)
        inputs_data['indices'] = rng.permutation(self.size).astype(np.int32)
        return inputs_data

    def create_tensor_array_read_v3(self, data_shape, data_type):
        size = data_shape[0]
        self.data_type = data_type
        self.size = size
        tf.compat.v1.reset_default_graph()
        # Create the graph and model
        with tf.compat.v1.Session() as sess:
            handle, flow = create_tensor_array(data_shape, data_type)
            index_to_read = tf.compat.v1.placeholder(tf.int32, [], 'index_to_read')
            tf.raw_ops.TensorArrayReadV3(handle=handle, index=index_to_read, flow_in=flow,
                                         dtype=tf.dtypes.as_dtype(data_type))
            tf.raw_ops.TensorArrayCloseV3(handle=handle)
            tf.compat.v1.global_variables_initializer()
            tf_net = sess.graph_def

        return tf_net, None

    test_data_basic = [
        dict(data_shape=[6], data_type=np.float32),
        dict(data_shape=[8, 5, 6, 10], data_type=np.int32),
    ]

    @pytest.mark.parametrize("params", test_data_basic)
    @pytest.mark.precommit_tf_fe
    @pytest.mark.nightly
    def test_tensor_array_read_v3(self, params, ie_device, precision, ir_version, temp_dir,
                                  use_new_frontend, use_old_api):
        self._test(*self.create_tensor_array_read_v3(**params),
                   ie_device, precision, ir_version, temp_dir=temp_dir,
                   use_new_frontend=use_new_frontend, use_old_api=use_old_api)


class TestTensorArrayWriteGatherV3(CommonTFLayerTest):
    def _prepare_input(self, inputs_info):
        assert 'data' in inputs_info
        assert 'indices' in inputs_info
        assert 'value_to_write' in inputs_info
        data_shape = inputs_info['data']
        value_shape = inputs_info['value_to_write']
        inputs_data = {}
        rng = np.random.default_rng()
        inputs_data['data'] = rng.integers(-10, 10, data_shape).astype(self.data_type)
        inputs_data['value_to_write'] = rng.integers(-10, 10, value_shape).astype(self.data_type)
        indices_data = rng.permutation(self.size).astype(np.int32)
        inputs_data['indices'] = np.delete(indices_data, np.where(indices_data == self.index_to_write))
        return inputs_data

    def create_tensor_array_write_v3(self, size, data_shape, data_type, index_to_write, indices_to_gather):
        self.data_type = data_type
        self.size = size
        self.index_to_write = index_to_write
        tf.compat.v1.reset_default_graph()
        # Create the graph and model
        with tf.compat.v1.Session() as sess:
            value_to_write = tf.compat.v1.placeholder(data_type, data_shape[1:], 'value_to_write')
            index_to_write_const = tf.constant(index_to_write, dtype=tf.int32, shape=[])
            indices_to_gather_const = tf.constant(indices_to_gather, dtype=tf.int32, shape=[len(indices_to_gather)])
            data = tf.compat.v1.placeholder(data_type, data_shape, 'data')
            indices = tf.compat.v1.placeholder(tf.int32, [size - 1], 'indices')
            size_const = tf.constant(size, dtype=tf.int32, shape=[])
            handle, flow = tf.raw_ops.TensorArrayV3(size=size_const, dtype=tf.as_dtype(data_type))
            flow = tf.raw_ops.TensorArrayScatterV3(handle=handle, indices=indices, value=data, flow_in=flow)
            flow = tf.raw_ops.TensorArrayWriteV3(handle=handle, index=index_to_write_const,
                                                 value=value_to_write, flow_in=flow)
            tf.raw_ops.TensorArrayGatherV3(handle=handle, indices=indices_to_gather_const, flow_in=flow,
                                           dtype=tf.dtypes.as_dtype(data_type))
            tf.raw_ops.TensorArrayCloseV3(handle=handle)
            tf.compat.v1.global_variables_initializer()
            tf_net = sess.graph_def

        return tf_net, None

    test_data_basic = [
        dict(size=7, data_shape=[6], data_type=np.float32, index_to_write=3, indices_to_gather=[0, 3, 1]),
        dict(size=10, data_shape=[9, 2, 4], data_type=np.int32, index_to_write=2, indices_to_gather=[2, 1, 4, 3]),
    ]

    @pytest.mark.parametrize("params", test_data_basic)
    @pytest.mark.precommit_tf_fe
    @pytest.mark.nightly
    def test_tensor_array_write_v3(self, params, ie_device, precision, ir_version, temp_dir,
                                   use_new_frontend, use_old_api):
        self._test(*self.create_tensor_array_write_v3(**params),
                   ie_device, precision, ir_version, temp_dir=temp_dir,
                   use_new_frontend=use_new_frontend, use_old_api=use_old_api)


class TestTensorArrayConcatV3(CommonTFLayerTest):
    def _prepare_input(self, inputs_info):
        assert 'data' in inputs_info
        assert 'indices' in inputs_info
        data_shape = inputs_info['data']
        inputs_data = {}
        rng = np.random.default_rng()
        inputs_data['data'] = rng.integers(-10, 10, data_shape).astype(self.data_type)
        inputs_data['indices'] = rng.permutation(self.size).astype(np.int32)
        return inputs_data

    def create_tensor_array_concat_v3(self, data_shape, data_type):
        size = data_shape[0]
        self.data_type = data_type
        self.size = size
        tf.compat.v1.reset_default_graph()
        # Create the graph and model
        with tf.compat.v1.Session() as sess:
            handle, flow = create_tensor_array(data_shape, data_type)
            tensor_array_concat_v3 = tf.raw_ops.TensorArrayConcatV3(handle=handle, flow_in=flow,
                                                                    dtype=tf.as_dtype(data_type))
            tf.identity(tensor_array_concat_v3[0], name='values')
            tf.identity(tensor_array_concat_v3[1], name='length')
            tf.raw_ops.TensorArrayCloseV3(handle=handle)
            tf.compat.v1.global_variables_initializer()
            tf_net = sess.graph_def

        return tf_net, None

    test_data_basic = [
        dict(data_shape=[5, 3, 11, 2], data_type=np.int32),
    ]

    @pytest.mark.parametrize("params", test_data_basic)
    @pytest.mark.precommit_tf_fe
    @pytest.mark.nightly
    def test_tensor_array_concat_v3(self, params, ie_device, precision, ir_version, temp_dir,
                                    use_new_frontend, use_old_api):
        self._test(*self.create_tensor_array_concat_v3(**params),
                   ie_device, precision, ir_version, temp_dir=temp_dir,
                   use_new_frontend=use_new_frontend, use_old_api=use_old_api)
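For reference, a minimal standalone sketch of the TF1 TensorArray pattern the create_tensor_array helper above builds (graph mode assumed; illustrative only, not part of the test suite):

import numpy as np
import tensorflow as tf

tf.compat.v1.disable_eager_execution()
with tf.compat.v1.Session() as sess:
    data = tf.compat.v1.placeholder(tf.float32, [5], 'data')
    indices = tf.compat.v1.placeholder(tf.int32, [5], 'indices')
    # create a TensorArray of 5 elements, scatter the input into it, and query its size
    handle, flow = tf.raw_ops.TensorArrayV3(size=tf.constant(5), dtype=tf.float32)
    flow = tf.raw_ops.TensorArrayScatterV3(handle=handle, indices=indices, value=data, flow_in=flow)
    size = tf.raw_ops.TensorArraySizeV3(handle=handle, flow_in=flow)
    print(sess.run(size, feed_dict={data: np.arange(5, dtype=np.float32),
                                    indices: np.arange(5, dtype=np.int32)}))  # prints 5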
@@ -312,8 +312,8 @@ def update_fallback_with_conversion_error(use_new_frontend: bool, is_tf: bool, e
    conversion_error_re = r"^(\[TensorFlow\ Frontend\]\ Internal\ error\,\ no\ translator\ found\ for\ operation\(s\)\:\ )((\w+)(\,\ \w+)*)$"
    conversion_error_match = re.findall(conversion_error_re, ex_msg, re.MULTILINE)
    all_fallback_operations = [
-       # corresponds to TF1 TensorList operation
-       "TensorArrayScatterV3", "TensorArrayV3", "TensorArraySizeV3", "TensorArrayGatherV3",
+       # corresponds to TF1 While operation
+       "LoopCond", "Enter", "NextIteration", "Exit", "Switch", "Merge",
        # corresponds to operations with complex tensors
        "FFT", "FFT2D", "FFT3D", "IFFT", "IFFT2D", "IFFT3D",
        "RFFT", "RFFT2D", "RFFT3D", "IRFFT", "IRFFT2D", "IRFFT3D",
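A simplified sketch (assumed logic, not the actual update_fallback_with_conversion_error body) of how the regex and this list work together: the regex pulls the unsupported op names out of the conversion error, and falling back to the legacy frontend is justified only when every reported op is in all_fallback_operations.

import re

conversion_error_re = r"^(\[TensorFlow\ Frontend\]\ Internal\ error\,\ no\ translator\ found\ for\ operation\(s\)\:\ )((\w+)(\,\ \w+)*)$"
all_fallback_operations = ["LoopCond", "Enter", "NextIteration", "Exit", "Switch", "Merge"]

ex_msg = "[TensorFlow Frontend] Internal error, no translator found for operation(s): LoopCond, Enter"
match = re.findall(conversion_error_re, ex_msg, re.MULTILINE)
if match:
    unsupported_ops = match[0][1].split(', ')
    # fall back only when every unsupported op is a known fallback case
    print(all(op in all_fallback_operations for op in unsupported_ops))  # True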
@@ -235,17 +235,13 @@ class TestMoFreezePlaceholderTFFE(unittest.TestCase):
                   freeze_placeholder_with_value,
                   input_shape, only_conversion, True)

-   def test_conversion_failure_fallback_default(self):
+   def test_conversion_tf1_while_default(self):
        self.basic("ctc_model_based.pbtxt", None, None, None, None,
                   None, None, True, True, False, False)

-   def test_conversion_failure_fallback_use_new_frontend(self):
-       with self.assertRaisesRegex(Exception,
-                                   "\[TensorFlow Frontend\] Internal error, no translator found for operation\(s\)\: "
-                                   "TensorArrayGatherV3\, TensorArrayReadV3\, TensorArraySizeV3\, TensorArrayV3\, TensorArrayWriteV3"):
-           self.basic("ctc_model_based.pbtxt", None, None, None, None,
-                      None, None, True, True, True, False)
+   @unittest.skipIf(platform == 'darwin', reason="Ticket - 122182")
+   def test_conversion_tf1_while_use_new_frontend(self):
+       self.basic("ctc_model_based.pbtxt", None, None, None, None,
+                  None, None, True, True, True, False)

    @unittest.skip("88349: Fix auto-pruning in legacy FE")
    def test_conversion_model_oneshot_iterator_use_legacy_frontend(self):