Add MVN-6 support to ngraph (#3464)
* Add MVN-6 to ngraph * Apply review feedback * Fix max opset number * Fix code style * Fix shape test * Disable reader test * Apply review feedback and remove reader test * Fix code style * Fix build * Apply review feedback * Fix build problem * Fix code style * Fix build
This commit is contained in:
parent
a569a0b529
commit
ab974e4f2e
@ -22,6 +22,7 @@
|
||||
#include <ngraph/opsets/opset2.hpp>
|
||||
#include <ngraph/opsets/opset3.hpp>
|
||||
#include <ngraph/opsets/opset5.hpp>
|
||||
#include <ngraph/opsets/opset6.hpp>
|
||||
#include <ngraph/variant.hpp>
|
||||
|
||||
#include <cpp/ie_cnn_network.h>
|
||||
@ -70,6 +71,7 @@ V10Parser::V10Parser(const std::vector<IExtensionPtr>& exts) : _exts(exts) {
|
||||
opsets["opset3"] = ngraph::get_opset3();
|
||||
opsets["opset4"] = ngraph::get_opset4();
|
||||
opsets["opset5"] = ngraph::get_opset5();
|
||||
opsets["opset6"] = ngraph::get_opset6();
|
||||
|
||||
// Load custom opsets
|
||||
for (const auto& ext : exts) {
|
||||
@ -427,7 +429,7 @@ std::shared_ptr<ngraph::Node> V10Parser::createNode(const std::vector<ngraph::Ou
|
||||
|
||||
// Check that operation in default opsets
|
||||
auto isDefaultOpSet = [](const std::string& version) -> bool {
|
||||
for (size_t i = 1; i <= 5; i++) {
|
||||
for (size_t i = 1; i <= 6; i++) {
|
||||
std::string opset_name = "opset" + std::to_string(i);
|
||||
if (version == opset_name)
|
||||
return true;
|
||||
|
@ -20,12 +20,12 @@
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "ngraph/op/util/fused_op.hpp"
|
||||
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace op
|
||||
{
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
|
||||
namespace v0
|
||||
{
|
||||
/// \brief Operator performing Mean Variance Normalization
|
||||
@ -87,7 +87,75 @@ namespace ngraph
|
||||
};
|
||||
}
|
||||
using v0::MVN;
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
||||
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
|
||||
/// \brief Specifies how eps is applied in MVN
|
||||
enum class MVNEpsMode
|
||||
{
|
||||
// Apply eps inside sqrt
|
||||
INSIDE_SQRT,
|
||||
// Apply eps outside sqrt
|
||||
OUTSIDE_SQRT
|
||||
};
|
||||
|
||||
NGRAPH_API
|
||||
std::ostream& operator<<(std::ostream& s, const MVNEpsMode& type);
|
||||
|
||||
namespace v6
|
||||
{
|
||||
/// \brief Operator performing Mean Variance Normalization
|
||||
///
|
||||
class NGRAPH_API MVN : public ngraph::op::Op
|
||||
{
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
MVN() = default;
|
||||
/// \brief Constructs an MVN operation.
|
||||
///
|
||||
/// \param data Input tensor with data
|
||||
/// \param reduction_axes A list of axes, along which to reduce.
|
||||
/// \param normalize_variance flag that denotes whether to perform variance
|
||||
/// normalization.
|
||||
/// \param eps the number to be added to the variance to avoid division by zero when
|
||||
/// normalizing the value
|
||||
/// \param eps_mode the mode of applying epsilon
|
||||
///
|
||||
MVN(const Output<Node>& data,
|
||||
const Output<Node>& reduction_axes,
|
||||
bool normalize_variance,
|
||||
float eps,
|
||||
MVNEpsMode eps_mode);
|
||||
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
float get_eps() const { return m_eps; }
|
||||
bool get_normalize_variance() const { return m_normalize_variance; }
|
||||
MVNEpsMode get_eps_mode() const { return m_eps_mode; }
|
||||
private:
|
||||
bool m_normalize_variance = true;
|
||||
float m_eps = (float)1e-6;
|
||||
MVNEpsMode m_eps_mode = MVNEpsMode::INSIDE_SQRT;
|
||||
};
|
||||
} // namespace v6
|
||||
} // namespace op
|
||||
|
||||
template <>
|
||||
class NGRAPH_API AttributeAdapter<op::MVNEpsMode>
|
||||
: public EnumAttributeAdapterBase<op::MVNEpsMode>
|
||||
{
|
||||
public:
|
||||
AttributeAdapter(op::MVNEpsMode& value)
|
||||
: EnumAttributeAdapterBase<op::MVNEpsMode>(value)
|
||||
{
|
||||
}
|
||||
|
||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::MVNEpsMode>", 0};
|
||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
||||
};
|
||||
} // namespace ngraph
|
||||
|
@ -133,4 +133,5 @@ namespace ngraph
|
||||
const NGRAPH_API OpSet& get_opset3();
|
||||
const NGRAPH_API OpSet& get_opset4();
|
||||
const NGRAPH_API OpSet& get_opset5();
|
||||
const NGRAPH_API OpSet& get_opset6();
|
||||
}
|
||||
|
29
ngraph/core/include/ngraph/opsets/opset6.hpp
Normal file
29
ngraph/core/include/ngraph/opsets/opset6.hpp
Normal file
@ -0,0 +1,29 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ngraph/ops.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace opset6
|
||||
{
|
||||
#define NGRAPH_OP(a, b) using b::a;
|
||||
#include "ngraph/opsets/opset6_tbl.hpp"
|
||||
#undef NGRAPH_OP
|
||||
} // namespace opset6
|
||||
} // namespace ngraph
|
176
ngraph/core/include/ngraph/opsets/opset6_tbl.hpp
Normal file
176
ngraph/core/include/ngraph/opsets/opset6_tbl.hpp
Normal file
@ -0,0 +1,176 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#ifndef NGRAPH_OP
|
||||
#warning "NGRAPH_OP not defined"
|
||||
#define NGRAPH_OP(x, y)
|
||||
#endif
|
||||
|
||||
NGRAPH_OP(Abs, ngraph::op::v0)
|
||||
NGRAPH_OP(Acos, ngraph::op::v0)
|
||||
NGRAPH_OP(Add, ngraph::op::v1)
|
||||
NGRAPH_OP(Asin, ngraph::op::v0)
|
||||
NGRAPH_OP(Atan, ngraph::op::v0)
|
||||
NGRAPH_OP(AvgPool, ngraph::op::v1)
|
||||
NGRAPH_OP(BatchNormInference, ngraph::op::v5)
|
||||
NGRAPH_OP(BinaryConvolution, ngraph::op::v1)
|
||||
NGRAPH_OP(Broadcast, ngraph::op::v3)
|
||||
NGRAPH_OP(Bucketize, ngraph::op::v3)
|
||||
NGRAPH_OP(CTCGreedyDecoder, ngraph::op::v0)
|
||||
NGRAPH_OP(Ceiling, ngraph::op::v0)
|
||||
NGRAPH_OP(Clamp, ngraph::op::v0)
|
||||
NGRAPH_OP(Concat, ngraph::op::v0)
|
||||
NGRAPH_OP(Constant, ngraph::op)
|
||||
NGRAPH_OP(Convert, ngraph::op::v0)
|
||||
NGRAPH_OP(ConvertLike, ngraph::op::v1)
|
||||
NGRAPH_OP(Convolution, ngraph::op::v1)
|
||||
NGRAPH_OP(ConvolutionBackpropData, ngraph::op::v1)
|
||||
NGRAPH_OP(Cos, ngraph::op::v0)
|
||||
NGRAPH_OP(Cosh, ngraph::op::v0)
|
||||
NGRAPH_OP(CumSum, ngraph::op::v0)
|
||||
NGRAPH_OP(DeformableConvolution, ngraph::op::v1)
|
||||
NGRAPH_OP(DeformablePSROIPooling, ngraph::op::v1)
|
||||
NGRAPH_OP(DepthToSpace, ngraph::op::v0)
|
||||
NGRAPH_OP(DetectionOutput, ngraph::op::v0)
|
||||
NGRAPH_OP(Divide, ngraph::op::v1)
|
||||
NGRAPH_OP(Elu, ngraph::op::v0)
|
||||
NGRAPH_OP(Erf, ngraph::op::v0)
|
||||
NGRAPH_OP(Equal, ngraph::op::v1)
|
||||
NGRAPH_OP(Exp, ngraph::op::v0)
|
||||
NGRAPH_OP(ExtractImagePatches, ngraph::op::v3)
|
||||
NGRAPH_OP(FakeQuantize, ngraph::op::v0)
|
||||
NGRAPH_OP(Floor, ngraph::op::v0)
|
||||
NGRAPH_OP(FloorMod, ngraph::op::v1)
|
||||
NGRAPH_OP(Gather, ngraph::op::v1)
|
||||
NGRAPH_OP(GatherTree, ngraph::op::v1)
|
||||
NGRAPH_OP(Greater, ngraph::op::v1)
|
||||
NGRAPH_OP(GreaterEqual, ngraph::op::v1)
|
||||
NGRAPH_OP(GroupConvolution, ngraph::op::v1)
|
||||
NGRAPH_OP(GroupConvolutionBackpropData, ngraph::op::v1)
|
||||
NGRAPH_OP(GRN, ngraph::op::v0)
|
||||
NGRAPH_OP(HardSigmoid, ngraph::op::v0)
|
||||
NGRAPH_OP(Less, ngraph::op::v1)
|
||||
NGRAPH_OP(LessEqual, ngraph::op::v1)
|
||||
NGRAPH_OP(Log, ngraph::op::v0)
|
||||
NGRAPH_OP(LogicalAnd, ngraph::op::v1)
|
||||
NGRAPH_OP(LogicalNot, ngraph::op::v1)
|
||||
NGRAPH_OP(LogicalOr, ngraph::op::v1)
|
||||
NGRAPH_OP(LogicalXor, ngraph::op::v1)
|
||||
NGRAPH_OP(LRN, ngraph::op::v0)
|
||||
NGRAPH_OP(LSTMCell, ngraph::op::v4)
|
||||
NGRAPH_OP(MatMul, ngraph::op::v0)
|
||||
NGRAPH_OP(MaxPool, ngraph::op::v1)
|
||||
NGRAPH_OP(Maximum, ngraph::op::v1)
|
||||
NGRAPH_OP(Minimum, ngraph::op::v1)
|
||||
NGRAPH_OP(Mod, ngraph::op::v1)
|
||||
NGRAPH_OP(Multiply, ngraph::op::v1)
|
||||
NGRAPH_OP(Negative, ngraph::op::v0)
|
||||
NGRAPH_OP(NormalizeL2, ngraph::op::v0)
|
||||
NGRAPH_OP(NotEqual, ngraph::op::v1)
|
||||
NGRAPH_OP(OneHot, ngraph::op::v1)
|
||||
NGRAPH_OP(PRelu, ngraph::op::v0)
|
||||
NGRAPH_OP(PSROIPooling, ngraph::op::v0)
|
||||
NGRAPH_OP(Pad, ngraph::op::v1)
|
||||
NGRAPH_OP(Parameter, ngraph::op::v0)
|
||||
NGRAPH_OP(Power, ngraph::op::v1)
|
||||
NGRAPH_OP(PriorBox, ngraph::op::v0)
|
||||
NGRAPH_OP(PriorBoxClustered, ngraph::op::v0)
|
||||
NGRAPH_OP(Proposal, ngraph::op::v4)
|
||||
NGRAPH_OP(Range, ngraph::op::v4)
|
||||
NGRAPH_OP(Relu, ngraph::op::v0)
|
||||
NGRAPH_OP(ReduceMax, ngraph::op::v1)
|
||||
NGRAPH_OP(ReduceLogicalAnd, ngraph::op::v1)
|
||||
NGRAPH_OP(ReduceLogicalOr, ngraph::op::v1)
|
||||
NGRAPH_OP(ReduceMean, ngraph::op::v1)
|
||||
NGRAPH_OP(ReduceMin, ngraph::op::v1)
|
||||
NGRAPH_OP(ReduceProd, ngraph::op::v1)
|
||||
NGRAPH_OP(ReduceSum, ngraph::op::v1)
|
||||
NGRAPH_OP(RegionYolo, ngraph::op::v0)
|
||||
NGRAPH_OP(ReorgYolo, ngraph::op::v0)
|
||||
NGRAPH_OP(Reshape, ngraph::op::v1)
|
||||
NGRAPH_OP(Result, ngraph::op::v0)
|
||||
NGRAPH_OP(ReverseSequence, ngraph::op::v0)
|
||||
NGRAPH_OP(ROIPooling, ngraph::op::v0)
|
||||
NGRAPH_OP(ScatterNDUpdate, ngraph::op::v3)
|
||||
NGRAPH_OP(Select, ngraph::op::v1)
|
||||
NGRAPH_OP(Selu, ngraph::op::v0)
|
||||
NGRAPH_OP(Sign, ngraph::op::v0)
|
||||
NGRAPH_OP(Sigmoid, ngraph::op::v0)
|
||||
NGRAPH_OP(Sin, ngraph::op::v0)
|
||||
NGRAPH_OP(Sinh, ngraph::op::v0)
|
||||
NGRAPH_OP(Softmax, ngraph::op::v1)
|
||||
NGRAPH_OP(Sqrt, ngraph::op::v0)
|
||||
NGRAPH_OP(SpaceToDepth, ngraph::op::v0)
|
||||
NGRAPH_OP(Split, ngraph::op::v1)
|
||||
NGRAPH_OP(SquaredDifference, ngraph::op::v0)
|
||||
NGRAPH_OP(Squeeze, ngraph::op::v0)
|
||||
NGRAPH_OP(StridedSlice, ngraph::op::v1)
|
||||
NGRAPH_OP(Subtract, ngraph::op::v1)
|
||||
NGRAPH_OP(Tan, ngraph::op::v0)
|
||||
NGRAPH_OP(Tanh, ngraph::op::v0)
|
||||
NGRAPH_OP(TensorIterator, ngraph::op::v0)
|
||||
NGRAPH_OP(Tile, ngraph::op::v0)
|
||||
NGRAPH_OP(Transpose, ngraph::op::v1)
|
||||
NGRAPH_OP(Unsqueeze, ngraph::op::v0)
|
||||
NGRAPH_OP(VariadicSplit, ngraph::op::v1)
|
||||
|
||||
// New operations added in opset2
|
||||
NGRAPH_OP(Gelu, ngraph::op::v0)
|
||||
NGRAPH_OP(BatchToSpace, ngraph::op::v1)
|
||||
NGRAPH_OP(SpaceToBatch, ngraph::op::v1)
|
||||
|
||||
// New operations added in opset3
|
||||
NGRAPH_OP(EmbeddingBagPackedSum, ngraph::op::v3)
|
||||
NGRAPH_OP(EmbeddingSegmentsSum, ngraph::op::v3)
|
||||
NGRAPH_OP(EmbeddingBagOffsetsSum, ngraph::op::v3)
|
||||
NGRAPH_OP(GRUCell, ngraph::op::v3)
|
||||
NGRAPH_OP(NonZero, ngraph::op::v3)
|
||||
NGRAPH_OP(RNNCell, ngraph::op::v0)
|
||||
NGRAPH_OP(ROIAlign, ngraph::op::v3)
|
||||
NGRAPH_OP(ScatterElementsUpdate, ngraph::op::v3)
|
||||
NGRAPH_OP(ScatterUpdate, ngraph::op::v3)
|
||||
NGRAPH_OP(ShuffleChannels, ngraph::op::v0)
|
||||
NGRAPH_OP(ShapeOf, ngraph::op::v3)
|
||||
NGRAPH_OP(Assign, ngraph::op::v3)
|
||||
NGRAPH_OP(ReadValue, ngraph::op::v3)
|
||||
NGRAPH_OP(TopK, ngraph::op::v3)
|
||||
|
||||
// New operations added in opset4
|
||||
NGRAPH_OP(Acosh, ngraph::op::v3)
|
||||
NGRAPH_OP(Asinh, ngraph::op::v3)
|
||||
NGRAPH_OP(Atanh, ngraph::op::v3)
|
||||
NGRAPH_OP(CTCLoss, ngraph::op::v4)
|
||||
NGRAPH_OP(HSwish, ngraph::op::v4)
|
||||
NGRAPH_OP(Interpolate, ngraph::op::v4)
|
||||
NGRAPH_OP(Mish, ngraph::op::v4)
|
||||
NGRAPH_OP(ReduceL1, ngraph::op::v4)
|
||||
NGRAPH_OP(ReduceL2, ngraph::op::v4)
|
||||
NGRAPH_OP(SoftPlus, ngraph::op::v4)
|
||||
NGRAPH_OP(Swish, ngraph::op::v4)
|
||||
|
||||
// New operations added in opset5
|
||||
NGRAPH_OP(GatherND, ngraph::op::v5)
|
||||
NGRAPH_OP(GRUSequence, ngraph::op::v5)
|
||||
NGRAPH_OP(HSigmoid, ngraph::op::v5)
|
||||
NGRAPH_OP(LogSoftmax, ngraph::op::v5)
|
||||
NGRAPH_OP(Loop, ngraph::op::v5)
|
||||
NGRAPH_OP(LSTMSequence, ngraph::op::v5)
|
||||
NGRAPH_OP(NonMaxSuppression, ngraph::op::v5)
|
||||
NGRAPH_OP(RNNSequence, ngraph::op::v5)
|
||||
NGRAPH_OP(Round, ngraph::op::v5)
|
||||
|
||||
// New operations added in opset6
|
||||
NGRAPH_OP(MVN, ngraph::op::v6)
|
@ -28,6 +28,8 @@
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
// ------------------------------ V0 ------------------------------
|
||||
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(op::v0::MVN, "MVN", 0);
|
||||
@ -119,3 +121,80 @@ bool op::MVN::visit_attributes(AttributeVisitor& visitor)
|
||||
visitor.on_attribute("reduction_axes", m_reduction_axes);
|
||||
return true;
|
||||
}
|
||||
|
||||
// ------------------------------ V6 ------------------------------
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
template <>
|
||||
NGRAPH_API EnumNames<op::MVNEpsMode>& EnumNames<op::MVNEpsMode>::get()
|
||||
{
|
||||
static auto enum_names =
|
||||
EnumNames<op::MVNEpsMode>("op::MVNEpsMode",
|
||||
{{"OUTSIDE_SQRT", op::MVNEpsMode::OUTSIDE_SQRT},
|
||||
{"INSIDE_SQRT", op::MVNEpsMode::INSIDE_SQRT}});
|
||||
return enum_names;
|
||||
}
|
||||
|
||||
constexpr DiscreteTypeInfo AttributeAdapter<op::MVNEpsMode>::type_info;
|
||||
|
||||
std::ostream& op::operator<<(std::ostream& s, const op::MVNEpsMode& type)
|
||||
{
|
||||
return s << as_string(type);
|
||||
}
|
||||
} // namespace ngraph
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(op::v6::MVN, "MVN", 6);
|
||||
|
||||
op::v6::MVN::MVN(const Output<Node>& data,
|
||||
const Output<Node>& reduction_axes,
|
||||
bool normalize_variance,
|
||||
float eps,
|
||||
MVNEpsMode eps_mode)
|
||||
: Op({data, reduction_axes})
|
||||
, m_eps{eps}
|
||||
, m_normalize_variance{normalize_variance}
|
||||
, m_eps_mode{eps_mode}
|
||||
{
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
void op::v6::MVN::validate_and_infer_types()
|
||||
{
|
||||
const auto data = get_input_partial_shape(0);
|
||||
const auto axes = get_input_partial_shape(1);
|
||||
|
||||
if (axes.is_static())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
is_vector(axes.to_shape()),
|
||||
"Expected 1D tensor for the 'axes' input. Got: ",
|
||||
axes);
|
||||
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
data.rank().is_dynamic() || data.rank().get_length() >= axes.get_shape()[0],
|
||||
"Expected rank for the 'data' input to be higher than axes shape. Got: ",
|
||||
data);
|
||||
}
|
||||
|
||||
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::v6::MVN::clone_with_new_inputs(const OutputVector& new_args) const
|
||||
{
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
new_args.size() == 2,
|
||||
"Expected 2 element in new_args for the MVN op but got ",
|
||||
new_args.size());
|
||||
return make_shared<op::v6::MVN>(
|
||||
new_args.at(0), new_args.at(1), m_normalize_variance, m_eps, m_eps_mode);
|
||||
}
|
||||
|
||||
bool op::v6::MVN::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
visitor.on_attribute("eps", m_eps);
|
||||
visitor.on_attribute("normalize_variance", m_normalize_variance);
|
||||
visitor.on_attribute("eps_mode", m_eps_mode);
|
||||
return true;
|
||||
}
|
||||
|
@ -128,6 +128,7 @@ namespace ngraph
|
||||
{
|
||||
return s << as_string(type);
|
||||
}
|
||||
|
||||
template <>
|
||||
NGRAPH_API EnumNames<op::TopKSortType>& EnumNames<op::TopKSortType>::get()
|
||||
{
|
||||
|
@ -137,4 +137,23 @@ const ngraph::OpSet& ngraph::get_opset5()
|
||||
}
|
||||
}
|
||||
return opset;
|
||||
}
|
||||
}
|
||||
|
||||
const ngraph::OpSet& ngraph::get_opset6()
|
||||
{
|
||||
static std::mutex init_mutex;
|
||||
static bool opset_is_initialized = false;
|
||||
static OpSet opset;
|
||||
if (!opset_is_initialized)
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(init_mutex);
|
||||
if (!opset_is_initialized)
|
||||
{
|
||||
#define NGRAPH_OP(NAME, NAMESPACE) opset.insert<NAMESPACE::NAME>();
|
||||
#include "ngraph/opsets/opset6_tbl.hpp"
|
||||
#undef NGRAPH_OP
|
||||
opset_is_initialized = true;
|
||||
}
|
||||
}
|
||||
return opset;
|
||||
}
|
||||
|
@ -21,6 +21,8 @@
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
// ------------------------------ V0 ------------------------------
|
||||
|
||||
TEST(type_prop, mvn)
|
||||
{
|
||||
auto data = make_shared<op::Parameter>(element::f32, Shape{1, 3, 6});
|
||||
@ -47,3 +49,35 @@ TEST(type_prop, mvn_partial)
|
||||
EXPECT_EQ(mvn_partial->get_reduction_axes(), AxisSet{});
|
||||
ASSERT_TRUE(mvn_partial->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
|
||||
}
|
||||
|
||||
// ------------------------------ V6 ------------------------------
|
||||
|
||||
TEST(type_prop, mvn_6)
|
||||
{
|
||||
auto data = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3, 4});
|
||||
auto axes = make_shared<op::Parameter>(element::i64, Shape{3});
|
||||
|
||||
auto mvn_func = make_shared<op::v6::MVN>(data, axes, true, 1e-6, op::MVNEpsMode::INSIDE_SQRT);
|
||||
EXPECT_EQ(mvn_func->get_element_type(), element::f32);
|
||||
EXPECT_EQ(mvn_func->get_shape(), (Shape{1, 2, 3, 4}));
|
||||
}
|
||||
|
||||
TEST(type_prop, mvn_6_partial)
|
||||
{
|
||||
auto data =
|
||||
make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 5, 6});
|
||||
auto axes = make_shared<op::Parameter>(element::i64, Shape{3});
|
||||
auto mvn_func = make_shared<op::v6::MVN>(data, axes, true, 1e-6, op::MVNEpsMode::INSIDE_SQRT);
|
||||
EXPECT_EQ(mvn_func->get_element_type(), element::f32);
|
||||
ASSERT_TRUE(mvn_func->get_output_partial_shape(0).same_scheme(
|
||||
(PartialShape{1, Dimension::dynamic(), 5, 6})));
|
||||
|
||||
// rank unknown
|
||||
auto mvn_partial =
|
||||
make_shared<op::v6::MVN>(make_shared<op::Parameter>(element::f32, PartialShape::dynamic()),
|
||||
axes,
|
||||
true,
|
||||
1e-6,
|
||||
op::MVNEpsMode::INSIDE_SQRT);
|
||||
ASSERT_TRUE(mvn_partial->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user