[TF Hub][TF FE] Fix 5D case for FusedBatchNorm (#19904)
Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent
df19699e3a
commit
d90ceb93d1
@ -146,7 +146,7 @@ void compute_fused_batch_norm_inference(const NodeContext& node,
|
||||
// retrieve attributes
|
||||
auto epsilon = node.get_attribute<float>("epsilon", 0.0001f);
|
||||
auto data_format = node.get_attribute<string>("data_format", "NHWC");
|
||||
bool is_nhwc = (data_format == "NHWC");
|
||||
bool is_nhwc = (data_format == "NHWC" || data_format == "NDHWC");
|
||||
|
||||
// create auxiliary Constant nodes for some attributes: epsilon and exponential_avg_factor
|
||||
auto eps_const = create_same_type_const_scalar<float>(x, epsilon);
|
||||
|
@ -35,7 +35,7 @@ class TestFusedBatchNorm(CommonTFLayerTest):
|
||||
# Create the graph and model
|
||||
with tf.compat.v1.Session() as sess:
|
||||
c_dim = x_shape[-1]
|
||||
if data_format == "NCHW":
|
||||
if data_format == "NCHW" or data_format == "NCDHW":
|
||||
c_dim = x_shape[1]
|
||||
x = tf.compat.v1.placeholder(tf.float32, x_shape, 'x')
|
||||
if empty_mean_variance:
|
||||
@ -92,6 +92,11 @@ class TestFusedBatchNorm(CommonTFLayerTest):
|
||||
fbn_version="v3"),
|
||||
dict(x_shape=[5, 10, 8, 2], epsilon=0.0002, exponential_avg_factor=0.2, data_format="NHWC",
|
||||
is_training=True, fbn_version="v3", empty_mean_variance=False),
|
||||
# 5D cases
|
||||
dict(x_shape=[5, 4, 3, 2, 3], epsilon=0.0005, exponential_avg_factor=0.0, data_format="NCDHW",
|
||||
is_training=False, fbn_version="v3"),
|
||||
dict(x_shape=[3, 4, 3, 3, 2], epsilon=0.0003, exponential_avg_factor=0.0, data_format="NDHWC",
|
||||
is_training=False, fbn_version="v3"),
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize("params", test_data_basic)
|
||||
|
@ -9,6 +9,7 @@ imagenet/efficientnet_v2_imagenet1k_b0/feature_vector,https://tfhub.dev/google/i
|
||||
imagenet/mobilenet_v1_100_224/classification,https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/5?tf-hub-format=compressed,skip,119718 - Accuracy issue
|
||||
magenta/arbitrary-image-stylization-v1-256,https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2?tf-hub-format=compressed
|
||||
small_bert/bert_en_uncased_L-4_H-256_A-4,https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-256_A-4/2?tf-hub-format=compressed,skip,119718 - Accuracy issue
|
||||
movinet/a5/base/kinetics-600/classification,https://tfhub.dev/tensorflow/movinet/a5/base/kinetics-600/classification/3?tf-hub-format=compressed
|
||||
# secure notebook models
|
||||
unet/industrial/class_1,https://tfhub.dev/nvidia/unet/industrial/class_1/1?tf-hub-format=compressed
|
||||
movenet/singlepose/thunder,https://tfhub.dev/google/movenet/singlepose/thunder/4?tf-hub-format=compressed
|
||||
|
Loading…
Reference in New Issue
Block a user