[TF FE] Fix translators for multiple output operations (#20787)

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev 2023-10-31 17:22:09 +04:00 committed by GitHub
parent 48c9598892
commit 8d6f56dd12
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 13 additions and 8 deletions

View File

@ -16,7 +16,7 @@ namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {
OutputVector translate_sparse_reshape_op(const ov::frontend::tensorflow::NodeContext& node) {
NamedOutputVector translate_sparse_reshape_op(const ov::frontend::tensorflow::NodeContext& node) {
// Currently, the translation for SparseReshape is possible only if new shape value is the same as the input shape
// value or it is different just by one dynamic dimension of the new shape that can be replace with the
// corresponding static dimension of the input shape.
@ -67,7 +67,12 @@ OutputVector translate_sparse_reshape_op(const ov::frontend::tensorflow::NodeCon
"This case with SparseReshape is not possible to translate to OpenVINO opset. The number "
"of dynamic shapes in new shape must be 1 at most.");
*/
return {input_indices, input_shape};
auto output_indices = input_indices;
auto output_shape = input_shape;
set_out_name(node.get_name() + ":0", output_indices);
set_out_name(node.get_name() + ":1", output_shape);
return {{"output_indices", output_indices}, {"output_shape", output_shape}};
}
NamedOutputVector translate_sparse_fill_empty_rows_op(const ov::frontend::tensorflow::NodeContext& node) {

View File

@ -43,7 +43,7 @@ TF_OP_CONVERTER(translate_queue_dequeue_many_op);
TF_OP_CONVERTER(translate_readvariable_op);
TF_OP_CONVERTER(translate_restorev2_op);
TF_OP_CONVERTER_NAMED(translate_sparse_fill_empty_rows_op);
TF_OP_CONVERTER(translate_sparse_reshape_op);
TF_OP_CONVERTER_NAMED(translate_sparse_reshape_op);
TF_OP_CONVERTER(translate_sparse_segment_sum_op);
TF_OP_CONVERTER(translate_staticregexfullmatch_op);
TF_OP_CONVERTER(translate_stringjoin_op);
@ -216,7 +216,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"MaxPool", CreatorFunction(translate_max_pool_op)},
{"MaxPoolV2", CreatorFunction(translate_max_pool_op)},
{"MaxPool3D", CreatorFunction(translate_max_pool_op)},
{"MaxPoolWithArgmax", CreatorFunction(translate_max_pool_op)},
{"MaxPoolWithArgmax", CreatorFunction(translate_max_pool_with_argmax)},
{"Merge", CreatorFunction(translate_merge_op)},
{"MirrorPad", CreatorFunction(translate_mirror_pad_op)},
{"MutableHashTable", CreatorFunction(translate_hash_table_op)},

View File

@ -93,6 +93,7 @@ OP_CONVERTER(translate_lrn_op);
OP_CONVERTER(translate_mat_mul_op);
OP_CONVERTER(translate_matrix_diag_op);
OP_CONVERTER(translate_max_pool_op);
OP_CONVERTER_NAMED(translate_max_pool_with_argmax);
OP_CONVERTER(translate_mirror_pad_op);
OP_CONVERTER_NAMED(translate_non_max_suppression_op);
OP_CONVERTER(translate_parallel_dynamic_stitch_op);

View File

@ -128,7 +128,7 @@ OutputVector translate_max_pool_v2(const NodeContext& node) {
return translate_max_pool_util(node, 2, ksize_vector, strides_vector);
}
OutputVector translate_max_pool_with_argmax(const NodeContext& node) {
NamedOutputVector translate_max_pool_with_argmax(const NodeContext& node) {
// MaxPoolWithArgmax has just one input. ksize and strides are attributes
TENSORFLOW_OP_VALIDATION(node,
node.get_input_size() > 0,
@ -199,8 +199,9 @@ OutputVector translate_max_pool_with_argmax(const NodeContext& node) {
convert_nchw_to_nhwc(true, output_indices, 4);
}
set_out_name(node_name + ":0", max_pool);
set_out_name(node_name + ":1", output_indices);
return {max_pool, output_indices};
return {{"output", max_pool}, {"argmax", output_indices}};
}
OutputVector translate_max_pool_op(const NodeContext& node) {
@ -210,8 +211,6 @@ OutputVector translate_max_pool_op(const NodeContext& node) {
return translate_max_pool_v2(node);
} else if (node.get_op_type() == "MaxPool3D") {
return translate_max_pool(node, 3);
} else if (node.get_op_type() == "MaxPoolWithArgmax") {
return translate_max_pool_with_argmax(node);
} else {
TENSORFLOW_OP_VALIDATION(node, false, "Only MaxPool2D, MaxPoolV2 and MaxPool3D are supported.");
}