Add decomposition transformation for eye 9 (#12403)

* Add eye decomposition transformation

* Fix EyeLike generation when diagonal shift
outside dimensions

* Add batch shape to eye decomposition

* Eye decomposition clean-up

* Remove reference part if no eye decompose in
decomposition tests

* Eye-Like use eye operator

* Disable eye decomposition for CPU plugin

* Use opset9 instead of ops in eye decomposition

* Apply transformations header style
to eye_decomposition.hpp

* Add model reference in eye decomposition tests
- use opset9 instead of ov::op:vX namespace

* Refactor eye decomposition:
- match style of other transformations
- add NodeRegister class to make and collect created nodes
- use `NodeRegister` in transformation for copy runtime info
- use `NodeRegister` in `MatcherPass` to replace new `register_new_node`
This commit is contained in:
Pawel Raasz 2022-08-24 16:59:25 +02:00 committed by GitHub
parent 3339d5a372
commit ed7275adf0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 546 additions and 96 deletions

View File

@ -480,6 +480,11 @@ def test_constant_err():
@pytest.mark.parametrize(
("shape", "shift"),
[
((1, 1), -1),
((2, 4), 5),
((2, 4), 15),
((2, 4), -5),
((2, 4), -15),
((4, 4), 0),
((4, 4), 1),
((4, 4), -1),

View File

@ -0,0 +1,26 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace pass {
class TRANSFORMATIONS_API EyeDecomposition;
} // namespace pass
} // namespace ov
/**
* @ingroup ie_transformation_common_api
*
* @brief Do eye decomposition to sub-graph (model).
*/
class ov::pass::EyeDecomposition : public MatcherPass {
public:
OPENVINO_RTTI("EyeDecomposition", "0");
EyeDecomposition();
};

View File

@ -95,6 +95,7 @@
#include "transformations/op_conversions/detection_output_downgrade.hpp"
#include "transformations/op_conversions/detection_output_upgrade.hpp"
#include "transformations/op_conversions/einsum_decomposition.hpp"
#include "transformations/op_conversions/eye_decomposition.hpp"
#include "transformations/op_conversions/gather_normalize_negative_indices.hpp"
#include "transformations/op_conversions/gelu7_downgrade.hpp"
#include "transformations/op_conversions/hsigmoid_decomposition.hpp"
@ -166,6 +167,7 @@ bool ngraph::pass::CommonOptimizations::run_on_model(const std::shared_ptr<ngrap
decomp->add_matcher<ngraph::pass::GatherNegativeConstIndicesNormalize>();
decomp->add_matcher<ngraph::pass::DropoutWithRandomUniformReplacer>();
decomp->add_matcher<ngraph::pass::TransposeReshapeEliminationForMatmul>();
decomp->add_matcher<ov::pass::EyeDecomposition>();
decomp->set_name("ngraph::pass::CommonDecompositions");
// CF is required after all decompositions

View File

@ -0,0 +1,147 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/op_conversions/eye_decomposition.hpp"
#include <memory>
#include "itt.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/op/util/gather_nd_base.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/opsets/opset9.hpp"
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
namespace ov {
namespace pass {
/** \brief Check if output is rank one and data type can be i32 or i64. */
const auto is_rank_one_int_shape = [](const Output<Node>& output) -> bool {
return pattern::type_matches_any({element::i32, element::i64})(output) && pattern::has_static_shape()(output) &&
pattern::rank_equals(1)(output);
};
/** \brief Predicate to check eye k node is valid. */
const auto k_predicate = [](const Output<Node>& output) -> bool {
return is_rank_one_int_shape(output) && (output.get_partial_shape()[0].get_length() == 1);
};
/** \brief Predicate to check eye batch node is valid. */
const auto batch_predicate = [](const Output<Node>& output) -> bool {
return is_rank_one_int_shape(output) && output.get_partial_shape()[0].get_length();
};
/**
* \brief Make eye model which generate eye matrix.
*
* If 'k' is outside the eye dimension then result matrix will be filled with zeros.
*
* \param reg Node register used store created nodes.
* \param height Height of eye
* \param width Width of eye
* \param k Eye diagonal shift.
* \param dtype Data type of eye.
*
* \return Pointer to decomposed eye model.
*/
std::shared_ptr<Node> make_eye_model(NodeRegister& reg,
const Output<Node>& height,
const Output<Node>& width,
const Output<Node>& k,
element::Type dtype) {
const auto zero_int = reg.add(opset9::Constant::create(element::i64, Shape{1}, {0}));
const auto zero = reg.add(opset9::Constant::create(dtype, Shape{1}, {0}));
const auto one = reg.add(opset9::Constant::create(dtype, Shape{1}, {1}));
const auto k_neg = reg.make<opset9::Negative>(k);
const auto k_axis = reg.make<opset9::Concat>(OutputVector{k_neg, k}, 0);
const auto eye_shape = reg.make<opset9::Concat>(OutputVector{height, width}, 0);
// Calculate eye zero padding and internal square eye size.
const auto pad_start = reg.make<opset9::Minimum>(eye_shape, reg.make<opset9::Maximum>(zero_int, k_axis));
const auto shape_pad_diff = reg.make<opset9::Subtract>(eye_shape, pad_start);
const auto eye_size = reg.make<opset9::ReduceMin>(shape_pad_diff, zero_int, true);
const auto pad_end = reg.make<opset9::Subtract>(shape_pad_diff, eye_size);
// Make 1d-eye as eye_size times of (1, zeros(eye_size)), trimmed at end by eye_size elements.
const auto zeros = reg.make<opset9::Tile>(zero, eye_size);
const auto one_followed_by_zeros = reg.make<opset9::Concat>(OutputVector{one, zeros}, 0);
const auto eye_1d = reg.make<opset9::Pad>(reg.make<opset9::Tile>(one_followed_by_zeros, eye_size),
zero_int,
reg.make<opset9::Negative>(eye_size),
op::PadMode::CONSTANT);
// Reshape 1d-eye to 2d-eye
const auto eye_2d =
reg.make<opset9::Reshape>(eye_1d, reg.make<opset9::Concat>(OutputVector{eye_size, eye_size}, 0), false);
// Pad Eye to get final shape
return reg.make<opset9::Pad>(eye_2d, pad_start, pad_end, op::PadMode::CONSTANT);
}
/**
* \brief Make eye model as basic 2D eye replicated as specified in batch size.
*
* \param reg Node register used store created nodes.
* \param eye Eye model.
* \param batch 1-D tensor which defines leading batch dimensions of output eye shape.
*
* \return Pointer to decomposed eye model.
*/
std::shared_ptr<Node> make_eye_batches(NodeRegister& reg, const Output<Node>& eye, const Output<Node>& batch) {
const auto eye_tile = reg.make<opset9::Constant>(element::i64, Shape{2}, 1);
// `batch_repeats` repeat eye matrix as tile only in higher dimensions than 1 by number(s) in batch parameter.
const auto batch_repeats = reg.make<opset9::Concat>(OutputVector{batch, eye_tile}, 0);
return reg.make<opset9::Tile>(eye, batch_repeats);
}
EyeDecomposition::EyeDecomposition() {
MATCHER_SCOPE(EyeDecomposition);
auto p_height = pattern::any_input();
auto p_width = pattern::any_input();
auto p_k = pattern::wrap_type<opset9::Constant>(k_predicate);
auto p_batch = pattern::wrap_type<opset9::Constant>(batch_predicate);
auto p_eye_no_batch = pattern::wrap_type<opset9::Eye>({p_height, p_width, p_k});
auto p_eye_batch = pattern::wrap_type<opset9::Eye>({p_height, p_width, p_k, p_batch});
auto p_eye = std::make_shared<pattern::op::Or>(OutputVector{p_eye_batch, p_eye_no_batch});
matcher_pass_callback callback = [=](pattern::Matcher& m) {
auto m_eye = std::dynamic_pointer_cast<opset9::Eye>(m.get_match_root());
if ((!m_eye) || transformation_callback(m_eye)) {
return false;
}
NodeRegister copy_reg;
const auto& pattern_to_output = m.get_pattern_value_map();
const auto dtype = m_eye->get_out_type();
const auto width = pattern_to_output.at(p_width);
const auto height = pattern_to_output.at(p_height);
const auto k = pattern_to_output.at(p_k);
auto eye = make_eye_model(copy_reg, height, width, k, dtype);
if (pattern_to_output.find(p_batch) != pattern_to_output.end()) {
eye = make_eye_batches(copy_reg, eye, pattern_to_output.at(p_batch));
}
eye->set_friendly_name(m_eye->get_friendly_name());
ov::copy_runtime_info(m_eye, copy_reg.get());
ov::replace_node(m_eye, eye);
return true;
};
auto m = std::make_shared<pattern::Matcher>(p_eye, matcher_name);
register_matcher(m, callback);
}
} // namespace pass
} // namespace ov

View File

@ -17,6 +17,52 @@ using graph_rewrite_callback = std::function<bool(pass::pattern::Matcher& m)>;
using recurrent_graph_rewrite_callback = std::function<bool(pass::pattern::RecurrentMatcher& m)>;
using handler_callback = std::function<bool(const std::shared_ptr<Node>& node)>;
namespace pass {
/// \brief Register openvino node pointers into container.
/// Can create and/or add existing node pointers into register
class NodeRegister {
public:
/// \brief Make new node and add it to register.
///
/// \tparam T Node type.
/// \tparam Args Node ctor args types.
///
/// \param args New node ctor arguments.
/// \return Shared pointer to node of type T.
template <typename T, class... Args>
std::shared_ptr<T> make(Args&&... args) {
auto node = std::make_shared<T>(std::forward<Args>(args)...);
return add(node);
}
/// \brief Add node to register
///
/// \tparam T Node type.
///
/// \param node Node to add
///
/// \return Shared pointer to new node added of type T.
template <typename T>
std::shared_ptr<T> add(const std::shared_ptr<T>& node) {
m_nodes.push_back(node);
return node;
}
/// \brief Get nodes container.
///
/// \return Const reference to nodes container.
const std::vector<std::shared_ptr<ov::Node>>& get() const {
return m_nodes;
}
/// \brief Clear register.
void clear() {
m_nodes.clear();
}
private:
std::vector<std::shared_ptr<Node>> m_nodes; //!< Stores added nodes.
};
/// \brief MatcherPass is a basic block for pattern based transformations. It describes
/// pattern and
/// action that is applied if pattern is matched.
@ -69,15 +115,12 @@ public:
template <typename T, class... Args>
std::shared_ptr<T> register_new_node(Args&&... args) {
auto node = std::make_shared<T>(std::forward<Args>(args)...);
m_new_nodes.push_back(node);
return node;
return m_new_nodes.make<T>(std::forward<Args>(args)...);
}
template <typename T>
std::shared_ptr<T> register_new_node(const std::shared_ptr<T>& node) {
m_new_nodes.push_back(node);
return node;
return m_new_nodes.add(node);
}
std::shared_ptr<ov::Node> register_new_node_(const std::shared_ptr<ov::Node>& node) {
@ -85,7 +128,7 @@ public:
}
const std::vector<std::shared_ptr<ov::Node>>& get_new_nodes() {
return m_new_nodes;
return m_new_nodes.get();
}
void clear_new_nodes() {
m_new_nodes.clear();
@ -104,7 +147,7 @@ protected:
private:
handler_callback m_handler;
std::shared_ptr<pattern::Matcher> m_matcher;
std::vector<std::shared_ptr<ov::Node>> m_new_nodes;
NodeRegister m_new_nodes;
};
/// \brief GraphRewrite is a container for MatcherPasses that allows to run them on Function

View File

@ -26,6 +26,7 @@ bool evaluate_eye(const ov::HostTensorPtr& out, const int64_t diagonal_index) {
NGRAPH_TYPE_CASE(evaluate, bf16, out, diagonal_index);
NGRAPH_TYPE_CASE(evaluate, i32, out, diagonal_index);
NGRAPH_TYPE_CASE(evaluate, f32, out, diagonal_index);
NGRAPH_TYPE_CASE(evaluate, f64, out, diagonal_index);
NGRAPH_TYPE_CASE(evaluate, i64, out, diagonal_index);
default:
rc = false;

View File

@ -8,6 +8,7 @@
#include "exceptions.hpp"
#include "ngraph/output_vector.hpp"
#include "openvino/op/eye.hpp"
#include "utils/common.hpp"
namespace ngraph {
@ -30,77 +31,6 @@ OutputVector get_shape_width_and_height(const Output<ngraph::Node>& shape) {
return {width, height};
}
/// \brief Calculate the size of the inner identity matrix and padding values.
/// \param shape Shape of the input tensor returned by a ShapeOf operator.
/// \param k Index of the EyeLike diagonal to be populated with ones.
/// 0 populates the main diagonal, k > 0 populates an upper diagonal,
/// and k < 0 populates a lower diagonal.
///
/// \returns A vector of 5 values. The first value is the size of the inner identity matrix.
/// The second value is the padding value for the left side of the inner identity matrix.
/// The third value is the padding value for the right side of the inner identity matrix.
/// The fourth value is the padding value for the top side of the inner identity matrix.
/// The fifth value is the padding value for the bottom side of the inner identity matrix.
OutputVector eyelike_component_dimensions(const Output<ngraph::Node>& shape, std::int64_t k) {
const auto dims = get_shape_width_and_height(shape);
const auto width = dims.at(0);
const auto height = dims.at(1);
// x1 and y1 are padding values for the left side and top side of the identity matrix.
const auto x1 = std::max(static_cast<int64_t>(0), k);
const auto y1 = std::max(static_cast<int64_t>(0), -k);
const auto x1_const = default_opset::Constant::create(ngraph::element::i64, Shape{1}, {x1});
const auto y1_const = default_opset::Constant::create(ngraph::element::i64, Shape{1}, {y1});
// upper_pads is a helper value for calculating the size of the inner identity matrix.
const auto upper_pads = default_opset::Constant::create(ngraph::element::i64, Shape{2}, {y1, x1});
// a is the size of the inner identity matrix.
const auto zero = default_opset::Constant::create(ngraph::element::i64, Shape{1}, {0});
const auto min_size =
std::make_shared<default_opset::ReduceMin>(std::make_shared<default_opset::Subtract>(shape, upper_pads),
zero,
true);
const auto a = std::make_shared<default_opset::Maximum>(min_size, zero);
// x2 and y2 are padding values for the right side and bottom side of the identity matrix.
// x2 = width - a - x1
// y2 = height - a - y1
const auto x2 =
std::make_shared<default_opset::Subtract>(std::make_shared<default_opset::Subtract>(width, a), x1_const);
const auto y2 =
std::make_shared<default_opset::Subtract>(std::make_shared<default_opset::Subtract>(height, a), y1_const);
return {a, x1_const, x2, y1_const, y2};
}
/// \brief Create a square identity matrix with the specified size and type.
/// \details The identity matrix consists of ones on the main diagonal and zeros elsewhere.
/// \param matrix_size Size of a side of the identity matrix.
/// \param target_type Data type of the identity matrix.
Output<ngraph::Node> square_identity_matrix(const Output<ngraph::Node>& matrix_size, element::Type target_type) {
// Construct a 1D representation of the identity matrix data
// One and zero are the values of the identity matrix.
const auto zero = default_opset::Constant::create(target_type, Shape{1}, {0});
const auto one = default_opset::Constant::create(target_type, Shape{1}, {1});
// One row of the identity matrix.
const auto zeros = std::make_shared<default_opset::Tile>(zero, matrix_size);
const auto one_followed_by_zeros = std::make_shared<default_opset::Concat>(OutputVector{one, zeros}, 0);
// The identity matrix as a 1D representation.
const auto one_int = default_opset::Constant::create(ngraph::element::i64, Shape{1}, {1});
const auto size_minus_one = std::make_shared<default_opset::Subtract>(matrix_size, one_int);
const auto one_d_data = std::make_shared<default_opset::Tile>(one_followed_by_zeros, size_minus_one);
const auto one_d_data_concat = std::make_shared<default_opset::Concat>(OutputVector{one_d_data, one}, 0);
// Reshape the 1D array to a 2D array
const auto output_shape = std::make_shared<default_opset::Concat>(OutputVector{matrix_size, matrix_size}, 0);
const auto diagonal = std::make_shared<default_opset::Reshape>(one_d_data_concat, output_shape, false);
return diagonal;
}
} // namespace
} // namespace detail
@ -116,32 +46,22 @@ OutputVector eye_like(const Node& node) {
input_rank.get_length(),
" is unsupported, only 2D shapes are supported");
const auto shift = node.get_attribute_value<std::int64_t>("k", 0);
std::int64_t dtype;
element::Type target_type;
if (node.has_attribute("dtype")) {
dtype = node.get_attribute_value<std::int64_t>("dtype");
std::int64_t dtype = node.get_attribute_value<std::int64_t>("dtype");
target_type = common::get_ngraph_element_type(dtype);
} else {
target_type = input.get_element_type();
}
const auto input_shape = std::make_shared<default_opset::ShapeOf>(input);
const auto dims = detail::get_shape_width_and_height(input_shape);
const auto width = dims.at(0);
const auto height = dims.at(1);
const auto k =
default_opset::Constant::create(ngraph::element::i64, {1}, {node.get_attribute_value<std::int64_t>("k", 0)});
const auto component_dimensions = detail::eyelike_component_dimensions(input_shape, shift);
const auto identity_matrix = detail::square_identity_matrix(component_dimensions.at(0), target_type);
const auto pads_begin =
std::make_shared<default_opset::Concat>(OutputVector{component_dimensions.at(3), component_dimensions.at(1)},
0);
const auto pads_end =
std::make_shared<default_opset::Concat>(OutputVector{component_dimensions.at(4), component_dimensions.at(2)},
0);
const auto zero = default_opset::Constant::create(target_type, Shape{}, {0});
const auto output =
std::make_shared<default_opset::Pad>(identity_matrix, pads_begin, pads_end, zero, ov::op::PadMode::CONSTANT);
const auto output = std::make_shared<ov::op::v9::Eye>(height, width, k, target_type);
return {output};
}

View File

@ -86,6 +86,7 @@
#include <transformations/op_conversions/convert_roi_align_v9_to_v3.hpp>
#include <transformations/op_conversions/convert_roi_align_v3_to_v9.hpp>
#include <transformations/op_conversions/softsign_decomposition.hpp>
#include "transformations/op_conversions/eye_decomposition.hpp"
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset2.hpp>
@ -460,6 +461,7 @@ static void TransformationUpToCPUSpecificOpSet(std::shared_ptr<ngraph::Function>
// Allow FP16 Converts to be folded and FP16 constants to be upgraded to FP32 data type
pass_config->disable<ov::pass::DisableDecompressionConvertConstantFolding>();
pass_config->disable<ov::pass::ConvertCompressedOnlyToLegacy>();
pass_config->disable<ov::pass::EyeDecomposition>();
pass_config->disable<ngraph::pass::ConvertGELU>();
pass_config->disable<ngraph::pass::ConvertShuffleChannels3>();

View File

@ -0,0 +1,299 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <memory>
#include <string>
#include "common_test_utils/ngraph_test_utils.hpp"
#include "gtest/gtest.h"
#include "openvino/core/model.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/opsets/opset9.hpp"
#include "transformations/op_conversions/eye_decomposition.hpp"
using namespace testing;
/** Helper to get access EyeDecomposition protected methods. */
class EyeDecompositionWrapper : public ov::pass::EyeDecomposition {
public:
std::shared_ptr<ov::Node> exp_eye(const ov::Output<ov::Node>& height,
const ov::Output<ov::Node>& width,
const ov::Output<ov::Node>& k,
ov::element::Type dtype) {
const auto zero_int = ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {0});
const auto zero = ov::opset9::Constant::create(dtype, ov::Shape{1}, {0});
const auto one = ov::opset9::Constant::create(dtype, ov::Shape{1}, {1});
const auto k_neg = std::make_shared<ov::opset9::Negative>(k);
const auto k_axis = std::make_shared<ov::opset9::Concat>(ov::OutputVector{k_neg, k}, 0);
const auto eye_shape = std::make_shared<ov::opset9::Concat>(ov::OutputVector{height, width}, 0);
// Calculate eye zero padding and internal square eye size.
const auto pad_start =
std::make_shared<ov::opset9::Minimum>(eye_shape, std::make_shared<ov::opset9::Maximum>(zero_int, k_axis));
const auto shape_pad_diff = std::make_shared<ov::opset9::Subtract>(eye_shape, pad_start);
const auto eye_size = std::make_shared<ov::opset9::ReduceMin>(shape_pad_diff, zero_int, true);
const auto pad_end = std::make_shared<ov::opset9::Subtract>(shape_pad_diff, eye_size);
// Make 1d-eye as eye_size times of (1, zeros(eye_size)), trimmed at end by eye_size elements.
const auto zeros = std::make_shared<ov::opset9::Tile>(zero, eye_size);
const auto one_followed_by_zeros = std::make_shared<ov::opset9::Concat>(ov::OutputVector{one, zeros}, 0);
const auto eye_1d =
std::make_shared<ov::opset9::Pad>(std::make_shared<ov::opset9::Tile>(one_followed_by_zeros, eye_size),
zero_int,
std::make_shared<ov::opset9::Negative>(eye_size),
ov::op::PadMode::CONSTANT);
// Reshape 1d-eye to 2d-eye
const auto eye_2d = std::make_shared<ov::opset9::Reshape>(
eye_1d,
std::make_shared<ov::opset9::Concat>(ov::OutputVector{eye_size, eye_size}, 0),
false);
// Pad Eye to get final shape
return std::make_shared<ov::opset9::Pad>(eye_2d, pad_start, pad_end, ov::op::PadMode::CONSTANT);
}
std::shared_ptr<ov::Node> exp_eye(const ov::Output<ov::Node>& height,
const ov::Output<ov::Node>& width,
const ov::Output<ov::Node>& k,
const ov::Output<ov::Node>& batch,
ov::element::Type dtype) {
const auto eye = exp_eye(height, width, k, dtype);
const auto eye_tile = std::make_shared<ov::opset9::Constant>(ov::element::i64, ov::Shape{2}, 1);
// `batch_repeats` repeat eye matrix as tile only in higher dimensions than 1 by number(s) in batch parameter.
const auto batch_repeats = std::make_shared<ov::opset9::Concat>(ov::OutputVector{batch, eye_tile}, 0);
return std::make_shared<ov::opset9::Tile>(eye, batch_repeats);
}
};
class FakeEye : public ov::op::Op {
public:
FakeEye() = default;
FakeEye(const ov::Output<ov::Node>& num_rows,
const ov::Output<ov::Node>& num_columns,
const ov::Output<ov::Node>& diagonal_index,
const ov::Output<ov::Node>& batch_shape,
const ov::element::Type& out_type)
: Op({num_rows, num_columns, diagonal_index, batch_shape}) {
constructor_validate_and_infer_types();
}
FakeEye(const ov::Output<ov::Node>& num_rows,
const ov::Output<ov::Node>& num_columns,
const ov::Output<ov::Node>& diagonal_index,
const ov::element::Type& out_type)
: Op({num_rows, num_columns, diagonal_index}) {
constructor_validate_and_infer_types();
}
std::shared_ptr<ov::Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override {
check_new_args_count(this, new_args);
if (new_args.size() == 3) {
return std::make_shared<FakeEye>(new_args[0], new_args[1], new_args[2], ov::element::f32);
} else if (new_args.size() == 4) {
return std::make_shared<FakeEye>(new_args[0], new_args[1], new_args[2], new_args[3], ov::element::f32);
} else {
throw ov::Exception("FakeEye has incorrect input number: " + std::to_string(new_args.size()));
}
}
};
class EyeTransformationTests : public TransformationTestsF {
protected:
EyeDecompositionWrapper eye_decomposition_wrapper;
ov::element::Type dtype;
size_t h, w;
int shift;
void SetUp() override {
TransformationTestsF::SetUp();
dtype = ov::element::f32;
h = 4;
w = 4;
}
template <class TEye>
std::shared_ptr<TEye> make_test_eye(const ov::Output<ov::Node>& k) const {
auto height = ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {h});
auto width = ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {w});
return std::make_shared<TEye>(height, width, k, dtype);
}
template <class TEye>
std::shared_ptr<TEye> make_test_eye() const {
auto k = ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {shift});
return make_test_eye<TEye>(k);
}
template <class TEye>
std::shared_ptr<TEye> make_test_eye_batch(const ov::Output<ov::Node>& batch) const {
auto height = ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {h});
auto width = ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {w});
auto k = ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {shift});
return std::make_shared<TEye>(height, width, k, batch, dtype);
}
};
/** \brief Diagonal shift is not `Constant`, there should be no decompose. */
TEST_F(EyeTransformationTests, shift_is_not_const) {
{
auto data = std::make_shared<ov::opset9::Parameter>(dtype, ov::Shape{h, w});
auto k = std::make_shared<ov::opset9::Parameter>(ov::element::i64, ov::Shape{1});
auto node = make_test_eye<ov::opset9::Eye>(k);
model = std::make_shared<ov::Model>(ov::NodeVector{node}, ov::ParameterVector{data, k});
manager.register_pass<ov::pass::EyeDecomposition>();
}
}
/** \brief Batch size is not `Constant`, there should be no decompose. */
TEST_F(EyeTransformationTests, batch_is_not_const) {
{
auto data = std::make_shared<ov::opset9::Parameter>(dtype, ov::Shape{h, w});
auto batch = std::make_shared<ov::opset9::Parameter>(ov::element::i64, ov::Shape{2});
auto node = make_test_eye_batch<ov::opset9::Eye>(batch);
model = std::make_shared<ov::Model>(ov::NodeVector{node}, ov::ParameterVector{data, batch});
manager.register_pass<ov::pass::EyeDecomposition>();
}
}
/** \brief Use fake eye as not supported op type, there should be no decompose. */
TEST_F(EyeTransformationTests, use_fake_eye) {
{
auto data = std::make_shared<ov::opset9::Parameter>(dtype, ov::Shape{h, w});
auto node = make_test_eye<FakeEye>();
model = std::make_shared<ov::Model>(ov::NodeVector{node}, ov::ParameterVector{data});
manager.register_pass<ov::pass::EyeDecomposition>();
}
}
using EyeTestParameters = std::tuple<ov::element::Type, // Eye element type
std::tuple<size_t, size_t>, // Eye dimensions (height, width)
int // diagonal shift
>;
class EyeTransformationTestsP : public EyeTransformationTests, public WithParamInterface<EyeTestParameters> {
protected:
void SetUp() override {
TransformationTestsF::SetUp();
std::tuple<size_t, size_t> dim;
std::tie(dtype, dim, shift) = GetParam();
std::tie(h, w) = dim;
}
};
INSTANTIATE_TEST_SUITE_P(eye_no_diagonal_shift,
EyeTransformationTestsP,
Combine(Values(ov::element::i32, ov::element::f32, ov::element::u8),
Combine(Range<size_t>(0, 10, 2), Range<size_t>(0, 10, 2)),
Values(0)),
PrintToStringParamName());
INSTANTIATE_TEST_SUITE_P(square_eye_diagonal_shift_within_dim,
EyeTransformationTestsP,
Combine(Values(ov::element::i32, ov::element::f32),
Values(std::make_tuple(4, 4)),
Range(-4, 5)),
PrintToStringParamName());
INSTANTIATE_TEST_SUITE_P(rectangular_narrow_eye_diagonal_shift_within_dim,
EyeTransformationTestsP,
Combine(Values(ov::element::i32, ov::element::f32),
Values(std::make_tuple(7, 3)),
Range(-7, 4)),
PrintToStringParamName());
INSTANTIATE_TEST_SUITE_P(rectangular_wide_eye_diagonal_shift_within_dim,
EyeTransformationTestsP,
Combine(Values(ov::element::i32, ov::element::f32),
Values(std::make_tuple(2, 4)),
Range(-2, 5)),
PrintToStringParamName());
INSTANTIATE_TEST_SUITE_P(eye_diagonal_shift_outside_dim,
EyeTransformationTestsP,
Combine(Values(ov::element::f32),
Combine(Range<size_t>(6, 10, 2), Range<size_t>(6, 10, 2)),
Values(-30, -11, 11, 25)),
PrintToStringParamName());
/** \brief Test eye decomposition for different data types, dimension and diagonal shift. */
TEST_P(EyeTransformationTestsP, eye_decompose) {
{
auto data = std::make_shared<ov::opset9::Parameter>(dtype, ov::Shape{h, w});
auto node = make_test_eye<ov::opset9::Eye>();
model = std::make_shared<ov::Model>(ov::NodeVector{node}, ov::ParameterVector{data});
manager.register_pass<ov::pass::EyeDecomposition>();
}
{
auto data = std::make_shared<ov::opset9::Parameter>(dtype, ov::Shape{h, w});
auto height = ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {h});
auto width = ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {w});
auto k = ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {shift});
auto node = eye_decomposition_wrapper.exp_eye(height, width, k, dtype);
model_ref = std::make_shared<ov::Model>(ov::NodeVector{node}, ov::ParameterVector{data});
}
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}
class BatchEyeTransformationTests : public EyeTransformationTests, public WithParamInterface<std::vector<size_t>> {};
INSTANTIATE_TEST_SUITE_P(batch_size,
BatchEyeTransformationTests,
Values(std::vector<size_t>{1},
std::vector<size_t>{2},
std::vector<size_t>{2, 1},
std::vector<size_t>{2, 3},
std::vector<size_t>{3, 5, 1},
std::vector<size_t>{3, 5, 4}),
PrintToStringParamName());
/** \brief Test eye decomposition for batch sizes and values. */
TEST_P(BatchEyeTransformationTests, eye_decompose) {
{
auto data = std::make_shared<ov::opset9::Parameter>(dtype, ov::Shape{h, w});
auto batch = ov::opset9::Constant::create(ov::element::i64, ov::Shape{GetParam().size()}, GetParam());
auto node = make_test_eye_batch<ov::opset9::Eye>(batch);
model = std::make_shared<ov::Model>(ov::NodeVector{node}, ov::ParameterVector{data});
manager.register_pass<ov::pass::EyeDecomposition>();
}
{
auto data = std::make_shared<ov::opset9::Parameter>(dtype, ov::Shape{h, w});
auto height = ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {h});
auto width = ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {w});
auto k = ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {shift});
auto batch = ov::opset9::Constant::create(ov::element::i64, ov::Shape{GetParam().size()}, GetParam());
auto node = eye_decomposition_wrapper.exp_eye(height, width, k, batch, dtype);
model_ref = std::make_shared<ov::Model>(ov::NodeVector{node}, ov::ParameterVector{data});
}
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}

View File

@ -4,7 +4,10 @@
#include "ngraph_test_utils.hpp"
TransformationTestsF::TransformationTestsF() : comparator(FunctionsComparator::no_default()) {
TransformationTestsF::TransformationTestsF()
: model(function),
model_ref(function_ref),
comparator(FunctionsComparator::no_default()) {
m_unh = std::make_shared<ngraph::pass::UniqueNamesHolder>();
comparator.enable(FunctionsComparator::CmpValues::NODES);
comparator.enable(FunctionsComparator::CmpValues::PRECISIONS);

View File

@ -43,6 +43,8 @@ public:
void enable_soft_names_comparison();
std::shared_ptr<ov::Model> function, function_ref;
// Aliases to function and function_ref pointers to be more corresponding with ov namespace.
std::shared_ptr<ov::Model>&model, &model_ref;
ngraph::pass::Manager manager;
FunctionsComparator comparator;