StridesOptimization - remove strides property from current node's target inputs (#13054)
StridesOptimization propagates strides attributes up the graph. This attribute is kept in Input<Node> runtime info. There is a case in CUDA plugin, where StridesOptimization is called twice and if strides attribute is kept in node's runtime info, the second run of this transformation tries to propagate strides once again which can result in shape inference failure.
This commit is contained in:
parent
2d6528c75f
commit
101462100b
@ -12,14 +12,22 @@
|
||||
#include "openvino/core/runtime_attribute.hpp"
|
||||
|
||||
namespace ov {
|
||||
|
||||
TRANSFORMATIONS_API bool has_strides_prop(const ngraph::Input<ngraph::Node>& node);
|
||||
TRANSFORMATIONS_API ngraph::Strides get_strides_prop(const ngraph::Input<ngraph::Node>& node);
|
||||
TRANSFORMATIONS_API void insert_strides_prop(ngraph::Input<ngraph::Node>& node, const ngraph::Strides& strides);
|
||||
TRANSFORMATIONS_API void remove_strides_prop(ngraph::Input<ngraph::Node>& node);
|
||||
|
||||
class TRANSFORMATIONS_API StridesPropagation : public ov::RuntimeAttribute {
|
||||
public:
|
||||
OPENVINO_RTTI("strides_propagation", "0");
|
||||
StridesPropagation() = default;
|
||||
StridesPropagation(const ngraph::Strides& value) : value{value} {}
|
||||
|
||||
bool is_copyable() const override {
|
||||
return false;
|
||||
}
|
||||
|
||||
ngraph::Strides value;
|
||||
};
|
||||
} // namespace ov
|
||||
|
@ -77,7 +77,7 @@ static void insert_pooling(const ngraph::Output<ngraph::Node>& first,
|
||||
second.replace_source_output(new_node);
|
||||
}
|
||||
|
||||
static void handle_not_equal_stride_props(std::vector<ngraph::Input<ngraph::Node>>&& next_ops) {
|
||||
static void handle_not_equal_stride_props(std::vector<ngraph::Input<ngraph::Node>>& next_ops) {
|
||||
for (auto& op : next_ops) {
|
||||
if (!has_strides_prop(op))
|
||||
continue;
|
||||
@ -96,6 +96,12 @@ static void handle_not_equal_stride_props(std::vector<ngraph::Input<ngraph::Node
|
||||
}
|
||||
}
|
||||
|
||||
static void remove_strides_property_from_nodes(std::vector<ngraph::Input<ngraph::Node>>& nodes) {
|
||||
for (auto& node : nodes) {
|
||||
remove_strides_prop(node);
|
||||
}
|
||||
}
|
||||
|
||||
ngraph::pass::ConvStridesPropagation::ConvStridesPropagation() {
|
||||
MATCHER_SCOPE(ConvStridesPropagation);
|
||||
auto data = pattern::any_input([](const Output<Node>& node) -> bool {
|
||||
@ -123,7 +129,7 @@ ngraph::pass::ConvStridesPropagation::ConvStridesPropagation() {
|
||||
std::tie(strides, all_ops_are_valid) = check_next_ops(next_ops);
|
||||
|
||||
if (!all_ops_are_valid) {
|
||||
handle_not_equal_stride_props(std::move(next_ops));
|
||||
handle_not_equal_stride_props(next_ops);
|
||||
} else {
|
||||
std::transform(conv_strides.begin(),
|
||||
conv_strides.end(),
|
||||
@ -148,6 +154,8 @@ ngraph::pass::ConvStridesPropagation::ConvStridesPropagation() {
|
||||
conv->set_strides(conv_strides);
|
||||
}
|
||||
|
||||
remove_strides_property_from_nodes(next_ops);
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
@ -174,6 +182,8 @@ ngraph::pass::SupportedNodesStridesPropagation::SupportedNodesStridesPropagation
|
||||
insert_strides_prop(input, strides);
|
||||
}
|
||||
|
||||
remove_strides_property_from_nodes(next_ops);
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
@ -188,7 +198,8 @@ ngraph::pass::UnsupportedNodesStridesPropagation::UnsupportedNodesStridesPropaga
|
||||
ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||
auto node = m.get_match_root();
|
||||
auto next_ops = op::util::get_node_target_inputs(node);
|
||||
handle_not_equal_stride_props(std::move(next_ops));
|
||||
handle_not_equal_stride_props(next_ops);
|
||||
remove_strides_property_from_nodes(next_ops);
|
||||
|
||||
return true;
|
||||
};
|
||||
|
@ -15,3 +15,11 @@ ngraph::Strides ov::get_strides_prop(const ngraph::Input<ngraph::Node>& node) {
|
||||
void ov::insert_strides_prop(ngraph::Input<ngraph::Node>& node, const ngraph::Strides& strides) {
|
||||
node.get_rt_info().emplace(StridesPropagation::get_type_info_static(), StridesPropagation{strides});
|
||||
}
|
||||
|
||||
void ov::remove_strides_prop(ngraph::Input<ngraph::Node>& node) {
|
||||
auto& rt_info = node.get_rt_info();
|
||||
auto it = rt_info.find(StridesPropagation::get_type_info_static());
|
||||
if (it != rt_info.end()) {
|
||||
rt_info.erase(it);
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user