Fix RTInfo for ReduceL2Decomposition (#16107)
* Fix RTInfo for ReduceL2Decomposition * Review comments
This commit is contained in:
@@ -32,7 +32,11 @@ ov::pass::ReduceL2Decomposition::ReduceL2Decomposition() {
|
||||
reduce_l2_node->get_keep_dims());
|
||||
auto sqrt = std::make_shared<ov::opset4::Sqrt>(reduce_sum);
|
||||
sqrt->set_friendly_name(m.get_match_root()->get_friendly_name());
|
||||
ngraph::copy_runtime_info(reduce_l2_node, {sqrt, reduce_sum, square, const_2});
|
||||
ov::NodeVector rt_info_from_nodes{reduce_l2_node};
|
||||
const auto reduce_l2_input_1 = reduce_l2_node->input_value(1).get_node();
|
||||
if (ov::op::util::is_constant(reduce_l2_input_1))
|
||||
rt_info_from_nodes.emplace_back(reduce_l2_input_1->shared_from_this());
|
||||
ngraph::copy_runtime_info(rt_info_from_nodes, {sqrt, reduce_sum, square, const_2});
|
||||
ngraph::replace_node(m.get_match_root(), sqrt);
|
||||
return true;
|
||||
};
|
||||
|
||||
@@ -15,7 +15,9 @@
|
||||
#include "transformations/common_optimizations/nop_elimination.hpp"
|
||||
#include "transformations/convert_precision.hpp"
|
||||
#include "transformations/init_node_info.hpp"
|
||||
#include "transformations/op_conversions/convert_reduce_to_pooling.hpp"
|
||||
#include "transformations/op_conversions/log_softmax_decomposition.hpp"
|
||||
#include "transformations/op_conversions/reduce_l2_decomposition.hpp"
|
||||
#include "transformations/rt_info/decompression.hpp"
|
||||
#include "transformations/rt_info/fused_names_attribute.hpp"
|
||||
|
||||
@@ -461,3 +463,30 @@ TEST_F(GetSupportedNodesTest, ShuffleChannelFusion) {
|
||||
},
|
||||
{}); // Nothing is supported due to unsupported ShuffleChannels
|
||||
}
|
||||
|
||||
TEST_F(GetSupportedNodesTest, FusedNameReduceL2Test) {
|
||||
{
|
||||
auto data = std::make_shared<ov::opset9::Parameter>(ov::element::f32, ov::Shape{1, 512});
|
||||
data->set_friendly_name("data");
|
||||
auto axes = ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {1});
|
||||
axes->set_friendly_name("axes");
|
||||
auto reduce_l2 = std::make_shared<ov::opset9::ReduceL2>(data, axes, true);
|
||||
reduce_l2->set_friendly_name("reduce_l2");
|
||||
|
||||
m_function = std::make_shared<ngraph::Function>(ov::NodeVector{reduce_l2}, ov::ParameterVector{data});
|
||||
}
|
||||
Run(
|
||||
[&](std::shared_ptr<ov::Model>& model) {
|
||||
ov::pass::Manager m;
|
||||
m.register_pass<ov::pass::InitNodeInfo>();
|
||||
m.register_pass<ov::pass::ReduceL2Decomposition>();
|
||||
m.register_pass<ov::pass::ConvertReduceToPooling>();
|
||||
m.run_passes(model);
|
||||
},
|
||||
[&](const std::shared_ptr<ngraph::Node>& op) {
|
||||
// Pooling is supported, but Sqrt is not
|
||||
return ov::op::util::is_parameter(op) || ov::op::util::is_output(op) || ov::op::util::is_constant(op) ||
|
||||
(std::dynamic_pointer_cast<ov::opset1::AvgPool>(op) != nullptr);
|
||||
},
|
||||
{}); // Check that constant axis is removed from supported
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user