[PT FE]: support aten::broadcast_to (#18899)

Ekaterina Aidova 2023-08-02 08:38:34 +04:00 committed by GitHub
parent 86b8e0a930
commit 968edc9375
2 changed files with 25 additions and 9 deletions

@@ -236,6 +236,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
         {"aten::avg_pool1d", op::quantizable_op<op::translate_avg_poolnd>},
         {"aten::avg_pool2d", op::quantizable_op<op::translate_avg_poolnd>},
         {"aten::avg_pool3d", op::quantizable_op<op::translate_avg_poolnd>},
+        {"aten::broadcast_to", op::translate_expand},
         {"aten::baddbmm", op::translate_addmm},
         {"aten::batch_norm", op::translate_batch_norm},
         {"aten::bitwise_not", op::translate_bitwise_not},

@@ -11,50 +11,65 @@ class TestExpand(PytorchLayerTest):
         import numpy as np
         return (np.random.randn(1, 3).astype(np.float32),)
 
-    def create_model(self, dim):
+    def create_model(self, dim, op_type="expand"):
         import torch
 
         class aten_expand(torch.nn.Module):
-            def __init__(self, dims):
+            def __init__(self, dims, op_type="expand"):
                 super(aten_expand, self).__init__()
                 self.dims = dims
+                if op_type == "broadcast_to":
+                    self.forward = self.forward_broadcast
 
             def forward(self, x):
                 return x.expand(self.dims)
 
+            def forward_broadcast(self, x):
+                return x.broadcast_to(self.dims)
+
         ref_net = None
 
-        return aten_expand(dim), ref_net, "aten::expand"
+        return aten_expand(dim, op_type), ref_net, f"aten::{op_type}"
 
     @pytest.mark.parametrize("dims", [(4, 3), (-1, -1), (1, 2, 3), (1, 2, 2, 3)])
+    @pytest.mark.parametrize("op_type", ["expand", "broadcast_to"])
     @pytest.mark.nightly
     @pytest.mark.precommit
-    def test_expand(self, dims, ie_device, precision, ir_version):
-        self._test(*self.create_model(dims), ie_device, precision, ir_version)
+    def test_expand(self, dims, op_type, ie_device, precision, ir_version):
+        self._test(*self.create_model(dims, op_type), ie_device, precision, ir_version)
 
 
 class TestExpandList(PytorchLayerTest):
     def _prepare_input(self, broadcast_shape):
         import numpy as np
         return (np.random.randn(1, 3).astype(np.float32), np.random.randn(*broadcast_shape).astype(np.float32))
 
-    def create_model(self):
+    def create_model(self, op_type="expand"):
         import torch
 
         class aten_expand(torch.nn.Module):
+            def __init__(self, op_type="expand"):
+                super(aten_expand, self).__init__()
+                if op_type == "broadcast_to":
+                    self.forward = self.forward_broadcast
+
             def forward(self, x, y):
                 y_shape = y.shape
                 return x.expand([y_shape[0], y_shape[1]])
 
+            def forward_broadcast(self, x, y):
+                y_shape = y.shape
+                return x.broadcast_to([y_shape[0], y_shape[1]])
+
         ref_net = None
 
-        return aten_expand(), ref_net, ["aten::expand", "prim::ListConstruct"]
+        return aten_expand(op_type), ref_net, [f"aten::{op_type}", "prim::ListConstruct"]
 
     @pytest.mark.parametrize("dims", [(3, 3), (2, 3), (1, 3), [4, 3]])
+    @pytest.mark.parametrize("op_type", ["expand", "broadcast_to"])
     @pytest.mark.nightly
     @pytest.mark.precommit
-    def test_expand(self, dims, ie_device, precision, ir_version):
-        self._test(*self.create_model(), ie_device, precision, ir_version, kwargs_to_prepare_input={"broadcast_shape": dims})
+    def test_expand(self, dims, op_type, ie_device, precision, ir_version):
+        self._test(*self.create_model(op_type), ie_device, precision, ir_version, kwargs_to_prepare_input={"broadcast_shape": dims})
 
 
 class TestExpandAs(PytorchLayerTest):
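
Both test classes select the op under test by rebinding forward on the instance in __init__, so a single module definition yields a graph containing either aten::expand or aten::broadcast_to. A standalone sketch of that pattern (the Broadcast module and shapes here are illustrative, not part of the test suite):

import torch

class Broadcast(torch.nn.Module):
    def __init__(self, op_type="expand"):
        super().__init__()
        # Rebinding forward on the instance decides which aten op gets traced.
        if op_type == "broadcast_to":
            self.forward = self.forward_broadcast

    def forward(self, x):
        return x.expand(4, 3)

    def forward_broadcast(self, x):
        return x.broadcast_to((4, 3))

traced = torch.jit.trace(Broadcast("broadcast_to"), torch.randn(1, 3))
print(traced.graph)  # expected to show an aten::broadcast_to node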