[ONNX][Code refactor] ONNX GroupNormalization - Optimize getting inputs

* Optimize getting inputs and use more const

* Fix typo
This commit is contained in:
Katarzyna Mitrus 2023-10-31 08:19:11 +01:00 committed by GitHub
parent b13cb8ce12
commit 1384471849
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -12,30 +12,34 @@ namespace onnx_import {
namespace op {
namespace set_1 {
OutputVector group_normalization(const Node& node) {
const auto data = node.get_ng_inputs().at(0); // Shape [N, C, ...]
auto scale = node.get_ng_inputs().at(1); // Shape [num_groups]
auto bias = node.get_ng_inputs().at(2); // Shape [num_groups]
const auto inputs = node.get_ng_inputs();
OPENVINO_ASSERT(inputs.size() == 3);
auto eps = node.get_attribute_value<float>("epsilon", 1e-05f);
auto num_groups = node.get_attribute_value<int64_t>("num_groups");
const auto& data = inputs[0]; // Shape [N, C, ...]
const auto& scale = inputs[1]; // Shape [num_groups]
const auto& bias = inputs[2]; // Shape [num_groups]
auto zero = default_opset::Constant::create(element::i64, Shape{1}, {0});
auto one = default_opset::Constant::create(element::i64, Shape{1}, {1});
auto c_dim = std::make_shared<default_opset::Gather>(std::make_shared<default_opset::ShapeOf>(data), one, zero);
auto g_dim = default_opset::Constant::create(element::i64, Shape{1}, {num_groups});
const auto eps = node.get_attribute_value<float>("epsilon", 1e-05f);
const auto num_groups = node.get_attribute_value<int64_t>("num_groups");
auto c_g_div = std::make_shared<default_opset::Divide>(c_dim, g_dim);
const auto zero = default_opset::Constant::create(element::i64, Shape{1}, {0});
const auto one = default_opset::Constant::create(element::i64, Shape{1}, {1});
const auto c_dim =
std::make_shared<default_opset::Gather>(std::make_shared<default_opset::ShapeOf>(data), one, zero);
const auto g_dim = default_opset::Constant::create(element::i64, Shape{1}, {num_groups});
const auto c_g_div = std::make_shared<default_opset::Divide>(c_dim, g_dim);
// Adjust scale and bias shape, [G] -> [G, C/G] -> [C]
scale = std::make_shared<default_opset::Unsqueeze>(scale, one);
auto broadcast_scale =
std::make_shared<default_opset::Broadcast>(scale, c_g_div, ov::op::BroadcastType::BIDIRECTIONAL);
auto c_scale = std::make_shared<default_opset::Reshape>(broadcast_scale, c_dim, false);
const auto scale_unsq = std::make_shared<default_opset::Unsqueeze>(scale, one);
const auto broadcast_scale =
std::make_shared<default_opset::Broadcast>(scale_unsq, c_g_div, ov::op::BroadcastType::BIDIRECTIONAL);
const auto c_scale = std::make_shared<default_opset::Reshape>(broadcast_scale, c_dim, false);
bias = std::make_shared<default_opset::Unsqueeze>(bias, one);
auto broadcast_bias =
std::make_shared<default_opset::Broadcast>(bias, c_g_div, ov::op::BroadcastType::BIDIRECTIONAL);
auto c_bias = std::make_shared<default_opset::Reshape>(broadcast_bias, c_dim, false);
const auto bias_unsq = std::make_shared<default_opset::Unsqueeze>(bias, one);
const auto broadcast_bias =
std::make_shared<default_opset::Broadcast>(bias_unsq, c_g_div, ov::op::BroadcastType::BIDIRECTIONAL);
const auto c_bias = std::make_shared<default_opset::Reshape>(broadcast_bias, c_dim, false);
return {std::make_shared<default_opset::GroupNormalization>(data, c_scale, c_bias, num_groups, eps)};
}