[PT FE] Add support for GFPGAN model (#21371)
* [PT FE] Add support for GFPGAN model
* Remove logs
* Fix codestyle
* Add support for aten::normal
parent cb5377fb1d
commit 007b6fd82c
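Usage note (not part of the diff): the ops added here are what the PyTorch frontend encounters when a traced model draws Gaussian noise. A minimal sketch, assuming a hypothetical module named NoisyIdentity; conversion goes through openvino.convert_model, the same entry point the new tests below use.

# Illustrative only: a hypothetical module that lowers to aten::normal_,
# converted through the PyTorch frontend like the tests added in this commit.
import torch
from openvino import convert_model


class NoisyIdentity(torch.nn.Module):
    def forward(self, x):
        noise = torch.empty_like(x)
        noise.normal_(mean=0.0, std=1.0)  # traced as aten::normal_
        return x + noise


ov_model = convert_model(NoisyIdentity(), example_input=(torch.randn(1, 3, 8, 8),))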
@@ -25,7 +25,11 @@ namespace op {
 using namespace ov::op;
 
 namespace {
-OutputVector make_random_normal(const NodeContext& context, Output<Node> sizes, element::Type target_type) {
+OutputVector make_random_normal(const NodeContext& context,
+                                const Output<Node>& sizes,
+                                element::Type target_type,
+                                const Output<Node>& scale_const,
+                                const Output<Node>& mean_const) {
     std::random_device rd;
     std::mt19937 gen(rd());
     std::uniform_int_distribution<uint64_t> distrib(0, 9999);
@@ -57,8 +61,6 @@ OutputVector make_random_normal(const NodeContext& context, Output<Node> sizes,
     auto multiply_two_pi_uniform_2 = context.mark_node(std::make_shared<v1::Multiply>(multiply_two_pi, uniform_2));
     auto cos = context.mark_node(std::make_shared<v0::Cos>(multiply_two_pi_uniform_2));
 
-    auto scale_const = context.mark_node(v0::Constant::create(target_type, Shape{1}, {1}));
-    auto mean_const = context.mark_node(v0::Constant::create(target_type, Shape{1}, {0}));
     auto sqrt_x_cos = context.mark_node(std::make_shared<v1::Multiply>(sqrt, cos));
     auto product = context.mark_node(std::make_shared<v1::Multiply>(scale_const, sqrt_x_cos));
     auto sum = context.mark_node(std::make_shared<v1::Add>(product, mean_const));
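For reference, the helper around this hunk appears to build a Box-Muller transform (a square root of the log of one uniform draw times the cosine of 2*pi times another), and the change above threads scale and mean through as parameters instead of fixed constants. A NumPy sketch of the same math, offered as an illustration of the technique rather than as the frontend code itself:

# Box-Muller sketch: two uniform draws -> one standard normal sample,
# then scaled and shifted the way the new scale/mean parameters are applied.
import numpy as np

rng = np.random.default_rng(0)
u1 = rng.uniform(size=1000)
u2 = rng.uniform(size=1000)
z = np.sqrt(-2.0 * np.log(u1)) * np.cos(2.0 * np.pi * u2)  # ~ N(0, 1)
scale, mean = 1.0, 0.0  # the defaults the randn translators now pass in
samples = scale * z + mean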
@@ -180,7 +182,9 @@ OutputVector translate_randn(const NodeContext& context) {
     // aten::randn.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
     // aten::randn.generator_out(SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)
     if (context.get_input_size() == 2 || context.get_input_size() == 3) {
-        auto res = make_random_normal(context, sizes, dtype);
+        auto scale = context.mark_node(v0::Constant::create(dtype, Shape{1}, {1}));
+        auto mean = context.mark_node(v0::Constant::create(dtype, Shape{1}, {0}));
+        auto res = make_random_normal(context, sizes, dtype, scale, mean);
         context.mutate_input(out_id, res[0]);
         return res;
     }
@@ -210,7 +214,9 @@ OutputVector translate_randn(const NodeContext& context) {
             FRONT_END_OP_CONVERSION_CHECK(false, "Couldn't get dtype input");
         }
     }
-    auto res = make_random_normal(context, sizes, dtype);
+    auto scale = context.mark_node(v0::Constant::create(dtype, Shape{1}, {1}));
+    auto mean = context.mark_node(v0::Constant::create(dtype, Shape{1}, {0}));
+    auto res = make_random_normal(context, sizes, dtype, scale, mean);
     if (!dtype_applied) {
         res[0] = context.mark_node(std::make_shared<v1::ConvertLike>(res[0], convert_like_out));
     }
@@ -226,7 +232,9 @@ OutputVector translate_randn_like(const NodeContext& context) {
     auto sizes = context.mark_node(std::make_shared<v3::ShapeOf>(inp_tensor, element::i32));
     auto dtype = element::f32;
     if (context.get_input_size() == 3) {
-        auto res = make_random_normal(context, sizes, dtype);
+        auto scale = context.mark_node(v0::Constant::create(dtype, Shape{1}, {1}));
+        auto mean = context.mark_node(v0::Constant::create(dtype, Shape{1}, {0}));
+        auto res = make_random_normal(context, sizes, dtype, scale, mean);
         context.mutate_input(2, res[0]);
         return res;
     }
@@ -246,7 +254,9 @@ OutputVector translate_randn_like(const NodeContext& context) {
            FRONT_END_OP_CONVERSION_CHECK(false, "Couldn't get dtype input");
        }
    }
-    auto res = make_random_normal(context, sizes, dtype);
+    auto scale = context.mark_node(v0::Constant::create(dtype, Shape{1}, {1}));
+    auto mean = context.mark_node(v0::Constant::create(dtype, Shape{1}, {0}));
+    auto res = make_random_normal(context, sizes, dtype, scale, mean);
     if (!dtype_applied) {
         res[0] = context.mark_node(std::make_shared<v1::ConvertLike>(res[0], convert_like_out));
     }
@@ -283,6 +293,84 @@ OutputVector translate_randint(const NodeContext& context) {
     return {res};
 };
 
+OutputVector translate_normal_(const NodeContext& context) {
+    // aten::normal_(Tensor(a!) self, float mean=0., float std=1., *, Generator? generator=None) -> Tensor(a!)
+    num_inputs_check(context, 3, 4);
+    auto inp_tensor = context.get_input(0);
+    auto mean = context.get_input(1);
+    auto std = context.get_input(2);
+    auto sizes = context.mark_node(std::make_shared<v3::ShapeOf>(inp_tensor, element::i32));
+    auto dtype = element::f32;
+    auto res = make_random_normal(context, sizes, dtype, std, mean);
+    res[0] = context.mark_node(std::make_shared<v1::ConvertLike>(res[0], inp_tensor));
+    context.mutate_input(0, res[0]);
+    return res;
+}
+
+OutputVector translate_normal(const NodeContext& context) {
+    num_inputs_check(context, 2, 8);
+    auto mean = context.get_input(0);
+    auto std = context.get_input(1);
+    auto dtype = element::f32;
+    if (context.get_input_size() == 3 || context.get_input_size() == 4) {
+        // aten::normal.Tensor_float(Tensor mean, float std=1., *, Generator? generator=None) -> Tensor
+        // aten::normal.Tensor_Tensor(Tensor mean, Tensor std, *, Generator? generator=None) -> Tensor
+        // aten::normal.Tensor_float_out(Tensor mean, float std=1., *, Generator? generator=None, Tensor(a!) out) ->
+        // Tensor(a!)
+        // aten::normal.Tensor_float_out(Tensor mean, float std=1., *, Generator? generator=None, Tensor(a!)
+        // out) -> Tensor(a!)
+        // aten::normal.Tensor_Tensor_out(Tensor mean, Tensor std, *, Generator? generator=None,
+        // Tensor(a!) out) -> Tensor(a!)
+        auto sizes = context.mark_node(std::make_shared<v3::ShapeOf>(mean, element::i32));
+        auto res = make_random_normal(context, sizes, dtype, std, mean);
+        if (!context.input_is_none(3)) {
+            // out
+            auto out = context.get_input(3);
+            res[0] = context.mark_node(std::make_shared<v1::ConvertLike>(res[0], out));
+            context.mutate_input(3, res[0]);
+        }
+        return res;
+    } else if (context.get_input_size() == 5) {
+        // aten::normal.float_float_out(float mean, float std, SymInt[] size, *, Generator? generator=None, Tensor(a!)
+        // out) -> Tensor(a!)
+        auto sizes = context.get_input(2);
+        auto res = make_random_normal(context, sizes, dtype, std, mean);
+        if (!context.input_is_none(4)) {
+            // out
+            auto out = context.get_input(4);
+            res[0] = context.mark_node(std::make_shared<v1::ConvertLike>(res[0], out));
+            context.mutate_input(4, res[0]);
+        }
+        return res;
+    } else if (context.get_input_size() == 8) {
+        // aten::normal.float_float(float mean, float std, SymInt[] size, *, Generator? generator=None, ScalarType?
+        // dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+        auto sizes = context.get_input(2);
+        Output<Node> convert_like_out;
+        bool dtype_applied = true;
+        if (!context.input_is_none(4)) {
+            if (std::dynamic_pointer_cast<v0::Constant>(
+                    context.get_input_from_visible_context(3).get_node_shared_ptr())) {
+                dtype = convert_dtype(context.const_input<int64_t>(4));
+            } else if (const auto& fw_node = cast_fw_node(context.get_input(3).get_node_shared_ptr(), "prim::dtype")) {
+                convert_like_out = fw_node->input_value(0);
+                dtype_applied = false;
+            } else {
+                FRONT_END_OP_CONVERSION_CHECK(false, "Couldn't get dtype input");
+            }
+        }
+        auto res = make_random_normal(context, sizes, dtype, std, mean);
+        if (!dtype_applied) {
+            res[0] = context.mark_node(std::make_shared<v1::ConvertLike>(res[0], convert_like_out));
+        }
+        return res;
+    } else {
+        FRONT_END_OP_CONVERSION_CHECK(false,
+                                      "Unsupported number of inputs to aten::normal operation: ",
+                                      context.get_input_size());
+    }
+}
+
 } // namespace op
 } // namespace pytorch
 } // namespace frontend
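To keep the schemas above straight, these are the PyTorch call forms that map onto them. The snippet is illustrative only and mirrors what the new layer tests later in this commit exercise through small wrapper modules.

# Illustrative mapping from Python calls to the aten::normal schemas handled above.
import torch

mean_t = torch.zeros(2, 3)
std_t = torch.ones(2, 3)
out = torch.empty(2, 3)

a = torch.zeros(2, 3).normal_(mean=0.0, std=1.0)  # aten::normal_ (in-place)
b = torch.normal(mean_t, std_t)                   # aten::normal.Tensor_Tensor
c = torch.normal(mean_t, 2.0)                     # aten::normal.Tensor_float
d = torch.normal(0.0, 1.0, (2, 3))                # aten::normal.float_float
e = torch.normal(mean_t, std_t, out=out)          # aten::normal.Tensor_Tensor_out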
@@ -129,6 +129,8 @@ OP_CONVERTER(translate_new_zeros);
 OP_CONVERTER(translate_nms);
 OP_CONVERTER(translate_nonzero);
 OP_CONVERTER(translate_norm);
+OP_CONVERTER(translate_normal);
+OP_CONVERTER(translate_normal_);
 OP_CONVERTER(translate_not);
 OP_CONVERTER(translate_numel);
 OP_CONVERTER(translate_one_hot);
@@ -438,6 +440,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
        {"aten::new_zeros", op::translate_new_zeros},
        {"aten::nonzero", op::translate_nonzero},
        {"aten::norm", op::translate_norm},
+       {"aten::normal", op::translate_normal},
+       {"aten::normal_", op::translate_normal_},
        {"aten::numel", op::translate_numel},
        {"aten::numpy_T", op::translate_t},
        {"aten::one_hot", op::translate_one_hot},
tests/layer_tests/pytorch_tests/test_rand.py (new file, 90 lines)
@@ -0,0 +1,90 @@
+# Copyright (C) 2018-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+import torch
+
+from pytorch_layer_test_class import PytorchLayerTest
+
+
+class TestInplaceNormal(PytorchLayerTest):
+    def _prepare_input(self):
+        import numpy as np
+        return (np.random.randn(1, 3, 224, 224).astype(np.float32),)
+
+    def create_model(self, mean, std):
+        class aten_normal(torch.nn.Module):
+            def __init__(self, mean, std):
+                super(aten_normal, self).__init__()
+                self.mean = mean
+                self.std = std
+
+            def forward(self, x):
+                x = x.to(torch.float32)
+                return x.normal_(mean=self.mean, std=self.std), x
+
+        return aten_normal(mean, std), None, "aten::normal_"
+
+    @pytest.mark.parametrize("mean,std", [(0., 1.), (5., 20.)])
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_inplace_normal(self, mean, std, ie_device, precision, ir_version):
+        self._test(*self.create_model(mean, std),
+                   ie_device, precision, ir_version, custom_eps=1e30)
+
+
+class TestNormal(PytorchLayerTest):
+    def _prepare_input(self):
+        import numpy as np
+        if isinstance(self.inputs, list):
+            return (np.random.randn(*self.inputs).astype(np.float32),)
+        return self.inputs
+
+    class aten_normal1(torch.nn.Module):
+        def forward(self, mean, std):
+            return torch.normal(mean, std)
+
+    class aten_normal2(torch.nn.Module):
+        def forward(self, mean, std):
+            x = torch.empty_like(mean, dtype=torch.float32)
+            return torch.normal(mean, std, out=x), x
+
+    class aten_normal3(torch.nn.Module):
+        def forward(self, mean):
+            return torch.normal(mean)
+
+    class aten_normal4(torch.nn.Module):
+        def forward(self, mean):
+            x = torch.empty_like(mean, dtype=torch.float32)
+            return torch.normal(mean, out=x), x
+
+    class aten_normal5(torch.nn.Module):
+        def forward(self, mean):
+            x = torch.empty_like(mean, dtype=torch.float32)
+            return torch.normal(mean, 2., out=x), x
+
+    class aten_normal6(torch.nn.Module):
+        def forward(self, x):
+            x = x.to(torch.float32)
+            return torch.normal(0., 1., x.shape)
+
+    class aten_normal7(torch.nn.Module):
+        def forward(self, x):
+            x = x.to(torch.float32)
+            return torch.normal(0., 1., x.shape, out=x), x
+
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    @pytest.mark.parametrize("model,inputs", [
+        (aten_normal1(), (torch.arange(1., 11.).numpy(), torch.arange(1, 0, -0.1).numpy())),
+        (aten_normal2(), (torch.arange(1., 11.).numpy(), torch.arange(1, 0, -0.1).numpy())),
+        (aten_normal3(), (torch.arange(1., 11.).numpy(),)),
+        (aten_normal4(), (torch.arange(1., 11.).numpy(),)),
+        (aten_normal5(), (torch.arange(1., 11.).numpy(),)),
+        (aten_normal6(), [1, 3, 224, 224]),
+        (aten_normal7(), [1, 3, 224, 224]),
+    ])
+    def test_inplace_normal(self, model, inputs, ie_device, precision, ir_version):
+        self.inputs = inputs
+        self._test(model, None, "aten::normal",
+                   ie_device, precision, ir_version, custom_eps=1e30)
@@ -15,3 +15,5 @@ protobuf
 soundfile
 pandas
 super-image
+basicsr
+facexlib
tests/model_hub_tests/torch_tests/test_gfpgan.py (new file, 81 lines)
@@ -0,0 +1,81 @@
+# Copyright (C) 2018-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+import subprocess
+import sys
+import tempfile
+
+import pytest
+import torch
+
+from torch_utils import TestTorchConvertModel
+from openvino import convert_model
+import numpy as np
+
+# To make tests reproducible we seed the random generator
+torch.manual_seed(0)
+
+
+class TestGFPGANConvertModel(TestTorchConvertModel):
+    def setup_class(self):
+        self.repo_dir = tempfile.TemporaryDirectory()
+        os.system(
+            f"git clone https://github.com/TencentARC/GFPGAN.git {self.repo_dir.name}")
+        subprocess.check_call(
+            ["git", "checkout", "bc5a5deb95a4a9653851177985d617af1b9bfa8b"], cwd=self.repo_dir.name)
+        checkpoint_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth"
+        subprocess.check_call(
+            ["wget", "-nv", checkpoint_url], cwd=self.repo_dir.name)
+
+    def load_model(self, model_name, model_link):
+        sys.path.append(self.repo_dir.name)
+        from gfpgan import GFPGANer
+
+        filename = os.path.join(self.repo_dir.name, 'GFPGANv1.3.pth')
+        arch = 'clean'
+        channel_multiplier = 2
+        restorer = GFPGANer(
+            model_path=filename,
+            upscale=2,
+            arch=arch,
+            channel_multiplier=channel_multiplier,
+            bg_upsampler=None)
+
+        self.example = (torch.randn(1, 3, 512, 512),)
+        self.inputs = (torch.randn(1, 3, 512, 512),)
+        return restorer.gfpgan
+
+    def convert_model(self, model_obj):
+        ov_model = convert_model(
+            model_obj, example_input=self.example, input=[1, 3, 512, 512], verbose=True)
+        return ov_model
+
+    def compare_results(self, fw_outputs, ov_outputs):
+        assert len(fw_outputs) == len(ov_outputs), \
+            "Different number of outputs between framework and OpenVINO:" \
+            " {} vs. {}".format(len(fw_outputs), len(ov_outputs))
+
+        fw_eps = 5e-2
+        is_ok = True
+        for i in range(len(ov_outputs)):
+            cur_fw_res = fw_outputs[i]
+            cur_ov_res = ov_outputs[i]
+            try:
+                np.testing.assert_allclose(
+                    cur_ov_res, cur_fw_res, fw_eps, fw_eps)
+            except AssertionError as e:
+                print(e)
+                # The model has the aten::normal_ operation, which produces random numbers.
+                # Cannot reliably validate output 0.
+                if i != 0:
+                    is_ok = False
+        assert is_ok, "Accuracy validation failed"
+
+    def teardown_class(self):
+        # remove all downloaded files from cache
+        self.repo_dir.cleanup()
+
+    @pytest.mark.nightly
+    def test_convert_model(self, ie_device):
+        self.run("GFPGAN", None, ie_device)