From 101462100bea5206e0f7a2a464992da9034641b7 Mon Sep 17 00:00:00 2001 From: Mateusz Tabaka Date: Fri, 16 Sep 2022 00:00:53 +0200 Subject: [PATCH] StridesOptimization - remove strides property from current node's target inputs (#13054) StridesOptimization propagates strides attributes up the graph. This attribute is kept in Input runtime info. There is a case in CUDA plugin, where StridesOptimization is called twice and if strides attribute is kept in node's runtime info, the second run of this transformation tries to propagate strides once again which can result in shape inference failure. --- .../rt_info/strides_property.hpp | 8 ++++++++ .../strides_optimization.cpp | 17 ++++++++++++++--- .../rt_info/strides_property.cpp | 8 ++++++++ 3 files changed, 30 insertions(+), 3 deletions(-) diff --git a/src/common/transformations/include/transformations/rt_info/strides_property.hpp b/src/common/transformations/include/transformations/rt_info/strides_property.hpp index 5204543687a..c1ef7b66d6f 100644 --- a/src/common/transformations/include/transformations/rt_info/strides_property.hpp +++ b/src/common/transformations/include/transformations/rt_info/strides_property.hpp @@ -12,14 +12,22 @@ #include "openvino/core/runtime_attribute.hpp" namespace ov { + TRANSFORMATIONS_API bool has_strides_prop(const ngraph::Input& node); TRANSFORMATIONS_API ngraph::Strides get_strides_prop(const ngraph::Input& node); TRANSFORMATIONS_API void insert_strides_prop(ngraph::Input& node, const ngraph::Strides& strides); +TRANSFORMATIONS_API void remove_strides_prop(ngraph::Input& node); + class TRANSFORMATIONS_API StridesPropagation : public ov::RuntimeAttribute { public: OPENVINO_RTTI("strides_propagation", "0"); StridesPropagation() = default; StridesPropagation(const ngraph::Strides& value) : value{value} {} + + bool is_copyable() const override { + return false; + } + ngraph::Strides value; }; } // namespace ov diff --git a/src/common/transformations/src/transformations/common_optimizations/strides_optimization.cpp b/src/common/transformations/src/transformations/common_optimizations/strides_optimization.cpp index fb7a0a57e36..297ec60917f 100644 --- a/src/common/transformations/src/transformations/common_optimizations/strides_optimization.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/strides_optimization.cpp @@ -77,7 +77,7 @@ static void insert_pooling(const ngraph::Output& first, second.replace_source_output(new_node); } -static void handle_not_equal_stride_props(std::vector>&& next_ops) { +static void handle_not_equal_stride_props(std::vector>& next_ops) { for (auto& op : next_ops) { if (!has_strides_prop(op)) continue; @@ -96,6 +96,12 @@ static void handle_not_equal_stride_props(std::vector>& nodes) { + for (auto& node : nodes) { + remove_strides_prop(node); + } +} + ngraph::pass::ConvStridesPropagation::ConvStridesPropagation() { MATCHER_SCOPE(ConvStridesPropagation); auto data = pattern::any_input([](const Output& node) -> bool { @@ -123,7 +129,7 @@ ngraph::pass::ConvStridesPropagation::ConvStridesPropagation() { std::tie(strides, all_ops_are_valid) = check_next_ops(next_ops); if (!all_ops_are_valid) { - handle_not_equal_stride_props(std::move(next_ops)); + handle_not_equal_stride_props(next_ops); } else { std::transform(conv_strides.begin(), conv_strides.end(), @@ -148,6 +154,8 @@ ngraph::pass::ConvStridesPropagation::ConvStridesPropagation() { conv->set_strides(conv_strides); } + remove_strides_property_from_nodes(next_ops); + return true; }; @@ -174,6 +182,8 @@ ngraph::pass::SupportedNodesStridesPropagation::SupportedNodesStridesPropagation insert_strides_prop(input, strides); } + remove_strides_property_from_nodes(next_ops); + return true; }; @@ -188,7 +198,8 @@ ngraph::pass::UnsupportedNodesStridesPropagation::UnsupportedNodesStridesPropaga ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) { auto node = m.get_match_root(); auto next_ops = op::util::get_node_target_inputs(node); - handle_not_equal_stride_props(std::move(next_ops)); + handle_not_equal_stride_props(next_ops); + remove_strides_property_from_nodes(next_ops); return true; }; diff --git a/src/common/transformations/src/transformations/rt_info/strides_property.cpp b/src/common/transformations/src/transformations/rt_info/strides_property.cpp index 4136de081f0..7f01d1c6be0 100644 --- a/src/common/transformations/src/transformations/rt_info/strides_property.cpp +++ b/src/common/transformations/src/transformations/rt_info/strides_property.cpp @@ -15,3 +15,11 @@ ngraph::Strides ov::get_strides_prop(const ngraph::Input& node) { void ov::insert_strides_prop(ngraph::Input& node, const ngraph::Strides& strides) { node.get_rt_info().emplace(StridesPropagation::get_type_info_static(), StridesPropagation{strides}); } + +void ov::remove_strides_prop(ngraph::Input& node) { + auto& rt_info = node.get_rt_info(); + auto it = rt_info.find(StridesPropagation::get_type_info_static()); + if (it != rt_info.end()) { + rt_info.erase(it); + } +}