diff --git a/model-optimizer/extensions/front/kaldi/split_recurrent_memoryoffset.py b/model-optimizer/extensions/front/kaldi/split_recurrent_memoryoffset.py index 78f91a9aeb1..e10c1058fa3 100644 --- a/model-optimizer/extensions/front/kaldi/split_recurrent_memoryoffset.py +++ b/model-optimizer/extensions/front/kaldi/split_recurrent_memoryoffset.py @@ -1,5 +1,5 @@ """ - Copyright (C) 2018-2020 Intel Corporation + Copyright (C) 2018-2021 Intel Corporation Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -57,6 +57,13 @@ class SplitRecurrentMemoryOffset(FrontReplacementSubgraph): # MemoryOffset node is not in a recurrent block -- no splitting is needed return + # check that node has information for future partial infer + # element_size is set in loader based on dimensions of previous layer from original Kaldi model if not offset_node.has_valid('element_size'): - raise Error("In a recurrent block 'element_size' for node {} is not set".format(offset_node.id)) + # check if previous layer contains information about its shape in out-size + # out-size is set in extractor of some nodes like affinecomponent based on weight's size + if offset_node.in_port(0).get_source().node.has_valid('out-size'): + offset_node['element_size'] = offset_node.in_port(0).get_source().node['out-size'] + else: + raise Error("In a recurrent block 'element_size' for node {} is not set".format(offset_node.id)) SplitRecurrentMemoryOffset.split_offset(offset_node) diff --git a/model-optimizer/mo/front/kaldi/loader/loader.py b/model-optimizer/mo/front/kaldi/loader/loader.py index 0cbf193e87a..093b8b30c90 100644 --- a/model-optimizer/mo/front/kaldi/loader/loader.py +++ b/model-optimizer/mo/front/kaldi/loader/loader.py @@ -234,15 +234,18 @@ def load_components(file_descr, graph, component_layer_map=None): # it is separated in 2 parts to remove cycle from graph file_descr.seek(start_index) dim = 0 - try: - collect_until_token(file_descr, b'', size_search_zone=end_index - start_index) - cur_index = file_descr.tell() - if start_index < cur_index < end_index: - dim = read_binary_integer32_token(file_descr) - else: + dim_words = {b'', b''} + for dim_word in dim_words: + try: + collect_until_token(file_descr, dim_word, size_search_zone=end_index - start_index) + cur_index = file_descr.tell() + if start_index < cur_index < end_index: + dim = read_binary_integer32_token(file_descr) + break + else: + file_descr.seek(start_index) + except Error: file_descr.seek(start_index) - except Error: - file_descr.seek(start_index) if is_nnet3: if name in component_layer_map: