Add support for aten::_set_item (#15643)
* Add support for aten::_set_item
* Update loop.cpp
* Update tests/layer_tests/pytorch_tests/test_set_item.py
* Update test_set_item.py
* Apply code review comments
* Fix code style
* Update tests/layer_tests/pytorch_tests/test_set_item.py

Co-authored-by: Ekaterina Aidova <ekaterina.aidova@intel.com>
This commit is contained in:
@@ -2,8 +2,9 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/op/loop.hpp"
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@@ -12,41 +13,30 @@ namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_loop(NodeContext& context) {
|
||||
auto loop = std::make_shared<opset10::Loop>(context.get_input(0), context.get_input(1));
|
||||
const auto& inputs = context.inputs();
|
||||
FRONT_END_OP_CONVERSION_CHECK(inputs.size() >= 2, "Loop must have at least 2 inputs.");
|
||||
auto loop = std::make_shared<ov::op::v5::Loop>(inputs[0], inputs[1]);
|
||||
auto decoder = context.get_decoder();
|
||||
FRONT_END_OP_CONVERSION_CHECK(decoder->get_subgraph_size() == 1, "Loop must have 1 subgraph.");
|
||||
auto subgraph_decoder = decoder->get_subgraph_decoder(0);
|
||||
auto body = context.convert_subgraph(0);
|
||||
loop->set_function(body);
|
||||
opset10::Loop::SpecialBodyPorts spec_ports{0, 0};
|
||||
ov::op::v5::Loop::SpecialBodyPorts spec_ports{0, 0};
|
||||
loop->set_special_body_ports(spec_ports);
|
||||
|
||||
auto inputs = subgraph_decoder->inputs();
|
||||
std::set<size_t> input_idxs(inputs.begin(), inputs.end());
|
||||
std::map<size_t, ParameterVector> inputs_map;
|
||||
|
||||
auto body_parameters = body->get_parameters();
|
||||
// #0 parameter is counter
|
||||
for (size_t i = 1; i < body_parameters.size(); i++) {
|
||||
// #0 body parameter is counter; #0 loop input is counter, #1 loop input is condition
|
||||
// Connect other inputs
|
||||
for (size_t i = 2; i < inputs.size(); i++) {
|
||||
loop->set_invariant_inputs(inputs[i], {body_parameters[i - 1]});
|
||||
}
|
||||
// Connect inputs from external context
|
||||
for (auto i = inputs.size() - 1; i < body_parameters.size(); i++) {
|
||||
auto param = body_parameters[i];
|
||||
auto name = param->get_output_tensor(0).get_any_name();
|
||||
size_t input_idx = (size_t)std::stoll(name);
|
||||
if (inputs_map.count(input_idx)) {
|
||||
inputs_map[input_idx] = {param};
|
||||
} else {
|
||||
inputs_map[input_idx].push_back(param);
|
||||
}
|
||||
}
|
||||
for (const auto& input : inputs_map) {
|
||||
if (!input_idxs.count(input.first)) {
|
||||
auto external_output = context.get_tensor_from_model_or_create_input(input.first);
|
||||
loop->set_invariant_inputs(external_output, input.second);
|
||||
} else {
|
||||
auto external_output = context.get_tensor_from_model(input.first);
|
||||
if (external_output.get_node()) {
|
||||
loop->set_invariant_inputs(external_output, input.second);
|
||||
}
|
||||
}
|
||||
auto external_output = context.get_tensor_from_model_or_create_input(input_idx);
|
||||
loop->set_invariant_inputs(external_output, {param});
|
||||
}
|
||||
// TODO: Connect back edges (merged inputs)
|
||||
auto body_results = body->get_results();
|
||||
@@ -69,4 +59,4 @@ OutputVector translate_loop(NodeContext& context) {
|
||||
} // namespace op
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
||||
} // namespace ov
|
||||
|
||||
36
src/frontends/pytorch/src/op/set_item.cpp
Normal file
36
src/frontends/pytorch/src/op/set_item.cpp
Normal file
@@ -0,0 +1,36 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/scatter_update.hpp"
|
||||
#include "openvino/op/unsqueeze.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
// Converts aten::_set_item to OpenVINO ops.
// schema: aten::_set_item.t(t[](a!) l, int idx, t(b -> *) el) -> t[](a!)
// _set_item writes an element into a list at the given index, in place.
// The list is treated as a tensor and the write is lowered to a
// ScatterUpdate along axis 0.
OutputVector translate_set_item(NodeContext& context) {
    num_inputs_check(context, 3, 3);
    // Scalar 0 reused both as the unsqueeze axis and the scatter axis.
    auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
    auto input = context.get_input(0);
    auto idx = context.get_input(1);
    // ScatterUpdate expects indices/updates with an explicit leading dim,
    // so promote the scalar index and value to rank 1.
    auto idx_unsqueezed = context.mark_node(std::make_shared<v0::Unsqueeze>(idx, zero));
    auto value = context.get_input(2);
    auto value_unsqueezed = context.mark_node(std::make_shared<v0::Unsqueeze>(value, zero));
    // NOTE(review): ScatterUpdate does not support negative indices, so
    // negative idx values are not handled here (tracked as issue 103748,
    // see the xfail in test_set_item.py).
    auto res = context.mark_node(std::make_shared<v3::ScatterUpdate>(input, idx_unsqueezed, value_unsqueezed, zero));
    // aten::_set_item mutates its first argument; propagate the updated
    // tensor back to input 0 so downstream users of the list see the write.
    context.mutate_input(0, res);
    return {res};
}
|
||||
|
||||
} // namespace op
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
||||
@@ -91,6 +91,7 @@ OP_CONVERTER(translate_roll);
|
||||
OP_CONVERTER(translate_rsqrt);
|
||||
OP_CONVERTER(translate_rsub);
|
||||
OP_CONVERTER(translate_select);
|
||||
OP_CONVERTER(translate_set_item);
|
||||
OP_CONVERTER(translate_selu);
|
||||
OP_CONVERTER(translate_size);
|
||||
OP_CONVERTER(translate_slice);
|
||||
@@ -123,6 +124,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"aten::__not__", op::translate_1to1_match_1_inputs<opset10::LogicalNot>},
|
||||
{"aten::_convolution", op::translate_convolution},
|
||||
{"aten::_convolution_mode", op::translate_convolution_mode},
|
||||
{"aten::_set_item", op::translate_set_item},
|
||||
{"aten::abs", op::translate_1to1_match_1_inputs<opset10::Abs>},
|
||||
{"aten::acos", op::translate_1to1_match_1_inputs<opset10::Acos>},
|
||||
{"aten::acos_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Acos>>},
|
||||
|
||||
@@ -70,10 +70,17 @@ class PytorchLayerTest:
|
||||
im = fe.load(decoder)
|
||||
om = fe.convert(im)
|
||||
|
||||
torch_inps = [torch.from_numpy(inp) if isinstance(inp, np.ndarray) else inp for inp in inputs]
|
||||
|
||||
params = om.get_parameters()
|
||||
# todo: support lists and dicts
|
||||
for i in range(len(inputs)):
|
||||
inp = inputs[i]
|
||||
if isinstance(inp, list):
|
||||
inputs[i] = np.array(inp)
|
||||
if inputs[i].dtype == np.int64:
|
||||
inputs[i] = inputs[i].astype(np.int32)
|
||||
inp = inputs[i]
|
||||
assert inp.dtype.name in self._type_map, f"Unknown type {inp.dtype}."
|
||||
params[i].set_element_type(self._type_map[inp.dtype.name])
|
||||
shape = [-1] * len(inp.shape) if dynamic_shapes else inp.shape
|
||||
@@ -90,7 +97,6 @@ class PytorchLayerTest:
|
||||
return
|
||||
|
||||
# Framework infer:
|
||||
torch_inps = [torch.from_numpy(inp) for inp in inputs]
|
||||
fw_res = model(*torch_inps)
|
||||
|
||||
if not isinstance(fw_res, (tuple)):
|
||||
|
||||
35
tests/layer_tests/pytorch_tests/test_set_item.py
Normal file
35
tests/layer_tests/pytorch_tests/test_set_item.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
|
||||
class TestSetItem(PytorchLayerTest):
    """Layer test for aten::_set_item (list element assignment)."""

    def _prepare_input(self):
        import numpy as np
        # A single 10-element Python list of ints fed to the traced model.
        return [np.random.randn(10).astype(np.int32).tolist()]

    def create_model(self, idx):
        import torch
        from typing import List

        class aten_set_item(torch.nn.Module):
            """Assigns 0 to a fixed position of the input list."""

            def __init__(self, idx):
                super(aten_set_item, self).__init__()
                self.idx = idx

            def forward(self, x: List[int]):
                # List indexing assignment is what traces to aten::_set_item.
                x[self.idx] = 0
                return torch.tensor(x).to(torch.int)

        ref_net = None
        return aten_set_item(idx), ref_net, "aten::_set_item"

    @pytest.mark.parametrize("idx", [0, 1, pytest.param(-1, marks=pytest.mark.xfail(reason="103748 ov scatter do not support negative indices"))])
    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_set_item_list(self, idx, ie_device, precision, ir_version):
        model, ref, op_name = self.create_model(idx)
        self._test(model, ref, op_name, ie_device, precision, ir_version)
|
||||
Reference in New Issue
Block a user