Added extractor for ONNX operation Size (#6031)
* Added extractor for ONNX operation Size * Moved transformation of Size operation from TF specific to generic front phase * Updated list of supported ONNX operation * Moved unit test for Size decomposition to a new location
This commit is contained in:
parent
859a3b8a30
commit
0db9d3e2c5
@ -500,6 +500,7 @@ Standard ONNX\* operators:
|
||||
| Sigmoid | No |
|
||||
| Sign | No |
|
||||
| Sin | No |
|
||||
| Size | No |
|
||||
| Slice | No |
|
||||
| Softmax | No |
|
||||
| Softplus | No |
|
||||
|
@ -338,6 +338,7 @@ extensions/front/onnx/roialign_ext.py
|
||||
extensions/front/onnx/roifeatureextractor_ext.py
|
||||
extensions/front/onnx/scatter_ext.py
|
||||
extensions/front/onnx/shape_ext.py
|
||||
extensions/front/onnx/size_ext.py
|
||||
extensions/front/onnx/slice_ext.py
|
||||
extensions/front/onnx/softmax_ext.py
|
||||
extensions/front/onnx/softmaxONNX_to_softmax.py
|
||||
@ -363,6 +364,7 @@ extensions/front/reshape_dim_normalizer.py
|
||||
extensions/front/restore_ports.py
|
||||
extensions/front/RollWithEmptyAxesReplacer.py
|
||||
extensions/front/scatter_normalizer.py
|
||||
extensions/front/SizeReplacer.py
|
||||
extensions/front/softmax.py
|
||||
extensions/front/Softplus_fusion.py
|
||||
extensions/front/softsign_replacer.py
|
||||
@ -483,7 +485,6 @@ extensions/front/tf/roll_ext.py
|
||||
extensions/front/tf/RollRealImagPack.py
|
||||
extensions/front/tf/select_ext.py
|
||||
extensions/front/tf/sign_ext.py
|
||||
extensions/front/tf/SizeReplacer.py
|
||||
extensions/front/tf/slice_ext.py
|
||||
extensions/front/tf/softmax_ext.py
|
||||
extensions/front/tf/softplus_ext.py
|
||||
|
17
model-optimizer/extensions/front/onnx/size_ext.py
Normal file
17
model-optimizer/extensions/front/onnx/size_ext.py
Normal file
@ -0,0 +1,17 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.ops.size import Size
|
||||
from mo.front.extractor import FrontExtractorOp
|
||||
|
||||
|
||||
class SizeExtractor(FrontExtractorOp):
|
||||
op = 'Size'
|
||||
enabled = True
|
||||
|
||||
@classmethod
|
||||
def extract(cls, node):
|
||||
Size.update_node_stat(node, {'output_type': np.int64})
|
||||
return cls.enabled
|
@ -9,6 +9,7 @@ from mo.ops.op import Op
|
||||
|
||||
class Size(Op):
|
||||
op = 'Size'
|
||||
enabled = False
|
||||
|
||||
def __init__(self, graph: Graph, attrs: dict):
|
||||
assert 'output_type' in attrs, 'Size has mandatory `output_type` attribute'
|
||||
|
@ -6,7 +6,7 @@ import unittest
|
||||
import numpy as np
|
||||
from generator import generator, generate
|
||||
|
||||
from extensions.front.tf.SizeReplacer import SizeFrontReplacer
|
||||
from extensions.front.SizeReplacer import SizeFrontReplacer
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from unit_tests.utils.graph import build_graph, regular_op_with_empty_data, result, connect, \
|
Loading…
Reference in New Issue
Block a user