type cast for Pow according to ONNX spec

This commit is contained in:
Pavel Esir 2021-06-01 13:52:25 +03:00
parent 9ef719ef8b
commit bf85ebf84b
5 changed files with 144 additions and 7 deletions

View File

@ -63,7 +63,7 @@ class MarkNodesWithShapeValues(BackReplacementPattern):
return shape_accepting_nodes return shape_accepting_nodes
@staticmethod @staticmethod
def get_sources_for_nodes(nodes_with_shape_inputs: List[Node], condition: Callable) -> List[Node]: def get_sources_for_nodes(nodes_with_shape_inputs: List[Node], filter_condition: Callable) -> List[Node]:
shape_accepting_ops = MarkNodesWithShapeValues.get_shape_accepting_ops() shape_accepting_ops = MarkNodesWithShapeValues.get_shape_accepting_ops()
sources = [] sources = []
for node in nodes_with_shape_inputs: for node in nodes_with_shape_inputs:
@ -72,7 +72,7 @@ class MarkNodesWithShapeValues(BackReplacementPattern):
continue continue
source_node = node.in_port(port_idx).get_source().node source_node = node.in_port(port_idx).get_source().node
if not condition(source_node): if not filter_condition(source_node):
continue continue
sources.append(source_node) sources.append(source_node)

View File

@ -5,6 +5,7 @@ import logging as log
import numpy as np import numpy as np
from extensions.ops.Cast import Cast
from mo.front.common.partial_infer.eltwise import eltwise_infer, bias_add_infer from mo.front.common.partial_infer.eltwise import eltwise_infer, bias_add_infer
from mo.graph.graph import Graph, Node from mo.graph.graph import Graph, Node
from mo.middle.passes.infer import copy_type_infer from mo.middle.passes.infer import copy_type_infer
@ -37,6 +38,36 @@ def override_data_type_of_constant(node: Node):
convert_const_node_value_type(node_to_convert, dst_type) convert_const_node_value_type(node_to_convert, dst_type)
def override_input_types_of_pow(pow_node: Node):
# according to ONNX specification Pow op can have different input types
# https://github.com/onnx/onnx/blob/master/docs/Operators.md#Pow (Examples/types)
base_type = pow_node.in_port(0).get_data_type()
exponent_type = pow_node.in_port(1).get_data_type()
if type(base_type) != np.dtype:
base_type = np.dtype(base_type)
if type(exponent_type) != np.dtype:
exponent_type = np.dtype(exponent_type)
prefix = pow_node.soft_get('name', pow_node.id) + '/Cast_{}_to_{}_type'
if base_type != exponent_type:
if base_type.itemsize >= exponent_type.itemsize:
cast_to_base_type = Cast(pow_node.graph, {'name': prefix.format('exponent', 'base'),
'dst_type': base_type}).create_node()
pow_node.in_port(1).get_connection().insert_node(cast_to_base_type)
Cast.type_infer(cast_to_base_type)
else:
cast_to_pow_type = Cast(pow_node.graph, {'name': prefix.format('base', 'exponent_type'),
'dst_type': exponent_type}).create_node()
pow_node.in_port(0).get_connection().insert_node(cast_to_pow_type)
Cast.type_infer(cast_to_pow_type)
cast_out_to_base_type = Cast(pow_node.graph, {'name': prefix.format('output', 'base'),
'dst_type': base_type}).create_node()
pow_node.out_port(0).get_connection().insert_node(cast_out_to_base_type)
Cast.type_infer(cast_out_to_base_type)
class Elementwise(Op): class Elementwise(Op):
enabled = False enabled = False
operation = None operation = None
@ -135,6 +166,11 @@ class Pow(Elementwise):
return np.array(a.astype(np.float32) ** b, dtype=np.float32) return np.array(a.astype(np.float32) ** b, dtype=np.float32)
return a ** b return a ** b
@staticmethod
def type_infer(node):
override_input_types_of_pow(node)
node.out_port(0).set_data_type(node.in_port(0).get_data_type())
class LogicalElementwise(Elementwise): class LogicalElementwise(Elementwise):
@staticmethod @staticmethod

View File

@ -769,12 +769,14 @@ class Graph(nx.MultiDiGraph):
list( list(
node.out_ports().keys()))) node.out_ports().keys())))
def dump_graph_for_graphviz(self, node_attrs: list = ['kind', 'op', 'shape', 'correct_data_layout', 'nchw_layout', def dump_graph_for_graphviz(self, additional_attrs: list = None,
'internal_layer_id'],
edge_attrs: list = ['in', 'out'], nodes_to_dump: list = None, edge_attrs: list = ['in', 'out'], nodes_to_dump: list = None,
save_to_svg=False, highlight_nodes: list = None): save_to_svg=False, highlight_nodes: list = None):
from extensions.ops.tensor_iterator import _get_internal_output_node_id, _get_internal_input_node_id from extensions.ops.tensor_iterator import _get_internal_output_node_id, _get_internal_input_node_id
node_attrs = ['kind', 'op', 'shape', 'correct_data_layout', 'nchw_layout', 'internal_layer_id']
if additional_attrs:
node_attrs.extend(additional_attrs)
fill_color = {'op': 'lightblue', 'data': 'whitesmoke', 'highlight': 'firebrick'} fill_color = {'op': 'lightblue', 'data': 'whitesmoke', 'highlight': 'firebrick'}
fill_color_by_type = {'Const': 'lightpink', 'Parameter': 'yellowgreen', 'TensorIterator': 'lemonchiffon'} fill_color_by_type = {'Const': 'lightpink', 'Parameter': 'yellowgreen', 'TensorIterator': 'lemonchiffon'}

