[TF FE] Add translators for NormalizeL2, ReverseSequence (#12913)
Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent
8922d73e7d
commit
0b1a70be0b
@ -26,6 +26,20 @@ OutputVector translate_identity_op(const NodeContext& node) {
|
||||
return {input};
|
||||
}
|
||||
|
||||
OutputVector translate_identity_n_op(const NodeContext& node) {
|
||||
auto input_size = node.get_input_size();
|
||||
auto node_name = node.get_name();
|
||||
|
||||
OutputVector result;
|
||||
for (int input_idx = 0; input_idx < input_size; ++input_idx) {
|
||||
auto input = node.get_input(input_idx);
|
||||
set_out_name(node_name + ":" + std::to_string(input_idx), input);
|
||||
result.push_back(input);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace op
|
||||
} // namespace tensorflow
|
||||
} // namespace frontend
|
||||
|
30
src/frontends/tensorflow/src/op/normalize_l2.cpp
Normal file
30
src/frontends/tensorflow/src/op/normalize_l2.cpp
Normal file
@ -0,0 +1,30 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "op_table.hpp"
|
||||
#include "openvino/opsets/opset8.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ov::opset8;
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace tensorflow {
|
||||
namespace op {
|
||||
OutputVector translate_normalize_l2_op(const NodeContext& node) {
|
||||
default_op_checks(node, 2, {"NormalizeL2"});
|
||||
auto x = node.get_input(0);
|
||||
auto axes = node.get_input(1);
|
||||
|
||||
// retrieve attribute
|
||||
auto eps = node.get_attribute<float>("epsilon");
|
||||
|
||||
auto normalize_l2 = make_shared<NormalizeL2>(x, axes, eps, ov::op::EpsMode::MAX);
|
||||
set_node_name(node.get_name(), normalize_l2);
|
||||
return {normalize_l2};
|
||||
}
|
||||
} // namespace op
|
||||
} // namespace tensorflow
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
31
src/frontends/tensorflow/src/op/reverse_sequence.cpp
Normal file
31
src/frontends/tensorflow/src/op/reverse_sequence.cpp
Normal file
@ -0,0 +1,31 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "op_table.hpp"
|
||||
#include "openvino/opsets/opset8.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ov::opset8;
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace tensorflow {
|
||||
namespace op {
|
||||
OutputVector translate_reverse_sequence_op(const NodeContext& node) {
|
||||
default_op_checks(node, 2, {"ReverseSequence"});
|
||||
auto input = node.get_input(0);
|
||||
auto seq_lengths = node.get_input(1);
|
||||
|
||||
// retrieve attributes
|
||||
auto seq_axis = node.get_attribute<int64_t>("seq_dim");
|
||||
auto batch_axis = node.get_attribute<int64_t>("batch_dim", 0);
|
||||
|
||||
auto reverse_sequence = make_shared<ReverseSequence>(input, seq_lengths, batch_axis, seq_axis);
|
||||
set_node_name(node.get_name(), reverse_sequence);
|
||||
return {reverse_sequence};
|
||||
}
|
||||
} // namespace op
|
||||
} // namespace tensorflow
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
@ -12,25 +12,32 @@ namespace ov {
|
||||
namespace frontend {
|
||||
namespace tensorflow {
|
||||
namespace op {
|
||||
OutputVector translate_top_k_base_op(const NodeContext& node, const ov::Output<ov::Node>& k_input, int min_input_size) {
|
||||
default_op_checks(node, min_input_size, {"TopK", "TopKV2"});
|
||||
auto input = node.get_input(0);
|
||||
|
||||
// retrieve k attribute
|
||||
bool sorted = node.get_attribute<bool>("sorted", true);
|
||||
auto top_k = make_shared<TopK>(input,
|
||||
k_input,
|
||||
-1,
|
||||
ov::op::v1::TopK::Mode::MAX,
|
||||
sorted ? TopK::SortType::SORT_VALUES : TopK::SortType::SORT_INDICES,
|
||||
ov::element::i32);
|
||||
set_node_name(node.get_name(), top_k);
|
||||
return {top_k};
|
||||
}
|
||||
OutputVector translate_top_k_op(const NodeContext& node) {
|
||||
// retrieve k attribute
|
||||
auto k = node.get_attribute<int64_t>("k");
|
||||
auto k_input = make_shared<Constant>(ov::element::i64, Shape{}, std::vector<int64_t>({k}));
|
||||
return translate_top_k_base_op(node, k_input, 1);
|
||||
}
|
||||
|
||||
OutputVector translate_top_k_v2_op(const NodeContext& node) {
|
||||
auto input = node.get_input(0);
|
||||
auto k = node.get_input(1);
|
||||
|
||||
TENSORFLOW_OP_VALIDATION(node, input.get_partial_shape().rank().is_static(), "Input rank must be static.");
|
||||
TENSORFLOW_OP_VALIDATION(node,
|
||||
input.get_partial_shape().rank().get_length() >= 1,
|
||||
"Input rank must be greater than 0.");
|
||||
// axis along which to compute top k indices
|
||||
int64_t k_axis = input.get_partial_shape().rank().get_length() - 1;
|
||||
bool sorted = node.get_attribute<bool>("sorted", true);
|
||||
auto res = std::make_shared<TopK>(input,
|
||||
k,
|
||||
k_axis,
|
||||
TopK::Mode::MAX,
|
||||
sorted ? TopK::SortType::SORT_VALUES : TopK::SortType::SORT_INDICES);
|
||||
set_node_name(node.get_name(), res);
|
||||
return res->outputs();
|
||||
default_op_checks(node, 2, {"TopKV2"});
|
||||
auto k_input = node.get_input(1);
|
||||
return translate_top_k_base_op(node, k_input, 1);
|
||||
}
|
||||
|
||||
} // namespace op
|
||||
|
@ -54,6 +54,7 @@ OP_CONVERTER(translate_gather_op);
|
||||
OP_CONVERTER(translate_gather_v2_op);
|
||||
OP_CONVERTER(translate_gather_nd_op);
|
||||
OP_CONVERTER(translate_identity_op);
|
||||
OP_CONVERTER(translate_identity_n_op);
|
||||
OP_CONVERTER(translate_interpolate_op);
|
||||
OP_CONVERTER(translate_is_finite_op);
|
||||
OP_CONVERTER(translate_l2_loss_op);
|
||||
@ -67,6 +68,7 @@ OP_CONVERTER(translate_mat_mul_op);
|
||||
OP_CONVERTER(translate_matrix_diag_op);
|
||||
OP_CONVERTER(translate_max_pool_op);
|
||||
OP_CONVERTER(translate_non_max_suppression_op);
|
||||
OP_CONVERTER(translate_normalize_l2_op);
|
||||
OP_CONVERTER(translate_pad_op);
|
||||
OP_CONVERTER(translate_placeholder_op);
|
||||
OP_CONVERTER(translate_placeholder_with_default_op);
|
||||
@ -82,6 +84,7 @@ OP_CONVERTER(translate_reciprocal_op);
|
||||
OP_CONVERTER(translate_reshape_op);
|
||||
OP_CONVERTER(translate_resource_gather_op);
|
||||
OP_CONVERTER(translate_reverse_op);
|
||||
OP_CONVERTER(translate_reverse_sequence_op);
|
||||
OP_CONVERTER(translate_roll_op);
|
||||
OP_CONVERTER(translate_round_op);
|
||||
OP_CONVERTER(translate_rsqrt_op);
|
||||
@ -99,6 +102,7 @@ OP_CONVERTER(translate_squeeze_op);
|
||||
OP_CONVERTER(translate_strided_slice_op);
|
||||
OP_CONVERTER(translate_sqrt_op);
|
||||
OP_CONVERTER(translate_tile_op);
|
||||
OP_CONVERTER(translate_top_k_op);
|
||||
OP_CONVERTER(translate_top_k_v2_op);
|
||||
OP_CONVERTER(translate_transpose_op);
|
||||
OP_CONVERTER(translate_unpack_op);
|
||||
@ -208,7 +212,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"GatherV2", translate_gather_v2_op},
|
||||
{"GatherNd", translate_gather_nd_op},
|
||||
{"Identity", translate_identity_op},
|
||||
{"IdentityN", translate_identity_op},
|
||||
{"IdentityN", translate_identity_n_op},
|
||||
{"IsFinite", translate_is_finite_op},
|
||||
{"L2Loss", translate_l2_loss_op},
|
||||
{"LeakyRelu", translate_leaky_relu_op},
|
||||
@ -229,6 +233,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"NonMaxSuppressionV4", translate_non_max_suppression_op},
|
||||
{"NonMaxSuppressionV5", translate_non_max_suppression_op},
|
||||
{"NoOp", translate_no_op}, // do nothing
|
||||
{"NormalizeL2", translate_normalize_l2_op},
|
||||
{"OneHot", translate_one_hot_op},
|
||||
{"Pack", translate_pack_op},
|
||||
{"Pad", translate_pad_op},
|
||||
@ -244,6 +249,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"Relu6", translate_relu_6_op},
|
||||
{"Reshape", translate_reshape_op},
|
||||
{"Reverse", translate_reverse_op},
|
||||
{"ReverseSequence", translate_reverse_sequence_op},
|
||||
{"ReverseV2", translate_reverse_op},
|
||||
{"ResizeBilinear", translate_interpolate_op},
|
||||
{"ResizeNearestNeighbor", translate_interpolate_op},
|
||||
@ -269,6 +275,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"SpaceToBatchND", translate_batch_nd_and_space_nd_op},
|
||||
{"StridedSlice", translate_strided_slice_op},
|
||||
{"Tile", translate_tile_op},
|
||||
{"TopK", translate_top_k_op},
|
||||
{"TopKV2", translate_top_k_v2_op},
|
||||
{"Transpose", translate_transpose_op},
|
||||
{"Unpack", translate_unpack_op},
|
||||
|
Loading…
Reference in New Issue
Block a user