Fix NormalizeL2Fusion and allow LpNormalization to be fused to NormalizeL2 (#5664)

* Fix NormalizeL2Fusion and allow LpNormalization to be fused to NormalizeL2

* apply code format

* use cast_vector<uint64_t>

* use MKLDNNNormalizeL2Node::isSupportedOperation in normalizeL2FusionCallback
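
In short, the decomposition the pass matches changes so that the epsilon guard sits under the square root, which is what opset4::NormalizeL2 computes (a sketch derived from the diffs below; the Add variant is analogous):

// Matched before the fix:
//   out = x / Maximum(Sqrt(ReduceSum(Power(x, 2), axes)), eps)
// Matched after the fix:
//   out = x / Sqrt(Maximum(ReduceSum(Power(x, 2), axes), eps))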
Mateusz Tabaka 2021-07-22 06:14:43 +02:00 committed by GitHub
parent aecb2b9693
commit 1e1e3bfffd
8 changed files with 84 additions and 74 deletions


@@ -26,6 +26,7 @@
#include "transformations/common_optimizations/convert_quantize_dequantize.hpp"
#include <transformations/common_optimizations/depth_to_space_fusion.hpp>
#include <transformations/common_optimizations/softmax_fusion.hpp>
+#include <transformations/common_optimizations/normalize_l2_fusion.hpp>
#include <transformations/op_conversions/convert_depth_to_space.hpp>
#include <transformations/op_conversions/convert_shuffle_channels3.hpp>
#include <transformations/op_conversions/convert_space_to_depth.hpp>
@@ -87,6 +88,7 @@
#include "nodes/mkldnn_mvn_node.h"
#include "nodes/mkldnn_fake_quantize_node.h"
#include "nodes/mkldnn_normalize_node.h"
#include "ngraph_transformations/convert_to_cpu_specific_opset.hpp"
#if !defined(__arm__) && !defined(_M_ARM) && !defined(__aarch64__) && !defined(_M_ARM64)
@@ -277,6 +279,13 @@ static void Transformation(CNNNetwork& clonedNetwork, const Config& conf) {
return node->input_value(0).get_partial_shape().rank().get_length() > 5;
});
+auto normalizeL2FusionCallback = [](const_node_ptr &node) -> bool {
+std::string errorMsg;
+return !MKLDNNNormalizeL2Node::isSupportedOperation(node, errorMsg);
+};
+pass_config->set_callback<ngraph::pass::NormalizeL2FusionWithAdd>(normalizeL2FusionCallback);
+pass_config->set_callback<ngraph::pass::NormalizeL2FusionWithMax>(normalizeL2FusionCallback);
// List of enabled/disabled transformations
pass_config->disable<ngraph::pass::ConvertGELU>();
pass_config->disable<ngraph::pass::ConvertShuffleChannels3>();


@@ -660,7 +660,7 @@ MKLDNNNormalizeL2Node::MKLDNNNormalizeL2Node(const std::shared_ptr<ngraph::Node>
}
}
-bool MKLDNNNormalizeL2Node::isSupportedOperation(const std::shared_ptr<ngraph::Node>& op, std::string& errorMessage) noexcept {
+bool MKLDNNNormalizeL2Node::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept {
try {
const auto norm = std::dynamic_pointer_cast<const ngraph::op::v0::NormalizeL2>(op);
if (!norm) {


@@ -84,7 +84,7 @@ public:
return false;
}
-static bool isSupportedOperation(const std::shared_ptr<ngraph::Node>& op, std::string& errorMessage) noexcept;
+static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;
bool canFuse(const MKLDNNNodePtr& node) const override;
private:


@@ -25,10 +25,10 @@ ngraph::pass::NormalizeL2FusionWithMax::NormalizeL2FusionWithMax() {
auto pow = std::make_shared<ngraph::opset4::Power>(input, exp);
auto axes = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto reduce_sum = std::make_shared<ngraph::opset4::ReduceSum>(pow, axes);
-auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(reduce_sum);
auto eps_const = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
-auto sqrt_max_eps = std::make_shared<ngraph::opset4::Maximum>(sqrt, eps_const);
-auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt_max_eps);
+auto max = std::make_shared<ngraph::opset4::Maximum>(reduce_sum, eps_const);
+auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(max);
+auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt);
ngraph::matcher_pass_callback matcher_pass_callback = [=](ngraph::pattern::Matcher& m) {
auto& pattern_to_output = m.get_pattern_value_map();
@@ -52,12 +52,14 @@ ngraph::pass::NormalizeL2FusionWithMax::NormalizeL2FusionWithMax() {
const auto eps_attr_value = eps_attr->cast_vector<float>()[0];
auto normalize_l2 = std::make_shared<ngraph::opset4::NormalizeL2>(data_input, axes_input, eps_attr_value, op::EpsMode::MAX);
+if (transformation_callback(normalize_l2))
+return false;
normalize_l2->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::copy_runtime_info({pattern_to_output.at(pow).get_node_shared_ptr(),
pattern_to_output.at(reduce_sum).get_node_shared_ptr(),
pattern_to_output.at(sqrt).get_node_shared_ptr(),
-pattern_to_output.at(sqrt_max_eps).get_node_shared_ptr(),
+pattern_to_output.at(max).get_node_shared_ptr(),
pattern_to_output.at(divide).get_node_shared_ptr()
},
normalize_l2);
@@ -79,10 +81,10 @@ ngraph::pass::NormalizeL2FusionWithAdd::NormalizeL2FusionWithAdd() {
auto pow = std::make_shared<ngraph::opset4::Power>(input, exp);
auto axes = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto reduce_sum = std::make_shared<ngraph::opset4::ReduceSum>(pow, axes);
-auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(reduce_sum);
auto eps_const = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
-auto sqrt_add_eps = std::make_shared<ngraph::opset4::Add>(sqrt, eps_const);
-auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt_add_eps);
+auto add = std::make_shared<ngraph::opset4::Add>(reduce_sum, eps_const);
+auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(add);
+auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt);
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
auto& pattern_to_output = m.get_pattern_value_map();
@@ -106,12 +108,14 @@ ngraph::pass::NormalizeL2FusionWithAdd::NormalizeL2FusionWithAdd() {
const auto eps_attr_value = op::util::has_constant_value<float>(exp_input, 2.0f);
auto normalize_l2 = std::make_shared<ngraph::opset4::NormalizeL2>(data_input, axes_input, eps_attr_value, op::EpsMode::ADD);
+if (transformation_callback(normalize_l2))
+return false;
normalize_l2->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::copy_runtime_info({pattern_to_output.at(pow).get_node_shared_ptr(),
pattern_to_output.at(reduce_sum).get_node_shared_ptr(),
pattern_to_output.at(sqrt).get_node_shared_ptr(),
-pattern_to_output.at(sqrt_add_eps).get_node_shared_ptr(),
+pattern_to_output.at(add).get_node_shared_ptr(),
pattern_to_output.at(divide).get_node_shared_ptr()
},
normalize_l2);
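
For illustration, a minimal self-contained sketch (not code from this patch; the shapes, eps value, and 2021.x-style include paths are assumptions) that builds the decomposition the updated pattern expects and runs the fixed pass over it:

#include <memory>
#include <ngraph/ngraph.hpp>
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/common_optimizations/normalize_l2_fusion.hpp>

using namespace ngraph;

int main() {
    // Decomposed L2 normalization: x / sqrt(max(reduce_sum(x^2, axes), eps))
    auto input = std::make_shared<opset4::Parameter>(element::f16, Shape{1, 2, 3});
    auto exp = opset4::Constant::create(element::f16, Shape{}, {2.0f});
    auto pow = std::make_shared<opset4::Power>(input, exp);
    auto axes = opset4::Constant::create(element::i64, Shape{2}, {0, 1});
    auto reduce_sum = std::make_shared<opset4::ReduceSum>(pow, axes);
    auto eps = opset4::Constant::create(element::f16, Shape{}, {1e-4f});
    auto max = std::make_shared<opset4::Maximum>(reduce_sum, eps);
    auto sqrt = std::make_shared<opset4::Sqrt>(max);
    auto divide = std::make_shared<opset4::Divide>(input, sqrt);
    auto f = std::make_shared<Function>(NodeVector{divide}, ParameterVector{input});

    // The matcher replaces the subgraph above with a single
    // opset4::NormalizeL2(input, axes, eps, EpsMode::MAX) node.
    pass::Manager manager;
    manager.register_pass<pass::NormalizeL2FusionWithMax>();
    manager.run_passes(f);
    return 0;
}

A backend can still opt out per node through a PassConfig callback, which is exactly what the plugin change above does via MKLDNNNormalizeL2Node::isSupportedOperation.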


@@ -27,10 +27,10 @@ TEST(TransformationTests, NormalizeL2FusionWithMax) {
auto pow = std::make_shared<ngraph::opset4::Power>(input, exp);
auto axes_const = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {0, 1});
auto reduce_sum = std::make_shared<ngraph::opset4::ReduceSum>(pow, axes_const);
-auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(reduce_sum);
auto eps_const = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {eps_value});
-auto sqrt_max_eps = std::make_shared<ngraph::opset4::Maximum>(sqrt, eps_const);
-auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt_max_eps);
+auto max = std::make_shared<ngraph::opset4::Maximum>(reduce_sum, eps_const);
+auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(max);
+auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{input});
@@ -62,10 +62,10 @@ TEST(TransformationTests, NormalizeL2FusionWithMaxIncorrectExp) {
auto pow = std::make_shared<ngraph::opset4::Power>(input, exp);
auto axes_const = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
auto reduce_sum = std::make_shared<ngraph::opset4::ReduceSum>(pow, axes_const);
-auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(reduce_sum);
auto eps_const = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {eps_value});
-auto sqrt_max_eps = std::make_shared<ngraph::opset4::Maximum>(sqrt, eps_const);
-auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt_max_eps);
+auto max = std::make_shared<ngraph::opset4::Maximum>(reduce_sum, eps_const);
+auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(max);
+auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{input});
@@ -81,10 +81,10 @@ TEST(TransformationTests, NormalizeL2FusionWithMaxIncorrectExp) {
auto pow = std::make_shared<ngraph::opset4::Power>(input, exp);
auto axes_const = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
auto reduce_sum = std::make_shared<ngraph::opset4::ReduceSum>(pow, axes_const);
-auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(reduce_sum);
auto eps_const = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {eps_value});
-auto sqrt_max_eps = std::make_shared<ngraph::opset4::Maximum>(sqrt, eps_const);
-auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt_max_eps);
+auto max = std::make_shared<ngraph::opset4::Maximum>(reduce_sum, eps_const);
+auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(max);
+auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{input});
}
@@ -101,10 +101,10 @@ TEST(TransformationTests, NormalizeL2FusionWithMaxIncorrectEpsValueShape) {
auto pow = std::make_shared<ngraph::opset4::Power>(input, exp);
auto axes_const = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
auto reduce_sum = std::make_shared<ngraph::opset4::ReduceSum>(pow, axes_const);
-auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(reduce_sum);
auto eps_const = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{2}, {1, 2});
-auto sqrt_max_eps = std::make_shared<ngraph::opset4::Maximum>(sqrt, eps_const);
-auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt_max_eps);
+auto max = std::make_shared<ngraph::opset4::Maximum>(reduce_sum, eps_const);
+auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(max);
+auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{input});
@@ -120,10 +120,10 @@ TEST(TransformationTests, NormalizeL2FusionWithMaxIncorrectEpsValueShape) {
auto pow = std::make_shared<ngraph::opset4::Power>(input, exp);
auto axes_const = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
auto reduce_sum = std::make_shared<ngraph::opset4::ReduceSum>(pow, axes_const);
-auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(reduce_sum);
auto eps_const = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{2}, {1, 2});
-auto sqrt_max_eps = std::make_shared<ngraph::opset4::Maximum>(sqrt, eps_const);
-auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt_max_eps);
+auto max = std::make_shared<ngraph::opset4::Maximum>(reduce_sum, eps_const);
+auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(max);
+auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{input});
}
@@ -141,10 +141,10 @@ TEST(TransformationTests, NormalizeL2FusionWithAdd) {
auto pow = std::make_shared<ngraph::opset4::Power>(input, exp);
auto axes_const = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {0, 1});
auto reduce_sum = std::make_shared<ngraph::opset4::ReduceSum>(pow, axes_const);
-auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(reduce_sum);
auto eps_const = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {eps_value});
-auto sqrt_add_eps = std::make_shared<ngraph::opset4::Add>(sqrt, eps_const);
-auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt_add_eps);
+auto add = std::make_shared<ngraph::opset4::Add>(reduce_sum, eps_const);
+auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(add);
+auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{input});
@@ -176,10 +176,10 @@ TEST(TransformationTests, NormalizeL2FusionWithAddIncorrectExp) {
auto pow = std::make_shared<ngraph::opset4::Power>(input, exp);
auto axes_const = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {0, 1});
auto reduce_sum = std::make_shared<ngraph::opset4::ReduceSum>(pow, axes_const);
-auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(reduce_sum);
auto eps_const = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {eps_value});
-auto sqrt_add_eps = std::make_shared<ngraph::opset4::Add>(sqrt, eps_const);
-auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt_add_eps);
+auto add = std::make_shared<ngraph::opset4::Add>(reduce_sum, eps_const);
+auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(add);
+auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{input});
@@ -196,10 +196,10 @@ TEST(TransformationTests, NormalizeL2FusionWithAddIncorrectExp) {
auto pow = std::make_shared<ngraph::opset4::Power>(input, exp);
auto axes_const = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {0, 1});
auto reduce_sum = std::make_shared<ngraph::opset4::ReduceSum>(pow, axes_const);
-auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(reduce_sum);
auto eps_const = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {eps_value});
-auto sqrt_add_eps = std::make_shared<ngraph::opset4::Add>(sqrt, eps_const);
-auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt_add_eps);
+auto add = std::make_shared<ngraph::opset4::Add>(reduce_sum, eps_const);
+auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(add);
+auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{input});
}
@@ -216,10 +216,10 @@ TEST(TransformationTests, NormalizeL2FusionWithAddIncorrectEpsValueShape) {
auto pow = std::make_shared<ngraph::opset4::Power>(input, exp);
auto axes_const = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
auto reduce_sum = std::make_shared<ngraph::opset4::ReduceSum>(pow, axes_const);
-auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(reduce_sum);
auto eps_const = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{2}, {1, 2});
-auto sqrt_add_eps = std::make_shared<ngraph::opset4::Add>(sqrt, eps_const);
-auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt_add_eps);
+auto add = std::make_shared<ngraph::opset4::Add>(reduce_sum, eps_const);
+auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(add);
+auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{input});
@@ -235,10 +235,10 @@ TEST(TransformationTests, NormalizeL2FusionWithAddIncorrectEpsValueShape) {
auto pow = std::make_shared<ngraph::opset4::Power>(input, exp);
auto axes_const = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
auto reduce_sum = std::make_shared<ngraph::opset4::ReduceSum>(pow, axes_const);
-auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(reduce_sum);
auto eps_const = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{2}, {1, 2});
-auto sqrt_add_eps = std::make_shared<ngraph::opset4::Add>(sqrt, eps_const);
-auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt_add_eps);
+auto add = std::make_shared<ngraph::opset4::Add>(reduce_sum, eps_const);
+auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(add);
+auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{input});
}


@@ -32,11 +32,13 @@ namespace ngraph
///
/// \param[in] value The input tensor.
/// \param[in] reduction_axes The axes along which we calculate norm.
+/// \param[in] keep_dims The flag indicates if axes will be removed or kept.
///
/// \return L-0 norm of value. The output sub-graph is composed of v1 ops.
///
std::shared_ptr<Node> l0_norm(const Output<Node>& value,
-const Output<Node>& reduction_axes);
+const Output<Node>& reduction_axes,
+bool keep_dims = false);
/// \brief Calculates L-1 norm of a value.
///
@@ -45,12 +47,14 @@ namespace ngraph
/// \param[in] value The input tensor.
/// \param[in] reduction_axes The axes along which we calculate norm.
/// \param[in] bias The bias added to the calculated sum.
+/// \param[in] keep_dims The flag indicates if axes will be removed or kept.
///
/// \return L-1 norm of value. The output sub-graph is composed of v1 ops.
///
std::shared_ptr<Node> l1_norm(const Output<Node>& value,
const Output<Node>& reduction_axes,
-float bias = 0.f);
+float bias = 0.f,
+bool keep_dims = false);
/// \brief Calculates L-2 norm of input tensor.
///
@@ -77,13 +81,15 @@ namespace ngraph
/// \param[in] reduction_axes The axes along which we calculate norm.
/// \param[in] p_norm The p norm to calculate.
/// \param[in] bias The bias added to the calculated sum.
+/// \param[in] keep_dims The flag indicates if axes will be removed or kept.
///
/// \return L-p norm of value. The output sub-graph is composed of v1 ops.
///
std::shared_ptr<Node> lp_norm(const Output<Node>& value,
const Output<Node>& reduction_axes,
std::size_t p_norm = 2,
-float bias = 0.f);
+float bias = 0.f,
+bool keep_dims = false);
} // namespace opset1
} // namespace builder
} // namespace ngraph


@@ -29,7 +29,8 @@ namespace ngraph
shared_ptr<Node> lp_norm(const Output<Node>& value,
size_t p_norm,
const Output<Node>& reduction_axes,
-float bias)
+float bias,
+bool keep_dims)
{
// In general "entrywise" lp-norm for matrix `A` is defined as following double
// sum:
@@ -40,7 +41,8 @@ namespace ngraph
// Get inner part of equation: abs_values^p_node, then sum over reduction_axes.
shared_ptr<Node> values{make_shared<ngraph::opset1::Power>(abs_values, p_node)};
-values = make_shared<ngraph::opset1::ReduceSum>(values, reduction_axes, false);
+values =
+make_shared<ngraph::opset1::ReduceSum>(values, reduction_axes, keep_dims);
shared_ptr<Node> bias_node{ngraph::opset1::Constant::create(
values->get_element_type(), Shape{}, {bias})};
@@ -58,7 +60,8 @@ namespace ngraph
} // namespace detail
shared_ptr<Node> builder::opset1::l0_norm(const Output<Node>& value,
-const Output<Node>& reduction_axes)
+const Output<Node>& reduction_axes,
+bool keep_dims)
{
// L0 norm returns number of elements different from zero.
const shared_ptr<Node> zero_node{
@@ -68,16 +71,18 @@ namespace ngraph
const shared_ptr<Node> non_zero_values = make_shared<ngraph::opset1::Convert>(
make_shared<ngraph::opset1::NotEqual>(value, zero_node), value.get_element_type());
-return make_shared<ngraph::opset1::ReduceSum>(non_zero_values, reduction_axes, false)
+return make_shared<ngraph::opset1::ReduceSum>(
+non_zero_values, reduction_axes, keep_dims)
->add_provenance_group_members_above({value});
}
shared_ptr<Node> builder::opset1::l1_norm(const Output<Node>& value,
const Output<Node>& reduction_axes,
-float bias)
+float bias,
+bool keep_dims)
{
const shared_ptr<Node> values{make_shared<ngraph::opset1::ReduceSum>(
-make_shared<ngraph::opset1::Abs>(value), reduction_axes, false)};
+make_shared<ngraph::opset1::Abs>(value), reduction_axes, keep_dims)};
const shared_ptr<Node> bias_node{
ngraph::opset1::Constant::create(values->get_element_type(), Shape{}, {bias})};
@@ -92,8 +97,10 @@ namespace ngraph
BiasMode bias_mode,
bool keep_dims)
{
-shared_ptr<Node> values{make_shared<ngraph::opset1::ReduceSum>(
-make_shared<ngraph::opset1::Multiply>(value, value), reduction_axes, keep_dims)};
+shared_ptr<Node> pow = make_shared<ngraph::opset1::Power>(
+value, make_shared<ngraph::opset1::Constant>(value.get_element_type(), Shape{}, 2));
+shared_ptr<Node> values{
+make_shared<ngraph::opset1::ReduceSum>(pow, reduction_axes, keep_dims)};
shared_ptr<Node> bias_node{
ngraph::opset1::Constant::create(values->get_element_type(), Shape{}, {bias})};
@@ -117,27 +124,28 @@ namespace ngraph
shared_ptr<Node> builder::opset1::lp_norm(const Output<Node>& value,
const Output<Node>& reduction_axes,
size_t p_norm,
-float bias)
+float bias,
+bool keep_dims)
{
// The number of non-zero elements
if (p_norm == 0)
{
-return opset1::l0_norm(value, reduction_axes);
+return opset1::l0_norm(value, reduction_axes, keep_dims);
}
// sum of absolute values.
else if (p_norm == 1)
{
-return opset1::l1_norm(value, reduction_axes, bias);
+return opset1::l1_norm(value, reduction_axes, bias, keep_dims);
}
// sqrt of sum of squares - Euclidean norm
else if (p_norm == 2)
{
-return opset1::l2_norm(value, reduction_axes, bias);
+return opset1::l2_norm(value, reduction_axes, bias, BiasMode::ADD, keep_dims);
}
// generic case
else
{
-return detail::opset1::lp_norm(value, p_norm, reduction_axes, bias);
+return detail::opset1::lp_norm(value, p_norm, reduction_axes, bias, keep_dims);
}
}
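
A hedged sketch of what the builder change means for callers (the exact tail of l2_norm, bias handling followed by Sqrt, is assumed rather than shown in the hunk above): squaring through Power instead of Multiply and keeping the reduced axis yields the Power -> ReduceSum -> Add/Maximum -> Sqrt shape that NormalizeL2Fusion recognizes.

#include <memory>
#include <ngraph/ngraph.hpp>
#include <ngraph/builder/norm.hpp>
#include <ngraph/opsets/opset1.hpp>

using namespace ngraph;

int main() {
    // Roughly: norm = Sqrt(ReduceSum(Power(data, 2), axes, keep_dims) + bias)
    auto data = std::make_shared<opset1::Parameter>(element::f32, Shape{1, 8});
    auto axes = opset1::Constant::create(element::i64, Shape{1}, {1});
    auto norm = builder::opset1::l2_norm(data, axes, /*bias=*/1e-6f,
                                         builder::BiasMode::ADD, /*keep_dims=*/true);
    auto f = std::make_shared<Function>(NodeVector{norm}, ParameterVector{data});
    return 0;
}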


@@ -30,7 +30,6 @@ namespace ngraph
const auto data_shape = data.get_partial_shape();
const auto data_rank = data_shape.rank();
-const auto data_rank_value = data_rank.get_length();
const std::int64_t p_norm{node.get_attribute_value<std::int64_t>("p", 2)};
const std::int64_t axis{node.get_attribute_value<std::int64_t>("axis", -1)};
@@ -46,23 +45,7 @@ namespace ngraph
const auto normalize_axis_const =
default_opset::Constant::create(element::i64, {}, {normalize_axis});
std::shared_ptr<ngraph::Node> norm = ngraph::builder::opset1::lp_norm(
-data, normalize_axis_const, static_cast<std::size_t>(p_norm));
-const auto target_shape = std::make_shared<default_opset::ShapeOf>(data);
-// Create a default axes order matching the data tensor rank and erase the
-// element at the 'normalize_axis' position. The erased element indicates the
-// axis
-// along which the data should be broadcasted.
-std::vector<size_t> axes_values(data_rank_value);
-std::iota(axes_values.begin(), axes_values.end(), 0);
-axes_values.erase(axes_values.begin() + normalize_axis);
-const auto axes_mapping = default_opset::Constant::create(
-element::i64, Shape{axes_values.size()}, axes_values);
-norm = std::make_shared<default_opset::Broadcast>(
-norm, target_shape, axes_mapping);
+data, normalize_axis_const, static_cast<std::size_t>(p_norm), 0.0f, true);
return {std::make_shared<default_opset::Divide>(data, norm)};
}
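
To make the shape reasoning behind the simplified LpNormalization lowering concrete, a small sketch (assumed shapes and names, not importer code): with keep_dims=true the reduced axis survives as a singleton dimension, so the final Divide broadcasts the norm over the data and the removed ShapeOf/axes-mapping/Broadcast block becomes unnecessary.

#include <iostream>
#include <memory>
#include <ngraph/ngraph.hpp>
#include <ngraph/builder/norm.hpp>
#include <ngraph/opsets/opset1.hpp>

using namespace ngraph;

int main() {
    auto data = std::make_shared<opset1::Parameter>(element::f32, Shape{2, 3, 4});
    auto axis = opset1::Constant::create(element::i64, Shape{}, {1});
    // p = 2, bias = 0, keep_dims = true: the norm keeps axis 1 as a size-1 dimension.
    auto norm = builder::opset1::lp_norm(data, axis, /*p_norm=*/2, /*bias=*/0.0f,
                                         /*keep_dims=*/true);
    // {2, 3, 4} / {2, 1, 4} broadcasts back to {2, 3, 4} without an explicit Broadcast.
    auto out = std::make_shared<opset1::Divide>(data, norm);
    std::cout << norm->get_shape() << " -> " << out->get_shape() << std::endl;
    return 0;
}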