Kaldi priors (#6258)
* add priors to loader and counts transformation
* fixes in select insertion for case with context gathering - LSTM - context gathering
fix for edge parallel to ReadValue
extend counts option to case of priors inside mdl model file
* fixed tests
* fixed typo
* fixed issue with input names
* fix priors loading + comments
* fix e2e test: error with not found transformation
* print debug info for dependency graph - should be reverted
* should be reverted: debug commit
* Revert "fix e2e test: error with not found transformation"
This reverts commit 8320fa99bf
.
* revert debug commits
* fixes after review
* review fixes
* review change
* review changes
This commit is contained in:
parent
fc7f80a34e
commit
ccf786438b
@ -50,7 +50,7 @@ The following list provides the Kaldi\*-specific parameters.
|
||||
|
||||
```sh
|
||||
Kaldi-specific parameters:
|
||||
--counts COUNTS A file name with full path to the counts file
|
||||
--counts COUNTS A file name with full path to the counts file or empty string to utilize count values from the model file
|
||||
--remove_output_softmax
|
||||
Removes the Softmax that is the output layer
|
||||
--remove_memory Remove the Memory layer and add new inputs and outputs instead
|
||||
@ -78,6 +78,8 @@ python3 mo.py --input_model wsj_dnn5b_smbr.nnet --counts wsj_dnn5b_smbr.counts -
|
||||
\f$|C|\f$ - number of elements in the counts array;
|
||||
* The normalized counts are subtracted from biases of the last or next to last layer (if last layer is SoftMax).
|
||||
|
||||
> **NOTE:** Model Optimizer will show warning if model contains counts values inside model and `--counts` option is not used.
|
||||
|
||||
* If you want to remove the last SoftMax layer in the topology, launch the Model Optimizer with the
|
||||
`--remove_output_softmax` flag.
|
||||
```sh
|
||||
|
@ -51,7 +51,7 @@ def apply_biases_to_last_layer(graph, counts):
|
||||
outputs_ids = find_outputs(graph)
|
||||
for output in outputs_ids.copy():
|
||||
node = Node(graph, output)
|
||||
if node.op != 'Assign':
|
||||
if node.op != 'Assign' and node.op != "MemoryOffset":
|
||||
continue
|
||||
outputs_ids.remove(output)
|
||||
|
||||
@ -82,6 +82,11 @@ def read_counts_file(file_path):
|
||||
except TypeError:
|
||||
raise Error('Expect counts file to contain list of floats.' +
|
||||
refer_to_faq_msg(90))
|
||||
|
||||
return counts
|
||||
|
||||
|
||||
def counts_to_priors(counts):
|
||||
cutoff = 1.00000001e-10
|
||||
cutoff_idxs = np.where(counts < cutoff)
|
||||
counts[cutoff_idxs] = cutoff
|
||||
@ -111,10 +116,21 @@ class ApplyCountsFilePattern(FrontReplacementSubgraph):
|
||||
]
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
try:
|
||||
counts = read_counts_file(graph.graph['cmd_params'].counts)
|
||||
except Exception as e:
|
||||
raise Error('Model Optimizer is not able to read counts file {}'.format(graph.graph['cmd_params'].counts) +
|
||||
refer_to_faq_msg(92)) from e
|
||||
# if empty string is in counts, read priors from model itself (on loader stage)
|
||||
if graph.graph['cmd_params'].counts == "":
|
||||
assert isinstance(graph.graph['priors'], (list, np.ndarray)) and len(graph.graph['priors']) != 0, \
|
||||
"Model file does not contain Priors tag with counts values, use separate file instead"
|
||||
counts = graph.graph['priors'].copy()
|
||||
else:
|
||||
# read counts from given file
|
||||
try:
|
||||
counts = read_counts_file(graph.graph['cmd_params'].counts)
|
||||
except Exception as e:
|
||||
raise Error('Model Optimizer is not able to read counts file {}'.format(graph.graph['cmd_params'].counts) +
|
||||
refer_to_faq_msg(92)) from e
|
||||
|
||||
# calculate normalized counts as follows:
|
||||
# c_i=log(c_i/sum(c_j))
|
||||
# set max_float/2 for almost zero c_i (< 1.e-10)
|
||||
counts = counts_to_priors(counts)
|
||||
apply_biases_to_last_layer(graph, counts)
|
||||
|
@ -1,6 +1,7 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
|
||||
from extensions.middle.MakeKaldiConstReshapable import create_const_with_batch_from_input
|
||||
@ -15,15 +16,27 @@ from mo.ops.concat import Concat
|
||||
from mo.ops.crop import Crop
|
||||
from mo.ops.read_value import ReadValue
|
||||
from mo.ops.result import Result
|
||||
from mo.utils.graph import bfs_search
|
||||
from mo.utils.error import Error
|
||||
from mo.utils.graph import invert_sub_graph_between_nodes
|
||||
|
||||
|
||||
class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
|
||||
"""
|
||||
Add Select before saving state with Memory to avoid garbage saving
|
||||
Add Select before saving state with Memory to avoid garbage saving.
|
||||
We need to know delay on each node where Select is adding. For that we traverse the whole graph and set frame time
|
||||
for each node using the following rules:
|
||||
* Splice increases frame time by length of its context. If Crop is following Splice - it takes one concrete
|
||||
moment of time, so frame time increases by its value
|
||||
Example:
|
||||
node ---> Splice(-5, -4, ... 0) ---> node
|
||||
frame time: 0 ---> 5 ---> 5
|
||||
node ---> Splice(-5, -4, ... 0) ---> Crop(offset = 2, dim = 1) ---> node
|
||||
frame time: 0 ---> 5 ---> 3 ---> 3
|
||||
* Nodes with several inputs have frame time= max (frame time of each input)
|
||||
* Node with one input have the same frame time as its input
|
||||
"""
|
||||
enabled = True
|
||||
graph_condition = [lambda graph: graph.graph['fw'] == 'kaldi']
|
||||
|
||||
def run_after(self):
|
||||
from extensions.middle.ReplaceMemoryOffsetWithSplice import ReplaceMemoryOffsetWithMemoryNodePattern
|
||||
@ -36,38 +49,63 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
|
||||
return [ReplaceSpliceNodePattern]
|
||||
|
||||
@staticmethod
|
||||
def pattern():
|
||||
return dict(
|
||||
nodes=[('op', dict(op='Assign'))],
|
||||
edges=[])
|
||||
def calculate_frame_time(graph: Graph):
|
||||
# there are either one or two inputs in Kaldi. Only main input can change delay in network.
|
||||
# Usually ivector input has name 'ivector'.
|
||||
inputs = graph.get_op_nodes(op='Parameter')
|
||||
if len(inputs) == 1:
|
||||
inp_name = inputs[0].name
|
||||
elif len(inputs) == 2:
|
||||
if inputs[0].name == 'ivector':
|
||||
inp_name = inputs[1].name
|
||||
elif inputs[1].name == 'ivector':
|
||||
inp_name = inputs[0].name
|
||||
else:
|
||||
raise Error("There are 2 inputs for Kaldi model but we can't find out which one is ivector. " +
|
||||
"Use name \'ivector\' for the corresponding input")
|
||||
else:
|
||||
raise Error("There are {} inputs for Kaldi model but we expect only 1 or 2".format(len(inputs)))
|
||||
|
||||
# sort nodes to calculate delays
|
||||
nodes = list(bfs_search(graph, [inp_name]))
|
||||
nx.set_node_attributes(G=graph, name='frame_time', values=-1)
|
||||
|
||||
for n in nodes:
|
||||
node = Node(graph, n)
|
||||
|
||||
# just ignore data nodes
|
||||
if node.kind != 'op':
|
||||
continue
|
||||
|
||||
# calculate frame_time (delay) that was not calculated
|
||||
if node.frame_time < 0:
|
||||
# Splice increases frame delay
|
||||
if node.op == "Splice":
|
||||
node.frame_time = node.in_port(0).get_source().node.frame_time + len(node.context) - 1
|
||||
# crop often used to get concrete time frame, set frame_time correctly for this case
|
||||
elif node.op == 'Crop':
|
||||
if node.in_port(0).get_connection().get_source().node.op == 'Splice':
|
||||
splice_node = node.in_port(0).get_source().node
|
||||
assert len(node.offset) == 1
|
||||
assert len(node.dim) == 1
|
||||
new_delay = splice_node.context[node.offset[0] // node.dim[0]] - splice_node.context[0]
|
||||
node.frame_time = splice_node.in_port(0).get_source().node.frame_time + new_delay
|
||||
else:
|
||||
node.frame_time = node.in_port(0).get_source().node.frame_time
|
||||
# for node with several inputs frame_time = maximum of delays from branches
|
||||
else:
|
||||
# find out maximum of delay and check that we have at least one branch with another delay
|
||||
node.frame_time = 0
|
||||
for inp in node.in_ports():
|
||||
if node.in_port(inp).disconnected():
|
||||
continue
|
||||
in_node = node.in_port(inp).get_source().node
|
||||
if in_node.frame_time > node.frame_time:
|
||||
node.frame_time = in_node.frame_time
|
||||
|
||||
@staticmethod
|
||||
def replace_pattern(graph: Graph, match: dict):
|
||||
node = match['op']
|
||||
|
||||
if node.name == 'iteration_number_out':
|
||||
return
|
||||
|
||||
# calculate length of context when state of inference becomes meaningful
|
||||
inputs = []
|
||||
for n in graph.get_op_nodes(**{'op': 'Parameter'}):
|
||||
inputs.append(n)
|
||||
|
||||
in_nodes = []
|
||||
for inp in inputs:
|
||||
for ins in inp.out_port(0).get_destinations():
|
||||
in_nodes.append(ins.node.name)
|
||||
|
||||
context_len = 1
|
||||
try:
|
||||
subgraph = invert_sub_graph_between_nodes(graph, [node.in_port(0).get_source().node.name], in_nodes)
|
||||
except Error:
|
||||
return
|
||||
|
||||
for n in subgraph:
|
||||
n_node = Node(graph, n)
|
||||
if n_node.kind == 'op' and n_node.op == 'Splice':
|
||||
context_len += len(n_node.context) - 1
|
||||
def insert_select(graph: Graph, node: Node):
|
||||
context_len = node.frame_time + 1
|
||||
|
||||
if context_len == 1:
|
||||
return
|
||||
@ -87,7 +125,7 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
|
||||
('mem_in_data', dict(shape=int64_array([context_len]))),
|
||||
('crop_mem_in', dict(op='Crop', axis=int64_array([1]),
|
||||
offset=int64_array([1]),
|
||||
dim=int64_array([context_len-1]))),
|
||||
dim=int64_array([context_len - 1]))),
|
||||
('crop_mem_in_data', dict()),
|
||||
('concat', dict(op='Concat', axis=1)),
|
||||
('concat_data', dict()),
|
||||
@ -115,17 +153,17 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
|
||||
else:
|
||||
init_value_mem_out = create_const_with_batch_from_input(in_node_port, context_len, precision=np.int32)
|
||||
mem_out = ReadValue(graph, {'name': 'iteration_number',
|
||||
'variable_id': 'iteration_'+node.name}).create_node()
|
||||
'variable_id': 'iteration_' + node.name}).create_node()
|
||||
mem_out.in_port(0).connect(init_value_mem_out.out_port(0))
|
||||
cut_first = Crop(graph, {'name': 'cut_first', 'axis': int64_array([1]),
|
||||
'offset': int64_array([1]), 'dim': int64_array([context_len-1])}).create_node()
|
||||
'offset': int64_array([1]), 'dim': int64_array([context_len - 1])}).create_node()
|
||||
cut_first.in_port(0).connect(mem_out.out_port(0))
|
||||
ones = create_const_with_batch_from_input(in_node_port, 1, 1, np.int32)
|
||||
concat = Concat(graph, {'name': 'concat_ones', 'in_ports_count': 2, 'axis': 1}).create_node()
|
||||
concat.in_port(0).connect(cut_first.out_port(0))
|
||||
concat.in_port(1).connect(ones.out_port(0))
|
||||
mem_in = Assign(graph, {'name': 'iteration_number_out',
|
||||
'variable_id': 'iteration_'+node.name}).create_node()
|
||||
'variable_id': 'iteration_' + node.name}).create_node()
|
||||
mem_in.in_port(0).connect(concat.out_port(0))
|
||||
res = Result(graph, {}).create_node()
|
||||
mem_in.out_port(0).connect(res.in_port(0))
|
||||
@ -143,3 +181,19 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
|
||||
select_node.in_port(0).connect(cast_in.out_port(0))
|
||||
select_node.out_port(0).connect(node.in_port(0))
|
||||
select_node.out_port(0).data.set_shape(in_node_shape)
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
if np.all([node.soft_get('name', node.id) == 'iteration_number_out'
|
||||
for node in graph.get_op_nodes(op='Assign')]):
|
||||
return
|
||||
|
||||
self.calculate_frame_time(graph)
|
||||
|
||||
for node in graph.get_op_nodes(op='Assign'):
|
||||
if node.soft_get('name', node.id) == 'iteration_number_out':
|
||||
continue
|
||||
self.insert_select(graph, node)
|
||||
|
||||
for node in graph.get_op_nodes():
|
||||
if 'frame_time' in node:
|
||||
del node['frame_time']
|
||||
|
@ -77,7 +77,7 @@ class ReplaceMemoryOffsetNodePattern(MiddleReplacementPattern):
|
||||
outs = input_node_out_port.get_destinations()
|
||||
for in_port in outs:
|
||||
out_ = in_port.node
|
||||
if out_['op'] != 'MemoryOffset' and out_['op'] != 'Splice':
|
||||
if out_.op == 'Concat' and out_ == out_node_in_ports[0].node:
|
||||
crop_input = Crop(graph, {'name': 'Splice_Crop',
|
||||
'axis': int64_array([1]),
|
||||
'offset': int64_array([-min(0, in_shape[1] * node_t)]),
|
||||
|
@ -14,6 +14,7 @@ from mo.front.extractor import add_outputs_identity
|
||||
from mo.front.kaldi.loader.utils import find_next_tag, read_placeholder, find_next_component, get_name_from_path, \
|
||||
find_end_of_component, end_of_nnet_tag, read_binary_integer32_token, get_parameters, read_token_value, \
|
||||
collect_until_token, collect_until_token_and_read, create_edge_attrs, get_args_for_specifier
|
||||
from mo.front.kaldi.utils import read_binary_vector
|
||||
from mo.graph.graph import Node, Graph
|
||||
from mo.ops.const import Const
|
||||
from mo.utils.error import Error
|
||||
@ -222,6 +223,20 @@ def load_kaldi_nnet3_model(graph, file_descr, nnet_name):
|
||||
o_n['parameters']['element_size'] = int64_array([1, node.shape[1]])
|
||||
|
||||
load_components(file_descr, graph, component_layer_map)
|
||||
load_priors(file_descr, graph)
|
||||
|
||||
|
||||
def load_priors(file_descr, graph):
|
||||
try:
|
||||
collect_until_token(file_descr, b'<Priors>')
|
||||
except Error:
|
||||
# just ignore if priors were not found
|
||||
return
|
||||
if graph.graph['cmd_params'].counts is not None:
|
||||
graph.graph['priors'] = read_binary_vector(file_descr)
|
||||
else:
|
||||
log.error("Model contains Prior values, if you want to embed them into the generated IR add option --counts=\"\" to command line",
|
||||
extra={'is_warning': True})
|
||||
|
||||
|
||||
def load_components(file_descr, graph, component_layer_map=None):
|
||||
|
@ -75,6 +75,17 @@ class CanonicalizePathCheckExistenceAction(CanonicalizePathAction):
|
||||
' but "{}" does not exist.'.format(self.dest, name))
|
||||
|
||||
|
||||
class CanonicalizePathCheckExistenceIfNeededAction(CanonicalizePathCheckExistenceAction):
|
||||
|
||||
def __call__(self, parser, namespace, values, option_string=None):
|
||||
if values is not None:
|
||||
if isinstance(values, str):
|
||||
if values != "":
|
||||
super().__call__(parser, namespace, values, option_string)
|
||||
else:
|
||||
setattr(namespace, self.dest, values)
|
||||
|
||||
|
||||
class DeprecatedCanonicalizePathCheckExistenceAction(CanonicalizePathCheckExistenceAction):
|
||||
def __call__(self, parser, namespace, values, option_string=None):
|
||||
super().__call__(parser, namespace, values, option_string)
|
||||
@ -415,7 +426,7 @@ def get_mxnet_cli_options():
|
||||
|
||||
def get_kaldi_cli_options():
|
||||
d = {
|
||||
'counts': '- A file name with full path to the counts file',
|
||||
'counts': '- A file name with full path to the counts file or empty string if you want to use counts from model',
|
||||
'remove_output_softmax': '- Removes the SoftMax layer that is the output layer',
|
||||
'remove_memory': '- Removes the Memory layer and use additional inputs and outputs instead'
|
||||
}
|
||||
@ -609,7 +620,7 @@ def get_kaldi_cli_parser(parser: argparse.ArgumentParser = None):
|
||||
kaldi_group.add_argument("--counts",
|
||||
help="Path to the counts file",
|
||||
default=None,
|
||||
action=CanonicalizePathCheckExistenceAction)
|
||||
action=CanonicalizePathCheckExistenceIfNeededAction)
|
||||
|
||||
kaldi_group.add_argument("--remove_output_softmax",
|
||||
help="Removes the SoftMax layer that is the output layer",
|
||||
|
@ -16,11 +16,11 @@ class InsertSelectTests(unittest.TestCase):
|
||||
# graph have no splices - selects should not be inserted
|
||||
def test_insert_select_0(self):
|
||||
graph = build_graph({
|
||||
'placeholder_1': {'kind': 'op', 'op': 'Parameter'},
|
||||
'input': {'kind': 'op', 'op': 'Parameter'},
|
||||
'placeholder_data_1': {'kind': 'data', 'shape': [1, 13]},
|
||||
'memory': {'kind': 'op', 'op': 'Assign'},
|
||||
},
|
||||
[('placeholder_1', 'placeholder_data_1'),
|
||||
[('input', 'placeholder_data_1'),
|
||||
('placeholder_data_1', 'memory')
|
||||
],
|
||||
nodes_with_edges_only=True)
|
||||
@ -33,7 +33,7 @@ class InsertSelectTests(unittest.TestCase):
|
||||
# graph contains 1 splice with context length 5, should be inserted select with memory as counter with length 5
|
||||
def test_insert_select_1(self):
|
||||
graph = build_graph({
|
||||
'placeholder_1': {'kind': 'op', 'op': 'Parameter'},
|
||||
'input': {'kind': 'op', 'op': 'Parameter'},
|
||||
'placeholder_data_1': {'kind': 'data', 'shape': [1, 13]},
|
||||
'splice_1': {'kind': 'op', 'op': 'Splice', 'context': np.array([-2, -1, 0, 1, 2])},
|
||||
'splice_data_1': {'kind': 'data', 'shape': [1, 13]},
|
||||
@ -41,7 +41,7 @@ class InsertSelectTests(unittest.TestCase):
|
||||
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
|
||||
'memory': {'kind': 'op', 'op': 'Assign', 'index': 0},
|
||||
},
|
||||
[('placeholder_1', 'placeholder_data_1'),
|
||||
[('input', 'placeholder_data_1'),
|
||||
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
|
||||
('splice_data_1', 'placeholder_2'), ('placeholder_2', 'placeholder_data_2'),
|
||||
('placeholder_data_2', 'memory')
|
||||
@ -49,7 +49,7 @@ class InsertSelectTests(unittest.TestCase):
|
||||
nodes_with_edges_only=True)
|
||||
AddSelectBeforeMemoryNodePattern().find_and_replace_pattern(graph)
|
||||
ref_graph = build_graph({
|
||||
'placeholder_1': {'kind': 'op', 'op': 'Parameter'},
|
||||
'input': {'kind': 'op', 'op': 'Parameter'},
|
||||
'placeholder_data_1': {'kind': 'data', 'shape': [1, 13]},
|
||||
'splice_1': {'kind': 'op', 'op': 'Splice', 'context': np.array([-2, -1, 0, 1, 2])},
|
||||
'splice_data_1': {'kind': 'data', 'shape': [1, 13]},
|
||||
@ -109,7 +109,7 @@ class InsertSelectTests(unittest.TestCase):
|
||||
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
|
||||
'memory': {'kind': 'op', 'op': 'Assign'},
|
||||
},
|
||||
[('placeholder_1', 'placeholder_data_1'),
|
||||
[('input', 'placeholder_data_1'),
|
||||
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
|
||||
('splice_data_1', 'placeholder_2'), ('placeholder_2', 'placeholder_data_2'),
|
||||
('placeholder_data_2', 'select', {'in': 1}),
|
||||
@ -168,7 +168,7 @@ class InsertSelectTests(unittest.TestCase):
|
||||
# should be inserted select with memory as counter with length 5
|
||||
def test_insert_select_2(self):
|
||||
graph = build_graph({
|
||||
'placeholder_1': {'kind': 'op', 'op': 'Parameter'},
|
||||
'input': {'kind': 'op', 'op': 'Parameter'},
|
||||
'placeholder_data_1': {'kind': 'data', 'shape': [1, 13]},
|
||||
'splice_1': {'kind': 'op', 'op': 'Splice', 'context': np.array([-2, -1, 0, 1, 2])},
|
||||
'splice_data_1': {'kind': 'data', 'shape': [1, 65]},
|
||||
@ -178,7 +178,7 @@ class InsertSelectTests(unittest.TestCase):
|
||||
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
|
||||
'memory': {'kind': 'op', 'op': 'Assign'},
|
||||
},
|
||||
[('placeholder_1', 'placeholder_data_1'),
|
||||
[('input', 'placeholder_data_1'),
|
||||
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
|
||||
('placeholder_data_1', 'splice_2'), ('splice_2', 'splice_data_2'),
|
||||
('splice_data_1', 'placeholder_2'), ('placeholder_2', 'placeholder_data_2'),
|
||||
@ -187,7 +187,7 @@ class InsertSelectTests(unittest.TestCase):
|
||||
nodes_with_edges_only=True)
|
||||
AddSelectBeforeMemoryNodePattern().find_and_replace_pattern(graph)
|
||||
ref_graph = build_graph({
|
||||
'placeholder_1': {'kind': 'op', 'op': 'Parameter'},
|
||||
'input': {'kind': 'op', 'op': 'Parameter'},
|
||||
'placeholder_data_1': {'kind': 'data', 'shape': [1, 13]},
|
||||
'splice_1': {'kind': 'op', 'op': 'Splice', 'context': np.array([-2, -1, 0, 1, 2])},
|
||||
'splice_data_1': {'kind': 'data', 'shape': [1, 65]},
|
||||
@ -249,7 +249,7 @@ class InsertSelectTests(unittest.TestCase):
|
||||
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
|
||||
'memory': {'kind': 'op', 'op': 'Assign'},
|
||||
},
|
||||
[('placeholder_1', 'placeholder_data_1'),
|
||||
[('input', 'placeholder_data_1'),
|
||||
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
|
||||
('placeholder_data_1', 'splice_2'), ('splice_2', 'splice_data_2'),
|
||||
('splice_data_1', 'placeholder_2'), ('placeholder_2', 'placeholder_data_2'),
|
||||
@ -308,7 +308,7 @@ class InsertSelectTests(unittest.TestCase):
|
||||
# should be inserted select with memory as counter with length 7
|
||||
def test_insert_select_3(self):
|
||||
graph = build_graph({
|
||||
'placeholder_1': {'kind': 'op', 'op': 'Parameter'},
|
||||
'input': {'kind': 'op', 'op': 'Parameter'},
|
||||
'placeholder_data_1': {'kind': 'data', 'shape': [1, 13]},
|
||||
'splice_1': {'kind': 'op', 'op': 'Splice', 'context': np.array([-2, -1, 0, 1, 2])},
|
||||
'splice_data_1': {'kind': 'data', 'shape': [1, 65]},
|
||||
@ -318,7 +318,7 @@ class InsertSelectTests(unittest.TestCase):
|
||||
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
|
||||
'memory': {'kind': 'op', 'op': 'Assign', 'index': 0},
|
||||
},
|
||||
[('placeholder_1', 'placeholder_data_1'),
|
||||
[('input', 'placeholder_data_1'),
|
||||
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
|
||||
('splice_data_1', 'splice_2'), ('splice_2', 'splice_data_2'),
|
||||
('splice_data_2', 'placeholder_2'), ('placeholder_2', 'placeholder_data_2'),
|
||||
@ -327,7 +327,7 @@ class InsertSelectTests(unittest.TestCase):
|
||||
nodes_with_edges_only=True)
|
||||
AddSelectBeforeMemoryNodePattern().find_and_replace_pattern(graph)
|
||||
ref_graph = build_graph({
|
||||
'placeholder_1': {'kind': 'op', 'op': 'Parameter'},
|
||||
'input': {'kind': 'op', 'op': 'Parameter'},
|
||||
'placeholder_data_1': {'kind': 'data', 'shape': [1, 13]},
|
||||
'splice_1': {'kind': 'op', 'op': 'Splice', 'context': np.array([-2, -1, 0, 1, 2])},
|
||||
'splice_data_1': {'kind': 'data', 'shape': [1, 65]},
|
||||
@ -389,7 +389,7 @@ class InsertSelectTests(unittest.TestCase):
|
||||
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
|
||||
'memory': {'kind': 'op', 'op': 'Assign', 'index': 0},
|
||||
},
|
||||
[('placeholder_1', 'placeholder_data_1'),
|
||||
[('input', 'placeholder_data_1'),
|
||||
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
|
||||
('splice_data_1', 'splice_2'), ('splice_2', 'splice_data_2'),
|
||||
('splice_data_2', 'placeholder_2'), ('placeholder_2', 'placeholder_data_2'),
|
||||
|
@ -107,8 +107,6 @@ class ReplaceMemoryOffsetNodePatternTests(unittest.TestCase):
|
||||
'splice': {'kind': 'op', 'op': 'Splice', 'context': range(-5, 1)},
|
||||
'splice_data': {'kind': 'data', 'shape': [1, 78]},
|
||||
'crop': {'kind': 'op', 'op': 'Crop', 'offset': 0, 'dim': 13},
|
||||
'crop_input': {'kind': 'op', 'op': 'Crop', 'offset': 65, 'dim': 13},
|
||||
'crop_input_data': {'kind': 'data', 'shape': [1, 13]},
|
||||
'memoryoffset_2_data': {'kind': 'data', 'shape': [1, 13]},
|
||||
'out_placeholder': {'kind': 'op', 'op': 'placeholder'},
|
||||
},
|
||||
@ -118,10 +116,8 @@ class ReplaceMemoryOffsetNodePatternTests(unittest.TestCase):
|
||||
('splice', 'splice_data'),
|
||||
('splice_data', 'crop'),
|
||||
('crop', 'memoryoffset_2_data'),
|
||||
('splice_data', 'crop_input'),
|
||||
('crop_input', 'crop_input_data'),
|
||||
('memoryoffset_2_data', 'out_placeholder'),
|
||||
('crop_input_data', 'out_placeholder')
|
||||
('in_node', 'out_placeholder')
|
||||
]
|
||||
)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user