diff --git a/docs/doxygen/ie_docs.xml b/docs/doxygen/ie_docs.xml
index 5ea61f802a3..0406f639fcb 100644
--- a/docs/doxygen/ie_docs.xml
+++ b/docs/doxygen/ie_docs.xml
@@ -113,7 +113,7 @@
-    <tab type="user" title="BatchNormInference-1" url="@ref openvino_docs_ops_normalization_BatchNormInference_1"/>
+    <tab type="user" title="BatchNormInference-5" url="@ref openvino_docs_ops_normalization_BatchNormInference_5"/>
diff --git a/docs/ops/normalization/BatchNormInference_5.md b/docs/ops/normalization/BatchNormInference_5.md
new file mode 100644
index 00000000000..aab4daee36c
--- /dev/null
+++ b/docs/ops/normalization/BatchNormInference_5.md
@@ -0,0 +1,98 @@
+## BatchNormInference {#openvino_docs_ops_normalization_BatchNormInference_5}
+
+**Versioned name**: *BatchNormInference-5*
+
+**Category**: *Normalization*
+
+**Short description**: *BatchNormInference* layer normalizes an `input` tensor by `mean` and `variance`, and applies a scale (`gamma`) and an offset (`beta`) to it.
+
+**Attributes**:
+
+* *epsilon*
+ * **Description**: *epsilon* is the number to be added to the variance to avoid division by zero when normalizing a value. For example, *epsilon* equal to 0.001 means that 0.001 is added to the variance.
+ * **Range of values**: a positive floating-point number
+ * **Type**: `float`
+ * **Default value**: None
+ * **Required**: *yes*
+
+**Inputs**
+
+* **1**: `input` - input tensor with data for normalization. At least a 2D tensor of type T; the second dimension represents the channel axis and must have a span of at least 1. **Required.**
+* **2**: `gamma` - gamma scaling for normalized value. A 1D tensor of type T with the same span as input's channel axis. **Required.**
+* **3**: `beta` - bias added to the scaled normalized value. A 1D tensor of type T with the same span as input's channel axis. **Required.**
+* **4**: `mean` - value for mean normalization. A 1D tensor of type T with the same span as input's channel axis. **Required.**
+* **5**: `variance` - value for variance normalization. A 1D tensor of type T with the same span as input's channel axis. **Required.**
+
+**Outputs**
+
+* **1**: The result of normalization. A tensor of the same type and shape as the 1st input tensor.
+
+**Types**
+
+* *T*: any numeric type.
+
+**Mathematical Formulation**
+
+*BatchNormInference* normalizes each channel of the `input` tensor, then applies the learned scale and shift (a NumPy sketch of these steps follows the list). Note that the mini-batch \f$\mathcal{B}\f$ is distinct from the shift parameter \f$\beta\f$.
+* **Input**: Values of \f$x\f$ over a mini-batch:
+    \f[
+        \mathcal{B} = \{ x_{1...m} \}
+    \f]
+* **Parameters to learn**: \f$ \gamma, \beta\f$
+* **Output**:
+    \f[
+        \{ o_{i} = BN_{\gamma, \beta} ( x_{i} ) \}
+    \f]
+* **Mini-batch mean**:
+    \f[
+        \mu_{\mathcal{B}} \leftarrow \frac{1}{m}\sum_{i=1}^{m} x_{i}
+    \f]
+* **Mini-batch variance**:
+    \f[
+        \sigma_{\mathcal{B}}^{2} \leftarrow \frac{1}{m}\sum_{i=1}^{m} ( x_{i} - \mu_{\mathcal{B}} )^{2}
+    \f]
+* **Normalize**:
+    \f[
+        \hat{x_{i}} \leftarrow \frac{x_{i} - \mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^{2} + \epsilon }}
+    \f]
+* **Scale and shift**:
+    \f[
+        o_{i} \leftarrow \gamma\hat{x_{i}} + \beta = BN_{\gamma ,\beta } ( x_{i} )
+    \f]
+
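+The inference-time computation can be sketched in NumPy as follows (illustrative only; the helper name and the concrete shapes are chosen for this example and are not part of the specification):
+
+```python
+import numpy as np
+
+def batch_norm_inference(x, gamma, beta, mean, variance, epsilon=0.001):
+    # Reshape the per-channel (C,) vectors so they broadcast over [N, C, ...].
+    shape = (1, -1) + (1,) * (x.ndim - 2)
+    g, b = gamma.reshape(shape), beta.reshape(shape)
+    m, v = mean.reshape(shape), variance.reshape(shape)
+    # o = gamma * (x - mean) / sqrt(variance + epsilon) + beta
+    return g * (x - m) / np.sqrt(v + epsilon) + b
+
+x = np.random.rand(1, 3, 224, 224).astype(np.float32)
+gamma, beta, mean, variance = (np.random.rand(3).astype(np.float32) for _ in range(4))
+assert batch_norm_inference(x, gamma, beta, mean, variance).shape == x.shape
+```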
+**Example**
+
+```xml
+<layer ... type="BatchNormInference" ...>
+    <data epsilon="0.001"/>
+    <input>
+        <port id="0">  <!-- input -->
+            <dim>1</dim>
+            <dim>3</dim>
+            <dim>224</dim>
+            <dim>224</dim>
+        </port>
+        <port id="1">  <!-- gamma -->
+            <dim>3</dim>
+        </port>
+        <port id="2">  <!-- beta -->
+            <dim>3</dim>
+        </port>
+        <port id="3">  <!-- mean -->
+            <dim>3</dim>
+        </port>
+        <port id="4">  <!-- variance -->
+            <dim>3</dim>
+        </port>
+    </input>
+    <output>
+        <port id="5">
+            <dim>1</dim>
+            <dim>3</dim>
+            <dim>224</dim>
+            <dim>224</dim>
+        </port>
+    </output>
+</layer>
+```
+
diff --git a/docs/ops/opset5.md b/docs/ops/opset5.md
index 75db70238f8..7db25f894d5 100644
--- a/docs/ops/opset5.md
+++ b/docs/ops/opset5.md
@@ -19,7 +19,7 @@ declared in `namespace opset5`.
* [Atan](arithmetic/Atan_1.md)
* [Atanh](arithmetic/Atanh_3.md)
* [AvgPool](pooling/AvgPool_1.md)
-* [BatchNormInference](normalization/BatchNormInference_1.md)
+* [BatchNormInference](normalization/BatchNormInference_5.md)
* [BatchToSpace](movement/BatchToSpace_2.md)
* [BinaryConvolution](convolution/BinaryConvolution_1.md)
* [Broadcast](movement/Broadcast_3.md)
diff --git a/inference-engine/src/legacy_api/src/convert_function_to_cnn_network.cpp b/inference-engine/src/legacy_api/src/convert_function_to_cnn_network.cpp
index 8a992a40f49..0cf25b94f7d 100644
--- a/inference-engine/src/legacy_api/src/convert_function_to_cnn_network.cpp
+++ b/inference-engine/src/legacy_api/src/convert_function_to_cnn_network.cpp
@@ -732,7 +732,6 @@ void convertFunctionToICNNNetwork(const std::shared_ptr>(),
std::make_shared>(),
std::make_shared>(),
- std::make_shared<Builder::NodeConverter<ngraph::op::BatchNormInference>>(),
std::make_shared>(),
std::make_shared>(),
std::make_shared>(),
diff --git a/inference-engine/src/legacy_api/src/ie_cnn_layer_builder_ngraph.cpp b/inference-engine/src/legacy_api/src/ie_cnn_layer_builder_ngraph.cpp
index c9adf5389e1..b1274b914ce 100644
--- a/inference-engine/src/legacy_api/src/ie_cnn_layer_builder_ngraph.cpp
+++ b/inference-engine/src/legacy_api/src/ie_cnn_layer_builder_ngraph.cpp
@@ -667,12 +667,6 @@ CNNLayer::Ptr NodeConverter::createLayer(const std::shared_
return res;
}
-template <>
-CNNLayer::Ptr NodeConverter<ngraph::op::BatchNormInference>::createLayer(
-    const std::shared_ptr<ngraph::Node>& layer) const {
-    THROW_IE_EXCEPTION << "BatchNormInference operation should be fused or decomposed";
-}
-
template <>
CNNLayer::Ptr NodeConverter<ngraph::op::Squeeze>::createLayer(const std::shared_ptr<ngraph::Node>& layer) const {
LayerParams params = {layer->get_friendly_name(), "Squeeze",
diff --git a/inference-engine/src/readers/ir_reader/ie_ir_parser.cpp b/inference-engine/src/readers/ir_reader/ie_ir_parser.cpp
index d07a3e9250d..a8208392fa2 100644
--- a/inference-engine/src/readers/ir_reader/ie_ir_parser.cpp
+++ b/inference-engine/src/readers/ir_reader/ie_ir_parser.cpp
@@ -394,7 +394,6 @@ std::shared_ptr V10Parser::createNode(const std::vector>("Asin"),
std::make_shared>("Atan"),
std::make_shared>("AvgPool"),
- std::make_shared<LayerCreator<ngraph::op::BatchNormInference>>("BatchNormInference"),
std::make_shared>("Ceiling"),
std::make_shared>("Clamp"),
std::make_shared>("Concat"),
@@ -951,20 +950,6 @@ std::shared_ptr V10Parser::LayerCreator:
activations, activations_alpha, activations_beta, clip);
}
-// BatchNormInference layer
-template <>
-std::shared_ptr<ngraph::Node> V10Parser::LayerCreator<ngraph::op::BatchNormInference>::createLayer(
- const ngraph::OutputVector& inputs, const pugi::xml_node& node, std::istream& binStream,
- const GenericLayerParams& layerParsePrms) {
- checkParameters(inputs, layerParsePrms, 5);
- pugi::xml_node dn = node.child("data");
- if (dn.empty())
- THROW_IE_EXCEPTION << "Cannot read parameter for " << getType() << " layer with name: " << layerParsePrms.name;
-
- float eps = GetFloatAttr(dn, "eps");
- return std::make_shared<ngraph::op::BatchNormInference>(inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], eps);
-}
-
// CTCGreedyDecoder layer
template <>
std::shared_ptr<ngraph::Node> V10Parser::LayerCreator<ngraph::op::CTCGreedyDecoder>::createLayer(
diff --git a/inference-engine/src/transformations/include/transformations/op_conversions/batch_norm_decomposition.hpp b/inference-engine/src/transformations/include/transformations/op_conversions/batch_norm_decomposition.hpp
index 8ca8879eb7c..7845d835cd5 100644
--- a/inference-engine/src/transformations/include/transformations/op_conversions/batch_norm_decomposition.hpp
+++ b/inference-engine/src/transformations/include/transformations/op_conversions/batch_norm_decomposition.hpp
@@ -11,6 +11,7 @@
#include
#include
+#include <ngraph/opsets/opset5.hpp>
using namespace std;
@@ -18,6 +19,7 @@ namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API BatchNormDecomposition;
+class TRANSFORMATIONS_API BatchNormV5Decomposition;
} // namespace pass
} // namespace ngraph
@@ -27,3 +29,9 @@ public:
NGRAPH_RTTI_DECLARATION;
BatchNormDecomposition();
};
+
+class ngraph::pass::BatchNormV5Decomposition: public ngraph::pass::MatcherPass {
+public:
+ NGRAPH_RTTI_DECLARATION;
+ BatchNormV5Decomposition();
+};
diff --git a/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp b/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp
index 059faa72337..7f2a8353f2a 100644
--- a/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp
+++ b/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp
@@ -93,6 +93,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptradd_matcher();
decomp->add_matcher();
decomp->add_matcher();
+ decomp->add_matcher<ngraph::pass::BatchNormV5Decomposition>();
decomp->set_name("ngraph::pass::CommonDecompositions");
// CF is required after all decompositions
diff --git a/inference-engine/src/transformations/src/transformations/op_conversions/batch_norm_decomposition.cpp b/inference-engine/src/transformations/src/transformations/op_conversions/batch_norm_decomposition.cpp
index 97e31823855..e76e30278f2 100644
--- a/inference-engine/src/transformations/src/transformations/op_conversions/batch_norm_decomposition.cpp
+++ b/inference-engine/src/transformations/src/transformations/op_conversions/batch_norm_decomposition.cpp
@@ -8,8 +8,11 @@
#include
#include
+#include <ngraph/opsets/opset5.hpp>
#include
+using namespace ngraph;
+
NGRAPH_RTTI_DEFINITION(ngraph::pass::BatchNormDecomposition, "BatchNormDecomposition", 0);
ngraph::pass::BatchNormDecomposition::BatchNormDecomposition() {
@@ -43,39 +46,107 @@ ngraph::pass::BatchNormDecomposition::BatchNormDecomposition() {
const auto& input_type = m_input->get_element_type();
// scale_add = variance + eps
- auto scale_add = make_shared<opset1::Add>(m_var, opset1::Constant::create(input_type, Shape{}, {m_bn->get_eps_value()}));
+ auto scale_add = make_shared<opset5::Add>(m_var, opset5::Constant::create(input_type, Shape{}, {m_bn->get_eps_value()}));
// scale = sqrt(variance + eps)
- auto scale = make_shared<opset1::Sqrt>(scale_add);
+ auto scale = make_shared<opset5::Sqrt>(scale_add);
// Divide `gamma` by `sqrt(variance + eps)`
- auto gamma_div_scale = std::make_shared<opset1::Divide>(m_gamma, scale);
+ auto gamma_div_scale = std::make_shared<opset5::Divide>(m_gamma, scale);
size_t dims_to_add = m_input->get_shape().size() - 2;
Shape input_aligned_shape = m_gamma->get_shape();
for (size_t i = 0; i < dims_to_add; ++i)
input_aligned_shape.push_back(1);
- auto new_shape = opset1::Constant::create(element::i64, Shape{input_aligned_shape.size()}, input_aligned_shape);
+ auto new_shape = opset5::Constant::create(element::i64, Shape{input_aligned_shape.size()}, input_aligned_shape);
- auto gamma_div_scale_aligned = make_shared<opset1::Reshape>(gamma_div_scale, new_shape, true);
- auto beta_aligned = make_shared<opset1::Reshape>(m_beta, new_shape, true);
- auto mean_aligned = make_shared<opset1::Reshape>(m_mean, new_shape, true);
+ auto gamma_div_scale_aligned = make_shared<opset5::Reshape>(gamma_div_scale, new_shape, true);
+ auto beta_aligned = make_shared<opset5::Reshape>(m_beta, new_shape, true);
+ auto mean_aligned = make_shared<opset5::Reshape>(m_mean, new_shape, true);
// input_sub_mean = input - mean
- auto input_sub_mean = register_new_node<opset1::Subtract>(m_input, mean_aligned);
+ auto input_sub_mean = register_new_node<opset5::Subtract>(m_input, mean_aligned);
// Multiply `input - mean` and `gamma / sqrt(variance + eps)`
- auto mul = std::make_shared<opset1::Multiply>(input_sub_mean, gamma_div_scale_aligned);
+ auto mul = std::make_shared<opset5::Multiply>(input_sub_mean, gamma_div_scale_aligned);
// Add `(input - mean) * gamma / sqrt(variance + eps)` and `beta`
- auto add = std::make_shared<opset1::Add>(mul, beta_aligned);
+ auto add = std::make_shared<opset5::Add>(mul, beta_aligned);
add->set_friendly_name(m_bn->get_friendly_name());
copy_runtime_info(m_bn, {scale_add, scale, gamma_div_scale, gamma_div_scale_aligned,
- beta_aligned, input_sub_mean, mul, add});
+ beta_aligned, input_sub_mean, mul, add});
+
+ replace_node(m_bn, add);
+
+ return true;
+ };
+ auto m = std::make_shared<ngraph::pattern::Matcher>(bn, "BatchNormDecomposition");
+ this->register_matcher(m, callback);
+}
+
+NGRAPH_RTTI_DEFINITION(ngraph::pass::BatchNormV5Decomposition, "BatchNormDecomposition", 5);
+
+ngraph::pass::BatchNormV5Decomposition::BatchNormV5Decomposition() {
+ Shape shape{2, 2, 1, 1};
+ auto input = make_shared<pattern::op::Label>(element::f32, shape);
+ auto mean_shape = Shape{2};
+ auto mean = make_shared<pattern::op::Label>(element::f32, mean_shape);
+ auto var_shape = Shape{2};
+ auto var = make_shared<pattern::op::Label>(element::f32, var_shape);
+ auto gamma_shape = Shape{2};
+ auto gamma = make_shared<pattern::op::Label>(element::f32, gamma_shape);
+ auto beta_shape = Shape{2};
+ auto beta = make_shared<pattern::op::Label>(element::f32, beta_shape);
+ auto bn = make_shared<opset5::BatchNormInference>(input, gamma, beta, mean, var, 0.001);
+
+ ngraph::graph_rewrite_callback callback = [this, input, gamma, beta, mean, var](ngraph::pattern::Matcher &m) {
+ auto pattern_map = m.get_pattern_map();
+
+ auto m_input = pattern_map[input];
+ auto m_gamma = pattern_map[gamma];
+ auto m_beta = pattern_map[beta];
+ auto m_mean = pattern_map[mean];
+ auto m_var = pattern_map[var];
+
+ // TODO: check that all input shapes are static
+
+ auto m_bn = dynamic_pointer_cast<opset5::BatchNormInference>(m.get_match_root());
+ if (!m_bn) {
+ return false;
+ }
+
+ const auto& input_type = m_input->get_element_type();
+ // scale_add = variance + eps
+ auto scale_add = make_shared<opset5::Add>(m_var, opset5::Constant::create(input_type, Shape{}, {m_bn->get_eps_value()}));
+ // scale = sqrt(variance + eps)
+ auto scale = make_shared<opset5::Sqrt>(scale_add);
+ // Divide `gamma` by `sqrt(variance + eps)`
+ auto gamma_div_scale = std::make_shared<opset5::Divide>(m_gamma, scale);
+
+ size_t dims_to_add = m_input->get_shape().size() - 2;
+ Shape input_aligned_shape = m_gamma->get_shape();
+ for (size_t i = 0; i < dims_to_add; ++i)
+ input_aligned_shape.push_back(1);
+ auto new_shape = opset5::Constant::create(element::i64, Shape{input_aligned_shape.size()}, input_aligned_shape);
+
+ auto gamma_div_scale_aligned = make_shared<opset5::Reshape>(gamma_div_scale, new_shape, true);
+ auto beta_aligned = make_shared<opset5::Reshape>(m_beta, new_shape, true);
+ auto mean_aligned = make_shared<opset5::Reshape>(m_mean, new_shape, true);
+
+ // input_sub_mean = input - mean
+ auto input_sub_mean = register_new_node<opset5::Subtract>(m_input, mean_aligned);
+ // Multiply `input - mean` and `gamma / sqrt(variance + eps)`
+ auto mul = std::make_shared<opset5::Multiply>(input_sub_mean, gamma_div_scale_aligned);
+ // Add `(input - mean) * gamma / sqrt(variance + eps)` and `beta`
+ auto add = std::make_shared<opset5::Add>(mul, beta_aligned);
+
+ add->set_friendly_name(m_bn->get_friendly_name());
+
+ copy_runtime_info(m_bn, {scale_add, scale, gamma_div_scale, gamma_div_scale_aligned,
+ beta_aligned, input_sub_mean, mul, add});
replace_node(m_bn, add);
return true;
};
-
auto m = std::make_shared<ngraph::pattern::Matcher>(bn, "BatchNormDecomposition");
this->register_matcher(m, callback);
}
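
Both matchers replace a single BatchNormInference node with the chain Add, Sqrt, Divide, Reshape, Subtract, Multiply, Add that computes the same function. A quick NumPy check of that algebraic equivalence (a sketch with arbitrary values, not part of the patch):

```python
import numpy as np

x = np.random.rand(2, 2, 1, 1).astype(np.float32)
gamma, beta, mean, var = (np.random.rand(2).astype(np.float32) for _ in range(4))
eps, shape = 0.001, (1, 2, 1, 1)

# Direct definition: gamma * (x - mean) / sqrt(var + eps) + beta.
bn = (gamma.reshape(shape) * (x - mean.reshape(shape))
      / np.sqrt(var.reshape(shape) + eps) + beta.reshape(shape))

# Decomposed form mirroring the pass: precompute gamma / sqrt(var + eps),
# then compute (x - mean) * gamma_div_scale + beta.
gamma_div_scale = (gamma / np.sqrt(var + eps)).reshape(shape)  # Add, Sqrt, Divide, Reshape
decomposed = (x - mean.reshape(shape)) * gamma_div_scale + beta.reshape(shape)  # Subtract, Multiply, Add

assert np.allclose(bn, decomposed, atol=1e-5)
```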
diff --git a/inference-engine/tests/functional/inference_engine/ngraph_reader/batch_norm_inference_tests.cpp b/inference-engine/tests/functional/inference_engine/ngraph_reader/batch_norm_inference_tests.cpp
index 45028d17471..94c8b367f61 100644
--- a/inference-engine/tests/functional/inference_engine/ngraph_reader/batch_norm_inference_tests.cpp
+++ b/inference-engine/tests/functional/inference_engine/ngraph_reader/batch_norm_inference_tests.cpp
@@ -87,8 +87,8 @@ TEST_F(NGraphReaderTests, ReadBatchNormInferenceNetwork) {
-
-
+
+
1
diff --git a/inference-engine/tests/ngraph_functions/src/batch_norm.cpp b/inference-engine/tests/ngraph_functions/src/batch_norm.cpp
index 14f4035e9e4..d45972b8271 100644
--- a/inference-engine/tests/ngraph_functions/src/batch_norm.cpp
+++ b/inference-engine/tests/ngraph_functions/src/batch_norm.cpp
@@ -24,7 +24,7 @@ std::shared_ptr<ngraph::Node> makeBatchNormInference(const ngraph::Output<Node>& data,
std::uniform_real_distribution<> dis(0.0, 10.0);
std::generate(values.begin(), values.end(), [&dis, &gen]() { return dis(gen); });
auto variance = ngraph::builder::makeConstant(ngPrc, ngraph::Shape{C}, values, !random);
- return std::make_shared<ngraph::opset4::BatchNormInference>(data, gamma, beta, mean, variance, epsilon);
+ return std::make_shared<ngraph::opset5::BatchNormInference>(data, gamma, beta, mean, variance, epsilon);
}
} // namespace builder
} // namespace ngraph
diff --git a/ngraph/core/include/ngraph/op/batch_norm.hpp b/ngraph/core/include/ngraph/op/batch_norm.hpp
index a81870da059..78aba1e134c 100644
--- a/ngraph/core/include/ngraph/op/batch_norm.hpp
+++ b/ngraph/core/include/ngraph/op/batch_norm.hpp
@@ -31,8 +31,7 @@ namespace ngraph
class NGRAPH_API BatchNormInference : public Op
{
public:
- static constexpr NodeTypeInfo type_info{"BatchNormInference", 0};
- const NodeTypeInfo& get_type_info() const override { return type_info; }
+ NGRAPH_RTTI_DECLARATION;
BatchNormInference() = default;
/// \param input [., C, ...]
/// \param gamma gamma scaling for normalized value. [C]
@@ -66,6 +65,44 @@ namespace ngraph
double m_epsilon;
};
} // namespace v0
- using v0::BatchNormInference;
+ namespace v5
+ {
+ class NGRAPH_API BatchNormInference : public Op
+ {
+ public:
+ NGRAPH_RTTI_DECLARATION;
+ BatchNormInference() = default;
+ /// \param input [., C, ...]
+ /// \param gamma gamma scaling for normalized value. [C]
+ /// \param beta bias added to the scaled normalized value [C]
+ /// \param mean value for mean normalization [C]
+ /// \param variance value for variance normalization [C]
+ /// \param epsilon Avoids division by 0 if input has 0 variance
+ BatchNormInference(const Output<Node>& input,
+                    const Output<Node>& gamma,
+                    const Output<Node>& beta,
+                    const Output<Node>& mean,
+                    const Output<Node>& variance,
+                    double epsilon);
+
+ bool visit_attributes(AttributeVisitor& visitor) override;
+
+ void validate_and_infer_types() override;
+
+ double get_eps_value() const { return m_epsilon; }
+ void set_eps_value(double epsilon) { m_epsilon = epsilon; }
+ std::shared_ptr<Node>
+     clone_with_new_inputs(const OutputVector& new_args) const override;
+
+ private:
+ static constexpr size_t INPUT_DATA = 0;
+ static constexpr size_t INPUT_GAMMA = 1;
+ static constexpr size_t INPUT_BETA = 2;
+ static constexpr size_t INPUT_MEAN = 3;
+ static constexpr size_t INPUT_VARIANCE = 4;
+
+ double m_epsilon;
+ };
+ } // namespace v5
}
}
diff --git a/ngraph/core/include/ngraph/opsets/opset5_tbl.hpp b/ngraph/core/include/ngraph/opsets/opset5_tbl.hpp
index 2fbc5f6c825..d22b0df92fe 100644
--- a/ngraph/core/include/ngraph/opsets/opset5_tbl.hpp
+++ b/ngraph/core/include/ngraph/opsets/opset5_tbl.hpp
@@ -25,7 +25,7 @@ NGRAPH_OP(Add, ngraph::op::v1)
NGRAPH_OP(Asin, ngraph::op::v0)
NGRAPH_OP(Atan, ngraph::op::v0)
NGRAPH_OP(AvgPool, ngraph::op::v1)
-NGRAPH_OP(BatchNormInference, ngraph::op::v0)
+NGRAPH_OP(BatchNormInference, ngraph::op::v5)
NGRAPH_OP(BinaryConvolution, ngraph::op::v1)
NGRAPH_OP(Broadcast, ngraph::op::v3)
NGRAPH_OP(Bucketize, ngraph::op::v3)
diff --git a/ngraph/core/src/op/batch_norm.cpp b/ngraph/core/src/op/batch_norm.cpp
index 04470f9944c..a778c4c15a6 100644
--- a/ngraph/core/src/op/batch_norm.cpp
+++ b/ngraph/core/src/op/batch_norm.cpp
@@ -23,27 +23,27 @@
using namespace std;
using namespace ngraph;
-constexpr NodeTypeInfo op::BatchNormInference::type_info;
+NGRAPH_RTTI_DEFINITION(op::v0::BatchNormInference, "BatchNormInference", 0);
-op::BatchNormInference::BatchNormInference(const Output<Node>& input,
-                                           const Output<Node>& gamma,
-                                           const Output<Node>& beta,
-                                           const Output<Node>& mean,
-                                           const Output<Node>& variance,
-                                           double epsilon)
+op::v0::BatchNormInference::BatchNormInference(const Output<Node>& input,
+                                               const Output<Node>& gamma,
+                                               const Output<Node>& beta,
+                                               const Output<Node>& mean,
+                                               const Output<Node>& variance,
+                                               double epsilon)
: Op({gamma, beta, input, mean, variance})
, m_epsilon(epsilon)
{
constructor_validate_and_infer_types();
}
-bool op::BatchNormInference::visit_attributes(AttributeVisitor& visitor)
+bool op::v0::BatchNormInference::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("epsilon", m_epsilon);
return true;
}
-void op::BatchNormInference::validate_and_infer_types()
+void op::v0::BatchNormInference::validate_and_infer_types()
{
element::Type result_et;
PartialShape result_batch_shape;
@@ -67,9 +67,60 @@ void op::BatchNormInference::validate_and_infer_types()
}
std::shared_ptr<Node>
- op::BatchNormInference::clone_with_new_inputs(const OutputVector& new_args) const
+ op::v0::BatchNormInference::clone_with_new_inputs(const OutputVector& new_args) const
{
check_new_args_count(this, new_args);
return std::make_shared<BatchNormInference>(
new_args.at(2), new_args.at(0), new_args.at(1), new_args.at(3), new_args.at(4), m_epsilon);
}
+
+NGRAPH_RTTI_DEFINITION(op::v5::BatchNormInference, "BatchNormInference", 5);
+
+op::v5::BatchNormInference::BatchNormInference(const Output<Node>& input,
+                                               const Output<Node>& gamma,
+                                               const Output<Node>& beta,
+                                               const Output<Node>& mean,
+                                               const Output<Node>& variance,
+                                               double epsilon)
+ : Op({input, gamma, beta, mean, variance})
+ , m_epsilon(epsilon)
+{
+ constructor_validate_and_infer_types();
+}
+
+bool op::v5::BatchNormInference::visit_attributes(AttributeVisitor& visitor)
+{
+ visitor.on_attribute("epsilon", m_epsilon);
+ return true;
+}
+
+void op::v5::BatchNormInference::validate_and_infer_types()
+{
+ element::Type result_et;
+ PartialShape result_batch_shape;
+ PartialShape result_channel_shape; // unused here
+
+ set_output_size(1);
+ std::tie(result_et, result_batch_shape, result_channel_shape) =
+ infer_batch_norm_forward(this,
+ get_input_element_type(INPUT_DATA),
+ get_input_element_type(INPUT_GAMMA),
+ get_input_element_type(INPUT_BETA),
+ get_input_element_type(INPUT_MEAN),
+ get_input_element_type(INPUT_VARIANCE),
+ get_input_partial_shape(INPUT_DATA),
+ get_input_partial_shape(INPUT_GAMMA),
+ get_input_partial_shape(INPUT_BETA),
+ get_input_partial_shape(INPUT_MEAN),
+ get_input_partial_shape(INPUT_VARIANCE));
+
+ set_output_type(0, result_et, result_batch_shape);
+}
+
+std::shared_ptr<Node>
+ op::v5::BatchNormInference::clone_with_new_inputs(const OutputVector& new_args) const
+{
+ check_new_args_count(this, new_args);
+ return std::make_shared<BatchNormInference>(
+ new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), new_args.at(4), m_epsilon);
+}
diff --git a/ngraph/python/src/ngraph/opset5/__init__.py b/ngraph/python/src/ngraph/opset5/__init__.py
index 8c115b0a403..0141b1e284d 100644
--- a/ngraph/python/src/ngraph/opset5/__init__.py
+++ b/ngraph/python/src/ngraph/opset5/__init__.py
@@ -26,7 +26,7 @@ from ngraph.opset3.ops import assign
from ngraph.opset1.ops import atan
from ngraph.opset4.ops import atanh
from ngraph.opset1.ops import avg_pool
-from ngraph.opset1.ops import batch_norm_inference
+from ngraph.opset5.ops import batch_norm_inference
from ngraph.opset2.ops import batch_to_space
from ngraph.opset1.ops import binary_convolution
from ngraph.opset3.ops import broadcast
diff --git a/ngraph/python/src/ngraph/opset5/ops.py b/ngraph/python/src/ngraph/opset5/ops.py
index 0c84162bb1c..ab6f9c53f2a 100644
--- a/ngraph/python/src/ngraph/opset5/ops.py
+++ b/ngraph/python/src/ngraph/opset5/ops.py
@@ -58,6 +58,32 @@ _get_node_factory_opset5 = partial(_get_node_factory, "opset5")
# -------------------------------------------- ops ------------------------------------------------
+@nameable_op
+def batch_norm_inference(
+ data: NodeInput,
+ gamma: NodeInput,
+ beta: NodeInput,
+ mean: NodeInput,
+ variance: NodeInput,
+ epsilon: float,
+ name: Optional[str] = None,
+) -> Node:
+ """Perform layer normalizes a input tensor by mean and variance with appling scale and offset.
+
+ :param data: The input tensor with data for normalization.
+ :param gamma: The per-channel scale applied to the normalized value.
+ :param beta: The bias added to the scaled normalized value.
+ :param mean: The value for mean normalization.
+ :param variance: The value for variance normalization.
+ :param epsilon: The number to be added to the variance to avoid division
+ by zero when normalizing a value.
+ :param name: The optional name of the output node.
+ :return: The new node which performs BatchNormInference.
+ """
+ inputs = as_nodes(data, gamma, beta, mean, variance)
+ return _get_node_factory_opset5().create("BatchNormInference", inputs, {"epsilon": epsilon})
+
+
@nameable_op
def gather_nd(
data: NodeInput,
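
A usage sketch for the new binding (assuming an environment where the `ngraph` Python package is importable; the parameter values here are arbitrary):

```python
import numpy as np
from ngraph import opset5 as ov

data = ov.parameter([1, 3, 224, 224], name="data", dtype=np.float32)
gamma = ov.constant(np.ones(3, dtype=np.float32))
beta = ov.constant(np.zeros(3, dtype=np.float32))
mean = ov.constant(np.zeros(3, dtype=np.float32))
variance = ov.constant(np.ones(3, dtype=np.float32))

# Creates an opset5 BatchNormInference node; epsilon is added to the variance.
node = ov.batch_norm_inference(data, gamma, beta, mean, variance, epsilon=0.001)
assert node.get_type_name() == "BatchNormInference"
```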
diff --git a/ngraph/test/backend/batch_norm.in.cpp b/ngraph/test/backend/batch_norm.in.cpp
index 49f277c2de0..d4b501c8c9d 100644
--- a/ngraph/test/backend/batch_norm.in.cpp
+++ b/ngraph/test/backend/batch_norm.in.cpp
@@ -46,7 +46,8 @@ public:
auto Beta = make_shared<op::Parameter>(etype, channel_shape);
auto Mean = make_shared<op::Parameter>(etype, channel_shape);
auto Variance = make_shared<op::Parameter>(etype, channel_shape);
- auto BN = make_shared<op::BatchNormInference>(Input, Gamma, Beta, Mean, Variance, epsilon);
+ auto BN =
+     make_shared<op::v5::BatchNormInference>(Input, Gamma, Beta, Mean, Variance, epsilon);
m_function = make_shared<Function>(BN, ParameterVector{Input, Gamma, Beta, Mean, Variance});
m_input = backend->create_tensor(etype, input_shape);
@@ -285,7 +286,52 @@ NGRAPH_TEST(${BACKEND_NAME}, batch_norm_inference_parameters_duplication)
double eps = 0.001;
auto shape_r = Shape{2, 2, 2, 1};
- auto bn = make_shared<op::BatchNormInference>(input, mvgb, mvgb, mvgb, mvgb, eps);
+ auto bn = make_shared<op::v0::BatchNormInference>(input, mvgb, mvgb, mvgb, mvgb, eps);
+
+ auto f = make_shared<Function>(bn, ParameterVector{input, mvgb, mvgb, mvgb, mvgb});
+ auto backend = runtime::Backend::create("${BACKEND_NAME}");
+ // Create some tensors for input/output
+ auto _input = backend->create_tensor(element::f32, input_shape);
+ copy_data(_input,
+ vector<float>{0.54881352f,
+ 0.71518934f,
+ 0.60276335f,
+ 0.54488319f,
+ 0.42365479f,
+ 0.64589411f,
+ 0.4375872f,
+ 0.89177299f});
+
+ auto _mvgb = backend->create_tensor(element::f32, mvgb_shape);
+ copy_data(_mvgb, vector<float>{1.0f, 1.0f});
+ auto bn_output = backend->create_tensor(element::f32, shape_r);
+
+ vector<float> expected_result{0.54903894f,
+ 0.71533161f,
+ 0.60296183f,
+ 0.54511058f,
+ 0.42394274f,
+ 0.64607101f,
+ 0.43786817f,
+ 0.89182704f};
+ auto handle = backend->compile(f);
+ handle->call_with_validate({bn_output}, {_input, _mvgb, _mvgb, _mvgb, _mvgb});
+
+ ASSERT_TRUE(
+ ngraph::test::all_close(expected_result, read_vector<float>(bn_output), 1e-3f, 1e-4f));
+}
+
+NGRAPH_TEST(${BACKEND_NAME}, batch_norm_inference_parameters_duplication_v5)
+{
+ auto input_shape = Shape{2, 2, 2, 1};
+ auto input = make_shared<op::Parameter>(element::f32, input_shape);
+
+ auto mvgb_shape = Shape{2};
+ auto mvgb = make_shared<op::Parameter>(element::f32, mvgb_shape);
+
+ double eps = 0.001;
+ auto shape_r = Shape{2, 2, 2, 1};
+ auto bn = make_shared<op::v5::BatchNormInference>(input, mvgb, mvgb, mvgb, mvgb, eps);
auto f = make_shared<Function>(bn, ParameterVector{input, mvgb, mvgb, mvgb, mvgb});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
@@ -334,7 +380,56 @@ NGRAPH_TEST(${BACKEND_NAME}, batch_norm_fprop_inference_b2c2h2w1)
auto var = make_shared<op::Parameter>(element::f32, var_shape);
double eps = 0.001;
auto shape_r = Shape{2, 2, 2, 1};
- auto bn = make_shared<op::BatchNormInference>(input, gamma, beta, mean, var, eps);
+ auto bn = make_shared<op::v0::BatchNormInference>(input, gamma, beta, mean, var, eps);
+
+ auto f = make_shared<Function>(bn, ParameterVector{input, gamma, beta, mean, var});
+ auto backend = runtime::Backend::create("${BACKEND_NAME}");
+ // Create some tensors for input/output
+ auto _input = backend->create_tensor(element::f32, input_shape);
+ copy_data(_input,
+ vector<float>{0.54881352f,
+ 0.71518934f,
+ 0.60276335f,
+ 0.54488319f,
+ 0.42365479f,
+ 0.64589411f,
+ 0.4375872f,
+ 0.89177299f});
+
+ auto _gamma = backend->create_tensor(element::f32, gamma_shape);
+ copy_data(_gamma, vector<float>{1.0f, 1.0f});
+ auto _beta = backend->create_tensor(element::f32, beta_shape);
+ copy_data(_beta, vector<float>{0.0f, 0.0f});
+ auto _mean = backend->create_tensor(element::f32, mean_shape);
+ copy_data(_mean, vector<float>{0.583388f, 0.619252f});
+ auto _var = backend->create_tensor(element::f32, var_shape);
+ copy_data(_var, vector<float>{0.0119972f, 0.0282681f});
+ auto bn_output = backend->create_tensor(element::f32, shape_r);
+
+ vector<float> expected_result{
+ -0.30327f, 1.1561f, -0.0963782f, -0.434702f, -1.4011f, 0.548275f, -1.06187f, 1.59295f};
+ auto handle = backend->compile(f);
+ handle->call_with_validate({bn_output}, {_input, _gamma, _beta, _mean, _var});
+
+ ASSERT_TRUE(
+ ngraph::test::all_close(expected_result, read_vector<float>(bn_output), 1e-3f, 1e-4f));
+}
+
+NGRAPH_TEST(${BACKEND_NAME}, batch_norm_fprop_inference_b2c2h2w1_v5)
+{
+ auto input_shape = Shape{2, 2, 2, 1};
+ auto input = make_shared<op::Parameter>(element::f32, input_shape);
+ auto gamma_shape = Shape{2};
+ auto gamma = make_shared<op::Parameter>(element::f32, gamma_shape);
+ auto beta_shape = Shape{2};
+ auto beta = make_shared<op::Parameter>(element::f32, beta_shape);
+ auto mean_shape = Shape{2};
+ auto mean = make_shared<op::Parameter>(element::f32, mean_shape);
+ auto var_shape = Shape{2};
+ auto var = make_shared<op::Parameter>(element::f32, var_shape);
+ double eps = 0.001;
+ auto shape_r = Shape{2, 2, 2, 1};
+ auto bn = make_shared<op::v5::BatchNormInference>(input, gamma, beta, mean, var, eps);
auto f = make_shared<Function>(bn, ParameterVector{input, gamma, beta, mean, var});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
diff --git a/ngraph/test/op_is.cpp b/ngraph/test/op_is.cpp
index a4504f20d03..fe64a5f334a 100644
--- a/ngraph/test/op_is.cpp
+++ b/ngraph/test/op_is.cpp
@@ -85,7 +85,7 @@ namespace
void op_is_BatchNormInference()
{
- op::BatchNormInference node;
+ op::v0::BatchNormInference node;
EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
diff --git a/ngraph/test/runtime/ie/unit_test.manifest b/ngraph/test/runtime/ie/unit_test.manifest
index fd516f5fc18..bbad53c3178 100644
--- a/ngraph/test/runtime/ie/unit_test.manifest
+++ b/ngraph/test/runtime/ie/unit_test.manifest
@@ -997,6 +997,7 @@ batch_norm_training_0eps_f64
# Function inputs number differ from number of given inputs
batch_norm_inference_parameters_duplication
+batch_norm_inference_parameters_duplication_v5
backwards_abs
backwards_acos
diff --git a/ngraph/test/runtime/interpreter/int_executable.hpp b/ngraph/test/runtime/interpreter/int_executable.hpp
index be993f6c980..b3b35503c16 100644
--- a/ngraph/test/runtime/interpreter/int_executable.hpp
+++ b/ngraph/test/runtime/interpreter/int_executable.hpp
@@ -250,8 +250,8 @@ protected:
}
case OP_TYPEID::BatchNormInference:
{
- const ngraph::op::BatchNormInference* bn =
-     static_cast<const ngraph::op::BatchNormInference*>(&node);
+ const ngraph::op::v0::BatchNormInference* bn =
+     static_cast<const ngraph::op::v0::BatchNormInference*>(&node);
reference::batch_norm_inference<T>(bn->get_eps_value(),
                                   args[0]->get_data_ptr<const T>(),
                                   args[1]->get_data_ptr<const T>(),
@@ -262,6 +262,20 @@ protected:
node.get_input_shape(2));
break;
}
+ case OP_TYPEID::BatchNormInference_v5:
+ {
+ const ngraph::op::v5::BatchNormInference* bn =
+     static_cast<const ngraph::op::v5::BatchNormInference*>(&node);
+ reference::batch_norm_inference<T>(bn->get_eps_value(),
+                                    args[1]->get_data_ptr<const T>(),
+                                    args[2]->get_data_ptr<const T>(),
+                                    args[0]->get_data_ptr<const T>(),
+                                    args[3]->get_data_ptr<const T>(),
+                                    args[4]->get_data_ptr<const T>(),
+                                    out[0]->get_data_ptr<T>(),
+                                    node.get_input_shape(0));
+ break;
+ }
case OP_TYPEID::BroadcastLike: break;
case OP_TYPEID::Ceiling:
{
diff --git a/ngraph/test/runtime/interpreter/opset_int_tbl.hpp b/ngraph/test/runtime/interpreter/opset_int_tbl.hpp
index f9e1ee46126..1570b695104 100644
--- a/ngraph/test/runtime/interpreter/opset_int_tbl.hpp
+++ b/ngraph/test/runtime/interpreter/opset_int_tbl.hpp
@@ -57,6 +57,7 @@ NGRAPH_OP(GatherND, op::v5)
NGRAPH_OP(LSTMSequence, op::v5)
NGRAPH_OP(GRUSequence, op::v5)
NGRAPH_OP(RNNSequence, op::v5)
+NGRAPH_OP(BatchNormInference, op::v5)
NGRAPH_OP(Round, op::v5)
NGRAPH_OP(LogSoftmax, op::v5)
#undef ID_SUFFIX
diff --git a/ngraph/test/runtime/opset0_tbl.hpp b/ngraph/test/runtime/opset0_tbl.hpp
index 2d918225cd4..abf6a5833a7 100644
--- a/ngraph/test/runtime/opset0_tbl.hpp
+++ b/ngraph/test/runtime/opset0_tbl.hpp
@@ -56,7 +56,7 @@ NGRAPH_OP(Add, ngraph::op)
NGRAPH_OP(Asin, ngraph::op)
NGRAPH_OP(Atan, ngraph::op)
NGRAPH_OP(AvgPool, ngraph::op::v0)
-NGRAPH_OP(BatchNormInference, ngraph::op)
+NGRAPH_OP(BatchNormInference, ngraph::op::v0)
NGRAPH_OP(Broadcast, ngraph::op)
NGRAPH_OP(BroadcastLike, ngraph::op)
NGRAPH_OP(Ceiling, ngraph::op)
diff --git a/ngraph/test/type_prop/batch_norm.cpp b/ngraph/test/type_prop/batch_norm.cpp
index c2cd9841bb9..0ab600a8a51 100644
--- a/ngraph/test/type_prop/batch_norm.cpp
+++ b/ngraph/test/type_prop/batch_norm.cpp
@@ -41,7 +41,8 @@ TEST(type_prop, batch_norm_inference_partial_all_rank_dynamic)
auto mean = make_shared<op::Parameter>(mean_et, mean_shape);
auto variance = make_shared<op::Parameter>(variance_et, variance_shape);
- auto bn = make_shared<op::BatchNormInference>(data_batch, gamma, beta, mean, variance, epsilon);
+ auto bn =
+     make_shared<op::v0::BatchNormInference>(data_batch, gamma, beta, mean, variance, epsilon);
ASSERT_EQ(bn->get_output_size(), 1);
ASSERT_EQ(bn->get_output_element_type(0), data_batch_et);
@@ -69,7 +70,8 @@ TEST(type_prop, batch_norm_inference_partial_input_rank_static_dynamic_ok)
auto mean = make_shared<op::Parameter>(mean_et, mean_shape);
auto variance = make_shared<op::Parameter>(variance_et, variance_shape);
- auto bn = make_shared<op::BatchNormInference>(data_batch, gamma, beta, mean, variance, epsilon);
+ auto bn =
+     make_shared<op::v0::BatchNormInference>(data_batch, gamma, beta, mean, variance, epsilon);
ASSERT_EQ(bn->get_output_size(), 1);
ASSERT_EQ(bn->get_output_element_type(0), data_batch_et);
@@ -100,8 +102,8 @@ TEST(type_prop, batch_norm_inference_partial_input_rank_static_dynamic_zero_chan
try
{
- auto bn =
-     make_shared<op::BatchNormInference>(data_batch, gamma, beta, mean, variance, epsilon);
+ auto bn = make_shared<op::v0::BatchNormInference>(
+     data_batch, gamma, beta, mean, variance, epsilon);
FAIL() << "Zero channel count not detected";
}
catch (const NodeValidationFailure& error)
@@ -134,7 +136,8 @@ TEST(type_prop, batch_norm_inference_partial_input_rank_dynamic_some_rank_static
auto mean = make_shared<op::Parameter>(mean_et, mean_shape);
auto variance = make_shared<op::Parameter>(variance_et, variance_shape);
- auto bn = make_shared<op::BatchNormInference>(data_batch, gamma, beta, mean, variance, epsilon);
+ auto bn =
+     make_shared<op::v0::BatchNormInference>(data_batch, gamma, beta, mean, variance, epsilon);
ASSERT_EQ(bn->get_output_size(), 1);
ASSERT_EQ(bn->get_output_element_type(0), data_batch_et);
@@ -163,8 +166,8 @@ TEST(type_prop, batch_norm_inference_partial_input_rank_dynamic_some_rank_static
try
{
- auto bn =
-     make_shared<op::BatchNormInference>(data_batch, gamma, beta, mean, variance, epsilon);
+ auto bn = make_shared<op::v0::BatchNormInference>(
+     data_batch, gamma, beta, mean, variance, epsilon);
FAIL() << "Wrong gamma/beta/mean/variance shape not detected";
}
catch (const NodeValidationFailure& error)
@@ -202,8 +205,8 @@ TEST(type_prop,
try
{
- auto bn =
-     make_shared<op::BatchNormInference>(data_batch, gamma, beta, mean, variance, epsilon);
+ auto bn = make_shared<op::v0::BatchNormInference>(
+     data_batch, gamma, beta, mean, variance, epsilon);
FAIL() << "Inconsistent gamma/beta/mean/variance shape not detected";
}
catch (const NodeValidationFailure& error)
@@ -240,8 +243,8 @@ TEST(type_prop,
try
{
- auto bn =
-     make_shared<op::BatchNormInference>(data_batch, gamma, beta, mean, variance, epsilon);
+ auto bn = make_shared<op::v0::BatchNormInference>(
+     data_batch, gamma, beta, mean, variance, epsilon);
FAIL() << "Inconsistent gamma/beta/mean/variance channel count not detected";
}
catch (const NodeValidationFailure& error)
@@ -275,7 +278,8 @@ TEST(type_prop, batch_norm_inference_partial_input_rank_static_dynamic_some_stat
auto mean = make_shared<op::Parameter>(mean_et, mean_shape);
auto variance = make_shared<op::Parameter>(variance_et, variance_shape);
- auto bn = make_shared<op::BatchNormInference>(data_batch, gamma, beta, mean, variance, epsilon);
+ auto bn =
+     make_shared<op::v0::BatchNormInference>(data_batch, gamma, beta, mean, variance, epsilon);
ASSERT_EQ(bn->get_output_size(), 1);
ASSERT_EQ(bn->get_output_element_type(0), data_batch_et);
@@ -306,8 +310,315 @@ TEST(type_prop,
try
{
- auto bn =
-     make_shared<op::BatchNormInference>(data_batch, gamma, beta, mean, variance, epsilon);
+ auto bn = make_shared<op::v0::BatchNormInference>(
+     data_batch, gamma, beta, mean, variance, epsilon);
+ FAIL() << "Inconsistent input/gamma/beta/mean/variance channel count not detected";
+ }
+ catch (const NodeValidationFailure& error)
+ {
+ EXPECT_HAS_SUBSTRING(error.what(),
+ std::string("Input channel dimension (4) does not match "
+ "shape for gamma/beta/mean/variance ({3})"));
+ }
+ catch (...)
+ {
+ FAIL() << "Deduced type check failed for unexpected reason";
+ }
+}
+
+TEST(type_prop, batch_norm_inference_partial_all_rank_dynamic_v5)
+{
+ PartialShape data_batch_shape{PartialShape::dynamic()};
+ PartialShape gamma_shape{PartialShape::dynamic()};
+ PartialShape beta_shape{PartialShape::dynamic()};
+ PartialShape mean_shape{PartialShape::dynamic()};
+ PartialShape variance_shape{PartialShape::dynamic()};
+ double epsilon = 0.001;
+ element::Type data_batch_et = element::f32;
+ element::Type gamma_et = element::f32;
+ element::Type beta_et = element::f32;
+ element::Type mean_et = element::f32;
+ element::Type variance_et = element::f32;
+
+ auto data_batch = make_shared<op::Parameter>(data_batch_et, data_batch_shape);
+ auto gamma = make_shared<op::Parameter>(gamma_et, gamma_shape);
+ auto beta = make_shared<op::Parameter>(beta_et, beta_shape);
+ auto mean = make_shared<op::Parameter>(mean_et, mean_shape);
+ auto variance = make_shared<op::Parameter>(variance_et, variance_shape);
+
+ auto bn =
+     make_shared<op::v5::BatchNormInference>(data_batch, gamma, beta, mean, variance, epsilon);
+
+ ASSERT_EQ(bn->get_output_size(), 1);
+ ASSERT_EQ(bn->get_output_element_type(0), data_batch_et);
+ ASSERT_TRUE(bn->get_output_partial_shape(0).rank().is_dynamic());
+}
+
+TEST(type_prop, batch_norm_inference_partial_input_rank_static_dynamic_ok_v5)
+{
+ PartialShape data_batch_shape{
+ 64, Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()};
+ PartialShape gamma_shape{PartialShape::dynamic()};
+ PartialShape beta_shape{PartialShape::dynamic()};
+ PartialShape mean_shape{PartialShape::dynamic()};
+ PartialShape variance_shape{PartialShape::dynamic()};
+ double epsilon = 0.001;
+ element::Type data_batch_et = element::f32;
+ element::Type gamma_et = element::f32;
+ element::Type beta_et = element::f32;
+ element::Type mean_et = element::f32;
+ element::Type variance_et = element::f32;
+
+ auto data_batch = make_shared<op::Parameter>(data_batch_et, data_batch_shape);
+ auto gamma = make_shared<op::Parameter>(gamma_et, gamma_shape);
+ auto beta = make_shared<op::Parameter>(beta_et, beta_shape);
+ auto mean = make_shared<op::Parameter>(mean_et, mean_shape);
+ auto variance = make_shared<op::Parameter>(variance_et, variance_shape);
+
+ auto bn =
+     make_shared<op::v5::BatchNormInference>(data_batch, gamma, beta, mean, variance, epsilon);
+
+ ASSERT_EQ(bn->get_output_size(), 1);
+ ASSERT_EQ(bn->get_output_element_type(0), data_batch_et);
+ ASSERT_TRUE(bn->get_output_partial_shape(0).same_scheme(
+ PartialShape{64, Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}));
+}
+
+TEST(type_prop, batch_norm_inference_partial_input_rank_static_dynamic_zero_channels_v5)
+{
+ PartialShape data_batch_shape{
+ Dimension::dynamic(), 0, Dimension::dynamic(), Dimension::dynamic()};
+ PartialShape gamma_shape{PartialShape::dynamic()};
+ PartialShape beta_shape{PartialShape::dynamic()};
+ PartialShape mean_shape{PartialShape::dynamic()};
+ PartialShape variance_shape{PartialShape::dynamic()};
+ double epsilon = 0.001;
+ element::Type data_batch_et = element::f32;
+ element::Type gamma_et = element::f32;
+ element::Type beta_et = element::f32;
+ element::Type mean_et = element::f32;
+ element::Type variance_et = element::f32;
+
+ auto data_batch = make_shared<op::Parameter>(data_batch_et, data_batch_shape);
+ auto gamma = make_shared<op::Parameter>(gamma_et, gamma_shape);
+ auto beta = make_shared<op::Parameter>(beta_et, beta_shape);
+ auto mean = make_shared<op::Parameter>(mean_et, mean_shape);
+ auto variance = make_shared<op::Parameter>(variance_et, variance_shape);
+
+ try
+ {
+     auto bn = make_shared<op::v5::BatchNormInference>(
+         data_batch, gamma, beta, mean, variance, epsilon);
+ FAIL() << "Zero channel count not detected";
+ }
+ catch (const NodeValidationFailure& error)
+ {
+ EXPECT_HAS_SUBSTRING(error.what(), std::string("Channel count must be at least 1"));
+ }
+ catch (...)
+ {
+ FAIL() << "Deduced type check failed for unexpected reason";
+ }
+}
+
+TEST(type_prop, batch_norm_inference_partial_input_rank_dynamic_some_rank_static_dynamic_ok_v5)
+{
+ PartialShape data_batch_shape{PartialShape::dynamic()};
+ PartialShape gamma_shape{Dimension::dynamic()};
+ PartialShape beta_shape{PartialShape::dynamic()};
+ PartialShape mean_shape{Dimension::dynamic()};
+ PartialShape variance_shape{PartialShape::dynamic()};
+ double epsilon = 0.001;
+ element::Type data_batch_et = element::f32;
+ element::Type gamma_et = element::f32;
+ element::Type beta_et = element::f32;
+ element::Type mean_et = element::f32;
+ element::Type variance_et = element::f32;
+
+ auto data_batch = make_shared<op::Parameter>(data_batch_et, data_batch_shape);
+ auto gamma = make_shared<op::Parameter>(gamma_et, gamma_shape);
+ auto beta = make_shared<op::Parameter>(beta_et, beta_shape);
+ auto mean = make_shared<op::Parameter>(mean_et, mean_shape);
+ auto variance = make_shared<op::Parameter>(variance_et, variance_shape);
+
+ auto bn =
+     make_shared<op::v5::BatchNormInference>(data_batch, gamma, beta, mean, variance, epsilon);
+
+ ASSERT_EQ(bn->get_output_size(), 1);
+ ASSERT_EQ(bn->get_output_element_type(0), data_batch_et);
+ ASSERT_TRUE(bn->get_output_partial_shape(0).rank().is_dynamic());
+}
+
+TEST(type_prop,
+ batch_norm_inference_partial_input_rank_dynamic_some_rank_static_dynamic_wrong_rank_v5)
+{
+ PartialShape data_batch_shape{PartialShape::dynamic()};
+ PartialShape gamma_shape{Dimension::dynamic(), Dimension::dynamic()};
+ PartialShape beta_shape{PartialShape::dynamic()};
+ PartialShape mean_shape{Dimension::dynamic(), Dimension::dynamic()};
+ PartialShape variance_shape{PartialShape::dynamic()};
+ double epsilon = 0.001;
+ element::Type data_batch_et = element::f32;
+ element::Type gamma_et = element::f32;
+ element::Type beta_et = element::f32;
+ element::Type mean_et = element::f32;
+ element::Type variance_et = element::f32;
+
+ auto data_batch = make_shared<op::Parameter>(data_batch_et, data_batch_shape);
+ auto gamma = make_shared<op::Parameter>(gamma_et, gamma_shape);
+ auto beta = make_shared<op::Parameter>(beta_et, beta_shape);
+ auto mean = make_shared<op::Parameter>(mean_et, mean_shape);
+ auto variance = make_shared<op::Parameter>(variance_et, variance_shape);
+
+ try
+ {
+     auto bn = make_shared<op::v5::BatchNormInference>(
+         data_batch, gamma, beta, mean, variance, epsilon);
+ FAIL() << "Wrong gamma/beta/mean/variance shape not detected";
+ }
+ catch (const NodeValidationFailure& error)
+ {
+ EXPECT_HAS_SUBSTRING(
+ error.what(),
+ std::string("Shape for gamma/beta/mean/variance ({?,?}) does not have rank 1"));
+ }
+ catch (...)
+ {
+ FAIL() << "Deduced type check failed for unexpected reason";
+ }
+}
+
+TEST(type_prop,
+ batch_norm_inference_partial_input_rank_dynamic_some_rank_static_dynamic_inconsistent_rank_v5)
+{
+ PartialShape data_batch_shape{PartialShape::dynamic()};
+ PartialShape gamma_shape{3, Dimension::dynamic()};
+ PartialShape beta_shape{PartialShape::dynamic()};
+ PartialShape mean_shape{Dimension::dynamic()};
+ PartialShape variance_shape{PartialShape::dynamic()};
+ double epsilon = 0.001;
+ element::Type data_batch_et = element::f32;
+ element::Type gamma_et = element::f32;
+ element::Type beta_et = element::f32;
+ element::Type mean_et = element::f32;
+ element::Type variance_et = element::f32;
+
+ auto data_batch = make_shared<op::Parameter>(data_batch_et, data_batch_shape);
+ auto gamma = make_shared<op::Parameter>(gamma_et, gamma_shape);
+ auto beta = make_shared<op::Parameter>(beta_et, beta_shape);
+ auto mean = make_shared<op::Parameter>(mean_et, mean_shape);
+ auto variance = make_shared<op::Parameter>(variance_et, variance_shape);
+
+ try
+ {
+     auto bn = make_shared<op::v5::BatchNormInference>(
+         data_batch, gamma, beta, mean, variance, epsilon);
+ FAIL() << "Inconsistent gamma/beta/mean/variance shape not detected";
+ }
+ catch (const NodeValidationFailure& error)
+ {
+ EXPECT_HAS_SUBSTRING(error.what(),
+ std::string("Shapes for gamma/beta/mean/variance do not match"));
+ }
+ catch (...)
+ {
+ FAIL() << "Deduced type check failed for unexpected reason";
+ }
+}
+
+TEST(type_prop,
+ batch_norm_inference_partial_input_rank_dynamic_some_static_inconsistent_channel_count_v5)
+{
+ PartialShape data_batch_shape{PartialShape::dynamic()};
+ PartialShape gamma_shape{3};
+ PartialShape beta_shape{PartialShape::dynamic()};
+ PartialShape mean_shape{4};
+ PartialShape variance_shape{PartialShape::dynamic()};
+ double epsilon = 0.001;
+ element::Type data_batch_et = element::f32;
+ element::Type gamma_et = element::f32;
+ element::Type beta_et = element::f32;
+ element::Type mean_et = element::f32;
+ element::Type variance_et = element::f32;
+
+ auto data_batch = make_shared<op::Parameter>(data_batch_et, data_batch_shape);
+ auto gamma = make_shared<op::Parameter>(gamma_et, gamma_shape);
+ auto beta = make_shared<op::Parameter>(beta_et, beta_shape);
+ auto mean = make_shared<op::Parameter>(mean_et, mean_shape);
+ auto variance = make_shared<op::Parameter>(variance_et, variance_shape);
+
+ try
+ {
+     auto bn = make_shared<op::v5::BatchNormInference>(
+         data_batch, gamma, beta, mean, variance, epsilon);
+ FAIL() << "Inconsistent gamma/beta/mean/variance channel count not detected";
+ }
+ catch (const NodeValidationFailure& error)
+ {
+ EXPECT_HAS_SUBSTRING(error.what(),
+ std::string("Shapes for gamma/beta/mean/variance do not match"));
+ }
+ catch (...)
+ {
+ FAIL() << "Deduced type check failed for unexpected reason";
+ }
+}
+
+TEST(type_prop, batch_norm_inference_partial_input_rank_static_dynamic_some_static_ok_v5)
+{
+ PartialShape data_batch_shape{64, Dimension::dynamic(), Dimension::dynamic(), 224};
+ PartialShape gamma_shape{3};
+ PartialShape beta_shape{PartialShape::dynamic()};
+ PartialShape mean_shape{3};
+ PartialShape variance_shape{PartialShape::dynamic()};
+ double epsilon = 0.001;
+ element::Type data_batch_et = element::f32;
+ element::Type gamma_et = element::f32;
+ element::Type beta_et = element::f32;
+ element::Type mean_et = element::f32;
+ element::Type variance_et = element::f32;
+
+ auto data_batch = make_shared<op::Parameter>(data_batch_et, data_batch_shape);
+ auto gamma = make_shared<op::Parameter>(gamma_et, gamma_shape);
+ auto beta = make_shared<op::Parameter>(beta_et, beta_shape);
+ auto mean = make_shared<op::Parameter>(mean_et, mean_shape);
+ auto variance = make_shared<op::Parameter>(variance_et, variance_shape);
+
+ auto bn =
+     make_shared<op::v5::BatchNormInference>(data_batch, gamma, beta, mean, variance, epsilon);
+
+ ASSERT_EQ(bn->get_output_size(), 1);
+ ASSERT_EQ(bn->get_output_element_type(0), data_batch_et);
+ ASSERT_TRUE(bn->get_output_partial_shape(0).same_scheme(
+ PartialShape{64, 3, Dimension::dynamic(), 224}));
+}
+
+TEST(
+ type_prop,
+ batch_norm_inference_partial_input_rank_static_dynamic_some_static_inconsistent_channel_count_v5)
+{
+ PartialShape data_batch_shape{64, 4, Dimension::dynamic(), 224};
+ PartialShape gamma_shape{3};
+ PartialShape beta_shape{PartialShape::dynamic()};
+ PartialShape mean_shape{3};
+ PartialShape variance_shape{PartialShape::dynamic()};
+ double epsilon = 0.001;
+ element::Type data_batch_et = element::f32;
+ element::Type gamma_et = element::f32;
+ element::Type beta_et = element::f32;
+ element::Type mean_et = element::f32;
+ element::Type variance_et = element::f32;
+
+ auto data_batch = make_shared<op::Parameter>(data_batch_et, data_batch_shape);
+ auto gamma = make_shared<op::Parameter>(gamma_et, gamma_shape);
+ auto beta = make_shared<op::Parameter>(beta_et, beta_shape);
+ auto mean = make_shared<op::Parameter>(mean_et, mean_shape);
+ auto variance = make_shared<op::Parameter>(variance_et, variance_shape);
+
+ try
+ {
+     auto bn = make_shared<op::v5::BatchNormInference>(
+         data_batch, gamma, beta, mean, variance, epsilon);
FAIL() << "Inconsistent input/gamma/beta/mean/variance channel count not detected";
}
catch (const NodeValidationFailure& error)