[TF Hub][TF FE] Fix 5D case for FusedBatchNorm (#19904)

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev 2023-09-18 16:51:33 +04:00 committed by GitHub
parent df19699e3a
commit d90ceb93d1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 8 additions and 2 deletions

View File

@ -146,7 +146,7 @@ void compute_fused_batch_norm_inference(const NodeContext& node,
// retrieve attributes
auto epsilon = node.get_attribute<float>("epsilon", 0.0001f);
auto data_format = node.get_attribute<string>("data_format", "NHWC");
bool is_nhwc = (data_format == "NHWC");
bool is_nhwc = (data_format == "NHWC" || data_format == "NDHWC");
// create auxiliary Constant nodes for some attributes: epsilon and exponential_avg_factor
auto eps_const = create_same_type_const_scalar<float>(x, epsilon);

View File

@ -35,7 +35,7 @@ class TestFusedBatchNorm(CommonTFLayerTest):
# Create the graph and model
with tf.compat.v1.Session() as sess:
c_dim = x_shape[-1]
if data_format == "NCHW":
if data_format == "NCHW" or data_format == "NCDHW":
c_dim = x_shape[1]
x = tf.compat.v1.placeholder(tf.float32, x_shape, 'x')
if empty_mean_variance:
@ -92,6 +92,11 @@ class TestFusedBatchNorm(CommonTFLayerTest):
fbn_version="v3"),
dict(x_shape=[5, 10, 8, 2], epsilon=0.0002, exponential_avg_factor=0.2, data_format="NHWC",
is_training=True, fbn_version="v3", empty_mean_variance=False),
# 5D cases
dict(x_shape=[5, 4, 3, 2, 3], epsilon=0.0005, exponential_avg_factor=0.0, data_format="NCDHW",
is_training=False, fbn_version="v3"),
dict(x_shape=[3, 4, 3, 3, 2], epsilon=0.0003, exponential_avg_factor=0.0, data_format="NDHWC",
is_training=False, fbn_version="v3"),
]
@pytest.mark.parametrize("params", test_data_basic)

View File

@ -9,6 +9,7 @@ imagenet/efficientnet_v2_imagenet1k_b0/feature_vector,https://tfhub.dev/google/i
imagenet/mobilenet_v1_100_224/classification,https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/5?tf-hub-format=compressed,skip,119718 - Accuracy issue
magenta/arbitrary-image-stylization-v1-256,https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2?tf-hub-format=compressed
small_bert/bert_en_uncased_L-4_H-256_A-4,https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-256_A-4/2?tf-hub-format=compressed,skip,119718 - Accuracy issue
movinet/a5/base/kinetics-600/classification,https://tfhub.dev/tensorflow/movinet/a5/base/kinetics-600/classification/3?tf-hub-format=compressed
# secure notebook models
unet/industrial/class_1,https://tfhub.dev/nvidia/unet/industrial/class_1/1?tf-hub-format=compressed
movenet/singlepose/thunder,https://tfhub.dev/google/movenet/singlepose/thunder/4?tf-hub-format=compressed