[TF FE] Add translators for ScatterND, Conv3DBackpropInputV2 ops (#10550)
* Add translators for ScatterND, ConvBackpropInputV2 ops * add a new line
This commit is contained in:
parent
5247fdfcaf
commit
472ebc0cd9
104
src/frontends/tensorflow/src/op/conv_3d_backprop.cpp
Normal file
104
src/frontends/tensorflow/src/op/conv_3d_backprop.cpp
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
// Copyright (C) 2018-2022 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "op_table.hpp"
|
||||||
|
#include "openvino/opsets/opset8.hpp"
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
using namespace ov::opset8;
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace frontend {
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace op {
|
||||||
|
|
||||||
|
OutputVector translate_conv_3d_backprop_input_v2_op(const NodeContext& node) {
|
||||||
|
auto ng_filter = node.get_input(1);
|
||||||
|
auto ng_out_backprop = node.get_input(2);
|
||||||
|
|
||||||
|
// TODO: refactor me to be less redundant with other convolution ops
|
||||||
|
auto tf_strides = node.get_attribute<std::vector<int64_t>>("strides");
|
||||||
|
auto tf_dilations = node.get_attribute<std::vector<int64_t>>("dilations");
|
||||||
|
auto tf_padding_type = node.get_attribute<std::string>("padding");
|
||||||
|
auto tf_data_format = node.get_attribute<std::string>("data_format");
|
||||||
|
|
||||||
|
TENSORFLOW_OP_VALIDATION(node,
|
||||||
|
tf_data_format == "NDHWC" || tf_data_format == "NCDHW",
|
||||||
|
"Conv3DBackpropInputV2 data format is neither NDHWC nor NCDHW. "
|
||||||
|
"Provided data format: ",
|
||||||
|
tf_data_format);
|
||||||
|
|
||||||
|
std::vector<int64_t> tf_input_sizes;
|
||||||
|
get_const_input(node, 0, &tf_input_sizes);
|
||||||
|
|
||||||
|
if (std::any_of(tf_input_sizes.begin(), tf_input_sizes.end(), [](int32_t size) {
|
||||||
|
return size <= 0;
|
||||||
|
})) {
|
||||||
|
FRONT_END_THROW("Conv3DBackpropInputV2 input sizes must be positive integers");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_ndhwc = (tf_data_format == "NDHWC");
|
||||||
|
|
||||||
|
ov::Strides ng_strides(3);
|
||||||
|
ov::Strides ng_dilations(3);
|
||||||
|
ov::Shape ng_image_shape(3);
|
||||||
|
ov::Shape ng_kernel_shape(3);
|
||||||
|
ov::Shape ng_batch_shape(5);
|
||||||
|
|
||||||
|
convert_nhwc_to_hw(is_ndhwc, tf_strides, ng_strides);
|
||||||
|
convert_nhwc_to_hw(is_ndhwc, tf_dilations, ng_dilations);
|
||||||
|
convert_nhwc_to_hw(is_ndhwc, tf_input_sizes, ng_image_shape);
|
||||||
|
convert_nhwc_to_nchw(node.get_name(), is_ndhwc, ng_out_backprop);
|
||||||
|
if (is_ndhwc) {
|
||||||
|
ng_batch_shape = {static_cast<unsigned long>(tf_input_sizes[0]),
|
||||||
|
static_cast<unsigned long>(tf_input_sizes[4]),
|
||||||
|
static_cast<unsigned long>(tf_input_sizes[1]),
|
||||||
|
static_cast<unsigned long>(tf_input_sizes[2]),
|
||||||
|
static_cast<unsigned long>(tf_input_sizes[3])};
|
||||||
|
} else {
|
||||||
|
ng_batch_shape = {static_cast<unsigned long>(tf_input_sizes[0]),
|
||||||
|
static_cast<unsigned long>(tf_input_sizes[1]),
|
||||||
|
static_cast<unsigned long>(tf_input_sizes[2]),
|
||||||
|
static_cast<unsigned long>(tf_input_sizes[3]),
|
||||||
|
static_cast<unsigned long>(tf_input_sizes[4])};
|
||||||
|
}
|
||||||
|
|
||||||
|
auto& ng_filter_shape = ng_filter.get_shape();
|
||||||
|
ng_kernel_shape[0] = ng_filter_shape[0];
|
||||||
|
ng_kernel_shape[1] = ng_filter_shape[1];
|
||||||
|
ng_kernel_shape[2] = ng_filter_shape[2];
|
||||||
|
transpose_3d<4, 3, 0, 1, 2>(ng_filter);
|
||||||
|
|
||||||
|
ov::CoordinateDiff ng_padding_below;
|
||||||
|
ov::CoordinateDiff ng_padding_above;
|
||||||
|
|
||||||
|
make_padding(tf_padding_type,
|
||||||
|
ng_image_shape,
|
||||||
|
ng_kernel_shape,
|
||||||
|
ng_strides,
|
||||||
|
ng_dilations,
|
||||||
|
ng_padding_below,
|
||||||
|
ng_padding_above);
|
||||||
|
|
||||||
|
auto ng_output_shape = make_shared<Constant>(element::i64,
|
||||||
|
Shape{ng_batch_shape.size() - 2},
|
||||||
|
vector<size_t>(ng_batch_shape.begin() + 2, ng_batch_shape.end()));
|
||||||
|
|
||||||
|
auto res_node = make_shared<ConvolutionBackpropData>(ng_out_backprop,
|
||||||
|
ng_filter,
|
||||||
|
ng_output_shape,
|
||||||
|
ng_strides,
|
||||||
|
ng_padding_below,
|
||||||
|
ng_padding_above,
|
||||||
|
ng_dilations);
|
||||||
|
auto res = res_node->output(0);
|
||||||
|
|
||||||
|
convert_nchw_to_nhwc(node.get_name(), is_ndhwc, res);
|
||||||
|
set_node_name(node.get_name(), res.get_node_shared_ptr());
|
||||||
|
return {res};
|
||||||
|
}
|
||||||
|
} // namespace op
|
||||||
|
} // namespace tensorflow
|
||||||
|
} // namespace frontend
|
||||||
|
} // namespace ov
|
29
src/frontends/tensorflow/src/op/scatter_nd.cpp
Normal file
29
src/frontends/tensorflow/src/op/scatter_nd.cpp
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
// Copyright (C) 2018-2022 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "op_table.hpp"
|
||||||
|
#include "openvino/opsets/opset8.hpp"
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
using namespace ov::opset8;
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace frontend {
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace op {
|
||||||
|
OutputVector translate_scatter_nd_op(const NodeContext& node) {
|
||||||
|
auto input_indices = node.get_input(0);
|
||||||
|
auto updates = node.get_input(1);
|
||||||
|
auto shape = node.get_input(2);
|
||||||
|
|
||||||
|
auto input_data = make_shared<opset8::Constant>(updates.get_element_type(), Shape{1}, 0);
|
||||||
|
auto broadcast = make_shared<opset8::Broadcast>(input_data, shape);
|
||||||
|
auto res = make_shared<opset8::ScatterNDUpdate>(broadcast, input_indices, updates);
|
||||||
|
set_node_name(node.get_name(), res);
|
||||||
|
return res->outputs();
|
||||||
|
}
|
||||||
|
} // namespace op
|
||||||
|
} // namespace tensorflow
|
||||||
|
} // namespace frontend
|
||||||
|
} // namespace ov
|
@ -33,6 +33,7 @@ OP_CONVERTER(translate_const_op);
|
|||||||
OP_CONVERTER(translate_conv_2d_op);
|
OP_CONVERTER(translate_conv_2d_op);
|
||||||
OP_CONVERTER(translate_conv_2d_backprop_input_op);
|
OP_CONVERTER(translate_conv_2d_backprop_input_op);
|
||||||
OP_CONVERTER(translate_conv_3d_op);
|
OP_CONVERTER(translate_conv_3d_op);
|
||||||
|
OP_CONVERTER(translate_conv_3d_backprop_input_v2_op);
|
||||||
OP_CONVERTER(translate_cumsum_op);
|
OP_CONVERTER(translate_cumsum_op);
|
||||||
OP_CONVERTER(translate_crop_and_resize_op);
|
OP_CONVERTER(translate_crop_and_resize_op);
|
||||||
OP_CONVERTER(translate_depth_to_space_op);
|
OP_CONVERTER(translate_depth_to_space_op);
|
||||||
@ -73,6 +74,7 @@ OP_CONVERTER(translate_reverse_op);
|
|||||||
OP_CONVERTER(translate_roll_op);
|
OP_CONVERTER(translate_roll_op);
|
||||||
OP_CONVERTER(translate_round_op);
|
OP_CONVERTER(translate_round_op);
|
||||||
OP_CONVERTER(translate_rsqrt_op);
|
OP_CONVERTER(translate_rsqrt_op);
|
||||||
|
OP_CONVERTER(translate_scatter_nd_op);
|
||||||
OP_CONVERTER(translate_select_op);
|
OP_CONVERTER(translate_select_op);
|
||||||
OP_CONVERTER(translate_shape_op);
|
OP_CONVERTER(translate_shape_op);
|
||||||
OP_CONVERTER(translate_size_op);
|
OP_CONVERTER(translate_size_op);
|
||||||
@ -167,6 +169,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
|||||||
{"Conv2D", translate_conv_2d_op},
|
{"Conv2D", translate_conv_2d_op},
|
||||||
{"Conv2DBackpropInput", translate_conv_2d_backprop_input_op},
|
{"Conv2DBackpropInput", translate_conv_2d_backprop_input_op},
|
||||||
{"Conv3D", translate_conv_3d_op},
|
{"Conv3D", translate_conv_3d_op},
|
||||||
|
{"Conv3DBackpropInputV2", translate_conv_3d_backprop_input_v2_op},
|
||||||
{"CropAndResize", translate_crop_and_resize_op},
|
{"CropAndResize", translate_crop_and_resize_op},
|
||||||
{"Cumsum", translate_cumsum_op},
|
{"Cumsum", translate_cumsum_op},
|
||||||
{"DepthToSpace", translate_depth_to_space_op},
|
{"DepthToSpace", translate_depth_to_space_op},
|
||||||
@ -220,6 +223,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
|||||||
{"Roll", translate_roll_op},
|
{"Roll", translate_roll_op},
|
||||||
{"Round", translate_round_op},
|
{"Round", translate_round_op},
|
||||||
{"Rsqrt", translate_rsqrt_op},
|
{"Rsqrt", translate_rsqrt_op},
|
||||||
|
{"ScatterNd", translate_scatter_nd_op},
|
||||||
{"Select", translate_select_op},
|
{"Select", translate_select_op},
|
||||||
{"SelectV2", translate_select_op},
|
{"SelectV2", translate_select_op},
|
||||||
{"Shape", translate_shape_op},
|
{"Shape", translate_shape_op},
|
||||||
|
Loading…
Reference in New Issue
Block a user