ExpandRangeConstant adjustment for bidirectional Broadcast (#6739)

* Fixes in the transformation ExpandRangeConstant.

* Fixed test.

* Now we use use ShapeOf for both inputs of Broadcast.

* Now the transformation ExpandRangeConstant uses two Gather layers.

* Deletec commented code.

* Fixed tests for the transformation ExpandRangeConstant.

* Rewritten the transformation ExpandRangeConstant using Select.
This commit is contained in:
Vladimir Gavrilov
2021-08-09 16:49:07 +03:00
committed by GitHub
parent 6a63cb9122
commit 838e701e5e
2 changed files with 65 additions and 24 deletions

View File

@@ -3,12 +3,15 @@
import numpy as np
from extensions.ops.elementwise import Equal
from extensions.ops.gather import Gather
from extensions.ops.range import Range
from extensions.ops.select import Select
from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.replacement import FrontReplacementSubgraph
from mo.front.tf.graph_utils import create_op_with_const_inputs, create_op_node_with_second_input
from mo.graph.graph import Graph, rename_nodes, Node
from mo.ops.shape import Shape
from mo.ops.unsqueeze import Unsqueeze
@@ -51,16 +54,31 @@ class ExpandRangeConstant(FrontReplacementSubgraph):
positive_idx = non_one_dims.item(0)
negative_idx = positive_idx - len(shape)
node_name = node.soft_get('name', node.id)
gather = create_op_with_const_inputs(graph, Gather, {1: int64_array(negative_idx), 2: int64_array(0)},
{'name': node.soft_get('name', node.id) + '/BroadcastingDim'})
{'name': node_name + '/BroadcastingDim'})
gather_for_const = create_op_with_const_inputs(graph, Gather, {1: int64_array(negative_idx), 2: int64_array(0)},
{'name': const_name + '/BroadcastingDim'})
shapeof_node = Shape(graph, {'name': const_name + '/ShapeOf'}).create_node()
shapeof_node.out_port(0).connect(gather_for_const.in_port(0))
equal_node = create_op_with_const_inputs(graph, Equal, {1: int64_array(1)}, {'name': node_name + '/ConstOne'})
gather.out_port(0).connect(equal_node.in_port(0))
select_node = Select(graph, {'name': node_name + '/Select',
'auto_broadcast': 'numpy'}).create_node([equal_node, gather_for_const, gather])
const.out_port(0).connect(shapeof_node.in_port(0))
range_node = create_op_with_const_inputs(graph, Range,
{0: np.array(0, dtype=value.dtype),
2: np.array(1, dtype=value.dtype)},
{'name': const_name + '/Range', 'dtype': value.dtype})
select_node.out_port(0).connect(range_node.in_port(1))
node.in_port(1).get_connection().add_destination(gather.in_port(0))
gather.out_port(0).connect(range_node.in_port(1))
node.in_port(0).get_connection().set_source(range_node.out_port(0))
if one_dims.size:

View File

@@ -8,7 +8,7 @@ import numpy as np
from extensions.front.broadcast_with_range import ExpandRangeConstant
from mo.utils.ir_engine.compare_graphs import compare_graphs
from unit_tests.utils.graph import build_graph, result, regular_op_with_shaped_data, valued_const_with_data, connect, \
regular_op_with_empty_data, connect_data
regular_op_with_empty_data
class TestRangeBroadcast(unittest.TestCase):
@@ -25,38 +25,61 @@ class TestRangeBroadcast(unittest.TestCase):
], nodes_with_edges_only=True)
ExpandRangeConstant().find_and_replace_pattern(graph)
graph_ref = build_graph({
graph_ref = build_graph(nodes_attrs={
**regular_op_with_shaped_data('shape', [2], {'type': 'Parameter'}),
**valued_const_with_data('value', np.arange(0, 384).reshape((1, 384))),
**regular_op_with_empty_data('bc', {'type': 'Broadcast'}),
**regular_op_with_empty_data('shapeof', {'type': 'ShapeOf'}),
**regular_op_with_empty_data('select', {'type': 'Select'}),
**regular_op_with_empty_data('gather', {'type': 'Gather'}),
'gather_const': {'type': 'Gather', 'kind': 'op', 'op': 'Gather'},
'equal': {'type': 'Equal', 'kind': 'op', 'op': 'Equal'},
# start
**valued_const_with_data('start', np.array(0)),
# limit
**valued_const_with_data('minus_one', np.array(-1)),
**valued_const_with_data('zero', np.array(0)),
**regular_op_with_empty_data('range_dim', {'type': 'Gather'}),
**valued_const_with_data('minus_one_0', np.array(-1)),
**valued_const_with_data('zero_0', np.array(0)),
**valued_const_with_data('minus_one_1', np.array(-1)),
**valued_const_with_data('zero_1', np.array(0)),
# delta
**valued_const_with_data('delta', np.array(1)),
**regular_op_with_empty_data('range', {'type': 'Range'}),
**regular_op_with_shaped_data('range', [1, 384], {'type': 'Range'}),
# keep dims
**valued_const_with_data('axes', np.array([0])),
**regular_op_with_empty_data('keep_shape', {'type': 'Unsqueeze'}),
**regular_op_with_shaped_data('keep_shape', [1, 384], {'type': 'Unsqueeze'}),
**valued_const_with_data('one', np.array(1)),
**regular_op_with_empty_data('bc', {'type': 'Broadcast'}),
**result(),
}, [
*connect('start', '0:range'),
*connect('shape', '0:range_dim'),
*connect('minus_one', '1:range_dim'),
*connect('zero', '2:range_dim'),
*connect('range_dim', '1:range'),
*connect('delta', '2:range'),
*connect('range', '0:keep_shape'),
*connect('axes', '1:keep_shape'),
*connect('keep_shape', '0:bc'),
*connect_data('shape', '1:bc'),
*connect('bc', 'output'),
], nodes_with_edges_only=True)
},
edges=[
*connect('value', 'shapeof'),
*connect('gather', '0:equal'),
('gather', 'select', {'in': 2, 'out': 0}),
('gather_const', 'select', {'in': 1}),
('equal', 'select', {'in': 0}),
*connect('minus_one_0', '1:gather'),
*connect('zero_0', '2:gather'),
*connect('shapeof', '0:gather_const'),
*connect('minus_one_1', '1:gather_const'),
*connect('zero_1', '2:gather_const'),
*connect('start', '0:range'),
*connect('select', '1:range'),
*connect('delta', '2:range'),
*connect('range', '0:keep_shape'),
*connect('axes', '1:keep_shape'),
*connect('keep_shape', '0:bc'),
*connect('one', '1:equal'),
*connect('shape', '1:bc'),
('shape_d', 'gather', {'out': 0, 'in': 0}),
*connect('bc', 'output'),
],
update_attributes={
'range_d': {'value': np.arange(0, 384).reshape((1, 384))},
'keep_shape_d': {'value': np.arange(0, 384).reshape((1, 384))},
})
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
self.assertTrue(flag, resp)
self.assertTrue(flag, resp)