[ MO ] Support MXNet operation arange_like (#8939)
* arange_like_op * added comments * added unittests * added step attr, changed axis condition, updated tests * added op description * fix nodes renaming * sorted imports * added case with repeat > 1 * finished arange_like, removed unit test * small fix in gather infer function * gather fix * fix doc * added unittests * correct renames * removed ConvertLike from div_sqrt_dim * used ReduceProd instead reshape-shapeof * added keep_dims attr to reduce_prod node
This commit is contained in:
parent
4ecab1eeea
commit
bc70b2b68b
@ -53,6 +53,7 @@
|
||||
| Symbol Name in MXNet\*| Limitations|
|
||||
| :----------| :----------|
|
||||
| _Plus | |
|
||||
| _contrib_arange_like | |
|
||||
| _contrib_box_nms | |
|
||||
| _contrib_DeformableConvolution | |
|
||||
| _contrib_DeformablePSROIPooling | |
|
||||
|
@ -268,6 +268,8 @@ openvino/tools/mo/front/mxnet/activation.py
|
||||
openvino/tools/mo/front/mxnet/adaptive_avg_pooling_ext.py
|
||||
openvino/tools/mo/front/mxnet/add_input_data_to_prior_boxes.py
|
||||
openvino/tools/mo/front/mxnet/arange_ext.py
|
||||
openvino/tools/mo/front/mxnet/arange_like_ext.py
|
||||
openvino/tools/mo/front/mxnet/arange_like_replacer.py
|
||||
openvino/tools/mo/front/mxnet/arange_replacer.py
|
||||
openvino/tools/mo/front/mxnet/batch_dot_ext.py
|
||||
openvino/tools/mo/front/mxnet/block_grad_ext.py
|
||||
@ -847,6 +849,7 @@ openvino/tools/mo/ops/__init__.py
|
||||
openvino/tools/mo/ops/activation.py
|
||||
openvino/tools/mo/ops/activation_ops.py
|
||||
openvino/tools/mo/ops/adaptive_avg_pooling.py
|
||||
openvino/tools/mo/ops/arange_like.py
|
||||
openvino/tools/mo/ops/argmax.py
|
||||
openvino/tools/mo/ops/argmin.py
|
||||
openvino/tools/mo/ops/assert_op.py
|
||||
|
25
tools/mo/openvino/tools/mo/front/mxnet/arange_like_ext.py
Normal file
25
tools/mo/openvino/tools/mo/front/mxnet/arange_like_ext.py
Normal file
@ -0,0 +1,25 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
|
||||
from openvino.tools.mo.front.extractor import FrontExtractorOp
|
||||
from openvino.tools.mo.front.mxnet.extractors.utils import get_mxnet_layer_attrs
|
||||
from openvino.tools.mo.graph.graph import Node
|
||||
from openvino.tools.mo.ops.arange_like import ArangeLikeOp
|
||||
|
||||
|
||||
class ArangeLikeExt(FrontExtractorOp):
|
||||
op = '_contrib_arange_like'
|
||||
enabled = True
|
||||
|
||||
@classmethod
|
||||
def extract(cls, node: Node):
|
||||
attrs = get_mxnet_layer_attrs(node.symbol_dict)
|
||||
ArangeLikeOp.update_node_stat(node, {
|
||||
'start': attrs.float('start', 0),
|
||||
'repeat': attrs.int('repeat', 1),
|
||||
'step': attrs.float('step', 1),
|
||||
'axis': attrs.int('axis', None),
|
||||
})
|
||||
return cls.enabled
|
148
tools/mo/openvino/tools/mo/front/mxnet/arange_like_replacer.py
Normal file
148
tools/mo/openvino/tools/mo/front/mxnet/arange_like_replacer.py
Normal file
@ -0,0 +1,148 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import numpy as np
|
||||
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import int64_array, mo_array
|
||||
from openvino.tools.mo.front.common.replacement import FrontReplacementOp
|
||||
from openvino.tools.mo.front.tf.graph_utils import create_op_with_const_inputs
|
||||
from openvino.tools.mo.graph.graph import Graph, rename_nodes
|
||||
from openvino.tools.mo.ops.Cast import Cast
|
||||
from openvino.tools.mo.ops.ReduceOps import ReduceProd
|
||||
from openvino.tools.mo.ops.elementwise import Add, Div, Mul
|
||||
from openvino.tools.mo.ops.gather import Gather
|
||||
from openvino.tools.mo.ops.range import Range
|
||||
from openvino.tools.mo.ops.reshape import Reshape
|
||||
from openvino.tools.mo.ops.shape import Shape
|
||||
from openvino.tools.mo.ops.slice import Slice
|
||||
from openvino.tools.mo.ops.squeeze import Squeeze
|
||||
from openvino.tools.mo.ops.tile import Tile
|
||||
from openvino.tools.mo.utils.error import Error
|
||||
|
||||
|
||||
class ArangeLikeReplacer(FrontReplacementOp):
|
||||
op = 'arange_like'
|
||||
enabled = True
|
||||
|
||||
def replace_sub_graph(self, graph: Graph, match: dict):
|
||||
node = match['op']
|
||||
name = node.soft_get('name', node.id)
|
||||
axis = node.axis
|
||||
input_shape_node = Shape(graph, {'name': name + '/ShapeOf'}).create_node()
|
||||
range_node = create_op_with_const_inputs(graph, Range, {0: mo_array(node.start),
|
||||
2: mo_array(node.step)}, {'name': name + '/Range'})
|
||||
node.in_port(0).get_connection().set_destination(input_shape_node.in_port(0))
|
||||
|
||||
if axis is not None:
|
||||
'''
|
||||
Replace arange_like op to subgraph:
|
||||
Shape - Gather - Range
|
||||
'''
|
||||
gather_node = create_op_with_const_inputs(graph, Gather, {1: int64_array([axis]),
|
||||
2: int64_array(0)},
|
||||
{'name': name + '/Gather'})
|
||||
input_shape_node.out_port(0).connect(gather_node.in_port(0))
|
||||
gather_node.out_port(0).connect(range_node.in_port(1))
|
||||
node.out_port(0).get_connection().set_source(range_node.out_port(0))
|
||||
rename_nodes([(node, name + '/ShouldBeDeleted'), (range_node, name)])
|
||||
else:
|
||||
r'''
|
||||
Replace arange_like op to subgraph:
|
||||
|
|
||||
ShapeOf ----------- |
|
||||
| |
|
||||
ReduceProd |
|
||||
| |
|
||||
Range |
|
||||
| |
|
||||
Reshape ----------- |
|
||||
|
|
||||
'''
|
||||
|
||||
flattened_shape_node = create_op_with_const_inputs(graph, ReduceProd, {1: int64_array([0])},
|
||||
{'name': input_shape_node.name + '/ReduceProd',
|
||||
'keep_dims': True})
|
||||
reshape_backward_node = Reshape(graph, {'name': name + '/Reshape_backward'}).create_node()
|
||||
|
||||
input_shape_node.out_port(0).connect(flattened_shape_node.in_port(0))
|
||||
flattened_shape_node.out_port(0).connect(range_node.in_port(1))
|
||||
range_node.out_port(0).connect(reshape_backward_node.in_port(0))
|
||||
input_shape_node.out_port(0).connect(reshape_backward_node.in_port(1))
|
||||
node.out_port(0).get_connection().set_source(reshape_backward_node.out_port(0))
|
||||
rename_nodes([(node, name + '/ShouldBeDeleted'), (reshape_backward_node, name)])
|
||||
|
||||
if node.repeat != 1:
|
||||
r"""
|
||||
First, we generate the correct stop value for Range like new_stop_value = stop_value // repeat + 1.
|
||||
Then repeats each value of the interval using Tile. After that we can get a longer interval
|
||||
so we reduce it with Slice.
|
||||
|
||||
Sub-graph after Range node will be look like
|
||||
|
||||
Range - Reshape([-1, 1]) - Tile([1, repeat]) - Reshape(-1) - Slice
|
||||
|
||||
"""
|
||||
|
||||
if node.repeat < 1:
|
||||
raise Error("Unexpected value {} of the attribute 'repeat' for the node {}". format(node.repeat, name))
|
||||
|
||||
div_node = create_op_with_const_inputs(graph, Div, {1: int64_array([node.repeat])},
|
||||
{'name': name + '/Divide'})
|
||||
add_node = create_op_with_const_inputs(graph, Add, {1: int64_array([1])},
|
||||
{'name': div_node.name + '/Add'})
|
||||
cast_node = Cast(graph, {'name': name + '/ConvertToI64', 'dst_type': np.int64}).create_node()
|
||||
|
||||
cast_node.out_port(0).connect(div_node.in_port(0))
|
||||
div_node.out_port(0).connect(add_node.in_port(0))
|
||||
range_node.in_port(1).get_connection().set_destination(cast_node.in_port(0))
|
||||
add_node.out_port(0).connect(range_node.in_port(1))
|
||||
|
||||
tile_forward_reshape = create_op_with_const_inputs(graph, Reshape, {1: int64_array([-1, 1])},
|
||||
{'name': range_node.name + '/ForwardReshape'})
|
||||
tile = create_op_with_const_inputs(graph, Tile, {1: int64_array([1, node.repeat])},
|
||||
{'name': tile_forward_reshape.name + '/Tile'})
|
||||
tile_backward_reshape = create_op_with_const_inputs(graph, Reshape, {1: int64_array([-1])},
|
||||
{'name': tile.name + '/BackwardReshape'})
|
||||
slice_node = create_op_with_const_inputs(graph, Slice, {1: int64_array([0]), 3: int64_array([0]),
|
||||
4: int64_array([1])},
|
||||
{'name': tile_backward_reshape.name + '/Slice'})
|
||||
|
||||
tile_forward_reshape.out_port(0).connect(tile.in_port(0))
|
||||
tile.out_port(0).connect(tile_backward_reshape.in_port(0))
|
||||
tile_backward_reshape.out_port(0).connect(slice_node.in_port(0))
|
||||
slice_node.in_port(2).connect(div_node.in_port(0).get_source())
|
||||
|
||||
range_node.out_port(0).get_connection().set_source(slice_node.out_port(0))
|
||||
range_node.out_port(0).connect(tile_forward_reshape.in_port(0))
|
||||
|
||||
if axis is not None:
|
||||
rename_nodes([(range_node, name + '/Range'), (slice_node, name)])
|
||||
|
||||
# MXNet arange_like op has no stop attribute and the result tensor always matches the input shape, so
|
||||
# we have to correct the stop value for the Range node if step != 1 or start != 0
|
||||
if node.step != 1:
|
||||
# If step attribute is not integer, we will generate an interval with a larger size and then reduce it
|
||||
# using Slice
|
||||
true_elements_count_port = range_node.in_port(1).get_source()
|
||||
mul_value = np.ceil(node.step) if node.step > 0 else np.floor(node.step)
|
||||
stop_value = create_op_with_const_inputs(graph, Mul, port_value_dict={1: mo_array(np.ceil(mul_value))},
|
||||
op_attrs={'name': range_node.name + '/Stop'})
|
||||
range_node.in_port(1).get_connection().insert_node(stop_value)
|
||||
|
||||
slice_range_values = create_op_with_const_inputs(graph, Slice, {1: int64_array([0]), 3: int64_array([0]),
|
||||
4: int64_array([1])},
|
||||
{'name': range_node.name + '/Slice'})
|
||||
slice_range_values.in_port(2).connect(true_elements_count_port)
|
||||
range_node.out_port(0).get_connection().insert_node(slice_range_values)
|
||||
|
||||
if axis is not None and node.repeat == 1:
|
||||
rename_nodes([(range_node, name + '/Range'), (slice_range_values, name)])
|
||||
|
||||
if node.start != 0:
|
||||
correct_stop_value = create_op_with_const_inputs(graph, Add, port_value_dict={1: mo_array(node.start)},
|
||||
op_attrs={'name': range_node.name + '/Correct_Stop'})
|
||||
range_node.in_port(1).get_connection().insert_node(correct_stop_value)
|
||||
|
||||
# Range node supports only scalar inputs
|
||||
squeeze_node = create_op_with_const_inputs(graph, Squeeze, port_value_dict={1: int64_array(0)},
|
||||
op_attrs={"name": range_node.name + '/Stop/Squeeze'})
|
||||
range_node.in_port(1).get_connection().insert_node(squeeze_node)
|
@ -39,15 +39,12 @@ class DivSqrtDim(FrontReplacementOp):
|
||||
# Due to specification, Power must have inputs with the same data type.
|
||||
convert_pow_input = Cast(graph, dict(dst_type=np.float32,
|
||||
name=shape_values_node.name + '/ConvertToFP32')).create_node()
|
||||
convert_pow_output = ConvertLike(graph, dict(name=pow_node.name + 'ConvertLike')).create_node()
|
||||
div_node = Div(graph, dict(name="Div")).create_node()
|
||||
|
||||
shape_values_node.out_port(0).connect(convert_pow_input.in_port(0))
|
||||
convert_pow_input.out_port(0).connect(pow_node.in_port(0))
|
||||
div_sqrt.in_port(0).get_connection().set_destination(div_node.in_port(0))
|
||||
pow_node.out_port(0).connect(convert_pow_output.in_port(0))
|
||||
convert_pow_output.in_port(1).connect(data_out_port)
|
||||
div_node.in_port(1).connect(convert_pow_output.out_port(0))
|
||||
div_node.in_port(1).connect(pow_node.out_port(0))
|
||||
div_sqrt.out_port(0).get_connection().set_source(div_node.out_port(0))
|
||||
|
||||
rename_nodes([(div_sqrt, div_sqrt_name + '/ShouldBeDeleted'), (div_node, div_sqrt_name)])
|
||||
|
29
tools/mo/openvino/tools/mo/ops/arange_like.py
Normal file
29
tools/mo/openvino/tools/mo/ops/arange_like.py
Normal file
@ -0,0 +1,29 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from openvino.tools.mo.graph.graph import Graph
|
||||
from openvino.tools.mo.ops.op import Op
|
||||
|
||||
|
||||
class ArangeLikeOp(Op):
|
||||
"""
|
||||
MXNet operation which returns a sequence of numbers. If axis attribute is None, the output has the
|
||||
same shape as the input. Otherwise, the output is a 1D array with size of the specified axis.
|
||||
|
||||
Attributes:
|
||||
start - Start of interval
|
||||
step - Spacing between values
|
||||
repeat - The repeating time of all elements. Now we can support only default value (= 1)
|
||||
axis - Arange elements according to the size of a certain axis of input array. Defualt value is None
|
||||
|
||||
"""
|
||||
op = 'arange_like'
|
||||
|
||||
def __init__(self, graph: Graph, attrs: dict):
|
||||
mandatory_props = {
|
||||
'type': None,
|
||||
'op': self.op,
|
||||
'infer': None,
|
||||
'in_ports_count': 1,
|
||||
'out_ports_count': 1,
|
||||
}
|
||||
super().__init__(graph, mandatory_props, attrs)
|
@ -90,6 +90,7 @@ class Gather(Op):
|
||||
data_value = node.in_port(0).data.get_value()
|
||||
indices_value = node.in_port(1).data.get_value()
|
||||
if data_value is not None and indices_value is not None and is_fully_defined(indices_value):
|
||||
indices_value = int64_array(indices_value)
|
||||
if batch_dims == 0:
|
||||
node.out_port(0).data.set_value(np.ma.take(data_value, indices_value, axis))
|
||||
else:
|
||||
|
238
tools/mo/unit_tests/mo/front/mxnet/arange_like_test.py
Normal file
238
tools/mo/unit_tests/mo/front/mxnet/arange_like_test.py
Normal file
@ -0,0 +1,238 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import unittest
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import int64_array
|
||||
from openvino.tools.mo.front.mxnet.arange_like_replacer import ArangeLikeReplacer
|
||||
from openvino.tools.mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from unit_tests.utils.graph import build_graph, shaped_parameter, regular_op_with_empty_data, result, connect, \
|
||||
shaped_const_with_data, connect_data
|
||||
|
||||
|
||||
class ArangeLikeReplacerTest(unittest.TestCase):
|
||||
def test_axis_not_none_start_0(self):
|
||||
graph = build_graph(
|
||||
nodes_attrs={
|
||||
**shaped_parameter('input', int64_array([1, 3, 5, 5])),
|
||||
**regular_op_with_empty_data('arange_like', {'op': 'arange_like', 'type': None, 'axis': 3, 'repeat': 1,
|
||||
'start': 0, 'step': 1}),
|
||||
**result('result')
|
||||
},
|
||||
edges=[
|
||||
*connect('input', 'arange_like'),
|
||||
*connect('arange_like', 'result')
|
||||
]
|
||||
)
|
||||
ref_graph = build_graph(
|
||||
nodes_attrs={
|
||||
**shaped_parameter('input', int64_array([1, 3, 5, 5])),
|
||||
**regular_op_with_empty_data('shape_of', {'op': 'ShapeOf', 'type': 'ShapeOf'}),
|
||||
**shaped_const_with_data('gather_axis', None),
|
||||
**shaped_const_with_data('gather_indices', None),
|
||||
**regular_op_with_empty_data('gather', {'op': 'Gather', 'type': 'Gather'}),
|
||||
**shaped_const_with_data('range_start', None),
|
||||
**shaped_const_with_data('range_step', None),
|
||||
**shaped_const_with_data('squeeze_const', None),
|
||||
**regular_op_with_empty_data('squeeze', {'op': 'Squeeze', 'type': 'Squeeze'}),
|
||||
**regular_op_with_empty_data('range', {'op': 'Range', 'type': 'Range'}),
|
||||
**result('result')
|
||||
},
|
||||
edges=[
|
||||
*connect('input', 'shape_of'),
|
||||
*connect('shape_of', '0:gather'),
|
||||
*connect('gather_axis', '1:gather'),
|
||||
*connect('gather_indices', '2:gather'),
|
||||
*connect('range_start', '0:range'),
|
||||
*connect('gather', '0:squeeze'),
|
||||
*connect('squeeze_const', '1:squeeze'),
|
||||
*connect('squeeze', '1:range'),
|
||||
*connect('range_step', '2:range'),
|
||||
*connect('range', 'result')
|
||||
],
|
||||
update_attributes={
|
||||
'gather_axis': {'value': 3},
|
||||
'gather_indices': {'value': 0},
|
||||
'range_start': {'value': 0},
|
||||
'range_step': {'value': 1}
|
||||
}
|
||||
)
|
||||
ArangeLikeReplacer().find_and_replace_pattern(graph)
|
||||
flag, resp = compare_graphs(graph, ref_graph, 'result', 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_axis_not_none_start_1_step_2(self):
|
||||
graph = build_graph(
|
||||
nodes_attrs={
|
||||
**shaped_parameter('input', int64_array([1, 3, 5, 5])),
|
||||
**regular_op_with_empty_data('arange_like', {'op': 'arange_like', 'type': None, 'axis': 3, 'repeat': 1,
|
||||
'start': 1, 'step': 2}),
|
||||
**result('result')
|
||||
},
|
||||
edges=[
|
||||
*connect('input', 'arange_like'),
|
||||
*connect('arange_like', 'result')
|
||||
]
|
||||
)
|
||||
ref_graph = build_graph(
|
||||
nodes_attrs={
|
||||
**shaped_parameter('input', int64_array([1, 3, 5, 5])),
|
||||
**regular_op_with_empty_data('shape_of', {'op': 'ShapeOf', 'type': 'ShapeOf'}),
|
||||
**shaped_const_with_data('gather_axis', None),
|
||||
**shaped_const_with_data('gather_indices', None),
|
||||
**regular_op_with_empty_data('gather', {'op': 'Gather', 'type': 'Gather'}),
|
||||
**regular_op_with_empty_data('mul', {'op': 'Mul', 'type': 'Multiply'}),
|
||||
**shaped_const_with_data('mul_const', None),
|
||||
**shaped_const_with_data('range_start', None),
|
||||
**shaped_const_with_data('range_step', None),
|
||||
**shaped_const_with_data('add_const', None),
|
||||
**regular_op_with_empty_data('add', {'op': 'Add', 'type': 'Add'}),
|
||||
**shaped_const_with_data('squeeze_const', None),
|
||||
**regular_op_with_empty_data('squeeze', {'op': 'Squeeze', 'type': 'Squeeze'}),
|
||||
**regular_op_with_empty_data('range', {'op': 'Range', 'type': 'Range'}),
|
||||
**regular_op_with_empty_data('slice', {'op': 'Slice', 'type': None}),
|
||||
**shaped_const_with_data('slice_start', None),
|
||||
**shaped_const_with_data('slice_axes', None),
|
||||
**shaped_const_with_data('slice_step', None),
|
||||
**result('result')
|
||||
},
|
||||
edges=[
|
||||
*connect('input', 'shape_of'),
|
||||
*connect('shape_of', '0:gather'),
|
||||
*connect('gather_axis', '1:gather'),
|
||||
*connect('gather_indices', '2:gather'),
|
||||
*connect('range_start', '0:range'),
|
||||
*connect('gather', '0:mul'),
|
||||
*connect('mul_const', '1:mul'),
|
||||
*connect('mul', '0:add'),
|
||||
*connect('add_const', '1:add'),
|
||||
*connect('squeeze_const', '1:squeeze'),
|
||||
*connect('add', '0:squeeze'),
|
||||
*connect('squeeze', '1:range'),
|
||||
*connect('range_step', '2:range'),
|
||||
*connect('range', '0:slice'),
|
||||
*connect('slice_start', '1:slice'),
|
||||
*connect_data('gather', '2:slice'),
|
||||
*connect('slice_axes', '3:slice'),
|
||||
*connect('slice_step', '4:slice'),
|
||||
*connect('slice', 'result')
|
||||
],
|
||||
update_attributes={
|
||||
'gather_axis': {'value': 3},
|
||||
'gather_indices': {'value': 0},
|
||||
'range_start': {'value': 1},
|
||||
'range_step': {'value': 2},
|
||||
'add_const': {'value': 1},
|
||||
'mul_const': {'value': 2},
|
||||
'slice_start': {'value': int64_array([0])},
|
||||
'slice_axes': {'value': int64_array([0])},
|
||||
'slice_step': {'value': int64_array([1])},
|
||||
}
|
||||
)
|
||||
ArangeLikeReplacer().find_and_replace_pattern(graph)
|
||||
flag, resp = compare_graphs(graph, ref_graph, 'result', 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_axis_none_start_0(self):
|
||||
graph = build_graph(
|
||||
nodes_attrs={
|
||||
**shaped_parameter('input', int64_array([1, 3, 5, 5])),
|
||||
**regular_op_with_empty_data('arange_like', {'op': 'arange_like', 'type': None, 'axis': None,
|
||||
'repeat': 1, 'start': 0, 'step': 1}),
|
||||
**result('result')
|
||||
},
|
||||
edges=[
|
||||
*connect('input', 'arange_like'),
|
||||
*connect('arange_like', 'result')
|
||||
]
|
||||
)
|
||||
ref_graph = build_graph(
|
||||
nodes_attrs={
|
||||
**shaped_parameter('input', int64_array([1, 3, 5, 5])),
|
||||
**regular_op_with_empty_data('shape_of', {'op': 'ShapeOf', 'type': 'ShapeOf'}),
|
||||
**regular_op_with_empty_data('reduce_prod', {'op': 'ReduceProd', 'type': 'ReduceProd'}),
|
||||
**shaped_const_with_data('reduce_prod_const', None),
|
||||
**shaped_const_with_data('squeeze_const', None),
|
||||
**regular_op_with_empty_data('squeeze', {'op': 'Squeeze', 'type': 'Squeeze'}),
|
||||
**shaped_const_with_data('range_start', None),
|
||||
**shaped_const_with_data('range_step', None),
|
||||
**regular_op_with_empty_data('range', {'op': 'Range', 'type': 'Range'}),
|
||||
**regular_op_with_empty_data('reshape_backward', {'op': 'Reshape', 'type': 'Reshape'}),
|
||||
**result('result')
|
||||
},
|
||||
edges=[
|
||||
*connect('input', 'shape_of'),
|
||||
*connect('shape_of', '0:reduce_prod'),
|
||||
*connect('reduce_prod_const', '1:reduce_prod'),
|
||||
*connect('squeeze_const', '1:squeeze'),
|
||||
*connect('reduce_prod', '0:squeeze'),
|
||||
*connect('range_start', '0:range'),
|
||||
*connect('range_step', '2:range'),
|
||||
*connect('squeeze', '1:range'),
|
||||
*connect('range', '0:reshape_backward'),
|
||||
*connect_data('shape_of', '1:reshape_backward'),
|
||||
*connect('reshape_backward', 'result')
|
||||
],
|
||||
update_attributes={
|
||||
'range_start': {'value': 0},
|
||||
'range_step': {'value': 1},
|
||||
'reduce_prod_const': {'value': int64_array([0])}
|
||||
}
|
||||
)
|
||||
|
||||
ArangeLikeReplacer().find_and_replace_pattern(graph)
|
||||
flag, resp = compare_graphs(graph, ref_graph, 'result', 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_axis_none_start_1(self):
|
||||
graph = build_graph(
|
||||
nodes_attrs={
|
||||
**shaped_parameter('input', int64_array([1, 3, 5, 5])),
|
||||
**regular_op_with_empty_data('arange_like', {'op': 'arange_like', 'type': None, 'axis': None,
|
||||
'repeat': 1, 'start': 1, 'step': 1}),
|
||||
**result('result')
|
||||
},
|
||||
edges=[
|
||||
*connect('input', 'arange_like'),
|
||||
*connect('arange_like', 'result')
|
||||
]
|
||||
)
|
||||
ref_graph = build_graph(
|
||||
nodes_attrs={
|
||||
**shaped_parameter('input', int64_array([1, 3, 5, 5])),
|
||||
**regular_op_with_empty_data('shape_of', {'op': 'ShapeOf', 'type': 'ShapeOf'}),
|
||||
**regular_op_with_empty_data('reduce_prod', {'op': 'ReduceProd', 'type': 'ReduceProd'}),
|
||||
**shaped_const_with_data('reduce_prod_const', None),
|
||||
**shaped_const_with_data('squeeze_const', None),
|
||||
**regular_op_with_empty_data('squeeze', {'op': 'Squeeze', 'type': 'Squeeze'}),
|
||||
**shaped_const_with_data('add_const', None),
|
||||
**regular_op_with_empty_data('add', {'op': 'Add', 'type': 'Add'}),
|
||||
**shaped_const_with_data('range_start', None),
|
||||
**shaped_const_with_data('range_step', None),
|
||||
**regular_op_with_empty_data('range', {'op': 'Range', 'type': 'Range'}),
|
||||
**regular_op_with_empty_data('reshape_backward', {'op': 'Reshape', 'type': 'Reshape'}),
|
||||
**result('result')
|
||||
},
|
||||
edges=[
|
||||
*connect('input', 'shape_of'),
|
||||
*connect('shape_of', '0:reduce_prod'),
|
||||
*connect('reduce_prod_const', '1:reduce_prod'),
|
||||
*connect('squeeze_const', '1:squeeze'),
|
||||
*connect('add_const', '1:add'),
|
||||
*connect('reduce_prod', '0:add'),
|
||||
*connect('add', '0:squeeze'),
|
||||
*connect('range_start', '0:range'),
|
||||
*connect('range_step', '2:range'),
|
||||
*connect('squeeze', '1:range'),
|
||||
*connect('range', '0:reshape_backward'),
|
||||
*connect_data('shape_of', '1:reshape_backward'),
|
||||
*connect('reshape_backward', 'result')
|
||||
],
|
||||
update_attributes={
|
||||
'range_start': {'value': 1},
|
||||
'range_step': {'value': 1},
|
||||
'add_const': {'value': 1},
|
||||
'reduce_prod_const': {'value': int64_array([0])}
|
||||
}
|
||||
)
|
||||
ArangeLikeReplacer().find_and_replace_pattern(graph)
|
||||
flag, resp = compare_graphs(graph, ref_graph, 'result', 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
@ -36,7 +36,6 @@ class DivSqrtDimTest(unittest.TestCase):
|
||||
**regular_op_with_empty_data('gather', {'op': 'Gather', 'type': 'Gather'}),
|
||||
**regular_op_with_empty_data('power', {'op': 'AttributedPower', 'power': 0.5, 'type': 'Power'}),
|
||||
**regular_op_with_empty_data('cast', {'op': 'Cast', 'type': 'Convert', 'dst_type': np.float32}),
|
||||
**regular_op_with_empty_data('z_convert_like', {'op': 'ConvertLike', 'type': 'ConvertLike'}),
|
||||
**regular_op_with_empty_data('div', {'op': 'Div', 'type': 'Divide'}),
|
||||
**result('result')
|
||||
},
|
||||
@ -48,9 +47,7 @@ class DivSqrtDimTest(unittest.TestCase):
|
||||
*connect('gather_indices', '2:gather'),
|
||||
*connect('gather', 'cast'),
|
||||
*connect('cast', 'power'),
|
||||
*connect('power', '0:z_convert_like'),
|
||||
*connect_front('input_d', '1:z_convert_like'),
|
||||
*connect('z_convert_like', '1:div'),
|
||||
*connect('power', '1:div'),
|
||||
*connect('div', 'result')
|
||||
],
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user