Add additional validations

This commit is contained in:
Tikhonov Ivan 2023-02-21 14:34:26 +00:00
parent 81b9e6eece
commit ef6e141082
8 changed files with 42 additions and 13 deletions

View File

@ -110,4 +110,8 @@ void RemoveSingleOutputConsumers(const std::shared_ptr<ov::Node>&);
ov::Output<ov::Node> ChangeValuesOrder(const ov::Output<ov::Node>& input,
const ov::AxisVector& transpose_axis_order,
const std::shared_ptr<ov::opset9::Constant>& axis);
void ValidateForward(const std::shared_ptr<ov::Node>& main_node);
void ValidateBackward(const std::shared_ptr<ov::Node>& main_node);
} // namespace transpose_sinking

View File

@ -330,12 +330,16 @@ ov::pass::TransposeReduction::TransposeReduction() {
auto new_transpose_order = std::make_shared<opset6::Constant>(transpose_order->get_element_type(),
Shape{transpose_order_values.size()},
transpose_order_values);
auto new_const = std::make_shared<opset6::Constant>(reduction_axes->get_element_type(),
reduction_axes->get_shape(),
new_values);
auto new_reduction = reduction->clone_with_new_inputs(
{transpose->input_value(0), !unsqueeze ? new_const : reduction->input_value(1)});
std::shared_ptr<Node> new_reduction;
if (!unsqueeze) {
auto new_const = std::make_shared<opset6::Constant>(reduction_axes->get_element_type(),
reduction_axes->get_shape(),
new_values);
new_reduction = reduction->clone_with_new_inputs({transpose->input_value(0), new_const});
} else {
new_reduction = reduction->clone_with_new_inputs({transpose->input_value(0), reduction->input_value(1)});
}
auto new_transpose = transpose->clone_with_new_inputs({new_reduction, new_transpose_order});
replace_node(reduction, new_transpose);
new_reduction->set_friendly_name(transpose->get_friendly_name());

View File

@ -42,7 +42,7 @@ ov::pass::TransposeSinkingBinaryForward::TransposeSinkingBinaryForward() {
register_new_node(new_node);
transpose_sinking::UpdateForwardSinkingAbility(new_node);
}
ValidateForward(main_node);
return true;
};
@ -79,7 +79,7 @@ ov::pass::TransposeSinkingBinaryBackward::TransposeSinkingBinaryBackward() {
RemoveSingleOutputConsumers(main_node);
SwapNames(transpose, main_node);
ValidateBackward(main_node);
return true;
};

View File

@ -52,6 +52,7 @@ ov::pass::TransposeSinkingConcatForward::TransposeSinkingConcatForward() {
const int64_t transposed_concat_axis = transpose_axis_order[concat_axis];
concat_node->set_axis(transposed_concat_axis);
concat_node->set_concatenation_axis(-1);
ValidateForward(main_node);
return true;
};
@ -97,6 +98,7 @@ ov::pass::TransposeSinkingConcatBackward::TransposeSinkingConcatBackward() {
const int64_t transposed_concat_axis = reversed_traspose_axis_order[concat_axis];
concat_node->set_axis(transposed_concat_axis);
concat_node->set_concatenation_axis(-1);
ValidateBackward(main_node);
return true;
};

View File

@ -55,6 +55,7 @@ ov::pass::TransposeSinkingPadForward::TransposeSinkingPadForward() {
register_new_node(new_node);
transpose_sinking::UpdateForwardSinkingAbility(new_node);
}
ValidateForward(main_node);
return true;
};
@ -98,6 +99,7 @@ ov::pass::TransposeSinkingPadBackward::TransposeSinkingPadBackward() {
ChangeValuesOrder(main_node->input_value(1), transpose_axis_order, axis));
main_node->input(2).replace_source_output(
ChangeValuesOrder(main_node->input_value(2), transpose_axis_order, axis));
ValidateBackward(main_node);
return true;
};

View File

@ -176,7 +176,7 @@ ov::pass::TransposeSinkingSplitBackward::TransposeSinkingSplitBackward() {
// remove split output transposes
RemoveSingleOutputConsumers(split);
ValidateBackward(split);
return true;
};
@ -217,10 +217,9 @@ ov::pass::TransposeSinkingSplitForward::TransposeSinkingSplitForward() {
auto new_split_axis_const =
std::make_shared<Constant>(split_axis_constant->get_element_type(), Shape{}, transposed_split_axis);
split->input(1).replace_source_output(new_split_axis_const);
split->validate_and_infer_types();
copy_runtime_info({split_axis_constant, transpose_input_info.transpose, transpose_input_info.transpose_const},
new_split_axis_const);
ValidateForward(main_node);
return true;
};

View File

@ -72,6 +72,8 @@ NodePair SwapOutputs(NodePtr first_node, NodePtr second_node) {
second_node->output(0).replace(first_node->output(0));
first_node->input(0).replace_source_output(out_2);
second_node->validate_and_infer_types();
first_node->validate_and_infer_types();
swap_names();
@ -109,7 +111,7 @@ ov::pass::TransposeSinkingUnaryForward::TransposeSinkingUnaryForward() {
register_new_node(new_nodes.second);
UpdateForwardSinkingAbility(new_nodes.second);
ValidateForward(unary_label);
return true;
};
@ -141,7 +143,7 @@ ov::pass::TransposeSinkingUnaryBackwardSingleConsumer::TransposeSinkingUnaryBack
register_new_node(new_nodes.first);
register_new_node(new_nodes.second);
ValidateBackward(unary_label);
return true;
};
@ -184,7 +186,7 @@ ov::pass::TransposeSinkingUnaryBackwardMultiConsumers::TransposeSinkingUnaryBack
// remove output transposes
RemoveSingleOutputConsumers(unary);
ValidateBackward(unary_label);
return true;
};

View File

@ -386,4 +386,20 @@ void RemoveSingleOutputConsumers(const NodePtr& node) {
}
}
void ValidateForward(const std::shared_ptr<ov::Node>& main_node) {
main_node->validate_and_infer_types();
for (const auto& out : main_node->outputs()) {
for (const auto consumer : out.get_target_inputs()) {
consumer.get_node()->validate_and_infer_types();
}
}
}
void ValidateBackward(const std::shared_ptr<ov::Node>& main_node) {
for (size_t i = 0; i < main_node->get_input_size(); ++i) {
main_node->get_input_node_shared_ptr(i)->validate_and_infer_types();
}
main_node->validate_and_infer_types();
}
} // namespace transpose_sinking