diff --git a/docs/template_plugin/tests/functional/subgraph_reference/preprocess.cpp b/docs/template_plugin/tests/functional/subgraph_reference/preprocess.cpp index c7c74ccf986..af3fb3ccb61 100644 --- a/docs/template_plugin/tests/functional/subgraph_reference/preprocess.cpp +++ b/docs/template_plugin/tests/functional/subgraph_reference/preprocess.cpp @@ -746,6 +746,101 @@ static RefPreprocessParams pre_and_post_processing() { return res; } +static RefPreprocessParams rgb_to_bgr() { + RefPreprocessParams res("rgb_to_bgr"); + res.function = []() { + auto f = create_simple_function(element::f32, Shape{2, 1, 1, 3}); + f = PrePostProcessor().input(InputInfo() + .tensor(InputTensorInfo().set_color_format(ColorFormat::RGB)) + .preprocess(PreProcessSteps().convert_color(ColorFormat::BGR))).build(f); + return f; + }; + + res.inputs.emplace_back(Shape{2, 3, 1, 1}, element::f32, std::vector{1, 2, 3, 4, 5, 6}); + res.expected.emplace_back(Shape{2, 3, 1, 1}, element::f32, std::vector{3, 2, 1, 6, 5, 4}); + return res; +} + +static RefPreprocessParams bgr_to_rgb() { + RefPreprocessParams res("bgr_to_rgb"); + res.function = []() { + auto f = create_simple_function(element::f32, Shape{2, 1, 1, 3}); + f = PrePostProcessor().input(InputInfo() + .tensor(InputTensorInfo().set_color_format(ColorFormat::BGR)) + .preprocess(PreProcessSteps().convert_color(ColorFormat::RGB))).build(f); + return f; + }; + + res.inputs.emplace_back(Shape{2, 3, 1, 1}, element::f32, std::vector{1, 2, 3, 4, 5, 6}); + res.expected.emplace_back(Shape{2, 3, 1, 1}, element::f32, std::vector{3, 2, 1, 6, 5, 4}); + return res; +} + +static RefPreprocessParams reverse_channels_nchw() { + RefPreprocessParams res("reverse_channels_nchw"); + res.function = []() { + auto f = create_simple_function(element::f32, PartialShape{1, 2, 2, 2}); + f = PrePostProcessor().input(InputInfo() + .tensor(InputTensorInfo().set_layout("NCHW")) + .preprocess(PreProcessSteps().reverse_channels())).build(f); + return f; + }; + + res.inputs.emplace_back(Shape{1, 2, 2, 2}, element::f32, std::vector{1, 2, 3, 4, 5, 6, 7, 8}); + res.expected.emplace_back(Shape{1, 2, 2, 2}, element::f32, std::vector{5, 6, 7, 8, 1, 2, 3, 4}); + return res; +} + +static RefPreprocessParams reverse_channels_dyn_layout() { + RefPreprocessParams res("reverse_channels_dyn_layout"); + res.function = []() { + auto f = create_simple_function(element::f32, PartialShape{1, 1, 3, 2}); + f = PrePostProcessor().input(InputInfo() + .tensor(InputTensorInfo().set_color_format(ColorFormat::BGR).set_layout("...CN")) + .preprocess(PreProcessSteps().convert_color(ColorFormat::RGB))).build(f); + return f; + }; + + res.inputs.emplace_back(Shape{1, 1, 3, 2}, element::f32, std::vector{1, 2, 3, 4, 5, 6}); + res.expected.emplace_back(Shape{1, 1, 3, 2}, element::f32, std::vector{5, 6, 3, 4, 1, 2}); + return res; +} + +static RefPreprocessParams reverse_dyn_shape() { + RefPreprocessParams res("reverse_dyn_shape"); + res.function = []() { + auto f = create_simple_function(element::u8, PartialShape{Dimension::dynamic(), + Dimension::dynamic(), + Dimension::dynamic(), + Dimension::dynamic()}); + f = PrePostProcessor().input(InputInfo() + .tensor(InputTensorInfo().set_layout("NCHW")) + .preprocess(PreProcessSteps().reverse_channels())).build(f); + return f; + }; + + res.inputs.emplace_back(element::u8, Shape{2, 2, 1, 3}, std::vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + res.expected.emplace_back(Shape{2, 2, 1, 3}, element::u8, std::vector{4, 5, 6, 1, 2, 3, 10, 11, 12, 7, 8, 9}); + return res; +} + +static RefPreprocessParams reverse_fully_dyn_shape() { + RefPreprocessParams res("reverse_fully_dyn_shape"); + res.function = []() { + auto f = create_simple_function(element::u8, PartialShape::dynamic()); + auto p = PreProcessSteps(); + p.reverse_channels(); + f = PrePostProcessor().input(InputInfo() + .tensor(InputTensorInfo().set_layout("...C??")) + .preprocess(std::move(p))).build(f); + return f; + }; + + res.inputs.emplace_back(element::u8, Shape{2, 2, 1, 3}, std::vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + res.expected.emplace_back(Shape{2, 2, 1, 3}, element::u8, std::vector{4, 5, 6, 1, 2, 3, 10, 11, 12, 7, 8, 9}); + return res; +} + std::vector allPreprocessTests() { return std::vector { simple_mean_scale(), @@ -773,7 +868,13 @@ std::vector allPreprocessTests() { convert_color_nv12_layout_resize(), element_type_before_convert_color_nv12(), postprocess_2_inputs_basic(), - pre_and_post_processing() + pre_and_post_processing(), + rgb_to_bgr(), + bgr_to_rgb(), + reverse_channels_nchw(), + reverse_channels_dyn_layout(), + reverse_dyn_shape(), + reverse_fully_dyn_shape() }; } diff --git a/ngraph/core/include/openvino/core/preprocess/preprocess_steps.hpp b/ngraph/core/include/openvino/core/preprocess/preprocess_steps.hpp index 6c08f8b1b2c..cacac99e324 100644 --- a/ngraph/core/include/openvino/core/preprocess/preprocess_steps.hpp +++ b/ngraph/core/include/openvino/core/preprocess/preprocess_steps.hpp @@ -228,6 +228,27 @@ public: /// /// \return Rvalue reference to 'this' to allow chaining with other calls in a builder-like manner. PreProcessSteps&& convert_layout(const Layout& dst_layout = {}) &&; + + /// \brief Reverse channels operation - Lvalue version. + /// + /// \details Adds appropriate operation which reverses channels layout. Operation requires layout having 'C' + /// dimension Operation convert_color (RGB<->BGR) does reversing of channels also, but only for NHWC layout + /// + /// \example Example: when user data has 'NCHW' layout (example is [1, 3, 224, 224] RGB order) but network expects + /// BGR planes order. Preprocessing may look like this: + /// + /// \code{.cpp} auto proc = + /// PrePostProcessor() + /// .input(InputInfo() + /// .tensor(InputTensorInfo().set_layout("NCHW")) // User data is NCHW + /// .preprocess(PreProcessSteps() + /// .reverse_channels() + /// ); + /// \endcode + /// + /// \return Reference to 'this' to allow chaining with other calls in a builder-like manner. + PreProcessSteps& reverse_channels() &; + PreProcessSteps&& reverse_channels() &&; }; } // namespace preprocess diff --git a/ngraph/core/src/preprocess/color_utils.cpp b/ngraph/core/src/preprocess/color_utils.cpp index 60eda3cbb00..eca26b10054 100644 --- a/ngraph/core/src/preprocess/color_utils.cpp +++ b/ngraph/core/src/preprocess/color_utils.cpp @@ -15,6 +15,10 @@ std::unique_ptr ColorFormatInfo::get(ColorFormat format) { case ColorFormat::NV12_TWO_PLANES: res.reset(new ColorFormatInfoNV12_TwoPlanes(format)); break; + case ColorFormat::RGB: + case ColorFormat::BGR: + res.reset(new ColorFormatNHWC(format)); + break; default: res.reset(new ColorFormatInfo(format)); break; diff --git a/ngraph/core/src/preprocess/color_utils.hpp b/ngraph/core/src/preprocess/color_utils.hpp index 6db9fe4116f..00f1dae2823 100644 --- a/ngraph/core/src/preprocess/color_utils.hpp +++ b/ngraph/core/src/preprocess/color_utils.hpp @@ -70,9 +70,18 @@ protected: }; // --- Derived classes --- -class ColorFormatInfoNV12_Single : public ColorFormatInfo { +class ColorFormatNHWC : public ColorFormatInfo { public: - explicit ColorFormatInfoNV12_Single(ColorFormat format) : ColorFormatInfo(format) {} + explicit ColorFormatNHWC(ColorFormat format) : ColorFormatInfo(format) {} + + Layout default_layout() const override { + return "NHWC"; + } +}; + +class ColorFormatInfoNV12_Single : public ColorFormatNHWC { +public: + explicit ColorFormatInfoNV12_Single(ColorFormat format) : ColorFormatNHWC(format) {} protected: PartialShape calculate_shape(size_t plane_num, const PartialShape& image_shape) const override { @@ -85,15 +94,11 @@ protected: } return result; } - - Layout default_layout() const override { - return "NHWC"; - } }; -class ColorFormatInfoNV12_TwoPlanes : public ColorFormatInfo { +class ColorFormatInfoNV12_TwoPlanes : public ColorFormatNHWC { public: - explicit ColorFormatInfoNV12_TwoPlanes(ColorFormat format) : ColorFormatInfo(format) {} + explicit ColorFormatInfoNV12_TwoPlanes(ColorFormat format) : ColorFormatNHWC(format) {} size_t planes_count() const override { return 2; @@ -119,10 +124,6 @@ protected: } return result; } - - Layout default_layout() const override { - return "NHWC"; - } }; } // namespace preprocess diff --git a/ngraph/core/src/preprocess/pre_post_process.cpp b/ngraph/core/src/preprocess/pre_post_process.cpp index d959b2c8b84..fe58a6e5ca4 100644 --- a/ngraph/core/src/preprocess/pre_post_process.cpp +++ b/ngraph/core/src/preprocess/pre_post_process.cpp @@ -790,6 +790,16 @@ PreProcessSteps&& PreProcessSteps::custom(const CustomPreprocessOp& preprocess_c return std::move(*this); } +PreProcessSteps& PreProcessSteps::reverse_channels() & { + m_impl->add_reverse_channels(); + return *this; +} + +PreProcessSteps&& PreProcessSteps::reverse_channels() && { + m_impl->add_reverse_channels(); + return std::move(*this); +} + // --------------------- OutputTensorInfo ------------------ OutputTensorInfo::OutputTensorInfo() : m_impl(std::unique_ptr(new OutputTensorInfoImpl())) {} OutputTensorInfo::OutputTensorInfo(OutputTensorInfo&&) noexcept = default; diff --git a/ngraph/core/src/preprocess/preprocess_steps_impl.cpp b/ngraph/core/src/preprocess/preprocess_steps_impl.cpp index f1f7dad447f..6af75d55d64 100644 --- a/ngraph/core/src/preprocess/preprocess_steps_impl.cpp +++ b/ngraph/core/src/preprocess/preprocess_steps_impl.cpp @@ -178,9 +178,9 @@ void PreStepsList::add_convert_layout_impl(const Layout& layout) { } void PreStepsList::add_convert_color_impl(const ColorFormat& dst_format) { - m_actions.emplace_back([&, dst_format](const std::vector>& nodes, - const std::shared_ptr& function, - PreprocessingContext& context) { + m_actions.emplace_back([dst_format](const std::vector>& nodes, + const std::shared_ptr& function, + PreprocessingContext& context) { if (context.color_format() == dst_format) { return std::make_tuple(nodes, false); } @@ -221,6 +221,12 @@ void PreStepsList::add_convert_color_impl(const ColorFormat& dst_format) { context.color_format() = dst_format; return std::make_tuple(std::vector>{convert}, true); } + if ((context.color_format() == ColorFormat::RGB || context.color_format() == ColorFormat::BGR) && + (dst_format == ColorFormat::RGB || dst_format == ColorFormat::BGR)) { + auto res = reverse_channels(nodes, function, context); + context.color_format() = dst_format; + return res; + } OPENVINO_ASSERT(false, "Source color format '", color_format_name(context.color_format()), @@ -228,6 +234,46 @@ void PreStepsList::add_convert_color_impl(const ColorFormat& dst_format) { }); } +void PreStepsList::add_reverse_channels() { + m_actions.emplace_back([](const std::vector>& nodes, + const std::shared_ptr& function, + PreprocessingContext& context) { + return reverse_channels(nodes, function, context); + }); +} + +std::tuple>, bool> PreStepsList::reverse_channels(const std::vector>& nodes, + const std::shared_ptr& function, + PreprocessingContext& context) { + OPENVINO_ASSERT(nodes.size() == 1, "Internal error: can't reverse channels for multi-plane inputs"); + OPENVINO_ASSERT(ov::layout::has_channels(context.layout()), + "Layout ", + context.layout().to_string(), + " doesn't have `channels` dimension"); + auto channels_idx = ov::layout::channels(context.layout()); + // Get shape of user's input tensor (e.g. Tensor[1, 3, 224, 224] -> {1, 3, 224, 224}) + auto shape_of = std::make_shared(nodes[0]); // E.g. {1, 3, 224, 224} + + auto constant_chan_idx = op::v0::Constant::create(element::i32, {}, {channels_idx}); // E.g. 1 + auto constant_chan_axis = op::v0::Constant::create(element::i32, {}, {0}); + // Gather will return scalar with number of channels (e.g. 3) + auto gather_channels_num = std::make_shared(shape_of, constant_chan_idx, constant_chan_axis); + + // Create Range from channels_num-1 to 0 (e.g. {2, 1, 0}) + auto const_minus1 = op::v0::Constant::create(element::i64, {}, {-1}); + auto channels_num_minus1 = std::make_shared(gather_channels_num, const_minus1); // E.g. 3-1=2 + // Add range + auto range_to = op::v0::Constant::create(element::i64, {}, {-1}); + auto range_step = op::v0::Constant::create(element::i64, {}, {-1}); + // E.g. {2, 1, 0} + auto range = std::make_shared(channels_num_minus1, range_to, range_step, element::i32); + + // Gather slices in reverse order (indexes are specified by 'range' operation) + auto constant_axis = op::v0::Constant::create(element::i32, {1}, {channels_idx}); + auto convert = std::make_shared(nodes[0], range, constant_axis); + return std::make_tuple(std::vector>{convert}, false); +} + //------------- Post processing ------ void PostStepsList::add_convert_impl(const element::Type& type) { m_actions.emplace_back([type](const Output& node, PostprocessingContext& ctxt) { diff --git a/ngraph/core/src/preprocess/preprocess_steps_impl.hpp b/ngraph/core/src/preprocess/preprocess_steps_impl.hpp index 00e15d50027..f6c96cac96c 100644 --- a/ngraph/core/src/preprocess/preprocess_steps_impl.hpp +++ b/ngraph/core/src/preprocess/preprocess_steps_impl.hpp @@ -155,6 +155,7 @@ public: void add_resize_impl(ResizeAlgorithm alg, int dst_height, int dst_width); void add_convert_layout_impl(const Layout& layout); void add_convert_color_impl(const ColorFormat& dst_format); + void add_reverse_channels(); const std::list& actions() const { return m_actions; @@ -163,6 +164,11 @@ public: return m_actions; } +private: + static std::tuple>, bool> reverse_channels(const std::vector>& nodes, + const std::shared_ptr& function, + PreprocessingContext& context); + private: std::list m_actions; }; diff --git a/ngraph/test/preprocess.cpp b/ngraph/test/preprocess.cpp index db6981089f2..0d1a4d3af5a 100644 --- a/ngraph/test/preprocess.cpp +++ b/ngraph/test/preprocess.cpp @@ -697,6 +697,27 @@ TEST(pre_post_process, preprocess_convert_layout_same) { EXPECT_EQ(size_old, f->get_ordered_ops().size()); } +TEST(pre_post_process, preprocess_reverse_channels_multiple_planes) { + auto f = create_simple_function(element::f32, Shape{1, 3, 2, 2}); + EXPECT_THROW( + f = PrePostProcessor() + .input(InputInfo() + .tensor(InputTensorInfo().set_color_format(ColorFormat::NV12_TWO_PLANES, {"Y", "UV"})) + .preprocess(PreProcessSteps().reverse_channels())) + .build(f), + ov::AssertFailure); +} + +TEST(pre_post_process, preprocess_reverse_channels_no_c_dim) { + auto f = create_simple_function(element::f32, Shape{1, 3, 2, 2}); + EXPECT_THROW(f = PrePostProcessor() + .input(InputInfo() + .tensor(InputTensorInfo().set_layout("N?HW")) + .preprocess(PreProcessSteps().reverse_channels())) + .build(f), + ov::AssertFailure); +} + // --- PostProcess - set/convert element type --- TEST(pre_post_process, postprocess_convert_element_type_explicit) {