[ONNX][Code refactor] ONNX GroupNormalization - Optimize getting inputs
* Optimize getting inputs and use more const * Fix typo
This commit is contained in:
parent
b13cb8ce12
commit
1384471849
@ -12,30 +12,34 @@ namespace onnx_import {
|
|||||||
namespace op {
|
namespace op {
|
||||||
namespace set_1 {
|
namespace set_1 {
|
||||||
OutputVector group_normalization(const Node& node) {
|
OutputVector group_normalization(const Node& node) {
|
||||||
const auto data = node.get_ng_inputs().at(0); // Shape [N, C, ...]
|
const auto inputs = node.get_ng_inputs();
|
||||||
auto scale = node.get_ng_inputs().at(1); // Shape [num_groups]
|
OPENVINO_ASSERT(inputs.size() == 3);
|
||||||
auto bias = node.get_ng_inputs().at(2); // Shape [num_groups]
|
|
||||||
|
|
||||||
auto eps = node.get_attribute_value<float>("epsilon", 1e-05f);
|
const auto& data = inputs[0]; // Shape [N, C, ...]
|
||||||
auto num_groups = node.get_attribute_value<int64_t>("num_groups");
|
const auto& scale = inputs[1]; // Shape [num_groups]
|
||||||
|
const auto& bias = inputs[2]; // Shape [num_groups]
|
||||||
|
|
||||||
auto zero = default_opset::Constant::create(element::i64, Shape{1}, {0});
|
const auto eps = node.get_attribute_value<float>("epsilon", 1e-05f);
|
||||||
auto one = default_opset::Constant::create(element::i64, Shape{1}, {1});
|
const auto num_groups = node.get_attribute_value<int64_t>("num_groups");
|
||||||
auto c_dim = std::make_shared<default_opset::Gather>(std::make_shared<default_opset::ShapeOf>(data), one, zero);
|
|
||||||
auto g_dim = default_opset::Constant::create(element::i64, Shape{1}, {num_groups});
|
|
||||||
|
|
||||||
auto c_g_div = std::make_shared<default_opset::Divide>(c_dim, g_dim);
|
const auto zero = default_opset::Constant::create(element::i64, Shape{1}, {0});
|
||||||
|
const auto one = default_opset::Constant::create(element::i64, Shape{1}, {1});
|
||||||
|
const auto c_dim =
|
||||||
|
std::make_shared<default_opset::Gather>(std::make_shared<default_opset::ShapeOf>(data), one, zero);
|
||||||
|
const auto g_dim = default_opset::Constant::create(element::i64, Shape{1}, {num_groups});
|
||||||
|
|
||||||
|
const auto c_g_div = std::make_shared<default_opset::Divide>(c_dim, g_dim);
|
||||||
|
|
||||||
// Adjust scale and bias shape, [G] -> [G, C/G] -> [C]
|
// Adjust scale and bias shape, [G] -> [G, C/G] -> [C]
|
||||||
scale = std::make_shared<default_opset::Unsqueeze>(scale, one);
|
const auto scale_unsq = std::make_shared<default_opset::Unsqueeze>(scale, one);
|
||||||
auto broadcast_scale =
|
const auto broadcast_scale =
|
||||||
std::make_shared<default_opset::Broadcast>(scale, c_g_div, ov::op::BroadcastType::BIDIRECTIONAL);
|
std::make_shared<default_opset::Broadcast>(scale_unsq, c_g_div, ov::op::BroadcastType::BIDIRECTIONAL);
|
||||||
auto c_scale = std::make_shared<default_opset::Reshape>(broadcast_scale, c_dim, false);
|
const auto c_scale = std::make_shared<default_opset::Reshape>(broadcast_scale, c_dim, false);
|
||||||
|
|
||||||
bias = std::make_shared<default_opset::Unsqueeze>(bias, one);
|
const auto bias_unsq = std::make_shared<default_opset::Unsqueeze>(bias, one);
|
||||||
auto broadcast_bias =
|
const auto broadcast_bias =
|
||||||
std::make_shared<default_opset::Broadcast>(bias, c_g_div, ov::op::BroadcastType::BIDIRECTIONAL);
|
std::make_shared<default_opset::Broadcast>(bias_unsq, c_g_div, ov::op::BroadcastType::BIDIRECTIONAL);
|
||||||
auto c_bias = std::make_shared<default_opset::Reshape>(broadcast_bias, c_dim, false);
|
const auto c_bias = std::make_shared<default_opset::Reshape>(broadcast_bias, c_dim, false);
|
||||||
|
|
||||||
return {std::make_shared<default_opset::GroupNormalization>(data, c_scale, c_bias, num_groups, eps)};
|
return {std::make_shared<default_opset::GroupNormalization>(data, c_scale, c_bias, num_groups, eps)};
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user