Fix of ReverseInputChannels for NHWC layout. (#8504)
* Fixed ReverseInputChannels for new api. * Renamed function, added node name in assert, code refactoring. * Added ReverseChannels propagation through transposes. * Refactored get_fw_index() method. * Fixed wrong name. * Added negative axis support.
This commit is contained in:
parent
925e24525a
commit
f800993e6f
@ -50,10 +50,38 @@ class InsertReverseChannels(BackReplacementPattern):
|
|||||||
"""
|
"""
|
||||||
enabled = False
|
enabled = False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_fw_index(node: Node, idx: int):
|
||||||
|
if not node.has_valid('rt_info'):
|
||||||
|
return idx
|
||||||
|
|
||||||
|
rt_info = node.rt_info
|
||||||
|
if not rt_info.contains('old_api_map'):
|
||||||
|
return idx
|
||||||
|
|
||||||
|
old_api_map_version = rt_info.get_attribute_version('old_api_map')
|
||||||
|
old_api_map = rt_info.info['old_api_map', old_api_map_version]
|
||||||
|
if 'inverse_order' not in old_api_map.info:
|
||||||
|
return idx
|
||||||
|
|
||||||
|
order = old_api_map.info['inverse_order']
|
||||||
|
node_name = node.soft_get('name', node.id)
|
||||||
|
|
||||||
|
if idx < 0:
|
||||||
|
assert not node.out_port(0).disconnected(), 'Cannot normalize negative axis {} in node {} ' \
|
||||||
|
'as out port is disconnected.'.format(idx, node_name)
|
||||||
|
data_rank = len(list(node.out_port(0).data.get_shape()))
|
||||||
|
idx = data_rank + idx
|
||||||
|
|
||||||
|
assert len(order) > idx >= 0, \
|
||||||
|
'Channel index {} is incompatible with old_api_map in node {}.'.format(idx, node_name)
|
||||||
|
return list(order).index(idx)
|
||||||
|
|
||||||
def find_and_replace_pattern(self, graph: Graph):
|
def find_and_replace_pattern(self, graph: Graph):
|
||||||
all_params = [(p.soft_get('name', p.id), p, list(p.out_port(0).data.get_shape()))
|
all_params = [(p.soft_get('name', p.id), p, list(p.out_port(0).data.get_shape()))
|
||||||
for p in graph.get_op_nodes(type='Parameter')]
|
for p in graph.get_op_nodes(type='Parameter')]
|
||||||
suitable_params = [(name, p, shape) for name, p, shape in all_params if len(shape) == 4 and shape[1] == 3]
|
suitable_params = [(name, p, shape) for name, p, shape in all_params if
|
||||||
|
len(shape) == 4 and shape[self.get_fw_index(p, 1)] == 3]
|
||||||
|
|
||||||
log.debug('All network inputs: {}'.format({name: shape for name, _, shape in all_params}))
|
log.debug('All network inputs: {}'.format({name: shape for name, _, shape in all_params}))
|
||||||
log.debug('Will reverse input channels for: {}'.format({name: shape for name, _, shape in suitable_params}))
|
log.debug('Will reverse input channels for: {}'.format({name: shape for name, _, shape in suitable_params}))
|
||||||
@ -66,10 +94,14 @@ class InsertReverseChannels(BackReplacementPattern):
|
|||||||
extra={'is_warning': True})
|
extra={'is_warning': True})
|
||||||
|
|
||||||
for name, parameter, _ in suitable_params:
|
for name, parameter, _ in suitable_params:
|
||||||
reverse_channels = ReverseChannels(graph, {'name': name + '/reverse_input_channels'}).create_node()
|
reverse_index = self.get_fw_index(parameter, 1)
|
||||||
parameter.out_port(0).get_connection().set_source(reverse_channels.out_port(0),
|
|
||||||
attributes_save_mode='source')
|
if parameter.out_port(0).disconnected():
|
||||||
parameter.out_port(0).connect(reverse_channels.in_port(0))
|
continue
|
||||||
|
|
||||||
|
reverse_channels = ReverseChannels(graph, {'name': name + '/reverse_input_channels',
|
||||||
|
'axis': reverse_index}).create_node()
|
||||||
|
parameter.out_port(0).get_connection().insert_node(reverse_channels, attributes_save_mode='source')
|
||||||
|
|
||||||
|
|
||||||
class ReverseChannelsPropagationDown(BackReplacementPattern):
|
class ReverseChannelsPropagationDown(BackReplacementPattern):
|
||||||
@ -95,11 +127,30 @@ class ReverseChannelsPropagationDown(BackReplacementPattern):
|
|||||||
'Shape': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_shape(node, rc),
|
'Shape': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_shape(node, rc),
|
||||||
'ShapeOf': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_shape(node, rc),
|
'ShapeOf': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_shape(node, rc),
|
||||||
|
|
||||||
'Pad': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through(node, rc),
|
'Pad': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_zero_port_only(node, rc),
|
||||||
|
'Transpose': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_transpose(node, rc),
|
||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def pass_rc_through(node: Node, reverse_channels: Node):
|
def pass_rc_through_transpose(node: Node, reverse_channels: Node):
|
||||||
|
if node.in_port(1).disconnected() or node.in_port(0).disconnected():
|
||||||
|
return False
|
||||||
|
order = node.in_port(1).data.get_value()
|
||||||
|
reverse_axis = reverse_channels.axis
|
||||||
|
data_rank = len(list(node.in_port(0).data.get_shape()))
|
||||||
|
|
||||||
|
if reverse_axis < 0:
|
||||||
|
reverse_axis = data_rank + reverse_axis
|
||||||
|
assert 0 < reverse_axis < data_rank, "Incorrect ReverseChannels axis in node {}.".format(reverse_channels)
|
||||||
|
|
||||||
|
if order is None:
|
||||||
|
return False
|
||||||
|
new_axis = list(order).index(reverse_axis)
|
||||||
|
reverse_channels.axis = new_axis
|
||||||
|
return ReverseChannelsPropagationDown.pass_rc_through_zero_port_only(node, reverse_channels)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def pass_rc_through_zero_port_only(node: Node, reverse_channels: Node):
|
||||||
r"""
|
r"""
|
||||||
BEFORE AFTER
|
BEFORE AFTER
|
||||||
|
|
||||||
@ -295,11 +346,30 @@ class ReverseChannelsPropagationUp(BackReplacementPattern):
|
|||||||
'Subtract': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
|
'Subtract': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
|
||||||
'Pow': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
|
'Pow': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
|
||||||
'Convert': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
|
'Convert': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
|
||||||
'Pad': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_pad(node, rc),
|
'Pad': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_zero_port_only(node, rc),
|
||||||
|
'Transpose': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_transpose(node, rc),
|
||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def lift_up_through_pad(node: Node, reverse_channels: Node):
|
def lift_up_through_transpose(node: Node, reverse_channels: Node):
|
||||||
|
if node.in_port(1).disconnected() or node.in_port(0).disconnected():
|
||||||
|
return False
|
||||||
|
order = node.in_port(1).data.get_value()
|
||||||
|
reverse_axis = reverse_channels.axis
|
||||||
|
data_rank = len(list(node.in_port(0).data.get_shape()))
|
||||||
|
|
||||||
|
if reverse_axis < 0:
|
||||||
|
reverse_axis = data_rank + reverse_axis
|
||||||
|
assert 0 < reverse_axis < data_rank, "Incorrect ReverseChannels axis in node {}.".format(reverse_channels)
|
||||||
|
|
||||||
|
if order is None:
|
||||||
|
return False
|
||||||
|
new_axis = order[reverse_axis]
|
||||||
|
reverse_channels.axis = new_axis
|
||||||
|
return ReverseChannelsPropagationUp.lift_up_through_zero_port_only(node, reverse_channels)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def lift_up_through_zero_port_only(node: Node, reverse_channels: Node):
|
||||||
r"""
|
r"""
|
||||||
BEFORE AFTER
|
BEFORE AFTER
|
||||||
|
|
||||||
@ -307,7 +377,7 @@ class ReverseChannelsPropagationUp(BackReplacementPattern):
|
|||||||
\
|
\
|
||||||
previous_op previous_op ReverseChannels previous_op
|
previous_op previous_op ReverseChannels previous_op
|
||||||
\ / \ /
|
\ / \ /
|
||||||
Pad Pad
|
Node Node
|
||||||
| |
|
| |
|
||||||
ReverseChannels next_op
|
ReverseChannels next_op
|
||||||
|
|
|
|
||||||
@ -323,7 +393,16 @@ class ReverseChannelsPropagationUp(BackReplacementPattern):
|
|||||||
reverse_channels.out_port(0).disconnect()
|
reverse_channels.out_port(0).disconnect()
|
||||||
reverse_channels.in_port(0).disconnect()
|
reverse_channels.in_port(0).disconnect()
|
||||||
src = node_input_port_0.get_connection().get_source()
|
src = node_input_port_0.get_connection().get_source()
|
||||||
|
|
||||||
|
if src.node.soft_get('type') == 'Parameter':
|
||||||
|
# For Parameter nodes tensor debug attributes should not move to the last node
|
||||||
|
# of subgraph. It is needed for the proper mapping of input framework name.
|
||||||
|
# For this reason "source" mode is used to keep tensor debug attributes at Parameter node.
|
||||||
|
node_input_port_0.get_connection().set_source(reverse_channels.out_port(0),
|
||||||
|
attributes_save_mode="source")
|
||||||
|
else:
|
||||||
node_input_port_0.get_connection().set_source(reverse_channels.out_port(0))
|
node_input_port_0.get_connection().set_source(reverse_channels.out_port(0))
|
||||||
|
|
||||||
src.connect(reverse_channels.in_port(0))
|
src.connect(reverse_channels.in_port(0))
|
||||||
for reverse_channels_destination in reverse_channels_out_nodes:
|
for reverse_channels_destination in reverse_channels_out_nodes:
|
||||||
node.out_port(0).get_connection().add_destination(reverse_channels_destination)
|
node.out_port(0).get_connection().add_destination(reverse_channels_destination)
|
||||||
|
@ -29,6 +29,17 @@ class RTInfo:
|
|||||||
"""
|
"""
|
||||||
self.info = defaultdict(dict)
|
self.info = defaultdict(dict)
|
||||||
|
|
||||||
|
def contains(self, attribute_name: str):
|
||||||
|
attr_count = [key[0] for key in list(self.info.keys())].count(attribute_name)
|
||||||
|
assert attr_count <= 1, 'Incorrect rt_info attribute, got more than one {}.'.format(attribute_name)
|
||||||
|
return attr_count > 0
|
||||||
|
|
||||||
|
def get_attribute_version(self, attribute_name: str):
|
||||||
|
for name, version in list(self.info.keys()):
|
||||||
|
if name == attribute_name:
|
||||||
|
return version
|
||||||
|
raise Exception("rt_info does not contain attribute with name {}".format(attribute_name))
|
||||||
|
|
||||||
|
|
||||||
class RTInfoElement:
|
class RTInfoElement:
|
||||||
"""
|
"""
|
||||||
|
@ -3,13 +3,16 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from extensions.back.ReverseInputChannels import ReverseChannelsPropagationUp, ReverseChannelsPropagationDown
|
from extensions.back.ReverseInputChannels import ReverseChannelsPropagationUp, ReverseChannelsPropagationDown, \
|
||||||
from mo.graph.graph import Node, Graph
|
InsertReverseChannels
|
||||||
from unit_tests.utils.graph import build_graph, result, connect, regular_op_with_shaped_data, valued_const_with_data
|
|
||||||
from mo.front.common.partial_infer.utils import int64_array, float32_array
|
from mo.front.common.partial_infer.utils import int64_array, float32_array
|
||||||
|
from mo.graph.graph import Node, Graph
|
||||||
|
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||||
|
from mo.utils.runtime_info import OldAPIMap, RTInfo
|
||||||
|
from unit_tests.utils.graph import build_graph, result, connect, regular_op_with_shaped_data, valued_const_with_data
|
||||||
|
|
||||||
nodes = {
|
nodes = {
|
||||||
**regular_op_with_shaped_data('placeholder1', [1, 3, 10, 10], {'type': 'Parameter'}),
|
**regular_op_with_shaped_data('placeholder1', [1, 3, 10, 10], {'type': 'Parameter', 'rt_info': RTInfo()}),
|
||||||
**regular_op_with_shaped_data('placeholder2', [1, 1, 1, 1], {'type': 'Parameter'}),
|
**regular_op_with_shaped_data('placeholder2', [1, 1, 1, 1], {'type': 'Parameter'}),
|
||||||
|
|
||||||
**regular_op_with_shaped_data('mul', [1, 3, 10, 10], {'type': 'Multiply'}),
|
**regular_op_with_shaped_data('mul', [1, 3, 10, 10], {'type': 'Multiply'}),
|
||||||
@ -35,6 +38,17 @@ nodes2 = {
|
|||||||
**result('result2'),
|
**result('result2'),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
nodes3 = {
|
||||||
|
**regular_op_with_shaped_data('placeholder', [1, 3, 10, 10], {'type': 'Parameter'}),
|
||||||
|
**regular_op_with_shaped_data('transpose', [1, 3, 10, 10], {'type': 'Transpose'}),
|
||||||
|
**valued_const_with_data('transpose_order', int64_array([0, 3, 1, 2])),
|
||||||
|
**regular_op_with_shaped_data('reverse_channels_up', [1, 3, 10, 10], {'type': 'ReverseChannels', 'axis': 3}),
|
||||||
|
**regular_op_with_shaped_data('reverse_channels_down', [1, 3, 10, 10], {'type': 'ReverseChannels', 'axis': 1}),
|
||||||
|
**result('result'),
|
||||||
|
**result('result2'),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class ReverseInputChannelsTest(unittest.TestCase):
|
class ReverseInputChannelsTest(unittest.TestCase):
|
||||||
def check_graph_attrs(self, graph: Graph, parameter_node_names: list):
|
def check_graph_attrs(self, graph: Graph, parameter_node_names: list):
|
||||||
for node in graph.get_op_nodes():
|
for node in graph.get_op_nodes():
|
||||||
@ -75,12 +89,11 @@ class ReverseInputChannelsTest(unittest.TestCase):
|
|||||||
node = Node(graph, 'pad')
|
node = Node(graph, 'pad')
|
||||||
reverse_channels = Node(graph, 'reverse_channels')
|
reverse_channels = Node(graph, 'reverse_channels')
|
||||||
|
|
||||||
keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_pad(node, reverse_channels)
|
keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_zero_port_only(node, reverse_channels)
|
||||||
self.assertTrue(keep_moving_up is True)
|
self.assertTrue(keep_moving_up is True)
|
||||||
self.assertTrue(len(new_reverses) == 1)
|
self.assertTrue(len(new_reverses) == 1)
|
||||||
self.check_graph_attrs(graph, ['placeholder'])
|
self.check_graph_attrs(graph, ['placeholder'])
|
||||||
|
|
||||||
|
|
||||||
def test_lift_up_through_pad2(self):
|
def test_lift_up_through_pad2(self):
|
||||||
graph = build_graph(nodes2, [*connect('placeholder', '0:mul'), *connect('mul_const', '1:mul'),
|
graph = build_graph(nodes2, [*connect('placeholder', '0:mul'), *connect('mul_const', '1:mul'),
|
||||||
*connect('mul', '0:pad'), *connect('pad_const_1', '1:pad'),
|
*connect('mul', '0:pad'), *connect('pad_const_1', '1:pad'),
|
||||||
@ -91,12 +104,11 @@ class ReverseInputChannelsTest(unittest.TestCase):
|
|||||||
node = Node(graph, 'pad')
|
node = Node(graph, 'pad')
|
||||||
reverse_channels = Node(graph, 'reverse_channels')
|
reverse_channels = Node(graph, 'reverse_channels')
|
||||||
|
|
||||||
keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_pad(node, reverse_channels)
|
keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_zero_port_only(node, reverse_channels)
|
||||||
self.assertTrue(keep_moving_up is True)
|
self.assertTrue(keep_moving_up is True)
|
||||||
self.assertTrue(len(new_reverses) == 1)
|
self.assertTrue(len(new_reverses) == 1)
|
||||||
self.check_graph_attrs(graph, ['placeholder'])
|
self.check_graph_attrs(graph, ['placeholder'])
|
||||||
|
|
||||||
|
|
||||||
def test_pass_rc_through(self):
|
def test_pass_rc_through(self):
|
||||||
graph = build_graph(nodes2, [*connect('placeholder', '0:mul'), *connect('mul_const', '1:mul'),
|
graph = build_graph(nodes2, [*connect('placeholder', '0:mul'), *connect('mul_const', '1:mul'),
|
||||||
*connect('mul', 'reverse_channels'), *connect('reverse_channels', '0:pad'),
|
*connect('mul', 'reverse_channels'), *connect('reverse_channels', '0:pad'),
|
||||||
@ -107,5 +119,114 @@ class ReverseInputChannelsTest(unittest.TestCase):
|
|||||||
node = Node(graph, 'pad')
|
node = Node(graph, 'pad')
|
||||||
reverse_channels = Node(graph, 'reverse_channels')
|
reverse_channels = Node(graph, 'reverse_channels')
|
||||||
|
|
||||||
ReverseChannelsPropagationDown.pass_rc_through(node, reverse_channels)
|
ReverseChannelsPropagationDown.pass_rc_through_zero_port_only(node, reverse_channels)
|
||||||
self.check_graph_attrs(graph, ['placeholder'])
|
self.check_graph_attrs(graph, ['placeholder'])
|
||||||
|
|
||||||
|
def test_lift_up_through_transpose(self):
|
||||||
|
graph = build_graph(nodes3, [*connect('placeholder', '0:transpose'), *connect('transpose_order', '1:transpose'),
|
||||||
|
*connect('transpose', 'reverse_channels_down'),
|
||||||
|
*connect('reverse_channels_down', 'result')])
|
||||||
|
graph_ref = build_graph(nodes3, [*connect('placeholder', 'reverse_channels_down'),
|
||||||
|
*connect('transpose_order', '1:transpose'),
|
||||||
|
*connect('reverse_channels_down', 'transpose'),
|
||||||
|
*connect('transpose', 'result')])
|
||||||
|
self.set_graph_attrs(graph, ['placeholder'])
|
||||||
|
|
||||||
|
node = Node(graph, 'transpose')
|
||||||
|
reverse_channels = Node(graph, 'reverse_channels_down')
|
||||||
|
|
||||||
|
keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_transpose(node, reverse_channels)
|
||||||
|
self.assertTrue(keep_moving_up is True)
|
||||||
|
self.assertTrue(len(new_reverses) == 1)
|
||||||
|
self.check_graph_attrs(graph, ['placeholder'])
|
||||||
|
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
|
||||||
|
self.assertTrue(flag, resp)
|
||||||
|
|
||||||
|
reverse_channels = Node(graph, 'reverse_channels_down')
|
||||||
|
self.assertTrue(reverse_channels.axis == 3)
|
||||||
|
|
||||||
|
def test_lift_down_through_transpose(self):
|
||||||
|
graph = build_graph(nodes3, [*connect('placeholder', 'reverse_channels_up'),
|
||||||
|
*connect('transpose_order', '1:transpose'),
|
||||||
|
*connect('reverse_channels_up', '0:transpose'),
|
||||||
|
*connect('transpose', 'result')])
|
||||||
|
graph_ref = build_graph(nodes3, [*connect('placeholder', '0:transpose'),
|
||||||
|
*connect('transpose_order', '1:transpose'),
|
||||||
|
*connect('transpose', 'reverse_channels_up'),
|
||||||
|
*connect('reverse_channels_up', '0:result')])
|
||||||
|
self.set_graph_attrs(graph, ['placeholder'])
|
||||||
|
|
||||||
|
node = Node(graph, 'transpose')
|
||||||
|
reverse_channels = Node(graph, 'reverse_channels_up')
|
||||||
|
|
||||||
|
keep_moving_down = ReverseChannelsPropagationDown.pass_rc_through_transpose(node, reverse_channels)
|
||||||
|
|
||||||
|
self.assertTrue(keep_moving_down is True)
|
||||||
|
self.check_graph_attrs(graph, ['placeholder'])
|
||||||
|
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
|
||||||
|
self.assertTrue(flag, resp)
|
||||||
|
|
||||||
|
reverse_channels = Node(graph, 'reverse_channels_down')
|
||||||
|
self.assertTrue(reverse_channels.axis == 1)
|
||||||
|
|
||||||
|
def test_lift_up_through_transpose_negative_axis(self):
|
||||||
|
graph = build_graph(nodes3, [*connect('placeholder', '0:transpose'), *connect('transpose_order', '1:transpose'),
|
||||||
|
*connect('transpose', 'reverse_channels_down'),
|
||||||
|
*connect('reverse_channels_down', 'result')])
|
||||||
|
graph_ref = build_graph(nodes3, [*connect('placeholder', 'reverse_channels_down'),
|
||||||
|
*connect('transpose_order', '1:transpose'),
|
||||||
|
*connect('reverse_channels_down', 'transpose'),
|
||||||
|
*connect('transpose', 'result')])
|
||||||
|
self.set_graph_attrs(graph, ['placeholder'])
|
||||||
|
|
||||||
|
node = Node(graph, 'transpose')
|
||||||
|
reverse_channels = Node(graph, 'reverse_channels_down')
|
||||||
|
reverse_channels.axis = -3
|
||||||
|
|
||||||
|
keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_transpose(node, reverse_channels)
|
||||||
|
self.assertTrue(keep_moving_up is True)
|
||||||
|
self.assertTrue(len(new_reverses) == 1)
|
||||||
|
self.check_graph_attrs(graph, ['placeholder'])
|
||||||
|
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
|
||||||
|
self.assertTrue(flag, resp)
|
||||||
|
|
||||||
|
reverse_channels = Node(graph, 'reverse_channels_down')
|
||||||
|
self.assertTrue(reverse_channels.axis == 3)
|
||||||
|
|
||||||
|
def test_lift_down_through_transpose_negative_axis(self):
|
||||||
|
graph = build_graph(nodes3, [*connect('placeholder', 'reverse_channels_up'),
|
||||||
|
*connect('transpose_order', '1:transpose'),
|
||||||
|
*connect('reverse_channels_up', '0:transpose'),
|
||||||
|
*connect('transpose', 'result')])
|
||||||
|
graph_ref = build_graph(nodes3, [*connect('placeholder', '0:transpose'),
|
||||||
|
*connect('transpose_order', '1:transpose'),
|
||||||
|
*connect('transpose', 'reverse_channels_up'),
|
||||||
|
*connect('reverse_channels_up', '0:result')])
|
||||||
|
self.set_graph_attrs(graph, ['placeholder'])
|
||||||
|
|
||||||
|
node = Node(graph, 'transpose')
|
||||||
|
reverse_channels = Node(graph, 'reverse_channels_up')
|
||||||
|
reverse_channels.axis = -1
|
||||||
|
|
||||||
|
keep_moving_down = ReverseChannelsPropagationDown.pass_rc_through_transpose(node, reverse_channels)
|
||||||
|
|
||||||
|
self.assertTrue(keep_moving_down is True)
|
||||||
|
self.check_graph_attrs(graph, ['placeholder'])
|
||||||
|
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
|
||||||
|
self.assertTrue(flag, resp)
|
||||||
|
|
||||||
|
reverse_channels = Node(graph, 'reverse_channels_down')
|
||||||
|
self.assertTrue(reverse_channels.axis == 1)
|
||||||
|
|
||||||
|
def test_get_fw_index(self):
|
||||||
|
graph = build_graph(nodes, [*connect('placeholder1', 'result')])
|
||||||
|
node = Node(graph, 'placeholder1')
|
||||||
|
old_api_map = OldAPIMap(version=0)
|
||||||
|
node.rt_info.info[('old_api_map', old_api_map.get_version())] = old_api_map
|
||||||
|
node.rt_info.info[('old_api_map', old_api_map.get_version())].old_api_transpose_parameter([0, 2, 3, 1])
|
||||||
|
self.assertTrue(InsertReverseChannels.get_fw_index(node, 0) == 0)
|
||||||
|
self.assertTrue(InsertReverseChannels.get_fw_index(node, 1) == 3)
|
||||||
|
self.assertTrue(InsertReverseChannels.get_fw_index(node, 2) == 1)
|
||||||
|
self.assertTrue(InsertReverseChannels.get_fw_index(node, 3) == 2)
|
||||||
|
self.assertTrue(InsertReverseChannels.get_fw_index(node, -2) == 1)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user