Enable transformation callback for TS transformations (#16767)
This commit is contained in:
parent
d2deae225a
commit
4812879318
@ -32,6 +32,10 @@ TSBinaryForward::TSBinaryForward() {
|
|||||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||||
auto& main_node_output = pattern_to_output.at(main_node_label);
|
auto& main_node_output = pattern_to_output.at(main_node_label);
|
||||||
auto main_node = main_node_output.get_node_shared_ptr();
|
auto main_node = main_node_output.get_node_shared_ptr();
|
||||||
|
if (transformation_callback(main_node)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node);
|
TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node);
|
||||||
|
|
||||||
// todo: support dynamic rank case
|
// todo: support dynamic rank case
|
||||||
@ -73,6 +77,9 @@ TSBinaryBackward::TSBinaryBackward() {
|
|||||||
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
|
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
|
||||||
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
||||||
auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();
|
auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();
|
||||||
|
if (transformation_callback(main_node)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node, transpose_const)) {
|
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node, transpose_const)) {
|
||||||
register_new_node(new_node);
|
register_new_node(new_node);
|
||||||
|
@ -27,6 +27,9 @@ TSConcatForward::TSConcatForward() {
|
|||||||
|
|
||||||
auto& main_node_output = pattern_to_output.at(main_node_label);
|
auto& main_node_output = pattern_to_output.at(main_node_label);
|
||||||
auto main_node = main_node_output.get_node_shared_ptr();
|
auto main_node = main_node_output.get_node_shared_ptr();
|
||||||
|
if (transformation_callback(main_node)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node);
|
TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node);
|
||||||
auto concat_node = as_type_ptr<Concat>(main_node);
|
auto concat_node = as_type_ptr<Concat>(main_node);
|
||||||
@ -77,6 +80,10 @@ TSConcatBackward::TSConcatBackward() {
|
|||||||
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
|
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
|
||||||
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
||||||
auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();
|
auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();
|
||||||
|
if (transformation_callback(main_node)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
auto concat_node = as_type_ptr<Concat>(main_node);
|
auto concat_node = as_type_ptr<Concat>(main_node);
|
||||||
auto concat_axis = concat_node->get_concatenation_axis();
|
auto concat_axis = concat_node->get_concatenation_axis();
|
||||||
if (concat_axis < 0) {
|
if (concat_axis < 0) {
|
||||||
|
@ -29,6 +29,10 @@ TSDataMovementForward::TSDataMovementForward() {
|
|||||||
const auto& pattern_to_node = m.get_pattern_map();
|
const auto& pattern_to_node = m.get_pattern_map();
|
||||||
|
|
||||||
auto& main_node = pattern_to_node.at(main_node_label);
|
auto& main_node = pattern_to_node.at(main_node_label);
|
||||||
|
if (transformation_callback(main_node)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
auto transpose = std::dynamic_pointer_cast<Transpose>(pattern_to_node.at(transpose_label));
|
auto transpose = std::dynamic_pointer_cast<Transpose>(pattern_to_node.at(transpose_label));
|
||||||
if (!transpose) {
|
if (!transpose) {
|
||||||
return false;
|
return false;
|
||||||
@ -92,6 +96,9 @@ TSDataMovementBackward::TSDataMovementBackward() {
|
|||||||
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
|
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
|
||||||
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
||||||
auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();
|
auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();
|
||||||
|
if (transformation_callback(main_node)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node,
|
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node,
|
||||||
transpose_const,
|
transpose_const,
|
||||||
|
@ -29,6 +29,10 @@ TSInterpolateForward::TSInterpolateForward() {
|
|||||||
const auto& pattern_to_node = m.get_pattern_map();
|
const auto& pattern_to_node = m.get_pattern_map();
|
||||||
|
|
||||||
auto& main_node = pattern_to_node.at(main_node_label);
|
auto& main_node = pattern_to_node.at(main_node_label);
|
||||||
|
if (transformation_callback(main_node)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
auto transpose = std::dynamic_pointer_cast<Transpose>(pattern_to_node.at(transpose_label));
|
auto transpose = std::dynamic_pointer_cast<Transpose>(pattern_to_node.at(transpose_label));
|
||||||
if (!transpose) {
|
if (!transpose) {
|
||||||
return false;
|
return false;
|
||||||
@ -102,6 +106,9 @@ TSInterpolateBackward::TSInterpolateBackward() {
|
|||||||
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
|
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
|
||||||
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
||||||
auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();
|
auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();
|
||||||
|
if (transformation_callback(main_node)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node,
|
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node,
|
||||||
transpose_const,
|
transpose_const,
|
||||||
|
@ -50,6 +50,10 @@ TSReductionForward::TSReductionForward() {
|
|||||||
|
|
||||||
auto transpose = pattern_to_output.at(transpose_label);
|
auto transpose = pattern_to_output.at(transpose_label);
|
||||||
auto reduction = pattern_to_output.at(reduce_label);
|
auto reduction = pattern_to_output.at(reduce_label);
|
||||||
|
if (transformation_callback(reduction)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
auto keep_dims = get_keep_dims(reduction);
|
auto keep_dims = get_keep_dims(reduction);
|
||||||
|
|
||||||
auto transpose_order = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
|
auto transpose_order = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
|
||||||
@ -107,6 +111,10 @@ TSReductionBackward::TSReductionBackward() {
|
|||||||
const auto& pattern_to_output = m.get_pattern_map();
|
const auto& pattern_to_output = m.get_pattern_map();
|
||||||
auto transpose = pattern_to_output.at(transpose_label);
|
auto transpose = pattern_to_output.at(transpose_label);
|
||||||
auto reduction = pattern_to_output.at(reduce_label);
|
auto reduction = pattern_to_output.at(reduce_label);
|
||||||
|
if (transformation_callback(reduction)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
auto keep_dims = get_keep_dims(reduction);
|
auto keep_dims = get_keep_dims(reduction);
|
||||||
|
|
||||||
auto transpose_order = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
|
auto transpose_order = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
|
||||||
|
@ -29,6 +29,10 @@ TSSliceForward::TSSliceForward() {
|
|||||||
const auto& pattern_to_node = m.get_pattern_map();
|
const auto& pattern_to_node = m.get_pattern_map();
|
||||||
|
|
||||||
auto& main_node = pattern_to_node.at(main_node_label);
|
auto& main_node = pattern_to_node.at(main_node_label);
|
||||||
|
if (transformation_callback(main_node)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
auto transpose = std::dynamic_pointer_cast<Transpose>(pattern_to_node.at(transpose_label));
|
auto transpose = std::dynamic_pointer_cast<Transpose>(pattern_to_node.at(transpose_label));
|
||||||
if (!transpose || main_node->get_input_size() < 5) {
|
if (!transpose || main_node->get_input_size() < 5) {
|
||||||
return false;
|
return false;
|
||||||
@ -84,6 +88,9 @@ TSSliceBackward::TSSliceBackward() {
|
|||||||
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
|
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
|
||||||
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
||||||
auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();
|
auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();
|
||||||
|
if (transformation_callback(main_node)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
if (main_node->get_input_size() < 5) {
|
if (main_node->get_input_size() < 5) {
|
||||||
return false;
|
return false;
|
||||||
|
@ -134,7 +134,7 @@ TSSplitBackward::TSSplitBackward() {
|
|||||||
split = FindInputNode<VariadicSplit>(transpose_label_node);
|
split = FindInputNode<VariadicSplit>(transpose_label_node);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!split) {
|
if (!split || transformation_callback(split)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
auto split_axis_constant = as_type_ptr<Constant>(split->input_value(1).get_node_shared_ptr());
|
auto split_axis_constant = as_type_ptr<Constant>(split->input_value(1).get_node_shared_ptr());
|
||||||
@ -200,6 +200,10 @@ TSSplitForward::TSSplitForward() {
|
|||||||
|
|
||||||
auto& main_node_output = pattern_to_output.at(main_node_label);
|
auto& main_node_output = pattern_to_output.at(main_node_label);
|
||||||
auto main_node = main_node_output.get_node_shared_ptr();
|
auto main_node = main_node_output.get_node_shared_ptr();
|
||||||
|
if (transformation_callback(main_node)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
auto split_axis_constant = as_type_ptr<Constant>(main_node->input_value(1).get_node_shared_ptr());
|
auto split_axis_constant = as_type_ptr<Constant>(main_node->input_value(1).get_node_shared_ptr());
|
||||||
if (!split_axis_constant) {
|
if (!split_axis_constant) {
|
||||||
return false;
|
return false;
|
||||||
|
@ -109,6 +109,9 @@ TSSqueezeForward::TSSqueezeForward() {
|
|||||||
|
|
||||||
auto transpose = pattern_to_output.at(transpose_label);
|
auto transpose = pattern_to_output.at(transpose_label);
|
||||||
auto squeeze = pattern_to_output.at(squeeze_label);
|
auto squeeze = pattern_to_output.at(squeeze_label);
|
||||||
|
if (transformation_callback(squeeze)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
auto transpose_order = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
|
auto transpose_order = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
|
||||||
auto squeeze_axes = as_type_ptr<Constant>(squeeze->get_input_node_shared_ptr(1));
|
auto squeeze_axes = as_type_ptr<Constant>(squeeze->get_input_node_shared_ptr(1));
|
||||||
@ -196,6 +199,9 @@ TSSqueezeBackward::TSSqueezeBackward() {
|
|||||||
|
|
||||||
auto transpose = pattern_to_output.at(transpose_label);
|
auto transpose = pattern_to_output.at(transpose_label);
|
||||||
auto squeeze = pattern_to_output.at(squeeze_label);
|
auto squeeze = pattern_to_output.at(squeeze_label);
|
||||||
|
if (transformation_callback(squeeze)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
auto transpose_order = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
|
auto transpose_order = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
|
||||||
auto squeeze_axes = as_type_ptr<Constant>(squeeze->get_input_node_shared_ptr(1));
|
auto squeeze_axes = as_type_ptr<Constant>(squeeze->get_input_node_shared_ptr(1));
|
||||||
|
@ -64,6 +64,9 @@ TSUnaryForward::TSUnaryForward() {
|
|||||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||||
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
||||||
auto unary = pattern_to_output.at(unary_label).get_node_shared_ptr();
|
auto unary = pattern_to_output.at(unary_label).get_node_shared_ptr();
|
||||||
|
if (transformation_callback(unary)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
const NodePair new_nodes = SwapNodes(transpose, unary);
|
const NodePair new_nodes = SwapNodes(transpose, unary);
|
||||||
|
|
||||||
@ -105,6 +108,9 @@ TSUnaryBackward::TSUnaryBackward() {
|
|||||||
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
|
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
|
||||||
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
||||||
auto unary = pattern_to_output.at(unary_label).get_node_shared_ptr();
|
auto unary = pattern_to_output.at(unary_label).get_node_shared_ptr();
|
||||||
|
if (transformation_callback(unary)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(unary, transpose_const)) {
|
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(unary, transpose_const)) {
|
||||||
register_new_node(new_node);
|
register_new_node(new_node);
|
||||||
|
@ -110,6 +110,9 @@ TSUnsqueezeForward::TSUnsqueezeForward() {
|
|||||||
|
|
||||||
auto transpose = pattern_to_output.at(transpose_label);
|
auto transpose = pattern_to_output.at(transpose_label);
|
||||||
auto unsqueeze = pattern_to_output.at(unsqueeze_label);
|
auto unsqueeze = pattern_to_output.at(unsqueeze_label);
|
||||||
|
if (transformation_callback(unsqueeze)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
auto transpose_order = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
|
auto transpose_order = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
|
||||||
auto unsqueeze_axes = as_type_ptr<Constant>(unsqueeze->get_input_node_shared_ptr(1));
|
auto unsqueeze_axes = as_type_ptr<Constant>(unsqueeze->get_input_node_shared_ptr(1));
|
||||||
@ -179,6 +182,9 @@ TSUnsqueezeBackward::TSUnsqueezeBackward() {
|
|||||||
|
|
||||||
auto transpose = pattern_to_output.at(transpose_label);
|
auto transpose = pattern_to_output.at(transpose_label);
|
||||||
auto unsqueeze = pattern_to_output.at(unsqueeze_label);
|
auto unsqueeze = pattern_to_output.at(unsqueeze_label);
|
||||||
|
if (transformation_callback(unsqueeze)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
auto transpose_order = std::dynamic_pointer_cast<Constant>(transpose->get_input_node_shared_ptr(1));
|
auto transpose_order = std::dynamic_pointer_cast<Constant>(transpose->get_input_node_shared_ptr(1));
|
||||||
auto unsqueeze_axes = std::dynamic_pointer_cast<Constant>(unsqueeze->get_input_node_shared_ptr(1));
|
auto unsqueeze_axes = std::dynamic_pointer_cast<Constant>(unsqueeze->get_input_node_shared_ptr(1));
|
||||||
|
Loading…
Reference in New Issue
Block a user