diff --git a/src/frontends/pytorch/src/op/embedding_bag.cpp b/src/frontends/pytorch/src/op/embedding_bag.cpp new file mode 100644 index 00000000000..ee1cba3d1cf --- /dev/null +++ b/src/frontends/pytorch/src/op/embedding_bag.cpp @@ -0,0 +1,69 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/convert.hpp" +#include "openvino/op/convert_like.hpp" +#include "openvino/op/embeddingbag_offsets_sum.hpp" +#include "openvino/op/embeddingbag_packedsum.hpp" +#include "utils.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { +namespace op { + +OutputVector translate_embedding_bag(const NodeContext& context) { + // aten::embedding_bag(weight, input, offsets=None, scale_grad_by_freq=False, mode_enum=1, sparse=False, + // per_sample_weights=None, include_last_offset=False, padding_idx=None) + num_inputs_check(context, 9, 9); + // we have only EmbeddingBagSum case support, check it before translation + auto mode = context.const_input(4); + FRONT_END_OP_CONVERSION_CHECK(mode == 0, "Only sum mode supported for aten::embedding_bag translation"); + auto weight = context.get_input(0); + auto indices = context.get_input(1); + indices = context.mark_node(std::make_shared(indices, element::i32)); + auto zero = context.mark_node(ov::op::v0::Constant::create(element::i32, Shape{}, {0})); + Output result; + // parameters scale_grad_by_freq, sparse, padding_idx have relation to gradient calculation for training, skip them + // no offsets case + if (context.input_is_none(2)) { + // no per_sample_weights + if (context.input_is_none(6)) { + result = context.mark_node(std::make_shared(weight, indices)); + } else { + auto per_sample_weight = context.get_input(6); + per_sample_weight = context.mark_node(std::make_shared(per_sample_weight, weight)); + result = context.mark_node( + std::make_shared(weight, indices, per_sample_weight)); + } + } else { + // with offsets case + auto offsets = context.get_input(2); + offsets = context.mark_node(std::make_shared(offsets, element::i32)); + auto include_last_offset = context.const_input(7); + FRONT_END_OP_CONVERSION_CHECK(!include_last_offset, "Inclusion last offset is not supported"); + // no per_sample_wights + if (context.input_is_none(6)) { + result = context.mark_node(std::make_shared(weight, indices, offsets)); + } else { + auto per_sample_weight = context.get_input(6); + per_sample_weight = context.mark_node(std::make_shared(per_sample_weight, weight)); + result = context.mark_node(std::make_shared(weight, + indices, + offsets, + zero, + per_sample_weight)); + } + // aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices. + // But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag. + } + return {result, zero, zero, zero}; +}; + +} // namespace op +} // namespace pytorch +} // namespace frontend +} // namespace ov \ No newline at end of file diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index f23652f2eee..372f1fb6163 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -40,6 +40,7 @@ OP_CONVERTER(translate_dim); OP_CONVERTER(translate_div); OP_CONVERTER(translate_elu); OP_CONVERTER(translate_embedding); +OP_CONVERTER(translate_embedding_bag); OP_CONVERTER(translate_empty); OP_CONVERTER(translate_expand); OP_CONVERTER(translate_expand_as); @@ -210,6 +211,7 @@ const std::map get_supported_ops() { {"aten::dropout_", op::skip_node}, {"aten::elu", op::translate_elu}, {"aten::embedding", op::translate_embedding}, + {"aten::embedding_bag", op::translate_embedding_bag}, {"aten::empty", op::translate_empty}, {"aten::eq", op::translate_1to1_match_2_inputs_align_types}, {"aten::exp", op::translate_1to1_match_1_inputs}, diff --git a/tests/layer_tests/pytorch_tests/test_embedding.py b/tests/layer_tests/pytorch_tests/test_embedding.py index ad9637211c9..a448e21d3ff 100644 --- a/tests/layer_tests/pytorch_tests/test_embedding.py +++ b/tests/layer_tests/pytorch_tests/test_embedding.py @@ -6,10 +6,10 @@ import pytest from pytorch_layer_test_class import PytorchLayerTest -class TestExp(PytorchLayerTest): +class TestEmbedding(PytorchLayerTest): def _prepare_input(self, indicies_size, indicies_dtype): import numpy as np - return (np.random.randint(0, 9, size=indicies_size).astype(indicies_dtype), np.random.randn(10, 10)) + return (np.random.randint(0, 9, size=indicies_size).astype(indicies_dtype), np.random.randn(10, 10).astype(np.float32)) def create_model(self): import torch @@ -28,6 +28,6 @@ class TestExp(PytorchLayerTest): @pytest.mark.precommit @pytest.mark.parametrize("indicies_size", [1, 2, 3, 4]) @pytest.mark.parametrize("indicies_dtype", ["int", "int32"]) - def test_exp(self, ie_device, precision, ir_version, indicies_size, indicies_dtype): + def test_embedding(self, ie_device, precision, ir_version, indicies_size, indicies_dtype): self._test(*self.create_model(), ie_device, precision, ir_version, kwargs_to_prepare_input={"indicies_size": indicies_size, "indicies_dtype": indicies_dtype}) \ No newline at end of file diff --git a/tests/layer_tests/pytorch_tests/test_embedding_bag.py b/tests/layer_tests/pytorch_tests/test_embedding_bag.py new file mode 100644 index 00000000000..2595b226931 --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_embedding_bag.py @@ -0,0 +1,91 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from pytorch_layer_test_class import PytorchLayerTest + + +class TestEmbeddingBag1dOffsets(PytorchLayerTest): + def _prepare_input(self, indicies_dtype, per_sample_weights=False): + import numpy as np + indices = np.array([2, 2, 2, 2, 4, 3, 2, 9]).astype(indicies_dtype) + weights = np.random.randn(10, 10).astype(np.float32) + offsets = np.array([0, 4]).astype(indicies_dtype) + if per_sample_weights: + per_sample_weights = np.random.randn( + *indices.shape).astype(np.float32) + return (indices, weights, offsets, per_sample_weights) + return (indices, weights, offsets) + + def create_model(self, per_sample_weights): + import torch + import torch.nn.functional as F + + class aten_embedding_bag(torch.nn.Module): + def __init__(self, per_sample_weights=False) -> None: + super().__init__() + if per_sample_weights: + self.forward = self.forward_offsets_per_sample_weights + + def forward(self, indicies, weight, offsets): + return F.embedding_bag(indicies, weight, offsets, mode="sum") + + def forward_offsets_per_sample_weights(self, indicies, weight, offsets, per_sample_wights): + return F.embedding_bag(indicies, weight, offsets, mode="sum", per_sample_weights=per_sample_wights) + + ref_net = None + + return aten_embedding_bag(per_sample_weights), ref_net, "aten::embedding_bag" + + @pytest.mark.nightly + @pytest.mark.precommit + @pytest.mark.parametrize("indicies_dtype", ["int", "int32"]) + @pytest.mark.parametrize("per_sample_weights", [True, False]) + def test_embedding_bag(self, ie_device, precision, ir_version, indicies_dtype, per_sample_weights): + self._test(*self.create_model(per_sample_weights), ie_device, precision, ir_version, + kwargs_to_prepare_input={"indicies_dtype": indicies_dtype, "per_sample_weights": per_sample_weights}, + trace_model=True, dynamic_shapes=not per_sample_weights) + + +class TestEmbeddingBag2d(PytorchLayerTest): + def _prepare_input(self, indicies_size, indicies_dtype, per_sample_weights): + import numpy as np + indices = np.random.randint( + 0, 9, size=indicies_size).astype(indicies_dtype) + weights = np.random.randn(10, 10).astype(np.float32) + if per_sample_weights: + per_sample_weights = np.random.randn( + *indices.shape).astype(np.float32) + return (indices, weights, per_sample_weights) + return (indices, weights) + + def create_model(self, per_sample_weights): + import torch + import torch.nn.functional as F + + class aten_embedding_bag(torch.nn.Module): + def __init__(self, per_sample_weights=False) -> None: + super().__init__() + if per_sample_weights: + self.forward = self.forward_per_sample_weights + + def forward(self, indicies, weight): + return F.embedding_bag(indicies, weight, mode="sum") + + def forward_per_sample_weights(self, indicies, weight, per_sample_wights): + return F.embedding_bag(indicies, weight, mode="sum", per_sample_weights=per_sample_wights) + + ref_net = None + + return aten_embedding_bag(per_sample_weights), ref_net, "aten::embedding_bag" + + @pytest.mark.nightly + @pytest.mark.precommit + @pytest.mark.parametrize("indicies_size", [[1, 1], [2, 5], [3, 10], [4, 7]]) + @pytest.mark.parametrize("indicies_dtype", ["int", "int32"]) + @pytest.mark.parametrize("per_sample_weights", [True, False]) + def test_embedding_bag(self, ie_device, precision, ir_version, indicies_dtype, indicies_size, per_sample_weights): + self._test(*self.create_model(per_sample_weights), ie_device, precision, ir_version, + kwargs_to_prepare_input={"indicies_size": indicies_size, "indicies_dtype": indicies_dtype, "per_sample_weights": per_sample_weights}, + trace_model=True, dynamic_shapes=not per_sample_weights)