[Pt FE]: aten::embedding_bag (#17098)

* [Pt FE]: aten::embedding_bag

* Update src/frontends/pytorch/src/op_table.cpp

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>

---------

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>
Ekaterina Aidova 2023-05-10 11:44:08 +04:00 committed by GitHub
parent e7d94ba020
commit 66e1af18b5
4 changed files with 165 additions and 3 deletions


@@ -0,0 +1,69 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/embeddingbag_offsets_sum.hpp"
#include "openvino/op/embeddingbag_packedsum.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace op {
OutputVector translate_embedding_bag(const NodeContext& context) {
// aten::embedding_bag(weight, input, offsets=None, scale_grad_by_freq=False, mode_enum=1, sparse=False,
// per_sample_weights=None, include_last_offset=False, padding_idx=None)
num_inputs_check(context, 9, 9);
// only the EmbeddingBagSum case is supported, so check the mode before translation
auto mode = context.const_input<int64_t>(4);
FRONT_END_OP_CONVERSION_CHECK(mode == 0, "Only sum mode supported for aten::embedding_bag translation");
auto weight = context.get_input(0);
auto indices = context.get_input(1);
indices = context.mark_node(std::make_shared<ov::op::v0::Convert>(indices, element::i32));
auto zero = context.mark_node(ov::op::v0::Constant::create(element::i32, Shape{}, {0}));
Output<Node> result;
// the scale_grad_by_freq, sparse and padding_idx parameters relate to gradient calculation during training, so they are skipped
// no offsets case
if (context.input_is_none(2)) {
// no per_sample_weights
if (context.input_is_none(6)) {
result = context.mark_node(std::make_shared<ov::op::v3::EmbeddingBagPackedSum>(weight, indices));
} else {
auto per_sample_weight = context.get_input(6);
per_sample_weight = context.mark_node(std::make_shared<ov::op::v1::ConvertLike>(per_sample_weight, weight));
result = context.mark_node(
std::make_shared<ov::op::v3::EmbeddingBagPackedSum>(weight, indices, per_sample_weight));
}
} else {
// with offsets case
auto offsets = context.get_input(2);
offsets = context.mark_node(std::make_shared<ov::op::v0::Convert>(offsets, element::i32));
auto include_last_offset = context.const_input<bool>(7);
FRONT_END_OP_CONVERSION_CHECK(!include_last_offset, "Including the last offset is not supported");
// no per_sample_weights
if (context.input_is_none(6)) {
result = context.mark_node(std::make_shared<ov::op::v3::EmbeddingBagOffsetsSum>(weight, indices, offsets));
} else {
auto per_sample_weight = context.get_input(6);
per_sample_weight = context.mark_node(std::make_shared<ov::op::v1::ConvertLike>(per_sample_weight, weight));
result = context.mark_node(std::make_shared<ov::op::v3::EmbeddingBagOffsetsSum>(weight,
indices,
offsets,
zero,
per_sample_weight));
}
// aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
// But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
}
return {result, zero, zero, zero};
};
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
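For reference, the offsets branch above maps aten::embedding_bag in sum mode onto EmbeddingBagOffsetsSum. A minimal NumPy/PyTorch sketch (not part of the change) of the semantics the converter relies on:

# Reference sketch: sum-mode semantics of the offsets case that the
# converter maps to EmbeddingBagOffsetsSum.
import numpy as np
import torch
import torch.nn.functional as F

weight = np.random.randn(10, 4).astype(np.float32)
indices = np.array([2, 2, 4, 3, 9], dtype=np.int64)
offsets = np.array([0, 3], dtype=np.int64)

# Bag b covers indices[offsets[b]:offsets[b + 1]]; the last bag runs to the end.
expected = np.stack([weight[indices[0:3]].sum(axis=0),
                     weight[indices[3:]].sum(axis=0)])
out = F.embedding_bag(torch.from_numpy(indices), torch.from_numpy(weight),
                      torch.from_numpy(offsets), mode="sum")
assert np.allclose(expected, out.numpy(), atol=1e-5)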


@@ -40,6 +40,7 @@ OP_CONVERTER(translate_dim);
OP_CONVERTER(translate_div);
OP_CONVERTER(translate_elu);
OP_CONVERTER(translate_embedding);
OP_CONVERTER(translate_embedding_bag);
OP_CONVERTER(translate_empty);
OP_CONVERTER(translate_expand);
OP_CONVERTER(translate_expand_as);
@@ -210,6 +211,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::dropout_", op::skip_node},
{"aten::elu", op::translate_elu},
{"aten::embedding", op::translate_embedding},
{"aten::embedding_bag", op::translate_embedding_bag},
{"aten::empty", op::translate_empty},
{"aten::eq", op::translate_1to1_match_2_inputs_align_types<opset10::Equal>},
{"aten::exp", op::translate_1to1_match_1_inputs<opset10::Exp>},

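With the converter declared and registered in the op table above, a model containing aten::embedding_bag can be converted by the PyTorch frontend. A minimal smoke-test sketch, assuming a recent OpenVINO Python package where openvino.convert_model accepts torch.nn.Module directly (the Bag class and shapes here are illustrative, not from the commit):

# Hypothetical smoke test of the newly registered converter.
import torch
import openvino as ov

class Bag(torch.nn.Module):
    def forward(self, indices, weight, offsets):
        return torch.nn.functional.embedding_bag(indices, weight, offsets, mode="sum")

example = (torch.arange(4, dtype=torch.int64),
           torch.randn(10, 8),
           torch.tensor([0, 2], dtype=torch.int64))
ov_model = ov.convert_model(Bag(), example_input=example)
compiled = ov.compile_model(ov_model, "CPU")
result = compiled([t.numpy() for t in example])[0]
print(result.shape)  # expected (2, 8): two bags, embedding dim 8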

@@ -6,10 +6,10 @@ import pytest
from pytorch_layer_test_class import PytorchLayerTest
-class TestExp(PytorchLayerTest):
+class TestEmbedding(PytorchLayerTest):
def _prepare_input(self, indicies_size, indicies_dtype):
import numpy as np
-return (np.random.randint(0, 9, size=indicies_size).astype(indicies_dtype), np.random.randn(10, 10))
+return (np.random.randint(0, 9, size=indicies_size).astype(indicies_dtype), np.random.randn(10, 10).astype(np.float32))
def create_model(self):
import torch
@@ -28,6 +28,6 @@ class TestExp(PytorchLayerTest):
@pytest.mark.precommit
@pytest.mark.parametrize("indicies_size", [1, 2, 3, 4])
@pytest.mark.parametrize("indicies_dtype", ["int", "int32"])
-def test_exp(self, ie_device, precision, ir_version, indicies_size, indicies_dtype):
+def test_embedding(self, ie_device, precision, ir_version, indicies_size, indicies_dtype):
self._test(*self.create_model(), ie_device, precision, ir_version,
kwargs_to_prepare_input={"indicies_size": indicies_size, "indicies_dtype": indicies_dtype})

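Besides renaming the copied TestExp class and test to TestEmbedding/test_embedding, the fixture now casts the weight input to float32: np.random.randn returns float64 by default, so without the cast the weight tensor would not be float32. A one-line illustration:

# np.random.randn produces float64 by default, hence the explicit cast above.
import numpy as np
assert np.random.randn(10, 10).dtype == np.float64
assert np.random.randn(10, 10).astype(np.float32).dtype == np.float32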

@@ -0,0 +1,91 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import pytest
from pytorch_layer_test_class import PytorchLayerTest
class TestEmbeddingBag1dOffsets(PytorchLayerTest):
def _prepare_input(self, indicies_dtype, per_sample_weights=False):
import numpy as np
indices = np.array([2, 2, 2, 2, 4, 3, 2, 9]).astype(indicies_dtype)
weights = np.random.randn(10, 10).astype(np.float32)
offsets = np.array([0, 4]).astype(indicies_dtype)
if per_sample_weights:
per_sample_weights = np.random.randn(
*indices.shape).astype(np.float32)
return (indices, weights, offsets, per_sample_weights)
return (indices, weights, offsets)
def create_model(self, per_sample_weights):
import torch
import torch.nn.functional as F
class aten_embedding_bag(torch.nn.Module):
def __init__(self, per_sample_weights=False) -> None:
super().__init__()
if per_sample_weights:
self.forward = self.forward_offsets_per_sample_weights
def forward(self, indicies, weight, offsets):
return F.embedding_bag(indicies, weight, offsets, mode="sum")
def forward_offsets_per_sample_weights(self, indicies, weight, offsets, per_sample_wights):
return F.embedding_bag(indicies, weight, offsets, mode="sum", per_sample_weights=per_sample_wights)
ref_net = None
return aten_embedding_bag(per_sample_weights), ref_net, "aten::embedding_bag"
@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.parametrize("indicies_dtype", ["int", "int32"])
@pytest.mark.parametrize("per_sample_weights", [True, False])
def test_embedding_bag(self, ie_device, precision, ir_version, indicies_dtype, per_sample_weights):
self._test(*self.create_model(per_sample_weights), ie_device, precision, ir_version,
kwargs_to_prepare_input={"indicies_dtype": indicies_dtype, "per_sample_weights": per_sample_weights},
trace_model=True, dynamic_shapes=not per_sample_weights)
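The per_sample_weights parametrization above can be checked by hand against the same fixture (indices [2, 2, 2, 2, 4, 3, 2, 9], offsets [0, 4]); a small reference sketch, not part of the test suite:

# Each gathered row is scaled by its sample weight before the per-bag sum.
import numpy as np
import torch
import torch.nn.functional as F

indices = np.array([2, 2, 2, 2, 4, 3, 2, 9], dtype=np.int64)
offsets = np.array([0, 4], dtype=np.int64)
weights = np.random.randn(10, 10).astype(np.float32)
psw = np.random.randn(*indices.shape).astype(np.float32)

scaled = weights[indices] * psw[:, None]
expected = np.stack([scaled[:4].sum(axis=0), scaled[4:].sum(axis=0)])
out = F.embedding_bag(torch.from_numpy(indices), torch.from_numpy(weights),
                      torch.from_numpy(offsets), mode="sum",
                      per_sample_weights=torch.from_numpy(psw))
assert np.allclose(expected, out.numpy(), atol=1e-5)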
class TestEmbeddingBag2d(PytorchLayerTest):
def _prepare_input(self, indicies_size, indicies_dtype, per_sample_weights):
import numpy as np
indices = np.random.randint(
0, 9, size=indicies_size).astype(indicies_dtype)
weights = np.random.randn(10, 10).astype(np.float32)
if per_sample_weights:
per_sample_weights = np.random.randn(
*indices.shape).astype(np.float32)
return (indices, weights, per_sample_weights)
return (indices, weights)
def create_model(self, per_sample_weights):
import torch
import torch.nn.functional as F
class aten_embedding_bag(torch.nn.Module):
def __init__(self, per_sample_weights=False) -> None:
super().__init__()
if per_sample_weights:
self.forward = self.forward_per_sample_weights
def forward(self, indicies, weight):
return F.embedding_bag(indicies, weight, mode="sum")
def forward_per_sample_weights(self, indicies, weight, per_sample_wights):
return F.embedding_bag(indicies, weight, mode="sum", per_sample_weights=per_sample_wights)
ref_net = None
return aten_embedding_bag(per_sample_weights), ref_net, "aten::embedding_bag"
@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.parametrize("indicies_size", [[1, 1], [2, 5], [3, 10], [4, 7]])
@pytest.mark.parametrize("indicies_dtype", ["int", "int32"])
@pytest.mark.parametrize("per_sample_weights", [True, False])
def test_embedding_bag(self, ie_device, precision, ir_version, indicies_dtype, indicies_size, per_sample_weights):
self._test(*self.create_model(per_sample_weights), ie_device, precision, ir_version,
kwargs_to_prepare_input={"indicies_size": indicies_size, "indicies_dtype": indicies_dtype, "per_sample_weights": per_sample_weights},
trace_model=True, dynamic_shapes=not per_sample_weights)
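In the 2-D case above no offsets are passed, so each row of the indices tensor forms one bag; this is the "packed" form that the converter lowers to EmbeddingBagPackedSum. A small verification sketch (illustrative shapes, not part of the test suite):

# With no offsets, a 2-D indices tensor is summed row by row.
import numpy as np
import torch
import torch.nn.functional as F

indices = np.random.randint(0, 9, size=(3, 10)).astype(np.int64)
weights = np.random.randn(10, 10).astype(np.float32)

expected = weights[indices].sum(axis=1)  # gather rows per bag, then sum
out = F.embedding_bag(torch.from_numpy(indices), torch.from_numpy(weights),
                      mode="sum")
assert np.allclose(expected, out.numpy(), atol=1e-5)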