[TF FE] Add translators for NormalizeL2, ReverseSequence (#12913)
Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent
8922d73e7d
commit
0b1a70be0b
@ -26,6 +26,20 @@ OutputVector translate_identity_op(const NodeContext& node) {
|
|||||||
return {input};
|
return {input};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
OutputVector translate_identity_n_op(const NodeContext& node) {
|
||||||
|
auto input_size = node.get_input_size();
|
||||||
|
auto node_name = node.get_name();
|
||||||
|
|
||||||
|
OutputVector result;
|
||||||
|
for (int input_idx = 0; input_idx < input_size; ++input_idx) {
|
||||||
|
auto input = node.get_input(input_idx);
|
||||||
|
set_out_name(node_name + ":" + std::to_string(input_idx), input);
|
||||||
|
result.push_back(input);
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace op
|
} // namespace op
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
} // namespace frontend
|
} // namespace frontend
|
||||||
|
30
src/frontends/tensorflow/src/op/normalize_l2.cpp
Normal file
30
src/frontends/tensorflow/src/op/normalize_l2.cpp
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
// Copyright (C) 2018-2022 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "op_table.hpp"
|
||||||
|
#include "openvino/opsets/opset8.hpp"
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
using namespace ov::opset8;
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace frontend {
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace op {
|
||||||
|
OutputVector translate_normalize_l2_op(const NodeContext& node) {
|
||||||
|
default_op_checks(node, 2, {"NormalizeL2"});
|
||||||
|
auto x = node.get_input(0);
|
||||||
|
auto axes = node.get_input(1);
|
||||||
|
|
||||||
|
// retrieve attribute
|
||||||
|
auto eps = node.get_attribute<float>("epsilon");
|
||||||
|
|
||||||
|
auto normalize_l2 = make_shared<NormalizeL2>(x, axes, eps, ov::op::EpsMode::MAX);
|
||||||
|
set_node_name(node.get_name(), normalize_l2);
|
||||||
|
return {normalize_l2};
|
||||||
|
}
|
||||||
|
} // namespace op
|
||||||
|
} // namespace tensorflow
|
||||||
|
} // namespace frontend
|
||||||
|
} // namespace ov
|
31
src/frontends/tensorflow/src/op/reverse_sequence.cpp
Normal file
31
src/frontends/tensorflow/src/op/reverse_sequence.cpp
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
// Copyright (C) 2018-2022 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "op_table.hpp"
|
||||||
|
#include "openvino/opsets/opset8.hpp"
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
using namespace ov::opset8;
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace frontend {
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace op {
|
||||||
|
OutputVector translate_reverse_sequence_op(const NodeContext& node) {
|
||||||
|
default_op_checks(node, 2, {"ReverseSequence"});
|
||||||
|
auto input = node.get_input(0);
|
||||||
|
auto seq_lengths = node.get_input(1);
|
||||||
|
|
||||||
|
// retrieve attributes
|
||||||
|
auto seq_axis = node.get_attribute<int64_t>("seq_dim");
|
||||||
|
auto batch_axis = node.get_attribute<int64_t>("batch_dim", 0);
|
||||||
|
|
||||||
|
auto reverse_sequence = make_shared<ReverseSequence>(input, seq_lengths, batch_axis, seq_axis);
|
||||||
|
set_node_name(node.get_name(), reverse_sequence);
|
||||||
|
return {reverse_sequence};
|
||||||
|
}
|
||||||
|
} // namespace op
|
||||||
|
} // namespace tensorflow
|
||||||
|
} // namespace frontend
|
||||||
|
} // namespace ov
|
@ -12,25 +12,32 @@ namespace ov {
|
|||||||
namespace frontend {
|
namespace frontend {
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace op {
|
namespace op {
|
||||||
|
OutputVector translate_top_k_base_op(const NodeContext& node, const ov::Output<ov::Node>& k_input, int min_input_size) {
|
||||||
|
default_op_checks(node, min_input_size, {"TopK", "TopKV2"});
|
||||||
|
auto input = node.get_input(0);
|
||||||
|
|
||||||
|
// retrieve k attribute
|
||||||
|
bool sorted = node.get_attribute<bool>("sorted", true);
|
||||||
|
auto top_k = make_shared<TopK>(input,
|
||||||
|
k_input,
|
||||||
|
-1,
|
||||||
|
ov::op::v1::TopK::Mode::MAX,
|
||||||
|
sorted ? TopK::SortType::SORT_VALUES : TopK::SortType::SORT_INDICES,
|
||||||
|
ov::element::i32);
|
||||||
|
set_node_name(node.get_name(), top_k);
|
||||||
|
return {top_k};
|
||||||
|
}
|
||||||
|
OutputVector translate_top_k_op(const NodeContext& node) {
|
||||||
|
// retrieve k attribute
|
||||||
|
auto k = node.get_attribute<int64_t>("k");
|
||||||
|
auto k_input = make_shared<Constant>(ov::element::i64, Shape{}, std::vector<int64_t>({k}));
|
||||||
|
return translate_top_k_base_op(node, k_input, 1);
|
||||||
|
}
|
||||||
|
|
||||||
OutputVector translate_top_k_v2_op(const NodeContext& node) {
|
OutputVector translate_top_k_v2_op(const NodeContext& node) {
|
||||||
auto input = node.get_input(0);
|
default_op_checks(node, 2, {"TopKV2"});
|
||||||
auto k = node.get_input(1);
|
auto k_input = node.get_input(1);
|
||||||
|
return translate_top_k_base_op(node, k_input, 1);
|
||||||
TENSORFLOW_OP_VALIDATION(node, input.get_partial_shape().rank().is_static(), "Input rank must be static.");
|
|
||||||
TENSORFLOW_OP_VALIDATION(node,
|
|
||||||
input.get_partial_shape().rank().get_length() >= 1,
|
|
||||||
"Input rank must be greater than 0.");
|
|
||||||
// axis along which to compute top k indices
|
|
||||||
int64_t k_axis = input.get_partial_shape().rank().get_length() - 1;
|
|
||||||
bool sorted = node.get_attribute<bool>("sorted", true);
|
|
||||||
auto res = std::make_shared<TopK>(input,
|
|
||||||
k,
|
|
||||||
k_axis,
|
|
||||||
TopK::Mode::MAX,
|
|
||||||
sorted ? TopK::SortType::SORT_VALUES : TopK::SortType::SORT_INDICES);
|
|
||||||
set_node_name(node.get_name(), res);
|
|
||||||
return res->outputs();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace op
|
} // namespace op
|
||||||
|
@ -54,6 +54,7 @@ OP_CONVERTER(translate_gather_op);
|
|||||||
OP_CONVERTER(translate_gather_v2_op);
|
OP_CONVERTER(translate_gather_v2_op);
|
||||||
OP_CONVERTER(translate_gather_nd_op);
|
OP_CONVERTER(translate_gather_nd_op);
|
||||||
OP_CONVERTER(translate_identity_op);
|
OP_CONVERTER(translate_identity_op);
|
||||||
|
OP_CONVERTER(translate_identity_n_op);
|
||||||
OP_CONVERTER(translate_interpolate_op);
|
OP_CONVERTER(translate_interpolate_op);
|
||||||
OP_CONVERTER(translate_is_finite_op);
|
OP_CONVERTER(translate_is_finite_op);
|
||||||
OP_CONVERTER(translate_l2_loss_op);
|
OP_CONVERTER(translate_l2_loss_op);
|
||||||
@ -67,6 +68,7 @@ OP_CONVERTER(translate_mat_mul_op);
|
|||||||
OP_CONVERTER(translate_matrix_diag_op);
|
OP_CONVERTER(translate_matrix_diag_op);
|
||||||
OP_CONVERTER(translate_max_pool_op);
|
OP_CONVERTER(translate_max_pool_op);
|
||||||
OP_CONVERTER(translate_non_max_suppression_op);
|
OP_CONVERTER(translate_non_max_suppression_op);
|
||||||
|
OP_CONVERTER(translate_normalize_l2_op);
|
||||||
OP_CONVERTER(translate_pad_op);
|
OP_CONVERTER(translate_pad_op);
|
||||||
OP_CONVERTER(translate_placeholder_op);
|
OP_CONVERTER(translate_placeholder_op);
|
||||||
OP_CONVERTER(translate_placeholder_with_default_op);
|
OP_CONVERTER(translate_placeholder_with_default_op);
|
||||||
@ -82,6 +84,7 @@ OP_CONVERTER(translate_reciprocal_op);
|
|||||||
OP_CONVERTER(translate_reshape_op);
|
OP_CONVERTER(translate_reshape_op);
|
||||||
OP_CONVERTER(translate_resource_gather_op);
|
OP_CONVERTER(translate_resource_gather_op);
|
||||||
OP_CONVERTER(translate_reverse_op);
|
OP_CONVERTER(translate_reverse_op);
|
||||||
|
OP_CONVERTER(translate_reverse_sequence_op);
|
||||||
OP_CONVERTER(translate_roll_op);
|
OP_CONVERTER(translate_roll_op);
|
||||||
OP_CONVERTER(translate_round_op);
|
OP_CONVERTER(translate_round_op);
|
||||||
OP_CONVERTER(translate_rsqrt_op);
|
OP_CONVERTER(translate_rsqrt_op);
|
||||||
@ -99,6 +102,7 @@ OP_CONVERTER(translate_squeeze_op);
|
|||||||
OP_CONVERTER(translate_strided_slice_op);
|
OP_CONVERTER(translate_strided_slice_op);
|
||||||
OP_CONVERTER(translate_sqrt_op);
|
OP_CONVERTER(translate_sqrt_op);
|
||||||
OP_CONVERTER(translate_tile_op);
|
OP_CONVERTER(translate_tile_op);
|
||||||
|
OP_CONVERTER(translate_top_k_op);
|
||||||
OP_CONVERTER(translate_top_k_v2_op);
|
OP_CONVERTER(translate_top_k_v2_op);
|
||||||
OP_CONVERTER(translate_transpose_op);
|
OP_CONVERTER(translate_transpose_op);
|
||||||
OP_CONVERTER(translate_unpack_op);
|
OP_CONVERTER(translate_unpack_op);
|
||||||
@ -208,7 +212,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
|||||||
{"GatherV2", translate_gather_v2_op},
|
{"GatherV2", translate_gather_v2_op},
|
||||||
{"GatherNd", translate_gather_nd_op},
|
{"GatherNd", translate_gather_nd_op},
|
||||||
{"Identity", translate_identity_op},
|
{"Identity", translate_identity_op},
|
||||||
{"IdentityN", translate_identity_op},
|
{"IdentityN", translate_identity_n_op},
|
||||||
{"IsFinite", translate_is_finite_op},
|
{"IsFinite", translate_is_finite_op},
|
||||||
{"L2Loss", translate_l2_loss_op},
|
{"L2Loss", translate_l2_loss_op},
|
||||||
{"LeakyRelu", translate_leaky_relu_op},
|
{"LeakyRelu", translate_leaky_relu_op},
|
||||||
@ -229,6 +233,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
|||||||
{"NonMaxSuppressionV4", translate_non_max_suppression_op},
|
{"NonMaxSuppressionV4", translate_non_max_suppression_op},
|
||||||
{"NonMaxSuppressionV5", translate_non_max_suppression_op},
|
{"NonMaxSuppressionV5", translate_non_max_suppression_op},
|
||||||
{"NoOp", translate_no_op}, // do nothing
|
{"NoOp", translate_no_op}, // do nothing
|
||||||
|
{"NormalizeL2", translate_normalize_l2_op},
|
||||||
{"OneHot", translate_one_hot_op},
|
{"OneHot", translate_one_hot_op},
|
||||||
{"Pack", translate_pack_op},
|
{"Pack", translate_pack_op},
|
||||||
{"Pad", translate_pad_op},
|
{"Pad", translate_pad_op},
|
||||||
@ -244,6 +249,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
|||||||
{"Relu6", translate_relu_6_op},
|
{"Relu6", translate_relu_6_op},
|
||||||
{"Reshape", translate_reshape_op},
|
{"Reshape", translate_reshape_op},
|
||||||
{"Reverse", translate_reverse_op},
|
{"Reverse", translate_reverse_op},
|
||||||
|
{"ReverseSequence", translate_reverse_sequence_op},
|
||||||
{"ReverseV2", translate_reverse_op},
|
{"ReverseV2", translate_reverse_op},
|
||||||
{"ResizeBilinear", translate_interpolate_op},
|
{"ResizeBilinear", translate_interpolate_op},
|
||||||
{"ResizeNearestNeighbor", translate_interpolate_op},
|
{"ResizeNearestNeighbor", translate_interpolate_op},
|
||||||
@ -269,6 +275,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
|||||||
{"SpaceToBatchND", translate_batch_nd_and_space_nd_op},
|
{"SpaceToBatchND", translate_batch_nd_and_space_nd_op},
|
||||||
{"StridedSlice", translate_strided_slice_op},
|
{"StridedSlice", translate_strided_slice_op},
|
||||||
{"Tile", translate_tile_op},
|
{"Tile", translate_tile_op},
|
||||||
|
{"TopK", translate_top_k_op},
|
||||||
{"TopKV2", translate_top_k_v2_op},
|
{"TopKV2", translate_top_k_v2_op},
|
||||||
{"Transpose", translate_transpose_op},
|
{"Transpose", translate_transpose_op},
|
||||||
{"Unpack", translate_unpack_op},
|
{"Unpack", translate_unpack_op},
|
||||||
|
Loading…
Reference in New Issue
Block a user