* Fixed the framework name attribute for ONNX and MXNet. * Fixed the framework name attribute for Caffe. * Removed unnecessary attribute setting from add_opoutput(). * Added identity nodes to the outputs in the MXNet loader. * Removed unnecessary reformatting. * Removed unnecessary reformatting. * Added a check for empty names. * Used node indices instead of node names in the loader. * Code refactoring; fixed a small bug.
220 lines
7.0 KiB
Python
220 lines
7.0 KiB
Python
# Copyright (C) 2018-2021 Intel Corporation
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import mxnet as mx
|
|
import numpy as np
|
|
|
|
from extensions.ops.elementwise import Elementwise
|
|
from mo.graph.graph import Node, Graph
|
|
from mo.ops.const import Const
|
|
from mo.utils.error import Error
|
|
from mo.utils.str_to import StrTo
|
|
from mo.utils.utils import refer_to_faq_msg
|
|
|
|
|
|
class AttrDictionary(object):
    """Typed accessor over a layer's attribute dictionary.

    Values may be strings or already-typed Python objects; the getters below convert them
    to the requested type. A wrapped dictionary of ``None`` marks the instance as invalid
    (see ``is_valid``).
    """

    def __init__(self, dict):
        self._dict = dict

    def is_valid(self):
        """Return True when a backing dictionary was supplied (i.e. it is not None)."""
        return self._dict is not None

    def dict(self):
        """Return the underlying dictionary object itself (not a copy)."""
        return self._dict

    def add_dict(self, dict):
        """Merge the entries of ``dict`` into the underlying dictionary."""
        self._dict.update(dict)

    def set(self, key, value):
        """Store ``value`` under ``key``."""
        self._dict[key] = value

    def remove(self, key):
        """Delete ``key`` if present; a missing key is ignored."""
        if key in self._dict:
            del self._dict[key]

    def str(self, key, default=None):
        """Return the raw (unconverted) value stored under ``key``.

        :param key: attribute name.
        :param default: value returned when ``key`` is absent (or the dictionary is None).
        :return: the stored value, or ``default``.
        :raises ValueError: if the backing dictionary is None and no default was given.
        """
        # is_valid must be *called*: the bound method object itself is always truthy, which
        # made this guard dead code and let `key in None` fail with a TypeError.
        if not self.is_valid():
            if default is None:
                raise ValueError("Missing required parameter: " + key)
            return default
        if key in self._dict:
            return self._dict[key]
        return default

    def dtype(self, key, default=None):
        """Return the numpy type for the MXNet dtype name stored under ``key``, else ``default``."""
        if self.is_valid() and key in self._dict:
            return mxnet_str_dtype_to_np(self._dict[key])
        return default

    def bool(self, key, default=None):
        """Return the value under ``key`` interpreted as a bool.

        Digit strings are read as integers ('0' -> False, '1' -> True); other strings are
        delegated to ``StrTo.bool``. Non-string values (including ``default``) are returned as is.
        """
        attr = self.str(key, default)
        if isinstance(attr, str):
            if attr.isdigit():
                return bool(int(attr))
            return StrTo.bool(attr)
        return attr

    def float(self, key, default=None):
        """Return the value under ``key`` converted to float (see ``val``)."""
        return self.val(key, float, default)

    def int(self, key, default=None):
        """Return the value under ``key`` converted to int (see ``val``)."""
        return self.val(key, int, default)

    def tuple(self, key, valtype=None, default=None):
        """Return the value under ``key`` as a tuple whose elements are converted by ``valtype``.

        Accepted string forms: a bare scalar ("3" -> (valtype("3"),)) or a bracketed sequence
        ("(1,2)" / "[1,2]", parsed by ``StrTo.tuple``). A bracketed string whose first element
        is empty (e.g. "()") yields ``default`` converted element-wise, or None when ``default``
        is None. Non-string values are converted element-wise.

        :param valtype: element conversion callable; None (the default) means the builtin ``str``.
        """
        # The default cannot be written as `valtype=str`: default values are evaluated in the
        # class body, where `str` names the AttrDictionary.str method defined above.
        if valtype is None:
            valtype = str
        attr = self.str(key, default)
        if attr is None:
            return default
        if not isinstance(attr, str):
            return tuple(valtype(x) for x in attr)
        if '(' not in attr and ')' not in attr and '[' not in attr and ']' not in attr:
            return (valtype(attr),)
        if not attr or not attr[1:-1].split(',')[0]:
            # Empty sequence literal: fall back to the default instead of iterating None.
            if default is None:
                return None
            return tuple(valtype(x) for x in default)
        return StrTo.tuple(valtype, attr)

    def list(self, key, valtype, default=None, sep=","):
        """Return the value under ``key`` as a list of ``valtype`` elements.

        A list value is converted element-wise; anything else is passed to ``StrTo.list``
        together with ``sep``.
        """
        attr = self.str(key, default)
        if isinstance(attr, list):
            return [valtype(x) for x in attr]
        return StrTo.list(attr, valtype, sep)

    def val(self, key, valtype, default=None):
        """Return the value under ``key`` converted with ``valtype``.

        The string 'None' is treated as None. None, values already of type ``valtype``, and
        any value when ``valtype`` is None are returned unchanged.
        """
        attr = self.str(key, default)
        if attr == 'None':
            attr = None
        if valtype is None or attr is None or isinstance(attr, valtype):
            return attr
        return valtype(attr)

    def has(self, key):
        """Return True if the dictionary is valid and contains ``key``."""
        return self.is_valid() and key in self._dict
