[POT] Transformations to remove and add Convert operations in graph (#8672)

* Added special passes to remove and add Convert operations in POT

* Update passes

* Implement changes to support new FP16 models

* Apply codestyle patch

* Revert Cast inserting and add data parameter

* Update FastBC rule

Co-authored-by: Malinin, Nikita <nikita.malinin@intel.com>
This commit is contained in:
Anton Chetverikov 2021-11-25 12:59:39 +03:00 committed by GitHub
parent e25c10075c
commit 163bc458db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 55 additions and 5 deletions

View File

@ -2,11 +2,12 @@
# SPDX-License-Identifier: Apache-2.0
from copy import deepcopy
import numpy as np
from mo.graph.graph import Graph
from openvino.tools.pot.graph.node_utils import get_node_inputs
from .editor import create_node, connect_nodes_by_name
from openvino.tools.pot.graph.node_utils import get_node_input, get_node_inputs
from .editor import create_node, connect_nodes_by_name, get_node_by_name
def build_graph(graph_attrs, meta_data, nodes, edges):
@ -139,4 +140,12 @@ def build_graph_for_node(model, input_name, input_shape, node, remove_bias=False
edges.append((node.name, result_name, {'out': 0, 'in': 0}))
graph = build_graph(*make_copy_graph_attrs(model, input_name, input_shape), nodes, edges)
graph.ir_v10 = True
# Add the necessary attribute to the new graph
src_node = get_node_by_name(graph, node.name)
weights_node = get_node_input(src_node, 1)
weights_node = get_node_input(weights_node, 0) \
if weights_node.type == 'FakeQuantize' else weights_node
if weights_node.out_port(0).get_data_type() == np.float16:
weights_node.out_node(0)['Insert_Convert_operation_after'] = True
return graph

View File

@ -10,7 +10,7 @@ from mo.utils.logger import init_logger
from openvino.inference_engine import IECore # pylint: disable=E0611
from openvino.offline_transformations import ApplyPOTTransformations # pylint: disable=import-error,no-name-in-module
from ..graph.passes import ModelPreprocessor
from ..graph.passes import ModelPreprocessor, remove_converts, add_removed_converts
from ..utils.logger import stdout_redirect
init_logger('ERROR', False)
@ -47,6 +47,7 @@ def load_graph(model_config, target_device='ANY'):
graph_from_ir.meta_data = meta_data
graph_from_ir.ir_v10 = True
graph_from_ir.graph['cmd_params'] = orig_graph_from_ir.graph['cmd_params']
remove_converts(graph_from_ir)
model_preprocessing(graph_from_ir)
if os.path.exists(serialized_xml_path):
os.remove(serialized_xml_path)
@ -71,8 +72,9 @@ def save_graph(graph: Graph, save_path, model_name=None):
if not os.access(save_path, os.W_OK):
raise PermissionError(
'Output directory {} is not writable for the current user. '.format(save_path))
save_restored_graph(graph=deepcopy(graph), path=save_path, meta_data=graph.meta_data,
graph_copy = deepcopy(graph)
add_removed_converts(graph_copy)
save_restored_graph(graph=graph_copy, path=save_path, meta_data=graph.meta_data,
name=model_name)

View File

@ -13,6 +13,7 @@ import numpy as np
from extensions.back.ForceStrictPrecision import ForceStrictPrecision
from extensions.back.compress_quantized_weights import CompressQuantizeWeights
from extensions.ops.elementwise import Add
from extensions.ops.Cast import Cast
from extensions.ops.fakequantize import FakeQuantize
from mo.back.replacement import BackReplacementPattern
from mo.front.common.replacement import FrontReplacementSubgraph
@ -20,6 +21,7 @@ from mo.graph.graph import Graph, Node
from mo.graph.port import Port
from mo.middle.pattern_match import apply_pattern
from mo.ops.const import Const
from mo.middle.passes.convert_data_type import convert_blob
from mo.middle.passes.infer import type_infer
from . import editor as ge
@ -704,6 +706,7 @@ def create_bias_node(graph: Graph, src_node):
for destination_port in destination_ports:
add_op.out_port(0).connect(destination_port)
add_bias.out_node(0)['Insert_Convert_operation_after'] = True
def create_fake_quantize_node(graph: Graph, name):
@ -878,3 +881,39 @@ def find_shape_subgraph_endpoints(out_ports: List[Port], visited: set = None) ->
visited_nodes.add(in_port.node)
visited.add(in_port)
return visited_nodes
def remove_converts(graph: Graph):
    """Fold FP16 ``Const -> Convert`` pairs out of the graph.

    Each Convert whose input is an FP16 Const is marked for constant
    folding, and the Convert's output data node is tagged with
    ``Insert_Convert_operation_after`` so the Convert can be re-inserted
    later (before serialization).
    """
    for convert_node in graph.get_op_nodes(type='Convert'):
        producer = convert_node.in_port(0).get_source().node
        if producer.type != 'Const' or producer.data_type != np.float16:
            continue
        # Remember on the data node after the Convert that a Convert has to
        # be restored here later.
        convert_node.out_node(0)['Insert_Convert_operation_after'] = True
        # Mark the Const and the Convert so that clean_up() folds them.
        producer['need_shape_inference'] = True
        convert_node['stop_value_propagation'] = False
        convert_node['need_shape_inference'] = True
    graph.clean_up()
def add_removed_converts(graph: Graph):
    """Restore the Convert operations that ``remove_converts`` folded away.

    For every data node tagged with ``Insert_Convert_operation_after``,
    insert a Convert (Cast to FP32) after the producing Const and convert
    the Const value back to FP16, so the saved model again stores FP16
    weights followed by a Convert to FP32.
    """
    for data_node_name in graph.get_nodes_with_attributes(Insert_Convert_operation_after=True):
        data_node = Node(graph, data_node_name)
        # Get access to the Const node connected to the marked data node;
        # after folding it is expected to hold FP32 values.
        const_op = data_node.in_node(0)
        assert const_op.data_type == np.float32, "Error when try to insert Convert operation after Const: {}".\
            format(const_op.soft_get('name'))

        convert_op = Cast(graph, {'dst_type': np.float32,
                                  'name': const_op.name + '/restored_convert',
                                  'stop_value_propagation': True}).create_node()

        # Re-wire: Const -> Convert -> original consumer.
        consumer_port = const_op.out_port(0).get_connection().get_destination()
        const_op.out_port(0).get_connection().set_destination(convert_op.in_port(0))
        convert_op.out_port(0).connect(consumer_port)

        # Convert the Const value back to FP16 (NOT FP32 — the previous
        # comment was wrong: convert_blob targets np.float16 here) so the
        # Const/Convert pair matches the original FP16 model.
        const_op.value, _, _ = convert_blob(const_op.value, np.float16)
        const_op.infer(const_op)