Files
openvino/tests/layer_tests/pytorch_tests/test_if.py
Maxim Vafin abaf61d059 Improve detectron2 support (#16011)
* Improve op support for detectron mask rcnn

* Initial commit

* Fix for reading processed list

* Format code

* Cleanup

* cleanup

* Cleanup

* cleanup test

* Add comment

* Add rt_info

* fix type

* More fixes for detectron

* Fix build

* Add tests for if

* Revert changes in index

* Add comment

* Fix test

* Fix get_axes_range

* Add tests and fix if type alignment

* Fix code style

---------

Co-authored-by: Mateusz <mateusz.mikolajczyk@intel.com>
2023-03-23 22:30:03 +00:00

41 lines
1.2 KiB
Python

# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import pytest
import numpy as np
from pytorch_layer_test_class import PytorchLayerTest
class TestIf(PytorchLayerTest):
    """Layer test for a data-dependent ``if`` in a PyTorch module (``prim::If``).

    The module under test picks one of two branches depending on the sign of
    a scalar input ``y``; the test is run once for a positive and once for a
    negative ``y`` so both branches are exercised.
    """

    def _prepare_input(self):
        """Return the inputs fed to the model.

        Returns:
            A tuple ``(x, y)`` where ``x`` is a random ``float32`` array of
            shape ``(1, 3, 224, 224)`` and ``y`` is the value stored on the
            instance by ``test_if``.
        """
        return (np.random.randn(1, 3, 224, 224).astype(np.float32), self.y)

    def create_model(self):
        """Build the module under test.

        Returns:
            A tuple ``(model, ref_net, kind)``: an instance of the ``prim_if``
            module, ``None`` as the reference net, and the string
            ``"prim::If"`` (presumably the op kind the test harness looks for
            in the converted graph -- confirm against ``PytorchLayerTest``).
        """
        # Imported lazily, inside the method, as in the original test.
        import torch

        class prim_if(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                # The two branches produce tensors that differ in both dtype
                # (uint8 vs. bool) and shape ((0, 10) vs. x.shape[:2]), so the
                # outputs of the ``if`` are not type-aligned until the final
                # cast to bool below.
                if y > 0:
                    res = x.new_empty((0, 10), dtype=torch.uint8)
                else:
                    res = torch.zeros(x.shape[:2], dtype=torch.bool)
                return res.to(torch.bool)

        ref_net = None

        return prim_if(), ref_net, "prim::If"

    @pytest.mark.parametrize("y", [np.array(1),
                                   np.array(-1)
                                   ])
    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_if(self, y, ie_device, precision, ir_version):
        """Run the layer test for one value of ``y``.

        ``y`` is stored on the instance because ``_prepare_input`` takes no
        arguments and reads it from ``self.y``.
        """
        self.y = y
        self._test(*self.create_model(), ie_device, precision, ir_version)