publish master branch snapshot, revision ea98a886d925eb152931aab13856e68037665562

This commit is contained in:
Alexey Suhov
2020-05-22 03:42:00 +03:00
parent deb008a26f
commit ccb7438803
45 changed files with 2674 additions and 875 deletions

View File

@@ -23,7 +23,7 @@ from mo.ops.crop import Crop
from mo.utils.logger import log
class CutMemory(BackReplacementPattern):
class CutMemoryInput(BackReplacementPattern):
"""
Cut Memory layers and have inputs/outputs in graph instead of them
"""
@@ -38,30 +38,56 @@ class CutMemory(BackReplacementPattern):
def pattern():
return dict(
nodes=[
('op', dict(kind='op', op='Memory'))],
('op', dict(kind='op', op='ReadValue'))],
edges=[]
)
@staticmethod
def replace_pattern(graph: Graph, match: dict):
node = match['op']
node_id = node['id']
node_id = node['variable_id']
if node.in_port(0).disconnected():
i = 0
for dest in node.out_port(0).get_destinations():
new_in = Parameter(graph, {'name': "Parameter_"+str(i)+"_for_"+node_id,
'shape': dest.data.get_shape()}).create_node()
i += 1
dest.disconnect()
new_in.out_port(0).connect(dest)
log.error("Add input/output mapped {} -> {} ".format(new_in.name, "Result_for_"+node_id),
extra={'is_warning': True})
else:
out_node_port = node.out_port(0).get_destination()
in_node_port = node.in_port(0).get_source()
node.in_port(0).disconnect()
node.out_port(0).disconnect()
crop = Crop(graph, {'name': 'Result_for_'+node_id, 'dim': np.array([1]), 'offset': np.array([0]), 'axis': np.array([0])}).create_node()
in_node_port.connect(crop.in_port(0))
crop.out_port(0).connect(out_node_port)
i = 0
node.in_port(0).disconnect()
for dest in node.out_port(0).get_destinations():
new_in = Parameter(graph, {'name': "Parameter_"+str(i)+"_for_"+node_id,
'shape': dest.data.get_shape()}).create_node()
i += 1
dest.disconnect()
new_in.out_port(0).connect(dest)
log.error("Add input/output mapped {} -> {} ".format(new_in.name, "Result_for_"+node_id),
extra={'is_warning': True})
class CutMemoryOutput(BackReplacementPattern):
    """Turn each Assign (memory write) into an ordinary graph output.

    The matched Assign node is detached from the graph and a Crop over the
    batch axis is spliced between its former producer and consumer, so the
    state value that used to be written to memory becomes a regular output.
    Applied only for Kaldi models converted with ``--remove_memory``.
    """
    enabled = True
    graph_condition = [lambda graph: graph.graph['fw'] == "kaldi" and graph.graph['cmd_params'].remove_memory]
    force_clean_up = True

    def run_before(self):
        # Must run before Parameter/Input normalization.
        return [ParameterToInput]

    @staticmethod
    def pattern():
        return dict(nodes=[('op', dict(kind='op', op='Assign'))], edges=[])

    @staticmethod
    def replace_pattern(graph: Graph, match: dict):
        assign_node = match['op']
        variable_id = assign_node['variable_id']

        # Remember how the Assign was wired before detaching it.
        consumer_port = assign_node.out_port(0).get_destination()
        producer_port = assign_node.in_port(0).get_source()
        assign_node.in_port(0).disconnect()
        assign_node.out_port(0).disconnect()

        # Crop a single element along the batch axis and splice it in place
        # of the removed Assign node.
        crop_node = Crop(graph, {'name': 'Result_for_' + variable_id,
                                 'dim': np.array([1]),
                                 'offset': np.array([0]),
                                 'axis': np.array([0])}).create_node()
        producer_port.connect(crop_node.in_port(0))
        crop_node.out_port(0).connect(consumer_port)

View File

@@ -17,7 +17,7 @@ import unittest
import numpy as np
from extensions.back.CutMemory import CutMemory
from extensions.back.CutMemory import CutMemoryInput, CutMemoryOutput
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph
@@ -29,18 +29,21 @@ class CutMemoryTest(unittest.TestCase):
nodes_attrs={
'input': {'kind': 'op'},
'data_in': {'kind': 'data', 'shape': None, 'value': None},
'memory_in': {'kind': 'op', 'op': 'Memory', 'index': 1, 'id': 'memory_', 'in_ports_count': 1},
'const_0': {'kind': 'op', 'op': 'Const'},
'const_0_data': {'kind': 'data'},
'memory_in': {'kind': 'op', 'op': 'ReadValue', 'variable_id': 'memory_'},
'data_mem': {'kind': 'data', 'shape': None, 'value': None},
'concat': {'kind': 'op', 'op': 'Concat', 'axis': 0},
'concat_data': {'kind': 'data', 'shape': None, 'value': None},
'some_op': {'kind': 'op'},
'some_op_data': {'kind': 'data', 'shape': None, 'value': None},
'memory_out': {'kind': 'op', 'op': 'Memory', 'index': 0, 'id': 'memory_'},
'memory_out': {'kind': 'op', 'op': 'Assign', 'variable_id': 'memory_'},
'data_mem_out': {'kind': 'data', 'shape': None, 'value': None},
'mem_out_result': {'kind': 'op', 'op': 'Result'}
},
edges=[
('input', 'data_in'), ('memory_in', 'data_mem'),
('input', 'data_in'),
('const_0', 'const_0_data'), ('const_0_data', 'memory_in'), ('memory_in', 'data_mem'),
('data_in', 'concat', {'in': 0}), ('data_mem', 'concat', {'in': 1}),
('concat', 'concat_data'), ('concat_data', 'some_op'),
('some_op', 'some_op_data'), ('some_op_data', 'memory_out'),
@@ -69,7 +72,8 @@ class CutMemoryTest(unittest.TestCase):
('crop', 'crop_data'), ('crop_data', 'mem_out_result')
],
)
CutMemory().find_and_replace_pattern(graph)
CutMemoryInput().find_and_replace_pattern(graph)
CutMemoryOutput().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, last_node='mem_out_result', check_op_attrs=True)
self.assertTrue(flag, resp)

View File

@@ -0,0 +1,140 @@
"""
Copyright (C) 2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from extensions.back.CutMemory import CutMemoryInput, CutMemoryOutput
from mo.back.replacement import BackReplacementPattern
from mo.graph.graph import Graph
from mo.ops.memory import Memory
"""
All transformations in this file should be removed after removing IR v7 support
"""
class ReplaceReadValueByMemory(BackReplacementPattern):
    """Swap every ReadValue node for a legacy Memory (read side).

    Transitional pass for IR v7 generation; delete it together with v7 IR
    support.
    """
    enabled = True
    graph_condition = [lambda graph: not graph.graph['cmd_params'].generate_experimental_IR_V10]
    force_clean_up = True

    def run_after(self):
        return [CutMemoryInput]

    @staticmethod
    def pattern():
        return dict(nodes=[('op', dict(kind='op', op='ReadValue'))], edges=[])

    @staticmethod
    def replace_pattern(graph: Graph, match: dict):
        read_value = match['op']
        variable_id = read_value['variable_id']

        # Memory has no init-value input, so drop the subgraph feeding ReadValue.
        read_value.in_port(0).disconnect()

        # Legacy Memory keeps only the per-sample shape: the leading (batch)
        # dimension of the ReadValue output is stripped.
        memory_node = Memory(graph, {'name': read_value.id, 'id': variable_id, 'index': 1, 'size': 2,
                                     'shape': list(read_value.out_port(0).data.get_shape())[1:]}).create_node()

        # Re-point every consumer of the ReadValue output to the new Memory.
        for consumer in read_value.out_port(0).get_destinations():
            consumer.disconnect()
            memory_node.out_port(0).connect(consumer)
class ReplaceAssignByMemory(BackReplacementPattern):
    """Swap every Assign node for a legacy Memory (write side).

    Transitional pass for IR v7 generation; delete it together with v7 IR
    support.
    """
    enabled = True
    graph_condition = [lambda graph: not graph.graph['cmd_params'].generate_experimental_IR_V10]
    force_clean_up = True

    def run_after(self):
        return [CutMemoryOutput]

    @staticmethod
    def pattern():
        return dict(nodes=[('op', dict(kind='op', op='Assign'))], edges=[])

    @staticmethod
    def replace_pattern(graph: Graph, match: dict):
        assign_node = match['op']
        variable_id = assign_node['variable_id']

        # Legacy Memory keeps only the per-sample shape (batch dim stripped).
        memory_node = Memory(graph, {'name': assign_node.id, 'id': variable_id, 'index': 0, 'size': 2,
                                     'shape': list(assign_node.out_port(0).data.get_shape())[1:]}).create_node()

        # Feed the Memory from the Assign's producer, then detach the Assign
        # and hand its consumers over to the Memory output.
        assign_node.in_port(0).get_source().connect(memory_node.in_port(0))
        assign_node.in_port(0).disconnect()
        assign_node.out_port(0).get_connection().set_source(memory_node.out_port(0))
class KaldiRemoveMemoryOutputBackReplacementPatternV7(BackReplacementPattern):
    """Strip the Result node that terminates each output Memory (v7 IR only)."""
    enabled = True
    graph_condition = [lambda graph: not graph.graph['cmd_params'].generate_experimental_IR_V10]

    def run_after(self):
        from extensions.back.pass_separator import BackFinish
        return [BackFinish]

    def run_before(self):
        from extensions.back.SpecialNodesFinalization import RemoveOutputOps
        return [RemoveOutputOps]

    @staticmethod
    def pattern():
        return dict(
            nodes=[
                ('memory_node', dict(op='Memory')),
                ('data_node', dict(kind='data')),
                ('op_output', dict(op='Result'))
            ],
            edges=[
                ('memory_node', 'data_node'),
                ('data_node', 'op_output')
            ]
        )

    @staticmethod
    def replace_pattern(graph: Graph, match: dict):
        """Remove the ``Memory -> Data -> Result`` tail after the Memory node.

        Memory nodes must stay in the IR, but by default they are graph
        outputs, so each carries a trailing Result that has to go.

        WARNING: do not run graph clean-up afterwards — the Memory nodes are
        no longer on an input-to-output path and would be swept away.

        :param graph: graph with the loaded model
        :param match: nodes matched by :meth:`pattern`
        """
        memory_node = match['memory_node']
        data_node = match['data_node']
        result_node = match['op_output']

        # Detach the data node from the Memory, then delete the whole tail.
        graph.remove_edge(memory_node.id, data_node.id)
        graph.remove_node(data_node.id)
        graph.remove_node(result_node.id)

View File

@@ -33,7 +33,7 @@ class KaldiRemoveMemoryOutputBackReplacementPattern(BackReplacementPattern):
def pattern():
return dict(
nodes=[
('memory_node', dict(op='Memory')),
('memory_node', dict(op='Assign')),
('data_node', dict(kind='data')),
('op_output', dict(op='Result'))
],
@@ -63,6 +63,8 @@ class KaldiRemoveMemoryOutputBackReplacementPattern(BackReplacementPattern):
"""
memory = match['memory_node']
data = match['data_node']
op_output = match['op_output']
graph.remove_edge(memory.id, data.id)
graph.remove_node(data.id)
graph.remove_node(op_output.id)

