[PT FE]: support aten::broadcast_tensors (#19994)

* broadcast tensors

* [PT FE]: support aten::broadcast_tensors

* apply review comments

* remove add
Ekaterina Aidova 2023-09-22 13:54:44 +04:00 committed by GitHub
parent 2151e5f979
commit 26d18c924b
2 changed files with 68 additions and 0 deletions


@@ -173,6 +173,29 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
        }
    }
    if (auto broadcast_tensors = cast_fw_node(input_node, "aten::broadcast_tensors")) {
        auto tensors = cast_fw_node(broadcast_tensors->input_value(0).get_node_shared_ptr(), "prim::ListConstruct");
        if (!tensors) {
            add_exception_to_fw_node(input_node,
                                     "aten::broadcast_tensors: only prim::ListConstruct supported as input.");
            return false;
        }
        // Discover the common broadcast shape: bidirectionally broadcast a scalar
        // constant against each input's shape, then take ShapeOf of the result.
        Output<Node> final_shape_t = opset10::Constant::create(element::i32, Shape{}, {0});
        for (auto input : tensors->inputs()) {
            auto tensor_shape = rg.make<opset10::ShapeOf>(input.get_source_output(), element::i32);
            final_shape_t =
                rg.make<opset10::Broadcast>(final_shape_t, tensor_shape, ov::op::BroadcastType::BIDIRECTIONAL);
        }
        auto final_shape = rg.make<opset10::ShapeOf>(final_shape_t, element::i32);
        // Broadcast every input tensor to the common shape and replace the ListUnpack outputs.
        OutputVector outputs;
        for (auto input : tensors->inputs()) {
            outputs.push_back(rg.make<opset10::Broadcast>(input.get_source_output(), final_shape));
        }
        copy_runtime_info_and_name(list_unpack, rg.get(), {input_node});
        replace_node(list_unpack, outputs);
        return true;
    }
    if (auto unbind = cast_fw_node(input_node, "aten::unbind")) {
        const auto input = unbind->get_input_source_output(0);
        const auto axis = unbind->get_input_source_output(1);


@@ -0,0 +1,45 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest

from pytorch_layer_test_class import PytorchLayerTest


class TestBroadcastTensors(PytorchLayerTest):
    def _prepare_input(self, x_shape, y_shape, z_shape, x_dtype, y_dtype, z_dtype):
        import numpy as np
        return (
            np.random.randn(*x_shape).astype(x_dtype),
            np.random.randn(*y_shape).astype(y_dtype),
            np.random.randn(*z_shape).astype(z_dtype))

    def create_model(self):
        import torch

        class aten_broadcast_tensors(torch.nn.Module):
            def __init__(self):
                super(aten_broadcast_tensors, self).__init__()

            def forward(self, x, y, z):
                x1, y1, z1 = torch.broadcast_tensors(x, y, z)
                return x1, y1, z1

        ref_net = None
        return aten_broadcast_tensors(), ref_net, ("prim::ListConstruct", "aten::broadcast_tensors", "prim::ListUnpack")

    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.parametrize("x_shape", [[1, ], [2, 1], [2, 2, 1]])
    @pytest.mark.parametrize("y_shape", [[2, ], [1, 2], [1, 2, 1]])
    @pytest.mark.parametrize("z_shape", [[1, 2], [2, 2], [1, 2, 1, 1]])
    @pytest.mark.parametrize("x_dtype", ["float32", "int32"])
    @pytest.mark.parametrize("y_dtype", ["float32", "int32"])
    @pytest.mark.parametrize("z_dtype", ["float32", "int32"])
    def test_broadcast_tensors(self, x_shape, y_shape, z_shape, x_dtype, y_dtype, z_dtype,
                               ie_device, precision, ir_version):
        self._test(*self.create_model(), ie_device, precision, ir_version,
                   kwargs_to_prepare_input={
                       "x_shape": x_shape, "x_dtype": x_dtype,
                       "y_shape": y_shape, "y_dtype": y_dtype,
                       "z_shape": z_shape, "z_dtype": z_dtype,
                   })