type cast for Pow according to ONNX spec
This commit is contained in:
parent
9ef719ef8b
commit
bf85ebf84b
@ -63,7 +63,7 @@ class MarkNodesWithShapeValues(BackReplacementPattern):
|
|||||||
return shape_accepting_nodes
|
return shape_accepting_nodes
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_sources_for_nodes(nodes_with_shape_inputs: List[Node], condition: Callable) -> List[Node]:
|
def get_sources_for_nodes(nodes_with_shape_inputs: List[Node], filter_condition: Callable) -> List[Node]:
|
||||||
shape_accepting_ops = MarkNodesWithShapeValues.get_shape_accepting_ops()
|
shape_accepting_ops = MarkNodesWithShapeValues.get_shape_accepting_ops()
|
||||||
sources = []
|
sources = []
|
||||||
for node in nodes_with_shape_inputs:
|
for node in nodes_with_shape_inputs:
|
||||||
@ -72,7 +72,7 @@ class MarkNodesWithShapeValues(BackReplacementPattern):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
source_node = node.in_port(port_idx).get_source().node
|
source_node = node.in_port(port_idx).get_source().node
|
||||||
if not condition(source_node):
|
if not filter_condition(source_node):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
sources.append(source_node)
|
sources.append(source_node)
|
||||||
|
@ -5,6 +5,7 @@ import logging as log
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from extensions.ops.Cast import Cast
|
||||||
from mo.front.common.partial_infer.eltwise import eltwise_infer, bias_add_infer
|
from mo.front.common.partial_infer.eltwise import eltwise_infer, bias_add_infer
|
||||||
from mo.graph.graph import Graph, Node
|
from mo.graph.graph import Graph, Node
|
||||||
from mo.middle.passes.infer import copy_type_infer
|
from mo.middle.passes.infer import copy_type_infer
|
||||||
@ -37,6 +38,36 @@ def override_data_type_of_constant(node: Node):
|
|||||||
convert_const_node_value_type(node_to_convert, dst_type)
|
convert_const_node_value_type(node_to_convert, dst_type)
|
||||||
|
|
||||||
|
|
||||||
|
def override_input_types_of_pow(pow_node: Node):
|
||||||
|
# according to ONNX specification Pow op can have different input types
|
||||||
|
# https://github.com/onnx/onnx/blob/master/docs/Operators.md#Pow (Examples/types)
|
||||||
|
|
||||||
|
base_type = pow_node.in_port(0).get_data_type()
|
||||||
|
exponent_type = pow_node.in_port(1).get_data_type()
|
||||||
|
if type(base_type) != np.dtype:
|
||||||
|
base_type = np.dtype(base_type)
|
||||||
|
if type(exponent_type) != np.dtype:
|
||||||
|
exponent_type = np.dtype(exponent_type)
|
||||||
|
|
||||||
|
prefix = pow_node.soft_get('name', pow_node.id) + '/Cast_{}_to_{}_type'
|
||||||
|
if base_type != exponent_type:
|
||||||
|
if base_type.itemsize >= exponent_type.itemsize:
|
||||||
|
cast_to_base_type = Cast(pow_node.graph, {'name': prefix.format('exponent', 'base'),
|
||||||
|
'dst_type': base_type}).create_node()
|
||||||
|
pow_node.in_port(1).get_connection().insert_node(cast_to_base_type)
|
||||||
|
Cast.type_infer(cast_to_base_type)
|
||||||
|
else:
|
||||||
|
cast_to_pow_type = Cast(pow_node.graph, {'name': prefix.format('base', 'exponent_type'),
|
||||||
|
'dst_type': exponent_type}).create_node()
|
||||||
|
pow_node.in_port(0).get_connection().insert_node(cast_to_pow_type)
|
||||||
|
Cast.type_infer(cast_to_pow_type)
|
||||||
|
|
||||||
|
cast_out_to_base_type = Cast(pow_node.graph, {'name': prefix.format('output', 'base'),
|
||||||
|
'dst_type': base_type}).create_node()
|
||||||
|
pow_node.out_port(0).get_connection().insert_node(cast_out_to_base_type)
|
||||||
|
Cast.type_infer(cast_out_to_base_type)
|
||||||
|
|
||||||
|
|
||||||
class Elementwise(Op):
|
class Elementwise(Op):
|
||||||
enabled = False
|
enabled = False
|
||||||
operation = None
|
operation = None
|
||||||
@ -135,6 +166,11 @@ class Pow(Elementwise):
|
|||||||
return np.array(a.astype(np.float32) ** b, dtype=np.float32)
|
return np.array(a.astype(np.float32) ** b, dtype=np.float32)
|
||||||
return a ** b
|
return a ** b
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def type_infer(node):
|
||||||
|
override_input_types_of_pow(node)
|
||||||
|
node.out_port(0).set_data_type(node.in_port(0).get_data_type())
|
||||||
|
|
||||||
|
|
||||||
class LogicalElementwise(Elementwise):
|
class LogicalElementwise(Elementwise):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -769,12 +769,14 @@ class Graph(nx.MultiDiGraph):
|
|||||||
list(
|
list(
|
||||||
node.out_ports().keys())))
|
node.out_ports().keys())))
|
||||||
|
|
||||||
def dump_graph_for_graphviz(self, node_attrs: list = ['kind', 'op', 'shape', 'correct_data_layout', 'nchw_layout',
|
def dump_graph_for_graphviz(self, additional_attrs: list = None,
|
||||||
'internal_layer_id'],
|
|
||||||
edge_attrs: list = ['in', 'out'], nodes_to_dump: list = None,
|
edge_attrs: list = ['in', 'out'], nodes_to_dump: list = None,
|
||||||
save_to_svg=False, highlight_nodes: list = None):
|
save_to_svg=False, highlight_nodes: list = None):
|
||||||
|
|
||||||
from extensions.ops.tensor_iterator import _get_internal_output_node_id, _get_internal_input_node_id
|
from extensions.ops.tensor_iterator import _get_internal_output_node_id, _get_internal_input_node_id
|
||||||
|
node_attrs = ['kind', 'op', 'shape', 'correct_data_layout', 'nchw_layout', 'internal_layer_id']
|
||||||
|
if additional_attrs:
|
||||||
|
node_attrs.extend(additional_attrs)
|
||||||
|
|
||||||
fill_color = {'op': 'lightblue', 'data': 'whitesmoke', 'highlight': 'firebrick'}
|
fill_color = {'op': 'lightblue', 'data': 'whitesmoke', 'highlight': 'firebrick'}
|
||||||
fill_color_by_type = {'Const': 'lightpink', 'Parameter': 'yellowgreen', 'TensorIterator': 'lemonchiffon'}
|
fill_color_by_type = {'Const': 'lightpink', 'Parameter': 'yellowgreen', 'TensorIterator': 'lemonchiffon'}
|
||||||
|
@ -333,7 +333,7 @@ class Port:
|
|||||||
else:
|
else:
|
||||||
self.get_connection().add_destination(port)
|
self.get_connection().add_destination(port)
|
||||||
|
|
||||||
def _get_data_type(self):
|
def _get_data_type(self) -> np.dtype:
|
||||||
"""
|
"""
|
||||||
Internal method which does not raise with error if the data type is not known.
|
Internal method which does not raise with error if the data type is not known.
|
||||||
Check value of the data node to determine input port data type as well as the respective value in the
|
Check value of the data node to determine input port data type as well as the respective value in the
|
||||||
@ -382,7 +382,7 @@ class Port:
|
|||||||
# I64 precision for shapes but not all IE plugins support I64, so we should trust data type infer functions
|
# I64 precision for shapes but not all IE plugins support I64, so we should trust data type infer functions
|
||||||
return source_port_data_type if source_port_data_type is not None else value_data_type
|
return source_port_data_type if source_port_data_type is not None else value_data_type
|
||||||
|
|
||||||
def get_data_type(self):
|
def get_data_type(self) -> np.dtype:
|
||||||
data_type = self._get_data_type()
|
data_type = self._get_data_type()
|
||||||
if data_type is None:
|
if data_type is None:
|
||||||
raise Error('The data type for {} port {} of node {} is not defined'.format(self.type, self.idx,
|
raise Error('The data type for {} port {} of node {} is not defined'.format(self.type, self.idx,
|
||||||
|
@ -5,11 +5,13 @@ import unittest
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from extensions.ops.elementwise import Round, Elementwise
|
from extensions.ops.Cast import Cast
|
||||||
|
from extensions.ops.elementwise import Round, Elementwise, Pow
|
||||||
from mo.front.common.partial_infer.utils import int64_array
|
from mo.front.common.partial_infer.utils import int64_array
|
||||||
from mo.graph.graph import Node
|
from mo.graph.graph import Node
|
||||||
from mo.middle.passes.infer import type_infer
|
from mo.middle.passes.infer import type_infer
|
||||||
from mo.ops.const import Const
|
from mo.ops.const import Const
|
||||||
|
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||||
from unit_tests.utils.graph import valued_const_with_data, result, regular_op_with_empty_data, connect, \
|
from unit_tests.utils.graph import valued_const_with_data, result, regular_op_with_empty_data, connect, \
|
||||||
shaped_parameter, build_graph
|
shaped_parameter, build_graph
|
||||||
|
|
||||||
@ -151,3 +153,100 @@ class TestElementwiseTypeAlignment(unittest.TestCase):
|
|||||||
type_infer(graph)
|
type_infer(graph)
|
||||||
add_node = Node(graph, 'add')
|
add_node = Node(graph, 'add')
|
||||||
self.assertEquals(add_node.out_port(0).get_data_type(), np.float32)
|
self.assertEquals(add_node.out_port(0).get_data_type(), np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPowTypeAlignment(unittest.TestCase):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def build_graph(edges, ref_edges, base_type=np.float32, exponent_type=np.float32):
|
||||||
|
input_shape = int64_array([1, 3, 255, 255])
|
||||||
|
|
||||||
|
nodes = {
|
||||||
|
**shaped_parameter('input_base', input_shape, {'data_type': base_type}),
|
||||||
|
**shaped_parameter('input_exponent', input_shape, {'data_type': exponent_type}),
|
||||||
|
**regular_op_with_empty_data('pow', {'op': 'Pow', 'type': 'Pow', 'type_infer': Pow.type_infer}),
|
||||||
|
**regular_op_with_empty_data('cast_input', {'op': 'Cast', 'type': 'Cast',
|
||||||
|
'type_infer': Cast.type_infer}),
|
||||||
|
**regular_op_with_empty_data('cast_output', {'op': 'Cast', 'type': 'Cast',
|
||||||
|
'type_infer': Cast.type_infer}),
|
||||||
|
**result('result'),
|
||||||
|
}
|
||||||
|
|
||||||
|
graph = build_graph(nodes, edges, nodes_with_edges_only=True)
|
||||||
|
graph_ref = build_graph(nodes, ref_edges, nodes_with_edges_only=True)
|
||||||
|
graph.stage = 'back'
|
||||||
|
return graph, graph_ref
|
||||||
|
|
||||||
|
def test_base_int32_exponent_float32(self):
|
||||||
|
edges = [
|
||||||
|
*connect('input_base', '0:pow'),
|
||||||
|
*connect('input_exponent', '1:pow'),
|
||||||
|
*connect('pow', 'result')
|
||||||
|
]
|
||||||
|
|
||||||
|
edges_ref = [
|
||||||
|
*connect('input_base', '0:pow'),
|
||||||
|
*connect('input_exponent', 'cast_input'),
|
||||||
|
*connect('cast_input', '1:pow'),
|
||||||
|
*connect('pow', 'result')
|
||||||
|
]
|
||||||
|
|
||||||
|
graph, graph_ref = self.build_graph(edges, edges_ref, base_type=np.int32, exponent_type=np.float32)
|
||||||
|
type_infer(graph)
|
||||||
|
|
||||||
|
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=False)
|
||||||
|
self.assertTrue(flag, 'Graphs after do not match to reference: {}'.format(resp))
|
||||||
|
|
||||||
|
def test_base_float32_exponent_int64(self):
|
||||||
|
edges = [
|
||||||
|
*connect('input_base', '0:pow'),
|
||||||
|
*connect('input_exponent', '1:pow'),
|
||||||
|
*connect('pow', 'result')
|
||||||
|
]
|
||||||
|
|
||||||
|
edges_ref = [
|
||||||
|
*connect('input_base', 'cast_input'),
|
||||||
|
*connect('cast_input', '0:pow'),
|
||||||
|
*connect('input_exponent', '1:pow'),
|
||||||
|
*connect('pow', 'cast_output'),
|
||||||
|
*connect('cast_output', 'result')
|
||||||
|
]
|
||||||
|
|
||||||
|
graph, graph_ref = self.build_graph(edges, edges_ref, base_type=np.float32, exponent_type=np.int64)
|
||||||
|
type_infer(graph)
|
||||||
|
|
||||||
|
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=False)
|
||||||
|
self.assertTrue(flag, 'Graphs after do not match to reference: {}'.format(resp))
|
||||||
|
|
||||||
|
def test_base_in64_exponent_float32(self):
|
||||||
|
edges = [
|
||||||
|
*connect('input_base', '0:pow'),
|
||||||
|
*connect('input_exponent', '1:pow'),
|
||||||
|
*connect('pow', 'result')
|
||||||
|
]
|
||||||
|
|
||||||
|
edges_ref = [
|
||||||
|
*connect('input_base', '0:pow'),
|
||||||
|
*connect('input_exponent', 'cast_input'),
|
||||||
|
*connect('cast_input', '1:pow'),
|
||||||
|
*connect('pow', 'result')
|
||||||
|
]
|
||||||
|
|
||||||
|
graph, graph_ref = self.build_graph(edges, edges_ref, base_type=np.int64, exponent_type=np.float32)
|
||||||
|
type_infer(graph)
|
||||||
|
|
||||||
|
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=False)
|
||||||
|
self.assertTrue(flag, 'Graphs after do not match to reference: {}'.format(resp))
|
||||||
|
|
||||||
|
def test_base_float32_exponent_float32(self):
|
||||||
|
edges = [
|
||||||
|
*connect('input_base', '0:pow'),
|
||||||
|
*connect('input_exponent', '1:pow'),
|
||||||
|
*connect('pow', 'result')
|
||||||
|
]
|
||||||
|
|
||||||
|
graph, graph_ref = self.build_graph(edges, edges, base_type=np.float32, exponent_type=np.float32)
|
||||||
|
type_infer(graph)
|
||||||
|
|
||||||
|
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=False)
|
||||||
|
self.assertTrue(flag, 'Graphs after do not match to reference: {}'.format(resp))
|
||||||
|
Loading…
Reference in New Issue
Block a user