From 50287625d70e27f04014a223a6806f318d86aa1d Mon Sep 17 00:00:00 2001 From: Katarzyna Mitrus Date: Thu, 5 May 2022 16:16:36 +0200 Subject: [PATCH] [Eye-9] Python API for Eye-9 (#11552) --- .../src/compatibility/ngraph/__init__.py | 329 +++++++++--------- .../compatibility/ngraph/opset9/__init__.py | 168 +++++++++ .../src/compatibility/ngraph/opset9/ops.py | 48 +++ .../ngraph/utils/node_factory.py | 2 +- .../compatibility/pyngraph/node_factory.cpp | 3 +- .../python/src/openvino/runtime/__init__.py | 33 +- .../src/openvino/runtime/opset9/__init__.py | 168 +++++++++ .../python/src/openvino/runtime/opset9/ops.py | 48 +++ .../src/pyopenvino/graph/node_factory.cpp | 3 +- .../pyopenvino/graph/passes/pattern_ops.cpp | 1 + .../python/tests/test_ngraph/test_eye.py | 102 ++++++ .../test_transformations/test_pattern_ops.py | 5 +- .../test_ngraph/test_eye.py | 103 ++++++ 13 files changed, 828 insertions(+), 185 deletions(-) create mode 100644 src/bindings/python/src/compatibility/ngraph/opset9/__init__.py create mode 100644 src/bindings/python/src/compatibility/ngraph/opset9/ops.py create mode 100644 src/bindings/python/src/openvino/runtime/opset9/__init__.py create mode 100644 src/bindings/python/src/openvino/runtime/opset9/ops.py create mode 100644 src/bindings/python/tests/test_ngraph/test_eye.py create mode 100644 src/bindings/python/tests_compatibility/test_ngraph/test_eye.py diff --git a/src/bindings/python/src/compatibility/ngraph/__init__.py b/src/bindings/python/src/compatibility/ngraph/__init__.py index 9bbd1ad9607..d1ba229cb61 100644 --- a/src/bindings/python/src/compatibility/ngraph/__init__.py +++ b/src/bindings/python/src/compatibility/ngraph/__init__.py @@ -17,170 +17,171 @@ from ngraph.impl import Node from ngraph.impl import PartialShape from ngraph.helpers import function_from_cnn from ngraph.helpers import function_to_cnn -from ngraph.opset8 import absolute -from ngraph.opset8 import absolute as abs -from ngraph.opset8 import acos -from ngraph.opset8 import acosh -from ngraph.opset8 import adaptive_avg_pool -from ngraph.opset8 import adaptive_max_pool -from ngraph.opset8 import add -from ngraph.opset8 import asin -from ngraph.opset8 import asinh -from ngraph.opset8 import assign -from ngraph.opset8 import atan -from ngraph.opset8 import atanh -from ngraph.opset8 import avg_pool -from ngraph.opset8 import batch_norm_inference -from ngraph.opset8 import batch_to_space -from ngraph.opset8 import binary_convolution -from ngraph.opset8 import broadcast -from ngraph.opset8 import bucketize -from ngraph.opset8 import ceiling -from ngraph.opset8 import ceiling as ceil -from ngraph.opset8 import clamp -from ngraph.opset8 import concat -from ngraph.opset8 import constant -from ngraph.opset8 import convert -from ngraph.opset8 import convert_like -from ngraph.opset8 import convolution -from ngraph.opset8 import convolution_backprop_data -from ngraph.opset8 import cos -from ngraph.opset8 import cosh -from ngraph.opset8 import ctc_greedy_decoder -from ngraph.opset8 import ctc_greedy_decoder_seq_len -from ngraph.opset8 import ctc_loss -from ngraph.opset8 import cum_sum -from ngraph.opset8 import cum_sum as cumsum -from ngraph.opset8 import deformable_convolution -from ngraph.opset8 import deformable_psroi_pooling -from ngraph.opset8 import depth_to_space -from ngraph.opset8 import detection_output -from ngraph.opset8 import dft -from ngraph.opset8 import divide -from ngraph.opset8 import einsum -from ngraph.opset8 import elu -from ngraph.opset8 import embedding_bag_offsets_sum -from ngraph.opset8 import embedding_bag_packed_sum -from ngraph.opset8 import embedding_segments_sum -from ngraph.opset8 import extract_image_patches -from ngraph.opset8 import equal -from ngraph.opset8 import erf -from ngraph.opset8 import exp -from ngraph.opset8 import fake_quantize -from ngraph.opset8 import floor -from ngraph.opset8 import floor_mod -from ngraph.opset8 import gather -from ngraph.opset8 import gather_elements -from ngraph.opset8 import gather_nd -from ngraph.opset8 import gather_tree -from ngraph.opset8 import gelu -from ngraph.opset8 import greater -from ngraph.opset8 import greater_equal -from ngraph.opset8 import grn -from ngraph.opset8 import group_convolution -from ngraph.opset8 import group_convolution_backprop_data -from ngraph.opset8 import gru_cell -from ngraph.opset8 import gru_sequence -from ngraph.opset8 import hard_sigmoid -from ngraph.opset8 import hsigmoid -from ngraph.opset8 import hswish -from ngraph.opset8 import idft -from ngraph.opset8 import if_op -from ngraph.opset8 import interpolate -from ngraph.opset8 import i420_to_bgr -from ngraph.opset8 import i420_to_rgb -from ngraph.opset8 import less -from ngraph.opset8 import less_equal -from ngraph.opset8 import log -from ngraph.opset8 import logical_and -from ngraph.opset8 import logical_not -from ngraph.opset8 import logical_or -from ngraph.opset8 import logical_xor -from ngraph.opset8 import log_softmax -from ngraph.opset8 import loop -from ngraph.opset8 import lrn -from ngraph.opset8 import lstm_cell -from ngraph.opset8 import lstm_sequence -from ngraph.opset8 import matmul -from ngraph.opset8 import matrix_nms -from ngraph.opset8 import max_pool -from ngraph.opset8 import maximum -from ngraph.opset8 import minimum -from ngraph.opset8 import mish -from ngraph.opset8 import mod -from ngraph.opset8 import multiclass_nms -from ngraph.opset8 import multiply -from ngraph.opset8 import mvn -from ngraph.opset8 import negative -from ngraph.opset8 import non_max_suppression -from ngraph.opset8 import non_zero -from ngraph.opset8 import normalize_l2 -from ngraph.opset8 import not_equal -from ngraph.opset8 import nv12_to_bgr -from ngraph.opset8 import nv12_to_rgb -from ngraph.opset8 import one_hot -from ngraph.opset8 import pad -from ngraph.opset8 import parameter -from ngraph.opset8 import power -from ngraph.opset8 import prelu -from ngraph.opset8 import prior_box -from ngraph.opset8 import prior_box_clustered -from ngraph.opset8 import psroi_pooling -from ngraph.opset8 import proposal -from ngraph.opset8 import random_uniform -from ngraph.opset8 import range -from ngraph.opset8 import read_value -from ngraph.opset8 import reduce_l1 -from ngraph.opset8 import reduce_l2 -from ngraph.opset8 import reduce_logical_and -from ngraph.opset8 import reduce_logical_or -from ngraph.opset8 import reduce_max -from ngraph.opset8 import reduce_mean -from ngraph.opset8 import reduce_min -from ngraph.opset8 import reduce_prod -from ngraph.opset8 import reduce_sum -from ngraph.opset8 import region_yolo -from ngraph.opset8 import reorg_yolo -from ngraph.opset8 import relu -from ngraph.opset8 import reshape -from ngraph.opset8 import result -from ngraph.opset8 import reverse_sequence -from ngraph.opset8 import rnn_cell -from ngraph.opset8 import rnn_sequence -from ngraph.opset8 import roi_align -from ngraph.opset8 import roi_pooling -from ngraph.opset8 import roll -from ngraph.opset8 import round -from ngraph.opset8 import scatter_elements_update -from ngraph.opset8 import scatter_update -from ngraph.opset8 import select -from ngraph.opset8 import selu -from ngraph.opset8 import shape_of -from ngraph.opset8 import shuffle_channels -from ngraph.opset8 import sigmoid -from ngraph.opset8 import sign -from ngraph.opset8 import sin -from ngraph.opset8 import sinh -from ngraph.opset8 import slice -from ngraph.opset8 import softmax -from ngraph.opset8 import softplus -from ngraph.opset8 import space_to_batch -from ngraph.opset8 import space_to_depth -from ngraph.opset8 import split -from ngraph.opset8 import sqrt -from ngraph.opset8 import squared_difference -from ngraph.opset8 import squeeze -from ngraph.opset8 import strided_slice -from ngraph.opset8 import subtract -from ngraph.opset8 import swish -from ngraph.opset8 import tan -from ngraph.opset8 import tanh -from ngraph.opset8 import tensor_iterator -from ngraph.opset8 import tile -from ngraph.opset8 import topk -from ngraph.opset8 import transpose -from ngraph.opset8 import unsqueeze -from ngraph.opset8 import variadic_split +from ngraph.opset9 import absolute +from ngraph.opset9 import absolute as abs +from ngraph.opset9 import acos +from ngraph.opset9 import acosh +from ngraph.opset9 import adaptive_avg_pool +from ngraph.opset9 import adaptive_max_pool +from ngraph.opset9 import add +from ngraph.opset9 import asin +from ngraph.opset9 import asinh +from ngraph.opset9 import assign +from ngraph.opset9 import atan +from ngraph.opset9 import atanh +from ngraph.opset9 import avg_pool +from ngraph.opset9 import batch_norm_inference +from ngraph.opset9 import batch_to_space +from ngraph.opset9 import binary_convolution +from ngraph.opset9 import broadcast +from ngraph.opset9 import bucketize +from ngraph.opset9 import ceiling +from ngraph.opset9 import ceiling as ceil +from ngraph.opset9 import clamp +from ngraph.opset9 import concat +from ngraph.opset9 import constant +from ngraph.opset9 import convert +from ngraph.opset9 import convert_like +from ngraph.opset9 import convolution +from ngraph.opset9 import convolution_backprop_data +from ngraph.opset9 import cos +from ngraph.opset9 import cosh +from ngraph.opset9 import ctc_greedy_decoder +from ngraph.opset9 import ctc_greedy_decoder_seq_len +from ngraph.opset9 import ctc_loss +from ngraph.opset9 import cum_sum +from ngraph.opset9 import cum_sum as cumsum +from ngraph.opset9 import deformable_convolution +from ngraph.opset9 import deformable_psroi_pooling +from ngraph.opset9 import depth_to_space +from ngraph.opset9 import detection_output +from ngraph.opset9 import dft +from ngraph.opset9 import divide +from ngraph.opset9 import einsum +from ngraph.opset9 import elu +from ngraph.opset9 import embedding_bag_offsets_sum +from ngraph.opset9 import embedding_bag_packed_sum +from ngraph.opset9 import embedding_segments_sum +from ngraph.opset9 import extract_image_patches +from ngraph.opset9 import equal +from ngraph.opset9 import erf +from ngraph.opset9 import exp +from ngraph.opset9 import eye +from ngraph.opset9 import fake_quantize +from ngraph.opset9 import floor +from ngraph.opset9 import floor_mod +from ngraph.opset9 import gather +from ngraph.opset9 import gather_elements +from ngraph.opset9 import gather_nd +from ngraph.opset9 import gather_tree +from ngraph.opset9 import gelu +from ngraph.opset9 import greater +from ngraph.opset9 import greater_equal +from ngraph.opset9 import grn +from ngraph.opset9 import group_convolution +from ngraph.opset9 import group_convolution_backprop_data +from ngraph.opset9 import gru_cell +from ngraph.opset9 import gru_sequence +from ngraph.opset9 import hard_sigmoid +from ngraph.opset9 import hsigmoid +from ngraph.opset9 import hswish +from ngraph.opset9 import idft +from ngraph.opset9 import if_op +from ngraph.opset9 import interpolate +from ngraph.opset9 import i420_to_bgr +from ngraph.opset9 import i420_to_rgb +from ngraph.opset9 import less +from ngraph.opset9 import less_equal +from ngraph.opset9 import log +from ngraph.opset9 import logical_and +from ngraph.opset9 import logical_not +from ngraph.opset9 import logical_or +from ngraph.opset9 import logical_xor +from ngraph.opset9 import log_softmax +from ngraph.opset9 import loop +from ngraph.opset9 import lrn +from ngraph.opset9 import lstm_cell +from ngraph.opset9 import lstm_sequence +from ngraph.opset9 import matmul +from ngraph.opset9 import matrix_nms +from ngraph.opset9 import max_pool +from ngraph.opset9 import maximum +from ngraph.opset9 import minimum +from ngraph.opset9 import mish +from ngraph.opset9 import mod +from ngraph.opset9 import multiclass_nms +from ngraph.opset9 import multiply +from ngraph.opset9 import mvn +from ngraph.opset9 import negative +from ngraph.opset9 import non_max_suppression +from ngraph.opset9 import non_zero +from ngraph.opset9 import normalize_l2 +from ngraph.opset9 import not_equal +from ngraph.opset9 import nv12_to_bgr +from ngraph.opset9 import nv12_to_rgb +from ngraph.opset9 import one_hot +from ngraph.opset9 import pad +from ngraph.opset9 import parameter +from ngraph.opset9 import power +from ngraph.opset9 import prelu +from ngraph.opset9 import prior_box +from ngraph.opset9 import prior_box_clustered +from ngraph.opset9 import psroi_pooling +from ngraph.opset9 import proposal +from ngraph.opset9 import random_uniform +from ngraph.opset9 import range +from ngraph.opset9 import read_value +from ngraph.opset9 import reduce_l1 +from ngraph.opset9 import reduce_l2 +from ngraph.opset9 import reduce_logical_and +from ngraph.opset9 import reduce_logical_or +from ngraph.opset9 import reduce_max +from ngraph.opset9 import reduce_mean +from ngraph.opset9 import reduce_min +from ngraph.opset9 import reduce_prod +from ngraph.opset9 import reduce_sum +from ngraph.opset9 import region_yolo +from ngraph.opset9 import reorg_yolo +from ngraph.opset9 import relu +from ngraph.opset9 import reshape +from ngraph.opset9 import result +from ngraph.opset9 import reverse_sequence +from ngraph.opset9 import rnn_cell +from ngraph.opset9 import rnn_sequence +from ngraph.opset9 import roi_align +from ngraph.opset9 import roi_pooling +from ngraph.opset9 import roll +from ngraph.opset9 import round +from ngraph.opset9 import scatter_elements_update +from ngraph.opset9 import scatter_update +from ngraph.opset9 import select +from ngraph.opset9 import selu +from ngraph.opset9 import shape_of +from ngraph.opset9 import shuffle_channels +from ngraph.opset9 import sigmoid +from ngraph.opset9 import sign +from ngraph.opset9 import sin +from ngraph.opset9 import sinh +from ngraph.opset9 import slice +from ngraph.opset9 import softmax +from ngraph.opset9 import softplus +from ngraph.opset9 import space_to_batch +from ngraph.opset9 import space_to_depth +from ngraph.opset9 import split +from ngraph.opset9 import sqrt +from ngraph.opset9 import squared_difference +from ngraph.opset9 import squeeze +from ngraph.opset9 import strided_slice +from ngraph.opset9 import subtract +from ngraph.opset9 import swish +from ngraph.opset9 import tan +from ngraph.opset9 import tanh +from ngraph.opset9 import tensor_iterator +from ngraph.opset9 import tile +from ngraph.opset9 import topk +from ngraph.opset9 import transpose +from ngraph.opset9 import unsqueeze +from ngraph.opset9 import variadic_split # Extend Node class to support binary operators diff --git a/src/bindings/python/src/compatibility/ngraph/opset9/__init__.py b/src/bindings/python/src/compatibility/ngraph/opset9/__init__.py new file mode 100644 index 00000000000..c531ef3476f --- /dev/null +++ b/src/bindings/python/src/compatibility/ngraph/opset9/__init__.py @@ -0,0 +1,168 @@ +# Copyright (C) 2018-2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from ngraph.opset1.ops import absolute +from ngraph.opset1.ops import absolute as abs +from ngraph.opset1.ops import acos +from ngraph.opset4.ops import acosh +from ngraph.opset8.ops import adaptive_avg_pool +from ngraph.opset8.ops import adaptive_max_pool +from ngraph.opset1.ops import add +from ngraph.opset1.ops import asin +from ngraph.opset4.ops import asinh +from ngraph.opset3.ops import assign +from ngraph.opset1.ops import atan +from ngraph.opset4.ops import atanh +from ngraph.opset1.ops import avg_pool +from ngraph.opset5.ops import batch_norm_inference +from ngraph.opset2.ops import batch_to_space +from ngraph.opset1.ops import binary_convolution +from ngraph.opset3.ops import broadcast +from ngraph.opset3.ops import bucketize +from ngraph.opset1.ops import ceiling +from ngraph.opset1.ops import ceiling as ceil +from ngraph.opset1.ops import clamp +from ngraph.opset1.ops import concat +from ngraph.opset1.ops import constant +from ngraph.opset1.ops import convert +from ngraph.opset1.ops import convert_like +from ngraph.opset1.ops import convolution +from ngraph.opset1.ops import convolution_backprop_data +from ngraph.opset1.ops import cos +from ngraph.opset1.ops import cosh +from ngraph.opset1.ops import ctc_greedy_decoder +from ngraph.opset6.ops import ctc_greedy_decoder_seq_len +from ngraph.opset4.ops import ctc_loss +from ngraph.opset3.ops import cum_sum +from ngraph.opset3.ops import cum_sum as cumsum +from ngraph.opset8.ops import deformable_convolution +from ngraph.opset1.ops import deformable_psroi_pooling +from ngraph.opset1.ops import depth_to_space +from ngraph.opset8.ops import detection_output +from ngraph.opset7.ops import dft +from ngraph.opset1.ops import divide +from ngraph.opset7.ops import einsum +from ngraph.opset1.ops import elu +from ngraph.opset3.ops import embedding_bag_offsets_sum +from ngraph.opset3.ops import embedding_bag_packed_sum +from ngraph.opset3.ops import embedding_segments_sum +from ngraph.opset3.ops import extract_image_patches +from ngraph.opset1.ops import equal +from ngraph.opset1.ops import erf +from ngraph.opset1.ops import exp +from ngraph.opset9.ops import eye +from ngraph.opset1.ops import fake_quantize +from ngraph.opset1.ops import floor +from ngraph.opset1.ops import floor_mod +from ngraph.opset8.ops import gather +from ngraph.opset6.ops import gather_elements +from ngraph.opset8.ops import gather_nd +from ngraph.opset1.ops import gather_tree +from ngraph.opset7.ops import gelu +from ngraph.opset1.ops import greater +from ngraph.opset1.ops import greater_equal +from ngraph.opset1.ops import grn +from ngraph.opset1.ops import group_convolution +from ngraph.opset1.ops import group_convolution_backprop_data +from ngraph.opset3.ops import gru_cell +from ngraph.opset5.ops import gru_sequence +from ngraph.opset1.ops import hard_sigmoid +from ngraph.opset5.ops import hsigmoid +from ngraph.opset4.ops import hswish +from ngraph.opset7.ops import idft +from ngraph.opset8.ops import if_op +from ngraph.opset1.ops import interpolate +from ngraph.opset8.ops import i420_to_bgr +from ngraph.opset8.ops import i420_to_rgb +from ngraph.opset1.ops import less +from ngraph.opset1.ops import less_equal +from ngraph.opset1.ops import log +from ngraph.opset1.ops import logical_and +from ngraph.opset1.ops import logical_not +from ngraph.opset1.ops import logical_or +from ngraph.opset1.ops import logical_xor +from ngraph.opset5.ops import log_softmax +from ngraph.opset5.ops import loop +from ngraph.opset1.ops import lrn +from ngraph.opset4.ops import lstm_cell +from ngraph.opset5.ops import lstm_sequence +from ngraph.opset1.ops import matmul +from ngraph.opset8.ops import matrix_nms +from ngraph.opset8.ops import max_pool +from ngraph.opset1.ops import maximum +from ngraph.opset1.ops import minimum +from ngraph.opset4.ops import mish +from ngraph.opset1.ops import mod +from ngraph.opset8.ops import multiclass_nms +from ngraph.opset1.ops import multiply +from ngraph.opset6.ops import mvn +from ngraph.opset1.ops import negative +from ngraph.opset5.ops import non_max_suppression +from ngraph.opset3.ops import non_zero +from ngraph.opset1.ops import normalize_l2 +from ngraph.opset1.ops import not_equal +from ngraph.opset8.ops import nv12_to_bgr +from ngraph.opset8.ops import nv12_to_rgb +from ngraph.opset1.ops import one_hot +from ngraph.opset1.ops import pad +from ngraph.opset1.ops import parameter +from ngraph.opset1.ops import power +from ngraph.opset1.ops import prelu +from ngraph.opset8.ops import prior_box +from ngraph.opset1.ops import prior_box_clustered +from ngraph.opset1.ops import psroi_pooling +from ngraph.opset4.ops import proposal +from ngraph.opset8.ops import random_uniform +from ngraph.opset1.ops import range +from ngraph.opset3.ops import read_value +from ngraph.opset4.ops import reduce_l1 +from ngraph.opset4.ops import reduce_l2 +from ngraph.opset1.ops import reduce_logical_and +from ngraph.opset1.ops import reduce_logical_or +from ngraph.opset1.ops import reduce_max +from ngraph.opset1.ops import reduce_mean +from ngraph.opset1.ops import reduce_min +from ngraph.opset1.ops import reduce_prod +from ngraph.opset1.ops import reduce_sum +from ngraph.opset1.ops import region_yolo +from ngraph.opset2.ops import reorg_yolo +from ngraph.opset1.ops import relu +from ngraph.opset1.ops import reshape +from ngraph.opset1.ops import result +from ngraph.opset1.ops import reverse_sequence +from ngraph.opset3.ops import rnn_cell +from ngraph.opset5.ops import rnn_sequence +from ngraph.opset3.ops import roi_align +from ngraph.opset2.ops import roi_pooling +from ngraph.opset7.ops import roll +from ngraph.opset5.ops import round +from ngraph.opset3.ops import scatter_elements_update +from ngraph.opset3.ops import scatter_update +from ngraph.opset1.ops import select +from ngraph.opset1.ops import selu +from ngraph.opset3.ops import shape_of +from ngraph.opset3.ops import shuffle_channels +from ngraph.opset1.ops import sigmoid +from ngraph.opset1.ops import sign +from ngraph.opset1.ops import sin +from ngraph.opset1.ops import sinh +from ngraph.opset8.ops import slice +from ngraph.opset8.ops import softmax +from ngraph.opset4.ops import softplus +from ngraph.opset2.ops import space_to_batch +from ngraph.opset1.ops import space_to_depth +from ngraph.opset1.ops import split +from ngraph.opset1.ops import sqrt +from ngraph.opset1.ops import squared_difference +from ngraph.opset1.ops import squeeze +from ngraph.opset1.ops import strided_slice +from ngraph.opset1.ops import subtract +from ngraph.opset4.ops import swish +from ngraph.opset1.ops import tan +from ngraph.opset1.ops import tanh +from ngraph.opset1.ops import tensor_iterator +from ngraph.opset1.ops import tile +from ngraph.opset3.ops import topk +from ngraph.opset1.ops import transpose +from ngraph.opset1.ops import unsqueeze +from ngraph.opset1.ops import variadic_split diff --git a/src/bindings/python/src/compatibility/ngraph/opset9/ops.py b/src/bindings/python/src/compatibility/ngraph/opset9/ops.py new file mode 100644 index 00000000000..cd38f699f69 --- /dev/null +++ b/src/bindings/python/src/compatibility/ngraph/opset9/ops.py @@ -0,0 +1,48 @@ +# Copyright (C) 2018-2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Factory functions for all ngraph ops.""" +from functools import partial +from typing import Optional + +import numpy as np +from ngraph.impl import Node +from ngraph.opset_utils import _get_node_factory +from ngraph.utils.decorators import nameable_op +from ngraph.utils.types import ( + NodeInput, + as_nodes, + as_node +) + +_get_node_factory_opset9 = partial(_get_node_factory, "opset9") + + +# -------------------------------------------- ops ------------------------------------------------ + + +@nameable_op +def eye( + num_rows: NodeInput, + num_columns: NodeInput, + diagonal_index: NodeInput, + output_type: str, + batch_shape: Optional[NodeInput] = None, + name: Optional[str] = None, +) -> Node: + """Return a node which performs eye operation. + + :param num_rows: The node providing row number tensor. + :param num_columns: The node providing column number tensor. + :param diagonal_index: The node providing the index of the diagonal to be populated. + :param output_type: Specifies the output tensor type, supports any numeric types. + :param batch_shape: The node providing the leading batch dimensions of output shape. Optionally. + :param name: The optional new name for output node. + :return: New node performing deformable convolution operation. + """ + if batch_shape is not None: + inputs = as_nodes(num_rows, num_columns, diagonal_index, batch_shape) + else: + inputs = as_nodes(num_rows, num_columns, diagonal_index) + + return _get_node_factory_opset9().create("Eye", inputs, {"output_type": output_type}) diff --git a/src/bindings/python/src/compatibility/ngraph/utils/node_factory.py b/src/bindings/python/src/compatibility/ngraph/utils/node_factory.py index 1f712515fd5..4c18d81fa98 100644 --- a/src/bindings/python/src/compatibility/ngraph/utils/node_factory.py +++ b/src/bindings/python/src/compatibility/ngraph/utils/node_factory.py @@ -12,7 +12,7 @@ from ngraph.impl import Node, Output from ngraph.exceptions import UserInputError -DEFAULT_OPSET = "opset8" +DEFAULT_OPSET = "opset9" class NodeFactory(object): diff --git a/src/bindings/python/src/compatibility/pyngraph/node_factory.cpp b/src/bindings/python/src/compatibility/pyngraph/node_factory.cpp index 609e4202e24..23a712a60a0 100644 --- a/src/bindings/python/src/compatibility/pyngraph/node_factory.cpp +++ b/src/bindings/python/src/compatibility/pyngraph/node_factory.cpp @@ -83,6 +83,7 @@ private: {"opset6", OpsetFunction(ngraph::get_opset6)}, {"opset7", OpsetFunction(ngraph::get_opset7)}, {"opset8", OpsetFunction(ngraph::get_opset8)}, + {"opset9", OpsetFunction(ngraph::get_opset9)}, }; auto it = s_opsets.find(opset_ver); @@ -92,7 +93,7 @@ private: return it->second(); } - const ngraph::OpSet& m_opset = ngraph::get_opset8(); + const ngraph::OpSet& m_opset = ngraph::get_opset9(); std::unordered_map> m_variables; }; } // namespace diff --git a/src/bindings/python/src/openvino/runtime/__init__.py b/src/bindings/python/src/openvino/runtime/__init__.py index c037fddd589..51ccc76e230 100644 --- a/src/bindings/python/src/openvino/runtime/__init__.py +++ b/src/bindings/python/src/openvino/runtime/__init__.py @@ -56,6 +56,7 @@ from openvino.runtime import opset5 from openvino.runtime import opset6 from openvino.runtime import opset7 from openvino.runtime import opset8 +from openvino.runtime import opset9 # Import properties API from openvino.pyopenvino import properties @@ -65,19 +66,19 @@ from openvino.runtime.ie_api import tensor_from_file from openvino.runtime.ie_api import compile_model # Extend Node class to support binary operators -Node.__add__ = opset8.add -Node.__sub__ = opset8.subtract -Node.__mul__ = opset8.multiply -Node.__div__ = opset8.divide -Node.__truediv__ = opset8.divide -Node.__radd__ = lambda left, right: opset8.add(right, left) -Node.__rsub__ = lambda left, right: opset8.subtract(right, left) -Node.__rmul__ = lambda left, right: opset8.multiply(right, left) -Node.__rdiv__ = lambda left, right: opset8.divide(right, left) -Node.__rtruediv__ = lambda left, right: opset8.divide(right, left) -Node.__eq__ = opset8.equal -Node.__ne__ = opset8.not_equal -Node.__lt__ = opset8.less -Node.__le__ = opset8.less_equal -Node.__gt__ = opset8.greater -Node.__ge__ = opset8.greater_equal +Node.__add__ = opset9.add +Node.__sub__ = opset9.subtract +Node.__mul__ = opset9.multiply +Node.__div__ = opset9.divide +Node.__truediv__ = opset9.divide +Node.__radd__ = lambda left, right: opset9.add(right, left) +Node.__rsub__ = lambda left, right: opset9.subtract(right, left) +Node.__rmul__ = lambda left, right: opset9.multiply(right, left) +Node.__rdiv__ = lambda left, right: opset9.divide(right, left) +Node.__rtruediv__ = lambda left, right: opset9.divide(right, left) +Node.__eq__ = opset9.equal +Node.__ne__ = opset9.not_equal +Node.__lt__ = opset9.less +Node.__le__ = opset9.less_equal +Node.__gt__ = opset9.greater +Node.__ge__ = opset9.greater_equal diff --git a/src/bindings/python/src/openvino/runtime/opset9/__init__.py b/src/bindings/python/src/openvino/runtime/opset9/__init__.py new file mode 100644 index 00000000000..a6c3291aee9 --- /dev/null +++ b/src/bindings/python/src/openvino/runtime/opset9/__init__.py @@ -0,0 +1,168 @@ +# Copyright (C) 2018-2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from openvino.runtime.opset1.ops import absolute +from openvino.runtime.opset1.ops import absolute as abs +from openvino.runtime.opset1.ops import acos +from openvino.runtime.opset4.ops import acosh +from openvino.runtime.opset8.ops import adaptive_avg_pool +from openvino.runtime.opset8.ops import adaptive_max_pool +from openvino.runtime.opset1.ops import add +from openvino.runtime.opset1.ops import asin +from openvino.runtime.opset4.ops import asinh +from openvino.runtime.opset3.ops import assign +from openvino.runtime.opset1.ops import atan +from openvino.runtime.opset4.ops import atanh +from openvino.runtime.opset1.ops import avg_pool +from openvino.runtime.opset5.ops import batch_norm_inference +from openvino.runtime.opset2.ops import batch_to_space +from openvino.runtime.opset1.ops import binary_convolution +from openvino.runtime.opset3.ops import broadcast +from openvino.runtime.opset3.ops import bucketize +from openvino.runtime.opset1.ops import ceiling +from openvino.runtime.opset1.ops import ceiling as ceil +from openvino.runtime.opset1.ops import clamp +from openvino.runtime.opset1.ops import concat +from openvino.runtime.opset1.ops import constant +from openvino.runtime.opset1.ops import convert +from openvino.runtime.opset1.ops import convert_like +from openvino.runtime.opset1.ops import convolution +from openvino.runtime.opset1.ops import convolution_backprop_data +from openvino.runtime.opset1.ops import cos +from openvino.runtime.opset1.ops import cosh +from openvino.runtime.opset1.ops import ctc_greedy_decoder +from openvino.runtime.opset6.ops import ctc_greedy_decoder_seq_len +from openvino.runtime.opset4.ops import ctc_loss +from openvino.runtime.opset3.ops import cum_sum +from openvino.runtime.opset3.ops import cum_sum as cumsum +from openvino.runtime.opset8.ops import deformable_convolution +from openvino.runtime.opset1.ops import deformable_psroi_pooling +from openvino.runtime.opset1.ops import depth_to_space +from openvino.runtime.opset8.ops import detection_output +from openvino.runtime.opset7.ops import dft +from openvino.runtime.opset1.ops import divide +from openvino.runtime.opset7.ops import einsum +from openvino.runtime.opset1.ops import elu +from openvino.runtime.opset3.ops import embedding_bag_offsets_sum +from openvino.runtime.opset3.ops import embedding_bag_packed_sum +from openvino.runtime.opset3.ops import embedding_segments_sum +from openvino.runtime.opset3.ops import extract_image_patches +from openvino.runtime.opset1.ops import equal +from openvino.runtime.opset1.ops import erf +from openvino.runtime.opset1.ops import exp +from openvino.runtime.opset9.ops import eye +from openvino.runtime.opset1.ops import fake_quantize +from openvino.runtime.opset1.ops import floor +from openvino.runtime.opset1.ops import floor_mod +from openvino.runtime.opset8.ops import gather +from openvino.runtime.opset6.ops import gather_elements +from openvino.runtime.opset8.ops import gather_nd +from openvino.runtime.opset1.ops import gather_tree +from openvino.runtime.opset7.ops import gelu +from openvino.runtime.opset1.ops import greater +from openvino.runtime.opset1.ops import greater_equal +from openvino.runtime.opset1.ops import grn +from openvino.runtime.opset1.ops import group_convolution +from openvino.runtime.opset1.ops import group_convolution_backprop_data +from openvino.runtime.opset3.ops import gru_cell +from openvino.runtime.opset5.ops import gru_sequence +from openvino.runtime.opset1.ops import hard_sigmoid +from openvino.runtime.opset5.ops import hsigmoid +from openvino.runtime.opset4.ops import hswish +from openvino.runtime.opset7.ops import idft +from openvino.runtime.opset8.ops import if_op +from openvino.runtime.opset1.ops import interpolate +from openvino.runtime.opset8.ops import i420_to_bgr +from openvino.runtime.opset8.ops import i420_to_rgb +from openvino.runtime.opset1.ops import less +from openvino.runtime.opset1.ops import less_equal +from openvino.runtime.opset1.ops import log +from openvino.runtime.opset1.ops import logical_and +from openvino.runtime.opset1.ops import logical_not +from openvino.runtime.opset1.ops import logical_or +from openvino.runtime.opset1.ops import logical_xor +from openvino.runtime.opset5.ops import log_softmax +from openvino.runtime.opset5.ops import loop +from openvino.runtime.opset1.ops import lrn +from openvino.runtime.opset4.ops import lstm_cell +from openvino.runtime.opset5.ops import lstm_sequence +from openvino.runtime.opset1.ops import matmul +from openvino.runtime.opset8.ops import matrix_nms +from openvino.runtime.opset8.ops import max_pool +from openvino.runtime.opset1.ops import maximum +from openvino.runtime.opset1.ops import minimum +from openvino.runtime.opset4.ops import mish +from openvino.runtime.opset1.ops import mod +from openvino.runtime.opset8.ops import multiclass_nms +from openvino.runtime.opset1.ops import multiply +from openvino.runtime.opset6.ops import mvn +from openvino.runtime.opset1.ops import negative +from openvino.runtime.opset5.ops import non_max_suppression +from openvino.runtime.opset3.ops import non_zero +from openvino.runtime.opset1.ops import normalize_l2 +from openvino.runtime.opset1.ops import not_equal +from openvino.runtime.opset8.ops import nv12_to_bgr +from openvino.runtime.opset8.ops import nv12_to_rgb +from openvino.runtime.opset1.ops import one_hot +from openvino.runtime.opset1.ops import pad +from openvino.runtime.opset1.ops import parameter +from openvino.runtime.opset1.ops import power +from openvino.runtime.opset1.ops import prelu +from openvino.runtime.opset8.ops import prior_box +from openvino.runtime.opset1.ops import prior_box_clustered +from openvino.runtime.opset1.ops import psroi_pooling +from openvino.runtime.opset4.ops import proposal +from openvino.runtime.opset1.ops import range +from openvino.runtime.opset8.ops import random_uniform +from openvino.runtime.opset3.ops import read_value +from openvino.runtime.opset4.ops import reduce_l1 +from openvino.runtime.opset4.ops import reduce_l2 +from openvino.runtime.opset1.ops import reduce_logical_and +from openvino.runtime.opset1.ops import reduce_logical_or +from openvino.runtime.opset1.ops import reduce_max +from openvino.runtime.opset1.ops import reduce_mean +from openvino.runtime.opset1.ops import reduce_min +from openvino.runtime.opset1.ops import reduce_prod +from openvino.runtime.opset1.ops import reduce_sum +from openvino.runtime.opset1.ops import region_yolo +from openvino.runtime.opset2.ops import reorg_yolo +from openvino.runtime.opset1.ops import relu +from openvino.runtime.opset1.ops import reshape +from openvino.runtime.opset1.ops import result +from openvino.runtime.opset1.ops import reverse_sequence +from openvino.runtime.opset3.ops import rnn_cell +from openvino.runtime.opset5.ops import rnn_sequence +from openvino.runtime.opset3.ops import roi_align +from openvino.runtime.opset2.ops import roi_pooling +from openvino.runtime.opset7.ops import roll +from openvino.runtime.opset5.ops import round +from openvino.runtime.opset3.ops import scatter_elements_update +from openvino.runtime.opset3.ops import scatter_update +from openvino.runtime.opset1.ops import select +from openvino.runtime.opset1.ops import selu +from openvino.runtime.opset3.ops import shape_of +from openvino.runtime.opset3.ops import shuffle_channels +from openvino.runtime.opset1.ops import sigmoid +from openvino.runtime.opset1.ops import sign +from openvino.runtime.opset1.ops import sin +from openvino.runtime.opset1.ops import sinh +from openvino.runtime.opset8.ops import slice +from openvino.runtime.opset8.ops import softmax +from openvino.runtime.opset4.ops import softplus +from openvino.runtime.opset2.ops import space_to_batch +from openvino.runtime.opset1.ops import space_to_depth +from openvino.runtime.opset1.ops import split +from openvino.runtime.opset1.ops import sqrt +from openvino.runtime.opset1.ops import squared_difference +from openvino.runtime.opset1.ops import squeeze +from openvino.runtime.opset1.ops import strided_slice +from openvino.runtime.opset1.ops import subtract +from openvino.runtime.opset4.ops import swish +from openvino.runtime.opset1.ops import tan +from openvino.runtime.opset1.ops import tanh +from openvino.runtime.opset1.ops import tensor_iterator +from openvino.runtime.opset1.ops import tile +from openvino.runtime.opset3.ops import topk +from openvino.runtime.opset1.ops import transpose +from openvino.runtime.opset1.ops import unsqueeze +from openvino.runtime.opset1.ops import variadic_split diff --git a/src/bindings/python/src/openvino/runtime/opset9/ops.py b/src/bindings/python/src/openvino/runtime/opset9/ops.py new file mode 100644 index 00000000000..187c7f08f66 --- /dev/null +++ b/src/bindings/python/src/openvino/runtime/opset9/ops.py @@ -0,0 +1,48 @@ +# Copyright (C) 2018-2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Factory functions for all ngraph ops.""" +from functools import partial +from typing import Optional + +import numpy as np +from openvino.runtime import Node +from openvino.runtime.opset_utils import _get_node_factory +from openvino.runtime.utils.decorators import nameable_op +from openvino.runtime.utils.types import ( + NodeInput, + as_nodes, + as_node +) + +_get_node_factory_opset9 = partial(_get_node_factory, "opset9") + + +# -------------------------------------------- ops ------------------------------------------------ + + +@nameable_op +def eye( + num_rows: NodeInput, + num_columns: NodeInput, + diagonal_index: NodeInput, + output_type: str, + batch_shape: Optional[NodeInput] = None, + name: Optional[str] = None, +) -> Node: + """Return a node which performs eye operation. + + :param num_rows: The node providing row number tensor. + :param num_columns: The node providing column number tensor. + :param diagonal_index: The node providing the index of the diagonal to be populated. + :param output_type: Specifies the output tensor type, supports any numeric types. + :param batch_shape: The node providing the leading batch dimensions of output shape. Optionally. + :param name: The optional new name for output node. + :return: New node performing deformable convolution operation. + """ + if batch_shape is not None: + inputs = as_nodes(num_rows, num_columns, diagonal_index, batch_shape) + else: + inputs = as_nodes(num_rows, num_columns, diagonal_index) + + return _get_node_factory_opset9().create("Eye", inputs, {"output_type": output_type}) diff --git a/src/bindings/python/src/pyopenvino/graph/node_factory.cpp b/src/bindings/python/src/pyopenvino/graph/node_factory.cpp index 0b55213824c..26be3b6e6c4 100644 --- a/src/bindings/python/src/pyopenvino/graph/node_factory.cpp +++ b/src/bindings/python/src/pyopenvino/graph/node_factory.cpp @@ -83,6 +83,7 @@ private: {"opset6", OpsetFunction(ov::get_opset6)}, {"opset7", OpsetFunction(ov::get_opset7)}, {"opset8", OpsetFunction(ov::get_opset8)}, + {"opset9", OpsetFunction(ov::get_opset9)}, }; auto it = s_opsets.find(opset_ver); @@ -92,7 +93,7 @@ private: return it->second(); } - const ov::OpSet& m_opset = ov::get_opset8(); + const ov::OpSet& m_opset = ov::get_opset9(); std::unordered_map> m_variables; }; } // namespace diff --git a/src/bindings/python/src/pyopenvino/graph/passes/pattern_ops.cpp b/src/bindings/python/src/pyopenvino/graph/passes/pattern_ops.cpp index 0bd0510778d..e76b240e56e 100644 --- a/src/bindings/python/src/pyopenvino/graph/passes/pattern_ops.cpp +++ b/src/bindings/python/src/pyopenvino/graph/passes/pattern_ops.cpp @@ -45,6 +45,7 @@ ov::NodeTypeInfo get_type(const std::string& type_name) { {"opset6", ngraph::get_opset6}, {"opset7", ngraph::get_opset7}, {"opset8", ngraph::get_opset8}, + {"opset9", ngraph::get_opset9}, }; if (!get_opset.count(opset_type)) { diff --git a/src/bindings/python/tests/test_ngraph/test_eye.py b/src/bindings/python/tests/test_ngraph/test_eye.py new file mode 100644 index 00000000000..6f08e7a4496 --- /dev/null +++ b/src/bindings/python/tests/test_ngraph/test_eye.py @@ -0,0 +1,102 @@ +# Copyright (C) 2018-2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import openvino.runtime.opset9 as ov +import numpy as np +import pytest + +from tests.runtime import get_runtime +from openvino.runtime.utils.types import get_element_type_str +from openvino.runtime.utils.types import get_element_type + + +@pytest.mark.parametrize( + "num_rows, num_columns, diagonal_index, out_type", + [ + pytest.param(2, 5, 0, np.float32), + pytest.param(5, 3, 2, np.int64), + pytest.param(3, 3, -1, np.float16), + pytest.param(5, 5, -10, np.float32), + ], +) +def test_eye_rectangle(num_rows, num_columns, diagonal_index, out_type): + num_rows_array = np.array([num_rows], np.int32) + num_columns_array = np.array([num_columns], np.int32) + diagonal_index_array = np.array([diagonal_index], np.int32) + num_rows_tensor = ov.constant(num_rows_array) + num_columns_tensor = ov.constant(num_columns_array) + diagonal_index_tensor = ov.constant(diagonal_index_array) + + # Create with param names + eye_node = ov.eye(num_rows=num_rows_tensor, + num_columns=num_columns_tensor, + diagonal_index=diagonal_index_tensor, + output_type=get_element_type_str(out_type)) + + # Create with default orded + eye_node = ov.eye(num_rows_tensor, + num_columns_tensor, + diagonal_index_tensor, + get_element_type_str(out_type)) + + expected_results = np.eye(num_rows, M=num_columns, k=diagonal_index, dtype=np.float32) + + assert eye_node.get_type_name() == "Eye" + assert eye_node.get_output_size() == 1 + assert eye_node.get_output_element_type(0) == get_element_type(out_type) + assert tuple(eye_node.get_output_shape(0)) == expected_results.shape + + # TODO: Enable with Eye reference implementation + # runtime = get_runtime() + # computation = runtime.computation(eye_node) + # eye_results = computation() + # assert np.allclose(eye_results, expected_results) + + +@pytest.mark.parametrize( + "num_rows, num_columns, diagonal_index, batch_shape, out_type", + [ + pytest.param(2, 5, 0, [1], np.float32), + pytest.param(5, 3, 2, [2, 2], np.int64), + pytest.param(3, 3, -1, [1, 3, 2], np.float16), + pytest.param(5, 5, -10, [1, 1], np.float32), + ], +) +def test_eye_batch_shape(num_rows, num_columns, diagonal_index, batch_shape, out_type): + num_rows_array = np.array([num_rows], np.int32) + num_columns_array = np.array([num_columns], np.int32) + diagonal_index_array = np.array([diagonal_index], np.int32) + batch_shape_array = np.array(batch_shape, np.int32) + num_rows_tensor = ov.constant(num_rows_array) + num_columns_tensor = ov.constant(num_columns_array) + diagonal_index_tensor = ov.constant(diagonal_index_array) + batch_shape_tensor = ov.constant(batch_shape_array) + + # Create with param names + eye_node = ov.eye(num_rows=num_rows_tensor, + num_columns=num_columns_tensor, + diagonal_index=diagonal_index_tensor, + batch_shape=batch_shape_tensor, + output_type=get_element_type_str(out_type)) + + # Create with default orded + eye_node = ov.eye(num_rows_tensor, + num_columns_tensor, + diagonal_index_tensor, + get_element_type_str(out_type), + batch_shape_tensor) + + output_shape = [*batch_shape, 1, 1] + one_matrix = np.eye(num_rows, M=num_columns, k=diagonal_index, dtype=np.float32) + expected_results = np.tile(one_matrix, output_shape) + + assert eye_node.get_type_name() == "Eye" + assert eye_node.get_output_size() == 1 + assert eye_node.get_output_element_type(0) == get_element_type(out_type) + assert tuple(eye_node.get_output_shape(0)) == expected_results.shape + + # TODO: Enable with Eye reference implementation + # runtime = get_runtime() + # computation = runtime.computation(eye_node) + # eye_results = computation() + # assert np.allclose(eye_results, expected_results) diff --git a/src/bindings/python/tests/test_transformations/test_pattern_ops.py b/src/bindings/python/tests/test_transformations/test_pattern_ops.py index a670ae96d43..b3df9224f9e 100644 --- a/src/bindings/python/tests/test_transformations/test_pattern_ops.py +++ b/src/bindings/python/tests/test_transformations/test_pattern_ops.py @@ -12,12 +12,13 @@ from utils.utils import expect_exception def test_wrap_type_pattern_type(): - for i in range(1, 9): + last_opstet_number = 9 + for i in range(1, last_opstet_number + 1): WrapType("opset{}.Parameter".format(i)) WrapType("opset{}::Parameter".format(i)) # Negative check not to forget to update opset map in get_type function - expect_exception(lambda: WrapType("opset9.Parameter"), "Unsupported opset type: opset9") + expect_exception(lambda: WrapType("opset10.Parameter"), "Unsupported opset type: opset10") # Generic negative test cases expect_exception(lambda: WrapType("")) diff --git a/src/bindings/python/tests_compatibility/test_ngraph/test_eye.py b/src/bindings/python/tests_compatibility/test_ngraph/test_eye.py new file mode 100644 index 00000000000..069a7d80f0b --- /dev/null +++ b/src/bindings/python/tests_compatibility/test_ngraph/test_eye.py @@ -0,0 +1,103 @@ +# Copyright (C) 2018-2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import ngraph as ng +import numpy as np +import pytest + +from ngraph.utils.types import get_element_type +from ngraph.utils.types import get_element_type_str +from tests_compatibility.runtime import get_runtime +from tests_compatibility.test_ngraph.util import run_op_node + + +@pytest.mark.parametrize( + "num_rows, num_columns, diagonal_index, out_type", + [ + pytest.param(2, 5, 0, np.float32), + pytest.param(5, 3, 2, np.int64), + pytest.param(3, 3, -1, np.float16), + pytest.param(5, 5, -10, np.float32), + ], +) +def test_eye_rectangle(num_rows, num_columns, diagonal_index, out_type): + num_rows_array = np.array([num_rows], np.int32) + num_columns_array = np.array([num_columns], np.int32) + diagonal_index_array = np.array([diagonal_index], np.int32) + num_rows_tensor = ng.constant(num_rows_array) + num_columns_tensor = ng.constant(num_columns_array) + diagonal_index_tensor = ng.constant(diagonal_index_array) + + # Create with param names + eye_node = ng.eye(num_rows=num_rows_tensor, + num_columns=num_columns_tensor, + diagonal_index=diagonal_index_tensor, + output_type=get_element_type_str(out_type)) + + # Create with default orded + eye_node = ng.eye(num_rows_tensor, + num_columns_tensor, + diagonal_index_tensor, + get_element_type_str(out_type)) + + expected_results = np.eye(num_rows, M=num_columns, k=diagonal_index, dtype=np.float32) + + assert eye_node.get_type_name() == "Eye" + assert eye_node.get_output_size() == 1 + assert eye_node.get_output_element_type(0) == get_element_type(out_type) + assert tuple(eye_node.get_output_shape(0)) == expected_results.shape + + # TODO: Enable with Eye reference implementation + # runtime = get_runtime() + # computation = runtime.computation(eye_node) + # eye_results = computation() + # assert np.allclose(eye_results, expected_results) + + +@pytest.mark.parametrize( + "num_rows, num_columns, diagonal_index, batch_shape, out_type", + [ + pytest.param(2, 5, 0, [1], np.float32), + pytest.param(5, 3, 2, [2, 2], np.int64), + pytest.param(3, 3, -1, [1, 3, 2], np.float16), + pytest.param(5, 5, -10, [1, 1], np.float32), + ], +) +def test_eye_batch_shape(num_rows, num_columns, diagonal_index, batch_shape, out_type): + num_rows_array = np.array([num_rows], np.int32) + num_columns_array = np.array([num_columns], np.int32) + diagonal_index_array = np.array([diagonal_index], np.int32) + batch_shape_array = np.array(batch_shape, np.int32) + num_rows_tensor = ng.constant(num_rows_array) + num_columns_tensor = ng.constant(num_columns_array) + diagonal_index_tensor = ng.constant(diagonal_index_array) + batch_shape_tensor = ng.constant(batch_shape_array) + + # Create with param names + eye_node = ng.eye(num_rows=num_rows_tensor, + num_columns=num_columns_tensor, + diagonal_index=diagonal_index_tensor, + batch_shape=batch_shape_tensor, + output_type=get_element_type_str(out_type)) + + # Create with default orded + eye_node = ng.eye(num_rows_tensor, + num_columns_tensor, + diagonal_index_tensor, + get_element_type_str(out_type), + batch_shape_tensor) + + output_shape = [*batch_shape, 1, 1] + one_matrix = np.eye(num_rows, M=num_columns, k=diagonal_index, dtype=np.float32) + expected_results = np.tile(one_matrix, output_shape) + + assert eye_node.get_type_name() == "Eye" + assert eye_node.get_output_size() == 1 + assert eye_node.get_output_element_type(0) == get_element_type(out_type) + assert tuple(eye_node.get_output_shape(0)) == expected_results.shape + + # TODO: Enable with Eye reference implementation + # runtime = get_runtime() + # computation = runtime.computation(eye_node) + # eye_results = computation() + # assert np.allclose(eye_results, expected_results)