GroupNormalization core op (#17781)

Tomasz Dołbniak 2023-06-01 08:49:02 +02:00 committed by GitHub
parent 0b6b16c83a
commit 02124aece4
21 changed files with 713 additions and 7 deletions

View File

@@ -38,9 +38,9 @@ The operation is applied per batch, per group of channels. This means that the e
* **1**: ``data`` - The input tensor to be normalized. The type of this tensor is *T*. The tensor's shape is arbitrary but the first two dimensions are interpreted as ``batch`` and ``channels`` respectively. **Required.**
-* **2**: ``scale`` - 1D tensor of type *T* containing the scale values for each group. The expected shape of this tensor is ``[C]`` where ``C`` is the number of channels in the ``data`` tensor. **Required.**
+* **2**: ``scale`` - 1D tensor of type *T* containing the scale values for each channel. The expected shape of this tensor is ``[C]`` where ``C`` is the number of channels in the ``data`` tensor. **Required.**
-* **3**: ``bias`` - 1D tensor of type *T* containing the bias values for each group. The expected shape of this tensor is ``[C]`` where ``C`` is the number of channels in the ``data`` tensor. **Required.**
+* **3**: ``bias`` - 1D tensor of type *T* containing the bias values for each channel. The expected shape of this tensor is ``[C]`` where ``C`` is the number of channels in the ``data`` tensor. **Required.**
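To make the corrected wording concrete: with ``data`` of shape ``[1, 12, 6, 6]`` and ``num_groups = 4`` (the combination used in the type_prop tests further below), the 12 channels are normalized in 4 groups of 3 channels each, but ``scale`` and ``bias`` still hold 12 values, one per channel, applied after the per-group normalization.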
**Outputs**

View File

@@ -20,7 +20,7 @@ from tests.test_transformations.utils.utils import expect_exception
def test_wrap_type_pattern_type():
-    last_opset_number = 11
+    last_opset_number = 12
    for i in range(1, last_opset_number + 1):
-        WrapType(f"opset{i}.Parameter")
+        WrapType(f"opset{i}::Parameter")

View File

@@ -0,0 +1,15 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/op/group_normalization.hpp"
namespace ngraph {
namespace op {
namespace v12 {
using ov::op::v12::GroupNormalization;
} // namespace v12
} // namespace op
} // namespace ngraph

View File

@@ -72,6 +72,7 @@
#include "ngraph/op/grid_sample.hpp"
#include "ngraph/op/grn.hpp"
#include "ngraph/op/group_conv.hpp"
#include "ngraph/op/group_normalization.hpp"
#include "ngraph/op/gru_cell.hpp"
#include "ngraph/op/gru_sequence.hpp"
#include "ngraph/op/hard_sigmoid.hpp"

View File

@@ -54,5 +54,6 @@ const NGRAPH_API OpSet& get_opset8();
const NGRAPH_API OpSet& get_opset9();
const NGRAPH_API OpSet& get_opset10();
const NGRAPH_API OpSet& get_opset11();
+const NGRAPH_API OpSet& get_opset12();
const NGRAPH_API std::map<std::string, std::function<const ngraph::OpSet&()>>& get_available_opsets();
} // namespace ngraph

View File

@@ -0,0 +1,15 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "ngraph/ops.hpp"
namespace ngraph {
namespace opset12 {
#define NGRAPH_OP(a, b) using b::a;
#include "ngraph/opsets/opset12_tbl.hpp"
#undef NGRAPH_OP
} // namespace opset12
} // namespace ngraph
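This header relies on the X-macro technique: every NGRAPH_OP(a, b) row of the opset12 table becomes a using-declaration. A sketch of the net effect for the new op, assuming the GroupNormalization row shown later in the opset12 table:

// NGRAPH_OP(GroupNormalization, ov::op::v12) expands, given the
// definition of NGRAPH_OP above, to a using-declaration inside
// namespace ngraph::opset12:
namespace ngraph {
namespace opset12 {
using ov::op::v12::GroupNormalization;  // legacy ngraph alias for the new op
}  // namespace opset12
}  // namespace ngraph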

View File

@@ -0,0 +1,12 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#ifndef NGRAPH_OP
# warning "NGRAPH_OP not defined"
# define NGRAPH_OP(x, y)
#endif
#define _OPENVINO_OP_REG NGRAPH_OP
#include "openvino/opsets/opset12_tbl.hpp"
#undef _OPENVINO_OP_REG

View File

@@ -0,0 +1,56 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/op/op.hpp"
namespace ov {
namespace op {
namespace v12 {
/// \brief GroupNormalization operation over the input tensor.
///
/// \ingroup ov_ops_cpp_api
class OPENVINO_API GroupNormalization : public Op {
public:
OPENVINO_OP("GroupNormalization", "opset12");
GroupNormalization() = default;
/// \param data The input tensor to be normalized
/// \param scale The tensor containing scale values for each channel
/// \param bias The tensor containing bias values for each channel
/// \param num_groups The number of groups that the channel dimension will be divided into
/// \param epsilon The value that prevents divisions by zero in GroupNormalization formula
GroupNormalization(const Output<Node>& data,
const Output<Node>& scale,
const Output<Node>& bias,
int64_t num_groups,
double epsilon);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
int64_t get_num_groups() const {
return m_num_groups;
}
void set_num_groups(int64_t num_groups) {
m_num_groups = num_groups;
}
double get_epsilon() const {
return m_epsilon;
}
void set_epsilon(double epsilon) {
m_epsilon = epsilon;
}
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
private:
int64_t m_num_groups;
double m_epsilon;
};
} // namespace v12
} // namespace op
} // namespace ov
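A minimal construction sketch against this header (the shapes mirror the type_prop tests below; the standalone helper function itself is an illustrative assumption, not code from this commit):

#include <memory>

#include "openvino/op/group_normalization.hpp"
#include "openvino/op/parameter.hpp"

// Builds a GroupNormalization node: 12 channels normalized in 4 groups,
// per-channel scale/bias vectors of shape [12], and a small epsilon
// guarding the division by the group variance.
std::shared_ptr<ov::op::v12::GroupNormalization> make_group_norm_example() {
    const auto data = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 12, 6, 6});
    const auto scale = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{12});
    const auto bias = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{12});
    return std::make_shared<ov::op::v12::GroupNormalization>(data, scale, bias, 4, 1e-5);
}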

View File

@@ -71,6 +71,7 @@
#include "openvino/op/grid_sample.hpp"
#include "openvino/op/grn.hpp"
#include "openvino/op/group_conv.hpp"
#include "openvino/op/group_normalization.hpp"
#include "openvino/op/gru_cell.hpp"
#include "openvino/op/gru_sequence.hpp"
#include "openvino/op/hard_sigmoid.hpp"

View File

@@ -171,6 +171,11 @@ const OPENVINO_API OpSet& get_opset10();
* @ingroup ov_opset_cpp_api
*/
const OPENVINO_API OpSet& get_opset11();
+/**
+* @brief Returns opset12
+* @ingroup ov_opset_cpp_api
+*/
+const OPENVINO_API OpSet& get_opset12();
/**
* @brief Returns map of available opsets

View File

@@ -0,0 +1,15 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/op/ops.hpp"
namespace ov {
namespace opset12 {
#define _OPENVINO_OP_REG(a, b) using b::a;
#include "openvino/opsets/opset12_tbl.hpp"
#undef _OPENVINO_OP_REG
} // namespace opset12
} // namespace ov

View File

@@ -0,0 +1,209 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#ifndef _OPENVINO_OP_REG
# warning "_OPENVINO_OP_REG not defined"
# define _OPENVINO_OP_REG(x, y)
#endif
_OPENVINO_OP_REG(Abs, ov::op::v0)
_OPENVINO_OP_REG(Acos, ov::op::v0)
_OPENVINO_OP_REG(Add, ov::op::v1)
_OPENVINO_OP_REG(Asin, ov::op::v0)
_OPENVINO_OP_REG(Atan, ov::op::v0)
_OPENVINO_OP_REG(AvgPool, ov::op::v1)
_OPENVINO_OP_REG(BatchNormInference, ov::op::v5)
_OPENVINO_OP_REG(BinaryConvolution, ov::op::v1)
_OPENVINO_OP_REG(Broadcast, ov::op::v3)
_OPENVINO_OP_REG(Bucketize, ov::op::v3)
_OPENVINO_OP_REG(CTCGreedyDecoder, ov::op::v0)
_OPENVINO_OP_REG(Ceiling, ov::op::v0)
_OPENVINO_OP_REG(Clamp, ov::op::v0)
_OPENVINO_OP_REG(Concat, ov::op::v0)
_OPENVINO_OP_REG(Constant, ov::op::v0)
_OPENVINO_OP_REG(Convert, ov::op::v0)
_OPENVINO_OP_REG(ConvertLike, ov::op::v1)
_OPENVINO_OP_REG(Convolution, ov::op::v1)
_OPENVINO_OP_REG(ConvolutionBackpropData, ov::op::v1)
_OPENVINO_OP_REG(Cos, ov::op::v0)
_OPENVINO_OP_REG(Cosh, ov::op::v0)
_OPENVINO_OP_REG(CumSum, ov::op::v0)
_OPENVINO_OP_REG(DeformablePSROIPooling, ov::op::v1)
_OPENVINO_OP_REG(DepthToSpace, ov::op::v0)
_OPENVINO_OP_REG(Divide, ov::op::v1)
_OPENVINO_OP_REG(Elu, ov::op::v0)
_OPENVINO_OP_REG(Erf, ov::op::v0)
_OPENVINO_OP_REG(Equal, ov::op::v1)
_OPENVINO_OP_REG(Exp, ov::op::v0)
_OPENVINO_OP_REG(ExtractImagePatches, ov::op::v3)
_OPENVINO_OP_REG(FakeQuantize, ov::op::v0)
_OPENVINO_OP_REG(Floor, ov::op::v0)
_OPENVINO_OP_REG(FloorMod, ov::op::v1)
_OPENVINO_OP_REG(GatherTree, ov::op::v1)
_OPENVINO_OP_REG(Greater, ov::op::v1)
_OPENVINO_OP_REG(GreaterEqual, ov::op::v1)
_OPENVINO_OP_REG(GridSample, ov::op::v9)
_OPENVINO_OP_REG(GroupConvolution, ov::op::v1)
_OPENVINO_OP_REG(GroupConvolutionBackpropData, ov::op::v1)
_OPENVINO_OP_REG(GRN, ov::op::v0)
_OPENVINO_OP_REG(HardSigmoid, ov::op::v0)
_OPENVINO_OP_REG(Less, ov::op::v1)
_OPENVINO_OP_REG(LessEqual, ov::op::v1)
_OPENVINO_OP_REG(Log, ov::op::v0)
_OPENVINO_OP_REG(LogicalAnd, ov::op::v1)
_OPENVINO_OP_REG(LogicalNot, ov::op::v1)
_OPENVINO_OP_REG(LogicalOr, ov::op::v1)
_OPENVINO_OP_REG(LogicalXor, ov::op::v1)
_OPENVINO_OP_REG(LRN, ov::op::v0)
_OPENVINO_OP_REG(LSTMCell, ov::op::v4)
_OPENVINO_OP_REG(MatMul, ov::op::v0)
_OPENVINO_OP_REG(Maximum, ov::op::v1)
_OPENVINO_OP_REG(Minimum, ov::op::v1)
_OPENVINO_OP_REG(Mod, ov::op::v1)
_OPENVINO_OP_REG(Multiply, ov::op::v1)
_OPENVINO_OP_REG(Negative, ov::op::v0)
_OPENVINO_OP_REG(NormalizeL2, ov::op::v0)
_OPENVINO_OP_REG(NotEqual, ov::op::v1)
_OPENVINO_OP_REG(OneHot, ov::op::v1)
_OPENVINO_OP_REG(PRelu, ov::op::v0)
_OPENVINO_OP_REG(PSROIPooling, ov::op::v0)
_OPENVINO_OP_REG(Pad, ov::op::v1)
_OPENVINO_OP_REG(Parameter, ov::op::v0)
_OPENVINO_OP_REG(Power, ov::op::v1)
_OPENVINO_OP_REG(PriorBoxClustered, ov::op::v0)
_OPENVINO_OP_REG(Proposal, ov::op::v4)
_OPENVINO_OP_REG(Range, ov::op::v4)
_OPENVINO_OP_REG(Relu, ov::op::v0)
_OPENVINO_OP_REG(ReduceMax, ov::op::v1)
_OPENVINO_OP_REG(ReduceLogicalAnd, ov::op::v1)
_OPENVINO_OP_REG(ReduceLogicalOr, ov::op::v1)
_OPENVINO_OP_REG(ReduceMean, ov::op::v1)
_OPENVINO_OP_REG(ReduceMin, ov::op::v1)
_OPENVINO_OP_REG(ReduceProd, ov::op::v1)
_OPENVINO_OP_REG(ReduceSum, ov::op::v1)
_OPENVINO_OP_REG(RegionYolo, ov::op::v0)
_OPENVINO_OP_REG(ReorgYolo, ov::op::v0)
_OPENVINO_OP_REG(Reshape, ov::op::v1)
_OPENVINO_OP_REG(Result, ov::op::v0)
_OPENVINO_OP_REG(ReverseSequence, ov::op::v0)
_OPENVINO_OP_REG(ROIPooling, ov::op::v0)
_OPENVINO_OP_REG(ScatterNDUpdate, ov::op::v3)
_OPENVINO_OP_REG(Select, ov::op::v1)
_OPENVINO_OP_REG(Selu, ov::op::v0)
_OPENVINO_OP_REG(Sign, ov::op::v0)
_OPENVINO_OP_REG(Sigmoid, ov::op::v0)
_OPENVINO_OP_REG(Sin, ov::op::v0)
_OPENVINO_OP_REG(Sinh, ov::op::v0)
_OPENVINO_OP_REG(Sqrt, ov::op::v0)
_OPENVINO_OP_REG(SpaceToDepth, ov::op::v0)
_OPENVINO_OP_REG(Split, ov::op::v1)
_OPENVINO_OP_REG(SquaredDifference, ov::op::v0)
_OPENVINO_OP_REG(Squeeze, ov::op::v0)
_OPENVINO_OP_REG(StridedSlice, ov::op::v1)
_OPENVINO_OP_REG(Subtract, ov::op::v1)
_OPENVINO_OP_REG(Tan, ov::op::v0)
_OPENVINO_OP_REG(Tanh, ov::op::v0)
_OPENVINO_OP_REG(TensorIterator, ov::op::v0)
_OPENVINO_OP_REG(Tile, ov::op::v0)
_OPENVINO_OP_REG(Transpose, ov::op::v1)
_OPENVINO_OP_REG(Unsqueeze, ov::op::v0)
_OPENVINO_OP_REG(VariadicSplit, ov::op::v1)
// New operations added in opset2
_OPENVINO_OP_REG(BatchToSpace, ov::op::v1)
_OPENVINO_OP_REG(SpaceToBatch, ov::op::v1)
// New operations added in opset3
_OPENVINO_OP_REG(EmbeddingBagPackedSum, ov::op::v3)
_OPENVINO_OP_REG(EmbeddingSegmentsSum, ov::op::v3)
_OPENVINO_OP_REG(EmbeddingBagOffsetsSum, ov::op::v3)
_OPENVINO_OP_REG(GRUCell, ov::op::v3)
_OPENVINO_OP_REG(NonZero, ov::op::v3)
_OPENVINO_OP_REG(RNNCell, ov::op::v0)
_OPENVINO_OP_REG(ScatterElementsUpdate, ov::op::v3)
_OPENVINO_OP_REG(ScatterUpdate, ov::op::v3)
_OPENVINO_OP_REG(ShuffleChannels, ov::op::v0)
_OPENVINO_OP_REG(ShapeOf, ov::op::v3)
// New operations added in opset4
_OPENVINO_OP_REG(Acosh, ov::op::v3)
_OPENVINO_OP_REG(Asinh, ov::op::v3)
_OPENVINO_OP_REG(Atanh, ov::op::v3)
_OPENVINO_OP_REG(CTCLoss, ov::op::v4)
_OPENVINO_OP_REG(HSwish, ov::op::v4)
_OPENVINO_OP_REG(Mish, ov::op::v4)
_OPENVINO_OP_REG(ReduceL1, ov::op::v4)
_OPENVINO_OP_REG(ReduceL2, ov::op::v4)
_OPENVINO_OP_REG(SoftPlus, ov::op::v4)
_OPENVINO_OP_REG(Swish, ov::op::v4)
// New operations added in opset5
_OPENVINO_OP_REG(GRUSequence, ov::op::v5)
_OPENVINO_OP_REG(HSigmoid, ov::op::v5)
_OPENVINO_OP_REG(LogSoftmax, ov::op::v5)
_OPENVINO_OP_REG(Loop, ov::op::v5)
_OPENVINO_OP_REG(LSTMSequence, ov::op::v5)
_OPENVINO_OP_REG(RNNSequence, ov::op::v5)
_OPENVINO_OP_REG(Round, ov::op::v5)
// New operations added in opset6
_OPENVINO_OP_REG(CTCGreedyDecoderSeqLen, ov::op::v6)
_OPENVINO_OP_REG(ExperimentalDetectronDetectionOutput, ov::op::v6)
_OPENVINO_OP_REG(ExperimentalDetectronGenerateProposalsSingleImage, ov::op::v6)
_OPENVINO_OP_REG(ExperimentalDetectronPriorGridGenerator, ov::op::v6)
_OPENVINO_OP_REG(ExperimentalDetectronROIFeatureExtractor, ov::op::v6)
_OPENVINO_OP_REG(ExperimentalDetectronTopKROIs, ov::op::v6)
_OPENVINO_OP_REG(GatherElements, ov::op::v6)
_OPENVINO_OP_REG(MVN, ov::op::v6)
_OPENVINO_OP_REG(Assign, ov::op::v6) // new version
_OPENVINO_OP_REG(ReadValue, ov::op::v6) // new version
// New operations added in opset7
_OPENVINO_OP_REG(DFT, ov::op::v7)
_OPENVINO_OP_REG(Einsum, ov::op::v7)
_OPENVINO_OP_REG(Gelu, ov::op::v7)
_OPENVINO_OP_REG(IDFT, ov::op::v7)
_OPENVINO_OP_REG(Roll, ov::op::v7)
// New operations added in opset8
_OPENVINO_OP_REG(Gather, ov::op::v8)
_OPENVINO_OP_REG(GatherND, ov::op::v8)
_OPENVINO_OP_REG(AdaptiveAvgPool, ov::op::v8)
_OPENVINO_OP_REG(AdaptiveMaxPool, ov::op::v8)
_OPENVINO_OP_REG(DeformableConvolution, ov::op::v8)
_OPENVINO_OP_REG(DetectionOutput, ov::op::v8)
_OPENVINO_OP_REG(I420toBGR, ov::op::v8)
_OPENVINO_OP_REG(I420toRGB, ov::op::v8)
_OPENVINO_OP_REG(MatrixNms, ov::op::v8)
_OPENVINO_OP_REG(MaxPool, ov::op::v8)
_OPENVINO_OP_REG(NV12toBGR, ov::op::v8)
_OPENVINO_OP_REG(NV12toRGB, ov::op::v8)
_OPENVINO_OP_REG(RandomUniform, ov::op::v8)
_OPENVINO_OP_REG(Slice, ov::op::v8)
_OPENVINO_OP_REG(Softmax, ov::op::v8)
_OPENVINO_OP_REG(If, ov::op::v8)
_OPENVINO_OP_REG(PriorBox, ov::op::v8)
// New operations added in opset9
_OPENVINO_OP_REG(IRDFT, ov::op::v9)
_OPENVINO_OP_REG(RDFT, ov::op::v9)
_OPENVINO_OP_REG(Eye, ov::op::v9)
_OPENVINO_OP_REG(NonMaxSuppression, ov::op::v9)
_OPENVINO_OP_REG(ROIAlign, ov::op::v9)
_OPENVINO_OP_REG(SoftSign, ov::op::v9)
_OPENVINO_OP_REG(GenerateProposals, ov::op::v9)
_OPENVINO_OP_REG(MulticlassNms, ov::op::v9)
// New operations added in opset10
_OPENVINO_OP_REG(IsFinite, ov::op::v10)
_OPENVINO_OP_REG(IsInf, ov::op::v10)
_OPENVINO_OP_REG(IsNaN, ov::op::v10)
_OPENVINO_OP_REG(Unique, ov::op::v10)
// New operations added in opset11
_OPENVINO_OP_REG(Interpolate, ov::op::v11)
_OPENVINO_OP_REG(TopK, ov::op::v11)
// New operations added in opset12
_OPENVINO_OP_REG(GroupNormalization, ov::op::v12)

View File

@@ -0,0 +1,62 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <array>
#include <openvino/core/validation_util.hpp>
#include <openvino/op/group_normalization.hpp>
namespace ov {
namespace op {
namespace v12 {
template <class TShape>
std::vector<TShape> shape_infer(const GroupNormalization* op, const std::vector<TShape>& input_shapes) {
const auto& inputs_count = input_shapes.size();
NODE_VALIDATION_CHECK(op, (inputs_count == 3));
const auto& data_shape = input_shapes[0];
const auto& data_rank = data_shape.rank();
const auto& scale_shape = input_shapes[1];
const auto& bias_shape = input_shapes[2];
NODE_VALIDATION_CHECK(op, op->get_num_groups() > 0, "The number of groups needs to be a positive integer value");
NODE_VALIDATION_CHECK(op, scale_shape.rank().compatible(1), "The scale input is required to be 1D");
NODE_VALIDATION_CHECK(op, bias_shape.rank().compatible(1), "The bias input is required to be 1D");
if (data_rank.is_static()) {
NODE_VALIDATION_CHECK(op, data_rank.get_length() >= 2, "The input tensor is required to be at least 2D");
const auto& channels_dim = data_shape[1];
NODE_VALIDATION_CHECK(op,
scale_shape.rank().is_dynamic() || channels_dim.compatible(scale_shape[0]),
"The scale input shape needs to match the channel dimension in the data input");
NODE_VALIDATION_CHECK(op,
bias_shape.rank().is_dynamic() || channels_dim.compatible(bias_shape[0]),
"The bias input shape needs to match the channel dimension in the data input");
NODE_VALIDATION_CHECK(op,
channels_dim.is_dynamic() || op->get_num_groups() <= channels_dim.get_length(),
"The number of groups must not exceed the number of channels in the input tensor");
NODE_VALIDATION_CHECK(op,
channels_dim.is_dynamic() || channels_dim.get_length() % op->get_num_groups() == 0,
"The number of channels is required to be evenly divisible by the number of groups");
}
NODE_VALIDATION_CHECK(op,
scale_shape.compatible(bias_shape),
"The shapes of both scale and bias inputs need to match");
return {input_shapes[0]};
}
template <class TShape>
void shape_infer(const GroupNormalization* op,
const std::vector<TShape>& input_shapes,
std::vector<TShape>& output_shapes) {
output_shapes = shape_infer(op, input_shapes);
}
} // namespace v12
} // namespace op
} // namespace ov
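As a usage sketch of this template (the helper function and its shapes are illustrative assumptions, not code from this commit):

#include "group_normalization_shape_inference.hpp"

// Runs the standalone shape inference for a node `op` that was built with
// num_groups = 4: three input shapes go in, one output shape comes out,
// and the output simply mirrors the data shape.
ov::PartialShape infer_example(const ov::op::v12::GroupNormalization* op) {
    const std::vector<ov::PartialShape> input_shapes{{2, 12, 6, 6}, {12}, {12}};
    const auto output_shapes = ov::op::v12::shape_infer(op, input_shapes);
    return output_shapes.at(0);  // {2, 12, 6, 6}: GroupNormalization is shape-preserving
}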

View File

@@ -0,0 +1,49 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/op/group_normalization.hpp"
#include "group_normalization_shape_inference.hpp"
#include "itt.hpp"
#include "openvino/core/attribute_visitor.hpp"
#include "openvino/core/validation_util.hpp"
namespace ov {
op::v12::GroupNormalization::GroupNormalization(const Output<Node>& data,
const Output<Node>& scale,
const Output<Node>& bias,
int64_t num_groups,
double epsilon)
: Op({data, scale, bias}),
m_num_groups{num_groups},
m_epsilon{epsilon} {
constructor_validate_and_infer_types();
}
bool op::v12::GroupNormalization::visit_attributes(AttributeVisitor& visitor) {
OV_OP_SCOPE(v12_GroupNormalization_visit_attributes);
visitor.on_attribute("num_groups", m_num_groups);
visitor.on_attribute("epsilon", m_epsilon);
return true;
}
void op::v12::GroupNormalization::validate_and_infer_types() {
OV_OP_SCOPE(v12_GroupNormalization_validate_and_infer_types);
OPENVINO_SUPPRESS_DEPRECATED_START
const auto output_shapes = shape_infer(this, get_node_input_partial_shapes(*this));
OPENVINO_SUPPRESS_DEPRECATED_END
set_output_type(0, get_input_element_type(0), output_shapes.at(0));
}
std::shared_ptr<Node> op::v12::GroupNormalization::clone_with_new_inputs(const OutputVector& new_args) const {
OV_OP_SCOPE(v12_GroupNormalization_clone_with_new_inputs);
check_new_args_count(this, new_args);
return std::make_shared<GroupNormalization>(new_args.at(0),
new_args.at(1),
new_args.at(2),
m_num_groups,
m_epsilon);
}
} // namespace ov

View File

@@ -61,7 +61,8 @@ const std::map<std::string, std::function<const ngraph::OpSet&()>>& ngraph::get_
_NGRAPH_REG_OPSET(opset8),
_NGRAPH_REG_OPSET(opset9),
_NGRAPH_REG_OPSET(opset10),
-_NGRAPH_REG_OPSET(opset11)};
+_NGRAPH_REG_OPSET(opset11),
+_NGRAPH_REG_OPSET(opset12)};
#undef _NGRAPH_REG_OPSET
return opset_map;
}
@@ -79,7 +80,8 @@ const std::map<std::string, std::function<const ov::OpSet&()>>& ov::get_availabl
_OPENVINO_REG_OPSET(opset8),
_OPENVINO_REG_OPSET(opset9),
_OPENVINO_REG_OPSET(opset10),
-_OPENVINO_REG_OPSET(opset11)};
+_OPENVINO_REG_OPSET(opset11),
+_OPENVINO_REG_OPSET(opset12)};
#undef _OPENVINO_REG_OPSET
return opset_map;
}
@@ -205,6 +207,17 @@ const ov::OpSet& ov::get_opset11() {
return opset;
}
+const ov::OpSet& ov::get_opset12() {
+static OpSet opset;
+static std::once_flag flag;
+std::call_once(flag, [&]() {
+#define _OPENVINO_OP_REG(NAME, NAMESPACE) opset.insert<NAMESPACE::NAME>();
+#include "openvino/opsets/opset12_tbl.hpp"
+#undef _OPENVINO_OP_REG
+});
+return opset;
+}
const ngraph::OpSet& ngraph::get_opset1() {
static OpSet opset(ov::get_opset1());
return opset;
@@ -259,3 +272,8 @@ const ngraph::OpSet& ngraph::get_opset11() {
static OpSet opset(ov::get_opset11());
return opset;
}
+const ngraph::OpSet& ngraph::get_opset12() {
+static OpSet opset(ov::get_opset12());
+return opset;
+}
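For reference, the call_once body above is another X-macro pass: re-including the opset12 table with _OPENVINO_OP_REG redefined turns each table row into an insert call, so the GroupNormalization row becomes, roughly:

// _OPENVINO_OP_REG(GroupNormalization, ov::op::v12) expands to:
opset.insert<ov::op::v12::GroupNormalization>();

// Afterwards the populated opset can be queried, e.g.:
const bool has_gn = ov::get_opset12().contains_type<ov::op::v12::GroupNormalization>();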

View File

@@ -67,6 +67,7 @@ _OPENVINO_OP_REG(Greater, ngraph::op::v1)
_OPENVINO_OP_REG(GreaterEqual, ngraph::op::v1)
_OPENVINO_OP_REG(GroupConvolution, ngraph::op::v1)
_OPENVINO_OP_REG(GroupConvolutionBackpropData, ngraph::op::v1)
+_OPENVINO_OP_REG(GroupNormalization, ngraph::op::v12)
_OPENVINO_OP_REG(HardSigmoid, ngraph::op::v0)
_OPENVINO_OP_REG(Interpolate, ngraph::op::v0)
_OPENVINO_OP_REG(Interpolate, ngraph::op::v4)

View File

@@ -10,6 +10,7 @@
#include "openvino/opsets/opset1.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/opsets/opset11.hpp"
#include "openvino/opsets/opset12.hpp"
#include "openvino/opsets/opset2.hpp"
#include "openvino/opsets/opset3.hpp"
#include "openvino/opsets/opset4.hpp"
@@ -67,7 +68,8 @@ INSTANTIATE_TEST_SUITE_P(opset,
OpsetTestParams{ov::get_opset8, 167},
OpsetTestParams{ov::get_opset9, 173},
OpsetTestParams{ov::get_opset10, 177},
-OpsetTestParams{ov::get_opset11, 177}),
+OpsetTestParams{ov::get_opset11, 177},
+OpsetTestParams{ov::get_opset12, 178}),
OpsetTestNameGenerator{});
class MyOpOld : public ov::op::Op {

View File

@@ -0,0 +1,204 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "common_test_utils/test_assertions.hpp"
#include "gtest/gtest.h"
#include "openvino/openvino.hpp"
#include "openvino/opsets/opset12.hpp"
#include "util/type_prop.hpp"
using namespace ov;
using namespace testing;
TEST(type_prop, group_normalization_basic) {
const auto data = std::make_shared<opset12::Parameter>(element::f32, Shape{1, 12, 6, 6});
const auto scale = std::make_shared<opset12::Parameter>(element::f32, Shape{12});
const auto bias = std::make_shared<opset12::Parameter>(element::f32, Shape{12});
const auto gn = std::make_shared<opset12::GroupNormalization>(data, scale, bias, 4, 0.00001f);
EXPECT_EQ(gn->get_element_type(), element::f32);
EXPECT_EQ(gn->get_shape(), (Shape{1, 12, 6, 6}));
}
TEST(type_prop, group_normalization_labels) {
auto data_shape = PartialShape{1, 12, 6, 6};
auto scale_shape = PartialShape{12};
auto bias_shape = PartialShape{12};
set_shape_labels(data_shape, 43);
set_shape_labels(scale_shape, 100);
set_shape_labels(bias_shape, 200);
const auto data = std::make_shared<opset12::Parameter>(element::f32, data_shape);
const auto scale = std::make_shared<opset12::Parameter>(element::f32, scale_shape);
const auto bias = std::make_shared<opset12::Parameter>(element::f32, bias_shape);
const auto gn = std::make_shared<opset12::GroupNormalization>(data, scale, bias, 4, 0.00001f);
EXPECT_THAT(get_shape_labels(gn->get_output_partial_shape(0)), ElementsAre(43, 44, 45, 46));
}
TEST(type_prop, group_normalization_dynamic_channels) {
const auto data = std::make_shared<opset12::Parameter>(element::f32, PartialShape{1, -1, 6, 6});
const auto scale = std::make_shared<opset12::Parameter>(element::f32, Shape{12});
const auto bias = std::make_shared<opset12::Parameter>(element::f32, Shape{12});
const auto gn = std::make_shared<opset12::GroupNormalization>(data, scale, bias, 2, 0.00001f);
EXPECT_EQ(gn->get_element_type(), element::f32);
EXPECT_EQ(gn->get_output_partial_shape(0), (PartialShape{1, -1, 6, 6}));
}
TEST(type_prop, group_normalization_dynamic_scale) {
const auto data = std::make_shared<opset12::Parameter>(element::f32, PartialShape{1, 4, 6, 6});
const auto scale = std::make_shared<opset12::Parameter>(element::f32, PartialShape{-1});
const auto bias = std::make_shared<opset12::Parameter>(element::f32, PartialShape{4});
const auto gn = std::make_shared<opset12::GroupNormalization>(data, scale, bias, 2, 0.00001f);
EXPECT_EQ(gn->get_element_type(), element::f32);
EXPECT_EQ(gn->get_output_partial_shape(0), (PartialShape{1, 4, 6, 6}));
}
TEST(type_prop, group_normalization_dynamic_bias) {
const auto data = std::make_shared<opset12::Parameter>(element::f32, PartialShape{1, 4, 6, 6});
const auto scale = std::make_shared<opset12::Parameter>(element::f32, PartialShape{4});
const auto bias = std::make_shared<opset12::Parameter>(element::f32, PartialShape{-1});
const auto gn = std::make_shared<opset12::GroupNormalization>(data, scale, bias, 2, 0.00001f);
EXPECT_EQ(gn->get_element_type(), element::f32);
EXPECT_EQ(gn->get_output_partial_shape(0), (PartialShape{1, 4, 6, 6}));
}
TEST(type_prop, group_normalization_dynamic_rank) {
const auto data = std::make_shared<opset12::Parameter>(element::f16, PartialShape::dynamic());
const auto scale = std::make_shared<opset12::Parameter>(element::f16, PartialShape{6});
const auto bias = std::make_shared<opset12::Parameter>(element::f16, PartialShape{6});
const auto gn = std::make_shared<opset12::GroupNormalization>(data, scale, bias, 3, 0.00001f);
EXPECT_EQ(gn->get_element_type(), element::f16);
EXPECT_EQ(gn->get_output_partial_shape(0), (PartialShape::dynamic()));
}
TEST(type_prop, group_normalization_dynamic_everything) {
const auto data = std::make_shared<opset12::Parameter>(element::f16, PartialShape{3, -1, 10, 10});
const auto scale = std::make_shared<opset12::Parameter>(element::f16, PartialShape{-1});
const auto bias = std::make_shared<opset12::Parameter>(element::f16, PartialShape{-1});
const auto gn = std::make_shared<opset12::GroupNormalization>(data, scale, bias, 7, 0.00001f);
EXPECT_EQ(gn->get_element_type(), element::f16);
EXPECT_EQ(gn->get_output_partial_shape(0), (PartialShape{3, -1, 10, 10}));
}
TEST(type_prop, group_normalization_dynamic_ranks) {
const auto data = std::make_shared<opset12::Parameter>(element::f16, PartialShape::dynamic());
const auto scale = std::make_shared<opset12::Parameter>(element::f16, PartialShape::dynamic());
const auto bias = std::make_shared<opset12::Parameter>(element::f16, PartialShape::dynamic());
const auto gn = std::make_shared<opset12::GroupNormalization>(data, scale, bias, 12, 0.00001f);
EXPECT_EQ(gn->get_element_type(), element::f16);
EXPECT_EQ(gn->get_output_partial_shape(0), (PartialShape::dynamic()));
}
TEST(type_prop, group_normalization_dynamic_intervals) {
auto data_shape = PartialShape{2, Dimension{10, 20}, 6, 6};
auto scale_shape = PartialShape{Dimension{10, 20}};
auto bias_shape = PartialShape{Dimension{10, 20}};
set_shape_labels(data_shape, 42);
set_shape_labels(scale_shape, 21);
set_shape_labels(bias_shape, 37);
const auto data = std::make_shared<opset12::Parameter>(element::f32, data_shape);
const auto scale = std::make_shared<opset12::Parameter>(element::f32, scale_shape);
const auto bias = std::make_shared<opset12::Parameter>(element::f32, bias_shape);
const auto gn = std::make_shared<opset12::GroupNormalization>(data, scale, bias, 2, 0.00001f);
EXPECT_EQ(gn->get_element_type(), element::f32);
EXPECT_EQ(gn->get_output_partial_shape(0), (PartialShape{2, Dimension{10, 20}, 6, 6}));
EXPECT_THAT(get_shape_labels(gn->get_output_partial_shape(0)), ElementsAre(42, 43, 44, 45));
}
TEST(type_prop, group_normalization_incorrect_scale_shape) {
const auto data = std::make_shared<opset12::Parameter>(element::f32, Shape{1, 12, 6, 6});
const auto scale = std::make_shared<opset12::Parameter>(element::f32, Shape{13});
const auto bias = std::make_shared<opset12::Parameter>(element::f32, Shape{12});
OV_EXPECT_THROW(std::ignore = std::make_shared<opset12::GroupNormalization>(data, scale, bias, 4, 0.00001f),
NodeValidationFailure,
HasSubstr("The scale input shape needs to match the channel dimension in the data input"));
}
TEST(type_prop, group_normalization_incorrect_bias_shape) {
const auto data = std::make_shared<opset12::Parameter>(element::f32, Shape{1, 12, 6, 6});
const auto scale = std::make_shared<opset12::Parameter>(element::f32, Shape{12});
const auto bias = std::make_shared<opset12::Parameter>(element::f32, Shape{14});
OV_EXPECT_THROW(std::ignore = std::make_shared<opset12::GroupNormalization>(data, scale, bias, 4, 0.00001f),
NodeValidationFailure,
HasSubstr("The bias input shape needs to match the channel dimension in the data input"));
}
TEST(type_prop, group_normalization_incompatible_scale_and_bias) {
const auto data = std::make_shared<opset12::Parameter>(element::f32, PartialShape{1, -1, 6, 6});
const auto scale = std::make_shared<opset12::Parameter>(element::f32, Shape{2});
const auto bias = std::make_shared<opset12::Parameter>(element::f32, Shape{4});
OV_EXPECT_THROW(std::ignore = std::make_shared<opset12::GroupNormalization>(data, scale, bias, 2, 0.00001f),
NodeValidationFailure,
HasSubstr("The shapes of both scale and bias inputs need to match"));
}
TEST(type_prop, group_normalization_incorrect_scale_rank) {
const auto data = std::make_shared<opset12::Parameter>(element::f32, Shape{1, 12, 6, 6});
const auto scale = std::make_shared<opset12::Parameter>(element::f32, Shape{12, 12});
const auto bias = std::make_shared<opset12::Parameter>(element::f32, Shape{12});
OV_EXPECT_THROW(std::ignore = std::make_shared<opset12::GroupNormalization>(data, scale, bias, 4, 0.00001f),
NodeValidationFailure,
HasSubstr("The scale input is required to be 1D"));
}
TEST(type_prop, group_normalization_incorrect_bias_rank) {
const auto data = std::make_shared<opset12::Parameter>(element::f32, Shape{1, 12, 6, 6});
const auto scale = std::make_shared<opset12::Parameter>(element::f32, Shape{12});
const auto bias = std::make_shared<opset12::Parameter>(element::f32, Shape{3, 14});
OV_EXPECT_THROW(std::ignore = std::make_shared<opset12::GroupNormalization>(data, scale, bias, 4, 0.00001f),
NodeValidationFailure,
HasSubstr("The bias input is required to be 1D"));
}
TEST(type_prop, group_normalization_incompatible_channels_and_groups) {
const auto data = std::make_shared<opset12::Parameter>(element::f32, PartialShape{1, 10, 6, 6});
const auto scale = std::make_shared<opset12::Parameter>(element::f32, Shape{10});
const auto bias = std::make_shared<opset12::Parameter>(element::f32, Shape{10});
OV_EXPECT_THROW(std::ignore = std::make_shared<opset12::GroupNormalization>(data, scale, bias, 3, 0.00001f),
NodeValidationFailure,
HasSubstr("The number of channels is required to be evenly divisible by the number of groups"));
}
TEST(type_prop, group_normalization_incorrect_data_rank) {
const auto data = std::make_shared<opset12::Parameter>(element::f32, PartialShape{10});
const auto scale = std::make_shared<opset12::Parameter>(element::f32, Shape{1});
const auto bias = std::make_shared<opset12::Parameter>(element::f32, Shape{1});
OV_EXPECT_THROW(std::ignore = std::make_shared<opset12::GroupNormalization>(data, scale, bias, 2, 0.00001f),
NodeValidationFailure,
HasSubstr("The input tensor is required to be at least 2D"));
}
TEST(type_prop, group_normalization_negative_num_groups) {
const auto data = std::make_shared<opset12::Parameter>(element::f32, PartialShape{1, 10});
const auto scale = std::make_shared<opset12::Parameter>(element::f32, Shape{10});
const auto bias = std::make_shared<opset12::Parameter>(element::f32, Shape{10});
OV_EXPECT_THROW(std::ignore = std::make_shared<opset12::GroupNormalization>(data, scale, bias, -3, 0.00001f),
NodeValidationFailure,
HasSubstr("The number of groups needs to be a positive integer value"));
}
TEST(type_prop, group_normalization_too_many_groups) {
const auto data = std::make_shared<opset12::Parameter>(element::f32, PartialShape{1, 10});
const auto scale = std::make_shared<opset12::Parameter>(element::f32, Shape{10});
const auto bias = std::make_shared<opset12::Parameter>(element::f32, Shape{10});
OV_EXPECT_THROW(std::ignore = std::make_shared<opset12::GroupNormalization>(data, scale, bias, 11, 0.00001f),
NodeValidationFailure,
HasSubstr("The number of groups must not exceed the number of channels in the input tensor"));
}

View File

@@ -0,0 +1,30 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "openvino/opsets/opset12.hpp"
#include "util/visitor.hpp"
using namespace std;
using namespace ov;
using ngraph::test::NodeBuilder;
TEST(attributes, group_normalization) {
NodeBuilder::get_ops().register_factory<opset12::GroupNormalization>();
const auto data = make_shared<opset12::Parameter>(element::f32, Shape{1, 3, 10, 10});
const auto scale = make_shared<opset12::Parameter>(element::f32, Shape{3});
const auto bias = make_shared<opset12::Parameter>(element::f32, Shape{3});
const auto op = make_shared<opset12::GroupNormalization>(data, scale, bias, 3, 0.00001f);
NodeBuilder builder(op);
auto g_gn = ov::as_type_ptr<opset12::GroupNormalization>(builder.create());
constexpr auto expected_attr_count = 2;
EXPECT_EQ(builder.get_value_map_size(), expected_attr_count);
EXPECT_EQ(op->get_num_groups(), g_gn->get_num_groups());
EXPECT_NEAR(op->get_epsilon(), g_gn->get_epsilon(), 0.00001f);
}

View File

@@ -25,7 +25,7 @@ inline const ov::OpSet& get_opset_by_name(const std::string& opset_name) {
if (opsets.find(opset_name) != opsets.end())
return opsets.at(opset_name)();
if (opset_name.empty() || opset_name == "latest") {
-        return ov::get_opset12();
+        return ov::get_opset12();
} else {
FRONT_END_GENERAL_CHECK(false, "Unsupported opset name: ", opset_name);
}
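A short usage sketch of the helper above after this change (the call site is an illustrative assumption):

// "latest" now resolves to opset12, so by-name lookups see the new op:
const ov::OpSet& latest = get_opset_by_name("latest");
const bool has_gn = latest.contains_type<ov::op::v12::GroupNormalization>();  // expected: true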

View File

@@ -90,6 +90,15 @@ std::shared_ptr<ov::Model> generate(const std::shared_ptr<ov::op::v5::BatchNormI
"BatchNormInterferenceGraph");
}
+std::shared_ptr<ov::Model> generate(const std::shared_ptr<ov::op::v12::GroupNormalization>& node) {
+const auto data = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{3, 14, 5, 5});
+const auto scale = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{14});
+const auto bias = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{14});
+const auto gn = std::make_shared<ov::op::v12::GroupNormalization>(data, scale, bias, 7, 0.00001f);
+const ov::ResultVector results{std::make_shared<ov::op::v0::Result>(gn)};
+return std::make_shared<ov::Model>(results, ov::ParameterVector{data, scale, bias}, "GroupNormalizationGraph");
+}
std::shared_ptr<ov::Model> generate(const std::shared_ptr<ov::op::v1::BatchToSpace> &node) {
const auto data = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{4, 1, 1, 3});
const auto block_shape = ov::op::v0::Constant::create(ov::element::i64, {4}, {1, 1, 1, 2});
@@ -1884,6 +1893,7 @@ OpGenerator getOpGeneratorMap() {
#include "openvino/opsets/opset9_tbl.hpp"
#include "openvino/opsets/opset10_tbl.hpp"
#include "openvino/opsets/opset11_tbl.hpp"
#include "openvino/opsets/opset12_tbl.hpp"
#undef _OPENVINO_OP_REG
};
return opGeneratorMap;