fix node validation

This commit is contained in:
Tikhonov Ivan 2023-02-23 10:31:56 +00:00
parent fa9fe34c16
commit 3baf0c7900
7 changed files with 33 additions and 97 deletions

View File

@ -111,7 +111,4 @@ ov::Output<ov::Node> ChangeValuesOrder(const ov::Output<ov::Node>& input,
const ov::AxisVector& transpose_axis_order,
const std::shared_ptr<ov::opset9::Constant>& axis);
void ValidateForward(const std::shared_ptr<ov::Node>& main_node);
void ValidateBackward(const std::shared_ptr<ov::Node>& main_node);
} // namespace transpose_sinking

View File

@ -23,14 +23,15 @@ using namespace transpose_sinking;
ov::pass::TransposeSinkingBinaryForward::TransposeSinkingBinaryForward() {
MATCHER_SCOPE(TransposeSinkingBinaryForward);
auto main_node_label = wrap_type<op::util::BinaryElementwiseArithmetic, PRelu>(IfNodeHasTransposeInputs);
auto main_node_label =
wrap_type<op::util::BinaryElementwiseArithmetic, PRelu>([](const Output<Node>& output) -> bool {
return has_static_rank()(output) && IfNodeHasTransposeInputs(output);
});
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
auto& main_node_output = pattern_to_output.at(main_node_label);
auto main_node = main_node_output.get_node_shared_ptr();
TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node);
// todo: support dynamic rank case
@ -38,11 +39,11 @@ ov::pass::TransposeSinkingBinaryForward::TransposeSinkingBinaryForward() {
if (!updated) {
return false;
}
main_node->validate_and_infer_types();
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
register_new_node(new_node);
transpose_sinking::UpdateForwardSinkingAbility(new_node);
}
ValidateForward(main_node);
return true;
};
@ -74,12 +75,11 @@ ov::pass::TransposeSinkingBinaryBackward::TransposeSinkingBinaryBackward() {
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node, transpose_const)) {
register_new_node(new_node);
}
main_node->validate_and_infer_types();
// remove output transposes
RemoveSingleOutputConsumers(main_node);
SwapNames(transpose, main_node);
ValidateBackward(main_node);
return true;
};

View File

@ -43,16 +43,17 @@ ov::pass::TransposeSinkingConcatForward::TransposeSinkingConcatForward() {
return false;
}
const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val();
const int64_t transposed_concat_axis = transpose_axis_order[concat_axis];
concat_node->set_axis(transposed_concat_axis);
concat_node->set_concatenation_axis(-1);
main_node->validate_and_infer_types();
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
register_new_node(new_node);
transpose_sinking::UpdateForwardSinkingAbility(new_node);
}
const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val();
const int64_t transposed_concat_axis = transpose_axis_order[concat_axis];
concat_node->set_axis(transposed_concat_axis);
concat_node->set_concatenation_axis(-1);
ValidateForward(main_node);
return true;
};
@ -88,17 +89,16 @@ ov::pass::TransposeSinkingConcatBackward::TransposeSinkingConcatBackward() {
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node, transpose_const)) {
register_new_node(new_node);
}
// remove output transposes
RemoveSingleOutputConsumers(main_node);
SwapNames(transpose, main_node);
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
const auto reversed_traspose_axis_order = ReverseTransposeOrder(transpose_axis_order);
const int64_t transposed_concat_axis = reversed_traspose_axis_order[concat_axis];
concat_node->set_axis(transposed_concat_axis);
concat_node->set_concatenation_axis(-1);
ValidateBackward(main_node);
concat_node->validate_and_infer_types();
// remove output transposes
RemoveSingleOutputConsumers(main_node);
SwapNames(transpose, main_node);
return true;
};

View File

@ -55,7 +55,6 @@ ov::pass::TransposeSinkingPadForward::TransposeSinkingPadForward() {
register_new_node(new_node);
transpose_sinking::UpdateForwardSinkingAbility(new_node);
}
ValidateForward(main_node);
return true;
};
@ -99,7 +98,7 @@ ov::pass::TransposeSinkingPadBackward::TransposeSinkingPadBackward() {
ChangeValuesOrder(main_node->input_value(1), transpose_axis_order, axis));
main_node->input(2).replace_source_output(
ChangeValuesOrder(main_node->input_value(2), transpose_axis_order, axis));
ValidateBackward(main_node);
main_node->validate_and_infer_types();
return true;
};

View File

@ -175,8 +175,8 @@ ov::pass::TransposeSinkingSplitBackward::TransposeSinkingSplitBackward() {
new_split_axis_const);
// remove split output transposes
split->validate_and_infer_types();
RemoveSingleOutputConsumers(split);
ValidateBackward(split);
return true;
};
@ -207,11 +207,6 @@ ov::pass::TransposeSinkingSplitForward::TransposeSinkingSplitForward() {
TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node);
sink_forward::RemoveInputNode(main_node, /* input_idx */ 0);
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
register_new_node(new_node);
transpose_sinking::UpdateForwardSinkingAbility(new_node);
}
const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val();
const size_t transposed_split_axis = transpose_axis_order[split_axis];
auto new_split_axis_const =
@ -219,7 +214,13 @@ ov::pass::TransposeSinkingSplitForward::TransposeSinkingSplitForward() {
split->input(1).replace_source_output(new_split_axis_const);
copy_runtime_info({split_axis_constant, transpose_input_info.transpose, transpose_input_info.transpose_const},
new_split_axis_const);
ValidateForward(main_node);
split->validate_and_infer_types();
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
register_new_node(new_node);
transpose_sinking::UpdateForwardSinkingAbility(new_node);
}
return true;
};

