From d90ceb93d167bc6713559f8ec00778230dea2b14 Mon Sep 17 00:00:00 2001 From: Roman Kazantsev Date: Mon, 18 Sep 2023 16:51:33 +0400 Subject: [PATCH] [TF Hub][TF FE] Fix 5D case for FusedBatchNorm (#19904) Signed-off-by: Kazantsev, Roman --- .../tensorflow_common/src/op/fused_batch_norm.cpp | 2 +- .../layer_tests/tensorflow_tests/test_tf_FusedBatchNorm.py | 7 ++++++- tests/model_hub_tests/tf_hub_tests/precommit_models | 1 + 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/frontends/tensorflow_common/src/op/fused_batch_norm.cpp b/src/frontends/tensorflow_common/src/op/fused_batch_norm.cpp index 656213e1573..027b5fcdf0d 100644 --- a/src/frontends/tensorflow_common/src/op/fused_batch_norm.cpp +++ b/src/frontends/tensorflow_common/src/op/fused_batch_norm.cpp @@ -146,7 +146,7 @@ void compute_fused_batch_norm_inference(const NodeContext& node, // retrieve attributes auto epsilon = node.get_attribute("epsilon", 0.0001f); auto data_format = node.get_attribute("data_format", "NHWC"); - bool is_nhwc = (data_format == "NHWC"); + bool is_nhwc = (data_format == "NHWC" || data_format == "NDHWC"); // create auxiliary Constant nodes for some attributes: epsilon and exponential_avg_factor auto eps_const = create_same_type_const_scalar(x, epsilon); diff --git a/tests/layer_tests/tensorflow_tests/test_tf_FusedBatchNorm.py b/tests/layer_tests/tensorflow_tests/test_tf_FusedBatchNorm.py index 884a8a5bf44..94ae7dc6282 100644 --- a/tests/layer_tests/tensorflow_tests/test_tf_FusedBatchNorm.py +++ b/tests/layer_tests/tensorflow_tests/test_tf_FusedBatchNorm.py @@ -35,7 +35,7 @@ class TestFusedBatchNorm(CommonTFLayerTest): # Create the graph and model with tf.compat.v1.Session() as sess: c_dim = x_shape[-1] - if data_format == "NCHW": + if data_format == "NCHW" or data_format == "NCDHW": c_dim = x_shape[1] x = tf.compat.v1.placeholder(tf.float32, x_shape, 'x') if empty_mean_variance: @@ -92,6 +92,11 @@ class TestFusedBatchNorm(CommonTFLayerTest): fbn_version="v3"), dict(x_shape=[5, 10, 8, 2], epsilon=0.0002, exponential_avg_factor=0.2, data_format="NHWC", is_training=True, fbn_version="v3", empty_mean_variance=False), + # 5D cases + dict(x_shape=[5, 4, 3, 2, 3], epsilon=0.0005, exponential_avg_factor=0.0, data_format="NCDHW", + is_training=False, fbn_version="v3"), + dict(x_shape=[3, 4, 3, 3, 2], epsilon=0.0003, exponential_avg_factor=0.0, data_format="NDHWC", + is_training=False, fbn_version="v3"), ] @pytest.mark.parametrize("params", test_data_basic) diff --git a/tests/model_hub_tests/tf_hub_tests/precommit_models b/tests/model_hub_tests/tf_hub_tests/precommit_models index 90dd3f4c5e1..f9be334761d 100644 --- a/tests/model_hub_tests/tf_hub_tests/precommit_models +++ b/tests/model_hub_tests/tf_hub_tests/precommit_models @@ -9,6 +9,7 @@ imagenet/efficientnet_v2_imagenet1k_b0/feature_vector,https://tfhub.dev/google/i imagenet/mobilenet_v1_100_224/classification,https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/5?tf-hub-format=compressed,skip,119718 - Accuracy issue magenta/arbitrary-image-stylization-v1-256,https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2?tf-hub-format=compressed small_bert/bert_en_uncased_L-4_H-256_A-4,https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-256_A-4/2?tf-hub-format=compressed,skip,119718 - Accuracy issue +movinet/a5/base/kinetics-600/classification,https://tfhub.dev/tensorflow/movinet/a5/base/kinetics-600/classification/3?tf-hub-format=compressed # secure notebook models unet/industrial/class_1,https://tfhub.dev/nvidia/unet/industrial/class_1/1?tf-hub-format=compressed movenet/singlepose/thunder,https://tfhub.dev/google/movenet/singlepose/thunder/4?tf-hub-format=compressed