[ MO ] Support MXNet operation arange_like (#8939)

* arange_like_op

* added comments

* added unittests

* added step attr, changed axis condition, updated tests

* added op description

* fix nodes renaming

* sorted imports

* added case with repeat > 1

* finished arange_like, removed unit test

* small fix in gather infer function

* gather fix

* fix doc

* added unittests

* correct renames

* removed ConvertLike from div_sqrt_dim

* used ReduceProd instead reshape-shapeof

* added keep_dims attr to reduce_prod node
This commit is contained in:
Yegor Kruglov 2022-01-27 12:44:12 +03:00 committed by GitHub
parent 4ecab1eeea
commit bc70b2b68b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 447 additions and 8 deletions

View File

@ -53,6 +53,7 @@
| Symbol Name in MXNet\*| Limitations| | Symbol Name in MXNet\*| Limitations|
| :----------| :----------| | :----------| :----------|
| _Plus | | | _Plus | |
| _contrib_arange_like | |
| _contrib_box_nms | | | _contrib_box_nms | |
| _contrib_DeformableConvolution | | | _contrib_DeformableConvolution | |
| _contrib_DeformablePSROIPooling | | | _contrib_DeformablePSROIPooling | |

View File

@ -268,6 +268,8 @@ openvino/tools/mo/front/mxnet/activation.py
openvino/tools/mo/front/mxnet/adaptive_avg_pooling_ext.py openvino/tools/mo/front/mxnet/adaptive_avg_pooling_ext.py
openvino/tools/mo/front/mxnet/add_input_data_to_prior_boxes.py openvino/tools/mo/front/mxnet/add_input_data_to_prior_boxes.py
openvino/tools/mo/front/mxnet/arange_ext.py openvino/tools/mo/front/mxnet/arange_ext.py
openvino/tools/mo/front/mxnet/arange_like_ext.py
openvino/tools/mo/front/mxnet/arange_like_replacer.py
openvino/tools/mo/front/mxnet/arange_replacer.py openvino/tools/mo/front/mxnet/arange_replacer.py
openvino/tools/mo/front/mxnet/batch_dot_ext.py openvino/tools/mo/front/mxnet/batch_dot_ext.py
openvino/tools/mo/front/mxnet/block_grad_ext.py openvino/tools/mo/front/mxnet/block_grad_ext.py
@ -847,6 +849,7 @@ openvino/tools/mo/ops/__init__.py
openvino/tools/mo/ops/activation.py openvino/tools/mo/ops/activation.py
openvino/tools/mo/ops/activation_ops.py openvino/tools/mo/ops/activation_ops.py
openvino/tools/mo/ops/adaptive_avg_pooling.py openvino/tools/mo/ops/adaptive_avg_pooling.py
openvino/tools/mo/ops/arange_like.py
openvino/tools/mo/ops/argmax.py openvino/tools/mo/ops/argmax.py
openvino/tools/mo/ops/argmin.py openvino/tools/mo/ops/argmin.py
openvino/tools/mo/ops/assert_op.py openvino/tools/mo/ops/assert_op.py

View File

@ -0,0 +1,25 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
from openvino.tools.mo.front.extractor import FrontExtractorOp
from openvino.tools.mo.front.mxnet.extractors.utils import get_mxnet_layer_attrs
from openvino.tools.mo.graph.graph import Node
from openvino.tools.mo.ops.arange_like import ArangeLikeOp
class ArangeLikeExt(FrontExtractorOp):
op = '_contrib_arange_like'
enabled = True
@classmethod
def extract(cls, node: Node):
attrs = get_mxnet_layer_attrs(node.symbol_dict)
ArangeLikeOp.update_node_stat(node, {
'start': attrs.float('start', 0),
'repeat': attrs.int('repeat', 1),
'step': attrs.float('step', 1),
'axis': attrs.int('axis', None),
})
return cls.enabled

View File

@ -0,0 +1,148 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
from openvino.tools.mo.front.common.partial_infer.utils import int64_array, mo_array
from openvino.tools.mo.front.common.replacement import FrontReplacementOp
from openvino.tools.mo.front.tf.graph_utils import create_op_with_const_inputs
from openvino.tools.mo.graph.graph import Graph, rename_nodes
from openvino.tools.mo.ops.Cast import Cast
from openvino.tools.mo.ops.ReduceOps import ReduceProd
from openvino.tools.mo.ops.elementwise import Add, Div, Mul
from openvino.tools.mo.ops.gather import Gather
from openvino.tools.mo.ops.range import Range
from openvino.tools.mo.ops.reshape import Reshape
from openvino.tools.mo.ops.shape import Shape
from openvino.tools.mo.ops.slice import Slice
from openvino.tools.mo.ops.squeeze import Squeeze
from openvino.tools.mo.ops.tile import Tile
from openvino.tools.mo.utils.error import Error
class ArangeLikeReplacer(FrontReplacementOp):
op = 'arange_like'
enabled = True
def replace_sub_graph(self, graph: Graph, match: dict):
node = match['op']
name = node.soft_get('name', node.id)
axis = node.axis
input_shape_node = Shape(graph, {'name': name + '/ShapeOf'}).create_node()
range_node = create_op_with_const_inputs(graph, Range, {0: mo_array(node.start),
2: mo_array(node.step)}, {'name': name + '/Range'})
node.in_port(0).get_connection().set_destination(input_shape_node.in_port(0))
if axis is not None:
'''
Replace arange_like op to subgraph:
Shape - Gather - Range
'''
gather_node = create_op_with_const_inputs(graph, Gather, {1: int64_array([axis]),
2: int64_array(0)},
{'name': name + '/Gather'})
input_shape_node.out_port(0).connect(gather_node.in_port(0))
gather_node.out_port(0).connect(range_node.in_port(1))
node.out_port(0).get_connection().set_source(range_node.out_port(0))
rename_nodes([(node, name + '/ShouldBeDeleted'), (range_node, name)])
else:
r'''
Replace arange_like op to subgraph:
|
ShapeOf ----------- |
| |
ReduceProd |
| |
Range |
| |
Reshape ----------- |
|
'''
flattened_shape_node = create_op_with_const_inputs(graph, ReduceProd, {1: int64_array([0])},
{'name': input_shape_node.name + '/ReduceProd',
'keep_dims': True})
reshape_backward_node = Reshape(graph, {'name': name + '/Reshape_backward'}).create_node()
input_shape_node.out_port(0).connect(flattened_shape_node.in_port(0))
flattened_shape_node.out_port(0).connect(range_node.in_port(1))
range_node.out_port(0).connect(reshape_backward_node.in_port(0))
input_shape_node.out_port(0).connect(reshape_backward_node.in_port(1))
node.out_port(0).get_connection().set_source(reshape_backward_node.out_port(0))
rename_nodes([(node, name + '/ShouldBeDeleted'), (reshape_backward_node, name)])
if node.repeat != 1:
r"""
First, we generate the correct stop value for Range like new_stop_value = stop_value // repeat + 1.
Then repeats each value of the interval using Tile. After that we can get a longer interval
so we reduce it with Slice.
Sub-graph after Range node will be look like
Range - Reshape([-1, 1]) - Tile([1, repeat]) - Reshape(-1) - Slice
"""
if node.repeat < 1:
raise Error("Unexpected value {} of the attribute 'repeat' for the node {}". format(node.repeat, name))
div_node = create_op_with_const_inputs(graph, Div, {1: int64_array([node.repeat])},
{'name': name + '/Divide'})
add_node = create_op_with_const_inputs(graph, Add, {1: int64_array([1])},
{'name': div_node.name + '/Add'})
cast_node = Cast(graph, {'name': name + '/ConvertToI64', 'dst_type': np.int64}).create_node()
cast_node.out_port(0).connect(div_node.in_port(0))
div_node.out_port(0).connect(add_node.in_port(0))
range_node.in_port(1).get_connection().set_destination(cast_node.in_port(0))
add_node.out_port(0).connect(range_node.in_port(1))
tile_forward_reshape = create_op_with_const_inputs(graph, Reshape, {1: int64_array([-1, 1])},
{'name': range_node.name + '/ForwardReshape'})
tile = create_op_with_const_inputs(graph, Tile, {1: int64_array([1, node.repeat])},
{'name': tile_forward_reshape.name + '/Tile'})
tile_backward_reshape = create_op_with_const_inputs(graph, Reshape, {1: int64_array([-1])},
{'name': tile.name + '/BackwardReshape'})
slice_node = create_op_with_const_inputs(graph, Slice, {1: int64_array([0]), 3: int64_array([0]),
4: int64_array([1])},
{'name': tile_backward_reshape.name + '/Slice'})
tile_forward_reshape.out_port(0).connect(tile.in_port(0))
tile.out_port(0).connect(tile_backward_reshape.in_port(0))
tile_backward_reshape.out_port(0).connect(slice_node.in_port(0))
slice_node.in_port(2).connect(div_node.in_port(0).get_source())
range_node.out_port(0).get_connection().set_source(slice_node.out_port(0))
range_node.out_port(0).connect(tile_forward_reshape.in_port(0))
if axis is not None:
rename_nodes([(range_node, name + '/Range'), (slice_node, name)])
# MXNet arange_like op has no stop attribute and the result tensor always matches the input shape, so
# we have to correct the stop value for the Range node if step != 1 or start != 0
if node.step != 1:
# If step attribute is not integer, we will generate an interval with a larger size and then reduce it
# using Slice
true_elements_count_port = range_node.in_port(1).get_source()
mul_value = np.ceil(node.step) if node.step > 0 else np.floor(node.step)
stop_value = create_op_with_const_inputs(graph, Mul, port_value_dict={1: mo_array(np.ceil(mul_value))},
op_attrs={'name': range_node.name + '/Stop'})
range_node.in_port(1).get_connection().insert_node(stop_value)
slice_range_values = create_op_with_const_inputs(graph, Slice, {1: int64_array([0]), 3: int64_array([0]),
4: int64_array([1])},
{'name': range_node.name + '/Slice'})
slice_range_values.in_port(2).connect(true_elements_count_port)
range_node.out_port(0).get_connection().insert_node(slice_range_values)
if axis is not None and node.repeat == 1:
rename_nodes([(range_node, name + '/Range'), (slice_range_values, name)])
if node.start != 0:
correct_stop_value = create_op_with_const_inputs(graph, Add, port_value_dict={1: mo_array(node.start)},
op_attrs={'name': range_node.name + '/Correct_Stop'})
range_node.in_port(1).get_connection().insert_node(correct_stop_value)
# Range node supports only scalar inputs
squeeze_node = create_op_with_const_inputs(graph, Squeeze, port_value_dict={1: int64_array(0)},
op_attrs={"name": range_node.name + '/Stop/Squeeze'})
range_node.in_port(1).get_connection().insert_node(squeeze_node)

View File

@ -39,15 +39,12 @@ class DivSqrtDim(FrontReplacementOp):
# Due to specification, Power must have inputs with the same data type. # Due to specification, Power must have inputs with the same data type.
convert_pow_input = Cast(graph, dict(dst_type=np.float32, convert_pow_input = Cast(graph, dict(dst_type=np.float32,
name=shape_values_node.name + '/ConvertToFP32')).create_node() name=shape_values_node.name + '/ConvertToFP32')).create_node()
convert_pow_output = ConvertLike(graph, dict(name=pow_node.name + 'ConvertLike')).create_node()
div_node = Div(graph, dict(name="Div")).create_node() div_node = Div(graph, dict(name="Div")).create_node()
shape_values_node.out_port(0).connect(convert_pow_input.in_port(0)) shape_values_node.out_port(0).connect(convert_pow_input.in_port(0))
convert_pow_input.out_port(0).connect(pow_node.in_port(0)) convert_pow_input.out_port(0).connect(pow_node.in_port(0))
div_sqrt.in_port(0).get_connection().set_destination(div_node.in_port(0)) div_sqrt.in_port(0).get_connection().set_destination(div_node.in_port(0))
pow_node.out_port(0).connect(convert_pow_output.in_port(0)) div_node.in_port(1).connect(pow_node.out_port(0))
convert_pow_output.in_port(1).connect(data_out_port)
div_node.in_port(1).connect(convert_pow_output.out_port(0))
div_sqrt.out_port(0).get_connection().set_source(div_node.out_port(0)) div_sqrt.out_port(0).get_connection().set_source(div_node.out_port(0))
rename_nodes([(div_sqrt, div_sqrt_name + '/ShouldBeDeleted'), (div_node, div_sqrt_name)]) rename_nodes([(div_sqrt, div_sqrt_name + '/ShouldBeDeleted'), (div_node, div_sqrt_name)])

View File

@ -0,0 +1,29 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from openvino.tools.mo.graph.graph import Graph
from openvino.tools.mo.ops.op import Op
class ArangeLikeOp(Op):
"""
MXNet operation which returns a sequence of numbers. If axis attribute is None, the output has the
same shape as the input. Otherwise, the output is a 1D array with size of the specified axis.
Attributes:
start - Start of interval
step - Spacing between values
repeat - The repeating time of all elements. Now we can support only default value (= 1)
axis - Arange elements according to the size of a certain axis of input array. Defualt value is None
"""
op = 'arange_like'
def __init__(self, graph: Graph, attrs: dict):
mandatory_props = {
'type': None,
'op': self.op,
'infer': None,
'in_ports_count': 1,
'out_ports_count': 1,
}
super().__init__(graph, mandatory_props, attrs)

View File

@ -90,6 +90,7 @@ class Gather(Op):
data_value = node.in_port(0).data.get_value() data_value = node.in_port(0).data.get_value()
indices_value = node.in_port(1).data.get_value() indices_value = node.in_port(1).data.get_value()
if data_value is not None and indices_value is not None and is_fully_defined(indices_value): if data_value is not None and indices_value is not None and is_fully_defined(indices_value):
indices_value = int64_array(indices_value)
if batch_dims == 0: if batch_dims == 0:
node.out_port(0).data.set_value(np.ma.take(data_value, indices_value, axis)) node.out_port(0).data.set_value(np.ma.take(data_value, indices_value, axis))
else: else:

View File

@ -0,0 +1,238 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import unittest
from openvino.tools.mo.front.common.partial_infer.utils import int64_array
from openvino.tools.mo.front.mxnet.arange_like_replacer import ArangeLikeReplacer
from openvino.tools.mo.utils.ir_engine.compare_graphs import compare_graphs
from unit_tests.utils.graph import build_graph, shaped_parameter, regular_op_with_empty_data, result, connect, \
shaped_const_with_data, connect_data
class ArangeLikeReplacerTest(unittest.TestCase):
def test_axis_not_none_start_0(self):
graph = build_graph(
nodes_attrs={
**shaped_parameter('input', int64_array([1, 3, 5, 5])),
**regular_op_with_empty_data('arange_like', {'op': 'arange_like', 'type': None, 'axis': 3, 'repeat': 1,
'start': 0, 'step': 1}),
**result('result')
},
edges=[
*connect('input', 'arange_like'),
*connect('arange_like', 'result')
]
)
ref_graph = build_graph(
nodes_attrs={
**shaped_parameter('input', int64_array([1, 3, 5, 5])),
**regular_op_with_empty_data('shape_of', {'op': 'ShapeOf', 'type': 'ShapeOf'}),
**shaped_const_with_data('gather_axis', None),
**shaped_const_with_data('gather_indices', None),
**regular_op_with_empty_data('gather', {'op': 'Gather', 'type': 'Gather'}),
**shaped_const_with_data('range_start', None),
**shaped_const_with_data('range_step', None),
**shaped_const_with_data('squeeze_const', None),
**regular_op_with_empty_data('squeeze', {'op': 'Squeeze', 'type': 'Squeeze'}),
**regular_op_with_empty_data('range', {'op': 'Range', 'type': 'Range'}),
**result('result')
},
edges=[
*connect('input', 'shape_of'),
*connect('shape_of', '0:gather'),
*connect('gather_axis', '1:gather'),
*connect('gather_indices', '2:gather'),
*connect('range_start', '0:range'),
*connect('gather', '0:squeeze'),
*connect('squeeze_const', '1:squeeze'),
*connect('squeeze', '1:range'),
*connect('range_step', '2:range'),
*connect('range', 'result')
],
update_attributes={
'gather_axis': {'value': 3},
'gather_indices': {'value': 0},
'range_start': {'value': 0},
'range_step': {'value': 1}
}
)
ArangeLikeReplacer().find_and_replace_pattern(graph)
flag, resp = compare_graphs(graph, ref_graph, 'result', 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
def test_axis_not_none_start_1_step_2(self):
graph = build_graph(
nodes_attrs={
**shaped_parameter('input', int64_array([1, 3, 5, 5])),
**regular_op_with_empty_data('arange_like', {'op': 'arange_like', 'type': None, 'axis': 3, 'repeat': 1,
'start': 1, 'step': 2}),
**result('result')
},
edges=[
*connect('input', 'arange_like'),
*connect('arange_like', 'result')
]
)
ref_graph = build_graph(
nodes_attrs={
**shaped_parameter('input', int64_array([1, 3, 5, 5])),
**regular_op_with_empty_data('shape_of', {'op': 'ShapeOf', 'type': 'ShapeOf'}),
**shaped_const_with_data('gather_axis', None),
**shaped_const_with_data('gather_indices', None),
**regular_op_with_empty_data('gather', {'op': 'Gather', 'type': 'Gather'}),
**regular_op_with_empty_data('mul', {'op': 'Mul', 'type': 'Multiply'}),
**shaped_const_with_data('mul_const', None),
**shaped_const_with_data('range_start', None),
**shaped_const_with_data('range_step', None),
**shaped_const_with_data('add_const', None),
**regular_op_with_empty_data('add', {'op': 'Add', 'type': 'Add'}),
**shaped_const_with_data('squeeze_const', None),
**regular_op_with_empty_data('squeeze', {'op': 'Squeeze', 'type': 'Squeeze'}),
**regular_op_with_empty_data('range', {'op': 'Range', 'type': 'Range'}),
**regular_op_with_empty_data('slice', {'op': 'Slice', 'type': None}),
**shaped_const_with_data('slice_start', None),
**shaped_const_with_data('slice_axes', None),
**shaped_const_with_data('slice_step', None),
**result('result')
},
edges=[
*connect('input', 'shape_of'),
*connect('shape_of', '0:gather'),
*connect('gather_axis', '1:gather'),
*connect('gather_indices', '2:gather'),
*connect('range_start', '0:range'),
*connect('gather', '0:mul'),
*connect('mul_const', '1:mul'),
*connect('mul', '0:add'),
*connect('add_const', '1:add'),
*connect('squeeze_const', '1:squeeze'),
*connect('add', '0:squeeze'),
*connect('squeeze', '1:range'),
*connect('range_step', '2:range'),
*connect('range', '0:slice'),
*connect('slice_start', '1:slice'),
*connect_data('gather', '2:slice'),
*connect('slice_axes', '3:slice'),
*connect('slice_step', '4:slice'),
*connect('slice', 'result')
],
update_attributes={
'gather_axis': {'value': 3},
'gather_indices': {'value': 0},
'range_start': {'value': 1},
'range_step': {'value': 2},
'add_const': {'value': 1},
'mul_const': {'value': 2},
'slice_start': {'value': int64_array([0])},
'slice_axes': {'value': int64_array([0])},
'slice_step': {'value': int64_array([1])},
}
)
ArangeLikeReplacer().find_and_replace_pattern(graph)
flag, resp = compare_graphs(graph, ref_graph, 'result', 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
def test_axis_none_start_0(self):
graph = build_graph(
nodes_attrs={
**shaped_parameter('input', int64_array([1, 3, 5, 5])),
**regular_op_with_empty_data('arange_like', {'op': 'arange_like', 'type': None, 'axis': None,
'repeat': 1, 'start': 0, 'step': 1}),
**result('result')
},
edges=[
*connect('input', 'arange_like'),
*connect('arange_like', 'result')
]
)
ref_graph = build_graph(
nodes_attrs={
**shaped_parameter('input', int64_array([1, 3, 5, 5])),
**regular_op_with_empty_data('shape_of', {'op': 'ShapeOf', 'type': 'ShapeOf'}),
**regular_op_with_empty_data('reduce_prod', {'op': 'ReduceProd', 'type': 'ReduceProd'}),
**shaped_const_with_data('reduce_prod_const', None),
**shaped_const_with_data('squeeze_const', None),
**regular_op_with_empty_data('squeeze', {'op': 'Squeeze', 'type': 'Squeeze'}),
**shaped_const_with_data('range_start', None),
**shaped_const_with_data('range_step', None),
**regular_op_with_empty_data('range', {'op': 'Range', 'type': 'Range'}),
**regular_op_with_empty_data('reshape_backward', {'op': 'Reshape', 'type': 'Reshape'}),
**result('result')
},
edges=[
*connect('input', 'shape_of'),
*connect('shape_of', '0:reduce_prod'),
*connect('reduce_prod_const', '1:reduce_prod'),
*connect('squeeze_const', '1:squeeze'),
*connect('reduce_prod', '0:squeeze'),
*connect('range_start', '0:range'),
*connect('range_step', '2:range'),
*connect('squeeze', '1:range'),
*connect('range', '0:reshape_backward'),
*connect_data('shape_of', '1:reshape_backward'),
*connect('reshape_backward', 'result')
],
update_attributes={
'range_start': {'value': 0},
'range_step': {'value': 1},
'reduce_prod_const': {'value': int64_array([0])}
}
)
ArangeLikeReplacer().find_and_replace_pattern(graph)
flag, resp = compare_graphs(graph, ref_graph, 'result', 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
def test_axis_none_start_1(self):
graph = build_graph(
nodes_attrs={
**shaped_parameter('input', int64_array([1, 3, 5, 5])),
**regular_op_with_empty_data('arange_like', {'op': 'arange_like', 'type': None, 'axis': None,
'repeat': 1, 'start': 1, 'step': 1}),
**result('result')
},
edges=[
*connect('input', 'arange_like'),
*connect('arange_like', 'result')
]
)
ref_graph = build_graph(
nodes_attrs={
**shaped_parameter('input', int64_array([1, 3, 5, 5])),
**regular_op_with_empty_data('shape_of', {'op': 'ShapeOf', 'type': 'ShapeOf'}),
**regular_op_with_empty_data('reduce_prod', {'op': 'ReduceProd', 'type': 'ReduceProd'}),
**shaped_const_with_data('reduce_prod_const', None),
**shaped_const_with_data('squeeze_const', None),
**regular_op_with_empty_data('squeeze', {'op': 'Squeeze', 'type': 'Squeeze'}),
**shaped_const_with_data('add_const', None),
**regular_op_with_empty_data('add', {'op': 'Add', 'type': 'Add'}),
**shaped_const_with_data('range_start', None),
**shaped_const_with_data('range_step', None),
**regular_op_with_empty_data('range', {'op': 'Range', 'type': 'Range'}),
**regular_op_with_empty_data('reshape_backward', {'op': 'Reshape', 'type': 'Reshape'}),
**result('result')
},
edges=[
*connect('input', 'shape_of'),
*connect('shape_of', '0:reduce_prod'),
*connect('reduce_prod_const', '1:reduce_prod'),
*connect('squeeze_const', '1:squeeze'),
*connect('add_const', '1:add'),
*connect('reduce_prod', '0:add'),
*connect('add', '0:squeeze'),
*connect('range_start', '0:range'),
*connect('range_step', '2:range'),
*connect('squeeze', '1:range'),
*connect('range', '0:reshape_backward'),
*connect_data('shape_of', '1:reshape_backward'),
*connect('reshape_backward', 'result')
],
update_attributes={
'range_start': {'value': 1},
'range_step': {'value': 1},
'add_const': {'value': 1},
'reduce_prod_const': {'value': int64_array([0])}
}
)
ArangeLikeReplacer().find_and_replace_pattern(graph)
flag, resp = compare_graphs(graph, ref_graph, 'result', 'result', check_op_attrs=True)
self.assertTrue(flag, resp)

View File

@ -36,7 +36,6 @@ class DivSqrtDimTest(unittest.TestCase):
**regular_op_with_empty_data('gather', {'op': 'Gather', 'type': 'Gather'}), **regular_op_with_empty_data('gather', {'op': 'Gather', 'type': 'Gather'}),
**regular_op_with_empty_data('power', {'op': 'AttributedPower', 'power': 0.5, 'type': 'Power'}), **regular_op_with_empty_data('power', {'op': 'AttributedPower', 'power': 0.5, 'type': 'Power'}),
**regular_op_with_empty_data('cast', {'op': 'Cast', 'type': 'Convert', 'dst_type': np.float32}), **regular_op_with_empty_data('cast', {'op': 'Cast', 'type': 'Convert', 'dst_type': np.float32}),
**regular_op_with_empty_data('z_convert_like', {'op': 'ConvertLike', 'type': 'ConvertLike'}),
**regular_op_with_empty_data('div', {'op': 'Div', 'type': 'Divide'}), **regular_op_with_empty_data('div', {'op': 'Div', 'type': 'Divide'}),
**result('result') **result('result')
}, },
@ -48,9 +47,7 @@ class DivSqrtDimTest(unittest.TestCase):
*connect('gather_indices', '2:gather'), *connect('gather_indices', '2:gather'),
*connect('gather', 'cast'), *connect('gather', 'cast'),
*connect('cast', 'power'), *connect('cast', 'power'),
*connect('power', '0:z_convert_like'), *connect('power', '1:div'),
*connect_front('input_d', '1:z_convert_like'),
*connect('z_convert_like', '1:div'),
*connect('div', 'result') *connect('div', 'result')
], ],
) )