[TF FE] Fix FusedBatchNormV3 in case of mean and variance empty tensors (#15675)
* [TF FE] Fix FusedBatchNormV3 in case of mean and variance empty tensors Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Add to nightly --------- Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent
ee2e9d497c
commit
65b69fe8ca
@ -104,18 +104,28 @@ void compute_weighted_batch_mean_and_variance(const Output<Node>& x,
|
||||
// for weighted_variance it is similar
|
||||
// (1 - exponential_avg_factor) * variance + exponential_avg_factor * batch_variance,
|
||||
// where batch_variance is the variance of the current batch in x.
|
||||
|
||||
// compute weighted_batch_mean
|
||||
auto const_one = make_shared<Constant>(exp_avg_factor_const.get_element_type(), Shape{}, 1);
|
||||
auto one_minus_exp_avg_factor = make_shared<Subtract>(const_one, exp_avg_factor_const);
|
||||
auto bt_mean_by_exp_avg = make_shared<Multiply>(batch_mean, exp_avg_factor_const);
|
||||
weighted_batch_mean = make_shared<Multiply>(mean, one_minus_exp_avg_factor)->output(0);
|
||||
weighted_batch_mean = make_shared<Add>(bt_mean_by_exp_avg, weighted_batch_mean);
|
||||
|
||||
// compute weighted_batch_mean
|
||||
// no need to weight in case of empty tensor mean
|
||||
if (mean.get_partial_shape().is_static() && shape_size(mean.get_shape()) > 0) {
|
||||
auto bt_mean_by_exp_avg = make_shared<Multiply>(batch_mean, exp_avg_factor_const);
|
||||
weighted_batch_mean = make_shared<Multiply>(mean, one_minus_exp_avg_factor)->output(0);
|
||||
weighted_batch_mean = make_shared<Add>(bt_mean_by_exp_avg, weighted_batch_mean);
|
||||
} else {
|
||||
weighted_batch_mean = batch_mean;
|
||||
}
|
||||
|
||||
// compute weighted_batch_variance
|
||||
auto bt_variance_by_exp_avg = make_shared<Multiply>(batch_variance, exp_avg_factor_const);
|
||||
weighted_batch_variance = make_shared<Multiply>(variance, one_minus_exp_avg_factor)->output(0);
|
||||
weighted_batch_variance = make_shared<Add>(bt_variance_by_exp_avg, weighted_batch_variance)->output(0);
|
||||
// no need to weight in case of empty tensor variance
|
||||
if (variance.get_partial_shape().is_static() && shape_size(variance.get_shape()) > 0) {
|
||||
auto bt_variance_by_exp_avg = make_shared<Multiply>(batch_variance, exp_avg_factor_const);
|
||||
weighted_batch_variance = make_shared<Multiply>(variance, one_minus_exp_avg_factor)->output(0);
|
||||
weighted_batch_variance = make_shared<Add>(bt_variance_by_exp_avg, weighted_batch_variance)->output(0);
|
||||
} else {
|
||||
weighted_batch_variance = batch_variance;
|
||||
}
|
||||
}
|
||||
|
||||
void compute_fused_batch_norm_inference(const NodeContext& node,
|
||||
|
@ -25,7 +25,7 @@ class TestFusedBatchNorm(CommonTFLayerTest):
|
||||
return inputs_data
|
||||
|
||||
def create_fused_batch_norm_net(self, x_shape, epsilon, exponential_avg_factor, data_format, is_training,
|
||||
fbn_version):
|
||||
fbn_version, empty_mean_variance=False):
|
||||
fbn_dict = {
|
||||
"v1": tf.raw_ops.FusedBatchNorm,
|
||||
"v2": tf.raw_ops.FusedBatchNormV2,
|
||||
@ -38,8 +38,13 @@ class TestFusedBatchNorm(CommonTFLayerTest):
|
||||
if data_format == "NCHW":
|
||||
c_dim = x_shape[1]
|
||||
x = tf.compat.v1.placeholder(tf.float32, x_shape, 'x')
|
||||
mean = tf.compat.v1.placeholder(tf.float32, [c_dim], 'mean')
|
||||
variance = tf.compat.v1.placeholder(tf.float32, [c_dim], 'variance')
|
||||
if empty_mean_variance:
|
||||
mean = tf.constant([], dtype=tf.float32)
|
||||
variance = tf.constant([], dtype=tf.float32)
|
||||
else:
|
||||
mean = tf.compat.v1.placeholder(tf.float32, [c_dim], 'mean')
|
||||
variance = tf.compat.v1.placeholder(tf.float32, [c_dim], 'variance')
|
||||
|
||||
scale = tf.compat.v1.placeholder(tf.float32, [c_dim], 'scale')
|
||||
offset = tf.compat.v1.placeholder(tf.float32, [c_dim], 'offset')
|
||||
fbn_func = fbn_dict[fbn_version]
|
||||
@ -85,6 +90,8 @@ class TestFusedBatchNorm(CommonTFLayerTest):
|
||||
dict(x_shape=[5, 4, 3, 2], epsilon=0.0005, exponential_avg_factor=0.0, data_format="NCHW",
|
||||
is_training=False,
|
||||
fbn_version="v3"),
|
||||
dict(x_shape=[5, 10, 8, 2], epsilon=0.0002, exponential_avg_factor=0.2, data_format="NHWC",
|
||||
is_training=True, fbn_version="v3", empty_mean_variance=False),
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize("params", test_data_basic)
|
||||
|
Loading…
Reference in New Issue
Block a user