Added extractor for ONNX operation Size (#6031)
* Added extractor for ONNX operation Size * Moved transformation of Size operation from TF specific to generic front phase * Updated list of supported ONNX operation * Moved unit test for Size decomposition to a new location
This commit is contained in:
parent
859a3b8a30
commit
0db9d3e2c5
@ -500,6 +500,7 @@ Standard ONNX\* operators:
|
|||||||
| Sigmoid | No |
|
| Sigmoid | No |
|
||||||
| Sign | No |
|
| Sign | No |
|
||||||
| Sin | No |
|
| Sin | No |
|
||||||
|
| Size | No |
|
||||||
| Slice | No |
|
| Slice | No |
|
||||||
| Softmax | No |
|
| Softmax | No |
|
||||||
| Softplus | No |
|
| Softplus | No |
|
||||||
|
@ -338,6 +338,7 @@ extensions/front/onnx/roialign_ext.py
|
|||||||
extensions/front/onnx/roifeatureextractor_ext.py
|
extensions/front/onnx/roifeatureextractor_ext.py
|
||||||
extensions/front/onnx/scatter_ext.py
|
extensions/front/onnx/scatter_ext.py
|
||||||
extensions/front/onnx/shape_ext.py
|
extensions/front/onnx/shape_ext.py
|
||||||
|
extensions/front/onnx/size_ext.py
|
||||||
extensions/front/onnx/slice_ext.py
|
extensions/front/onnx/slice_ext.py
|
||||||
extensions/front/onnx/softmax_ext.py
|
extensions/front/onnx/softmax_ext.py
|
||||||
extensions/front/onnx/softmaxONNX_to_softmax.py
|
extensions/front/onnx/softmaxONNX_to_softmax.py
|
||||||
@ -363,6 +364,7 @@ extensions/front/reshape_dim_normalizer.py
|
|||||||
extensions/front/restore_ports.py
|
extensions/front/restore_ports.py
|
||||||
extensions/front/RollWithEmptyAxesReplacer.py
|
extensions/front/RollWithEmptyAxesReplacer.py
|
||||||
extensions/front/scatter_normalizer.py
|
extensions/front/scatter_normalizer.py
|
||||||
|
extensions/front/SizeReplacer.py
|
||||||
extensions/front/softmax.py
|
extensions/front/softmax.py
|
||||||
extensions/front/Softplus_fusion.py
|
extensions/front/Softplus_fusion.py
|
||||||
extensions/front/softsign_replacer.py
|
extensions/front/softsign_replacer.py
|
||||||
@ -483,7 +485,6 @@ extensions/front/tf/roll_ext.py
|
|||||||
extensions/front/tf/RollRealImagPack.py
|
extensions/front/tf/RollRealImagPack.py
|
||||||
extensions/front/tf/select_ext.py
|
extensions/front/tf/select_ext.py
|
||||||
extensions/front/tf/sign_ext.py
|
extensions/front/tf/sign_ext.py
|
||||||
extensions/front/tf/SizeReplacer.py
|
|
||||||
extensions/front/tf/slice_ext.py
|
extensions/front/tf/slice_ext.py
|
||||||
extensions/front/tf/softmax_ext.py
|
extensions/front/tf/softmax_ext.py
|
||||||
extensions/front/tf/softplus_ext.py
|
extensions/front/tf/softplus_ext.py
|
||||||
|
17
model-optimizer/extensions/front/onnx/size_ext.py
Normal file
17
model-optimizer/extensions/front/onnx/size_ext.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
# Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from extensions.ops.size import Size
|
||||||
|
from mo.front.extractor import FrontExtractorOp
|
||||||
|
|
||||||
|
|
||||||
|
class SizeExtractor(FrontExtractorOp):
|
||||||
|
op = 'Size'
|
||||||
|
enabled = True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def extract(cls, node):
|
||||||
|
Size.update_node_stat(node, {'output_type': np.int64})
|
||||||
|
return cls.enabled
|
@ -9,6 +9,7 @@ from mo.ops.op import Op
|
|||||||
|
|
||||||
class Size(Op):
|
class Size(Op):
|
||||||
op = 'Size'
|
op = 'Size'
|
||||||
|
enabled = False
|
||||||
|
|
||||||
def __init__(self, graph: Graph, attrs: dict):
|
def __init__(self, graph: Graph, attrs: dict):
|
||||||
assert 'output_type' in attrs, 'Size has mandatory `output_type` attribute'
|
assert 'output_type' in attrs, 'Size has mandatory `output_type` attribute'
|
||||||
|
@ -6,7 +6,7 @@ import unittest
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from generator import generator, generate
|
from generator import generator, generate
|
||||||
|
|
||||||
from extensions.front.tf.SizeReplacer import SizeFrontReplacer
|
from extensions.front.SizeReplacer import SizeFrontReplacer
|
||||||
from mo.front.common.partial_infer.utils import int64_array
|
from mo.front.common.partial_infer.utils import int64_array
|
||||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||||
from unit_tests.utils.graph import build_graph, regular_op_with_empty_data, result, connect, \
|
from unit_tests.utils.graph import build_graph, regular_op_with_empty_data, result, connect, \
|
Loading…
Reference in New Issue
Block a user