diff --git a/src/common/transformations/src/transformations/transpose_sinking/ts_binary.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_binary.cpp index ef46b4143c2..f4fbeaefb1f 100644 --- a/src/common/transformations/src/transformations/transpose_sinking/ts_binary.cpp +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_binary.cpp @@ -32,6 +32,10 @@ TSBinaryForward::TSBinaryForward() { const auto& pattern_to_output = m.get_pattern_value_map(); auto& main_node_output = pattern_to_output.at(main_node_label); auto main_node = main_node_output.get_node_shared_ptr(); + if (transformation_callback(main_node)) { + return false; + } + TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node); // todo: support dynamic rank case @@ -73,6 +77,9 @@ TSBinaryBackward::TSBinaryBackward() { auto transpose_const = as_type_ptr(pattern_to_output.at(transpose_const_label).get_node_shared_ptr()); auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr(); auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr(); + if (transformation_callback(main_node)) { + return false; + } for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node, transpose_const)) { register_new_node(new_node); diff --git a/src/common/transformations/src/transformations/transpose_sinking/ts_concat.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_concat.cpp index 4694b945650..726f40f216c 100644 --- a/src/common/transformations/src/transformations/transpose_sinking/ts_concat.cpp +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_concat.cpp @@ -27,6 +27,9 @@ TSConcatForward::TSConcatForward() { auto& main_node_output = pattern_to_output.at(main_node_label); auto main_node = main_node_output.get_node_shared_ptr(); + if (transformation_callback(main_node)) { + return false; + } TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node); auto concat_node = as_type_ptr(main_node); @@ -77,6 +80,10 @@ TSConcatBackward::TSConcatBackward() { auto transpose_const = as_type_ptr(pattern_to_output.at(transpose_const_label).get_node_shared_ptr()); auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr(); auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr(); + if (transformation_callback(main_node)) { + return false; + } + auto concat_node = as_type_ptr(main_node); auto concat_axis = concat_node->get_concatenation_axis(); if (concat_axis < 0) { diff --git a/src/common/transformations/src/transformations/transpose_sinking/ts_data_movement.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_data_movement.cpp index 482841d2c65..c59fb887be5 100644 --- a/src/common/transformations/src/transformations/transpose_sinking/ts_data_movement.cpp +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_data_movement.cpp @@ -29,6 +29,10 @@ TSDataMovementForward::TSDataMovementForward() { const auto& pattern_to_node = m.get_pattern_map(); auto& main_node = pattern_to_node.at(main_node_label); + if (transformation_callback(main_node)) { + return false; + } + auto transpose = std::dynamic_pointer_cast(pattern_to_node.at(transpose_label)); if (!transpose) { return false; @@ -92,6 +96,9 @@ TSDataMovementBackward::TSDataMovementBackward() { auto transpose_const = as_type_ptr(pattern_to_output.at(transpose_const_label).get_node_shared_ptr()); auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr(); auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr(); + if (transformation_callback(main_node)) { + return false; + } for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node, transpose_const, diff --git a/src/common/transformations/src/transformations/transpose_sinking/ts_interpolate.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_interpolate.cpp index 0a9c2b7458f..88d603aba01 100644 --- a/src/common/transformations/src/transformations/transpose_sinking/ts_interpolate.cpp +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_interpolate.cpp @@ -29,6 +29,10 @@ TSInterpolateForward::TSInterpolateForward() { const auto& pattern_to_node = m.get_pattern_map(); auto& main_node = pattern_to_node.at(main_node_label); + if (transformation_callback(main_node)) { + return false; + } + auto transpose = std::dynamic_pointer_cast(pattern_to_node.at(transpose_label)); if (!transpose) { return false; @@ -102,6 +106,9 @@ TSInterpolateBackward::TSInterpolateBackward() { auto transpose_const = as_type_ptr(pattern_to_output.at(transpose_const_label).get_node_shared_ptr()); auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr(); auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr(); + if (transformation_callback(main_node)) { + return false; + } for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node, transpose_const, diff --git a/src/common/transformations/src/transformations/transpose_sinking/ts_reduction.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_reduction.cpp index a5b608f5b99..1b10c01e696 100644 --- a/src/common/transformations/src/transformations/transpose_sinking/ts_reduction.cpp +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_reduction.cpp @@ -50,6 +50,10 @@ TSReductionForward::TSReductionForward() { auto transpose = pattern_to_output.at(transpose_label); auto reduction = pattern_to_output.at(reduce_label); + if (transformation_callback(reduction)) { + return false; + } + auto keep_dims = get_keep_dims(reduction); auto transpose_order = as_type_ptr(transpose->get_input_node_shared_ptr(1)); @@ -107,6 +111,10 @@ TSReductionBackward::TSReductionBackward() { const auto& pattern_to_output = m.get_pattern_map(); auto transpose = pattern_to_output.at(transpose_label); auto reduction = pattern_to_output.at(reduce_label); + if (transformation_callback(reduction)) { + return false; + } + auto keep_dims = get_keep_dims(reduction); auto transpose_order = as_type_ptr(transpose->get_input_node_shared_ptr(1)); diff --git a/src/common/transformations/src/transformations/transpose_sinking/ts_slice.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_slice.cpp index de4390f07ee..e8b0f5c663d 100644 --- a/src/common/transformations/src/transformations/transpose_sinking/ts_slice.cpp +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_slice.cpp @@ -29,6 +29,10 @@ TSSliceForward::TSSliceForward() { const auto& pattern_to_node = m.get_pattern_map(); auto& main_node = pattern_to_node.at(main_node_label); + if (transformation_callback(main_node)) { + return false; + } + auto transpose = std::dynamic_pointer_cast(pattern_to_node.at(transpose_label)); if (!transpose || main_node->get_input_size() < 5) { return false; @@ -84,6 +88,9 @@ TSSliceBackward::TSSliceBackward() { auto transpose_const = as_type_ptr(pattern_to_output.at(transpose_const_label).get_node_shared_ptr()); auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr(); auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr(); + if (transformation_callback(main_node)) { + return false; + } if (main_node->get_input_size() < 5) { return false; diff --git a/src/common/transformations/src/transformations/transpose_sinking/ts_split.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_split.cpp index 3aeb74436e7..38c5e5c9c28 100644 --- a/src/common/transformations/src/transformations/transpose_sinking/ts_split.cpp +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_split.cpp @@ -134,7 +134,7 @@ TSSplitBackward::TSSplitBackward() { split = FindInputNode(transpose_label_node); } - if (!split) { + if (!split || transformation_callback(split)) { return false; } auto split_axis_constant = as_type_ptr(split->input_value(1).get_node_shared_ptr()); @@ -200,6 +200,10 @@ TSSplitForward::TSSplitForward() { auto& main_node_output = pattern_to_output.at(main_node_label); auto main_node = main_node_output.get_node_shared_ptr(); + if (transformation_callback(main_node)) { + return false; + } + auto split_axis_constant = as_type_ptr(main_node->input_value(1).get_node_shared_ptr()); if (!split_axis_constant) { return false; diff --git a/src/common/transformations/src/transformations/transpose_sinking/ts_squeeze.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_squeeze.cpp index 8ff816bd3f1..f2954701e34 100644 --- a/src/common/transformations/src/transformations/transpose_sinking/ts_squeeze.cpp +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_squeeze.cpp @@ -109,6 +109,9 @@ TSSqueezeForward::TSSqueezeForward() { auto transpose = pattern_to_output.at(transpose_label); auto squeeze = pattern_to_output.at(squeeze_label); + if (transformation_callback(squeeze)) { + return false; + } auto transpose_order = as_type_ptr(transpose->get_input_node_shared_ptr(1)); auto squeeze_axes = as_type_ptr(squeeze->get_input_node_shared_ptr(1)); @@ -196,6 +199,9 @@ TSSqueezeBackward::TSSqueezeBackward() { auto transpose = pattern_to_output.at(transpose_label); auto squeeze = pattern_to_output.at(squeeze_label); + if (transformation_callback(squeeze)) { + return false; + } auto transpose_order = as_type_ptr(transpose->get_input_node_shared_ptr(1)); auto squeeze_axes = as_type_ptr(squeeze->get_input_node_shared_ptr(1)); diff --git a/src/common/transformations/src/transformations/transpose_sinking/ts_unary.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_unary.cpp index ed543ca9d92..e3147d5cf7a 100644 --- a/src/common/transformations/src/transformations/transpose_sinking/ts_unary.cpp +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_unary.cpp @@ -64,6 +64,9 @@ TSUnaryForward::TSUnaryForward() { const auto& pattern_to_output = m.get_pattern_value_map(); auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr(); auto unary = pattern_to_output.at(unary_label).get_node_shared_ptr(); + if (transformation_callback(unary)) { + return false; + } const NodePair new_nodes = SwapNodes(transpose, unary); @@ -105,6 +108,9 @@ TSUnaryBackward::TSUnaryBackward() { auto transpose_const = as_type_ptr(pattern_to_output.at(transpose_const_label).get_node_shared_ptr()); auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr(); auto unary = pattern_to_output.at(unary_label).get_node_shared_ptr(); + if (transformation_callback(unary)) { + return false; + } for (auto& new_node : sink_backward::InsertTransposeBeforeNode(unary, transpose_const)) { register_new_node(new_node); diff --git a/src/common/transformations/src/transformations/transpose_sinking/ts_unsqueeze.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_unsqueeze.cpp index 9d9416d8e38..3b92a2072a7 100644 --- a/src/common/transformations/src/transformations/transpose_sinking/ts_unsqueeze.cpp +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_unsqueeze.cpp @@ -110,6 +110,9 @@ TSUnsqueezeForward::TSUnsqueezeForward() { auto transpose = pattern_to_output.at(transpose_label); auto unsqueeze = pattern_to_output.at(unsqueeze_label); + if (transformation_callback(unsqueeze)) { + return false; + } auto transpose_order = as_type_ptr(transpose->get_input_node_shared_ptr(1)); auto unsqueeze_axes = as_type_ptr(unsqueeze->get_input_node_shared_ptr(1)); @@ -179,6 +182,9 @@ TSUnsqueezeBackward::TSUnsqueezeBackward() { auto transpose = pattern_to_output.at(transpose_label); auto unsqueeze = pattern_to_output.at(unsqueeze_label); + if (transformation_callback(unsqueeze)) { + return false; + } auto transpose_order = std::dynamic_pointer_cast(transpose->get_input_node_shared_ptr(1)); auto unsqueeze_axes = std::dynamic_pointer_cast(unsqueeze->get_input_node_shared_ptr(1));