type cast for Pow according to ONNX spec

This commit is contained in:
Pavel Esir 2021-06-01 13:52:25 +03:00
parent 9ef719ef8b
commit bf85ebf84b
5 changed files with 144 additions and 7 deletions

View File

@ -63,7 +63,7 @@ class MarkNodesWithShapeValues(BackReplacementPattern):
return shape_accepting_nodes
@staticmethod
def get_sources_for_nodes(nodes_with_shape_inputs: List[Node], condition: Callable) -> List[Node]:
def get_sources_for_nodes(nodes_with_shape_inputs: List[Node], filter_condition: Callable) -> List[Node]:
shape_accepting_ops = MarkNodesWithShapeValues.get_shape_accepting_ops()
sources = []
for node in nodes_with_shape_inputs:
@ -72,7 +72,7 @@ class MarkNodesWithShapeValues(BackReplacementPattern):
continue
source_node = node.in_port(port_idx).get_source().node
if not condition(source_node):
if not filter_condition(source_node):
continue
sources.append(source_node)

View File

@ -5,6 +5,7 @@ import logging as log
import numpy as np
from extensions.ops.Cast import Cast
from mo.front.common.partial_infer.eltwise import eltwise_infer, bias_add_infer
from mo.graph.graph import Graph, Node
from mo.middle.passes.infer import copy_type_infer
@ -37,6 +38,36 @@ def override_data_type_of_constant(node: Node):
convert_const_node_value_type(node_to_convert, dst_type)
def override_input_types_of_pow(pow_node: Node):
# according to ONNX specification Pow op can have different input types
# https://github.com/onnx/onnx/blob/master/docs/Operators.md#Pow (Examples/types)
base_type = pow_node.in_port(0).get_data_type()
exponent_type = pow_node.in_port(1).get_data_type()
if type(base_type) != np.dtype:
base_type = np.dtype(base_type)
if type(exponent_type) != np.dtype:
exponent_type = np.dtype(exponent_type)
prefix = pow_node.soft_get('name', pow_node.id) + '/Cast_{}_to_{}_type'
if base_type != exponent_type:
if base_type.itemsize >= exponent_type.itemsize:
cast_to_base_type = Cast(pow_node.graph, {'name': prefix.format('exponent', 'base'),
'dst_type': base_type}).create_node()
pow_node.in_port(1).get_connection().insert_node(cast_to_base_type)
Cast.type_infer(cast_to_base_type)
else:
cast_to_pow_type = Cast(pow_node.graph, {'name': prefix.format('base', 'exponent_type'),
'dst_type': exponent_type}).create_node()
pow_node.in_port(0).get_connection().insert_node(cast_to_pow_type)
Cast.type_infer(cast_to_pow_type)
cast_out_to_base_type = Cast(pow_node.graph, {'name': prefix.format('output', 'base'),
'dst_type': base_type}).create_node()
pow_node.out_port(0).get_connection().insert_node(cast_out_to_base_type)
Cast.type_infer(cast_out_to_base_type)
class Elementwise(Op):
enabled = False
operation = None
@ -135,6 +166,11 @@ class Pow(Elementwise):
return np.array(a.astype(np.float32) ** b, dtype=np.float32)
return a ** b
@staticmethod
def type_infer(node):
override_input_types_of_pow(node)
node.out_port(0).set_data_type(node.in_port(0).get_data_type())
class LogicalElementwise(Elementwise):
@staticmethod

View File

