Remove old Scatter operations (#1265)
This commit is contained in:
parent
45d1b4eb19
commit
9dedb39cfc
@ -312,12 +312,8 @@ set (SRC
|
||||
op/reverse.hpp
|
||||
op/reverse_sequence.cpp
|
||||
op/reverse_sequence.hpp
|
||||
op/scatter_add.cpp
|
||||
op/scatter_add.hpp
|
||||
op/scatter_elements_update.cpp
|
||||
op/scatter_elements_update.hpp
|
||||
op/scatter_nd_add.cpp
|
||||
op/scatter_nd_add.hpp
|
||||
op/scatter_nd_update.cpp
|
||||
op/scatter_nd_update.hpp
|
||||
op/scatter_update.cpp
|
||||
@ -413,8 +409,6 @@ set (SRC
|
||||
op/fused/rnn_cell.hpp
|
||||
op/fused/scale_shift.cpp
|
||||
op/fused/scale_shift.hpp
|
||||
op/fused/scatter_nd.cpp
|
||||
op/fused/scatter_nd.hpp
|
||||
op/fused/stack.cpp
|
||||
op/fused/stack.hpp
|
||||
op/fused/selu.cpp
|
||||
|
@ -1,140 +0,0 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include "ngraph/op/fused/scatter_nd.hpp"
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "ngraph/op/convert.hpp"
|
||||
#include "ngraph/op/scatter_nd_add.hpp"
|
||||
#include "ngraph/op/select.hpp"
|
||||
#include "ngraph/shape.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
constexpr NodeTypeInfo op::v0::ScatterND::type_info;
|
||||
|
||||
op::v0::ScatterND::ScatterND(const Output<Node>& data,
|
||||
const Output<Node>& indices,
|
||||
const Output<Node>& updates)
|
||||
: op::util::FusedOp({data, indices, updates})
|
||||
{
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::v0::ScatterND::clone_with_new_inputs(const OutputVector& new_args) const
|
||||
{
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<ScatterND>(new_args.at(0), new_args.at(1), new_args.at(2));
|
||||
}
|
||||
|
||||
void op::v0::ScatterND::pre_validate_and_infer_types()
|
||||
{
|
||||
const static int DATA = 0;
|
||||
const static int INDICES = 1;
|
||||
const static int UPDATES = 2;
|
||||
|
||||
element::Type data_et = input_value(DATA).get_element_type();
|
||||
element::Type indices_et = input_value(INDICES).get_element_type();
|
||||
element::Type updates_et = input_value(UPDATES).get_element_type();
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
indices_et == element::i32 || indices_et == element::i64,
|
||||
"Indices element type must be i64 or i32.");
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
data_et == updates_et,
|
||||
"Updates element type must be the same as element type of data.");
|
||||
|
||||
const PartialShape& data_ps = get_input_partial_shape(DATA);
|
||||
const PartialShape& indices_ps = get_input_partial_shape(INDICES);
|
||||
const PartialShape& updates_ps = get_input_partial_shape(UPDATES);
|
||||
|
||||
if (data_ps.rank().is_static())
|
||||
{
|
||||
const size_t data_rank = data_ps.rank().get_length();
|
||||
NODE_VALIDATION_CHECK(this, data_rank >= 1, "Data rank is expected to be at least 1.");
|
||||
}
|
||||
|
||||
if (indices_ps.rank().is_static())
|
||||
{
|
||||
const size_t indices_rank = indices_ps.rank().get_length();
|
||||
|
||||
NODE_VALIDATION_CHECK(
|
||||
this, indices_rank >= 1, "Indices rank is expected to be at least 1.");
|
||||
}
|
||||
|
||||
if (indices_ps.rank().is_static() && data_ps.rank().is_static())
|
||||
{
|
||||
const size_t indices_rank = indices_ps.rank().get_length();
|
||||
const size_t last_dim_pos = indices_rank - 1;
|
||||
const Dimension indices_last_dim = indices_ps[last_dim_pos];
|
||||
if (indices_last_dim.is_static())
|
||||
{
|
||||
const size_t indices_last_dim_value = indices_last_dim.get_length();
|
||||
const size_t data_rank = data_ps.rank().get_length();
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
indices_last_dim_value <= data_rank,
|
||||
"Last dimension of indices can be at most the rank of data.");
|
||||
|
||||
if (updates_ps.rank().is_static())
|
||||
{
|
||||
const size_t expected_updates_rank =
|
||||
data_rank + indices_rank - indices_last_dim_value - 1;
|
||||
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
updates_ps.rank().get_length() == expected_updates_rank,
|
||||
"Updates rank is expected to be equal data_rank + indices_rank - "
|
||||
"indices_shape[-1] - 1.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
set_output_type(0, data_et, data_ps);
|
||||
}
|
||||
|
||||
NodeVector op::ScatterND::decompose_op() const
|
||||
{
|
||||
const auto data = input_value(0);
|
||||
const auto indices = input_value(1);
|
||||
const auto updates = input_value(2);
|
||||
|
||||
const Shape& data_shape = data.get_shape();
|
||||
const Shape& updates_shape = updates.get_shape();
|
||||
|
||||
element::Type data_et = data.get_element_type();
|
||||
|
||||
// Create a boolean mask that matches the data tensor shape and
|
||||
// contains 'true' values in the positions indicated by 'indices'
|
||||
// and 'false' values everywhere else.
|
||||
|
||||
const auto true_values = op::Constant::create(element::i64, updates_shape, {1});
|
||||
const auto false_values = op::Constant::create(element::i64, data_shape, {0});
|
||||
|
||||
const auto mask = std::make_shared<op::v0::ScatterNDAdd>(false_values, indices, true_values);
|
||||
|
||||
const auto mask_bool = std::make_shared<op::v0::Convert>(mask, element::boolean);
|
||||
|
||||
const auto zeros = op::Constant::create(data_et, data_shape, {0});
|
||||
|
||||
// Create an intermediate node that will contain the original data and
|
||||
// zeros in the positions indicated by indices.
|
||||
|
||||
const auto intermediate = std::make_shared<op::v0::Select>(mask_bool, zeros, data);
|
||||
|
||||
return {std::make_shared<op::v0::ScatterNDAdd>(intermediate, indices, updates)};
|
||||
}
|
@ -1,51 +0,0 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ngraph/op/util/fused_op.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace op
|
||||
{
|
||||
namespace v0
|
||||
{
|
||||
/// \brief Replace values within provided tensor by `updates` according to `indices`.
|
||||
class NGRAPH_API ScatterND : public op::util::FusedOp
|
||||
{
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"ScatterND", 0};
|
||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
||||
ScatterND() = default;
|
||||
/// \param data The tensor whithn slice-values will be updated
|
||||
/// \param indices Index tensor: Data type must be `element::i32` or `element::i64`
|
||||
/// \param updates The tensor of replacement-slice-values
|
||||
ScatterND(const Output<Node>& data,
|
||||
const Output<Node>& indices,
|
||||
const Output<Node>& updates);
|
||||
|
||||
void pre_validate_and_infer_types() override;
|
||||
|
||||
virtual NodeVector decompose_op() const override;
|
||||
|
||||
virtual std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
};
|
||||
}
|
||||
using v0::ScatterND;
|
||||
}
|
||||
}
|
@ -181,11 +181,7 @@ NGRAPH_OP(ReverseSequence, ngraph::op::v0, 0)
|
||||
NGRAPH_OP(Round, ngraph::op::v0, 0)
|
||||
NGRAPH_OP(ROIAlign, ngraph::op::v3, 3)
|
||||
NGRAPH_OP(ScaleShift, ngraph::op::v0, 0)
|
||||
NGRAPH_OP(ScatterAdd, ngraph::op::v0, 0)
|
||||
NGRAPH_OP(ScatterAdd, ngraph::op::v3, 3)
|
||||
NGRAPH_OP(ScatterElementsUpdate, ngraph::op::v3, 3)
|
||||
NGRAPH_OP(ScatterND, ngraph::op::v0, 0)
|
||||
NGRAPH_OP(ScatterNDAdd, ngraph::op::v0, 0)
|
||||
NGRAPH_OP(ScatterUpdate, ngraph::op::v3, 3)
|
||||
NGRAPH_OP(Select, ngraph::op::v0, 0)
|
||||
NGRAPH_OP(Select, ngraph::op::v1, 1)
|
||||
|
@ -1,103 +0,0 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include "ngraph/op/scatter_add.hpp"
|
||||
#include "ngraph/shape.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
static int INPUTS = 0;
|
||||
static int INDICES = 1;
|
||||
static int UPDATES = 2;
|
||||
|
||||
constexpr NodeTypeInfo op::v0::ScatterAdd::type_info;
|
||||
|
||||
shared_ptr<Node> op::v0::ScatterAdd::clone_with_new_inputs(const OutputVector& new_args) const
|
||||
{
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<ScatterAdd>(new_args.at(INPUTS), new_args.at(INDICES), new_args.at(UPDATES));
|
||||
}
|
||||
|
||||
void op::v0::ScatterAdd::validate_and_infer_types()
|
||||
{
|
||||
element::Type inputs_et = get_input_element_type(INPUTS);
|
||||
element::Type indices_et = get_input_element_type(INDICES);
|
||||
element::Type updates_et = get_input_element_type(UPDATES);
|
||||
|
||||
const PartialShape& inputs_shape = get_input_partial_shape(INPUTS);
|
||||
const PartialShape& indices_shape = get_input_partial_shape(INDICES);
|
||||
const PartialShape& updates_shape = get_input_partial_shape(UPDATES);
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
indices_et == element::i32 || indices_et == element::i64,
|
||||
"Indices element type must be i64 or i32");
|
||||
|
||||
NODE_VALIDATION_CHECK(
|
||||
this, updates_et == inputs_et, "Updates element type must be the same as Inputs");
|
||||
|
||||
// updates rank must be at indices rank + inputs rank - 1
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
inputs_shape.rank().is_dynamic() || indices_shape.rank().is_dynamic() ||
|
||||
updates_shape.rank().is_dynamic() ||
|
||||
updates_shape.rank().get_length() ==
|
||||
indices_shape.rank().get_length() +
|
||||
inputs_shape.rank().get_length() - 1,
|
||||
"Updates rank is expected to be indices rank + inputs rank - 1");
|
||||
|
||||
bool compatible = true;
|
||||
if (inputs_shape.is_static() && indices_shape.is_static() && updates_shape.is_static())
|
||||
{
|
||||
for (size_t i = 0; i < indices_shape.rank().get_length(); i++)
|
||||
{
|
||||
compatible = compatible && updates_shape[i].same_scheme(indices_shape[i]);
|
||||
}
|
||||
for (size_t i = 1; i < inputs_shape.rank().get_length(); i++)
|
||||
{
|
||||
compatible = compatible &&
|
||||
updates_shape[indices_shape.rank().get_length() + i - 1].same_scheme(
|
||||
inputs_shape[i]);
|
||||
}
|
||||
}
|
||||
|
||||
NODE_VALIDATION_CHECK(
|
||||
this, compatible, "Updates shape must be indices_shape + inputs_shape[1:]");
|
||||
|
||||
set_output_type(0, inputs_et, inputs_shape);
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
//
|
||||
// Introduced in Opset 3
|
||||
//
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
constexpr NodeTypeInfo op::v3::ScatterAdd::type_info;
|
||||
|
||||
op::v3::ScatterAdd::ScatterAdd(const Output<Node>& data,
|
||||
const Output<Node>& indices,
|
||||
const Output<Node>& updates,
|
||||
const Output<Node>& axis)
|
||||
: util::ScatterBase(data, indices, updates, axis)
|
||||
{
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::v3::ScatterAdd::clone_with_new_inputs(const OutputVector& new_args) const
|
||||
{
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<v3::ScatterAdd>(
|
||||
new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3));
|
||||
}
|
@ -1,31 +0,0 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include "ngraph/op/scatter_nd_add.hpp"
|
||||
#include "ngraph/shape.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
constexpr NodeTypeInfo op::ScatterNDAdd::type_info;
|
||||
|
||||
shared_ptr<Node> op::ScatterNDAdd::clone_with_new_inputs(const OutputVector& new_args) const
|
||||
{
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<ScatterNDAdd>(new_args.at(op::util::ScatterNDBase::INPUTS),
|
||||
new_args.at(op::util::ScatterNDBase::INDICES),
|
||||
new_args.at(op::util::ScatterNDBase::UPDATES));
|
||||
}
|
@ -1,51 +0,0 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "ngraph/op/util/scatter_nd_base.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace op
|
||||
{
|
||||
namespace v0
|
||||
{
|
||||
/// \brief Add updates to slices from inputs addressed by indices
|
||||
class NGRAPH_API ScatterNDAdd : public util::ScatterNDBase
|
||||
{
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"ScatterNDAdd", 0};
|
||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
||||
ScatterNDAdd() = default;
|
||||
/// \param inputs Tensor
|
||||
/// \param indices Index tensor: Data type must be `element::i32` or `element::i64`
|
||||
/// \param updates Tensor: Must have same type as inputs
|
||||
ScatterNDAdd(const Output<Node>& inputs,
|
||||
const Output<Node>& indices,
|
||||
const Output<Node>& updates)
|
||||
: util::ScatterNDBase(inputs, indices, updates)
|
||||
{
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
};
|
||||
}
|
||||
using v0::ScatterNDAdd;
|
||||
}
|
||||
}
|
@ -85,7 +85,6 @@
|
||||
#include "ngraph/op/fused/prelu.hpp"
|
||||
#include "ngraph/op/fused/rnn_cell.hpp"
|
||||
#include "ngraph/op/fused/scale_shift.hpp"
|
||||
#include "ngraph/op/fused/scatter_nd.hpp"
|
||||
#include "ngraph/op/fused/selu.hpp"
|
||||
#include "ngraph/op/fused/shuffle_channels.hpp"
|
||||
#include "ngraph/op/fused/softmax_crossentropy.hpp"
|
||||
@ -151,9 +150,7 @@
|
||||
#include "ngraph/op/roi_align.hpp"
|
||||
#include "ngraph/op/roi_pooling.hpp"
|
||||
#include "ngraph/op/round.hpp"
|
||||
#include "ngraph/op/scatter_add.hpp"
|
||||
#include "ngraph/op/scatter_elements_update.hpp"
|
||||
#include "ngraph/op/scatter_nd_add.hpp"
|
||||
#include "ngraph/op/scatter_nd_update.hpp"
|
||||
#include "ngraph/op/scatter_update.hpp"
|
||||
#include "ngraph/op/select.hpp"
|
||||
|
@ -1,119 +0,0 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstring>
|
||||
#include <numeric>
|
||||
|
||||
#include "ngraph/coordinate_transform.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace runtime
|
||||
{
|
||||
namespace reference
|
||||
{
|
||||
template <typename T, typename U>
|
||||
void scatter_add(T* inputs,
|
||||
U* indices,
|
||||
T* updates,
|
||||
T* out,
|
||||
const Shape& inputs_shape,
|
||||
const Shape& indices_shape,
|
||||
const Shape& updates_shape,
|
||||
const Shape& out_shape)
|
||||
{
|
||||
using namespace std;
|
||||
// Copy inputs to out
|
||||
memcpy(out, inputs, sizeof(T) * shape_size(inputs_shape));
|
||||
// Create a CoordinateTransform for "indices"
|
||||
size_t indices_ndim = static_cast<size_t>(indices_shape.size());
|
||||
Coordinate indices_start_corner(indices_ndim, 0);
|
||||
Coordinate indices_end_corner(indices_shape);
|
||||
Strides indices_strides(indices_ndim, 1);
|
||||
AxisVector indices_axis_order(indices_ndim);
|
||||
iota(indices_axis_order.begin(), indices_axis_order.end(), 0);
|
||||
CoordinateTransform indices_transform(indices_shape,
|
||||
indices_start_corner,
|
||||
indices_end_corner,
|
||||
indices_strides,
|
||||
indices_axis_order);
|
||||
// Create an outer CoordinateTransform for "update"
|
||||
size_t updates_ndim = static_cast<size_t>(updates_shape.size());
|
||||
Coordinate updates_outer_start_corner(updates_ndim, 0);
|
||||
Coordinate updates_outer_end_corner(updates_shape);
|
||||
for (size_t i = indices_ndim; i < updates_ndim; i++)
|
||||
{
|
||||
updates_outer_end_corner[i] = 1;
|
||||
}
|
||||
Strides updates_strides(updates_ndim, 1);
|
||||
AxisVector updates_axis_order(updates_ndim);
|
||||
iota(updates_axis_order.begin(), updates_axis_order.end(), 0);
|
||||
CoordinateTransform updates_outer_transform(updates_shape,
|
||||
updates_outer_start_corner,
|
||||
updates_outer_end_corner,
|
||||
updates_strides,
|
||||
updates_axis_order);
|
||||
// Common vars for out
|
||||
size_t out_ndim = static_cast<size_t>(out_shape.size());
|
||||
Strides out_strides(out_ndim, 1);
|
||||
AxisVector out_axis_order(out_ndim);
|
||||
iota(out_axis_order.begin(), out_axis_order.end(), 0);
|
||||
// Visit one updates silce and one out silce at a time.
|
||||
auto updates_outer_coord_iter = updates_outer_transform.begin();
|
||||
for (const Coordinate& indices_coord : indices_transform)
|
||||
{
|
||||
auto indices_index = indices_transform.index(indices_coord);
|
||||
U slice_index = indices[indices_index];
|
||||
// Create CoordinateTransform for out slice
|
||||
Coordinate out_start_corner(out_ndim, 0);
|
||||
Coordinate out_end_corner(out_shape);
|
||||
out_start_corner[0] = static_cast<size_t>(slice_index);
|
||||
out_end_corner[0] = out_start_corner[0] + 1;
|
||||
CoordinateTransform out_transform(
|
||||
out_shape, out_start_corner, out_end_corner, out_strides, out_axis_order);
|
||||
// Create CoordinateTransform for updates slice
|
||||
Coordinate updates_inner_start_corner = *updates_outer_coord_iter;
|
||||
Coordinate updates_inner_end_corner(updates_shape);
|
||||
for (size_t i = 0; i < indices_ndim; i++)
|
||||
{
|
||||
updates_inner_end_corner[i] = updates_inner_start_corner[i] + 1;
|
||||
}
|
||||
CoordinateTransform updates_inner_transform(updates_shape,
|
||||
updates_inner_start_corner,
|
||||
updates_inner_end_corner,
|
||||
updates_strides,
|
||||
updates_axis_order);
|
||||
|
||||
// Add one element from updates to inputs at a time
|
||||
auto updates_inner_coord_iter = updates_inner_transform.begin();
|
||||
for (const Coordinate& out_coord : out_transform)
|
||||
{
|
||||
if (updates_inner_coord_iter == updates_inner_transform.end())
|
||||
{
|
||||
break;
|
||||
}
|
||||
out[out_transform.index(out_coord)] +=
|
||||
updates[updates_inner_transform.index(*updates_inner_coord_iter)];
|
||||
updates_inner_coord_iter++;
|
||||
}
|
||||
updates_outer_coord_iter++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -1,114 +0,0 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstring>
|
||||
#include <numeric>
|
||||
|
||||
#include "ngraph/coordinate_transform.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace runtime
|
||||
{
|
||||
namespace reference
|
||||
{
|
||||
template <typename T, typename U>
|
||||
void scatter_nd_add(T* inputs,
|
||||
U* indices,
|
||||
T* updates,
|
||||
T* out,
|
||||
const Shape& inputs_shape,
|
||||
const Shape& indices_shape,
|
||||
const Shape& updates_shape,
|
||||
const Shape& out_shape)
|
||||
{
|
||||
using namespace std;
|
||||
// Copy inputs to out
|
||||
memcpy(out, inputs, sizeof(T) * shape_size(inputs_shape));
|
||||
// Create a CoordinateTransform for "indices" that visits only the first element
|
||||
// along inner most axis
|
||||
size_t indices_ndim = static_cast<size_t>(indices_shape.size());
|
||||
Coordinate indices_outer_start_corner(indices_ndim, 0);
|
||||
Coordinate indices_outer_end_corner(indices_shape);
|
||||
size_t slice_rank = indices_shape[indices_ndim - 1];
|
||||
indices_outer_end_corner[indices_ndim - 1] = 1;
|
||||
Strides indices_strides(indices_ndim, 1);
|
||||
AxisVector indices_axis_order(indices_ndim);
|
||||
std::iota(indices_axis_order.begin(), indices_axis_order.end(), 0);
|
||||
CoordinateTransform indices_outer_transform(indices_shape,
|
||||
indices_outer_start_corner,
|
||||
indices_outer_end_corner,
|
||||
indices_strides,
|
||||
indices_axis_order);
|
||||
|
||||
// Create a matching CoordinateTransform for "updates" that visits the same outer
|
||||
// coordinates
|
||||
size_t updates_ndim = static_cast<size_t>(updates_shape.size());
|
||||
Strides updates_strides(updates_ndim, 1);
|
||||
AxisVector updates_axis_order(updates_ndim);
|
||||
std::iota(updates_axis_order.begin(), updates_axis_order.end(), 0);
|
||||
Coordinate updates_outer_start_corner(updates_ndim, 0);
|
||||
Coordinate updates_outer_end_corner(updates_shape);
|
||||
for (size_t i = indices_ndim - 1; i < updates_ndim; i++)
|
||||
{
|
||||
updates_outer_end_corner[i] = 1;
|
||||
}
|
||||
CoordinateTransform updates_outer_transform(updates_shape,
|
||||
updates_outer_start_corner,
|
||||
updates_outer_end_corner,
|
||||
updates_strides,
|
||||
updates_axis_order);
|
||||
|
||||
// Add an updates slice to a slice on out indexed by innermost dim ofindices
|
||||
size_t out_ndim = static_cast<size_t>(out_shape.size());
|
||||
Strides out_strides(out_ndim, 1);
|
||||
AxisVector out_axis_order(out_ndim);
|
||||
std::iota(out_axis_order.begin(), out_axis_order.end(), 0);
|
||||
|
||||
auto updates_outer_coord_iter = updates_outer_transform.begin();
|
||||
for (const Coordinate& indices_coord : indices_outer_transform)
|
||||
{
|
||||
if (updates_outer_coord_iter == updates_outer_transform.end())
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
Coordinate out_start_corner(out_ndim, 0);
|
||||
Coordinate out_end_corner(out_shape);
|
||||
auto indices_index = indices_outer_transform.index(indices_coord);
|
||||
for (size_t i = 0; i < slice_rank; i++)
|
||||
{
|
||||
U index = indices[indices_index];
|
||||
out_start_corner[i] = index;
|
||||
out_end_corner[i] = index + 1;
|
||||
indices_index++;
|
||||
}
|
||||
CoordinateTransform out_transform(
|
||||
out_shape, out_start_corner, out_end_corner, out_strides, out_axis_order);
|
||||
auto updates_index = updates_outer_transform.index(*updates_outer_coord_iter);
|
||||
for (const Coordinate& out_coord : out_transform)
|
||||
{
|
||||
out[out_transform.index(out_coord)] += updates[updates_index];
|
||||
updates_index++;
|
||||
}
|
||||
updates_outer_coord_iter++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -1951,21 +1951,6 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
|
||||
node = make_shared<op::ScaleShift>(args[0], args[1], args[2]);
|
||||
break;
|
||||
}
|
||||
case OP_TYPEID::ScatterAdd:
|
||||
{
|
||||
node = make_shared<op::ScatterAdd>(args[0], args[1], args[2]);
|
||||
break;
|
||||
}
|
||||
case OP_TYPEID::ScatterND:
|
||||
{
|
||||
node = make_shared<op::ScatterND>(args[0], args[1], args[2]);
|
||||
break;
|
||||
}
|
||||
case OP_TYPEID::ScatterNDAdd:
|
||||
{
|
||||
node = make_shared<op::ScatterNDAdd>(args[0], args[1], args[2]);
|
||||
break;
|
||||
}
|
||||
case OP_TYPEID::Select:
|
||||
{
|
||||
node = make_shared<op::Select>(args[0], args[1], args[2]);
|
||||
@ -2996,12 +2981,6 @@ json JSONSerializer::serialize_node(const Node& n)
|
||||
}
|
||||
case OP_TYPEID::ScaleShift: { break;
|
||||
}
|
||||
case OP_TYPEID::ScatterAdd: { break;
|
||||
}
|
||||
case OP_TYPEID::ScatterND: { break;
|
||||
}
|
||||
case OP_TYPEID::ScatterNDAdd: { break;
|
||||
}
|
||||
case OP_TYPEID::Select: { break;
|
||||
}
|
||||
case OP_TYPEID::Selu: { break;
|
||||
|
@ -173,9 +173,7 @@ set(SRC
|
||||
type_prop/roi_align.cpp
|
||||
type_prop/rnn_cell.cpp
|
||||
type_prop/scale_shift.cpp
|
||||
type_prop/scatter_add.cpp
|
||||
type_prop/scatter_elements_update.cpp
|
||||
type_prop/scatter_nd.cpp
|
||||
type_prop/scatter_nd_update.cpp
|
||||
type_prop/scatter_update.cpp
|
||||
type_prop/select.cpp
|
||||
@ -352,7 +350,6 @@ set(MULTI_TEST_SRC
|
||||
backend/reverse_sequence.in.cpp
|
||||
backend/reverse.in.cpp
|
||||
backend/round.in.cpp
|
||||
backend/scatter.in.cpp
|
||||
backend/select.in.cpp
|
||||
backend/shape_of.in.cpp
|
||||
backend/sigmoid.in.cpp
|
||||
|
@ -1,288 +0,0 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include <algorithm>
|
||||
#include <cinttypes>
|
||||
#include <cmath>
|
||||
#include <cstdlib>
|
||||
#include <random>
|
||||
#include <string>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "ngraph/runtime/tensor.hpp"
|
||||
#include "runtime/backend.hpp"
|
||||
#include "util/all_close.hpp"
|
||||
#include "util/all_close_f.hpp"
|
||||
#include "util/ndarray.hpp"
|
||||
#include "util/random.hpp"
|
||||
#include "util/test_control.hpp"
|
||||
#include "util/test_tools.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
static string s_manifest = "${MANIFEST}";
|
||||
|
||||
#if 0
|
||||
NGRAPH_TEST(${BACKEND_NAME}, scatter_add_4d_indices)
|
||||
{
|
||||
Shape ref_shape{3, 3, 3};
|
||||
Shape indices_shape{2, 3, 4, 2};
|
||||
Shape updates_shape{2, 3, 4, 2, 3, 3};
|
||||
Shape out_shape{3, 3, 3};
|
||||
auto R = make_shared<op::Parameter>(element::f32, ref_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
|
||||
auto U = make_shared<op::Parameter>(element::f32, updates_shape);
|
||||
auto G = make_shared<op::ScatterAdd>(R, I, U);
|
||||
auto f =
|
||||
make_shared<Function>(make_shared<op::GetOutputElement>(G, 0), ParameterVector{R, I, U});
|
||||
|
||||
auto backend = runtime::Backend::create("${BACKEND_NAME}");
|
||||
|
||||
// Create some tensors for input/output
|
||||
auto r = backend->create_tensor(element::f32, ref_shape);
|
||||
copy_data(r, vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5,
|
||||
6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8});
|
||||
auto i = backend->create_tensor(element::i32, indices_shape);
|
||||
copy_data(i, vector<int32_t>{0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0,
|
||||
1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1,
|
||||
2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2});
|
||||
auto u = backend->create_tensor(element::f32, updates_shape);
|
||||
copy_data(u,
|
||||
vector<float>{
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8,
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8,
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8,
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8,
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8,
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8,
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8,
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8,
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8,
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8,
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8,
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8,
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8,
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8,
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8,
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8});
|
||||
auto result = backend->create_tensor(element::f32, out_shape);
|
||||
|
||||
auto c = backend->compile(f);
|
||||
c->call_with_validate({result}, {r, i, u});
|
||||
EXPECT_TRUE(test::all_close_f(
|
||||
(vector<float>{0, 17, 34, 51, 68, 85, 102, 119, 136, 17, 34, 51, 68, 85,
|
||||
102, 119, 136, 153, 0, 17, 34, 51, 68, 85, 102, 119, 136}),
|
||||
read_vector<float>(result),
|
||||
MIN_FLOAT_TOLERANCE_BITS));
|
||||
}
|
||||
#endif
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, scatter_add_3d_indices)
|
||||
{
|
||||
Shape ref_shape{2, 3, 3};
|
||||
Shape indices_shape{2, 2, 2};
|
||||
Shape updates_shape{2, 2, 2, 3, 3};
|
||||
Shape out_shape{2, 3, 3};
|
||||
auto R = make_shared<op::Parameter>(element::f32, ref_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
|
||||
auto U = make_shared<op::Parameter>(element::f32, updates_shape);
|
||||
auto G = make_shared<op::ScatterAdd>(R, I, U);
|
||||
auto f =
|
||||
make_shared<Function>(make_shared<op::GetOutputElement>(G, 0), ParameterVector{R, I, U});
|
||||
|
||||
auto backend = runtime::Backend::create("${BACKEND_NAME}");
|
||||
|
||||
// Create some tensors for input/output
|
||||
auto r = backend->create_tensor(element::f32, ref_shape);
|
||||
copy_data(r, vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9});
|
||||
auto i = backend->create_tensor(element::i32, indices_shape);
|
||||
copy_data(i, vector<int32_t>{0, 1, 1, 0, 0, 1, 1, 0});
|
||||
auto u = backend->create_tensor(element::f32, updates_shape);
|
||||
copy_data(u, vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9,
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8,
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9,
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8});
|
||||
auto result = backend->create_tensor(element::f32, out_shape);
|
||||
|
||||
auto c = backend->compile(f);
|
||||
c->call_with_validate({result}, {r, i, u});
|
||||
EXPECT_TRUE(test::all_close_f(
|
||||
(vector<float>{0, 5, 10, 15, 20, 25, 30, 35, 40, 5, 10, 15, 20, 25, 30, 35, 40, 45}),
|
||||
read_vector<float>(result),
|
||||
MIN_FLOAT_TOLERANCE_BITS));
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, scatter_add_2d_indices)
|
||||
{
|
||||
Shape ref_shape{3};
|
||||
Shape indices_shape{2, 2};
|
||||
Shape updates_shape{2, 2};
|
||||
Shape out_shape{3};
|
||||
auto R = make_shared<op::Parameter>(element::f32, ref_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
|
||||
auto U = make_shared<op::Parameter>(element::f32, updates_shape);
|
||||
auto G = make_shared<op::ScatterAdd>(R, I, U);
|
||||
auto f =
|
||||
make_shared<Function>(make_shared<op::GetOutputElement>(G, 0), ParameterVector{R, I, U});
|
||||
|
||||
auto backend = runtime::Backend::create("${BACKEND_NAME}");
|
||||
|
||||
// Create some tensors for input/output
|
||||
auto r = backend->create_tensor(element::f32, ref_shape);
|
||||
copy_data(r, vector<float>{0, 1, 2});
|
||||
auto i = backend->create_tensor(element::i32, indices_shape);
|
||||
copy_data(i, vector<int32_t>{0, 1, 1, 0});
|
||||
auto u = backend->create_tensor(element::f32, updates_shape);
|
||||
copy_data(u, vector<float>{1, 2, 3, 4});
|
||||
auto result = backend->create_tensor(element::f32, out_shape);
|
||||
|
||||
auto c = backend->compile(f);
|
||||
c->call_with_validate({result}, {r, i, u});
|
||||
EXPECT_TRUE(test::all_close_f(
|
||||
(vector<float>{5, 6, 2}), read_vector<float>(result), MIN_FLOAT_TOLERANCE_BITS));
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, scatter_add_1d_indices)
|
||||
{
|
||||
Shape ref_shape{2, 3, 3};
|
||||
Shape indices_shape{2};
|
||||
Shape updates_shape{2, 3, 3};
|
||||
Shape out_shape{2, 3, 3};
|
||||
auto R = make_shared<op::Parameter>(element::f32, ref_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
|
||||
auto U = make_shared<op::Parameter>(element::f32, updates_shape);
|
||||
auto G = make_shared<op::ScatterAdd>(R, I, U);
|
||||
auto f =
|
||||
make_shared<Function>(make_shared<op::GetOutputElement>(G, 0), ParameterVector{R, I, U});
|
||||
|
||||
auto backend = runtime::Backend::create("${BACKEND_NAME}");
|
||||
|
||||
// Create some tensors for input/output
|
||||
auto r = backend->create_tensor(element::f32, ref_shape);
|
||||
copy_data(r, vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9});
|
||||
auto i = backend->create_tensor(element::i32, indices_shape);
|
||||
copy_data(i, vector<int32_t>{1, 0});
|
||||
auto u = backend->create_tensor(element::f32, updates_shape);
|
||||
copy_data(u, vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8});
|
||||
auto result = backend->create_tensor(element::f32, out_shape);
|
||||
|
||||
auto c = backend->compile(f);
|
||||
c->call_with_validate({result}, {r, i, u});
|
||||
EXPECT_TRUE(test::all_close_f(
|
||||
(vector<float>{0, 2, 4, 6, 8, 10, 12, 14, 16, 2, 4, 6, 8, 10, 12, 14, 16, 18}),
|
||||
read_vector<float>(result),
|
||||
MIN_FLOAT_TOLERANCE_BITS));
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, scatter_add_scalar_indices)
|
||||
{
|
||||
Shape ref_shape{2, 3, 3};
|
||||
Shape indices_shape{};
|
||||
Shape updates_shape{3, 3};
|
||||
Shape out_shape{2, 3, 3};
|
||||
auto R = make_shared<op::Parameter>(element::f32, ref_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
|
||||
auto U = make_shared<op::Parameter>(element::f32, updates_shape);
|
||||
auto G = make_shared<op::ScatterAdd>(R, I, U);
|
||||
auto f =
|
||||
make_shared<Function>(make_shared<op::GetOutputElement>(G, 0), ParameterVector{R, I, U});
|
||||
|
||||
auto backend = runtime::Backend::create("${BACKEND_NAME}");
|
||||
|
||||
// Create some tensors for input/output
|
||||
auto r = backend->create_tensor(element::f32, ref_shape);
|
||||
copy_data(r, vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9});
|
||||
auto i = backend->create_tensor(element::i32, indices_shape);
|
||||
copy_data(i, vector<int32_t>{1});
|
||||
auto u = backend->create_tensor(element::f32, updates_shape);
|
||||
copy_data(u, vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9});
|
||||
auto result = backend->create_tensor(element::f32, out_shape);
|
||||
|
||||
auto c = backend->compile(f);
|
||||
c->call_with_validate({result}, {r, i, u});
|
||||
EXPECT_TRUE(test::all_close_f(
|
||||
(vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 2, 4, 6, 8, 10, 12, 14, 16, 18}),
|
||||
read_vector<float>(result),
|
||||
MIN_FLOAT_TOLERANCE_BITS));
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, scatter_nd_add_batch_2d_to_3d)
|
||||
{
|
||||
Shape ref_shape{3, 3, 3};
|
||||
Shape indices_shape{2, 1};
|
||||
Shape updates_shape{2, 3, 3};
|
||||
Shape out_shape{3, 3, 3};
|
||||
auto R = make_shared<op::Parameter>(element::f32, ref_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
|
||||
auto U = make_shared<op::Parameter>(element::f32, updates_shape);
|
||||
auto G = make_shared<op::ScatterNDAdd>(R, I, U);
|
||||
auto f =
|
||||
make_shared<Function>(make_shared<op::GetOutputElement>(G, 0), ParameterVector{R, I, U});
|
||||
|
||||
auto backend = runtime::Backend::create("${BACKEND_NAME}");
|
||||
|
||||
// Create some tensors for input/output
|
||||
auto r = backend->create_tensor(element::f32, ref_shape);
|
||||
copy_data(r, vector<float>{1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5,
|
||||
5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9});
|
||||
auto i = backend->create_tensor(element::i32, indices_shape);
|
||||
copy_data(i, vector<int32_t>{0, 2});
|
||||
auto u = backend->create_tensor(element::f32, updates_shape);
|
||||
copy_data(u, vector<float>{1, 1, 1, 2, 2, 2, 3, 3, 3, 7, 7, 7, 8, 8, 8, 9, 9, 9});
|
||||
auto result = backend->create_tensor(element::f32, out_shape);
|
||||
|
||||
auto c = backend->compile(f);
|
||||
c->call_with_validate({result}, {r, i, u});
|
||||
EXPECT_TRUE(test::all_close_f((vector<float>{2, 2, 2, 4, 4, 4, 6, 6, 6, 4, 4, 4, 5, 5,
|
||||
5, 6, 6, 6, 14, 14, 14, 16, 16, 16, 18, 18, 18}),
|
||||
read_vector<float>(result),
|
||||
MIN_FLOAT_TOLERANCE_BITS));
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, scatter_nd_add_2d_to_3d)
|
||||
{
|
||||
Shape ref_shape{3, 3, 3};
|
||||
Shape indices_shape{1};
|
||||
Shape updates_shape{3, 3};
|
||||
Shape out_shape{3, 3, 3};
|
||||
auto R = make_shared<op::Parameter>(element::f32, ref_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
|
||||
auto U = make_shared<op::Parameter>(element::f32, updates_shape);
|
||||
auto G = make_shared<op::ScatterNDAdd>(R, I, U);
|
||||
auto f =
|
||||
make_shared<Function>(make_shared<op::GetOutputElement>(G, 0), ParameterVector{R, I, U});
|
||||
|
||||
auto backend = runtime::Backend::create("${BACKEND_NAME}");
|
||||
|
||||
// Create some tensors for input/output
|
||||
auto r = backend->create_tensor(element::f32, ref_shape);
|
||||
copy_data(r, vector<float>{1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5,
|
||||
5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9});
|
||||
auto i = backend->create_tensor(element::i32, indices_shape);
|
||||
copy_data(i, vector<int32_t>{0});
|
||||
auto u = backend->create_tensor(element::f32, updates_shape);
|
||||
copy_data(u, vector<float>{1, 1, 1, 2, 2, 2, 3, 3, 3});
|
||||
auto result = backend->create_tensor(element::f32, out_shape);
|
||||
|
||||
auto c = backend->compile(f);
|
||||
c->call_with_validate({result}, {r, i, u});
|
||||
EXPECT_TRUE(test::all_close_f((vector<float>{2, 2, 2, 4, 4, 4, 6, 6, 6, 4, 4, 4, 5, 5,
|
||||
5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9}),
|
||||
read_vector<float>(result),
|
||||
MIN_FLOAT_TOLERANCE_BITS));
|
||||
}
|
@ -899,33 +899,6 @@ namespace
|
||||
EXPECT_FALSE(node.is_binary_elementwise_logical());
|
||||
}
|
||||
|
||||
void op_is_ScatterAdd()
|
||||
{
|
||||
op::ScatterAdd node;
|
||||
EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
|
||||
EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
|
||||
EXPECT_FALSE(node.is_binary_elementwise_comparison());
|
||||
EXPECT_FALSE(node.is_binary_elementwise_logical());
|
||||
}
|
||||
|
||||
void op_is_ScatterND()
|
||||
{
|
||||
op::ScatterND node;
|
||||
EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
|
||||
EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
|
||||
EXPECT_FALSE(node.is_binary_elementwise_comparison());
|
||||
EXPECT_FALSE(node.is_binary_elementwise_logical());
|
||||
}
|
||||
|
||||
void op_is_ScatterNDAdd()
|
||||
{
|
||||
op::ScatterNDAdd node;
|
||||
EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
|
||||
EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
|
||||
EXPECT_FALSE(node.is_binary_elementwise_comparison());
|
||||
EXPECT_FALSE(node.is_binary_elementwise_logical());
|
||||
}
|
||||
|
||||
void op_is_Select()
|
||||
{
|
||||
op::Select node;
|
||||
|
@ -553,14 +553,6 @@ all_change_axis
|
||||
# Positive input shape should be the same as negative input shape
|
||||
select_v1
|
||||
|
||||
# Cannot cast ngraph node ScatterNDAdd to CNNLayer!
|
||||
scatter_add_3d_indices
|
||||
scatter_add_2d_indices
|
||||
scatter_add_1d_indices
|
||||
scatter_add_scalar_indices
|
||||
scatter_nd_add_batch_2d_to_3d
|
||||
scatter_nd_add_2d_to_3d
|
||||
|
||||
# Cannot cast ngraph node Reverse to CNNLayer!
|
||||
reverse_1d_0
|
||||
reverse_2d_0
|
||||
|
@ -79,8 +79,6 @@
|
||||
#include "ngraph/runtime/reference/reverse.hpp"
|
||||
#include "ngraph/runtime/reference/reverse_sequence.hpp"
|
||||
#include "ngraph/runtime/reference/round.hpp"
|
||||
#include "ngraph/runtime/reference/scatter_add.hpp"
|
||||
#include "ngraph/runtime/reference/scatter_nd_add.hpp"
|
||||
#include "ngraph/runtime/reference/select.hpp"
|
||||
#include "ngraph/runtime/reference/send.hpp"
|
||||
#include "ngraph/runtime/reference/sigmoid.hpp"
|
||||
@ -1126,66 +1124,6 @@ protected:
|
||||
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
|
||||
break;
|
||||
}
|
||||
case OP_TYPEID::ScatterAdd:
|
||||
{
|
||||
if (node.get_input_element_type(1) == element::i64)
|
||||
{
|
||||
reference::scatter_add<T, int64_t>(args[0]->get_data_ptr<T>(),
|
||||
args[1]->get_data_ptr<int64_t>(),
|
||||
args[2]->get_data_ptr<T>(),
|
||||
out[0]->get_data_ptr<T>(),
|
||||
node.get_input_shape(0),
|
||||
node.get_input_shape(1),
|
||||
node.get_input_shape(2),
|
||||
node.get_output_shape(0));
|
||||
}
|
||||
else if (node.get_input_element_type(1) == element::i32)
|
||||
{
|
||||
reference::scatter_add<T, int32_t>(args[0]->get_data_ptr<T>(),
|
||||
args[1]->get_data_ptr<int32_t>(),
|
||||
args[2]->get_data_ptr<T>(),
|
||||
out[0]->get_data_ptr<T>(),
|
||||
node.get_input_shape(0),
|
||||
node.get_input_shape(1),
|
||||
node.get_input_shape(2),
|
||||
node.get_output_shape(0));
|
||||
}
|
||||
else
|
||||
{
|
||||
throw ngraph_error("Unexpected type");
|
||||
}
|
||||
break;
|
||||
}
|
||||
case OP_TYPEID::ScatterNDAdd:
|
||||
{
|
||||
if (node.get_input_element_type(1) == element::i64)
|
||||
{
|
||||
reference::scatter_nd_add<T, int64_t>(args[0]->get_data_ptr<T>(),
|
||||
args[1]->get_data_ptr<int64_t>(),
|
||||
args[2]->get_data_ptr<T>(),
|
||||
out[0]->get_data_ptr<T>(),
|
||||
node.get_input_shape(0),
|
||||
node.get_input_shape(1),
|
||||
node.get_input_shape(2),
|
||||
node.get_output_shape(0));
|
||||
}
|
||||
else if (node.get_input_element_type(1) == element::i32)
|
||||
{
|
||||
reference::scatter_nd_add<T, int32_t>(args[0]->get_data_ptr<T>(),
|
||||
args[1]->get_data_ptr<int32_t>(),
|
||||
args[2]->get_data_ptr<T>(),
|
||||
out[0]->get_data_ptr<T>(),
|
||||
node.get_input_shape(0),
|
||||
node.get_input_shape(1),
|
||||
node.get_input_shape(2),
|
||||
node.get_output_shape(0));
|
||||
}
|
||||
else
|
||||
{
|
||||
throw ngraph_error("Unexpected type");
|
||||
}
|
||||
break;
|
||||
}
|
||||
case OP_TYPEID::Select:
|
||||
{
|
||||
size_t element_count = shape_size(node.get_output_shape(0));
|
||||
@ -1332,7 +1270,6 @@ protected:
|
||||
case OP_TYPEID::PRelu:
|
||||
case OP_TYPEID::RNNCell:
|
||||
case OP_TYPEID::ScaleShift:
|
||||
case OP_TYPEID::ScatterND:
|
||||
case OP_TYPEID::Selu:
|
||||
case OP_TYPEID::ShuffleChannels:
|
||||
case OP_TYPEID::SoftmaxCrossEntropy:
|
||||
|
@ -143,9 +143,6 @@ NGRAPH_OP(ReverseSequence, ngraph::op)
|
||||
NGRAPH_OP(RNNCell, ngraph::op)
|
||||
NGRAPH_OP(Round, ngraph::op)
|
||||
NGRAPH_OP(ScaleShift, ngraph::op)
|
||||
NGRAPH_OP(ScatterAdd, ngraph::op)
|
||||
NGRAPH_OP(ScatterND, ngraph::op)
|
||||
NGRAPH_OP(ScatterNDAdd, ngraph::op)
|
||||
NGRAPH_OP(Select, ngraph::op)
|
||||
NGRAPH_OP(Selu, ngraph::op)
|
||||
NGRAPH_OP(Send, ngraph::op)
|
||||
|
@ -1,399 +0,0 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "util/type_prop.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
TEST(type_prop, scatter_add_fail_indices_element_type)
|
||||
{
|
||||
Shape ref_shape{2, 3, 3};
|
||||
Shape indices_shape{2, 2};
|
||||
Shape updates_shape{2, 2, 3, 3};
|
||||
Shape out_shape{2, 3, 3};
|
||||
auto R = make_shared<op::Parameter>(element::f32, ref_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i16, indices_shape);
|
||||
auto U = make_shared<op::Parameter>(element::f32, updates_shape);
|
||||
try
|
||||
{
|
||||
auto G = make_shared<op::ScatterAdd>(R, I, U);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Incorrect indices element type";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Indices element type must be i64 or i32"));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, scatter_add_fail_updates_element_type)
|
||||
{
|
||||
Shape ref_shape{2, 3, 3};
|
||||
Shape indices_shape{2, 2};
|
||||
Shape updates_shape{2, 2, 3, 3};
|
||||
Shape out_shape{2, 3, 3};
|
||||
auto R = make_shared<op::Parameter>(element::f32, ref_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
|
||||
auto U = make_shared<op::Parameter>(element::i32, updates_shape);
|
||||
try
|
||||
{
|
||||
auto G = make_shared<op::ScatterAdd>(R, I, U);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Incorrect updates element type";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(error.what(),
|
||||
std::string("Updates element type must be the same as Inputs"));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, scatter_add_fail_updates_rank)
|
||||
{
|
||||
Shape ref_shape{2, 3, 3};
|
||||
Shape indices_shape{2, 2};
|
||||
Shape updates_shape{2, 3, 3};
|
||||
Shape out_shape{2, 3, 3};
|
||||
auto R = make_shared<op::Parameter>(element::f32, ref_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
|
||||
auto U = make_shared<op::Parameter>(element::f32, updates_shape);
|
||||
try
|
||||
{
|
||||
auto G = make_shared<op::ScatterAdd>(R, I, U);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Incorrect updates rank";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(
|
||||
error.what(),
|
||||
std::string("Updates rank is expected to be indices rank + inputs rank - 1"));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, scatter_add_fail_updates_shape)
|
||||
{
|
||||
Shape ref_shape{2, 3, 3};
|
||||
Shape indices_shape{2, 2};
|
||||
Shape updates_shape{1, 2, 3, 3};
|
||||
Shape out_shape{2, 3, 3};
|
||||
auto R = make_shared<op::Parameter>(element::f32, ref_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
|
||||
auto U = make_shared<op::Parameter>(element::f32, updates_shape);
|
||||
try
|
||||
{
|
||||
auto G = make_shared<op::ScatterAdd>(R, I, U);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Incorrect updates shape";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(error.what(),
|
||||
std::string("Updates shape must be indices_shape + inputs_shape[1:]"));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, scatter_add_v3_fail_indices_element_type)
|
||||
{
|
||||
Shape ref_shape{2, 3, 4};
|
||||
Shape indices_shape{2, 1};
|
||||
Shape updates_shape{2, 2, 1, 4};
|
||||
auto R = make_shared<op::Parameter>(element::f32, ref_shape);
|
||||
auto I = make_shared<op::Parameter>(element::f16, indices_shape);
|
||||
auto U = make_shared<op::Parameter>(element::f32, updates_shape);
|
||||
auto A = op::Constant::create(element::i64, Shape{}, {1});
|
||||
try
|
||||
{
|
||||
auto G = make_shared<op::v3::ScatterAdd>(R, I, U, A);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Incorrect indices element type";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(
|
||||
error.what(), std::string("Indices element type must be of an integral number type"));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, scatter_add_v3_fail_updates_data_et_not_equal)
|
||||
{
|
||||
Shape ref_shape{2, 3, 4};
|
||||
Shape indices_shape{2, 1};
|
||||
Shape updates_shape{2, 2, 1, 4};
|
||||
auto R = make_shared<op::Parameter>(element::f32, ref_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i16, indices_shape);
|
||||
auto U = make_shared<op::Parameter>(element::u32, updates_shape);
|
||||
auto A = op::Constant::create(element::u32, Shape{1}, {1});
|
||||
try
|
||||
{
|
||||
auto G = make_shared<op::v3::ScatterAdd>(R, I, U, A);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Incorrect updates element type";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(error.what(),
|
||||
std::string("Element types for input data and updates do not match"));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, scatter_add_v3_fail_axis_element_type)
|
||||
{
|
||||
Shape ref_shape{2, 3, 4};
|
||||
Shape indices_shape{2, 1};
|
||||
Shape updates_shape{2, 2, 1, 4};
|
||||
auto R = make_shared<op::Parameter>(element::i16, ref_shape);
|
||||
auto I = make_shared<op::Parameter>(element::u64, indices_shape);
|
||||
auto U = make_shared<op::Parameter>(element::i16, updates_shape);
|
||||
auto A = op::Constant::create(element::f32, Shape{1}, {1.5f});
|
||||
try
|
||||
{
|
||||
auto G = make_shared<op::v3::ScatterAdd>(R, I, U, A);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Incorrect updates element type";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(error.what(),
|
||||
std::string("Axis element type must be of an integral number type"));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, scatter_add_v3_fail_axis_shape)
|
||||
{
|
||||
Shape ref_shape{2, 3, 4};
|
||||
Shape indices_shape{2, 1};
|
||||
Shape updates_shape{2, 2, 1, 4};
|
||||
auto R = make_shared<op::Parameter>(element::u8, ref_shape);
|
||||
auto I = make_shared<op::Parameter>(element::u16, indices_shape);
|
||||
auto U = make_shared<op::Parameter>(element::u8, updates_shape);
|
||||
auto A = op::Constant::create(element::u8, Shape{2}, {1, 5});
|
||||
try
|
||||
{
|
||||
auto G = make_shared<op::v3::ScatterAdd>(R, I, U, A);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Incorrect updates element type";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(error.what(),
|
||||
std::string("Axis input shape is required to be scalar or 1D tensor"));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, scatter_add_v3_fail_updates_rank)
|
||||
{
|
||||
Shape ref_shape{2, 3, 4};
|
||||
Shape indices_shape{2, 1};
|
||||
Shape updates_shape{2, 1, 4};
|
||||
auto R = make_shared<op::Parameter>(element::f64, ref_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i16, indices_shape);
|
||||
auto U = make_shared<op::Parameter>(element::f64, updates_shape);
|
||||
auto A = op::Constant::create(element::u8, Shape{}, {0});
|
||||
try
|
||||
{
|
||||
auto G = make_shared<op::v3::ScatterAdd>(R, I, U, A);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Incorrect updates element type";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(
|
||||
error.what(),
|
||||
std::string("Updates rank is expected to be indices rank + data rank - 1"));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, scatter_add_v3_fail_updates_shape_axis)
|
||||
{
|
||||
Shape ref_shape{2, 3, 4};
|
||||
Shape indices_shape{2, 1};
|
||||
Shape updates_shape{2, 2, 1, 4};
|
||||
auto R = make_shared<op::Parameter>(element::u64, ref_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i16, indices_shape);
|
||||
auto U = make_shared<op::Parameter>(element::u64, updates_shape);
|
||||
auto A = op::Constant::create(element::u16, Shape{}, {0});
|
||||
try
|
||||
{
|
||||
auto G = make_shared<op::v3::ScatterAdd>(R, I, U, A);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Incorrect updates element type";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(
|
||||
error.what(),
|
||||
std::string("Updates shape must have appropriate dimensions equal to indices and "
|
||||
"data dimensions"));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, scatter_add_v3_fail_updates_shape_indices)
|
||||
{
|
||||
Shape ref_shape{2, 3, 4};
|
||||
Shape indices_shape{2, 1};
|
||||
Shape updates_shape{2, 3, 1, 4};
|
||||
auto R = make_shared<op::Parameter>(element::u32, ref_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i16, indices_shape);
|
||||
auto U = make_shared<op::Parameter>(element::u32, updates_shape);
|
||||
auto A = op::Constant::create(element::i32, Shape{}, {1});
|
||||
try
|
||||
{
|
||||
auto G = make_shared<op::v3::ScatterAdd>(R, I, U, A);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Incorrect updates element type";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(
|
||||
error.what(),
|
||||
std::string("Updates shape must have appropriate dimensions equal to indices and "
|
||||
"data dimensions"));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, scatter_add_v3_fail_updates_shape_data_before_axis)
|
||||
{
|
||||
Shape ref_shape{2, 3, 4};
|
||||
Shape indices_shape{2, 1};
|
||||
Shape updates_shape{3, 2, 1, 4};
|
||||
auto R = make_shared<op::Parameter>(element::u16, ref_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i16, indices_shape);
|
||||
auto U = make_shared<op::Parameter>(element::u16, updates_shape);
|
||||
auto A = op::Constant::create(element::i8, Shape{}, {1});
|
||||
try
|
||||
{
|
||||
auto G = make_shared<op::v3::ScatterAdd>(R, I, U, A);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Incorrect updates element type";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(
|
||||
error.what(),
|
||||
std::string("Updates shape must have appropriate dimensions equal to indices and "
|
||||
"data dimensions"));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, scatter_add_v3_fail_updates_shape_data_after_axis)
|
||||
{
|
||||
Shape ref_shape{2, 3, 4};
|
||||
Shape indices_shape{2, 1};
|
||||
Shape updates_shape{2, 2, 1, 5};
|
||||
auto R = make_shared<op::Parameter>(element::i8, ref_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i16, indices_shape);
|
||||
auto U = make_shared<op::Parameter>(element::i8, updates_shape);
|
||||
auto A = op::Constant::create(element::i16, Shape{}, {1});
|
||||
try
|
||||
{
|
||||
auto G = make_shared<op::v3::ScatterAdd>(R, I, U, A);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Incorrect updates element type";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(
|
||||
error.what(),
|
||||
std::string("Updates shape must have appropriate dimensions equal to indices and "
|
||||
"data dimensions"));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, scatter_add_v3)
|
||||
{
|
||||
Shape ref_shape{2, 3, 4};
|
||||
Shape indices_shape{2, 1};
|
||||
Shape updates_shape{2, 2, 1, 4};
|
||||
auto R = make_shared<op::Parameter>(element::i8, ref_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i16, indices_shape);
|
||||
auto U = make_shared<op::Parameter>(element::i8, updates_shape);
|
||||
auto A = op::Constant::create(element::i16, Shape{}, {1});
|
||||
|
||||
auto scatter_update = make_shared<op::v3::ScatterAdd>(R, I, U, A);
|
||||
EXPECT_EQ(scatter_update->get_output_element_type(0), element::i8);
|
||||
EXPECT_EQ(scatter_update->get_output_shape(0), ref_shape);
|
||||
}
|
||||
|
||||
TEST(type_prop, scatter_add_v3_dynamic_data_shape)
|
||||
{
|
||||
PartialShape ref_shape = PartialShape::dynamic();
|
||||
Shape indices_shape{2, 1};
|
||||
Shape updates_shape{2, 2, 1, 4};
|
||||
auto R = make_shared<op::Parameter>(element::i8, ref_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i16, indices_shape);
|
||||
auto U = make_shared<op::Parameter>(element::i8, updates_shape);
|
||||
auto A = op::Constant::create(element::i16, Shape{}, {1});
|
||||
|
||||
auto scatter_update = make_shared<op::v3::ScatterAdd>(R, I, U, A);
|
||||
EXPECT_EQ(scatter_update->get_output_element_type(0), element::i8);
|
||||
EXPECT_TRUE(scatter_update->get_output_partial_shape(0).is_dynamic());
|
||||
}
|
@ -1,263 +0,0 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "util/type_prop.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
TEST(type_prop, scatter_nd_add_fail_indices_element_type)
|
||||
{
|
||||
Shape ref_shape{3, 3, 3};
|
||||
Shape indices_shape{1};
|
||||
Shape updates_shape{3, 3};
|
||||
Shape out_shape{3, 3, 3};
|
||||
auto R = make_shared<op::Parameter>(element::f32, ref_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i16, indices_shape);
|
||||
auto U = make_shared<op::Parameter>(element::f32, updates_shape);
|
||||
try
|
||||
{
|
||||
auto G = make_shared<op::ScatterNDAdd>(R, I, U);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Incorrect indices element type";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Indices element type must be i64 or i32"));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, scatter_nd_add_fail_indices_rank)
|
||||
{
|
||||
Shape ref_shape{3, 3, 3};
|
||||
Shape indices_shape{};
|
||||
Shape updates_shape{3, 3};
|
||||
Shape out_shape{3, 3, 3};
|
||||
auto R = make_shared<op::Parameter>(element::f32, ref_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
|
||||
auto U = make_shared<op::Parameter>(element::f32, updates_shape);
|
||||
try
|
||||
{
|
||||
auto G = make_shared<op::ScatterNDAdd>(R, I, U);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Incorrect indices rank";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(error.what(),
|
||||
std::string("Indices rank is expected to be at least 1"));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, scatter_nd_add_fail_indices_last_dim)
|
||||
{
|
||||
Shape ref_shape{3, 3, 3};
|
||||
Shape indices_shape{2, 4};
|
||||
Shape updates_shape{2, 3, 3};
|
||||
Shape out_shape{3, 3, 3};
|
||||
auto R = make_shared<op::Parameter>(element::f32, ref_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
|
||||
auto U = make_shared<op::Parameter>(element::f32, updates_shape);
|
||||
try
|
||||
{
|
||||
auto G = make_shared<op::ScatterNDAdd>(R, I, U);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Incorrect indices innermost dim";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(
|
||||
error.what(),
|
||||
std::string("Last dimension of indices can be at most the rank of inputs"));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, scatter_nd_add_fail_updates_element_type)
|
||||
{
|
||||
Shape ref_shape{3, 3, 3};
|
||||
Shape indices_shape{1};
|
||||
Shape updates_shape{3, 3};
|
||||
Shape out_shape{3, 3, 3};
|
||||
auto R = make_shared<op::Parameter>(element::f32, ref_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
|
||||
auto U = make_shared<op::Parameter>(element::i32, updates_shape);
|
||||
try
|
||||
{
|
||||
auto G = make_shared<op::ScatterNDAdd>(R, I, U);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Incorrect updates element type";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(error.what(),
|
||||
std::string("Updates element type must be the same as inputs"));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, scatter_nd_add_fail_updates_rank)
|
||||
{
|
||||
Shape ref_shape{3, 3, 3};
|
||||
Shape indices_shape{1};
|
||||
Shape updates_shape{3, 3, 3};
|
||||
Shape out_shape{3, 3, 3};
|
||||
auto R = make_shared<op::Parameter>(element::f32, ref_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
|
||||
auto U = make_shared<op::Parameter>(element::f32, updates_shape);
|
||||
try
|
||||
{
|
||||
auto G = make_shared<op::ScatterNDAdd>(R, I, U);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Incorrect updates rank";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(error.what(),
|
||||
std::string("Rank of updates must be rank of inputs + rank of indices "
|
||||
"- last dimension of indices - 1"));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, scatter_nd_add_fail_updates_shape)
|
||||
{
|
||||
Shape ref_shape{3, 3, 3};
|
||||
Shape indices_shape{1};
|
||||
Shape updates_shape{2, 3};
|
||||
Shape out_shape{3, 3, 3};
|
||||
auto R = make_shared<op::Parameter>(element::f32, ref_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
|
||||
auto U = make_shared<op::Parameter>(element::f32, updates_shape);
|
||||
try
|
||||
{
|
||||
auto G = make_shared<op::ScatterNDAdd>(R, I, U);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Incorrect updates shape";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(
|
||||
error.what(),
|
||||
std::string(
|
||||
"updates_shape[indices_rank-1:] shape must be input_shape[indices_shape[-1]:]"));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, scatter_nd_fail_updates_element_type)
|
||||
{
|
||||
Shape ref_shape{3, 3, 3};
|
||||
Shape indices_shape{1};
|
||||
Shape updates_shape{3, 3};
|
||||
Shape out_shape{3, 3, 3};
|
||||
auto R = make_shared<op::Parameter>(element::f32, ref_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
|
||||
auto U = make_shared<op::Parameter>(element::i32, updates_shape);
|
||||
try
|
||||
{
|
||||
auto G = make_shared<op::ScatterND>(R, I, U);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Created ScatterND op with incorrect updates element type.";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(
|
||||
error.what(),
|
||||
std::string("Updates element type must be the same as element type of data."));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, scatter_nd_fail_updates_rank)
|
||||
{
|
||||
Shape ref_shape{3, 3, 3};
|
||||
Shape indices_shape{1};
|
||||
Shape updates_shape{3, 3, 3};
|
||||
Shape out_shape{3, 3, 3};
|
||||
auto R = make_shared<op::Parameter>(element::f32, ref_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
|
||||
auto U = make_shared<op::Parameter>(element::f32, updates_shape);
|
||||
try
|
||||
{
|
||||
auto G = make_shared<op::ScatterND>(R, I, U);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Created ScatterND op with incorrect updates rank";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(
|
||||
error.what(),
|
||||
std::string("Updates rank is expected to be equal data_rank + indices_rank - "
|
||||
"indices_shape[-1] - 1."));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, scatter_nd_fail_updates_shape)
|
||||
{
|
||||
Shape ref_shape{3, 3, 3};
|
||||
Shape indices_shape{4};
|
||||
Shape updates_shape{2};
|
||||
Shape out_shape{3, 3, 3};
|
||||
auto R = make_shared<op::Parameter>(element::f32, ref_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
|
||||
auto U = make_shared<op::Parameter>(element::f32, updates_shape);
|
||||
try
|
||||
{
|
||||
auto G = make_shared<op::ScatterND>(R, I, U);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Created ScatterND op with incorrect indices shape";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(
|
||||
error.what(),
|
||||
std::string("Last dimension of indices can be at most the rank of data."));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user