This commit is contained in:
Evgeny Kotov 2023-03-21 19:47:27 +01:00
parent cdfd77a415
commit 5c249d98d9

View File

@ -51,6 +51,7 @@ NodePair SwapNodes(const NodePtr& first_node, const NodePtr& second_node) {
return std::make_pair(new_first_node, new_second_node);
}
NodePtr GetPatternNode(const PatternValueMap& pattern_to_output, const NodeVector& nodes) {
for (const auto& node : nodes) {
auto it = pattern_to_output.find(node);
@ -61,6 +62,7 @@ NodePtr GetPatternNode(const PatternValueMap& pattern_to_output, const NodeVecto
return {};
}
} // namespace
TSUnaryForward::TSUnaryForward() {
MATCHER_SCOPE(TSUnaryForward);
auto transpose_label = wrap_type<Transpose>({any_input(), any_input()});
@ -82,11 +84,13 @@ TSUnaryForward::TSUnaryForward() {
auto m = std::make_shared<Matcher>(unary_label, matcher_name);
register_matcher(m, matcher_pass_callback);
}
namespace {
bool IfSinkingEnabled(const Output<Node>& output) {
return is_sinking_node(output.get_node_shared_ptr());
}
} // namespace
TSUnaryBackward::TSUnaryBackward() {
MATCHER_SCOPE(TSUnaryBackwardMultiConsumers);
auto unary_restrictions = [](const Output<Node>& output) -> bool {