Files
openvino/model-optimizer/mo/front/mxnet/extractors/utils.py
Anastasia Popova cadff031d5 Fixed framework name attribute in mapping file. (#5046)
* Fixed framework name attribute for onnx, mxnet.

* Fixed framework name attribute for caffe.

* Removed unnecessary attribute setting from add_opoutput()

* Added identity nodes to outputs in the mxnet loader.

* Removed unnecessary reformat.

* Removed unnecessary reformat.

* Added check for empty name.

* Used nodes indices instead of node names in loader.

* Code refactoring, small bug fix.
2021-04-08 20:59:44 +03:00

220 lines
7.0 KiB
Python

# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import mxnet as mx
import numpy as np
from extensions.ops.elementwise import Elementwise
from mo.graph.graph import Node, Graph
from mo.ops.const import Const
from mo.utils.error import Error
from mo.utils.str_to import StrTo
from mo.utils.utils import refer_to_faq_msg
class AttrDictionary(object):
    """Typed accessor over the attribute dictionary of an MXNet JSON layer.

    Wraps a (possibly ``None``) dict of raw attribute values -- strings read from the
    symbol JSON or already-parsed Python values -- and offers helpers that parse a value
    into a bool, number, tuple, list or numpy dtype.

    NOTE: the public method names (``dict``, ``str``, ``bool``, ``float``, ``int``,
    ``tuple``, ``list``) shadow the builtins inside this class body only. Method bodies
    still see the builtins, because class scope is not an enclosing scope for functions.
    Default-argument expressions, however, ARE evaluated in the class body.
    """

    def __init__(self, dict):
        # ``dict`` shadows the builtin here; the name is kept for keyword callers.
        self._dict = dict

    def is_valid(self):
        """Return True when a dictionary (anything other than None) is wrapped."""
        return self._dict is not None

    def dict(self):
        """Return the wrapped dictionary object itself (not a copy)."""
        return self._dict

    def add_dict(self, dict):
        """Merge ``dict`` into the wrapped dictionary in place."""
        self._dict.update(dict)

    def set(self, key, value):
        """Store ``value`` under ``key``."""
        self._dict[key] = value

    def remove(self, key):
        """Delete ``key`` if present; a missing key is ignored."""
        if key in self._dict:
            del self._dict[key]

    def str(self, key, default=None):
        """Return the raw value stored under ``key``, or ``default`` when it is absent.

        :raises ValueError: when no dictionary is wrapped and no default is given.
        """
        # is_valid must be *called*: the bound method object itself is always truthy,
        # which previously let a None dictionary reach ``key in None`` (TypeError).
        if not self.is_valid():
            if default is None:
                raise ValueError("Missing required parameter: " + key)
            return default
        if key in self._dict:
            return self._dict[key]
        return default

    def dtype(self, key, default=None):
        """Return the numpy type for the MXNet dtype name stored under ``key``, else ``default``."""
        if self.is_valid() and key in self._dict:
            return mxnet_str_dtype_to_np(self._dict[key])
        return default

    def bool(self, key, default=None):
        """Parse the value under ``key`` as a bool.

        Digit strings go through ``int`` (``'0'`` -> False, ``'1'`` -> True), other strings
        through ``StrTo.bool``; non-string values are returned unchanged.
        """
        attr = self.str(key, default)
        if isinstance(attr, str):
            if attr.isdigit():
                return bool(int(attr))
            return StrTo.bool(attr)
        return attr

    def float(self, key, default=None):
        """Return the value under ``key`` converted to float (see ``val``)."""
        return self.val(key, float, default)

    def int(self, key, default=None):
        """Return the value under ``key`` converted to int (see ``val``)."""
        return self.val(key, int, default)

    def tuple(self, key, valtype=None, default=None):
        """Parse the value under ``key`` as a tuple of ``valtype`` items.

        Accepted forms: a string without brackets (-> 1-tuple), a bracketed string such
        as ``"(1, 2)"`` or ``"[1, 2]"`` (parsed by ``StrTo.tuple``), or any other iterable
        value (items converted one by one). An empty bracket pair such as ``"()"`` yields
        the converted ``default``, or ``()`` when no default is given.

        :param valtype: item converter; ``None`` (the default) means the builtin ``str``.
        :return: a tuple, or ``default`` when the key is absent.
        """
        # The previous default ``valtype=str`` was evaluated in the class body, where
        # ``str`` already names the ``str`` method above, so every call that relied on
        # the default raised TypeError. Inside the method body ``str`` is the builtin.
        if valtype is None:
            valtype = str
        attr = self.str(key, default)
        if attr is None:
            return default
        if not isinstance(attr, str):
            return tuple([valtype(x) for x in attr])
        if not any(bracket in attr for bracket in '()[]'):
            return (valtype(attr),)
        if not attr[1:-1].split(',')[0]:
            # Empty bracket pair: fall back to the default instead of iterating None.
            if default is None:
                return ()
            return tuple([valtype(x) for x in default])
        return StrTo.tuple(valtype, attr)

    def list(self, key, valtype, default=None, sep=","):
        """Parse the value under ``key`` as a list of ``valtype`` items.

        A list value has each item converted; anything else is handed to ``StrTo.list``
        together with ``sep``.
        """
        attr = self.str(key, default)
        if isinstance(attr, list):
            return [valtype(x) for x in attr]
        return StrTo.list(attr, valtype, sep)

    def val(self, key, valtype, default=None):
        """Return the value under ``key`` converted with ``valtype``.

        The string ``'None'`` maps to None. None, values that already are ``valtype``
        instances, and every value when ``valtype`` is None are returned unchanged.
        """
        attr = self.str(key, default)
        attr = None if attr == 'None' else attr
        if valtype is None:
            return attr
        if not isinstance(attr, valtype) and attr is not None:
            return valtype(attr)
        return attr

    def has(self, key):
        """Return True when a dictionary is wrapped and it contains ``key``."""
        return self.is_valid() and key in self._dict
def get_mxnet_node_edges(node: dict, node_id: [int, str], nodes_list: list, index_node_key: dict):
    """Build the edges that feed ``node`` from the entries of its ``inputs`` list.

    Every entry of ``node['inputs']`` is indexed as ``[source node index, source port, ...]``.
    Its position in the list is used as the destination input port.

    :return: ``(edges, used_indices)``. ``edges`` is the list of triples returned by
        ``create_mxnet_edge``; ``used_indices`` is the set of source node indices that
        ``node`` reads from.
    """
    edges = []
    sources = set()
    for dst_port, src_entry in enumerate(node['inputs']):
        src_index = src_entry[0]
        # create_mxnet_edge stores its 3rd argument as the edge's 'in' port and its 4th
        # as 'out', so the destination port goes first, then the source port.
        edges.append(create_mxnet_edge(index_node_key[src_index], index_node_key[node_id], dst_port,
                                       src_entry[1], nodes_list[src_index]['name']))
        sources.add(src_index)
    return edges, sources
def create_mxnet_edge(src_node_id: str, dst_node_id: str, src_port: int, dst_port: int, framework_name: str):
    """Return a ``(src_node_id, dst_node_id, attrs)`` triple describing one graph edge.

    NOTE: despite the parameter names, ``src_port`` is stored as the edge's ``'in'``
    attribute and ``dst_port`` as its ``'out'`` attribute (also used in
    ``fw_tensor_debug_info``). ``get_mxnet_node_edges`` passes the destination input
    port as ``src_port`` and the source output port as ``dst_port``. The names are kept
    unchanged for keyword callers.
    """
    # Debug anchor: (framework name, out port, tensor name). The framework name is
    # also used as the tensor name.
    debug_info = [(framework_name, dst_port, framework_name)]
    edge_attrs = {}
    edge_attrs['in'] = src_port
    edge_attrs['out'] = dst_port
    edge_attrs['fw_tensor_debug_info'] = debug_info
    edge_attrs['in_attrs'] = ['in']
    edge_attrs['out_attrs'] = ['out']
    edge_attrs['data_attrs'] = ['fw_tensor_debug_info']
    return src_node_id, dst_node_id, edge_attrs
def _layer_attrs_key(json_dic: dict):
    """Return the key under which an MXNet JSON layer stores its attributes.

    ``'attr'`` wins over ``'attrs'``. When neither is present, ``'param'`` is returned,
    and that key may itself be absent from ``json_dic``.
    """
    for key in ('attr', 'attrs'):
        if key in json_dic:
            return key
    return 'param'


def get_mxnet_layer_attrs(json_dic: dict):
    """Wrap the attributes of an MXNet JSON layer in an ``AttrDictionary``.

    A layer that has none of ``'attr'``, ``'attrs'`` or ``'param'`` yields an
    ``AttrDictionary`` over a fresh empty dict.
    """
    key = _layer_attrs_key(json_dic)
    return AttrDictionary(json_dic[key] if key in json_dic else {})


def get_json_layer_attrs(json_dic: dict):
    """Return the raw attribute dict of an MXNet JSON layer.

    Uses the same key precedence as ``get_mxnet_layer_attrs``.

    :raises KeyError: when the layer has none of ``'attr'``, ``'attrs'`` or ``'param'``.
    """
    return json_dic[_layer_attrs_key(json_dic)]
def load_params(input_model, data_names=('data',)):
    """Load MXNet weights from ``input_model`` into a bare ``mx.mod.Module``.

    Supported files, chosen by the text after the last ``'.'`` of the path:
      * ``params``: keys prefixed ``'arg:'`` / ``'aux:'`` become argument / auxiliary
        parameters (prefix stripped); any other key is kept whole as an argument.
      * ``nd``: every entry becomes an auxiliary parameter if the path contains
        ``'auxs'``, else an argument parameter if it contains ``'args'``. Otherwise no
        entry is kept.

    :param input_model: path to the weights file.
    :param data_names: input names; only the first one is used.
    :return: ``mx.mod.Module`` whose private parameter fields are filled directly.
    :raises Error: for any other file extension.
    """
    file_format = input_model.split('.')[-1]
    # Reject unsupported formats before paying for mx.nd.load().
    if file_format not in ('params', 'nd'):
        raise Error(
            'Unsupported Input model file type {}. Model Optimizer support only .params and .nd files format. ' +
            refer_to_faq_msg(85), file_format)

    loaded_weight = mx.nd.load(input_model)
    arg_params = {}
    aux_params = {}
    arg_keys = []
    aux_keys = []
    if file_format == 'params':
        for key in loaded_weight:
            # Split on the first ':' only, so parameter names containing ':' stay intact.
            prefix, sep, name = key.partition(':')
            if sep and prefix == 'aux':
                aux_keys.append(name)
                aux_params[name] = loaded_weight[key]
            elif sep and prefix == 'arg':
                arg_keys.append(name)
                arg_params[name] = loaded_weight[key]
            else:
                arg_keys.append(key)
                arg_params[key] = loaded_weight[key]
    else:
        # .nd file: the destination depends only on the path, so decide it once.
        # NOTE(review): this is a substring test on the whole path, directories
        # included -- confirm that is intended.
        if 'auxs' in input_model:
            target_keys, target_params = aux_keys, aux_params
        elif 'args' in input_model:
            target_keys, target_params = arg_keys, arg_params
        else:
            target_keys, target_params = None, None
        if target_keys is not None:
            for key in loaded_weight:
                target_keys.append(key)
                target_params[key] = loaded_weight[key]

    data = mx.sym.Variable(data_names[0])
    # NOTE(review): label_names reuses the data name, as before -- confirm intended.
    model_params = mx.mod.Module(data, data_names=(data_names[0],), label_names=(data_names[0],))
    model_params._arg_params = arg_params
    model_params._aux_params = aux_params
    model_params._param_names = arg_keys
    model_params._aux_names = aux_keys
    return model_params
def init_rnn_states(model_nodes):
    """Collect the declared shapes of the extra inputs of every ``RNN`` node.

    For each node whose ``op`` is ``'RNN'``, every input after the first two is looked
    up in ``model_nodes``. If that input node has a non-empty ``__shape__`` attribute,
    its name is mapped to the shape parsed as a tuple of ints.

    :param model_nodes: node list of an MXNet symbol JSON.
    :return: dict ``{input node name: shape tuple}``.
    """
    states = {}
    for node in model_nodes:
        if node['op'] != 'RNN':
            continue
        # NOTE(review): inputs [0:2] are presumably data and packed parameters and the
        # rest initial states -- confirm against the MXNet RNN operator definition.
        for state_input in node['inputs'][2:]:
            state_node = model_nodes[state_input[0]]
            shape = get_mxnet_layer_attrs(state_node).tuple('__shape__', int, None)
            if shape:
                states[state_node['name']] = shape
    return states
def scalar_ops_replacer(graph: Graph, node: Node, elementwise_op_type=Elementwise):
    """Replace a scalar op ``node`` with an ``elementwise_op_type`` node.

    Creates a Const node holding ``node.scalar`` and a new elementwise node. The
    connection on ``node``'s input port 0 is redirected to the elementwise input 0, the
    Const feeds elementwise input 1, and the connection on ``node``'s output port 0
    gets the elementwise output 0 as its source.

    :return: the created elementwise node.
    """
    const_name = node.id + '/const'
    lin_name = node.id + '/lin_'
    scalar_const = Const(graph, {'value': node.scalar,
                                 'symbol_dict': {'name': const_name}}).create_node()
    lin_node = elementwise_op_type(graph, {'name': lin_name,
                                           'symbol_dict': {'name': lin_name}}).create_node()

    # The rewiring order below is the same as before and must stay that way.
    node.in_port(0).get_connection().set_destination(lin_node.in_port(0))
    lin_node.in_port(1).get_connection().set_source(scalar_const.out_port(0))
    node.out_port(0).get_connection().set_source(lin_node.out_port(0))
    return lin_node
# MXNet dtype name -> numpy scalar type.
MXNET_DATA_TYPES = {
    'float16': np.float16,
    'float32': np.float32,
    'float64': np.float64,
    'uint8': np.uint8,
    'int8': np.int8,
    'int32': np.int32,
    'int64': np.int64,
}


def mxnet_str_dtype_to_np(dtype: str):
    """Map an MXNet dtype name such as ``'float32'`` to the numpy scalar type.

    :raises KeyError: for a name missing from ``MXNET_DATA_TYPES``. The message lists
        the supported names.
    """
    try:
        return MXNET_DATA_TYPES[dtype]
    except KeyError:
        raise KeyError('Unsupported MXNet data type "{}". Supported types: {}'.format(
            dtype, ', '.join(sorted(MXNET_DATA_TYPES)))) from None