diff --git a/src/bindings/python/tests/test_onnx/test_ops_unary.py b/src/bindings/python/tests/test_onnx/test_ops_unary.py index e9deeb9edc1..2d15a642ef1 100644 --- a/src/bindings/python/tests/test_onnx/test_ops_unary.py +++ b/src/bindings/python/tests/test_onnx/test_ops_unary.py @@ -480,6 +480,11 @@ def test_constant_err(): @pytest.mark.parametrize( ("shape", "shift"), [ + ((1, 1), -1), + ((2, 4), 5), + ((2, 4), 15), + ((2, 4), -5), + ((2, 4), -15), ((4, 4), 0), ((4, 4), 1), ((4, 4), -1), diff --git a/src/common/transformations/include/transformations/op_conversions/eye_decomposition.hpp b/src/common/transformations/include/transformations/op_conversions/eye_decomposition.hpp new file mode 100644 index 00000000000..414ef9734b3 --- /dev/null +++ b/src/common/transformations/include/transformations/op_conversions/eye_decomposition.hpp @@ -0,0 +1,26 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/graph_rewrite.hpp" +#include "transformations_visibility.hpp" + +namespace ov { +namespace pass { + +class TRANSFORMATIONS_API EyeDecomposition; +} // namespace pass +} // namespace ov + +/** + * @ingroup ie_transformation_common_api + * + * @brief Do eye decomposition to sub-graph (model). + */ +class ov::pass::EyeDecomposition : public MatcherPass { +public: + OPENVINO_RTTI("EyeDecomposition", "0"); + EyeDecomposition(); +}; diff --git a/src/common/transformations/src/transformations/common_optimizations/common_optimizations.cpp b/src/common/transformations/src/transformations/common_optimizations/common_optimizations.cpp index 4c7a61ef0eb..e00c55376e8 100644 --- a/src/common/transformations/src/transformations/common_optimizations/common_optimizations.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/common_optimizations.cpp @@ -95,6 +95,7 @@ #include "transformations/op_conversions/detection_output_downgrade.hpp" #include "transformations/op_conversions/detection_output_upgrade.hpp" #include "transformations/op_conversions/einsum_decomposition.hpp" +#include "transformations/op_conversions/eye_decomposition.hpp" #include "transformations/op_conversions/gather_normalize_negative_indices.hpp" #include "transformations/op_conversions/gelu7_downgrade.hpp" #include "transformations/op_conversions/hsigmoid_decomposition.hpp" @@ -166,6 +167,7 @@ bool ngraph::pass::CommonOptimizations::run_on_model(const std::shared_ptradd_matcher(); decomp->add_matcher(); decomp->add_matcher(); + decomp->add_matcher(); decomp->set_name("ngraph::pass::CommonDecompositions"); // CF is required after all decompositions diff --git a/src/common/transformations/src/transformations/op_conversions/eye_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/eye_decomposition.cpp new file mode 100644 index 00000000000..ad5ffb77005 --- /dev/null +++ b/src/common/transformations/src/transformations/op_conversions/eye_decomposition.cpp @@ -0,0 +1,147 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/op_conversions/eye_decomposition.hpp" + +#include + +#include "itt.hpp" +#include "openvino/core/rt_info.hpp" +#include "openvino/op/util/gather_nd_base.hpp" +#include "openvino/op/util/op_types.hpp" +#include "openvino/opsets/opset9.hpp" +#include "openvino/pass/pattern/op/or.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" + +namespace ov { +namespace pass { + +/** \brief Check if output is rank one and data type can be i32 or i64. */ +const auto is_rank_one_int_shape = [](const Output& output) -> bool { + return pattern::type_matches_any({element::i32, element::i64})(output) && pattern::has_static_shape()(output) && + pattern::rank_equals(1)(output); +}; + +/** \brief Predicate to check eye k node is valid. */ +const auto k_predicate = [](const Output& output) -> bool { + return is_rank_one_int_shape(output) && (output.get_partial_shape()[0].get_length() == 1); +}; + +/** \brief Predicate to check eye batch node is valid. */ +const auto batch_predicate = [](const Output& output) -> bool { + return is_rank_one_int_shape(output) && output.get_partial_shape()[0].get_length(); +}; + +/** + * \brief Make eye model which generate eye matrix. + * + * If 'k' is outside the eye dimension then result matrix will be filled with zeros. + * + * \param reg Node register used store created nodes. + * \param height Height of eye + * \param width Width of eye + * \param k Eye diagonal shift. + * \param dtype Data type of eye. + * + * \return Pointer to decomposed eye model. + */ +std::shared_ptr make_eye_model(NodeRegister& reg, + const Output& height, + const Output& width, + const Output& k, + element::Type dtype) { + const auto zero_int = reg.add(opset9::Constant::create(element::i64, Shape{1}, {0})); + const auto zero = reg.add(opset9::Constant::create(dtype, Shape{1}, {0})); + const auto one = reg.add(opset9::Constant::create(dtype, Shape{1}, {1})); + + const auto k_neg = reg.make(k); + const auto k_axis = reg.make(OutputVector{k_neg, k}, 0); + + const auto eye_shape = reg.make(OutputVector{height, width}, 0); + + // Calculate eye zero padding and internal square eye size. + const auto pad_start = reg.make(eye_shape, reg.make(zero_int, k_axis)); + const auto shape_pad_diff = reg.make(eye_shape, pad_start); + const auto eye_size = reg.make(shape_pad_diff, zero_int, true); + const auto pad_end = reg.make(shape_pad_diff, eye_size); + + // Make 1d-eye as eye_size times of (1, zeros(eye_size)), trimmed at end by eye_size elements. + const auto zeros = reg.make(zero, eye_size); + const auto one_followed_by_zeros = reg.make(OutputVector{one, zeros}, 0); + const auto eye_1d = reg.make(reg.make(one_followed_by_zeros, eye_size), + zero_int, + reg.make(eye_size), + op::PadMode::CONSTANT); + // Reshape 1d-eye to 2d-eye + const auto eye_2d = + reg.make(eye_1d, reg.make(OutputVector{eye_size, eye_size}, 0), false); + + // Pad Eye to get final shape + return reg.make(eye_2d, pad_start, pad_end, op::PadMode::CONSTANT); +} + +/** + * \brief Make eye model as basic 2D eye replicated as specified in batch size. + * + * \param reg Node register used store created nodes. + * \param eye Eye model. + * \param batch 1-D tensor which defines leading batch dimensions of output eye shape. + * + * \return Pointer to decomposed eye model. + */ +std::shared_ptr make_eye_batches(NodeRegister& reg, const Output& eye, const Output& batch) { + const auto eye_tile = reg.make(element::i64, Shape{2}, 1); + + // `batch_repeats` repeat eye matrix as tile only in higher dimensions than 1 by number(s) in batch parameter. + const auto batch_repeats = reg.make(OutputVector{batch, eye_tile}, 0); + + return reg.make(eye, batch_repeats); +} + +EyeDecomposition::EyeDecomposition() { + MATCHER_SCOPE(EyeDecomposition); + + auto p_height = pattern::any_input(); + auto p_width = pattern::any_input(); + auto p_k = pattern::wrap_type(k_predicate); + auto p_batch = pattern::wrap_type(batch_predicate); + + auto p_eye_no_batch = pattern::wrap_type({p_height, p_width, p_k}); + auto p_eye_batch = pattern::wrap_type({p_height, p_width, p_k, p_batch}); + + auto p_eye = std::make_shared(OutputVector{p_eye_batch, p_eye_no_batch}); + + matcher_pass_callback callback = [=](pattern::Matcher& m) { + auto m_eye = std::dynamic_pointer_cast(m.get_match_root()); + + if ((!m_eye) || transformation_callback(m_eye)) { + return false; + } + + NodeRegister copy_reg; + const auto& pattern_to_output = m.get_pattern_value_map(); + + const auto dtype = m_eye->get_out_type(); + const auto width = pattern_to_output.at(p_width); + const auto height = pattern_to_output.at(p_height); + const auto k = pattern_to_output.at(p_k); + + auto eye = make_eye_model(copy_reg, height, width, k, dtype); + + if (pattern_to_output.find(p_batch) != pattern_to_output.end()) { + eye = make_eye_batches(copy_reg, eye, pattern_to_output.at(p_batch)); + } + + eye->set_friendly_name(m_eye->get_friendly_name()); + ov::copy_runtime_info(m_eye, copy_reg.get()); + ov::replace_node(m_eye, eye); + return true; + }; + + auto m = std::make_shared(p_eye, matcher_name); + register_matcher(m, callback); +} + +} // namespace pass +} // namespace ov diff --git a/src/core/include/openvino/pass/graph_rewrite.hpp b/src/core/include/openvino/pass/graph_rewrite.hpp index a309a559823..b057ae7a47f 100644 --- a/src/core/include/openvino/pass/graph_rewrite.hpp +++ b/src/core/include/openvino/pass/graph_rewrite.hpp @@ -17,6 +17,52 @@ using graph_rewrite_callback = std::function; using recurrent_graph_rewrite_callback = std::function; using handler_callback = std::function& node)>; namespace pass { +/// \brief Register openvino node pointers into container. +/// Can create and/or add existing node pointers into register +class NodeRegister { +public: + /// \brief Make new node and add it to register. + /// + /// \tparam T Node type. + /// \tparam Args Node ctor args types. + /// + /// \param args New node ctor arguments. + /// \return Shared pointer to node of type T. + template + std::shared_ptr make(Args&&... args) { + auto node = std::make_shared(std::forward(args)...); + return add(node); + } + + /// \brief Add node to register + /// + /// \tparam T Node type. + /// + /// \param node Node to add + /// + /// \return Shared pointer to new node added of type T. + template + std::shared_ptr add(const std::shared_ptr& node) { + m_nodes.push_back(node); + return node; + } + + /// \brief Get nodes container. + /// + /// \return Const reference to nodes container. + const std::vector>& get() const { + return m_nodes; + } + + /// \brief Clear register. + void clear() { + m_nodes.clear(); + } + +private: + std::vector> m_nodes; //!< Stores added nodes. +}; + /// \brief MatcherPass is a basic block for pattern based transformations. It describes /// pattern and /// action that is applied if pattern is matched. @@ -69,15 +115,12 @@ public: template std::shared_ptr register_new_node(Args&&... args) { - auto node = std::make_shared(std::forward(args)...); - m_new_nodes.push_back(node); - return node; + return m_new_nodes.make(std::forward(args)...); } template std::shared_ptr register_new_node(const std::shared_ptr& node) { - m_new_nodes.push_back(node); - return node; + return m_new_nodes.add(node); } std::shared_ptr register_new_node_(const std::shared_ptr& node) { @@ -85,7 +128,7 @@ public: } const std::vector>& get_new_nodes() { - return m_new_nodes; + return m_new_nodes.get(); } void clear_new_nodes() { m_new_nodes.clear(); @@ -104,7 +147,7 @@ protected: private: handler_callback m_handler; std::shared_ptr m_matcher; - std::vector> m_new_nodes; + NodeRegister m_new_nodes; }; /// \brief GraphRewrite is a container for MatcherPasses that allows to run them on Function diff --git a/src/core/src/op/eye.cpp b/src/core/src/op/eye.cpp index 02bc7880192..07c7bcf38b1 100644 --- a/src/core/src/op/eye.cpp +++ b/src/core/src/op/eye.cpp @@ -26,6 +26,7 @@ bool evaluate_eye(const ov::HostTensorPtr& out, const int64_t diagonal_index) { NGRAPH_TYPE_CASE(evaluate, bf16, out, diagonal_index); NGRAPH_TYPE_CASE(evaluate, i32, out, diagonal_index); NGRAPH_TYPE_CASE(evaluate, f32, out, diagonal_index); + NGRAPH_TYPE_CASE(evaluate, f64, out, diagonal_index); NGRAPH_TYPE_CASE(evaluate, i64, out, diagonal_index); default: rc = false; diff --git a/src/frontends/onnx/frontend/src/op/eye_like.cpp b/src/frontends/onnx/frontend/src/op/eye_like.cpp index 83748441d60..4fa5a2de241 100644 --- a/src/frontends/onnx/frontend/src/op/eye_like.cpp +++ b/src/frontends/onnx/frontend/src/op/eye_like.cpp @@ -8,6 +8,7 @@ #include "exceptions.hpp" #include "ngraph/output_vector.hpp" +#include "openvino/op/eye.hpp" #include "utils/common.hpp" namespace ngraph { @@ -30,77 +31,6 @@ OutputVector get_shape_width_and_height(const Output& shape) { return {width, height}; } - -/// \brief Calculate the size of the inner identity matrix and padding values. -/// \param shape Shape of the input tensor returned by a ShapeOf operator. -/// \param k Index of the EyeLike diagonal to be populated with ones. -/// 0 populates the main diagonal, k > 0 populates an upper diagonal, -/// and k < 0 populates a lower diagonal. -/// -/// \returns A vector of 5 values. The first value is the size of the inner identity matrix. -/// The second value is the padding value for the left side of the inner identity matrix. -/// The third value is the padding value for the right side of the inner identity matrix. -/// The fourth value is the padding value for the top side of the inner identity matrix. -/// The fifth value is the padding value for the bottom side of the inner identity matrix. -OutputVector eyelike_component_dimensions(const Output& shape, std::int64_t k) { - const auto dims = get_shape_width_and_height(shape); - const auto width = dims.at(0); - const auto height = dims.at(1); - - // x1 and y1 are padding values for the left side and top side of the identity matrix. - const auto x1 = std::max(static_cast(0), k); - const auto y1 = std::max(static_cast(0), -k); - const auto x1_const = default_opset::Constant::create(ngraph::element::i64, Shape{1}, {x1}); - const auto y1_const = default_opset::Constant::create(ngraph::element::i64, Shape{1}, {y1}); - - // upper_pads is a helper value for calculating the size of the inner identity matrix. - const auto upper_pads = default_opset::Constant::create(ngraph::element::i64, Shape{2}, {y1, x1}); - - // a is the size of the inner identity matrix. - const auto zero = default_opset::Constant::create(ngraph::element::i64, Shape{1}, {0}); - const auto min_size = - std::make_shared(std::make_shared(shape, upper_pads), - zero, - true); - const auto a = std::make_shared(min_size, zero); - - // x2 and y2 are padding values for the right side and bottom side of the identity matrix. - // x2 = width - a - x1 - // y2 = height - a - y1 - const auto x2 = - std::make_shared(std::make_shared(width, a), x1_const); - const auto y2 = - std::make_shared(std::make_shared(height, a), y1_const); - - return {a, x1_const, x2, y1_const, y2}; -} - -/// \brief Create a square identity matrix with the specified size and type. -/// \details The identity matrix consists of ones on the main diagonal and zeros elsewhere. -/// \param matrix_size Size of a side of the identity matrix. -/// \param target_type Data type of the identity matrix. -Output square_identity_matrix(const Output& matrix_size, element::Type target_type) { - // Construct a 1D representation of the identity matrix data - // One and zero are the values of the identity matrix. - const auto zero = default_opset::Constant::create(target_type, Shape{1}, {0}); - const auto one = default_opset::Constant::create(target_type, Shape{1}, {1}); - - // One row of the identity matrix. - const auto zeros = std::make_shared(zero, matrix_size); - const auto one_followed_by_zeros = std::make_shared(OutputVector{one, zeros}, 0); - - // The identity matrix as a 1D representation. - const auto one_int = default_opset::Constant::create(ngraph::element::i64, Shape{1}, {1}); - const auto size_minus_one = std::make_shared(matrix_size, one_int); - const auto one_d_data = std::make_shared(one_followed_by_zeros, size_minus_one); - const auto one_d_data_concat = std::make_shared(OutputVector{one_d_data, one}, 0); - - // Reshape the 1D array to a 2D array - const auto output_shape = std::make_shared(OutputVector{matrix_size, matrix_size}, 0); - const auto diagonal = std::make_shared(one_d_data_concat, output_shape, false); - return diagonal; -} - } // namespace } // namespace detail @@ -116,32 +46,22 @@ OutputVector eye_like(const Node& node) { input_rank.get_length(), " is unsupported, only 2D shapes are supported"); - const auto shift = node.get_attribute_value("k", 0); - - std::int64_t dtype; element::Type target_type; if (node.has_attribute("dtype")) { - dtype = node.get_attribute_value("dtype"); + std::int64_t dtype = node.get_attribute_value("dtype"); target_type = common::get_ngraph_element_type(dtype); } else { target_type = input.get_element_type(); } const auto input_shape = std::make_shared(input); + const auto dims = detail::get_shape_width_and_height(input_shape); + const auto width = dims.at(0); + const auto height = dims.at(1); + const auto k = + default_opset::Constant::create(ngraph::element::i64, {1}, {node.get_attribute_value("k", 0)}); - const auto component_dimensions = detail::eyelike_component_dimensions(input_shape, shift); - const auto identity_matrix = detail::square_identity_matrix(component_dimensions.at(0), target_type); - - const auto pads_begin = - std::make_shared(OutputVector{component_dimensions.at(3), component_dimensions.at(1)}, - 0); - const auto pads_end = - std::make_shared(OutputVector{component_dimensions.at(4), component_dimensions.at(2)}, - 0); - - const auto zero = default_opset::Constant::create(target_type, Shape{}, {0}); - const auto output = - std::make_shared(identity_matrix, pads_begin, pads_end, zero, ov::op::PadMode::CONSTANT); + const auto output = std::make_shared(height, width, k, target_type); return {output}; } diff --git a/src/plugins/intel_cpu/src/plugin.cpp b/src/plugins/intel_cpu/src/plugin.cpp index 945ed2dacfd..993a3ccc552 100644 --- a/src/plugins/intel_cpu/src/plugin.cpp +++ b/src/plugins/intel_cpu/src/plugin.cpp @@ -86,6 +86,7 @@ #include #include #include +#include "transformations/op_conversions/eye_decomposition.hpp" #include #include @@ -460,6 +461,7 @@ static void TransformationUpToCPUSpecificOpSet(std::shared_ptr // Allow FP16 Converts to be folded and FP16 constants to be upgraded to FP32 data type pass_config->disable(); pass_config->disable(); + pass_config->disable(); pass_config->disable(); pass_config->disable(); diff --git a/src/tests/functional/inference_engine/transformations/op_conversions/eye_decomposition_test.cpp b/src/tests/functional/inference_engine/transformations/op_conversions/eye_decomposition_test.cpp new file mode 100644 index 00000000000..91b5aa67499 --- /dev/null +++ b/src/tests/functional/inference_engine/transformations/op_conversions/eye_decomposition_test.cpp @@ -0,0 +1,299 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include + +#include "common_test_utils/ngraph_test_utils.hpp" +#include "gtest/gtest.h" +#include "openvino/core/model.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/parameter.hpp" +#include "openvino/opsets/opset9.hpp" +#include "transformations/op_conversions/eye_decomposition.hpp" + +using namespace testing; + +/** Helper to get access EyeDecomposition protected methods. */ +class EyeDecompositionWrapper : public ov::pass::EyeDecomposition { +public: + std::shared_ptr exp_eye(const ov::Output& height, + const ov::Output& width, + const ov::Output& k, + ov::element::Type dtype) { + const auto zero_int = ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {0}); + const auto zero = ov::opset9::Constant::create(dtype, ov::Shape{1}, {0}); + const auto one = ov::opset9::Constant::create(dtype, ov::Shape{1}, {1}); + + const auto k_neg = std::make_shared(k); + const auto k_axis = std::make_shared(ov::OutputVector{k_neg, k}, 0); + + const auto eye_shape = std::make_shared(ov::OutputVector{height, width}, 0); + + // Calculate eye zero padding and internal square eye size. + const auto pad_start = + std::make_shared(eye_shape, std::make_shared(zero_int, k_axis)); + const auto shape_pad_diff = std::make_shared(eye_shape, pad_start); + const auto eye_size = std::make_shared(shape_pad_diff, zero_int, true); + const auto pad_end = std::make_shared(shape_pad_diff, eye_size); + + // Make 1d-eye as eye_size times of (1, zeros(eye_size)), trimmed at end by eye_size elements. + const auto zeros = std::make_shared(zero, eye_size); + const auto one_followed_by_zeros = std::make_shared(ov::OutputVector{one, zeros}, 0); + const auto eye_1d = + std::make_shared(std::make_shared(one_followed_by_zeros, eye_size), + zero_int, + std::make_shared(eye_size), + ov::op::PadMode::CONSTANT); + // Reshape 1d-eye to 2d-eye + const auto eye_2d = std::make_shared( + eye_1d, + std::make_shared(ov::OutputVector{eye_size, eye_size}, 0), + false); + + // Pad Eye to get final shape + return std::make_shared(eye_2d, pad_start, pad_end, ov::op::PadMode::CONSTANT); + } + + std::shared_ptr exp_eye(const ov::Output& height, + const ov::Output& width, + const ov::Output& k, + const ov::Output& batch, + ov::element::Type dtype) { + const auto eye = exp_eye(height, width, k, dtype); + const auto eye_tile = std::make_shared(ov::element::i64, ov::Shape{2}, 1); + + // `batch_repeats` repeat eye matrix as tile only in higher dimensions than 1 by number(s) in batch parameter. + const auto batch_repeats = std::make_shared(ov::OutputVector{batch, eye_tile}, 0); + + return std::make_shared(eye, batch_repeats); + } +}; + +class FakeEye : public ov::op::Op { +public: + FakeEye() = default; + + FakeEye(const ov::Output& num_rows, + const ov::Output& num_columns, + const ov::Output& diagonal_index, + const ov::Output& batch_shape, + const ov::element::Type& out_type) + : Op({num_rows, num_columns, diagonal_index, batch_shape}) { + constructor_validate_and_infer_types(); + } + + FakeEye(const ov::Output& num_rows, + const ov::Output& num_columns, + const ov::Output& diagonal_index, + const ov::element::Type& out_type) + : Op({num_rows, num_columns, diagonal_index}) { + constructor_validate_and_infer_types(); + } + + std::shared_ptr clone_with_new_inputs(const ov::OutputVector& new_args) const override { + check_new_args_count(this, new_args); + if (new_args.size() == 3) { + return std::make_shared(new_args[0], new_args[1], new_args[2], ov::element::f32); + } else if (new_args.size() == 4) { + return std::make_shared(new_args[0], new_args[1], new_args[2], new_args[3], ov::element::f32); + } else { + throw ov::Exception("FakeEye has incorrect input number: " + std::to_string(new_args.size())); + } + } +}; + +class EyeTransformationTests : public TransformationTestsF { +protected: + EyeDecompositionWrapper eye_decomposition_wrapper; + + ov::element::Type dtype; + size_t h, w; + int shift; + + void SetUp() override { + TransformationTestsF::SetUp(); + + dtype = ov::element::f32; + h = 4; + w = 4; + } + + template + std::shared_ptr make_test_eye(const ov::Output& k) const { + auto height = ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {h}); + auto width = ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {w}); + + return std::make_shared(height, width, k, dtype); + } + + template + std::shared_ptr make_test_eye() const { + auto k = ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {shift}); + + return make_test_eye(k); + } + + template + std::shared_ptr make_test_eye_batch(const ov::Output& batch) const { + auto height = ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {h}); + auto width = ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {w}); + auto k = ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {shift}); + + return std::make_shared(height, width, k, batch, dtype); + } +}; + +/** \brief Diagonal shift is not `Constant`, there should be no decompose. */ +TEST_F(EyeTransformationTests, shift_is_not_const) { + { + auto data = std::make_shared(dtype, ov::Shape{h, w}); + auto k = std::make_shared(ov::element::i64, ov::Shape{1}); + auto node = make_test_eye(k); + + model = std::make_shared(ov::NodeVector{node}, ov::ParameterVector{data, k}); + + manager.register_pass(); + } +} + +/** \brief Batch size is not `Constant`, there should be no decompose. */ +TEST_F(EyeTransformationTests, batch_is_not_const) { + { + auto data = std::make_shared(dtype, ov::Shape{h, w}); + auto batch = std::make_shared(ov::element::i64, ov::Shape{2}); + auto node = make_test_eye_batch(batch); + + model = std::make_shared(ov::NodeVector{node}, ov::ParameterVector{data, batch}); + + manager.register_pass(); + } +} + +/** \brief Use fake eye as not supported op type, there should be no decompose. */ +TEST_F(EyeTransformationTests, use_fake_eye) { + { + auto data = std::make_shared(dtype, ov::Shape{h, w}); + auto node = make_test_eye(); + + model = std::make_shared(ov::NodeVector{node}, ov::ParameterVector{data}); + + manager.register_pass(); + } +} + +using EyeTestParameters = std::tuple, // Eye dimensions (height, width) + int // diagonal shift + >; + +class EyeTransformationTestsP : public EyeTransformationTests, public WithParamInterface { +protected: + void SetUp() override { + TransformationTestsF::SetUp(); + + std::tuple dim; + std::tie(dtype, dim, shift) = GetParam(); + std::tie(h, w) = dim; + } +}; + +INSTANTIATE_TEST_SUITE_P(eye_no_diagonal_shift, + EyeTransformationTestsP, + Combine(Values(ov::element::i32, ov::element::f32, ov::element::u8), + Combine(Range(0, 10, 2), Range(0, 10, 2)), + Values(0)), + PrintToStringParamName()); + +INSTANTIATE_TEST_SUITE_P(square_eye_diagonal_shift_within_dim, + EyeTransformationTestsP, + Combine(Values(ov::element::i32, ov::element::f32), + Values(std::make_tuple(4, 4)), + Range(-4, 5)), + PrintToStringParamName()); + +INSTANTIATE_TEST_SUITE_P(rectangular_narrow_eye_diagonal_shift_within_dim, + EyeTransformationTestsP, + Combine(Values(ov::element::i32, ov::element::f32), + Values(std::make_tuple(7, 3)), + Range(-7, 4)), + PrintToStringParamName()); + +INSTANTIATE_TEST_SUITE_P(rectangular_wide_eye_diagonal_shift_within_dim, + EyeTransformationTestsP, + Combine(Values(ov::element::i32, ov::element::f32), + Values(std::make_tuple(2, 4)), + Range(-2, 5)), + PrintToStringParamName()); + +INSTANTIATE_TEST_SUITE_P(eye_diagonal_shift_outside_dim, + EyeTransformationTestsP, + Combine(Values(ov::element::f32), + Combine(Range(6, 10, 2), Range(6, 10, 2)), + Values(-30, -11, 11, 25)), + PrintToStringParamName()); + +/** \brief Test eye decomposition for different data types, dimension and diagonal shift. */ +TEST_P(EyeTransformationTestsP, eye_decompose) { + { + auto data = std::make_shared(dtype, ov::Shape{h, w}); + auto node = make_test_eye(); + + model = std::make_shared(ov::NodeVector{node}, ov::ParameterVector{data}); + + manager.register_pass(); + } + + { + auto data = std::make_shared(dtype, ov::Shape{h, w}); + auto height = ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {h}); + auto width = ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {w}); + auto k = ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {shift}); + + auto node = eye_decomposition_wrapper.exp_eye(height, width, k, dtype); + model_ref = std::make_shared(ov::NodeVector{node}, ov::ParameterVector{data}); + } + + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); +} + +class BatchEyeTransformationTests : public EyeTransformationTests, public WithParamInterface> {}; + +INSTANTIATE_TEST_SUITE_P(batch_size, + BatchEyeTransformationTests, + Values(std::vector{1}, + std::vector{2}, + std::vector{2, 1}, + std::vector{2, 3}, + std::vector{3, 5, 1}, + std::vector{3, 5, 4}), + PrintToStringParamName()); + +/** \brief Test eye decomposition for batch sizes and values. */ +TEST_P(BatchEyeTransformationTests, eye_decompose) { + { + auto data = std::make_shared(dtype, ov::Shape{h, w}); + auto batch = ov::opset9::Constant::create(ov::element::i64, ov::Shape{GetParam().size()}, GetParam()); + auto node = make_test_eye_batch(batch); + + model = std::make_shared(ov::NodeVector{node}, ov::ParameterVector{data}); + + manager.register_pass(); + } + + { + auto data = std::make_shared(dtype, ov::Shape{h, w}); + auto height = ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {h}); + auto width = ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {w}); + auto k = ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {shift}); + auto batch = ov::opset9::Constant::create(ov::element::i64, ov::Shape{GetParam().size()}, GetParam()); + + auto node = eye_decomposition_wrapper.exp_eye(height, width, k, batch, dtype); + model_ref = std::make_shared(ov::NodeVector{node}, ov::ParameterVector{data}); + } + + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); +} diff --git a/src/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp b/src/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp index b25dfbe771e..dcecd9672a0 100644 --- a/src/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp +++ b/src/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp @@ -4,7 +4,10 @@ #include "ngraph_test_utils.hpp" -TransformationTestsF::TransformationTestsF() : comparator(FunctionsComparator::no_default()) { +TransformationTestsF::TransformationTestsF() + : model(function), + model_ref(function_ref), + comparator(FunctionsComparator::no_default()) { m_unh = std::make_shared(); comparator.enable(FunctionsComparator::CmpValues::NODES); comparator.enable(FunctionsComparator::CmpValues::PRECISIONS); diff --git a/src/tests/ie_test_utils/common_test_utils/ngraph_test_utils.hpp b/src/tests/ie_test_utils/common_test_utils/ngraph_test_utils.hpp index 126c256a846..5101bb93b0b 100644 --- a/src/tests/ie_test_utils/common_test_utils/ngraph_test_utils.hpp +++ b/src/tests/ie_test_utils/common_test_utils/ngraph_test_utils.hpp @@ -43,6 +43,8 @@ public: void enable_soft_names_comparison(); std::shared_ptr function, function_ref; + // Aliases to function and function_ref pointers to be more corresponding with ov namespace. + std::shared_ptr&model, &model_ref; ngraph::pass::Manager manager; FunctionsComparator comparator;