|
|
|
|
|
|
def get_mxnet_node_edges(node: dict, node_id: [int, str], nodes_list: list, index_node_key: dict):
    """Build the incoming edges of one node from its ``inputs`` list.

    :param node: node description; each entry of ``node['inputs']`` is read as
        entry[0] = index of the producing node in ``nodes_list`` and
        entry[1] = output index of that producer.
    :param node_id: index of ``node``, translated through ``index_node_key``.
    :param nodes_list: all node descriptions; the producer's 'name' is used as framework name.
    :param index_node_key: maps node indices to the ids used as edge endpoints.
    :return: (list of edges built by create_mxnet_edge, set of producer indices referenced).
    """
    edges = []
    producer_indices = set()
    for port_index, input_ref in enumerate(node['inputs']):
        producer_index = input_ref[0]
        edges.append(create_mxnet_edge(
            index_node_key[producer_index],
            index_node_key[node_id],
            port_index,
            input_ref[1],
            nodes_list[producer_index]['name'],
        ))
        producer_indices.add(producer_index)
    return edges, producer_indices
|
|
|
|
|
|
def create_mxnet_edge(src_node_id: str, dst_node_id: str, src_port: int, dst_port: int, framework_name: str):
    """Return an edge description ``(src_node_id, dst_node_id, attrs)``.

    Beware of the parameter naming: ``src_port`` is stored as the edge's 'in' attribute and
    ``dst_port`` as its 'out' attribute. get_mxnet_node_edges passes the consumer's input
    index as ``src_port`` and the producer's output index as ``dst_port``.
    """
    attrs = {'in': src_port, 'out': dst_port}
    # debug anchor for framework name, out port and tensor name
    attrs['fw_tensor_debug_info'] = [(framework_name, dst_port, framework_name)]
    attrs['in_attrs'] = ['in']
    attrs['out_attrs'] = ['out']
    attrs['data_attrs'] = ['fw_tensor_debug_info']
    return src_node_id, dst_node_id, attrs
|
|
|
|
|
|
def get_mxnet_layer_attrs(json_dic: dict):
    """Wrap a layer's attribute dict in an AttrDictionary.

    The attributes are taken from the first key present among 'attr', 'attrs' and 'param'
    (in that priority order); a layer with none of them yields an empty AttrDictionary.
    """
    for attr_key in ('attr', 'attrs', 'param'):
        if attr_key in json_dic:
            return AttrDictionary(json_dic[attr_key])
    return AttrDictionary({})
|
|
|
|
|
|
def get_json_layer_attrs(json_dic):
    """Return the raw attribute dict of a layer description.

    Looks under 'attr', then 'attrs', and finally falls back to 'param'.

    :raises KeyError: if none of the three keys is present.
    """
    for attr_key in ('attr', 'attrs'):
        if attr_key in json_dic:
            return json_dic[attr_key]
    return json_dic['param']
