[TF FE] Fix translators for multiple output operations (#20787)
Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent
48c9598892
commit
8d6f56dd12
@ -16,7 +16,7 @@ namespace ov {
|
||||
namespace frontend {
|
||||
namespace tensorflow {
|
||||
namespace op {
|
||||
OutputVector translate_sparse_reshape_op(const ov::frontend::tensorflow::NodeContext& node) {
|
||||
NamedOutputVector translate_sparse_reshape_op(const ov::frontend::tensorflow::NodeContext& node) {
|
||||
// Currently, the translation for SparseReshape is possible only if new shape value is the same as the input shape
|
||||
// value or it is different just by one dynamic dimension of the new shape that can be replace with the
|
||||
// corresponding static dimension of the input shape.
|
||||
@ -67,7 +67,12 @@ OutputVector translate_sparse_reshape_op(const ov::frontend::tensorflow::NodeCon
|
||||
"This case with SparseReshape is not possible to translate to OpenVINO opset. The number "
|
||||
"of dynamic shapes in new shape must be 1 at most.");
|
||||
*/
|
||||
return {input_indices, input_shape};
|
||||
auto output_indices = input_indices;
|
||||
auto output_shape = input_shape;
|
||||
set_out_name(node.get_name() + ":0", output_indices);
|
||||
set_out_name(node.get_name() + ":1", output_shape);
|
||||
|
||||
return {{"output_indices", output_indices}, {"output_shape", output_shape}};
|
||||
}
|
||||
|
||||
NamedOutputVector translate_sparse_fill_empty_rows_op(const ov::frontend::tensorflow::NodeContext& node) {
|
||||
|
@ -43,7 +43,7 @@ TF_OP_CONVERTER(translate_queue_dequeue_many_op);
|
||||
TF_OP_CONVERTER(translate_readvariable_op);
|
||||
TF_OP_CONVERTER(translate_restorev2_op);
|
||||
TF_OP_CONVERTER_NAMED(translate_sparse_fill_empty_rows_op);
|
||||
TF_OP_CONVERTER(translate_sparse_reshape_op);
|
||||
TF_OP_CONVERTER_NAMED(translate_sparse_reshape_op);
|
||||
TF_OP_CONVERTER(translate_sparse_segment_sum_op);
|
||||
TF_OP_CONVERTER(translate_staticregexfullmatch_op);
|
||||
TF_OP_CONVERTER(translate_stringjoin_op);
|
||||
@ -216,7 +216,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"MaxPool", CreatorFunction(translate_max_pool_op)},
|
||||
{"MaxPoolV2", CreatorFunction(translate_max_pool_op)},
|
||||
{"MaxPool3D", CreatorFunction(translate_max_pool_op)},
|
||||
{"MaxPoolWithArgmax", CreatorFunction(translate_max_pool_op)},
|
||||
{"MaxPoolWithArgmax", CreatorFunction(translate_max_pool_with_argmax)},
|
||||
{"Merge", CreatorFunction(translate_merge_op)},
|
||||
{"MirrorPad", CreatorFunction(translate_mirror_pad_op)},
|
||||
{"MutableHashTable", CreatorFunction(translate_hash_table_op)},
|
||||
|
@ -93,6 +93,7 @@ OP_CONVERTER(translate_lrn_op);
|
||||
OP_CONVERTER(translate_mat_mul_op);
|
||||
OP_CONVERTER(translate_matrix_diag_op);
|
||||
OP_CONVERTER(translate_max_pool_op);
|
||||
OP_CONVERTER_NAMED(translate_max_pool_with_argmax);
|
||||
OP_CONVERTER(translate_mirror_pad_op);
|
||||
OP_CONVERTER_NAMED(translate_non_max_suppression_op);
|
||||
OP_CONVERTER(translate_parallel_dynamic_stitch_op);
|
||||
|
@ -128,7 +128,7 @@ OutputVector translate_max_pool_v2(const NodeContext& node) {
|
||||
return translate_max_pool_util(node, 2, ksize_vector, strides_vector);
|
||||
}
|
||||
|
||||
OutputVector translate_max_pool_with_argmax(const NodeContext& node) {
|
||||
NamedOutputVector translate_max_pool_with_argmax(const NodeContext& node) {
|
||||
// MaxPoolWithArgmax has just one input. ksize and strides are attributes
|
||||
TENSORFLOW_OP_VALIDATION(node,
|
||||
node.get_input_size() > 0,
|
||||
@ -199,8 +199,9 @@ OutputVector translate_max_pool_with_argmax(const NodeContext& node) {
|
||||
convert_nchw_to_nhwc(true, output_indices, 4);
|
||||
}
|
||||
|
||||
set_out_name(node_name + ":0", max_pool);
|
||||
set_out_name(node_name + ":1", output_indices);
|
||||
return {max_pool, output_indices};
|
||||
return {{"output", max_pool}, {"argmax", output_indices}};
|
||||
}
|
||||
|
||||
OutputVector translate_max_pool_op(const NodeContext& node) {
|
||||
@ -210,8 +211,6 @@ OutputVector translate_max_pool_op(const NodeContext& node) {
|
||||
return translate_max_pool_v2(node);
|
||||
} else if (node.get_op_type() == "MaxPool3D") {
|
||||
return translate_max_pool(node, 3);
|
||||
} else if (node.get_op_type() == "MaxPoolWithArgmax") {
|
||||
return translate_max_pool_with_argmax(node);
|
||||
} else {
|
||||
TENSORFLOW_OP_VALIDATION(node, false, "Only MaxPool2D, MaxPoolV2 and MaxPool3D are supported.");
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user