diff --git a/src/frontends/tensorflow/src/frontend.cpp b/src/frontends/tensorflow/src/frontend.cpp
index d3fe7faf6d3..60a0795b3de 100644
--- a/src/frontends/tensorflow/src/frontend.cpp
+++ b/src/frontends/tensorflow/src/frontend.cpp
@@ -192,7 +192,9 @@ void FrontEnd::translate_graph(const ov::frontend::InputModel::Ptr& model,
                 results.push_back(result);
             } else {
                 auto param = std::dynamic_pointer_cast<ov::opset8::Parameter>(output.get_node_shared_ptr());
-                if (param && operation_decoder->get_op_type() != "Identity") {
+                // avoid duplicating Parameter nodes if they are already in the Parameters vector
+                if (param && operation_decoder->get_op_type() != "Identity" &&
+                    std::find(params.begin(), params.end(), param) == params.end()) {
                     params.push_back(param);
                 }
                 ng_op_map[operation_name].push_back(output);
diff --git a/src/frontends/tensorflow/src/op/avg_pool.cpp b/src/frontends/tensorflow/src/op/avg_pool.cpp
index d2a5618c9e7..7fe55a6884f 100644
--- a/src/frontends/tensorflow/src/op/avg_pool.cpp
+++ b/src/frontends/tensorflow/src/op/avg_pool.cpp
@@ -14,13 +14,23 @@ namespace tensorflow {
 namespace op {
 
 OutputVector translate_avg_pool_op(const NodeContext& node) {
-    Output<Node> ng_input = node.get_input(0);
+    default_op_checks(node, 1, {"AvgPool", "AvgPool3D"});
+    auto op_type = node.get_op_type();
+    auto input = node.get_input(0);
+    auto spatial_dim = (op_type == "AvgPool") ? 2 : 3;
+
+    // retrieve attributes for AvgPool operation
     auto tf_strides = node.get_attribute<std::vector<int64_t>>("strides");
     auto tf_ksize = node.get_attribute<std::vector<int64_t>>("ksize");
     auto tf_padding_type = node.get_attribute<std::string>("padding");
-    auto tf_data_format = node.get_attribute<std::string>("data_format");
+    ov::op::PadType auto_pad = convert_tf_padding(node, tf_padding_type);
+    TENSORFLOW_OP_VALIDATION(node,
+                             auto_pad == ov::op::PadType::VALID || auto_pad == ov::op::PadType::SAME_UPPER,
+                             "AvgPool and AvgPool3D support only VALID or SAME_UPPER padding mode.");
+    // retrieve optional attribute
+    auto tf_data_format = node.get_attribute<std::string>("data_format", (spatial_dim == 2) ? "NHWC" : "NDHWC");
"NHWC" : "NDHWC"); TENSORFLOW_OP_VALIDATION( node, tf_data_format == "NHWC" || tf_data_format == "NCHW" || tf_data_format == "NDHWC" || tf_data_format == "NCDHW", @@ -28,47 +38,26 @@ OutputVector translate_avg_pool_op(const NodeContext& node) { bool is_nhwc = (tf_data_format == "NHWC") || (tf_data_format == "NDHWC"); - int N = 2; - if (node.get_op_type() == "AvgPool3D") { - N = 3; - } + // prepare inputs for OpenVINO AvgPool + Strides strides(spatial_dim); + Shape kernel_shape(spatial_dim); + Shape dilations(spatial_dim, 1); + convert_nhwc_to_hw(is_nhwc, tf_strides, strides); + convert_nhwc_to_hw(is_nhwc, tf_ksize, kernel_shape); + convert_nhwc_to_nchw(is_nhwc, input, ov::Rank(spatial_dim + 2)); - Strides ng_strides(N); - Shape ng_image_shape(N); - Shape ng_kernel_shape(N); - convert_nhwc_to_hw(is_nhwc, tf_strides, ng_strides); - convert_nhwc_to_hw(is_nhwc, ng_input.get_shape(), ng_image_shape); - convert_nhwc_to_hw(is_nhwc, tf_ksize, ng_kernel_shape); - convert_nhwc_to_nchw(is_nhwc, ng_input); - - CoordinateDiff padding_below; - CoordinateDiff padding_above; - Shape ng_dilations(N, 1); - make_padding(tf_padding_type, - ng_image_shape, - ng_kernel_shape, - ng_strides, - ng_dilations, - padding_below, - padding_above); - - // TODO: remove this once OV supports negative padding - // (CoordinateDiff) for AvgPool - Shape ng_padding_below(padding_below.begin(), padding_below.end()); - Shape ng_padding_above(padding_above.begin(), padding_above.end()); - - auto res_node = make_shared(ng_input, - ng_strides, - ng_padding_below, - ng_padding_above, - ng_kernel_shape, + auto avg_pool = make_shared(input, + strides, + Shape({}), + Shape({}), + kernel_shape, true, - ov::op::RoundingType::FLOOR); - auto res = res_node->output(0); - - convert_nchw_to_nhwc(is_nhwc, res); - set_node_name(node.get_name(), res.get_node_shared_ptr()); - return {res}; + ov::op::RoundingType::FLOOR, + auto_pad); + auto avg_pool_output = avg_pool->output(0); + convert_nchw_to_nhwc(is_nhwc, avg_pool_output, ov::Rank(spatial_dim + 2)); + set_node_name(node.get_name(), avg_pool); + return {avg_pool}; } } // namespace op } // namespace tensorflow diff --git a/src/frontends/tensorflow/src/op/bias_add.cpp b/src/frontends/tensorflow/src/op/bias_add.cpp index be7f457dcfe..2306cfadc19 100644 --- a/src/frontends/tensorflow/src/op/bias_add.cpp +++ b/src/frontends/tensorflow/src/op/bias_add.cpp @@ -14,28 +14,28 @@ namespace tensorflow { namespace op { OutputVector translate_bias_add_op(const NodeContext& node) { + default_op_checks(node, 2, {"BiasAdd"}); auto value = node.get_input(0); auto bias = node.get_input(1); + // retrieve optional attributes std::string data_format = node.get_attribute("data_format", "NHWC"); - TENSORFLOW_OP_VALIDATION(node, data_format == "NHWC" || data_format == "NCHW", "BiasAdd data format is neither NHWC nor NCHW."); - auto value_shape = value.get_partial_shape(); - auto bias_shape = bias.get_partial_shape(); - TENSORFLOW_OP_VALIDATION(node, bias_shape.size() == 1, "Bias input of BiasAdd must have one dimension"); - Output bias_reshaped = bias; // in case NCHW layout bias must be reshaped to have a shape (1, C, 1, ...) 
     // for further correct use of Add operation
     if (data_format == "NCHW") {
+        // TODO: add support for dynamic rank in case NCHW layout
+        auto value_shape = value.get_partial_shape();
         TENSORFLOW_OP_VALIDATION(node,
                                  value_shape.rank().is_static(),
                                  "Value of dynamic rank for BiasAdd in NCHW layout is not supported.");
         auto value_rank = value_shape.rank().get_length();
+
         std::vector<int64_t> axes_unsqueeze;
         for (size_t dim_ind = 0; dim_ind < value_rank; ++dim_ind) {
             if (dim_ind != 1) {
diff --git a/src/frontends/tensorflow/src/op/conv_2d_backprop.cpp b/src/frontends/tensorflow/src/op/conv_2d_backprop.cpp
index 0f5c2cc1b26..9b8dde95fd3 100644
--- a/src/frontends/tensorflow/src/op/conv_2d_backprop.cpp
+++ b/src/frontends/tensorflow/src/op/conv_2d_backprop.cpp
@@ -16,7 +16,7 @@ namespace tensorflow {
 namespace op {
 
 OutputVector translate_conv_2d_backprop_input_op(const NodeContext& node) {
-    TENSORFLOW_OP_VALIDATION(node, node.get_input_size() >= 3, "Conv2DBackpropInput must have at least three inputs.");
+    default_op_checks(node, 3, {"Conv2DBackpropInput"});
     auto input_sizes = node.get_input(0);
     auto filter = node.get_input(1);
     auto out_backprop = node.get_input(2);
@@ -73,7 +73,7 @@ OutputVector translate_conv_2d_backprop_input_op(const NodeContext& node) {
 
     // prepare inputs to ConvolutionBackpropData
     filter = make_transpose(filter, {3, 2, 0, 1});
-    convert_nhwc_to_nchw(is_nhwc, out_backprop);
+    convert_nhwc_to_nchw(is_nhwc, out_backprop, ov::Rank(4));
 
     // initially think that output shape defined for NCHW layout
     auto ss_begin = make_shared<Constant>(element::i64, Shape{1}, std::vector<int64_t>{2});
@@ -104,7 +104,7 @@ OutputVector translate_conv_2d_backprop_input_op(const NodeContext& node) {
 
     // insert Transpose only if original Conv2DBackpropInput is in NHWC layout
     auto conv_backprop_output = conv_backprop->output(0);
-    convert_nchw_to_nhwc(is_nhwc, conv_backprop_output);
+    convert_nchw_to_nhwc(is_nhwc, conv_backprop_output, ov::Rank(4));
 
     // move the original name to new ConvolutionBackpropData if original layout is NCHW
     // move the original name to Transpose if original layout is NHWC
diff --git a/src/frontends/tensorflow/src/op/conv_3d_backprop.cpp b/src/frontends/tensorflow/src/op/conv_3d_backprop.cpp
index 03253ad8e5e..65b9fd7c1a6 100644
--- a/src/frontends/tensorflow/src/op/conv_3d_backprop.cpp
+++ b/src/frontends/tensorflow/src/op/conv_3d_backprop.cpp
@@ -14,7 +14,7 @@ namespace tensorflow {
 namespace op {
 
 OutputVector translate_conv_3d_backprop_input_v2_op(const NodeContext& node) {
-    TENSORFLOW_OP_VALIDATION(node, node.get_input_size() >= 3, "Conv3DBackpropInput must have at least three inputs.");
+    default_op_checks(node, 3, {"Conv3DBackpropInput"});
     auto input_sizes = node.get_input(0);
     auto filter = node.get_input(1);
     auto out_backprop = node.get_input(2);
@@ -75,7 +75,7 @@ OutputVector translate_conv_3d_backprop_input_v2_op(const NodeContext& node) {
 
     // prepare inputs to ConvolutionBackpropData
     filter = make_transpose(filter, {4, 3, 0, 1, 2});
-    convert_nhwc_to_nchw(is_nhwc, out_backprop);
+    convert_nhwc_to_nchw(is_nhwc, out_backprop, ov::Rank(5));
 
     // initially think that output shape defined for NCDHW layout
     auto ss_begin = make_shared<Constant>(element::i64, Shape{1}, std::vector<int64_t>{2});
@@ -106,7 +106,7 @@ OutputVector translate_conv_3d_backprop_input_v2_op(const NodeContext& node) {
 
     // insert Transpose only if original Conv3DBackpropInput is in NDHWC layout
     auto conv_backprop_output = conv_backprop->output(0);
-    convert_nchw_to_nhwc(is_nhwc, conv_backprop_output);
+    convert_nchw_to_nhwc(is_nhwc, conv_backprop_output, ov::Rank(5));
 
     // move the original name to new ConvolutionBackpropData if original layout is NCHW
     // move the original name to Transpose if original layout is NHWC
diff --git a/src/frontends/tensorflow/src/op/depthwise_conv_2d.cpp b/src/frontends/tensorflow/src/op/depthwise_conv_2d.cpp
index 7844fad2e94..f9eb976b7d9 100644
--- a/src/frontends/tensorflow/src/op/depthwise_conv_2d.cpp
+++ b/src/frontends/tensorflow/src/op/depthwise_conv_2d.cpp
@@ -14,66 +14,55 @@ namespace tensorflow {
 namespace op {
 
 OutputVector translate_depthwise_conv_2d_native_op(const NodeContext& node) {
-    auto ng_input = node.get_input(0);
-    auto ng_filter = node.get_input(1);
+    default_op_checks(node, 2, {"DepthwiseConv2dNative"});
+    auto input = node.get_input(0);
+    auto filter = node.get_input(1);
 
+    // retrieve mandatory attributes for DepthwiseConv2dNative
     auto tf_strides = node.get_attribute<std::vector<int64_t>>("strides");
-    auto tf_dilations = node.get_attribute<std::vector<int64_t>>("dilations");
     auto tf_padding_type = node.get_attribute<std::string>("padding");
-    auto tf_data_format = node.get_attribute<std::string>("data_format");
+    ov::op::PadType auto_pad = convert_tf_padding(node, tf_padding_type);
+
+    // retrieve optional attributes
+    auto tf_data_format = node.get_attribute<std::string>("data_format", "NHWC");
+    auto tf_dilations = node.get_attribute<std::vector<int64_t>>("dilations", {1, 1, 1, 1});
 
+    TENSORFLOW_OP_VALIDATION(node,
+                             auto_pad != ov::op::PadType::EXPLICIT,
+                             "Explicit padding for DepthwiseConv2dNative is not supported.");
     TENSORFLOW_OP_VALIDATION(node,
                              tf_data_format == "NHWC" || tf_data_format == "NCHW",
-                             "DepthwiseConv2D data format is neither NHWC nor NCHW");
+                             "DepthwiseConv2dNative data format is neither NHWC nor NCHW");
 
     bool is_nhwc = (tf_data_format == "NHWC");
 
-    Strides ng_strides(2);
-    Strides ng_dilations(2);
-
-    Shape ng_image_shape(2);
-    Shape ng_kernel_shape(2);
-    convert_nhwc_to_hw(is_nhwc, ng_input.get_shape(), ng_image_shape);
-    convert_nhwc_to_hw(is_nhwc, tf_strides, ng_strides);
-    convert_nhwc_to_hw(is_nhwc, tf_dilations, ng_dilations);
-    convert_nhwc_to_nchw(is_nhwc, ng_input);
+    Strides strides(2);
+    Strides dilations(2);
+    convert_nhwc_to_hw(is_nhwc, tf_strides, strides);
+    convert_nhwc_to_hw(is_nhwc, tf_dilations, dilations);
+
+    convert_nhwc_to_nchw(is_nhwc, input, ov::Rank(4));
 
-    auto& ng_filter_shape = ng_filter.get_shape();
-    ng_kernel_shape[0] = ng_filter_shape[0];
-    ng_kernel_shape[1] = ng_filter_shape[1];
-
-    CoordinateDiff ng_padding_below;
-    CoordinateDiff ng_padding_above;
-    make_padding(tf_padding_type,
-                 ng_image_shape,
-                 ng_kernel_shape,
-                 ng_strides,
-                 ng_dilations,
-                 ng_padding_below,
-                 ng_padding_above);
-
-    // H W I M -> H W I 1 M
-    auto filter_shape = make_shared<Constant>(
-        element::u64,
-        Shape{5},
-        ov::Shape{ng_filter_shape[0], ng_filter_shape[1], ng_filter_shape[2], 1, ng_filter_shape[3]});
-    auto reshaped_filter = make_shared<Reshape>(ng_filter, filter_shape, false);
-
-    // H W I 1 M -> I M 1 H W
-    auto order = make_shared<Constant>(element::i64, Shape{5}, vector<int64_t>{2, 4, 3, 0, 1});
-    auto transposed_filter = make_shared<Transpose>(reshaped_filter, order);
-
-    auto ng_conv_node = make_shared<GroupConvolution>(ng_input,
-                                                      transposed_filter,
-                                                      ng_strides,
-                                                      ng_padding_below,
-                                                      ng_padding_above,
-                                                      ng_dilations);
-    auto ng_conv = ng_conv_node->output(0);
-
-    convert_nchw_to_nhwc(is_nhwc, ng_conv);
-    set_node_name(node.get_name(), ng_conv.get_node_shared_ptr());
-    return {ng_conv};
+    // prepare filter to have a number of groups equal to CIN
+    auto unsqueeze_filter =
+        make_shared<Unsqueeze>(filter, make_shared<Constant>(element::i64, Shape{1}, std::vector<int64_t>{3}));
+    auto transposed_filter =
+        make_shared<Transpose>(unsqueeze_filter,
+                               make_shared<Constant>(element::i64, Shape{5}, std::vector<int64_t>{2, 4, 3, 0, 1}));
+
+    ov::Output<ov::Node> group_conv = make_shared<GroupConvolution>(input,
+                                                                    transposed_filter,
+                                                                    strides,
+                                                                    CoordinateDiff({}),
+                                                                    CoordinateDiff({}),
+                                                                    dilations,
+                                                                    auto_pad);
+    ov::frontend::tensorflow::convert_nchw_to_nhwc(is_nhwc, group_conv, ov::Rank(4));
+    ov::frontend::tensorflow::set_node_name(node.get_name(), group_conv.get_node_shared_ptr());
+    return {group_conv};
 }
 }  // namespace op
 }  // namespace tensorflow
diff --git a/src/frontends/tensorflow/src/op/fused_batch_norm.cpp b/src/frontends/tensorflow/src/op/fused_batch_norm.cpp
index b6aa146f10d..64ef6179256 100644
--- a/src/frontends/tensorflow/src/op/fused_batch_norm.cpp
+++ b/src/frontends/tensorflow/src/op/fused_batch_norm.cpp
@@ -15,6 +15,7 @@ namespace tensorflow {
 namespace op {
 
 OutputVector translate_fused_batch_norm_op(const NodeContext& node) {
+    default_op_checks(node, 5, {"FusedBatchNorm", "FusedBatchNormV2", "FusedBatchNormV3"});
     auto ng_input = node.get_input(0);
     auto ng_scale = node.get_input(1);
     auto ng_offset = node.get_input(2);
@@ -35,11 +36,11 @@ OutputVector translate_fused_batch_norm_op(const NodeContext& node) {
 
     OPENVINO_DEBUG << "epsilon: " << tf_epsilon;
 
-    convert_nhwc_to_nchw(is_nhwc, ng_input);
+    convert_nhwc_to_nchw(is_nhwc, ng_input, ov::Rank(4));
 
     auto ng_batch_norm =
         make_shared<BatchNormInference>(ng_input, ng_scale, ng_offset, ng_mean, ng_variance, tf_epsilon)->output(0);
-    convert_nchw_to_nhwc(is_nhwc, ng_batch_norm);
+    convert_nchw_to_nhwc(is_nhwc, ng_batch_norm, ov::Rank(4));
 
     // TODO: Why are there so many? Is it correct?
     OutputVector result = {ng_batch_norm, ng_mean, ng_variance, ng_mean, ng_variance};
diff --git a/src/frontends/tensorflow/src/op/max_pool.cpp b/src/frontends/tensorflow/src/op/max_pool.cpp
index eb45700155b..0f8a1042a78 100644
--- a/src/frontends/tensorflow/src/op/max_pool.cpp
+++ b/src/frontends/tensorflow/src/op/max_pool.cpp
@@ -19,7 +19,7 @@ OutputVector translate_max_pool_util(const NodeContext& node,
                                      size_t spatial_dims_num,
                                      const std::vector<int64_t>& tf_kernel_sizes,
                                      const std::vector<int64_t>& tf_strides) {
-    TENSORFLOW_OP_VALIDATION(node, node.get_input_size() > 0, "MaxPool operation must have at least one input.");
+    default_op_checks(node, 1, {"MaxPool", "MaxPoolV2", "MaxPool3D"});
     TENSORFLOW_OP_VALIDATION(node,
                              spatial_dims_num == 2 || spatial_dims_num == 3,
                              "Only MaxPool2D and MaxPool3D are supported.");
@@ -61,7 +61,7 @@ OutputVector translate_max_pool_util(const NodeContext& node,
     }
 
     // prepare input to MaxPool
-    convert_nhwc_to_nchw(is_nhwc, input);
+    convert_nhwc_to_nchw(is_nhwc, input, ov::Rank(spatial_dims_num + 2));
 
     auto max_pool_node = std::make_shared<MaxPool>(input,
                                                    strides,
@@ -72,7 +72,7 @@ OutputVector translate_max_pool_util(const NodeContext& node,
                                                    ov::op::RoundingType::FLOOR,
                                                    auto_pad);
     auto max_pool = max_pool_node->output(0);
-    ov::frontend::tensorflow::convert_nchw_to_nhwc(is_nhwc, max_pool);
+    ov::frontend::tensorflow::convert_nchw_to_nhwc(is_nhwc, max_pool, ov::Rank(spatial_dims_num + 2));
     ov::frontend::tensorflow::set_node_name(node.get_name(), max_pool.get_node_shared_ptr());
     return {max_pool};
 }
diff --git a/src/frontends/tensorflow/src/openvino_conversions.cpp b/src/frontends/tensorflow/src/openvino_conversions.cpp
index 068d671ab16..4811dd331d2 100644
--- a/src/frontends/tensorflow/src/openvino_conversions.cpp
+++ b/src/frontends/tensorflow/src/openvino_conversions.cpp
@@ -10,27 +10,37 @@ namespace ov {
 namespace frontend {
 namespace tensorflow {
 
-void convert_nhwc_to_nchw(bool need_convert, ov::Output<ov::Node>& node) {
+void convert_nhwc_to_nchw(bool need_convert, ov::Output<ov::Node>& node, ov::Rank input_rank) {
     if (need_convert) {
-        OPENVINO_ASSERT(node.get_partial_shape().rank().is_static(),
-                        "The input rank must be static to convert to the first channel format.");
-        auto rank = node.get_partial_shape().rank().get_length();
-        if (rank == 4) {
+        if (input_rank.is_dynamic()) {
+            // TODO: use ShapeOf sub-graph to generate permutation vector
+            OPENVINO_ASSERT(node.get_partial_shape().rank().is_static(),
+                            "For conversion into the first channel format, the input rank must be static or determined "
+                            "based on the operation.");
+            input_rank = node.get_partial_shape().rank();
+        }
+        auto rank_value = input_rank.get_length();
+        if (rank_value == 4) {
             node = make_transpose(node, {0, 3, 1, 2});
-        } else if (rank == 5) {
+        } else if (rank_value == 5) {
             node = make_transpose(node, {0, 4, 1, 2, 3});
         }
     }
 }
 
-void convert_nchw_to_nhwc(bool need_convert, ov::Output<ov::Node>& node) {
+void convert_nchw_to_nhwc(bool need_convert, ov::Output<ov::Node>& node, ov::Rank input_rank) {
     if (need_convert) {
-        OPENVINO_ASSERT(node.get_partial_shape().rank().is_static(),
-                        "The input rank must be static to convert to the last channel format.");
-        auto rank = node.get_partial_shape().rank().get_length();
-        if (rank == 4) {
+        if (input_rank.is_dynamic()) {
+            // TODO: use ShapeOf sub-graph to generate permutation vector
+            OPENVINO_ASSERT(node.get_partial_shape().rank().is_static(),
+                            "For conversion into the last channel format, the input rank must be static or determined "
+                            "based on the operation.");
+            input_rank = node.get_partial_shape().rank();
+        }
+        auto rank_value = input_rank.get_length();
+        if (rank_value == 4) {
             node = make_transpose(node, {0, 2, 3, 1});
-        } else if (rank == 5) {
+        } else if (rank_value == 5) {
             node = make_transpose(node, {0, 2, 3, 4, 1});
         }
     }
diff --git a/src/frontends/tensorflow/src/openvino_conversions.hpp b/src/frontends/tensorflow/src/openvino_conversions.hpp
index 94fec13033e..010effffb81 100644
--- a/src/frontends/tensorflow/src/openvino_conversions.hpp
+++ b/src/frontends/tensorflow/src/openvino_conversions.hpp
@@ -39,9 +39,9 @@ void convert_nchw_to_hw(const std::vector<T>& src, std::vector<size_t>& dst) {
     }
 }
 }  // namespace detail
 
-void convert_nhwc_to_nchw(bool need_convert, ov::Output<ov::Node>& node);
+void convert_nhwc_to_nchw(bool need_convert, ov::Output<ov::Node>& node, ov::Rank input_rank = ov::Rank::dynamic());
 
-void convert_nchw_to_nhwc(bool need_convert, ov::Output<ov::Node>& node);
+void convert_nchw_to_nhwc(bool need_convert, ov::Output<ov::Node>& node, ov::Rank input_rank = ov::Rank::dynamic());
 
 template <typename T>
 void convert_nhwc_to_hw(bool is_nhwc, const std::vector<T>& src, std::vector<size_t>& dst) {
diff --git a/src/frontends/tensorflow/src/utils.cpp b/src/frontends/tensorflow/src/utils.cpp
index f567a18e06c..79818cfde91 100644
--- a/src/frontends/tensorflow/src/utils.cpp
+++ b/src/frontends/tensorflow/src/utils.cpp
@@ -33,7 +33,10 @@ ov::op::PadType ov::frontend::tensorflow::convert_tf_padding(const ov::frontend::tensorflow::NodeContext& node,
                                            "MaxPool",
                                            "MaxPoolV2",
                                            "MaxPool3D",
-                                           "ExtractImagePatches"};
+                                           "ExtractImagePatches",
+                                           "DepthwiseConv2dNative",
+                                           "AvgPool",
+                                           "AvgPool3D"};
 
     auto op_type = node.get_op_type();
     TENSORFLOW_OP_VALIDATION(node,
@@ -55,7 +58,8 @@ ov::op::PadType ov::frontend::tensorflow::convert_tf_padding(const ov::frontend::tensorflow::NodeContext& node,
             return ov::op::PadType::SAME_LOWER;
         }
     } else if (op_type == "Conv2D" || op_type == "Conv3D" || op_type == "MaxPool" || op_type == "MaxPoolV2" ||
-               op_type == "MaxPool3D" || op_type == "ExtractImagePatches") {
+               op_type == "MaxPool3D" || op_type == "ExtractImagePatches" || op_type == "DepthwiseConv2dNative" ||
"AvgPool" || op_type == "AvgPool3D") { if (tf_padding == "SAME") { // According to the formulas for calculating auto_pad values of the // Conv layer in the Operation specification, @@ -166,7 +170,7 @@ ov::OutputVector ov::frontend::tensorflow::translate_convolution_op(const ov::fr } // prepare inputs to Convolution - ov::frontend::tensorflow::convert_nhwc_to_nchw(is_nhwc, input); + ov::frontend::tensorflow::convert_nhwc_to_nchw(is_nhwc, input, ov::Rank(spatial_dims_num + 2)); ov::AxisVector permutation_2d = {3, 2, 0, 1}; ov::AxisVector permutation_3d = {4, 3, 0, 1, 2}; filter = ov::frontend::tensorflow::make_transpose(filter, spatial_dims_num == 2 ? permutation_2d : permutation_3d); @@ -174,11 +178,23 @@ ov::OutputVector ov::frontend::tensorflow::translate_convolution_op(const ov::fr ov::Output conv = std::make_shared(input, filter, strides, pads_begin, pads_end, dilations, auto_pad); - ov::frontend::tensorflow::convert_nchw_to_nhwc(is_nhwc, conv); + ov::frontend::tensorflow::convert_nchw_to_nhwc(is_nhwc, conv, ov::Rank(spatial_dims_num + 2)); ov::frontend::tensorflow::set_node_name(node.get_name(), conv.get_node_shared_ptr()); return {conv}; } +void ov::frontend::tensorflow::default_op_checks(const ov::frontend::tensorflow::NodeContext& node, + int min_input_size, + const std::vector& supported_ops) { + auto op_type = node.get_op_type(); + TENSORFLOW_OP_VALIDATION(node, + std::find(supported_ops.begin(), supported_ops.end(), op_type) != supported_ops.end(), + op_type + " is not supported for conversion."); + TENSORFLOW_OP_VALIDATION(node, + node.get_input_size() >= min_input_size, + op_type + " must have at least " + std::to_string(min_input_size) + " inputs."); +} + bool ov::frontend::tensorflow::is_conditional_edge(const std::string& input_tensor_name) { return input_tensor_name.length() > 0 && input_tensor_name[0] == '^'; } diff --git a/src/frontends/tensorflow/src/utils.hpp b/src/frontends/tensorflow/src/utils.hpp index 81170f8adfc..5f93e67ce56 100644 --- a/src/frontends/tensorflow/src/utils.hpp +++ b/src/frontends/tensorflow/src/utils.hpp @@ -72,6 +72,9 @@ void fill_explicit_pads_vectors(const NodeContext& node, const std::vector& tf_explicit_paddings, ov::CoordinateDiff& pads_begin, ov::CoordinateDiff& pads_end); + +void default_op_checks(const NodeContext& node, int min_input_size, const std::vector& supported_ops); + } // namespace tensorflow } // namespace frontend } // namespace ov