[ MO ] Support MXNet operation arange_like (#8939)
* arange_like_op * added comments * added unittests * added step attr, changed axis condition, updated tests * added op description * fix nodes renaming * sorted imports * added case with repeat > 1 * finished arange_like, removed unit test * small fix in gather infer function * gather fix * fix doc * added unittests * correct renames * removed ConvertLike from div_sqrt_dim * used ReduceProd instead reshape-shapeof * added keep_dims attr to reduce_prod node
This commit is contained in:
parent
4ecab1eeea
commit
bc70b2b68b
@ -53,6 +53,7 @@
|
|||||||
| Symbol Name in MXNet\*| Limitations|
|
| Symbol Name in MXNet\*| Limitations|
|
||||||
| :----------| :----------|
|
| :----------| :----------|
|
||||||
| _Plus | |
|
| _Plus | |
|
||||||
|
| _contrib_arange_like | |
|
||||||
| _contrib_box_nms | |
|
| _contrib_box_nms | |
|
||||||
| _contrib_DeformableConvolution | |
|
| _contrib_DeformableConvolution | |
|
||||||
| _contrib_DeformablePSROIPooling | |
|
| _contrib_DeformablePSROIPooling | |
|
||||||
|
@ -268,6 +268,8 @@ openvino/tools/mo/front/mxnet/activation.py
|
|||||||
openvino/tools/mo/front/mxnet/adaptive_avg_pooling_ext.py
|
openvino/tools/mo/front/mxnet/adaptive_avg_pooling_ext.py
|
||||||
openvino/tools/mo/front/mxnet/add_input_data_to_prior_boxes.py
|
openvino/tools/mo/front/mxnet/add_input_data_to_prior_boxes.py
|
||||||
openvino/tools/mo/front/mxnet/arange_ext.py
|
openvino/tools/mo/front/mxnet/arange_ext.py
|
||||||
|
openvino/tools/mo/front/mxnet/arange_like_ext.py
|
||||||
|
openvino/tools/mo/front/mxnet/arange_like_replacer.py
|
||||||
openvino/tools/mo/front/mxnet/arange_replacer.py
|
openvino/tools/mo/front/mxnet/arange_replacer.py
|
||||||
openvino/tools/mo/front/mxnet/batch_dot_ext.py
|
openvino/tools/mo/front/mxnet/batch_dot_ext.py
|
||||||
openvino/tools/mo/front/mxnet/block_grad_ext.py
|
openvino/tools/mo/front/mxnet/block_grad_ext.py
|
||||||
@ -847,6 +849,7 @@ openvino/tools/mo/ops/__init__.py
|
|||||||
openvino/tools/mo/ops/activation.py
|
openvino/tools/mo/ops/activation.py
|
||||||
openvino/tools/mo/ops/activation_ops.py
|
openvino/tools/mo/ops/activation_ops.py
|
||||||
openvino/tools/mo/ops/adaptive_avg_pooling.py
|
openvino/tools/mo/ops/adaptive_avg_pooling.py
|
||||||
|
openvino/tools/mo/ops/arange_like.py
|
||||||
openvino/tools/mo/ops/argmax.py
|
openvino/tools/mo/ops/argmax.py
|
||||||
openvino/tools/mo/ops/argmin.py
|
openvino/tools/mo/ops/argmin.py
|
||||||
openvino/tools/mo/ops/assert_op.py
|
openvino/tools/mo/ops/assert_op.py
|
||||||
|
25
tools/mo/openvino/tools/mo/front/mxnet/arange_like_ext.py
Normal file
25
tools/mo/openvino/tools/mo/front/mxnet/arange_like_ext.py
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
# Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from openvino.tools.mo.front.extractor import FrontExtractorOp
|
||||||
|
from openvino.tools.mo.front.mxnet.extractors.utils import get_mxnet_layer_attrs
|
||||||
|
from openvino.tools.mo.graph.graph import Node
|
||||||
|
from openvino.tools.mo.ops.arange_like import ArangeLikeOp
|
||||||
|
|
||||||
|
|
||||||
|
class ArangeLikeExt(FrontExtractorOp):
|
||||||
|
op = '_contrib_arange_like'
|
||||||
|
enabled = True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def extract(cls, node: Node):
|
||||||
|
attrs = get_mxnet_layer_attrs(node.symbol_dict)
|
||||||
|
ArangeLikeOp.update_node_stat(node, {
|
||||||
|
'start': attrs.float('start', 0),
|
||||||
|
'repeat': attrs.int('repeat', 1),
|
||||||
|
'step': attrs.float('step', 1),
|
||||||
|
'axis': attrs.int('axis', None),
|
||||||
|
})
|
||||||
|
return cls.enabled
|
148
tools/mo/openvino/tools/mo/front/mxnet/arange_like_replacer.py
Normal file
148
tools/mo/openvino/tools/mo/front/mxnet/arange_like_replacer.py
Normal file
@ -0,0 +1,148 @@
|
|||||||
|
# Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from openvino.tools.mo.front.common.partial_infer.utils import int64_array, mo_array
|
||||||
|
from openvino.tools.mo.front.common.replacement import FrontReplacementOp
|
||||||
|
from openvino.tools.mo.front.tf.graph_utils import create_op_with_const_inputs
|
||||||
|
from openvino.tools.mo.graph.graph import Graph, rename_nodes
|
||||||
|
from openvino.tools.mo.ops.Cast import Cast
|
||||||
|
from openvino.tools.mo.ops.ReduceOps import ReduceProd
|
||||||
|
from openvino.tools.mo.ops.elementwise import Add, Div, Mul
|
||||||
|
from openvino.tools.mo.ops.gather import Gather
|
||||||
|
from openvino.tools.mo.ops.range import Range
|
||||||
|
from openvino.tools.mo.ops.reshape import Reshape
|
||||||
|
from openvino.tools.mo.ops.shape import Shape
|
||||||
|
from openvino.tools.mo.ops.slice import Slice
|
||||||
|
from openvino.tools.mo.ops.squeeze import Squeeze
|
||||||
|
from openvino.tools.mo.ops.tile import Tile
|
||||||
|
from openvino.tools.mo.utils.error import Error
|
||||||
|
|
||||||
|
|
||||||
|
class ArangeLikeReplacer(FrontReplacementOp):
|
||||||
|
op = 'arange_like'
|
||||||
|
enabled = True
|
||||||
|
|
||||||
|
def replace_sub_graph(self, graph: Graph, match: dict):
|
||||||
|
node = match['op']
|
||||||
|
name = node.soft_get('name', node.id)
|
||||||
|
axis = node.axis
|
||||||
|
input_shape_node = Shape(graph, {'name': name + '/ShapeOf'}).create_node()
|
||||||
|
range_node = create_op_with_const_inputs(graph, Range, {0: mo_array(node.start),
|
||||||
|
2: mo_array(node.step)}, {'name': name + '/Range'})
|
||||||
|
node.in_port(0).get_connection().set_destination(input_shape_node.in_port(0))
|
||||||
|
|
||||||
|
if axis is not None:
|
||||||
|
'''
|
||||||
|
Replace arange_like op to subgraph:
|
||||||
|
Shape - Gather - Range
|
||||||
|
'''
|
||||||
|
gather_node = create_op_with_const_inputs(graph, Gather, {1: int64_array([axis]),
|
||||||
|
2: int64_array(0)},
|
||||||
|
{'name': name + '/Gather'})
|
||||||
|
input_shape_node.out_port(0).connect(gather_node.in_port(0))
|
||||||
|
gather_node.out_port(0).connect(range_node.in_port(1))
|
||||||
|
node.out_port(0).get_connection().set_source(range_node.out_port(0))
|
||||||
|
rename_nodes([(node, name + '/ShouldBeDeleted'), (range_node, name)])
|
||||||
|
else:
|
||||||
|
r'''
|
||||||
|
Replace arange_like op to subgraph:
|
||||||
|
|
|
||||||
|
ShapeOf ----------- |
|
||||||
|
| |
|
||||||
|
ReduceProd |
|
||||||
|
| |
|
||||||
|
Range |
|
||||||
|
| |
|
||||||
|
Reshape ----------- |
|
||||||
|
|
|
||||||
|
'''
|
||||||
|
|
||||||
|
flattened_shape_node = create_op_with_const_inputs(graph, ReduceProd, {1: int64_array([0])},
|
||||||
|
{'name': input_shape_node.name + '/ReduceProd',
|
||||||
|
'keep_dims': True})
|
||||||
|
reshape_backward_node = Reshape(graph, {'name': name + '/Reshape_backward'}).create_node()
|
||||||
|
|
||||||
|
input_shape_node.out_port(0).connect(flattened_shape_node.in_port(0))
|
||||||
|
flattened_shape_node.out_port(0).connect(range_node.in_port(1))
|
||||||
|
range_node.out_port(0).connect(reshape_backward_node.in_port(0))
|
||||||
|
input_shape_node.out_port(0).connect(reshape_backward_node.in_port(1))
|
||||||
|
node.out_port(0).get_connection().set_source(reshape_backward_node.out_port(0))
|
||||||
|
rename_nodes([(node, name + '/ShouldBeDeleted'), (reshape_backward_node, name)])
|
||||||
|
|
||||||
|
if node.repeat != 1:
|
||||||
|
r"""
|
||||||
|
First, we generate the correct stop value for Range like new_stop_value = stop_value // repeat + 1.
|
||||||
|
Then repeats each value of the interval using Tile. After that we can get a longer interval
|
||||||
|
so we reduce it with Slice.
|
||||||
|
|
||||||
|
Sub-graph after Range node will be look like
|
||||||
|
|
||||||
|
Range - Reshape([-1, 1]) - Tile([1, repeat]) - Reshape(-1) - Slice
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
if node.repeat < 1:
|
||||||
|
raise Error("Unexpected value {} of the attribute 'repeat' for the node {}". format(node.repeat, name))
|
||||||
|
|
||||||
|
div_node = create_op_with_const_inputs(graph, Div, {1: int64_array([node.repeat])},
|
||||||
|
{'name': name + '/Divide'})
|
||||||
|
add_node = create_op_with_const_inputs(graph, Add, {1: int64_array([1])},
|
||||||
|
{'name': div_node.name + '/Add'})
|
||||||
|
cast_node = Cast(graph, {'name': name + '/ConvertToI64', 'dst_type': np.int64}).create_node()
|
||||||
|
|
||||||
|
cast_node.out_port(0).connect(div_node.in_port(0))
|
||||||
|
div_node.out_port(0).connect(add_node.in_port(0))
|
||||||
|
range_node.in_port(1).get_connection().set_destination(cast_node.in_port(0))
|
||||||
|
add_node.out_port(0).connect(range_node.in_port(1))
|
||||||
|
|
||||||
|
tile_forward_reshape = create_op_with_const_inputs(graph, Reshape, {1: int64_array([-1, 1])},
|
||||||
|
{'name': range_node.name + '/ForwardReshape'})
|
||||||
|
tile = create_op_with_const_inputs(graph, Tile, {1: int64_array([1, node.repeat])},
|
||||||
|
{'name': tile_forward_reshape.name + '/Tile'})
|
||||||
|
tile_backward_reshape = create_op_with_const_inputs(graph, Reshape, {1: int64_array([-1])},
|
||||||
|
{'name': tile.name + '/BackwardReshape'})
|
||||||
|
slice_node = create_op_with_const_inputs(graph, Slice, {1: int64_array([0]), 3: int64_array([0]),
|
||||||
|
4: int64_array([1])},
|
||||||
|
{'name': tile_backward_reshape.name + '/Slice'})
|
||||||
|
|
||||||
|
tile_forward_reshape.out_port(0).connect(tile.in_port(0))
|
||||||
|
tile.out_port(0).connect(tile_backward_reshape.in_port(0))
|
||||||
|
tile_backward_reshape.out_port(0).connect(slice_node.in_port(0))
|
||||||
|
slice_node.in_port(2).connect(div_node.in_port(0).get_source())
|
||||||
|
|
||||||
|
range_node.out_port(0).get_connection().set_source(slice_node.out_port(0))
|
||||||
|
range_node.out_port(0).connect(tile_forward_reshape.in_port(0))
|
||||||
|
|
||||||
|
if axis is not None:
|
||||||
|
rename_nodes([(range_node, name + '/Range'), (slice_node, name)])
|
||||||
|
|
||||||
|
# MXNet arange_like op has no stop attribute and the result tensor always matches the input shape, so
|
||||||
|
# we have to correct the stop value for the Range node if step != 1 or start != 0
|
||||||
|
if node.step != 1:
|
||||||
|
# If step attribute is not integer, we will generate an interval with a larger size and then reduce it
|
||||||
|
# using Slice
|
||||||
|
true_elements_count_port = range_node.in_port(1).get_source()
|
||||||
|
mul_value = np.ceil(node.step) if node.step > 0 else np.floor(node.step)
|
||||||
|
stop_value = create_op_with_const_inputs(graph, Mul, port_value_dict={1: mo_array(np.ceil(mul_value))},
|
||||||
|
op_attrs={'name': range_node.name + '/Stop'})
|
||||||
|
range_node.in_port(1).get_connection().insert_node(stop_value)
|
||||||
|
|
||||||
|
slice_range_values = create_op_with_const_inputs(graph, Slice, {1: int64_array([0]), 3: int64_array([0]),
|
||||||
|
4: int64_array([1])},
|
||||||
|
{'name': range_node.name + '/Slice'})
|
||||||
|
slice_range_values.in_port(2).connect(true_elements_count_port)
|
||||||
|
range_node.out_port(0).get_connection().insert_node(slice_range_values)
|
||||||
|
|
||||||
|
if axis is not None and node.repeat == 1:
|
||||||
|
rename_nodes([(range_node, name + '/Range'), (slice_range_values, name)])
|
||||||
|
|
||||||
|
if node.start != 0:
|
||||||
|
correct_stop_value = create_op_with_const_inputs(graph, Add, port_value_dict={1: mo_array(node.start)},
|
||||||
|
op_attrs={'name': range_node.name + '/Correct_Stop'})
|
||||||
|
range_node.in_port(1).get_connection().insert_node(correct_stop_value)
|
||||||
|
|
||||||
|
# Range node supports only scalar inputs
|
||||||
|
squeeze_node = create_op_with_const_inputs(graph, Squeeze, port_value_dict={1: int64_array(0)},
|
||||||
|
op_attrs={"name": range_node.name + '/Stop/Squeeze'})
|
||||||
|
range_node.in_port(1).get_connection().insert_node(squeeze_node)
|
@ -39,15 +39,12 @@ class DivSqrtDim(FrontReplacementOp):
|
|||||||
# Due to specification, Power must have inputs with the same data type.
|
# Due to specification, Power must have inputs with the same data type.
|
||||||
convert_pow_input = Cast(graph, dict(dst_type=np.float32,
|
convert_pow_input = Cast(graph, dict(dst_type=np.float32,
|
||||||
name=shape_values_node.name + '/ConvertToFP32')).create_node()
|
name=shape_values_node.name + '/ConvertToFP32')).create_node()
|
||||||
convert_pow_output = ConvertLike(graph, dict(name=pow_node.name + 'ConvertLike')).create_node()
|
|
||||||
div_node = Div(graph, dict(name="Div")).create_node()
|
div_node = Div(graph, dict(name="Div")).create_node()
|
||||||
|
|
||||||
shape_values_node.out_port(0).connect(convert_pow_input.in_port(0))
|
shape_values_node.out_port(0).connect(convert_pow_input.in_port(0))
|
||||||
convert_pow_input.out_port(0).connect(pow_node.in_port(0))
|
convert_pow_input.out_port(0).connect(pow_node.in_port(0))
|
||||||
div_sqrt.in_port(0).get_connection().set_destination(div_node.in_port(0))
|
div_sqrt.in_port(0).get_connection().set_destination(div_node.in_port(0))
|
||||||
pow_node.out_port(0).connect(convert_pow_output.in_port(0))
|
div_node.in_port(1).connect(pow_node.out_port(0))
|
||||||
convert_pow_output.in_port(1).connect(data_out_port)
|
|
||||||
div_node.in_port(1).connect(convert_pow_output.out_port(0))
|
|
||||||
div_sqrt.out_port(0).get_connection().set_source(div_node.out_port(0))
|
div_sqrt.out_port(0).get_connection().set_source(div_node.out_port(0))
|
||||||
|
|
||||||
rename_nodes([(div_sqrt, div_sqrt_name + '/ShouldBeDeleted'), (div_node, div_sqrt_name)])
|
rename_nodes([(div_sqrt, div_sqrt_name + '/ShouldBeDeleted'), (div_node, div_sqrt_name)])
|
||||||
|
29
tools/mo/openvino/tools/mo/ops/arange_like.py
Normal file
29
tools/mo/openvino/tools/mo/ops/arange_like.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
# Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
from openvino.tools.mo.graph.graph import Graph
|
||||||
|
from openvino.tools.mo.ops.op import Op
|
||||||
|
|
||||||
|
|
||||||
|
class ArangeLikeOp(Op):
|
||||||
|
"""
|
||||||
|
MXNet operation which returns a sequence of numbers. If axis attribute is None, the output has the
|
||||||
|
same shape as the input. Otherwise, the output is a 1D array with size of the specified axis.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
start - Start of interval
|
||||||
|
step - Spacing between values
|
||||||
|
repeat - The repeating time of all elements. Now we can support only default value (= 1)
|
||||||
|
axis - Arange elements according to the size of a certain axis of input array. Defualt value is None
|
||||||
|
|
||||||
|
"""
|
||||||
|
op = 'arange_like'
|
||||||
|
|
||||||
|
def __init__(self, graph: Graph, attrs: dict):
|
||||||
|
mandatory_props = {
|
||||||
|
'type': None,
|
||||||
|
'op': self.op,
|
||||||
|
'infer': None,
|
||||||
|
'in_ports_count': 1,
|
||||||
|
'out_ports_count': 1,
|
||||||
|
}
|
||||||
|
super().__init__(graph, mandatory_props, attrs)
|
@ -90,6 +90,7 @@ class Gather(Op):
|
|||||||
data_value = node.in_port(0).data.get_value()
|
data_value = node.in_port(0).data.get_value()
|
||||||
indices_value = node.in_port(1).data.get_value()
|
indices_value = node.in_port(1).data.get_value()
|
||||||
if data_value is not None and indices_value is not None and is_fully_defined(indices_value):
|
if data_value is not None and indices_value is not None and is_fully_defined(indices_value):
|
||||||
|
indices_value = int64_array(indices_value)
|
||||||
if batch_dims == 0:
|
if batch_dims == 0:
|
||||||
node.out_port(0).data.set_value(np.ma.take(data_value, indices_value, axis))
|
node.out_port(0).data.set_value(np.ma.take(data_value, indices_value, axis))
|
||||||
else:
|
else:
|
||||||
|
238
tools/mo/unit_tests/mo/front/mxnet/arange_like_test.py
Normal file
238
tools/mo/unit_tests/mo/front/mxnet/arange_like_test.py
Normal file
@ -0,0 +1,238 @@
|
|||||||
|
# Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
import unittest
|
||||||
|
from openvino.tools.mo.front.common.partial_infer.utils import int64_array
|
||||||
|
from openvino.tools.mo.front.mxnet.arange_like_replacer import ArangeLikeReplacer
|
||||||
|
from openvino.tools.mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||||
|
from unit_tests.utils.graph import build_graph, shaped_parameter, regular_op_with_empty_data, result, connect, \
|
||||||
|
shaped_const_with_data, connect_data
|
||||||
|
|
||||||
|
|
||||||
|
class ArangeLikeReplacerTest(unittest.TestCase):
|
||||||
|
def test_axis_not_none_start_0(self):
|
||||||
|
graph = build_graph(
|
||||||
|
nodes_attrs={
|
||||||
|
**shaped_parameter('input', int64_array([1, 3, 5, 5])),
|
||||||
|
**regular_op_with_empty_data('arange_like', {'op': 'arange_like', 'type': None, 'axis': 3, 'repeat': 1,
|
||||||
|
'start': 0, 'step': 1}),
|
||||||
|
**result('result')
|
||||||
|
},
|
||||||
|
edges=[
|
||||||
|
*connect('input', 'arange_like'),
|
||||||
|
*connect('arange_like', 'result')
|
||||||
|
]
|
||||||
|
)
|
||||||
|
ref_graph = build_graph(
|
||||||
|
nodes_attrs={
|
||||||
|
**shaped_parameter('input', int64_array([1, 3, 5, 5])),
|
||||||
|
**regular_op_with_empty_data('shape_of', {'op': 'ShapeOf', 'type': 'ShapeOf'}),
|
||||||
|
**shaped_const_with_data('gather_axis', None),
|
||||||
|
**shaped_const_with_data('gather_indices', None),
|
||||||
|
**regular_op_with_empty_data('gather', {'op': 'Gather', 'type': 'Gather'}),
|
||||||
|
**shaped_const_with_data('range_start', None),
|
||||||
|
**shaped_const_with_data('range_step', None),
|
||||||
|
**shaped_const_with_data('squeeze_const', None),
|
||||||
|
**regular_op_with_empty_data('squeeze', {'op': 'Squeeze', 'type': 'Squeeze'}),
|
||||||
|
**regular_op_with_empty_data('range', {'op': 'Range', 'type': 'Range'}),
|
||||||
|
**result('result')
|
||||||
|
},
|
||||||
|
edges=[
|
||||||
|
*connect('input', 'shape_of'),
|
||||||
|
*connect('shape_of', '0:gather'),
|
||||||
|
*connect('gather_axis', '1:gather'),
|
||||||
|
*connect('gather_indices', '2:gather'),
|
||||||
|
*connect('range_start', '0:range'),
|
||||||
|
*connect('gather', '0:squeeze'),
|
||||||
|
*connect('squeeze_const', '1:squeeze'),
|
||||||
|
*connect('squeeze', '1:range'),
|
||||||
|
*connect('range_step', '2:range'),
|
||||||
|
*connect('range', 'result')
|
||||||
|
],
|
||||||
|
update_attributes={
|
||||||
|
'gather_axis': {'value': 3},
|
||||||
|
'gather_indices': {'value': 0},
|
||||||
|
'range_start': {'value': 0},
|
||||||
|
'range_step': {'value': 1}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
ArangeLikeReplacer().find_and_replace_pattern(graph)
|
||||||
|
flag, resp = compare_graphs(graph, ref_graph, 'result', 'result', check_op_attrs=True)
|
||||||
|
self.assertTrue(flag, resp)
|
||||||
|
|
||||||
|
def test_axis_not_none_start_1_step_2(self):
|
||||||
|
graph = build_graph(
|
||||||
|
nodes_attrs={
|
||||||
|
**shaped_parameter('input', int64_array([1, 3, 5, 5])),
|
||||||
|
**regular_op_with_empty_data('arange_like', {'op': 'arange_like', 'type': None, 'axis': 3, 'repeat': 1,
|
||||||
|
'start': 1, 'step': 2}),
|
||||||
|
**result('result')
|
||||||
|
},
|
||||||
|
edges=[
|
||||||
|
*connect('input', 'arange_like'),
|
||||||
|
*connect('arange_like', 'result')
|
||||||
|
]
|
||||||
|
)
|
||||||
|
ref_graph = build_graph(
|
||||||
|
nodes_attrs={
|
||||||
|
**shaped_parameter('input', int64_array([1, 3, 5, 5])),
|
||||||
|
**regular_op_with_empty_data('shape_of', {'op': 'ShapeOf', 'type': 'ShapeOf'}),
|
||||||
|
**shaped_const_with_data('gather_axis', None),
|
||||||
|
**shaped_const_with_data('gather_indices', None),
|
||||||
|
**regular_op_with_empty_data('gather', {'op': 'Gather', 'type': 'Gather'}),
|
||||||
|
**regular_op_with_empty_data('mul', {'op': 'Mul', 'type': 'Multiply'}),
|
||||||
|
**shaped_const_with_data('mul_const', None),
|
||||||
|
**shaped_const_with_data('range_start', None),
|
||||||
|
**shaped_const_with_data('range_step', None),
|
||||||
|
**shaped_const_with_data('add_const', None),
|
||||||
|
**regular_op_with_empty_data('add', {'op': 'Add', 'type': 'Add'}),
|
||||||
|
**shaped_const_with_data('squeeze_const', None),
|
||||||
|
**regular_op_with_empty_data('squeeze', {'op': 'Squeeze', 'type': 'Squeeze'}),
|
||||||
|
**regular_op_with_empty_data('range', {'op': 'Range', 'type': 'Range'}),
|
||||||
|
**regular_op_with_empty_data('slice', {'op': 'Slice', 'type': None}),
|
||||||
|
**shaped_const_with_data('slice_start', None),
|
||||||
|
**shaped_const_with_data('slice_axes', None),
|
||||||
|
**shaped_const_with_data('slice_step', None),
|
||||||
|
**result('result')
|
||||||
|
},
|
||||||
|
edges=[
|
||||||
|
*connect('input', 'shape_of'),
|
||||||
|
*connect('shape_of', '0:gather'),
|
||||||
|
*connect('gather_axis', '1:gather'),
|
||||||
|
*connect('gather_indices', '2:gather'),
|
||||||
|
*connect('range_start', '0:range'),
|
||||||
|
*connect('gather', '0:mul'),
|
||||||
|
*connect('mul_const', '1:mul'),
|
||||||
|
*connect('mul', '0:add'),
|
||||||
|
*connect('add_const', '1:add'),
|
||||||
|
*connect('squeeze_const', '1:squeeze'),
|
||||||
|
*connect('add', '0:squeeze'),
|
||||||
|
*connect('squeeze', '1:range'),
|
||||||
|
*connect('range_step', '2:range'),
|
||||||
|
*connect('range', '0:slice'),
|
||||||
|
*connect('slice_start', '1:slice'),
|
||||||
|
*connect_data('gather', '2:slice'),
|
||||||
|
*connect('slice_axes', '3:slice'),
|
||||||
|
*connect('slice_step', '4:slice'),
|
||||||
|
*connect('slice', 'result')
|
||||||
|
],
|
||||||
|
update_attributes={
|
||||||
|
'gather_axis': {'value': 3},
|
||||||
|
'gather_indices': {'value': 0},
|
||||||
|
'range_start': {'value': 1},
|
||||||
|
'range_step': {'value': 2},
|
||||||
|
'add_const': {'value': 1},
|
||||||
|
'mul_const': {'value': 2},
|
||||||
|
'slice_start': {'value': int64_array([0])},
|
||||||
|
'slice_axes': {'value': int64_array([0])},
|
||||||
|
'slice_step': {'value': int64_array([1])},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
ArangeLikeReplacer().find_and_replace_pattern(graph)
|
||||||
|
flag, resp = compare_graphs(graph, ref_graph, 'result', 'result', check_op_attrs=True)
|
||||||
|
self.assertTrue(flag, resp)
|
||||||
|
|
||||||
|
def test_axis_none_start_0(self):
|
||||||
|
graph = build_graph(
|
||||||
|
nodes_attrs={
|
||||||
|
**shaped_parameter('input', int64_array([1, 3, 5, 5])),
|
||||||
|
**regular_op_with_empty_data('arange_like', {'op': 'arange_like', 'type': None, 'axis': None,
|
||||||
|
'repeat': 1, 'start': 0, 'step': 1}),
|
||||||
|
**result('result')
|
||||||
|
},
|
||||||
|
edges=[
|
||||||
|
*connect('input', 'arange_like'),
|
||||||
|
*connect('arange_like', 'result')
|
||||||
|
]
|
||||||
|
)
|
||||||
|
ref_graph = build_graph(
|
||||||
|
nodes_attrs={
|
||||||
|
**shaped_parameter('input', int64_array([1, 3, 5, 5])),
|
||||||
|
**regular_op_with_empty_data('shape_of', {'op': 'ShapeOf', 'type': 'ShapeOf'}),
|
||||||
|
**regular_op_with_empty_data('reduce_prod', {'op': 'ReduceProd', 'type': 'ReduceProd'}),
|
||||||
|
**shaped_const_with_data('reduce_prod_const', None),
|
||||||
|
**shaped_const_with_data('squeeze_const', None),
|
||||||
|
**regular_op_with_empty_data('squeeze', {'op': 'Squeeze', 'type': 'Squeeze'}),
|
||||||
|
**shaped_const_with_data('range_start', None),
|
||||||
|
**shaped_const_with_data('range_step', None),
|
||||||
|
**regular_op_with_empty_data('range', {'op': 'Range', 'type': 'Range'}),
|
||||||
|
**regular_op_with_empty_data('reshape_backward', {'op': 'Reshape', 'type': 'Reshape'}),
|
||||||
|
**result('result')
|
||||||
|
},
|
||||||
|
edges=[
|
||||||
|
*connect('input', 'shape_of'),
|
||||||
|
*connect('shape_of', '0:reduce_prod'),
|
||||||
|
*connect('reduce_prod_const', '1:reduce_prod'),
|
||||||
|
*connect('squeeze_const', '1:squeeze'),
|
||||||
|
*connect('reduce_prod', '0:squeeze'),
|
||||||
|
*connect('range_start', '0:range'),
|
||||||
|
*connect('range_step', '2:range'),
|
||||||
|
*connect('squeeze', '1:range'),
|
||||||
|
*connect('range', '0:reshape_backward'),
|
||||||
|
*connect_data('shape_of', '1:reshape_backward'),
|
||||||
|
*connect('reshape_backward', 'result')
|
||||||
|
],
|
||||||
|
update_attributes={
|
||||||
|
'range_start': {'value': 0},
|
||||||
|
'range_step': {'value': 1},
|
||||||
|
'reduce_prod_const': {'value': int64_array([0])}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
ArangeLikeReplacer().find_and_replace_pattern(graph)
|
||||||
|
flag, resp = compare_graphs(graph, ref_graph, 'result', 'result', check_op_attrs=True)
|
||||||
|
self.assertTrue(flag, resp)
|
||||||
|
|
||||||
|
def test_axis_none_start_1(self):
|
||||||
|
graph = build_graph(
|
||||||
|
nodes_attrs={
|
||||||
|
**shaped_parameter('input', int64_array([1, 3, 5, 5])),
|
||||||
|
**regular_op_with_empty_data('arange_like', {'op': 'arange_like', 'type': None, 'axis': None,
|
||||||
|
'repeat': 1, 'start': 1, 'step': 1}),
|
||||||
|
**result('result')
|
||||||
|
},
|
||||||
|
edges=[
|
||||||
|
*connect('input', 'arange_like'),
|
||||||
|
*connect('arange_like', 'result')
|
||||||
|
]
|
||||||
|
)
|
||||||
|
ref_graph = build_graph(
|
||||||
|
nodes_attrs={
|
||||||
|
**shaped_parameter('input', int64_array([1, 3, 5, 5])),
|
||||||
|
**regular_op_with_empty_data('shape_of', {'op': 'ShapeOf', 'type': 'ShapeOf'}),
|
||||||
|
**regular_op_with_empty_data('reduce_prod', {'op': 'ReduceProd', 'type': 'ReduceProd'}),
|
||||||
|
**shaped_const_with_data('reduce_prod_const', None),
|
||||||
|
**shaped_const_with_data('squeeze_const', None),
|
||||||
|
**regular_op_with_empty_data('squeeze', {'op': 'Squeeze', 'type': 'Squeeze'}),
|
||||||
|
**shaped_const_with_data('add_const', None),
|
||||||
|
**regular_op_with_empty_data('add', {'op': 'Add', 'type': 'Add'}),
|
||||||
|
**shaped_const_with_data('range_start', None),
|
||||||
|
**shaped_const_with_data('range_step', None),
|
||||||
|
**regular_op_with_empty_data('range', {'op': 'Range', 'type': 'Range'}),
|
||||||
|
**regular_op_with_empty_data('reshape_backward', {'op': 'Reshape', 'type': 'Reshape'}),
|
||||||
|
**result('result')
|
||||||
|
},
|
||||||
|
edges=[
|
||||||
|
*connect('input', 'shape_of'),
|
||||||
|
*connect('shape_of', '0:reduce_prod'),
|
||||||
|
*connect('reduce_prod_const', '1:reduce_prod'),
|
||||||
|
*connect('squeeze_const', '1:squeeze'),
|
||||||
|
*connect('add_const', '1:add'),
|
||||||
|
*connect('reduce_prod', '0:add'),
|
||||||
|
*connect('add', '0:squeeze'),
|
||||||
|
*connect('range_start', '0:range'),
|
||||||
|
*connect('range_step', '2:range'),
|
||||||
|
*connect('squeeze', '1:range'),
|
||||||
|
*connect('range', '0:reshape_backward'),
|
||||||
|
*connect_data('shape_of', '1:reshape_backward'),
|
||||||
|
*connect('reshape_backward', 'result')
|
||||||
|
],
|
||||||
|
update_attributes={
|
||||||
|
'range_start': {'value': 1},
|
||||||
|
'range_step': {'value': 1},
|
||||||
|
'add_const': {'value': 1},
|
||||||
|
'reduce_prod_const': {'value': int64_array([0])}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
ArangeLikeReplacer().find_and_replace_pattern(graph)
|
||||||
|
flag, resp = compare_graphs(graph, ref_graph, 'result', 'result', check_op_attrs=True)
|
||||||
|
self.assertTrue(flag, resp)
|
@ -36,7 +36,6 @@ class DivSqrtDimTest(unittest.TestCase):
|
|||||||
**regular_op_with_empty_data('gather', {'op': 'Gather', 'type': 'Gather'}),
|
**regular_op_with_empty_data('gather', {'op': 'Gather', 'type': 'Gather'}),
|
||||||
**regular_op_with_empty_data('power', {'op': 'AttributedPower', 'power': 0.5, 'type': 'Power'}),
|
**regular_op_with_empty_data('power', {'op': 'AttributedPower', 'power': 0.5, 'type': 'Power'}),
|
||||||
**regular_op_with_empty_data('cast', {'op': 'Cast', 'type': 'Convert', 'dst_type': np.float32}),
|
**regular_op_with_empty_data('cast', {'op': 'Cast', 'type': 'Convert', 'dst_type': np.float32}),
|
||||||
**regular_op_with_empty_data('z_convert_like', {'op': 'ConvertLike', 'type': 'ConvertLike'}),
|
|
||||||
**regular_op_with_empty_data('div', {'op': 'Div', 'type': 'Divide'}),
|
**regular_op_with_empty_data('div', {'op': 'Div', 'type': 'Divide'}),
|
||||||
**result('result')
|
**result('result')
|
||||||
},
|
},
|
||||||
@ -48,9 +47,7 @@ class DivSqrtDimTest(unittest.TestCase):
|
|||||||
*connect('gather_indices', '2:gather'),
|
*connect('gather_indices', '2:gather'),
|
||||||
*connect('gather', 'cast'),
|
*connect('gather', 'cast'),
|
||||||
*connect('cast', 'power'),
|
*connect('cast', 'power'),
|
||||||
*connect('power', '0:z_convert_like'),
|
*connect('power', '1:div'),
|
||||||
*connect_front('input_d', '1:z_convert_like'),
|
|
||||||
*connect('z_convert_like', '1:div'),
|
|
||||||
*connect('div', 'result')
|
*connect('div', 'result')
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user