137 lines
5.5 KiB
Python
137 lines
5.5 KiB
Python
"""
|
|
Copyright (C) 2018-2020 Intel Corporation
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""
|
|
|
|
import json
|
|
import logging as log
|
|
import os
|
|
|
|
import mxnet as mx
|
|
import numpy as np
|
|
|
|
from mo.front.mxnet.extractor import common_mxnet_fields
|
|
from mo.front.mxnet.extractors.utils import get_mxnet_node_edges, load_params, init_rnn_states
|
|
from mo.front.mxnet.nd_to_params import build_params_file
|
|
from mo.graph.graph import Node, Graph
|
|
from mo.utils.error import Error
|
|
from mo.utils.utils import refer_to_faq_msg
|
|
|
|
|
|
def load_symbol_nodes(model_name, input_symbol: str = None, legacy_mxnet_model: bool = False):
    """
    Load the list of nodes from an MXNet symbol JSON file.

    :param model_name: model prefix; the file ``<model_name>-symbol.json`` is used when
        ``input_symbol`` is not given
    :param input_symbol: explicit path to the symbol JSON file; takes precedence over ``model_name``
    :param legacy_mxnet_model: when True the symbol is loaded with ``mx.symbol.load`` and
        re-serialized via ``tojson()`` instead of being read directly as JSON
    :return: the value stored under the ``'nodes'`` key of the symbol JSON
    :raises Error: if the JSON file does not exist (non-legacy path only)
    """
    if input_symbol:
        json_name = input_symbol
        if legacy_mxnet_model:
            log.warning('If you use --input_symbol with legacy MXNet models be sure that symbol and param names ' +
                        'have correct format supported by MXNet')
    else:
        json_name = '%s-symbol.json' % model_name

    if legacy_mxnet_model:
        log.warning('For legacy MXNet models Model Optimizer does not support conversion of old MXNet models ' +
                    '(trained with 1.0.0 version of MXNet and lower) with custom layers. ' +
                    refer_to_faq_msg(93))
        # Let MXNet itself parse the file and re-emit JSON; presumably this normalizes
        # symbol files written by older MXNet versions — TODO confirm.
        sym = mx.symbol.load(json_name)
        model_nodes = json.loads(sym.tojson())
    else:
        if not os.path.isfile(json_name):
            raise Error('Specified input json {} does not exist. ' +
                        refer_to_faq_msg(84), json_name)
        # Use a context manager so the file handle is closed deterministically.
        with open(json_name) as json_file:
            model_nodes = json.load(json_file)

    return model_nodes['nodes']
|
|
|
|
|
|
def parse_input_model(input_model):
    """
    Split an MXNet checkpoint path of the form ``<prefix>-<iteration>.<ext>``.

    :param input_model: path to the checkpoint file, e.g. ``dir/net-0010.params``
    :return: tuple ``(prefix, iteration)``; ``prefix`` keeps any directory part
    :raises ValueError: if the text after the last ``-`` of the file stem is not an integer
    """
    # Drop everything from the last '.' on (empty string when there is no '.').
    stem, _, _ = input_model.rpartition('.')
    # File name without its directory part.
    base_name = stem.rpartition(os.sep)[2]
    iteration = int(base_name.rpartition('-')[2])
    prefix = stem.rpartition('-')[0]
    return prefix, iteration
|
|
|
|
|
|
def _parse_model_or_raise(model_path):
    """Run ``parse_input_model`` on ``model_path`` and turn a format ``ValueError`` into an ``Error``."""
    try:
        return parse_input_model(model_path)
    except ValueError as err:
        raise Error(
            'Input model name {} is not in an expected format, cannot extract iteration number. ' +
            refer_to_faq_msg(48),
            model_path) from err


def load_symbol_def(input_model_name, input_symbol, input_names: str = '', nd_prefix_name: str = '',
                    pretrained_model_name: str = '', legacy_mxnet_model: bool = False):
    """
    Load the symbol nodes and parameters of an MXNet model.

    Two modes are supported:

    * neither ``nd_prefix_name`` nor ``pretrained_model_name`` given: ``input_model_name`` is
      parsed as ``<prefix>-<iteration>.<ext>`` and its parameters are loaded with ``load_params``;
    * ``nd_prefix_name``, ``pretrained_model_name`` and ``input_symbol`` all given: the iteration
      is taken from ``pretrained_model_name``, the model name is the part of ``input_symbol``
      before its last ``-``, and parameters are produced by ``build_params_file``.

    :param input_model_name: path to the checkpoint (params) file
    :param input_symbol: optional path to the symbol JSON file
    :param input_names: comma-separated names of the network inputs
    :param nd_prefix_name: prefix of the additional ``.nd`` parameter files
    :param pretrained_model_name: path to the pretrained checkpoint used with ``nd_prefix_name``
    :param legacy_mxnet_model: forwarded to ``load_symbol_nodes``
    :return: tuple ``(model_nodes, model_params, model_name, iteration_number)``
    :raises Error: on a malformed model file name or an incomplete combination of
        ``nd_prefix_name`` / ``pretrained_model_name`` / ``input_symbol``
    """
    if not nd_prefix_name and not pretrained_model_name:
        # model file name is expected to look like '<prefix>-<iteration>.<ext>'
        model_name, iteration_number = _parse_model_or_raise(input_model_name)

        if input_names:
            model_params = load_params(input_model_name, data_names=input_names.split(','))
        else:
            model_params = load_params(input_model_name)

    elif nd_prefix_name and pretrained_model_name and input_symbol:
        # Only the iteration comes from the pretrained checkpoint; the model name is derived
        # from the symbol file name instead.
        _, iteration_number = _parse_model_or_raise(pretrained_model_name)
        model_name = '-'.join(input_symbol.split('-')[:-1])
        model_params = build_params_file(nd_prefix_name, pretrained_model_name, input_names)
    else:
        raise Error(
            "Arguments --nd_prefix_name, --pretrained_model_name and --input_symbol should be provided. Please provide all or do not use any. " +
            refer_to_faq_msg(81))

    model_nodes = load_symbol_nodes(model_name, input_symbol, legacy_mxnet_model)

    return model_nodes, model_params, model_name, iteration_number
|
|
|
|
|
|
def symbol_attrs(symbol_node):
    """Wrap a raw MXNet symbol node dict as graph-node attributes under the 'symbol_dict' key."""
    attrs = dict(symbol_dict=symbol_node)
    return attrs
|
|
|
|
|
|
def symbol2nx(graph, model_nodes, model_params, input_names: str = ''):
    """
    Add one graph node per MXNet symbol node to ``graph`` and connect them with edges.

    A node whose name appears in ``model_params._arg_params`` or ``model_params._aux_params``
    (and is not a network input) gets its parameter array stored under ``'value'`` as float32.
    A node named after an RNN state from ``init_rnn_states`` gets a zero array of the
    corresponding shape. Each symbol node dict is mutated in place when a value is attached.

    :param graph: graph to populate; returned after population
    :param model_nodes: list of MXNet symbol node dicts (the ``'nodes'`` entry of the symbol JSON)
    :param model_params: object exposing ``_arg_params`` / ``_aux_params`` mappings of arrays
        with ``asnumpy()``
    :param input_names: comma-separated input names; defaults to ``'data'`` when empty
    :return: ``graph``
    """
    # A set gives O(1) membership tests in the loop below.
    input_names = set(input_names.split(',')) if input_names else {'data'}

    # Mapping of RNN state name -> shape used to create its zero initial value.
    rnn_states = init_rnn_states(model_nodes)

    # MXNet refers to a node's inputs by their index in model_nodes, so remember which graph
    # node id each index was assigned in order to build the edges afterwards.
    index_node_keys = {}
    for i, node in enumerate(model_nodes):
        name = node['name']
        if name in model_params._arg_params and name not in input_names:
            node['value'] = np.array(model_params._arg_params[name].asnumpy(), dtype=np.float32)
        elif name in model_params._aux_params and name not in input_names:
            node['value'] = np.array(model_params._aux_params[name].asnumpy(), dtype=np.float32)
        elif name in rnn_states:
            node['value'] = np.zeros(rnn_states[name])
        node_name = graph.unique_id(name)
        graph.add_node(node_name, **symbol_attrs(node))
        graph.node[node_name].update(common_mxnet_fields(Node(graph, node_name)))
        index_node_keys[i] = node_name

    # Snapshot the node list once rather than copying it for every node (quadratic otherwise).
    nodes_list = list(model_nodes)
    for i, node in enumerate(model_nodes):
        edges = get_mxnet_node_edges(node, i, nodes_list, index_node_keys)
        if len(edges) > 0:
            graph.add_edges_from(edges)

    return graph
|
|
|
|
|
|
def find_output_node(graph: Graph, src_input_index):
    """
    Find the first consumer of the node at position ``src_input_index``.

    Only nodes positioned after ``src_input_index`` in ``graph.nodes(data=True)`` order are
    scanned. A node is a consumer when one of the entries in its
    ``symbol_dict['inputs']`` has ``src_input_index`` as its first element.

    :param graph: graph whose node attributes carry the MXNet ``symbol_dict``
    :param src_input_index: position of the producer node
    :return: id of the first consumer node, or None when there is none
    """
    all_nodes = list(graph.nodes(data=True))
    for node_id, node_attrs in all_nodes[src_input_index + 1:]:
        inputs = node_attrs['symbol_dict']['inputs']
        if any(entry[0] == src_input_index for entry in inputs):
            return node_id
    return None
|