Fix Scatter value infer for fully dynamic value (#17165)
* Fix issue with dynamic Scatter in MO IR Reader * Only normalize for 1D tensors * Add test
This commit is contained in:
@@ -1081,6 +1081,7 @@ openvino/tools/mo/utils/ir_reader/extenders/strided_slice_extender.py
|
||||
openvino/tools/mo/utils/ir_reader/extenders/tensoriterator_extender.py
|
||||
openvino/tools/mo/utils/ir_reader/extenders/topk_extender.py
|
||||
openvino/tools/mo/utils/ir_reader/extenders/variadic_split_extender.py
|
||||
openvino/tools/mo/utils/ir_reader/internal_ops/scatter.py
|
||||
openvino/tools/mo/utils/ir_reader/internal_ops/squeeze.py
|
||||
openvino/tools/mo/utils/ir_reader/internal_ops/unique.py
|
||||
openvino/tools/mo/utils/ir_reader/internal_ops/unsqueeze.py
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
|
||||
from openvino.tools.mo.graph.graph import Node
|
||||
from openvino.tools.mo.ops.scatter import ScatterUpdate, Scatter
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import dynamic_dimension_value
|
||||
|
||||
class ScatterUpdateInternal(ScatterUpdate):
|
||||
@staticmethod
|
||||
def infer(node: Node):
|
||||
updates_value = node.in_port(2).data.get_value()
|
||||
if updates_value is not None and isinstance(updates_value, np.ma.masked_array) and updates_value.ndim == 1:
|
||||
# we need to normalize masked_array so that the value infer works as expected
|
||||
value = [item if item is not np.ma.masked else dynamic_dimension_value for item in updates_value]
|
||||
updates_value = np.ma.masked_equal(value, dynamic_dimension_value).astype(dtype=updates_value.dtype)
|
||||
node.in_port(2).data.set_value(updates_value)
|
||||
ScatterUpdate.infer(node)
|
||||
@@ -35,6 +35,7 @@ from openvino.tools.mo.utils.ir_reader.extender import Extender
|
||||
from openvino.tools.mo.utils.ir_reader.internal_ops.squeeze import SqueezeInternal
|
||||
from openvino.tools.mo.utils.ir_reader.internal_ops.unsqueeze import UnsqueezeInternal
|
||||
from openvino.tools.mo.utils.ir_reader.internal_ops.unique import UniqueInternal
|
||||
from openvino.tools.mo.utils.ir_reader.internal_ops.scatter import ScatterUpdateInternal
|
||||
|
||||
# Operations not registered in collect_ops() function
|
||||
custom_ops = {
|
||||
@@ -51,6 +52,7 @@ custom_ops = {
|
||||
'MaxPool': Pooling,
|
||||
'Multiply': Mul,
|
||||
'Power': Pow,
|
||||
'ScatterUpdate': ScatterUpdateInternal,
|
||||
'Slice': OvSlice,
|
||||
'Split': Split,
|
||||
'Squeeze': SqueezeInternal,
|
||||
|
||||
@@ -8,7 +8,7 @@ from pathlib import Path
|
||||
|
||||
import openvino.runtime.opset11 as opset11
|
||||
import openvino.runtime.opset10 as opset10
|
||||
from openvino.runtime import Model, serialize, Core
|
||||
from openvino.runtime import Model, serialize, Core, PartialShape, Dimension
|
||||
|
||||
from openvino.tools.mo.utils.ir_reader.restore_graph import restore_graph_from_ir, save_restored_graph
|
||||
from openvino.tools.mo.utils.logger import init_logger
|
||||
@@ -193,3 +193,17 @@ class TestOps(unittest.TestCase):
|
||||
graph = TestOps.check_graph_can_save(model, 'strided_slice_model')
|
||||
strided_slice_node = graph.get_op_nodes(op="StridedSlice")[0]
|
||||
self.assertEqual(strided_slice_node["version"], "opset1")
|
||||
|
||||
def test_scatter_dynamic_shape(self):
|
||||
data_parameter = opset11.parameter(
|
||||
PartialShape.dynamic(Dimension(2)), name="Data", dtype=np.float32)
|
||||
shape_of = opset11.shape_of(data_parameter)
|
||||
gather = opset11.gather(shape_of, np.int32(1), 0)
|
||||
unsqueeze = opset11.unsqueeze(gather, 0)
|
||||
scatter = opset11.scatter_update(np.int64([0, 0]), np.int64([1]), unsqueeze, axis=0)
|
||||
mul = opset11.multiply(scatter, np.int64([1, 2]))
|
||||
reshape = opset11.reshape(data_parameter, mul, True)
|
||||
model = Model(reshape, [data_parameter])
|
||||
graph = TestOps.check_graph_can_save(model, 'scatter_dynamic_model')
|
||||
scatter_update_node = graph.get_op_nodes(op="ScatterUpdate")[0]
|
||||
self.assertListEqual(scatter_update_node.out_port(0).data.get_value().tolist(), [0, None])
|
||||
|
||||
Reference in New Issue
Block a user