publish master branch snapshot, revision ea98a886d925eb152931aab13856e68037665562
This commit is contained in:
@@ -23,7 +23,7 @@ from mo.ops.crop import Crop
|
||||
from mo.utils.logger import log
|
||||
|
||||
|
||||
class CutMemory(BackReplacementPattern):
|
||||
class CutMemoryInput(BackReplacementPattern):
|
||||
"""
|
||||
Cut Memory layers and have inputs/outputs in graph instead of them
|
||||
"""
|
||||
@@ -38,30 +38,56 @@ class CutMemory(BackReplacementPattern):
|
||||
def pattern():
|
||||
return dict(
|
||||
nodes=[
|
||||
('op', dict(kind='op', op='Memory'))],
|
||||
('op', dict(kind='op', op='ReadValue'))],
|
||||
edges=[]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def replace_pattern(graph: Graph, match: dict):
|
||||
node = match['op']
|
||||
node_id = node['id']
|
||||
node_id = node['variable_id']
|
||||
|
||||
if node.in_port(0).disconnected():
|
||||
i = 0
|
||||
for dest in node.out_port(0).get_destinations():
|
||||
new_in = Parameter(graph, {'name': "Parameter_"+str(i)+"_for_"+node_id,
|
||||
'shape': dest.data.get_shape()}).create_node()
|
||||
i += 1
|
||||
dest.disconnect()
|
||||
new_in.out_port(0).connect(dest)
|
||||
log.error("Add input/output mapped {} -> {} ".format(new_in.name, "Result_for_"+node_id),
|
||||
extra={'is_warning': True})
|
||||
else:
|
||||
out_node_port = node.out_port(0).get_destination()
|
||||
in_node_port = node.in_port(0).get_source()
|
||||
node.in_port(0).disconnect()
|
||||
node.out_port(0).disconnect()
|
||||
crop = Crop(graph, {'name': 'Result_for_'+node_id, 'dim': np.array([1]), 'offset': np.array([0]), 'axis': np.array([0])}).create_node()
|
||||
in_node_port.connect(crop.in_port(0))
|
||||
crop.out_port(0).connect(out_node_port)
|
||||
i = 0
|
||||
node.in_port(0).disconnect()
|
||||
for dest in node.out_port(0).get_destinations():
|
||||
new_in = Parameter(graph, {'name': "Parameter_"+str(i)+"_for_"+node_id,
|
||||
'shape': dest.data.get_shape()}).create_node()
|
||||
i += 1
|
||||
dest.disconnect()
|
||||
new_in.out_port(0).connect(dest)
|
||||
log.error("Add input/output mapped {} -> {} ".format(new_in.name, "Result_for_"+node_id),
|
||||
extra={'is_warning': True})
|
||||
|
||||
|
||||
class CutMemoryOutput(BackReplacementPattern):
|
||||
"""
|
||||
Cut Memory layers and have inputs/outputs in graph instead of them
|
||||
"""
|
||||
enabled = True
|
||||
graph_condition = [lambda graph: graph.graph['fw'] == "kaldi" and graph.graph['cmd_params'].remove_memory]
|
||||
force_clean_up = True
|
||||
|
||||
def run_before(self):
|
||||
return [ParameterToInput]
|
||||
|
||||
@staticmethod
|
||||
def pattern():
|
||||
return dict(
|
||||
nodes=[
|
||||
('op', dict(kind='op', op='Assign'))],
|
||||
edges=[]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def replace_pattern(graph: Graph, match: dict):
|
||||
node = match['op']
|
||||
node_id = node['variable_id']
|
||||
|
||||
out_node_port = node.out_port(0).get_destination()
|
||||
in_node_port = node.in_port(0).get_source()
|
||||
node.in_port(0).disconnect()
|
||||
node.out_port(0).disconnect()
|
||||
crop = Crop(graph, {'name': 'Result_for_'+node_id, 'dim': np.array([1]), 'offset': np.array([0]),
|
||||
'axis': np.array([0])}).create_node()
|
||||
in_node_port.connect(crop.in_port(0))
|
||||
crop.out_port(0).connect(out_node_port)
|
||||
|
||||
@@ -17,7 +17,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.back.CutMemory import CutMemory
|
||||
from extensions.back.CutMemory import CutMemoryInput, CutMemoryOutput
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from mo.utils.unittest.graph import build_graph
|
||||
|
||||
@@ -29,18 +29,21 @@ class CutMemoryTest(unittest.TestCase):
|
||||
nodes_attrs={
|
||||
'input': {'kind': 'op'},
|
||||
'data_in': {'kind': 'data', 'shape': None, 'value': None},
|
||||
'memory_in': {'kind': 'op', 'op': 'Memory', 'index': 1, 'id': 'memory_', 'in_ports_count': 1},
|
||||
'const_0': {'kind': 'op', 'op': 'Const'},
|
||||
'const_0_data': {'kind': 'data'},
|
||||
'memory_in': {'kind': 'op', 'op': 'ReadValue', 'variable_id': 'memory_'},
|
||||
'data_mem': {'kind': 'data', 'shape': None, 'value': None},
|
||||
'concat': {'kind': 'op', 'op': 'Concat', 'axis': 0},
|
||||
'concat_data': {'kind': 'data', 'shape': None, 'value': None},
|
||||
'some_op': {'kind': 'op'},
|
||||
'some_op_data': {'kind': 'data', 'shape': None, 'value': None},
|
||||
'memory_out': {'kind': 'op', 'op': 'Memory', 'index': 0, 'id': 'memory_'},
|
||||
'memory_out': {'kind': 'op', 'op': 'Assign', 'variable_id': 'memory_'},
|
||||
'data_mem_out': {'kind': 'data', 'shape': None, 'value': None},
|
||||
'mem_out_result': {'kind': 'op', 'op': 'Result'}
|
||||
},
|
||||
edges=[
|
||||
('input', 'data_in'), ('memory_in', 'data_mem'),
|
||||
('input', 'data_in'),
|
||||
('const_0', 'const_0_data'), ('const_0_data', 'memory_in'), ('memory_in', 'data_mem'),
|
||||
('data_in', 'concat', {'in': 0}), ('data_mem', 'concat', {'in': 1}),
|
||||
('concat', 'concat_data'), ('concat_data', 'some_op'),
|
||||
('some_op', 'some_op_data'), ('some_op_data', 'memory_out'),
|
||||
@@ -69,7 +72,8 @@ class CutMemoryTest(unittest.TestCase):
|
||||
('crop', 'crop_data'), ('crop_data', 'mem_out_result')
|
||||
],
|
||||
)
|
||||
CutMemory().find_and_replace_pattern(graph)
|
||||
CutMemoryInput().find_and_replace_pattern(graph)
|
||||
CutMemoryOutput().find_and_replace_pattern(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, last_node='mem_out_result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
140
model-optimizer/extensions/back/ReadValueAssignToMemory.py
Normal file
140
model-optimizer/extensions/back/ReadValueAssignToMemory.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""
|
||||
Copyright (C) 2020 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
from extensions.back.CutMemory import CutMemoryInput, CutMemoryOutput
|
||||
from mo.back.replacement import BackReplacementPattern
|
||||
from mo.graph.graph import Graph
|
||||
from mo.ops.memory import Memory
|
||||
|
||||
|
||||
"""
|
||||
All transformations in this file should be removed after removing IR v7 support
|
||||
"""
|
||||
|
||||
|
||||
class ReplaceReadValueByMemory(BackReplacementPattern):
|
||||
"""
|
||||
Replace ReadValue by Memory. Should be removed after v7 IR support removing.
|
||||
"""
|
||||
enabled = True
|
||||
graph_condition = [lambda graph: not graph.graph['cmd_params'].generate_experimental_IR_V10]
|
||||
force_clean_up = True
|
||||
|
||||
def run_after(self):
|
||||
return [CutMemoryInput]
|
||||
|
||||
@staticmethod
|
||||
def pattern():
|
||||
return dict(
|
||||
nodes=[
|
||||
('op', dict(kind='op', op='ReadValue'))],
|
||||
edges=[]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def replace_pattern(graph: Graph, match: dict):
|
||||
node = match['op']
|
||||
node_id = node['variable_id']
|
||||
|
||||
node.in_port(0).disconnect()
|
||||
new_in = Memory(graph, {'name': node.id, 'id': node_id, 'index': 1, 'size': 2,
|
||||
'shape': list(node.out_port(0).data.get_shape())[1:]}).create_node()
|
||||
for dest in node.out_port(0).get_destinations():
|
||||
dest.disconnect()
|
||||
new_in.out_port(0).connect(dest)
|
||||
|
||||
|
||||
class ReplaceAssignByMemory(BackReplacementPattern):
|
||||
"""
|
||||
Replace Assign by Memory. Should be removed after v7 IR support removing.
|
||||
"""
|
||||
enabled = True
|
||||
graph_condition = [lambda graph: not graph.graph['cmd_params'].generate_experimental_IR_V10]
|
||||
force_clean_up = True
|
||||
|
||||
def run_after(self):
|
||||
return [CutMemoryOutput]
|
||||
|
||||
@staticmethod
|
||||
def pattern():
|
||||
return dict(
|
||||
nodes=[
|
||||
('op', dict(kind='op', op='Assign'))],
|
||||
edges=[]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def replace_pattern(graph: Graph, match: dict):
|
||||
node = match['op']
|
||||
node_id = node['variable_id']
|
||||
|
||||
new_out = Memory(graph, {'name': node.id, 'id': node_id, 'index': 0, 'size': 2,
|
||||
'shape': list(node.out_port(0).data.get_shape())[1:]}).create_node()
|
||||
node.in_port(0).get_source().connect(new_out.in_port(0))
|
||||
node.in_port(0).disconnect()
|
||||
node.out_port(0).get_connection().set_source(new_out.out_port(0))
|
||||
|
||||
|
||||
class KaldiRemoveMemoryOutputBackReplacementPatternV7(BackReplacementPattern):
|
||||
enabled = True
|
||||
graph_condition = [lambda graph: not graph.graph['cmd_params'].generate_experimental_IR_V10]
|
||||
|
||||
def run_after(self):
|
||||
from extensions.back.pass_separator import BackFinish
|
||||
return [BackFinish]
|
||||
|
||||
def run_before(self):
|
||||
from extensions.back.SpecialNodesFinalization import RemoveOutputOps
|
||||
return [RemoveOutputOps]
|
||||
|
||||
@staticmethod
|
||||
def pattern():
|
||||
return dict(
|
||||
nodes=[
|
||||
('memory_node', dict(op='Memory')),
|
||||
('data_node', dict(kind='data')),
|
||||
('op_output', dict(op='Result'))
|
||||
],
|
||||
edges=[
|
||||
('memory_node', 'data_node'),
|
||||
('data_node', 'op_output')
|
||||
]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def replace_pattern(graph: Graph, match: dict):
|
||||
"""
|
||||
Need to find the pattern: Memory -> Data -> Result
|
||||
|
||||
It is needed to make Memory nodes appear in IR,
|
||||
but they are output nodes by default and we remove the Result node after each output memory.
|
||||
|
||||
DO NOT use graph clean up after it
|
||||
otherwise Memory nodes would be removed as they are not on the path from input to output
|
||||
|
||||
Parameters
|
||||
----------
|
||||
graph : Graph
|
||||
Graph with loaded model.
|
||||
match : dict
|
||||
Patterns which were found in graph structure.
|
||||
"""
|
||||
memory = match['memory_node']
|
||||
data = match['data_node']
|
||||
op_output = match['op_output']
|
||||
|
||||
graph.remove_edge(memory.id, data.id)
|
||||
graph.remove_node(data.id)
|
||||
graph.remove_node(op_output.id)
|
||||
@@ -33,7 +33,7 @@ class KaldiRemoveMemoryOutputBackReplacementPattern(BackReplacementPattern):
|
||||
def pattern():
|
||||
return dict(
|
||||
nodes=[
|
||||
('memory_node', dict(op='Memory')),
|
||||
('memory_node', dict(op='Assign')),
|
||||
('data_node', dict(kind='data')),
|
||||
('op_output', dict(op='Result'))
|
||||
],
|
||||
@@ -63,6 +63,8 @@ class KaldiRemoveMemoryOutputBackReplacementPattern(BackReplacementPattern):
|
||||
"""
|
||||
memory = match['memory_node']
|
||||
data = match['data_node']
|
||||
op_output = match['op_output']
|
||||
|
||||
graph.remove_edge(memory.id, data.id)
|
||||
graph.remove_node(data.id)
|
||||
graph.remove_node(op_output.id)
|
||||
|
||||
@@ -26,7 +26,7 @@ class KaldiRemoveMemoryOutputTest(unittest.TestCase):
|
||||
'kind': 'data'
|
||||
},
|
||||
'memory_node': {
|
||||
'op': 'Memory',
|
||||
'op': 'Assign',
|
||||
'kind': 'op'
|
||||
},
|
||||
'output_node': {
|
||||
|
||||
@@ -63,7 +63,7 @@ def apply_biases_to_last_layer(graph, counts):
|
||||
outputs_ids = find_outputs(graph)
|
||||
for output in outputs_ids.copy():
|
||||
node = Node(graph, output)
|
||||
if node.op != 'Memory':
|
||||
if node.op != 'Assign':
|
||||
continue
|
||||
outputs_ids.remove(output)
|
||||
|
||||
|
||||
@@ -13,12 +13,12 @@
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
from extensions.ops.elementwise import Add, Mul
|
||||
from extensions.ops.split import Split
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.common.replacement import FrontReplacementOp
|
||||
from mo.front.tf.graph_utils import create_op_with_const_inputs
|
||||
from mo.graph.graph import Node, Graph
|
||||
from mo.ops.eltwise import Eltwise
|
||||
from mo.ops.eltwise_n import EltwiseN
|
||||
from mo.utils.error import Error
|
||||
|
||||
@@ -43,8 +43,12 @@ class ReplaceEltwiseNin1NodePattern(FrontReplacementOp):
|
||||
edge_attrs = inp[0][1]
|
||||
graph.add_edge(in_node, ss_node.id, **edge_attrs)
|
||||
if ss_node.num_splits == 2:
|
||||
eltwise_node = Eltwise(graph, attrs={'name': 'Eltwise_' + node.name,
|
||||
'operation': node['operation']}).create_node()
|
||||
if node['operation'] == 'mul':
|
||||
eltwise_node = Mul(graph, attrs={'name': 'Eltwise_' + node.name}).create_node()
|
||||
elif node['operation'] == 'sum':
|
||||
eltwise_node = Add(graph, attrs={'name': 'Eltwise_' + node.name}).create_node()
|
||||
else:
|
||||
raise Error('Error on replacing Kaldi eltwise: unknown type ' + node['operation'])
|
||||
elif ss_node.num_splits > 2:
|
||||
eltwise_node = EltwiseN(graph, attrs={'name': 'Eltwise_' + node.name,
|
||||
'operation': node['operation']}).create_node()
|
||||
|
||||
@@ -20,13 +20,19 @@ from extensions.ops.activation_ops import Tanh, Sigmoid
|
||||
from extensions.ops.elementwise import Add, Mul
|
||||
from extensions.ops.split import Split
|
||||
from mo.front.caffe.extractors.utils import input_as_const
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.common.replacement import FrontReplacementOp
|
||||
from mo.graph.graph import Node, Graph
|
||||
from mo.graph.graph import Node, Graph, Port
|
||||
from mo.ops.assign import Assign
|
||||
from mo.ops.broadcast import Broadcast
|
||||
from mo.ops.clamp import Clamp
|
||||
from mo.ops.crop import Crop
|
||||
from mo.ops.concat import Concat
|
||||
from mo.ops.const import Const
|
||||
from mo.ops.memory import Memory
|
||||
from mo.ops.read_value import ReadValue
|
||||
from mo.ops.result import Result
|
||||
from mo.ops.scale_shift import ScaleShiftOp
|
||||
from mo.ops.shape import Shape
|
||||
|
||||
|
||||
def unique_id(prefix: str = 'id') -> str:
|
||||
@@ -46,6 +52,35 @@ def unique_id(prefix: str = 'id') -> str:
|
||||
unique_id.names = []
|
||||
|
||||
|
||||
def create_zero_value_with_batch_from_input(input_out_port: Port, second_dim, precision = np.float):
|
||||
# create init_graph connected to ReadValue
|
||||
graph = input_out_port.node.graph
|
||||
input_name = input_out_port.node.name
|
||||
shape_of_input = Shape(graph, {'name': 'shape/' + input_name}).create_node()
|
||||
shape_of_input.in_port(0).connect(input_out_port)
|
||||
dim_for_get_batch = Const(graph, {'name': 'dim/crop_batch/'+shape_of_input.name,
|
||||
'value': int64_array([1]), 'shape': int64_array([1])}).create_node()
|
||||
get_batch = Crop(graph, {'name': 'crop_batch/' + shape_of_input.name,
|
||||
'axis': int64_array([0]), 'offset': int64_array([0])
|
||||
}).create_node()
|
||||
get_batch.in_port(0).connect(shape_of_input.out_port(0))
|
||||
get_batch.in_port(1).connect(dim_for_get_batch.out_port(0))
|
||||
mem_shape_2nd_dim = Const(graph, {'name': 'gifo_r_weights_shape/'+input_name,
|
||||
'value': int64_array([second_dim]),
|
||||
'shape': int64_array([1])}).create_node()
|
||||
mem_shape = Concat(graph, {'name': 'gather_memory_shape/' + input_name,
|
||||
'axis': 0, 'in_ports_count': 2}).create_node()
|
||||
mem_shape.in_port(0).connect(get_batch.out_port(0))
|
||||
mem_shape.in_port(1).connect(mem_shape_2nd_dim.out_port(0))
|
||||
fill_value = Const(graph, {'name': 'fill_value/'+input_name,
|
||||
'value': np.array([0.0], precision), 'shape': int64_array([1])}).create_node()
|
||||
init_value_prev_lstm_output = Broadcast(graph, {'name': 'init_value/'+input_name,
|
||||
}).create_node()
|
||||
init_value_prev_lstm_output.in_port(0).connect(fill_value.out_port(0))
|
||||
init_value_prev_lstm_output.in_port(1).connect(mem_shape.out_port(0))
|
||||
return init_value_prev_lstm_output
|
||||
|
||||
|
||||
class ReplaceLSTMNodePattern(FrontReplacementOp):
|
||||
op = "LSTMCell"
|
||||
enabled = True
|
||||
@@ -69,7 +104,7 @@ class ReplaceLSTMNodePattern(FrontReplacementOp):
|
||||
)
|
||||
|
||||
def replace_op(self, graph: Graph, node: Node):
|
||||
input_node = node.in_node()
|
||||
input_out_port = node.in_port(0).get_source()
|
||||
|
||||
memory_pair_input = unique_id('id')
|
||||
memory_pair_output = unique_id('id')
|
||||
@@ -81,16 +116,17 @@ class ReplaceLSTMNodePattern(FrontReplacementOp):
|
||||
'bias_term': True,
|
||||
}
|
||||
|
||||
fc_layer_after_input = FullyConnected(graph, fc_layer_after_input_attrs).create_node([input_node])
|
||||
fc_layer_after_input = FullyConnected(graph, fc_layer_after_input_attrs).create_node()
|
||||
fc_layer_after_input.in_port(0).connect(input_out_port)
|
||||
input_as_const(fc_layer_after_input, fc_layer_after_input_attrs, 1, 'weights', node.gifo_x_weights)
|
||||
input_as_const(fc_layer_after_input, fc_layer_after_input_attrs, 2, 'biases', node.gifo_biases)
|
||||
|
||||
prev_lstm_output = Memory(graph, {'name': 'prev_memory_output',
|
||||
'id': memory_pair_input,
|
||||
'index': 1,
|
||||
'size': 2,
|
||||
'shape': np.array([node.gifo_r_weights_shape[1]], dtype=np.int64)
|
||||
}).create_node()
|
||||
init_value_prev_lstm_output = create_zero_value_with_batch_from_input(input_out_port,
|
||||
node.gifo_r_weights_shape[1])
|
||||
prev_lstm_output = ReadValue(graph, {'name': 'prev_memory_output',
|
||||
'variable_id': memory_pair_input
|
||||
}).create_node()
|
||||
prev_lstm_output.in_port(0).connect(init_value_prev_lstm_output.out_port(0))
|
||||
|
||||
# *Memory(output) -> FullyConnected
|
||||
fc_layer_from_prev_state_attrs = {'name': 'prev_memory_output_fullyconnected',
|
||||
@@ -99,15 +135,16 @@ class ReplaceLSTMNodePattern(FrontReplacementOp):
|
||||
'bias_term': False,
|
||||
}
|
||||
|
||||
fc_layer_from_prev_state = FullyConnected(graph, fc_layer_from_prev_state_attrs).create_node(
|
||||
[prev_lstm_output])
|
||||
fc_layer_from_prev_state = FullyConnected(graph, fc_layer_from_prev_state_attrs).create_node()
|
||||
fc_layer_from_prev_state.in_port(0).connect(prev_lstm_output.out_port(0))
|
||||
input_as_const(fc_layer_from_prev_state, fc_layer_from_prev_state_attrs, 1, 'weights', node.gifo_r_weights)
|
||||
|
||||
# Memory -> FullyConnected \
|
||||
# *Eltwise(sum)
|
||||
# Input -> FullyConnected /
|
||||
join_input_prev_state_sum = Add(graph, {'name': 'join_input_eltwise',
|
||||
}).create_node([fc_layer_from_prev_state, fc_layer_after_input])
|
||||
join_input_prev_state_sum = Add(graph, {'name': 'join_input_eltwise'}).create_node()
|
||||
join_input_prev_state_sum.in_port(0).connect(fc_layer_from_prev_state.out_port(0))
|
||||
join_input_prev_state_sum.in_port(1).connect(fc_layer_after_input.out_port(0))
|
||||
|
||||
# *Eltwise(sum) -> Split
|
||||
# it is split into 4 nodes: Act, Eltw*3
|
||||
@@ -120,131 +157,147 @@ class ReplaceLSTMNodePattern(FrontReplacementOp):
|
||||
# |____(4)Eltwise(sum)
|
||||
split_joined_input_axis = Const(graph, {'value': np.int64(1)}).create_node()
|
||||
split_joined_input = Split(graph, {'name': 'join_input_split',
|
||||
'num_splits': 4,
|
||||
}).create_node([join_input_prev_state_sum, split_joined_input_axis])
|
||||
'num_splits': 4, 'out_ports_count': 4}).create_node()
|
||||
split_joined_input.in_port(0).connect(join_input_prev_state_sum.out_port(0))
|
||||
split_joined_input.in_port(1).connect(split_joined_input_axis.out_port(0))
|
||||
|
||||
prev_lstm_state = Memory(graph, {'name': 'prev_memory_state',
|
||||
'id': memory_pair_output,
|
||||
'index': 1,
|
||||
'size': 2,
|
||||
'shape': np.array([node.input_gate_weights.shape[0]], dtype=np.int64)
|
||||
}).create_node()
|
||||
# prev_lstm_state = Memory(graph, {'name': 'prev_memory_state',
|
||||
# 'id': memory_pair_output,
|
||||
# 'index': 1,
|
||||
# 'size': 2,
|
||||
# 'shape': np.array([node.input_gate_weights.shape[0]], dtype=np.int64)
|
||||
# }).create_node()
|
||||
init_value_prev_lstm_state = create_zero_value_with_batch_from_input(split_joined_input.out_port(0),
|
||||
node.input_gate_weights.shape[0])
|
||||
prev_lstm_state = ReadValue(graph, {'name': 'prev_memory_state',
|
||||
'variable_id': memory_pair_output}).create_node()
|
||||
prev_lstm_state.in_port(0).connect(init_value_prev_lstm_state.out_port(0))
|
||||
|
||||
# *Memory(state) -> *ScaleShift(input)
|
||||
state_input_scaleshift_attrs = {'name': 'input_scaleshift',
|
||||
'bias_term': False
|
||||
}
|
||||
state_input_scaleshift = ScaleShiftOp(graph, state_input_scaleshift_attrs).create_node([prev_lstm_state])
|
||||
state_input_scaleshift = ScaleShiftOp(graph, state_input_scaleshift_attrs).create_node()
|
||||
state_input_scaleshift.in_port(0).connect(prev_lstm_state.out_port(0))
|
||||
input_as_const(state_input_scaleshift, state_input_scaleshift_attrs, 1, 'weights', node.input_gate_weights)
|
||||
|
||||
# *Memory(state) -> *ScaleShift(forget)
|
||||
state_forget_scaleshift_attrs = {'name': 'forget_scaleshift',
|
||||
'bias_term': False
|
||||
}
|
||||
state_forget_scaleshift = ScaleShiftOp(graph, state_forget_scaleshift_attrs).create_node([prev_lstm_state])
|
||||
state_forget_scaleshift = ScaleShiftOp(graph, state_forget_scaleshift_attrs).create_node()
|
||||
state_forget_scaleshift.in_port(0).connect(prev_lstm_state.out_port(0))
|
||||
input_as_const(state_forget_scaleshift, state_forget_scaleshift_attrs, 1, 'weights', node.forget_gate_weights)
|
||||
|
||||
# Split \
|
||||
# (2)Eltwise(sum)
|
||||
# Memory(state) -> *ScaleShift(input) /
|
||||
join_prev_lstm_input_joined_input_sum = Add(graph, {'name': 'join_prev_lstm_input_joined_input_eltwise',
|
||||
}).create_node([(split_joined_input, 1),
|
||||
state_input_scaleshift
|
||||
])
|
||||
join_prev_lstm_input_joined_input_sum = Add(graph, {'name': 'join_prev_lstm_input_joined_input_eltwise'
|
||||
}).create_node()
|
||||
join_prev_lstm_input_joined_input_sum.in_port(0).connect(split_joined_input.out_port(1))
|
||||
join_prev_lstm_input_joined_input_sum.in_port(1).connect(state_input_scaleshift.out_port(0))
|
||||
# Split \
|
||||
# (3)Eltwise(sum)
|
||||
# Memory(state) -> *ScaleShift(forget) /
|
||||
join_prev_lstm_input_joined_forget_sum = Add(graph, {'name': 'join_prev_lstm_input_joined_forget_sum',
|
||||
}).create_node([(split_joined_input, 2),
|
||||
state_forget_scaleshift
|
||||
])
|
||||
}).create_node()
|
||||
join_prev_lstm_input_joined_forget_sum.in_port(0).connect(split_joined_input.out_port(2))
|
||||
join_prev_lstm_input_joined_forget_sum.in_port(1).connect(state_forget_scaleshift.out_port(0))
|
||||
|
||||
# Split -> Tanh
|
||||
remember_tahn = Tanh(graph, {'name': 'remember_tahnv'}).create_node([(split_joined_input, 0)])
|
||||
remember_tahn = Tanh(graph, {'name': 'remember_tahnv'}).create_node()
|
||||
remember_tahn.in_port(0).connect(split_joined_input.out_port(0))
|
||||
|
||||
# Split -> (2)Eltwise(sum) -> *Sigmoid
|
||||
remember_sigmoid = Sigmoid(graph, {'name': 'remember_sigmoid'
|
||||
}).create_node([join_prev_lstm_input_joined_input_sum])
|
||||
remember_sigmoid = Sigmoid(graph, {'name': 'remember_sigmoid'}).create_node()
|
||||
remember_sigmoid.in_port(0).connect(join_prev_lstm_input_joined_input_sum.out_port(0))
|
||||
|
||||
# Split -> (3)Eltwise(sum) -> **Sigmoid
|
||||
forget_sigmoid = Sigmoid(graph, {'name': 'forget_sigmoid'
|
||||
}).create_node([join_prev_lstm_input_joined_forget_sum])
|
||||
forget_sigmoid = Sigmoid(graph, {'name': 'forget_sigmoid'}).create_node()
|
||||
forget_sigmoid.in_port(0).connect(join_prev_lstm_input_joined_forget_sum.out_port(0))
|
||||
|
||||
# *Memory(state) \
|
||||
# (6)Eltwise(mul)
|
||||
# Split -> (3)Eltwise(sum) -> **Sigmoid /
|
||||
join_forget_prev_state_mul = Mul(graph, {'name': 'join_forget_prev_state_mul',
|
||||
}).create_node([forget_sigmoid, prev_lstm_state])
|
||||
join_forget_prev_state_mul = Mul(graph, {'name': 'join_forget_prev_state_mul'}).create_node()
|
||||
join_forget_prev_state_mul.in_port(0).connect(forget_sigmoid.out_port(0))
|
||||
join_forget_prev_state_mul.in_port(1).connect(prev_lstm_state.out_port(0))
|
||||
|
||||
# Split -> Tahn \
|
||||
# (5)Eltwise(mul)
|
||||
# Split -> (2)Eltwise(sum) -> *Sigmoid /
|
||||
join_remember_candidates_mul = Mul(graph, {'name': 'join_remember_candidates_mul',
|
||||
}).create_node([remember_tahn, remember_sigmoid])
|
||||
join_remember_candidates_mul = Mul(graph, {'name': 'join_remember_candidates_mul'}).create_node()
|
||||
join_remember_candidates_mul.in_port(0).connect(remember_tahn.out_port(0))
|
||||
join_remember_candidates_mul.in_port(1).connect(remember_sigmoid.out_port(0))
|
||||
|
||||
# (5)Eltwise(mul) \
|
||||
# (7)Eltwise(sum)
|
||||
# (6)Eltwise(mul) /
|
||||
join_forget_remember_sum = Add(graph, {'name': 'join_forget_remember_sum',
|
||||
}).create_node(
|
||||
[join_forget_prev_state_mul, join_remember_candidates_mul])
|
||||
join_forget_remember_sum = Add(graph, {'name': 'join_forget_remember_sum'}).create_node()
|
||||
join_forget_remember_sum.in_port(0).connect(join_forget_prev_state_mul.out_port(0))
|
||||
join_forget_remember_sum.in_port(1).connect(join_remember_candidates_mul.out_port(0))
|
||||
|
||||
# (7)Eltwise(sum) -> Clamp
|
||||
join_forget_clamp = Clamp(graph, {'name': 'join_forget_clamp',
|
||||
'max': node.clip_value,
|
||||
'min': -node.clip_value
|
||||
}).create_node(
|
||||
[join_forget_remember_sum])
|
||||
'min': -node.clip_value}).create_node()
|
||||
join_forget_clamp.in_port(0).connect(join_forget_remember_sum.out_port(0))
|
||||
#
|
||||
# Clamp -> (2)Memory(state)
|
||||
next_lstm_state = Memory(graph, {'name': 'next_lstm_state',
|
||||
'id': memory_pair_output,
|
||||
'index': 0,
|
||||
'size': 2,
|
||||
'shape': np.array([node.input_gate_weights.shape[0]], dtype=np.int64)
|
||||
}).create_node([join_forget_clamp])
|
||||
Result(graph, {'name': 'next_lstm_state_out'}).create_node([next_lstm_state])
|
||||
next_lstm_state = Assign(graph, {'name': 'next_lstm_state',
|
||||
'variable_id': memory_pair_output}).create_node()
|
||||
next_lstm_state.in_port(0).connect(join_forget_clamp.out_port(0))
|
||||
|
||||
res_node = Result(graph, {'name': 'next_lstm_state_out'}).create_node()
|
||||
res_node.in_port(0).connect(next_lstm_state.out_port(0))
|
||||
|
||||
# Clamp -> (2)Tahn
|
||||
state_filtered_tahn = Tanh(graph, {'name': 'state_filtered_tahn'}).create_node([join_forget_clamp])
|
||||
state_filtered_tahn = Tanh(graph, {'name': 'state_filtered_tahn'}).create_node()
|
||||
state_filtered_tahn.in_port(0).connect(join_forget_clamp.out_port(0))
|
||||
|
||||
# Clamp -> (2)ScaleShift
|
||||
clamp_scaleshift_attrs = {'name': 'clamp_scaleshift',
|
||||
'bias_term': False}
|
||||
clamp_scaleshift = ScaleShiftOp(graph, clamp_scaleshift_attrs).create_node([join_forget_clamp])
|
||||
clamp_scaleshift = ScaleShiftOp(graph, clamp_scaleshift_attrs).create_node()
|
||||
clamp_scaleshift.in_port(0).connect(join_forget_clamp.out_port(0))
|
||||
input_as_const(clamp_scaleshift, clamp_scaleshift_attrs, 1, 'weights', node.output_gate_weights)
|
||||
|
||||
# Split \
|
||||
# (4)Eltwise(sum)
|
||||
# Clamp -> (2)ScaleShift /
|
||||
join_next_lstm_input_joined_input_sum = Add(graph, {'name': 'join_next_lstm_input_joined_input_sum',
|
||||
}).create_node([(split_joined_input, 3), clamp_scaleshift])
|
||||
}).create_node()
|
||||
join_next_lstm_input_joined_input_sum.in_port(0).connect(split_joined_input.out_port(3))
|
||||
join_next_lstm_input_joined_input_sum.in_port(1).connect(clamp_scaleshift.out_port(0))
|
||||
|
||||
# (4)Eltwise(sum) -> (3)Sigmoid
|
||||
output_sigmoid = Sigmoid(graph, {'name': 'output_sigmoid'}).create_node([join_next_lstm_input_joined_input_sum])
|
||||
output_sigmoid = Sigmoid(graph, {'name': 'output_sigmoid'}).create_node()
|
||||
output_sigmoid.in_port(0).connect(join_next_lstm_input_joined_input_sum.out_port(0))
|
||||
|
||||
# (4)Eltwise(sum) -> (3)Sigmoid \
|
||||
# (5)Eltwise(mul)
|
||||
# Clamp -> (2)Tahn /
|
||||
joined_output_mul = Mul(graph, {'name': 'joined_output_mul'}).create_node([state_filtered_tahn, output_sigmoid])
|
||||
joined_output_mul = Mul(graph, {'name': 'joined_output_mul'}).create_node()
|
||||
joined_output_mul.in_port(0).connect(state_filtered_tahn.out_port(0))
|
||||
joined_output_mul.in_port(1).connect(output_sigmoid.out_port(0))
|
||||
|
||||
# (5)Eltwise(mul) -> (3)FullyConnected
|
||||
fc_output_attrs = {'name': 'FullyConnected',
|
||||
'out-size': node.projection_weights_shape[0],
|
||||
'transpose_weights': True,
|
||||
'bias_term': False}
|
||||
fc_output = FullyConnected(graph, fc_output_attrs).create_node([joined_output_mul])
|
||||
fc_output = FullyConnected(graph, fc_output_attrs).create_node()
|
||||
fc_output.in_port(0).connect(joined_output_mul.out_port(0))
|
||||
input_as_const(fc_output, fc_output_attrs, 1, 'weights', node.projection_weights)
|
||||
|
||||
# / (2)Memory(output)
|
||||
# (3)FullyConnected
|
||||
# \ Output (any next node) (edge created automatically after replacement)
|
||||
next_lstm_output = Memory(graph, {'name': 'next_lstm_output',
|
||||
'id': memory_pair_input,
|
||||
'index': 0,
|
||||
'size': 2,
|
||||
'shape': np.array([node.gifo_r_weights_shape[1]], dtype=np.int64)
|
||||
}).create_node([fc_output])
|
||||
Result(graph, {'name': 'next_lstm_output_out'}).create_node([next_lstm_output])
|
||||
next_lstm_output = Assign(graph, {'name': 'next_lstm_output',
|
||||
'variable_id': memory_pair_input}).create_node()
|
||||
next_lstm_output.in_port(0).connect(fc_output.out_port(0))
|
||||
|
||||
res_node_lstm_output = Result(graph, {'name': 'next_lstm_output_out'}).create_node()
|
||||
res_node_lstm_output.in_port(0).connect(next_lstm_output.out_port(0))
|
||||
|
||||
return [fc_output.id]
|
||||
|
||||
@@ -22,7 +22,7 @@ from mo.front.common.replacement import FrontReplacementOp
|
||||
from mo.graph.graph import Node, Graph
|
||||
from mo.ops.concat import Concat
|
||||
from mo.ops.const import Const
|
||||
from mo.ops.eltwise import Eltwise
|
||||
from extensions.ops.elementwise import Add, Mul
|
||||
from mo.ops.scale_shift import ScaleShiftOp
|
||||
|
||||
|
||||
@@ -41,19 +41,19 @@ class ReplaceLstmNonLinearityPattern(FrontReplacementOp):
|
||||
def replace_op(self, graph: Graph, node: Node):
|
||||
# split input to (i_part, f_part, c_part, o_part, ct_1)
|
||||
split_node_axis = Const(graph, {'value': np.int64(1)}).create_node()
|
||||
split_node = Split(graph, {'name': graph.unique_id(prefix='Split_lstm_input_'),
|
||||
split_node = Split(graph, {'name': 'Split_lstm_input_',
|
||||
'num_splits': 5}).create_node()
|
||||
node.in_port(0).get_connection().set_destination(split_node.in_port(0))
|
||||
split_node.in_port(1).connect(split_node_axis.out_port(0))
|
||||
|
||||
# i_t = Sigmoid(i_part + w_ic*ct_1)
|
||||
i_scale_attrs = {'name': graph.unique_id(prefix='i_scaleshift'),
|
||||
i_scale_attrs = {'name': 'i_scaleshift',
|
||||
'bias_term': False}
|
||||
i_scale = ScaleShiftOp(graph, i_scale_attrs).create_node()
|
||||
input_as_const(i_scale, i_scale_attrs, 1, 'weights', node.i_weights)
|
||||
split_node.out_port(4).connect(i_scale.in_port(0))
|
||||
|
||||
sum_i_c = Eltwise(graph, {'name': graph.unique_id(prefix='sum_i_c_'), 'operation': 'sum'}).create_node()
|
||||
sum_i_c = Add(graph, {'name': 'sum_i_c_'}).create_node()
|
||||
split_node.out_port(0).connect(sum_i_c.in_port(0))
|
||||
i_scale.out_port(0).connect(sum_i_c.in_port(1))
|
||||
|
||||
@@ -61,13 +61,13 @@ class ReplaceLstmNonLinearityPattern(FrontReplacementOp):
|
||||
sum_i_c.out_port(0).connect(i_sigmoid.in_port(0))
|
||||
|
||||
# f_t = Sigmoid(f_part + w_fc*ct_1)
|
||||
f_scale_attrs = {'name': graph.unique_id(prefix='f_scaleshift'),
|
||||
f_scale_attrs = {'name': 'f_scaleshift',
|
||||
'bias_term': False}
|
||||
f_scale = ScaleShiftOp(graph, f_scale_attrs).create_node()
|
||||
input_as_const(f_scale, f_scale_attrs, 1, 'weights', node.f_weights)
|
||||
split_node.out_port(4).connect(f_scale.in_port(0))
|
||||
|
||||
sum_f_c = Eltwise(graph, {'name': graph.unique_id(prefix='sum_f_c_'), 'operation': 'sum'}).create_node()
|
||||
sum_f_c = Add(graph, {'name': 'sum_f_c_'}).create_node()
|
||||
split_node.out_port(1).connect(sum_f_c.in_port(0))
|
||||
f_scale.out_port(0).connect(sum_f_c.in_port(1))
|
||||
|
||||
@@ -78,28 +78,26 @@ class ReplaceLstmNonLinearityPattern(FrontReplacementOp):
|
||||
c_tanh = Tanh(graph, {'name': 'c_tanh'}).create_node()
|
||||
split_node.out_port(2).connect(c_tanh.in_port(0))
|
||||
|
||||
prod_i_c_tanh = Eltwise(graph, {'name': graph.unique_id(prefix='prod_i_c_tanh_'),
|
||||
'operation': 'mul'}).create_node()
|
||||
prod_i_c_tanh = Mul(graph, {'name': 'prod_i_c_tanh_'}).create_node()
|
||||
i_sigmoid.out_port(0).connect(prod_i_c_tanh.in_port(0))
|
||||
c_tanh.out_port(0).connect(prod_i_c_tanh.in_port(1))
|
||||
|
||||
prod_f_ct_1 = Eltwise(graph, {'name': graph.unique_id(prefix='prod_f_ct_1_'),
|
||||
'operation': 'mul'}).create_node()
|
||||
prod_f_ct_1 = Mul(graph, {'name': 'prod_f_ct_1_'}).create_node()
|
||||
f_sigmoid.out_port(0).connect(prod_f_ct_1.in_port(0))
|
||||
split_node.out_port(4).connect(prod_f_ct_1.in_port(1))
|
||||
|
||||
sum_f_i = Eltwise(graph, {'name': graph.unique_id(prefix='sum_f_i_'), 'operation': 'sum'}).create_node()
|
||||
sum_f_i = Add(graph, {'name': 'sum_f_i_'}).create_node()
|
||||
prod_f_ct_1.out_port(0).connect(sum_f_i.in_port(0))
|
||||
prod_i_c_tanh.out_port(0).connect(sum_f_i.in_port(1))
|
||||
|
||||
# o_t = Sigmoid(o_part + w_oc*c_t)
|
||||
o_scale_attrs = {'name': graph.unique_id(prefix='o_scaleshift'),
|
||||
o_scale_attrs = {'name': 'o_scaleshift',
|
||||
'bias_term': False}
|
||||
o_scale = ScaleShiftOp(graph, o_scale_attrs).create_node()
|
||||
input_as_const(o_scale, o_scale_attrs, 1, 'weights', node.o_weights)
|
||||
sum_f_i.out_port(0).connect(o_scale.in_port(0))
|
||||
|
||||
sum_o_c = Eltwise(graph, {'name': graph.unique_id(prefix='sum_o_c_'), 'operation': 'sum'}).create_node()
|
||||
sum_o_c = Add(graph, {'name': 'sum_o_c_'}).create_node()
|
||||
split_node.out_port(3).connect(sum_o_c.in_port(0))
|
||||
o_scale.out_port(0).connect(sum_o_c.in_port(1))
|
||||
|
||||
@@ -110,13 +108,12 @@ class ReplaceLstmNonLinearityPattern(FrontReplacementOp):
|
||||
c_t_tanh = Tanh(graph, {'name': 'c_t_tanh'}).create_node()
|
||||
sum_f_i.out_port(0).connect(c_t_tanh.in_port(0))
|
||||
|
||||
prod_o_c_t_tanh = Eltwise(graph, {'name': graph.unique_id(prefix='prod_o_c_t_tanh_'),
|
||||
'operation': 'mul'}).create_node()
|
||||
prod_o_c_t_tanh = Mul(graph, {'name': 'prod_o_c_t_tanh_'}).create_node()
|
||||
o_sigmoid.out_port(0).connect(prod_o_c_t_tanh.in_port(0))
|
||||
c_t_tanh.out_port(0).connect(prod_o_c_t_tanh.in_port(1))
|
||||
|
||||
# add concat to create 1 output
|
||||
concat = Concat(graph, {'name': graph.unique_id(prefix='Concat_c_m')}).create_node()
|
||||
concat = Concat(graph, {'name': 'Concat_c_m'}).create_node()
|
||||
concat.add_sequence_of_ports('in', range(2))
|
||||
sum_f_i.out_port(0).connect(concat.in_port(0))
|
||||
prod_o_c_t_tanh.out_port(0).connect(concat.in_port(1))
|
||||
|
||||
@@ -15,15 +15,18 @@
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
from extensions.front.kaldi.replace_lstm_node_pattern import create_zero_value_with_batch_from_input
|
||||
from extensions.ops.elementwise import Equal
|
||||
from extensions.ops.select import Select
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.graph.graph import Graph, Node
|
||||
from mo.middle.pattern_match import find_pattern_matches, inverse_dict
|
||||
from mo.middle.replacement import MiddleReplacementPattern
|
||||
from mo.ops.assign import Assign
|
||||
from mo.ops.concat import Concat
|
||||
from mo.ops.const import Const
|
||||
from mo.ops.crop import Crop
|
||||
from mo.ops.memory import Memory
|
||||
from mo.ops.read_value import ReadValue
|
||||
from mo.ops.result import Result
|
||||
from mo.utils.error import Error
|
||||
from mo.utils.graph import invert_sub_graph_between_nodes
|
||||
@@ -48,7 +51,7 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
|
||||
@staticmethod
|
||||
def pattern():
|
||||
return dict(
|
||||
nodes=[('op', dict(op='Memory', index=0))],
|
||||
nodes=[('op', dict(op='Assign'))],
|
||||
edges=[])
|
||||
|
||||
@staticmethod
|
||||
@@ -93,9 +96,8 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
|
||||
select_node.in_port(2).connect(zero_else.out_port(0))
|
||||
|
||||
# check if we have already appropriate iteration counter
|
||||
existing_counters = find_pattern_matches(graph, nodes=[('mem_in', dict(op='Memory', index=1,
|
||||
shape=int64_array([context_len]))),
|
||||
('mem_in_data', dict()),
|
||||
existing_counters = find_pattern_matches(graph, nodes=[('mem_in', dict(op='ReadValue')),
|
||||
('mem_in_data', dict(shape=int64_array([context_len]))),
|
||||
('crop_mem_in', dict(op='Crop', axis=int64_array([1]),
|
||||
offset=int64_array([1]),
|
||||
dim=int64_array([context_len-1]))),
|
||||
@@ -104,8 +106,7 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
|
||||
('concat_data', dict()),
|
||||
('const_1', dict(op='Const')),
|
||||
('const_1_data', dict()),
|
||||
('mem_out', dict(op='Memory', index=0,
|
||||
shape=int64_array([context_len]))),
|
||||
('mem_out', dict(op='Assign')),
|
||||
('crop_out', dict(op='Crop', axis=int64_array([1]),
|
||||
offset=int64_array([0]),
|
||||
dim=int64_array([1]))),
|
||||
@@ -122,12 +123,13 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
|
||||
('crop_out_data', 'select')])
|
||||
counter_match = next(existing_counters, None)
|
||||
if counter_match is not None:
|
||||
ones = Node(graph, inverse_dict(counter_match)['const_1'])
|
||||
input_port = Node(graph, inverse_dict(counter_match)['crop_out']).out_port(0)
|
||||
else:
|
||||
mem_out = Memory(graph, {'name': 'iteration_number', 'size': 2,
|
||||
'index': 1, 'id': 'iteration_' + node.name,
|
||||
'shape': int64_array([context_len]),
|
||||
'dst_type': np.int32}).create_node()
|
||||
init_value_mem_out = create_zero_value_with_batch_from_input(in_node_port, context_len, np.int32)
|
||||
mem_out = ReadValue(graph, {'name': 'iteration_number',
|
||||
'variable_id': 'iteration_'+node.name}).create_node()
|
||||
mem_out.in_port(0).connect(init_value_mem_out.out_port(0))
|
||||
cut_first = Crop(graph, {'name': 'cut_first', 'axis': int64_array([1]),
|
||||
'offset': int64_array([1]), 'dim': int64_array([context_len-1])}).create_node()
|
||||
cut_first.in_port(0).connect(mem_out.out_port(0))
|
||||
@@ -135,9 +137,8 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
|
||||
concat = Concat(graph, {'name': 'concat_ones', 'in_ports_count': 2, 'axis': 1}).create_node()
|
||||
concat.in_port(0).connect(cut_first.out_port(0))
|
||||
concat.in_port(1).connect(ones.out_port(0))
|
||||
mem_in = Memory(graph, {'name': 'iteration_number_out', 'size': 2,
|
||||
'index': 0, 'id': 'iteration_' + node.name,
|
||||
'shape': int64_array([context_len])}).create_node()
|
||||
mem_in = Assign(graph, {'name': 'iteration_number_out',
|
||||
'variable_id': 'iteration_'+node.name}).create_node()
|
||||
mem_in.in_port(0).connect(concat.out_port(0))
|
||||
res = Result(graph, {}).create_node()
|
||||
mem_in.out_port(0).connect(res.in_port(0))
|
||||
@@ -146,6 +147,12 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
|
||||
cut_last.in_port(0).connect(concat.out_port(0))
|
||||
input_port = cut_last.out_port(0)
|
||||
|
||||
select_node.in_port(0).connect(input_port)
|
||||
# Check if data from memory is 1
|
||||
# if it is True, we have correct data and should proceed with saving it to memory
|
||||
# else we have not gathered context and have garbage here, shouldn't change initial state of memory
|
||||
cast_in = Equal(graph, {'name': input_port.node.name + '/cast_to_bool'}).create_node()
|
||||
cast_in.in_port(0).connect(ones.out_port(0))
|
||||
cast_in.in_port(1).connect(input_port)
|
||||
select_node.in_port(0).connect(cast_in.out_port(0))
|
||||
select_node.out_port(0).connect(node.in_port(0))
|
||||
select_node.out_port(0).data.set_shape(in_node_shape)
|
||||
|
||||
@@ -30,23 +30,14 @@ class InsertSelectTests(unittest.TestCase):
|
||||
graph = build_graph({'in_node': {'kind': 'data', 'shape': [1, 13]},
|
||||
'placeholder_1': {'kind': 'op', 'op': None},
|
||||
'placeholder_data_1': {'kind': 'data', 'shape': [1, 13]},
|
||||
'memory': {'kind': 'op', 'op': 'Memory', 'index': 0},
|
||||
'memory': {'kind': 'op', 'op': 'Assign'},
|
||||
},
|
||||
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
|
||||
('placeholder_data_1', 'memory')
|
||||
],
|
||||
nodes_with_edges_only=True)
|
||||
ref_graph = graph.copy()
|
||||
AddSelectBeforeMemoryNodePattern().find_and_replace_pattern(graph)
|
||||
ref_graph = build_graph({'in_node': {'kind': 'data', 'shape': [1, 13]},
|
||||
'placeholder_1': {'kind': 'op', 'op': None},
|
||||
'placeholder_data_1': {'kind': 'data', 'shape': [1, 13]},
|
||||
'memory': {'kind': 'op', 'op': 'Memory', 'index': 0},
|
||||
},
|
||||
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
|
||||
('placeholder_data_1', 'memory')
|
||||
],
|
||||
nodes_with_edges_only=True
|
||||
)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, ref_graph, 'memory')
|
||||
self.assertTrue(flag, resp)
|
||||
@@ -60,7 +51,7 @@ class InsertSelectTests(unittest.TestCase):
|
||||
'splice_data_1': {'kind': 'data', 'shape': [1, 13]},
|
||||
'placeholder_2': {'kind': 'op', 'op': None},
|
||||
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
|
||||
'memory': {'kind': 'op', 'op': 'Memory', 'index': 0},
|
||||
'memory': {'kind': 'op', 'op': 'Assign', 'index': 0},
|
||||
},
|
||||
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
|
||||
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
|
||||
@@ -76,15 +67,32 @@ class InsertSelectTests(unittest.TestCase):
|
||||
'splice_data_1': {'kind': 'data', 'shape': [1, 13]},
|
||||
'placeholder_2': {'kind': 'op', 'op': None},
|
||||
|
||||
'memory_in': {'kind': 'op', 'op': 'Memory', 'shape': int64_array([5])},
|
||||
'shape': {'kind': 'op', 'op': 'ShapeOf'},
|
||||
'shape_data': {'kind': 'data'},
|
||||
'crop_batch': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
|
||||
'crop_batch_data': {'kind': 'data'},
|
||||
'crop_batch_dim':{'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
|
||||
'crop_batch_dim_data': {'kind': 'data'},
|
||||
'second_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([5])},
|
||||
'second_dim_data': {'kind': 'data'},
|
||||
'gather_shape': {'kind': 'op', 'op': 'Concat'},
|
||||
'gather_shape_data': {'kind': 'data'},
|
||||
'fill_value': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
|
||||
'fill_value_data': {'kind': 'data'},
|
||||
'broadcast': {'kind': 'op', 'op': 'Broadcast'},
|
||||
'broadcast_data': {'kind': 'data'},
|
||||
|
||||
'memory_in': {'kind': 'op', 'op': 'ReadValue', 'shape': int64_array([5])},
|
||||
'memory_in_data': {'kind': 'data'},
|
||||
'memory_out': {'kind': 'op', 'op': 'Memory', 'shape': int64_array([5])},
|
||||
'memory_out': {'kind': 'op', 'op': 'Assign', 'shape': int64_array([5])},
|
||||
'memory_out_data': {'kind': 'data'},
|
||||
'result': {'kind': 'op', 'op': 'Result'},
|
||||
'crop_in': {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 1, 'dim': 4},
|
||||
'crop_in_data': {'kind': 'data'},
|
||||
'crop_out': {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 0, 'dim': 1},
|
||||
'crop_out_data': {'kind': 'data'},
|
||||
'equal': {'kind': 'op', 'op': 'Equal'},
|
||||
'equal_data': {'kind': 'data'},
|
||||
'select': {'kind': 'op', 'op': 'Select'},
|
||||
'select_out_data': {'kind': 'data', 'shape': [1, 26]},
|
||||
'const_0': {'kind': 'op', 'op': 'Const'},
|
||||
@@ -95,22 +103,34 @@ class InsertSelectTests(unittest.TestCase):
|
||||
'concat_data': {'kind': 'data'},
|
||||
|
||||
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
|
||||
'memory': {'kind': 'op', 'op': 'Memory', 'index': 0},
|
||||
'memory': {'kind': 'op', 'op': 'Assign'},
|
||||
},
|
||||
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
|
||||
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
|
||||
('splice_data_1', 'placeholder_2'), ('placeholder_2', 'placeholder_data_2'),
|
||||
('placeholder_data_2', 'select', {'in': 1}),
|
||||
|
||||
('placeholder_data_2', 'shape'), ('shape', 'shape_data'),
|
||||
('shape_data', 'crop_batch'), ('crop_batch', 'crop_batch_data'),
|
||||
('crop_batch_dim', 'crop_batch_dim_data'),
|
||||
('crop_batch_dim_data', 'crop_batch', {'in': 1}),
|
||||
('second_dim', 'second_dim_data'), ('second_dim_data', 'gather_shape', {'in': 1}),
|
||||
('crop_batch_data', 'gather_shape', {'in': 0}), ('gather_shape', 'gather_shape_data'),
|
||||
('fill_value', 'fill_value_data'), ('fill_value_data', 'broadcast', {'in': 0}),
|
||||
('gather_shape_data', 'broadcast', {'in': 1}), ('broadcast', 'broadcast_data'),
|
||||
('broadcast_data', 'memory_in'),
|
||||
|
||||
('memory_in', 'memory_in_data'), ('memory_in_data', 'crop_in'),
|
||||
('crop_in', 'crop_in_data'), ('crop_in_data', 'concat', {'in': 0}),
|
||||
('const_1', 'const_1_data'), ('const_1_data', 'concat', {'in': 1}),
|
||||
('concat', 'concat_data'), ('concat_data', 'memory_out'),
|
||||
('memory_out', 'memory_out_data'), ('memory_out_data', 'result'),
|
||||
('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
|
||||
('crop_out_data', 'select', {'in': 0}),
|
||||
('const_0', 'const_0_data'), ('const_0_data', 'select', {'in': 2}),
|
||||
('crop_out_data', 'equal', {'in': 1}), ('const_1_data', 'equal', {'in': 0}),
|
||||
('equal', 'equal_data'),
|
||||
('equal_data', 'select', {'in': 0}),
|
||||
|
||||
('const_0', 'const_0_data'), ('const_0_data', 'select', {'in': 2}),
|
||||
('select', 'select_out_data'),
|
||||
('select_out_data', 'memory')
|
||||
],
|
||||
@@ -132,7 +152,7 @@ class InsertSelectTests(unittest.TestCase):
|
||||
'splice_data_2': {'kind': 'data', 'shape': [1, 39]},
|
||||
'placeholder_2': {'kind': 'op', 'op': None},
|
||||
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
|
||||
'memory': {'kind': 'op', 'op': 'Memory', 'index': 0},
|
||||
'memory': {'kind': 'op', 'op': 'Assign'},
|
||||
},
|
||||
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
|
||||
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
|
||||
@@ -151,15 +171,32 @@ class InsertSelectTests(unittest.TestCase):
|
||||
'splice_data_2': {'kind': 'data', 'shape': [1, 39]},
|
||||
'placeholder_2': {'kind': 'op', 'op': None},
|
||||
|
||||
'memory_in': {'kind': 'op', 'op': 'Memory', 'shape': int64_array([5])},
|
||||
'shape': {'kind': 'op', 'op': 'ShapeOf'},
|
||||
'shape_data': {'kind': 'data'},
|
||||
'crop_batch': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
|
||||
'crop_batch_data': {'kind': 'data'},
|
||||
'crop_batch_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
|
||||
'crop_batch_dim_data': {'kind': 'data'},
|
||||
'second_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([5])},
|
||||
'second_dim_data': {'kind': 'data'},
|
||||
'gather_shape': {'kind': 'op', 'op': 'Concat'},
|
||||
'gather_shape_data': {'kind': 'data'},
|
||||
'fill_value': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
|
||||
'fill_value_data': {'kind': 'data'},
|
||||
'broadcast': {'kind': 'op', 'op': 'Broadcast'},
|
||||
'broadcast_data': {'kind': 'data'},
|
||||
|
||||
'memory_in': {'kind': 'op', 'op': 'ReadValue'},
|
||||
'memory_in_data': {'kind': 'data'},
|
||||
'memory_out': {'kind': 'op', 'op': 'Memory', 'shape': int64_array([5])},
|
||||
'memory_out': {'kind': 'op', 'op': 'Assign'},
|
||||
'memory_out_data': {'kind': 'data'},
|
||||
'result': {'kind': 'op', 'op': 'Result'},
|
||||
'crop_in': {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 1, 'dim': 4},
|
||||
'crop_in_data': {'kind': 'data'},
|
||||
'crop_out': {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 0, 'dim': 1},
|
||||
'crop_out_data': {'kind': 'data'},
|
||||
'equal': {'kind': 'op', 'op': 'Equal'},
|
||||
'equal_data': {'kind': 'data'},
|
||||
'select': {'kind': 'op', 'op': 'Select'},
|
||||
'select_out_data': {'kind': 'data', 'shape': [1, 26]},
|
||||
'const_0': {'kind': 'op', 'op': 'Const'},
|
||||
@@ -170,7 +207,7 @@ class InsertSelectTests(unittest.TestCase):
|
||||
'concat_data': {'kind': 'data'},
|
||||
|
||||
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
|
||||
'memory': {'kind': 'op', 'op': 'Memory', 'index': 0},
|
||||
'memory': {'kind': 'op', 'op': 'Assign'},
|
||||
},
|
||||
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
|
||||
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
|
||||
@@ -178,13 +215,25 @@ class InsertSelectTests(unittest.TestCase):
|
||||
('splice_data_1', 'placeholder_2'), ('placeholder_2', 'placeholder_data_2'),
|
||||
('placeholder_data_2', 'select', {'in': 1}),
|
||||
|
||||
('placeholder_data_2', 'shape'), ('shape', 'shape_data'),
|
||||
('shape_data', 'crop_batch'), ('crop_batch', 'crop_batch_data'),
|
||||
('crop_batch_dim', 'crop_batch_dim_data'),
|
||||
('crop_batch_dim_data', 'crop_batch', {'in': 1}),
|
||||
('second_dim', 'second_dim_data'), ('second_dim_data', 'gather_shape', {'in': 1}),
|
||||
('crop_batch_data', 'gather_shape', {'in': 0}), ('gather_shape', 'gather_shape_data'),
|
||||
('fill_value', 'fill_value_data'), ('fill_value_data', 'broadcast', {'in': 0}),
|
||||
('gather_shape_data', 'broadcast', {'in': 1}), ('broadcast', 'broadcast_data'),
|
||||
('broadcast_data', 'memory_in'),
|
||||
|
||||
('memory_in', 'memory_in_data'), ('memory_in_data', 'crop_in'),
|
||||
('crop_in', 'crop_in_data'), ('crop_in_data', 'concat', {'in': 0}),
|
||||
('const_1', 'const_1_data'), ('const_1_data', 'concat', {'in': 1}),
|
||||
('concat', 'concat_data'), ('concat_data', 'memory_out'),
|
||||
('memory_out', 'memory_out_data'), ('memory_out_data', 'result'),
|
||||
('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
|
||||
('crop_out_data', 'select', {'in': 0}),
|
||||
('crop_out_data', 'equal', {'in': 1}), ('const_1_data', 'equal', {'in': 0}),
|
||||
('equal', 'equal_data'),
|
||||
('equal_data', 'select', {'in': 0}),
|
||||
('const_0', 'const_0_data'), ('const_0_data', 'select', {'in': 2}),
|
||||
|
||||
('select', 'select_out_data'),
|
||||
@@ -208,7 +257,7 @@ class InsertSelectTests(unittest.TestCase):
|
||||
'splice_data_2': {'kind': 'data', 'shape': [1, 39]},
|
||||
'placeholder_2': {'kind': 'op', 'op': None},
|
||||
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
|
||||
'memory': {'kind': 'op', 'op': 'Memory', 'index': 0},
|
||||
'memory': {'kind': 'op', 'op': 'Assign', 'index': 0},
|
||||
},
|
||||
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
|
||||
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
|
||||
@@ -227,15 +276,32 @@ class InsertSelectTests(unittest.TestCase):
|
||||
'splice_data_2': {'kind': 'data', 'shape': [1, 39]},
|
||||
'placeholder_2': {'kind': 'op', 'op': None},
|
||||
|
||||
'memory_in': {'kind': 'op', 'op': 'Memory', 'shape': int64_array([7])},
|
||||
'shape': {'kind': 'op', 'op': 'ShapeOf'},
|
||||
'shape_data': {'kind': 'data'},
|
||||
'crop_batch': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
|
||||
'crop_batch_data': {'kind': 'data'},
|
||||
'crop_batch_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
|
||||
'crop_batch_dim_data': {'kind': 'data'},
|
||||
'second_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([7])},
|
||||
'second_dim_data': {'kind': 'data'},
|
||||
'gather_shape': {'kind': 'op', 'op': 'Concat'},
|
||||
'gather_shape_data': {'kind': 'data'},
|
||||
'fill_value': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
|
||||
'fill_value_data': {'kind': 'data'},
|
||||
'broadcast': {'kind': 'op', 'op': 'Broadcast'},
|
||||
'broadcast_data': {'kind': 'data'},
|
||||
|
||||
'memory_in': {'kind': 'op', 'op': 'ReadValue'},
|
||||
'memory_in_data': {'kind': 'data'},
|
||||
'memory_out': {'kind': 'op', 'op': 'Memory', 'shape': int64_array([7])},
|
||||
'memory_out': {'kind': 'op', 'op': 'Assign'},
|
||||
'memory_out_data': {'kind': 'data'},
|
||||
'result': {'kind': 'op', 'op': 'Result'},
|
||||
'crop_in': {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 1, 'dim': 6},
|
||||
'crop_in_data': {'kind': 'data'},
|
||||
'crop_out': {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 0, 'dim': 1},
|
||||
'crop_out_data': {'kind': 'data'},
|
||||
'equal': {'kind': 'op', 'op': 'Equal'},
|
||||
'equal_data': {'kind': 'data'},
|
||||
'select': {'kind': 'op', 'op': 'Select'},
|
||||
'select_out_data': {'kind': 'data', 'shape': [1, 26]},
|
||||
'const_0': {'kind': 'op', 'op': 'Const'},
|
||||
@@ -246,7 +312,7 @@ class InsertSelectTests(unittest.TestCase):
|
||||
'concat_data': {'kind': 'data'},
|
||||
|
||||
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
|
||||
'memory': {'kind': 'op', 'op': 'Memory', 'index': 0},
|
||||
'memory': {'kind': 'op', 'op': 'Assign', 'index': 0},
|
||||
},
|
||||
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
|
||||
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
|
||||
@@ -254,13 +320,25 @@ class InsertSelectTests(unittest.TestCase):
|
||||
('splice_data_2', 'placeholder_2'), ('placeholder_2', 'placeholder_data_2'),
|
||||
('placeholder_data_2', 'select', {'in': 1}),
|
||||
|
||||
('placeholder_data_2', 'shape'), ('shape', 'shape_data'),
|
||||
('shape_data', 'crop_batch'), ('crop_batch', 'crop_batch_data'),
|
||||
('crop_batch_dim', 'crop_batch_dim_data'),
|
||||
('crop_batch_dim_data', 'crop_batch', {'in': 1}),
|
||||
('second_dim', 'second_dim_data'), ('second_dim_data', 'gather_shape', {'in': 1}),
|
||||
('crop_batch_data', 'gather_shape', {'in': 0}), ('gather_shape', 'gather_shape_data'),
|
||||
('fill_value', 'fill_value_data'), ('fill_value_data', 'broadcast', {'in': 0}),
|
||||
('gather_shape_data', 'broadcast', {'in': 1}), ('broadcast', 'broadcast_data'),
|
||||
('broadcast_data', 'memory_in'),
|
||||
|
||||
('memory_in', 'memory_in_data'), ('memory_in_data', 'crop_in'),
|
||||
('crop_in', 'crop_in_data'), ('crop_in_data', 'concat', {'in': 0}),
|
||||
('const_1', 'const_1_data'), ('const_1_data', 'concat', {'in': 1}),
|
||||
('concat', 'concat_data'), ('concat_data', 'memory_out'),
|
||||
('memory_out', 'memory_out_data'), ('memory_out_data', 'result'),
|
||||
('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
|
||||
('crop_out_data', 'select', {'in': 0}),
|
||||
('crop_out_data', 'equal', {'in': 1}), ('const_1_data', 'equal', {'in': 0}),
|
||||
('equal', 'equal_data'),
|
||||
('equal_data', 'select', {'in': 0}),
|
||||
('const_0', 'const_0_data'), ('const_0_data', 'select', {'in': 2}),
|
||||
|
||||
('select', 'select_out_data'),
|
||||
|
||||
@@ -59,7 +59,18 @@ class RemoveUselessCropsPattern(MiddleReplacementPattern):
|
||||
if out['op'] == 'Crop' and out['axis'] == axis and \
|
||||
len(out.out_port(0).get_destinations()) == 1 and \
|
||||
out.out_port(0).get_destination().node == concat_node:
|
||||
offsets_dims.append((out['offset'], out['dim']))
|
||||
# crop type 1
|
||||
if 'dim' in out:
|
||||
offsets_dims.append((out['offset'], out['dim']))
|
||||
# crop type 3
|
||||
elif 'crop_begin' in out and 'crop_end' in out:
|
||||
offsets_dims.append((out['crop_begin'], out['crop_end']-out['crop_begin']))
|
||||
# crop type 2 with const dim
|
||||
elif not out.in_port(1).disconnected() and out.in_port(1).data.get_value() is not None:
|
||||
offsets_dims.append((out['offset'], out.in_port(1).data.get_value()))
|
||||
# crop type 2 with non-const dim or strange type of crop
|
||||
else:
|
||||
return
|
||||
crop_list.append(out)
|
||||
|
||||
offsets_dims.sort(key=lambda off_dim: off_dim[0])
|
||||
|
||||
@@ -84,6 +84,136 @@ class RemoveUselessCropsPatternTests(unittest.TestCase):
|
||||
(flag, resp) = compare_graphs(graph, ref_graph, 'placeholder')
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_useless_crops_type2(self):
|
||||
graph = build_graph({'placeholder_in': {'kind': 'op', 'op': 'Parameter'},
|
||||
'in_node': {'kind': 'data', 'shape': [1, 130]},
|
||||
'crop1': {'kind': 'op', 'op': 'Crop', 'offset': 0, 'dim': 26, 'axis': -1},
|
||||
'crop_data_1': {'kind': 'data', 'shape': [1, 26]},
|
||||
'const_26': {'kind': 'op', 'op': 'Const', 'value': 26},
|
||||
'const_26_data': {'kind': 'data', 'value': 26},
|
||||
'crop2': {'kind': 'op', 'op': 'Crop', 'offset': 26, 'axis': -1},
|
||||
'crop_data_2': {'kind': 'data', 'shape': [1, 26]},
|
||||
'crop3': {'kind': 'op', 'op': 'Crop', 'offset': 52, 'dim': 26, 'axis': -1},
|
||||
'crop_data_3': {'kind': 'data', 'shape': [1, 26]},
|
||||
'crop4': {'kind': 'op', 'op': 'Crop', 'offset': 78, 'dim': 26, 'axis': -1},
|
||||
'crop_data_4': {'kind': 'data', 'shape': [1, 26]},
|
||||
'crop5': {'kind': 'op', 'op': 'Crop', 'offset': 104, 'dim': 26, 'axis': -1},
|
||||
'crop_data_5': {'kind': 'data', 'shape': [1, 26]},
|
||||
'concat': {'kind': 'op', 'op': 'Concat'},
|
||||
'concat_data': {'kind': 'data', 'shape': [1, 130]},
|
||||
'placeholder': {'kind': 'op', 'op': 'Parameter'},
|
||||
},
|
||||
[('placeholder_in', 'in_node'),
|
||||
('in_node', 'crop1'), ('crop1', 'crop_data_1'),
|
||||
('in_node', 'crop2', {'in': 0}), ('const_26', 'const_26_data'),
|
||||
('const_26_data', 'crop2', {'in': 1}), ('crop2', 'crop_data_2'),
|
||||
('in_node', 'crop3'), ('crop3', 'crop_data_3'),
|
||||
('in_node', 'crop4'), ('crop4', 'crop_data_4'),
|
||||
('in_node', 'crop5'), ('crop5', 'crop_data_5'),
|
||||
('crop_data_1', 'concat'),
|
||||
('crop_data_2', 'concat'),
|
||||
('crop_data_3', 'concat'),
|
||||
('crop_data_4', 'concat'),
|
||||
('crop_data_5', 'concat'),
|
||||
('concat', 'concat_data'),
|
||||
('concat_data', 'placeholder')])
|
||||
RemoveUselessCropsPattern().find_and_replace_pattern(graph)
|
||||
ref_graph = build_graph({'placeholder_in': {'kind': 'op', 'op': 'Parameter'},
|
||||
'in_node': {'kind': 'data', 'shape': [1, 130]},
|
||||
'crop1': {'kind': 'op', 'op': 'Crop', 'offset': 0, 'dim': 26, 'axis': -1},
|
||||
'crop_data_1': {'kind': 'data', 'shape': [1, 26]},
|
||||
'const_26': {'kind': 'op', 'op': 'Const', 'value': 26},
|
||||
'const_26_data': {'kind': 'data', 'value': 26},
|
||||
'crop2': {'kind': 'op', 'op': 'Crop', 'offset': 26, 'dim': 26, 'axis': -1},
|
||||
'crop_data_2': {'kind': 'data', 'shape': [1, 26]},
|
||||
'crop3': {'kind': 'op', 'op': 'Crop', 'offset': 52, 'dim': 26, 'axis': -1},
|
||||
'crop_data_3': {'kind': 'data', 'shape': [1, 26]},
|
||||
'crop4': {'kind': 'op', 'op': 'Crop', 'offset': 78, 'dim': 26, 'axis': -1},
|
||||
'crop_data_4': {'kind': 'data', 'shape': [1, 26]},
|
||||
'crop5': {'kind': 'op', 'op': 'Crop', 'offset': 104, 'dim': 26, 'axis': -1},
|
||||
'crop_data_5': {'kind': 'data', 'shape': [1, 26]},
|
||||
'concat': {'kind': 'op', 'op': 'Concat'},
|
||||
'concat_data': {'kind': 'data', 'shape': [1, 130]},
|
||||
'placeholder': {'kind': 'op', 'op': 'Parameter'},
|
||||
},
|
||||
[
|
||||
('placeholder_in', 'in_node'),
|
||||
('in_node', 'crop1'), ('crop1', 'crop_data_1'),
|
||||
('in_node', 'crop2', {'in': 0}), ('const_26', 'const_26_data'),
|
||||
('const_26_data', 'crop2', {'in': 1}), ('crop2', 'crop_data_2'),
|
||||
('in_node', 'crop3'), ('crop3', 'crop_data_3'),
|
||||
('in_node', 'crop4'), ('crop4', 'crop_data_4'),
|
||||
('in_node', 'crop5'), ('crop5', 'crop_data_5'),
|
||||
('concat', 'concat_data'),
|
||||
('in_node', 'placeholder')
|
||||
]
|
||||
)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, ref_graph, 'placeholder')
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_useless_crops_type3(self):
|
||||
graph = build_graph({'placeholder_in': {'kind': 'op', 'op': 'Parameter'},
|
||||
'in_node': {'kind': 'data', 'shape': [1, 130]},
|
||||
'crop1': {'kind': 'op', 'op': 'Crop', 'offset': 0, 'dim': 26, 'axis': -1},
|
||||
'crop_data_1': {'kind': 'data', 'shape': [1, 26]},
|
||||
'crop2': {'kind': 'op', 'op': 'Crop', 'crop_begin': 26, 'crop_end': 52, 'axis': -1},
|
||||
'crop_data_2': {'kind': 'data', 'shape': [1, 26]},
|
||||
'crop3': {'kind': 'op', 'op': 'Crop', 'offset': 52, 'dim': 26, 'axis': -1},
|
||||
'crop_data_3': {'kind': 'data', 'shape': [1, 26]},
|
||||
'crop4': {'kind': 'op', 'op': 'Crop', 'offset': 78, 'dim': 26, 'axis': -1},
|
||||
'crop_data_4': {'kind': 'data', 'shape': [1, 26]},
|
||||
'crop5': {'kind': 'op', 'op': 'Crop', 'offset': 104, 'dim': 26, 'axis': -1},
|
||||
'crop_data_5': {'kind': 'data', 'shape': [1, 26]},
|
||||
'concat': {'kind': 'op', 'op': 'Concat'},
|
||||
'concat_data': {'kind': 'data', 'shape': [1, 130]},
|
||||
'placeholder': {'kind': 'op', 'op': 'Parameter'},
|
||||
},
|
||||
[('placeholder_in', 'in_node'),
|
||||
('in_node', 'crop1'), ('crop1', 'crop_data_1'),
|
||||
('in_node', 'crop2'), ('crop2', 'crop_data_2'),
|
||||
('in_node', 'crop3'), ('crop3', 'crop_data_3'),
|
||||
('in_node', 'crop4'), ('crop4', 'crop_data_4'),
|
||||
('in_node', 'crop5'), ('crop5', 'crop_data_5'),
|
||||
('crop_data_1', 'concat'),
|
||||
('crop_data_2', 'concat'),
|
||||
('crop_data_3', 'concat'),
|
||||
('crop_data_4', 'concat'),
|
||||
('crop_data_5', 'concat'),
|
||||
('concat', 'concat_data'),
|
||||
('concat_data', 'placeholder')])
|
||||
RemoveUselessCropsPattern().find_and_replace_pattern(graph)
|
||||
ref_graph = build_graph({'placeholder_in': {'kind': 'op', 'op': 'Parameter'},
|
||||
'in_node': {'kind': 'data', 'shape': [1, 130]},
|
||||
'crop1': {'kind': 'op', 'op': 'Crop', 'offset': 0, 'dim': 26, 'axis': -1},
|
||||
'crop_data_1': {'kind': 'data', 'shape': [1, 26]},
|
||||
'crop2': {'kind': 'op', 'op': 'Crop', 'crop_begin': 26, 'crop_end': 52, 'axis': -1},
|
||||
'crop_data_2': {'kind': 'data', 'shape': [1, 26]},
|
||||
'crop3': {'kind': 'op', 'op': 'Crop', 'offset': 52, 'dim': 26, 'axis': -1},
|
||||
'crop_data_3': {'kind': 'data', 'shape': [1, 26]},
|
||||
'crop4': {'kind': 'op', 'op': 'Crop', 'offset': 78, 'dim': 26, 'axis': -1},
|
||||
'crop_data_4': {'kind': 'data', 'shape': [1, 26]},
|
||||
'crop5': {'kind': 'op', 'op': 'Crop', 'offset': 104, 'dim': 26, 'axis': -1},
|
||||
'crop_data_5': {'kind': 'data', 'shape': [1, 26]},
|
||||
'concat': {'kind': 'op', 'op': 'Concat'},
|
||||
'concat_data': {'kind': 'data', 'shape': [1, 130]},
|
||||
'placeholder': {'kind': 'op', 'op': 'Parameter'},
|
||||
},
|
||||
[
|
||||
('placeholder_in', 'in_node'),
|
||||
('in_node', 'crop1'), ('crop1', 'crop_data_1'),
|
||||
('in_node', 'crop2'), ('crop2', 'crop_data_2'),
|
||||
('in_node', 'crop3'), ('crop3', 'crop_data_3'),
|
||||
('in_node', 'crop4'), ('crop4', 'crop_data_4'),
|
||||
('in_node', 'crop5'), ('crop5', 'crop_data_5'),
|
||||
('concat', 'concat_data'),
|
||||
('in_node', 'placeholder')
|
||||
]
|
||||
)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, ref_graph, 'placeholder')
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_useful_crops(self):
|
||||
graph = build_graph({'placeholder_in': {'kind': 'op', 'op': 'Parameter'},
|
||||
'in_node': {'kind': 'data', 'shape': [1, 130]},
|
||||
|
||||
@@ -15,13 +15,15 @@
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
from extensions.front.kaldi.replace_lstm_node_pattern import create_zero_value_with_batch_from_input
|
||||
from extensions.ops.splice import Splice
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.graph.graph import Graph, Node
|
||||
from mo.middle.replacement import MiddleReplacementPattern
|
||||
from mo.ops.assign import Assign
|
||||
from mo.ops.concat import Concat
|
||||
from mo.ops.crop import Crop
|
||||
from mo.ops.memory import Memory
|
||||
from mo.ops.read_value import ReadValue
|
||||
from mo.ops.result import Result
|
||||
from mo.utils.error import Error
|
||||
|
||||
@@ -67,7 +69,8 @@ class ReplaceMemoryOffsetNodePattern(MiddleReplacementPattern):
|
||||
|
||||
splice = Splice(graph, {'name': node_name,
|
||||
'id': node_id,
|
||||
'context': int64_array(range(node_t, 1)) if node_t < 0 else int64_array(range(0, node_t+1))}).create_node()
|
||||
'context': int64_array(range(node_t, 1))
|
||||
if node_t < 0 else int64_array(range(0, node_t+1))}).create_node()
|
||||
splice.in_port(0).connect(input_node_out_port)
|
||||
|
||||
# offset of Crop will be 0 (first element) if node_t < 0 and in_shape[1]*node_t (last element) if node_t > 0
|
||||
@@ -106,6 +109,7 @@ class ReplaceMemoryOffsetWithMemoryNodePattern(MiddleReplacementPattern):
|
||||
Replace MemoryOffset with Memory if IfDefined used with it to avoid cycles
|
||||
"""
|
||||
enabled = True
|
||||
force_shape_inference = True
|
||||
|
||||
def run_before(self):
|
||||
from extensions.middle.RemoveDuplicationMemory import RemoveMemoryDuplicationPattern
|
||||
@@ -141,43 +145,34 @@ class ReplaceMemoryOffsetWithMemoryNodePattern(MiddleReplacementPattern):
|
||||
in_shape = input_port.data.get_shape()
|
||||
node_t = abs(node.t)
|
||||
|
||||
memory_out = Memory(graph, {'name': pair_name, 'id': node_name+pair_name,
|
||||
'index': 1, 'size': 2,
|
||||
'shape': np.array([in_shape[1]*node_t])}).create_node()
|
||||
init_value_memory_out = create_zero_value_with_batch_from_input(input_port, in_shape[1]*node_t)
|
||||
memory_out = ReadValue(graph, {'name': pair_name, 'variable_id': node_name+pair_name}).create_node()
|
||||
init_value_memory_out.out_port(0).connect(memory_out.in_port(0))
|
||||
|
||||
if node_t > 1:
|
||||
crop_concat = Crop(graph, {'name': 'Memory_crop', 'dim': np.array([in_shape[1]*(node_t-1)]),
|
||||
'offset': np.array([in_shape[1]]), 'axis': np.array([1])}).create_node()
|
||||
memory_out.out_port(0).connect(crop_concat.in_port(0))
|
||||
memory_out.out_port(0).data.set_shape(np.array([in_shape[0], memory_out.shape[0]]))
|
||||
concat = Concat(graph, {'name': 'Memory_concat'}).create_node()
|
||||
concat.add_sequence_of_ports('in', range(2))
|
||||
crop_concat.out_port(0).connect(concat.in_port(0))
|
||||
crop_concat.out_port(0).data.set_shape(np.array([in_shape[0], crop_concat.dim]))
|
||||
concat.in_port(1).connect(input_port)
|
||||
memory_in = Memory(graph, {'name': node_name, 'id': node_name + pair_name,
|
||||
'index': 0, 'size': 2,
|
||||
'shape': memory_out.shape}).create_node()
|
||||
|
||||
memory_in = Assign(graph, {'name': node_name, 'variable_id': node_name + pair_name}).create_node()
|
||||
concat.out_port(0).connect(memory_in.in_port(0))
|
||||
concat.out_port(0).data.set_shape(np.array([in_shape[0], memory_in.shape[0]]))
|
||||
out = Result(graph, {'name': 'Memory_output'}).create_node()
|
||||
memory_in.out_port(0).connect(out.in_port(0))
|
||||
memory_in.out_port(0).data.set_shape(np.array([in_shape[0], memory_out.shape[0]]))
|
||||
|
||||
crop_out = Crop(graph, {'name': 'Memory_crop_out', 'dim': np.array([in_shape[1]]),
|
||||
'offset': np.array([0]), 'axis': np.array([1])}).create_node()
|
||||
memory_out.out_port(0).connect(crop_out.in_port(0))
|
||||
out_port.get_connection().set_source(crop_out.out_port(0))
|
||||
crop_out.out_port(0).data.set_shape(np.array([in_shape[0], crop_out.dim]))
|
||||
else:
|
||||
memory_in = Memory(graph, {'name': node_name, 'id': node_name + pair_name,
|
||||
'index': 0, 'size': 2,
|
||||
'shape': memory_out.shape}).create_node()
|
||||
memory_in = Assign(graph, {'name': node_name, 'variable_id': node_name + pair_name}).create_node()
|
||||
memory_in.in_port(0).connect(input_port)
|
||||
out = Result(graph, {'name': 'Memory_output'}).create_node()
|
||||
memory_in.out_port(0).connect(out.in_port(0))
|
||||
memory_in.out_port(0).data.set_shape(np.array([in_shape[0], memory_out.shape[0]]))
|
||||
out_port.get_connection().set_source(memory_out.out_port(0))
|
||||
memory_out.out_port(0).data.set_shape(np.array([in_shape[0], memory_out.shape[0]]))
|
||||
|
||||
graph.remove_node(op_output_id)
|
||||
graph.remove_node(node.id)
|
||||
|
||||
@@ -13,15 +13,16 @@
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
from extensions.front.kaldi.replace_lstm_node_pattern import unique_id
|
||||
from extensions.front.kaldi.replace_lstm_node_pattern import unique_id, create_zero_value_with_batch_from_input
|
||||
from extensions.ops.split import VariadicSplit
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.tf.graph_utils import create_op_with_const_inputs
|
||||
from mo.graph.graph import Graph
|
||||
from mo.middle.replacement import MiddleReplacementPattern
|
||||
from mo.ops.assign import Assign
|
||||
from mo.ops.concat import Concat
|
||||
from mo.ops.crop import Crop
|
||||
from mo.ops.memory import Memory
|
||||
from mo.ops.read_value import ReadValue
|
||||
from mo.ops.result import Result
|
||||
|
||||
|
||||
@@ -39,7 +40,7 @@ class ReplaceSpliceNodePattern(MiddleReplacementPattern):
|
||||
So this pass will convert this graph to the next one:
|
||||
|
||||
Input [N, H] __
|
||||
\ /
|
||||
/ /
|
||||
Concat [N, k*H]
|
||||
/ \
|
||||
Memory [N, k*H] -> Slice [N, (k-1)*H] Memory [N, k*H]
|
||||
@@ -67,11 +68,9 @@ class ReplaceSpliceNodePattern(MiddleReplacementPattern):
|
||||
|
||||
memory_pair_id = unique_id('id')
|
||||
# Memory(in)
|
||||
input_memory = Memory(graph, {'name': 'prev_splice_memory',
|
||||
'id': memory_pair_id,
|
||||
'index': 1,
|
||||
'size': 2,
|
||||
'shape': int64_array([memory_size])}).create_node()
|
||||
input_memory = ReadValue(graph, {'name': 'prev_splice_memory',
|
||||
'variable_id': memory_pair_id}).create_node()
|
||||
|
||||
# Memory(in) \
|
||||
# Crop
|
||||
# Input(temp) /
|
||||
@@ -90,11 +89,7 @@ class ReplaceSpliceNodePattern(MiddleReplacementPattern):
|
||||
concat_node.in_port(0).connect(crop.out_port(0))
|
||||
|
||||
# Concat -> Memory(out)
|
||||
mem_out = Memory(graph, {'name': 'out_splice_memory',
|
||||
'id': memory_pair_id,
|
||||
'index': 0,
|
||||
'size': 2,
|
||||
'shape': int64_array([memory_size])}).create_node()
|
||||
mem_out = Assign(graph, {'name': 'out_splice_memory', 'variable_id': memory_pair_id}).create_node()
|
||||
mem_out.in_port(0).connect(concat_node.out_port(0))
|
||||
Result(graph).create_node().in_port(0).connect(mem_out.out_port(0))
|
||||
|
||||
@@ -110,11 +105,12 @@ class ReplaceSpliceNodePattern(MiddleReplacementPattern):
|
||||
|
||||
# create separate splice construction for const_dim
|
||||
memory_pair_id = unique_id('memory_for_const_dim')
|
||||
input_memory_const_dim = Memory(graph, {'name': 'const_dim_in_memory',
|
||||
'id': memory_pair_id,
|
||||
'index': 1,
|
||||
'size': 2,
|
||||
'shape': int64_array([memory_size_constdim])}).create_node()
|
||||
init_value_input_memory_const_dim = create_zero_value_with_batch_from_input(split.out_port(1),
|
||||
memory_size_constdim)
|
||||
input_memory_const_dim = ReadValue(graph, {'name': 'const_dim_in_memory',
|
||||
'variable_id': memory_pair_id}).create_node()
|
||||
init_value_input_memory_const_dim.out_port(0).connect(input_memory_const_dim.in_port(0))
|
||||
|
||||
crop_const_dim = Crop(graph, {'name': 'const_dim_crop',
|
||||
'axis': int64_array([1]),
|
||||
'offset': int64_array([memory_element_constdim]),
|
||||
@@ -127,11 +123,8 @@ class ReplaceSpliceNodePattern(MiddleReplacementPattern):
|
||||
'axis': 1}).create_node()
|
||||
concat_node_const_dim.in_port(0).connect(crop_const_dim.out_port(0))
|
||||
|
||||
mem_out_const_dim = Memory(graph, {'name': 'const_dim_out_memory',
|
||||
'id': memory_pair_id,
|
||||
'index': 0,
|
||||
'size': 2,
|
||||
'shape': int64_array([memory_size_constdim])}).create_node()
|
||||
mem_out_const_dim = Assign(graph, {'name': 'const_dim_out_memory',
|
||||
'variable_id': memory_pair_id}).create_node()
|
||||
mem_out_const_dim.in_port(0).connect(concat_node_const_dim.out_port(0))
|
||||
Result(graph).create_node().in_port(0).connect(mem_out_const_dim.out_port(0))
|
||||
|
||||
@@ -148,9 +141,15 @@ class ReplaceSpliceNodePattern(MiddleReplacementPattern):
|
||||
concat_const.in_port(1).connect(crop_first.out_port(0))
|
||||
concat_const.in_port(0).connect(concat_node.out_port(0))
|
||||
|
||||
init_value_input_memory = create_zero_value_with_batch_from_input(split.out_port(0),
|
||||
memory_size)
|
||||
init_value_input_memory.out_port(0).connect(input_memory.in_port(0))
|
||||
node.in_port(0).get_connection().set_destination(split.in_port(0))
|
||||
node.out_port(0).get_connection().set_source(concat_const.out_port(0))
|
||||
else:
|
||||
init_value_input_memory = create_zero_value_with_batch_from_input(node.in_port(0).get_source(),
|
||||
memory_size)
|
||||
init_value_input_memory.out_port(0).connect(input_memory.in_port(0))
|
||||
node.in_port(0).get_connection().set_destination(concat_node.in_port(1))
|
||||
node.out_port(0).get_connection().set_source(concat_node.out_port(0))
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
import unittest
|
||||
|
||||
from extensions.middle.ReplaceSpliceNodePattern import ReplaceSpliceNodePattern
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.graph.graph import Node
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from mo.utils.unittest.graph import build_graph
|
||||
@@ -42,19 +43,47 @@ class ReplaceSpliceNodePatternTests(unittest.TestCase):
|
||||
|
||||
ref_graph = build_graph({'in_placeholder': {'kind': 'op', 'op': None},
|
||||
'in_node': {'kind': 'data', 'shape': [1, 13]},
|
||||
'memory_in': {'kind': 'op', 'op': 'Memory'},
|
||||
|
||||
'shape': {'kind': 'op', 'op': 'ShapeOf'},
|
||||
'shape_data': {'kind': 'data'},
|
||||
'crop_batch': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
|
||||
'crop_batch_data': {'kind': 'data'},
|
||||
'crop_batch_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
|
||||
'crop_batch_dim_data': {'kind': 'data'},
|
||||
'second_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([143])},
|
||||
'second_dim_data': {'kind': 'data'},
|
||||
'gather_shape': {'kind': 'op', 'op': 'Concat'},
|
||||
'gather_shape_data': {'kind': 'data'},
|
||||
'fill_value': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
|
||||
'fill_value_data': {'kind': 'data'},
|
||||
'broadcast': {'kind': 'op', 'op': 'Broadcast'},
|
||||
'broadcast_data': {'kind': 'data'},
|
||||
|
||||
'memory_in': {'kind': 'op', 'op': 'ReadValue'},
|
||||
'memory_in_data': {'kind': 'data'},
|
||||
'crop_mem': {'kind': 'op', 'op': 'Crop', 'offset': 13, 'dim': 130},
|
||||
'crop_mem_data': {'kind': 'data'},
|
||||
'concat': {'kind': 'op', 'op': 'Concat'},
|
||||
'concat_data': {'kind': 'data', 'shape': [1, 143]},
|
||||
'memory_out': {'kind': 'op', 'op': 'Memory'},
|
||||
'memory_out': {'kind': 'op', 'op': 'Assign'},
|
||||
'memory_out_data': {'kind': 'data'},
|
||||
'result': {'kind': 'op', 'op': 'Result'},
|
||||
'out_placeholder': {'kind': 'op', 'op': 'placeholder'},
|
||||
},
|
||||
[
|
||||
('in_placeholder', 'in_node'),
|
||||
|
||||
('in_node', 'shape'), ('shape', 'shape_data'),
|
||||
('shape_data', 'crop_batch'), ('crop_batch', 'crop_batch_data'),
|
||||
('crop_batch_dim', 'crop_batch_dim_data'),
|
||||
('crop_batch_dim_data', 'crop_batch', {'in': 1}),
|
||||
('second_dim', 'second_dim_data'), ('second_dim_data', 'gather_shape', {'in': 1}),
|
||||
('crop_batch_data', 'gather_shape', {'in': 0}),
|
||||
('gather_shape', 'gather_shape_data'),
|
||||
('fill_value', 'fill_value_data'), ('fill_value_data', 'broadcast', {'in': 0}),
|
||||
('gather_shape_data', 'broadcast', {'in': 1}), ('broadcast', 'broadcast_data'),
|
||||
('broadcast_data', 'memory_in'),
|
||||
|
||||
('memory_in', 'memory_in_data'),
|
||||
('memory_in_data', 'crop_mem'),
|
||||
('crop_mem', 'crop_mem_data'),
|
||||
@@ -86,22 +115,54 @@ class ReplaceSpliceNodePatternTests(unittest.TestCase):
|
||||
'split': {'kind': 'op', 'op': 'Split'},
|
||||
'split_data_0': {'kind': 'data'},
|
||||
'split_data_1': {'kind': 'data'},
|
||||
'memory_in': {'kind': 'op', 'op': 'Memory'},
|
||||
|
||||
'shape': {'kind': 'op', 'op': 'ShapeOf'},
|
||||
'shape_data': {'kind': 'data'},
|
||||
'crop_batch': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
|
||||
'crop_batch_data': {'kind': 'data'},
|
||||
'crop_batch_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
|
||||
'crop_batch_dim_data': {'kind': 'data'},
|
||||
'second_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([33])},
|
||||
'second_dim_data': {'kind': 'data'},
|
||||
'gather_shape': {'kind': 'op', 'op': 'Concat'},
|
||||
'gather_shape_data': {'kind': 'data'},
|
||||
'fill_value': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
|
||||
'fill_value_data': {'kind': 'data'},
|
||||
'broadcast': {'kind': 'op', 'op': 'Broadcast'},
|
||||
'broadcast_data': {'kind': 'data'},
|
||||
|
||||
'memory_in': {'kind': 'op', 'op': 'ReadValue'},
|
||||
'memory_in_data': {'kind': 'data'},
|
||||
'crop_mem': {'kind': 'op', 'op': 'Crop', 'offset': 3, 'dim': 30},
|
||||
'crop_mem_data': {'kind': 'data'},
|
||||
'concat': {'kind': 'op', 'op': 'Concat'},
|
||||
'concat_data': {'kind': 'data'},
|
||||
'memory_out': {'kind': 'op', 'op': 'Memory'},
|
||||
'memory_out': {'kind': 'op', 'op': 'Assign'},
|
||||
'memory_out_data': {'kind': 'data'},
|
||||
'result': {'kind': 'op', 'op': 'Result'},
|
||||
'memory_in_constdims': {'kind': 'op', 'op': 'Memory'},
|
||||
|
||||
'shape_2': {'kind': 'op', 'op': 'ShapeOf'},
|
||||
'shape_2_data': {'kind': 'data'},
|
||||
'crop_batch_2': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
|
||||
'crop_batch_2_data': {'kind': 'data'},
|
||||
'crop_batch_dim_2': {'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
|
||||
'crop_batch_dim_2_data': {'kind': 'data'},
|
||||
'second_dim_2': {'kind': 'op', 'op': 'Const', 'value': int64_array([33])},
|
||||
'second_dim_2_data': {'kind': 'data'},
|
||||
'gather_shape_2': {'kind': 'op', 'op': 'Concat'},
|
||||
'gather_shape_2_data': {'kind': 'data'},
|
||||
'fill_value_2': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
|
||||
'fill_value_2_data': {'kind': 'data'},
|
||||
'broadcast_2': {'kind': 'op', 'op': 'Broadcast'},
|
||||
'broadcast_2_data': {'kind': 'data'},
|
||||
|
||||
'memory_in_constdims': {'kind': 'op', 'op': 'ReadValue'},
|
||||
'memory_in_constdims_data': {'kind': 'data'},
|
||||
'crop_mem_constdims': {'kind': 'op', 'op': 'Crop', 'offset': 10, 'dim': 100},
|
||||
'crop_mem_constdims_data': {'kind': 'data'},
|
||||
'concat_constdims': {'kind': 'op', 'op': 'Concat'},
|
||||
'concat_constdims_data': {'kind': 'data'},
|
||||
'memory_out_constdims': {'kind': 'op', 'op': 'Memory'},
|
||||
'memory_out_constdims': {'kind': 'op', 'op': 'Assign'},
|
||||
'memory_out_constdims_data': {'kind': 'data'},
|
||||
'result_constdims': {'kind': 'op', 'op': 'Result'},
|
||||
'crop_first_constdims': {'kind': 'op', 'op': 'Crop', 'offset': 0, 'dim': 10},
|
||||
@@ -121,6 +182,18 @@ class ReplaceSpliceNodePatternTests(unittest.TestCase):
|
||||
('in_node', 'split', {'in': 0}),
|
||||
('split', 'split_data_0', {'out': 0}),
|
||||
('split', 'split_data_1', {'out': 1}),
|
||||
|
||||
('split_data_0', 'shape'), ('shape', 'shape_data'),
|
||||
('shape_data', 'crop_batch'), ('crop_batch', 'crop_batch_data'),
|
||||
('crop_batch_dim', 'crop_batch_dim_data'),
|
||||
('crop_batch_dim_data', 'crop_batch', {'in': 1}),
|
||||
('second_dim', 'second_dim_data'), ('second_dim_data', 'gather_shape', {'in': 1}),
|
||||
('crop_batch_data', 'gather_shape', {'in': 0}),
|
||||
('gather_shape', 'gather_shape_data'),
|
||||
('fill_value', 'fill_value_data'), ('fill_value_data', 'broadcast', {'in': 0}),
|
||||
('gather_shape_data', 'broadcast', {'in': 1}), ('broadcast', 'broadcast_data'),
|
||||
('broadcast_data', 'memory_in'),
|
||||
|
||||
('memory_in', 'memory_in_data'),
|
||||
('memory_in_data', 'crop_mem'),
|
||||
('crop_mem', 'crop_mem_data'),
|
||||
@@ -130,6 +203,18 @@ class ReplaceSpliceNodePatternTests(unittest.TestCase):
|
||||
('concat_data', 'memory_out'),
|
||||
('memory_out', 'memory_out_data'),
|
||||
('memory_out_data', 'result'),
|
||||
|
||||
('split_data_1', 'shape_2'), ('shape_2', 'shape_2_data'),
|
||||
('shape_2_data', 'crop_batch_2'), ('crop_batch_2', 'crop_batch_2_data'),
|
||||
('crop_batch_dim_2', 'crop_batch_dim_2_data'),
|
||||
('crop_batch_dim_2_data', 'crop_batch_2', {'in': 1}),
|
||||
('second_dim_2', 'second_dim_2_data'), ('second_dim_2_data', 'gather_shape_2', {'in': 1}),
|
||||
('crop_batch_2_data', 'gather_shape_2', {'in': 0}),
|
||||
('gather_shape_2', 'gather_shape_2_data'),
|
||||
('fill_value_2', 'fill_value_2_data'), ('fill_value_2_data', 'broadcast_2', {'in': 0}),
|
||||
('gather_shape_2_data', 'broadcast_2', {'in': 1}), ('broadcast_2', 'broadcast_2_data'),
|
||||
('broadcast_2_data', 'memory_in_constdims'),
|
||||
|
||||
('memory_in_constdims', 'memory_in_constdims_data'),
|
||||
('memory_in_constdims_data', 'crop_mem_constdims'),
|
||||
('crop_mem_constdims', 'crop_mem_constdims_data'),
|
||||
|
||||
Reference in New Issue
Block a user