[TF FE] Add translators for ScatterND, Conv3DBackpropInputV2 ops (#10550)

* Add translators for ScatterND, Conv3DBackpropInputV2 ops

* add a new line
This commit is contained in:
Ivan Tikhonov 2022-02-22 12:20:32 +03:00 committed by GitHub
parent 5247fdfcaf
commit 472ebc0cd9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 137 additions and 0 deletions

View File

@ -0,0 +1,104 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "op_table.hpp"
#include "openvino/opsets/opset8.hpp"
using namespace std;
using namespace ov::opset8;
namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {
// Translates TF Conv3DBackpropInputV2 into ov::opset8::ConvolutionBackpropData.
//
// Inputs (per TF op definition):
//   0 - input_sizes : 1-D constant with the 5-D shape of the convolution input
//   1 - filter      : 5-D kernel tensor in TF layout [D, H, W, C_in, C_out]
//   2 - out_backprop: gradient w.r.t. the convolution output
// Attributes read: strides, dilations, padding, data_format (NDHWC or NCDHW).
// Returns the single output of the constructed ConvolutionBackpropData node.
OutputVector translate_conv_3d_backprop_input_v2_op(const NodeContext& node) {
    auto ng_filter = node.get_input(1);
    auto ng_out_backprop = node.get_input(2);

    // TODO: refactor me to be less redundant with other convolution ops
    auto tf_strides = node.get_attribute<std::vector<int64_t>>("strides");
    auto tf_dilations = node.get_attribute<std::vector<int64_t>>("dilations");
    auto tf_padding_type = node.get_attribute<std::string>("padding");
    auto tf_data_format = node.get_attribute<std::string>("data_format");

    TENSORFLOW_OP_VALIDATION(node,
                             tf_data_format == "NDHWC" || tf_data_format == "NCDHW",
                             "Conv3DBackpropInputV2 data format is neither NDHWC nor NCDHW. "
                             "Provided data format: ",
                             tf_data_format);

    // input_sizes must be a compile-time constant: it determines the output shape.
    std::vector<int64_t> tf_input_sizes;
    get_const_input(node, 0, &tf_input_sizes);

    // Use int64_t in the predicate: the values are int64_t, and narrowing them
    // to int32_t could wrap large dimensions and defeat the positivity check.
    if (std::any_of(tf_input_sizes.begin(), tf_input_sizes.end(), [](int64_t size) {
            return size <= 0;
        })) {
        FRONT_END_THROW("Conv3DBackpropInputV2 input sizes must be positive integers");
    }

    bool is_ndhwc = (tf_data_format == "NDHWC");

    ov::Strides ng_strides(3);
    ov::Strides ng_dilations(3);
    ov::Shape ng_image_shape(3);
    ov::Shape ng_kernel_shape(3);
    ov::Shape ng_batch_shape(5);

    // Normalize attributes and the gradient tensor to channels-first (NCDHW) layout.
    convert_nhwc_to_hw(is_ndhwc, tf_strides, ng_strides);
    convert_nhwc_to_hw(is_ndhwc, tf_dilations, ng_dilations);
    convert_nhwc_to_hw(is_ndhwc, tf_input_sizes, ng_image_shape);
    convert_nhwc_to_nchw(node.get_name(), is_ndhwc, ng_out_backprop);

    // Cast to size_t (the ov::Shape element type); unsigned long is only 32-bit
    // on LLP64 platforms such as Windows and could truncate large dimensions.
    if (is_ndhwc) {
        ng_batch_shape = {static_cast<size_t>(tf_input_sizes[0]),
                          static_cast<size_t>(tf_input_sizes[4]),
                          static_cast<size_t>(tf_input_sizes[1]),
                          static_cast<size_t>(tf_input_sizes[2]),
                          static_cast<size_t>(tf_input_sizes[3])};
    } else {
        ng_batch_shape = {static_cast<size_t>(tf_input_sizes[0]),
                          static_cast<size_t>(tf_input_sizes[1]),
                          static_cast<size_t>(tf_input_sizes[2]),
                          static_cast<size_t>(tf_input_sizes[3]),
                          static_cast<size_t>(tf_input_sizes[4])};
    }

    // TF filter layout is [D, H, W, C_in, C_out]; the spatial kernel dims come first.
    auto& ng_filter_shape = ng_filter.get_shape();
    ng_kernel_shape[0] = ng_filter_shape[0];
    ng_kernel_shape[1] = ng_filter_shape[1];
    ng_kernel_shape[2] = ng_filter_shape[2];
    // Reorder filter to OpenVINO layout [C_out, C_in, D, H, W].
    transpose_3d<4, 3, 0, 1, 2>(ng_filter);

    ov::CoordinateDiff ng_padding_below;
    ov::CoordinateDiff ng_padding_above;
    make_padding(tf_padding_type,
                 ng_image_shape,
                 ng_kernel_shape,
                 ng_strides,
                 ng_dilations,
                 ng_padding_below,
                 ng_padding_above);

    // ConvolutionBackpropData takes only the spatial part of the output shape
    // (batch and channel dims are inferred), hence the begin() + 2 slice.
    auto ng_output_shape = make_shared<Constant>(element::i64,
                                                 Shape{ng_batch_shape.size() - 2},
                                                 vector<size_t>(ng_batch_shape.begin() + 2, ng_batch_shape.end()));

    auto res_node = make_shared<ConvolutionBackpropData>(ng_out_backprop,
                                                         ng_filter,
                                                         ng_output_shape,
                                                         ng_strides,
                                                         ng_padding_below,
                                                         ng_padding_above,
                                                         ng_dilations);
    auto res = res_node->output(0);

    // Restore the original TF layout if the graph was channels-last.
    convert_nchw_to_nhwc(node.get_name(), is_ndhwc, res);
    set_node_name(node.get_name(), res.get_node_shared_ptr());
    return {res};
}
} // namespace op
} // namespace tensorflow
} // namespace frontend
} // namespace ov

View File

@ -0,0 +1,29 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "op_table.hpp"
#include "openvino/opsets/opset8.hpp"
using namespace std;
using namespace ov::opset8;
namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {
// Translates TF ScatterNd(indices, updates, shape): builds a zero tensor of the
// requested shape and scatters `updates` into it at the positions given by
// `indices`, using opset8::ScatterNDUpdate.
OutputVector translate_scatter_nd_op(const NodeContext& node) {
    auto indices = node.get_input(0);
    auto updates = node.get_input(1);
    auto target_shape = node.get_input(2);

    // Zero scalar of the updates' element type, broadcast to the full output shape.
    auto zero = make_shared<opset8::Constant>(updates.get_element_type(), Shape{1}, 0);
    auto zeros_tensor = make_shared<opset8::Broadcast>(zero, target_shape);

    auto scatter = make_shared<opset8::ScatterNDUpdate>(zeros_tensor, indices, updates);
    set_node_name(node.get_name(), scatter);
    return scatter->outputs();
}
} // namespace op
} // namespace tensorflow
} // namespace frontend
} // namespace ov

View File

@ -33,6 +33,7 @@ OP_CONVERTER(translate_const_op);
OP_CONVERTER(translate_conv_2d_op);
OP_CONVERTER(translate_conv_2d_backprop_input_op);
OP_CONVERTER(translate_conv_3d_op);
OP_CONVERTER(translate_conv_3d_backprop_input_v2_op);
OP_CONVERTER(translate_cumsum_op);
OP_CONVERTER(translate_crop_and_resize_op);
OP_CONVERTER(translate_depth_to_space_op);
@ -73,6 +74,7 @@ OP_CONVERTER(translate_reverse_op);
OP_CONVERTER(translate_roll_op);
OP_CONVERTER(translate_round_op);
OP_CONVERTER(translate_rsqrt_op);
OP_CONVERTER(translate_scatter_nd_op);
OP_CONVERTER(translate_select_op);
OP_CONVERTER(translate_shape_op);
OP_CONVERTER(translate_size_op);
@ -167,6 +169,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"Conv2D", translate_conv_2d_op},
{"Conv2DBackpropInput", translate_conv_2d_backprop_input_op},
{"Conv3D", translate_conv_3d_op},
{"Conv3DBackpropInputV2", translate_conv_3d_backprop_input_v2_op},
{"CropAndResize", translate_crop_and_resize_op},
{"Cumsum", translate_cumsum_op},
{"DepthToSpace", translate_depth_to_space_op},
@ -220,6 +223,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"Roll", translate_roll_op},
{"Round", translate_round_op},
{"Rsqrt", translate_rsqrt_op},
{"ScatterNd", translate_scatter_nd_op},
{"Select", translate_select_op},
{"SelectV2", translate_select_op},
{"Shape", translate_shape_op},