[ONNX] use Reshape instead of Broadcast in InstanceNorm (#5148)
This commit is contained in:
parent
6b777b643d
commit
34385eb45d
@ -71,25 +71,9 @@ namespace ngraph
|
||||
const auto conv_shape = std::make_shared<default_opset::ShapeOf>(ng_conv);
|
||||
const auto conv_rank = std::make_shared<default_opset::ShapeOf>(conv_shape);
|
||||
|
||||
// Prepare tail shape (rank = conv.rank - 2): [1, 1, 1, 1, ... ]
|
||||
const auto one_const =
|
||||
default_opset::Constant::create(element::i64, Shape{1}, {1});
|
||||
const auto two_const =
|
||||
default_opset::Constant::create(element::i64, Shape{1}, {2});
|
||||
const auto tail_shape_rank =
|
||||
std::make_shared<default_opset::Subtract>(conv_rank, two_const);
|
||||
const auto tail_shape =
|
||||
std::make_shared<default_opset::Broadcast>(one_const, tail_shape_rank);
|
||||
|
||||
// Construct new bias shape: [1, C, 1, 1, ... ]
|
||||
const auto C_dim = std::make_shared<default_opset::ShapeOf>(bias);
|
||||
const auto bias_shape = std::make_shared<default_opset::Concat>(
|
||||
OutputVector{one_const, C_dim, tail_shape}, 0);
|
||||
|
||||
const auto reshaped_bias =
|
||||
std::make_shared<default_opset::Reshape>(bias, bias_shape, false);
|
||||
|
||||
return {std::make_shared<default_opset::Add>(ng_conv, reshaped_bias)};
|
||||
return {std::make_shared<default_opset::Add>(
|
||||
ng_conv,
|
||||
reshape::reshape_channel_shaped_node_to_nchw(bias, conv_rank))};
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
@ -18,6 +18,7 @@
|
||||
#include "ngraph/partial_shape.hpp"
|
||||
#include "op/instance_norm.hpp"
|
||||
#include "utils/common.hpp"
|
||||
#include "utils/reshape.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
@ -84,22 +85,15 @@ namespace ngraph
|
||||
auto mvn = std::make_shared<default_opset::MVN>(
|
||||
data, reduction_axes, true, epsilon, ngraph::op::MVNEpsMode::INSIDE_SQRT);
|
||||
|
||||
const auto data_shape_node = std::make_shared<default_opset::ShapeOf>(data);
|
||||
|
||||
// Broadcast preserving channel dimension
|
||||
scale = std::make_shared<default_opset::Broadcast>(
|
||||
scale,
|
||||
data_shape_node,
|
||||
std::make_shared<default_opset::Constant>(element::i64, Shape{1}, 1));
|
||||
bias = std::make_shared<default_opset::Broadcast>(
|
||||
bias,
|
||||
data_shape_node,
|
||||
std::make_shared<default_opset::Constant>(element::i64, Shape{1}, 1));
|
||||
const auto mvn_shape = std::make_shared<default_opset::ShapeOf>(mvn);
|
||||
const auto mvn_rank = std::make_shared<default_opset::ShapeOf>(mvn_shape);
|
||||
|
||||
// scale * mvn + bias
|
||||
std::shared_ptr<ngraph::Node> result =
|
||||
std::make_shared<default_opset::Multiply>(mvn, scale);
|
||||
result = std::make_shared<default_opset::Add>(result, bias);
|
||||
std::make_shared<default_opset::Multiply>(
|
||||
mvn, reshape::reshape_channel_shaped_node_to_nchw(scale, mvn_rank));
|
||||
result = std::make_shared<default_opset::Add>(
|
||||
result, reshape::reshape_channel_shaped_node_to_nchw(bias, mvn_rank));
|
||||
|
||||
return {result};
|
||||
}
|
||||
|
@ -84,15 +84,35 @@ namespace ngraph
|
||||
std::shared_ptr<ngraph::Node> result =
|
||||
std::make_shared<default_opset::Reshape>(mvn, data_shape_node, true);
|
||||
|
||||
const auto& rank = data.get_partial_shape().rank();
|
||||
NGRAPH_CHECK(rank.is_static());
|
||||
auto data_rank_size = rank.get_length();
|
||||
const auto& scale_shape = scale.get_partial_shape();
|
||||
NGRAPH_CHECK(scale_shape.rank().is_static());
|
||||
auto scale_rank = scale_shape.rank().get_length();
|
||||
|
||||
result = std::make_shared<default_opset::Multiply>(
|
||||
result,
|
||||
reshape::reshape_channel_shaped_node_to_nchw(scale, data_rank_size));
|
||||
result = std::make_shared<default_opset::Add>(
|
||||
result, reshape::reshape_channel_shaped_node_to_nchw(bias, data_rank_size));
|
||||
const auto& bias_shape = bias.get_partial_shape();
|
||||
NGRAPH_CHECK(bias_shape.rank().is_static());
|
||||
auto bias_rank = bias_shape.rank().get_length();
|
||||
|
||||
const auto data_rank =
|
||||
std::make_shared<default_opset::ShapeOf>(data_shape_node);
|
||||
|
||||
if (scale_rank == 1)
|
||||
{
|
||||
result = std::make_shared<default_opset::Multiply>(
|
||||
result, reshape::reshape_channel_shaped_node_to_nchw(scale, data_rank));
|
||||
}
|
||||
else
|
||||
{
|
||||
result = std::make_shared<default_opset::Multiply>(result, scale);
|
||||
}
|
||||
if (bias_rank == 1)
|
||||
{
|
||||
result = std::make_shared<default_opset::Add>(
|
||||
result, reshape::reshape_channel_shaped_node_to_nchw(bias, data_rank));
|
||||
}
|
||||
else
|
||||
{
|
||||
result = std::make_shared<default_opset::Add>(result, bias);
|
||||
}
|
||||
|
||||
return {result};
|
||||
}
|
||||
|
@ -104,21 +104,22 @@ namespace ngraph
|
||||
|
||||
Output<ngraph::Node>
|
||||
reshape_channel_shaped_node_to_nchw(const Output<ngraph::Node>& node,
|
||||
size_t expected_rank)
|
||||
const Output<ngraph::Node>& expected_rank)
|
||||
{
|
||||
const auto& rank = node.get_partial_shape().rank();
|
||||
NGRAPH_CHECK(rank.is_static());
|
||||
size_t node_rank = rank.get_length();
|
||||
if (node_rank == 1)
|
||||
{
|
||||
// reshape the node with shape {C} to {1, C, 1, 1, ..., 1}
|
||||
std::vector<size_t> reshape_pattern_values(expected_rank, 1U);
|
||||
reshape_pattern_values[1] = node.get_shape().front();
|
||||
const auto reshape_pattern = default_opset::Constant::create(
|
||||
element::u64, Shape{reshape_pattern_values.size()}, reshape_pattern_values);
|
||||
return std::make_shared<default_opset::Reshape>(node, reshape_pattern, false);
|
||||
}
|
||||
return node;
|
||||
// Prepare tail shape (rank = conv.rank - 2): [1, 1, 1, 1, ... ]
|
||||
const auto one_const = default_opset::Constant::create(element::i64, Shape{1}, {1});
|
||||
const auto two_const = default_opset::Constant::create(element::i64, Shape{1}, {2});
|
||||
const auto tail_shape_rank =
|
||||
std::make_shared<default_opset::Subtract>(expected_rank, two_const);
|
||||
const auto tail_shape =
|
||||
std::make_shared<default_opset::Broadcast>(one_const, tail_shape_rank);
|
||||
|
||||
// Construct new bias shape: [1, C, 1, 1, ... ]
|
||||
const auto C_dim = std::make_shared<default_opset::ShapeOf>(node);
|
||||
const auto new_shape = std::make_shared<default_opset::Concat>(
|
||||
OutputVector{one_const, C_dim, tail_shape}, 0);
|
||||
|
||||
return std::make_shared<default_opset::Reshape>(node, new_shape, false);
|
||||
}
|
||||
|
||||
} // namespace reshape
|
||||
|
@ -63,7 +63,7 @@ namespace ngraph
|
||||
///
|
||||
Output<ngraph::Node>
|
||||
reshape_channel_shaped_node_to_nchw(const Output<ngraph::Node>& node,
|
||||
size_t expected_rank);
|
||||
const Output<ngraph::Node>& expected_rank);
|
||||
|
||||
} // namespace reshape
|
||||
} // namespace onnx_import
|
||||
|
@ -59,7 +59,6 @@ graph {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user