Added workaround for logical elementwise operations to change the constant input data type if it does not match the other one (#1955)

This commit is contained in:
Evgeny Lazarev 2020-08-27 08:46:34 +03:00 committed by GitHub
parent d82a16abd8
commit 0e9ead3495
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -19,12 +19,31 @@ import logging as log
import numpy as np
from mo.front.common.partial_infer.eltwise import eltwise_infer, bias_add_infer
from mo.graph.graph import Graph
from mo.graph.graph import Graph, Node
from mo.middle.passes.infer import copy_type_infer
from mo.ops.op import Op
from mo.pipeline.common import convert_const_node_value_type
def override_data_type_of_constant(node: Node):
in_type_0 = node.in_port(0).get_data_type()
in_type_1 = node.in_port(1).get_data_type()
if in_type_0 != in_type_1:
# in case of input values data type mismatch we try to change the type of the constant to match the type of
# another input. The input values data type mismatch occur when the MO performs replacement of some
# operations like SquaredDifference of inputs with floating point data type to Power layer with the integer
# power value 2, or when replacing Neg operation with Mul with -1 as second input.
in_node_0 = node.in_port(0).get_source().node
in_node_1 = node.in_port(1).get_source().node
if in_node_0.op == 'Const':
convert_const_node_value_type(in_node_0, in_type_1)
elif in_node_1.op == 'Const':
convert_const_node_value_type(in_node_1, in_type_0)
else:
log.error('Elementwise operation {} has inputs of different data types: {} and {}'.format(
node.soft_get('name'), in_type_0, in_type_1))
class Elementwise(Op):
enabled = False
operation = None
@ -48,23 +67,7 @@ class Elementwise(Op):
@staticmethod
def type_infer(node):
in_type_0 = node.in_port(0).get_data_type()
in_type_1 = node.in_port(1).get_data_type()
if in_type_0 != in_type_1:
# in case of input values data type mismatch we try to change the type of the constant to match the type of
# another input. The input values data type mismatch occur when the MO performs replacement of some
# operations like SquaredDifference of inputs with floating point data type to Power layer with the integer
# power value 2, or when replacing Neg operation with Mul with -1 as second input.
in_node_0 = node.in_port(0).get_source().node
in_node_1 = node.in_port(1).get_source().node
if in_node_0.op == 'Const':
convert_const_node_value_type(in_node_0, in_type_1)
elif in_node_1.op == 'Const':
convert_const_node_value_type(in_node_1, in_type_0)
else:
log.error('Elementwise operation {} has inputs of different data types: {} and {}'.format(
node.soft_get('name'), in_type_0, in_type_1))
override_data_type_of_constant(node)
node.out_port(0).set_data_type(node.in_port(0).get_data_type())
@ -139,6 +142,7 @@ class Pow(Elementwise):
class LogicalElementwise(Elementwise):
@staticmethod
def type_infer(node):
override_data_type_of_constant(node)
node.out_port(0).set_data_type(np.bool)