Evaluate, ngraph_reader check and Python nGraph API for GatherElements (#3624)

* add GatherElements evaluate to interpreter backend

* Finally successfully run on backend

* debugged non-typical cases

* added ngraph_reader tests for GatherElements

* added Python API for GatherElements

* apply_style

* corrected python api tests

* applied comments

* style-apply

* finally corrected nGraph Python API for GatherElements

* minor corrections

* style-apply

* replaced quotes

* added blank line

* corrected evaluate and disabled unit-tests for not yet supported plugins

* style-apply

* applied comments: negative tests added and additional checks in evaluate

* added bound check for axis in evaluate

* style-apply

* apply review comments

* fast correct evaluate for GatherElements

* style-apply

* revert changes in interpreter unit_test.manifest for Gather

* 🚀 optimized general solution; added separate calculation for 2D

* 🚀 applied comments

* style-apply
This commit is contained in:
Pavel Esir 2021-01-11 19:52:47 +03:00 committed by GitHub
parent 8ec38cf039
commit d9bd59c7a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 1206 additions and 150 deletions

View File

@ -0,0 +1,122 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <string>
#include "ngraph_reader_tests.hpp"
TEST_F(NGraphReaderTests, ReadGatherElementsNetwork) {
std::string model = R"V0G0N(
<net name="Network" version="10">
<layers>
<layer id="0" name="data" type="Parameter" version="opset1">
<data element_type="f32" shape="3,7,5"/>
<output>
<port id="0" precision="FP32">
<dim>3</dim>
<dim>7</dim>
<dim>5</dim>
</port>
</output>
</layer>
<layer id="1" name="indices" type="Parameter" version="opset1">
<data element_type="i32" shape="3,10,5"/>
<output>
<port id="0" precision="I32">
<dim>3</dim>
<dim>10</dim>
<dim>5</dim>
</port>
</output>
</layer>
<layer id="2" name="MyGatherElements" type="GatherElements" version="opset6">
<data axis="1"/>
<input>
<port id="0">
<dim>3</dim>
<dim>7</dim>
<dim>5</dim>
</port>
<port id="1">
<dim>3</dim>
<dim>10</dim>
<dim>5</dim>
</port>
</input>
<output>
<port id="2" precision="FP32">
<dim>3</dim>
<dim>10</dim>
<dim>5</dim>
</port>
</output>
</layer>
<layer id="3" name="MyGatherND/sink_port_0" type="Result" version="opset1">
<input>
<port id="0">
<dim>3</dim>
<dim>10</dim>
<dim>5</dim>
</port>
</input>
</layer>
</layers>
<edges>
<edge from-layer="0" from-port="0" to-layer="2" to-port="0"/>
<edge from-layer="1" from-port="0" to-layer="2" to-port="1"/>
<edge from-layer="2" from-port="2" to-layer="3" to-port="0"/>
</edges>
</net>
)V0G0N";
std::string modelV5 = R"V0G0N(
<net name="Network" version="5" precision="FP32" batch="1">
<layers>
<layer id="0" name="data" type="Input" precision="FP32">
<output>
<port id="0" precision="FP32">
<dim>3</dim>
<dim>7</dim>
<dim>5</dim>
</port>
</output>
</layer>
<layer id="1" name="indices" type="Input" precision="I32">
<output>
<port id="0" precision="I32">
<dim>3</dim>
<dim>10</dim>
<dim>5</dim>
</port>
</output>
</layer>
<layer id="2" name="MyGatherElements" type="GatherElements" version="opset6">
<data axis="1"/>
<input>
<port id="0">
<dim>3</dim>
<dim>7</dim>
<dim>5</dim>
</port>
<port id="1">
<dim>3</dim>
<dim>10</dim>
<dim>5</dim>
</port>
</input>
<output>
<port id="2" precision="FP32">
<dim>3</dim>
<dim>10</dim>
<dim>5</dim>
</port>
</output>
</layer>
</layers>
<edges>
<edge from-layer="0" from-port="0" to-layer="2" to-port="0"/>
<edge from-layer="1" from-port="0" to-layer="2" to-port="1"/>
</edges>
</net>
)V0G0N";
compareIRs(model, modelV5, 10);
}

View File

@ -0,0 +1,149 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/coordinate_index.hpp"
#include "ngraph/coordinate_transform.hpp"
namespace ngraph
{
namespace runtime
{
namespace reference
{
template <typename T, typename U>
void gather_elements(const T* data,
const U* indices,
T* out,
const Shape& data_shape,
const Shape& indices_shape,
const Shape& out_shape,
int64_t axis)
{
if (axis < 0)
{
axis += data_shape.size();
}
if (axis < 0 || axis >= data_shape.size())
{
throw std::domain_error{
"axis for GatherElements exceeds allowed range [0, data_rank)"};
}
// in 1D case results can be achieved without additional calculations
if (data_shape.size() == 1)
{
for (int64_t i = 0; i < indices_shape[0]; i++)
{
if (indices[i] > data_shape[0])
{
throw std::domain_error{
"indices values of GatherElement exceed data size"};
}
out[i] = data[indices[i]];
}
return;
}
// 2D case is most frequent in order to run faster simpler separate solution
// implemented
size_t num_rows = indices_shape[0];
size_t num_columns = indices_shape[1];
size_t data_num_columns = data_shape[1];
if (data_shape.size() == 2)
{
int64_t idx;
if (axis == 0)
{
for (int64_t i = 0; i < num_rows; i++)
for (int64_t j = 0; j < num_columns; j++)
{
idx = indices[num_columns * i + j];
if (idx < 0 || idx > data_shape[0] - 1)
{
throw std::domain_error{
"indices values of GatherElement exceed data size"};
}
out[num_columns * i + j] = data[data_num_columns * idx + j];
}
return;
}
else // axis == 1
{
for (int64_t i = 0; i < num_rows; i++)
for (int64_t j = 0; j < num_columns; j++)
{
idx = indices[num_columns * i + j];
if (idx < 0 || idx > data_shape[1] - 1)
{
throw std::domain_error{
"indices values of GatherElement exceed data size"};
}
out[num_columns * i + j] = data[data_num_columns * i + idx];
}
return;
}
}
/*
assume data and indices are 5D and axis = 2
size of indices(N0,N1,N2,N3,N4)
size of data (N0,N1,N2',N3,N4)
the offset for indices will be
N4*N3*N2*N1*n0 + N4*N3*N2*n1 + N4*N3*n2 + N4*n3 + n4
and for data
N4*N3*N2'*N1*n0 + N4*N3*N2'*n1 + N4*N3*n2' + N4*n3 + n4
all values (except n2') are fixed or gradually increase
most of offset calculations are shared. We can rewrite offset for data as follows
data_offset = N4*N3*N2'(N1*n0 + n1) + N4*N3*n2' + (N4*n3 + n4)
N4*N3*N2' - outer_sum_inc
N4*N3*N2'(N1*n0 + n1) - outer_sum
N4*N3*n2' - n2' is red from indices tensor n2' = indices[n0,n1,n2,n3,n4]
(N4*n3 + n4) - inner_sum
*/
size_t max_inner_sum = 1;
for (int i = axis + 1; i < indices_shape.size(); i++)
max_inner_sum *= indices_shape[i];
size_t max_outer_sum = 1, outer_sum_inc = 1;
for (int i = 0; i < axis; i++)
max_outer_sum *= indices_shape[i];
for (int i = axis; i < data_shape.size(); i++)
outer_sum_inc *= data_shape[i];
max_outer_sum *= outer_sum_inc;
for (size_t outer_sum = 0, i = 0; outer_sum < max_outer_sum;
outer_sum += outer_sum_inc)
for (size_t k = 0; k < indices_shape[axis]; k++)
for (size_t inner_sum = 0; inner_sum < max_inner_sum; inner_sum++)
{
if (indices[i] < 0 || indices[i] > data_shape[axis] - 1)
{
throw std::domain_error{
"indices values of GatherElement exceed data size"};
}
out[i] = data[outer_sum + max_inner_sum * indices[i] + inner_sum];
i++;
}
}
} // namespace reference
} // namespace runtime
} // namespace ngraph

