Enable transformation callback for TS transformations (#16767)

This commit is contained in:
Ivan Tikhonov 2023-04-06 20:42:01 +04:00 committed by GitHub
parent d2deae225a
commit 4812879318
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 66 additions and 1 deletions

View File

@ -32,6 +32,10 @@ TSBinaryForward::TSBinaryForward() {
const auto& pattern_to_output = m.get_pattern_value_map();
auto& main_node_output = pattern_to_output.at(main_node_label);
auto main_node = main_node_output.get_node_shared_ptr();
if (transformation_callback(main_node)) {
return false;
}
TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node);
// todo: support dynamic rank case
@ -73,6 +77,9 @@ TSBinaryBackward::TSBinaryBackward() {
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();
if (transformation_callback(main_node)) {
return false;
}
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node, transpose_const)) {
register_new_node(new_node);

View File

@ -27,6 +27,9 @@ TSConcatForward::TSConcatForward() {
auto& main_node_output = pattern_to_output.at(main_node_label);
auto main_node = main_node_output.get_node_shared_ptr();
if (transformation_callback(main_node)) {
return false;
}
TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node);
auto concat_node = as_type_ptr<Concat>(main_node);
@ -77,6 +80,10 @@ TSConcatBackward::TSConcatBackward() {
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();
if (transformation_callback(main_node)) {
return false;
}
auto concat_node = as_type_ptr<Concat>(main_node);
auto concat_axis = concat_node->get_concatenation_axis();
if (concat_axis < 0) {

View File

@ -29,6 +29,10 @@ TSDataMovementForward::TSDataMovementForward() {
const auto& pattern_to_node = m.get_pattern_map();
auto& main_node = pattern_to_node.at(main_node_label);
if (transformation_callback(main_node)) {
return false;
}
auto transpose = std::dynamic_pointer_cast<Transpose>(pattern_to_node.at(transpose_label));
if (!transpose) {
return false;
@ -92,6 +96,9 @@ TSDataMovementBackward::TSDataMovementBackward() {
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();
if (transformation_callback(main_node)) {
return false;
}
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node,
transpose_const,

View File

@ -29,6 +29,10 @@ TSInterpolateForward::TSInterpolateForward() {
const auto& pattern_to_node = m.get_pattern_map();
auto& main_node = pattern_to_node.at(main_node_label);
if (transformation_callback(main_node)) {
return false;
}
auto transpose = std::dynamic_pointer_cast<Transpose>(pattern_to_node.at(transpose_label));
if (!transpose) {
return false;
@ -102,6 +106,9 @@ TSInterpolateBackward::TSInterpolateBackward() {
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();
if (transformation_callback(main_node)) {
return false;
}
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node,
transpose_const,

View File

@ -50,6 +50,10 @@ TSReductionForward::TSReductionForward() {
auto transpose = pattern_to_output.at(transpose_label);
auto reduction = pattern_to_output.at(reduce_label);
if (transformation_callback(reduction)) {
return false;
}
auto keep_dims = get_keep_dims(reduction);
auto transpose_order = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
@ -107,6 +111,10 @@ TSReductionBackward::TSReductionBackward() {
const auto& pattern_to_output = m.get_pattern_map();
auto transpose = pattern_to_output.at(transpose_label);
auto reduction = pattern_to_output.at(reduce_label);
if (transformation_callback(reduction)) {
return false;
}
auto keep_dims = get_keep_dims(reduction);
auto transpose_order = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));

View File

@ -29,6 +29,10 @@ TSSliceForward::TSSliceForward() {
const auto& pattern_to_node = m.get_pattern_map();
auto& main_node = pattern_to_node.at(main_node_label);
if (transformation_callback(main_node)) {
return false;
}
auto transpose = std::dynamic_pointer_cast<Transpose>(pattern_to_node.at(transpose_label));
if (!transpose || main_node->get_input_size() < 5) {
return false;
@ -84,6 +88,9 @@ TSSliceBackward::TSSliceBackward() {
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();
if (transformation_callback(main_node)) {
return false;
}
if (main_node->get_input_size() < 5) {
return false;

View File

@ -134,7 +134,7 @@ TSSplitBackward::TSSplitBackward() {
split = FindInputNode<VariadicSplit>(transpose_label_node);
}
if (!split) {
if (!split || transformation_callback(split)) {
return false;
}
auto split_axis_constant = as_type_ptr<Constant>(split->input_value(1).get_node_shared_ptr());
@ -200,6 +200,10 @@ TSSplitForward::TSSplitForward() {
auto& main_node_output = pattern_to_output.at(main_node_label);
auto main_node = main_node_output.get_node_shared_ptr();
if (transformation_callback(main_node)) {
return false;
}
auto split_axis_constant = as_type_ptr<Constant>(main_node->input_value(1).get_node_shared_ptr());
if (!split_axis_constant) {
return false;

View File

@ -109,6 +109,9 @@ TSSqueezeForward::TSSqueezeForward() {
auto transpose = pattern_to_output.at(transpose_label);
auto squeeze = pattern_to_output.at(squeeze_label);
if (transformation_callback(squeeze)) {
return false;
}
auto transpose_order = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
auto squeeze_axes = as_type_ptr<Constant>(squeeze->get_input_node_shared_ptr(1));
@ -196,6 +199,9 @@ TSSqueezeBackward::TSSqueezeBackward() {
auto transpose = pattern_to_output.at(transpose_label);
auto squeeze = pattern_to_output.at(squeeze_label);
if (transformation_callback(squeeze)) {
return false;
}
auto transpose_order = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
auto squeeze_axes = as_type_ptr<Constant>(squeeze->get_input_node_shared_ptr(1));

View File

@ -64,6 +64,9 @@ TSUnaryForward::TSUnaryForward() {
const auto& pattern_to_output = m.get_pattern_value_map();
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
auto unary = pattern_to_output.at(unary_label).get_node_shared_ptr();
if (transformation_callback(unary)) {
return false;
}
const NodePair new_nodes = SwapNodes(transpose, unary);
@ -105,6 +108,9 @@ TSUnaryBackward::TSUnaryBackward() {
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
auto unary = pattern_to_output.at(unary_label).get_node_shared_ptr();
if (transformation_callback(unary)) {
return false;
}
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(unary, transpose_const)) {
register_new_node(new_node);

View File

@ -110,6 +110,9 @@ TSUnsqueezeForward::TSUnsqueezeForward() {
auto transpose = pattern_to_output.at(transpose_label);
auto unsqueeze = pattern_to_output.at(unsqueeze_label);
if (transformation_callback(unsqueeze)) {
return false;
}
auto transpose_order = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
auto unsqueeze_axes = as_type_ptr<Constant>(unsqueeze->get_input_node_shared_ptr(1));
@ -179,6 +182,9 @@ TSUnsqueezeBackward::TSUnsqueezeBackward() {
auto transpose = pattern_to_output.at(transpose_label);
auto unsqueeze = pattern_to_output.at(unsqueeze_label);
if (transformation_callback(unsqueeze)) {
return false;
}
auto transpose_order = std::dynamic_pointer_cast<Constant>(transpose->get_input_node_shared_ptr(1));
auto unsqueeze_axes = std::dynamic_pointer_cast<Constant>(unsqueeze->get_input_node_shared_ptr(1));