Tdnnf (#5255)
* initial changes (IR not generated)
* extractor fix
* convert tdnnf (with correct infer)
* refactoring + comments in code
* added unit tests + a couple of fixes based on the tests
* changed order for old convolutions
* fix pylint
* small refactoring
* added an if to avoid changes in old IRs
* doc updated
* fixed layout and kernel shapes for old convolutions
* fixed test
* moved test
* fixed import in test
* fixed backward compatibility
* review fixes
parent 6624a77827
commit 7b52e3155a
@@ -383,6 +383,7 @@ Standard Kaldi\* Layers:
 | splicecomponent | No |
 | tanhcomponent | No |
 | tdnncomponent | No |
+| timeheightconvolutioncomponent | No |
 
 ## ONNX\* Supported Operators
@@ -160,6 +160,7 @@ extensions/front/kaldi/memoryoffset_batch_update.py
 extensions/front/kaldi/replace_eltwise_nin1.py
 extensions/front/kaldi/replace_lstm_node_pattern.py
 extensions/front/kaldi/replace_lstm_nonlinearity.py
+extensions/front/kaldi/replace_timeheightconvolution.py
 extensions/front/kaldi/set_ports.py
 extensions/front/kaldi/sigmoid_ext.py
 extensions/front/kaldi/split_recurrent_memoryoffset.py
@@ -865,6 +866,7 @@ mo/front/kaldi/extractors/softmax_ext.py
 mo/front/kaldi/extractors/specaugment_component_ext.py
 mo/front/kaldi/extractors/splice_component_ext.py
 mo/front/kaldi/extractors/tdnncomponent_ext.py
+mo/front/kaldi/extractors/timeheightconvolution_ext.py
 mo/front/kaldi/loader/__init__.py
 mo/front/kaldi/loader/loader.py
 mo/front/kaldi/loader/utils.py
@@ -977,6 +979,7 @@ mo/ops/squeeze.py
 mo/ops/strided_slice.py
 mo/ops/tdnncomponent.py
 mo/ops/tile.py
+mo/ops/timeheightconvolution.py
 mo/ops/unsqueeze.py
 mo/pipeline/__init__.py
 mo/pipeline/common.py
model-optimizer/extensions/front/kaldi/add_reshape_around_convolution.py

@@ -1,11 +1,9 @@
 # Copyright (C) 2018-2021 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
-import numpy as np
-
-from extensions.ops.Cast import Cast
 from extensions.ops.elementwise import Div
-from mo.front.common.partial_infer.utils import int64_array, float32_array
+from extensions.ops.transpose import Transpose
+from mo.front.common.partial_infer.utils import int64_array
 from mo.front.common.replacement import FrontReplacementPattern
 from mo.front.tf.graph_utils import create_op_with_const_inputs, create_op_node_with_second_input
 from mo.graph.graph import Graph
@@ -45,32 +43,45 @@ class ReplaceConvolutionReshape(FrontReplacementPattern):
         node = match['conv']
         node_name = node.soft_get('name', node.id)
 
-        dst_dtype = np.float32  # even if data_type=FP16 use float32 for shape values
-
         # create Reshape before convolution
-        # shape = [in_shape[0], in_shape[1]/patch_stride, 1, patch_stride]
-        i_shape = Shape(graph, {'name': node_name + '/Shape'}).create_node()
-        shape = Cast(graph, {'name': node_name + '/to_float',
-                             'dst_type': dst_dtype}).create_node()
-        i_shape.in_port(0).connect(node.in_port(0).get_source())
-        shape.in_port(0).connect(i_shape.out_port(0))
+        # if a transpose will be applied (new models):
+        #     shape = [in_shape[0], t = in_shape[1] / (frame_height * C), frame_height, C]
+        # else (for old models, to avoid failures on GNA - should be removed as soon as GNA is updated):
+        #     shape = [in_shape[0], t = in_shape[1] / (frame_height * C), C, frame_height]
+        sp_dim_1 = 1
+        if node.has_valid('patch_stride'):
+            channel_dim = 2
+            sp_dim_2 = 3
+            frame_height = node.patch_stride
+        else:
+            channel_dim = 3
+            sp_dim_2 = 2
+            frame_height = node.height_in
 
-        N, H = node_to_get_shape_value_of_indices(shape, [0]), node_to_get_shape_value_of_indices(shape, [1])
+        i_shape = Shape(graph, {'name': node_name + '/Shape'}).create_node()
+        i_shape.in_port(0).connect(node.in_port(0).get_source())
+
+        N, H = node_to_get_shape_value_of_indices(i_shape, [0]), node_to_get_shape_value_of_indices(i_shape, [1])
 
         div = create_op_with_const_inputs(
-            graph, Div, {1: float32_array([node.patch_stride])}, {'name': node_name + '/div_stride_h'})
+            graph, Div, {1: int64_array([frame_height * node.kernel[1]])}, {'name': node_name + '/div_stride_h'})
         div.in_port(0).connect(H.out_port(0))
 
-        concat = create_op_with_const_inputs(graph, Concat, {2: float32_array([1]), 3: float32_array([node.patch_stride])},
+        concat = create_op_with_const_inputs(graph, Concat, {sp_dim_2: int64_array([frame_height]),
+                                                             channel_dim: int64_array([node.kernel[1]])},
                                              {'name': node_name + '/concat_all_dims', 'in_ports_count': 4, 'axis': 0})
         concat.in_port(0).connect(N.out_port(0))
-        concat.in_port(1).connect(div.out_port(0))
-
-        reshape_pattern = Cast(graph, {'name': node_name + '/to_int', 'dst_type': np.int64}).create_node()
-        concat.out_port(0).connect(reshape_pattern.in_port(0))
+        concat.in_port(sp_dim_1).connect(div.out_port(0))
 
         reshape_in = Reshape(graph, {'name': node_name + '/reshape_in'}).create_node()
-        reshape_in.in_port(1).connect(reshape_pattern.out_port(0))
+        reshape_in.in_port(1).connect(concat.out_port(0))
+
+        # change layout from NHWC to NCHW;
+        # should be replaced by common Permute logic in the future
+        transpose = None
+        if channel_dim == 3 and node.channel_dims == 1:
+            transpose = create_op_node_with_second_input(graph, Transpose, int64_array([0, 3, 1, 2]),
+                                                         {'name': node.name + '/Transpose'}, reshape_in)
 
         # create Reshape after Convolution
         reshape_out = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
@@ -78,7 +89,7 @@ class ReplaceConvolutionReshape(FrontReplacementPattern):
 
         # connect input_reshape_node
         source = node.in_port(0).get_source()
-        node.in_port(0).get_connection().set_source(reshape_in.out_port(0))
+        node.in_port(0).get_connection().set_source(transpose.out_port(0) if transpose else reshape_in.out_port(0))
         reshape_in.in_port(0).connect(source)
         # connect output_reshape_node
         node.out_port(0).get_connection().set_source(reshape_out.out_port(0))
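The two hunks above change how the Reshape inserted before a convolution computes its target shape: instead of casting the Shape output to float and back, the pattern now assembles an int64 [N, T, H, C] shape directly (or [N, T, C, H] for old models, so GNA-targeted IRs stay unchanged) and adds a Transpose to NCHW on the new-model path. A minimal numpy sketch of the intended shape arithmetic, with made-up sizes that are not from this commit:

import numpy as np

# Hypothetical sizes: batch 1, 2 time steps, frame height 4, kernel width
# (channels) 3. Kaldi feeds convolutions a flat row per chunk: [N, T*H*C].
n, t, h, c = 1, 2, 4, 3
flat = np.arange(n * t * h * c).reshape(n, t * h * c)

# T is recovered by dividing the flat length by frame_height * kernel_width,
# which is what the Div node computes from the Shape node output.
frame_height, kernel_width = h, c
time = flat.shape[1] // (frame_height * kernel_width)

# new-model layout assembled by the Concat: [N, T, H, C]
nhwc = flat.reshape(n, time, frame_height, kernel_width)

# the Transpose([0, 3, 1, 2]) then yields NCHW for the convolution
nchw = nhwc.transpose(0, 3, 1, 2)
print(nchw.shape)  # (1, 3, 2, 4)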
model-optimizer/extensions/front/kaldi/replace_timeheightconvolution.py
Normal file

@@ -0,0 +1,103 @@
# Copyright (C) 2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.replacement import FrontReplacementPattern
from mo.graph.graph import Node, Graph, rename_node
from mo.ops.concat import Concat
from mo.ops.convolution import Convolution
from mo.ops.memoryoffset import MemoryOffset


class ReplaceTimeHeightConvolutionPattern(FrontReplacementPattern):
    enabled = True
    run_not_recursively = True

    def run_after(self):
        from extensions.front.MoveEmbeddedInputsToInputs import MoveEmbeddedInputsToInputs
        return [MoveEmbeddedInputsToInputs]

    def run_before(self):
        from extensions.front.kaldi.add_permute_after_convolution import ReplaceConvolutionTranspose
        from extensions.front.kaldi.add_reshape_around_convolution import ReplaceConvolutionReshape
        from extensions.front.kaldi.memory_offset_adjustment import MemoryOffsetAdjustment
        from extensions.front.kaldi.split_recurrent_memoryoffset import SplitRecurrentMemoryOffset
        return [MemoryOffsetAdjustment, ReplaceConvolutionReshape, ReplaceConvolutionTranspose,
                SplitRecurrentMemoryOffset]

    def find_and_replace_pattern(self, graph: Graph):
        for node in graph.get_op_nodes(op='timeheightconvolutioncomponent'):
            self.replace_timeheightconv(graph, node)

    def replace_timeheightconv(self, graph: Graph, node: Node):
        req_time_offsets = node.soft_get('time_offsets')
        offsets = node.soft_get("offsets", [[]])
        all_time_offsets = list(set(offsets[:, 0]))
        all_time_offsets.sort()
        in_name = node.soft_get('name', node.id)
        rename_node(node, in_name + '/to_delete')

        # create MemoryOffsets for context gathering;
        # a Concat is needed if there is more than one time offset
        concat = Concat(graph, attrs={'name': in_name + '/Concat',
                                      'in_ports_count': len(all_time_offsets)}).create_node()
        i = 0
        for t in all_time_offsets:
            # if the time offset is included in required_time_offsets, no default value is needed
            has_default = t not in req_time_offsets
            memoff = MemoryOffset(graph, attrs={'name': in_name + '/MemoryOffset_' + str(i),
                                                't': t, 'has_default': has_default, 'splitted': False,
                                                'pair_name': in_name + '/MemoryOffset_pair_' + str(i)}).create_node()
            concat.in_port(i).connect(memoff.out_port(0))
            memoff.in_port(0).connect(node.in_port(0).get_source())
            i = i + 1

        stride = node.soft_get("height_subsample", 1)

        kernel = int64_array([0, 0])
        kernel[0] = len(set(offsets[:, 0]))
        kernel[1] = len(set(offsets[:, 1]))

        pad_h = int64_array([0, 0])
        pad_h[0] = -min(offsets[:, 1]) if min(offsets[:, 1]) < 0 else 0
        pad_h[1] = stride * node.height_out - (node.height_in - max([max(offsets[:, 1]), 0]))

        dilation_t = (max(offsets[:, 0]) - min(offsets[:, 0])) / (kernel[0] - 1) if kernel[0] > 1 else 1
        dilation_h = (max(offsets[:, 1]) - min(offsets[:, 1])) / (kernel[1] - 1) if kernel[1] > 1 else 1

        conv_attrs = {
            'name': in_name,
            'output': node['out_channels'],
            'height_in': node.height_in,
            'bias_term': None,
            'pad': int64_array([[0, 0], [0, 0], [0, 0], pad_h]),
            'pad_spatial_shape': int64_array([[0, 0], pad_h]),
            'dilation': int64_array([1, 1, dilation_t, dilation_h]),
            'kernel': int64_array([node.out_channels, node.in_channels, kernel[0], kernel[1]]),
            'stride': int64_array([1, 1, 1, stride]),
            'kernel_spatial': kernel,
            'input_feature_channel': 1,
            'output_feature_channel': 0,
            'channel_dims': int64_array([1]),
            'spatial_dims': int64_array([2, 3]),
            'batch_dims': int64_array([0]),
            'kernel_spatial_idx': int64_array([2, 3]),
            'group': 1,
            'reshape_kernel': True,
            'bias_addable': True,
        }
        conv = Convolution(graph, attrs=conv_attrs).create_node()
        conv.in_port(0).connect(concat.out_port(0))
        conv.in_port(1).connect(node.in_port(1).get_source())

        # change layout for weights from OHWI to OIHW;
        # in the future this should be replaced by common Permute mechanics
        weights = conv.in_port(1).get_source().node.value
        weights = weights.reshape(int64_array([node.out_channels, -1, node.in_channels]))
        weights = weights.transpose(int64_array([0, 2, 1]))
        weights = weights.flatten()
        conv.in_port(1).get_source().node.value = weights

        conv.in_port(2).connect(node.in_port(2).get_source())
        node.out_port(0).get_connection().set_source(conv.out_port(0))
        graph.remove_node(node.id)
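A small self-contained sketch of the geometry derivation above — kernel, height padding, and dilation from the Kaldi offsets grid. The numbers mirror test_timeheightconvolution_2_offsets_dilation further down; integer division is used here only to keep the printed values integral (the replacer uses true division):

import numpy as np

# Offsets grid: time offsets {-1, 1}, height offsets {-3, 0, 3},
# height_in = height_out = 80, height_subsample (stride) 1.
offsets = np.array([[-1, -3], [-1, 0], [-1, 3], [1, -3], [1, 0], [1, 3]])
height_in, height_out, stride = 80, 80, 1

# kernel extent on each axis = number of distinct offsets on that axis
kernel = np.array([len(set(offsets[:, 0])), len(set(offsets[:, 1]))])

# height padding covers the negative offsets (begin) and whatever the
# strided output needs beyond the input (end)
pad_h = np.array([0, 0])
pad_h[0] = -min(offsets[:, 1]) if min(offsets[:, 1]) < 0 else 0
pad_h[1] = stride * height_out - (height_in - max(max(offsets[:, 1]), 0))

# dilation = offset span divided by (kernel - 1)
dilation_t = (max(offsets[:, 0]) - min(offsets[:, 0])) // (kernel[0] - 1) if kernel[0] > 1 else 1
dilation_h = (max(offsets[:, 1]) - min(offsets[:, 1])) // (kernel[1] - 1) if kernel[1] > 1 else 1

print(kernel, pad_h, dilation_t, dilation_h)  # [2 3] [3 3] 2 3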
@@ -26,6 +26,10 @@ class ReplaceMemoryOffsetNodePattern(MiddleReplacementPattern):
         from extensions.middle.RemoveDuplicationMemory import RemoveMemoryDuplicationPattern
         return [RemoveMemoryDuplicationPattern]
 
+    def run_after(self):
+        from extensions.middle.split_tdnn_memoryoffset import SplitTdnnMemoryOffset
+        return [SplitTdnnMemoryOffset]
+
     @staticmethod
     def pattern():
         return dict(
@@ -5,11 +5,9 @@ import numpy as np
 
 from mo.front.caffe.extractors.utils import embed_input
 from mo.front.extractor import FrontExtractorOp
-from mo.front.kaldi.loader.utils import read_binary_bool_token, read_binary_integer32_token, collect_until_token, \
-    read_binary_float_token
+from mo.front.kaldi.loader.utils import collect_until_token, read_binary_float_token, read_binary_integer32_token
 from mo.front.kaldi.utils import read_binary_vector
 from mo.ops.scale_shift import ScaleShiftOp
-from mo.utils.error import Error
 
 
 class BatchNormComponentFrontExtractor(FrontExtractorOp):

@@ -26,18 +24,12 @@ class BatchNormComponentFrontExtractor(FrontExtractorOp):
         collect_until_token(pb, b'<BlockDim>')
         block_dim = read_binary_integer32_token(pb)
 
-        if block_dim != dim:
-            raise Error("Dim is not equal BlockDim for BatchNorm is not supported")
-
         collect_until_token(pb, b'<Epsilon>')
         eps = read_binary_float_token(pb)
 
         collect_until_token(pb, b'<TargetRms>')
         target_rms = read_binary_float_token(pb)
 
-        collect_until_token(pb, b'<TestMode>')
-        test_mode = read_binary_bool_token(pb)
-
         collect_until_token(pb, b'<StatsMean>')
         mean = read_binary_vector(pb)
 

@@ -47,8 +39,13 @@ class BatchNormComponentFrontExtractor(FrontExtractorOp):
         scale = target_rms / np.sqrt(var + eps)
         shift = - target_rms * mean / np.sqrt(var + eps)
-        attrs = {'out-size': len(shift)}
+
+        scale = np.tile(scale, dim // block_dim)
+        shift = np.tile(shift, dim // block_dim)
+
+        attrs = {'out-size': dim}
         embed_input(attrs, 1, 'weights', scale)
         embed_input(attrs, 2, 'biases', shift)
 
         ScaleShiftOp.update_node_stat(node, attrs)
         return cls.enabled
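These hunks let the batchnorm extractor fold statistics into a ScaleShift even when <BlockDim> is smaller than <Dim>, by tiling the per-block scale and shift across all blocks; the old block_dim != dim rejection and the unused <TestMode> read are dropped. A standalone numpy sketch of the folding, with made-up statistics:

import numpy as np

# Made-up statistics: dim = 6, block_dim = 3, so the stored stats cover one
# block and repeat across dim // block_dim = 2 blocks.
dim, block_dim = 6, 3
mean = np.array([0.1, 0.2, 0.3])
var = np.array([1.0, 4.0, 9.0])
eps, target_rms = 1e-3, 1.0

# fold batchnorm into an affine transform y = scale * x + shift
scale = target_rms / np.sqrt(var + eps)
shift = -target_rms * mean / np.sqrt(var + eps)

# tile the per-block values up to the full dimension, as the new code does
scale = np.tile(scale, dim // block_dim)
shift = np.tile(shift, dim // block_dim)
assert len(scale) == len(shift) == dim  # ScaleShift 'out-size' is now dim

x = np.ones(dim)
y = scale * x + shift  # equivalent to batchnorm applied in test mode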
model-optimizer/mo/front/kaldi/extractors/timeheightconvolution_ext.py
Normal file

@@ -0,0 +1,62 @@
# Copyright (C) 2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np

from mo.front.caffe.extractors.utils import embed_input
from mo.front.extractor import FrontExtractorOp
from mo.front.kaldi.loader.utils import collect_until_token, read_token_value
from mo.front.kaldi.utils import read_binary_matrix, read_binary_vector, read_binary_vector_of_pairs
from mo.ops.timeheightconvolution import TimeHeightConvolutionComponent


class TimeHeightConvolutionFrontExtractor(FrontExtractorOp):
    op = 'timeheightconvolutioncomponent'
    enabled = True

    @classmethod
    def extract(cls, node):
        pb = node.parameters
        collect_until_token(pb, b'<ConvolutionModel>')
        in_shape = read_token_value(pb, b'<NumFiltersIn>')
        out_shape = read_token_value(pb, b'<NumFiltersOut>')
        height_in = read_token_value(pb, b'<HeightIn>')
        height_out = read_token_value(pb, b'<HeightOut>')
        height_subsample = read_token_value(pb, b'<HeightSubsampleOut>')
        collect_until_token(pb, b'<Offsets>')
        offsets = read_binary_vector_of_pairs(pb, read_token=False, dtype=np.int32)
        collect_until_token(pb, b'<RequiredTimeOffsets>')
        time_offsets = read_binary_vector(pb, read_token=False, dtype=np.int32)
        collect_until_token(pb, b'<LinearParams>')
        weights, _ = read_binary_matrix(pb)
        collect_until_token(pb, b'<BiasParams>')
        biases = read_binary_vector(pb)

        offsets = offsets.reshape([len(offsets) // 2, 2])
        mapping_rule = {
            # stride for the h axis
            'height_subsample': height_subsample,
            # input dimension for the h axis
            'height_in': height_in,
            # output dimension for the h axis
            'height_out': height_out,
            # input dimension for the channel axis
            'in_channels': in_shape,
            # output dimension for the channel axis
            'out_channels': out_shape,
            # array of (time, height) pairs like the following:
            # [(-1, -1) (-1, 0) (-1, 1)
            #  ( 0, -1) ( 0, 0) ( 0, 1)
            #  ( 1, -1) ( 1, 0) ( 1, 1)]
            # it means that a 3x3 kernel is applied to calculate each output value
            'offsets': offsets,
            # time offsets required to calculate the current convolution;
            # time_offsets = [-1, 0, 1] for the example above means no padding on the time axis
            # and 3 values should be prepared;
            # time_offsets = [0] means zero padding [1, 1] on the time axis
            'time_offsets': time_offsets,
            'out-size': out_shape * height_out}

        embed_input(mapping_rule, 1, 'weights', weights)
        embed_input(mapping_rule, 2, 'biases', biases)

        TimeHeightConvolutionComponent.update_node_stat(node, mapping_rule)
        return cls.enabled
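The <Offsets> stream is read flattened and folded into (time, height) pairs; the distinct values in each column then give the kernel extent on each axis. A sketch using a hypothetical 3x3 kernel, mirroring the comment in the extractor above:

import numpy as np

# Flat stream of nine (time, height) pairs, as read_binary_vector_of_pairs
# would return it (values are illustrative only).
flat = np.array([-1, -1, -1, 0, -1, 1,
                  0, -1,  0, 0,  0, 1,
                  1, -1,  1, 0,  1, 1], dtype=np.int32)
offsets = flat.reshape([len(flat) // 2, 2])

time_offsets = sorted(set(offsets[:, 0]))    # [-1, 0, 1] -> kernel size 3 in t
height_offsets = sorted(set(offsets[:, 1]))  # [-1, 0, 1] -> kernel size 3 in h
print(time_offsets, height_offsets)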
@@ -52,6 +52,7 @@ supported_components = [
     'sumgroupcomponent',
     'tanhcomponent',
     'tdnncomponent',
+    'timeheightconvolutioncomponent',
 ]
 
 
@@ -28,6 +28,13 @@ def read_binary_vector(file_desc: io.BufferedReader, read_token: bool = True, dtype=np.float32):
     return read_blob(file_desc, elements_number, dtype)
 
 
+def read_binary_vector_of_pairs(file_desc: io.BufferedReader, read_token: bool = True, dtype=np.float32):
+    if read_token:
+        read_placeholder(file_desc)
+    elements_number = read_binary_integer32_token(file_desc)
+    return read_blob(file_desc, 2 * elements_number, dtype)
+
+
 def read_learning_info(pb: io.BufferedReader):
     while True:
         read_placeholder(pb, 1)
model-optimizer/mo/ops/timeheightconvolution.py
Normal file

@@ -0,0 +1,19 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from mo.graph.graph import Graph
from mo.ops.op import Op


class TimeHeightConvolutionComponent(Op):
    op = 'timeheightconvolutioncomponent'
    enabled = False

    def __init__(self, graph: Graph, attrs: dict):
        super().__init__(graph, {
            'type': None,
            'op': self.op,
            'infer': None,
            'in_ports_count': 1,
            'out_ports_count': 1,
        }, attrs)
@@ -0,0 +1,324 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import unittest

import numpy as np

from extensions.front.kaldi.replace_timeheightconvolution import ReplaceTimeHeightConvolutionPattern
from mo.front.common.partial_infer.utils import int64_array
from mo.utils.ir_engine.compare_graphs import compare_graphs
from unit_tests.utils.graph import build_graph, regular_op, connect_front, const


class TimeheightconvolutionReplacerTest(unittest.TestCase):
    nodes = {
        **regular_op('placeholder', {}),
        **regular_op('timeheightconv', {'op': 'timeheightconvolutioncomponent'}),
        **const('weights', int64_array([])),
        **const('biases', int64_array([])),
        **regular_op('placeholder_out', {}),

        **regular_op('concat', {'type': 'Concat', 'axis': 1}),
        **regular_op('memoryoffset_0', {'type': None, 'op': 'MemoryOffset', 't': -1, 'has_default': False}),
        **regular_op('memoryoffset_1', {'type': None, 'op': 'MemoryOffset', 't': 0, 'has_default': False}),
        **regular_op('memoryoffset_2', {'type': None, 'op': 'MemoryOffset', 't': 1, 'has_default': True}),
        **regular_op('conv', {'op': 'Convolution', 'type': 'Convolution', 'output': 12, 'height_in': 80}),
    }

    def test_timeheightconvolution_1offset(self):
        graph = build_graph(self.nodes, [
            *connect_front('placeholder', '0:timeheightconv'),
            *connect_front('weights', '1:timeheightconv'),
            *connect_front('biases', '2:timeheightconv'),
            *connect_front('timeheightconv', 'placeholder_out')
        ], nodes_with_edges_only=True)

        graph.stage = 'front'

        conv = graph.nodes['timeheightconv']
        conv['height_subsample'] = 1
        conv['height_in'] = 80
        conv['height_out'] = 80
        conv['in_channels'] = 1
        conv['out_channels'] = 12
        conv['offsets'] = int64_array([[-1, -1], [-1, 0], [-1, 1]])
        conv['time_offsets'] = [-1]
        graph.nodes['weights']['value'] = np.zeros([36])

        ref_graph = build_graph(self.nodes, [
            *connect_front('placeholder', 'memoryoffset_0'),
            *connect_front('memoryoffset_0', '0:concat'),
            *connect_front('concat', '0:conv'),
            *connect_front('weights', '1:conv'),
            *connect_front('biases', '2:conv'),
            *connect_front('conv', 'placeholder_out')
        ], nodes_with_edges_only=True)
        ref_graph.nodes['weights']['value'] = np.zeros([36])
        new_conv = ref_graph.nodes['conv']
        new_conv['pad'] = int64_array([[0, 0], [0, 0], [0, 0], [1, 1]])
        new_conv['dilation'] = int64_array([1, 1, 1, 1])
        new_conv['kernel'] = int64_array([12, 1, 1, 3])
        new_conv['stride'] = int64_array([1, 1, 1, 1])

        ReplaceTimeHeightConvolutionPattern().find_and_replace_pattern(graph)

        (flag, resp) = compare_graphs(graph, ref_graph, 'placeholder_out', check_op_attrs=True)
        self.assertTrue(flag, resp)

    def test_timeheightconvolution_2_offsets(self):
        graph = build_graph(self.nodes, [
            *connect_front('placeholder', '0:timeheightconv'),
            *connect_front('weights', '1:timeheightconv'),
            *connect_front('biases', '2:timeheightconv'),
            *connect_front('timeheightconv', 'placeholder_out')
        ], nodes_with_edges_only=True)

        graph.stage = 'front'

        conv = graph.nodes['timeheightconv']
        conv['height_subsample'] = 1
        conv['height_in'] = 80
        conv['height_out'] = 80
        conv['in_channels'] = 1
        conv['out_channels'] = 12
        conv['offsets'] = int64_array([[-1, -1], [-1, 0], [-1, 1], [0, -1], [0, 0], [0, 1]])
        conv['time_offsets'] = int64_array([-1, 0])
        graph.nodes['weights']['value'] = np.zeros([72])

        ref_graph = build_graph(self.nodes, [
            *connect_front('placeholder', 'memoryoffset_0'),
            *connect_front('placeholder', 'memoryoffset_1'),
            *connect_front('memoryoffset_0', '0:concat'),
            *connect_front('memoryoffset_1', '1:concat'),
            *connect_front('concat', '0:conv'),
            *connect_front('weights', '1:conv'),
            *connect_front('biases', '2:conv'),
            *connect_front('conv', 'placeholder_out')
        ], nodes_with_edges_only=True)
        ref_graph.nodes['weights']['value'] = np.zeros([72])
        new_conv = ref_graph.nodes['conv']
        new_conv['pad'] = int64_array([[0, 0], [0, 0], [0, 0], [1, 1]])
        new_conv['dilation'] = int64_array([1, 1, 1, 1])
        new_conv['kernel'] = int64_array([12, 1, 2, 3])
        new_conv['stride'] = int64_array([1, 1, 1, 1])

        ReplaceTimeHeightConvolutionPattern().find_and_replace_pattern(graph)

        (flag, resp) = compare_graphs(graph, ref_graph, 'placeholder_out', check_op_attrs=True)
        self.assertTrue(flag, resp)

    def test_timeheightconvolution_2_offsets_def(self):
        graph = build_graph(self.nodes, [
            *connect_front('placeholder', '0:timeheightconv'),
            *connect_front('weights', '1:timeheightconv'),
            *connect_front('biases', '2:timeheightconv'),
            *connect_front('timeheightconv', 'placeholder_out')
        ], nodes_with_edges_only=True)

        graph.stage = 'front'

        conv = graph.nodes['timeheightconv']
        conv['height_subsample'] = 1
        conv['height_in'] = 80
        conv['height_out'] = 80
        conv['in_channels'] = 1
        conv['out_channels'] = 12
        conv['offsets'] = int64_array([[0, -1], [0, 0], [0, 1], [1, -1], [1, 0], [1, 1]])
        conv['time_offsets'] = int64_array([0])
        graph.nodes['weights']['value'] = np.zeros([72])

        ref_graph = build_graph(self.nodes, [
            *connect_front('placeholder', 'memoryoffset_1'),
            *connect_front('placeholder', 'memoryoffset_2'),
            *connect_front('memoryoffset_1', '0:concat'),
            *connect_front('memoryoffset_2', '1:concat'),
            *connect_front('concat', '0:conv'),
            *connect_front('weights', '1:conv'),
            *connect_front('biases', '2:conv'),
            *connect_front('conv', 'placeholder_out')
        ], nodes_with_edges_only=True)
        ref_graph.nodes['weights']['value'] = np.zeros([72])
        new_conv = ref_graph.nodes['conv']
        new_conv['pad'] = int64_array([[0, 0], [0, 0], [0, 0], [1, 1]])
        new_conv['dilation'] = int64_array([1, 1, 1, 1])
        new_conv['kernel'] = int64_array([12, 1, 2, 3])
        new_conv['stride'] = int64_array([1, 1, 1, 1])

        ReplaceTimeHeightConvolutionPattern().find_and_replace_pattern(graph)

        (flag, resp) = compare_graphs(graph, ref_graph, 'placeholder_out', check_op_attrs=True)
        self.assertTrue(flag, resp)

    def test_timeheightconvolution_2_offsets_dilation(self):
        graph = build_graph(self.nodes, [
            *connect_front('placeholder', '0:timeheightconv'),
            *connect_front('weights', '1:timeheightconv'),
            *connect_front('biases', '2:timeheightconv'),
            *connect_front('timeheightconv', 'placeholder_out')
        ], nodes_with_edges_only=True)

        graph.stage = 'front'

        conv = graph.nodes['timeheightconv']
        conv['height_subsample'] = 1
        conv['height_in'] = 80
        conv['height_out'] = 80
        conv['in_channels'] = 1
        conv['out_channels'] = 12
        conv['offsets'] = int64_array([[-1, -3], [-1, 0], [-1, 3], [1, -3], [1, 0], [1, 3]])
        conv['time_offsets'] = int64_array([-1])
        graph.nodes['weights']['value'] = np.zeros([72])

        ref_graph = build_graph(self.nodes, [
            *connect_front('placeholder', 'memoryoffset_0'),
            *connect_front('placeholder', 'memoryoffset_2'),
            *connect_front('memoryoffset_0', '0:concat'),
            *connect_front('memoryoffset_2', '1:concat'),
            *connect_front('concat', '0:conv'),
            *connect_front('weights', '1:conv'),
            *connect_front('biases', '2:conv'),
            *connect_front('conv', 'placeholder_out')
        ], nodes_with_edges_only=True)
        ref_graph.nodes['weights']['value'] = np.zeros([72])
        new_conv = ref_graph.nodes['conv']
        new_conv['pad'] = int64_array([[0, 0], [0, 0], [0, 0], [3, 3]])
        new_conv['dilation'] = int64_array([1, 1, 2, 3])
        new_conv['kernel'] = int64_array([12, 1, 2, 3])
        new_conv['stride'] = int64_array([1, 1, 1, 1])

        ReplaceTimeHeightConvolutionPattern().find_and_replace_pattern(graph)

        (flag, resp) = compare_graphs(graph, ref_graph, 'placeholder_out', check_op_attrs=True)
        self.assertTrue(flag, resp)

    def test_timeheightconvolution_2_offsets_pad(self):
        graph = build_graph(self.nodes, [
            *connect_front('placeholder', '0:timeheightconv'),
            *connect_front('weights', '1:timeheightconv'),
            *connect_front('biases', '2:timeheightconv'),
            *connect_front('timeheightconv', 'placeholder_out')
        ], nodes_with_edges_only=True)

        graph.stage = 'front'

        conv = graph.nodes['timeheightconv']
        conv['height_subsample'] = 1
        conv['height_in'] = 80
        conv['height_out'] = 74
        conv['in_channels'] = 1
        conv['out_channels'] = 12
        conv['offsets'] = int64_array([[-1, 0], [-1, 3], [-1, 6], [1, 0], [1, 3], [1, 6]])
        conv['time_offsets'] = int64_array([-1])
        graph.nodes['weights']['value'] = np.zeros([72])

        ref_graph = build_graph(self.nodes, [
            *connect_front('placeholder', 'memoryoffset_0'),
            *connect_front('placeholder', 'memoryoffset_2'),
            *connect_front('memoryoffset_0', '0:concat'),
            *connect_front('memoryoffset_2', '1:concat'),
            *connect_front('concat', '0:conv'),
            *connect_front('weights', '1:conv'),
            *connect_front('biases', '2:conv'),
            *connect_front('conv', 'placeholder_out')
        ], nodes_with_edges_only=True)
        ref_graph.nodes['weights']['value'] = np.zeros([72])
        new_conv = ref_graph.nodes['conv']
        new_conv['pad'] = int64_array([[0, 0], [0, 0], [0, 0], [0, 0]])
        new_conv['dilation'] = int64_array([1, 1, 2, 3])
        new_conv['kernel'] = int64_array([12, 1, 2, 3])
        new_conv['stride'] = int64_array([1, 1, 1, 1])

        ReplaceTimeHeightConvolutionPattern().find_and_replace_pattern(graph)

        (flag, resp) = compare_graphs(graph, ref_graph, 'placeholder_out', check_op_attrs=True)
        self.assertTrue(flag, resp)

    def test_timeheightconvolution_out_channels(self):
        graph = build_graph(self.nodes, [
            *connect_front('placeholder', '0:timeheightconv'),
            *connect_front('weights', '1:timeheightconv'),
            *connect_front('biases', '2:timeheightconv'),
            *connect_front('timeheightconv', 'placeholder_out')
        ], nodes_with_edges_only=True)

        graph.stage = 'front'

        conv = graph.nodes['timeheightconv']
        conv['height_subsample'] = 1
        conv['height_in'] = 80
        conv['height_out'] = 74
        conv['in_channels'] = 3
        conv['out_channels'] = 4
        conv['offsets'] = int64_array([[-1, 0], [-1, 3], [-1, 6], [1, 0], [1, 3], [1, 6]])
        conv['time_offsets'] = int64_array([-1])
        graph.nodes['weights']['value'] = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
                                                    19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
                                                    37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
                                                    55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72])

        ref_graph = build_graph(self.nodes, [
            *connect_front('placeholder', 'memoryoffset_0'),
            *connect_front('placeholder', 'memoryoffset_2'),
            *connect_front('memoryoffset_0', '0:concat'),
            *connect_front('memoryoffset_2', '1:concat'),
            *connect_front('concat', '0:conv'),
            *connect_front('weights', '1:conv'),
            *connect_front('biases', '2:conv'),
            *connect_front('conv', 'placeholder_out')
        ], nodes_with_edges_only=True)
        ref_graph.nodes['weights']['value'] = np.array([1, 4, 7, 10, 13, 16, 2, 5, 8, 11, 14, 17, 3, 6, 9, 12, 15, 18,
                                                        19, 22, 25, 28, 31, 34, 20, 23, 26, 29, 32, 35, 21, 24, 27, 30, 33, 36,
                                                        37, 40, 43, 46, 49, 52, 38, 41, 44, 47, 50, 53, 39, 42, 45, 48, 51, 54,
                                                        55, 58, 61, 64, 67, 70, 56, 59, 62, 65, 68, 71, 57, 60, 63, 66, 69, 72])
        new_conv = ref_graph.nodes['conv']
        new_conv['output'] = 4
        new_conv['pad'] = int64_array([[0, 0], [0, 0], [0, 0], [0, 0]])
        new_conv['dilation'] = int64_array([1, 1, 2, 3])
        new_conv['kernel'] = int64_array([4, 3, 2, 3])
        new_conv['stride'] = int64_array([1, 1, 1, 1])

        ReplaceTimeHeightConvolutionPattern().find_and_replace_pattern(graph)

        (flag, resp) = compare_graphs(graph, ref_graph, 'placeholder_out', check_op_attrs=True)
        self.assertTrue(flag, resp)

    def test_timeheightconvolution_2_offsets_stride(self):
        graph = build_graph(self.nodes, [
            *connect_front('placeholder', '0:timeheightconv'),
            *connect_front('weights', '1:timeheightconv'),
            *connect_front('biases', '2:timeheightconv'),
            *connect_front('timeheightconv', 'placeholder_out')
        ], nodes_with_edges_only=True)

        graph.stage = 'front'

        conv = graph.nodes['timeheightconv']
        conv['height_subsample'] = 2
        conv['height_in'] = 80
        conv['height_out'] = 37
        conv['in_channels'] = 1
        conv['out_channels'] = 12
        conv['offsets'] = int64_array([[-1, 0], [-1, 3], [-1, 6], [1, 0], [1, 3], [1, 6]])
        conv['time_offsets'] = int64_array([-1])
        graph.nodes['weights']['value'] = np.zeros([72])

        ref_graph = build_graph(self.nodes, [
            *connect_front('placeholder', 'memoryoffset_0'),
            *connect_front('placeholder', 'memoryoffset_2'),
            *connect_front('memoryoffset_0', '0:concat'),
            *connect_front('memoryoffset_2', '1:concat'),
            *connect_front('concat', '0:conv'),
            *connect_front('weights', '1:conv'),
            *connect_front('biases', '2:conv'),
            *connect_front('conv', 'placeholder_out')
        ], nodes_with_edges_only=True)
        ref_graph.nodes['weights']['value'] = np.zeros([72])
        new_conv = ref_graph.nodes['conv']
        new_conv['pad'] = int64_array([[0, 0], [0, 0], [0, 0], [0, 0]])
        new_conv['dilation'] = int64_array([1, 1, 2, 3])
        new_conv['kernel'] = int64_array([12, 1, 2, 3])
        new_conv['stride'] = int64_array([1, 1, 1, 2])

        ReplaceTimeHeightConvolutionPattern().find_and_replace_pattern(graph)

        (flag, resp) = compare_graphs(graph, ref_graph, 'placeholder_out', check_op_attrs=True)
        self.assertTrue(flag, resp)
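The out_channels test above also pins down the OHWI to OIHW weight relayout performed by the replacer; its expected permutation of 1..72 can be reproduced with plain numpy:

import numpy as np

# out_channels = 4, in_channels = 3, kernel 2x3 (6 spatial positions),
# weights enumerated 1..72 as in test_timeheightconvolution_out_channels.
out_channels, in_channels = 4, 3
weights = np.arange(1, 73)

w = weights.reshape([out_channels, -1, in_channels])  # O x (H*W) x I
w = w.transpose([0, 2, 1])                            # O x I x (H*W)
w = w.flatten()

print(w[:12])  # [ 1  4  7 10 13 16  2  5  8 11 14 17]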
@@ -22,7 +22,7 @@ class ConvolutionalComponentFrontExtractorTest(KaldiFrontExtractorTest):
         pb += KaldiFrontExtractorTest.write_tag_with_value('<PatchStride>', 4)
         pb += KaldiFrontExtractorTest.generate_learn_info()
         pb += b'<Filters> '
-        pb += KaldiFrontExtractorTest.generate_matrix([2, 1])
+        pb += KaldiFrontExtractorTest.generate_matrix([2, 4])
         pb += b'<Bias> '
         pb += KaldiFrontExtractorTest.generate_vector(2)
         cls.test_node['parameters'] = TestKaldiUtilsLoading.bytesio_from(pb)

@@ -50,6 +50,6 @@ class ConvolutionalComponentFrontExtractorTest(KaldiFrontExtractorTest):
         self.assertEqual(self.test_node[attr], val_attrs[attr])
 
     def test_convolution_blobs(self):
-        self.assertTrue(np.array_equal(self.test_node.weights, [0, 1]))
+        self.assertTrue(np.array_equal(self.test_node.weights, [0, 1, 2, 3, 4, 5, 6, 7]))
         self.assertTrue(np.array_equal(self.test_node.biases, [0, 1]))