fix node validation
This commit is contained in:
parent
fa9fe34c16
commit
3baf0c7900
@ -111,7 +111,4 @@ ov::Output<ov::Node> ChangeValuesOrder(const ov::Output<ov::Node>& input,
|
||||
const ov::AxisVector& transpose_axis_order,
|
||||
const std::shared_ptr<ov::opset9::Constant>& axis);
|
||||
|
||||
void ValidateForward(const std::shared_ptr<ov::Node>& main_node);
|
||||
|
||||
void ValidateBackward(const std::shared_ptr<ov::Node>& main_node);
|
||||
} // namespace transpose_sinking
|
||||
|
@ -23,14 +23,15 @@ using namespace transpose_sinking;
|
||||
ov::pass::TransposeSinkingBinaryForward::TransposeSinkingBinaryForward() {
|
||||
MATCHER_SCOPE(TransposeSinkingBinaryForward);
|
||||
|
||||
auto main_node_label = wrap_type<op::util::BinaryElementwiseArithmetic, PRelu>(IfNodeHasTransposeInputs);
|
||||
auto main_node_label =
|
||||
wrap_type<op::util::BinaryElementwiseArithmetic, PRelu>([](const Output<Node>& output) -> bool {
|
||||
return has_static_rank()(output) && IfNodeHasTransposeInputs(output);
|
||||
});
|
||||
|
||||
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
|
||||
auto& main_node_output = pattern_to_output.at(main_node_label);
|
||||
auto main_node = main_node_output.get_node_shared_ptr();
|
||||
|
||||
TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node);
|
||||
|
||||
// todo: support dynamic rank case
|
||||
@ -38,11 +39,11 @@ ov::pass::TransposeSinkingBinaryForward::TransposeSinkingBinaryForward() {
|
||||
if (!updated) {
|
||||
return false;
|
||||
}
|
||||
main_node->validate_and_infer_types();
|
||||
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
|
||||
register_new_node(new_node);
|
||||
transpose_sinking::UpdateForwardSinkingAbility(new_node);
|
||||
}
|
||||
ValidateForward(main_node);
|
||||
return true;
|
||||
};
|
||||
|
||||
@ -74,12 +75,11 @@ ov::pass::TransposeSinkingBinaryBackward::TransposeSinkingBinaryBackward() {
|
||||
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node, transpose_const)) {
|
||||
register_new_node(new_node);
|
||||
}
|
||||
|
||||
main_node->validate_and_infer_types();
|
||||
// remove output transposes
|
||||
RemoveSingleOutputConsumers(main_node);
|
||||
|
||||
SwapNames(transpose, main_node);
|
||||
ValidateBackward(main_node);
|
||||
return true;
|
||||
};
|
||||
|
||||
|
@ -43,16 +43,17 @@ ov::pass::TransposeSinkingConcatForward::TransposeSinkingConcatForward() {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val();
|
||||
const int64_t transposed_concat_axis = transpose_axis_order[concat_axis];
|
||||
concat_node->set_axis(transposed_concat_axis);
|
||||
concat_node->set_concatenation_axis(-1);
|
||||
|
||||
main_node->validate_and_infer_types();
|
||||
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
|
||||
register_new_node(new_node);
|
||||
transpose_sinking::UpdateForwardSinkingAbility(new_node);
|
||||
}
|
||||
|
||||
const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val();
|
||||
const int64_t transposed_concat_axis = transpose_axis_order[concat_axis];
|
||||
concat_node->set_axis(transposed_concat_axis);
|
||||
concat_node->set_concatenation_axis(-1);
|
||||
ValidateForward(main_node);
|
||||
return true;
|
||||
};
|
||||
|
||||
@ -88,17 +89,16 @@ ov::pass::TransposeSinkingConcatBackward::TransposeSinkingConcatBackward() {
|
||||
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node, transpose_const)) {
|
||||
register_new_node(new_node);
|
||||
}
|
||||
|
||||
// remove output transposes
|
||||
RemoveSingleOutputConsumers(main_node);
|
||||
|
||||
SwapNames(transpose, main_node);
|
||||
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
|
||||
const auto reversed_traspose_axis_order = ReverseTransposeOrder(transpose_axis_order);
|
||||
const int64_t transposed_concat_axis = reversed_traspose_axis_order[concat_axis];
|
||||
concat_node->set_axis(transposed_concat_axis);
|
||||
concat_node->set_concatenation_axis(-1);
|
||||
ValidateBackward(main_node);
|
||||
concat_node->validate_and_infer_types();
|
||||
// remove output transposes
|
||||
RemoveSingleOutputConsumers(main_node);
|
||||
|
||||
SwapNames(transpose, main_node);
|
||||
return true;
|
||||
};
|
||||
|
||||
|
@ -55,7 +55,6 @@ ov::pass::TransposeSinkingPadForward::TransposeSinkingPadForward() {
|
||||
register_new_node(new_node);
|
||||
transpose_sinking::UpdateForwardSinkingAbility(new_node);
|
||||
}
|
||||
ValidateForward(main_node);
|
||||
return true;
|
||||
};
|
||||
|
||||
@ -99,7 +98,7 @@ ov::pass::TransposeSinkingPadBackward::TransposeSinkingPadBackward() {
|
||||
ChangeValuesOrder(main_node->input_value(1), transpose_axis_order, axis));
|
||||
main_node->input(2).replace_source_output(
|
||||
ChangeValuesOrder(main_node->input_value(2), transpose_axis_order, axis));
|
||||
ValidateBackward(main_node);
|
||||
main_node->validate_and_infer_types();
|
||||
return true;
|
||||
};
|
||||
|
||||
|
@ -175,8 +175,8 @@ ov::pass::TransposeSinkingSplitBackward::TransposeSinkingSplitBackward() {
|
||||
new_split_axis_const);
|
||||
|
||||
// remove split output transposes
|
||||
split->validate_and_infer_types();
|
||||
RemoveSingleOutputConsumers(split);
|
||||
ValidateBackward(split);
|
||||
return true;
|
||||
};
|
||||
|
||||
@ -207,11 +207,6 @@ ov::pass::TransposeSinkingSplitForward::TransposeSinkingSplitForward() {
|
||||
TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node);
|
||||
|
||||
sink_forward::RemoveInputNode(main_node, /* input_idx */ 0);
|
||||
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
|
||||
register_new_node(new_node);
|
||||
transpose_sinking::UpdateForwardSinkingAbility(new_node);
|
||||
}
|
||||
|
||||
const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val();
|
||||
const size_t transposed_split_axis = transpose_axis_order[split_axis];
|
||||
auto new_split_axis_const =
|
||||
@ -219,7 +214,13 @@ ov::pass::TransposeSinkingSplitForward::TransposeSinkingSplitForward() {
|
||||
split->input(1).replace_source_output(new_split_axis_const);
|
||||
copy_runtime_info({split_axis_constant, transpose_input_info.transpose, transpose_input_info.transpose_const},
|
||||
new_split_axis_const);
|
||||
ValidateForward(main_node);
|
||||
split->validate_and_infer_types();
|
||||
|
||||
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
|
||||
register_new_node(new_node);
|
||||
transpose_sinking::UpdateForwardSinkingAbility(new_node);
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
|
@ -27,7 +27,7 @@ using NodePair = std::pair<NodePtr, NodePtr>;
|
||||
* @param second_node first node pointer
|
||||
* @return NodePair pair of nodes in new order that allows to register them in MatcherPass
|
||||
*/
|
||||
NodePair SwapNodes(NodePtr first_node, NodePtr second_node) {
|
||||
NodePair SwapNodes(const NodePtr& first_node, const NodePtr& second_node) {
|
||||
auto second_node_inputs = second_node->input_values();
|
||||
second_node_inputs[0] = first_node->input_value(0);
|
||||
|
||||
@ -45,49 +45,9 @@ NodePair SwapNodes(NodePtr first_node, NodePtr second_node) {
|
||||
return std::make_pair(new_first_node, new_second_node);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief SwapOutputs has much better performance than SwapNodes and covers the most of the real situations
|
||||
* but cannot work when the consumers count greater than one
|
||||
* @param first_node first node pointer
|
||||
* @param second_node second node pointer
|
||||
* @return NodePair pair of nodes in new order that allows to register them in MatcherPass
|
||||
*/
|
||||
NodePair SwapOutputs(NodePtr first_node, NodePtr second_node) {
|
||||
const auto first_node_output_names = first_node->output(0).get_names();
|
||||
const auto second_node_output_names = second_node->output(0).get_names();
|
||||
|
||||
auto swap_names = [&]() {
|
||||
const std::string first_name = first_node->get_friendly_name();
|
||||
first_node->set_friendly_name(second_node->get_friendly_name());
|
||||
second_node->set_friendly_name(first_name);
|
||||
|
||||
first_node->output(0).set_names(second_node_output_names);
|
||||
second_node->output(0).set_names(first_node_output_names);
|
||||
};
|
||||
|
||||
auto out_1 = first_node->input_value(0);
|
||||
second_node->input(0).replace_source_output(out_1);
|
||||
|
||||
auto out_2 = second_node->output(0);
|
||||
second_node->output(0).replace(first_node->output(0));
|
||||
|
||||
first_node->input(0).replace_source_output(out_2);
|
||||
second_node->validate_and_infer_types();
|
||||
first_node->validate_and_infer_types();
|
||||
|
||||
swap_names();
|
||||
|
||||
return std::make_pair(second_node, first_node);
|
||||
}
|
||||
|
||||
NodePair Swap(NodePtr first_node, NodePtr second_node) {
|
||||
NodePair new_nodes;
|
||||
|
||||
if (first_node->output(0).get_target_inputs().size() > 1 || second_node->output(0).get_target_inputs().size() > 1)
|
||||
new_nodes = SwapNodes(first_node, second_node);
|
||||
else
|
||||
new_nodes = SwapOutputs(first_node, second_node);
|
||||
|
||||
new_nodes = SwapNodes(first_node, second_node);
|
||||
return new_nodes;
|
||||
}
|
||||
|
||||
@ -111,7 +71,6 @@ ov::pass::TransposeSinkingUnaryForward::TransposeSinkingUnaryForward() {
|
||||
register_new_node(new_nodes.second);
|
||||
|
||||
UpdateForwardSinkingAbility(new_nodes.second);
|
||||
ValidateForward(unary_label);
|
||||
return true;
|
||||
};
|
||||
|
||||
@ -143,7 +102,6 @@ ov::pass::TransposeSinkingUnaryBackwardSingleConsumer::TransposeSinkingUnaryBack
|
||||
|
||||
register_new_node(new_nodes.first);
|
||||
register_new_node(new_nodes.second);
|
||||
ValidateBackward(unary_label);
|
||||
return true;
|
||||
};
|
||||
|
||||
@ -183,10 +141,9 @@ ov::pass::TransposeSinkingUnaryBackwardMultiConsumers::TransposeSinkingUnaryBack
|
||||
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(unary, transpose_const)) {
|
||||
register_new_node(new_node);
|
||||
}
|
||||
|
||||
unary->validate_and_infer_types();
|
||||
// remove output transposes
|
||||
RemoveSingleOutputConsumers(unary);
|
||||
ValidateBackward(unary_label);
|
||||
return true;
|
||||
};
|
||||
|
||||
|
@ -202,24 +202,22 @@ void RemoveInputNode(const NodePtr& main_node, size_t input_idx) {
|
||||
NodeVector InsertOutputTransposes(const NodePtr& main_node, const TransposeInputsInfo& transpose_input_info) {
|
||||
if (transpose_input_info.isEmpty())
|
||||
return {};
|
||||
|
||||
auto new_transpose_order = AlignTransposeOrder(main_node->output(0), transpose_input_info);
|
||||
const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val();
|
||||
const auto transpose_element_type = transpose_input_info.transpose_const->get_element_type();
|
||||
|
||||
NodeVector new_nodes;
|
||||
|
||||
for (size_t i = 0; i < main_node->get_output_size(); ++i) {
|
||||
auto new_transpose_const =
|
||||
std::make_shared<Constant>(transpose_element_type, Shape{new_transpose_order.size()}, new_transpose_order);
|
||||
auto new_transpose_const = std::make_shared<Constant>(transpose_element_type,
|
||||
Shape{transpose_axis_order.size()},
|
||||
transpose_axis_order);
|
||||
auto main_node_consumers = main_node->output(i).get_target_inputs();
|
||||
auto new_transpose = std::make_shared<Transpose>(main_node->output(i), new_transpose_const);
|
||||
for (auto& consumer : main_node_consumers) {
|
||||
consumer.replace_source_output(new_transpose);
|
||||
}
|
||||
|
||||
copy_runtime_info(main_node, {new_transpose, new_transpose_const});
|
||||
SwapOutputNames(main_node->output(i), new_transpose->output(0));
|
||||
|
||||
if (main_node->get_output_size() > 1)
|
||||
new_transpose->set_friendly_name(main_node->get_friendly_name() + "." + std::to_string(i));
|
||||
else
|
||||
@ -386,20 +384,4 @@ void RemoveSingleOutputConsumers(const NodePtr& node) {
|
||||
}
|
||||
}
|
||||
|
||||
void ValidateForward(const std::shared_ptr<ov::Node>& main_node) {
|
||||
main_node->validate_and_infer_types();
|
||||
for (const auto& out : main_node->outputs()) {
|
||||
for (const auto consumer : out.get_target_inputs()) {
|
||||
consumer.get_node()->validate_and_infer_types();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ValidateBackward(const std::shared_ptr<ov::Node>& main_node) {
|
||||
for (size_t i = 0; i < main_node->get_input_size(); ++i) {
|
||||
main_node->get_input_node_shared_ptr(i)->validate_and_infer_types();
|
||||
}
|
||||
main_node->validate_and_infer_types();
|
||||
}
|
||||
|
||||
} // namespace transpose_sinking
|
||||
|
Loading…
Reference in New Issue
Block a user