StridesOptimization - remove strides property from current node's target inputs (#13054)

StridesOptimization propagates strides attributes up the graph.
This attribute is kept in Input<Node> runtime info.
There is a case in CUDA plugin, where StridesOptimization is called
twice and if strides attribute is kept in node's runtime info,
the second run of this transformation tries to propagate strides
once again which can result in shape inference failure.
This commit is contained in:
Mateusz Tabaka 2022-09-16 00:00:53 +02:00 committed by GitHub
parent 2d6528c75f
commit 101462100b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 30 additions and 3 deletions

View File

@ -12,14 +12,22 @@
#include "openvino/core/runtime_attribute.hpp"
namespace ov {
TRANSFORMATIONS_API bool has_strides_prop(const ngraph::Input<ngraph::Node>& node);
TRANSFORMATIONS_API ngraph::Strides get_strides_prop(const ngraph::Input<ngraph::Node>& node);
TRANSFORMATIONS_API void insert_strides_prop(ngraph::Input<ngraph::Node>& node, const ngraph::Strides& strides);
TRANSFORMATIONS_API void remove_strides_prop(ngraph::Input<ngraph::Node>& node);
class TRANSFORMATIONS_API StridesPropagation : public ov::RuntimeAttribute {
public:
OPENVINO_RTTI("strides_propagation", "0");
StridesPropagation() = default;
StridesPropagation(const ngraph::Strides& value) : value{value} {}
bool is_copyable() const override {
return false;
}
ngraph::Strides value;
};
} // namespace ov

View File

@ -77,7 +77,7 @@ static void insert_pooling(const ngraph::Output<ngraph::Node>& first,
second.replace_source_output(new_node);
}
static void handle_not_equal_stride_props(std::vector<ngraph::Input<ngraph::Node>>&& next_ops) {
static void handle_not_equal_stride_props(std::vector<ngraph::Input<ngraph::Node>>& next_ops) {
for (auto& op : next_ops) {
if (!has_strides_prop(op))
continue;
@ -96,6 +96,12 @@ static void handle_not_equal_stride_props(std::vector<ngraph::Input<ngraph::Node
}
}
static void remove_strides_property_from_nodes(std::vector<ngraph::Input<ngraph::Node>>& nodes) {
for (auto& node : nodes) {
remove_strides_prop(node);
}
}
ngraph::pass::ConvStridesPropagation::ConvStridesPropagation() {
MATCHER_SCOPE(ConvStridesPropagation);
auto data = pattern::any_input([](const Output<Node>& node) -> bool {
@ -123,7 +129,7 @@ ngraph::pass::ConvStridesPropagation::ConvStridesPropagation() {
std::tie(strides, all_ops_are_valid) = check_next_ops(next_ops);
if (!all_ops_are_valid) {
handle_not_equal_stride_props(std::move(next_ops));
handle_not_equal_stride_props(next_ops);
} else {
std::transform(conv_strides.begin(),
conv_strides.end(),
@ -148,6 +154,8 @@ ngraph::pass::ConvStridesPropagation::ConvStridesPropagation() {
conv->set_strides(conv_strides);
}
remove_strides_property_from_nodes(next_ops);
return true;
};
@ -174,6 +182,8 @@ ngraph::pass::SupportedNodesStridesPropagation::SupportedNodesStridesPropagation
insert_strides_prop(input, strides);
}
remove_strides_property_from_nodes(next_ops);
return true;
};
@ -188,7 +198,8 @@ ngraph::pass::UnsupportedNodesStridesPropagation::UnsupportedNodesStridesPropaga
ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
auto node = m.get_match_root();
auto next_ops = op::util::get_node_target_inputs(node);
handle_not_equal_stride_props(std::move(next_ops));
handle_not_equal_stride_props(next_ops);
remove_strides_property_from_nodes(next_ops);
return true;
};

View File

@ -15,3 +15,11 @@ ngraph::Strides ov::get_strides_prop(const ngraph::Input<ngraph::Node>& node) {
void ov::insert_strides_prop(ngraph::Input<ngraph::Node>& node, const ngraph::Strides& strides) {
node.get_rt_info().emplace(StridesPropagation::get_type_info_static(), StridesPropagation{strides});
}
void ov::remove_strides_prop(ngraph::Input<ngraph::Node>& node) {
auto& rt_info = node.get_rt_info();
auto it = rt_info.find(StridesPropagation::get_type_info_static());
if (it != rt_info.end()) {
rt_info.erase(it);
}
}