Add MVN-6 support to ngraph (#3464)

* Add MVN-6 to ngraph

* Apply review feedback

* Fix max opset number

* Fix code style

* Fix shape test

* Disable reader test

* Apply review feedback and remove reader test

* Fix code style

* Fix build

* Apply review feedback

* Fix build problem

* Fix code style

* Fix build
This commit is contained in:
Maxim Vafin 2020-12-15 21:36:44 +03:00 committed by GitHub
parent a569a0b529
commit ab974e4f2e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 416 additions and 7 deletions

View File

@ -22,6 +22,7 @@
#include <ngraph/opsets/opset2.hpp> #include <ngraph/opsets/opset2.hpp>
#include <ngraph/opsets/opset3.hpp> #include <ngraph/opsets/opset3.hpp>
#include <ngraph/opsets/opset5.hpp> #include <ngraph/opsets/opset5.hpp>
#include <ngraph/opsets/opset6.hpp>
#include <ngraph/variant.hpp> #include <ngraph/variant.hpp>
#include <cpp/ie_cnn_network.h> #include <cpp/ie_cnn_network.h>
@ -70,6 +71,7 @@ V10Parser::V10Parser(const std::vector<IExtensionPtr>& exts) : _exts(exts) {
opsets["opset3"] = ngraph::get_opset3(); opsets["opset3"] = ngraph::get_opset3();
opsets["opset4"] = ngraph::get_opset4(); opsets["opset4"] = ngraph::get_opset4();
opsets["opset5"] = ngraph::get_opset5(); opsets["opset5"] = ngraph::get_opset5();
opsets["opset6"] = ngraph::get_opset6();
// Load custom opsets // Load custom opsets
for (const auto& ext : exts) { for (const auto& ext : exts) {
@ -427,7 +429,7 @@ std::shared_ptr<ngraph::Node> V10Parser::createNode(const std::vector<ngraph::Ou
// Check that operation in default opsets // Check that operation in default opsets
auto isDefaultOpSet = [](const std::string& version) -> bool { auto isDefaultOpSet = [](const std::string& version) -> bool {
for (size_t i = 1; i <= 5; i++) { for (size_t i = 1; i <= 6; i++) {
std::string opset_name = "opset" + std::to_string(i); std::string opset_name = "opset" + std::to_string(i);
if (version == opset_name) if (version == opset_name)
return true; return true;

View File

@ -20,12 +20,12 @@
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
#include "ngraph/op/util/fused_op.hpp" #include "ngraph/op/util/fused_op.hpp"
NGRAPH_SUPPRESS_DEPRECATED_START
namespace ngraph namespace ngraph
{ {
namespace op namespace op
{ {
NGRAPH_SUPPRESS_DEPRECATED_START
namespace v0 namespace v0
{ {
/// \brief Operator performing Mean Variance Normalization /// \brief Operator performing Mean Variance Normalization
@ -87,7 +87,75 @@ namespace ngraph
}; };
} }
using v0::MVN; using v0::MVN;
} // namespace op
} // namespace ngraph
NGRAPH_SUPPRESS_DEPRECATED_END NGRAPH_SUPPRESS_DEPRECATED_END
/// \brief Specifies how eps is applied in MVN
enum class MVNEpsMode
{
// Apply eps inside sqrt
INSIDE_SQRT,
// Apply eps outside sqrt
OUTSIDE_SQRT
};
NGRAPH_API
std::ostream& operator<<(std::ostream& s, const MVNEpsMode& type);
namespace v6
{
/// \brief Operator performing Mean Variance Normalization
///
class NGRAPH_API MVN : public ngraph::op::Op
{
public:
NGRAPH_RTTI_DECLARATION;
MVN() = default;
/// \brief Constructs an MVN operation.
///
/// \param data Input tensor with data
/// \param reduction_axes A list of axes, along which to reduce.
/// \param normalize_variance flag that denotes whether to perform variance
/// normalization.
/// \param eps the number to be added to the variance to avoid division by zero when
/// normalizing the value
/// \param eps_mode the mode of applying epsilon
///
MVN(const Output<Node>& data,
const Output<Node>& reduction_axes,
bool normalize_variance,
float eps,
MVNEpsMode eps_mode);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
float get_eps() const { return m_eps; }
bool get_normalize_variance() const { return m_normalize_variance; }
MVNEpsMode get_eps_mode() const { return m_eps_mode; }
private:
bool m_normalize_variance = true;
float m_eps = (float)1e-6;
MVNEpsMode m_eps_mode = MVNEpsMode::INSIDE_SQRT;
};
} // namespace v6
} // namespace op
template <>
class NGRAPH_API AttributeAdapter<op::MVNEpsMode>
: public EnumAttributeAdapterBase<op::MVNEpsMode>
{
public:
AttributeAdapter(op::MVNEpsMode& value)
: EnumAttributeAdapterBase<op::MVNEpsMode>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::MVNEpsMode>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
} // namespace ngraph

View File

@ -133,4 +133,5 @@ namespace ngraph
const NGRAPH_API OpSet& get_opset3(); const NGRAPH_API OpSet& get_opset3();
const NGRAPH_API OpSet& get_opset4(); const NGRAPH_API OpSet& get_opset4();
const NGRAPH_API OpSet& get_opset5(); const NGRAPH_API OpSet& get_opset5();
const NGRAPH_API OpSet& get_opset6();
} }

View File

@ -0,0 +1,29 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/ops.hpp"
namespace ngraph
{
namespace opset6
{
#define NGRAPH_OP(a, b) using b::a;
#include "ngraph/opsets/opset6_tbl.hpp"
#undef NGRAPH_OP
} // namespace opset6
} // namespace ngraph

View File

@ -0,0 +1,176 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#ifndef NGRAPH_OP
#warning "NGRAPH_OP not defined"
#define NGRAPH_OP(x, y)
#endif
NGRAPH_OP(Abs, ngraph::op::v0)
NGRAPH_OP(Acos, ngraph::op::v0)
NGRAPH_OP(Add, ngraph::op::v1)
NGRAPH_OP(Asin, ngraph::op::v0)
NGRAPH_OP(Atan, ngraph::op::v0)
NGRAPH_OP(AvgPool, ngraph::op::v1)
NGRAPH_OP(BatchNormInference, ngraph::op::v5)
NGRAPH_OP(BinaryConvolution, ngraph::op::v1)
NGRAPH_OP(Broadcast, ngraph::op::v3)
NGRAPH_OP(Bucketize, ngraph::op::v3)
NGRAPH_OP(CTCGreedyDecoder, ngraph::op::v0)
NGRAPH_OP(Ceiling, ngraph::op::v0)
NGRAPH_OP(Clamp, ngraph::op::v0)
NGRAPH_OP(Concat, ngraph::op::v0)
NGRAPH_OP(Constant, ngraph::op)
NGRAPH_OP(Convert, ngraph::op::v0)
NGRAPH_OP(ConvertLike, ngraph::op::v1)
NGRAPH_OP(Convolution, ngraph::op::v1)
NGRAPH_OP(ConvolutionBackpropData, ngraph::op::v1)
NGRAPH_OP(Cos, ngraph::op::v0)
NGRAPH_OP(Cosh, ngraph::op::v0)
NGRAPH_OP(CumSum, ngraph::op::v0)
NGRAPH_OP(DeformableConvolution, ngraph::op::v1)
NGRAPH_OP(DeformablePSROIPooling, ngraph::op::v1)
NGRAPH_OP(DepthToSpace, ngraph::op::v0)
NGRAPH_OP(DetectionOutput, ngraph::op::v0)
NGRAPH_OP(Divide, ngraph::op::v1)
NGRAPH_OP(Elu, ngraph::op::v0)
NGRAPH_OP(Erf, ngraph::op::v0)
NGRAPH_OP(Equal, ngraph::op::v1)
NGRAPH_OP(Exp, ngraph::op::v0)
NGRAPH_OP(ExtractImagePatches, ngraph::op::v3)
NGRAPH_OP(FakeQuantize, ngraph::op::v0)
NGRAPH_OP(Floor, ngraph::op::v0)
NGRAPH_OP(FloorMod, ngraph::op::v1)
NGRAPH_OP(Gather, ngraph::op::v1)
NGRAPH_OP(GatherTree, ngraph::op::v1)
NGRAPH_OP(Greater, ngraph::op::v1)
NGRAPH_OP(GreaterEqual, ngraph::op::v1)
NGRAPH_OP(GroupConvolution, ngraph::op::v1)
NGRAPH_OP(GroupConvolutionBackpropData, ngraph::op::v1)
NGRAPH_OP(GRN, ngraph::op::v0)
NGRAPH_OP(HardSigmoid, ngraph::op::v0)
NGRAPH_OP(Less, ngraph::op::v1)
NGRAPH_OP(LessEqual, ngraph::op::v1)
NGRAPH_OP(Log, ngraph::op::v0)
NGRAPH_OP(LogicalAnd, ngraph::op::v1)
NGRAPH_OP(LogicalNot, ngraph::op::v1)
NGRAPH_OP(LogicalOr, ngraph::op::v1)
NGRAPH_OP(LogicalXor, ngraph::op::v1)
NGRAPH_OP(LRN, ngraph::op::v0)
NGRAPH_OP(LSTMCell, ngraph::op::v4)
NGRAPH_OP(MatMul, ngraph::op::v0)
NGRAPH_OP(MaxPool, ngraph::op::v1)
NGRAPH_OP(Maximum, ngraph::op::v1)
NGRAPH_OP(Minimum, ngraph::op::v1)
NGRAPH_OP(Mod, ngraph::op::v1)
NGRAPH_OP(Multiply, ngraph::op::v1)
NGRAPH_OP(Negative, ngraph::op::v0)
NGRAPH_OP(NormalizeL2, ngraph::op::v0)
NGRAPH_OP(NotEqual, ngraph::op::v1)
NGRAPH_OP(OneHot, ngraph::op::v1)
NGRAPH_OP(PRelu, ngraph::op::v0)
NGRAPH_OP(PSROIPooling, ngraph::op::v0)
NGRAPH_OP(Pad, ngraph::op::v1)
NGRAPH_OP(Parameter, ngraph::op::v0)
NGRAPH_OP(Power, ngraph::op::v1)
NGRAPH_OP(PriorBox, ngraph::op::v0)
NGRAPH_OP(PriorBoxClustered, ngraph::op::v0)
NGRAPH_OP(Proposal, ngraph::op::v4)
NGRAPH_OP(Range, ngraph::op::v4)
NGRAPH_OP(Relu, ngraph::op::v0)
NGRAPH_OP(ReduceMax, ngraph::op::v1)
NGRAPH_OP(ReduceLogicalAnd, ngraph::op::v1)
NGRAPH_OP(ReduceLogicalOr, ngraph::op::v1)
NGRAPH_OP(ReduceMean, ngraph::op::v1)
NGRAPH_OP(ReduceMin, ngraph::op::v1)
NGRAPH_OP(ReduceProd, ngraph::op::v1)
NGRAPH_OP(ReduceSum, ngraph::op::v1)
NGRAPH_OP(RegionYolo, ngraph::op::v0)
NGRAPH_OP(ReorgYolo, ngraph::op::v0)
NGRAPH_OP(Reshape, ngraph::op::v1)
NGRAPH_OP(Result, ngraph::op::v0)
NGRAPH_OP(ReverseSequence, ngraph::op::v0)
NGRAPH_OP(ROIPooling, ngraph::op::v0)
NGRAPH_OP(ScatterNDUpdate, ngraph::op::v3)
NGRAPH_OP(Select, ngraph::op::v1)
NGRAPH_OP(Selu, ngraph::op::v0)
NGRAPH_OP(Sign, ngraph::op::v0)
NGRAPH_OP(Sigmoid, ngraph::op::v0)
NGRAPH_OP(Sin, ngraph::op::v0)
NGRAPH_OP(Sinh, ngraph::op::v0)
NGRAPH_OP(Softmax, ngraph::op::v1)
NGRAPH_OP(Sqrt, ngraph::op::v0)
NGRAPH_OP(SpaceToDepth, ngraph::op::v0)
NGRAPH_OP(Split, ngraph::op::v1)
NGRAPH_OP(SquaredDifference, ngraph::op::v0)
NGRAPH_OP(Squeeze, ngraph::op::v0)
NGRAPH_OP(StridedSlice, ngraph::op::v1)
NGRAPH_OP(Subtract, ngraph::op::v1)
NGRAPH_OP(Tan, ngraph::op::v0)
NGRAPH_OP(Tanh, ngraph::op::v0)
NGRAPH_OP(TensorIterator, ngraph::op::v0)
NGRAPH_OP(Tile, ngraph::op::v0)
NGRAPH_OP(Transpose, ngraph::op::v1)
NGRAPH_OP(Unsqueeze, ngraph::op::v0)
NGRAPH_OP(VariadicSplit, ngraph::op::v1)
// New operations added in opset2
NGRAPH_OP(Gelu, ngraph::op::v0)
NGRAPH_OP(BatchToSpace, ngraph::op::v1)
NGRAPH_OP(SpaceToBatch, ngraph::op::v1)
// New operations added in opset3
NGRAPH_OP(EmbeddingBagPackedSum, ngraph::op::v3)
NGRAPH_OP(EmbeddingSegmentsSum, ngraph::op::v3)
NGRAPH_OP(EmbeddingBagOffsetsSum, ngraph::op::v3)
NGRAPH_OP(GRUCell, ngraph::op::v3)
NGRAPH_OP(NonZero, ngraph::op::v3)
NGRAPH_OP(RNNCell, ngraph::op::v0)
NGRAPH_OP(ROIAlign, ngraph::op::v3)
NGRAPH_OP(ScatterElementsUpdate, ngraph::op::v3)
NGRAPH_OP(ScatterUpdate, ngraph::op::v3)
NGRAPH_OP(ShuffleChannels, ngraph::op::v0)
NGRAPH_OP(ShapeOf, ngraph::op::v3)
NGRAPH_OP(Assign, ngraph::op::v3)
NGRAPH_OP(ReadValue, ngraph::op::v3)
NGRAPH_OP(TopK, ngraph::op::v3)
// New operations added in opset4
NGRAPH_OP(Acosh, ngraph::op::v3)
NGRAPH_OP(Asinh, ngraph::op::v3)
NGRAPH_OP(Atanh, ngraph::op::v3)
NGRAPH_OP(CTCLoss, ngraph::op::v4)
NGRAPH_OP(HSwish, ngraph::op::v4)
NGRAPH_OP(Interpolate, ngraph::op::v4)
NGRAPH_OP(Mish, ngraph::op::v4)
NGRAPH_OP(ReduceL1, ngraph::op::v4)
NGRAPH_OP(ReduceL2, ngraph::op::v4)
NGRAPH_OP(SoftPlus, ngraph::op::v4)
NGRAPH_OP(Swish, ngraph::op::v4)
// New operations added in opset5
NGRAPH_OP(GatherND, ngraph::op::v5)
NGRAPH_OP(GRUSequence, ngraph::op::v5)
NGRAPH_OP(HSigmoid, ngraph::op::v5)
NGRAPH_OP(LogSoftmax, ngraph::op::v5)
NGRAPH_OP(Loop, ngraph::op::v5)
NGRAPH_OP(LSTMSequence, ngraph::op::v5)
NGRAPH_OP(NonMaxSuppression, ngraph::op::v5)
NGRAPH_OP(RNNSequence, ngraph::op::v5)
NGRAPH_OP(Round, ngraph::op::v5)
// New operations added in opset6
NGRAPH_OP(MVN, ngraph::op::v6)

View File

@ -28,6 +28,8 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
// ------------------------------ V0 ------------------------------
NGRAPH_SUPPRESS_DEPRECATED_START NGRAPH_SUPPRESS_DEPRECATED_START
NGRAPH_RTTI_DEFINITION(op::v0::MVN, "MVN", 0); NGRAPH_RTTI_DEFINITION(op::v0::MVN, "MVN", 0);
@ -119,3 +121,80 @@ bool op::MVN::visit_attributes(AttributeVisitor& visitor)
visitor.on_attribute("reduction_axes", m_reduction_axes); visitor.on_attribute("reduction_axes", m_reduction_axes);
return true; return true;
} }
// ------------------------------ V6 ------------------------------
namespace ngraph
{
template <>
NGRAPH_API EnumNames<op::MVNEpsMode>& EnumNames<op::MVNEpsMode>::get()
{
static auto enum_names =
EnumNames<op::MVNEpsMode>("op::MVNEpsMode",
{{"OUTSIDE_SQRT", op::MVNEpsMode::OUTSIDE_SQRT},
{"INSIDE_SQRT", op::MVNEpsMode::INSIDE_SQRT}});
return enum_names;
}
constexpr DiscreteTypeInfo AttributeAdapter<op::MVNEpsMode>::type_info;
std::ostream& op::operator<<(std::ostream& s, const op::MVNEpsMode& type)
{
return s << as_string(type);
}
} // namespace ngraph
NGRAPH_RTTI_DEFINITION(op::v6::MVN, "MVN", 6);
op::v6::MVN::MVN(const Output<Node>& data,
const Output<Node>& reduction_axes,
bool normalize_variance,
float eps,
MVNEpsMode eps_mode)
: Op({data, reduction_axes})
, m_eps{eps}
, m_normalize_variance{normalize_variance}
, m_eps_mode{eps_mode}
{
constructor_validate_and_infer_types();
}
void op::v6::MVN::validate_and_infer_types()
{
const auto data = get_input_partial_shape(0);
const auto axes = get_input_partial_shape(1);
if (axes.is_static())
{
NODE_VALIDATION_CHECK(this,
is_vector(axes.to_shape()),
"Expected 1D tensor for the 'axes' input. Got: ",
axes);
NODE_VALIDATION_CHECK(
this,
data.rank().is_dynamic() || data.rank().get_length() >= axes.get_shape()[0],
"Expected rank for the 'data' input to be higher than axes shape. Got: ",
data);
}
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
}
shared_ptr<Node> op::v6::MVN::clone_with_new_inputs(const OutputVector& new_args) const
{
NODE_VALIDATION_CHECK(this,
new_args.size() == 2,
"Expected 2 element in new_args for the MVN op but got ",
new_args.size());
return make_shared<op::v6::MVN>(
new_args.at(0), new_args.at(1), m_normalize_variance, m_eps, m_eps_mode);
}
bool op::v6::MVN::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("eps", m_eps);
visitor.on_attribute("normalize_variance", m_normalize_variance);
visitor.on_attribute("eps_mode", m_eps_mode);
return true;
}

View File

@ -128,6 +128,7 @@ namespace ngraph
{ {
return s << as_string(type); return s << as_string(type);
} }
template <> template <>
NGRAPH_API EnumNames<op::TopKSortType>& EnumNames<op::TopKSortType>::get() NGRAPH_API EnumNames<op::TopKSortType>& EnumNames<op::TopKSortType>::get()
{ {

View File

@ -138,3 +138,22 @@ const ngraph::OpSet& ngraph::get_opset5()
} }
return opset; return opset;
} }
const ngraph::OpSet& ngraph::get_opset6()
{
static std::mutex init_mutex;
static bool opset_is_initialized = false;
static OpSet opset;
if (!opset_is_initialized)
{
std::lock_guard<std::mutex> guard(init_mutex);
if (!opset_is_initialized)
{
#define NGRAPH_OP(NAME, NAMESPACE) opset.insert<NAMESPACE::NAME>();
#include "ngraph/opsets/opset6_tbl.hpp"
#undef NGRAPH_OP
opset_is_initialized = true;
}
}
return opset;
}

View File

@ -21,6 +21,8 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
// ------------------------------ V0 ------------------------------
TEST(type_prop, mvn) TEST(type_prop, mvn)
{ {
auto data = make_shared<op::Parameter>(element::f32, Shape{1, 3, 6}); auto data = make_shared<op::Parameter>(element::f32, Shape{1, 3, 6});
@ -47,3 +49,35 @@ TEST(type_prop, mvn_partial)
EXPECT_EQ(mvn_partial->get_reduction_axes(), AxisSet{}); EXPECT_EQ(mvn_partial->get_reduction_axes(), AxisSet{});
ASSERT_TRUE(mvn_partial->get_output_partial_shape(0).same_scheme(PartialShape::dynamic())); ASSERT_TRUE(mvn_partial->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
} }
// ------------------------------ V6 ------------------------------
TEST(type_prop, mvn_6)
{
auto data = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3, 4});
auto axes = make_shared<op::Parameter>(element::i64, Shape{3});
auto mvn_func = make_shared<op::v6::MVN>(data, axes, true, 1e-6, op::MVNEpsMode::INSIDE_SQRT);
EXPECT_EQ(mvn_func->get_element_type(), element::f32);
EXPECT_EQ(mvn_func->get_shape(), (Shape{1, 2, 3, 4}));
}
TEST(type_prop, mvn_6_partial)
{
auto data =
make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 5, 6});
auto axes = make_shared<op::Parameter>(element::i64, Shape{3});
auto mvn_func = make_shared<op::v6::MVN>(data, axes, true, 1e-6, op::MVNEpsMode::INSIDE_SQRT);
EXPECT_EQ(mvn_func->get_element_type(), element::f32);
ASSERT_TRUE(mvn_func->get_output_partial_shape(0).same_scheme(
(PartialShape{1, Dimension::dynamic(), 5, 6})));
// rank unknown
auto mvn_partial =
make_shared<op::v6::MVN>(make_shared<op::Parameter>(element::f32, PartialShape::dynamic()),
axes,
true,
1e-6,
op::MVNEpsMode::INSIDE_SQRT);
ASSERT_TRUE(mvn_partial->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
}