RGB to Gray ConvertColor Preprocessor (#13855)

* Add RGB to Gray ConvertColor Preprocessor

* Revert some changes

* Additional updates

* Polish RGBToGray implementation

Move Conv computations to FP32 and add Round for INT cases

* Extend tests - add comparison with OpenCV

* ClangFormat

* Return commented code

* Make `pre_post_process` tests parametrized

* Add shape calculation for RGB ColorFormat. Make RGB->GRAY 4D only

* Update ov_core_unit_tests

* Update ov_template_func_tests

* Clang Formatting

* Fix python preprocess tests

* Fix warning: truncation from 'double' to 'float'
of RGB->GRAY weights

* Fix flake8 issues

* Fix `convert_color_asserts` test

* Treat GRAY as NHWC

* Resolve review issues with tests

* Fix Py preprocess tests after rebase

* Clang Format

* Add case if C < 0

* Add reference link to RGB coefficients

* Add doxygen documentation for ColorFormat enum
This commit is contained in:
Vitaliy Urusovskij 2023-01-13 13:50:30 +04:00 committed by GitHub
parent ded9156b98
commit 32e74f339e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 419 additions and 50 deletions

View File

@ -199,7 +199,7 @@ def test_graph_preprocess_output_postprocess():
def test_graph_preprocess_spatial_static_shape():
shape = [2, 2, 2]
shape = [3, 2, 2]
parameter_a = ops.parameter(shape, dtype=np.int32, name="A")
model = parameter_a
function = Model(model, [parameter_a], "TestFunction")
@ -210,7 +210,7 @@ def test_graph_preprocess_spatial_static_shape():
ppp = PrePostProcessor(function)
inp = ppp.input()
inp.tensor().set_layout(layout).set_spatial_static_shape(2, 2).set_color_format(color_format)
inp.preprocess().convert_element_type(Type.f32).mean([1., 2.])
inp.preprocess().convert_element_type(Type.f32).mean([1., 2., 3.])
inp.model().set_layout(layout)
out = ppp.output()
out.tensor().set_layout(layout).set_element_type(Type.f32)
@ -227,7 +227,7 @@ def test_graph_preprocess_spatial_static_shape():
]
assert len(model_operators) == 7
assert function.get_output_size() == 1
assert list(function.get_output_shape(0)) == [2, 2, 2]
assert list(function.get_output_shape(0)) == [3, 2, 2]
assert function.get_output_element_type(0) == Type.f32
for op in expected_ops:
assert op in model_operators
@ -363,9 +363,10 @@ def test_graph_preprocess_set_memory_type():
(ResizeAlgorithm.RESIZE_NEAREST, ColorFormat.BGR, ColorFormat.I420_THREE_PLANES, True),
(ResizeAlgorithm.RESIZE_NEAREST, ColorFormat.BGR, ColorFormat.NV12_SINGLE_PLANE, True),
(ResizeAlgorithm.RESIZE_NEAREST, ColorFormat.BGR, ColorFormat.NV12_TWO_PLANES, True),
(ResizeAlgorithm.RESIZE_NEAREST, ColorFormat.BGR, ColorFormat.UNDEFINED, True)])
(ResizeAlgorithm.RESIZE_NEAREST, ColorFormat.BGR, ColorFormat.UNDEFINED, True),
])
def test_graph_preprocess_steps(algorithm, color_format1, color_format2, is_failing):
shape = [1, 1, 3, 3]
shape = [1, 3, 3, 3]
parameter_a = ops.parameter(shape, dtype=np.float32, name="A")
model = parameter_a
function = Model(model, [parameter_a], "TestFunction")
@ -394,7 +395,7 @@ def test_graph_preprocess_steps(algorithm, color_format1, color_format2, is_fail
]
assert len(model_operators) == 16
assert function.get_output_size() == 1
assert list(function.get_output_shape(0)) == [1, 1, 3, 3]
assert list(function.get_output_shape(0)) == [1, 3, 3, 3]
assert function.get_output_element_type(0) == Type.f32
for op in expected_ops:
assert op in model_operators

View File

