Fix RTInfo for ReduceL2Decomposition (#16107)

* Fix RTInfo for ReduceL2Decomposition

* Review comments
This commit is contained in:
Nadezhda Ageeva
2023-03-07 10:24:55 +04:00
committed by GitHub
parent 87b18a21c1
commit 0dad7749b5
2 changed files with 34 additions and 1 deletions

View File

@@ -32,7 +32,11 @@ ov::pass::ReduceL2Decomposition::ReduceL2Decomposition() {
reduce_l2_node->get_keep_dims());
auto sqrt = std::make_shared<ov::opset4::Sqrt>(reduce_sum);
sqrt->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::copy_runtime_info(reduce_l2_node, {sqrt, reduce_sum, square, const_2});
ov::NodeVector rt_info_from_nodes{reduce_l2_node};
const auto reduce_l2_input_1 = reduce_l2_node->input_value(1).get_node();
if (ov::op::util::is_constant(reduce_l2_input_1))
rt_info_from_nodes.emplace_back(reduce_l2_input_1->shared_from_this());
ngraph::copy_runtime_info(rt_info_from_nodes, {sqrt, reduce_sum, square, const_2});
ngraph::replace_node(m.get_match_root(), sqrt);
return true;
};

View File

@@ -15,7 +15,9 @@
#include "transformations/common_optimizations/nop_elimination.hpp"
#include "transformations/convert_precision.hpp"
#include "transformations/init_node_info.hpp"
#include "transformations/op_conversions/convert_reduce_to_pooling.hpp"
#include "transformations/op_conversions/log_softmax_decomposition.hpp"
#include "transformations/op_conversions/reduce_l2_decomposition.hpp"
#include "transformations/rt_info/decompression.hpp"
#include "transformations/rt_info/fused_names_attribute.hpp"
@@ -461,3 +463,30 @@ TEST_F(GetSupportedNodesTest, ShuffleChannelFusion) {
},
{}); // Nothing is supported due to unsupported ShuffleChannels
}
TEST_F(GetSupportedNodesTest, FusedNameReduceL2Test) {
{
auto data = std::make_shared<ov::opset9::Parameter>(ov::element::f32, ov::Shape{1, 512});
data->set_friendly_name("data");
auto axes = ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {1});
axes->set_friendly_name("axes");
auto reduce_l2 = std::make_shared<ov::opset9::ReduceL2>(data, axes, true);
reduce_l2->set_friendly_name("reduce_l2");
m_function = std::make_shared<ngraph::Function>(ov::NodeVector{reduce_l2}, ov::ParameterVector{data});
}
Run(
[&](std::shared_ptr<ov::Model>& model) {
ov::pass::Manager m;
m.register_pass<ov::pass::InitNodeInfo>();
m.register_pass<ov::pass::ReduceL2Decomposition>();
m.register_pass<ov::pass::ConvertReduceToPooling>();
m.run_passes(model);
},
[&](const std::shared_ptr<ngraph::Node>& op) {
// Pooling is supported, but Sqrt is not
return ov::op::util::is_parameter(op) || ov::op::util::is_output(op) || ov::op::util::is_constant(op) ||
(std::dynamic_pointer_cast<ov::opset1::AvgPool>(op) != nullptr);
},
{}); // Check that constant axis is removed from supported
}