* Allow MO to generate IR with -1 in dimensions
* Some fixes to support -1 for StridedSlice operation
* Updated TensorArrayGatherV3 shape infer to support dynamic output shape
* Several fixes to support undefined dimensions in the Broadcast, Reshape, Slice and Tile
* Fixed bug in the normalization transformation of TF NMS to opset NMS
* Updated shape infer functions related to StridedSlice and NMS
* Updated Select shape inference function to use common shape broadcasting function supporting dynamism
* Fixed operation TFResize shape infer function to work correctly for case when model is converted with --disable_nhwc_to_nchw
* Dynamic Range and update asserts in NMS
* Changed the way dynamic dimensions are specified. Refactored shape inference functions and common places to use new approach
* More fixes to support dynamic shapes
* More fixes for support of dynamic shapes
* Fixed generation of IR with dynamic dimensions
* Allow reading IRs with undefined dimensions
* More changes in the IE to support dynamic dimensions
* Fixes for Switch, Merge, Concat shape and value infer related to dynamism
* Fixed TensorArray related ops to properly handle dynamic dimensions. Fixed StridedSlice infer for case with new_axis
* Fixed shape_for_layout function to generate masked array
* Fixed shape inference for Convolution and Poolings to support dynamic spatial dimensions
* Updated shape infer functions for CTCGreedyDecoder, CTCLoss and Enter
* Fixed shape inference with dynamic dimensions for MatMul, Split, Upsample, SpaceToBatch, some fixes for the TI
* Fixes for undefined dimensions support for Proposal and DetectionOutput
* Fixed ExtractImagePatches, DepthToSpace and RegionYolo shape infer functions to work with partially dynamic dimensions
* Changes in tf_window_op_pad_infer to better work with dynamic dimensions
* Fixed output shape calculation for StridedSlice operation
* More StridedSlice fixes
* Fixed resolve_convolution_with_group
* Fixed unit tests
* Fixed unit tests
* Fixed Switch op unit tests
* Fixed shape inference for Upsample operation
* Updated unit tests for the Concat operation
* Fixed eltwise shape infer unit tests
* Fixed shape infer tests for Convolution and DetectionOutput ops
* Fixed Crop shape infer function tests
* Fixed Slice op unit test and minor fix in the shape inference. Fixed emitter
* Updated unit test for telemetry and match_shape function for dynamism
* Fixed unit test for the DetectionOutput
* Added support for the TF ClipByValue operation
* Fixed GatherND shape inference for dynamic shapes support
* Dynamic shapes support for the MO IR Reader
* Fixed BlockLSTM operation to not work as an extractor
* Allow to serialize IRs with partially defined shapes
* Updated SelectBroadcast transformation to not check shape values
* Fixed MO IR comparator
* Fixed SS value propagation when slices are dynamic
* Do not re-run graph clean-up for ProposalMutation
* Fixed InterpolateSequenceToInterpolate transformation to support dynamic dimensions
* Fixed Loop iteration count calculation and reading IteratorGetNext shapes
* Fixed unit test for serialization
* Fixed serialization test
* Fixed RandomUniform shape infer
* Fixed several transformations related to RNN to respect dynamic output shapes
* Fixed Deconvolution shape calculation for dynamic batch. Eltwise shape infer improvements
* Fixed shape infer functions for ExperimentalDetectron ops, reverted changes for NonZero and removed debug prints
* Fixed check for dynamism of a list, fixed value propagation for Concat op and remove redundant shape infer for reshape
* Update Eltwise value propagation to use np.ma
* Fixed ExpandDims shape infer function
* Shape infer functions fixes and improvements
* Remove Accum op from the MO
* Updated activation functions shape infer
* Removed unsupported operation Correlation
* Fixed shape infers for several functions
* Removed unsupported DataAugmentation operation
* Fixed shape infer functions for several ops in extensions directory
* Removed unsupported operation PowerFile
* Removed unsupported SpatialTransformer, SimplerNMS and PredictionHeatmap operations
* More shape infer functions updates
* Merge shape infer fix
* Fixed typo
* Fixed TensorArraySize shape infer function
* Fixed VariadicSplit and Squeeze shape infer
* Fixed ONNX models Parameter extractor
* Updated Select value propagation for the dynamic case
* Fixed ReorgYolo shape infer and test
* Removed unnecessary tests
* Fixed Tile shape infer
* Fixed SparseFillEmptyRows unit tests
* Fixed package bom
* Added extractor for the TF operation Mod
* Fixed value propagation for MatMul operation
* Updated Parameter extender to generate shape_array when shape is partially defined only
* Fixed BOM file
* Fixed issue with the TF OD API models and DetectionOutput op. Now the shape infer function for the DO does not re-infer "num_classes" attribute value if it is already known
* Fixed unit test for the DO infer
* Fixed num classes calculation for the DO generation for Faster/Mask-RCNN models
* Changed NMS op to produce static output shape
* Restore dynamic output shape calculation for the NMS for NMS-5
* Fixed CellNormalizer transformation. It should work for static shapes only
* RNNCell Op class fixes
* Revert some changes
* Updated documentation with a list of supported operations
* Revert changes
* Fixes for the ConstantFill op
* Removed redundant SequenceLengthToMask transformation
* TensorArray* ops shape infer code style and refactoring
* Reverse some unnecessary changes in the ConvolutionNormalizer
* Fixes and unit tests for shape_array, compare_shapes, is_fully_defined functions
* Implemented shape_insert, shape_delete functions and tests for them
* Modified code to use shape_delete function
* Added usage of shape_insert function where necessary
* Use shape_insert function in many places
* Some fixes in shape inference for various ops
* Updated shape_delete function to support negative indices
* Changes and unit tests for the MatMul infer function
* Removed strange code from the TF Merge infer function
* Merge op shape infer fixes
* Fixed value propagation in the transformation EltwiseInputReshape.py for the dynamic dimension case
* Code cleanup
* Updated GatherND to support dynamic dimensions
* Minor fixes
* Fixed shape_insert and shape_delete to support np.int64 and np.int32 types
* Updated Upsample operation unit tests with dynamic input shapes
* Minor change in the extensions/back/ConvolutionNormalizer.py to make sure that input dimensions are static
* Fixed ConvertGroupedStridedSlice transformation and added unit tests
* Revert debug changes
* Fixed value propagation for Unsqueeze to work with partially defined input values
* Typo fix
* Added unit tests for the Unsqueeze op shape infer
* Broadcasting functions changes and unit tests
* Fixed Tile value inference for partially defined input tensor
* Unit tests for Split and VariadicSplit ops
* Fixes for the Concat infer + unit tests
* Removed redundant tf_pack shape infer
* Fixed Concat value infer and added unit tests
* Fixed StridedSlice shape inference for case with dynamic slices
* Fixes related to StridedSlice shape infer, changes in tests
* Unit tests for the eltwise shape and value infer
* Fixed Pad op value propagation to allow dynamic input values to be propagated
* Unit test for Pooling dynamic input shape infer
* Squeeze op unit tests for dynamic input shape
* Added assert to the Squeeze op shape infer for case when squeeze dimension is dynamic value
* Added message to the MO when input shapes are dynamic
* Convolution dynamic unit test
* Removed redundant transformation GroupedConvWeightsNormalize
* Removed non-ascii character from the message
* Fixed typo in the BOM file
* Code style and comment fixes
* Fixed copy-paste issue in the DO shape infer function
* Fixed setting dynamic shape in the MO command line
* Added function to compare tensor with dynamic values. Fixes in the unit tests and shape infer functions
* Improved Reshape shape infer + added unit tests
* Fixed value propagation for Select op
* Renamed several internal functions, minor code fixes.
* Code style fixes
* Modified condition in the _set_shape method of the Port class to not check shape if the "override_output_shape" attribute is specified
* Fixed constant value propagation for ReduceOps when inputs have dynamic values. Added unit test
* Fixed shape infer for the Loop for dynamic dimensions case
* Fix in the NMS shape infer to avoid ragged numpy array generation. Fixed Scatter shape infer validation
* Improved shapes infer for eltwise ops with respect to dynamic dimensions
* Changed code comments
* Renamed tensor names in the ClipByValueTFTransformation
* Changed np.ma.allequal to strict_compare_tensors in the Merge op infer
* Replaced np.ma.allequal with strict_compare_tensors.
* Fixed Merge op value infer
* Fixed debug code
* Removed commented line
* Updated condition to check for dynamic shapes in the Partial infer to not fail for MxNet models
* Improvements to the get_shape_from_slice and is_dynamic_slice functions
* Reverted change in the `normalize_slices_attr` for ellipsis mask case
* Updated shape conditions in the ScatterNDBase op to support dynamic dimensions
* Crop op file refactoring
* Set "type" attribute to None for SparseFillEmptyRows op which is not from any opset
* Removed unnecessary extractor test
* Restored Crop operation type
* Removed "type" attribute from the Crop operation and updated the MO code to find Crop by "op" attribute
* Fixed If shape infer function to produce dynamic dimensions
* Updated If shape and value infer to properly work when condition is static
* Fixed fusing transformation check to work with dynamic dimensions. Change comparison in the shape_inference function to not use strict shapes comparison
* Optimize imports in the LayerNorm
* ConvertGroupedStridedSlice minor fixes related to dynamism support
* Fixed ConvertGroupedStridedSlice to properly check if the dimension is sliced
355 lines
17 KiB
Python
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import logging as log

import numpy as np

from mo.front.common.partial_infer.utils import int64_array, is_fully_defined, dynamic_dimension_value
from mo.graph.graph import Node, Graph
from mo.middle.passes.infer import partial_infer
from mo.ops.op import Op


class If(Op):
    """
    The If operation has a condition input which defines which sub-graph, "then" or "else", is executed.
    """
    op = 'If'
    enabled = False

    def __init__(self, graph: Graph, attrs: dict):
        base_attrs = {
            'type': self.op,
            'op': self.op,
            'then_graph': None,  # a Graph object with the "then" body sub-graph (condition is True)
            'else_graph': None,  # a Graph object with the "else" body sub-graph (condition is False)
            'sub_graphs': ['then_graph', 'else_graph'],  # built-in attribute with all sub-graphs
            'version': 'opset8',
            'infer': self.infer,
            'type_infer': self.type_infer,
        }
        base_attrs.update(attrs)
        super().__init__(graph, base_attrs, attrs)

    def port_map_attrs(self):
        return [
            'external_port_id',
            'internal_layer_id'
        ]

    @staticmethod
    def connect_body_input(if_node: Node, condition: bool, if_input_port_idx: int, body_parameter: Node):
        """
        Update the specified body parameter and connect it with If input

        :param if_node: the If node
        :param condition: the boolean defining which body graph (then/else) the parameter belongs to
        :param if_input_port_idx: the input port index to connect
        :param body_parameter: the body parameter node to connect
        :return: None
        """
        assert if_node.soft_get('op') == 'If'
        assert body_parameter.soft_get('op') == 'Parameter'
        sub_graph = if_node.then_graph if condition else if_node.else_graph
        assert body_parameter.id in sub_graph
        body_parameter['input_id'] = if_input_port_idx

    @staticmethod
    def connect_body_output(if_node: Node, condition: bool, if_output_port_idx: int, internal_result: Node):
        """
        Update the specified output port and connect it with If output

        :param if_node: the If node
        :param condition: the boolean defining which body graph (then/else) the Result belongs to
        :param if_output_port_idx: the output port index to connect
        :param internal_result: the body Result node to connect
        :return: None
        """
        assert if_node.soft_get('op') == 'If'
        assert internal_result.soft_get('op') == 'Result'
        sub_graph = if_node.then_graph if condition else if_node.else_graph
        assert internal_result.id in sub_graph
        internal_result['output_id'] = if_output_port_idx

    @staticmethod
    def update_body_parameters_type(if_node: Node, condition: bool):
        """
        Update the data type for If body Parameter nodes based on data type of the outer graph nodes producing data
        for them.

        :param if_node: The If node
        :param condition: the boolean defining a condition (then/else) graph
        :return: None
        """
        assert if_node.soft_get('type') == 'If'

        subgraph = if_node.then_graph if condition else if_node.else_graph
        for node in subgraph.get_op_nodes():
            if node.has('input_id'):
                assert node.soft_get('type') == 'Parameter'
                input_port_id = node['input_id']
                input_type = if_node.in_port(input_port_id).get_data_type()
                node.data_type = input_type
                log.debug('Updated data type for the body node with name "{}" with value {}'
                          .format(node.name, node.data_type))

    @staticmethod
    def update_body_parameters_shape(if_node: Node, condition: bool):
        """
        Update shape for If body parameters.

        :param if_node: The If node
        :param condition: the boolean defining a condition (then/else) graph
        :return: None
        """
        subgraph = if_node.then_graph if condition else if_node.else_graph
        for node in subgraph.get_op_nodes():
            if node.has('input_id'):
                assert node.soft_get('type') == 'Parameter'
                input_port_id = node['input_id']
                input_shape = if_node.in_port(input_port_id).data.get_shape()
                if node.soft_get('shape', None) is None:
                    node['shape'] = None
                node.shape = input_shape.copy()
                log.debug('Updated shape for the body node with name "{}" with value {}'
                          .format(node.soft_get('name', node.soft_get('id')), node.shape))

    @staticmethod
    def results_mapping_and_finding_fake_outputs(output_nodes_in_subgraph, branch_name, outputs_mapping):
        """
        This method checks the Result nodes of a subgraph and fills the mapping between the If operation outputs and
        the internal subgraph results. It also reports whether the subgraph has only fake results.

        :param output_nodes_in_subgraph: Result nodes with attribute 'output_id'
        :param branch_name: name of subgraph
        :param outputs_mapping: map between If operation output ID and subgraph results

        :return: True if every result of the subgraph is an empty tensor (its shape has a zero dimension)
        """
        graph_contain_fake_outputs = True

        for output_node in output_nodes_in_subgraph:
            assert output_node.soft_get('type') == 'Result'
            port_id = output_node['output_id']
            assert port_id in outputs_mapping.keys(), 'Incorrect mapping of {0} outputs with {1} outputs! ' \
                                                      'Can\'t find port with ID {2} in If operation.' \
                .format(branch_name, output_node.name, port_id)
            outputs_mapping[port_id][branch_name] = output_node
            out_node_shape = output_node.in_port(0).data.get_shape()
            graph_contain_fake_outputs = graph_contain_fake_outputs and np.any(out_node_shape == 0)
        return graph_contain_fake_outputs

    @staticmethod
    def update_if_output_ports_shape(if_node: Node):
        """
        Update shape and values for If output ports.

        :param if_node: The If node to update output ports and shapes
        :return: None
        """
        node_name = if_node.soft_get('name', if_node.id)

        then_outputs = [node for node in if_node.then_graph.get_op_nodes() if node.has('output_id')]
        else_outputs = [node for node in if_node.else_graph.get_op_nodes() if node.has('output_id')]
        outputs_mapping = {}
        outputs_number = len(if_node.out_ports())

        if outputs_number == 0 and len(if_node.out_ports(control_flow=True)) != 0:
            # Some models have an If with control flow outputs only.
            # Set an empty shape for the control flow outputs of such an If.
            # TODO: need to rethink and redo support for control flow edges in if operation
            for node in if_node.out_nodes(control_flow=True).values():
                node.shape = int64_array([])
            return

        for port_id in if_node.out_ports().keys():
            outputs_mapping[port_id] = {}

        # then_contains_fake_outputs/else_contains_fake_outputs are True if every output of then_body/else_body has
        # a zero dimension (for example, shape [0]). It means then_body/else_body does not return data and further
        # shape inference for this branch is not possible.
        # TODO: exclude support fake_outputs from this code when we will support shape_inference with empty tensors

        then_contains_fake_outputs = \
            If.results_mapping_and_finding_fake_outputs(then_outputs, 'then_graph', outputs_mapping)
        else_contains_fake_outputs = \
            If.results_mapping_and_finding_fake_outputs(else_outputs, 'else_graph', outputs_mapping)

        # use_then_shape is True when else_body does not return data or when then_body returns data. If
        # use_then_shape is True, If's outputs will have the same shapes as then_body results
        use_then_shape = else_contains_fake_outputs or not then_contains_fake_outputs

        cond_value = if_node.in_port(0).data.get_value()

        for port_id in outputs_mapping:
            then_else_nodes = outputs_mapping[port_id]
            assert 'then_graph' in then_else_nodes.keys(), 'then_graph does not connect with If.out_port[{0}] ' \
                                                           'in {1} node!'.format(port_id, node_name)
            assert 'else_graph' in then_else_nodes.keys(), 'else_graph does not connect with If.out_port[{0}] ' \
                                                           'in {1} node!'.format(port_id, node_name)

            then_shape = then_else_nodes['then_graph'].in_port(0).data.get_shape()
            then_value = then_else_nodes['then_graph'].in_port(0).data.get_value()
            else_shape = then_else_nodes['else_graph'].in_port(0).data.get_shape()
            else_value = then_else_nodes['else_graph'].in_port(0).data.get_value()

            if is_fully_defined(cond_value):
                # the condition is static, so the output is taken from the selected branch only
                if cond_value.item() is True:
                    if then_value is not None:
                        if_node.out_port(port_id).data.set_value(then_value)
                    else:
                        if_node.out_port(port_id).data.set_shape(then_shape)
                else:
                    if else_value is not None:
                        if_node.out_port(port_id).data.set_value(else_value)
                    else:
                        if_node.out_port(port_id).data.set_shape(else_shape)
            else:
                if then_contains_fake_outputs ^ else_contains_fake_outputs:
                    # if exactly one of the branches has fake outputs then use the other one
                    if_node.out_port(port_id).data.set_shape(then_shape if use_then_shape else else_shape)
                else:
                    # find "intersection" which is equal to the dimension value if corresponding dimensions are equal
                    # and dynamic otherwise
                    assert len(then_shape) == len(else_shape), 'Ranks of "then" and "else" output tensors are ' \
                                                               'different for node {} for port {}'.format(node_name,
                                                                                                          port_id)
                    output_shape = [d1 if is_fully_defined(d1) and is_fully_defined(d2) and d1 == d2 else
                                    dynamic_dimension_value for d1, d2 in zip(then_shape, else_shape)]
                    if_node.out_port(port_id).data.set_shape(output_shape)

    @staticmethod
    def update_if_output_ports_type(if_node: Node):
        """
        Update types for If output ports.

        :param if_node: The If node to update output ports and types
        :return: None
        """
        then_outputs = [node for node in if_node.then_graph.get_op_nodes() if node.has('output_id')]
        else_outputs = [node for node in if_node.else_graph.get_op_nodes() if node.has('output_id')]
        outputs_mapping = {}
        outputs_number = len(if_node.out_ports())
        assert outputs_number == len(then_outputs), 'Incorrect number of outputs in then_graph of If with ' \
                                                    'name {0}! then_graph must have {1} outputs' \
            .format(if_node.name, outputs_number)
        assert outputs_number == len(else_outputs), 'Incorrect number of outputs in else_graph of If with ' \
                                                    'name {0}! else_graph must have {1} outputs' \
            .format(if_node.name, outputs_number)
        for port_id in if_node.out_ports().keys():
            outputs_mapping[port_id] = {}
        port_ids = outputs_mapping.keys()
        for then_output_node in then_outputs:
            assert then_output_node.soft_get('type') == 'Result'
            port_id = then_output_node['output_id']
            assert port_id in port_ids, 'Incorrect mapping then_graph outputs with {0} outputs! ' \
                                        'Can\'t find port with ID {1} in If operation.' \
                .format(then_output_node.name, port_id)
            outputs_mapping[port_id]['then_graph'] = then_output_node

        for else_output_node in else_outputs:
            assert else_output_node.soft_get('type') == 'Result'
            port_id = else_output_node['output_id']
            assert port_id in port_ids, 'Incorrect mapping else_graph outputs with {0} outputs! ' \
                                        'Can\'t find port with ID {1} in If operation.' \
                .format(else_output_node.name, port_id)
            outputs_mapping[port_id]['else_graph'] = else_output_node

        for port_id in outputs_mapping:
            then_else_nodes = outputs_mapping[port_id]
            assert 'then_graph' in then_else_nodes.keys(), 'then_graph does not connect with If.out_port[{0}] ' \
                                                           'in {1} node!'.format(port_id, if_node.name)
            assert 'else_graph' in then_else_nodes.keys(), 'else_graph does not connect with If.out_port[{0}] ' \
                                                           'in {1} node!'.format(port_id, if_node.name)
            then_type = then_else_nodes['then_graph'].in_port(0).get_data_type()
            else_type = then_else_nodes['else_graph'].in_port(0).get_data_type()
            assert then_type == else_type, 'Cannot get type for if.out_port[{0}]! ' \
                                           'Types in then_graph and else_graph are not equal!'.format(port_id)
            if_node.out_port(port_id).set_data_type(then_type)

    @staticmethod
    def re_numerate_internal_id_and_get_if_id(if_node):
        """
        This method is called before IR generation. It sets internal_layer_id for the nodes of both bodies.

        :param if_node: The If node whose bodies need internal_layer_id to be set.
        :return: the ID of the If node
        """
        then_graph_nodes = if_node.then_graph.nodes()
        for idx in range(len(if_node.then_graph.get_op_nodes())):
            then_graph_nodes[idx]['internal_layer_id'] = idx
        else_graph_nodes = if_node.else_graph.nodes()
        for idx in range(len(if_node.else_graph.get_op_nodes())):
            else_graph_nodes[idx]['internal_layer_id'] = idx
        return if_node.node

    def substitute_ie_attrs(self, new_attrs: dict):
        """
        Replace the standard list of attributes in layer/data with the attributes
        delivered by backend_attrs
        """

        port_map_attrs = self.port_map_attrs()
        new_attrs.update({
            'IE': [(
                'layer',
                [('id', lambda node: self.re_numerate_internal_id_and_get_if_id(node)), 'name', 'type', 'version'],
                [
                    '@ports',
                    ('then_port_map', [], [
                        ('@list', lambda node: self.generate_port_map(node, True, 'in'),
                         ('input', port_map_attrs, [])),
                        ('@list', lambda node: self.generate_port_map(node, True, 'out'),
                         ('output', port_map_attrs, [])),
                    ]),
                    ('else_port_map', [], [
                        ('@list', lambda node: self.generate_port_map(node, False, 'in'),
                         ('input', port_map_attrs, [])),
                        ('@list', lambda node: self.generate_port_map(node, False, 'out'),
                         ('output', port_map_attrs, [])),
                    ]),
                    ('then_body', [], [('@network', 'then_graph')]),
                    ('else_body', [], [('@network', 'else_graph')]),
                ])]
        })

    @staticmethod
    def generate_port_map(if_node: Node, condition: bool, dir: str):
        """
        Extract port_map attributes from if_node and its subgraphs attributes.

        :param if_node: The If node
        :param condition: the boolean defining a condition (then/else) graph
        :param dir: the str value defining the type of port_map ('in' for inputs, anything else for outputs)
        :return: port_map -> list of dictionaries with two values (external_port_id and internal_layer_id)
        """
        port_map = []
        subgraph = if_node.then_graph if condition else if_node.else_graph
        name_of_connection = 'input_id' if dir == 'in' else 'output_id'

        for internal_node in subgraph.get_op_nodes():
            if internal_node.has(name_of_connection):
                port_map.append({'external_port_id': internal_node[name_of_connection],
                                 'internal_layer_id': internal_node['internal_layer_id']})

        return port_map

    @staticmethod
    def infer(if_node: Node):
        If.update_body_parameters_shape(if_node, True)
        If.update_body_parameters_shape(if_node, False)
        partial_infer(if_node.then_graph)
        partial_infer(if_node.else_graph)
        If.update_if_output_ports_shape(if_node)

    @staticmethod
    def type_infer(if_node: Node):
        from mo.middle.passes.infer import type_infer
        If.update_body_parameters_type(if_node, True)
        If.update_body_parameters_type(if_node, False)
        type_infer(if_node.then_graph)
        type_infer(if_node.else_graph)
        If.update_if_output_ports_type(if_node)
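
The output-shape decision made in `update_if_output_ports_shape` can be sketched in isolation. The snippet below is an illustration only and is not part of the Model Optimizer: it uses plain Python lists instead of MO shape arrays, uses `None` as a stand-in for `dynamic_dimension_value`, and checks for fake (empty) outputs per port, whereas the real code checks them per branch across all outputs. The names `DYNAMIC`, `is_fake`, `merge_shapes` and `resolve_output_shape` are hypothetical.

```python
from typing import List, Optional, Sequence

# Stand-in for MO's dynamic_dimension_value (an assumption made for this sketch only).
DYNAMIC = None

Shape = Sequence[Optional[int]]


def is_fake(shape: Shape) -> bool:
    """Treat an output as 'fake' (carrying no data) when any dimension is 0."""
    return any(d == 0 for d in shape)


def merge_shapes(then_shape: Shape, else_shape: Shape) -> List[Optional[int]]:
    """Keep a dimension only if both branches agree on a static value."""
    assert len(then_shape) == len(else_shape), 'Ranks of "then" and "else" outputs differ'
    return [d1 if d1 is not DYNAMIC and d1 == d2 else DYNAMIC
            for d1, d2 in zip(then_shape, else_shape)]


def resolve_output_shape(cond: Optional[bool], then_shape: Shape, else_shape: Shape) -> List[Optional[int]]:
    """Pick the shape of one If output port."""
    if cond is not None:
        # Static condition: only the selected branch matters.
        return list(then_shape if cond else else_shape)
    then_fake, else_fake = is_fake(then_shape), is_fake(else_shape)
    if then_fake ^ else_fake:
        # Exactly one branch returns no data: use the other one.
        return list(else_shape if then_fake else then_shape)
    # Otherwise take the dimension-wise "intersection" of both shapes.
    return merge_shapes(then_shape, else_shape)


if __name__ == '__main__':
    print(resolve_output_shape(None, [1, 3, 224], [1, 3, 112]))
```

With an unknown condition, dimensions that differ between the branches become dynamic, while a static condition simply forwards the shape of the selected branch.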