View File

@ -333,7 +333,7 @@ class Port:
else: else:
self.get_connection().add_destination(port) self.get_connection().add_destination(port)
def _get_data_type(self): def _get_data_type(self) -> np.dtype:
""" """
Internal method which does not raise with error if the data type is not known. Internal method which does not raise with error if the data type is not known.
Check value of the data node to determine input port data type as well as the respective value in the Check value of the data node to determine input port data type as well as the respective value in the
@ -382,7 +382,7 @@ class Port:
# I64 precision for shapes but not all IE plugins support I64, so we should trust data type infer functions # I64 precision for shapes but not all IE plugins support I64, so we should trust data type infer functions
return source_port_data_type if source_port_data_type is not None else value_data_type return source_port_data_type if source_port_data_type is not None else value_data_type
def get_data_type(self): def get_data_type(self) -> np.dtype:
data_type = self._get_data_type() data_type = self._get_data_type()
if data_type is None: if data_type is None:
raise Error('The data type for {} port {} of node {} is not defined'.format(self.type, self.idx, raise Error('The data type for {} port {} of node {} is not defined'.format(self.type, self.idx,

View File

@ -5,11 +5,13 @@ import unittest
import numpy as np import numpy as np
from extensions.ops.elementwise import Round, Elementwise from extensions.ops.Cast import Cast
from extensions.ops.elementwise import Round, Elementwise, Pow
from mo.front.common.partial_infer.utils import int64_array from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Node from mo.graph.graph import Node
from mo.middle.passes.infer import type_infer from mo.middle.passes.infer import type_infer
from mo.ops.const import Const from mo.ops.const import Const
from mo.utils.ir_engine.compare_graphs import compare_graphs
from unit_tests.utils.graph import valued_const_with_data, result, regular_op_with_empty_data, connect, \ from unit_tests.utils.graph import valued_const_with_data, result, regular_op_with_empty_data, connect, \
shaped_parameter, build_graph shaped_parameter, build_graph
@ -151,3 +153,100 @@ class TestElementwiseTypeAlignment(unittest.TestCase):
type_infer(graph) type_infer(graph)
add_node = Node(graph, 'add') add_node = Node(graph, 'add')
self.assertEquals(add_node.out_port(0).get_data_type(), np.float32) self.assertEquals(add_node.out_port(0).get_data_type(), np.float32)
class TestPowTypeAlignment(unittest.TestCase):
@staticmethod
def build_graph(edges, ref_edges, base_type=np.float32, exponent_type=np.float32):
input_shape = int64_array([1, 3, 255, 255])
nodes = {
**shaped_parameter('input_base', input_shape, {'data_type': base_type}),
**shaped_parameter('input_exponent', input_shape, {'data_type': exponent_type}),
**regular_op_with_empty_data('pow', {'op': 'Pow', 'type': 'Pow', 'type_infer': Pow.type_infer}),
**regular_op_with_empty_data('cast_input', {'op': 'Cast', 'type': 'Cast',
'type_infer': Cast.type_infer}),
**regular_op_with_empty_data('cast_output', {'op': 'Cast', 'type': 'Cast',
'type_infer': Cast.type_infer}),
**result('result'),
}
graph = build_graph(nodes, edges, nodes_with_edges_only=True)
graph_ref = build_graph(nodes, ref_edges, nodes_with_edges_only=True)
graph.stage = 'back'
return graph, graph_ref
def test_base_int32_exponent_float32(self):
edges = [
*connect('input_base', '0:pow'),
*connect('input_exponent', '1:pow'),
*connect('pow', 'result')
]
edges_ref = [
*connect('input_base', '0:pow'),
*connect('input_exponent', 'cast_input'),
*connect('cast_input', '1:pow'),
*connect('pow', 'result')
]
graph, graph_ref = self.build_graph(edges, edges_ref, base_type=np.int32, exponent_type=np.float32)
type_infer(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=False)
self.assertTrue(flag, 'Graphs after do not match to reference: {}'.format(resp))
def test_base_float32_exponent_int64(self):
edges = [
*connect('input_base', '0:pow'),
*connect('input_exponent', '1:pow'),
*connect('pow', 'result')
]
edges_ref = [
*connect('input_base', 'cast_input'),
*connect('cast_input', '0:pow'),
*connect('input_exponent', '1:pow'),
*connect('pow', 'cast_output'),
*connect('cast_output', 'result')
]
graph, graph_ref = self.build_graph(edges, edges_ref, base_type=np.float32, exponent_type=np.int64)
type_infer(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=False)
self.assertTrue(flag, 'Graphs after do not match to reference: {}'.format(resp))
def test_base_in64_exponent_float32(self):
edges = [
*connect('input_base', '0:pow'),
*connect('input_exponent', '1:pow'),
*connect('pow', 'result')
]
edges_ref = [
*connect('input_base', '0:pow'),
*connect('input_exponent', 'cast_input'),
*connect('cast_input', '1:pow'),
*connect('pow', 'result')
]
graph, graph_ref = self.build_graph(edges, edges_ref, base_type=np.int64, exponent_type=np.float32)
type_infer(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=False)
self.assertTrue(flag, 'Graphs after do not match to reference: {}'.format(resp))
def test_base_float32_exponent_float32(self):
edges = [
*connect('input_base', '0:pow'),
*connect('input_exponent', '1:pow'),
*connect('pow', 'result')
]
graph, graph_ref = self.build_graph(edges, edges, base_type=np.float32, exponent_type=np.float32)
type_infer(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=False)
self.assertTrue(flag, 'Graphs after do not match to reference: {}'.format(resp))