Support for dynamic EyeLike operator (#7895)

* Support for dynamic EyeLike operator

* Fix broken test

* Address code review comments.

* Address code review comments.

* Bugfix

* Address code review comments.

* Add tests for dynamic cases

* Style apply
This commit is contained in:
Michał Karzyński 2021-10-14 16:50:03 +02:00 committed by GitHub
parent aa6d1f873b
commit 686c7fd57f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 268 additions and 19 deletions

View File

@ -7,20 +7,119 @@
#include <memory>
#include "exceptions.hpp"
#include "ngraph/output_vector.hpp"
#include "utils/common.hpp"
namespace ngraph {
namespace onnx_import {
namespace op {
namespace detail {
namespace {
/// \brief Split a shape returned by a ShapeOf operation into two outputs: width and height.
OutputVector get_shape_width_and_height(const Output<ngraph::Node>& shape) {
const auto axis = ngraph::op::Constant::create(ngraph::element::i64, {1}, {0});
const auto height =
std::make_shared<default_opset::Gather>(shape,
ngraph::op::Constant::create(ngraph::element::i64, {1}, {0}),
axis);
const auto width =
std::make_shared<default_opset::Gather>(shape,
ngraph::op::Constant::create(ngraph::element::i64, {1}, {1}),
axis);
return {width, height};
}
/// \brief Calculate the size of the inner identity matrix and padding values.
/// \param shape Shape of the input tensor returned by a ShapeOf operator.
/// \param k Index of the EyeLike diagonal to be populated with ones.
/// 0 populates the main diagonal, k > 0 populates an upper diagonal,
/// and k < 0 populates a lower diagonal.
///
/// \returns A vector of 5 values. The first value is the size of the inner identity matrix.
/// The second value is the padding value for the left side of the inner identity matrix.
/// The third value is the padding value for the right side of the inner identity matrix.
/// The fourth value is the padding value for the top side of the inner identity matrix.
/// The fifth value is the padding value for the bottom side of the inner identity matrix.
OutputVector eyelike_component_dimensions(const Output<ngraph::Node>& shape, std::int64_t k) {
const auto dims = get_shape_width_and_height(shape);
const auto width = dims.at(0);
const auto height = dims.at(1);
// x1 and y1 are padding values for the left side and top side of the identity matrix.
const auto x1 = std::max(static_cast<int64_t>(0), k);
const auto y1 = std::max(static_cast<int64_t>(0), -k);
const auto x1_const = default_opset::Constant::create(ngraph::element::i64, Shape{1}, {x1});
const auto y1_const = default_opset::Constant::create(ngraph::element::i64, Shape{1}, {y1});
// upper_pads is a helper value for calculating the size of the inner identity matrix.
const auto upper_pads = default_opset::Constant::create(ngraph::element::i64, Shape{2}, {y1, x1});
// a is the size of the inner identity matrix.
const auto zero = default_opset::Constant::create(ngraph::element::i64, Shape{1}, {0});
const auto min_size =
std::make_shared<default_opset::ReduceMin>(std::make_shared<default_opset::Subtract>(shape, upper_pads),
zero,
true);
const auto a = std::make_shared<default_opset::Maximum>(min_size, zero);
// x2 and y2 are padding values for the right side and bottom side of the identity matrix.
// x2 = width - a - x1
// y2 = height - a - y1
const auto x2 =
std::make_shared<default_opset::Subtract>(std::make_shared<default_opset::Subtract>(width, a), x1_const);
const auto y2 =
std::make_shared<default_opset::Subtract>(std::make_shared<default_opset::Subtract>(height, a), y1_const);
return {a, x1_const, x2, y1_const, y2};
}
/// \brief Create a square identity matrix with the specified size and type.
/// \details The identity matrix consists of ones on the main diagonal and zeros elsewhere.
/// \param matrix_size Size of a side of the identity matrix.
/// \param target_type Data type of the identity matrix.
Output<ngraph::Node> square_identity_matrix(const Output<ngraph::Node>& matrix_size, element::Type target_type) {
// Construct a 1D representation of the identity matrix data
// One and zero are the values of the identity matrix.
const auto zero = default_opset::Constant::create(target_type, Shape{1}, {0});
const auto one = default_opset::Constant::create(target_type, Shape{1}, {1});
// One row of the identity matrix.
const auto zeros = std::make_shared<default_opset::Tile>(zero, matrix_size);
const auto one_followed_by_zeros = std::make_shared<default_opset::Concat>(OutputVector{one, zeros}, 0);
// The identity matrix as a 1D representation.
const auto one_int = default_opset::Constant::create(ngraph::element::i64, Shape{1}, {1});
const auto size_minus_one = std::make_shared<default_opset::Subtract>(matrix_size, one_int);
const auto one_d_data = std::make_shared<default_opset::Tile>(one_followed_by_zeros, size_minus_one);
const auto one_d_data_concat = std::make_shared<default_opset::Concat>(OutputVector{one_d_data, one}, 0);
// Reshape the 1D array to a 2D array
const auto output_shape = std::make_shared<default_opset::Concat>(OutputVector{matrix_size, matrix_size}, 0);
const auto diagonal = std::make_shared<default_opset::Reshape>(one_d_data_concat, output_shape, false);
return diagonal;
}
} // namespace
} // namespace detail
namespace set_1 {
OutputVector eye_like(const Node& node) {
const auto input = node.get_ng_inputs().at(0);
const auto& input_shape = input.get_shape();
const auto& input_rank = input.get_partial_shape().rank();
CHECK_VALID_NODE(node,
input_rank.compatible(Rank(2)),
"The provided shape rank: ",
input_rank.get_length(),
" is unsupported, only 2D shapes are supported");
const auto shift = node.get_attribute_value<std::int64_t>("k", 0);
std::int64_t dtype;
element::Type target_type;
std::int64_t shift = node.get_attribute_value<std::int64_t>("k", 0);
if (node.has_attribute("dtype")) {
dtype = node.get_attribute_value<std::int64_t>("dtype");
target_type = common::get_ngraph_element_type(dtype);
@ -28,21 +127,26 @@ OutputVector eye_like(const Node& node) {
target_type = input.get_element_type();
}
CHECK_VALID_NODE(node,
input_shape.size() == 2,
"The provided shape rank: ",
input_shape.size(),
" is unsupported, only 2D shapes are supported");
const auto input_shape = std::make_shared<default_opset::ShapeOf>(input);
std::shared_ptr<ngraph::Node> eye_like_matrix = common::shifted_square_identity(input_shape, target_type, shift);
const auto component_dimensions = detail::eyelike_component_dimensions(input_shape, shift);
const auto identity_matrix = detail::square_identity_matrix(component_dimensions.at(0), target_type);
return {eye_like_matrix};
const auto pads_begin =
std::make_shared<default_opset::Concat>(OutputVector{component_dimensions.at(3), component_dimensions.at(1)},
0);
const auto pads_end =
std::make_shared<default_opset::Concat>(OutputVector{component_dimensions.at(4), component_dimensions.at(2)},
0);
const auto zero = default_opset::Constant::create(target_type, Shape{}, {0});
const auto output =
std::make_shared<default_opset::Pad>(identity_matrix, pads_begin, pads_end, zero, ov::op::PadMode::CONSTANT);
return {output};
}
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph

View File

@ -11,11 +11,10 @@ namespace ngraph {
namespace onnx_import {
namespace op {
namespace set_1 {
OutputVector eye_like(const Node& node);
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph

View File

@ -0,0 +1,46 @@
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "x"
output: "y"
op_type: "EyeLike"
attribute {
name: "k"
i: -1
type: INT
}
}
name: "test_graph"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
}
dim {
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
shape {
dim {
}
dim {
}
}
}
}
}
}
opset_import {
version: 9
}

View File

@ -11,7 +11,7 @@ graph {
type: INT
}
}
name: "hardmax_graph"
name: "test_graph"
input {
name: "x"
type {

View File

@ -0,0 +1,45 @@
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "x"
output: "y"
op_type: "EyeLike"
attribute {
name: "k"
i: -1
type: INT
}
}
name: "test_graph"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 4
}
dim {
dim_value: 5
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
}
}
}
}
opset_import {
version: 9
}

View File

@ -2525,6 +2525,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_eye_like) {
const auto function = onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/eye_like.onnx"));
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_input<float>(Shape{3, 4}, {5.f, 5.f, 5.f, 5.f, 5.f, 5.f, 5.f, 5.f, 5.f, 5.f, 5.f, 5.f});
test_case.add_expected_output<float>(Shape{3, 4}, {0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f});
test_case.run();

View File

@ -1254,3 +1254,25 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_space_to_depth_dynamic_input) {
test_case.add_expected_output(Shape{1, 8, 2, 2}, expected_output);
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_eye_like_dyn_shape) {
const auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/dynamic_shapes/eye_like_dyn_shape.onnx"));
auto test_case = test::TestCase<TestEngine, TestCaseType::DYNAMIC>(function);
test_case.add_input<float>(Shape{3, 4}, {5.f, 5.f, 5.f, 5.f, 5.f, 5.f, 5.f, 5.f, 5.f, 5.f, 5.f, 5.f});
test_case.add_expected_output<float>(Shape{3, 4}, {0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f});
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_eye_like_dyn_rank) {
const auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/dynamic_shapes/eye_like_dyn_rank.onnx"));
auto test_case = test::TestCase<TestEngine, TestCaseType::DYNAMIC>(function);
test_case.add_input<float>(Shape{3, 4}, {5.f, 5.f, 5.f, 5.f, 5.f, 5.f, 5.f, 5.f, 5.f, 5.f, 5.f, 5.f});
test_case.add_expected_output<float>(Shape{3, 4}, {0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f});
test_case.run();
}

