Files
openvino/model-optimizer/extensions/ops/embedding_bag.py
Maxim Vafin f1811ad060 Implement support for opset3 EmbeddingBag ops (#546)
* [MO] Implement EmbeddingBag_3

* Transform dynamic sub-graph of Wide and Deep into EmbeddingSegmentsSum

- Expressed SparseWeightedSum sub-graph through EmbeddingSegmentsSum
- Removed experimental SparseWeightedSum layer
- Implemented tests for the transformation

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Fix EmbeddingBag shape infer

* Fix EmbeddingSegmentsSum transformation for Wide and Deep

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Fix EmbeddingSegmentsSum replacer after ports swap

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Update package_BOM.txt

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Add unit tests for EmbeddingXXX shape infer

* Fix ATen resolver

* Remove deleted files from BOM

* Add opset version to embedding_bag

* Use base class for EmbeddingBag

* Fix per_sample_weights case

* Fix EmbeddingSegmentsSum transformation

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Fix EmbeddingBag checks

* Fix ATen front transformation and merge conflicts

* Fix BOM

* Work around limitation for I64 input of W&D model

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Clean up Where operation to fix the effect of the WhereDecomposition transform

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Fix BOM

* Correct EmbeddingSegmentsSum transform for Wide and Deep

Add casting segment ids to i32 and remove ConstToResult sub-graph.

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Update BOM with RemoveConstToResult transform

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Add more comments for RemoveConstToResult transformation

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Remove useless logging in EmbeddingSegmentsSum transformation

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Small fixes

* Move EmbeddingBag resolving back to front phase

* Improve error messages

* Fix typo in unittests

* Reimplement sparse_reshape middle transform

Avoid deprecated API.

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Clean-up graph after sparse_reshape and ConstToResult transformation

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Fix clean-up for transformations

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Fix clean-up for transformation #2

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com>
2020-06-08 18:06:40 +03:00

119 lines
4.8 KiB
Python

"""
Copyright (c) 2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
from mo.graph.graph import Node, Graph
from mo.ops.op import Op
class EmbeddingBagBase(Op):
    """Shared base for the EmbeddingBag-family operation classes.

    Concrete operations override the class attributes ``op``/``op_type``,
    ``version`` and ``in_ports_count`` and supply their own ``infer``; the
    constructor forwards those values to ``Op`` as node attributes.
    """
    # NOTE(review): presumably keeps this abstract base from being picked up
    # as a usable operation -- confirm against Op's handling of `enabled`.
    enabled = False
    op = None
    op_type = None
    version = None
    in_ports_count = None

    def __init__(self, graph: Graph, attrs: dict):
        """Build the operation from the class-level fields.

        :param graph: graph passed through to ``Op.__init__``
        :param attrs: caller-supplied attributes, forwarded to ``Op.__init__``
                      as its third argument
        """
        # Attributes derived from the (sub)class definition; every operation
        # in this family exposes exactly one output port.
        mandatory_props = {
            'op': self.op,
            'type': self.op_type,
            'version': self.version,
            'infer': self.infer,
            'in_ports_count': self.in_ports_count,
            'out_ports_count': 1,
        }
        super().__init__(graph, mandatory_props, attrs)

    @staticmethod
    def infer(node: Node):
        """Placeholder shape inference; always raises.

        :param node: unused
        :raises NotImplementedError: always, because only the specialized
                                     subclasses implement shape inference
        """
        message = 'Please use specialized EmbeddingBag operation class, EmbeddingBagBase is base class'
        raise NotImplementedError(message)
class EmbeddingBagOffsetsSum(EmbeddingBagBase):
    """opset3 ``EmbeddingBagOffsetsSum`` operation.

    Declares 5 input ports. Shape inference requires ports 0, 1 and 2 to be
    connected and reads the shapes found on port 0 (weights) and port 2
    (offsets).
    """
    op = op_type = 'EmbeddingBagOffsetsSum'
    version = 'opset3'
    in_ports_count = 5

    @staticmethod
    def infer(node: Node):
        """Set the output shape to ``[offsets_shape[0]] + weights_shape[1:]``.

        :param node: node whose output port 0 receives the inferred shape
        :raises AssertionError: if ports 0-2 are not all connected, if the
            weights shape is unknown or has fewer than 2 dimensions, or if the
            offsets shape is unknown or is not of rank 1
        """
        name = node.soft_get('name', node.id)

        connected_in_ports = {idx: port for idx, port in node.in_ports().items() if not port.disconnected()}
        assert len(connected_in_ports) >= 3 and all(p in connected_in_ports for p in [0, 1, 2]), \
            "EmbeddingBagOffsetsSum should have at least 3 connected input port, but it doesn't " \
            "for node: `{}`. Ports: {}".format(name, connected_in_ports)

        weights_shape = node.in_port(0).data.get_shape()
        # Guard against a not-yet-inferred shape so the failure is reported
        # through the assertion message instead of a TypeError from len(None).
        assert weights_shape is not None and len(weights_shape) >= 2, \
            "EmbeddingBagOffsetsSum should have at least 2D weights for node: `{}`".format(name)

        offsets_shape = node.in_port(2).data.get_shape()
        assert offsets_shape is not None and len(offsets_shape) == 1, \
            "Rank of the offsets in EmbeddingBagOffsetsSum should be equal to 1 for node: `{}`".format(name)

        # One output row per offset; trailing dimensions come from the weights.
        node.out_port(0).data.set_shape(np.concatenate((offsets_shape[:1], weights_shape[1:])))
class EmbeddingBagPackedSum(EmbeddingBagBase):
    """opset3 ``EmbeddingBagPackedSum`` operation.

    Declares 3 input ports. Shape inference requires ports 0 and 1 to be
    connected and reads the shapes found on port 0 (weights) and port 1
    (the packed indices input).
    """
    op = op_type = 'EmbeddingBagPackedSum'
    version = 'opset3'
    in_ports_count = 3

    @staticmethod
    def infer(node: Node):
        """Set the output shape to ``[input_shape[0]] + weights_shape[1:]``.

        :param node: node whose output port 0 receives the inferred shape
        :raises AssertionError: if ports 0-1 are not both connected, if the
            weights shape is unknown or has fewer than 2 dimensions, or if the
            shape on port 1 is unknown or has no leading dimension
        """
        name = node.soft_get('name', node.id)

        connected_in_ports = {idx: port for idx, port in node.in_ports().items() if not port.disconnected()}
        assert len(connected_in_ports) >= 2 and all(p in connected_in_ports for p in [0, 1]), \
            "EmbeddingBagPackedSum should have at least 2 connected input port, but it doesn't for node: `{}`. " \
            "Ports: {}".format(name, connected_in_ports)

        weights_shape = node.in_port(0).data.get_shape()
        # Guard against a not-yet-inferred shape so the failure is reported
        # through the assertion message instead of a TypeError from len(None).
        assert weights_shape is not None and len(weights_shape) >= 2, \
            "EmbeddingBagPackedSum should have at least 2D weights for node: `{}`".format(name)

        input_shape = node.in_port(1).data.get_shape()
        # The first dimension of this input becomes the first output dimension,
        # so it must exist; a rank-0 shape would silently drop it.
        assert input_shape is not None and len(input_shape) >= 1, \
            "EmbeddingBagPackedSum should have indices with at least 1 dimension for node: `{}`".format(name)

        node.out_port(0).data.set_shape(np.concatenate((input_shape[:1], weights_shape[1:])))
class EmbeddingSegmentsSum(EmbeddingBagBase):
    """opset3 ``EmbeddingSegmentsSum`` operation.

    Declares 6 input ports. Shape inference requires ports 0-3 to be
    connected; it reads the shapes on port 0 (weights), port 1 (indices) and
    port 2 (segment ids) and the constant value on port 3 (num_segments).
    """
    op = op_type = 'EmbeddingSegmentsSum'
    version = 'opset3'
    in_ports_count = 6

    @staticmethod
    def infer(node: Node):
        """Set the output shape to ``[num_segments] + weights_shape[1:]``.

        :param node: node whose output port 0 receives the inferred shape
        :raises AssertionError: if ports 0-3 are not all connected, if the
            weights shape is unknown or has fewer than 2 dimensions, if the
            indices and segment ids shapes are unknown, not of rank 1 or
            differ, or if num_segments is not a single constant value
        """
        name = node.soft_get('name', node.id)

        connected_in_ports = {idx: port for idx, port in node.in_ports().items() if not port.disconnected()}
        assert len(connected_in_ports) >= 4 and all(p in connected_in_ports for p in [0, 1, 2, 3]), \
            "EmbeddingSegmentsSum should have at least 4 connected input port, but it doesn't for node: `{}`. " \
            "Ports: {}".format(name, connected_in_ports)

        weights_shape = node.in_port(0).data.get_shape()
        # Guard against a not-yet-inferred shape so the failure is reported
        # through the assertion message instead of a TypeError from len(None).
        assert weights_shape is not None and len(weights_shape) >= 2, \
            "EmbeddingSegmentsSum should have at least 2D weights for node: `{}`".format(name)

        indices_shape = node.in_port(1).data.get_shape()
        segment_ids = node.in_port(2).data.get_shape()
        # np.array_equal yields a plain bool, so the assertion does not depend
        # on the truthiness of an element-wise array comparison.
        assert indices_shape is not None and segment_ids is not None and \
            len(indices_shape) == 1 and len(segment_ids) == 1 and np.array_equal(indices_shape, segment_ids), \
            "Both indices and segment_ids should have the same shape for node: `{}`".format(name)

        num_segments = node.in_port(3).data.get_value()
        assert num_segments is not None, "EmbeddingSegmentsSum should have a constant num_segments provided, but it " \
                                         "doesn't for node: `{}`.".format(name)

        # Accept the constant both as a 0-D scalar and as a 1-element array:
        # flatten it to 1-D so it can be concatenated with the weights dims.
        num_segments = np.reshape(num_segments, -1)
        assert len(num_segments) == 1, \
            "EmbeddingSegmentsSum should have a single-value num_segments for node: `{}`".format(name)

        output_shape = np.concatenate((num_segments, weights_shape[1:]))
        node.out_port(0).data.set_shape(output_shape)