ExpandRangeConstant adjustment for bidirectional Broadcast (#6739)
* Fixes in the transformation ExpandRangeConstant. * Fixed test. * Now we use use ShapeOf for both inputs of Broadcast. * Now the transformation ExpandRangeConstant uses two Gather layers. * Deletec commented code. * Fixed tests for the transformation ExpandRangeConstant. * Rewritten the transformation ExpandRangeConstant using Select.
This commit is contained in:
committed by
GitHub
parent
6a63cb9122
commit
838e701e5e
@@ -3,12 +3,15 @@
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.ops.elementwise import Equal
|
||||
from extensions.ops.gather import Gather
|
||||
from extensions.ops.range import Range
|
||||
from extensions.ops.select import Select
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.common.replacement import FrontReplacementSubgraph
|
||||
from mo.front.tf.graph_utils import create_op_with_const_inputs, create_op_node_with_second_input
|
||||
from mo.graph.graph import Graph, rename_nodes, Node
|
||||
from mo.ops.shape import Shape
|
||||
from mo.ops.unsqueeze import Unsqueeze
|
||||
|
||||
|
||||
@@ -51,16 +54,31 @@ class ExpandRangeConstant(FrontReplacementSubgraph):
|
||||
|
||||
positive_idx = non_one_dims.item(0)
|
||||
negative_idx = positive_idx - len(shape)
|
||||
|
||||
node_name = node.soft_get('name', node.id)
|
||||
gather = create_op_with_const_inputs(graph, Gather, {1: int64_array(negative_idx), 2: int64_array(0)},
|
||||
{'name': node.soft_get('name', node.id) + '/BroadcastingDim'})
|
||||
{'name': node_name + '/BroadcastingDim'})
|
||||
gather_for_const = create_op_with_const_inputs(graph, Gather, {1: int64_array(negative_idx), 2: int64_array(0)},
|
||||
{'name': const_name + '/BroadcastingDim'})
|
||||
shapeof_node = Shape(graph, {'name': const_name + '/ShapeOf'}).create_node()
|
||||
shapeof_node.out_port(0).connect(gather_for_const.in_port(0))
|
||||
|
||||
equal_node = create_op_with_const_inputs(graph, Equal, {1: int64_array(1)}, {'name': node_name + '/ConstOne'})
|
||||
gather.out_port(0).connect(equal_node.in_port(0))
|
||||
|
||||
select_node = Select(graph, {'name': node_name + '/Select',
|
||||
'auto_broadcast': 'numpy'}).create_node([equal_node, gather_for_const, gather])
|
||||
|
||||
const.out_port(0).connect(shapeof_node.in_port(0))
|
||||
|
||||
range_node = create_op_with_const_inputs(graph, Range,
|
||||
{0: np.array(0, dtype=value.dtype),
|
||||
2: np.array(1, dtype=value.dtype)},
|
||||
{'name': const_name + '/Range', 'dtype': value.dtype})
|
||||
select_node.out_port(0).connect(range_node.in_port(1))
|
||||
|
||||
node.in_port(1).get_connection().add_destination(gather.in_port(0))
|
||||
gather.out_port(0).connect(range_node.in_port(1))
|
||||
|
||||
node.in_port(0).get_connection().set_source(range_node.out_port(0))
|
||||
|
||||
if one_dims.size:
|
||||
|
||||
@@ -8,7 +8,7 @@ import numpy as np
|
||||
from extensions.front.broadcast_with_range import ExpandRangeConstant
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from unit_tests.utils.graph import build_graph, result, regular_op_with_shaped_data, valued_const_with_data, connect, \
|
||||
regular_op_with_empty_data, connect_data
|
||||
regular_op_with_empty_data
|
||||
|
||||
|
||||
class TestRangeBroadcast(unittest.TestCase):
|
||||
@@ -25,38 +25,61 @@ class TestRangeBroadcast(unittest.TestCase):
|
||||
], nodes_with_edges_only=True)
|
||||
ExpandRangeConstant().find_and_replace_pattern(graph)
|
||||
|
||||
graph_ref = build_graph({
|
||||
graph_ref = build_graph(nodes_attrs={
|
||||
**regular_op_with_shaped_data('shape', [2], {'type': 'Parameter'}),
|
||||
**valued_const_with_data('value', np.arange(0, 384).reshape((1, 384))),
|
||||
**regular_op_with_empty_data('bc', {'type': 'Broadcast'}),
|
||||
**regular_op_with_empty_data('shapeof', {'type': 'ShapeOf'}),
|
||||
**regular_op_with_empty_data('select', {'type': 'Select'}),
|
||||
**regular_op_with_empty_data('gather', {'type': 'Gather'}),
|
||||
'gather_const': {'type': 'Gather', 'kind': 'op', 'op': 'Gather'},
|
||||
'equal': {'type': 'Equal', 'kind': 'op', 'op': 'Equal'},
|
||||
|
||||
# start
|
||||
**valued_const_with_data('start', np.array(0)),
|
||||
# limit
|
||||
**valued_const_with_data('minus_one', np.array(-1)),
|
||||
**valued_const_with_data('zero', np.array(0)),
|
||||
**regular_op_with_empty_data('range_dim', {'type': 'Gather'}),
|
||||
**valued_const_with_data('minus_one_0', np.array(-1)),
|
||||
**valued_const_with_data('zero_0', np.array(0)),
|
||||
**valued_const_with_data('minus_one_1', np.array(-1)),
|
||||
**valued_const_with_data('zero_1', np.array(0)),
|
||||
# delta
|
||||
**valued_const_with_data('delta', np.array(1)),
|
||||
**regular_op_with_empty_data('range', {'type': 'Range'}),
|
||||
**regular_op_with_shaped_data('range', [1, 384], {'type': 'Range'}),
|
||||
|
||||
# keep dims
|
||||
**valued_const_with_data('axes', np.array([0])),
|
||||
**regular_op_with_empty_data('keep_shape', {'type': 'Unsqueeze'}),
|
||||
**regular_op_with_shaped_data('keep_shape', [1, 384], {'type': 'Unsqueeze'}),
|
||||
|
||||
**valued_const_with_data('one', np.array(1)),
|
||||
|
||||
**regular_op_with_empty_data('bc', {'type': 'Broadcast'}),
|
||||
**result(),
|
||||
}, [
|
||||
*connect('start', '0:range'),
|
||||
*connect('shape', '0:range_dim'),
|
||||
*connect('minus_one', '1:range_dim'),
|
||||
*connect('zero', '2:range_dim'),
|
||||
*connect('range_dim', '1:range'),
|
||||
*connect('delta', '2:range'),
|
||||
*connect('range', '0:keep_shape'),
|
||||
*connect('axes', '1:keep_shape'),
|
||||
*connect('keep_shape', '0:bc'),
|
||||
*connect_data('shape', '1:bc'),
|
||||
*connect('bc', 'output'),
|
||||
], nodes_with_edges_only=True)
|
||||
},
|
||||
edges=[
|
||||
*connect('value', 'shapeof'),
|
||||
*connect('gather', '0:equal'),
|
||||
('gather', 'select', {'in': 2, 'out': 0}),
|
||||
('gather_const', 'select', {'in': 1}),
|
||||
('equal', 'select', {'in': 0}),
|
||||
*connect('minus_one_0', '1:gather'),
|
||||
*connect('zero_0', '2:gather'),
|
||||
*connect('shapeof', '0:gather_const'),
|
||||
*connect('minus_one_1', '1:gather_const'),
|
||||
*connect('zero_1', '2:gather_const'),
|
||||
*connect('start', '0:range'),
|
||||
*connect('select', '1:range'),
|
||||
*connect('delta', '2:range'),
|
||||
*connect('range', '0:keep_shape'),
|
||||
*connect('axes', '1:keep_shape'),
|
||||
*connect('keep_shape', '0:bc'),
|
||||
*connect('one', '1:equal'),
|
||||
*connect('shape', '1:bc'),
|
||||
('shape_d', 'gather', {'out': 0, 'in': 0}),
|
||||
*connect('bc', 'output'),
|
||||
],
|
||||
update_attributes={
|
||||
'range_d': {'value': np.arange(0, 384).reshape((1, 384))},
|
||||
'keep_shape_d': {'value': np.arange(0, 384).reshape((1, 384))},
|
||||
})
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
self.assertTrue(flag, resp)
|
||||
Reference in New Issue
Block a user