View File

@@ -26,7 +26,7 @@ class KaldiRemoveMemoryOutputTest(unittest.TestCase):
'kind': 'data'
},
'memory_node': {
'op': 'Memory',
'op': 'Assign',
'kind': 'op'
},
'output_node': {

View File

@@ -63,7 +63,7 @@ def apply_biases_to_last_layer(graph, counts):
outputs_ids = find_outputs(graph)
for output in outputs_ids.copy():
node = Node(graph, output)
if node.op != 'Memory':
if node.op != 'Assign':
continue
outputs_ids.remove(output)

View File

@@ -13,12 +13,12 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
from extensions.ops.elementwise import Add, Mul
from extensions.ops.split import Split
from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.replacement import FrontReplacementOp
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Node, Graph
from mo.ops.eltwise import Eltwise
from mo.ops.eltwise_n import EltwiseN
from mo.utils.error import Error
@@ -43,8 +43,12 @@ class ReplaceEltwiseNin1NodePattern(FrontReplacementOp):
edge_attrs = inp[0][1]
graph.add_edge(in_node, ss_node.id, **edge_attrs)
if ss_node.num_splits == 2:
eltwise_node = Eltwise(graph, attrs={'name': 'Eltwise_' + node.name,
'operation': node['operation']}).create_node()
if node['operation'] == 'mul':
eltwise_node = Mul(graph, attrs={'name': 'Eltwise_' + node.name}).create_node()
elif node['operation'] == 'sum':
eltwise_node = Add(graph, attrs={'name': 'Eltwise_' + node.name}).create_node()
else:
raise Error('Error on replacing Kaldi eltwise: unknown type ' + node['operation'])
elif ss_node.num_splits > 2:
eltwise_node = EltwiseN(graph, attrs={'name': 'Eltwise_' + node.name,
'operation': node['operation']}).create_node()

View File

@@ -20,13 +20,19 @@ from extensions.ops.activation_ops import Tanh, Sigmoid
from extensions.ops.elementwise import Add, Mul
from extensions.ops.split import Split
from mo.front.caffe.extractors.utils import input_as_const
from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.replacement import FrontReplacementOp
from mo.graph.graph import Node, Graph
from mo.graph.graph import Node, Graph, Port
from mo.ops.assign import Assign
from mo.ops.broadcast import Broadcast
from mo.ops.clamp import Clamp
from mo.ops.crop import Crop
from mo.ops.concat import Concat
from mo.ops.const import Const
from mo.ops.memory import Memory
from mo.ops.read_value import ReadValue
from mo.ops.result import Result
from mo.ops.scale_shift import ScaleShiftOp
from mo.ops.shape import Shape
def unique_id(prefix: str = 'id') -> str:
@@ -46,6 +52,35 @@ def unique_id(prefix: str = 'id') -> str:
unique_id.names = []
def create_zero_value_with_batch_from_input(input_out_port: Port, second_dim, precision=float):
    """Build a subgraph that produces a zero tensor of shape ``[batch, second_dim]``.

    The batch size is taken dynamically from the tensor on ``input_out_port``
    (Shape -> Crop of dimension 0), concatenated with the constant
    ``second_dim``, and used to Broadcast a scalar zero. The resulting node is
    typically connected as the init-value input of a ReadValue.

    :param input_out_port: source port whose runtime shape supplies the batch dim
    :param second_dim: size of the second (per-sample) dimension
    :param precision: element type of the zero filler, anything np.array accepts
                      as dtype. Default is the builtin ``float`` — ``np.float``
                      was merely an alias for it and is removed in NumPy >= 1.24.
    :return: the Broadcast node producing the zero tensor
    """
    # create init_graph connected to ReadValue
    graph = input_out_port.node.graph
    input_name = input_out_port.node.name

    # Runtime shape of the input tensor.
    shape_of_input = Shape(graph, {'name': 'shape/' + input_name}).create_node()
    shape_of_input.in_port(0).connect(input_out_port)

    # Crop out element 0 of the shape vector -> the batch dimension.
    dim_for_get_batch = Const(graph, {'name': 'dim/crop_batch/' + shape_of_input.name,
                                      'value': int64_array([1]), 'shape': int64_array([1])}).create_node()
    get_batch = Crop(graph, {'name': 'crop_batch/' + shape_of_input.name,
                             'axis': int64_array([0]), 'offset': int64_array([0])
                             }).create_node()
    get_batch.in_port(0).connect(shape_of_input.out_port(0))
    get_batch.in_port(1).connect(dim_for_get_batch.out_port(0))

    # Target shape = concat([batch], [second_dim]).
    mem_shape_2nd_dim = Const(graph, {'name': 'gifo_r_weights_shape/' + input_name,
                                      'value': int64_array([second_dim]),
                                      'shape': int64_array([1])}).create_node()
    mem_shape = Concat(graph, {'name': 'gather_memory_shape/' + input_name,
                               'axis': 0, 'in_ports_count': 2}).create_node()
    mem_shape.in_port(0).connect(get_batch.out_port(0))
    mem_shape.in_port(1).connect(mem_shape_2nd_dim.out_port(0))

    # Broadcast a scalar zero of the requested precision to the target shape.
    fill_value = Const(graph, {'name': 'fill_value/' + input_name,
                               'value': np.array([0.0], precision), 'shape': int64_array([1])}).create_node()
    init_value_prev_lstm_output = Broadcast(graph, {'name': 'init_value/' + input_name,
                                                    }).create_node()
    init_value_prev_lstm_output.in_port(0).connect(fill_value.out_port(0))
    init_value_prev_lstm_output.in_port(1).connect(mem_shape.out_port(0))

    return init_value_prev_lstm_output
class ReplaceLSTMNodePattern(FrontReplacementOp):
op = "LSTMCell"
enabled = True
@@ -69,7 +104,7 @@ class ReplaceLSTMNodePattern(FrontReplacementOp):
)
def replace_op(self, graph: Graph, node: Node):
input_node = node.in_node()
input_out_port = node.in_port(0).get_source()
memory_pair_input = unique_id('id')
memory_pair_output = unique_id('id')
@@ -81,16 +116,17 @@ class ReplaceLSTMNodePattern(FrontReplacementOp):
'bias_term': True,
}
fc_layer_after_input = FullyConnected(graph, fc_layer_after_input_attrs).create_node([input_node])
fc_layer_after_input = FullyConnected(graph, fc_layer_after_input_attrs).create_node()
fc_layer_after_input.in_port(0).connect(input_out_port)
input_as_const(fc_layer_after_input, fc_layer_after_input_attrs, 1, 'weights', node.gifo_x_weights)
input_as_const(fc_layer_after_input, fc_layer_after_input_attrs, 2, 'biases', node.gifo_biases)
prev_lstm_output = Memory(graph, {'name': 'prev_memory_output',
'id': memory_pair_input,
'index': 1,
'size': 2,
'shape': np.array([node.gifo_r_weights_shape[1]], dtype=np.int64)
}).create_node()
init_value_prev_lstm_output = create_zero_value_with_batch_from_input(input_out_port,
node.gifo_r_weights_shape[1])
prev_lstm_output = ReadValue(graph, {'name': 'prev_memory_output',
'variable_id': memory_pair_input
}).create_node()
prev_lstm_output.in_port(0).connect(init_value_prev_lstm_output.out_port(0))
# *Memory(output) -> FullyConnected
fc_layer_from_prev_state_attrs = {'name': 'prev_memory_output_fullyconnected',
@@ -99,15 +135,16 @@ class ReplaceLSTMNodePattern(FrontReplacementOp):
'bias_term': False,
}
fc_layer_from_prev_state = FullyConnected(graph, fc_layer_from_prev_state_attrs).create_node(
[prev_lstm_output])
fc_layer_from_prev_state = FullyConnected(graph, fc_layer_from_prev_state_attrs).create_node()
fc_layer_from_prev_state.in_port(0).connect(prev_lstm_output.out_port(0))
input_as_const(fc_layer_from_prev_state, fc_layer_from_prev_state_attrs, 1, 'weights', node.gifo_r_weights)
# Memory -> FullyConnected \
# *Eltwise(sum)
# Input -> FullyConnected /
join_input_prev_state_sum = Add(graph, {'name': 'join_input_eltwise',
}).create_node([fc_layer_from_prev_state, fc_layer_after_input])
join_input_prev_state_sum = Add(graph, {'name': 'join_input_eltwise'}).create_node()
join_input_prev_state_sum.in_port(0).connect(fc_layer_from_prev_state.out_port(0))
join_input_prev_state_sum.in_port(1).connect(fc_layer_after_input.out_port(0))
# *Eltwise(sum) -> Split
# it is split into 4 nodes: Act, Eltw*3
@@ -120,131 +157,147 @@ class ReplaceLSTMNodePattern(FrontReplacementOp):
# |____(4)Eltwise(sum)
split_joined_input_axis = Const(graph, {'value': np.int64(1)}).create_node()
split_joined_input = Split(graph, {'name': 'join_input_split',
'num_splits': 4,
}).create_node([join_input_prev_state_sum, split_joined_input_axis])
'num_splits': 4, 'out_ports_count': 4}).create_node()
split_joined_input.in_port(0).connect(join_input_prev_state_sum.out_port(0))
split_joined_input.in_port(1).connect(split_joined_input_axis.out_port(0))
prev_lstm_state = Memory(graph, {'name': 'prev_memory_state',
'id': memory_pair_output,
'index': 1,
'size': 2,
'shape': np.array([node.input_gate_weights.shape[0]], dtype=np.int64)
}).create_node()
# prev_lstm_state = Memory(graph, {'name': 'prev_memory_state',
# 'id': memory_pair_output,
# 'index': 1,
# 'size': 2,
# 'shape': np.array([node.input_gate_weights.shape[0]], dtype=np.int64)
# }).create_node()
init_value_prev_lstm_state = create_zero_value_with_batch_from_input(split_joined_input.out_port(0),
node.input_gate_weights.shape[0])
prev_lstm_state = ReadValue(graph, {'name': 'prev_memory_state',
'variable_id': memory_pair_output}).create_node()
prev_lstm_state.in_port(0).connect(init_value_prev_lstm_state.out_port(0))
# *Memory(state) -> *ScaleShift(input)
state_input_scaleshift_attrs = {'name': 'input_scaleshift',
'bias_term': False
}
state_input_scaleshift = ScaleShiftOp(graph, state_input_scaleshift_attrs).create_node([prev_lstm_state])
state_input_scaleshift = ScaleShiftOp(graph, state_input_scaleshift_attrs).create_node()
state_input_scaleshift.in_port(0).connect(prev_lstm_state.out_port(0))
input_as_const(state_input_scaleshift, state_input_scaleshift_attrs, 1, 'weights', node.input_gate_weights)
# *Memory(state) -> *ScaleShift(forget)
state_forget_scaleshift_attrs = {'name': 'forget_scaleshift',
'bias_term': False
}
state_forget_scaleshift = ScaleShiftOp(graph, state_forget_scaleshift_attrs).create_node([prev_lstm_state])
state_forget_scaleshift = ScaleShiftOp(graph, state_forget_scaleshift_attrs).create_node()
state_forget_scaleshift.in_port(0).connect(prev_lstm_state.out_port(0))
input_as_const(state_forget_scaleshift, state_forget_scaleshift_attrs, 1, 'weights', node.forget_gate_weights)
# Split \
# (2)Eltwise(sum)
# Memory(state) -> *ScaleShift(input) /
join_prev_lstm_input_joined_input_sum = Add(graph, {'name': 'join_prev_lstm_input_joined_input_eltwise',
}).create_node([(split_joined_input, 1),
state_input_scaleshift
])
join_prev_lstm_input_joined_input_sum = Add(graph, {'name': 'join_prev_lstm_input_joined_input_eltwise'
}).create_node()
join_prev_lstm_input_joined_input_sum.in_port(0).connect(split_joined_input.out_port(1))
join_prev_lstm_input_joined_input_sum.in_port(1).connect(state_input_scaleshift.out_port(0))
# Split \
# (3)Eltwise(sum)
# Memory(state) -> *ScaleShift(forget) /
join_prev_lstm_input_joined_forget_sum = Add(graph, {'name': 'join_prev_lstm_input_joined_forget_sum',
}).create_node([(split_joined_input, 2),
state_forget_scaleshift
])
}).create_node()
join_prev_lstm_input_joined_forget_sum.in_port(0).connect(split_joined_input.out_port(2))
join_prev_lstm_input_joined_forget_sum.in_port(1).connect(state_forget_scaleshift.out_port(0))
# Split -> Tanh
remember_tahn = Tanh(graph, {'name': 'remember_tahnv'}).create_node([(split_joined_input, 0)])
remember_tahn = Tanh(graph, {'name': 'remember_tahnv'}).create_node()
remember_tahn.in_port(0).connect(split_joined_input.out_port(0))
# Split -> (2)Eltwise(sum) -> *Sigmoid
remember_sigmoid = Sigmoid(graph, {'name': 'remember_sigmoid'
}).create_node([join_prev_lstm_input_joined_input_sum])
remember_sigmoid = Sigmoid(graph, {'name': 'remember_sigmoid'}).create_node()
remember_sigmoid.in_port(0).connect(join_prev_lstm_input_joined_input_sum.out_port(0))
# Split -> (3)Eltwise(sum) -> **Sigmoid
forget_sigmoid = Sigmoid(graph, {'name': 'forget_sigmoid'
}).create_node([join_prev_lstm_input_joined_forget_sum])
forget_sigmoid = Sigmoid(graph, {'name': 'forget_sigmoid'}).create_node()
forget_sigmoid.in_port(0).connect(join_prev_lstm_input_joined_forget_sum.out_port(0))
# *Memory(state) \
# (6)Eltwise(mul)
# Split -> (3)Eltwise(sum) -> **Sigmoid /
join_forget_prev_state_mul = Mul(graph, {'name': 'join_forget_prev_state_mul',
}).create_node([forget_sigmoid, prev_lstm_state])
join_forget_prev_state_mul = Mul(graph, {'name': 'join_forget_prev_state_mul'}).create_node()
join_forget_prev_state_mul.in_port(0).connect(forget_sigmoid.out_port(0))
join_forget_prev_state_mul.in_port(1).connect(prev_lstm_state.out_port(0))
# Split -> Tahn \
# (5)Eltwise(mul)
# Split -> (2)Eltwise(sum) -> *Sigmoid /
join_remember_candidates_mul = Mul(graph, {'name': 'join_remember_candidates_mul',
}).create_node([remember_tahn, remember_sigmoid])
join_remember_candidates_mul = Mul(graph, {'name': 'join_remember_candidates_mul'}).create_node()
join_remember_candidates_mul.in_port(0).connect(remember_tahn.out_port(0))
join_remember_candidates_mul.in_port(1).connect(remember_sigmoid.out_port(0))
# (5)Eltwise(mul) \
# (7)Eltwise(sum)
# (6)Eltwise(mul) /
join_forget_remember_sum = Add(graph, {'name': 'join_forget_remember_sum',
}).create_node(
[join_forget_prev_state_mul, join_remember_candidates_mul])
join_forget_remember_sum = Add(graph, {'name': 'join_forget_remember_sum'}).create_node()
join_forget_remember_sum.in_port(0).connect(join_forget_prev_state_mul.out_port(0))
join_forget_remember_sum.in_port(1).connect(join_remember_candidates_mul.out_port(0))
# (7)Eltwise(sum) -> Clamp
join_forget_clamp = Clamp(graph, {'name': 'join_forget_clamp',
'max': node.clip_value,
'min': -node.clip_value
}).create_node(
[join_forget_remember_sum])
'min': -node.clip_value}).create_node()
join_forget_clamp.in_port(0).connect(join_forget_remember_sum.out_port(0))
#
# Clamp -> (2)Memory(state)
next_lstm_state = Memory(graph, {'name': 'next_lstm_state',
'id': memory_pair_output,
'index': 0,
'size': 2,
'shape': np.array([node.input_gate_weights.shape[0]], dtype=np.int64)
}).create_node([join_forget_clamp])
Result(graph, {'name': 'next_lstm_state_out'}).create_node([next_lstm_state])
next_lstm_state = Assign(graph, {'name': 'next_lstm_state',
'variable_id': memory_pair_output}).create_node()
next_lstm_state.in_port(0).connect(join_forget_clamp.out_port(0))
res_node = Result(graph, {'name': 'next_lstm_state_out'}).create_node()
res_node.in_port(0).connect(next_lstm_state.out_port(0))
# Clamp -> (2)Tahn
state_filtered_tahn = Tanh(graph, {'name': 'state_filtered_tahn'}).create_node([join_forget_clamp])
state_filtered_tahn = Tanh(graph, {'name': 'state_filtered_tahn'}).create_node()
state_filtered_tahn.in_port(0).connect(join_forget_clamp.out_port(0))
# Clamp -> (2)ScaleShift
clamp_scaleshift_attrs = {'name': 'clamp_scaleshift',
'bias_term': False}
clamp_scaleshift = ScaleShiftOp(graph, clamp_scaleshift_attrs).create_node([join_forget_clamp])
clamp_scaleshift = ScaleShiftOp(graph, clamp_scaleshift_attrs).create_node()
clamp_scaleshift.in_port(0).connect(join_forget_clamp.out_port(0))
input_as_const(clamp_scaleshift, clamp_scaleshift_attrs, 1, 'weights', node.output_gate_weights)
# Split \
# (4)Eltwise(sum)
# Clamp -> (2)ScaleShift /
join_next_lstm_input_joined_input_sum = Add(graph, {'name': 'join_next_lstm_input_joined_input_sum',
}).create_node([(split_joined_input, 3), clamp_scaleshift])
}).create_node()
join_next_lstm_input_joined_input_sum.in_port(0).connect(split_joined_input.out_port(3))
join_next_lstm_input_joined_input_sum.in_port(1).connect(clamp_scaleshift.out_port(0))
# (4)Eltwise(sum) -> (3)Sigmoid
output_sigmoid = Sigmoid(graph, {'name': 'output_sigmoid'}).create_node([join_next_lstm_input_joined_input_sum])
output_sigmoid = Sigmoid(graph, {'name': 'output_sigmoid'}).create_node()
output_sigmoid.in_port(0).connect(join_next_lstm_input_joined_input_sum.out_port(0))
# (4)Eltwise(sum) -> (3)Sigmoid \
# (5)Eltwise(mul)
# Clamp -> (2)Tahn /
joined_output_mul = Mul(graph, {'name': 'joined_output_mul'}).create_node([state_filtered_tahn, output_sigmoid])
joined_output_mul = Mul(graph, {'name': 'joined_output_mul'}).create_node()
joined_output_mul.in_port(0).connect(state_filtered_tahn.out_port(0))
joined_output_mul.in_port(1).connect(output_sigmoid.out_port(0))
# (5)Eltwise(mul) -> (3)FullyConnected
fc_output_attrs = {'name': 'FullyConnected',
'out-size': node.projection_weights_shape[0],
'transpose_weights': True,
'bias_term': False}
fc_output = FullyConnected(graph, fc_output_attrs).create_node([joined_output_mul])
fc_output = FullyConnected(graph, fc_output_attrs).create_node()
fc_output.in_port(0).connect(joined_output_mul.out_port(0))
input_as_const(fc_output, fc_output_attrs, 1, 'weights', node.projection_weights)
# / (2)Memory(output)
# (3)FullyConnected
# \ Output (any next node) (edge created automatically after replacement)
next_lstm_output = Memory(graph, {'name': 'next_lstm_output',
'id': memory_pair_input,
'index': 0,
'size': 2,
'shape': np.array([node.gifo_r_weights_shape[1]], dtype=np.int64)
}).create_node([fc_output])
Result(graph, {'name': 'next_lstm_output_out'}).create_node([next_lstm_output])
next_lstm_output = Assign(graph, {'name': 'next_lstm_output',
'variable_id': memory_pair_input}).create_node()
next_lstm_output.in_port(0).connect(fc_output.out_port(0))
res_node_lstm_output = Result(graph, {'name': 'next_lstm_output_out'}).create_node()
res_node_lstm_output.in_port(0).connect(next_lstm_output.out_port(0))
return [fc_output.id]