View File

@ -45,6 +45,7 @@ packages = [
"ngraph.opset3",
"ngraph.opset4",
"ngraph.opset5",
"ngraph.opset6",
"ngraph.utils",
"ngraph.impl",
"ngraph.impl.op",

View File

@ -28,153 +28,154 @@ from ngraph.impl import Function
from ngraph.helpers import function_from_cnn
from ngraph.helpers import function_to_cnn
from ngraph.opset5 import absolute
from ngraph.opset5 import absolute as abs
from ngraph.opset5 import acos
from ngraph.opset5 import acosh
from ngraph.opset5 import add
from ngraph.opset5 import asin
from ngraph.opset5 import asinh
from ngraph.opset5 import assign
from ngraph.opset5 import atan
from ngraph.opset5 import atanh
from ngraph.opset5 import avg_pool
from ngraph.opset5 import batch_norm_inference
from ngraph.opset5 import batch_to_space
from ngraph.opset5 import binary_convolution
from ngraph.opset5 import broadcast
from ngraph.opset5 import bucketize
from ngraph.opset5 import ceiling
from ngraph.opset5 import ceiling as ceil
from ngraph.opset5 import clamp
from ngraph.opset5 import concat
from ngraph.opset5 import constant
from ngraph.opset5 import convert
from ngraph.opset5 import convert_like
from ngraph.opset5 import convolution
from ngraph.opset5 import convolution_backprop_data
from ngraph.opset5 import cos
from ngraph.opset5 import cosh
from ngraph.opset5 import ctc_greedy_decoder
from ngraph.opset5 import ctc_loss
from ngraph.opset5 import cum_sum
from ngraph.opset5 import cum_sum as cumsum
from ngraph.opset5 import deformable_convolution
from ngraph.opset5 import deformable_psroi_pooling
from ngraph.opset5 import depth_to_space
from ngraph.opset5 import detection_output
from ngraph.opset5 import divide
from ngraph.opset5 import elu
from ngraph.opset5 import embedding_bag_offsets_sum
from ngraph.opset5 import embedding_bag_packed_sum
from ngraph.opset5 import embedding_segments_sum
from ngraph.opset5 import extract_image_patches
from ngraph.opset5 import equal
from ngraph.opset5 import erf
from ngraph.opset5 import exp
from ngraph.opset5 import fake_quantize
from ngraph.opset5 import floor
from ngraph.opset5 import floor_mod
from ngraph.opset5 import gather
from ngraph.opset5 import gather_nd
from ngraph.opset5 import gather_tree
from ngraph.opset5 import gelu
from ngraph.opset5 import greater
from ngraph.opset5 import greater_equal
from ngraph.opset5 import grn
from ngraph.opset5 import group_convolution
from ngraph.opset5 import group_convolution_backprop_data
from ngraph.opset5 import gru_cell
from ngraph.opset5 import gru_sequence
from ngraph.opset5 import hard_sigmoid
from ngraph.opset5 import hsigmoid
from ngraph.opset5 import hswish
from ngraph.opset5 import interpolate
from ngraph.opset5 import less
from ngraph.opset5 import less_equal
from ngraph.opset5 import log
from ngraph.opset5 import logical_and
from ngraph.opset5 import logical_not
from ngraph.opset5 import logical_or
from ngraph.opset5 import logical_xor
from ngraph.opset5 import log_softmax
from ngraph.opset5 import loop
from ngraph.opset5 import lrn
from ngraph.opset5 import lstm_cell
from ngraph.opset5 import lstm_sequence
from ngraph.opset5 import matmul
from ngraph.opset5 import max_pool
from ngraph.opset5 import maximum
from ngraph.opset5 import minimum
from ngraph.opset5 import mish
from ngraph.opset5 import mod
from ngraph.opset5 import multiply
from ngraph.opset5 import mvn
from ngraph.opset5 import negative
from ngraph.opset5 import non_max_suppression
from ngraph.opset5 import non_zero
from ngraph.opset5 import normalize_l2
from ngraph.opset5 import not_equal
from ngraph.opset5 import one_hot
from ngraph.opset5 import pad
from ngraph.opset5 import parameter
from ngraph.opset5 import power
from ngraph.opset5 import prelu
from ngraph.opset5 import prior_box
from ngraph.opset5 import prior_box_clustered
from ngraph.opset5 import psroi_pooling
from ngraph.opset5 import proposal
from ngraph.opset5 import range
from ngraph.opset5 import read_value
from ngraph.opset5 import reduce_l1
from ngraph.opset5 import reduce_l2
from ngraph.opset5 import reduce_logical_and
from ngraph.opset5 import reduce_logical_or
from ngraph.opset5 import reduce_max
from ngraph.opset5 import reduce_mean
from ngraph.opset5 import reduce_min
from ngraph.opset5 import reduce_prod
from ngraph.opset5 import reduce_sum
from ngraph.opset5 import region_yolo
from ngraph.opset5 import reorg_yolo
from ngraph.opset5 import relu
from ngraph.opset5 import reshape
from ngraph.opset5 import result
from ngraph.opset5 import reverse_sequence
from ngraph.opset5 import rnn_cell
from ngraph.opset5 import rnn_sequence
from ngraph.opset5 import roi_align
from ngraph.opset5 import roi_pooling
from ngraph.opset5 import round
from ngraph.opset5 import scatter_elements_update
from ngraph.opset5 import scatter_update
from ngraph.opset5 import select
from ngraph.opset5 import selu
from ngraph.opset5 import shape_of
from ngraph.opset5 import shuffle_channels
from ngraph.opset5 import sigmoid
from ngraph.opset5 import sign
from ngraph.opset5 import sin
from ngraph.opset5 import sinh
from ngraph.opset5 import softmax
from ngraph.opset5 import softplus
from ngraph.opset5 import space_to_batch
from ngraph.opset5 import space_to_depth
from ngraph.opset5 import split
from ngraph.opset5 import sqrt
from ngraph.opset5 import squared_difference
from ngraph.opset5 import squeeze
from ngraph.opset5 import strided_slice
from ngraph.opset5 import subtract
from ngraph.opset5 import swish
from ngraph.opset5 import tan
from ngraph.opset5 import tanh
from ngraph.opset5 import tensor_iterator
from ngraph.opset5 import tile
from ngraph.opset5 import topk
from ngraph.opset5 import transpose
from ngraph.opset5 import unsqueeze
from ngraph.opset5 import variadic_split
from ngraph.opset6 import absolute
from ngraph.opset6 import absolute as abs
from ngraph.opset6 import acos
from ngraph.opset6 import acosh
from ngraph.opset6 import add
from ngraph.opset6 import asin
from ngraph.opset6 import asinh
from ngraph.opset6 import assign
from ngraph.opset6 import atan
from ngraph.opset6 import atanh
from ngraph.opset6 import avg_pool
from ngraph.opset6 import batch_norm_inference
from ngraph.opset6 import batch_to_space
from ngraph.opset6 import binary_convolution
from ngraph.opset6 import broadcast
from ngraph.opset6 import bucketize
from ngraph.opset6 import ceiling
from ngraph.opset6 import ceiling as ceil
from ngraph.opset6 import clamp
from ngraph.opset6 import concat
from ngraph.opset6 import constant
from ngraph.opset6 import convert
from ngraph.opset6 import convert_like
from ngraph.opset6 import convolution
from ngraph.opset6 import convolution_backprop_data
from ngraph.opset6 import cos
from ngraph.opset6 import cosh
from ngraph.opset6 import ctc_greedy_decoder
from ngraph.opset6 import ctc_loss
from ngraph.opset6 import cum_sum
from ngraph.opset6 import cum_sum as cumsum
from ngraph.opset6 import deformable_convolution
from ngraph.opset6 import deformable_psroi_pooling
from ngraph.opset6 import depth_to_space
from ngraph.opset6 import detection_output
from ngraph.opset6 import divide
from ngraph.opset6 import elu
from ngraph.opset6 import embedding_bag_offsets_sum
from ngraph.opset6 import embedding_bag_packed_sum
from ngraph.opset6 import embedding_segments_sum
from ngraph.opset6 import extract_image_patches
from ngraph.opset6 import equal
from ngraph.opset6 import erf
from ngraph.opset6 import exp
from ngraph.opset6 import fake_quantize
from ngraph.opset6 import floor
from ngraph.opset6 import floor_mod
from ngraph.opset6 import gather
from ngraph.opset6 import gather_elements
from ngraph.opset6 import gather_nd
from ngraph.opset6 import gather_tree
from ngraph.opset6 import gelu
from ngraph.opset6 import greater
from ngraph.opset6 import greater_equal
from ngraph.opset6 import grn
from ngraph.opset6 import group_convolution
from ngraph.opset6 import group_convolution_backprop_data
from ngraph.opset6 import gru_cell
from ngraph.opset6 import gru_sequence
from ngraph.opset6 import hard_sigmoid
from ngraph.opset6 import hsigmoid
from ngraph.opset6 import hswish
from ngraph.opset6 import interpolate
from ngraph.opset6 import less
from ngraph.opset6 import less_equal
from ngraph.opset6 import log
from ngraph.opset6 import logical_and
from ngraph.opset6 import logical_not
from ngraph.opset6 import logical_or
from ngraph.opset6 import logical_xor
from ngraph.opset6 import log_softmax
from ngraph.opset6 import loop
from ngraph.opset6 import lrn
from ngraph.opset6 import lstm_cell
from ngraph.opset6 import lstm_sequence
from ngraph.opset6 import matmul
from ngraph.opset6 import max_pool
from ngraph.opset6 import maximum
from ngraph.opset6 import minimum
from ngraph.opset6 import mish
from ngraph.opset6 import mod
from ngraph.opset6 import multiply
from ngraph.opset6 import mvn
from ngraph.opset6 import negative
from ngraph.opset6 import non_max_suppression
from ngraph.opset6 import non_zero
from ngraph.opset6 import normalize_l2
from ngraph.opset6 import not_equal
from ngraph.opset6 import one_hot
from ngraph.opset6 import pad
from ngraph.opset6 import parameter
from ngraph.opset6 import power
from ngraph.opset6 import prelu
from ngraph.opset6 import prior_box
from ngraph.opset6 import prior_box_clustered
from ngraph.opset6 import psroi_pooling
from ngraph.opset6 import proposal
from ngraph.opset6 import range
from ngraph.opset6 import read_value
from ngraph.opset6 import reduce_l1
from ngraph.opset6 import reduce_l2
from ngraph.opset6 import reduce_logical_and
from ngraph.opset6 import reduce_logical_or
from ngraph.opset6 import reduce_max
from ngraph.opset6 import reduce_mean
from ngraph.opset6 import reduce_min
from ngraph.opset6 import reduce_prod
from ngraph.opset6 import reduce_sum
from ngraph.opset6 import region_yolo
from ngraph.opset6 import reorg_yolo
from ngraph.opset6 import relu
from ngraph.opset6 import reshape
from ngraph.opset6 import result
from ngraph.opset6 import reverse_sequence
from ngraph.opset6 import rnn_cell
from ngraph.opset6 import rnn_sequence
from ngraph.opset6 import roi_align
from ngraph.opset6 import roi_pooling
from ngraph.opset6 import round
from ngraph.opset6 import scatter_elements_update
from ngraph.opset6 import scatter_update
from ngraph.opset6 import select
from ngraph.opset6 import selu
from ngraph.opset6 import shape_of
from ngraph.opset6 import shuffle_channels
from ngraph.opset6 import sigmoid
from ngraph.opset6 import sign
from ngraph.opset6 import sin
from ngraph.opset6 import sinh
from ngraph.opset6 import softmax
from ngraph.opset6 import softplus
from ngraph.opset6 import space_to_batch
from ngraph.opset6 import space_to_depth
from ngraph.opset6 import split
from ngraph.opset6 import sqrt
from ngraph.opset6 import squared_difference
from ngraph.opset6 import squeeze
from ngraph.opset6 import strided_slice
from ngraph.opset6 import subtract
from ngraph.opset6 import swish
from ngraph.opset6 import tan
from ngraph.opset6 import tanh
from ngraph.opset6 import tensor_iterator
from ngraph.opset6 import tile
from ngraph.opset6 import topk
from ngraph.opset6 import transpose
from ngraph.opset6 import unsqueeze
from ngraph.opset6 import variadic_split
# Extend Node class to support binary operators

View File

@ -0,0 +1,165 @@
# ******************************************************************************
# Copyright 2017-2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
from ngraph.opset1.ops import absolute
from ngraph.opset1.ops import absolute as abs
from ngraph.opset1.ops import acos
from ngraph.opset4.ops import acosh
from ngraph.opset1.ops import add
from ngraph.opset1.ops import asin
from ngraph.opset4.ops import asinh
from ngraph.opset3.ops import assign
from ngraph.opset1.ops import atan
from ngraph.opset4.ops import atanh
from ngraph.opset1.ops import avg_pool
from ngraph.opset5.ops import batch_norm_inference
from ngraph.opset2.ops import batch_to_space
from ngraph.opset1.ops import binary_convolution
from ngraph.opset3.ops import broadcast
from ngraph.opset3.ops import bucketize
from ngraph.opset1.ops import ceiling
from ngraph.opset1.ops import ceiling as ceil
from ngraph.opset1.ops import clamp
from ngraph.opset1.ops import concat
from ngraph.opset1.ops import constant
from ngraph.opset1.ops import convert
from ngraph.opset1.ops import convert_like
from ngraph.opset1.ops import convolution
from ngraph.opset1.ops import convolution_backprop_data
from ngraph.opset1.ops import cos
from ngraph.opset1.ops import cosh
from ngraph.opset1.ops import ctc_greedy_decoder
from ngraph.opset4.ops import ctc_loss
from ngraph.opset3.ops import cum_sum
from ngraph.opset3.ops import cum_sum as cumsum
from ngraph.opset1.ops import deformable_convolution
from ngraph.opset1.ops import deformable_psroi_pooling
from ngraph.opset1.ops import depth_to_space
from ngraph.opset1.ops import detection_output
from ngraph.opset1.ops import divide
from ngraph.opset1.ops import elu
from ngraph.opset3.ops import embedding_bag_offsets_sum
from ngraph.opset3.ops import embedding_bag_packed_sum
from ngraph.opset3.ops import embedding_segments_sum
from ngraph.opset3.ops import extract_image_patches
from ngraph.opset1.ops import equal
from ngraph.opset1.ops import erf
from ngraph.opset1.ops import exp
from ngraph.opset1.ops import fake_quantize
from ngraph.opset1.ops import floor
from ngraph.opset1.ops import floor_mod
from ngraph.opset1.ops import gather
from ngraph.opset6.ops import gather_elements
from ngraph.opset5.ops import gather_nd
from ngraph.opset1.ops import gather_tree
from ngraph.opset2.ops import gelu
from ngraph.opset1.ops import greater
from ngraph.opset1.ops import greater_equal
from ngraph.opset1.ops import grn
from ngraph.opset1.ops import group_convolution
from ngraph.opset1.ops import group_convolution_backprop_data
from ngraph.opset3.ops import gru_cell
from ngraph.opset5.ops import gru_sequence
from ngraph.opset1.ops import hard_sigmoid
from ngraph.opset5.ops import hsigmoid
from ngraph.opset4.ops import hswish
from ngraph.opset1.ops import interpolate
from ngraph.opset1.ops import less
from ngraph.opset1.ops import less_equal
from ngraph.opset1.ops import log
from ngraph.opset1.ops import logical_and
from ngraph.opset1.ops import logical_not
from ngraph.opset1.ops import logical_or
from ngraph.opset1.ops import logical_xor
from ngraph.opset5.ops import log_softmax
from ngraph.opset5.ops import loop
from ngraph.opset1.ops import lrn
from ngraph.opset4.ops import lstm_cell
from ngraph.opset1.ops import lstm_sequence
from ngraph.opset1.ops import matmul
from ngraph.opset1.ops import max_pool
from ngraph.opset1.ops import maximum
from ngraph.opset1.ops import minimum
from ngraph.opset4.ops import mish
from ngraph.opset1.ops import mod
from ngraph.opset1.ops import multiply
from ngraph.opset2.ops import mvn
from ngraph.opset1.ops import negative
from ngraph.opset5.ops import non_max_suppression
from ngraph.opset3.ops import non_zero
from ngraph.opset1.ops import normalize_l2
from ngraph.opset1.ops import not_equal
from ngraph.opset1.ops import one_hot
from ngraph.opset1.ops import pad
from ngraph.opset1.ops import parameter
from ngraph.opset1.ops import power
from ngraph.opset1.ops import prelu
from ngraph.opset1.ops import prior_box
from ngraph.opset1.ops import prior_box_clustered
from ngraph.opset1.ops import psroi_pooling
from ngraph.opset4.ops import proposal
from ngraph.opset1.ops import range
from ngraph.opset3.ops import read_value
from ngraph.opset4.ops import reduce_l1
from ngraph.opset4.ops import reduce_l2
from ngraph.opset1.ops import reduce_logical_and
from ngraph.opset1.ops import reduce_logical_or
from ngraph.opset1.ops import reduce_max
from ngraph.opset1.ops import reduce_mean
from ngraph.opset1.ops import reduce_min
from ngraph.opset1.ops import reduce_prod
from ngraph.opset1.ops import reduce_sum
from ngraph.opset1.ops import region_yolo
from ngraph.opset2.ops import reorg_yolo
from ngraph.opset1.ops import relu
from ngraph.opset1.ops import reshape
from ngraph.opset1.ops import result
from ngraph.opset1.ops import reverse_sequence
from ngraph.opset3.ops import rnn_cell
from ngraph.opset5.ops import rnn_sequence
from ngraph.opset3.ops import roi_align
from ngraph.opset2.ops import roi_pooling
from ngraph.opset5.ops import round
from ngraph.opset3.ops import scatter_elements_update
from ngraph.opset3.ops import scatter_update
from ngraph.opset1.ops import select
from ngraph.opset1.ops import selu
from ngraph.opset3.ops import shape_of
from ngraph.opset3.ops import shuffle_channels
from ngraph.opset1.ops import sigmoid
from ngraph.opset1.ops import sign
from ngraph.opset1.ops import sin
from ngraph.opset1.ops import sinh
from ngraph.opset1.ops import softmax
from ngraph.opset4.ops import softplus
from ngraph.opset2.ops import space_to_batch
from ngraph.opset1.ops import space_to_depth
from ngraph.opset1.ops import split
from ngraph.opset1.ops import sqrt
from ngraph.opset1.ops import squared_difference
from ngraph.opset1.ops import squeeze
from ngraph.opset1.ops import strided_slice
from ngraph.opset1.ops import subtract
from ngraph.opset4.ops import swish
from ngraph.opset1.ops import tan
from ngraph.opset1.ops import tanh
from ngraph.opset1.ops import tensor_iterator
from ngraph.opset1.ops import tile
from ngraph.opset3.ops import topk
from ngraph.opset1.ops import transpose
from ngraph.opset1.ops import unsqueeze
from ngraph.opset1.ops import variadic_split

View File

@ -0,0 +1,81 @@
# ******************************************************************************
# Copyright 2017-2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
"""Factory functions for all ngraph ops."""
from typing import Callable, Iterable, List, Optional, Set, Union
import numpy as np
from functools import partial
from ngraph.impl import Node, Shape
from ngraph.impl.op import Constant, Parameter
from ngraph.opset_utils import _get_node_factory
from ngraph.utils.decorators import binary_op, nameable_op, unary_op
from ngraph.utils.input_validation import (
assert_list_of_ints,
check_valid_attributes,
is_non_negative_value,
is_positive_value,
)
from ngraph.utils.node_factory import NodeFactory
from ngraph.utils.tensor_iterator_types import (
GraphBody,
TensorIteratorSliceInputDesc,
TensorIteratorMergedInputDesc,
TensorIteratorInvariantInputDesc,
TensorIteratorBodyOutputDesc,
TensorIteratorConcatOutputDesc,
)
from ngraph.utils.types import (
NodeInput,
NumericData,
NumericType,
ScalarData,
TensorShape,
as_node,
as_nodes,
get_dtype,
get_element_type,
get_element_type_str,
make_constant_node,
)
_get_node_factory_opset6 = partial(_get_node_factory, "opset6")
# -------------------------------------------- ops ------------------------------------------------
@nameable_op
def gather_elements(
data: NodeInput,
indices: NodeInput,
axis: Optional[int] = 0,
name: Optional[str] = None,
) -> Node:
"""Return a node which performs GatherND.
@param data: N-D tensor with data for gathering
@param indices: N-D tensor with indices by which data is gathered
@param axis: axis along which elements are gathered
@return: The new node which performs GatherElements
"""
inputs = as_nodes(data, indices)
attributes = {
"axis": axis
}
return _get_node_factory_opset6().create("GatherElements", inputs, attributes)

View File

@ -5,7 +5,7 @@ from _pyngraph import NodeFactory as _NodeFactory
from ngraph.impl import Node, Output
DEFAULT_OPSET = "opset5"
DEFAULT_OPSET = "opset6"
class NodeFactory(object):

View File

@ -92,6 +92,7 @@ namespace
{"opset3", OpsetFunction(ngraph::get_opset3)},
{"opset4", OpsetFunction(ngraph::get_opset4)},
{"opset5", OpsetFunction(ngraph::get_opset5)},
{"opset6", OpsetFunction(ngraph::get_opset6)},
};
auto it = s_opsets.find(opset_ver);
@ -102,7 +103,7 @@ namespace
return it->second();
}
const ngraph::OpSet& m_opset{ngraph::get_opset5()};
const ngraph::OpSet& m_opset{ngraph::get_opset6()};
};
}