View File

@ -27,7 +27,7 @@ using NodePair = std::pair<NodePtr, NodePtr>;
* @param second_node first node pointer
* @return NodePair pair of nodes in new order that allows to register them in MatcherPass
*/
NodePair SwapNodes(NodePtr first_node, NodePtr second_node) {
NodePair SwapNodes(const NodePtr& first_node, const NodePtr& second_node) {
auto second_node_inputs = second_node->input_values();
second_node_inputs[0] = first_node->input_value(0);
@ -45,49 +45,9 @@ NodePair SwapNodes(NodePtr first_node, NodePtr second_node) {
return std::make_pair(new_first_node, new_second_node);
}
/**
* @brief SwapOutputs has much better performance than SwapNodes and covers the most of the real situations
* but cannot work when the consumers count greater than one
* @param first_node first node pointer
* @param second_node second node pointer
* @return NodePair pair of nodes in new order that allows to register them in MatcherPass
*/
NodePair SwapOutputs(NodePtr first_node, NodePtr second_node) {
const auto first_node_output_names = first_node->output(0).get_names();
const auto second_node_output_names = second_node->output(0).get_names();
auto swap_names = [&]() {
const std::string first_name = first_node->get_friendly_name();
first_node->set_friendly_name(second_node->get_friendly_name());
second_node->set_friendly_name(first_name);
first_node->output(0).set_names(second_node_output_names);
second_node->output(0).set_names(first_node_output_names);
};
auto out_1 = first_node->input_value(0);
second_node->input(0).replace_source_output(out_1);
auto out_2 = second_node->output(0);
second_node->output(0).replace(first_node->output(0));
first_node->input(0).replace_source_output(out_2);
second_node->validate_and_infer_types();
first_node->validate_and_infer_types();
swap_names();
return std::make_pair(second_node, first_node);
}
NodePair Swap(NodePtr first_node, NodePtr second_node) {
NodePair new_nodes;
if (first_node->output(0).get_target_inputs().size() > 1 || second_node->output(0).get_target_inputs().size() > 1)
new_nodes = SwapNodes(first_node, second_node);
else
new_nodes = SwapOutputs(first_node, second_node);
new_nodes = SwapNodes(first_node, second_node);
return new_nodes;
}
@ -111,7 +71,6 @@ ov::pass::TransposeSinkingUnaryForward::TransposeSinkingUnaryForward() {
register_new_node(new_nodes.second);
UpdateForwardSinkingAbility(new_nodes.second);
ValidateForward(unary_label);
return true;
};
@ -143,7 +102,6 @@ ov::pass::TransposeSinkingUnaryBackwardSingleConsumer::TransposeSinkingUnaryBack
register_new_node(new_nodes.first);
register_new_node(new_nodes.second);
ValidateBackward(unary_label);
return true;
};
@ -183,10 +141,9 @@ ov::pass::TransposeSinkingUnaryBackwardMultiConsumers::TransposeSinkingUnaryBack
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(unary, transpose_const)) {
register_new_node(new_node);
}
unary->validate_and_infer_types();
// remove output transposes
RemoveSingleOutputConsumers(unary);
ValidateBackward(unary_label);
return true;
};

View File

@ -202,24 +202,22 @@ void RemoveInputNode(const NodePtr& main_node, size_t input_idx) {
NodeVector InsertOutputTransposes(const NodePtr& main_node, const TransposeInputsInfo& transpose_input_info) {
if (transpose_input_info.isEmpty())
return {};
auto new_transpose_order = AlignTransposeOrder(main_node->output(0), transpose_input_info);
const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val();
const auto transpose_element_type = transpose_input_info.transpose_const->get_element_type();
NodeVector new_nodes;
for (size_t i = 0; i < main_node->get_output_size(); ++i) {
auto new_transpose_const =
std::make_shared<Constant>(transpose_element_type, Shape{new_transpose_order.size()}, new_transpose_order);
auto new_transpose_const = std::make_shared<Constant>(transpose_element_type,
Shape{transpose_axis_order.size()},
transpose_axis_order);
auto main_node_consumers = main_node->output(i).get_target_inputs();
auto new_transpose = std::make_shared<Transpose>(main_node->output(i), new_transpose_const);
for (auto& consumer : main_node_consumers) {
consumer.replace_source_output(new_transpose);
}
copy_runtime_info(main_node, {new_transpose, new_transpose_const});
SwapOutputNames(main_node->output(i), new_transpose->output(0));
if (main_node->get_output_size() > 1)
new_transpose->set_friendly_name(main_node->get_friendly_name() + "." + std::to_string(i));
else
@ -386,20 +384,4 @@ void RemoveSingleOutputConsumers(const NodePtr& node) {
}
}
void ValidateForward(const std::shared_ptr<ov::Node>& main_node) {
main_node->validate_and_infer_types();
for (const auto& out : main_node->outputs()) {
for (const auto consumer : out.get_target_inputs()) {
consumer.get_node()->validate_and_infer_types();
}
}
}
void ValidateBackward(const std::shared_ptr<ov::Node>& main_node) {
for (size_t i = 0; i < main_node->get_input_size(); ++i) {
main_node->get_input_node_shared_ptr(i)->validate_and_infer_types();
}
main_node->validate_and_infer_types();
}
} // namespace transpose_sinking