[TF FE] Fix FusedBatchNormV3 in case of mean and variance empty tensors (#15675)

* [TF FE] Fix FusedBatchNormV3 in case of mean and variance empty tensors

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Add to nightly

---------

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
Roman Kazantsev 2023-02-13 18:23:48 +04:00 committed by GitHub
parent ee2e9d497c
commit 65b69fe8ca
2 changed files with 28 additions and 11 deletions


@@ -104,18 +104,28 @@ void compute_weighted_batch_mean_and_variance(const Output<Node>& x,
     // for weighted_variance it is similar
     // (1 - exponential_avg_factor) * variance + exponential_avg_factor * batch_variance,
     // where batch_variance is the variance of the current batch in x.
-    // compute weighted_batch_mean
     auto const_one = make_shared<Constant>(exp_avg_factor_const.get_element_type(), Shape{}, 1);
     auto one_minus_exp_avg_factor = make_shared<Subtract>(const_one, exp_avg_factor_const);
-    auto bt_mean_by_exp_avg = make_shared<Multiply>(batch_mean, exp_avg_factor_const);
-    weighted_batch_mean = make_shared<Multiply>(mean, one_minus_exp_avg_factor)->output(0);
-    weighted_batch_mean = make_shared<Add>(bt_mean_by_exp_avg, weighted_batch_mean);
+    // compute weighted_batch_mean
+    // no need to weight in case of empty tensor mean
+    if (mean.get_partial_shape().is_static() && shape_size(mean.get_shape()) > 0) {
+        auto bt_mean_by_exp_avg = make_shared<Multiply>(batch_mean, exp_avg_factor_const);
+        weighted_batch_mean = make_shared<Multiply>(mean, one_minus_exp_avg_factor)->output(0);
+        weighted_batch_mean = make_shared<Add>(bt_mean_by_exp_avg, weighted_batch_mean);
+    } else {
+        weighted_batch_mean = batch_mean;
+    }
 
     // compute weighted_batch_variance
-    auto bt_variance_by_exp_avg = make_shared<Multiply>(batch_variance, exp_avg_factor_const);
-    weighted_batch_variance = make_shared<Multiply>(variance, one_minus_exp_avg_factor)->output(0);
-    weighted_batch_variance = make_shared<Add>(bt_variance_by_exp_avg, weighted_batch_variance)->output(0);
+    // no need to weight in case of empty tensor variance
+    if (variance.get_partial_shape().is_static() && shape_size(variance.get_shape()) > 0) {
+        auto bt_variance_by_exp_avg = make_shared<Multiply>(batch_variance, exp_avg_factor_const);
+        weighted_batch_variance = make_shared<Multiply>(variance, one_minus_exp_avg_factor)->output(0);
+        weighted_batch_variance = make_shared<Add>(bt_variance_by_exp_avg, weighted_batch_variance)->output(0);
+    } else {
+        weighted_batch_variance = batch_variance;
+    }
 }
 
 void compute_fused_batch_norm_inference(const NodeContext& node,
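For clarity, the converter change above keeps the documented update rule, weighted = (1 - exponential_avg_factor) * input_stat + exponential_avg_factor * batch_stat, and simply forwards the batch statistics when the incoming mean or variance tensor is empty. A minimal NumPy sketch of that behavior (a hypothetical illustration, not code from this commit):

import numpy as np

def weighted_stats(mean, variance, batch_mean, batch_variance, exp_avg_factor):
    # Blend only when mean/variance are non-empty, as in the converter above;
    # otherwise pass the freshly computed batch statistics through unchanged.
    if mean.size > 0:
        w_mean = (1.0 - exp_avg_factor) * mean + exp_avg_factor * batch_mean
    else:
        w_mean = batch_mean
    if variance.size > 0:
        w_var = (1.0 - exp_avg_factor) * variance + exp_avg_factor * batch_variance
    else:
        w_var = batch_variance
    return w_mean, w_var

# Empty mean/variance (the case this commit fixes) just yields the batch statistics.
batch_mean = np.array([0.1, 0.2], dtype=np.float32)
batch_variance = np.array([1.0, 1.5], dtype=np.float32)
empty = np.array([], dtype=np.float32)
print(weighted_stats(empty, empty, batch_mean, batch_variance, exp_avg_factor=0.2))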


@@ -25,7 +25,7 @@ class TestFusedBatchNorm(CommonTFLayerTest):
         return inputs_data
 
     def create_fused_batch_norm_net(self, x_shape, epsilon, exponential_avg_factor, data_format, is_training,
-                                    fbn_version):
+                                    fbn_version, empty_mean_variance=False):
         fbn_dict = {
             "v1": tf.raw_ops.FusedBatchNorm,
             "v2": tf.raw_ops.FusedBatchNormV2,
@@ -38,8 +38,13 @@ class TestFusedBatchNorm(CommonTFLayerTest):
             if data_format == "NCHW":
                 c_dim = x_shape[1]
             x = tf.compat.v1.placeholder(tf.float32, x_shape, 'x')
-            mean = tf.compat.v1.placeholder(tf.float32, [c_dim], 'mean')
-            variance = tf.compat.v1.placeholder(tf.float32, [c_dim], 'variance')
+            if empty_mean_variance:
+                mean = tf.constant([], dtype=tf.float32)
+                variance = tf.constant([], dtype=tf.float32)
+            else:
+                mean = tf.compat.v1.placeholder(tf.float32, [c_dim], 'mean')
+                variance = tf.compat.v1.placeholder(tf.float32, [c_dim], 'variance')
             scale = tf.compat.v1.placeholder(tf.float32, [c_dim], 'scale')
             offset = tf.compat.v1.placeholder(tf.float32, [c_dim], 'offset')
             fbn_func = fbn_dict[fbn_version]
@@ -85,6 +90,8 @@ class TestFusedBatchNorm(CommonTFLayerTest):
         dict(x_shape=[5, 4, 3, 2], epsilon=0.0005, exponential_avg_factor=0.0, data_format="NCHW",
              is_training=False,
              fbn_version="v3"),
+        dict(x_shape=[5, 10, 8, 2], epsilon=0.0002, exponential_avg_factor=0.2, data_format="NHWC",
+             is_training=True, fbn_version="v3", empty_mean_variance=False),
     ]
 
     @pytest.mark.parametrize("params", test_data_basic)
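The new empty_mean_variance parameter mirrors the TensorFlow pattern that previously broke conversion: FusedBatchNormV3 called in training mode with empty mean and variance inputs, so the running statistics come entirely from the current batch. A standalone sketch of that pattern (assuming TF 2.x eager execution; not part of the test file):

import numpy as np
import tensorflow as tf

# Hypothetical repro: empty mean/variance inputs are allowed when is_training=True,
# in which case the op computes the batch statistics itself.
x = tf.constant(np.random.rand(5, 10, 8, 2).astype(np.float32))
scale = tf.constant(np.ones(2, dtype=np.float32))
offset = tf.constant(np.zeros(2, dtype=np.float32))
empty = tf.constant([], dtype=tf.float32)

y, batch_mean, batch_variance, *_ = tf.raw_ops.FusedBatchNormV3(
    x=x, scale=scale, offset=offset, mean=empty, variance=empty,
    epsilon=0.0002, exponential_avg_factor=0.2,
    data_format="NHWC", is_training=True)
print(y.shape, batch_mean.numpy(), batch_variance.numpy())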