106 lines
4.3 KiB
Python
106 lines
4.3 KiB
Python
# Copyright (C) 2018-2021 Intel Corporation
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import numpy as np
|
|
|
|
from mo.graph.graph import Node, Graph
|
|
from mo.ops.op import Op
|
|
|
|
|
|
class EmbeddingBagBase(Op):
    """
    Common base class for the EmbeddingBag* / EmbeddingSegmentsSum operations.

    Concrete subclasses in this module override the class attributes ``op`` / ``op_type``,
    ``version`` and ``in_ports_count`` and supply their own static ``infer``. This base only
    packs those attributes into the mandatory attribute dict handed to ``Op.__init__``.
    """
    # NOTE(review): presumably keeps this abstract base out of operation registration — confirm
    # against mo.ops.op.Op.
    enabled = False

    # Placeholders; each concrete subclass below assigns real values.
    op = op_type = None
    version = None
    in_ports_count = None

    def __init__(self, graph: Graph, attrs: dict):
        """Pass the subclass-specific mandatory attributes, plus ``attrs``, to ``Op.__init__``."""
        mandatory_props = dict(
            op=self.op,
            type=self.op_type,
            version=self.version,
            infer=self.infer,
            in_ports_count=self.in_ports_count,
            out_ports_count=1,  # every operation in this family has a single output
        )
        super().__init__(graph, mandatory_props, attrs)

    @staticmethod
    def infer(node: Node):
        """Always raise: shape inference is only defined by the concrete subclasses."""
        raise NotImplementedError('Please use specialized EmbeddingBag operation class, EmbeddingBagBase is base class')
|
|
|
|
|
|
class EmbeddingBagOffsetsSum(EmbeddingBagBase):
    """
    opset3 EmbeddingBagOffsetsSum operation.

    Input ports 0, 1 and 2 must be connected; ports 3 and 4 may be left disconnected and are
    not read by shape inference. Port 0 carries the weights (embedding table, rank >= 2) and
    port 2 the 1D offsets, one entry per bag. Port 1 presumably carries the indices (per the
    opset3 spec — confirm).

    Output 0 shape: ``[offsets.shape[0]] + weights.shape[1:]``.
    """
    op = op_type = 'EmbeddingBagOffsetsSum'
    version = 'opset3'
    in_ports_count = 5

    @staticmethod
    def infer(node: Node):
        """
        Set the output shape from the weights and offsets input shapes.

        :param node: the EmbeddingBagOffsetsSum node to infer.
        :raises AssertionError: if any of ports 0-2 is disconnected, if the weights shape is
            unknown or of rank < 2, or if the offsets shape is unknown or not 1D.
        """
        name = node.soft_get('name', node.id)

        connected_in_ports = {idx: port for idx, port in node.in_ports().items() if not port.disconnected()}
        # Ports 0-2 being connected already implies at least three connected ports.
        assert all(port in connected_in_ports for port in (0, 1, 2)), \
            "EmbeddingBagOffsetsSum should have at least 3 connected input port, but it doesn't " \
            "for node: `{}`. Ports: {}".format(name, connected_in_ports)

        weights_shape = node.in_port(0).data.get_shape()
        # Check for None first: len(None) would raise a TypeError instead of this diagnostic.
        assert weights_shape is not None and len(weights_shape) >= 2, \
            "EmbeddingBagOffsetsSum should have at least 2D weights for node: `{}`".format(name)
        offsets_shape = node.in_port(2).data.get_shape()
        assert offsets_shape is not None and len(offsets_shape) == 1, \
            "Rank of the offsets in EmbeddingBagOffsetsSum should be equal to 1 for node: `{}`".format(name)

        # One output row per bag; each row has the trailing dimensions of an embedding entry.
        node.out_port(0).data.set_shape(np.concatenate((offsets_shape[:1], weights_shape[1:])))
