From e12983e82cf001890986a1663ef4d8e2ce283541 Mon Sep 17 00:00:00 2001 From: Artur Kulikowski Date: Tue, 18 Oct 2022 16:30:53 +0200 Subject: [PATCH] Copy runtime info in RIC fusion transformation (#13490) * Copy runtime info in RIC fusion transformation * Use ov namespace * Iterate over const refs instead of use iterator --- .../common_optimizations/ric_fusion.cpp | 46 +++++++++++++------ .../preprocessing_fusion_tests.cpp | 24 ---------- 2 files changed, 32 insertions(+), 38 deletions(-) diff --git a/src/common/transformations/src/transformations/common_optimizations/ric_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/ric_fusion.cpp index c73de03bb1e..9e736b07747 100644 --- a/src/common/transformations/src/transformations/common_optimizations/ric_fusion.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/ric_fusion.cpp @@ -65,7 +65,7 @@ public: } // Apply callback to materialize RIC inside graph - void materialize(Input input) const { + void materialize(Input input, const ov::NodeVector& nodes) const { if (get_axis() >= input.get_partial_shape().size()) { NGRAPH_DEBUG << "Axis calculated to materialize RIC on input: " << input << " is out of range"; return; @@ -84,7 +84,7 @@ public: } auto gather = std::make_shared(output, create_1d_const(order), create_1d_const({get_axis()})); input.replace_source_output(gather); - // TODO: copy runtime info from RIC sub-graph (ticket 88597) + ov::copy_runtime_info(nodes, gather); } bool can_be_fused() const { @@ -190,14 +190,26 @@ void erase(T port) { } // namespace ric_attr namespace init { + +namespace { + +void add_node_with_inputs_to_vector(const std::shared_ptr& node, NodeVector& vector) { + vector.push_back(node); + const auto& inputs = node->inputs(); + for (const auto& input : inputs) { + vector.push_back(input.get_source_output().get_node_shared_ptr()); + } +} + +} // namespace class SplitConcat : public ngraph::pass::MatcherPass { public: - SplitConcat() { + SplitConcat(NodeVector& nodes_to_fuse) { MATCHER_SCOPE(SplitConcat); auto split_p = pattern::wrap_type(); auto pattern_root = pattern::wrap_type({split_p, split_p, split_p}); - auto callback = [=](pattern::Matcher& m) { + auto callback = [=, &nodes_to_fuse](pattern::Matcher& m) { const auto& pattern_map = m.get_pattern_value_map(); auto concat = ov::as_type_ptr(pattern_map.at(pattern_root).get_node_shared_ptr()); auto split = ov::as_type_ptr(pattern_map.at(split_p).get_node_shared_ptr()); @@ -235,6 +247,9 @@ public: // Mark-up RIC output ric_attr::init(concat, order, concat->get_axis()); + + nodes_to_fuse.push_back(concat); + add_node_with_inputs_to_vector(split, nodes_to_fuse); return true; }; @@ -245,14 +260,14 @@ public: class Gather : public ngraph::pass::MatcherPass { public: - Gather() { + Gather(NodeVector& nodes_to_fuse) { MATCHER_SCOPE(Gather); auto input_p = pattern::any_input(pattern::has_static_rank()); auto indices_p = pattern::any_input(); auto axis_p = pattern::wrap_type(); auto pattern_root = pattern::wrap_type({input_p, indices_p, axis_p}); - auto callback = [=](pattern::Matcher& m) { + auto callback = [=, &nodes_to_fuse](pattern::Matcher& m) { const auto& pattern_map = m.get_pattern_value_map(); const auto& output = pattern_map.at(pattern_root); @@ -261,9 +276,10 @@ public: return false; const auto axis_value = axis->cast_vector().at(0); - - if (ov::is_preprocesing_node(output.get_node_shared_ptr())) { + auto gather = output.get_node_shared_ptr(); + if (ov::is_preprocesing_node(gather)) { ric_attr::init(output, {}, axis_value); + add_node_with_inputs_to_vector(gather, nodes_to_fuse); return true; } @@ -291,6 +307,7 @@ public: return false; } ric_attr::init(output, order_values, axis_value); + add_node_with_inputs_to_vector(gather, nodes_to_fuse); return true; }; @@ -567,17 +584,17 @@ bool need_to_erase_ric(const Output& output) { class InsertReverseInputChannel : public ngraph::pass::MatcherPass { public: - InsertReverseInputChannel() { + InsertReverseInputChannel(NodeVector& fused_nodes) { MATCHER_SCOPE(InsertReverseInputChannel); auto pattern_root = pattern::any_input(); - auto callback = [](pattern::Matcher& m) { + auto callback = [&fused_nodes](pattern::Matcher& m) { const auto& node = m.get_match_root(); for (const auto& input : node->inputs()) { if (!ric_attr::has(input)) continue; const auto& ric = ric_attr::get(input); if (ric.can_be_fused() && ric.is_final()) { - ric.materialize(input); + ric.materialize(input, fused_nodes); } } return false; @@ -803,10 +820,11 @@ bool ngraph::pass::ReverseInputChannelsFusion::run_on_model(const std::shared_pt Manager m; m.set_per_pass_validation(false); + NodeVector nodes_to_fuse; // First we need to initialize and propagate RIC attributes through entire graph auto ric_prop = m.register_pass(); - ric_prop->add_matcher(); - ric_prop->add_matcher(); + ric_prop->add_matcher(nodes_to_fuse); + ric_prop->add_matcher(nodes_to_fuse); ric_prop->add_matcher(); ric_prop->add_matcher(); ric_prop->add_matcher(); @@ -824,7 +842,7 @@ bool ngraph::pass::ReverseInputChannelsFusion::run_on_model(const std::shared_pt // Second we fuse available RIC into nodes and remove original nodes related to fused RIC auto ric_fuse = m.register_pass(); - ric_fuse->add_matcher(); + ric_fuse->add_matcher(nodes_to_fuse); ric_fuse->add_matcher(); ric_fuse->add_matcher(); diff --git a/src/tests/functional/inference_engine/transformations/common_optimizations/preprocessing_fusion_tests.cpp b/src/tests/functional/inference_engine/transformations/common_optimizations/preprocessing_fusion_tests.cpp index d4b3096fa35..357731428b8 100644 --- a/src/tests/functional/inference_engine/transformations/common_optimizations/preprocessing_fusion_tests.cpp +++ b/src/tests/functional/inference_engine/transformations/common_optimizations/preprocessing_fusion_tests.cpp @@ -144,7 +144,6 @@ TEST_F(TransformationTestsF, RICFusionSimple) { comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); - disable_rt_info_check(); comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } @@ -201,7 +200,6 @@ TEST_F(TransformationTestsF, RICFusionHard) { comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); - disable_rt_info_check(); comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } @@ -226,7 +224,6 @@ TEST_F(TransformationTestsF, RICFusionDynamic) { comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); - disable_rt_info_check(); } TEST_F(TransformationTestsF, RICFusionEltwise1) { @@ -251,7 +248,6 @@ TEST_F(TransformationTestsF, RICFusionEltwise1) { comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); - disable_rt_info_check(); comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } @@ -276,7 +272,6 @@ TEST_F(TransformationTestsF, RICFusionEltwise2) { comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); - disable_rt_info_check(); comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } @@ -301,7 +296,6 @@ TEST_F(TransformationTestsF, RICFusionEltwise3) { comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); - disable_rt_info_check(); comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } @@ -327,7 +321,6 @@ TEST_F(TransformationTestsF, RICFusionEltwise4) { comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); - disable_rt_info_check(); comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } @@ -353,7 +346,6 @@ TEST_F(TransformationTestsF, RICFusionEltwise5) { comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); - disable_rt_info_check(); comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } @@ -393,7 +385,6 @@ TEST_F(TransformationTestsF, RICFusionEltwiseTwoRIC) { } comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); - disable_rt_info_check(); comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } @@ -434,7 +425,6 @@ TEST_F(TransformationTestsF, RICFusionGroupConv) { comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); - disable_rt_info_check(); comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } @@ -476,7 +466,6 @@ TEST_F(TransformationTestsF, RICFusionTranspose) { comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); - disable_rt_info_check(); comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } @@ -502,7 +491,6 @@ TEST_F(TransformationTestsF, RICFusionFQOnTheWay) { comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); - disable_rt_info_check(); comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } @@ -541,7 +529,6 @@ TEST_F(TransformationTestsF, RICFusionFQOnTheWay2) { comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); - disable_rt_info_check(); comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } @@ -581,7 +568,6 @@ TEST_F(TransformationTestsF, RICFusionFQOnTheWay3) { comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); - disable_rt_info_check(); comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } @@ -607,7 +593,6 @@ TEST_F(TransformationTestsF, RICFusionShapeOf) { comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); - disable_rt_info_check(); } TEST_F(TransformationTestsF, RICFusionGatherDetectionNegative) { @@ -752,7 +737,6 @@ TEST_F(TransformationTestsF, FuseScaleValue) { function_ref = std::make_shared(NodeVector{ conv }, ParameterVector{ input }); } - disable_rt_info_check(); comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } @@ -782,7 +766,6 @@ TEST_F(TransformationTestsF, FuseScaleValues) { function_ref = std::make_shared(NodeVector{ conv }, ParameterVector{ input }); } - disable_rt_info_check(); comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } @@ -834,7 +817,6 @@ TEST_F(TransformationTestsF, RICFusionConvertMultiply) { apply_reverse_input_channels(function, {{0, "NCHW"}}); } manager.register_pass(); - disable_rt_info_check(); { auto parameter = std::make_shared(element::f32, Shape{1, 3, 14, 14}); std::shared_ptr activations = std::make_shared(parameter, @@ -882,7 +864,6 @@ TEST_F(TransformationTestsF, RICFusionConvertMultiplyGroupConv) { apply_reverse_input_channels(function, {{0, "NCHW"}}); } manager.register_pass(); - disable_rt_info_check(); { auto data = std::make_shared(element::f32, data_shape); std::shared_ptr weights = opset8::Constant::create(element::f32, Shape{3, 3, 1, 4, 4}, {-2}); @@ -934,7 +915,6 @@ TEST_F(TransformationTestsF, RICFusionConvertMultiplyNegative1) { apply_reverse_input_channels(function, {{0, "NCHW"}}); } manager.register_pass(); - disable_rt_info_check(); { auto parameter = std::make_shared(element::f32, Shape{1, 3, 14, 14}); std::shared_ptr activations = std::make_shared(parameter, @@ -994,7 +974,6 @@ TEST_F(TransformationTestsF, RICFusionConvertMultiplyNegativeBroadcast) { apply_reverse_input_channels(function, {{0, "NCHW"}}); } manager.register_pass(); - disable_rt_info_check(); { auto parameter = std::make_shared(element::f32, Shape{1, 3, 14, 14}); std::shared_ptr activations = std::make_shared(parameter, @@ -1075,7 +1054,6 @@ TEST_F(TransformationTestsF, RICFusionConvertMultiplyNonScalarFQInput) { apply_reverse_input_channels(function, {{0, "NCHW"}}); } manager.register_pass(); - disable_rt_info_check(); { auto parameter = std::make_shared(element::f32, Shape{1, 3, 14, 14}); auto gather = create_gather(std::make_shared(element::f32, Shape{1, 3, 14, 14}), {2, 1, 0}, 1); @@ -1147,7 +1125,6 @@ TEST_F(TransformationTestsF, RICFusionTwoConvolutions) { apply_reverse_input_channels(function, {{0, "NCHW"}}); manager.register_pass(); - disable_rt_info_check(); } { auto conv1_with_gather = create_conv_with_gather(input, create_weights({3, 3, 1, 1}), {2, 1, 0}); @@ -1167,7 +1144,6 @@ TEST_F(TransformationTestsF, RICFusionTwoConvolutionsTheSameWeights) { apply_reverse_input_channels(function, {{0, "NCHW"}}); manager.register_pass(); - disable_rt_info_check(); } { auto conv1_with_gather = create_conv_with_gather(input, weights, {2, 1, 0});