[TF FE] Support dynamic rank support for Convolutional and Pooling operations (#12661)
* [TF FE] Add dynamic rank support for Convolutional and Pooling operations Refactor DepthwiseConv2D, AvgPool, and FusedBatchNorm operations Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Fix build issue with rvalue Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Fix build issue with climit Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Skip duplication of Parameter nodes Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Revert changes in StridedSlice and add check for AvgPool operation type Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Revert the rest of changes for StridedSlice Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Fix translator for AvgPool: add pad mode Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Introduce helper default_op_checks Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent
d70d7b2171
commit
fdac22042c
@ -192,7 +192,9 @@ void FrontEnd::translate_graph(const ov::frontend::InputModel::Ptr& model,
|
||||
results.push_back(result);
|
||||
} else {
|
||||
auto param = std::dynamic_pointer_cast<ov::opset8::Parameter>(output.get_node_shared_ptr());
|
||||
if (param && operation_decoder->get_op_type() != "Identity") {
|
||||
// avoid duplicating Parameter nodes if they are already in the Parameters vector
|
||||
if (param && operation_decoder->get_op_type() != "Identity" &&
|
||||
std::find(params.begin(), params.end(), param) == params.end()) {
|
||||
params.push_back(param);
|
||||
}
|
||||
ng_op_map[operation_name].push_back(output);
|
||||
|
@ -14,13 +14,23 @@ namespace tensorflow {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_avg_pool_op(const NodeContext& node) {
|
||||
Output<Node> ng_input = node.get_input(0);
|
||||
default_op_checks(node, 1, {"AvgPool", "AvgPool3D"});
|
||||
auto op_type = node.get_op_type();
|
||||
auto input = node.get_input(0);
|
||||
|
||||
auto spatial_dim = (op_type == "AvgPool") ? 2 : 3;
|
||||
|
||||
// retrieve attributes for AvgPool operation
|
||||
auto tf_strides = node.get_attribute<std::vector<int64_t>>("strides");
|
||||
auto tf_ksize = node.get_attribute<std::vector<int64_t>>("ksize");
|
||||
auto tf_padding_type = node.get_attribute<std::string>("padding");
|
||||
auto tf_data_format = node.get_attribute<std::string>("data_format");
|
||||
ov::op::PadType auto_pad = convert_tf_padding(node, tf_padding_type);
|
||||
TENSORFLOW_OP_VALIDATION(node,
|
||||
auto_pad == ov::op::PadType::VALID || auto_pad == ov::op::PadType::SAME_UPPER,
|
||||
"AvgPool and AvgPool3D supports only VALID or SAME_UPPER padding mode.");
|
||||
|
||||
// retrieve optional attribute
|
||||
auto tf_data_format = node.get_attribute<std::string>("data_format", (spatial_dim == 2) ? "NHWC" : "NDHWC");
|
||||
TENSORFLOW_OP_VALIDATION(
|
||||
node,
|
||||
tf_data_format == "NHWC" || tf_data_format == "NCHW" || tf_data_format == "NDHWC" || tf_data_format == "NCDHW",
|
||||
@ -28,47 +38,26 @@ OutputVector translate_avg_pool_op(const NodeContext& node) {
|
||||
|
||||
bool is_nhwc = (tf_data_format == "NHWC") || (tf_data_format == "NDHWC");
|
||||
|
||||
int N = 2;
|
||||
if (node.get_op_type() == "AvgPool3D") {
|
||||
N = 3;
|
||||
}
|
||||
// prepare inputs for OpenVINO AvgPool
|
||||
Strides strides(spatial_dim);
|
||||
Shape kernel_shape(spatial_dim);
|
||||
Shape dilations(spatial_dim, 1);
|
||||
convert_nhwc_to_hw(is_nhwc, tf_strides, strides);
|
||||
convert_nhwc_to_hw(is_nhwc, tf_ksize, kernel_shape);
|
||||
convert_nhwc_to_nchw(is_nhwc, input, ov::Rank(spatial_dim + 2));
|
||||
|
||||
Strides ng_strides(N);
|
||||
Shape ng_image_shape(N);
|
||||
Shape ng_kernel_shape(N);
|
||||
convert_nhwc_to_hw(is_nhwc, tf_strides, ng_strides);
|
||||
convert_nhwc_to_hw(is_nhwc, ng_input.get_shape(), ng_image_shape);
|
||||
convert_nhwc_to_hw(is_nhwc, tf_ksize, ng_kernel_shape);
|
||||
convert_nhwc_to_nchw(is_nhwc, ng_input);
|
||||
|
||||
CoordinateDiff padding_below;
|
||||
CoordinateDiff padding_above;
|
||||
Shape ng_dilations(N, 1);
|
||||
make_padding(tf_padding_type,
|
||||
ng_image_shape,
|
||||
ng_kernel_shape,
|
||||
ng_strides,
|
||||
ng_dilations,
|
||||
padding_below,
|
||||
padding_above);
|
||||
|
||||
// TODO: remove this once OV supports negative padding
|
||||
// (CoordinateDiff) for AvgPool
|
||||
Shape ng_padding_below(padding_below.begin(), padding_below.end());
|
||||
Shape ng_padding_above(padding_above.begin(), padding_above.end());
|
||||
|
||||
auto res_node = make_shared<AvgPool>(ng_input,
|
||||
ng_strides,
|
||||
ng_padding_below,
|
||||
ng_padding_above,
|
||||
ng_kernel_shape,
|
||||
auto avg_pool = make_shared<AvgPool>(input,
|
||||
strides,
|
||||
Shape({}),
|
||||
Shape({}),
|
||||
kernel_shape,
|
||||
true,
|
||||
ov::op::RoundingType::FLOOR);
|
||||
auto res = res_node->output(0);
|
||||
|
||||
convert_nchw_to_nhwc(is_nhwc, res);
|
||||
set_node_name(node.get_name(), res.get_node_shared_ptr());
|
||||
return {res};
|
||||
ov::op::RoundingType::FLOOR,
|
||||
auto_pad);
|
||||
auto avg_pool_output = avg_pool->output(0);
|
||||
convert_nchw_to_nhwc(is_nhwc, avg_pool_output, ov::Rank(spatial_dim + 2));
|
||||
set_node_name(node.get_name(), avg_pool);
|
||||
return {avg_pool};
|
||||
}
|
||||
} // namespace op
|
||||
} // namespace tensorflow
|
||||
|
@ -14,28 +14,28 @@ namespace tensorflow {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_bias_add_op(const NodeContext& node) {
|
||||
default_op_checks(node, 2, {"BiasAdd"});
|
||||
auto value = node.get_input(0);
|
||||
auto bias = node.get_input(1);
|
||||
|
||||
// retrieve optional attributes
|
||||
std::string data_format = node.get_attribute<std::string>("data_format", "NHWC");
|
||||
|
||||
TENSORFLOW_OP_VALIDATION(node,
|
||||
data_format == "NHWC" || data_format == "NCHW",
|
||||
"BiasAdd data format is neither NHWC nor NCHW.");
|
||||
|
||||
auto value_shape = value.get_partial_shape();
|
||||
auto bias_shape = bias.get_partial_shape();
|
||||
TENSORFLOW_OP_VALIDATION(node, bias_shape.size() == 1, "Bias input of BiasAdd must have one dimension");
|
||||
|
||||
Output<Node> bias_reshaped = bias;
|
||||
|
||||
// in case NCHW layout bias must be reshaped to have a shape (1, C, 1, ...)
|
||||
// for further correct use of Add operation
|
||||
if (data_format == "NCHW") {
|
||||
// TODO: add support for dynamic rank in case NCHW layout
|
||||
auto value_shape = value.get_partial_shape();
|
||||
TENSORFLOW_OP_VALIDATION(node,
|
||||
value_shape.rank().is_static(),
|
||||
"Value of dynamic rank for BiasAdd in NCHW layout is not supported.");
|
||||
auto value_rank = value_shape.rank().get_length();
|
||||
|
||||
std::vector<int64_t> axes_unsqueeze;
|
||||
for (size_t dim_ind = 0; dim_ind < value_rank; ++dim_ind) {
|
||||
if (dim_ind != 1) {
|
||||
|
@ -16,7 +16,7 @@ namespace tensorflow {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_conv_2d_backprop_input_op(const NodeContext& node) {
|
||||
TENSORFLOW_OP_VALIDATION(node, node.get_input_size() >= 3, "Conv2DBackpropInput must have at least three inputs.");
|
||||
default_op_checks(node, 3, {"Conv2DBackpropInput"});
|
||||
auto input_sizes = node.get_input(0);
|
||||
auto filter = node.get_input(1);
|
||||
auto out_backprop = node.get_input(2);
|
||||
@ -73,7 +73,7 @@ OutputVector translate_conv_2d_backprop_input_op(const NodeContext& node) {
|
||||
|
||||
// prepare inputs to ConvolutionBackpropData
|
||||
filter = make_transpose(filter, {3, 2, 0, 1});
|
||||
convert_nhwc_to_nchw(is_nhwc, out_backprop);
|
||||
convert_nhwc_to_nchw(is_nhwc, out_backprop, ov::Rank(4));
|
||||
|
||||
// initially think that output shape defined for NCHW layout
|
||||
auto ss_begin = make_shared<Constant>(element::i64, Shape{1}, std::vector<int64_t>{2});
|
||||
@ -104,7 +104,7 @@ OutputVector translate_conv_2d_backprop_input_op(const NodeContext& node) {
|
||||
|
||||
// insert Transpose only if original Conv2DBackpropInput is in NHWC layout
|
||||
auto conv_backprop_output = conv_backprop->output(0);
|
||||
convert_nchw_to_nhwc(is_nhwc, conv_backprop_output);
|
||||
convert_nchw_to_nhwc(is_nhwc, conv_backprop_output, ov::Rank(4));
|
||||
|
||||
// move the original name to new ConvolutionBackpropData if original layout is NCHW
|
||||
// move the original name to Transpose if original layout is NHWC
|
||||
|
@ -14,7 +14,7 @@ namespace tensorflow {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_conv_3d_backprop_input_v2_op(const NodeContext& node) {
|
||||
TENSORFLOW_OP_VALIDATION(node, node.get_input_size() >= 3, "Conv3DBackpropInput must have at least three inputs.");
|
||||
default_op_checks(node, 3, {"Conv3DBackpropInput"});
|
||||
auto input_sizes = node.get_input(0);
|
||||
auto filter = node.get_input(1);
|
||||
auto out_backprop = node.get_input(2);
|
||||
@ -75,7 +75,7 @@ OutputVector translate_conv_3d_backprop_input_v2_op(const NodeContext& node) {
|
||||
|
||||
// prepare inputs to ConvolutionBackpropData
|
||||
filter = make_transpose(filter, {4, 3, 0, 1, 2});
|
||||
convert_nhwc_to_nchw(is_nhwc, out_backprop);
|
||||
convert_nhwc_to_nchw(is_nhwc, out_backprop, ov::Rank(5));
|
||||
|
||||
// initially think that output shape defined for NCDHW layout
|
||||
auto ss_begin = make_shared<Constant>(element::i64, Shape{1}, std::vector<int64_t>{2});
|
||||
@ -106,7 +106,7 @@ OutputVector translate_conv_3d_backprop_input_v2_op(const NodeContext& node) {
|
||||
|
||||
// insert Transpose only if original Conv3DBackpropInput is in NDHWC layout
|
||||
auto conv_backprop_output = conv_backprop->output(0);
|
||||
convert_nchw_to_nhwc(is_nhwc, conv_backprop_output);
|
||||
convert_nchw_to_nhwc(is_nhwc, conv_backprop_output, ov::Rank(5));
|
||||
|
||||
// move the original name to new ConvolutionBackpropData if original layout is NCHW
|
||||
// move the original name to Transpose if original layout is NHWC
|
||||
|
@ -14,66 +14,55 @@ namespace tensorflow {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_depthwise_conv_2d_native_op(const NodeContext& node) {
|
||||
auto ng_input = node.get_input(0);
|
||||
auto ng_filter = node.get_input(1);
|
||||
default_op_checks(node, 2, {"DepthwiseConv2dNative"});
|
||||
auto input = node.get_input(0);
|
||||
auto filter = node.get_input(1);
|
||||
|
||||
// retrive mandatory attributes for DepthwiseConv2dNative
|
||||
auto tf_strides = node.get_attribute<std::vector<int64_t>>("strides");
|
||||
auto tf_dilations = node.get_attribute<std::vector<int64_t>>("dilations");
|
||||
auto tf_padding_type = node.get_attribute<std::string>("padding");
|
||||
auto tf_data_format = node.get_attribute<std::string>("data_format");
|
||||
ov::op::PadType auto_pad = convert_tf_padding(node, tf_padding_type);
|
||||
|
||||
// retrieve optional attributes
|
||||
auto tf_data_format = node.get_attribute<std::string>("data_format", "NHWC");
|
||||
auto tf_dilations = node.get_attribute<std::vector<int64_t>>("dilations", {1, 1, 1, 1});
|
||||
|
||||
TENSORFLOW_OP_VALIDATION(node,
|
||||
auto_pad != ov::op::PadType::EXPLICIT,
|
||||
"Explicit padding for DepthwiseConv2dNative is not supported.");
|
||||
TENSORFLOW_OP_VALIDATION(node,
|
||||
tf_data_format == "NHWC" || tf_data_format == "NCHW",
|
||||
"DepthwiseConv2D data format is neither NHWC nor NCHW");
|
||||
"DepthwiseConv2dNative data format is neither NHWC nor NCHW");
|
||||
|
||||
bool is_nhwc = (tf_data_format == "NHWC");
|
||||
|
||||
Strides ng_strides(2);
|
||||
Strides ng_dilations(2);
|
||||
Strides strides(2);
|
||||
Strides dilations(2);
|
||||
convert_nhwc_to_hw(is_nhwc, tf_strides, strides);
|
||||
convert_nhwc_to_hw(is_nhwc, tf_dilations, dilations);
|
||||
|
||||
Shape ng_image_shape(2);
|
||||
Shape ng_kernel_shape(2);
|
||||
|
||||
convert_nhwc_to_hw(is_nhwc, ng_input.get_shape(), ng_image_shape);
|
||||
convert_nhwc_to_hw(is_nhwc, tf_strides, ng_strides);
|
||||
convert_nhwc_to_hw(is_nhwc, tf_dilations, ng_dilations);
|
||||
convert_nhwc_to_nchw(is_nhwc, ng_input);
|
||||
convert_nhwc_to_nchw(is_nhwc, input, ov::Rank(4));
|
||||
|
||||
auto& ng_filter_shape = ng_filter.get_shape();
|
||||
ng_kernel_shape[0] = ng_filter_shape[0];
|
||||
ng_kernel_shape[1] = ng_filter_shape[1];
|
||||
// prepare filter to have a number of groups equal to CIN
|
||||
auto unsqueeze_filter =
|
||||
make_shared<Unsqueeze>(filter, make_shared<Constant>(element::i64, Shape{1}, std::vector<int64_t>{3}));
|
||||
auto transposed_filter =
|
||||
make_shared<Transpose>(unsqueeze_filter,
|
||||
make_shared<Constant>(element::i64, Shape{5}, std::vector<int64_t>{2, 4, 3, 0, 1}));
|
||||
|
||||
CoordinateDiff ng_padding_below;
|
||||
CoordinateDiff ng_padding_above;
|
||||
make_padding(tf_padding_type,
|
||||
ng_image_shape,
|
||||
ng_kernel_shape,
|
||||
ng_strides,
|
||||
ng_dilations,
|
||||
ng_padding_below,
|
||||
ng_padding_above);
|
||||
|
||||
// H W I M -> H W I 1 M
|
||||
auto filter_shape = make_shared<Constant>(
|
||||
element::u64,
|
||||
Shape{5},
|
||||
ov::Shape{ng_filter_shape[0], ng_filter_shape[1], ng_filter_shape[2], 1, ng_filter_shape[3]});
|
||||
auto reshaped_filter = make_shared<Reshape>(ng_filter, filter_shape, false);
|
||||
|
||||
// H W I 1 M -> I M 1 H W
|
||||
auto order = make_shared<Constant>(element::i64, Shape{5}, vector<int64_t>{2, 4, 3, 0, 1});
|
||||
auto transposed_filter = make_shared<opset8::Transpose>(reshaped_filter, order);
|
||||
|
||||
auto ng_conv_node = make_shared<GroupConvolution>(ng_input,
|
||||
transposed_filter,
|
||||
ng_strides,
|
||||
ng_padding_below,
|
||||
ng_padding_above,
|
||||
ng_dilations);
|
||||
auto ng_conv = ng_conv_node->output(0);
|
||||
|
||||
convert_nchw_to_nhwc(is_nhwc, ng_conv);
|
||||
set_node_name(node.get_name(), ng_conv.get_node_shared_ptr());
|
||||
return {ng_conv};
|
||||
ov::Output<ov::Node> group_conv = make_shared<GroupConvolution>(input,
|
||||
transposed_filter,
|
||||
strides,
|
||||
CoordinateDiff({}),
|
||||
CoordinateDiff({}),
|
||||
dilations,
|
||||
auto_pad);
|
||||
ov::frontend::tensorflow::convert_nchw_to_nhwc(is_nhwc, group_conv, ov::Rank(4));
|
||||
ov::frontend::tensorflow::set_node_name(node.get_name(), group_conv.get_node_shared_ptr());
|
||||
return {group_conv};
|
||||
}
|
||||
} // namespace op
|
||||
} // namespace tensorflow
|
||||
|
@ -15,6 +15,7 @@ namespace tensorflow {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_fused_batch_norm_op(const NodeContext& node) {
|
||||
default_op_checks(node, 5, {"FusedBatchNorm", "FusedBatchNormV2", "FusedBatchNormV3"});
|
||||
auto ng_input = node.get_input(0);
|
||||
auto ng_scale = node.get_input(1);
|
||||
auto ng_offset = node.get_input(2);
|
||||
@ -35,11 +36,11 @@ OutputVector translate_fused_batch_norm_op(const NodeContext& node) {
|
||||
|
||||
OPENVINO_DEBUG << "epsilon: " << tf_epsilon;
|
||||
|
||||
convert_nhwc_to_nchw(is_nhwc, ng_input);
|
||||
convert_nhwc_to_nchw(is_nhwc, ng_input, ov::Rank(4));
|
||||
|
||||
auto ng_batch_norm =
|
||||
make_shared<BatchNormInference>(ng_input, ng_scale, ng_offset, ng_mean, ng_variance, tf_epsilon)->output(0);
|
||||
convert_nchw_to_nhwc(is_nhwc, ng_batch_norm);
|
||||
convert_nchw_to_nhwc(is_nhwc, ng_batch_norm, ov::Rank(4));
|
||||
|
||||
// TODO: Why are there so many? Is it correct?
|
||||
OutputVector result = {ng_batch_norm, ng_mean, ng_variance, ng_mean, ng_variance};
|
||||
|
@ -19,7 +19,7 @@ OutputVector translate_max_pool_util(const NodeContext& node,
|
||||
size_t spatial_dims_num,
|
||||
const std::vector<int64_t>& tf_kernel_sizes,
|
||||
const std::vector<int64_t>& tf_strides) {
|
||||
TENSORFLOW_OP_VALIDATION(node, node.get_input_size() > 0, "MaxPool operation must have at least one input.");
|
||||
default_op_checks(node, 1, {"MaxPool2D", "MaxPool3D"});
|
||||
TENSORFLOW_OP_VALIDATION(node,
|
||||
spatial_dims_num == 2 || spatial_dims_num == 3,
|
||||
"Only MaxPool2D and MaxPool3D are supported.");
|
||||
@ -61,7 +61,7 @@ OutputVector translate_max_pool_util(const NodeContext& node,
|
||||
}
|
||||
|
||||
// prepare input to MaxPool
|
||||
convert_nhwc_to_nchw(is_nhwc, input);
|
||||
convert_nhwc_to_nchw(is_nhwc, input, ov::Rank(spatial_dims_num + 2));
|
||||
|
||||
auto max_pool_node = std::make_shared<ov::opset8::MaxPool>(input,
|
||||
strides,
|
||||
@ -72,7 +72,7 @@ OutputVector translate_max_pool_util(const NodeContext& node,
|
||||
ov::op::RoundingType::FLOOR,
|
||||
auto_pad);
|
||||
auto max_pool = max_pool_node->output(0);
|
||||
ov::frontend::tensorflow::convert_nchw_to_nhwc(is_nhwc, max_pool);
|
||||
ov::frontend::tensorflow::convert_nchw_to_nhwc(is_nhwc, max_pool, ov::Rank(spatial_dims_num + 2));
|
||||
ov::frontend::tensorflow::set_node_name(node.get_name(), max_pool.get_node_shared_ptr());
|
||||
return {max_pool};
|
||||
}
|
||||
|
@ -10,27 +10,37 @@ namespace ov {
|
||||
namespace frontend {
|
||||
namespace tensorflow {
|
||||
|
||||
void convert_nhwc_to_nchw(bool need_convert, ov::Output<ov::Node>& node) {
|
||||
void convert_nhwc_to_nchw(bool need_convert, ov::Output<ov::Node>& node, ov::Rank input_rank) {
|
||||
if (need_convert) {
|
||||
OPENVINO_ASSERT(node.get_partial_shape().rank().is_static(),
|
||||
"The input rank must be static to convert to the first channel format.");
|
||||
auto rank = node.get_partial_shape().rank().get_length();
|
||||
if (rank == 4) {
|
||||
if (input_rank.is_dynamic()) {
|
||||
// TODO: use ShapeOf sub-graph to generate permutation vector
|
||||
OPENVINO_ASSERT(node.get_partial_shape().rank().is_static(),
|
||||
"For conversion into the first channel format, the input rank must be static or determined "
|
||||
"based on the operation.");
|
||||
input_rank = node.get_partial_shape().rank();
|
||||
}
|
||||
auto rank_value = input_rank.get_length();
|
||||
if (rank_value == 4) {
|
||||
node = make_transpose(node, {0, 3, 1, 2});
|
||||
} else if (rank == 5) {
|
||||
} else if (rank_value == 5) {
|
||||
node = make_transpose(node, {0, 4, 1, 2, 3});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void convert_nchw_to_nhwc(bool need_convert, ov::Output<ov::Node>& node) {
|
||||
void convert_nchw_to_nhwc(bool need_convert, ov::Output<ov::Node>& node, ov::Rank input_rank) {
|
||||
if (need_convert) {
|
||||
OPENVINO_ASSERT(node.get_partial_shape().rank().is_static(),
|
||||
"The input rank must be static to convert to the last channel format.");
|
||||
auto rank = node.get_partial_shape().rank().get_length();
|
||||
if (rank == 4) {
|
||||
if (input_rank.is_dynamic()) {
|
||||
// TODO: use ShapeOf sub-graph to generate permutation vector
|
||||
OPENVINO_ASSERT(node.get_partial_shape().rank().is_static(),
|
||||
"For conversion into the last channel format, the input rank must be static or determined "
|
||||
"based on the operation.");
|
||||
input_rank = node.get_partial_shape().rank();
|
||||
}
|
||||
auto rank_value = input_rank.get_length();
|
||||
if (rank_value == 4) {
|
||||
node = make_transpose(node, {0, 2, 3, 1});
|
||||
} else if (rank == 5) {
|
||||
} else if (rank_value == 5) {
|
||||
node = make_transpose(node, {0, 2, 3, 4, 1});
|
||||
}
|
||||
}
|
||||
|
@ -39,9 +39,9 @@ void convert_nchw_to_hw(const std::vector<T>& src, std::vector<size_t>& dst) {
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
void convert_nhwc_to_nchw(bool need_convert, ov::Output<ov::Node>& node);
|
||||
void convert_nhwc_to_nchw(bool need_convert, ov::Output<ov::Node>& node, ov::Rank input_rank = ov::Rank::dynamic());
|
||||
|
||||
void convert_nchw_to_nhwc(bool need_convert, ov::Output<ov::Node>& node);
|
||||
void convert_nchw_to_nhwc(bool need_convert, ov::Output<ov::Node>& node, ov::Rank input_rank = ov::Rank::dynamic());
|
||||
|
||||
template <typename T>
|
||||
void convert_nhwc_to_hw(bool is_nhwc, const std::vector<T>& src, std::vector<size_t>& dst) {
|
||||
|
@ -33,7 +33,10 @@ ov::op::PadType ov::frontend::tensorflow::convert_tf_padding(const ov::frontend:
|
||||
"MaxPool",
|
||||
"MaxPoolV2",
|
||||
"MaxPool3D",
|
||||
"ExtractImagePatches"};
|
||||
"ExtractImagePatches",
|
||||
"DepthwiseConv2dNative",
|
||||
"AvgPool",
|
||||
"AvgPool3D"};
|
||||
auto op_type = node.get_op_type();
|
||||
|
||||
TENSORFLOW_OP_VALIDATION(node,
|
||||
@ -55,7 +58,8 @@ ov::op::PadType ov::frontend::tensorflow::convert_tf_padding(const ov::frontend:
|
||||
return ov::op::PadType::SAME_LOWER;
|
||||
}
|
||||
} else if (op_type == "Conv2D" || op_type == "Conv3D" || op_type == "MaxPool" || op_type == "MaxPoolV2" ||
|
||||
op_type == "MaxPool3D" || op_type == "ExtractImagePatches") {
|
||||
op_type == "MaxPool3D" || op_type == "ExtractImagePatches" || op_type == "DepthwiseConv2dNative" ||
|
||||
op_type == "AvgPool" || op_type == "AvgPool3D") {
|
||||
if (tf_padding == "SAME") {
|
||||
// According to the formulas for calculating auto_pad values of the
|
||||
// Conv layer in the Operation specification,
|
||||
@ -166,7 +170,7 @@ ov::OutputVector ov::frontend::tensorflow::translate_convolution_op(const ov::fr
|
||||
}
|
||||
|
||||
// prepare inputs to Convolution
|
||||
ov::frontend::tensorflow::convert_nhwc_to_nchw(is_nhwc, input);
|
||||
ov::frontend::tensorflow::convert_nhwc_to_nchw(is_nhwc, input, ov::Rank(spatial_dims_num + 2));
|
||||
ov::AxisVector permutation_2d = {3, 2, 0, 1};
|
||||
ov::AxisVector permutation_3d = {4, 3, 0, 1, 2};
|
||||
filter = ov::frontend::tensorflow::make_transpose(filter, spatial_dims_num == 2 ? permutation_2d : permutation_3d);
|
||||
@ -174,11 +178,23 @@ ov::OutputVector ov::frontend::tensorflow::translate_convolution_op(const ov::fr
|
||||
ov::Output<ov::Node> conv =
|
||||
std::make_shared<Convolution>(input, filter, strides, pads_begin, pads_end, dilations, auto_pad);
|
||||
|
||||
ov::frontend::tensorflow::convert_nchw_to_nhwc(is_nhwc, conv);
|
||||
ov::frontend::tensorflow::convert_nchw_to_nhwc(is_nhwc, conv, ov::Rank(spatial_dims_num + 2));
|
||||
ov::frontend::tensorflow::set_node_name(node.get_name(), conv.get_node_shared_ptr());
|
||||
return {conv};
|
||||
}
|
||||
|
||||
void ov::frontend::tensorflow::default_op_checks(const ov::frontend::tensorflow::NodeContext& node,
|
||||
int min_input_size,
|
||||
const std::vector<std::string>& supported_ops) {
|
||||
auto op_type = node.get_op_type();
|
||||
TENSORFLOW_OP_VALIDATION(node,
|
||||
std::find(supported_ops.begin(), supported_ops.end(), op_type) != supported_ops.end(),
|
||||
op_type + " is not supported for conversion.");
|
||||
TENSORFLOW_OP_VALIDATION(node,
|
||||
node.get_input_size() >= min_input_size,
|
||||
op_type + " must have at least " + std::to_string(min_input_size) + " inputs.");
|
||||
}
|
||||
|
||||
bool ov::frontend::tensorflow::is_conditional_edge(const std::string& input_tensor_name) {
|
||||
return input_tensor_name.length() > 0 && input_tensor_name[0] == '^';
|
||||
}
|
||||
|
@ -72,6 +72,9 @@ void fill_explicit_pads_vectors(const NodeContext& node,
|
||||
const std::vector<int64_t>& tf_explicit_paddings,
|
||||
ov::CoordinateDiff& pads_begin,
|
||||
ov::CoordinateDiff& pads_end);
|
||||
|
||||
void default_op_checks(const NodeContext& node, int min_input_size, const std::vector<std::string>& supported_ops);
|
||||
|
||||
} // namespace tensorflow
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
||||
|
Loading…
Reference in New Issue
Block a user