Fix RTInfo for ReduceL2Decomposition (#16107)
* Fix RTInfo for ReduceL2Decomposition * Review comments
This commit is contained in:
@@ -32,7 +32,11 @@ ov::pass::ReduceL2Decomposition::ReduceL2Decomposition() {
|
|||||||
reduce_l2_node->get_keep_dims());
|
reduce_l2_node->get_keep_dims());
|
||||||
auto sqrt = std::make_shared<ov::opset4::Sqrt>(reduce_sum);
|
auto sqrt = std::make_shared<ov::opset4::Sqrt>(reduce_sum);
|
||||||
sqrt->set_friendly_name(m.get_match_root()->get_friendly_name());
|
sqrt->set_friendly_name(m.get_match_root()->get_friendly_name());
|
||||||
ngraph::copy_runtime_info(reduce_l2_node, {sqrt, reduce_sum, square, const_2});
|
ov::NodeVector rt_info_from_nodes{reduce_l2_node};
|
||||||
|
const auto reduce_l2_input_1 = reduce_l2_node->input_value(1).get_node();
|
||||||
|
if (ov::op::util::is_constant(reduce_l2_input_1))
|
||||||
|
rt_info_from_nodes.emplace_back(reduce_l2_input_1->shared_from_this());
|
||||||
|
ngraph::copy_runtime_info(rt_info_from_nodes, {sqrt, reduce_sum, square, const_2});
|
||||||
ngraph::replace_node(m.get_match_root(), sqrt);
|
ngraph::replace_node(m.get_match_root(), sqrt);
|
||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -15,7 +15,9 @@
|
|||||||
#include "transformations/common_optimizations/nop_elimination.hpp"
|
#include "transformations/common_optimizations/nop_elimination.hpp"
|
||||||
#include "transformations/convert_precision.hpp"
|
#include "transformations/convert_precision.hpp"
|
||||||
#include "transformations/init_node_info.hpp"
|
#include "transformations/init_node_info.hpp"
|
||||||
|
#include "transformations/op_conversions/convert_reduce_to_pooling.hpp"
|
||||||
#include "transformations/op_conversions/log_softmax_decomposition.hpp"
|
#include "transformations/op_conversions/log_softmax_decomposition.hpp"
|
||||||
|
#include "transformations/op_conversions/reduce_l2_decomposition.hpp"
|
||||||
#include "transformations/rt_info/decompression.hpp"
|
#include "transformations/rt_info/decompression.hpp"
|
||||||
#include "transformations/rt_info/fused_names_attribute.hpp"
|
#include "transformations/rt_info/fused_names_attribute.hpp"
|
||||||
|
|
||||||
@@ -461,3 +463,30 @@ TEST_F(GetSupportedNodesTest, ShuffleChannelFusion) {
|
|||||||
},
|
},
|
||||||
{}); // Nothing is supported due to unsupported ShuffleChannels
|
{}); // Nothing is supported due to unsupported ShuffleChannels
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(GetSupportedNodesTest, FusedNameReduceL2Test) {
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<ov::opset9::Parameter>(ov::element::f32, ov::Shape{1, 512});
|
||||||
|
data->set_friendly_name("data");
|
||||||
|
auto axes = ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {1});
|
||||||
|
axes->set_friendly_name("axes");
|
||||||
|
auto reduce_l2 = std::make_shared<ov::opset9::ReduceL2>(data, axes, true);
|
||||||
|
reduce_l2->set_friendly_name("reduce_l2");
|
||||||
|
|
||||||
|
m_function = std::make_shared<ngraph::Function>(ov::NodeVector{reduce_l2}, ov::ParameterVector{data});
|
||||||
|
}
|
||||||
|
Run(
|
||||||
|
[&](std::shared_ptr<ov::Model>& model) {
|
||||||
|
ov::pass::Manager m;
|
||||||
|
m.register_pass<ov::pass::InitNodeInfo>();
|
||||||
|
m.register_pass<ov::pass::ReduceL2Decomposition>();
|
||||||
|
m.register_pass<ov::pass::ConvertReduceToPooling>();
|
||||||
|
m.run_passes(model);
|
||||||
|
},
|
||||||
|
[&](const std::shared_ptr<ngraph::Node>& op) {
|
||||||
|
// Pooling is supported, but Sqrt is not
|
||||||
|
return ov::op::util::is_parameter(op) || ov::op::util::is_output(op) || ov::op::util::is_constant(op) ||
|
||||||
|
(std::dynamic_pointer_cast<ov::opset1::AvgPool>(op) != nullptr);
|
||||||
|
},
|
||||||
|
{}); // Check that constant axis is removed from supported
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user