BitwiseElementwise ops init in MO (#20386)
This commit is contained in:
parent
6519afd4d3
commit
29475c738e
@ -246,3 +246,27 @@ class Sqrt(UnaryElementwise):
|
|||||||
if np.issubdtype(a.dtype, np.signedinteger):
|
if np.issubdtype(a.dtype, np.signedinteger):
|
||||||
return float32_array(a.astype(np.float32) ** 0.5)
|
return float32_array(a.astype(np.float32) ** 0.5)
|
||||||
return a ** 0.5
|
return a ** 0.5
|
||||||
|
|
||||||
|
|
||||||
|
class BitwiseAnd(Elementwise):
|
||||||
|
op = 'BitwiseAnd'
|
||||||
|
op_type = 'BitwiseAnd'
|
||||||
|
version = 'opset13'
|
||||||
|
|
||||||
|
|
||||||
|
class BitwiseOr(Elementwise):
|
||||||
|
op = 'BitwiseOr'
|
||||||
|
op_type = 'BitwiseOr'
|
||||||
|
version = 'opset13'
|
||||||
|
|
||||||
|
|
||||||
|
class BitwiseXor(Elementwise):
|
||||||
|
op = 'BitwiseXor'
|
||||||
|
op_type = 'BitwiseXor'
|
||||||
|
version = 'opset13'
|
||||||
|
|
||||||
|
|
||||||
|
class BitwiseNot(UnaryElementwise):
|
||||||
|
op = 'BitwiseNot'
|
||||||
|
op_type = 'BitwiseNot'
|
||||||
|
version = 'opset13'
|
||||||
|
@ -6,6 +6,7 @@ import tempfile
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import openvino.runtime.opset13 as opset13
|
||||||
import openvino.runtime.opset12 as opset12
|
import openvino.runtime.opset12 as opset12
|
||||||
import openvino.runtime.opset11 as opset11
|
import openvino.runtime.opset11 as opset11
|
||||||
import openvino.runtime.opset10 as opset10
|
import openvino.runtime.opset10 as opset10
|
||||||
@ -245,3 +246,49 @@ class TestOps(unittest.TestCase):
|
|||||||
self.assertEqual(gn_node["version"], "opset12")
|
self.assertEqual(gn_node["version"], "opset12")
|
||||||
self.assertEqual(gn_node['num_groups'], 1)
|
self.assertEqual(gn_node['num_groups'], 1)
|
||||||
self.assertEqual(gn_node['epsilon'], 1e-06)
|
self.assertEqual(gn_node['epsilon'], 1e-06)
|
||||||
|
|
||||||
|
def test_bitwise_and_13(self):
|
||||||
|
a = opset13.parameter([4, 1], name="A", dtype=np.int32)
|
||||||
|
b = opset13.parameter([1, 2], name="B", dtype=np.int32)
|
||||||
|
|
||||||
|
op = opset13.bitwise_and(a, b)
|
||||||
|
model = Model(op, [a, b])
|
||||||
|
graph = TestOps.check_graph_can_save(model, "bitwise_and_model")
|
||||||
|
op_node = graph.get_op_nodes(op="BitwiseAnd")[0]
|
||||||
|
self.assertListEqual(op_node.out_port(0).data.get_shape().tolist(), [4, 2])
|
||||||
|
self.assertEqual(op_node["version"], "opset13")
|
||||||
|
self.assertEqual(op_node["auto_broadcast"], "numpy")
|
||||||
|
|
||||||
|
def test_bitwise_or_13(self):
|
||||||
|
a = opset13.parameter([4, 1], name="A", dtype=np.int32)
|
||||||
|
b = opset13.parameter([1, 2], name="B", dtype=np.int32)
|
||||||
|
|
||||||
|
op = opset13.bitwise_or(a, b)
|
||||||
|
model = Model(op, [a, b])
|
||||||
|
graph = TestOps.check_graph_can_save(model, "bitwise_or_model")
|
||||||
|
op_node = graph.get_op_nodes(op="BitwiseOr")[0]
|
||||||
|
self.assertListEqual(op_node.out_port(0).data.get_shape().tolist(), [4, 2])
|
||||||
|
self.assertEqual(op_node["version"], "opset13")
|
||||||
|
self.assertEqual(op_node["auto_broadcast"], "numpy")
|
||||||
|
|
||||||
|
def test_bitwise_xor_13(self):
|
||||||
|
a = opset13.parameter([4, 1], name="A", dtype=np.int32)
|
||||||
|
b = opset13.parameter([1, 2], name="B", dtype=np.int32)
|
||||||
|
|
||||||
|
op = opset13.bitwise_xor(a, b)
|
||||||
|
model = Model(op, [a, b])
|
||||||
|
graph = TestOps.check_graph_can_save(model, "bitwise_xor_model")
|
||||||
|
op_node = graph.get_op_nodes(op="BitwiseXor")[0]
|
||||||
|
self.assertListEqual(op_node.out_port(0).data.get_shape().tolist(), [4, 2])
|
||||||
|
self.assertEqual(op_node["version"], "opset13")
|
||||||
|
self.assertEqual(op_node["auto_broadcast"], "numpy")
|
||||||
|
|
||||||
|
def test_bitwise_not_13(self):
|
||||||
|
a = opset13.parameter([4, 2], name="A", dtype=np.int32)
|
||||||
|
|
||||||
|
op = opset13.bitwise_not(a)
|
||||||
|
model = Model(op, [a])
|
||||||
|
graph = TestOps.check_graph_can_save(model, "bitwise_not_model")
|
||||||
|
op_node = graph.get_op_nodes(op="BitwiseNot")[0]
|
||||||
|
self.assertListEqual(op_node.out_port(0).data.get_shape().tolist(), [4, 2])
|
||||||
|
self.assertEqual(op_node["version"], "opset13")
|
||||||
|
Loading…
Reference in New Issue
Block a user