[TF FE] Add translators for NormalizeL2, ReverseSequence (#12913)

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev 2022-09-07 18:02:48 +03:00 committed by GitHub
parent 8922d73e7d
commit 0b1a70be0b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 107 additions and 18 deletions

View File

@ -26,6 +26,20 @@ OutputVector translate_identity_op(const NodeContext& node) {
return {input}; return {input};
} }
OutputVector translate_identity_n_op(const NodeContext& node) {
auto input_size = node.get_input_size();
auto node_name = node.get_name();
OutputVector result;
for (int input_idx = 0; input_idx < input_size; ++input_idx) {
auto input = node.get_input(input_idx);
set_out_name(node_name + ":" + std::to_string(input_idx), input);
result.push_back(input);
}
return result;
}
} // namespace op } // namespace op
} // namespace tensorflow } // namespace tensorflow
} // namespace frontend } // namespace frontend

View File

@ -0,0 +1,30 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "op_table.hpp"
#include "openvino/opsets/opset8.hpp"
using namespace std;
using namespace ov::opset8;
namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {
OutputVector translate_normalize_l2_op(const NodeContext& node) {
default_op_checks(node, 2, {"NormalizeL2"});
auto x = node.get_input(0);
auto axes = node.get_input(1);
// retrieve attribute
auto eps = node.get_attribute<float>("epsilon");
auto normalize_l2 = make_shared<NormalizeL2>(x, axes, eps, ov::op::EpsMode::MAX);
set_node_name(node.get_name(), normalize_l2);
return {normalize_l2};
}
} // namespace op
} // namespace tensorflow
} // namespace frontend
} // namespace ov

View File

@ -0,0 +1,31 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "op_table.hpp"
#include "openvino/opsets/opset8.hpp"
using namespace std;
using namespace ov::opset8;
namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {
OutputVector translate_reverse_sequence_op(const NodeContext& node) {
default_op_checks(node, 2, {"ReverseSequence"});
auto input = node.get_input(0);
auto seq_lengths = node.get_input(1);
// retrieve attributes
auto seq_axis = node.get_attribute<int64_t>("seq_dim");
auto batch_axis = node.get_attribute<int64_t>("batch_dim", 0);
auto reverse_sequence = make_shared<ReverseSequence>(input, seq_lengths, batch_axis, seq_axis);
set_node_name(node.get_name(), reverse_sequence);
return {reverse_sequence};
}
} // namespace op
} // namespace tensorflow
} // namespace frontend
} // namespace ov

View File