|
|
|
|
|
|
def load_params(input_model, data_names=('data',)):
    """Load weights from a .params or .nd file into an ``mx.mod.Module``.

    The module is built around a single Variable named ``data_names[0]``; its private
    parameter fields are filled directly with the loaded arrays.

    * ``.params``: keys prefixed with ``aux:`` become auxiliary parameters; keys prefixed with
      ``arg:`` and keys without a ':' prefix become argument parameters. The text after the
      first ':' (up to the next ':') is used as the parameter name.
    * ``.nd``: every array becomes an auxiliary parameter if the file name contains 'auxs', or
      an argument parameter if it contains 'args'. Otherwise the arrays are ignored.

    :param input_model: path to the weights file; its format is the text after the last '.'.
    :param data_names: only the first name is used (as both data and label name).
    :return: ``mx.mod.Module`` with _arg_params, _aux_params, _param_names and _aux_names set.
    :raises Error: if the file format is neither 'params' nor 'nd'.
    """
    import os

    file_format = input_model.split('.')[-1]
    # Validate the format before touching the file, so unsupported inputs get the Model
    # Optimizer error rather than whatever mx.nd.load raises on them.
    if file_format not in ('params', 'nd'):
        raise Error(
            'Unsupported Input model file type {}. Model Optimizer support only .params and .nd files format. ' +
            refer_to_faq_msg(85), file_format)

    arg_params = {}
    aux_params = {}
    arg_keys = []
    aux_keys = []
    loaded_weight = mx.nd.load(input_model)

    if file_format == 'params':
        for key in loaded_weight:
            keys = key.split(':')
            if len(keys) > 1 and keys[0] == 'aux':
                aux_keys.append(keys[1])
                aux_params[keys[1]] = loaded_weight[key]
            elif len(keys) > 1 and keys[0] == 'arg':
                arg_keys.append(keys[1])
                arg_params[keys[1]] = loaded_weight[key]
            else:
                arg_keys.append(key)
                arg_params[key] = loaded_weight[key]
    else:
        # Classify by the file name only: a directory component that happens to contain
        # 'auxs' or 'args' must not decide how the arrays are interpreted.
        file_name = os.path.basename(input_model)
        if 'auxs' in file_name:
            target_keys, target_params = aux_keys, aux_params
        elif 'args' in file_name:
            target_keys, target_params = arg_keys, arg_params
        else:
            # NOTE(review): as before, arrays from an .nd file with neither marker are dropped.
            target_keys, target_params = [], {}
        for key in loaded_weight:
            target_keys.append(key)
            target_params[key] = loaded_weight[key]

    data = mx.sym.Variable(data_names[0])
    model_params = mx.mod.Module(data, data_names=(data_names[0],), label_names=(data_names[0],))
    model_params._arg_params = arg_params
    model_params._aux_params = aux_params
    model_params._param_names = arg_keys
    model_params._aux_names = aux_keys
    return model_params
|
|
|
|
|
|
def init_rnn_states(model_nodes):
    """Collect declared shapes of the extra inputs of RNN nodes.

    For every node whose 'op' is 'RNN', each input from the third onward (``inputs[2:]``) is
    inspected. If the referenced producer node has a non-empty '__shape__' attribute, that
    shape (tuple of ints) is recorded under the producer's 'name'.

    :param model_nodes: list of node descriptions; 'inputs' entries reference producers by
        index (entry[0]).
    :return: dict mapping producer node name -> shape tuple.
    """
    states = {}
    for node in model_nodes:
        if node['op'] != 'RNN':
            continue
        # presumably these are the RNN's initial-state inputs — TODO confirm
        for input_ref in node['inputs'][2:]:
            producer = model_nodes[input_ref[0]]
            shape = get_mxnet_layer_attrs(producer).tuple('__shape__', int, None)
            if shape:
                states[producer['name']] = shape
    return states
|
|
|
|
|
|
def scalar_ops_replacer(graph: Graph, node: Node, elementwise_op_type=Elementwise):
    """Re-express ``node`` as a binary ``elementwise_op_type`` node with a Const second operand.

    A Const holding ``node.scalar`` and an ``elementwise_op_type`` node named
    ``<node.id>/lin_`` are created. The connection into input 0 of ``node`` is redirected to
    input 0 of the new node, and the Const feeds its input 1. The connection from output 0 of
    ``node`` is re-sourced from output 0 of the new node.

    :return: the newly created elementwise node.
    """
    base_name = node.id
    const_node = Const(graph, dict(value=node.scalar,
                                   symbol_dict={'name': base_name + '/const'})).create_node()
    lin_name = base_name + '/lin_'
    lin_node = elementwise_op_type(graph, dict(name=lin_name, symbol_dict={'name': lin_name})).create_node()

    # Rewire in this order: data input, constant operand, then downstream consumers.
    node.in_port(0).get_connection().set_destination(lin_node.in_port(0))
    lin_node.in_port(1).get_connection().set_source(const_node.out_port(0))
    node.out_port(0).get_connection().set_source(lin_node.out_port(0))
    return lin_node
|
|
|
|
|
|
# Mapping from MXNet dtype names to numpy scalar types. Names missing here cannot be
# translated: mxnet_str_dtype_to_np indexes this dict directly and raises KeyError.
MXNET_DATA_TYPES = {
    'float16': np.float16,
    'float32': np.float32,
    'float64': np.float64,
    'uint8': np.uint8,
    'int8': np.int8,
    'int32': np.int32,
    'int64': np.int64,
}
|
|
|
|
|
|
def mxnet_str_dtype_to_np(dtype: str):
    """Translate an MXNet dtype name into the matching numpy type.

    :raises KeyError: if ``dtype`` is not a key of MXNET_DATA_TYPES.
    """
    np_type = MXNET_DATA_TYPES[dtype]
    return np_type
|