Tf 2.0 od api models with loop (#4772)
* Do not run transformations for TF 2.X OD API models recursively (needed for models with Loop operation) * Added anchor front transformation to group all TF OD API transformations. Added new necessary dependencies from KerasRNN transformations related to While support * Added JSON configuration files for TF 2.4 OD API SSD and EfficientDet models * Updated documentation with table of supported TF 2.x OD API models * Improved visualization of the dependency graph * Updated version of the pre-processing transformation for TF 2.4 OD API models * Fixes in the TF 2.x OD API models conversion * Fixed order of applying mean/scale values for TF 2.X OD API pre-processing * Updates to the documentation * Fixes for the preprocessor block transformation for the TF OD API models * Added code comments * Fixed bom file * Unit tests for the TF 2.4 OD API ObjectDetectionAPIPreprocessor2Replacement transformation * Code cleanup * Updates to the documentation on how to convert TF OD API models and graph dumper * Added assert to make sure that operations in the `get_specific_ops_with_const_inputs` has exactly 2 inputs
This commit is contained in:
parent
26a4022672
commit
522ad39a48
@ -35,11 +35,11 @@ Detailed information on how to convert models from the <a href="https://github.c
|
||||
|VGG-16| [vgg_16_2016_08_28.tar.gz](http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz)| [103.94,116.78,123.68] | 1 |
|
||||
|VGG-19| [vgg_19_2016_08_28.tar.gz](http://download.tensorflow.org/models/vgg_19_2016_08_28.tar.gz)| [103.94,116.78,123.68] | 1 |
|
||||
|
||||
**Supported Frozen Topologies from TensorFlow Object Detection Models Zoo**
|
||||
**Supported Pre-Trained Topologies from TensorFlow 1 Object Detection Models Zoo**
|
||||
|
||||
Detailed information on how to convert models from the <a href="https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md">Object Detection Models Zoo</a> is available in the [Converting TensorFlow Object Detection API Models](tf_specific/Convert_Object_Detection_API_Models.md) chapter. The table below contains models from the Object Detection Models zoo that are supported.
|
||||
Detailed information on how to convert models from the <a href="https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf1_detection_zoo.md">TensorFlow 1 Detection Model Zoo</a> is available in the [Converting TensorFlow Object Detection API Models](tf_specific/Convert_Object_Detection_API_Models.md) chapter. The table below contains models from the Object Detection Models zoo that are supported.
|
||||
|
||||
| Model Name| TensorFlow Object Detection API Models (Frozen)|
|
||||
| Model Name| TensorFlow 1 Object Detection API Models|
|
||||
| :------------- | -----:|
|
||||
|SSD MobileNet V1 COCO\*| [ssd_mobilenet_v1_coco_2018_01_28.tar.gz](http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28.tar.gz)|
|
||||
|SSD MobileNet V1 0.75 Depth COCO| [ssd_mobilenet_v1_0.75_depth_300x300_coco14_sync_2018_07_03.tar.gz](http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_0.75_depth_300x300_coco14_sync_2018_07_03.tar.gz)|
|
||||
@ -68,6 +68,43 @@ Detailed information on how to convert models from the <a href="https://github.c
|
||||
|Faster R-CNN Inception ResNet V2 Low Proposals Open Images\*| [faster_rcnn_inception_resnet_v2_atrous_lowproposals_oid_2018_01_28.tar.gz](http://download.tensorflow.org/models/object_detection/faster_rcnn_inception_resnet_v2_atrous_lowproposals_oid_2018_01_28.tar.gz)|
|
||||
|Faster R-CNN ResNet 101 AVA v2.1\*| [faster_rcnn_resnet101_ava_v2.1_2018_04_30.tar.gz](http://download.tensorflow.org/models/object_detection/faster_rcnn_resnet101_ava_v2.1_2018_04_30.tar.gz)|
|
||||
|
||||
**Supported Pre-Trained Topologies from TensorFlow 2 Object Detection Models Zoo**
|
||||
|
||||
Detailed information on how to convert models from the <a href="https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_detection_zoo.md">TensorFlow 2 Detection Model Zoo</a> is available in the [Converting TensorFlow Object Detection API Models](tf_specific/Convert_Object_Detection_API_Models.md) chapter. The table below contains models from the Object Detection Models zoo that are supported.
|
||||
|
||||
| Model Name| TensorFlow 2 Object Detection API Models|
|
||||
| :------------- | -----:|
|
||||
| EfficientDet D0 512x512 | [efficientdet_d0_coco17_tpu-32.tar.gz](http://download.tensorflow.org/models/object_detection/tf2/20200711/efficientdet_d0_coco17_tpu-32.tar.gz)|
|
||||
| EfficientDet D1 640x640 | [efficientdet_d1_coco17_tpu-32.tar.gz](http://download.tensorflow.org/models/object_detection/tf2/20200711/efficientdet_d1_coco17_tpu-32.tar.gz)|
|
||||
| EfficientDet D2 768x768 | [efficientdet_d2_coco17_tpu-32.tar.gz](http://download.tensorflow.org/models/object_detection/tf2/20200711/efficientdet_d2_coco17_tpu-32.tar.gz)|
|
||||
| EfficientDet D3 896x896 | [efficientdet_d3_coco17_tpu-32.tar.gz](http://download.tensorflow.org/models/object_detection/tf2/20200711/efficientdet_d3_coco17_tpu-32.tar.gz)|
|
||||
| EfficientDet D4 1024x1024 | [efficientdet_d4_coco17_tpu-32.tar.gz](http://download.tensorflow.org/models/object_detection/tf2/20200711/efficientdet_d4_coco17_tpu-32.tar.gz)|
|
||||
| EfficientDet D5 1280x1280 | [efficientdet_d5_coco17_tpu-32.tar.gz](http://download.tensorflow.org/models/object_detection/tf2/20200711/efficientdet_d5_coco17_tpu-32.tar.gz)|
|
||||
| EfficientDet D6 1280x1280 | [efficientdet_d6_coco17_tpu-32.tar.gz](http://download.tensorflow.org/models/object_detection/tf2/20200711/efficientdet_d6_coco17_tpu-32.tar.gz)|
|
||||
| EfficientDet D7 1536x1536 | [efficientdet_d7_coco17_tpu-32.tar.gz](http://download.tensorflow.org/models/object_detection/tf2/20200711/efficientdet_d7_coco17_tpu-32.tar.gz)|
|
||||
| SSD MobileNet v2 320x320 | [ssd_mobilenet_v2_320x320_coco17_tpu-8.tar.gz](http://download.tensorflow.org/models/object_detection/tf2/20200711/ssd_mobilenet_v2_320x320_coco17_tpu-8.tar.gz)|
|
||||
| SSD MobileNet V1 FPN 640x640 | [ssd_mobilenet_v1_fpn_640x640_coco17_tpu-8.tar.gz](http://download.tensorflow.org/models/object_detection/tf2/20200711/ssd_mobilenet_v1_fpn_640x640_coco17_tpu-8.tar.gz)|
|
||||
| SSD MobileNet V2 FPNLite 320x320 | [ssd_mobilenet_v2_fpnlite_320x320_coco17_tpu-8.tar.gz](http://download.tensorflow.org/models/object_detection/tf2/20200711/ssd_mobilenet_v2_fpnlite_320x320_coco17_tpu-8.tar.gz)|
|
||||
| SSD MobileNet V2 FPNLite 640x640 | [ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8.tar.gz](http://download.tensorflow.org/models/object_detection/tf2/20200711/ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8.tar.gz)|
|
||||
| SSD ResNet50 V1 FPN 640x640 (RetinaNet50) | [ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.tar.gz](http://download.tensorflow.org/models/object_detection/tf2/20200711/ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.tar.gz)|
|
||||
| SSD ResNet50 V1 FPN 1024x1024 (RetinaNet50) | [ssd_resnet50_v1_fpn_1024x1024_coco17_tpu-8.tar.gz](http://download.tensorflow.org/models/object_detection/tf2/20200711/ssd_resnet50_v1_fpn_1024x1024_coco17_tpu-8.tar.gz)|
|
||||
| SSD ResNet101 V1 FPN 640x640 (RetinaNet101) | [ssd_resnet101_v1_fpn_640x640_coco17_tpu-8.tar.gz](http://download.tensorflow.org/models/object_detection/tf2/20200711/ssd_resnet101_v1_fpn_640x640_coco17_tpu-8.tar.gz)|
|
||||
| SSD ResNet101 V1 FPN 1024x1024 (RetinaNet101) | [ssd_resnet101_v1_fpn_1024x1024_coco17_tpu-8.tar.gz](http://download.tensorflow.org/models/object_detection/tf2/20200711/ssd_resnet101_v1_fpn_1024x1024_coco17_tpu-8.tar.gz)|
|
||||
| SSD ResNet152 V1 FPN 640x640 (RetinaNet152) | [ssd_resnet152_v1_fpn_640x640_coco17_tpu-8.tar.gz](http://download.tensorflow.org/models/object_detection/tf2/20200711/ssd_resnet152_v1_fpn_640x640_coco17_tpu-8.tar.gz)|
|
||||
| SSD ResNet152 V1 FPN 1024x1024 (RetinaNet152) | [ssd_resnet152_v1_fpn_1024x1024_coco17_tpu-8.tar.gz](http://download.tensorflow.org/models/object_detection/tf2/20200711/ssd_resnet152_v1_fpn_1024x1024_coco17_tpu-8.tar.gz)|
|
||||
| Faster R-CNN ResNet50 V1 640x640 | [faster_rcnn_resnet50_v1_640x640_coco17_tpu-8.tar.gz](http://download.tensorflow.org/models/object_detection/tf2/20200711/faster_rcnn_resnet50_v1_640x640_coco17_tpu-8.tar.gz)|
|
||||
| Faster R-CNN ResNet50 V1 1024x1024 | [faster_rcnn_resnet50_v1_1024x1024_coco17_tpu-8.tar.gz](http://download.tensorflow.org/models/object_detection/tf2/20200711/faster_rcnn_resnet50_v1_1024x1024_coco17_tpu-8.tar.gz)|
|
||||
| Faster R-CNN ResNet50 V1 800x1333 | [faster_rcnn_resnet50_v1_800x1333_coco17_gpu-8.tar.gz](http://download.tensorflow.org/models/object_detection/tf2/20200711/faster_rcnn_resnet50_v1_800x1333_coco17_gpu-8.tar.gz)|
|
||||
| Faster R-CNN ResNet101 V1 640x640 | [faster_rcnn_resnet101_v1_640x640_coco17_tpu-8.tar.gz](http://download.tensorflow.org/models/object_detection/tf2/20200711/faster_rcnn_resnet101_v1_640x640_coco17_tpu-8.tar.gz)|
|
||||
| Faster R-CNN ResNet101 V1 1024x1024 | [faster_rcnn_resnet101_v1_1024x1024_coco17_tpu-8.tar.gz](http://download.tensorflow.org/models/object_detection/tf2/20200711/faster_rcnn_resnet101_v1_1024x1024_coco17_tpu-8.tar.gz)|
|
||||
| Faster R-CNN ResNet101 V1 800x1333 | [faster_rcnn_resnet101_v1_800x1333_coco17_gpu-8.tar.gz](http://download.tensorflow.org/models/object_detection/tf2/20200711/faster_rcnn_resnet101_v1_800x1333_coco17_gpu-8.tar.gz)|
|
||||
| Faster R-CNN ResNet152 V1 640x640 | [faster_rcnn_resnet152_v1_640x640_coco17_tpu-8.tar.gz](http://download.tensorflow.org/models/object_detection/tf2/20200711/faster_rcnn_resnet152_v1_640x640_coco17_tpu-8.tar.gz)|
|
||||
| Faster R-CNN ResNet152 V1 1024x1024 | [faster_rcnn_resnet152_v1_1024x1024_coco17_tpu-8.tar.gz](http://download.tensorflow.org/models/object_detection/tf2/20200711/faster_rcnn_resnet152_v1_1024x1024_coco17_tpu-8.tar.gz)|
|
||||
| Faster R-CNN ResNet152 V1 800x1333 | [faster_rcnn_resnet152_v1_800x1333_coco17_gpu-8.tar.gz](http://download.tensorflow.org/models/object_detection/tf2/20200711/faster_rcnn_resnet152_v1_800x1333_coco17_gpu-8.tar.gz)|
|
||||
| Faster R-CNN Inception ResNet V2 640x640 | [faster_rcnn_inception_resnet_v2_640x640_coco17_tpu-8.tar.gz](http://download.tensorflow.org/models/object_detection/tf2/20200711/faster_rcnn_inception_resnet_v2_640x640_coco17_tpu-8.tar.gz)|
|
||||
| Faster R-CNN Inception ResNet V2 1024x1024 | [faster_rcnn_inception_resnet_v2_1024x1024_coco17_tpu-8.tar.gz](http://download.tensorflow.org/models/object_detection/tf2/20200711/faster_rcnn_inception_resnet_v2_1024x1024_coco17_tpu-8.tar.gz)|
|
||||
| Mask R-CNN Inception ResNet V2 1024x1024 | [mask_rcnn_inception_resnet_v2_1024x1024_coco17_gpu-8.tar.gz](http://download.tensorflow.org/models/object_detection/tf2/20200711/mask_rcnn_inception_resnet_v2_1024x1024_coco17_gpu-8.tar.gz)|
|
||||
|
||||
**Supported Frozen Quantized Topologies**
|
||||
|
||||
The topologies hosted on the TensorFlow\* Lite [site](https://www.tensorflow.org/lite/guide/hosted_models). The frozen model file (`.pb` file) should be fed to the Model Optimizer.
|
||||
|
@ -18,27 +18,30 @@ To convert a TensorFlow\* Object Detection API model, go to the `<INSTALL_DIR>/d
|
||||
* `--input_model <path_to_frozen.pb>` --- File with a pre-trained model (binary or text .pb file after freezing) OR `--saved_model_dir <path_to_saved_model>` for the TensorFlow\* 2 models
|
||||
* `--transformations_config <path_to_subgraph_replacement_configuration_file.json>` --- A subgraph replacement configuration file with transformations description. For the models downloaded from the TensorFlow\* Object Detection API zoo, you can find the configuration files in the `<INSTALL_DIR>/deployment_tools/model_optimizer/extensions/front/tf` directory. Use:
|
||||
* `ssd_v2_support.json` --- for frozen SSD topologies from the models zoo version up to 1.13.X inclusively
|
||||
* `ssd_support_api_v.1.14.json` --- for frozen SSD topologies trained using the TensorFlow\* Object Detection API version 1.14 up to 1.14.X inclusively
|
||||
* `ssd_support_api_v.1.15.json` --- for frozen SSD topologies trained using the TensorFlow\* Object Detection API version 1.15 up to 2.0
|
||||
* `ssd_support_api_v.2.0.json` --- for frozen SSD topologies trained using the TensorFlow\* Object Detection API version 2.0 or higher
|
||||
* `faster_rcnn_support.json` --- for frozen Faster R-CNN topologies from the models zoo
|
||||
* `ssd_support_api_v.1.14.json` --- for SSD topologies trained using the TensorFlow\* Object Detection API version 1.14 up to 1.14.X inclusively
|
||||
* `ssd_support_api_v.1.15.json` --- for SSD topologies trained using the TensorFlow\* Object Detection API version 1.15 up to 2.0
|
||||
* `ssd_support_api_v.2.0.json` --- for SSD topologies trained using the TensorFlow\* Object Detection API version 2.0 up to 2.3.X inclusively
|
||||
* `ssd_support_api_v.2.4.json` --- for SSD topologies trained using the TensorFlow\* Object Detection API version 2.4 or higher
|
||||
* `efficient_det_support_api_v.2.0.json` --- for EfficientDet topologies trained using the TensorFlow\* Object Detection API version 2.0 up to 2.3.X inclusively
|
||||
* `efficient_det_support_api_v.2.4.json` --- for EfficientDet topologies trained using the TensorFlow\* Object Detection API version 2.4 or higher
|
||||
* `faster_rcnn_support.json` --- for Faster R-CNN topologies from the TF 1.X models zoo trained with TensorFlow\* version up to 1.6.X inclusively
|
||||
* `faster_rcnn_support_api_v1.7.json` --- for Faster R-CNN topologies trained using the TensorFlow\* Object Detection API version 1.7.0 up to 1.9.X inclusively
|
||||
* `faster_rcnn_support_api_v1.10.json` --- for Faster R-CNN topologies trained using the TensorFlow\* Object Detection API version 1.10.0 up to 1.12.X inclusively
|
||||
* `faster_rcnn_support_api_v1.13.json` --- for Faster R-CNN topologies trained using the TensorFlow\* Object Detection API version 1.13.X
|
||||
* `faster_rcnn_support_api_v1.14.json` --- for Faster R-CNN topologies trained using the TensorFlow\* Object Detection API version 1.14.0 up to 1.14.X inclusively
|
||||
* `faster_rcnn_support_api_v1.15.json` --- for Faster R-CNN topologies trained using the TensorFlow\* Object Detection API version 1.15.0 up to 2.0
|
||||
* `faster_rcnn_support_api_v2.0.json` --- for Faster R-CNN topologies trained using the TensorFlow\* Object Detection API version 2.0 or higher
|
||||
* `mask_rcnn_support.json` --- for frozen Mask R-CNN topologies from the models zoo
|
||||
* `mask_rcnn_support.json` --- for Mask R-CNN topologies from the TF 1.X models zoo trained with TensorFlow\* version 1.9.0 or lower.
|
||||
* `mask_rcnn_support_api_v1.7.json` --- for Mask R-CNN topologies trained using the TensorFlow\* Object Detection API version 1.7.0 up to 1.9.X inclusively
|
||||
* `mask_rcnn_support_api_v1.11.json` --- for Mask R-CNN topologies trained using the TensorFlow\* Object Detection API version 1.11.0 up to 1.12.X inclusively
|
||||
* `mask_rcnn_support_api_v1.13.json` --- for Mask R-CNN topologies trained using the TensorFlow\* Object Detection API version 1.13.0 up to 1.13.X inclusively
|
||||
* `mask_rcnn_support_api_v1.14.json` --- for Mask R-CNN topologies trained using the TensorFlow\* Object Detection API version 1.14.0 up to 1.14.X inclusively
|
||||
* `mask_rcnn_support_api_v1.15.json` --- for Mask R-CNN topologies trained using the TensorFlow\* Object Detection API version 1.15.0 up to 2.0
|
||||
* `mask_rcnn_support_api_v2.0.json` --- for Mask R-CNN topologies trained using the TensorFlow\* Object Detection API version 2.0 or higher
|
||||
* `rfcn_support.json` --- for the frozen RFCN topology from the models zoo frozen with TensorFlow\* version 1.9.0 or lower.
|
||||
* `rfcn_support_api_v1.10.json` --- for the frozen RFCN topology from the models zoo frozen with TensorFlow\* version 1.10.0 up to 1.12.X inclusively
|
||||
* `rfcn_support_api_v1.13.json` --- for the frozen RFCN topology from the models zoo frozen with TensorFlow\* version 1.13.X.
|
||||
* `rfcn_support_api_v1.14.json` --- for the frozen RFCN topology from the models zoo frozen with TensorFlow\* version 1.14.0 or higher.
|
||||
* `rfcn_support.json` --- for RFCN topology from the models zoo trained with TensorFlow\* version up to 1.9.X inclusively
|
||||
* `rfcn_support_api_v1.10.json` --- for RFCN topology from the models zoo frozen with TensorFlow\* version 1.10.0 up to 1.12.X inclusively
|
||||
* `rfcn_support_api_v1.13.json` --- for RFCN topology from the models zoo frozen with TensorFlow\* version 1.13.X
|
||||
* `rfcn_support_api_v1.14.json` --- for RFCN topology from the models zoo frozen with TensorFlow\* version 1.14.0 or higher
|
||||
* `--tensorflow_object_detection_api_pipeline_config <path_to_pipeline.config>` --- A special configuration file that describes the topology hyper-parameters and structure of the TensorFlow Object Detection API model. For the models downloaded from the TensorFlow\* Object Detection API zoo, the configuration file is named `pipeline.config`. If you plan to train a model yourself, you can find templates for these files in the [models repository](https://github.com/tensorflow/models/tree/master/research/object_detection/samples/configs).
|
||||
* `--input_shape` (optional) --- A custom input image shape. Refer to [Custom Input Shape](#tf_od_custom_input_shape) for more information how the `--input_shape` parameter is handled for the TensorFlow* Object Detection API models.
|
||||
|
||||
@ -96,7 +99,7 @@ def calculate_shape_keeping_aspect_ratio(H: int, W: int, min_dimension: int, max
|
||||
|
||||
Models with `keep_aspect_ratio_resizer` were trained to recognize object in real aspect ratio, in contrast with most of the classification topologies trained to recognize objects stretched vertically and horizontally as well. By default, the Model Optimizer converts topologies with `keep_aspect_ratio_resizer` to consume a square input image. If the non-square image is provided as input, it is stretched without keeping aspect ratio that results to objects detection quality decrease.
|
||||
|
||||
> **NOTE**: It is highly recommended to specify the `--input_shape` command line parameter for the models with `keep_aspect_ratio_resizer` if the input image dimensions are known in advance.
|
||||
> **NOTE**: It is highly recommended specifying the `--input_shape` command line parameter for the models with `keep_aspect_ratio_resizer` if the input image dimensions are known in advance.
|
||||
|
||||
## Important Notes About Feeding Input Images to the Samples
|
||||
|
||||
|
@ -390,6 +390,7 @@ extensions/front/tf/cumsum_ext.py
|
||||
extensions/front/tf/deconv_ext.py
|
||||
extensions/front/tf/depth_to_space.py
|
||||
extensions/front/tf/efficient_det_support_api_v2.0.json
|
||||
extensions/front/tf/efficient_det_support_api_v2.4.json
|
||||
extensions/front/tf/elementwise_ext.py
|
||||
extensions/front/tf/embedding_segments_sum.py
|
||||
extensions/front/tf/expand_dims_ext.py
|
||||
@ -404,6 +405,7 @@ extensions/front/tf/faster_rcnn_support_api_v1.14.json
|
||||
extensions/front/tf/faster_rcnn_support_api_v1.15.json
|
||||
extensions/front/tf/faster_rcnn_support_api_v1.7.json
|
||||
extensions/front/tf/faster_rcnn_support_api_v2.0.json
|
||||
extensions/front/tf/faster_rcnn_support_api_v2.4.json
|
||||
extensions/front/tf/fifo_queue_v2_ext.py
|
||||
extensions/front/tf/fifo_replacer.py
|
||||
extensions/front/tf/fill_ext.py
|
||||
@ -430,6 +432,7 @@ extensions/front/tf/mask_rcnn_support_api_v1.14.json
|
||||
extensions/front/tf/mask_rcnn_support_api_v1.15.json
|
||||
extensions/front/tf/mask_rcnn_support_api_v1.7.json
|
||||
extensions/front/tf/mask_rcnn_support_api_v2.0.json
|
||||
extensions/front/tf/mask_rcnn_support_api_v2.4.json
|
||||
extensions/front/tf/matmul_ext.py
|
||||
extensions/front/tf/mvn.py
|
||||
extensions/front/tf/mvn_unrolled.py
|
||||
@ -478,6 +481,7 @@ extensions/front/tf/ssd_support.json
|
||||
extensions/front/tf/ssd_support_api_v1.14.json
|
||||
extensions/front/tf/ssd_support_api_v1.15.json
|
||||
extensions/front/tf/ssd_support_api_v2.0.json
|
||||
extensions/front/tf/ssd_support_api_v2.4.json
|
||||
extensions/front/tf/ssd_toolbox_detection_output.json
|
||||
extensions/front/tf/ssd_toolbox_multihead_detection_output.json
|
||||
extensions/front/tf/ssd_v2_support.json
|
||||
|
@ -25,6 +25,7 @@ from extensions.front.split_normalizer import SqueezeAxis
|
||||
from extensions.front.standalone_const_eraser import StandaloneConstEraser
|
||||
from extensions.front.tf.CropAndResizeReplacement import CropAndResizeReplacement
|
||||
from extensions.front.tf.FakeQuantWithMinMaxVars import FakeQuantWithMinMaxVarsToQuantize
|
||||
from extensions.front.tf.KerasRNNTransformation import KerasRNNInputSlicing, KerasRNNOutputConcatenation
|
||||
from extensions.front.tf.TFSliceToSlice import TFSliceToSliceReplacer
|
||||
from extensions.front.tf.pad_tf_to_pad import PadTFToPad
|
||||
from extensions.middle.InsertLayoutPropagationTransposes import mark_as_correct_data_layout, \
|
||||
@ -32,7 +33,7 @@ from extensions.middle.InsertLayoutPropagationTransposes import mark_as_correct_
|
||||
from extensions.ops.DetectionOutput import DetectionOutput
|
||||
from extensions.ops.ReduceOps import ReduceMean
|
||||
from extensions.ops.activation_ops import Sigmoid
|
||||
from extensions.ops.elementwise import Mul
|
||||
from extensions.ops.elementwise import Mul, Sub, Add, Div
|
||||
from extensions.ops.gather import Gather
|
||||
from extensions.ops.parameter import Parameter
|
||||
from extensions.ops.priorbox_clustered import PriorBoxClusteredOp
|
||||
@ -41,10 +42,12 @@ from extensions.ops.psroipooling import PSROIPoolingOp
|
||||
from extensions.ops.transpose import Transpose
|
||||
from mo.front.common.layout import get_batch_dim, get_height_dim, get_width_dim
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.common.replacement import FrontReplacementPattern
|
||||
from mo.front.extractor import output_user_data_repack, add_output_ops
|
||||
from mo.front.subgraph_matcher import SubgraphMatch
|
||||
from mo.front.tf.graph_utils import add_activation_function_after_node, add_convolution_to_swap_xy_coordinates, \
|
||||
mark_squeeze_reshape_concat_before_detection_output, add_fake_background_loc, create_op_node_with_second_input
|
||||
mark_squeeze_reshape_concat_before_detection_output, add_fake_background_loc, create_op_node_with_second_input, \
|
||||
create_op_with_const_inputs
|
||||
from mo.front.tf.replacement import FrontReplacementFromConfigFileSubGraph, FrontReplacementFromConfigFileGeneral
|
||||
from mo.graph.graph import Graph, Node
|
||||
from mo.ops.concat import Concat
|
||||
@ -57,7 +60,7 @@ from mo.ops.roipooling import ROIPooling
|
||||
from mo.ops.shape import Shape
|
||||
from mo.ops.softmax import Softmax
|
||||
from mo.utils.error import Error
|
||||
from mo.utils.graph import backward_bfs_for_operation, bfs_search, clear_tensor_names_info
|
||||
from mo.utils.graph import backward_bfs_for_operation, bfs_search, clear_tensor_names_info, sub_graph_between_nodes
|
||||
from mo.utils.pipeline_config import PipelineConfig
|
||||
|
||||
missing_param_error = 'To convert the model specify path to the pipeline configuration file which was used to ' \
|
||||
@ -512,19 +515,107 @@ def update_parameter_shape(graph: Graph, match: [SubgraphMatch, None]):
|
||||
return initial_input_node_name, parameter_node
|
||||
|
||||
|
||||
class ObjectDetectionAPIPreprocessorReplacement(FrontReplacementFromConfigFileSubGraph):
|
||||
class ObjectDetectionAPITransformationsStart(FrontReplacementPattern):
|
||||
"""
|
||||
The class replaces the "Preprocessor" block resizing input image and applying mean/scale values. Only nodes related
|
||||
to applying mean/scaling values are kept.
|
||||
This is a anchor transformation which is used to distinguish TF OD API models related transformations.
|
||||
"""
|
||||
replacement_id = 'ObjectDetectionAPIPreprocessorReplacement'
|
||||
enabled = True
|
||||
|
||||
def run_after(self):
|
||||
return [CropAndResizeReplacement, FakeQuantWithMinMaxVarsToQuantize]
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
pass
|
||||
|
||||
|
||||
class ObjectDetectionAPITransformationsFinish(FrontReplacementPattern):
|
||||
"""
|
||||
This is a anchor transformation which is used to distinguish TF OD API models related transformations.
|
||||
"""
|
||||
enabled = True
|
||||
# cleanup the graph after applying of TF OD API transformations to remove a lot of unconnected nodes to avoid issues
|
||||
# with shape inference
|
||||
force_clean_up = True
|
||||
|
||||
def run_before(self):
|
||||
# PadTFToPad inserts Transpose ops for Pad ops inside the sub-graph corresponding to DetectionOutput.
|
||||
# But the inputs corresponding to padding values is re-used as inputs for newly created Pad node. This input
|
||||
# is removed during removing nodes from the DO sub-graph so the first input to Transpose is missing which
|
||||
# results in TransposeOrderNormalizer transformation failure.
|
||||
return [Pack, TransposeOrderNormalizer, PadTFToPad]
|
||||
return [Pack, TransposeOrderNormalizer, PadTFToPad, SqueezeAxis, StandaloneConstEraser, TFSliceToSliceReplacer,
|
||||
KerasRNNOutputConcatenation, KerasRNNInputSlicing]
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
pass
|
||||
|
||||
|
||||
def get_specific_ops_with_const_inputs(first_node: Node, allowed_ops: list, forward: bool = True):
|
||||
"""
|
||||
Returns the list with information about consecutive nodes of operation from "allowed_ops".
|
||||
|
||||
:param first_node: The first node (not included) to start looking for nodes from the "allowed_ops" list
|
||||
:param allowed_ops: list of allowed operations
|
||||
:param forward: flag specifying direction of search
|
||||
:return: list of triplets (Node, const_port_index, const_value)
|
||||
"""
|
||||
node = first_node.out_port(0).get_destination().node if forward else first_node.in_port(0).get_source().node
|
||||
result = [] # (Node, port # with constant input, value)
|
||||
while node.soft_get('op') in allowed_ops:
|
||||
num_in_ports = len(node.in_ports())
|
||||
assert num_in_ports == 2, 'The node "{}" should have exactly 2 inputs, but it has only {}.' \
|
||||
''.format(node.soft_get('name', node.id), num_in_ports)
|
||||
for port in (0, 1):
|
||||
if node.in_port(port).get_source().node.has_valid('value'): # this is a constant input to the node
|
||||
result.append((node, port, node.in_port(port).get_source().node.value.copy()))
|
||||
node = node.out_port(0).get_destination().node if forward else node.in_port(1 - port).get_source().node
|
||||
break
|
||||
return result
|
||||
|
||||
|
||||
def get_preprocessing_ops(graph: Graph, start_node_id_suffix: str, end_node_id_suffix: str):
|
||||
"""
|
||||
Finds a sequence of pre-processing nodes (Sub, Mul, Div and Add) after the node with the id suffix
|
||||
'end_node_id_suffix' or ending with the node with id suffix 'end_node_id_suffix'.
|
||||
|
||||
:param graph: graph to look for pre-processing ops
|
||||
:param start_node_id_suffix: suffix of the start node name
|
||||
:param end_node_id_suffix: suffix of the end node name
|
||||
:return: the list with pre-processing nodes information and flag specifying nodes position
|
||||
"""
|
||||
start_node = None
|
||||
end_node = None
|
||||
for node in graph.get_op_nodes():
|
||||
if node.id.endswith(start_node_id_suffix):
|
||||
start_node = node
|
||||
if node.id.endswith(end_node_id_suffix):
|
||||
end_node = node
|
||||
|
||||
assert start_node is not None and end_node is not None, \
|
||||
'Failed to find start/end nodes of the pre-processing block. The section of the transformation JSON ' \
|
||||
'configuration file related to "ObjectDetectionAPIPreprocessor2Replacement" transformation should be updated ' \
|
||||
'for this particular model.'
|
||||
allowed_ops = ['Sub', 'Mul', 'Div', 'Add']
|
||||
preprocessing_nodes = get_specific_ops_with_const_inputs(start_node, allowed_ops, False)
|
||||
trailing = False # switch to apply newly created pre-processing nodes after/before start_node/end_node
|
||||
if len(preprocessing_nodes) == 0:
|
||||
preprocessing_nodes = get_specific_ops_with_const_inputs(end_node, allowed_ops, True)
|
||||
trailing = True
|
||||
return preprocessing_nodes, trailing
|
||||
|
||||
|
||||
class ObjectDetectionAPIPreprocessorReplacement(FrontReplacementFromConfigFileSubGraph):
|
||||
"""
|
||||
The class replaces the "Preprocessor" block resizing input image and applying mean/scale values. Only nodes related
|
||||
to applying mean/scaling values are kept.
|
||||
"""
|
||||
replacement_id = 'ObjectDetectionAPIPreprocessorReplacement'
|
||||
run_not_recursively = True
|
||||
|
||||
def run_before(self):
|
||||
return [ObjectDetectionAPITransformationsFinish]
|
||||
|
||||
def run_after(self):
|
||||
return [ObjectDetectionAPITransformationsStart]
|
||||
|
||||
def nodes_to_remove(self, graph: Graph, match: SubgraphMatch):
|
||||
new_nodes_to_remove = match.matched_nodes_names()
|
||||
@ -593,15 +684,33 @@ class ObjectDetectionAPIPreprocessor2Replacement(FrontReplacementFromConfigFileG
|
||||
"""
|
||||
The class replaces the "Preprocessor" block resizing input image and applying mean/scale values. Only nodes related
|
||||
to applying mean/scaling values are kept. The transformation is used for TensorFlow 2.X models.
|
||||
|
||||
There are 6 possible cases:
|
||||
1. ... -> Scale -> Start -> Resize -> End -> ...
|
||||
2. ... -> Start -> Resize -> End -> Scale -> ...
|
||||
3. ... -> Start -> Resize -> End -> ...
|
||||
4. ... -> Start -> While (... -> Scale -> Resize -> ...) -> End -> ...
|
||||
5. ... -> Start -> While (... -> Resize -> Scale -> ...) -> End -> ...
|
||||
6. ... -> Start -> While (... -> Resize -> ...) -> End -> ...
|
||||
|
||||
Where:
|
||||
- "Start" - is the node name specified in the transformation configuration file
|
||||
- "End" - is the node name specified in the transformation configuration file
|
||||
- "Scale" - a node or a sequence of element-wise nodes like Mul, Add, Sub or Div with Const input
|
||||
- "While" (... nodes ... ) - a Loop operation with body nodes specified in parentheses
|
||||
- "Resize" - the Resize sub-graph being removed
|
||||
|
||||
The transformation creates a new sub-graph of pre-processing nodes if in the original model it is inside the Loop,
|
||||
or keeps the existing one if they are in the main graph originally.
|
||||
"""
|
||||
replacement_id = 'ObjectDetectionAPIPreprocessor2Replacement'
|
||||
run_not_recursively = True
|
||||
|
||||
def run_before(self):
|
||||
# PadTFToPad inserts Transpose ops for Pad ops inside the sub-graph corresponding to DetectionOutput.
|
||||
# But the inputs corresponding to padding values is re-used as inputs for newly created Pad node. This input
|
||||
# is removed during removing nodes from the DO sub-graph so the first input to Transpose is missing which
|
||||
# results in TransposeOrderNormalizer transformation failure.
|
||||
return [Pack, TransposeOrderNormalizer, PadTFToPad]
|
||||
return [ObjectDetectionAPITransformationsFinish]
|
||||
|
||||
def run_after(self):
|
||||
return [ObjectDetectionAPITransformationsStart]
|
||||
|
||||
def transform_graph(self, graph: Graph, replacement_descriptions: dict):
|
||||
update_parameter_shape(graph, None)
|
||||
@ -609,16 +718,69 @@ class ObjectDetectionAPIPreprocessor2Replacement(FrontReplacementFromConfigFileG
|
||||
start_nodes = replacement_descriptions['start_nodes']
|
||||
end_nodes = replacement_descriptions['end_nodes']
|
||||
|
||||
start_nodes = [node_id for node_id in start_nodes if node_id in graph.nodes]
|
||||
end_nodes = [node_id for node_id in end_nodes if node_id in graph.nodes]
|
||||
|
||||
assert len(start_nodes) >= 1
|
||||
assert start_nodes[0] in graph.nodes
|
||||
input_node = Node(graph, start_nodes[0])
|
||||
start_node = Node(graph, start_nodes[0])
|
||||
|
||||
assert len(end_nodes) >= 1
|
||||
assert end_nodes[0] in graph.nodes
|
||||
output_node = Node(graph, end_nodes[0])
|
||||
end_node = Node(graph, end_nodes[0])
|
||||
|
||||
output_node.out_port(0).get_connection().set_source(input_node.in_port(0).get_source())
|
||||
input_node.in_port(0).disconnect()
|
||||
# determine nodes between specified input and output nodes to check if there is a Loop op among them
|
||||
sub_graph_node_ids = sub_graph_between_nodes(graph, start_nodes, end_nodes, include_control_flow=False,
|
||||
allow_non_reachable_end_nodes=True)
|
||||
|
||||
pre_processing_in_loop = False
|
||||
# If the pre-processing block contains Loop operation then mean and scale value should be obtained from it using
|
||||
# some pre-defined marker nodes existing for all pre-processing blocks.
|
||||
# If there is no Loop then pre-processing nodes are in the main graph and they should be obtained from it
|
||||
loop_nodes_ids = [node_id for node_id in sub_graph_node_ids if graph.node[node_id].get('op') == 'Loop']
|
||||
if len(loop_nodes_ids):
|
||||
assert len(loop_nodes_ids) == 1, 'There should be exactly one Loop node in the pre-processor block.'
|
||||
pre_processing_in_loop = True
|
||||
loop_node = Node(graph, loop_nodes_ids[0])
|
||||
body_graph = loop_node.body
|
||||
# we stick to the nodes with ids 'map/while/Preprocessor/unstack' and 'map/while/Preprocessor/stack' as they
|
||||
# "wrap" nodes performing image resize. The scale/mean values nodes are located strictly before or after
|
||||
# them
|
||||
pre_processing_ops, trailing = get_preprocessing_ops(body_graph,
|
||||
'map/while/Preprocessor/unstack',
|
||||
'map/while/Preprocessor/stack')
|
||||
else:
|
||||
pre_processing_ops, trailing = get_preprocessing_ops(graph, start_node.id, end_node.id)
|
||||
|
||||
if len(pre_processing_ops):
|
||||
# if the pre-processing is applied before the resize then reverse them to be in the topological order
|
||||
if not trailing:
|
||||
pre_processing_ops = list(reversed(pre_processing_ops))
|
||||
|
||||
if pre_processing_in_loop: # case 4 and 5
|
||||
# build a sub-graph containing a sequence of pre_processing_ops if they came from the Loop
|
||||
new_preprocessing_ops = []
|
||||
ops_mapping = {'Add': Add, 'Div': Div, 'Mul': Mul, 'Sub': Sub}
|
||||
for idx in range(len(pre_processing_ops)):
|
||||
origin_node, const_port_ind, value = pre_processing_ops[idx]
|
||||
new_node = create_op_with_const_inputs(graph, ops_mapping[origin_node.op], {const_port_ind: value})
|
||||
if len(new_preprocessing_ops):
|
||||
new_node.in_port(1 - const_port_ind).connect(new_preprocessing_ops[-1].out_port(0))
|
||||
new_preprocessing_ops.append(new_node)
|
||||
|
||||
# replace sub-graph between start and end nodes (including them) with new_preprocessing_ops nodes
|
||||
end_node.out_port(0).get_connection().set_source(new_preprocessing_ops[-1].out_port(0))
|
||||
start_node.in_port(0).get_connection().set_destination(
|
||||
new_preprocessing_ops[0].in_port(new_preprocessing_ops[0].is_in_port_connected(0)))
|
||||
else:
|
||||
if trailing: # case 2
|
||||
# change output of the end_node to be produced with the start node producer
|
||||
source_port = start_node.in_port(0).get_source()
|
||||
source_port.disconnect()
|
||||
end_node.out_port(0).get_connection().set_source(source_port)
|
||||
else: # case 1
|
||||
# change output of the end_node to be produced with the last preprocessing op
|
||||
end_node.out_port(0).get_connection().set_source(pre_processing_ops[-1][0].out_port(0))
|
||||
else: # simply remove the nodes in between start_node and end_node (including them). Case 3 and 6
|
||||
end_node.out_port(0).get_connection().set_source(start_node.in_port(0).get_source())
|
||||
|
||||
print('The Preprocessor block has been removed. Only nodes performing mean value subtraction and scaling (if'
|
||||
' applicable) are kept.')
|
||||
@ -633,12 +795,13 @@ class ObjectDetectionAPIDetectionOutputReplacement(FrontReplacementFromConfigFil
|
||||
Refer to the code for more details.
|
||||
"""
|
||||
replacement_id = 'ObjectDetectionAPIDetectionOutputReplacement'
|
||||
run_not_recursively = True
|
||||
|
||||
def run_before(self):
|
||||
return [ObjectDetectionAPIMaskRCNNROIPoolingSecondReplacement, SqueezeAxis, TransposeOrderNormalizer]
|
||||
return [ObjectDetectionAPIMaskRCNNROIPoolingSecondReplacement]
|
||||
|
||||
def run_after(self):
|
||||
return [ObjectDetectionAPIProposalReplacement, CropAndResizeReplacement, FakeQuantWithMinMaxVarsToQuantize]
|
||||
return [ObjectDetectionAPIProposalReplacement]
|
||||
|
||||
def nodes_to_remove(self, graph: Graph, match: SubgraphMatch):
|
||||
new_nodes_to_remove = match.matched_nodes_names().copy()
|
||||
@ -822,6 +985,10 @@ class ObjectDetectionAPIDetectionOutputReplacement(FrontReplacementFromConfigFil
|
||||
|
||||
class ObjectDetectionAPIMaskRCNNROIPoolingSecondReplacement(FrontReplacementFromConfigFileSubGraph):
|
||||
replacement_id = 'ObjectDetectionAPIMaskRCNNROIPoolingSecondReplacement'
|
||||
run_not_recursively = True
|
||||
|
||||
def run_before(self):
|
||||
return [ObjectDetectionAPITransformationsFinish]
|
||||
|
||||
def run_after(self):
|
||||
return [ObjectDetectionAPIProposalReplacement]
|
||||
@ -893,6 +1060,10 @@ class ObjectDetectionAPIMaskRCNNSigmoidReplacement(FrontReplacementFromConfigFil
|
||||
Adds activation with sigmoid function to the end of the network producing masks tensors.
|
||||
"""
|
||||
replacement_id = 'ObjectDetectionAPIMaskRCNNSigmoidReplacement'
|
||||
run_not_recursively = True
|
||||
|
||||
def run_before(self):
|
||||
return [ObjectDetectionAPITransformationsFinish]
|
||||
|
||||
def run_after(self):
|
||||
return [ObjectDetectionAPIMaskRCNNROIPoolingSecondReplacement]
|
||||
@ -919,12 +1090,13 @@ class ObjectDetectionAPIProposalReplacement(FrontReplacementFromConfigFileSubGra
|
||||
Refer to comments inside the function for more information about performed actions.
|
||||
"""
|
||||
replacement_id = 'ObjectDetectionAPIProposalReplacement'
|
||||
run_not_recursively = True
|
||||
|
||||
def run_after(self):
|
||||
return [ObjectDetectionAPIPreprocessorReplacement, ObjectDetectionAPIPreprocessor2Replacement]
|
||||
|
||||
def run_before(self):
|
||||
return [CropAndResizeReplacement, TransposeOrderNormalizer, Pack]
|
||||
return [ObjectDetectionAPITransformationsFinish]
|
||||
|
||||
def output_edges_match(self, graph: Graph, match: SubgraphMatch, new_sub_graph: dict):
|
||||
return {match.output_node(0)[0].id: new_sub_graph['proposal_node'].id}
|
||||
@ -1070,13 +1242,13 @@ class ObjectDetectionAPIProposalReplacement(FrontReplacementFromConfigFileSubGra
|
||||
|
||||
class ObjectDetectionAPISSDPostprocessorReplacement(FrontReplacementFromConfigFileSubGraph):
|
||||
replacement_id = 'ObjectDetectionAPISSDPostprocessorReplacement'
|
||||
run_not_recursively = True
|
||||
|
||||
def run_after(self):
|
||||
return [ObjectDetectionAPIPreprocessorReplacement, ObjectDetectionAPIPreprocessor2Replacement,
|
||||
FakeQuantWithMinMaxVarsToQuantize]
|
||||
return [ObjectDetectionAPIPreprocessorReplacement, ObjectDetectionAPIPreprocessor2Replacement]
|
||||
|
||||
def run_before(self):
|
||||
return [StandaloneConstEraser, TransposeOrderNormalizer, TFSliceToSliceReplacer]
|
||||
return [ObjectDetectionAPITransformationsFinish]
|
||||
|
||||
def output_edges_match(self, graph: Graph, match: SubgraphMatch, new_sub_graph: dict):
|
||||
# the DetectionOutput in IE produces single tensor, but in TF it produces two tensors, so create only one output
|
||||
@ -1215,10 +1387,13 @@ class ObjectDetectionAPIOutputReplacement(FrontReplacementFromConfigFileGeneral)
|
||||
SecondStageBoxPredictor_1/Conv_1/BiasAdd will be output if it exists in the graph.
|
||||
"""
|
||||
replacement_id = 'ObjectDetectionAPIOutputReplacement'
|
||||
run_not_recursively = True
|
||||
|
||||
def run_after(self):
|
||||
return [ObjectDetectionAPITransformationsStart]
|
||||
|
||||
def run_before(self):
|
||||
return [ObjectDetectionAPIPreprocessorReplacement, ObjectDetectionAPIPreprocessor2Replacement,
|
||||
TransposeOrderNormalizer]
|
||||
return [ObjectDetectionAPIPreprocessorReplacement, ObjectDetectionAPIPreprocessor2Replacement]
|
||||
|
||||
def transform_graph(self, graph: Graph, replacement_descriptions: dict):
|
||||
if graph.graph['cmd_params'].output is not None:
|
||||
@ -1240,9 +1415,13 @@ class ObjectDetectionAPIOutputReplacement(FrontReplacementFromConfigFileGeneral)
|
||||
|
||||
class ObjectDetectionAPIPSROIPoolingReplacement(FrontReplacementFromConfigFileSubGraph):
|
||||
replacement_id = 'ObjectDetectionAPIPSROIPoolingReplacement'
|
||||
run_not_recursively = True
|
||||
|
||||
def run_after(self):
|
||||
return [ObjectDetectionAPIProposalReplacement, TransposeOrderNormalizer]
|
||||
return [ObjectDetectionAPIProposalReplacement]
|
||||
|
||||
def run_before(self):
|
||||
return [ObjectDetectionAPITransformationsFinish]
|
||||
|
||||
def output_edges_match(self, graph: Graph, match: SubgraphMatch, new_sub_graph: dict):
|
||||
return {match.output_node(0)[0].id: new_sub_graph['output_node'].id}
|
||||
@ -1311,10 +1490,13 @@ class ObjectDetectionAPIConstValueOverride(FrontReplacementFromConfigFileGeneral
|
||||
no more equal to the 'first_stage_max_proposals' saved as a constant.
|
||||
"""
|
||||
replacement_id = 'ObjectDetectionAPIConstValueOverride'
|
||||
run_not_recursively = True
|
||||
|
||||
def run_after(self):
|
||||
return [ObjectDetectionAPITransformationsStart]
|
||||
|
||||
def run_before(self):
|
||||
return [ObjectDetectionAPIPreprocessorReplacement, ObjectDetectionAPIPreprocessor2Replacement,
|
||||
TransposeOrderNormalizer]
|
||||
return [ObjectDetectionAPIPreprocessorReplacement, ObjectDetectionAPIPreprocessor2Replacement]
|
||||
|
||||
def transform_graph(self, graph: Graph, replacement_descriptions: dict):
|
||||
argv = graph.graph['cmd_params']
|
||||
|
@ -17,13 +17,17 @@
|
||||
import unittest
|
||||
|
||||
from generator import generator, generate
|
||||
from unittest.mock import patch
|
||||
|
||||
from extensions.front.tf.ObjectDetectionAPI import calculate_shape_keeping_aspect_ratio, \
|
||||
calculate_placeholder_spatial_shape
|
||||
calculate_placeholder_spatial_shape, ObjectDetectionAPIPreprocessor2Replacement
|
||||
from mo.front.common.partial_infer.utils import float32_array
|
||||
from mo.front.subgraph_matcher import SubgraphMatch
|
||||
from mo.graph.graph import Graph
|
||||
from mo.utils.custom_replacement_config import CustomReplacementDescriptor
|
||||
from mo.utils.error import Error
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from mo.utils.unittest.graph import const, regular_op, result, build_graph, connect_front
|
||||
|
||||
|
||||
class FakePipelineConfig:
|
||||
@ -109,3 +113,198 @@ class TestCalculatePlaceholderSpatialShape(unittest.TestCase):
|
||||
|
||||
def test_missing_input_shape_information(self):
|
||||
self.assertRaises(Error, calculate_placeholder_spatial_shape, self.graph, self.match, self.pipeline_config)
|
||||
|
||||
|
||||
@patch('extensions.front.tf.ObjectDetectionAPI.update_parameter_shape')
|
||||
class TestObjectDetectionAPIPreprocessor2Replacement(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.start_node_name = 'StatefulPartitionedCall/Preprocessor/unstack'
|
||||
self.end_node_name = 'StatefulPartitionedCall/Preprocessor/stack'
|
||||
self.end_node_name2 = 'StatefulPartitionedCall/Preprocessor/stack2'
|
||||
self.loop_start_node_name = 'prefix/map/while/Preprocessor/unstack'
|
||||
self.loop_end_node_name = 'prefix/map/while/Preprocessor/stack'
|
||||
self.mul_const = float32_array([0.025, 0.374, -0.45])
|
||||
self.sub_const = float32_array([2.0, 3.0, 4.0])
|
||||
|
||||
self.nodes = {
|
||||
**regular_op('input', {'type': 'Parameter'}),
|
||||
|
||||
**regular_op('mul', {'op': 'Mul', 'type': 'Multiply', 'name': 'my_mul'}),
|
||||
**regular_op('sub', {'op': 'Sub', 'type': 'Subtract', 'name': 'my_sub'}),
|
||||
**const('mul_const', self.mul_const),
|
||||
**const('sub_const', self.sub_const),
|
||||
|
||||
**regular_op(self.start_node_name, {'op': 'Identity'}),
|
||||
**regular_op(self.end_node_name, {'op': 'Identity'}),
|
||||
**regular_op(self.end_node_name2, {'op': 'Identity'}),
|
||||
|
||||
**regular_op('loop', {'op': 'Loop', 'body': None}),
|
||||
|
||||
**regular_op('resize', {'type': 'Interpolate'}),
|
||||
**result('result'),
|
||||
}
|
||||
self.replacement_desc = {'start_nodes': [self.start_node_name],
|
||||
'end_nodes': [self.end_node_name, self.end_node_name2]}
|
||||
|
||||
def build_ref_graph(self, preprocessing: bool):
|
||||
if preprocessing:
|
||||
ref_edges = [*connect_front('input', '0:mul'),
|
||||
*connect_front('mul_const', '1:mul'),
|
||||
*connect_front('sub_const', '0:sub'),
|
||||
*connect_front('mul', '1:sub'),
|
||||
*connect_front('sub', 'result'),
|
||||
]
|
||||
else:
|
||||
ref_edges = [*connect_front('input', 'result')]
|
||||
ref_graph = build_graph(self.nodes, ref_edges, nodes_with_edges_only=True)
|
||||
ref_graph.stage = 'front'
|
||||
return ref_graph
|
||||
|
||||
def test_case_1(self, update_parameter_shape_mock):
|
||||
# test for case #1 described in the ObjectDetectionAPIPreprocessor2Replacement
|
||||
update_parameter_shape_mock.return_value = None
|
||||
edges = [*connect_front('input', '0:mul'),
|
||||
*connect_front('mul_const', '1:mul'),
|
||||
*connect_front('sub_const', '0:sub'),
|
||||
*connect_front('mul', '1:sub'),
|
||||
*connect_front('sub', self.start_node_name),
|
||||
*connect_front(self.start_node_name, 'resize'),
|
||||
*connect_front('resize', self.end_node_name),
|
||||
*connect_front(self.end_node_name, 'result'),
|
||||
]
|
||||
graph = build_graph(self.nodes, edges)
|
||||
graph.stage = 'front'
|
||||
|
||||
ObjectDetectionAPIPreprocessor2Replacement().transform_graph(graph, self.replacement_desc)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, self.build_ref_graph(True), 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_case_2(self, update_parameter_shape_mock):
|
||||
# test for case #2 described in the ObjectDetectionAPIPreprocessor2Replacement
|
||||
update_parameter_shape_mock.return_value = None
|
||||
|
||||
edges = [*connect_front('input', self.start_node_name),
|
||||
*connect_front(self.start_node_name, 'resize'),
|
||||
*connect_front('resize', self.end_node_name),
|
||||
*connect_front(self.end_node_name, '0:mul'),
|
||||
*connect_front('mul_const', '1:mul'),
|
||||
*connect_front('sub_const', '0:sub'),
|
||||
*connect_front('mul', '1:sub'),
|
||||
*connect_front('sub', 'result'),
|
||||
]
|
||||
graph = build_graph(self.nodes, edges)
|
||||
graph.stage = 'front'
|
||||
|
||||
ObjectDetectionAPIPreprocessor2Replacement().transform_graph(graph, self.replacement_desc)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, self.build_ref_graph(True), 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_case_3(self, update_parameter_shape_mock):
|
||||
# test for case #3 described in the ObjectDetectionAPIPreprocessor2Replacement
|
||||
update_parameter_shape_mock.return_value = None
|
||||
|
||||
edges = [*connect_front('input', self.start_node_name),
|
||||
*connect_front(self.start_node_name, 'resize'),
|
||||
*connect_front('resize', self.end_node_name),
|
||||
*connect_front(self.end_node_name, 'result'),
|
||||
]
|
||||
graph = build_graph(self.nodes, edges)
|
||||
graph.stage = 'front'
|
||||
|
||||
ObjectDetectionAPIPreprocessor2Replacement().transform_graph(graph, self.replacement_desc)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, self.build_ref_graph(False), 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def build_main_graph(self, pre_processing: str):
|
||||
def build_body_graph(pre_processing: str):
|
||||
nodes = {
|
||||
**regular_op('input', {'type': 'Parameter'}),
|
||||
|
||||
**regular_op('mul', {'op': 'Mul', 'type': 'Multiply', 'name': 'my_body_mul'}),
|
||||
**regular_op('sub', {'op': 'Sub', 'type': 'Subtract', 'name': 'my_body_sub'}),
|
||||
**const('body_mul_const', self.mul_const),
|
||||
**const('body_sub_const', self.sub_const),
|
||||
|
||||
**regular_op(self.loop_start_node_name, {'op': 'Identity'}),
|
||||
**regular_op(self.loop_end_node_name, {'op': 'Identity'}),
|
||||
|
||||
**regular_op('resize', {'type': 'Interpolate'}),
|
||||
**result('result'),
|
||||
}
|
||||
edges = None
|
||||
if pre_processing == 'no':
|
||||
edges = [*connect_front('input', self.loop_start_node_name),
|
||||
*connect_front(self.loop_start_node_name, 'resize'),
|
||||
*connect_front('resize', self.loop_end_node_name),
|
||||
*connect_front(self.loop_end_node_name, 'result'),
|
||||
]
|
||||
elif pre_processing == 'trailing':
|
||||
edges = [*connect_front('input', self.loop_start_node_name),
|
||||
*connect_front(self.loop_start_node_name, 'resize'),
|
||||
*connect_front('resize', self.loop_end_node_name),
|
||||
*connect_front(self.loop_end_node_name, '0:mul'),
|
||||
*connect_front('body_mul_const', '1:mul'),
|
||||
*connect_front('body_sub_const', '0:sub'),
|
||||
*connect_front('mul', '1:sub'),
|
||||
*connect_front('sub', 'result'),
|
||||
]
|
||||
else:
|
||||
edges = [*connect_front('input', '0:mul'),
|
||||
*connect_front('body_mul_const', '1:mul'),
|
||||
*connect_front('body_sub_const', '0:sub'),
|
||||
*connect_front('mul', '1:sub'),
|
||||
*connect_front('sub', self.loop_start_node_name),
|
||||
*connect_front(self.loop_start_node_name, 'resize'),
|
||||
*connect_front('resize', self.loop_end_node_name),
|
||||
*connect_front(self.loop_end_node_name, 'result'),
|
||||
]
|
||||
graph = build_graph(nodes, edges, nodes_with_edges_only=True)
|
||||
graph.stage = 'front'
|
||||
return graph
|
||||
|
||||
edges = [*connect_front('input', self.start_node_name),
|
||||
*connect_front(self.start_node_name, 'loop'),
|
||||
*connect_front('loop:0', self.end_node_name),
|
||||
*connect_front('loop:1', self.end_node_name2),
|
||||
*connect_front(self.end_node_name, 'result'),
|
||||
]
|
||||
graph = build_graph(self.nodes, edges, {'loop': {'body': build_body_graph(pre_processing)}},
|
||||
nodes_with_edges_only=True)
|
||||
graph.stage = 'front'
|
||||
return graph
|
||||
|
||||
def test_case_4(self, update_parameter_shape_mock):
|
||||
# test for case #4 described in the ObjectDetectionAPIPreprocessor2Replacement
|
||||
update_parameter_shape_mock.return_value = None
|
||||
|
||||
graph = self.build_main_graph('leading')
|
||||
|
||||
ObjectDetectionAPIPreprocessor2Replacement().transform_graph(graph, self.replacement_desc)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, self.build_ref_graph(True), 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_case_5(self, update_parameter_shape_mock):
|
||||
# test for case #5 described in the ObjectDetectionAPIPreprocessor2Replacement
|
||||
update_parameter_shape_mock.return_value = None
|
||||
|
||||
graph = self.build_main_graph('trailing')
|
||||
|
||||
ObjectDetectionAPIPreprocessor2Replacement().transform_graph(graph, self.replacement_desc)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, self.build_ref_graph(True), 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_case_6(self, update_parameter_shape_mock):
|
||||
# test for case #6 described in the ObjectDetectionAPIPreprocessor2Replacement
|
||||
update_parameter_shape_mock.return_value = None
|
||||
|
||||
graph = self.build_main_graph('no')
|
||||
|
||||
ObjectDetectionAPIPreprocessor2Replacement().transform_graph(graph, self.replacement_desc)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, self.build_ref_graph(False), 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
@ -0,0 +1,50 @@
|
||||
[
|
||||
{
|
||||
"custom_attributes": {
|
||||
"start_nodes": ["StatefulPartitionedCall/map/TensorArrayUnstack/TensorListFromTensor"],
|
||||
"end_nodes": ["StatefulPartitionedCall/map/TensorArrayV2Stack/TensorListStack",
|
||||
"StatefulPartitionedCall/map/TensorArrayV2Stack_1/TensorListStack"]
|
||||
},
|
||||
"id": "ObjectDetectionAPIPreprocessor2Replacement",
|
||||
"match_kind": "general"
|
||||
},
|
||||
{
|
||||
"custom_attributes": {
|
||||
"code_type": "caffe.PriorBoxParameter.CENTER_SIZE",
|
||||
"pad_mode": "caffe.ResizeParameter.CONSTANT",
|
||||
"resize_mode": "caffe.ResizeParameter.WARP",
|
||||
"clip_before_nms": false,
|
||||
"clip_after_nms": true,
|
||||
"disable_prior_boxes_layers_generator": true
|
||||
},
|
||||
"id": "ObjectDetectionAPISSDPostprocessorReplacement",
|
||||
"include_inputs_to_sub_graph": true,
|
||||
"include_outputs_to_sub_graph": true,
|
||||
"instances": {
|
||||
"end_points": [
|
||||
"StatefulPartitionedCall/Identity",
|
||||
"StatefulPartitionedCall/Identity_1",
|
||||
"StatefulPartitionedCall/Identity_2",
|
||||
"StatefulPartitionedCall/Identity_3",
|
||||
"StatefulPartitionedCall/Identity_4",
|
||||
"StatefulPartitionedCall/Identity_5",
|
||||
"StatefulPartitionedCall/Identity_6",
|
||||
"StatefulPartitionedCall/Identity_7"
|
||||
],
|
||||
"start_points": [
|
||||
"StatefulPartitionedCall/Postprocessor/Reshape_1",
|
||||
"StatefulPartitionedCall/Postprocessor/scale_logits",
|
||||
"StatefulPartitionedCall/Postprocessor/Tile",
|
||||
"StatefulPartitionedCall/Postprocessor/Cast_1"
|
||||
]
|
||||
},
|
||||
"match_kind": "points"
|
||||
},
|
||||
{
|
||||
"custom_attributes": {
|
||||
"outputs": "StatefulPartitionedCall/Identity,StatefulPartitionedCall/Identity_1,StatefulPartitionedCall/Identity_2,StatefulPartitionedCall/Identity_3,StatefulPartitionedCall/Identity_4,StatefulPartitionedCall/Identity_5,StatefulPartitionedCall/Identity_6,StatefulPartitionedCall/Identity_7"
|
||||
},
|
||||
"id": "ObjectDetectionAPIOutputReplacement",
|
||||
"match_kind": "general"
|
||||
}
|
||||
]
|
@ -0,0 +1,82 @@
|
||||
[
|
||||
{
|
||||
"custom_attributes": {
|
||||
"start_nodes": ["StatefulPartitionedCall/map/TensorArrayUnstack/TensorListFromTensor"],
|
||||
"end_nodes": ["StatefulPartitionedCall/map/TensorArrayV2Stack/TensorListStack",
|
||||
"StatefulPartitionedCall/map/TensorArrayV2Stack_1/TensorListStack"]
|
||||
},
|
||||
"id": "ObjectDetectionAPIPreprocessor2Replacement",
|
||||
"match_kind": "general"
|
||||
},
|
||||
{
|
||||
"custom_attributes": {
|
||||
"clip_before_nms": false,
|
||||
"clip_after_nms": true
|
||||
},
|
||||
"id": "ObjectDetectionAPIProposalReplacement",
|
||||
"include_inputs_to_sub_graph": true,
|
||||
"include_outputs_to_sub_graph": true,
|
||||
"instances": {
|
||||
"end_points": [
|
||||
"StatefulPartitionedCall/stack_3",
|
||||
"StatefulPartitionedCall/BatchMultiClassNonMaxSuppression/stack_10",
|
||||
"StatefulPartitionedCall/Shape"
|
||||
],
|
||||
"start_points": [
|
||||
"StatefulPartitionedCall/concat/concat",
|
||||
"StatefulPartitionedCall/concat_1/concat",
|
||||
"StatefulPartitionedCall/GridAnchorGenerator/Identity",
|
||||
"StatefulPartitionedCall/Cast_1",
|
||||
"StatefulPartitionedCall/Cast_2",
|
||||
"StatefulPartitionedCall/Shape"
|
||||
]
|
||||
},
|
||||
"match_kind": "points"
|
||||
},
|
||||
{
|
||||
"custom_attributes": {
|
||||
"clip_before_nms": false,
|
||||
"clip_after_nms": true,
|
||||
"background_label_id": 0,
|
||||
"coordinates_swap_method": "swap_weights"
|
||||
},
|
||||
"id": "ObjectDetectionAPIDetectionOutputReplacement",
|
||||
"inputs": [
|
||||
[
|
||||
{
|
||||
"node": "Reshape$",
|
||||
"port": 0
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"node": "Reshape_1$",
|
||||
"port": 0
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"node": "ExpandDims$",
|
||||
"port": 0
|
||||
}
|
||||
]
|
||||
],
|
||||
"instances": [
|
||||
".*SecondStagePostprocessor/"
|
||||
],
|
||||
"match_kind": "scope",
|
||||
"outputs": [
|
||||
{
|
||||
"node": "Cast_3$",
|
||||
"port": 0
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"custom_attributes": {
|
||||
"outputs": "StatefulPartitionedCall/SecondStagePostprocessor/Cast_3"
|
||||
},
|
||||
"id": "ObjectDetectionAPIOutputReplacement",
|
||||
"match_kind": "general"
|
||||
}
|
||||
]
|
@ -0,0 +1,91 @@
|
||||
[
|
||||
{
|
||||
"custom_attributes": {
|
||||
"start_nodes": ["StatefulPartitionedCall/map/TensorArrayUnstack/TensorListFromTensor"],
|
||||
"end_nodes": ["StatefulPartitionedCall/map/TensorArrayV2Stack/TensorListStack",
|
||||
"StatefulPartitionedCall/map/TensorArrayV2Stack_1/TensorListStack"]
|
||||
},
|
||||
"id": "ObjectDetectionAPIPreprocessor2Replacement",
|
||||
"match_kind": "general"
|
||||
},
|
||||
{
|
||||
"custom_attributes": {
|
||||
"clip_before_nms": false,
|
||||
"clip_after_nms": true
|
||||
},
|
||||
"id": "ObjectDetectionAPIProposalReplacement",
|
||||
"include_inputs_to_sub_graph": true,
|
||||
"include_outputs_to_sub_graph": true,
|
||||
"instances": {
|
||||
"end_points": [
|
||||
"StatefulPartitionedCall/stack_3",
|
||||
"StatefulPartitionedCall/BatchMultiClassNonMaxSuppression/stack_10",
|
||||
"StatefulPartitionedCall/Shape"
|
||||
],
|
||||
"start_points": [
|
||||
"StatefulPartitionedCall/concat/concat",
|
||||
"StatefulPartitionedCall/concat_1/concat",
|
||||
"StatefulPartitionedCall/GridAnchorGenerator/Identity",
|
||||
"StatefulPartitionedCall/Cast_1",
|
||||
"StatefulPartitionedCall/Cast_2",
|
||||
"StatefulPartitionedCall/Shape"
|
||||
]
|
||||
},
|
||||
"match_kind": "points"
|
||||
},
|
||||
{
|
||||
"custom_attributes": {
|
||||
"clip_before_nms": false,
|
||||
"clip_after_nms": true,
|
||||
"background_label_id": 0,
|
||||
"coordinates_swap_method": "swap_weights"
|
||||
},
|
||||
"id": "ObjectDetectionAPIDetectionOutputReplacement",
|
||||
"include_inputs_to_sub_graph": true,
|
||||
"include_outputs_to_sub_graph": true,
|
||||
"instances": {
|
||||
"end_points": [
|
||||
"StatefulPartitionedCall/BatchMultiClassNonMaxSuppression_1/stack_8",
|
||||
"StatefulPartitionedCall/BatchMultiClassNonMaxSuppression_1/stack_6"
|
||||
],
|
||||
"start_points": [
|
||||
"StatefulPartitionedCall/Reshape_4",
|
||||
"StatefulPartitionedCall/Reshape_5",
|
||||
"StatefulPartitionedCall/ExpandDims_6",
|
||||
"StatefulPartitionedCall/Cast_5"
|
||||
]
|
||||
},
|
||||
"match_kind": "points"
|
||||
},
|
||||
{
|
||||
"custom_attributes": {
|
||||
},
|
||||
"id": "ObjectDetectionAPIMaskRCNNROIPoolingSecondReplacement",
|
||||
"include_inputs_to_sub_graph": true,
|
||||
"include_outputs_to_sub_graph": true,
|
||||
"instances": {
|
||||
"end_points": [
|
||||
"StatefulPartitionedCall/Reshape_10"
|
||||
],
|
||||
"start_points": [
|
||||
"StatefulPartitionedCall/CropAndResize_1/CropAndResize",
|
||||
"StatefulPartitionedCall/CropAndResize_1/Reshape"
|
||||
]
|
||||
},
|
||||
"match_kind": "points"
|
||||
},
|
||||
{
|
||||
"custom_attributes": {
|
||||
"masks_node_prefix_name": "StatefulPartitionedCall/mask_rcnn_keras_box_predictor/mask_rcnn_mask_head/"
|
||||
},
|
||||
"id": "ObjectDetectionAPIMaskRCNNSigmoidReplacement",
|
||||
"match_kind": "general"
|
||||
},
|
||||
{
|
||||
"custom_attributes": {
|
||||
"outputs": "StatefulPartitionedCall/mask_rcnn_keras_box_predictor/mask_rcnn_mask_head/MaskPredictor_last_conv2d/BiasAdd,StatefulPartitionedCall/Reshape_13"
|
||||
},
|
||||
"id": "ObjectDetectionAPIOutputReplacement",
|
||||
"match_kind": "general"
|
||||
}
|
||||
]
|
@ -2,8 +2,7 @@
|
||||
{
|
||||
"custom_attributes": {
|
||||
"start_nodes": ["StatefulPartitionedCall/Preprocessor/unstack"],
|
||||
"end_nodes": ["StatefulPartitionedCall/Preprocessor/stack",
|
||||
"StatefulPartitionedCall/Preprocessor/stack_1"]
|
||||
"end_nodes": ["StatefulPartitionedCall/Preprocessor/stack"]
|
||||
},
|
||||
"id": "ObjectDetectionAPIPreprocessor2Replacement",
|
||||
"match_kind": "general"
|
||||
|
@ -0,0 +1,50 @@
|
||||
[
|
||||
{
|
||||
"custom_attributes": {
|
||||
"start_nodes": ["StatefulPartitionedCall/map/TensorArrayUnstack/TensorListFromTensor"],
|
||||
"end_nodes": ["StatefulPartitionedCall/map/TensorArrayV2Stack/TensorListStack",
|
||||
"StatefulPartitionedCall/map/TensorArrayV2Stack_1/TensorListStack"]
|
||||
},
|
||||
"id": "ObjectDetectionAPIPreprocessor2Replacement",
|
||||
"match_kind": "general"
|
||||
},
|
||||
{
|
||||
"custom_attributes": {
|
||||
"code_type": "caffe.PriorBoxParameter.CENTER_SIZE",
|
||||
"pad_mode": "caffe.ResizeParameter.CONSTANT",
|
||||
"resize_mode": "caffe.ResizeParameter.WARP",
|
||||
"clip_before_nms": false,
|
||||
"clip_after_nms": true,
|
||||
"disable_prior_boxes_layers_generator": true
|
||||
},
|
||||
"id": "ObjectDetectionAPISSDPostprocessorReplacement",
|
||||
"include_inputs_to_sub_graph": true,
|
||||
"include_outputs_to_sub_graph": true,
|
||||
"instances": {
|
||||
"end_points": [
|
||||
"StatefulPartitionedCall/Identity",
|
||||
"StatefulPartitionedCall/Identity_1",
|
||||
"StatefulPartitionedCall/Identity_2",
|
||||
"StatefulPartitionedCall/Identity_3",
|
||||
"StatefulPartitionedCall/Identity_4",
|
||||
"StatefulPartitionedCall/Identity_5",
|
||||
"StatefulPartitionedCall/Identity_6",
|
||||
"StatefulPartitionedCall/Identity_7"
|
||||
],
|
||||
"start_points": [
|
||||
"StatefulPartitionedCall/Postprocessor/Reshape_1",
|
||||
"StatefulPartitionedCall/Postprocessor/scale_logits",
|
||||
"StatefulPartitionedCall/Postprocessor/Tile",
|
||||
"StatefulPartitionedCall/Postprocessor/Cast_1"
|
||||
]
|
||||
},
|
||||
"match_kind": "points"
|
||||
},
|
||||
{
|
||||
"custom_attributes": {
|
||||
"outputs": "StatefulPartitionedCall/Identity,StatefulPartitionedCall/Identity_1,StatefulPartitionedCall/Identity_2,StatefulPartitionedCall/Identity_3,StatefulPartitionedCall/Identity_4,StatefulPartitionedCall/Identity_5,StatefulPartitionedCall/Identity_6,StatefulPartitionedCall/Identity_7"
|
||||
},
|
||||
"id": "ObjectDetectionAPIOutputReplacement",
|
||||
"match_kind": "general"
|
||||
}
|
||||
]
|
@ -21,6 +21,9 @@ from mo.graph.graph import Graph
|
||||
|
||||
class TransformationsConfig(FrontReplacementPattern):
|
||||
enabled = True
|
||||
# do not run this transformation recursively otherwise transformations which are enabled with a configuration file
|
||||
# will be registered multiple times
|
||||
run_not_recursively = True
|
||||
graph_condition = [lambda graph: graph.graph['cmd_params'].transformations_config is not None]
|
||||
|
||||
def run_before(self):
|
||||
|
@ -128,7 +128,7 @@ class DependencyGraph(Graph):
|
||||
if nodes_to_dump is None:
|
||||
nodes_to_dump = self.nodes()
|
||||
string = '\ndigraph {\n'
|
||||
string += 'node [color=lightblue2, style=filled];\n'
|
||||
string += 'node [color=lightblue2, style=filled, shape=box];\n'
|
||||
|
||||
for node in nodes_to_dump:
|
||||
attrs = ""
|
||||
|
@ -118,7 +118,7 @@ def is_connected_component(graph: Graph, node_names: list):
|
||||
|
||||
|
||||
def sub_graph_between_nodes(graph: Graph, start_nodes: list, end_nodes: list, detect_extra_start_node: callable=None,
|
||||
include_control_flow=True):
|
||||
include_control_flow=True, allow_non_reachable_end_nodes=False):
|
||||
"""
|
||||
Finds nodes of the sub-graph between 'start_nodes' and 'end_nodes'. Input nodes for the sub-graph nodes are also
|
||||
added to the sub-graph. Constant inputs of the 'start_nodes' are also added to the sub-graph.
|
||||
@ -128,6 +128,7 @@ def sub_graph_between_nodes(graph: Graph, start_nodes: list, end_nodes: list, de
|
||||
:param detect_extra_start_node: callable function to add additional nodes to the list of start nodes instead of
|
||||
traversing the graph further. The list of additional start nodes is returned of the function is not None.
|
||||
:param include_control_flow: flag to specify whether to follow the control flow edges or not
|
||||
:param allow_non_reachable_end_nodes: do not fail if the end nodes are not reachable from the start nodes
|
||||
:return: list of nodes of the identified sub-graph or None if the sub-graph cannot be extracted.
|
||||
"""
|
||||
sub_graph_nodes = list()
|
||||
@ -162,7 +163,7 @@ def sub_graph_between_nodes(graph: Graph, start_nodes: list, end_nodes: list, de
|
||||
for start_node in start_nodes:
|
||||
graph.dfs(start_node, forward_visited)
|
||||
for end_node in end_nodes:
|
||||
if end_node not in forward_visited:
|
||||
if not allow_non_reachable_end_nodes and end_node not in forward_visited:
|
||||
raise Error('End node "{}" is not reachable from start nodes: {}. '.format(end_node, start_nodes) +
|
||||
refer_to_faq_msg(74))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user