[ONNX] use Reshape instead of Broadcast in InstanceNorm (#5148)

Mateusz Tabaka 2021-04-12 12:03:00 +02:00 committed by GitHub
parent 6b777b643d
commit 34385eb45d
6 changed files with 54 additions and 56 deletions

@@ -71,25 +71,9 @@ namespace ngraph
const auto conv_shape = std::make_shared<default_opset::ShapeOf>(ng_conv);
const auto conv_rank = std::make_shared<default_opset::ShapeOf>(conv_shape);
-// Prepare tail shape (rank = conv.rank - 2): [1, 1, 1, 1, ... ]
-const auto one_const =
-    default_opset::Constant::create(element::i64, Shape{1}, {1});
-const auto two_const =
-    default_opset::Constant::create(element::i64, Shape{1}, {2});
-const auto tail_shape_rank =
-    std::make_shared<default_opset::Subtract>(conv_rank, two_const);
-const auto tail_shape =
-    std::make_shared<default_opset::Broadcast>(one_const, tail_shape_rank);
-// Construct new bias shape: [1, C, 1, 1, ... ]
-const auto C_dim = std::make_shared<default_opset::ShapeOf>(bias);
-const auto bias_shape = std::make_shared<default_opset::Concat>(
-    OutputVector{one_const, C_dim, tail_shape}, 0);
-const auto reshaped_bias =
-    std::make_shared<default_opset::Reshape>(bias, bias_shape, false);
-return {std::make_shared<default_opset::Add>(ng_conv, reshaped_bias)};
+return {std::make_shared<default_opset::Add>(
+    ng_conv,
+    reshape::reshape_channel_shaped_node_to_nchw(bias, conv_rank))};
}
} // namespace
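
The hunk above drops the inline Broadcast-based shape construction and reuses the shared reshape helper for the convolution bias. As a standalone plain-C++ illustration (not nGraph code; the function name and the example shapes are mine), this is why a channel-shaped {C} bias cannot be added to the NCHW convolution output as-is: numpy-style broadcasting aligns shapes from the right, so {C} would line up with the W axis, while {1, C, 1, 1} lines up with the channel axis.

#include <cstdint>
#include <iostream>
#include <vector>

// True if `operand` broadcasts against `target` under right-alignment rules
// (assumes operand has no more dimensions than target).
bool broadcastable(std::vector<int64_t> operand, const std::vector<int64_t>& target)
{
    operand.insert(operand.begin(), target.size() - operand.size(), 1); // left-pad with 1s
    for (std::size_t i = 0; i < target.size(); ++i)
        if (operand[i] != 1 && operand[i] != target[i])
            return false;
    return true;
}

int main()
{
    const std::vector<int64_t> conv_out{2, 16, 8, 8}; // {N, C, H, W}
    std::cout << std::boolalpha
              << broadcastable({16}, conv_out) << '\n'           // false: {C} aligns with W
              << broadcastable({1, 16, 1, 1}, conv_out) << '\n'; // true: scales per channel
}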

@@ -18,6 +18,7 @@
#include "ngraph/partial_shape.hpp"
#include "op/instance_norm.hpp"
#include "utils/common.hpp"
+#include "utils/reshape.hpp"
namespace ngraph
{
@@ -84,22 +85,15 @@ namespace ngraph
auto mvn = std::make_shared<default_opset::MVN>(
data, reduction_axes, true, epsilon, ngraph::op::MVNEpsMode::INSIDE_SQRT);
-const auto data_shape_node = std::make_shared<default_opset::ShapeOf>(data);
-// Broadcast preserving channel dimension
-scale = std::make_shared<default_opset::Broadcast>(
-    scale,
-    data_shape_node,
-    std::make_shared<default_opset::Constant>(element::i64, Shape{1}, 1));
-bias = std::make_shared<default_opset::Broadcast>(
-    bias,
-    data_shape_node,
-    std::make_shared<default_opset::Constant>(element::i64, Shape{1}, 1));
+const auto mvn_shape = std::make_shared<default_opset::ShapeOf>(mvn);
+const auto mvn_rank = std::make_shared<default_opset::ShapeOf>(mvn_shape);
// scale * mvn + bias
std::shared_ptr<ngraph::Node> result =
-    std::make_shared<default_opset::Multiply>(mvn, scale);
-result = std::make_shared<default_opset::Add>(result, bias);
+    std::make_shared<default_opset::Multiply>(
+        mvn, reshape::reshape_channel_shaped_node_to_nchw(scale, mvn_rank));
+result = std::make_shared<default_opset::Add>(
+    result, reshape::reshape_channel_shaped_node_to_nchw(bias, mvn_rank));
return {result};
}
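
For reference, the rewritten InstanceNorm subgraph computes result[n][c][h][w] = scale[c] * mvn[n][c][h][w] + bias[c] once MVN has normalized each (n, c) slice; reshaping scale and bias to {1, C, 1, ..., 1} is what lets the Multiply and Add broadcast per channel. A plain-C++ reference loop that spells that broadcast out (illustrative only; the function name is mine and the MVN output is taken as given):

#include <cstddef>
#include <vector>

// In-place per-channel scale/shift over a flattened NCHW buffer.
void scale_shift_nchw(std::vector<float>& mvn,           // N*C*H*W, already normalized
                      const std::vector<float>& scale,   // length C
                      const std::vector<float>& bias,    // length C
                      std::size_t N, std::size_t C, std::size_t H, std::size_t W)
{
    for (std::size_t n = 0; n < N; ++n)
        for (std::size_t c = 0; c < C; ++c)
            for (std::size_t i = 0; i < H * W; ++i)
            {
                const std::size_t idx = (n * C + c) * H * W + i;
                mvn[idx] = scale[c] * mvn[idx] + bias[c];
            }
}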

@@ -84,15 +84,35 @@ namespace ngraph
std::shared_ptr<ngraph::Node> result =
std::make_shared<default_opset::Reshape>(mvn, data_shape_node, true);
-const auto& rank = data.get_partial_shape().rank();
-NGRAPH_CHECK(rank.is_static());
-auto data_rank_size = rank.get_length();
+const auto& scale_shape = scale.get_partial_shape();
+NGRAPH_CHECK(scale_shape.rank().is_static());
+auto scale_rank = scale_shape.rank().get_length();
-result = std::make_shared<default_opset::Multiply>(
-    result,
-    reshape::reshape_channel_shaped_node_to_nchw(scale, data_rank_size));
-result = std::make_shared<default_opset::Add>(
-    result, reshape::reshape_channel_shaped_node_to_nchw(bias, data_rank_size));
+const auto& bias_shape = bias.get_partial_shape();
+NGRAPH_CHECK(bias_shape.rank().is_static());
+auto bias_rank = bias_shape.rank().get_length();
+const auto data_rank =
+    std::make_shared<default_opset::ShapeOf>(data_shape_node);
+if (scale_rank == 1)
+{
+    result = std::make_shared<default_opset::Multiply>(
+        result, reshape::reshape_channel_shaped_node_to_nchw(scale, data_rank));
+}
+else
+{
+    result = std::make_shared<default_opset::Multiply>(result, scale);
+}
+if (bias_rank == 1)
+{
+    result = std::make_shared<default_opset::Add>(
+        result, reshape::reshape_channel_shaped_node_to_nchw(bias, data_rank));
+}
+else
+{
+    result = std::make_shared<default_opset::Add>(result, bias);
+}
return {result};
}
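
The new branches above only insert the channel reshape when scale or bias is a 1-D vector of length C; an operand that already carries two or more dimensions is multiplied or added directly and left to the usual broadcasting rules. A small plain-C++ sketch of that decision (the function name is mine, not part of the importer):

#include <cstddef>
#include <cstdint>
#include <vector>

// Shape to feed the elementwise op with: a 1-D {C} operand is rewritten to
// {1, C, 1, ..., 1} so it broadcasts per channel; anything of higher rank is
// passed through unchanged.
std::vector<int64_t> shape_for_elementwise_operand(const std::vector<int64_t>& operand_shape,
                                                   std::size_t data_rank)
{
    if (operand_shape.size() != 1)
        return operand_shape;            // rank >= 2: use as-is
    std::vector<int64_t> nchw(data_rank, 1);
    nchw[1] = operand_shape.front();     // {C} -> {1, C, 1, ..., 1}
    return nchw;
}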

@@ -104,21 +104,22 @@ namespace ngraph
Output<ngraph::Node>
reshape_channel_shaped_node_to_nchw(const Output<ngraph::Node>& node,
-                                    size_t expected_rank)
+                                    const Output<ngraph::Node>& expected_rank)
{
-    const auto& rank = node.get_partial_shape().rank();
-    NGRAPH_CHECK(rank.is_static());
-    size_t node_rank = rank.get_length();
-    if (node_rank == 1)
-    {
-        // reshape the node with shape {C} to {1, C, 1, 1, ..., 1}
-        std::vector<size_t> reshape_pattern_values(expected_rank, 1U);
-        reshape_pattern_values[1] = node.get_shape().front();
-        const auto reshape_pattern = default_opset::Constant::create(
-            element::u64, Shape{reshape_pattern_values.size()}, reshape_pattern_values);
-        return std::make_shared<default_opset::Reshape>(node, reshape_pattern, false);
-    }
-    return node;
+    // Prepare tail shape (rank = conv.rank - 2): [1, 1, 1, 1, ... ]
+    const auto one_const = default_opset::Constant::create(element::i64, Shape{1}, {1});
+    const auto two_const = default_opset::Constant::create(element::i64, Shape{1}, {2});
+    const auto tail_shape_rank =
+        std::make_shared<default_opset::Subtract>(expected_rank, two_const);
+    const auto tail_shape =
+        std::make_shared<default_opset::Broadcast>(one_const, tail_shape_rank);
+    // Construct new bias shape: [1, C, 1, 1, ... ]
+    const auto C_dim = std::make_shared<default_opset::ShapeOf>(node);
+    const auto new_shape = std::make_shared<default_opset::Concat>(
+        OutputVector{one_const, C_dim, tail_shape}, 0);
+    return std::make_shared<default_opset::Reshape>(node, new_shape, false);
}
} // namespace reshape
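
The helper now assembles the target shape in the graph as Concat([1], ShapeOf(node), Broadcast(1, rank - 2)) instead of filling a constant from a size_t rank known at import time. A plain-C++ side-by-side of the two constructions (names are mine) showing they yield the same {1, C, 1, ..., 1} pattern for any rank >= 2; the concat form just takes the rank as a runtime value:

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Removed version: rank is a size_t fixed at import time.
std::vector<int64_t> pattern_static(int64_t C, std::size_t expected_rank)
{
    std::vector<int64_t> p(expected_rank, 1);
    p[1] = C;                              // {1, C, 1, ..., 1}
    return p;
}

// New version: mirrors one_const / C_dim / tail_shape concatenated together.
std::vector<int64_t> pattern_concat(int64_t C, int64_t rank)
{
    std::vector<int64_t> p{1};             // one_const
    p.push_back(C);                        // C_dim = ShapeOf(node)
    p.insert(p.end(), rank - 2, 1);        // tail_shape = Broadcast(1, rank - 2)
    return p;
}

int main()
{
    for (int64_t rank = 2; rank <= 5; ++rank)
        assert(pattern_static(16, rank) == pattern_concat(16, rank));
}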

@@ -63,7 +63,7 @@ namespace ngraph
///
Output<ngraph::Node>
reshape_channel_shaped_node_to_nchw(const Output<ngraph::Node>& node,
-                                        size_t expected_rank);
+                                        const Output<ngraph::Node>& expected_rank);
} // namespace reshape
} // namespace onnx_import

@@ -59,7 +59,6 @@ graph {
dim {
dim_value: 2
}
}
}
}