Added workaround for logical elementwise operations to change the constant input data type if it does not match the other one (#1955)
This commit is contained in:
parent
d82a16abd8
commit
0e9ead3495
@ -19,12 +19,31 @@ import logging as log
|
||||
import numpy as np
|
||||
|
||||
from mo.front.common.partial_infer.eltwise import eltwise_infer, bias_add_infer
|
||||
from mo.graph.graph import Graph
|
||||
from mo.graph.graph import Graph, Node
|
||||
from mo.middle.passes.infer import copy_type_infer
|
||||
from mo.ops.op import Op
|
||||
from mo.pipeline.common import convert_const_node_value_type
|
||||
|
||||
|
||||
def override_data_type_of_constant(node: Node):
|
||||
in_type_0 = node.in_port(0).get_data_type()
|
||||
in_type_1 = node.in_port(1).get_data_type()
|
||||
if in_type_0 != in_type_1:
|
||||
# in case of input values data type mismatch we try to change the type of the constant to match the type of
|
||||
# another input. The input values data type mismatch occur when the MO performs replacement of some
|
||||
# operations like SquaredDifference of inputs with floating point data type to Power layer with the integer
|
||||
# power value 2, or when replacing Neg operation with Mul with -1 as second input.
|
||||
in_node_0 = node.in_port(0).get_source().node
|
||||
in_node_1 = node.in_port(1).get_source().node
|
||||
if in_node_0.op == 'Const':
|
||||
convert_const_node_value_type(in_node_0, in_type_1)
|
||||
elif in_node_1.op == 'Const':
|
||||
convert_const_node_value_type(in_node_1, in_type_0)
|
||||
else:
|
||||
log.error('Elementwise operation {} has inputs of different data types: {} and {}'.format(
|
||||
node.soft_get('name'), in_type_0, in_type_1))
|
||||
|
||||
|
||||
class Elementwise(Op):
|
||||
enabled = False
|
||||
operation = None
|
||||
@ -48,23 +67,7 @@ class Elementwise(Op):
|
||||
|
||||
@staticmethod
|
||||
def type_infer(node):
|
||||
in_type_0 = node.in_port(0).get_data_type()
|
||||
in_type_1 = node.in_port(1).get_data_type()
|
||||
if in_type_0 != in_type_1:
|
||||
# in case of input values data type mismatch we try to change the type of the constant to match the type of
|
||||
# another input. The input values data type mismatch occur when the MO performs replacement of some
|
||||
# operations like SquaredDifference of inputs with floating point data type to Power layer with the integer
|
||||
# power value 2, or when replacing Neg operation with Mul with -1 as second input.
|
||||
in_node_0 = node.in_port(0).get_source().node
|
||||
in_node_1 = node.in_port(1).get_source().node
|
||||
if in_node_0.op == 'Const':
|
||||
convert_const_node_value_type(in_node_0, in_type_1)
|
||||
elif in_node_1.op == 'Const':
|
||||
convert_const_node_value_type(in_node_1, in_type_0)
|
||||
else:
|
||||
log.error('Elementwise operation {} has inputs of different data types: {} and {}'.format(
|
||||
node.soft_get('name'), in_type_0, in_type_1))
|
||||
|
||||
override_data_type_of_constant(node)
|
||||
node.out_port(0).set_data_type(node.in_port(0).get_data_type())
|
||||
|
||||
|
||||
@ -139,6 +142,7 @@ class Pow(Elementwise):
|
||||
class LogicalElementwise(Elementwise):
|
||||
@staticmethod
|
||||
def type_infer(node):
|
||||
override_data_type_of_constant(node)
|
||||
node.out_port(0).set_data_type(np.bool)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user