[PT FE]: support aten::broadcast_tensors (#19994)
* broadcast tensors * [PT FE]: support aten::broadcast_tensors * apply review comments * remove add
This commit is contained in:
parent
2151e5f979
commit
26d18c924b
@ -173,6 +173,29 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
|
||||
}
|
||||
}
|
||||
|
||||
if (auto broadcast_tensors = cast_fw_node(input_node, "aten::broadcast_tensors")) {
|
||||
auto tensors = cast_fw_node(broadcast_tensors->input_value(0).get_node_shared_ptr(), "prim::ListConstruct");
|
||||
if (!tensors) {
|
||||
add_exception_to_fw_node(input_node,
|
||||
"aten::broadcast_tensors: only prim::ListConstruct supported as input.");
|
||||
return false;
|
||||
}
|
||||
Output<Node> final_shape_t = opset10::Constant::create(element::i32, Shape{}, {0});
|
||||
for (auto input : tensors->inputs()) {
|
||||
auto tensor_shape = rg.make<opset10::ShapeOf>(input.get_source_output(), element::i32);
|
||||
final_shape_t =
|
||||
rg.make<opset10::Broadcast>(final_shape_t, tensor_shape, ov::op::BroadcastType::BIDIRECTIONAL);
|
||||
}
|
||||
auto final_shape = rg.make<opset10::ShapeOf>(final_shape_t, element::i32);
|
||||
OutputVector outputs;
|
||||
for (auto input : tensors->inputs()) {
|
||||
outputs.push_back(rg.make<opset10::Broadcast>(input.get_source_output(), final_shape));
|
||||
}
|
||||
copy_runtime_info_and_name(list_unpack, rg.get(), {input_node});
|
||||
replace_node(list_unpack, outputs);
|
||||
return true;
|
||||
}
|
||||
|
||||
if (auto unbind = cast_fw_node(input_node, "aten::unbind")) {
|
||||
const auto input = unbind->get_input_source_output(0);
|
||||
const auto axis = unbind->get_input_source_output(1);
|
||||
|
45
tests/layer_tests/pytorch_tests/test_broadcast_tensors.py
Normal file
45
tests/layer_tests/pytorch_tests/test_broadcast_tensors.py
Normal file
@ -0,0 +1,45 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
|
||||
class TestBroadcastTensors(PytorchLayerTest):
|
||||
def _prepare_input(self, x_shape, y_shape, z_shape, x_dtype, y_dtype, z_dtype):
|
||||
import numpy as np
|
||||
return (
|
||||
np.random.randn(*x_shape).astype(x_dtype),
|
||||
np.random.randn(*y_shape).astype(y_dtype),
|
||||
np.random.randn(*z_shape).astype(z_dtype))
|
||||
|
||||
def create_model(self):
|
||||
import torch
|
||||
|
||||
class aten_broadcast_tensors(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(aten_broadcast_tensors, self).__init__()
|
||||
|
||||
def forward(self, x, y, z):
|
||||
x1, y1, z1 = torch.broadcast_tensors(x, y, z)
|
||||
return x1, y1, z1
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_broadcast_tensors(), ref_net, ("prim::ListConstruct", "aten::broadcast_tensors", "prim::ListUnpack")
|
||||
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
@pytest.mark.parametrize("x_shape", [[1, ], [2, 1], [2, 2, 1]])
|
||||
@pytest.mark.parametrize("y_shape", [[2, ], [1, 2], [1, 2, 1]])
|
||||
@pytest.mark.parametrize("z_shape", [[1, 2], [2, 2], [1, 2, 1, 1]])
|
||||
@pytest.mark.parametrize("x_dtype", ["float32", "int32"])
|
||||
@pytest.mark.parametrize("y_dtype", ["float32", "int32"])
|
||||
@pytest.mark.parametrize("z_dtype", ["float32", "int32"])
|
||||
def test_broadcast_tensors(self, x_shape, y_shape, z_shape, x_dtype, y_dtype, z_dtype, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(), ie_device, precision, ir_version, kwargs_to_prepare_input={
|
||||
"x_shape": x_shape, "x_dtype": x_dtype,
|
||||
"y_shape": y_shape, "y_dtype": y_dtype,
|
||||
"z_shape": z_shape, "z_dtype": z_dtype,
|
||||
})
|
Loading…
Reference in New Issue
Block a user