[Pt FE]: aten::embedding_bag (#17098)
* [Pt FE]: aten::embedding_bag * Update src/frontends/pytorch/src/op_table.cpp Co-authored-by: Maxim Vafin <maxim.vafin@intel.com> --------- Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>
This commit is contained in:
parent
e7d94ba020
commit
66e1af18b5
69
src/frontends/pytorch/src/op/embedding_bag.cpp
Normal file
69
src/frontends/pytorch/src/op/embedding_bag.cpp
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
// Copyright (C) 2018-2023 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||||
|
#include "openvino/op/constant.hpp"
|
||||||
|
#include "openvino/op/convert.hpp"
|
||||||
|
#include "openvino/op/convert_like.hpp"
|
||||||
|
#include "openvino/op/embeddingbag_offsets_sum.hpp"
|
||||||
|
#include "openvino/op/embeddingbag_packedsum.hpp"
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace frontend {
|
||||||
|
namespace pytorch {
|
||||||
|
namespace op {
|
||||||
|
|
||||||
|
OutputVector translate_embedding_bag(const NodeContext& context) {
|
||||||
|
// aten::embedding_bag(weight, input, offsets=None, scale_grad_by_freq=False, mode_enum=1, sparse=False,
|
||||||
|
// per_sample_weights=None, include_last_offset=False, padding_idx=None)
|
||||||
|
num_inputs_check(context, 9, 9);
|
||||||
|
// we have only EmbeddingBagSum case support, check it before translation
|
||||||
|
auto mode = context.const_input<int64_t>(4);
|
||||||
|
FRONT_END_OP_CONVERSION_CHECK(mode == 0, "Only sum mode supported for aten::embedding_bag translation");
|
||||||
|
auto weight = context.get_input(0);
|
||||||
|
auto indices = context.get_input(1);
|
||||||
|
indices = context.mark_node(std::make_shared<ov::op::v0::Convert>(indices, element::i32));
|
||||||
|
auto zero = context.mark_node(ov::op::v0::Constant::create(element::i32, Shape{}, {0}));
|
||||||
|
Output<Node> result;
|
||||||
|
// parameters scale_grad_by_freq, sparse, padding_idx have relation to gradient calculation for training, skip them
|
||||||
|
// no offsets case
|
||||||
|
if (context.input_is_none(2)) {
|
||||||
|
// no per_sample_weights
|
||||||
|
if (context.input_is_none(6)) {
|
||||||
|
result = context.mark_node(std::make_shared<ov::op::v3::EmbeddingBagPackedSum>(weight, indices));
|
||||||
|
} else {
|
||||||
|
auto per_sample_weight = context.get_input(6);
|
||||||
|
per_sample_weight = context.mark_node(std::make_shared<ov::op::v1::ConvertLike>(per_sample_weight, weight));
|
||||||
|
result = context.mark_node(
|
||||||
|
std::make_shared<ov::op::v3::EmbeddingBagPackedSum>(weight, indices, per_sample_weight));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// with offsets case
|
||||||
|
auto offsets = context.get_input(2);
|
||||||
|
offsets = context.mark_node(std::make_shared<ov::op::v0::Convert>(offsets, element::i32));
|
||||||
|
auto include_last_offset = context.const_input<bool>(7);
|
||||||
|
FRONT_END_OP_CONVERSION_CHECK(!include_last_offset, "Inclusion last offset is not supported");
|
||||||
|
// no per_sample_wights
|
||||||
|
if (context.input_is_none(6)) {
|
||||||
|
result = context.mark_node(std::make_shared<ov::op::v3::EmbeddingBagOffsetsSum>(weight, indices, offsets));
|
||||||
|
} else {
|
||||||
|
auto per_sample_weight = context.get_input(6);
|
||||||
|
per_sample_weight = context.mark_node(std::make_shared<ov::op::v1::ConvertLike>(per_sample_weight, weight));
|
||||||
|
result = context.mark_node(std::make_shared<ov::op::v3::EmbeddingBagOffsetsSum>(weight,
|
||||||
|
indices,
|
||||||
|
offsets,
|
||||||
|
zero,
|
||||||
|
per_sample_weight));
|
||||||
|
}
|
||||||
|
// aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
|
||||||
|
// But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
|
||||||
|
}
|
||||||
|
return {result, zero, zero, zero};
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace op
|
||||||
|
} // namespace pytorch
|
||||||
|
} // namespace frontend
|
||||||
|
} // namespace ov
|
@ -40,6 +40,7 @@ OP_CONVERTER(translate_dim);
|
|||||||
OP_CONVERTER(translate_div);
|
OP_CONVERTER(translate_div);
|
||||||
OP_CONVERTER(translate_elu);
|
OP_CONVERTER(translate_elu);
|
||||||
OP_CONVERTER(translate_embedding);
|
OP_CONVERTER(translate_embedding);
|
||||||
|
OP_CONVERTER(translate_embedding_bag);
|
||||||
OP_CONVERTER(translate_empty);
|
OP_CONVERTER(translate_empty);
|
||||||
OP_CONVERTER(translate_expand);
|
OP_CONVERTER(translate_expand);
|
||||||
OP_CONVERTER(translate_expand_as);
|
OP_CONVERTER(translate_expand_as);
|
||||||
@ -210,6 +211,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
|||||||
{"aten::dropout_", op::skip_node},
|
{"aten::dropout_", op::skip_node},
|
||||||
{"aten::elu", op::translate_elu},
|
{"aten::elu", op::translate_elu},
|
||||||
{"aten::embedding", op::translate_embedding},
|
{"aten::embedding", op::translate_embedding},
|
||||||
|
{"aten::embedding_bag", op::translate_embedding_bag},
|
||||||
{"aten::empty", op::translate_empty},
|
{"aten::empty", op::translate_empty},
|
||||||
{"aten::eq", op::translate_1to1_match_2_inputs_align_types<opset10::Equal>},
|
{"aten::eq", op::translate_1to1_match_2_inputs_align_types<opset10::Equal>},
|
||||||
{"aten::exp", op::translate_1to1_match_1_inputs<opset10::Exp>},
|
{"aten::exp", op::translate_1to1_match_1_inputs<opset10::Exp>},
|
||||||
|
@ -6,10 +6,10 @@ import pytest
|
|||||||
from pytorch_layer_test_class import PytorchLayerTest
|
from pytorch_layer_test_class import PytorchLayerTest
|
||||||
|
|
||||||
|
|
||||||
class TestExp(PytorchLayerTest):
|
class TestEmbedding(PytorchLayerTest):
|
||||||
def _prepare_input(self, indicies_size, indicies_dtype):
|
def _prepare_input(self, indicies_size, indicies_dtype):
|
||||||
import numpy as np
|
import numpy as np
|
||||||
return (np.random.randint(0, 9, size=indicies_size).astype(indicies_dtype), np.random.randn(10, 10))
|
return (np.random.randint(0, 9, size=indicies_size).astype(indicies_dtype), np.random.randn(10, 10).astype(np.float32))
|
||||||
|
|
||||||
def create_model(self):
|
def create_model(self):
|
||||||
import torch
|
import torch
|
||||||
@ -28,6 +28,6 @@ class TestExp(PytorchLayerTest):
|
|||||||
@pytest.mark.precommit
|
@pytest.mark.precommit
|
||||||
@pytest.mark.parametrize("indicies_size", [1, 2, 3, 4])
|
@pytest.mark.parametrize("indicies_size", [1, 2, 3, 4])
|
||||||
@pytest.mark.parametrize("indicies_dtype", ["int", "int32"])
|
@pytest.mark.parametrize("indicies_dtype", ["int", "int32"])
|
||||||
def test_exp(self, ie_device, precision, ir_version, indicies_size, indicies_dtype):
|
def test_embedding(self, ie_device, precision, ir_version, indicies_size, indicies_dtype):
|
||||||
self._test(*self.create_model(), ie_device, precision, ir_version,
|
self._test(*self.create_model(), ie_device, precision, ir_version,
|
||||||
kwargs_to_prepare_input={"indicies_size": indicies_size, "indicies_dtype": indicies_dtype})
|
kwargs_to_prepare_input={"indicies_size": indicies_size, "indicies_dtype": indicies_dtype})
|
91
tests/layer_tests/pytorch_tests/test_embedding_bag.py
Normal file
91
tests/layer_tests/pytorch_tests/test_embedding_bag.py
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
# Copyright (C) 2018-2023 Intel Corporation
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from pytorch_layer_test_class import PytorchLayerTest
|
||||||
|
|
||||||
|
|
||||||
|
class TestEmbeddingBag1dOffsets(PytorchLayerTest):
|
||||||
|
def _prepare_input(self, indicies_dtype, per_sample_weights=False):
|
||||||
|
import numpy as np
|
||||||
|
indices = np.array([2, 2, 2, 2, 4, 3, 2, 9]).astype(indicies_dtype)
|
||||||
|
weights = np.random.randn(10, 10).astype(np.float32)
|
||||||
|
offsets = np.array([0, 4]).astype(indicies_dtype)
|
||||||
|
if per_sample_weights:
|
||||||
|
per_sample_weights = np.random.randn(
|
||||||
|
*indices.shape).astype(np.float32)
|
||||||
|
return (indices, weights, offsets, per_sample_weights)
|
||||||
|
return (indices, weights, offsets)
|
||||||
|
|
||||||
|
def create_model(self, per_sample_weights):
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
class aten_embedding_bag(torch.nn.Module):
|
||||||
|
def __init__(self, per_sample_weights=False) -> None:
|
||||||
|
super().__init__()
|
||||||
|
if per_sample_weights:
|
||||||
|
self.forward = self.forward_offsets_per_sample_weights
|
||||||
|
|
||||||
|
def forward(self, indicies, weight, offsets):
|
||||||
|
return F.embedding_bag(indicies, weight, offsets, mode="sum")
|
||||||
|
|
||||||
|
def forward_offsets_per_sample_weights(self, indicies, weight, offsets, per_sample_wights):
|
||||||
|
return F.embedding_bag(indicies, weight, offsets, mode="sum", per_sample_weights=per_sample_wights)
|
||||||
|
|
||||||
|
ref_net = None
|
||||||
|
|
||||||
|
return aten_embedding_bag(per_sample_weights), ref_net, "aten::embedding_bag"
|
||||||
|
|
||||||
|
@pytest.mark.nightly
|
||||||
|
@pytest.mark.precommit
|
||||||
|
@pytest.mark.parametrize("indicies_dtype", ["int", "int32"])
|
||||||
|
@pytest.mark.parametrize("per_sample_weights", [True, False])
|
||||||
|
def test_embedding_bag(self, ie_device, precision, ir_version, indicies_dtype, per_sample_weights):
|
||||||
|
self._test(*self.create_model(per_sample_weights), ie_device, precision, ir_version,
|
||||||
|
kwargs_to_prepare_input={"indicies_dtype": indicies_dtype, "per_sample_weights": per_sample_weights},
|
||||||
|
trace_model=True, dynamic_shapes=not per_sample_weights)
|
||||||
|
|
||||||
|
|
||||||
|
class TestEmbeddingBag2d(PytorchLayerTest):
|
||||||
|
def _prepare_input(self, indicies_size, indicies_dtype, per_sample_weights):
|
||||||
|
import numpy as np
|
||||||
|
indices = np.random.randint(
|
||||||
|
0, 9, size=indicies_size).astype(indicies_dtype)
|
||||||
|
weights = np.random.randn(10, 10).astype(np.float32)
|
||||||
|
if per_sample_weights:
|
||||||
|
per_sample_weights = np.random.randn(
|
||||||
|
*indices.shape).astype(np.float32)
|
||||||
|
return (indices, weights, per_sample_weights)
|
||||||
|
return (indices, weights)
|
||||||
|
|
||||||
|
def create_model(self, per_sample_weights):
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
class aten_embedding_bag(torch.nn.Module):
|
||||||
|
def __init__(self, per_sample_weights=False) -> None:
|
||||||
|
super().__init__()
|
||||||
|
if per_sample_weights:
|
||||||
|
self.forward = self.forward_per_sample_weights
|
||||||
|
|
||||||
|
def forward(self, indicies, weight):
|
||||||
|
return F.embedding_bag(indicies, weight, mode="sum")
|
||||||
|
|
||||||
|
def forward_per_sample_weights(self, indicies, weight, per_sample_wights):
|
||||||
|
return F.embedding_bag(indicies, weight, mode="sum", per_sample_weights=per_sample_wights)
|
||||||
|
|
||||||
|
ref_net = None
|
||||||
|
|
||||||
|
return aten_embedding_bag(per_sample_weights), ref_net, "aten::embedding_bag"
|
||||||
|
|
||||||
|
@pytest.mark.nightly
|
||||||
|
@pytest.mark.precommit
|
||||||
|
@pytest.mark.parametrize("indicies_size", [[1, 1], [2, 5], [3, 10], [4, 7]])
|
||||||
|
@pytest.mark.parametrize("indicies_dtype", ["int", "int32"])
|
||||||
|
@pytest.mark.parametrize("per_sample_weights", [True, False])
|
||||||
|
def test_embedding_bag(self, ie_device, precision, ir_version, indicies_dtype, indicies_size, per_sample_weights):
|
||||||
|
self._test(*self.create_model(per_sample_weights), ie_device, precision, ir_version,
|
||||||
|
kwargs_to_prepare_input={"indicies_size": indicies_size, "indicies_dtype": indicies_dtype, "per_sample_weights": per_sample_weights},
|
||||||
|
trace_model=True, dynamic_shapes=not per_sample_weights)
|
Loading…
Reference in New Issue
Block a user