View File

@ -16,7 +16,7 @@
import numpy as np
import ngraph as ng
from ngraph.impl import Type
from ngraph.impl import Type, Shape
from tests.runtime import get_runtime
from tests.test_ngraph.util import run_op_node
@ -213,3 +213,18 @@ def test_gather_nd():
assert node.get_output_size() == 1
assert list(node.get_output_shape(0)) == expected_shape
assert node.get_output_element_type(0) == Type.f32
def test_gather_elements():
indices_type = np.int32
data_dtype = np.float32
data = ng.parameter(Shape([2, 5]), dtype=data_dtype, name="data")
indices = ng.parameter(Shape([2, 100]), dtype=indices_type, name="indices")
axis = 1
expected_shape = [2, 100]
node = ng.gather_elements(data, indices, axis)
assert node.get_type_name() == "GatherElements"
assert node.get_output_size() == 1
assert list(node.get_output_shape(0)) == expected_shape
assert node.get_output_element_type(0) == Type.f32

View File

@ -275,6 +275,7 @@ set(MULTI_TEST_SRC
backend/function_name.in.cpp
backend/fused_op.in.cpp
backend/gather.in.cpp
backend/gather_elements.in.cpp
backend/gather_nd.in.cpp
backend/gelu.in.cpp
backend/group_convolution.in.cpp

View File

@ -0,0 +1,450 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/engine/test_engines.hpp"
#include "util/test_case.hpp"
#include "util/test_control.hpp"
NGRAPH_SUPPRESS_DEPRECATED_START
using namespace std;
using namespace ngraph;
static string s_manifest = "${MANIFEST}";
using TestEngine = test::ENGINE_CLASS_NAME(${BACKEND_NAME});
NGRAPH_TEST(${BACKEND_NAME}, evaluate_1D_gather_elements_3_indices_int32)
{
auto arg1 = make_shared<op::Parameter>(element::i32, PartialShape{3});
auto arg2 = make_shared<op::Parameter>(element::i32, PartialShape{7});
int64_t axis = 0;
auto gather_el = make_shared<op::v6::GatherElements>(arg1, arg2, axis);
auto fun = make_shared<Function>(OutputVector{gather_el}, ParameterVector{arg1, arg2});
auto test_case = test::TestCase<TestEngine>(fun);
test_case.add_input<int32_t>({1, 2, 3});
test_case.add_input<int32_t>({1, 2, 0, 2, 0, 0, 2});
test_case.add_expected_output<int32_t>(vector<int32_t>{2, 3, 1, 3, 1, 1, 3});
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, evaluate_2D_gather_elements_2x2_indices_int32_axis_0)
{
auto arg1 = make_shared<op::Parameter>(element::i32, PartialShape{2, 2});
auto arg2 = make_shared<op::Parameter>(element::i32, PartialShape{2, 2});
int64_t axis = 0;
auto gather_el = make_shared<op::v6::GatherElements>(arg1, arg2, axis);
auto fun = make_shared<Function>(OutputVector{gather_el}, ParameterVector{arg1, arg2});
auto test_case = test::TestCase<TestEngine>(fun);
// clang-format off
test_case.add_input<int32_t>({1, 2,
3, 4});
test_case.add_input<int32_t>({0, 1,
0, 0});
test_case.add_expected_output<int32_t>(vector<int32_t>{1, 4,
1, 2});
// clang-format on
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, evaluate_2D_gather_elements_2x2_indices_int32_axis_1)
{
auto arg1 = make_shared<op::Parameter>(element::i32, PartialShape{2, 2});
auto arg2 = make_shared<op::Parameter>(element::i32, PartialShape{2, 2});
int64_t axis = 1;
auto gather_el = make_shared<op::v6::GatherElements>(arg1, arg2, axis);
auto fun = make_shared<Function>(OutputVector{gather_el}, ParameterVector{arg1, arg2});
auto test_case = test::TestCase<TestEngine>(fun);
// clang-format off
std::vector<int32_t> data{1, 2,
3, 4};
std::vector<int32_t> indices{0, 1,
0, 0};
test_case.add_multiple_inputs<int32_t>({data, indices});
test_case.add_expected_output<int32_t>(vector<int32_t>{1, 2,
3, 3});
test_case.run();
// clang-format on
}
NGRAPH_TEST(${BACKEND_NAME}, evaluate_2D_gather_elements_2x2_indices_int32_axis_minus_1)
{
auto arg1 = make_shared<op::Parameter>(element::i32, PartialShape{2, 2});
auto arg2 = make_shared<op::Parameter>(element::i32, PartialShape{2, 2});
int64_t axis = -1;
auto gather_el = make_shared<op::v6::GatherElements>(arg1, arg2, axis);
auto fun = make_shared<Function>(OutputVector{gather_el}, ParameterVector{arg1, arg2});
auto test_case = test::TestCase<TestEngine>(fun);
// clang-format off
std::vector<int32_t> data{1, 2,
3, 4};
std::vector<int32_t> indices{0, 1,
0, 0};
test_case.add_multiple_inputs<int32_t>({data, indices});
test_case.add_expected_output<int32_t>(vector<int32_t>{1, 2,
3, 3});
// clang-format on
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, evaluate_2D_gather_elements_2x3_indices_int32)
{
auto arg1 = make_shared<op::Parameter>(element::i32, PartialShape{3, 3});
auto arg2 = make_shared<op::Parameter>(element::i32, PartialShape{2, 3});
int64_t axis = 0;
auto gather_el = make_shared<op::v6::GatherElements>(arg1, arg2, axis);
auto fun = make_shared<Function>(OutputVector{gather_el}, ParameterVector{arg1, arg2});
auto test_case = test::TestCase<TestEngine>(fun);
// clang-format off
std::vector<int32_t> data{1, 2, 3,
4, 5, 6,
7, 8, 9};
std::vector<int32_t> indices{1, 2, 0,
2, 0, 0};
test_case.add_multiple_inputs<int32_t>({data, indices});
test_case.add_expected_output<int32_t>(vector<int32_t>{4, 8, 3,
7, 2, 3});
// clang-format on
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, evaluate_3D_gather_elements_3x2x2_indices_int32)
{
auto arg1 = make_shared<op::Parameter>(element::i32, PartialShape{3, 2, 2});
auto arg2 = make_shared<op::Parameter>(element::i32, PartialShape{3, 2, 2});
int64_t axis = -1;
auto gather_el = make_shared<op::v6::GatherElements>(arg1, arg2, axis);
auto fun = make_shared<Function>(OutputVector{gather_el}, ParameterVector{arg1, arg2});
auto test_case = test::TestCase<TestEngine>(fun);
// clang-format off
std::vector<int32_t> data{1, 2,
3, 4,
5, 6,
7, 8,
9, 10,
11, 12};
std::vector<int32_t> indices{1, 0,
0, 1,
1, 1,
1, 0,
0, 0,
1, 1};
test_case.add_multiple_inputs<int32_t>({data, indices});
test_case.add_expected_output<int32_t>(vector<int32_t>{2, 1,
3, 4,
6, 6,
8, 7,
9, 9,
12, 12});
// clang-format on
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, evaluate_4D_gather_elements_3x2x2x2_indices_int64)
{
auto arg1 = make_shared<op::Parameter>(element::i32, PartialShape{3, 2, 2, 2});
auto arg2 = make_shared<op::Parameter>(element::i64, PartialShape{3, 2, 2, 4});
int64_t axis = -1;
auto gather_el = make_shared<op::v6::GatherElements>(arg1, arg2, axis);
auto fun = make_shared<Function>(OutputVector{gather_el}, ParameterVector{arg1, arg2});
auto test_case = test::TestCase<TestEngine>(fun);
// clang-format off
std::vector<int32_t> data{ 1, 2,
3, 4,
5, 6,
7, 8,
9, 10,
11, 12,
13, 14,
15, 16,
17, 18,
19, 20,
21, 22,
23, 24};
std::vector<int64_t> indices{1, 0, 0, 0,
0, 1, 1, 0,
1, 1, 1, 1,
1, 0, 0, 1,
0, 0, 0, 1,
1, 1, 1, 0,
0, 0, 0, 0,
1, 0, 1, 0,
1, 1, 1, 1,
1, 0, 1, 0,
1, 0, 0, 1,
0, 0, 0, 0};
test_case.add_input<int32_t>(data);
test_case.add_input<int64_t>(indices);
test_case.add_expected_output<int32_t>(vector<int32_t>{2, 1, 1, 1,
3, 4, 4, 3,
6, 6, 6, 6,
8, 7, 7, 8,
9, 9, 9, 10,
12, 12, 12, 11,
13, 13, 13, 13,
16, 15, 16, 15,
18, 18, 18, 18,
20, 19, 20, 19,
22, 21, 21, 22,
23, 23, 23, 23});
// clang-format on
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, evaluate_3D_gather_elements_3x2x2_indices_int64)
{
auto arg1 = make_shared<op::Parameter>(element::i32, PartialShape{3, 2, 2});
auto arg2 = make_shared<op::Parameter>(element::i64, PartialShape{3, 2, 2});
int64_t axis = -1;
auto gather_el = make_shared<op::v6::GatherElements>(arg1, arg2, axis);
auto fun = make_shared<Function>(OutputVector{gather_el}, ParameterVector{arg1, arg2});
auto test_case = test::TestCase<TestEngine>(fun);
// clang-format off
std::vector<int32_t> data{1, 2,
3, 4,
5, 6,
7, 8,
9, 10,
11, 12};
std::vector<int64_t> indices{1, 0,
0, 1,
1, 1,
1, 0,
0, 0,
1, 1};
test_case.add_input<int32_t>(data);
test_case.add_input<int64_t>(indices);
test_case.add_expected_output<int32_t>(vector<int32_t>{2, 1,
3, 4,
6, 6,
8, 7,
9, 9,
12, 12});
// clang-format on
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, evaluate_2D_gather_elements_3x2_data_bool)
{
auto arg1 = make_shared<op::Parameter>(element::boolean, PartialShape{3, 2});
auto arg2 = make_shared<op::Parameter>(element::i32, PartialShape{2, 2});
int64_t axis = 0;
auto gather_el = make_shared<op::v6::GatherElements>(arg1, arg2, axis);
auto fun = make_shared<Function>(OutputVector{gather_el}, ParameterVector{arg1, arg2});
auto test_case = test::TestCase<TestEngine>(fun);
// clang-format off
std::vector<bool> data{true, false, true,
true, false, false};
std::vector<int32_t> indices{0, 1,
0, 2};
test_case.add_input<bool>(data);
test_case.add_input<int32_t>(indices);
test_case.add_expected_output<bool>(vector<bool>{true, true,
true, false});
// clang-format on
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, evaluate_2D_gather_elements_2x3_data_float32)
{
auto arg1 = make_shared<op::Parameter>(element::f32, PartialShape{3, 3});
auto arg2 = make_shared<op::Parameter>(element::i32, PartialShape{2, 3});
int64_t axis = 0;
auto gather_el = make_shared<op::v6::GatherElements>(arg1, arg2, axis);
auto fun = make_shared<Function>(OutputVector{gather_el}, ParameterVector{arg1, arg2});
auto test_case = test::TestCase<TestEngine>(fun);
// clang-format off
std::vector<float> data{1.0f, 2.0f, 3.0f,
4.0f, 5.0f, 6.0f,
7.0f, 8.0f, 9.0f};
std::vector<int32_t> indices{1, 2, 0,
2, 0, 0};
test_case.add_input<float>(data);
test_case.add_input<int32_t>(indices);
test_case.add_expected_output<float>(vector<float>{4.0f, 8.0f, 3.0f,
7.0f, 2.0f, 3.0f});
// clang-format on
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, evaluate_2D_gather_elements_2x2x1_data_float32)
{
auto arg1 = make_shared<op::Parameter>(element::i32, PartialShape{2, 2, 1});
auto arg2 = make_shared<op::Parameter>(element::i32, PartialShape{4, 2, 1});
int64_t axis = 0;
auto gather_el = make_shared<op::v6::GatherElements>(arg1, arg2, axis);
auto fun = make_shared<Function>(OutputVector{gather_el}, ParameterVector{arg1, arg2});
auto test_case = test::TestCase<TestEngine>(fun);
// clang-format off
std::vector<int32_t> data{5,
4,
1,
4};
std::vector<int32_t> indices{0,
0,
1,
1,
1,
1,
0,
1};
test_case.add_input<int32_t>(data);
test_case.add_input<int32_t>(indices);
test_case.add_expected_output<float>(vector<float>{5,
4,
1,
4,
1,
4,
5,
4});
// clang-format on
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, evaluate_1D_gather_elements_negative_test)
{
auto arg1 = make_shared<op::Parameter>(element::i32, PartialShape{3});
auto arg2 = make_shared<op::Parameter>(element::i32, PartialShape{7});
int64_t axis = 0;
auto gather_el = make_shared<op::v6::GatherElements>(arg1, arg2, axis);
auto fun = make_shared<Function>(OutputVector{gather_el}, ParameterVector{arg1, arg2});
auto test_case = test::TestCase<TestEngine>(fun);
std::vector<int32_t> data{1, 2, 3};
std::vector<int32_t> indices{1, 2, 0, 2, 0, 0, 8};
test_case.add_multiple_inputs<int32_t>({data, indices});
test_case.add_expected_output<int32_t>(vector<int32_t>{2, 3, 1, 3, 1, 1, 3});
try
{
test_case.run();
// Should have thrown, so fail if it didn't
FAIL() << "Evaluate out ouf bound indices check failed";
}
catch (const std::domain_error& error)
{
ASSERT_EQ(error.what(), std::string("indices values of GatherElement exceed data size"));
}
catch (...)
{
FAIL() << "Evaluate out ouf bound indices check failed";
}
}
NGRAPH_TEST(${BACKEND_NAME}, evaluate_2D_gather_elements_negative_test)
{
auto arg1 = make_shared<op::Parameter>(element::i32, PartialShape{3, 3});
auto arg2 = make_shared<op::Parameter>(element::i32, PartialShape{2, 3});
int64_t axis = 0;
auto gather_el = make_shared<op::v6::GatherElements>(arg1, arg2, axis);
auto fun = make_shared<Function>(OutputVector{gather_el}, ParameterVector{arg1, arg2});
auto test_case = test::TestCase<TestEngine>(fun);
// clang-format off
std::vector<int32_t> data{1, 2, 3,
4, 5, 6,
7, 8, 9};
std::vector<int32_t> indices{1, 3, 0,
2, 0, 0};
test_case.add_multiple_inputs<int32_t>({data, indices});
test_case.add_expected_output<int32_t>(vector<int32_t>{4, 8, 3,
7, 2, 3});
// clang-format on
try
{
test_case.run();
// Should have thrown, so fail if it didn't
FAIL() << "Evaluate out ouf bound indices check failed";
}
catch (const std::domain_error& error)
{
ASSERT_EQ(error.what(), std::string("indices values of GatherElement exceed data size"));
}
catch (...)
{
FAIL() << "Evaluate out ouf bound indices check failed";
}
}

View File

@ -1540,3 +1540,31 @@ onnx_controlflow_loop_infinite
# unsupported dynamic ops
onnx_dyn_shapes_reduce_max_dynamic_input_rank_negative_axis
IE_GPU.range_v4_trunc_inputs
# not implemented yet on CPU and GPU plugins
IE_CPU.evaluate_1D_gather_elements_3_indices_int32
IE_CPU.evaluate_2D_gather_elements_2x2_indices_int32_axis_0
IE_CPU.evaluate_2D_gather_elements_2x2_indices_int32_axis_1
IE_CPU.evaluate_2D_gather_elements_2x2_indices_int32_axis_minus_1
IE_CPU.evaluate_2D_gather_elements_2x3_indices_int32
IE_CPU.evaluate_3D_gather_elements_3x2x2_indices_int32
IE_CPU.evaluate_3D_gather_elements_3x2x2_indices_int64
IE_CPU.evaluate_2D_gather_elements_3x2_data_bool
IE_CPU.evaluate_2D_gather_elements_2x3_data_float32
IE_CPU.evaluate_1D_gather_elements_negative_test
IE_CPU.evaluate_2D_gather_elements_negative_test
IE_CPU.evaluate_2D_gather_elements_2x2x1_data_float32
IE_CPU.evaluate_4D_gather_elements_3x2x2x2_indices_int64
IE_GPU.evaluate_1D_gather_elements_3_indices_int32
IE_GPU.evaluate_2D_gather_elements_2x2_indices_int32_axis_0
IE_GPU.evaluate_2D_gather_elements_2x2_indices_int32_axis_1
IE_GPU.evaluate_2D_gather_elements_2x2_indices_int32_axis_minus_1
IE_GPU.evaluate_2D_gather_elements_2x3_indices_int32
IE_GPU.evaluate_3D_gather_elements_3x2x2_indices_int32
IE_GPU.evaluate_3D_gather_elements_3x2x2_indices_int64
IE_GPU.evaluate_2D_gather_elements_3x2_data_bool
IE_GPU.evaluate_2D_gather_elements_2x3_data_float32
IE_GPU.evaluate_1D_gather_elements_negative_test
IE_GPU.evaluate_2D_gather_elements_negative_test
IE_GPU.evaluate_2D_gather_elements_2x2x1_data_float32
IE_GPU.evaluate_4D_gather_elements_3x2x2x2_indices_int64

View File

@ -35,6 +35,7 @@
#include <ngraph/runtime/reference/embedding_segments_sum.hpp>
#include <ngraph/runtime/reference/extract_image_patches.hpp>
#include <ngraph/runtime/reference/fake_quantize.hpp>
#include <ngraph/runtime/reference/gather_elements.hpp>
#include <ngraph/runtime/reference/gather_nd.hpp>
#include <ngraph/runtime/reference/gather_tree.hpp>
#include <ngraph/runtime/reference/gelu.hpp>
@ -1580,6 +1581,45 @@ namespace
return true;
}
template <element::Type_t ET>
bool evaluate(const shared_ptr<op::v6::GatherElements>& op,
const HostTensorVector& outputs,
const HostTensorVector& inputs)
{
using T = typename element_type_traits<ET>::value_type;
Shape params_shape = inputs[0]->get_shape();
Shape indices_shape = inputs[1]->get_shape();
outputs[0]->set_shape(indices_shape);
if (inputs[1]->get_element_type() == element::i64)
{
runtime::reference::gather_elements<T, int64_t>(inputs[0]->get_data_ptr<ET>(),
inputs[1]->get_data_ptr<int64_t>(),
outputs[0]->get_data_ptr<ET>(),
inputs[0]->get_shape(),
inputs[1]->get_shape(),
outputs[0]->get_shape(),
op->get_axis());
}
else if (inputs[1]->get_element_type() == element::i32)
{
runtime::reference::gather_elements<T, int32_t>(inputs[0]->get_data_ptr<ET>(),
inputs[1]->get_data_ptr<int32_t>(),
outputs[0]->get_data_ptr<ET>(),
inputs[0]->get_shape(),
inputs[1]->get_shape(),
outputs[0]->get_shape(),
op->get_axis());
}
else
{
throw ngraph_error("Unexpected indices type");
}
return true;
}
template <element::Type_t ET>
bool evaluate(const shared_ptr<op::v5::GatherND>& op,
const HostTensorVector& outputs,

View File

@ -86,3 +86,5 @@ NGRAPH_OP(LSTMSequence, op::v5)
NGRAPH_OP(NonMaxSuppression, op::v5)
NGRAPH_OP(RNNSequence, op::v5)
NGRAPH_OP(Round, op::v5)
NGRAPH_OP(GatherElements, op::v6)