[POT] Transformations to remove and add Convert operations in graph (#8672)
* Added special passes to remove and add Convert operations in POT * Update passes * Implement changes to support new FP16 models * Apply codestyle patch * Revert Cast inserting and add data parameter * Update FastBC rule Co-authored-by: Malinin, Nikita <nikita.malinin@intel.com>
This commit is contained in:
parent
e25c10075c
commit
163bc458db
@ -2,11 +2,12 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from copy import deepcopy
|
||||
import numpy as np
|
||||
|
||||
from mo.graph.graph import Graph
|
||||
|
||||
from openvino.tools.pot.graph.node_utils import get_node_inputs
|
||||
from .editor import create_node, connect_nodes_by_name
|
||||
from openvino.tools.pot.graph.node_utils import get_node_input, get_node_inputs
|
||||
from .editor import create_node, connect_nodes_by_name, get_node_by_name
|
||||
|
||||
|
||||
def build_graph(graph_attrs, meta_data, nodes, edges):
|
||||
@ -139,4 +140,12 @@ def build_graph_for_node(model, input_name, input_shape, node, remove_bias=False
|
||||
edges.append((node.name, result_name, {'out': 0, 'in': 0}))
|
||||
graph = build_graph(*make_copy_graph_attrs(model, input_name, input_shape), nodes, edges)
|
||||
graph.ir_v10 = True
|
||||
|
||||
# Add the necessary attribute to the new graph
|
||||
src_node = get_node_by_name(graph, node.name)
|
||||
weights_node = get_node_input(src_node, 1)
|
||||
weights_node = get_node_input(weights_node, 0) \
|
||||
if weights_node.type == 'FakeQuantize' else weights_node
|
||||
if weights_node.out_port(0).get_data_type() == np.float16:
|
||||
weights_node.out_node(0)['Insert_Convert_operation_after'] = True
|
||||
return graph
|
||||
|
@ -10,7 +10,7 @@ from mo.utils.logger import init_logger
|
||||
from openvino.inference_engine import IECore # pylint: disable=E0611
|
||||
from openvino.offline_transformations import ApplyPOTTransformations # pylint: disable=import-error,no-name-in-module
|
||||
|
||||
from ..graph.passes import ModelPreprocessor
|
||||
from ..graph.passes import ModelPreprocessor, remove_converts, add_removed_converts
|
||||
from ..utils.logger import stdout_redirect
|
||||
|
||||
init_logger('ERROR', False)
|
||||
@ -47,6 +47,7 @@ def load_graph(model_config, target_device='ANY'):
|
||||
graph_from_ir.meta_data = meta_data
|
||||
graph_from_ir.ir_v10 = True
|
||||
graph_from_ir.graph['cmd_params'] = orig_graph_from_ir.graph['cmd_params']
|
||||
remove_converts(graph_from_ir)
|
||||
model_preprocessing(graph_from_ir)
|
||||
if os.path.exists(serialized_xml_path):
|
||||
os.remove(serialized_xml_path)
|
||||
@ -71,8 +72,9 @@ def save_graph(graph: Graph, save_path, model_name=None):
|
||||
if not os.access(save_path, os.W_OK):
|
||||
raise PermissionError(
|
||||
'Output directory {} is not writable for the current user. '.format(save_path))
|
||||
|
||||
save_restored_graph(graph=deepcopy(graph), path=save_path, meta_data=graph.meta_data,
|
||||
graph_copy = deepcopy(graph)
|
||||
add_removed_converts(graph_copy)
|
||||
save_restored_graph(graph=graph_copy, path=save_path, meta_data=graph.meta_data,
|
||||
name=model_name)
|
||||
|
||||
|
||||
|
@ -13,6 +13,7 @@ import numpy as np
|
||||
from extensions.back.ForceStrictPrecision import ForceStrictPrecision
|
||||
from extensions.back.compress_quantized_weights import CompressQuantizeWeights
|
||||
from extensions.ops.elementwise import Add
|
||||
from extensions.ops.Cast import Cast
|
||||
from extensions.ops.fakequantize import FakeQuantize
|
||||
from mo.back.replacement import BackReplacementPattern
|
||||
from mo.front.common.replacement import FrontReplacementSubgraph
|
||||
@ -20,6 +21,7 @@ from mo.graph.graph import Graph, Node
|
||||
from mo.graph.port import Port
|
||||
from mo.middle.pattern_match import apply_pattern
|
||||
from mo.ops.const import Const
|
||||
from mo.middle.passes.convert_data_type import convert_blob
|
||||
from mo.middle.passes.infer import type_infer
|
||||
|
||||
from . import editor as ge
|
||||
@ -704,6 +706,7 @@ def create_bias_node(graph: Graph, src_node):
|
||||
|
||||
for destination_port in destination_ports:
|
||||
add_op.out_port(0).connect(destination_port)
|
||||
add_bias.out_node(0)['Insert_Convert_operation_after'] = True
|
||||
|
||||
|
||||
def create_fake_quantize_node(graph: Graph, name):
|
||||
@ -878,3 +881,39 @@ def find_shape_subgraph_endpoints(out_ports: List[Port], visited: set = None) ->
|
||||
visited_nodes.add(in_port.node)
|
||||
visited.add(in_port)
|
||||
return visited_nodes
|
||||
|
||||
|
||||
def remove_converts(graph: Graph):
    """Fold FP16 Const -> Convert pairs in the graph, remembering where they were.

    For every Convert node fed directly by an FP16 Const, the data node after the
    Convert is tagged with ``Insert_Convert_operation_after=True`` so that
    ``add_removed_converts`` can restore the Convert before the model is saved.
    The Const and Convert are then marked for re-inference/value propagation so
    that the subsequent ``graph.clean_up()`` constant-folds them away.

    :param graph: model graph to transform in place
    """
    for op in graph.get_op_nodes(type='Convert'):
        source_op = op.in_port(0).get_source().node
        if source_op.type == 'Const' and source_op.data_type == np.float16:
            # Get access to data node after Convert operation and set Insert_Convert_operation_after
            # to restore Convert operation later
            op.out_node(0)['Insert_Convert_operation_after'] = True
            # Mark Const and Convert operation to fold them
            source_op['need_shape_inference'] = True
            op['stop_value_propagation'] = False
            op['need_shape_inference'] = True
    # clean_up() performs the actual folding of the marked Const/Convert pairs
    graph.clean_up()
||||
|
||||
|
||||
def add_removed_converts(graph: Graph):
    """Restore the FP16 Const -> Convert(FP32) pairs folded by ``remove_converts``.

    Walks all data nodes tagged with ``Insert_Convert_operation_after=True``,
    re-inserts a Convert (Cast to FP32) right after the producing Const, and
    converts the Const value back to FP16 so the serialized model matches the
    original FP16 representation.

    :param graph: model graph to transform in place
    """
    for data_node_name in graph.get_nodes_with_attributes(Insert_Convert_operation_after=True):
        data_node = Node(graph, data_node_name)
        # Get access to Const node connected to data node
        const_op = data_node.in_node(0)
        # After folding the Const must hold FP32 data; anything else means the
        # marker attribute is attached to an unexpected node
        assert const_op.data_type == np.float32, "Error when try to insert Convert operation after Const: {}".\
            format(const_op.soft_get('name'))

        # stop_value_propagation keeps the restored Convert from being folded again
        convert_op = Cast(graph, {'dst_type': np.float32,
                                  'name': const_op.name + '/restored_convert',
                                  'stop_value_propagation': True}).create_node()

        # Insert Convert operation after Const operation
        # NOTE(review): only a single consumer port is rewired here — presumably
        # the folded Const always has exactly one destination; verify for
        # multi-consumer constants
        consumer_port = const_op.out_port(0).get_connection().get_destination()
        const_op.out_port(0).get_connection().set_destination(convert_op.in_port(0))
        convert_op.out_port(0).connect(consumer_port)

        # Convert Const value back to FP16; the inserted Convert casts it to FP32
        # at runtime, keeping types in the graph consistent
        const_op.value, _, _ = convert_blob(const_op.value, np.float16)
        const_op.infer(const_op)
||||
|
Loading…
Reference in New Issue
Block a user