Files
openvino/docs/HOWTO/mo_extensions/front/tf/Complex.py
Artyom Anokhov 0f1fa2cbde Update copyrights with 2023 year (#15149)
* Update copyrights with 2023 year

* Updated more files.
2023-02-02 15:47:25 +01:00

45 lines
1.6 KiB
Python

# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#! [complex:transformation]
import logging as log
import numpy as np
from mo.front.common.replacement import FrontReplacementSubgraph
from mo.graph.graph import Graph
class Complex(FrontReplacementSubgraph):
enabled = True
def pattern(self):
return dict(
nodes=[
('strided_slice_real', dict(op='StridedSlice')),
('strided_slice_imag', dict(op='StridedSlice')),
('complex', dict(op='Complex')),
],
edges=[
('strided_slice_real', 'complex', {'in': 0}),
('strided_slice_imag', 'complex', {'in': 1}),
])
@staticmethod
def replace_sub_graph(graph: Graph, match: dict):
strided_slice_real = match['strided_slice_real']
strided_slice_imag = match['strided_slice_imag']
complex_node = match['complex']
# make sure that both strided slice operations get the same data as input
assert strided_slice_real.in_port(0).get_source() == strided_slice_imag.in_port(0).get_source()
# identify the output port of the operation producing datat for strided slice nodes
input_node_output_port = strided_slice_real.in_port(0).get_source()
input_node_output_port.disconnect()
# change the connection so now all consumers of "complex_node" get data from input node of strided slice nodes
complex_node.out_port(0).get_connection().set_source(input_node_output_port)
#! [complex:transformation]