View File

@ -55,8 +55,7 @@ TEST(onnx_importer, exception_msg_onnx_node_validation_failure) {
// This test should throw a std error because of attempt to access shape from dynamic tensor.
TEST(onnx_importer, exception_msg_std_err_wrapped) {
try {
onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/dynamic_shapes/eye_link_dyn_shape.onnx"));
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/eye_like_wrong_shape.onnx"));
// Should have thrown, so fail if it didn't
FAIL() << "ONNX Importer did not detected incorrect model!";
} catch (const std::exception& e) {

View File

@ -235,6 +235,8 @@ onnx_model_rnn_defaults_fwd_const_dynamic
onnx_model_depth_to_space_dynamic_input
onnx_model_space_to_depth_dynamic_input
squeeze_dynamic
onnx_model_eye_like_dyn_shape
onnx_model_eye_like_dyn_rank
# Constant network

View File

@ -483,3 +483,34 @@ def test_constant_err():
ng_results = run_node(node, [])
assert np.allclose(ng_results, [values])
@pytest.mark.parametrize(
"shape, shift",
[
((4, 4), 0),
((4, 4), 1),
((4, 4), -1),
((4, 4), 2),
((4, 4), -2),
((4, 4), 3),
((4, 4), -3),
((3, 4), 0),
((3, 4), 1),
((3, 4), -1),
((3, 4), 2),
((3, 4), -2),
((5, 3), 0),
((5, 3), 1),
((5, 3), -1),
((5, 3), 2),
((5, 3), -2),
],
)
def test_eye_like(shape, shift):
input_tensor = np.arange(np.prod(shape)).reshape(shape)
node = onnx.helper.make_node("EyeLike", inputs=["x"], outputs=["y"], k=shift)
result = run_node(node, [input_tensor])[0]
assert np.allclose(result, np.eye(shape[0], shape[1], k=shift))