@ -769,12 +769,14 @@ class Graph(nx.MultiDiGraph):
list(
node.out_ports().keys())))
def dump_graph_for_graphviz(self, node_attrs: list = ['kind', 'op', 'shape', 'correct_data_layout', 'nchw_layout',
'internal_layer_id'],
def dump_graph_for_graphviz(self, additional_attrs: list = None,
edge_attrs: list = ['in', 'out'], nodes_to_dump: list = None,
save_to_svg=False, highlight_nodes: list = None):
from extensions.ops.tensor_iterator import _get_internal_output_node_id, _get_internal_input_node_id
node_attrs = ['kind', 'op', 'shape', 'correct_data_layout', 'nchw_layout', 'internal_layer_id']
if additional_attrs:
node_attrs.extend(additional_attrs)
fill_color = {'op': 'lightblue', 'data': 'whitesmoke', 'highlight': 'firebrick'}
fill_color_by_type = {'Const': 'lightpink', 'Parameter': 'yellowgreen', 'TensorIterator': 'lemonchiffon'}

View File

@ -333,7 +333,7 @@ class Port:
else:
self.get_connection().add_destination(port)
def _get_data_type(self):
def _get_data_type(self) -> np.dtype:
"""
Internal method which does not raise with error if the data type is not known.
Check value of the data node to determine input port data type as well as the respective value in the
@ -382,7 +382,7 @@ class Port:
# I64 precision for shapes but not all IE plugins support I64, so we should trust data type infer functions
return source_port_data_type if source_port_data_type is not None else value_data_type
def get_data_type(self):
def get_data_type(self) -> np.dtype:
data_type = self._get_data_type()
if data_type is None:
raise Error('The data type for {} port {} of node {} is not defined'.format(self.type, self.idx,

View File

@ -5,11 +5,13 @@ import unittest
import numpy as np
from extensions.ops.elementwise import Round, Elementwise
from extensions.ops.Cast import Cast
from extensions.ops.elementwise import Round, Elementwise, Pow
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Node
from mo.middle.passes.infer import type_infer
from mo.ops.const import Const
from mo.utils.ir_engine.compare_graphs import compare_graphs
from unit_tests.utils.graph import valued_const_with_data, result, regular_op_with_empty_data, connect, \
shaped_parameter, build_graph
@ -151,3 +153,100 @@ class TestElementwiseTypeAlignment(unittest.TestCase):
type_infer(graph)
add_node = Node(graph, 'add')
self.assertEquals(add_node.out_port(0).get_data_type(), np.float32)
class TestPowTypeAlignment(unittest.TestCase):
@staticmethod
def build_graph(edges, ref_edges, base_type=np.float32, exponent_type=np.float32):
input_shape = int64_array([1, 3, 255, 255])
nodes = {
**shaped_parameter('input_base', input_shape, {'data_type': base_type}),
**shaped_parameter('input_exponent', input_shape, {'data_type': exponent_type}),
**regular_op_with_empty_data('pow', {'op': 'Pow', 'type': 'Pow', 'type_infer': Pow.type_infer}),
**regular_op_with_empty_data('cast_input', {'op': 'Cast', 'type': 'Cast',
'type_infer': Cast.type_infer}),
**regular_op_with_empty_data('cast_output', {'op': 'Cast', 'type': 'Cast',
'type_infer': Cast.type_infer}),
**result('result'),
}
graph = build_graph(nodes, edges, nodes_with_edges_only=True)
graph_ref = build_graph(nodes, ref_edges, nodes_with_edges_only=True)
graph.stage = 'back'
return graph, graph_ref
def test_base_int32_exponent_float32(self):
edges = [
*connect('input_base', '0:pow'),
*connect('input_exponent', '1:pow'),
*connect('pow', 'result')
]
edges_ref = [
*connect('input_base', '0:pow'),
*connect('input_exponent', 'cast_input'),
*connect('cast_input', '1:pow'),
*connect('pow', 'result')
]
graph, graph_ref = self.build_graph(edges, edges_ref, base_type=np.int32, exponent_type=np.float32)
type_infer(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=False)
self.assertTrue(flag, 'Graphs after do not match to reference: {}'.format(resp))
def test_base_float32_exponent_int64(self):
edges = [
*connect('input_base', '0:pow'),
*connect('input_exponent', '1:pow'),
*connect('pow', 'result')
]
edges_ref = [
*connect('input_base', 'cast_input'),
*connect('cast_input', '0:pow'),
*connect('input_exponent', '1:pow'),
*connect('pow', 'cast_output'),
*connect('cast_output', 'result')
]
graph, graph_ref = self.build_graph(edges, edges_ref, base_type=np.float32, exponent_type=np.int64)
type_infer(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=False)
self.assertTrue(flag, 'Graphs after do not match to reference: {}'.format(resp))
def test_base_in64_exponent_float32(self):
edges = [
*connect('input_base', '0:pow'),
*connect('input_exponent', '1:pow'),
*connect('pow', 'result')
]
edges_ref = [
*connect('input_base', '0:pow'),
*connect('input_exponent', 'cast_input'),
*connect('cast_input', '1:pow'),
*connect('pow', 'result')
]
graph, graph_ref = self.build_graph(edges, edges_ref, base_type=np.int64, exponent_type=np.float32)
type_infer(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=False)
self.assertTrue(flag, 'Graphs after do not match to reference: {}'.format(resp))
def test_base_float32_exponent_float32(self):
edges = [
*connect('input_base', '0:pow'),
*connect('input_exponent', '1:pow'),
*connect('pow', 'result')
]
graph, graph_ref = self.build_graph(edges, edges, base_type=np.float32, exponent_type=np.float32)
type_infer(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=False)
self.assertTrue(flag, 'Graphs after do not match to reference: {}'.format(resp))