[PT FE]: fix aten::embedding realization for integer-like indices an… (#15721)

* [PT FE]: fix aten::embedding realization for integer-like indices and add tests

* more comments

* Update src/frontends/pytorch/src/op/embedding.cpp

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>

---------

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>
This commit is contained in:
Ekaterina Aidova
2023-02-17 12:09:46 +04:00
committed by GitHub
parent f03a3321fc
commit 225f9b3801
2 changed files with 42 additions and 4 deletions

View File

@@ -4,6 +4,7 @@
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/gather.hpp"
#include "utils.hpp"
@@ -13,13 +14,17 @@ namespace pytorch {
namespace op {
// Convert aten::embedding into an OpenVINO Gather over the weight table.
// aten::embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1,
//                 bool scale_grad_by_freq=False, bool sparse=False)
OutputVector translate_embedding(NodeContext& context) {
    num_inputs_check(context, 5, 5);
    auto weight = context.get_input(0);
    auto lookup_ids = context.get_input(1);
    // Inputs 2..4 only influence training-time gradient computation:
    //   padding_idx         - entries at this index do not contribute to the gradient
    //   scale_grad_by_freq  - scales gradients by inverse word frequency in the mini-batch
    //   sparse              - represents the gradient as a sparse tensor
    // Inference ignores them, but the boolean flags must be false since true would
    // imply semantics this translation does not reproduce.
    // TODO: find out the meaning of input idx 2
    FRONT_END_OP_CONVERSION_CHECK(
        !context.const_input<bool>(3) && !context.const_input<bool>(4),
        "Only False is supported on inputs with indexes 3 and 4 for aten::embedding translation");
    // Gather requires integer indices; normalize any integer-like input type to i64
    // so the lookup succeeds regardless of the dtype PyTorch supplied.
    lookup_ids = context.mark_node(std::make_shared<ov::op::v0::Convert>(lookup_ids, element::i64));
    auto gather_axis = context.mark_node(ov::op::v0::Constant::create(element::i64, Shape{}, {0}));
    // Row lookup along axis 0 of the weight matrix.
    return {context.mark_node(std::make_shared<ov::op::v8::Gather>(weight, lookup_ids, gather_axis))};
};

View File

@@ -0,0 +1,33 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import pytest
from pytorch_layer_test_class import PytorchLayerTest
# NOTE(review): class/method names look copy-pasted from an exp() test —
# consider renaming to TestEmbedding/test_embedding; "indicies" is also a
# misspelling of "indices", but the parametrize argnames, fixture kwargs and
# the kwargs_to_prepare_input keys must all stay in sync if renamed.
class TestExp(PytorchLayerTest):
    def _prepare_input(self, indicies_size, indicies_dtype):
        # Returns (indices, weight): random integer indices in [0, 9) cast to the
        # requested dtype, and a 10x10 float weight table for the embedding lookup.
        import numpy as np
        return (np.random.randint(0, 9, size=indicies_size).astype(indicies_dtype), np.random.randn(10, 10))
    def create_model(self):
        import torch
        import torch.nn.functional as F
        # Minimal module wrapping F.embedding so tracing produces aten::embedding.
        class aten_embedding(torch.nn.Module):
            def forward(self, indicies, weight):
                return F.embedding(indicies, weight)
        ref_net = None
        # (model, reference net, traced op name) triple expected by PytorchLayerTest.
        return aten_embedding(), ref_net, "aten::embedding"
    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.parametrize("indicies_size", [1, 2, 3, 4])
    @pytest.mark.parametrize("indicies_dtype", ["int", "int32"])
    def test_exp(self, ie_device, precision, ir_version, indicies_size, indicies_dtype):
        # NOTE(review): "int" maps to the platform-default numpy integer (i64 on most
        # Linux builds) — presumably intentional to exercise the i64 Convert path; verify.
        self._test(*self.create_model(), ie_device, precision, ir_version,
                   kwargs_to_prepare_input={"indicies_size": indicies_size, "indicies_dtype": indicies_dtype})