Added extractor for ONNX operation Size (#6031)

* Added extractor for ONNX operation Size

* Moved transformation of Size operation from TF specific to generic front phase

* Updated list of supported ONNX operation

* Moved unit test for Size decomposition to a new location
This commit is contained in:
Evgeny Lazarev 2021-06-04 17:18:01 +03:00 committed by GitHub
parent 859a3b8a30
commit 0db9d3e2c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 22 additions and 2 deletions

View File

@ -500,6 +500,7 @@ Standard ONNX\* operators:
| Sigmoid | No |
| Sign | No |
| Sin | No |
| Size | No |
| Slice | No |
| Softmax | No |
| Softplus | No |

View File

@ -338,6 +338,7 @@ extensions/front/onnx/roialign_ext.py
extensions/front/onnx/roifeatureextractor_ext.py
extensions/front/onnx/scatter_ext.py
extensions/front/onnx/shape_ext.py
extensions/front/onnx/size_ext.py
extensions/front/onnx/slice_ext.py
extensions/front/onnx/softmax_ext.py
extensions/front/onnx/softmaxONNX_to_softmax.py
@ -363,6 +364,7 @@ extensions/front/reshape_dim_normalizer.py
extensions/front/restore_ports.py
extensions/front/RollWithEmptyAxesReplacer.py
extensions/front/scatter_normalizer.py
extensions/front/SizeReplacer.py
extensions/front/softmax.py
extensions/front/Softplus_fusion.py
extensions/front/softsign_replacer.py
@ -483,7 +485,6 @@ extensions/front/tf/roll_ext.py
extensions/front/tf/RollRealImagPack.py
extensions/front/tf/select_ext.py
extensions/front/tf/sign_ext.py
extensions/front/tf/SizeReplacer.py
extensions/front/tf/slice_ext.py
extensions/front/tf/softmax_ext.py
extensions/front/tf/softplus_ext.py

View File

@ -0,0 +1,17 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
from extensions.ops.size import Size
from mo.front.extractor import FrontExtractorOp
class SizeExtractor(FrontExtractorOp):
op = 'Size'
enabled = True
@classmethod
def extract(cls, node):
Size.update_node_stat(node, {'output_type': np.int64})
return cls.enabled

View File

@ -9,6 +9,7 @@ from mo.ops.op import Op
class Size(Op):
op = 'Size'
enabled = False
def __init__(self, graph: Graph, attrs: dict):
assert 'output_type' in attrs, 'Size has mandatory `output_type` attribute'

View File

@ -6,7 +6,7 @@ import unittest
import numpy as np
from generator import generator, generate
from extensions.front.tf.SizeReplacer import SizeFrontReplacer
from extensions.front.SizeReplacer import SizeFrontReplacer
from mo.front.common.partial_infer.utils import int64_array
from mo.utils.ir_engine.compare_graphs import compare_graphs
from unit_tests.utils.graph import build_graph, regular_op_with_empty_data, result, connect, \