@ -9,19 +9,16 @@ namespace preprocess {
/// \brief Color format enumeration for conversion
enum class ColorFormat {
UNDEFINED,
NV12_SINGLE_PLANE, // Image in NV12 format as single tensor
/// \brief Image in NV12 format represented as separate tensors for Y and UV planes.
NV12_TWO_PLANES,
I420_SINGLE_PLANE, // Image in I420 (YUV) format as single tensor
/// \brief Image in I420 format represented as separate tensors for Y, U and V planes.
I420_THREE_PLANES,
RGB,
BGR,
/// \brief Image in RGBX interleaved format (4 channels)
RGBX,
/// \brief Image in BGRX interleaved format (4 channels)
BGRX
UNDEFINED, //!< Undefined color format
NV12_SINGLE_PLANE, //!< Image in NV12 format represented as separate tensors for Y and UV planes.
NV12_TWO_PLANES, //!< Image in NV12 format represented as separate tensors for Y and UV planes.
I420_SINGLE_PLANE, //!< Image in I420 (YUV) format as single tensor
I420_THREE_PLANES, //!< Image in I420 format represented as separate tensors for Y, U and V planes.
RGB, //!< Image in RGB interleaved format (3 channels)
BGR, //!< Image in BGR interleaved format (3 channels)
GRAY, //!< Image in GRAY format (1 channel)
RGBX, //!< Image in RGBX interleaved format (4 channels)
BGRX //!< Image in BGRX interleaved format (4 channels)
};
} // namespace preprocess

View File

