Tdnnf (#5255)

* initial changes (IR not generated)
* extractor fix
* convert tdnnf (with correct infer)
* refactoring + comments in code
* added unit tests + couple fixes based on tests
* change order for old convolutions
* fix pylint
* small refactoring
* added if to remove changes in old IRs
* doc updated
* fix layout and kernel shapes for old convolutions
* fixed test
* moved test
* fix import in test
* fixed backward compatibility
* review fixes
parent 6624a77827
commit 7b52e3155a
@@ -383,6 +383,7 @@ Standard Kaldi\* Layers:
 | splicecomponent | No |
 | tanhcomponent | No |
 | tdnncomponent | No |
+| timeheightconvolutioncomponent | No |


 ## ONNX\* Supported Operators
@@ -160,6 +160,7 @@ extensions/front/kaldi/memoryoffset_batch_update.py
 extensions/front/kaldi/replace_eltwise_nin1.py
 extensions/front/kaldi/replace_lstm_node_pattern.py
 extensions/front/kaldi/replace_lstm_nonlinearity.py
+extensions/front/kaldi/replace_timeheightconvolution.py
 extensions/front/kaldi/set_ports.py
 extensions/front/kaldi/sigmoid_ext.py
 extensions/front/kaldi/split_recurrent_memoryoffset.py
@@ -865,6 +866,7 @@ mo/front/kaldi/extractors/softmax_ext.py
 mo/front/kaldi/extractors/specaugment_component_ext.py
 mo/front/kaldi/extractors/splice_component_ext.py
 mo/front/kaldi/extractors/tdnncomponent_ext.py
+mo/front/kaldi/extractors/timeheightconvolution_ext.py
 mo/front/kaldi/loader/__init__.py
 mo/front/kaldi/loader/loader.py
 mo/front/kaldi/loader/utils.py
@@ -977,6 +979,7 @@ mo/ops/squeeze.py
 mo/ops/strided_slice.py
 mo/ops/tdnncomponent.py
+mo/ops/timeheightconvolution.py
 mo/ops/tile.py
 mo/ops/unsqueeze.py
 mo/pipeline/__init__.py
 mo/pipeline/common.py
@@ -1,11 +1,9 @@
 # Copyright (C) 2018-2021 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
-import numpy as np
-
-from extensions.ops.Cast import Cast
 from extensions.ops.elementwise import Div
-from mo.front.common.partial_infer.utils import int64_array, float32_array
+from extensions.ops.transpose import Transpose
+from mo.front.common.partial_infer.utils import int64_array
 from mo.front.common.replacement import FrontReplacementPattern
 from mo.front.tf.graph_utils import create_op_with_const_inputs, create_op_node_with_second_input
 from mo.graph.graph import Graph
@@ -45,32 +43,45 @@ class ReplaceConvolutionReshape(FrontReplacementPattern):
         node = match['conv']
         node_name = node.soft_get('name', node.id)
 
-        dst_dtype = np.float32  # even if data_type=FP16 use float32 for shape values
-
         # create Reshape before convolution
-        # shape = [in_shape[0], in_shape[1]/patch_stride, 1, patch_stride]
-        i_shape = Shape(graph, {'name': node_name + '/Shape'}).create_node()
-        shape = Cast(graph, {'name': node_name + '/to_float',
-                             'dst_type': dst_dtype}).create_node()
-        i_shape.in_port(0).connect(node.in_port(0).get_source())
-        shape.in_port(0).connect(i_shape.out_port(0))
+        # if transpose will be applied (new models)
+        #   shape = [in_shape[0], t = in_shape[1]/(patch_stride*t), patch_stride, C=1]
+        # else (for old models to avoid fails on GNA - should be removed as soon as GNA will be changed)
+        #   shape = [in_shape[0], t = in_shape[1]/(patch_stride*t), C=1, patch_stride]
+        sp_dim_1 = 1
+        if node.has_valid('patch_stride'):
+            channel_dim = 2
+            sp_dim_2 = 3
+            frame_height = node.patch_stride
+        else:
+            channel_dim = 3
+            sp_dim_2 = 2
+            frame_height = node.height_in
+
+        i_shape = Shape(graph, {'name': node_name + '/Shape'}).create_node()
+        i_shape.in_port(0).connect(node.in_port(0).get_source())
 
-        N, H = node_to_get_shape_value_of_indices(shape, [0]), node_to_get_shape_value_of_indices(shape, [1])
+        N, H = node_to_get_shape_value_of_indices(i_shape, [0]), node_to_get_shape_value_of_indices(i_shape, [1])
 
         div = create_op_with_const_inputs(
-            graph, Div, {1: float32_array([node.patch_stride])}, {'name': node_name + '/div_stride_h'})
+            graph, Div, {1: int64_array([frame_height * node.kernel[1]])}, {'name': node_name + '/div_stride_h'})
         div.in_port(0).connect(H.out_port(0))
 
-        concat = create_op_with_const_inputs(graph, Concat, {2: float32_array([1]), 3: float32_array([node.patch_stride])},
+        concat = create_op_with_const_inputs(graph, Concat, {sp_dim_2: int64_array([frame_height]),
+                                                             channel_dim: int64_array([node.kernel[1]])},
                                              {'name': node_name + '/concat_all_dims', 'in_ports_count': 4, 'axis': 0})
         concat.in_port(0).connect(N.out_port(0))
-        concat.in_port(1).connect(div.out_port(0))
-
-        reshape_pattern = Cast(graph, {'name': node_name + '/to_int', 'dst_type': np.int64}).create_node()
-        concat.out_port(0).connect(reshape_pattern.in_port(0))
+        concat.in_port(sp_dim_1).connect(div.out_port(0))
 
         reshape_in = Reshape(graph, {'name': node_name + '/reshape_in'}).create_node()
-        reshape_in.in_port(1).connect(reshape_pattern.out_port(0))
+        reshape_in.in_port(1).connect(concat.out_port(0))
+
+        # change layout from NHWC to NCHW
+        # should be replaced by common Permute logic in future
+        transpose = None
+        if channel_dim == 3 and node.channel_dims == 1:
+            transpose = create_op_node_with_second_input(graph, Transpose, int64_array([0, 3, 1, 2]),
+                                                         {'name': node.name + '/Transpose'}, reshape_in)
 
         # create Reshape after Convolution
         reshape_out = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
@@ -78,7 +89,7 @@ class ReplaceConvolutionReshape(FrontReplacementPattern):
 
         # connect input_reshape_node
         source = node.in_port(0).get_source()
-        node.in_port(0).get_connection().set_source(reshape_in.out_port(0))
+        node.in_port(0).get_connection().set_source(transpose.out_port(0) if transpose else reshape_in.out_port(0))
         reshape_in.in_port(0).connect(source)
         # connect output_reshape_node
         node.out_port(0).get_connection().set_source(reshape_out.out_port(0))
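The net effect of the new branch is easiest to see with concrete numbers. A minimal numpy sketch (all values hypothetical, not taken from this commit) of the Reshape-plus-Transpose that now feeds Convolution for new models:

import numpy as np

N, t, H, C = 1, 10, 80, 3            # hypothetical batch, time steps, frame height, channels
flat = np.zeros((N, t * H * C))      # Kaldi-style flattened input, shape [1, 2400]

# new-model branch: Reshape to [N, t, H, C], where t = in_shape[1] / (frame_height * C)
nhwc = flat.reshape(N, flat.shape[1] // (H * C), H, C)

# Transpose([0, 3, 1, 2]) then converts NHWC to NCHW for the Convolution
nchw = nhwc.transpose(0, 3, 1, 2)
assert nchw.shape == (N, C, t, H)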
@@ -0,0 +1,103 @@
+# Copyright (C) 2021 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from mo.front.common.partial_infer.utils import int64_array
+from mo.front.common.replacement import FrontReplacementPattern
+from mo.graph.graph import Node, Graph, rename_node
+from mo.ops.concat import Concat
+from mo.ops.convolution import Convolution
+from mo.ops.memoryoffset import MemoryOffset
+
+
+class ReplaceTimeHeightConvolutionPattern(FrontReplacementPattern):
+    enabled = True
+    run_not_recursively = True
+
+    def run_after(self):
+        from extensions.front.MoveEmbeddedInputsToInputs import MoveEmbeddedInputsToInputs
+        return [MoveEmbeddedInputsToInputs]
+
+    def run_before(self):
+        from extensions.front.kaldi.add_permute_after_convolution import ReplaceConvolutionTranspose
+        from extensions.front.kaldi.add_reshape_around_convolution import ReplaceConvolutionReshape
+        from extensions.front.kaldi.memory_offset_adjustment import MemoryOffsetAdjustment
+        from extensions.front.kaldi.split_recurrent_memoryoffset import SplitRecurrentMemoryOffset
+        return [MemoryOffsetAdjustment, ReplaceConvolutionReshape, ReplaceConvolutionTranspose,
+                SplitRecurrentMemoryOffset]
+
+    def find_and_replace_pattern(self, graph: Graph):
+        for node in graph.get_op_nodes(op='timeheightconvolutioncomponent'):
+            self.replace_timeheightconv(graph, node)
+
+    def replace_timeheightconv(self, graph: Graph, node: Node):
+        req_time_offsets = node.soft_get('time_offsets')
+        offsets = node.soft_get("offsets", [[]])
+        all_time_offsets = list(set(offsets[:, 0]))
+        all_time_offsets.sort()
+        in_name = node.soft_get('name', node.id)
+        rename_node(node, in_name + '/to_delete')
+
+        # create memoryoffsets for context gathering
+        # we need concat if time offsets more than 1
+        concat = Concat(graph, attrs={'name': in_name + '/Concat',
+                                      'in_ports_count': len(all_time_offsets)}).create_node()
+        i = 0
+        for t in all_time_offsets:
+            # if time offset included in required_time_offsets we don't need default value
+            has_default = t not in req_time_offsets
+            memoff = MemoryOffset(graph, attrs={'name': in_name + '/MemoryOffset_' + str(i),
+                                                't': t, 'has_default': has_default, 'splitted': False,
+                                                'pair_name': in_name + '/MemoryOffset_pair_' + str(i)}).create_node()
+            concat.in_port(i).connect(memoff.out_port(0))
+            memoff.in_port(0).connect(node.in_port(0).get_source())
+            i = i + 1
+
+        stride = node.soft_get("height_subsample", 1)
+
+        kernel = int64_array([0, 0])
+        kernel[0] = len(set(offsets[:, 0]))
+        kernel[1] = len(set(offsets[:, 1]))
+
+        pad_h = int64_array([0, 0])
+        pad_h[0] = -min(offsets[:, 1]) if min(offsets[:, 1]) < 0 else 0
+        pad_h[1] = stride * node.height_out - (node.height_in - max([max(offsets[:, 1]), 0]))
+
+        dilation_t = (max(offsets[:, 0]) - min(offsets[:, 0])) / (kernel[0] - 1) if kernel[0] > 1 else 1
+        dilation_h = (max(offsets[:, 1]) - min(offsets[:, 1])) / (kernel[1] - 1) if kernel[1] > 1 else 1
+
+        conv_attrs = {
+            'name': in_name,
+            'output': node['out_channels'],
+            'height_in': node.height_in,
+            'bias_term': None,
+            'pad': int64_array([[0, 0], [0, 0], [0, 0], pad_h]),
+            'pad_spatial_shape': int64_array([[0, 0], pad_h]),
+            'dilation': int64_array([1, 1, dilation_t, dilation_h]),
+            'kernel': int64_array([node.out_channels, node.in_channels, kernel[0], kernel[1]]),
+            'stride': int64_array([1, 1, 1, stride]),
+            'kernel_spatial': kernel,
+            'input_feature_channel': 1,
+            'output_feature_channel': 0,
+            'channel_dims': int64_array([1]),
+            'spatial_dims': int64_array([2, 3]),
+            'batch_dims': int64_array([0]),
+            'kernel_spatial_idx': int64_array([2, 3]),
+            'group': 1,
+            'reshape_kernel': True,
+            'bias_addable': True,
+        }
+        conv = Convolution(graph, attrs=conv_attrs).create_node()
+        conv.in_port(0).connect(concat.out_port(0))
+        conv.in_port(1).connect(node.in_port(1).get_source())
+
+        # change layout for weights from OHWI to OIHW
+        # in future should be replaced by common Permute mechanics
+        weights = conv.in_port(1).get_source().node.value
+        weights = weights.reshape(int64_array([node.out_channels, -1, node.in_channels]))
+        weights = weights.transpose(int64_array([0, 2, 1]))
+        weights = weights.flatten()
+        conv.in_port(1).get_source().node.value = weights
+
+        conv.in_port(2).connect(node.in_port(2).get_source())
+        node.out_port(0).get_connection().set_source(conv.out_port(0))
+        graph.remove_node(node.id)
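The kernel/dilation/padding arithmetic above is compact, so here is a worked check in plain numpy, using the offsets from test_timeheightconvolution_2_offsets_dilation added later in this commit (height_in=80, height_out=80, height_subsample=1):

import numpy as np

offsets = np.array([[-1, -3], [-1, 0], [-1, 3], [1, -3], [1, 0], [1, 3]])
stride, height_in, height_out = 1, 80, 80

kernel = [len(set(offsets[:, 0])), len(set(offsets[:, 1]))]                   # [2, 3]
dilation_t = (offsets[:, 0].max() - offsets[:, 0].min()) // (kernel[0] - 1)   # (1 - -1) / 1 = 2
dilation_h = (offsets[:, 1].max() - offsets[:, 1].min()) // (kernel[1] - 1)   # (3 - -3) / 2 = 3
pad_top = -offsets[:, 1].min() if offsets[:, 1].min() < 0 else 0              # 3
pad_bottom = stride * height_out - (height_in - max(offsets[:, 1].max(), 0))  # 80 - 77 = 3

assert (kernel, dilation_t, dilation_h, pad_top, pad_bottom) == ([2, 3], 2, 3, 3, 3)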
@@ -26,6 +26,10 @@ class ReplaceMemoryOffsetNodePattern(MiddleReplacementPattern):
         from extensions.middle.RemoveDuplicationMemory import RemoveMemoryDuplicationPattern
         return [RemoveMemoryDuplicationPattern]
 
+    def run_after(self):
+        from extensions.middle.split_tdnn_memoryoffset import SplitTdnnMemoryOffset
+        return [SplitTdnnMemoryOffset]
+
     @staticmethod
     def pattern():
         return dict(
@@ -5,11 +5,9 @@ import numpy as np
 
 from mo.front.caffe.extractors.utils import embed_input
 from mo.front.extractor import FrontExtractorOp
-from mo.front.kaldi.loader.utils import read_binary_bool_token, read_binary_integer32_token, collect_until_token, \
-    read_binary_float_token
+from mo.front.kaldi.loader.utils import collect_until_token, read_binary_float_token, read_binary_integer32_token
 from mo.front.kaldi.utils import read_binary_vector
 from mo.ops.scale_shift import ScaleShiftOp
-from mo.utils.error import Error
 
 
 class BatchNormComponentFrontExtractor(FrontExtractorOp):
@@ -26,18 +24,12 @@ class BatchNormComponentFrontExtractor(FrontExtractorOp):
         collect_until_token(pb, b'<BlockDim>')
         block_dim = read_binary_integer32_token(pb)
 
-        if block_dim != dim:
-            raise Error("Dim is not equal BlockDim for BatchNorm is not supported")
-
         collect_until_token(pb, b'<Epsilon>')
         eps = read_binary_float_token(pb)
 
         collect_until_token(pb, b'<TargetRms>')
         target_rms = read_binary_float_token(pb)
 
-        collect_until_token(pb, b'<TestMode>')
-        test_mode = read_binary_bool_token(pb)
-
         collect_until_token(pb, b'<StatsMean>')
         mean = read_binary_vector(pb)
 
@@ -47,8 +39,13 @@ class BatchNormComponentFrontExtractor(FrontExtractorOp):
         scale = target_rms / np.sqrt(var + eps)
 
         shift = - target_rms * mean / np.sqrt(var + eps)
-        attrs = {'out-size': len(shift)}
+
+        scale = np.tile(scale, dim // block_dim)
+        shift = np.tile(shift, dim // block_dim)
+
+        attrs = {'out-size': dim}
         embed_input(attrs, 1, 'weights', scale)
         embed_input(attrs, 2, 'biases', shift)
 
         ScaleShiftOp.update_node_stat(node, attrs)
         return cls.enabled
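The tiling above replaces the old hard failure on block_dim != dim. A toy sketch with made-up numbers of what the two np.tile calls do:

import numpy as np

dim, block_dim = 6, 3                      # hypothetical: stats are stored per block of 3
scale = np.array([0.5, 1.0, 2.0])          # per-block scale computed from mean/var
shift = np.array([0.1, 0.0, -0.1])         # per-block shift

scale = np.tile(scale, dim // block_dim)   # -> [0.5, 1.0, 2.0, 0.5, 1.0, 2.0]
shift = np.tile(shift, dim // block_dim)
assert scale.size == dim                   # so 'out-size' can be dim rather than len(shift)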
@@ -0,0 +1,62 @@
+# Copyright (C) 2021 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+import numpy as np
+
+from mo.front.caffe.extractors.utils import embed_input
+from mo.front.extractor import FrontExtractorOp
+from mo.front.kaldi.loader.utils import collect_until_token, read_token_value
+from mo.front.kaldi.utils import read_binary_matrix, read_binary_vector, read_binary_vector_of_pairs
+from mo.ops.timeheightconvolution import TimeHeightConvolutionComponent
+
+
+class TimeHeightConvolutionFrontExtractor(FrontExtractorOp):
+    op = 'timeheightconvolutioncomponent'
+    enabled = True
+
+    @classmethod
+    def extract(cls, node):
+        pb = node.parameters
+        collect_until_token(pb, b'<ConvolutionModel>')
+        in_shape = read_token_value(pb, b'<NumFiltersIn>')
+        out_shape = read_token_value(pb, b'<NumFiltersOut>')
+        height_in = read_token_value(pb, b'<HeightIn>')
+        height_out = read_token_value(pb, b'<HeightOut>')
+        height_subsample = read_token_value(pb, b'<HeightSubsampleOut>')
+        collect_until_token(pb, b'<Offsets>')
+        offsets = read_binary_vector_of_pairs(pb, read_token=False, dtype=np.int32)
+        collect_until_token(pb, b'<RequiredTimeOffsets>')
+        time_offsets = read_binary_vector(pb, read_token=False, dtype=np.int32)
+        collect_until_token(pb, b'<LinearParams>')
+        weights, _ = read_binary_matrix(pb)
+        collect_until_token(pb, b'<BiasParams>')
+        biases = read_binary_vector(pb)
+
+        offsets = offsets.reshape([len(offsets)//2, 2])
+        mapping_rule = {  # stride for h axis
+            'height_subsample': height_subsample,
+            # input dimension for h axis
+            'height_in': height_in,
+            # output dimension for h axis
+            'height_out': height_out,
+            # input dimension for channel axis
+            'in_channels': in_shape,
+            # output dimension for channel axis
+            'out_channels': out_shape,
+            # array with pairs like the following
+            # [ (-1, -1) (-1, 0) (-1, 1)
+            #   (0, -1)  (0, 0)  (0, 1)
+            #   (1, -1)  (1, 0)  (1, 1)]
+            # it means that kernel 3x3 will be applied to calculate current value of output
+            'offsets': offsets,
+            # required time offsets to calculate current convolution
+            # time_offsets = [-1, 0, 1] for previous example means no padding for time axis and
+            # 3 values should be prepared
+            # time_offsets = [0] means zero padding [1, 1] for time axis
+            'time_offsets': time_offsets,
+            'out-size': out_shape * height_out}
+
+        embed_input(mapping_rule, 1, 'weights', weights)
+        embed_input(mapping_rule, 2, 'biases', biases)
+
+        TimeHeightConvolutionComponent.update_node_stat(node, mapping_rule)
+        return cls.enabled
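To make the extracted attributes concrete, here is a toy rendering of mapping_rule (illustrative values only; the offsets grid is the 3x3 example from the comment in the extractor above):

import numpy as np

offsets = np.array([[-1, -1], [-1, 0], [-1, 1],
                    [0, -1],  [0, 0],  [0, 1],
                    [1, -1],  [1, 0],  [1, 1]])    # 3x3 kernel over (time, height)

example_mapping_rule = {
    'height_subsample': 1,                 # stride over the height axis
    'height_in': 80, 'height_out': 80,     # height dimension before/after convolution
    'in_channels': 1, 'out_channels': 12,
    'offsets': offsets,
    'time_offsets': np.array([-1, 0, 1]),  # all three frames required -> no time padding
    'out-size': 12 * 80,                   # out_channels * height_out
}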
@@ -52,6 +52,7 @@ supported_components = [
     'sumgroupcomponent',
     'tanhcomponent',
     'tdnncomponent',
+    'timeheightconvolutioncomponent',
 ]
 
 
@@ -28,6 +28,13 @@ def read_binary_vector(file_desc: io.BufferedReader, read_token: bool = True, dtype=np.float32):
     return read_blob(file_desc, elements_number, dtype)
 
 
+def read_binary_vector_of_pairs(file_desc: io.BufferedReader, read_token: bool = True, dtype=np.float32):
+    if read_token:
+        read_placeholder(file_desc)
+    elements_number = read_binary_integer32_token(file_desc)
+    return read_blob(file_desc, 2 * elements_number, dtype)
+
+
 def read_learning_info(pb: io.BufferedReader):
     while True:
         read_placeholder(pb, 1)
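A short usage note for the new helper: <Offsets> stores n (time, height) pairs as 2n flat scalars, so the extractor above reads them flat and then reshapes to [n, 2]. A sketch, assuming pb is a stream positioned at the pair data and numpy is imported as np:

offsets = read_binary_vector_of_pairs(pb, read_token=False, dtype=np.int32)  # shape [2 * n]
offsets = offsets.reshape([len(offsets) // 2, 2])                            # shape [n, 2]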
model-optimizer/mo/ops/timeheightconvolution.py (new file, 19 lines)
@@ -0,0 +1,19 @@
+# Copyright (C) 2018-2021 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from mo.graph.graph import Graph
+from mo.ops.op import Op
+
+
+class TimeHeightConvolutionComponent(Op):
+    op = 'timeheightconvolutioncomponent'
+    enabled = False
+
+    def __init__(self, graph: Graph, attrs: dict):
+        super().__init__(graph, {
+            'type': None,
+            'op': self.op,
+            'infer': None,
+            'in_ports_count': 1,
+            'out_ports_count': 1,
+        }, attrs)
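For orientation, a comment-only sketch of how this placeholder op travels through the pipeline; all names come from files touched by this commit:

# 1. loader: 'timeheightconvolutioncomponent' is now listed in supported_components,
#    so the component is no longer rejected at load time
# 2. extractor (timeheightconvolution_ext.py): parses the Kaldi blob and calls
#    TimeHeightConvolutionComponent.update_node_stat(node, mapping_rule)
# 3. front replacer (replace_timeheightconvolution.py): rewrites the node into
#    MemoryOffset(s) -> Concat -> Convolution and deletes the placeholder,
#    which is why 'type' and 'infer' above are None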
@@ -0,0 +1,324 @@
+# Copyright (C) 2018-2021 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+
+from extensions.front.kaldi.replace_timeheightconvolution import ReplaceTimeHeightConvolutionPattern
+from mo.front.common.partial_infer.utils import int64_array
+from mo.utils.ir_engine.compare_graphs import compare_graphs
+from unit_tests.utils.graph import build_graph, regular_op, connect_front, const
+
+
+class TimeheightconvolutionReplacerTest(unittest.TestCase):
+    nodes = {
+        **regular_op('placeholder', {}),
+        **regular_op('timeheightconv', {'op': 'timeheightconvolutioncomponent'}),
+        **const('weights', int64_array([])),
+        **const('biases', int64_array([])),
+        **regular_op('placeholder_out', {}),
+
+        **regular_op('concat', {'type': 'Concat', 'axis': 1}),
+        **regular_op('memoryoffset_0', {'type': None, 'op': 'MemoryOffset', 't': -1, 'has_default': False}),
+        **regular_op('memoryoffset_1', {'type': None, 'op': 'MemoryOffset', 't': 0, 'has_default': False}),
+        **regular_op('memoryoffset_2', {'type': None, 'op': 'MemoryOffset', 't': 1, 'has_default': True}),
+        **regular_op('conv', {'op': 'Convolution', 'type': 'Convolution', 'output': 12, 'height_in': 80}),
+    }
+
+    def test_timeheightconvolution_1offset(self):
+        graph = build_graph(self.nodes, [
+            *connect_front('placeholder', '0:timeheightconv'),
+            *connect_front('weights', '1:timeheightconv'),
+            *connect_front('biases', '2:timeheightconv'),
+            *connect_front('timeheightconv', 'placeholder_out')
+        ], nodes_with_edges_only=True)
+
+        graph.stage = 'front'
+
+        conv = graph.nodes['timeheightconv']
+        conv['height_subsample'] = 1
+        conv['height_in'] = 80
+        conv['height_out'] = 80
+        conv['in_channels'] = 1
+        conv['out_channels'] = 12
+        conv['offsets'] = int64_array([[-1, -1], [-1, 0], [-1, 1]])
+        conv['time_offsets'] = [-1]
+        graph.nodes['weights']['value'] = np.zeros([36])
+
+        ref_graph = build_graph(self.nodes, [
+            *connect_front('placeholder', 'memoryoffset_0'),
+            *connect_front('memoryoffset_0', '0:concat'),
+            *connect_front('concat', '0:conv'),
+            *connect_front('weights', '1:conv'),
+            *connect_front('biases', '2:conv'),
+            *connect_front('conv', 'placeholder_out')
+        ], nodes_with_edges_only=True)
+        ref_graph.nodes['weights']['value'] = np.zeros([36])
+        new_conv = ref_graph.nodes['conv']
+        new_conv['pad'] = int64_array([[0, 0], [0, 0], [0, 0], [1, 1]])
+        new_conv['dilation'] = int64_array([1, 1, 1, 1])
+        new_conv['kernel'] = int64_array([12, 1, 1, 3])
+        new_conv['stride'] = int64_array([1, 1, 1, 1])
+
+        ReplaceTimeHeightConvolutionPattern().find_and_replace_pattern(graph)
+
+        (flag, resp) = compare_graphs(graph, ref_graph, 'placeholder_out', check_op_attrs=True)
+        self.assertTrue(flag, resp)
+
+    def test_timeheightconvolution_2_offsets(self):
+        graph = build_graph(self.nodes, [
+            *connect_front('placeholder', '0:timeheightconv'),
+            *connect_front('weights', '1:timeheightconv'),
+            *connect_front('biases', '2:timeheightconv'),
+            *connect_front('timeheightconv', 'placeholder_out')
+        ], nodes_with_edges_only=True)
+
+        graph.stage = 'front'
+
+        conv = graph.nodes['timeheightconv']
+        conv['height_subsample'] = 1
+        conv['height_in'] = 80
+        conv['height_out'] = 80
+        conv['in_channels'] = 1
+        conv['out_channels'] = 12
+        conv['offsets'] = int64_array([[-1, -1], [-1, 0], [-1, 1], [0, -1], [0, 0], [0, 1]])
+        conv['time_offsets'] = int64_array([-1, 0])
+        graph.nodes['weights']['value'] = np.zeros([72])
+
+        ref_graph = build_graph(self.nodes, [
+            *connect_front('placeholder', 'memoryoffset_0'),
+            *connect_front('placeholder', 'memoryoffset_1'),
+            *connect_front('memoryoffset_0', '0:concat'),
+            *connect_front('memoryoffset_1', '1:concat'),
+            *connect_front('concat', '0:conv'),
+            *connect_front('weights', '1:conv'),
+            *connect_front('biases', '2:conv'),
+            *connect_front('conv', 'placeholder_out')
+        ], nodes_with_edges_only=True)
+        ref_graph.nodes['weights']['value'] = np.zeros([72])
+        new_conv = ref_graph.nodes['conv']
+        new_conv['pad'] = int64_array([[0, 0], [0, 0], [0, 0], [1, 1]])
+        new_conv['dilation'] = int64_array([1, 1, 1, 1])
+        new_conv['kernel'] = int64_array([12, 1, 2, 3])
+        new_conv['stride'] = int64_array([1, 1, 1, 1])
+
+        ReplaceTimeHeightConvolutionPattern().find_and_replace_pattern(graph)
+
+        (flag, resp) = compare_graphs(graph, ref_graph, 'placeholder_out', check_op_attrs=True)
+        self.assertTrue(flag, resp)
+
+    def test_timeheightconvolution_2_offsets_def(self):
+        graph = build_graph(self.nodes, [
+            *connect_front('placeholder', '0:timeheightconv'),
+            *connect_front('weights', '1:timeheightconv'),
+            *connect_front('biases', '2:timeheightconv'),
+            *connect_front('timeheightconv', 'placeholder_out')
+        ], nodes_with_edges_only=True)
+
+        graph.stage = 'front'
+
+        conv = graph.nodes['timeheightconv']
+        conv['height_subsample'] = 1
+        conv['height_in'] = 80
+        conv['height_out'] = 80
+        conv['in_channels'] = 1
+        conv['out_channels'] = 12
+        conv['offsets'] = int64_array([[0, -1], [0, 0], [0, 1], [1, -1], [1, 0], [1, 1]])
+        conv['time_offsets'] = int64_array([0])
+        graph.nodes['weights']['value'] = np.zeros([72])
+
+        ref_graph = build_graph(self.nodes, [
+            *connect_front('placeholder', 'memoryoffset_1'),
+            *connect_front('placeholder', 'memoryoffset_2'),
+            *connect_front('memoryoffset_1', '0:concat'),
+            *connect_front('memoryoffset_2', '1:concat'),
+            *connect_front('concat', '0:conv'),
+            *connect_front('weights', '1:conv'),
+            *connect_front('biases', '2:conv'),
+            *connect_front('conv', 'placeholder_out')
+        ], nodes_with_edges_only=True)
+        ref_graph.nodes['weights']['value'] = np.zeros([72])
+        new_conv = ref_graph.nodes['conv']
+        new_conv['pad'] = int64_array([[0, 0], [0, 0], [0, 0], [1, 1]])
+        new_conv['dilation'] = int64_array([1, 1, 1, 1])
+        new_conv['kernel'] = int64_array([12, 1, 2, 3])
+        new_conv['stride'] = int64_array([1, 1, 1, 1])
+
+        ReplaceTimeHeightConvolutionPattern().find_and_replace_pattern(graph)
+
+        (flag, resp) = compare_graphs(graph, ref_graph, 'placeholder_out', check_op_attrs=True)
+        self.assertTrue(flag, resp)
+
+    def test_timeheightconvolution_2_offsets_dilation(self):
+        graph = build_graph(self.nodes, [
+            *connect_front('placeholder', '0:timeheightconv'),
+            *connect_front('weights', '1:timeheightconv'),
+            *connect_front('biases', '2:timeheightconv'),
+            *connect_front('timeheightconv', 'placeholder_out')
+        ], nodes_with_edges_only=True)
+
+        graph.stage = 'front'
+
+        conv = graph.nodes['timeheightconv']
+        conv['height_subsample'] = 1
+        conv['height_in'] = 80
+        conv['height_out'] = 80
+        conv['in_channels'] = 1
+        conv['out_channels'] = 12
+        conv['offsets'] = int64_array([[-1, -3], [-1, 0], [-1, 3], [1, -3], [1, 0], [1, 3]])
+        conv['time_offsets'] = int64_array([-1])
+        graph.nodes['weights']['value'] = np.zeros([72])
+
+        ref_graph = build_graph(self.nodes, [
+            *connect_front('placeholder', 'memoryoffset_0'),
+            *connect_front('placeholder', 'memoryoffset_2'),
+            *connect_front('memoryoffset_0', '0:concat'),
+            *connect_front('memoryoffset_2', '1:concat'),
+            *connect_front('concat', '0:conv'),
+            *connect_front('weights', '1:conv'),
+            *connect_front('biases', '2:conv'),
+            *connect_front('conv', 'placeholder_out')
+        ], nodes_with_edges_only=True)
+        ref_graph.nodes['weights']['value'] = np.zeros([72])
+        new_conv = ref_graph.nodes['conv']
+        new_conv['pad'] = int64_array([[0, 0], [0, 0], [0, 0], [3, 3]])
+        new_conv['dilation'] = int64_array([1, 1, 2, 3])
+        new_conv['kernel'] = int64_array([12, 1, 2, 3])
+        new_conv['stride'] = int64_array([1, 1, 1, 1])
+
+        ReplaceTimeHeightConvolutionPattern().find_and_replace_pattern(graph)
+
+        (flag, resp) = compare_graphs(graph, ref_graph, 'placeholder_out', check_op_attrs=True)
+        self.assertTrue(flag, resp)
+
+    def test_timeheightconvolution_2_offsets_pad(self):
+        graph = build_graph(self.nodes, [
+            *connect_front('placeholder', '0:timeheightconv'),
+            *connect_front('weights', '1:timeheightconv'),
+            *connect_front('biases', '2:timeheightconv'),
+            *connect_front('timeheightconv', 'placeholder_out')
+        ], nodes_with_edges_only=True)
+
+        graph.stage = 'front'
+        conv = graph.nodes['timeheightconv']
+        conv['height_subsample'] = 1
+        conv['height_in'] = 80
+        conv['height_out'] = 74
+        conv['in_channels'] = 1
+        conv['out_channels'] = 12
+        conv['offsets'] = int64_array([[-1, 0], [-1, 3], [-1, 6], [1, 0], [1, 3], [1, 6]])
+        conv['time_offsets'] = int64_array([-1])
+        graph.nodes['weights']['value'] = np.zeros([72])
+
+        ref_graph = build_graph(self.nodes, [
+            *connect_front('placeholder', 'memoryoffset_0'),
+            *connect_front('placeholder', 'memoryoffset_2'),
+            *connect_front('memoryoffset_0', '0:concat'),
+            *connect_front('memoryoffset_2', '1:concat'),
+            *connect_front('concat', '0:conv'),
+            *connect_front('weights', '1:conv'),
+            *connect_front('biases', '2:conv'),
+            *connect_front('conv', 'placeholder_out')
+        ], nodes_with_edges_only=True)
+        ref_graph.nodes['weights']['value'] = np.zeros([72])
+        new_conv = ref_graph.nodes['conv']
+        new_conv['pad'] = int64_array([[0, 0], [0, 0], [0, 0], [0, 0]])
+        new_conv['dilation'] = int64_array([1, 1, 2, 3])
+        new_conv['kernel'] = int64_array([12, 1, 2, 3])
+        new_conv['stride'] = int64_array([1, 1, 1, 1])
+
+        ReplaceTimeHeightConvolutionPattern().find_and_replace_pattern(graph)
+
+        (flag, resp) = compare_graphs(graph, ref_graph, 'placeholder_out', check_op_attrs=True)
+        self.assertTrue(flag, resp)
+
+    def test_timeheightconvolution_out_channels(self):
+        graph = build_graph(self.nodes, [
+            *connect_front('placeholder', '0:timeheightconv'),
+            *connect_front('weights', '1:timeheightconv'),
+            *connect_front('biases', '2:timeheightconv'),
+            *connect_front('timeheightconv', 'placeholder_out')
+        ], nodes_with_edges_only=True)
+
+        graph.stage = 'front'
+        conv = graph.nodes['timeheightconv']
+        conv['height_subsample'] = 1
+        conv['height_in'] = 80
+        conv['height_out'] = 74
+        conv['in_channels'] = 3
+        conv['out_channels'] = 4
+        conv['offsets'] = int64_array([[-1, 0], [-1, 3], [-1, 6], [1, 0], [1, 3], [1, 6]])
+        conv['time_offsets'] = int64_array([-1])
+        graph.nodes['weights']['value'] = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
+                                                    19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
+                                                    37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
+                                                    55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72])
+
+        ref_graph = build_graph(self.nodes, [
+            *connect_front('placeholder', 'memoryoffset_0'),
+            *connect_front('placeholder', 'memoryoffset_2'),
+            *connect_front('memoryoffset_0', '0:concat'),
+            *connect_front('memoryoffset_2', '1:concat'),
+            *connect_front('concat', '0:conv'),
+            *connect_front('weights', '1:conv'),
+            *connect_front('biases', '2:conv'),
+            *connect_front('conv', 'placeholder_out')
+        ], nodes_with_edges_only=True)
+        ref_graph.nodes['weights']['value'] = np.array([1, 4, 7, 10, 13, 16, 2, 5, 8, 11, 14, 17, 3, 6, 9, 12, 15, 18,
+                                                        19, 22, 25, 28, 31, 34, 20, 23, 26, 29, 32, 35, 21, 24, 27, 30, 33, 36,
+                                                        37, 40, 43, 46, 49, 52, 38, 41, 44, 47, 50, 53, 39, 42, 45, 48, 51, 54,
+                                                        55, 58, 61, 64, 67, 70, 56, 59, 62, 65, 68, 71, 57, 60, 63, 66, 69, 72])
+        new_conv = ref_graph.nodes['conv']
+        new_conv['output'] = 4
+        new_conv['pad'] = int64_array([[0, 0], [0, 0], [0, 0], [0, 0]])
+        new_conv['dilation'] = int64_array([1, 1, 2, 3])
+        new_conv['kernel'] = int64_array([4, 3, 2, 3])
+        new_conv['stride'] = int64_array([1, 1, 1, 1])
+
+        ReplaceTimeHeightConvolutionPattern().find_and_replace_pattern(graph)
+
+        (flag, resp) = compare_graphs(graph, ref_graph, 'placeholder_out', check_op_attrs=True)
+        self.assertTrue(flag, resp)
+
+    def test_timeheightconvolution_2_offsets_stride(self):
+        graph = build_graph(self.nodes, [
+            *connect_front('placeholder', '0:timeheightconv'),
+            *connect_front('weights', '1:timeheightconv'),
+            *connect_front('biases', '2:timeheightconv'),
+            *connect_front('timeheightconv', 'placeholder_out')
+        ], nodes_with_edges_only=True)
+
+        graph.stage = 'front'
+        conv = graph.nodes['timeheightconv']
+        conv['height_subsample'] = 2
+        conv['height_in'] = 80
+        conv['height_out'] = 37
+        conv['in_channels'] = 1
+        conv['out_channels'] = 12
+        conv['offsets'] = int64_array([[-1, 0], [-1, 3], [-1, 6], [1, 0], [1, 3], [1, 6]])
+        conv['time_offsets'] = int64_array([-1])
+        graph.nodes['weights']['value'] = np.zeros([72])
+
+        ref_graph = build_graph(self.nodes, [
+            *connect_front('placeholder', 'memoryoffset_0'),
+            *connect_front('placeholder', 'memoryoffset_2'),
+            *connect_front('memoryoffset_0', '0:concat'),
+            *connect_front('memoryoffset_2', '1:concat'),
+            *connect_front('concat', '0:conv'),
+            *connect_front('weights', '1:conv'),
+            *connect_front('biases', '2:conv'),
+            *connect_front('conv', 'placeholder_out')
+        ], nodes_with_edges_only=True)
+        ref_graph.nodes['weights']['value'] = np.zeros([72])
+        new_conv = ref_graph.nodes['conv']
+        new_conv['pad'] = int64_array([[0, 0], [0, 0], [0, 0], [0, 0]])
+        new_conv['dilation'] = int64_array([1, 1, 2, 3])
+        new_conv['kernel'] = int64_array([12, 1, 2, 3])
+        new_conv['stride'] = int64_array([1, 1, 1, 2])
+
+        ReplaceTimeHeightConvolutionPattern().find_and_replace_pattern(graph)
+
+        (flag, resp) = compare_graphs(graph, ref_graph, 'placeholder_out', check_op_attrs=True)
+        self.assertTrue(flag, resp)
@@ -22,7 +22,7 @@ class ConvolutionalComponentFrontExtractorTest(KaldiFrontExtractorTest):
         pb += KaldiFrontExtractorTest.write_tag_with_value('<PatchStride>', 4)
         pb += KaldiFrontExtractorTest.generate_learn_info()
         pb += b'<Filters> '
-        pb += KaldiFrontExtractorTest.generate_matrix([2, 1])
+        pb += KaldiFrontExtractorTest.generate_matrix([2, 4])
         pb += b'<Bias> '
         pb += KaldiFrontExtractorTest.generate_vector(2)
         cls.test_node['parameters'] = TestKaldiUtilsLoading.bytesio_from(pb)
@@ -50,6 +50,6 @@ class ConvolutionalComponentFrontExtractorTest(KaldiFrontExtractorTest):
         self.assertEqual(self.test_node[attr], val_attrs[attr])
 
     def test_convolution_blobs(self):
-        self.assertTrue(np.array_equal(self.test_node.weights, [0, 1]))
+        self.assertTrue(np.array_equal(self.test_node.weights, [0, 1, 2, 3, 4, 5, 6, 7]))
         self.assertTrue(np.array_equal(self.test_node.biases, [0, 1]))