Files
openvino/model-optimizer/extensions/back/CellNormalizer.py
Alexey Suhov 6478f1742a Align copyright notice in python scripts (CVS-51320) (#4974)
* Align copyright notice in python scripts (CVS-51320)
2021-03-26 17:54:28 +03:00

70 lines
2.8 KiB
Python

# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
from extensions.ops.split import VariadicSplit
from mo.back.replacement import BackReplacementPattern
from mo.front.common.partial_infer.utils import int64_array
from mo.front.tf.graph_utils import create_op_node_with_second_input, create_op_with_const_inputs
from mo.graph.graph import Graph
from mo.ops.const import Const
from mo.ops.reshape import Reshape
class CellNormalizer(BackReplacementPattern):
# This class splits WR input on W and R for LSTMCell, GRUCell, RNNCell
enabled = True
force_clean_up = True
def pattern(self):
return dict(
nodes=[
('cell', dict(type=lambda type: type in ['LSTMCell', 'GRUCell', 'RNNCell']))
],
edges=[]
)
def replace_pattern(self, graph: Graph, match: dict):
node = match['cell']
cell_name = node.soft_get('name', node.id)
cell_type = node.soft_get('type')
WR_input_id = node.soft_get('wr_input_id')
hidden_size_coef = node.soft_get('gates_count')
hidden_size = node.get_attrs()["hidden_size"]
# default values for RNNCell/GRUCell
additional_port_id = 4
if cell_type == "LSTMCell":
additional_port_id = 5
WR_shape = node.in_port(WR_input_id).data.get_shape()
assert WR_shape is not None, "Undefined 'WR' input shape for Cell node '{}'".format(cell_name)
num_elements_in_WR = np.prod(WR_shape)
input_size = (num_elements_in_WR / (hidden_size_coef * hidden_size)) - hidden_size
# Reshape
reshape = create_op_node_with_second_input(graph, Reshape,
int64_array([hidden_size_coef * hidden_size,
hidden_size + input_size]),
{'name': cell_name + '/Dims'})
# VariadicSplit
split = create_op_with_const_inputs(graph, VariadicSplit, {1: int64_array(1),
2: int64_array([input_size, hidden_size])},
{'out_ports_count': 2, 'name': cell_name + '/Split'},
reshape)
# Cell
node.in_port(WR_input_id).get_connection().set_destination(reshape.in_port(0))
node.add_input_port(additional_port_id, skip_if_exist=True)
assert node.in_port(additional_port_id).disconnected()
# (x, y, WR, B) -> (x, y, W, R, B(additional_port))
node.in_port(additional_port_id - 1).get_connection().set_destination(node.in_port(additional_port_id))
split.out_port(0).connect(node.in_port(additional_port_id - 2))
split.out_port(1).connect(node.in_port(additional_port_id - 1))