[TF FE] Optimize and correct FusedBatchNorm translator - avoid Transposes (#14135)
* [TF FE] Optimize FusedBatchNorm translator - avoid Transposes Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Fix FusedBatchNorm: correct compute for batch_mean and batch_variance * Add layer tests and fix based on them * Mark failing test cases and bind to issue tickets * Work around tests for FusedBatchNorm in inference mode * Use separate fictitious constants for reserved outputs Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent
a0a6e1c141
commit
cc219d085e
@ -3,55 +3,265 @@
|
||||
//
|
||||
|
||||
#include "op_table.hpp"
|
||||
#include "openvino/opsets/opset8.hpp"
|
||||
#include "openvino/util/log.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ov;
|
||||
using namespace ov::opset8;
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace tensorflow {
|
||||
namespace op {
|
||||
namespace {
|
||||
void generate_axes_range_except_c(const Output<Node>& x_rank, bool is_nhwc, Output<Node>& axes_no_c) {
|
||||
auto const_one = make_shared<Constant>(element::i32, Shape{}, 1);
|
||||
if (is_nhwc) {
|
||||
auto const_zero = make_shared<Constant>(element::i32, Shape{}, 0);
|
||||
auto rank_minus_one = make_shared<Subtract>(x_rank, const_one);
|
||||
axes_no_c = make_shared<Range>(const_zero, rank_minus_one, const_one, element::i32)->output(0);
|
||||
} else {
|
||||
auto const_zero = make_shared<Constant>(element::i32, Shape{1}, 0);
|
||||
auto const_two = make_shared<Constant>(element::i32, Shape{}, 2);
|
||||
// in NCHW layout case
|
||||
axes_no_c = make_shared<Range>(const_two, x_rank, const_one, element::i32)->output(0);
|
||||
// add batch dimension as well
|
||||
axes_no_c = make_shared<Concat>(OutputVector{const_zero, axes_no_c}, 0);
|
||||
}
|
||||
}
|
||||
|
||||
OutputVector translate_fused_batch_norm_op(const NodeContext& node) {
|
||||
void adjust_coeff(const Output<Node>& x_rank,
|
||||
element::Type x_type,
|
||||
const Output<Node>& coeff,
|
||||
Output<Node>& adjusted_coeff,
|
||||
bool is_nhwc) {
|
||||
// adjust types of the normalizing coefficients
|
||||
// they can vary for FusedBatchNormV2 and FusedBatchNormV3 operations
|
||||
adjusted_coeff = make_shared<Convert>(coeff, x_type)->output(0);
|
||||
|
||||
if (is_nhwc) {
|
||||
return;
|
||||
}
|
||||
|
||||
// in case NCHW format, we need to unsqueeze the normalizing coefficient by lower dimensions
|
||||
// to have the coefficient of shape [C, 1, 1]
|
||||
// generate axes range for unsqueezing the coefficient
|
||||
auto const_one = make_shared<Constant>(element::i32, Shape{}, 1);
|
||||
auto x_rank_minus_one = make_shared<Subtract>(x_rank, const_one);
|
||||
auto axes = make_shared<Range>(const_one, x_rank_minus_one, const_one, element::i32);
|
||||
|
||||
// adjust shapes of the normalizing coefficients
|
||||
adjusted_coeff = make_shared<Unsqueeze>(adjusted_coeff, axes)->output(0);
|
||||
}
|
||||
|
||||
void compute_batch_mean_and_variance(const Output<Node>& x,
|
||||
const Output<Node>& x_rank,
|
||||
bool is_nhwc,
|
||||
Output<Node>& batch_mean,
|
||||
Output<Node>& batch_variance) {
|
||||
// generate axes range for reduction operation
|
||||
Output<Node> reduce_axes;
|
||||
generate_axes_range_except_c(x_rank, is_nhwc, reduce_axes);
|
||||
|
||||
// compute batch_mean
|
||||
batch_mean = make_shared<ReduceMean>(x, reduce_axes, false)->output(0);
|
||||
|
||||
// compute batch_variance
|
||||
auto unsqueezed_batch_mean = make_shared<Unsqueeze>(batch_mean, reduce_axes);
|
||||
batch_variance = make_shared<Subtract>(x, unsqueezed_batch_mean)->output(0);
|
||||
auto const_two = make_shared<Constant>(x.get_element_type(), Shape{}, 2);
|
||||
batch_variance = make_shared<Power>(batch_variance, const_two);
|
||||
batch_variance = make_shared<ReduceMean>(batch_variance, reduce_axes)->output(0);
|
||||
|
||||
// for training mode, variance of FusedBatchNorm is computed with Bessel's correction
|
||||
// batch_variance must be multiplied by n / (n - 1), where n is a number of samples
|
||||
// to compute variance
|
||||
auto x_shape = make_shared<ShapeOf>(x, element::i32);
|
||||
auto gather_axis = make_shared<Constant>(element::i32, Shape{}, 0);
|
||||
auto needed_dim_values = make_shared<Gather>(x_shape, reduce_axes, gather_axis);
|
||||
auto n = make_shared<ReduceProd>(needed_dim_values, gather_axis, false)->output(0);
|
||||
n = make_shared<Convert>(n, batch_variance.get_element_type())->output(0);
|
||||
auto const_one = make_shared<Constant>(batch_variance.get_element_type(), Shape{}, 1);
|
||||
auto bessel_correction = make_shared<Subtract>(n, const_one)->output(0);
|
||||
bessel_correction = make_shared<Divide>(n, bessel_correction);
|
||||
|
||||
// adjust batch_variance by bessel correction
|
||||
batch_variance = make_shared<Multiply>(batch_variance, bessel_correction);
|
||||
}
|
||||
|
||||
void compute_weighted_batch_mean_and_variance(const Output<Node>& x,
|
||||
const Output<Node>& mean,
|
||||
const Output<Node>& variance,
|
||||
const Output<Node>& exp_avg_factor_const,
|
||||
const Output<Node>& x_rank,
|
||||
const Output<Node>& batch_mean,
|
||||
const Output<Node>& batch_variance,
|
||||
Output<Node>& weighted_batch_mean,
|
||||
Output<Node>& weighted_batch_variance) {
|
||||
// compute weighted_mean and weighted_variance by the following formula:
|
||||
// (1 - exponential_avg_factor) * mean + exponential_avg_factor * batch_mean,
|
||||
// where batch_mean is the mean of the current batch in x.
|
||||
// for weighted_variance it is similar
|
||||
// (1 - exponential_avg_factor) * variance + exponential_avg_factor * batch_variance,
|
||||
// where batch_variance is the variance of the current batch in x.
|
||||
|
||||
// compute weighted_batch_mean
|
||||
auto const_one = make_shared<Constant>(exp_avg_factor_const.get_element_type(), Shape{}, 1);
|
||||
auto one_minus_exp_avg_factor = make_shared<Subtract>(const_one, exp_avg_factor_const);
|
||||
auto bt_mean_by_exp_avg = make_shared<Multiply>(batch_mean, exp_avg_factor_const);
|
||||
weighted_batch_mean = make_shared<Multiply>(mean, one_minus_exp_avg_factor)->output(0);
|
||||
weighted_batch_mean = make_shared<Add>(bt_mean_by_exp_avg, weighted_batch_mean);
|
||||
|
||||
// compute weighted_batch_variance
|
||||
auto bt_variance_by_exp_avg = make_shared<Multiply>(batch_variance, exp_avg_factor_const);
|
||||
weighted_batch_variance = make_shared<Multiply>(variance, one_minus_exp_avg_factor)->output(0);
|
||||
weighted_batch_variance = make_shared<Add>(bt_variance_by_exp_avg, weighted_batch_variance)->output(0);
|
||||
}
|
||||
|
||||
void compute_fused_batch_norm_inference(const NodeContext& node,
|
||||
Output<Node>& fused_batch_norm,
|
||||
Output<Node>& batch_mean,
|
||||
Output<Node>& batch_variance) {
|
||||
// when it is inference mode, there are five inputs: x, scale, offset, mean, and variance
|
||||
// The formula for FusedBatchNorm is the following:
|
||||
// (x - mean) / sqrt(variance + eps) * scale + offset
|
||||
default_op_checks(node, 5, {"FusedBatchNorm", "FusedBatchNormV2", "FusedBatchNormV3"});
|
||||
auto ng_input = node.get_input(0);
|
||||
auto ng_scale = node.get_input(1);
|
||||
auto ng_offset = node.get_input(2);
|
||||
auto ng_mean = node.get_input(3);
|
||||
auto ng_variance = node.get_input(4);
|
||||
|
||||
bool is_v3 = node.get_op_type() == "FusedBatchNormV3";
|
||||
|
||||
auto data_format = node.get_attribute<std::string>("data_format");
|
||||
TENSORFLOW_OP_VALIDATION(node, data_format == "NHWC" || data_format == "NCHW", "Unsupported data format");
|
||||
auto x = node.get_input(0);
|
||||
auto scale = node.get_input(1);
|
||||
auto offset = node.get_input(2);
|
||||
auto mean = node.get_input(3);
|
||||
auto variance = node.get_input(4);
|
||||
|
||||
// retrieve attributes
|
||||
auto epsilon = node.get_attribute<float>("epsilon", 0.0001f);
|
||||
auto data_format = node.get_attribute<string>("data_format", "NHWC");
|
||||
bool is_nhwc = (data_format == "NHWC");
|
||||
|
||||
OPENVINO_DEBUG << "data_format: " << data_format;
|
||||
// create auxiliary Constant nodes for some attributes: epsilon and exponential_avg_factor
|
||||
auto eps_const = make_shared<Constant>(x.get_element_type(), Shape{}, epsilon);
|
||||
auto half = make_shared<Constant>(x.get_element_type(), Shape{}, 0.5);
|
||||
|
||||
// TODO: where does 0.0001 come from?
|
||||
auto tf_epsilon = node.get_attribute<float>("epsilon", 0.0001f);
|
||||
// adjust normalizing coefficients: scale, offset, mean, and variance
|
||||
auto x_rank = compute_subgraph_scalar_rank(x, element::i32, true);
|
||||
Output<Node> adjusted_scale, adjusted_offset, adjusted_mean, adjusted_variance;
|
||||
adjust_coeff(x_rank, x.get_element_type(), scale, adjusted_scale, is_nhwc);
|
||||
adjust_coeff(x_rank, x.get_element_type(), offset, adjusted_offset, is_nhwc);
|
||||
adjust_coeff(x_rank, x.get_element_type(), mean, adjusted_mean, is_nhwc);
|
||||
adjust_coeff(x_rank, x.get_element_type(), variance, adjusted_variance, is_nhwc);
|
||||
|
||||
OPENVINO_DEBUG << "epsilon: " << tf_epsilon;
|
||||
// perform the main part of the transformation
|
||||
// 1. subtract mean from the input
|
||||
auto x_minus_mean = make_shared<Subtract>(x, adjusted_mean);
|
||||
|
||||
convert_nhwc_to_nchw(is_nhwc, ng_input, ov::Rank(4));
|
||||
// 2. normalize the input after the shifting
|
||||
auto var_plus_eps = make_shared<Add>(adjusted_variance, eps_const);
|
||||
auto root_sq_var = make_shared<Power>(var_plus_eps, half);
|
||||
auto normalized_x = make_shared<Divide>(x_minus_mean, root_sq_var);
|
||||
|
||||
auto ng_batch_norm =
|
||||
make_shared<BatchNormInference>(ng_input, ng_scale, ng_offset, ng_mean, ng_variance, tf_epsilon)->output(0);
|
||||
convert_nchw_to_nhwc(is_nhwc, ng_batch_norm, ov::Rank(4));
|
||||
// 3. scale the input after the normalization
|
||||
auto scaled_x = make_shared<Multiply>(normalized_x, adjusted_scale);
|
||||
fused_batch_norm = make_shared<Add>(scaled_x, adjusted_offset)->output(0);
|
||||
|
||||
// TODO: Why are there so many? Is it correct?
|
||||
OutputVector result = {ng_batch_norm, ng_mean, ng_variance, ng_mean, ng_variance};
|
||||
if (is_v3) {
|
||||
// FusedBatchNormV3 has 6 outputs
|
||||
result.push_back(ng_mean); // reserve_space_3
|
||||
// mean and variance go as outputs for batch_mean and batch_variance
|
||||
// exponential_avg_factor has no affect on it
|
||||
batch_mean = mean;
|
||||
batch_variance = variance;
|
||||
}
|
||||
|
||||
void compute_fused_batch_norm_training(const NodeContext& node,
|
||||
Output<Node>& fused_batch_norm,
|
||||
Output<Node>& batch_mean,
|
||||
Output<Node>& batch_variance) {
|
||||
// when is_training is True, the operations have just three inputs: x, scale, and offset
|
||||
default_op_checks(node, 3, {"FusedBatchNorm", "FusedBatchNormV2", "FusedBatchNormV3"});
|
||||
auto x = node.get_input(0);
|
||||
auto scale = node.get_input(1);
|
||||
auto offset = node.get_input(2);
|
||||
|
||||
// retrieve attributes
|
||||
auto epsilon = node.get_attribute<float>("epsilon", 0.0001f);
|
||||
auto data_format = node.get_attribute<string>("data_format", "NHWC");
|
||||
bool is_nhwc = (data_format == "NHWC");
|
||||
|
||||
// adjust normalizing coefficients: scale, offset
|
||||
auto x_rank = compute_subgraph_scalar_rank(x, element::i32, true);
|
||||
Output<Node> adjusted_scale, adjusted_offset;
|
||||
adjust_coeff(x_rank, x.get_element_type(), scale, adjusted_scale, is_nhwc);
|
||||
adjust_coeff(x_rank, x.get_element_type(), offset, adjusted_offset, is_nhwc);
|
||||
|
||||
// generate axes for MVN operations
|
||||
Output<Node> mvn_axes;
|
||||
generate_axes_range_except_c(x_rank, is_nhwc, mvn_axes);
|
||||
|
||||
// perform mean-variance normalization
|
||||
auto mvn = make_shared<MVN>(x, mvn_axes, true, epsilon, ov::op::MVNEpsMode::INSIDE_SQRT);
|
||||
|
||||
// perform scaling and shifting
|
||||
fused_batch_norm = make_shared<Multiply>(mvn, adjusted_scale)->output(0);
|
||||
fused_batch_norm = make_shared<Add>(fused_batch_norm, adjusted_offset)->output(0);
|
||||
|
||||
// compute two other outputs: batch_mean and batch_variance
|
||||
compute_batch_mean_and_variance(x, x_rank, is_nhwc, batch_mean, batch_variance);
|
||||
if (node.get_input_size() >= 5) {
|
||||
Output<Node> weighted_batch_mean, weighted_batch_variance;
|
||||
auto exponential_avg_factor = node.get_attribute<float>("exponential_avg_factor", 1.0f);
|
||||
auto exp_avg_factor_const = make_shared<Constant>(scale.get_element_type(), Shape{}, exponential_avg_factor);
|
||||
auto mean = node.get_input(3);
|
||||
auto variance = node.get_input(4);
|
||||
compute_weighted_batch_mean_and_variance(x,
|
||||
mean,
|
||||
variance,
|
||||
exp_avg_factor_const,
|
||||
x_rank,
|
||||
batch_mean,
|
||||
batch_variance,
|
||||
weighted_batch_mean,
|
||||
weighted_batch_variance);
|
||||
batch_mean = weighted_batch_mean;
|
||||
batch_variance = weighted_batch_variance;
|
||||
}
|
||||
set_node_name(node.get_name(), ng_batch_norm.get_node_shared_ptr());
|
||||
return result;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
OutputVector translate_fused_batch_norm_op(const NodeContext& node) {
|
||||
default_op_checks(node, 3, {"FusedBatchNorm", "FusedBatchNormV2", "FusedBatchNormV3"});
|
||||
auto scale = node.get_input(1);
|
||||
|
||||
// understand which version of the FusedBatchNorm operation
|
||||
auto op_type = node.get_op_type();
|
||||
auto is_v3 = (op_type == "FusedBatchNormV3");
|
||||
|
||||
// there are two modes of FusedBatchNorm operations: training and inference
|
||||
// compute three meaningful outputs: fused_batch_norm, batch_mean and batch_variance
|
||||
Output<Node> fused_batch_norm, batch_mean, batch_variance;
|
||||
auto is_training = node.get_attribute<bool>("is_training", true);
|
||||
if (is_training) {
|
||||
compute_fused_batch_norm_training(node, fused_batch_norm, batch_mean, batch_variance);
|
||||
} else {
|
||||
compute_fused_batch_norm_inference(node, fused_batch_norm, batch_mean, batch_variance);
|
||||
}
|
||||
|
||||
// create fictious output for reserved outputs of FusedBatchNorm operation
|
||||
auto zero_const = make_shared<Constant>(scale.get_element_type(), Shape{}, 0);
|
||||
auto zero_const2 = make_shared<Constant>(scale.get_element_type(), Shape{}, 0);
|
||||
|
||||
// set node names and tensor names
|
||||
set_node_name(node.get_name(), fused_batch_norm.get_node_shared_ptr());
|
||||
set_node_name(node.get_name() + ":1", batch_mean.get_node_shared_ptr());
|
||||
set_node_name(node.get_name() + ":2", batch_variance.get_node_shared_ptr());
|
||||
set_node_name(node.get_name() + ":3", zero_const);
|
||||
set_node_name(node.get_name() + ":4", zero_const2);
|
||||
|
||||
OutputVector results = OutputVector{fused_batch_norm, batch_mean, batch_variance, zero_const, zero_const2};
|
||||
if (is_v3) {
|
||||
auto zero_const3 = make_shared<Constant>(scale.get_element_type(), Shape{}, 0);
|
||||
set_node_name(node.get_name() + ":5", zero_const3);
|
||||
results.push_back(zero_const3);
|
||||
}
|
||||
|
||||
return results;
|
||||
}
|
||||
} // namespace op
|
||||
} // namespace tensorflow
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
||||
} // namespace ov
|
||||
|
@ -32,6 +32,7 @@ class TestKerasBatchNormalization(CommonTF2LayerTest):
|
||||
@pytest.mark.parametrize("params", test_data_float32)
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
@pytest.mark.precommit_tf_fe
|
||||
def test_keras_batch_normalization_float32(self, params, ie_device, precision, ir_version,
|
||||
temp_dir, use_old_api, use_new_frontend):
|
||||
self._test(*self.create_keras_batch_normalization_net(**params, ir_version=ir_version),
|
||||
|
97
tests/layer_tests/tensorflow_tests/test_tf_FusedBatchNorm.py
Normal file
97
tests/layer_tests/tensorflow_tests/test_tf_FusedBatchNorm.py
Normal file
@ -0,0 +1,97 @@
|
||||
# Copyright (C) 2018-2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import tensorflow as tf
|
||||
from common.tf_layer_test_class import CommonTFLayerTest
|
||||
|
||||
|
||||
class TestFusedBatchNorm(CommonTFLayerTest):
|
||||
def _prepare_input(self, inputs_info):
|
||||
inputs_data = {}
|
||||
x_shape = inputs_info['x']
|
||||
inputs_data['x'] = np.random.randint(-10, 10, x_shape)
|
||||
scale_shape = inputs_info['scale']
|
||||
inputs_data['scale'] = np.random.randint(-10, 10, scale_shape)
|
||||
offset_shape = inputs_info['offset']
|
||||
inputs_data['offset'] = np.random.randint(-10, 10, offset_shape)
|
||||
if 'mean' in inputs_info:
|
||||
mean_shape = inputs_info['mean']
|
||||
inputs_data['mean'] = np.random.randint(-10, 10, mean_shape)
|
||||
if 'variance' in inputs_info:
|
||||
variance_shape = inputs_info['variance']
|
||||
inputs_data['variance'] = np.random.randint(0, 10, variance_shape)
|
||||
return inputs_data
|
||||
|
||||
def create_fused_batch_norm_net(self, x_shape, epsilon, exponential_avg_factor, data_format, is_training,
|
||||
fbn_version):
|
||||
fbn_dict = {
|
||||
"v1": tf.raw_ops.FusedBatchNorm,
|
||||
"v2": tf.raw_ops.FusedBatchNormV2,
|
||||
"v3": tf.raw_ops.FusedBatchNormV3
|
||||
}
|
||||
tf.compat.v1.reset_default_graph()
|
||||
# Create the graph and model
|
||||
with tf.compat.v1.Session() as sess:
|
||||
c_dim = x_shape[-1]
|
||||
if data_format == "NCHW":
|
||||
c_dim = x_shape[1]
|
||||
x = tf.compat.v1.placeholder(tf.float32, x_shape, 'x')
|
||||
mean = tf.compat.v1.placeholder(tf.float32, [c_dim], 'mean')
|
||||
variance = tf.compat.v1.placeholder(tf.float32, [c_dim], 'variance')
|
||||
scale = tf.compat.v1.placeholder(tf.float32, [c_dim], 'scale')
|
||||
offset = tf.compat.v1.placeholder(tf.float32, [c_dim], 'offset')
|
||||
fbn_func = fbn_dict[fbn_version]
|
||||
if not is_training:
|
||||
# due to limitation in the layer test infrastructure - it finds tensor names for Parameter and Result nodes
|
||||
# by get_any_name() that cannot work if some nodes fused to Parameter or Result node have multiple tensor names
|
||||
# This issue is tracked in 97192 ticket
|
||||
# Now it is worked around by guarding Parameter Node with AddV2
|
||||
mean = tf.raw_ops.AddV2(x=mean, y=tf.constant(2.0, dtype=tf.float32))
|
||||
variance = tf.raw_ops.AddV2(x=variance, y=tf.constant(2.0, dtype=tf.float32))
|
||||
fused_batch_norm = fbn_func(x=x, scale=scale, offset=offset, epsilon=epsilon,
|
||||
mean=mean, variance=variance,
|
||||
exponential_avg_factor=exponential_avg_factor, data_format=data_format,
|
||||
is_training=is_training, name="FusedBatchNorm")
|
||||
tf.identity(fused_batch_norm[0], name='y')
|
||||
tf.identity(fused_batch_norm[1], name='batch_mean')
|
||||
tf.identity(fused_batch_norm[2], name='batch_variance')
|
||||
tf.compat.v1.global_variables_initializer()
|
||||
|
||||
tf_net = sess.graph_def
|
||||
|
||||
return tf_net, None
|
||||
|
||||
test_data_basic = [
|
||||
# Currently these cases are passing on Windows, looks a problem with CPU on Linux
|
||||
pytest.param(dict(x_shape=[2, 3, 4, 5], epsilon=0.0001, exponential_avg_factor=1, data_format="NHWC",
|
||||
is_training=True,
|
||||
fbn_version="v1"), marks=pytest.mark.xfail(reason="97191")),
|
||||
pytest.param(dict(x_shape=[2, 3, 4, 5], epsilon=0.0005, exponential_avg_factor=0.3, data_format="NHWC",
|
||||
is_training=True,
|
||||
fbn_version="v2"), marks=pytest.mark.xfail(reason="97191")),
|
||||
pytest.param(dict(x_shape=[3, 2, 1, 5], epsilon=0.00003, exponential_avg_factor=0.7, data_format="NCHW",
|
||||
is_training=True,
|
||||
fbn_version="v3"), marks=pytest.mark.xfail(reason="97191")),
|
||||
pytest.param(dict(x_shape=[3, 4, 2, 5], epsilon=0.0003, exponential_avg_factor=0.0, data_format="NCHW",
|
||||
is_training=True,
|
||||
fbn_version="v3"), marks=pytest.mark.xfail(reason="97191")),
|
||||
dict(x_shape=[2, 3, 4, 5], epsilon=0.0001, exponential_avg_factor=1, data_format="NHWC",
|
||||
is_training=False,
|
||||
fbn_version="v1"),
|
||||
dict(x_shape=[3, 2, 1, 4], epsilon=0.0005, exponential_avg_factor=0.3, data_format="NCHW",
|
||||
is_training=False,
|
||||
fbn_version="v2"),
|
||||
dict(x_shape=[5, 4, 3, 2], epsilon=0.0005, exponential_avg_factor=0.0, data_format="NCHW",
|
||||
is_training=False,
|
||||
fbn_version="v3"),
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize("params", test_data_basic)
|
||||
@pytest.mark.precommit_tf_fe
|
||||
def test_fused_batch_norm_basic(self, params, ie_device, precision, ir_version, temp_dir,
|
||||
use_new_frontend, use_old_api):
|
||||
self._test(*self.create_fused_batch_norm_net(**params),
|
||||
ie_device, precision, ir_version, temp_dir=temp_dir,
|
||||
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
|
Loading…
Reference in New Issue
Block a user