Add additional validations
This commit is contained in:
parent
81b9e6eece
commit
ef6e141082
@ -110,4 +110,8 @@ void RemoveSingleOutputConsumers(const std::shared_ptr<ov::Node>&);
|
||||
ov::Output<ov::Node> ChangeValuesOrder(const ov::Output<ov::Node>& input,
|
||||
const ov::AxisVector& transpose_axis_order,
|
||||
const std::shared_ptr<ov::opset9::Constant>& axis);
|
||||
|
||||
void ValidateForward(const std::shared_ptr<ov::Node>& main_node);
|
||||
|
||||
void ValidateBackward(const std::shared_ptr<ov::Node>& main_node);
|
||||
} // namespace transpose_sinking
|
||||
|
@ -330,12 +330,16 @@ ov::pass::TransposeReduction::TransposeReduction() {
|
||||
auto new_transpose_order = std::make_shared<opset6::Constant>(transpose_order->get_element_type(),
|
||||
Shape{transpose_order_values.size()},
|
||||
transpose_order_values);
|
||||
auto new_const = std::make_shared<opset6::Constant>(reduction_axes->get_element_type(),
|
||||
reduction_axes->get_shape(),
|
||||
new_values);
|
||||
|
||||
auto new_reduction = reduction->clone_with_new_inputs(
|
||||
{transpose->input_value(0), !unsqueeze ? new_const : reduction->input_value(1)});
|
||||
std::shared_ptr<Node> new_reduction;
|
||||
if (!unsqueeze) {
|
||||
auto new_const = std::make_shared<opset6::Constant>(reduction_axes->get_element_type(),
|
||||
reduction_axes->get_shape(),
|
||||
new_values);
|
||||
new_reduction = reduction->clone_with_new_inputs({transpose->input_value(0), new_const});
|
||||
} else {
|
||||
new_reduction = reduction->clone_with_new_inputs({transpose->input_value(0), reduction->input_value(1)});
|
||||
}
|
||||
auto new_transpose = transpose->clone_with_new_inputs({new_reduction, new_transpose_order});
|
||||
replace_node(reduction, new_transpose);
|
||||
new_reduction->set_friendly_name(transpose->get_friendly_name());
|
||||
|
@ -42,7 +42,7 @@ ov::pass::TransposeSinkingBinaryForward::TransposeSinkingBinaryForward() {
|
||||
register_new_node(new_node);
|
||||
transpose_sinking::UpdateForwardSinkingAbility(new_node);
|
||||
}
|
||||
|
||||
ValidateForward(main_node);
|
||||
return true;
|
||||
};
|
||||
|
||||
@ -79,7 +79,7 @@ ov::pass::TransposeSinkingBinaryBackward::TransposeSinkingBinaryBackward() {
|
||||
RemoveSingleOutputConsumers(main_node);
|
||||
|
||||
SwapNames(transpose, main_node);
|
||||
|
||||
ValidateBackward(main_node);
|
||||
return true;
|
||||
};
|
||||
|
||||
|
@ -52,6 +52,7 @@ ov::pass::TransposeSinkingConcatForward::TransposeSinkingConcatForward() {
|
||||
const int64_t transposed_concat_axis = transpose_axis_order[concat_axis];
|
||||
concat_node->set_axis(transposed_concat_axis);
|
||||
concat_node->set_concatenation_axis(-1);
|
||||
ValidateForward(main_node);
|
||||
return true;
|
||||
};
|
||||
|
||||
@ -97,6 +98,7 @@ ov::pass::TransposeSinkingConcatBackward::TransposeSinkingConcatBackward() {
|
||||
const int64_t transposed_concat_axis = reversed_traspose_axis_order[concat_axis];
|
||||
concat_node->set_axis(transposed_concat_axis);
|
||||
concat_node->set_concatenation_axis(-1);
|
||||
ValidateBackward(main_node);
|
||||
return true;
|
||||
};
|
||||
|
||||
|
@ -55,6 +55,7 @@ ov::pass::TransposeSinkingPadForward::TransposeSinkingPadForward() {
|
||||
register_new_node(new_node);
|
||||
transpose_sinking::UpdateForwardSinkingAbility(new_node);
|
||||
}
|
||||
ValidateForward(main_node);
|
||||
return true;
|
||||
};
|
||||
|
||||
@ -98,6 +99,7 @@ ov::pass::TransposeSinkingPadBackward::TransposeSinkingPadBackward() {
|
||||
ChangeValuesOrder(main_node->input_value(1), transpose_axis_order, axis));
|
||||
main_node->input(2).replace_source_output(
|
||||
ChangeValuesOrder(main_node->input_value(2), transpose_axis_order, axis));
|
||||
ValidateBackward(main_node);
|
||||
return true;
|
||||
};
|
||||
|
||||
|
@ -176,7 +176,7 @@ ov::pass::TransposeSinkingSplitBackward::TransposeSinkingSplitBackward() {
|
||||
|
||||
// remove split output transposes
|
||||
RemoveSingleOutputConsumers(split);
|
||||
|
||||
ValidateBackward(split);
|
||||
return true;
|
||||
};
|
||||
|
||||
@ -217,10 +217,9 @@ ov::pass::TransposeSinkingSplitForward::TransposeSinkingSplitForward() {
|
||||
auto new_split_axis_const =
|
||||
std::make_shared<Constant>(split_axis_constant->get_element_type(), Shape{}, transposed_split_axis);
|
||||
split->input(1).replace_source_output(new_split_axis_const);
|
||||
split->validate_and_infer_types();
|
||||
copy_runtime_info({split_axis_constant, transpose_input_info.transpose, transpose_input_info.transpose_const},
|
||||
new_split_axis_const);
|
||||
|
||||
ValidateForward(main_node);
|
||||
return true;
|
||||
};
|
||||
|
||||
|
@ -72,6 +72,8 @@ NodePair SwapOutputs(NodePtr first_node, NodePtr second_node) {
|
||||
second_node->output(0).replace(first_node->output(0));
|
||||
|
||||
first_node->input(0).replace_source_output(out_2);
|
||||
second_node->validate_and_infer_types();
|
||||
first_node->validate_and_infer_types();
|
||||
|
||||
swap_names();
|
||||
|
||||
@ -109,7 +111,7 @@ ov::pass::TransposeSinkingUnaryForward::TransposeSinkingUnaryForward() {
|
||||
register_new_node(new_nodes.second);
|
||||
|
||||
UpdateForwardSinkingAbility(new_nodes.second);
|
||||
|
||||
ValidateForward(unary_label);
|
||||
return true;
|
||||
};
|
||||
|
||||
@ -141,7 +143,7 @@ ov::pass::TransposeSinkingUnaryBackwardSingleConsumer::TransposeSinkingUnaryBack
|
||||
|
||||
register_new_node(new_nodes.first);
|
||||
register_new_node(new_nodes.second);
|
||||
|
||||
ValidateBackward(unary_label);
|
||||
return true;
|
||||
};
|
||||
|
||||
@ -184,7 +186,7 @@ ov::pass::TransposeSinkingUnaryBackwardMultiConsumers::TransposeSinkingUnaryBack
|
||||
|
||||
// remove output transposes
|
||||
RemoveSingleOutputConsumers(unary);
|
||||
|
||||
ValidateBackward(unary_label);
|
||||
return true;
|
||||
};
|
||||
|
||||
|
@ -386,4 +386,20 @@ void RemoveSingleOutputConsumers(const NodePtr& node) {
|
||||
}
|
||||
}
|
||||
|
||||
void ValidateForward(const std::shared_ptr<ov::Node>& main_node) {
|
||||
main_node->validate_and_infer_types();
|
||||
for (const auto& out : main_node->outputs()) {
|
||||
for (const auto consumer : out.get_target_inputs()) {
|
||||
consumer.get_node()->validate_and_infer_types();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ValidateBackward(const std::shared_ptr<ov::Node>& main_node) {
|
||||
for (size_t i = 0; i < main_node->get_input_size(); ++i) {
|
||||
main_node->get_input_node_shared_ptr(i)->validate_and_infer_types();
|
||||
}
|
||||
main_node->validate_and_infer_types();
|
||||
}
|
||||
|
||||
} // namespace transpose_sinking
|
||||
|
Loading…
Reference in New Issue
Block a user