@ -12,25 +12,32 @@ namespace ov {
namespace frontend { namespace frontend {
namespace tensorflow { namespace tensorflow {
namespace op { namespace op {
OutputVector translate_top_k_base_op(const NodeContext& node, const ov::Output<ov::Node>& k_input, int min_input_size) {
default_op_checks(node, min_input_size, {"TopK", "TopKV2"});
auto input = node.get_input(0);
// retrieve k attribute
bool sorted = node.get_attribute<bool>("sorted", true);
auto top_k = make_shared<TopK>(input,
k_input,
-1,
ov::op::v1::TopK::Mode::MAX,
sorted ? TopK::SortType::SORT_VALUES : TopK::SortType::SORT_INDICES,
ov::element::i32);
set_node_name(node.get_name(), top_k);
return {top_k};
}
OutputVector translate_top_k_op(const NodeContext& node) {
// retrieve k attribute
auto k = node.get_attribute<int64_t>("k");
auto k_input = make_shared<Constant>(ov::element::i64, Shape{}, std::vector<int64_t>({k}));
return translate_top_k_base_op(node, k_input, 1);
}
OutputVector translate_top_k_v2_op(const NodeContext& node) { OutputVector translate_top_k_v2_op(const NodeContext& node) {
auto input = node.get_input(0); default_op_checks(node, 2, {"TopKV2"});
auto k = node.get_input(1); auto k_input = node.get_input(1);
return translate_top_k_base_op(node, k_input, 1);
TENSORFLOW_OP_VALIDATION(node, input.get_partial_shape().rank().is_static(), "Input rank must be static.");
TENSORFLOW_OP_VALIDATION(node,
input.get_partial_shape().rank().get_length() >= 1,
"Input rank must be greater than 0.");
// axis along which to compute top k indices
int64_t k_axis = input.get_partial_shape().rank().get_length() - 1;
bool sorted = node.get_attribute<bool>("sorted", true);
auto res = std::make_shared<TopK>(input,
k,
k_axis,
TopK::Mode::MAX,
sorted ? TopK::SortType::SORT_VALUES : TopK::SortType::SORT_INDICES);
set_node_name(node.get_name(), res);
return res->outputs();
} }
} // namespace op } // namespace op

View File

@ -54,6 +54,7 @@ OP_CONVERTER(translate_gather_op);
OP_CONVERTER(translate_gather_v2_op); OP_CONVERTER(translate_gather_v2_op);
OP_CONVERTER(translate_gather_nd_op); OP_CONVERTER(translate_gather_nd_op);
OP_CONVERTER(translate_identity_op); OP_CONVERTER(translate_identity_op);
OP_CONVERTER(translate_identity_n_op);
OP_CONVERTER(translate_interpolate_op); OP_CONVERTER(translate_interpolate_op);
OP_CONVERTER(translate_is_finite_op); OP_CONVERTER(translate_is_finite_op);
OP_CONVERTER(translate_l2_loss_op); OP_CONVERTER(translate_l2_loss_op);
@ -67,6 +68,7 @@ OP_CONVERTER(translate_mat_mul_op);
OP_CONVERTER(translate_matrix_diag_op); OP_CONVERTER(translate_matrix_diag_op);
OP_CONVERTER(translate_max_pool_op); OP_CONVERTER(translate_max_pool_op);
OP_CONVERTER(translate_non_max_suppression_op); OP_CONVERTER(translate_non_max_suppression_op);
OP_CONVERTER(translate_normalize_l2_op);
OP_CONVERTER(translate_pad_op); OP_CONVERTER(translate_pad_op);
OP_CONVERTER(translate_placeholder_op); OP_CONVERTER(translate_placeholder_op);
OP_CONVERTER(translate_placeholder_with_default_op); OP_CONVERTER(translate_placeholder_with_default_op);
@ -82,6 +84,7 @@ OP_CONVERTER(translate_reciprocal_op);
OP_CONVERTER(translate_reshape_op); OP_CONVERTER(translate_reshape_op);
OP_CONVERTER(translate_resource_gather_op); OP_CONVERTER(translate_resource_gather_op);
OP_CONVERTER(translate_reverse_op); OP_CONVERTER(translate_reverse_op);
OP_CONVERTER(translate_reverse_sequence_op);
OP_CONVERTER(translate_roll_op); OP_CONVERTER(translate_roll_op);
OP_CONVERTER(translate_round_op); OP_CONVERTER(translate_round_op);
OP_CONVERTER(translate_rsqrt_op); OP_CONVERTER(translate_rsqrt_op);
@ -99,6 +102,7 @@ OP_CONVERTER(translate_squeeze_op);
OP_CONVERTER(translate_strided_slice_op); OP_CONVERTER(translate_strided_slice_op);
OP_CONVERTER(translate_sqrt_op); OP_CONVERTER(translate_sqrt_op);
OP_CONVERTER(translate_tile_op); OP_CONVERTER(translate_tile_op);
OP_CONVERTER(translate_top_k_op);
OP_CONVERTER(translate_top_k_v2_op); OP_CONVERTER(translate_top_k_v2_op);
OP_CONVERTER(translate_transpose_op); OP_CONVERTER(translate_transpose_op);
OP_CONVERTER(translate_unpack_op); OP_CONVERTER(translate_unpack_op);
@ -208,7 +212,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"GatherV2", translate_gather_v2_op}, {"GatherV2", translate_gather_v2_op},
{"GatherNd", translate_gather_nd_op}, {"GatherNd", translate_gather_nd_op},
{"Identity", translate_identity_op}, {"Identity", translate_identity_op},
{"IdentityN", translate_identity_op}, {"IdentityN", translate_identity_n_op},
{"IsFinite", translate_is_finite_op}, {"IsFinite", translate_is_finite_op},
{"L2Loss", translate_l2_loss_op}, {"L2Loss", translate_l2_loss_op},
{"LeakyRelu", translate_leaky_relu_op}, {"LeakyRelu", translate_leaky_relu_op},
@ -229,6 +233,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"NonMaxSuppressionV4", translate_non_max_suppression_op}, {"NonMaxSuppressionV4", translate_non_max_suppression_op},
{"NonMaxSuppressionV5", translate_non_max_suppression_op}, {"NonMaxSuppressionV5", translate_non_max_suppression_op},
{"NoOp", translate_no_op}, // do nothing {"NoOp", translate_no_op}, // do nothing
{"NormalizeL2", translate_normalize_l2_op},
{"OneHot", translate_one_hot_op}, {"OneHot", translate_one_hot_op},
{"Pack", translate_pack_op}, {"Pack", translate_pack_op},
{"Pad", translate_pad_op}, {"Pad", translate_pad_op},
@ -244,6 +249,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"Relu6", translate_relu_6_op}, {"Relu6", translate_relu_6_op},
{"Reshape", translate_reshape_op}, {"Reshape", translate_reshape_op},
{"Reverse", translate_reverse_op}, {"Reverse", translate_reverse_op},
{"ReverseSequence", translate_reverse_sequence_op},
{"ReverseV2", translate_reverse_op}, {"ReverseV2", translate_reverse_op},
{"ResizeBilinear", translate_interpolate_op}, {"ResizeBilinear", translate_interpolate_op},
{"ResizeNearestNeighbor", translate_interpolate_op}, {"ResizeNearestNeighbor", translate_interpolate_op},
@ -269,6 +275,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"SpaceToBatchND", translate_batch_nd_and_space_nd_op}, {"SpaceToBatchND", translate_batch_nd_and_space_nd_op},
{"StridedSlice", translate_strided_slice_op}, {"StridedSlice", translate_strided_slice_op},
{"Tile", translate_tile_op}, {"Tile", translate_tile_op},
{"TopK", translate_top_k_op},
{"TopKV2", translate_top_k_v2_op}, {"TopKV2", translate_top_k_v2_op},
{"Transpose", translate_transpose_op}, {"Transpose", translate_transpose_op},
{"Unpack", translate_unpack_op}, {"Unpack", translate_unpack_op},