type cast for Pow according to ONNX spec
This commit is contained in:
parent
9ef719ef8b
commit
bf85ebf84b
@ -63,7 +63,7 @@ class MarkNodesWithShapeValues(BackReplacementPattern):
|
||||
return shape_accepting_nodes
|
||||
|
||||
@staticmethod
|
||||
def get_sources_for_nodes(nodes_with_shape_inputs: List[Node], condition: Callable) -> List[Node]:
|
||||
def get_sources_for_nodes(nodes_with_shape_inputs: List[Node], filter_condition: Callable) -> List[Node]:
|
||||
shape_accepting_ops = MarkNodesWithShapeValues.get_shape_accepting_ops()
|
||||
sources = []
|
||||
for node in nodes_with_shape_inputs:
|
||||
@ -72,7 +72,7 @@ class MarkNodesWithShapeValues(BackReplacementPattern):
|
||||
continue
|
||||
|
||||
source_node = node.in_port(port_idx).get_source().node
|
||||
if not condition(source_node):
|
||||
if not filter_condition(source_node):
|
||||
continue
|
||||
|
||||
sources.append(source_node)
|
||||
|
@ -5,6 +5,7 @@ import logging as log
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.ops.Cast import Cast
|
||||
from mo.front.common.partial_infer.eltwise import eltwise_infer, bias_add_infer
|
||||
from mo.graph.graph import Graph, Node
|
||||
from mo.middle.passes.infer import copy_type_infer
|
||||
@ -37,6 +38,36 @@ def override_data_type_of_constant(node: Node):
|
||||
convert_const_node_value_type(node_to_convert, dst_type)
|
||||
|
||||
|
||||
def override_input_types_of_pow(pow_node: Node):
|
||||
# according to ONNX specification Pow op can have different input types
|
||||
# https://github.com/onnx/onnx/blob/master/docs/Operators.md#Pow (Examples/types)
|
||||
|
||||
base_type = pow_node.in_port(0).get_data_type()
|
||||
exponent_type = pow_node.in_port(1).get_data_type()
|
||||
if type(base_type) != np.dtype:
|
||||
base_type = np.dtype(base_type)
|
||||
if type(exponent_type) != np.dtype:
|
||||
exponent_type = np.dtype(exponent_type)
|
||||
|
||||
prefix = pow_node.soft_get('name', pow_node.id) + '/Cast_{}_to_{}_type'
|
||||
if base_type != exponent_type:
|
||||
if base_type.itemsize >= exponent_type.itemsize:
|
||||
cast_to_base_type = Cast(pow_node.graph, {'name': prefix.format('exponent', 'base'),
|
||||
'dst_type': base_type}).create_node()
|
||||
pow_node.in_port(1).get_connection().insert_node(cast_to_base_type)
|
||||
Cast.type_infer(cast_to_base_type)
|
||||
else:
|
||||
cast_to_pow_type = Cast(pow_node.graph, {'name': prefix.format('base', 'exponent_type'),
|
||||
'dst_type': exponent_type}).create_node()
|
||||
pow_node.in_port(0).get_connection().insert_node(cast_to_pow_type)
|
||||
Cast.type_infer(cast_to_pow_type)
|
||||
|
||||
cast_out_to_base_type = Cast(pow_node.graph, {'name': prefix.format('output', 'base'),
|
||||
'dst_type': base_type}).create_node()
|
||||
pow_node.out_port(0).get_connection().insert_node(cast_out_to_base_type)
|
||||
Cast.type_infer(cast_out_to_base_type)
|
||||
|
||||
|
||||
class Elementwise(Op):
|
||||
enabled = False
|
||||
operation = None
|
||||
@ -135,6 +166,11 @@ class Pow(Elementwise):
|
||||
return np.array(a.astype(np.float32) ** b, dtype=np.float32)
|
||||
return a ** b
|
||||
|
||||
@staticmethod
|
||||
def type_infer(node):
|
||||
override_input_types_of_pow(node)
|
||||
node.out_port(0).set_data_type(node.in_port(0).get_data_type())
|
||||
|
||||
|
||||
class LogicalElementwise(Elementwise):
|
||||
@staticmethod
|
||||
|
@ -769,12 +769,14 @@ class Graph(nx.MultiDiGraph):
|
||||
list(
|
||||
node.out_ports().keys())))
|
||||
|
||||
def dump_graph_for_graphviz(self, node_attrs: list = ['kind', 'op', 'shape', 'correct_data_layout', 'nchw_layout',
|
||||
'internal_layer_id'],
|
||||
def dump_graph_for_graphviz(self, additional_attrs: list = None,
|
||||
edge_attrs: list = ['in', 'out'], nodes_to_dump: list = None,
|
||||
save_to_svg=False, highlight_nodes: list = None):
|
||||
|
||||
from extensions.ops.tensor_iterator import _get_internal_output_node_id, _get_internal_input_node_id
|
||||
node_attrs = ['kind', 'op', 'shape', 'correct_data_layout', 'nchw_layout', 'internal_layer_id']
|
||||
if additional_attrs:
|
||||
node_attrs.extend(additional_attrs)
|
||||
|
||||
fill_color = {'op': 'lightblue', 'data': 'whitesmoke', 'highlight': 'firebrick'}
|
||||
fill_color_by_type = {'Const': 'lightpink', 'Parameter': 'yellowgreen', 'TensorIterator': 'lemonchiffon'}
|
||||
|
@ -333,7 +333,7 @@ class Port:
|
||||
else:
|
||||
self.get_connection().add_destination(port)
|
||||
|
||||
def _get_data_type(self):
|
||||
def _get_data_type(self) -> np.dtype:
|
||||
"""
|
||||
Internal method which does not raise with error if the data type is not known.
|
||||
Check value of the data node to determine input port data type as well as the respective value in the
|
||||
@ -382,7 +382,7 @@ class Port:
|
||||
# I64 precision for shapes but not all IE plugins support I64, so we should trust data type infer functions
|
||||
return source_port_data_type if source_port_data_type is not None else value_data_type
|
||||
|
||||
def get_data_type(self):
|
||||
def get_data_type(self) -> np.dtype:
|
||||
data_type = self._get_data_type()
|
||||
if data_type is None:
|
||||
raise Error('The data type for {} port {} of node {} is not defined'.format(self.type, self.idx,
|
||||
|
@ -5,11 +5,13 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.ops.elementwise import Round, Elementwise
|
||||
from extensions.ops.Cast import Cast
|
||||
from extensions.ops.elementwise import Round, Elementwise, Pow
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.graph.graph import Node
|
||||
from mo.middle.passes.infer import type_infer
|
||||
from mo.ops.const import Const
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from unit_tests.utils.graph import valued_const_with_data, result, regular_op_with_empty_data, connect, \
|
||||
shaped_parameter, build_graph
|
||||
|
||||
@ -151,3 +153,100 @@ class TestElementwiseTypeAlignment(unittest.TestCase):
|
||||
type_infer(graph)
|
||||
add_node = Node(graph, 'add')
|
||||
self.assertEquals(add_node.out_port(0).get_data_type(), np.float32)
|
||||
|
||||
|
||||
class TestPowTypeAlignment(unittest.TestCase):
|
||||
|
||||
@staticmethod
|
||||
def build_graph(edges, ref_edges, base_type=np.float32, exponent_type=np.float32):
|
||||
input_shape = int64_array([1, 3, 255, 255])
|
||||
|
||||
nodes = {
|
||||
**shaped_parameter('input_base', input_shape, {'data_type': base_type}),
|
||||
**shaped_parameter('input_exponent', input_shape, {'data_type': exponent_type}),
|
||||
**regular_op_with_empty_data('pow', {'op': 'Pow', 'type': 'Pow', 'type_infer': Pow.type_infer}),
|
||||
**regular_op_with_empty_data('cast_input', {'op': 'Cast', 'type': 'Cast',
|
||||
'type_infer': Cast.type_infer}),
|
||||
**regular_op_with_empty_data('cast_output', {'op': 'Cast', 'type': 'Cast',
|
||||
'type_infer': Cast.type_infer}),
|
||||
**result('result'),
|
||||
}
|
||||
|
||||
graph = build_graph(nodes, edges, nodes_with_edges_only=True)
|
||||
graph_ref = build_graph(nodes, ref_edges, nodes_with_edges_only=True)
|
||||
graph.stage = 'back'
|
||||
return graph, graph_ref
|
||||
|
||||
def test_base_int32_exponent_float32(self):
|
||||
edges = [
|
||||
*connect('input_base', '0:pow'),
|
||||
*connect('input_exponent', '1:pow'),
|
||||
*connect('pow', 'result')
|
||||
]
|
||||
|
||||
edges_ref = [
|
||||
*connect('input_base', '0:pow'),
|
||||
*connect('input_exponent', 'cast_input'),
|
||||
*connect('cast_input', '1:pow'),
|
||||
*connect('pow', 'result')
|
||||
]
|
||||
|
||||
graph, graph_ref = self.build_graph(edges, edges_ref, base_type=np.int32, exponent_type=np.float32)
|
||||
type_infer(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=False)
|
||||
self.assertTrue(flag, 'Graphs after do not match to reference: {}'.format(resp))
|
||||
|
||||
def test_base_float32_exponent_int64(self):
|
||||
edges = [
|
||||
*connect('input_base', '0:pow'),
|
||||
*connect('input_exponent', '1:pow'),
|
||||
*connect('pow', 'result')
|
||||
]
|
||||
|
||||
edges_ref = [
|
||||
*connect('input_base', 'cast_input'),
|
||||
*connect('cast_input', '0:pow'),
|
||||
*connect('input_exponent', '1:pow'),
|
||||
*connect('pow', 'cast_output'),
|
||||
*connect('cast_output', 'result')
|
||||
]
|
||||
|
||||
graph, graph_ref = self.build_graph(edges, edges_ref, base_type=np.float32, exponent_type=np.int64)
|
||||
type_infer(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=False)
|
||||
self.assertTrue(flag, 'Graphs after do not match to reference: {}'.format(resp))
|
||||
|
||||
def test_base_in64_exponent_float32(self):
|
||||
edges = [
|
||||
*connect('input_base', '0:pow'),
|
||||
*connect('input_exponent', '1:pow'),
|
||||
*connect('pow', 'result')
|
||||
]
|
||||
|
||||
edges_ref = [
|
||||
*connect('input_base', '0:pow'),
|
||||
*connect('input_exponent', 'cast_input'),
|
||||
*connect('cast_input', '1:pow'),
|
||||
*connect('pow', 'result')
|
||||
]
|
||||
|
||||
graph, graph_ref = self.build_graph(edges, edges_ref, base_type=np.int64, exponent_type=np.float32)
|
||||
type_infer(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=False)
|
||||
self.assertTrue(flag, 'Graphs after do not match to reference: {}'.format(resp))
|
||||
|
||||
def test_base_float32_exponent_float32(self):
|
||||
edges = [
|
||||
*connect('input_base', '0:pow'),
|
||||
*connect('input_exponent', '1:pow'),
|
||||
*connect('pow', 'result')
|
||||
]
|
||||
|
||||
graph, graph_ref = self.build_graph(edges, edges, base_type=np.float32, exponent_type=np.float32)
|
||||
type_infer(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=False)
|
||||
self.assertTrue(flag, 'Graphs after do not match to reference: {}'.format(resp))
|
||||
|
Loading…
Reference in New Issue
Block a user