[PT FE]: support aten::broadcast_to (#18899)
This commit is contained in:
parent
86b8e0a930
commit
968edc9375
@ -236,6 +236,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
|
||||
{"aten::avg_pool1d", op::quantizable_op<op::translate_avg_poolnd>},
|
||||
{"aten::avg_pool2d", op::quantizable_op<op::translate_avg_poolnd>},
|
||||
{"aten::avg_pool3d", op::quantizable_op<op::translate_avg_poolnd>},
|
||||
{"aten::broadcast_to", op::translate_expand},
|
||||
{"aten::baddbmm", op::translate_addmm},
|
||||
{"aten::batch_norm", op::translate_batch_norm},
|
||||
{"aten::bitwise_not", op::translate_bitwise_not},
|
||||
|
@ -11,50 +11,65 @@ class TestExpand(PytorchLayerTest):
|
||||
import numpy as np
|
||||
return (np.random.randn(1, 3).astype(np.float32),)
|
||||
|
||||
def create_model(self, dim):
|
||||
def create_model(self, dim, op_type="expand"):
|
||||
import torch
|
||||
|
||||
class aten_expand(torch.nn.Module):
|
||||
def __init__(self, dims):
|
||||
def __init__(self, dims, op_type="expand"):
|
||||
super(aten_expand, self).__init__()
|
||||
self.dims = dims
|
||||
if op_type == "broadcast_to":
|
||||
self.forward = self.forward_broadcast
|
||||
|
||||
def forward(self, x):
|
||||
return x.expand(self.dims)
|
||||
|
||||
def forward_broadcast(self, x):
|
||||
return x.broadcast_to(self.dims)
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_expand(dim), ref_net, "aten::expand"
|
||||
return aten_expand(dim, op_type), ref_net, f"aten::{op_type}"
|
||||
|
||||
@pytest.mark.parametrize("dims", [(4, 3), (-1, -1), (1, 2, 3), (1, 2, 2, 3)])
|
||||
@pytest.mark.parametrize("op_type", ["expand", "broadcast_to"])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_expand(self, dims, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(dims), ie_device, precision, ir_version)
|
||||
def test_expand(self, dims, op_type, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(dims, op_type), ie_device, precision, ir_version)
|
||||
|
||||
class TestExpandList(PytorchLayerTest):
|
||||
def _prepare_input(self, broadcast_shape):
|
||||
import numpy as np
|
||||
return (np.random.randn(1, 3).astype(np.float32), np.random.randn(*broadcast_shape).astype(np.float32))
|
||||
|
||||
def create_model(self):
|
||||
def create_model(self, op_type="expand"):
|
||||
import torch
|
||||
|
||||
class aten_expand(torch.nn.Module):
|
||||
def __init__(self, op_type="expand"):
|
||||
super(aten_expand, self).__init__()
|
||||
if op_type == "broadcast_to":
|
||||
self.forward = self.forward_broadcast
|
||||
|
||||
def forward(self, x, y):
|
||||
y_shape = y.shape
|
||||
return x.expand([y_shape[0], y_shape[1]])
|
||||
|
||||
def forward_broadcast(self, x, y):
|
||||
y_shape = y.shape
|
||||
return x.broadcast_to([y_shape[0], y_shape[1]])
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_expand(), ref_net, ["aten::expand", "prim::ListConstruct"]
|
||||
return aten_expand(op_type), ref_net, [f"aten::{op_type}", "prim::ListConstruct"]
|
||||
|
||||
@pytest.mark.parametrize("dims", [(3, 3), (2, 3), (1, 3), [4, 3]])
|
||||
@pytest.mark.parametrize("op_type", ["expand", "broadcast_to"])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_expand(self, dims, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(), ie_device, precision, ir_version, kwargs_to_prepare_input={"broadcast_shape": dims})
|
||||
def test_expand(self, dims, op_type, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(op_type), ie_device, precision, ir_version, kwargs_to_prepare_input={"broadcast_shape": dims})
|
||||
|
||||
|
||||
class TestExpandAs(PytorchLayerTest):
|
||||
|
Loading…
Reference in New Issue
Block a user