Add MVN-6 support to ngraph (#3464)
* Add MVN-6 to ngraph * Apply review feedback * Fix max opset number * Fix code style * Fix shape test * Disable reader test * Apply review feedback and remove reader test * Fix code style * Fix build * Apply review feedback * Fix build problem * Fix code style * Fix build
This commit is contained in:
parent
a569a0b529
commit
ab974e4f2e
@ -22,6 +22,7 @@
|
|||||||
#include <ngraph/opsets/opset2.hpp>
|
#include <ngraph/opsets/opset2.hpp>
|
||||||
#include <ngraph/opsets/opset3.hpp>
|
#include <ngraph/opsets/opset3.hpp>
|
||||||
#include <ngraph/opsets/opset5.hpp>
|
#include <ngraph/opsets/opset5.hpp>
|
||||||
|
#include <ngraph/opsets/opset6.hpp>
|
||||||
#include <ngraph/variant.hpp>
|
#include <ngraph/variant.hpp>
|
||||||
|
|
||||||
#include <cpp/ie_cnn_network.h>
|
#include <cpp/ie_cnn_network.h>
|
||||||
@ -70,6 +71,7 @@ V10Parser::V10Parser(const std::vector<IExtensionPtr>& exts) : _exts(exts) {
|
|||||||
opsets["opset3"] = ngraph::get_opset3();
|
opsets["opset3"] = ngraph::get_opset3();
|
||||||
opsets["opset4"] = ngraph::get_opset4();
|
opsets["opset4"] = ngraph::get_opset4();
|
||||||
opsets["opset5"] = ngraph::get_opset5();
|
opsets["opset5"] = ngraph::get_opset5();
|
||||||
|
opsets["opset6"] = ngraph::get_opset6();
|
||||||
|
|
||||||
// Load custom opsets
|
// Load custom opsets
|
||||||
for (const auto& ext : exts) {
|
for (const auto& ext : exts) {
|
||||||
@ -427,7 +429,7 @@ std::shared_ptr<ngraph::Node> V10Parser::createNode(const std::vector<ngraph::Ou
|
|||||||
|
|
||||||
// Check that operation in default opsets
|
// Check that operation in default opsets
|
||||||
auto isDefaultOpSet = [](const std::string& version) -> bool {
|
auto isDefaultOpSet = [](const std::string& version) -> bool {
|
||||||
for (size_t i = 1; i <= 5; i++) {
|
for (size_t i = 1; i <= 6; i++) {
|
||||||
std::string opset_name = "opset" + std::to_string(i);
|
std::string opset_name = "opset" + std::to_string(i);
|
||||||
if (version == opset_name)
|
if (version == opset_name)
|
||||||
return true;
|
return true;
|
||||||
|
@ -20,12 +20,12 @@
|
|||||||
#include "ngraph/op/op.hpp"
|
#include "ngraph/op/op.hpp"
|
||||||
#include "ngraph/op/util/fused_op.hpp"
|
#include "ngraph/op/util/fused_op.hpp"
|
||||||
|
|
||||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph
|
||||||
{
|
{
|
||||||
namespace op
|
namespace op
|
||||||
{
|
{
|
||||||
|
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||||
|
|
||||||
namespace v0
|
namespace v0
|
||||||
{
|
{
|
||||||
/// \brief Operator performing Mean Variance Normalization
|
/// \brief Operator performing Mean Variance Normalization
|
||||||
@ -87,7 +87,75 @@ namespace ngraph
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
using v0::MVN;
|
using v0::MVN;
|
||||||
} // namespace op
|
|
||||||
} // namespace ngraph
|
|
||||||
|
|
||||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||||
|
|
||||||
|
/// \brief Specifies how eps is applied in MVN
|
||||||
|
enum class MVNEpsMode
|
||||||
|
{
|
||||||
|
// Apply eps inside sqrt
|
||||||
|
INSIDE_SQRT,
|
||||||
|
// Apply eps outside sqrt
|
||||||
|
OUTSIDE_SQRT
|
||||||
|
};
|
||||||
|
|
||||||
|
NGRAPH_API
|
||||||
|
std::ostream& operator<<(std::ostream& s, const MVNEpsMode& type);
|
||||||
|
|
||||||
|
namespace v6
|
||||||
|
{
|
||||||
|
/// \brief Operator performing Mean Variance Normalization
|
||||||
|
///
|
||||||
|
class NGRAPH_API MVN : public ngraph::op::Op
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
|
||||||
|
MVN() = default;
|
||||||
|
/// \brief Constructs an MVN operation.
|
||||||
|
///
|
||||||
|
/// \param data Input tensor with data
|
||||||
|
/// \param reduction_axes A list of axes, along which to reduce.
|
||||||
|
/// \param normalize_variance flag that denotes whether to perform variance
|
||||||
|
/// normalization.
|
||||||
|
/// \param eps the number to be added to the variance to avoid division by zero when
|
||||||
|
/// normalizing the value
|
||||||
|
/// \param eps_mode the mode of applying epsilon
|
||||||
|
///
|
||||||
|
MVN(const Output<Node>& data,
|
||||||
|
const Output<Node>& reduction_axes,
|
||||||
|
bool normalize_variance,
|
||||||
|
float eps,
|
||||||
|
MVNEpsMode eps_mode);
|
||||||
|
|
||||||
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
|
void validate_and_infer_types() override;
|
||||||
|
|
||||||
|
std::shared_ptr<Node>
|
||||||
|
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
|
|
||||||
|
float get_eps() const { return m_eps; }
|
||||||
|
bool get_normalize_variance() const { return m_normalize_variance; }
|
||||||
|
MVNEpsMode get_eps_mode() const { return m_eps_mode; }
|
||||||
|
private:
|
||||||
|
bool m_normalize_variance = true;
|
||||||
|
float m_eps = (float)1e-6;
|
||||||
|
MVNEpsMode m_eps_mode = MVNEpsMode::INSIDE_SQRT;
|
||||||
|
};
|
||||||
|
} // namespace v6
|
||||||
|
} // namespace op
|
||||||
|
|
||||||
|
template <>
|
||||||
|
class NGRAPH_API AttributeAdapter<op::MVNEpsMode>
|
||||||
|
: public EnumAttributeAdapterBase<op::MVNEpsMode>
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
AttributeAdapter(op::MVNEpsMode& value)
|
||||||
|
: EnumAttributeAdapterBase<op::MVNEpsMode>(value)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::MVNEpsMode>", 0};
|
||||||
|
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
||||||
|
};
|
||||||
|
} // namespace ngraph
|
||||||
|
@ -133,4 +133,5 @@ namespace ngraph
|
|||||||
const NGRAPH_API OpSet& get_opset3();
|
const NGRAPH_API OpSet& get_opset3();
|
||||||
const NGRAPH_API OpSet& get_opset4();
|
const NGRAPH_API OpSet& get_opset4();
|
||||||
const NGRAPH_API OpSet& get_opset5();
|
const NGRAPH_API OpSet& get_opset5();
|
||||||
|
const NGRAPH_API OpSet& get_opset6();
|
||||||
}
|
}
|
||||||
|
29
ngraph/core/include/ngraph/opsets/opset6.hpp
Normal file
29
ngraph/core/include/ngraph/opsets/opset6.hpp
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
//*****************************************************************************
|
||||||
|
// Copyright 2017-2020 Intel Corporation
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
//*****************************************************************************
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "ngraph/ops.hpp"
|
||||||
|
|
||||||
|
namespace ngraph
|
||||||
|
{
|
||||||
|
namespace opset6
|
||||||
|
{
|
||||||
|
#define NGRAPH_OP(a, b) using b::a;
|
||||||
|
#include "ngraph/opsets/opset6_tbl.hpp"
|
||||||
|
#undef NGRAPH_OP
|
||||||
|
} // namespace opset6
|
||||||
|
} // namespace ngraph
|
176
ngraph/core/include/ngraph/opsets/opset6_tbl.hpp
Normal file
176
ngraph/core/include/ngraph/opsets/opset6_tbl.hpp
Normal file
@ -0,0 +1,176 @@
|
|||||||
|
//*****************************************************************************
|
||||||
|
// Copyright 2017-2020 Intel Corporation
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
//*****************************************************************************
|
||||||
|
|
||||||
|
#ifndef NGRAPH_OP
|
||||||
|
#warning "NGRAPH_OP not defined"
|
||||||
|
#define NGRAPH_OP(x, y)
|
||||||
|
#endif
|
||||||
|
|
||||||
|
NGRAPH_OP(Abs, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(Acos, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(Add, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(Asin, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(Atan, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(AvgPool, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(BatchNormInference, ngraph::op::v5)
|
||||||
|
NGRAPH_OP(BinaryConvolution, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(Broadcast, ngraph::op::v3)
|
||||||
|
NGRAPH_OP(Bucketize, ngraph::op::v3)
|
||||||
|
NGRAPH_OP(CTCGreedyDecoder, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(Ceiling, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(Clamp, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(Concat, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(Constant, ngraph::op)
|
||||||
|
NGRAPH_OP(Convert, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(ConvertLike, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(Convolution, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(ConvolutionBackpropData, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(Cos, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(Cosh, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(CumSum, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(DeformableConvolution, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(DeformablePSROIPooling, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(DepthToSpace, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(DetectionOutput, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(Divide, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(Elu, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(Erf, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(Equal, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(Exp, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(ExtractImagePatches, ngraph::op::v3)
|
||||||
|
NGRAPH_OP(FakeQuantize, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(Floor, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(FloorMod, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(Gather, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(GatherTree, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(Greater, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(GreaterEqual, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(GroupConvolution, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(GroupConvolutionBackpropData, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(GRN, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(HardSigmoid, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(Less, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(LessEqual, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(Log, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(LogicalAnd, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(LogicalNot, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(LogicalOr, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(LogicalXor, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(LRN, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(LSTMCell, ngraph::op::v4)
|
||||||
|
NGRAPH_OP(MatMul, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(MaxPool, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(Maximum, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(Minimum, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(Mod, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(Multiply, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(Negative, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(NormalizeL2, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(NotEqual, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(OneHot, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(PRelu, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(PSROIPooling, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(Pad, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(Parameter, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(Power, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(PriorBox, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(PriorBoxClustered, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(Proposal, ngraph::op::v4)
|
||||||
|
NGRAPH_OP(Range, ngraph::op::v4)
|
||||||
|
NGRAPH_OP(Relu, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(ReduceMax, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(ReduceLogicalAnd, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(ReduceLogicalOr, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(ReduceMean, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(ReduceMin, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(ReduceProd, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(ReduceSum, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(RegionYolo, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(ReorgYolo, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(Reshape, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(Result, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(ReverseSequence, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(ROIPooling, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(ScatterNDUpdate, ngraph::op::v3)
|
||||||
|
NGRAPH_OP(Select, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(Selu, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(Sign, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(Sigmoid, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(Sin, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(Sinh, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(Softmax, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(Sqrt, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(SpaceToDepth, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(Split, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(SquaredDifference, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(Squeeze, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(StridedSlice, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(Subtract, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(Tan, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(Tanh, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(TensorIterator, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(Tile, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(Transpose, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(Unsqueeze, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(VariadicSplit, ngraph::op::v1)
|
||||||
|
|
||||||
|
// New operations added in opset2
|
||||||
|
NGRAPH_OP(Gelu, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(BatchToSpace, ngraph::op::v1)
|
||||||
|
NGRAPH_OP(SpaceToBatch, ngraph::op::v1)
|
||||||
|
|
||||||
|
// New operations added in opset3
|
||||||
|
NGRAPH_OP(EmbeddingBagPackedSum, ngraph::op::v3)
|
||||||
|
NGRAPH_OP(EmbeddingSegmentsSum, ngraph::op::v3)
|
||||||
|
NGRAPH_OP(EmbeddingBagOffsetsSum, ngraph::op::v3)
|
||||||
|
NGRAPH_OP(GRUCell, ngraph::op::v3)
|
||||||
|
NGRAPH_OP(NonZero, ngraph::op::v3)
|
||||||
|
NGRAPH_OP(RNNCell, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(ROIAlign, ngraph::op::v3)
|
||||||
|
NGRAPH_OP(ScatterElementsUpdate, ngraph::op::v3)
|
||||||
|
NGRAPH_OP(ScatterUpdate, ngraph::op::v3)
|
||||||
|
NGRAPH_OP(ShuffleChannels, ngraph::op::v0)
|
||||||
|
NGRAPH_OP(ShapeOf, ngraph::op::v3)
|
||||||
|
NGRAPH_OP(Assign, ngraph::op::v3)
|
||||||
|
NGRAPH_OP(ReadValue, ngraph::op::v3)
|
||||||
|
NGRAPH_OP(TopK, ngraph::op::v3)
|
||||||
|
|
||||||
|
// New operations added in opset4
|
||||||
|
NGRAPH_OP(Acosh, ngraph::op::v3)
|
||||||
|
NGRAPH_OP(Asinh, ngraph::op::v3)
|
||||||
|
NGRAPH_OP(Atanh, ngraph::op::v3)
|
||||||
|
NGRAPH_OP(CTCLoss, ngraph::op::v4)
|
||||||
|
NGRAPH_OP(HSwish, ngraph::op::v4)
|
||||||
|
NGRAPH_OP(Interpolate, ngraph::op::v4)
|
||||||
|
NGRAPH_OP(Mish, ngraph::op::v4)
|
||||||
|
NGRAPH_OP(ReduceL1, ngraph::op::v4)
|
||||||
|
NGRAPH_OP(ReduceL2, ngraph::op::v4)
|
||||||
|
NGRAPH_OP(SoftPlus, ngraph::op::v4)
|
||||||
|
NGRAPH_OP(Swish, ngraph::op::v4)
|
||||||
|
|
||||||
|
// New operations added in opset5
|
||||||
|
NGRAPH_OP(GatherND, ngraph::op::v5)
|
||||||
|
NGRAPH_OP(GRUSequence, ngraph::op::v5)
|
||||||
|
NGRAPH_OP(HSigmoid, ngraph::op::v5)
|
||||||
|
NGRAPH_OP(LogSoftmax, ngraph::op::v5)
|
||||||
|
NGRAPH_OP(Loop, ngraph::op::v5)
|
||||||
|
NGRAPH_OP(LSTMSequence, ngraph::op::v5)
|
||||||
|
NGRAPH_OP(NonMaxSuppression, ngraph::op::v5)
|
||||||
|
NGRAPH_OP(RNNSequence, ngraph::op::v5)
|
||||||
|
NGRAPH_OP(Round, ngraph::op::v5)
|
||||||
|
|
||||||
|
// New operations added in opset6
|
||||||
|
NGRAPH_OP(MVN, ngraph::op::v6)
|
@ -28,6 +28,8 @@
|
|||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace ngraph;
|
using namespace ngraph;
|
||||||
|
|
||||||
|
// ------------------------------ V0 ------------------------------
|
||||||
|
|
||||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||||
|
|
||||||
NGRAPH_RTTI_DEFINITION(op::v0::MVN, "MVN", 0);
|
NGRAPH_RTTI_DEFINITION(op::v0::MVN, "MVN", 0);
|
||||||
@ -119,3 +121,80 @@ bool op::MVN::visit_attributes(AttributeVisitor& visitor)
|
|||||||
visitor.on_attribute("reduction_axes", m_reduction_axes);
|
visitor.on_attribute("reduction_axes", m_reduction_axes);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ------------------------------ V6 ------------------------------
|
||||||
|
|
||||||
|
namespace ngraph
|
||||||
|
{
|
||||||
|
template <>
|
||||||
|
NGRAPH_API EnumNames<op::MVNEpsMode>& EnumNames<op::MVNEpsMode>::get()
|
||||||
|
{
|
||||||
|
static auto enum_names =
|
||||||
|
EnumNames<op::MVNEpsMode>("op::MVNEpsMode",
|
||||||
|
{{"OUTSIDE_SQRT", op::MVNEpsMode::OUTSIDE_SQRT},
|
||||||
|
{"INSIDE_SQRT", op::MVNEpsMode::INSIDE_SQRT}});
|
||||||
|
return enum_names;
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr DiscreteTypeInfo AttributeAdapter<op::MVNEpsMode>::type_info;
|
||||||
|
|
||||||
|
std::ostream& op::operator<<(std::ostream& s, const op::MVNEpsMode& type)
|
||||||
|
{
|
||||||
|
return s << as_string(type);
|
||||||
|
}
|
||||||
|
} // namespace ngraph
|
||||||
|
|
||||||
|
NGRAPH_RTTI_DEFINITION(op::v6::MVN, "MVN", 6);
|
||||||
|
|
||||||
|
op::v6::MVN::MVN(const Output<Node>& data,
|
||||||
|
const Output<Node>& reduction_axes,
|
||||||
|
bool normalize_variance,
|
||||||
|
float eps,
|
||||||
|
MVNEpsMode eps_mode)
|
||||||
|
: Op({data, reduction_axes})
|
||||||
|
, m_eps{eps}
|
||||||
|
, m_normalize_variance{normalize_variance}
|
||||||
|
, m_eps_mode{eps_mode}
|
||||||
|
{
|
||||||
|
constructor_validate_and_infer_types();
|
||||||
|
}
|
||||||
|
|
||||||
|
void op::v6::MVN::validate_and_infer_types()
|
||||||
|
{
|
||||||
|
const auto data = get_input_partial_shape(0);
|
||||||
|
const auto axes = get_input_partial_shape(1);
|
||||||
|
|
||||||
|
if (axes.is_static())
|
||||||
|
{
|
||||||
|
NODE_VALIDATION_CHECK(this,
|
||||||
|
is_vector(axes.to_shape()),
|
||||||
|
"Expected 1D tensor for the 'axes' input. Got: ",
|
||||||
|
axes);
|
||||||
|
|
||||||
|
NODE_VALIDATION_CHECK(
|
||||||
|
this,
|
||||||
|
data.rank().is_dynamic() || data.rank().get_length() >= axes.get_shape()[0],
|
||||||
|
"Expected rank for the 'data' input to be higher than axes shape. Got: ",
|
||||||
|
data);
|
||||||
|
}
|
||||||
|
|
||||||
|
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
shared_ptr<Node> op::v6::MVN::clone_with_new_inputs(const OutputVector& new_args) const
|
||||||
|
{
|
||||||
|
NODE_VALIDATION_CHECK(this,
|
||||||
|
new_args.size() == 2,
|
||||||
|
"Expected 2 element in new_args for the MVN op but got ",
|
||||||
|
new_args.size());
|
||||||
|
return make_shared<op::v6::MVN>(
|
||||||
|
new_args.at(0), new_args.at(1), m_normalize_variance, m_eps, m_eps_mode);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool op::v6::MVN::visit_attributes(AttributeVisitor& visitor)
|
||||||
|
{
|
||||||
|
visitor.on_attribute("eps", m_eps);
|
||||||
|
visitor.on_attribute("normalize_variance", m_normalize_variance);
|
||||||
|
visitor.on_attribute("eps_mode", m_eps_mode);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
@ -128,6 +128,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
return s << as_string(type);
|
return s << as_string(type);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
NGRAPH_API EnumNames<op::TopKSortType>& EnumNames<op::TopKSortType>::get()
|
NGRAPH_API EnumNames<op::TopKSortType>& EnumNames<op::TopKSortType>::get()
|
||||||
{
|
{
|
||||||
|
@ -138,3 +138,22 @@ const ngraph::OpSet& ngraph::get_opset5()
|
|||||||
}
|
}
|
||||||
return opset;
|
return opset;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const ngraph::OpSet& ngraph::get_opset6()
|
||||||
|
{
|
||||||
|
static std::mutex init_mutex;
|
||||||
|
static bool opset_is_initialized = false;
|
||||||
|
static OpSet opset;
|
||||||
|
if (!opset_is_initialized)
|
||||||
|
{
|
||||||
|
std::lock_guard<std::mutex> guard(init_mutex);
|
||||||
|
if (!opset_is_initialized)
|
||||||
|
{
|
||||||
|
#define NGRAPH_OP(NAME, NAMESPACE) opset.insert<NAMESPACE::NAME>();
|
||||||
|
#include "ngraph/opsets/opset6_tbl.hpp"
|
||||||
|
#undef NGRAPH_OP
|
||||||
|
opset_is_initialized = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return opset;
|
||||||
|
}
|
||||||
|
@ -21,6 +21,8 @@
|
|||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace ngraph;
|
using namespace ngraph;
|
||||||
|
|
||||||
|
// ------------------------------ V0 ------------------------------
|
||||||
|
|
||||||
TEST(type_prop, mvn)
|
TEST(type_prop, mvn)
|
||||||
{
|
{
|
||||||
auto data = make_shared<op::Parameter>(element::f32, Shape{1, 3, 6});
|
auto data = make_shared<op::Parameter>(element::f32, Shape{1, 3, 6});
|
||||||
@ -47,3 +49,35 @@ TEST(type_prop, mvn_partial)
|
|||||||
EXPECT_EQ(mvn_partial->get_reduction_axes(), AxisSet{});
|
EXPECT_EQ(mvn_partial->get_reduction_axes(), AxisSet{});
|
||||||
ASSERT_TRUE(mvn_partial->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
|
ASSERT_TRUE(mvn_partial->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ------------------------------ V6 ------------------------------
|
||||||
|
|
||||||
|
TEST(type_prop, mvn_6)
|
||||||
|
{
|
||||||
|
auto data = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3, 4});
|
||||||
|
auto axes = make_shared<op::Parameter>(element::i64, Shape{3});
|
||||||
|
|
||||||
|
auto mvn_func = make_shared<op::v6::MVN>(data, axes, true, 1e-6, op::MVNEpsMode::INSIDE_SQRT);
|
||||||
|
EXPECT_EQ(mvn_func->get_element_type(), element::f32);
|
||||||
|
EXPECT_EQ(mvn_func->get_shape(), (Shape{1, 2, 3, 4}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(type_prop, mvn_6_partial)
|
||||||
|
{
|
||||||
|
auto data =
|
||||||
|
make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 5, 6});
|
||||||
|
auto axes = make_shared<op::Parameter>(element::i64, Shape{3});
|
||||||
|
auto mvn_func = make_shared<op::v6::MVN>(data, axes, true, 1e-6, op::MVNEpsMode::INSIDE_SQRT);
|
||||||
|
EXPECT_EQ(mvn_func->get_element_type(), element::f32);
|
||||||
|
ASSERT_TRUE(mvn_func->get_output_partial_shape(0).same_scheme(
|
||||||
|
(PartialShape{1, Dimension::dynamic(), 5, 6})));
|
||||||
|
|
||||||
|
// rank unknown
|
||||||
|
auto mvn_partial =
|
||||||
|
make_shared<op::v6::MVN>(make_shared<op::Parameter>(element::f32, PartialShape::dynamic()),
|
||||||
|
axes,
|
||||||
|
true,
|
||||||
|
1e-6,
|
||||||
|
op::MVNEpsMode::INSIDE_SQRT);
|
||||||
|
ASSERT_TRUE(mvn_partial->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user