View File

@@ -22,7 +22,7 @@ from mo.front.common.replacement import FrontReplacementOp
from mo.graph.graph import Node, Graph
from mo.ops.concat import Concat
from mo.ops.const import Const
from mo.ops.eltwise import Eltwise
from extensions.ops.elementwise import Add, Mul
from mo.ops.scale_shift import ScaleShiftOp
@@ -41,19 +41,19 @@ class ReplaceLstmNonLinearityPattern(FrontReplacementOp):
def replace_op(self, graph: Graph, node: Node):
# split input to (i_part, f_part, c_part, o_part, ct_1)
split_node_axis = Const(graph, {'value': np.int64(1)}).create_node()
split_node = Split(graph, {'name': graph.unique_id(prefix='Split_lstm_input_'),
split_node = Split(graph, {'name': 'Split_lstm_input_',
'num_splits': 5}).create_node()
node.in_port(0).get_connection().set_destination(split_node.in_port(0))
split_node.in_port(1).connect(split_node_axis.out_port(0))
# i_t = Sigmoid(i_part + w_ic*ct_1)
i_scale_attrs = {'name': graph.unique_id(prefix='i_scaleshift'),
i_scale_attrs = {'name': 'i_scaleshift',
'bias_term': False}
i_scale = ScaleShiftOp(graph, i_scale_attrs).create_node()
input_as_const(i_scale, i_scale_attrs, 1, 'weights', node.i_weights)
split_node.out_port(4).connect(i_scale.in_port(0))
sum_i_c = Eltwise(graph, {'name': graph.unique_id(prefix='sum_i_c_'), 'operation': 'sum'}).create_node()
sum_i_c = Add(graph, {'name': 'sum_i_c_'}).create_node()
split_node.out_port(0).connect(sum_i_c.in_port(0))
i_scale.out_port(0).connect(sum_i_c.in_port(1))
@@ -61,13 +61,13 @@ class ReplaceLstmNonLinearityPattern(FrontReplacementOp):
sum_i_c.out_port(0).connect(i_sigmoid.in_port(0))
# f_t = Sigmoid(f_part + w_fc*ct_1)
f_scale_attrs = {'name': graph.unique_id(prefix='f_scaleshift'),
f_scale_attrs = {'name': 'f_scaleshift',
'bias_term': False}
f_scale = ScaleShiftOp(graph, f_scale_attrs).create_node()
input_as_const(f_scale, f_scale_attrs, 1, 'weights', node.f_weights)
split_node.out_port(4).connect(f_scale.in_port(0))
sum_f_c = Eltwise(graph, {'name': graph.unique_id(prefix='sum_f_c_'), 'operation': 'sum'}).create_node()
sum_f_c = Add(graph, {'name': 'sum_f_c_'}).create_node()
split_node.out_port(1).connect(sum_f_c.in_port(0))
f_scale.out_port(0).connect(sum_f_c.in_port(1))
@@ -78,28 +78,26 @@ class ReplaceLstmNonLinearityPattern(FrontReplacementOp):
c_tanh = Tanh(graph, {'name': 'c_tanh'}).create_node()
split_node.out_port(2).connect(c_tanh.in_port(0))
prod_i_c_tanh = Eltwise(graph, {'name': graph.unique_id(prefix='prod_i_c_tanh_'),
'operation': 'mul'}).create_node()
prod_i_c_tanh = Mul(graph, {'name': 'prod_i_c_tanh_'}).create_node()
i_sigmoid.out_port(0).connect(prod_i_c_tanh.in_port(0))
c_tanh.out_port(0).connect(prod_i_c_tanh.in_port(1))
prod_f_ct_1 = Eltwise(graph, {'name': graph.unique_id(prefix='prod_f_ct_1_'),
'operation': 'mul'}).create_node()
prod_f_ct_1 = Mul(graph, {'name': 'prod_f_ct_1_'}).create_node()
f_sigmoid.out_port(0).connect(prod_f_ct_1.in_port(0))
split_node.out_port(4).connect(prod_f_ct_1.in_port(1))
sum_f_i = Eltwise(graph, {'name': graph.unique_id(prefix='sum_f_i_'), 'operation': 'sum'}).create_node()
sum_f_i = Add(graph, {'name': 'sum_f_i_'}).create_node()
prod_f_ct_1.out_port(0).connect(sum_f_i.in_port(0))
prod_i_c_tanh.out_port(0).connect(sum_f_i.in_port(1))
# o_t = Sigmoid(o_part + w_oc*c_t)
o_scale_attrs = {'name': graph.unique_id(prefix='o_scaleshift'),
o_scale_attrs = {'name': 'o_scaleshift',
'bias_term': False}
o_scale = ScaleShiftOp(graph, o_scale_attrs).create_node()
input_as_const(o_scale, o_scale_attrs, 1, 'weights', node.o_weights)
sum_f_i.out_port(0).connect(o_scale.in_port(0))
sum_o_c = Eltwise(graph, {'name': graph.unique_id(prefix='sum_o_c_'), 'operation': 'sum'}).create_node()
sum_o_c = Add(graph, {'name': 'sum_o_c_'}).create_node()
split_node.out_port(3).connect(sum_o_c.in_port(0))
o_scale.out_port(0).connect(sum_o_c.in_port(1))
@@ -110,13 +108,12 @@ class ReplaceLstmNonLinearityPattern(FrontReplacementOp):
c_t_tanh = Tanh(graph, {'name': 'c_t_tanh'}).create_node()
sum_f_i.out_port(0).connect(c_t_tanh.in_port(0))
prod_o_c_t_tanh = Eltwise(graph, {'name': graph.unique_id(prefix='prod_o_c_t_tanh_'),
'operation': 'mul'}).create_node()
prod_o_c_t_tanh = Mul(graph, {'name': 'prod_o_c_t_tanh_'}).create_node()
o_sigmoid.out_port(0).connect(prod_o_c_t_tanh.in_port(0))
c_t_tanh.out_port(0).connect(prod_o_c_t_tanh.in_port(1))
# add concat to create 1 output
concat = Concat(graph, {'name': graph.unique_id(prefix='Concat_c_m')}).create_node()
concat = Concat(graph, {'name': 'Concat_c_m'}).create_node()
concat.add_sequence_of_ports('in', range(2))
sum_f_i.out_port(0).connect(concat.in_port(0))
prod_o_c_t_tanh.out_port(0).connect(concat.in_port(1))

View File

@@ -15,15 +15,18 @@
"""
import numpy as np
from extensions.front.kaldi.replace_lstm_node_pattern import create_zero_value_with_batch_from_input
from extensions.ops.elementwise import Equal
from extensions.ops.select import Select
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Graph, Node
from mo.middle.pattern_match import find_pattern_matches, inverse_dict
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.assign import Assign
from mo.ops.concat import Concat
from mo.ops.const import Const
from mo.ops.crop import Crop
from mo.ops.memory import Memory
from mo.ops.read_value import ReadValue
from mo.ops.result import Result
from mo.utils.error import Error
from mo.utils.graph import invert_sub_graph_between_nodes
@@ -48,7 +51,7 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
@staticmethod
def pattern():
return dict(
nodes=[('op', dict(op='Memory', index=0))],
nodes=[('op', dict(op='Assign'))],
edges=[])
@staticmethod
@@ -93,9 +96,8 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
select_node.in_port(2).connect(zero_else.out_port(0))
# check if we have already appropriate iteration counter
existing_counters = find_pattern_matches(graph, nodes=[('mem_in', dict(op='Memory', index=1,
shape=int64_array([context_len]))),
('mem_in_data', dict()),
existing_counters = find_pattern_matches(graph, nodes=[('mem_in', dict(op='ReadValue')),
('mem_in_data', dict(shape=int64_array([context_len]))),
('crop_mem_in', dict(op='Crop', axis=int64_array([1]),
offset=int64_array([1]),
dim=int64_array([context_len-1]))),
@@ -104,8 +106,7 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
('concat_data', dict()),
('const_1', dict(op='Const')),
('const_1_data', dict()),
('mem_out', dict(op='Memory', index=0,
shape=int64_array([context_len]))),
('mem_out', dict(op='Assign')),
('crop_out', dict(op='Crop', axis=int64_array([1]),
offset=int64_array([0]),
dim=int64_array([1]))),
@@ -122,12 +123,13 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
('crop_out_data', 'select')])
counter_match = next(existing_counters, None)
if counter_match is not None:
ones = Node(graph, inverse_dict(counter_match)['const_1'])
input_port = Node(graph, inverse_dict(counter_match)['crop_out']).out_port(0)
else:
mem_out = Memory(graph, {'name': 'iteration_number', 'size': 2,
'index': 1, 'id': 'iteration_' + node.name,
'shape': int64_array([context_len]),
'dst_type': np.int32}).create_node()
init_value_mem_out = create_zero_value_with_batch_from_input(in_node_port, context_len, np.int32)
mem_out = ReadValue(graph, {'name': 'iteration_number',
'variable_id': 'iteration_'+node.name}).create_node()
mem_out.in_port(0).connect(init_value_mem_out.out_port(0))
cut_first = Crop(graph, {'name': 'cut_first', 'axis': int64_array([1]),
'offset': int64_array([1]), 'dim': int64_array([context_len-1])}).create_node()
cut_first.in_port(0).connect(mem_out.out_port(0))
@@ -135,9 +137,8 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
concat = Concat(graph, {'name': 'concat_ones', 'in_ports_count': 2, 'axis': 1}).create_node()
concat.in_port(0).connect(cut_first.out_port(0))
concat.in_port(1).connect(ones.out_port(0))
mem_in = Memory(graph, {'name': 'iteration_number_out', 'size': 2,
'index': 0, 'id': 'iteration_' + node.name,
'shape': int64_array([context_len])}).create_node()
mem_in = Assign(graph, {'name': 'iteration_number_out',
'variable_id': 'iteration_'+node.name}).create_node()
mem_in.in_port(0).connect(concat.out_port(0))
res = Result(graph, {}).create_node()
mem_in.out_port(0).connect(res.in_port(0))
@@ -146,6 +147,12 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
cut_last.in_port(0).connect(concat.out_port(0))
input_port = cut_last.out_port(0)
select_node.in_port(0).connect(input_port)
# Check if data from memory is 1
# if it is True, we have correct data and should proceed with saving it to memory
# else we have not gathered context and have garbage here, shouldn't change initial state of memory
cast_in = Equal(graph, {'name': input_port.node.name + '/cast_to_bool'}).create_node()
cast_in.in_port(0).connect(ones.out_port(0))
cast_in.in_port(1).connect(input_port)
select_node.in_port(0).connect(cast_in.out_port(0))
select_node.out_port(0).connect(node.in_port(0))
select_node.out_port(0).data.set_shape(in_node_shape)

View File

@@ -30,23 +30,14 @@ class InsertSelectTests(unittest.TestCase):
graph = build_graph({'in_node': {'kind': 'data', 'shape': [1, 13]},
'placeholder_1': {'kind': 'op', 'op': None},
'placeholder_data_1': {'kind': 'data', 'shape': [1, 13]},
'memory': {'kind': 'op', 'op': 'Memory', 'index': 0},
'memory': {'kind': 'op', 'op': 'Assign'},
},
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
('placeholder_data_1', 'memory')
],
nodes_with_edges_only=True)
ref_graph = graph.copy()
AddSelectBeforeMemoryNodePattern().find_and_replace_pattern(graph)
ref_graph = build_graph({'in_node': {'kind': 'data', 'shape': [1, 13]},
'placeholder_1': {'kind': 'op', 'op': None},
'placeholder_data_1': {'kind': 'data', 'shape': [1, 13]},
'memory': {'kind': 'op', 'op': 'Memory', 'index': 0},
},
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
('placeholder_data_1', 'memory')
],
nodes_with_edges_only=True
)
(flag, resp) = compare_graphs(graph, ref_graph, 'memory')
self.assertTrue(flag, resp)
@@ -60,7 +51,7 @@ class InsertSelectTests(unittest.TestCase):
'splice_data_1': {'kind': 'data', 'shape': [1, 13]},
'placeholder_2': {'kind': 'op', 'op': None},
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
'memory': {'kind': 'op', 'op': 'Memory', 'index': 0},
'memory': {'kind': 'op', 'op': 'Assign', 'index': 0},
},
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
@@ -76,15 +67,32 @@ class InsertSelectTests(unittest.TestCase):
'splice_data_1': {'kind': 'data', 'shape': [1, 13]},
'placeholder_2': {'kind': 'op', 'op': None},
'memory_in': {'kind': 'op', 'op': 'Memory', 'shape': int64_array([5])},
'shape': {'kind': 'op', 'op': 'ShapeOf'},
'shape_data': {'kind': 'data'},
'crop_batch': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
'crop_batch_data': {'kind': 'data'},
'crop_batch_dim':{'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
'crop_batch_dim_data': {'kind': 'data'},
'second_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([5])},
'second_dim_data': {'kind': 'data'},
'gather_shape': {'kind': 'op', 'op': 'Concat'},
'gather_shape_data': {'kind': 'data'},
'fill_value': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
'fill_value_data': {'kind': 'data'},
'broadcast': {'kind': 'op', 'op': 'Broadcast'},
'broadcast_data': {'kind': 'data'},
'memory_in': {'kind': 'op', 'op': 'ReadValue', 'shape': int64_array([5])},
'memory_in_data': {'kind': 'data'},
'memory_out': {'kind': 'op', 'op': 'Memory', 'shape': int64_array([5])},
'memory_out': {'kind': 'op', 'op': 'Assign', 'shape': int64_array([5])},
'memory_out_data': {'kind': 'data'},
'result': {'kind': 'op', 'op': 'Result'},
'crop_in': {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 1, 'dim': 4},
'crop_in_data': {'kind': 'data'},
'crop_out': {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 0, 'dim': 1},
'crop_out_data': {'kind': 'data'},
'equal': {'kind': 'op', 'op': 'Equal'},
'equal_data': {'kind': 'data'},
'select': {'kind': 'op', 'op': 'Select'},
'select_out_data': {'kind': 'data', 'shape': [1, 26]},
'const_0': {'kind': 'op', 'op': 'Const'},
@@ -95,22 +103,34 @@ class InsertSelectTests(unittest.TestCase):
'concat_data': {'kind': 'data'},
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
'memory': {'kind': 'op', 'op': 'Memory', 'index': 0},
'memory': {'kind': 'op', 'op': 'Assign'},
},
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
('splice_data_1', 'placeholder_2'), ('placeholder_2', 'placeholder_data_2'),
('placeholder_data_2', 'select', {'in': 1}),
('placeholder_data_2', 'shape'), ('shape', 'shape_data'),
('shape_data', 'crop_batch'), ('crop_batch', 'crop_batch_data'),
('crop_batch_dim', 'crop_batch_dim_data'),
('crop_batch_dim_data', 'crop_batch', {'in': 1}),
('second_dim', 'second_dim_data'), ('second_dim_data', 'gather_shape', {'in': 1}),
('crop_batch_data', 'gather_shape', {'in': 0}), ('gather_shape', 'gather_shape_data'),
('fill_value', 'fill_value_data'), ('fill_value_data', 'broadcast', {'in': 0}),
('gather_shape_data', 'broadcast', {'in': 1}), ('broadcast', 'broadcast_data'),
('broadcast_data', 'memory_in'),
('memory_in', 'memory_in_data'), ('memory_in_data', 'crop_in'),
('crop_in', 'crop_in_data'), ('crop_in_data', 'concat', {'in': 0}),
('const_1', 'const_1_data'), ('const_1_data', 'concat', {'in': 1}),
('concat', 'concat_data'), ('concat_data', 'memory_out'),
('memory_out', 'memory_out_data'), ('memory_out_data', 'result'),
('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
('crop_out_data', 'select', {'in': 0}),
('const_0', 'const_0_data'), ('const_0_data', 'select', {'in': 2}),
('crop_out_data', 'equal', {'in': 1}), ('const_1_data', 'equal', {'in': 0}),
('equal', 'equal_data'),
('equal_data', 'select', {'in': 0}),
('const_0', 'const_0_data'), ('const_0_data', 'select', {'in': 2}),
('select', 'select_out_data'),
('select_out_data', 'memory')
],
@@ -132,7 +152,7 @@ class InsertSelectTests(unittest.TestCase):
'splice_data_2': {'kind': 'data', 'shape': [1, 39]},
'placeholder_2': {'kind': 'op', 'op': None},
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
'memory': {'kind': 'op', 'op': 'Memory', 'index': 0},
'memory': {'kind': 'op', 'op': 'Assign'},
},
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
@@ -151,15 +171,32 @@ class InsertSelectTests(unittest.TestCase):
'splice_data_2': {'kind': 'data', 'shape': [1, 39]},
'placeholder_2': {'kind': 'op', 'op': None},
'memory_in': {'kind': 'op', 'op': 'Memory', 'shape': int64_array([5])},
'shape': {'kind': 'op', 'op': 'ShapeOf'},
'shape_data': {'kind': 'data'},
'crop_batch': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
'crop_batch_data': {'kind': 'data'},
'crop_batch_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
'crop_batch_dim_data': {'kind': 'data'},
'second_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([5])},
'second_dim_data': {'kind': 'data'},
'gather_shape': {'kind': 'op', 'op': 'Concat'},
'gather_shape_data': {'kind': 'data'},
'fill_value': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
'fill_value_data': {'kind': 'data'},
'broadcast': {'kind': 'op', 'op': 'Broadcast'},
'broadcast_data': {'kind': 'data'},
'memory_in': {'kind': 'op', 'op': 'ReadValue'},
'memory_in_data': {'kind': 'data'},
'memory_out': {'kind': 'op', 'op': 'Memory', 'shape': int64_array([5])},
'memory_out': {'kind': 'op', 'op': 'Assign'},
'memory_out_data': {'kind': 'data'},
'result': {'kind': 'op', 'op': 'Result'},
'crop_in': {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 1, 'dim': 4},
'crop_in_data': {'kind': 'data'},
'crop_out': {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 0, 'dim': 1},
'crop_out_data': {'kind': 'data'},
'equal': {'kind': 'op', 'op': 'Equal'},
'equal_data': {'kind': 'data'},
'select': {'kind': 'op', 'op': 'Select'},
'select_out_data': {'kind': 'data', 'shape': [1, 26]},
'const_0': {'kind': 'op', 'op': 'Const'},
@@ -170,7 +207,7 @@ class InsertSelectTests(unittest.TestCase):
'concat_data': {'kind': 'data'},
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
'memory': {'kind': 'op', 'op': 'Memory', 'index': 0},
'memory': {'kind': 'op', 'op': 'Assign'},
},
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
@@ -178,13 +215,25 @@ class InsertSelectTests(unittest.TestCase):
('splice_data_1', 'placeholder_2'), ('placeholder_2', 'placeholder_data_2'),
('placeholder_data_2', 'select', {'in': 1}),
('placeholder_data_2', 'shape'), ('shape', 'shape_data'),
('shape_data', 'crop_batch'), ('crop_batch', 'crop_batch_data'),
('crop_batch_dim', 'crop_batch_dim_data'),
('crop_batch_dim_data', 'crop_batch', {'in': 1}),
('second_dim', 'second_dim_data'), ('second_dim_data', 'gather_shape', {'in': 1}),
('crop_batch_data', 'gather_shape', {'in': 0}), ('gather_shape', 'gather_shape_data'),
('fill_value', 'fill_value_data'), ('fill_value_data', 'broadcast', {'in': 0}),
('gather_shape_data', 'broadcast', {'in': 1}), ('broadcast', 'broadcast_data'),
('broadcast_data', 'memory_in'),
('memory_in', 'memory_in_data'), ('memory_in_data', 'crop_in'),
('crop_in', 'crop_in_data'), ('crop_in_data', 'concat', {'in': 0}),
('const_1', 'const_1_data'), ('const_1_data', 'concat', {'in': 1}),
('concat', 'concat_data'), ('concat_data', 'memory_out'),
('memory_out', 'memory_out_data'), ('memory_out_data', 'result'),
('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
('crop_out_data', 'select', {'in': 0}),
('crop_out_data', 'equal', {'in': 1}), ('const_1_data', 'equal', {'in': 0}),
('equal', 'equal_data'),
('equal_data', 'select', {'in': 0}),
('const_0', 'const_0_data'), ('const_0_data', 'select', {'in': 2}),
('select', 'select_out_data'),
@@ -208,7 +257,7 @@ class InsertSelectTests(unittest.TestCase):
'splice_data_2': {'kind': 'data', 'shape': [1, 39]},
'placeholder_2': {'kind': 'op', 'op': None},
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
'memory': {'kind': 'op', 'op': 'Memory', 'index': 0},
'memory': {'kind': 'op', 'op': 'Assign', 'index': 0},
},
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
@@ -227,15 +276,32 @@ class InsertSelectTests(unittest.TestCase):
'splice_data_2': {'kind': 'data', 'shape': [1, 39]},
'placeholder_2': {'kind': 'op', 'op': None},
'memory_in': {'kind': 'op', 'op': 'Memory', 'shape': int64_array([7])},
'shape': {'kind': 'op', 'op': 'ShapeOf'},
'shape_data': {'kind': 'data'},
'crop_batch': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
'crop_batch_data': {'kind': 'data'},
'crop_batch_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
'crop_batch_dim_data': {'kind': 'data'},
'second_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([7])},
'second_dim_data': {'kind': 'data'},
'gather_shape': {'kind': 'op', 'op': 'Concat'},
'gather_shape_data': {'kind': 'data'},
'fill_value': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
'fill_value_data': {'kind': 'data'},
'broadcast': {'kind': 'op', 'op': 'Broadcast'},
'broadcast_data': {'kind': 'data'},
'memory_in': {'kind': 'op', 'op': 'ReadValue'},
'memory_in_data': {'kind': 'data'},
'memory_out': {'kind': 'op', 'op': 'Memory', 'shape': int64_array([7])},
'memory_out': {'kind': 'op', 'op': 'Assign'},
'memory_out_data': {'kind': 'data'},
'result': {'kind': 'op', 'op': 'Result'},
'crop_in': {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 1, 'dim': 6},
'crop_in_data': {'kind': 'data'},
'crop_out': {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 0, 'dim': 1},
'crop_out_data': {'kind': 'data'},
'equal': {'kind': 'op', 'op': 'Equal'},
'equal_data': {'kind': 'data'},
'select': {'kind': 'op', 'op': 'Select'},
'select_out_data': {'kind': 'data', 'shape': [1, 26]},
'const_0': {'kind': 'op', 'op': 'Const'},
@@ -246,7 +312,7 @@ class InsertSelectTests(unittest.TestCase):
'concat_data': {'kind': 'data'},
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
'memory': {'kind': 'op', 'op': 'Memory', 'index': 0},
'memory': {'kind': 'op', 'op': 'Assign', 'index': 0},
},
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
@@ -254,13 +320,25 @@ class InsertSelectTests(unittest.TestCase):
('splice_data_2', 'placeholder_2'), ('placeholder_2', 'placeholder_data_2'),
('placeholder_data_2', 'select', {'in': 1}),
('placeholder_data_2', 'shape'), ('shape', 'shape_data'),
('shape_data', 'crop_batch'), ('crop_batch', 'crop_batch_data'),
('crop_batch_dim', 'crop_batch_dim_data'),
('crop_batch_dim_data', 'crop_batch', {'in': 1}),
('second_dim', 'second_dim_data'), ('second_dim_data', 'gather_shape', {'in': 1}),
('crop_batch_data', 'gather_shape', {'in': 0}), ('gather_shape', 'gather_shape_data'),
('fill_value', 'fill_value_data'), ('fill_value_data', 'broadcast', {'in': 0}),
('gather_shape_data', 'broadcast', {'in': 1}), ('broadcast', 'broadcast_data'),
('broadcast_data', 'memory_in'),
('memory_in', 'memory_in_data'), ('memory_in_data', 'crop_in'),
('crop_in', 'crop_in_data'), ('crop_in_data', 'concat', {'in': 0}),
('const_1', 'const_1_data'), ('const_1_data', 'concat', {'in': 1}),
('concat', 'concat_data'), ('concat_data', 'memory_out'),
('memory_out', 'memory_out_data'), ('memory_out_data', 'result'),
('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
('crop_out_data', 'select', {'in': 0}),
('crop_out_data', 'equal', {'in': 1}), ('const_1_data', 'equal', {'in': 0}),
('equal', 'equal_data'),
('equal_data', 'select', {'in': 0}),
('const_0', 'const_0_data'), ('const_0_data', 'select', {'in': 2}),
('select', 'select_out_data'),

View File

@@ -59,7 +59,18 @@ class RemoveUselessCropsPattern(MiddleReplacementPattern):
if out['op'] == 'Crop' and out['axis'] == axis and \
len(out.out_port(0).get_destinations()) == 1 and \
out.out_port(0).get_destination().node == concat_node:
offsets_dims.append((out['offset'], out['dim']))
# crop type 1
if 'dim' in out:
offsets_dims.append((out['offset'], out['dim']))
# crop type 3
elif 'crop_begin' in out and 'crop_end' in out:
offsets_dims.append((out['crop_begin'], out['crop_end']-out['crop_begin']))
# crop type 2 with const dim
elif not out.in_port(1).disconnected() and out.in_port(1).data.get_value() is not None:
offsets_dims.append((out['offset'], out.in_port(1).data.get_value()))
# crop type 2 with non-const dim or strange type of crop
else:
return
crop_list.append(out)
offsets_dims.sort(key=lambda off_dim: off_dim[0])

View File

@@ -84,6 +84,136 @@ class RemoveUselessCropsPatternTests(unittest.TestCase):
(flag, resp) = compare_graphs(graph, ref_graph, 'placeholder')
self.assertTrue(flag, resp)
def test_useless_crops_type2(self):
graph = build_graph({'placeholder_in': {'kind': 'op', 'op': 'Parameter'},
'in_node': {'kind': 'data', 'shape': [1, 130]},
'crop1': {'kind': 'op', 'op': 'Crop', 'offset': 0, 'dim': 26, 'axis': -1},
'crop_data_1': {'kind': 'data', 'shape': [1, 26]},
'const_26': {'kind': 'op', 'op': 'Const', 'value': 26},
'const_26_data': {'kind': 'data', 'value': 26},
'crop2': {'kind': 'op', 'op': 'Crop', 'offset': 26, 'axis': -1},
'crop_data_2': {'kind': 'data', 'shape': [1, 26]},
'crop3': {'kind': 'op', 'op': 'Crop', 'offset': 52, 'dim': 26, 'axis': -1},
'crop_data_3': {'kind': 'data', 'shape': [1, 26]},
'crop4': {'kind': 'op', 'op': 'Crop', 'offset': 78, 'dim': 26, 'axis': -1},
'crop_data_4': {'kind': 'data', 'shape': [1, 26]},
'crop5': {'kind': 'op', 'op': 'Crop', 'offset': 104, 'dim': 26, 'axis': -1},
'crop_data_5': {'kind': 'data', 'shape': [1, 26]},
'concat': {'kind': 'op', 'op': 'Concat'},
'concat_data': {'kind': 'data', 'shape': [1, 130]},
'placeholder': {'kind': 'op', 'op': 'Parameter'},
},
[('placeholder_in', 'in_node'),
('in_node', 'crop1'), ('crop1', 'crop_data_1'),
('in_node', 'crop2', {'in': 0}), ('const_26', 'const_26_data'),
('const_26_data', 'crop2', {'in': 1}), ('crop2', 'crop_data_2'),
('in_node', 'crop3'), ('crop3', 'crop_data_3'),
('in_node', 'crop4'), ('crop4', 'crop_data_4'),
('in_node', 'crop5'), ('crop5', 'crop_data_5'),
('crop_data_1', 'concat'),
('crop_data_2', 'concat'),
('crop_data_3', 'concat'),
('crop_data_4', 'concat'),
('crop_data_5', 'concat'),
('concat', 'concat_data'),
('concat_data', 'placeholder')])
RemoveUselessCropsPattern().find_and_replace_pattern(graph)
ref_graph = build_graph({'placeholder_in': {'kind': 'op', 'op': 'Parameter'},
'in_node': {'kind': 'data', 'shape': [1, 130]},
'crop1': {'kind': 'op', 'op': 'Crop', 'offset': 0, 'dim': 26, 'axis': -1},
'crop_data_1': {'kind': 'data', 'shape': [1, 26]},
'const_26': {'kind': 'op', 'op': 'Const', 'value': 26},
'const_26_data': {'kind': 'data', 'value': 26},
'crop2': {'kind': 'op', 'op': 'Crop', 'offset': 26, 'dim': 26, 'axis': -1},
'crop_data_2': {'kind': 'data', 'shape': [1, 26]},
'crop3': {'kind': 'op', 'op': 'Crop', 'offset': 52, 'dim': 26, 'axis': -1},
'crop_data_3': {'kind': 'data', 'shape': [1, 26]},
'crop4': {'kind': 'op', 'op': 'Crop', 'offset': 78, 'dim': 26, 'axis': -1},
'crop_data_4': {'kind': 'data', 'shape': [1, 26]},
'crop5': {'kind': 'op', 'op': 'Crop', 'offset': 104, 'dim': 26, 'axis': -1},
'crop_data_5': {'kind': 'data', 'shape': [1, 26]},
'concat': {'kind': 'op', 'op': 'Concat'},
'concat_data': {'kind': 'data', 'shape': [1, 130]},
'placeholder': {'kind': 'op', 'op': 'Parameter'},
},
[
('placeholder_in', 'in_node'),
('in_node', 'crop1'), ('crop1', 'crop_data_1'),
('in_node', 'crop2', {'in': 0}), ('const_26', 'const_26_data'),
('const_26_data', 'crop2', {'in': 1}), ('crop2', 'crop_data_2'),
('in_node', 'crop3'), ('crop3', 'crop_data_3'),
('in_node', 'crop4'), ('crop4', 'crop_data_4'),
('in_node', 'crop5'), ('crop5', 'crop_data_5'),
('concat', 'concat_data'),
('in_node', 'placeholder')
]
)
(flag, resp) = compare_graphs(graph, ref_graph, 'placeholder')
self.assertTrue(flag, resp)
def test_useless_crops_type3(self):
graph = build_graph({'placeholder_in': {'kind': 'op', 'op': 'Parameter'},
'in_node': {'kind': 'data', 'shape': [1, 130]},
'crop1': {'kind': 'op', 'op': 'Crop', 'offset': 0, 'dim': 26, 'axis': -1},
'crop_data_1': {'kind': 'data', 'shape': [1, 26]},
'crop2': {'kind': 'op', 'op': 'Crop', 'crop_begin': 26, 'crop_end': 52, 'axis': -1},
'crop_data_2': {'kind': 'data', 'shape': [1, 26]},
'crop3': {'kind': 'op', 'op': 'Crop', 'offset': 52, 'dim': 26, 'axis': -1},
'crop_data_3': {'kind': 'data', 'shape': [1, 26]},
'crop4': {'kind': 'op', 'op': 'Crop', 'offset': 78, 'dim': 26, 'axis': -1},
'crop_data_4': {'kind': 'data', 'shape': [1, 26]},
'crop5': {'kind': 'op', 'op': 'Crop', 'offset': 104, 'dim': 26, 'axis': -1},
'crop_data_5': {'kind': 'data', 'shape': [1, 26]},
'concat': {'kind': 'op', 'op': 'Concat'},
'concat_data': {'kind': 'data', 'shape': [1, 130]},
'placeholder': {'kind': 'op', 'op': 'Parameter'},
},
[('placeholder_in', 'in_node'),
('in_node', 'crop1'), ('crop1', 'crop_data_1'),
('in_node', 'crop2'), ('crop2', 'crop_data_2'),
('in_node', 'crop3'), ('crop3', 'crop_data_3'),
('in_node', 'crop4'), ('crop4', 'crop_data_4'),
('in_node', 'crop5'), ('crop5', 'crop_data_5'),
('crop_data_1', 'concat'),
('crop_data_2', 'concat'),
('crop_data_3', 'concat'),
('crop_data_4', 'concat'),
('crop_data_5', 'concat'),
('concat', 'concat_data'),
('concat_data', 'placeholder')])
RemoveUselessCropsPattern().find_and_replace_pattern(graph)
ref_graph = build_graph({'placeholder_in': {'kind': 'op', 'op': 'Parameter'},
'in_node': {'kind': 'data', 'shape': [1, 130]},
'crop1': {'kind': 'op', 'op': 'Crop', 'offset': 0, 'dim': 26, 'axis': -1},
'crop_data_1': {'kind': 'data', 'shape': [1, 26]},
'crop2': {'kind': 'op', 'op': 'Crop', 'crop_begin': 26, 'crop_end': 52, 'axis': -1},
'crop_data_2': {'kind': 'data', 'shape': [1, 26]},
'crop3': {'kind': 'op', 'op': 'Crop', 'offset': 52, 'dim': 26, 'axis': -1},
'crop_data_3': {'kind': 'data', 'shape': [1, 26]},
'crop4': {'kind': 'op', 'op': 'Crop', 'offset': 78, 'dim': 26, 'axis': -1},
'crop_data_4': {'kind': 'data', 'shape': [1, 26]},
'crop5': {'kind': 'op', 'op': 'Crop', 'offset': 104, 'dim': 26, 'axis': -1},
'crop_data_5': {'kind': 'data', 'shape': [1, 26]},
'concat': {'kind': 'op', 'op': 'Concat'},
'concat_data': {'kind': 'data', 'shape': [1, 130]},
'placeholder': {'kind': 'op', 'op': 'Parameter'},
},
[
('placeholder_in', 'in_node'),
('in_node', 'crop1'), ('crop1', 'crop_data_1'),
('in_node', 'crop2'), ('crop2', 'crop_data_2'),
('in_node', 'crop3'), ('crop3', 'crop_data_3'),
('in_node', 'crop4'), ('crop4', 'crop_data_4'),
('in_node', 'crop5'), ('crop5', 'crop_data_5'),
('concat', 'concat_data'),
('in_node', 'placeholder')
]
)
(flag, resp) = compare_graphs(graph, ref_graph, 'placeholder')
self.assertTrue(flag, resp)
def test_useful_crops(self):
graph = build_graph({'placeholder_in': {'kind': 'op', 'op': 'Parameter'},
'in_node': {'kind': 'data', 'shape': [1, 130]},

View File

@@ -15,13 +15,15 @@
"""
import numpy as np
from extensions.front.kaldi.replace_lstm_node_pattern import create_zero_value_with_batch_from_input
from extensions.ops.splice import Splice
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Graph, Node
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.assign import Assign
from mo.ops.concat import Concat
from mo.ops.crop import Crop
from mo.ops.memory import Memory
from mo.ops.read_value import ReadValue
from mo.ops.result import Result
from mo.utils.error import Error
@@ -67,7 +69,8 @@ class ReplaceMemoryOffsetNodePattern(MiddleReplacementPattern):
splice = Splice(graph, {'name': node_name,
'id': node_id,
'context': int64_array(range(node_t, 1)) if node_t < 0 else int64_array(range(0, node_t+1))}).create_node()
'context': int64_array(range(node_t, 1))
if node_t < 0 else int64_array(range(0, node_t+1))}).create_node()
splice.in_port(0).connect(input_node_out_port)
# offset of Crop will be 0 (first element) if node_t < 0 and in_shape[1]*node_t (last element) if node_t > 0
@@ -106,6 +109,7 @@ class ReplaceMemoryOffsetWithMemoryNodePattern(MiddleReplacementPattern):
Replace MemoryOffset with Memory if IfDefined used with it to avoid cycles
"""
enabled = True
force_shape_inference = True
def run_before(self):
from extensions.middle.RemoveDuplicationMemory import RemoveMemoryDuplicationPattern
@@ -141,43 +145,34 @@ class ReplaceMemoryOffsetWithMemoryNodePattern(MiddleReplacementPattern):
in_shape = input_port.data.get_shape()
node_t = abs(node.t)
memory_out = Memory(graph, {'name': pair_name, 'id': node_name+pair_name,
'index': 1, 'size': 2,
'shape': np.array([in_shape[1]*node_t])}).create_node()
init_value_memory_out = create_zero_value_with_batch_from_input(input_port, in_shape[1]*node_t)
memory_out = ReadValue(graph, {'name': pair_name, 'variable_id': node_name+pair_name}).create_node()
init_value_memory_out.out_port(0).connect(memory_out.in_port(0))
if node_t > 1:
crop_concat = Crop(graph, {'name': 'Memory_crop', 'dim': np.array([in_shape[1]*(node_t-1)]),
'offset': np.array([in_shape[1]]), 'axis': np.array([1])}).create_node()
memory_out.out_port(0).connect(crop_concat.in_port(0))
memory_out.out_port(0).data.set_shape(np.array([in_shape[0], memory_out.shape[0]]))
concat = Concat(graph, {'name': 'Memory_concat'}).create_node()
concat.add_sequence_of_ports('in', range(2))
crop_concat.out_port(0).connect(concat.in_port(0))
crop_concat.out_port(0).data.set_shape(np.array([in_shape[0], crop_concat.dim]))
concat.in_port(1).connect(input_port)
memory_in = Memory(graph, {'name': node_name, 'id': node_name + pair_name,
'index': 0, 'size': 2,
'shape': memory_out.shape}).create_node()
memory_in = Assign(graph, {'name': node_name, 'variable_id': node_name + pair_name}).create_node()
concat.out_port(0).connect(memory_in.in_port(0))
concat.out_port(0).data.set_shape(np.array([in_shape[0], memory_in.shape[0]]))
out = Result(graph, {'name': 'Memory_output'}).create_node()
memory_in.out_port(0).connect(out.in_port(0))
memory_in.out_port(0).data.set_shape(np.array([in_shape[0], memory_out.shape[0]]))
crop_out = Crop(graph, {'name': 'Memory_crop_out', 'dim': np.array([in_shape[1]]),
'offset': np.array([0]), 'axis': np.array([1])}).create_node()
memory_out.out_port(0).connect(crop_out.in_port(0))
out_port.get_connection().set_source(crop_out.out_port(0))
crop_out.out_port(0).data.set_shape(np.array([in_shape[0], crop_out.dim]))
else:
memory_in = Memory(graph, {'name': node_name, 'id': node_name + pair_name,
'index': 0, 'size': 2,
'shape': memory_out.shape}).create_node()
memory_in = Assign(graph, {'name': node_name, 'variable_id': node_name + pair_name}).create_node()
memory_in.in_port(0).connect(input_port)
out = Result(graph, {'name': 'Memory_output'}).create_node()
memory_in.out_port(0).connect(out.in_port(0))
memory_in.out_port(0).data.set_shape(np.array([in_shape[0], memory_out.shape[0]]))
out_port.get_connection().set_source(memory_out.out_port(0))
memory_out.out_port(0).data.set_shape(np.array([in_shape[0], memory_out.shape[0]]))
graph.remove_node(op_output_id)
graph.remove_node(node.id)

View File

@@ -13,15 +13,16 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
from extensions.front.kaldi.replace_lstm_node_pattern import unique_id
from extensions.front.kaldi.replace_lstm_node_pattern import unique_id, create_zero_value_with_batch_from_input
from extensions.ops.split import VariadicSplit
from mo.front.common.partial_infer.utils import int64_array
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.assign import Assign
from mo.ops.concat import Concat
from mo.ops.crop import Crop
from mo.ops.memory import Memory
from mo.ops.read_value import ReadValue
from mo.ops.result import Result
@@ -39,7 +40,7 @@ class ReplaceSpliceNodePattern(MiddleReplacementPattern):
So this pass will convert this graph to the next one:
Input [N, H] __
\ /
/ /
Concat [N, k*H]
/ \
Memory [N, k*H] -> Slice [N, (k-1)*H] Memory [N, k*H]
@@ -67,11 +68,9 @@ class ReplaceSpliceNodePattern(MiddleReplacementPattern):
memory_pair_id = unique_id('id')
# Memory(in)
input_memory = Memory(graph, {'name': 'prev_splice_memory',
'id': memory_pair_id,
'index': 1,
'size': 2,
'shape': int64_array([memory_size])}).create_node()
input_memory = ReadValue(graph, {'name': 'prev_splice_memory',
'variable_id': memory_pair_id}).create_node()
# Memory(in) \
# Crop
# Input(temp) /
@@ -90,11 +89,7 @@ class ReplaceSpliceNodePattern(MiddleReplacementPattern):
concat_node.in_port(0).connect(crop.out_port(0))
# Concat -> Memory(out)
mem_out = Memory(graph, {'name': 'out_splice_memory',
'id': memory_pair_id,
'index': 0,
'size': 2,
'shape': int64_array([memory_size])}).create_node()
mem_out = Assign(graph, {'name': 'out_splice_memory', 'variable_id': memory_pair_id}).create_node()
mem_out.in_port(0).connect(concat_node.out_port(0))
Result(graph).create_node().in_port(0).connect(mem_out.out_port(0))
@@ -110,11 +105,12 @@ class ReplaceSpliceNodePattern(MiddleReplacementPattern):
# create separate splice construction for const_dim
memory_pair_id = unique_id('memory_for_const_dim')
input_memory_const_dim = Memory(graph, {'name': 'const_dim_in_memory',
'id': memory_pair_id,
'index': 1,
'size': 2,
'shape': int64_array([memory_size_constdim])}).create_node()
init_value_input_memory_const_dim = create_zero_value_with_batch_from_input(split.out_port(1),
memory_size_constdim)
input_memory_const_dim = ReadValue(graph, {'name': 'const_dim_in_memory',
'variable_id': memory_pair_id}).create_node()
init_value_input_memory_const_dim.out_port(0).connect(input_memory_const_dim.in_port(0))
crop_const_dim = Crop(graph, {'name': 'const_dim_crop',
'axis': int64_array([1]),
'offset': int64_array([memory_element_constdim]),
@@ -127,11 +123,8 @@ class ReplaceSpliceNodePattern(MiddleReplacementPattern):
'axis': 1}).create_node()
concat_node_const_dim.in_port(0).connect(crop_const_dim.out_port(0))
mem_out_const_dim = Memory(graph, {'name': 'const_dim_out_memory',
'id': memory_pair_id,
'index': 0,
'size': 2,
'shape': int64_array([memory_size_constdim])}).create_node()
mem_out_const_dim = Assign(graph, {'name': 'const_dim_out_memory',
'variable_id': memory_pair_id}).create_node()
mem_out_const_dim.in_port(0).connect(concat_node_const_dim.out_port(0))
Result(graph).create_node().in_port(0).connect(mem_out_const_dim.out_port(0))
@@ -148,9 +141,15 @@ class ReplaceSpliceNodePattern(MiddleReplacementPattern):
concat_const.in_port(1).connect(crop_first.out_port(0))
concat_const.in_port(0).connect(concat_node.out_port(0))
init_value_input_memory = create_zero_value_with_batch_from_input(split.out_port(0),
memory_size)
init_value_input_memory.out_port(0).connect(input_memory.in_port(0))
node.in_port(0).get_connection().set_destination(split.in_port(0))
node.out_port(0).get_connection().set_source(concat_const.out_port(0))
else:
init_value_input_memory = create_zero_value_with_batch_from_input(node.in_port(0).get_source(),
memory_size)
init_value_input_memory.out_port(0).connect(input_memory.in_port(0))
node.in_port(0).get_connection().set_destination(concat_node.in_port(1))
node.out_port(0).get_connection().set_source(concat_node.out_port(0))

View File

@@ -16,6 +16,7 @@
import unittest
from extensions.middle.ReplaceSpliceNodePattern import ReplaceSpliceNodePattern
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Node
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph
@@ -42,19 +43,47 @@ class ReplaceSpliceNodePatternTests(unittest.TestCase):
ref_graph = build_graph({'in_placeholder': {'kind': 'op', 'op': None},
'in_node': {'kind': 'data', 'shape': [1, 13]},
'memory_in': {'kind': 'op', 'op': 'Memory'},
'shape': {'kind': 'op', 'op': 'ShapeOf'},
'shape_data': {'kind': 'data'},
'crop_batch': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
'crop_batch_data': {'kind': 'data'},
'crop_batch_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
'crop_batch_dim_data': {'kind': 'data'},
'second_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([143])},
'second_dim_data': {'kind': 'data'},
'gather_shape': {'kind': 'op', 'op': 'Concat'},
'gather_shape_data': {'kind': 'data'},
'fill_value': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
'fill_value_data': {'kind': 'data'},
'broadcast': {'kind': 'op', 'op': 'Broadcast'},
'broadcast_data': {'kind': 'data'},
'memory_in': {'kind': 'op', 'op': 'ReadValue'},
'memory_in_data': {'kind': 'data'},
'crop_mem': {'kind': 'op', 'op': 'Crop', 'offset': 13, 'dim': 130},
'crop_mem_data': {'kind': 'data'},
'concat': {'kind': 'op', 'op': 'Concat'},
'concat_data': {'kind': 'data', 'shape': [1, 143]},
'memory_out': {'kind': 'op', 'op': 'Memory'},
'memory_out': {'kind': 'op', 'op': 'Assign'},
'memory_out_data': {'kind': 'data'},
'result': {'kind': 'op', 'op': 'Result'},
'out_placeholder': {'kind': 'op', 'op': 'placeholder'},
},
[
('in_placeholder', 'in_node'),
('in_node', 'shape'), ('shape', 'shape_data'),
('shape_data', 'crop_batch'), ('crop_batch', 'crop_batch_data'),
('crop_batch_dim', 'crop_batch_dim_data'),
('crop_batch_dim_data', 'crop_batch', {'in': 1}),
('second_dim', 'second_dim_data'), ('second_dim_data', 'gather_shape', {'in': 1}),
('crop_batch_data', 'gather_shape', {'in': 0}),
('gather_shape', 'gather_shape_data'),
('fill_value', 'fill_value_data'), ('fill_value_data', 'broadcast', {'in': 0}),
('gather_shape_data', 'broadcast', {'in': 1}), ('broadcast', 'broadcast_data'),
('broadcast_data', 'memory_in'),
('memory_in', 'memory_in_data'),
('memory_in_data', 'crop_mem'),
('crop_mem', 'crop_mem_data'),
@@ -86,22 +115,54 @@ class ReplaceSpliceNodePatternTests(unittest.TestCase):
'split': {'kind': 'op', 'op': 'Split'},
'split_data_0': {'kind': 'data'},
'split_data_1': {'kind': 'data'},
'memory_in': {'kind': 'op', 'op': 'Memory'},
'shape': {'kind': 'op', 'op': 'ShapeOf'},
'shape_data': {'kind': 'data'},
'crop_batch': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
'crop_batch_data': {'kind': 'data'},
'crop_batch_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
'crop_batch_dim_data': {'kind': 'data'},
'second_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([33])},
'second_dim_data': {'kind': 'data'},
'gather_shape': {'kind': 'op', 'op': 'Concat'},
'gather_shape_data': {'kind': 'data'},
'fill_value': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
'fill_value_data': {'kind': 'data'},
'broadcast': {'kind': 'op', 'op': 'Broadcast'},
'broadcast_data': {'kind': 'data'},
'memory_in': {'kind': 'op', 'op': 'ReadValue'},
'memory_in_data': {'kind': 'data'},
'crop_mem': {'kind': 'op', 'op': 'Crop', 'offset': 3, 'dim': 30},
'crop_mem_data': {'kind': 'data'},
'concat': {'kind': 'op', 'op': 'Concat'},
'concat_data': {'kind': 'data'},
'memory_out': {'kind': 'op', 'op': 'Memory'},
'memory_out': {'kind': 'op', 'op': 'Assign'},
'memory_out_data': {'kind': 'data'},
'result': {'kind': 'op', 'op': 'Result'},
'memory_in_constdims': {'kind': 'op', 'op': 'Memory'},
'shape_2': {'kind': 'op', 'op': 'ShapeOf'},
'shape_2_data': {'kind': 'data'},
'crop_batch_2': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
'crop_batch_2_data': {'kind': 'data'},
'crop_batch_dim_2': {'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
'crop_batch_dim_2_data': {'kind': 'data'},
'second_dim_2': {'kind': 'op', 'op': 'Const', 'value': int64_array([33])},
'second_dim_2_data': {'kind': 'data'},
'gather_shape_2': {'kind': 'op', 'op': 'Concat'},
'gather_shape_2_data': {'kind': 'data'},
'fill_value_2': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
'fill_value_2_data': {'kind': 'data'},
'broadcast_2': {'kind': 'op', 'op': 'Broadcast'},
'broadcast_2_data': {'kind': 'data'},
'memory_in_constdims': {'kind': 'op', 'op': 'ReadValue'},
'memory_in_constdims_data': {'kind': 'data'},
'crop_mem_constdims': {'kind': 'op', 'op': 'Crop', 'offset': 10, 'dim': 100},
'crop_mem_constdims_data': {'kind': 'data'},
'concat_constdims': {'kind': 'op', 'op': 'Concat'},
'concat_constdims_data': {'kind': 'data'},
'memory_out_constdims': {'kind': 'op', 'op': 'Memory'},
'memory_out_constdims': {'kind': 'op', 'op': 'Assign'},
'memory_out_constdims_data': {'kind': 'data'},
'result_constdims': {'kind': 'op', 'op': 'Result'},
'crop_first_constdims': {'kind': 'op', 'op': 'Crop', 'offset': 0, 'dim': 10},
@@ -121,6 +182,18 @@ class ReplaceSpliceNodePatternTests(unittest.TestCase):
('in_node', 'split', {'in': 0}),
('split', 'split_data_0', {'out': 0}),
('split', 'split_data_1', {'out': 1}),
('split_data_0', 'shape'), ('shape', 'shape_data'),
('shape_data', 'crop_batch'), ('crop_batch', 'crop_batch_data'),
('crop_batch_dim', 'crop_batch_dim_data'),
('crop_batch_dim_data', 'crop_batch', {'in': 1}),
('second_dim', 'second_dim_data'), ('second_dim_data', 'gather_shape', {'in': 1}),
('crop_batch_data', 'gather_shape', {'in': 0}),
('gather_shape', 'gather_shape_data'),
('fill_value', 'fill_value_data'), ('fill_value_data', 'broadcast', {'in': 0}),
('gather_shape_data', 'broadcast', {'in': 1}), ('broadcast', 'broadcast_data'),
('broadcast_data', 'memory_in'),
('memory_in', 'memory_in_data'),
('memory_in_data', 'crop_mem'),
('crop_mem', 'crop_mem_data'),
@@ -130,6 +203,18 @@ class ReplaceSpliceNodePatternTests(unittest.TestCase):
('concat_data', 'memory_out'),
('memory_out', 'memory_out_data'),
('memory_out_data', 'result'),
('split_data_1', 'shape_2'), ('shape_2', 'shape_2_data'),
('shape_2_data', 'crop_batch_2'), ('crop_batch_2', 'crop_batch_2_data'),
('crop_batch_dim_2', 'crop_batch_dim_2_data'),
('crop_batch_dim_2_data', 'crop_batch_2', {'in': 1}),
('second_dim_2', 'second_dim_2_data'), ('second_dim_2_data', 'gather_shape_2', {'in': 1}),
('crop_batch_2_data', 'gather_shape_2', {'in': 0}),
('gather_shape_2', 'gather_shape_2_data'),
('fill_value_2', 'fill_value_2_data'), ('fill_value_2_data', 'broadcast_2', {'in': 0}),
('gather_shape_2_data', 'broadcast_2', {'in': 1}), ('broadcast_2', 'broadcast_2_data'),
('broadcast_2_data', 'memory_in_constdims'),
('memory_in_constdims', 'memory_in_constdims_data'),
('memory_in_constdims_data', 'crop_mem_constdims'),
('crop_mem_constdims', 'crop_mem_constdims_data'),