|
|
|
|
|
|
class EmbeddingBagPackedSum(EmbeddingBagBase):
    """
    opset3 EmbeddingBagPackedSum operation.

    Input ports 0 and 1 must be connected; port 2 may be left disconnected and is not read by
    shape inference. Port 0 carries the weights (embedding table, rank >= 2). Port 1 carries
    the packed indices, which must be 2D; only their first dimension (the number of bags) is
    used.

    Output 0 shape: ``[indices.shape[0]] + weights.shape[1:]``.
    """
    op = op_type = 'EmbeddingBagPackedSum'
    version = 'opset3'
    in_ports_count = 3

    @staticmethod
    def infer(node: Node):
        """
        Set the output shape from the weights and packed indices input shapes.

        :param node: the EmbeddingBagPackedSum node to infer.
        :raises AssertionError: if port 0 or 1 is disconnected, if the weights shape is unknown
            or of rank < 2, or if the indices shape is unknown or not 2D.
        """
        name = node.soft_get('name', node.id)

        connected_in_ports = {idx: port for idx, port in node.in_ports().items() if not port.disconnected()}
        # Ports 0-1 being connected already implies at least two connected ports.
        assert all(port in connected_in_ports for port in (0, 1)), \
            "EmbeddingBagPackedSum should have at least 2 connected input port, but it doesn't for node: `{}`. " \
            "Ports: {}".format(name, connected_in_ports)

        weights_shape = node.in_port(0).data.get_shape()
        # Check for None first: len(None) would raise a TypeError instead of this diagnostic.
        assert weights_shape is not None and len(weights_shape) >= 2, \
            "EmbeddingBagPackedSum should have at least 2D weights for node: `{}`".format(name)
        input_shape = node.in_port(1).data.get_shape()
        # Without this check an unknown shape crashes on slicing, and a rank-0 indices input
        # would silently drop the batch dimension from the output shape.
        assert input_shape is not None and len(input_shape) == 2, \
            "EmbeddingBagPackedSum should have 2D indices for node: `{}`".format(name)

        # One output row per bag; each row has the trailing dimensions of an embedding entry.
        node.out_port(0).data.set_shape(np.concatenate((input_shape[:1], weights_shape[1:])))
|
|
|
|
|
|
class EmbeddingSegmentsSum(EmbeddingBagBase):
    """
    opset3 EmbeddingSegmentsSum operation.

    Input ports 0-3 must be connected; ports 4 and 5 may be left disconnected and are not read
    by shape inference.
    - Port 0 carries the weights (embedding table, rank >= 2).
    - Ports 1 and 2 carry the indices and segment_ids, both 1D and of equal shape.
    - Port 3 carries num_segments, which must be a constant.

    Output 0 shape: ``[num_segments] + weights.shape[1:]``.
    """
    op = op_type = 'EmbeddingSegmentsSum'
    version = 'opset3'
    in_ports_count = 6

    @staticmethod
    def infer(node: Node):
        """
        Set the output shape from the weights shape and the constant num_segments value.

        :param node: the EmbeddingSegmentsSum node to infer.
        :raises AssertionError: under any of these conditions:
            - any of ports 0-3 is disconnected;
            - the weights shape is unknown or of rank < 2;
            - the indices / segment_ids shapes are unknown, not 1D or differ;
            - num_segments is not a constant, or is not a single value.
        """
        name = node.soft_get('name', node.id)

        connected_in_ports = {idx: port for idx, port in node.in_ports().items() if not port.disconnected()}
        # Ports 0-3 being connected already implies at least four connected ports.
        assert all(port in connected_in_ports for port in (0, 1, 2, 3)), \
            "EmbeddingSegmentsSum should have at least 4 connected input port, but it doesn't for node: `{}`. " \
            "Ports: {}".format(name, connected_in_ports)

        weights_shape = node.in_port(0).data.get_shape()
        # Check for None first: len(None) would raise a TypeError instead of this diagnostic.
        assert weights_shape is not None and len(weights_shape) >= 2, \
            "EmbeddingSegmentsSum should have at least 2D weights for node: `{}`".format(name)
        indices_shape = node.in_port(1).data.get_shape()
        segment_ids_shape = node.in_port(2).data.get_shape()
        # np.array_equal yields a single bool. `==` on array-valued shapes (as get_shape()
        # presumably returns) is element-wise and only worked here because both are length 1.
        assert indices_shape is not None and segment_ids_shape is not None and \
            len(indices_shape) == 1 and len(segment_ids_shape) == 1 and \
            np.array_equal(indices_shape, segment_ids_shape), \
            "Both indices and segment_ids should have the same shape for node: `{}`".format(name)

        num_segments = node.in_port(3).data.get_value()
        assert num_segments is not None, "EmbeddingSegmentsSum should have a constant num_segments provided, but it " \
                                         "doesn't for node: `{}`.".format(name)
        # Flatten, so that a scalar and a single-element 1D constant both become a 1-element vector.
        # Wrapping a 1D value in a list would create a 2D operand and break the concatenation.
        num_segments = np.reshape(num_segments, -1)
        assert num_segments.size == 1, \
            "EmbeddingSegmentsSum should have a scalar num_segments for node: `{}`".format(name)

        output_shape = np.concatenate((num_segments, weights_shape[1:]))
        node.out_port(0).data.set_shape(output_shape)
|