@ -20,13 +20,18 @@ std::unique_ptr<ColorFormatInfo> ColorFormatInfo::get(ColorFormat format) {
res.reset(new ColorFormatInfoI420_ThreePlanes(format));
break;
case ColorFormat::RGB:
res.reset(new ColorFormatRGB(format));
break;
case ColorFormat::BGR:
res.reset(new ColorFormatNHWC(format));
res.reset(new ColorFormatBGR(format));
break;
case ColorFormat::RGBX:
case ColorFormat::BGRX:
res.reset(new ColorFormatInfo_RGBX_Base(format));
break;
case ColorFormat::GRAY:
res.reset(new ColorFormatNHWC(format));
break;
default:
res.reset(new ColorFormatInfo(format));
break;

View File

@ -43,6 +43,9 @@ inline std::string color_format_name(ColorFormat format) {
case ColorFormat::BGRX:
name = "BGRX";
break;
case ColorFormat::GRAY:
name = "GRAY";
break;
default:
name = "Unknown";
break;
@ -65,16 +68,17 @@ public:
return {};
}
// Calculate shape of plane based image shape in NHWC format
PartialShape shape(size_t plane_num, const PartialShape& image_src_shape) const {
PartialShape shape(size_t plane_num, const PartialShape& src_shape, const Layout& src_layout) const {
OPENVINO_ASSERT(plane_num < planes_count(),
"Internal error: incorrect plane number specified for color format");
return calculate_shape(plane_num, image_src_shape);
return calculate_shape(plane_num, src_shape, src_layout);
}
protected:
virtual PartialShape calculate_shape(size_t plane_num, const PartialShape& image_shape) const {
return image_shape;
virtual PartialShape calculate_shape(size_t plane_num,
const PartialShape& src_shape,
const Layout& src_layout) const {
return src_shape;
}
explicit ColorFormatInfo(ColorFormat format) : m_format(format) {}
@ -91,15 +95,38 @@ public:
}
};
// Assume that it should work with NHWC and NCHW cases, but default is NHWC
class ColorFormatRGB : public ColorFormatNHWC {
public:
explicit ColorFormatRGB(ColorFormat format) : ColorFormatNHWC(format) {}
protected:
PartialShape calculate_shape(size_t plane_num,
const PartialShape& src_shape,
const Layout& src_layout) const override {
PartialShape result = src_shape;
if (src_shape.rank().is_static()) {
auto c_idx = ov::layout::channels_idx(src_layout);
c_idx = c_idx < 0 ? c_idx + src_shape.size() : c_idx;
result[c_idx] = 3;
}
return result;
}
};
using ColorFormatBGR = ColorFormatRGB;
// Applicable for both NV12 and I420 formats
class ColorFormatInfoYUV420_Single : public ColorFormatNHWC {
public:
explicit ColorFormatInfoYUV420_Single(ColorFormat format) : ColorFormatNHWC(format) {}
protected:
PartialShape calculate_shape(size_t plane_num, const PartialShape& image_shape) const override {
PartialShape result = image_shape;
if (image_shape.rank().is_static() && image_shape.rank().get_length() == 4) {
PartialShape calculate_shape(size_t plane_num,
const PartialShape& src_shape,
const Layout& src_layout) const override {
PartialShape result = src_shape;
if (src_shape.rank().is_static() && src_shape.rank().get_length() == 4) {
result[3] = 1;
if (result[1].is_static()) {
result[1] = result[1].get_length() * 3 / 2;
@ -118,9 +145,11 @@ public:
}
protected:
PartialShape calculate_shape(size_t plane_num, const PartialShape& image_shape) const override {
PartialShape result = image_shape;
if (image_shape.rank().is_static() && image_shape.rank().get_length() == 4) {
PartialShape calculate_shape(size_t plane_num,
const PartialShape& src_shape,
const Layout& src_layout) const override {
PartialShape result = src_shape;
if (src_shape.rank().is_static() && src_shape.rank().get_length() == 4) {
if (plane_num == 0) {
result[3] = 1;
return result;
@ -148,9 +177,11 @@ public:
}
protected:
PartialShape calculate_shape(size_t plane_num, const PartialShape& image_shape) const override {
PartialShape result = image_shape;
if (image_shape.rank().is_static() && image_shape.rank().get_length() == 4) {
PartialShape calculate_shape(size_t plane_num,
const PartialShape& src_shape,
const Layout& src_layout) const override {
PartialShape result = src_shape;
if (src_shape.rank().is_static() && src_shape.rank().get_length() == 4) {
result[3] = 1; // Number of channels is always 1 for I420 planes
if (plane_num == 0) {
return result;
@ -173,9 +204,11 @@ public:
explicit ColorFormatInfo_RGBX_Base(ColorFormat format) : ColorFormatNHWC(format) {}
protected:
PartialShape calculate_shape(size_t plane_num, const PartialShape& image_shape) const override {
PartialShape result = image_shape;
if (image_shape.rank().is_static() && image_shape.rank().get_length() == 4) {
PartialShape calculate_shape(size_t plane_num,
const PartialShape& src_shape,
const Layout& src_layout) const override {
PartialShape result = src_shape;
if (src_shape.rank().is_static() && src_shape.rank().get_length() == 4) {
result[3] = 4;
return result;
}

View File

@ -95,7 +95,7 @@ InputInfo::InputInfoImpl::InputInfoData InputInfo::InputInfoImpl::create_new_par
// Create separate parameter for each plane. Shape is based on color format
for (size_t plane = 0; plane < color_info->planes_count(); plane++) {
auto plane_shape = color_info->shape(plane, new_param_shape);
auto plane_shape = color_info->shape(plane, new_param_shape, res.m_tensor_layout);
auto plane_param = std::make_shared<opset8::Parameter>(tensor_elem_type, plane_shape);
if (plane < get_tensor_data()->planes_sub_names().size()) {
std::unordered_set<std::string> plane_tensor_names;
@ -163,7 +163,7 @@ bool InputInfo::InputInfoImpl::build(const std::shared_ptr<Model>& model,
context.model_shape() = data.m_param->get_partial_shape();
context.target_element_type() = data.m_param->get_element_type();
// 2. Apply preprocessing
// Apply preprocessing
auto nodes = data.as_nodes();
for (const auto& action : get_preprocess()->actions()) {
auto action_result = action.m_op(nodes, model, context);
@ -176,11 +176,12 @@ bool InputInfo::InputInfoImpl::build(const std::shared_ptr<Model>& model,
"preprocessing operation. Current format is '",
color_format_name(context.color_format()),
"'");
OPENVINO_ASSERT(is_rgb_family(context.color_format()) || context.color_format() == ColorFormat::UNDEFINED,
"model shall have RGB/BGR color format. Consider add 'convert_color' preprocessing operation "
OPENVINO_ASSERT(is_rgb_family(context.color_format()) || context.color_format() == ColorFormat::GRAY ||
context.color_format() == ColorFormat::UNDEFINED,
"model shall have RGB/BGR/GRAY color format. Consider add 'convert_color' preprocessing operation "
"to convert current color format '",
color_format_name(context.color_format()),
"'to RGB/BGR");
"'to RGB/BGR/GRAY");
// Implicit: Convert element type + layout to user's tensor implicitly
auto implicit_steps = create_implicit_steps(context, nodes[0].get_element_type());
@ -193,6 +194,7 @@ bool InputInfo::InputInfoImpl::build(const std::shared_ptr<Model>& model,
if (node.get_partial_shape() != context.model_shape()) {
need_validate = true; // Trigger revalidation if input parameter shape is changed
}
// Check final shape
OPENVINO_ASSERT(node.get_partial_shape().compatible(context.model_shape()),
"Resulting shape '",

View File

@ -428,6 +428,64 @@ void PreStepsList::add_convert_color_impl(const ColorFormat& dst_format) {
context.color_format() = dst_format;
return res;
}
if ((context.color_format() == ColorFormat::RGB || context.color_format() == ColorFormat::BGR) &&
(dst_format == ColorFormat::GRAY)) {
auto node = nodes[0];
auto elem_type = node.get_element_type();
auto shape = node.get_partial_shape();
OPENVINO_ASSERT(shape.size() == 4,
"Input shape size should be equal to 4, actual size: ",
shape.size());
auto channels_idx = get_and_check_channels_idx(context.layout(), shape);
OPENVINO_ASSERT(shape[channels_idx] == 3,
"Channels dimesion should be equal to 3, actual value: ",
shape[channels_idx]);
auto is_transposed = false, is_converted = false;
if (channels_idx + 1 == shape.size()) {
// Transpose N...C to NC...
auto permutation = layout::utils::find_permutation(context.layout(), shape, ov::Layout{"NC..."});
auto perm_constant =
op::v0::Constant::create<int64_t>(element::i64, Shape{permutation.size()}, permutation);
node = std::make_shared<op::v1::Transpose>(node, perm_constant);
is_transposed = true;
}
if (elem_type.is_integral_number()) {
// Compute in floats due weights are floats
node = std::make_shared<op::v0::Convert>(node, element::f32);
is_converted = true;
}
// RGB coefficients were used from https://docs.opencv.org/3.4/de/d25/imgproc_color_conversions.html
auto weights_data = context.color_format() == ColorFormat::RGB
? std::vector<float>{0.299f, 0.587f, 0.114f}
: std::vector<float>{0.114f, 0.587f, 0.299f};
auto weights_shape = ov::Shape(shape.size(), 1);
weights_shape[1] = 3; // Set kernel layout to [1, 3, 1, ...]
auto weights_node = std::make_shared<ov::op::v0::Constant>(element::f32, weights_shape, weights_data);
node = std::make_shared<ov::op::v1::Convolution>(node,
weights_node,
ov::Strides(weights_shape.size() - 2, 1),
ov::CoordinateDiff(weights_shape.size() - 2, 0),
ov::CoordinateDiff(weights_shape.size() - 2, 0),
ov::Strides(weights_shape.size() - 2, 1));
if (is_converted) {
// Round values according to OpenCV rule before converting to integral values
auto round_val =
std::make_shared<ov::op::v5::Round>(node, ov::op::v5::Round::RoundMode::HALF_TO_EVEN);
node = std::make_shared<op::v0::Convert>(round_val, elem_type);
}
if (is_transposed) {
// Return NC... to N...C
auto permutation = layout::utils::find_permutation(ov::Layout{"NC..."}, shape, context.layout());
auto perm_constant =
op::v0::Constant::create<int64_t>(element::i64, Shape{permutation.size()}, permutation);
node = std::make_shared<op::v1::Transpose>(node, perm_constant);
}
context.color_format() = dst_format;
return std::make_tuple(std::vector<Output<Node>>{node}, true);
}
if (context.color_format() == ColorFormat::RGBX) {
if (dst_format == ColorFormat::RGB) {
auto res = cut_last_channel(nodes, function, context);

View File

@ -2,12 +2,14 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "common_test_utils/test_assertions.hpp"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/ops.hpp"
#include "openvino/core/preprocess/pre_post_process.hpp"
#include "openvino/opsets/opset8.hpp"
#include "openvino/util/common_util.hpp"
#include "preprocess/color_utils.hpp"
#include "util/test_tools.hpp"
using namespace ov;
@ -209,6 +211,140 @@ TEST(pre_post_process, tensor_element_type_and_scale) {
EXPECT_EQ(ov::layout::get_layout(f->input(0)), Layout());
}
class convert_color_to_gray : public ::testing::TestWithParam<ColorFormat> {};
TEST_P(convert_color_to_gray, nhwc_to_nhwc) {
const auto& COLOR_FORMAT = GetParam();
auto f = create_simple_function(element::f32, PartialShape{Dimension::dynamic(), 2, 2, 1});
auto p = PrePostProcessor(f);
auto name = f->get_parameters()[0]->get_friendly_name();
auto tensor_names = f->get_parameters().front()->get_output_tensor(0).get_names();
p.input().tensor().set_color_format(COLOR_FORMAT);
p.input().preprocess().convert_color(ColorFormat::GRAY);
f = p.build();
EXPECT_EQ(f->get_parameters().size(), 1);
EXPECT_EQ(f->get_parameters().front()->get_layout(), "NHWC");
EXPECT_EQ(ov::layout::get_layout(f->input(0)), "NHWC");
EXPECT_EQ(f->get_parameters().front()->get_friendly_name(), name);
EXPECT_EQ(f->get_parameters().front()->get_output_tensor(0).get_names(), tensor_names);
EXPECT_EQ(f->get_parameters().front()->get_partial_shape(), (PartialShape{Dimension::dynamic(), 2, 2, 3}));
EXPECT_EQ(f->get_result()->get_output_partial_shape(0), (PartialShape{Dimension::dynamic(), 2, 2, 1}));
}
TEST_P(convert_color_to_gray, nchw_to_nchw) {
const auto& COLOR_FORMAT = GetParam();
auto f = create_simple_function(element::f32, PartialShape{Dimension::dynamic(), 1, 2, 2});
auto p = PrePostProcessor(f);
p.input().tensor().set_layout("NCHW").set_color_format(COLOR_FORMAT);
p.input().model().set_layout("NCHW");
p.input().preprocess().convert_color(ColorFormat::GRAY);
f = p.build();
EXPECT_EQ(f->get_parameters().front()->get_layout(), "NCHW");
EXPECT_EQ(f->get_parameters().front()->get_partial_shape(), (PartialShape{Dimension::dynamic(), 3, 2, 2}));
EXPECT_EQ(f->get_result()->get_output_partial_shape(0), (PartialShape{Dimension::dynamic(), 1, 2, 2}));
}
TEST_P(convert_color_to_gray, nhwc_to_nchw) {
const auto& COLOR_FORMAT = GetParam();
auto f = create_simple_function(element::f32, PartialShape{Dimension::dynamic(), 1, 2, 2});
auto p = PrePostProcessor(f);
p.input().tensor().set_layout("NHWC").set_color_format(COLOR_FORMAT);
p.input().model().set_layout("NCHW");
p.input().preprocess().convert_color(ColorFormat::GRAY);
f = p.build();
EXPECT_EQ(f->get_parameters().front()->get_layout(), "NHWC");
EXPECT_EQ(f->get_parameters().front()->get_partial_shape(), (PartialShape{Dimension::dynamic(), 2, 2, 3}));
EXPECT_EQ(f->get_result()->get_output_partial_shape(0), (PartialShape{Dimension::dynamic(), 1, 2, 2}));
}
TEST_P(convert_color_to_gray, nchw_to_nhwc) {
const auto& COLOR_FORMAT = GetParam();
auto f = create_simple_function(element::f32, PartialShape{Dimension::dynamic(), 2, 2, 1});
auto p = PrePostProcessor(f);
p.input().tensor().set_layout("NCHW").set_color_format(COLOR_FORMAT);
p.input().model().set_layout("NHWC");
p.input().preprocess().convert_color(ColorFormat::GRAY);
f = p.build();
EXPECT_EQ(f->get_parameters().front()->get_layout(), "NCHW");
EXPECT_EQ(f->get_parameters().front()->get_partial_shape(), (PartialShape{Dimension::dynamic(), 3, 2, 2}));
EXPECT_EQ(f->get_result()->get_output_partial_shape(0), (PartialShape{Dimension::dynamic(), 2, 2, 1}));
}
TEST_P(convert_color_to_gray, assert_no_N_dim) {
const auto& COLOR_FORMAT = GetParam();
OV_EXPECT_THROW(auto f = create_simple_function(element::f32, PartialShape{Dimension::dynamic(), 2, 2, 1});
auto p = PrePostProcessor(f);
p.input().tensor().set_layout("DHWC").set_color_format(COLOR_FORMAT);
p.input().preprocess().convert_color(ColorFormat::GRAY);
p.build();
, ov::AssertFailure, ::testing::HasSubstr("Dimension name 'N' is not found in layout"));
}
TEST_P(convert_color_to_gray, assert_no_C_dim) {
const auto& COLOR_FORMAT = GetParam();
OV_EXPECT_THROW(auto f = create_simple_function(element::f32, PartialShape{Dimension::dynamic(), 2, 2, 1});
auto p = PrePostProcessor(f);
p.input().tensor().set_layout("NHWD").set_color_format(COLOR_FORMAT);
p.input().preprocess().convert_color(ColorFormat::GRAY);
p.build();
, ov::AssertFailure, ::testing::HasSubstr("C dimension index is not defined"));
}
TEST_P(convert_color_to_gray, assert_rank_less_4) {
const auto& COLOR_FORMAT = GetParam();
OV_EXPECT_THROW(auto f = create_simple_function(element::f32, PartialShape{Dimension::dynamic(), 2, 3});
auto p = PrePostProcessor(f);
p.input().tensor().set_layout("NHC").set_color_format(COLOR_FORMAT);
p.input().preprocess().convert_color(ColorFormat::GRAY);
p.build();
, ov::AssertFailure, ::testing::HasSubstr("Input shape size should be equal to 4, actual size: 3"));
}
TEST_P(convert_color_to_gray, assert_rank_more_4) {
const auto& COLOR_FORMAT = GetParam();
OV_EXPECT_THROW(auto f = create_simple_function(element::f32, PartialShape{Dimension::dynamic(), 2, 2, 2, 3});
auto p = PrePostProcessor(f);
p.input().tensor().set_layout("NDHWC").set_color_format(COLOR_FORMAT);
p.input().preprocess().convert_color(ColorFormat::GRAY);
p.build();
, ov::AssertFailure, ::testing::HasSubstr("Input shape size should be equal to 4, actual size: 5"));
}
TEST_P(convert_color_to_gray, assert_C_not_equal_1) {
const auto& COLOR_FORMAT = GetParam();
OV_EXPECT_THROW(auto f = create_simple_function(element::f32, PartialShape{Dimension::dynamic(), 2, 2, 5});
auto p = PrePostProcessor(f);
p.input().tensor().set_layout("NHWC").set_color_format(COLOR_FORMAT);
p.input().preprocess().convert_color(ColorFormat::GRAY);
p.build();
,
ov::AssertFailure,
::testing::HasSubstr("Resulting shape '[?,2,2,1]' after preprocessing is not aligned with original "
"parameter's shape: [?,2,2,5]"));
}
INSTANTIATE_TEST_SUITE_P(pre_post_process,
convert_color_to_gray,
::testing::Values(ColorFormat::RGB, ColorFormat::BGR),
[](const ::testing::TestParamInfo<convert_color_to_gray::ParamType>& info) {
std::string name =
color_format_name(info.param) + "_to_" + color_format_name(ColorFormat::GRAY);
return name;
});
TEST(pre_post_process, convert_color_nv12_rgb_single) {
auto f = create_simple_function(element::f32, PartialShape{Dimension::dynamic(), 2, 2, 3});
auto p = PrePostProcessor(f);

View File

@ -66,7 +66,7 @@ protected:
// convolution
std::shared_ptr<ngraph::opset1::Constant> weightsNode = nullptr;
ngraph::Shape convFilterShape = { channelsCount, channelsCount, 3, 3 }; // out channel, /input channels, kernel h, kernel w
ngraph::Shape convFilterShape = { channelsCount, channelsCount, 3, 3 }; // out channel, input channels, kernel h, kernel w
if (netPrecision == Precision::FP32) {
std::vector<float> weightValuesFP32;
weightValuesFP32.resize(channelsCount * channelsCount * 3 * 3);

View File

@ -29,7 +29,12 @@ public:
}
};
class PreprocessOpenCVReferenceTest_YUV : public PreprocessOpenCVReferenceTest {
/// \brief Test class with counting deviated pixels
///
/// OpenCV contains custom implementation for 8U and 16U (all calculations
/// are done in INTs instead of FLOATs), so deviation in 1 color step
/// between pixels is expected
class PreprocessOpenCVReferenceTest_8U : public PreprocessOpenCVReferenceTest {
public:
void Validate() override {
threshold = 1.f;
@ -58,7 +63,141 @@ static std::shared_ptr<Model> create_simple_function(element::Type type, const P
return std::make_shared<ov::Model>(ResultVector{res}, ParameterVector{data1});
}
TEST_F(PreprocessOpenCVReferenceTest_YUV, convert_i420_full_color_range) {
TEST_F(PreprocessOpenCVReferenceTest, convert_rgb_gray_fp32) {
const size_t input_height = 50;
const size_t input_width = 50;
auto input_shape = Shape{1, input_height, input_width, 3};
auto model_shape = Shape{1, input_height, input_width, 1};
auto input_img = std::vector<float> (shape_size(input_shape));
std::default_random_engine random(0); // hard-coded seed to make test results predictable
std::uniform_int_distribution<int> distrib(-5, 300);
for (std::size_t i = 0; i < shape_size(input_shape); i++)
input_img[i] = static_cast<float>(distrib(random));
function = create_simple_function(element::f32, model_shape);
inputData.clear();
auto p = PrePostProcessor(function);
p.input().tensor().set_color_format(ColorFormat::RGB);
p.input().preprocess().convert_color(ColorFormat::GRAY);
function = p.build();
const auto &param = function->get_parameters()[0];
inputData.emplace_back(param->get_element_type(), param->get_shape(), input_img.data());
// Calculate reference expected values from OpenCV
cv::Mat cvPic = cv::Mat(input_height, input_width, CV_32FC3, input_img.data());
cv::Mat picGRAY;
cv::cvtColor(cvPic, picGRAY, CV_RGB2GRAY);
refOutData.emplace_back(param->get_element_type(), model_shape, picGRAY.data);
// Exec now
Exec();
}
TEST_F(PreprocessOpenCVReferenceTest, convert_bgr_gray_fp32) {
const size_t input_height = 50;
const size_t input_width = 50;
auto input_shape = Shape{1, input_height, input_width, 3};
auto model_shape = Shape{1, input_height, input_width, 1};
auto input_img = std::vector<float> (shape_size(input_shape));
std::default_random_engine random(0); // hard-coded seed to make test results predictable
std::uniform_int_distribution<int> distrib(-5, 300);
for (std::size_t i = 0; i < shape_size(input_shape); i++)
input_img[i] = static_cast<float>(distrib(random));
function = create_simple_function(element::f32, model_shape);
inputData.clear();
auto p = PrePostProcessor(function);
p.input().tensor().set_color_format(ColorFormat::BGR);
p.input().preprocess().convert_color(ColorFormat::GRAY);
function = p.build();
const auto &param = function->get_parameters()[0];
inputData.emplace_back(param->get_element_type(), param->get_shape(), input_img.data());
// Calculate reference expected values from OpenCV
cv::Mat cvPic = cv::Mat(input_height, input_width, CV_32FC3, input_img.data());
cv::Mat picGRAY;
cv::cvtColor(cvPic, picGRAY, CV_BGR2GRAY);
refOutData.emplace_back(param->get_element_type(), model_shape, picGRAY.data);
// Exec now
Exec();
}
TEST_F(PreprocessOpenCVReferenceTest_8U, convert_rgb_gray_u8) {
const size_t input_height = 50;
const size_t input_width = 50;
auto input_shape = Shape{1, input_height, input_width, 3};
auto model_shape = Shape{1, input_height, input_width, 1};
auto input_img = std::vector<float> (shape_size(input_shape));
std::default_random_engine random(0); // hard-coded seed to make test results predictable
std::uniform_int_distribution<int> distrib(0, 255);
for (std::size_t i = 0; i < shape_size(input_shape); i++)
input_img[i] = static_cast<uint8_t>(distrib(random));
function = create_simple_function(element::u8, model_shape);
inputData.clear();
auto p = PrePostProcessor(function);
p.input().tensor().set_color_format(ColorFormat::RGB);
p.input().preprocess().convert_color(ColorFormat::GRAY);
function = p.build();
const auto &param = function->get_parameters()[0];
inputData.emplace_back(param->get_element_type(), param->get_shape(), input_img.data());
// Calculate reference expected values from OpenCV
cv::Mat cvPic = cv::Mat(input_height, input_width, CV_8UC3, input_img.data());
cv::Mat picGRAY;
cv::cvtColor(cvPic, picGRAY, CV_RGB2GRAY);
refOutData.emplace_back(param->get_element_type(), model_shape, picGRAY.data);
// Exec now
Exec();
}
TEST_F(PreprocessOpenCVReferenceTest_8U, convert_bgr_gray_u8) {
const size_t input_height = 50;
const size_t input_width = 50;
auto input_shape = Shape{1, input_height, input_width, 3};
auto model_shape = Shape{1, input_height, input_width, 1};
auto input_img = std::vector<uint8_t> (shape_size(input_shape));
std::default_random_engine random(0); // hard-coded seed to make test results predictable
std::uniform_int_distribution<int> distrib(0, 255);
for (std::size_t i = 0; i < shape_size(input_shape); i++)
input_img[i] = static_cast<uint8_t>(distrib(random));
function = create_simple_function(element::u8, model_shape);
inputData.clear();
auto p = PrePostProcessor(function);
p.input().tensor().set_color_format(ColorFormat::BGR);
p.input().preprocess().convert_color(ColorFormat::GRAY);
function = p.build();
const auto &param = function->get_parameters()[0];
inputData.emplace_back(param->get_element_type(), param->get_shape(), input_img.data());
// Calculate reference expected values from OpenCV
cv::Mat cvPic = cv::Mat(input_height, input_width, CV_8UC3, input_img.data());
cv::Mat picGRAY;
cv::cvtColor(cvPic, picGRAY, CV_BGR2GRAY);
refOutData.emplace_back(param->get_element_type(), model_shape, picGRAY.data);
// Exec now
Exec();
}
TEST_F(PreprocessOpenCVReferenceTest_8U, convert_i420_full_color_range) {
size_t height = 64; // 64/2 = 32 values for R
size_t width = 64; // 64/2 = 32 values for G
int b_step = 5;
@ -94,7 +233,7 @@ TEST_F(PreprocessOpenCVReferenceTest_YUV, convert_i420_full_color_range) {
Exec();
}
TEST_F(PreprocessOpenCVReferenceTest_YUV, convert_nv12_full_color_range) {
TEST_F(PreprocessOpenCVReferenceTest_8U, convert_nv12_full_color_range) {
size_t height = 64; // 64/2 = 32 values for R
size_t width = 64; // 64/2 = 32 values for G
int b_step = 5;
@ -130,7 +269,7 @@ TEST_F(PreprocessOpenCVReferenceTest_YUV, convert_nv12_full_color_range) {
Exec();
}
TEST_F(PreprocessOpenCVReferenceTest_YUV, convert_nv12_colored) {
TEST_F(PreprocessOpenCVReferenceTest_8U, convert_nv12_colored) {
auto input_yuv = std::vector<uint8_t> {235, 81, 235, 81, 109, 184};
auto func_shape = Shape{1, 2, 2, 3};
function = create_simple_function(element::u8, func_shape);
@ -180,9 +319,7 @@ TEST_F(PreprocessOpenCVReferenceTest, resize_u8_simple_linear) {
Exec();
}
TEST_F(PreprocessOpenCVReferenceTest, resize_u8_large_picture_linear) {
threshold = 1.f;
abs_threshold = 1.f; // Some pixels still have deviations of 1 step
TEST_F(PreprocessOpenCVReferenceTest_8U, resize_u8_large_picture_linear) {
const size_t input_height = 50;
const size_t input_width = 50;
const size_t func_height = 37;