Reference implementation for ScatterUpdate (#1678)
* Reference implementation for ScatterUpdate and use of it in evaluate. * Review comments. Clarify comments. * Update file directory. * Replace scatter_update reference implementation in ngraph/core/reference/ * Remove template code from ScatterUpdate reference implementation * Apply review requests Co-authored-by: mitruska <katarzyna.mitrus@intel.com>
This commit is contained in:
parent
db2e5c0728
commit
393e9295cd
@ -18,6 +18,7 @@
|
||||
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "ngraph/op/util/scatter_base.hpp"
|
||||
#include "ngraph/runtime/host_tensor.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
@ -49,6 +50,9 @@ namespace ngraph
|
||||
|
||||
virtual std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& inputs) const override;
|
||||
|
||||
bool evaluate(const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) const override;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
@ -1,86 +1,120 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include "ngraph/check.hpp"
|
||||
#include "ngraph/coordinate_transform.hpp"
|
||||
#include "ngraph/shape.hpp"
|
||||
|
||||
using namespace ngraph;
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace runtime
|
||||
{
|
||||
namespace reference
|
||||
{
|
||||
template <typename dataType, typename indicesType, typename axisType>
|
||||
void scatterUpdate(const dataType* inputData,
|
||||
const indicesType* indices,
|
||||
const dataType* updates,
|
||||
const axisType* _axis,
|
||||
dataType* outBuf,
|
||||
const Shape& dataShape,
|
||||
const Shape& indicesShape,
|
||||
const Shape& updatesShape)
|
||||
void scatter_update(const char* input_data,
|
||||
const int64_t* indices,
|
||||
const char* updates,
|
||||
const int64_t axis,
|
||||
char* out_buf,
|
||||
const size_t elem_size,
|
||||
const Shape& data_shape,
|
||||
const Shape& indices_shape,
|
||||
const Shape& updates_shape)
|
||||
{
|
||||
int rank = static_cast<int>(dataShape.size());
|
||||
if (_axis[0] < -rank || _axis[0] > rank - 1)
|
||||
// Copy inputs to out
|
||||
std::memcpy(out_buf, input_data, elem_size * shape_size(data_shape));
|
||||
|
||||
// Algorithm overview
|
||||
// data[..., indices[m, n, ..., p], ...] = updates[..., m, n, ..., p, ...]
|
||||
// where first ... in the data corresponds to first axis dimensions,
|
||||
// last ... in the data corresponds to the rank(data) - (axis + 1) dimensions.
|
||||
|
||||
//
|
||||
// for i_coord in indices[m, n, ..., p]:
|
||||
// # get linear index
|
||||
// i_idx = index(i_coord)
|
||||
// # simultaneously iterate over two slices of data with same elements count
|
||||
// for d_coord in slice data[..., i_idx, ...],
|
||||
// u_coord in slice updates[..., i_coord, ...]
|
||||
// data[index(d_coord)] = updates[index(u_coord)]
|
||||
|
||||
CoordinateTransform indices_transform{indices_shape};
|
||||
CoordinateTransform data_transform{data_shape};
|
||||
|
||||
size_t indices_ndim = indices_shape.size();
|
||||
size_t updates_ndim = updates_shape.size();
|
||||
|
||||
// Create an outer CoordinateTransform for "update", which would allow to
|
||||
// iterate only over "indices" dimensions:
|
||||
// set to "1" all non-indices dimensions
|
||||
// updates[1, ..., 1, m, n, ..., p, 1, 1,..., 1]
|
||||
Coordinate updates_indices_start_corner(updates_ndim, 0);
|
||||
Coordinate updates_indices_end_corner(updates_ndim, 1);
|
||||
for (size_t i = 0; i < indices_ndim; ++i)
|
||||
{
|
||||
std::string error =
|
||||
std::string("ScatterUpdate layer has out of bounds axis value: ") +
|
||||
std::to_string(_axis[0]);
|
||||
throw ngraph_error(error);
|
||||
updates_indices_end_corner[axis + i] = updates_shape[axis + i];
|
||||
}
|
||||
size_t axis = _axis[0] < 0 ? _axis[0] + rank : _axis[0];
|
||||
CoordinateTransform indicesTransform{indicesShape};
|
||||
CoordinateTransform updates_indices_transform(
|
||||
updates_shape, updates_indices_start_corner, updates_indices_end_corner);
|
||||
// Is needed to simultaneously iterate over updates coordinates while
|
||||
// iterating over indices.
|
||||
auto updates_indices_coord_iter = updates_indices_transform.begin();
|
||||
|
||||
Shape dataShapeIter = dataShape;
|
||||
dataShapeIter.erase(dataShapeIter.begin() + axis);
|
||||
CoordinateTransform dataTransfIter{dataShapeIter};
|
||||
|
||||
CoordinateTransform updateTransform{updatesShape};
|
||||
CoordinateTransform dataTransform{dataShape};
|
||||
|
||||
std::memcpy(outBuf, inputData, sizeof(dataType) * shape_size(dataShape));
|
||||
|
||||
for (const Coordinate& indicesCoordIt : indicesTransform)
|
||||
for (const Coordinate& indices_cord : indices_transform)
|
||||
{
|
||||
const size_t indicesIdx = indicesTransform.index(indicesCoordIt);
|
||||
const size_t indices_idx = indices_transform.index(indices_cord);
|
||||
int64_t slice_index = indices[indices_idx];
|
||||
|
||||
if (indices[indicesIdx] < 0)
|
||||
{
|
||||
std::string error =
|
||||
std::string("ScatterUpdate layer has negative index value: ") +
|
||||
std::to_string(indices[indicesIdx]);
|
||||
throw ngraph_error(error);
|
||||
}
|
||||
const size_t idx = static_cast<size_t>(indices[indicesIdx]);
|
||||
if (dataShape[axis] <= idx)
|
||||
{
|
||||
std::string error =
|
||||
std::string("ScatterUpdate layer has out of bounds coordinate: ") +
|
||||
std::to_string(idx) + " on 'data' input on " + std::to_string(axis) +
|
||||
"th axis";
|
||||
throw ngraph_error(error);
|
||||
}
|
||||
// Define the extent of coordinates which will be updated.
|
||||
Coordinate out_start_corner(data_shape.size(), 0);
|
||||
Coordinate out_end_corner(data_shape);
|
||||
out_start_corner[axis] = static_cast<size_t>(slice_index);
|
||||
out_end_corner[axis] = out_start_corner[axis] + 1;
|
||||
CoordinateTransform out_transform(data_shape, out_start_corner, out_end_corner);
|
||||
|
||||
for (const Coordinate& dataCoordIt : dataTransfIter)
|
||||
// Define the CoordinateTransform for updates coordinates.
|
||||
// All except indices-dimensions.
|
||||
Coordinate updates_update_start_corner = *updates_indices_coord_iter;
|
||||
Coordinate updates_update_end_corner(updates_shape);
|
||||
for (size_t i = 0; i < indices_ndim; ++i)
|
||||
{
|
||||
Coordinate dataCoord = dataCoordIt;
|
||||
dataCoord.insert(dataCoord.begin() + axis, idx);
|
||||
const size_t startIndices = dataTransform.index(dataCoord);
|
||||
|
||||
auto updCoord = dataCoordIt;
|
||||
updCoord.insert(
|
||||
updCoord.begin() + axis, indicesCoordIt.begin(), indicesCoordIt.end());
|
||||
const size_t startUpd = updateTransform.index(updCoord);
|
||||
outBuf[startIndices] = updates[startUpd];
|
||||
updates_update_end_corner[axis + i] =
|
||||
updates_update_start_corner[axis + i] + 1;
|
||||
}
|
||||
// The m, n, .., p symbols stand for values at those axes.
|
||||
// The m+1 means value at axis m plus 1.
|
||||
// udpates_shape (start): [ 0, ..., 0, m , n , ... p , 0, ..., 0]
|
||||
// updates_shape (end): [-1, ..., -1, m+1, n+1, ... p+1, -1, ..., -1]
|
||||
CoordinateTransform updates_update_transform(
|
||||
updates_shape, updates_update_start_corner, updates_update_end_corner);
|
||||
auto updates_update_coord_iter = updates_update_transform.begin();
|
||||
for (const Coordinate& out_cord : out_transform)
|
||||
{
|
||||
const auto src_idx =
|
||||
updates_update_transform.index(*updates_update_coord_iter) * elem_size;
|
||||
std::copy(updates + src_idx,
|
||||
updates + (src_idx + elem_size),
|
||||
out_buf + out_transform.index(out_cord) * elem_size);
|
||||
updates_update_coord_iter++;
|
||||
}
|
||||
updates_indices_coord_iter++;
|
||||
}
|
||||
}
|
||||
} // namespace reference
|
||||
} // namespace runtime
|
||||
} // namespace ngraph
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -15,7 +15,11 @@
|
||||
//*****************************************************************************
|
||||
|
||||
#include "ngraph/op/scatter_update.hpp"
|
||||
#include "ngraph/runtime/reference/scatter_update.hpp"
|
||||
#include "ngraph/shape.hpp"
|
||||
#include "ngraph/type/element_type.hpp"
|
||||
#include "ngraph/type/element_type_traits.hpp"
|
||||
#include "ngraph/validation_util.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
@ -36,3 +40,110 @@ shared_ptr<Node> op::v3::ScatterUpdate::clone_with_new_inputs(const OutputVector
|
||||
return make_shared<v3::ScatterUpdate>(
|
||||
new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3));
|
||||
}
|
||||
|
||||
bool op::v3::ScatterUpdate::evaluate(const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) const
|
||||
{
|
||||
const auto& data = inputs[0];
|
||||
const auto& indices = inputs[1];
|
||||
const auto& updates = inputs[2];
|
||||
const auto& axis = inputs[3];
|
||||
const auto& out = outputs[0];
|
||||
|
||||
const auto elem_size = data->get_element_type().size();
|
||||
out->set_shape(data->get_shape());
|
||||
|
||||
int64_t axis_val = 0;
|
||||
switch (axis->get_element_type())
|
||||
{
|
||||
case element::Type_t::i8: axis_val = axis->get_data_ptr<element::Type_t::i8>()[0]; break;
|
||||
case element::Type_t::i16: axis_val = axis->get_data_ptr<element::Type_t::i16>()[0]; break;
|
||||
case element::Type_t::i32: axis_val = axis->get_data_ptr<element::Type_t::i32>()[0]; break;
|
||||
case element::Type_t::i64: axis_val = axis->get_data_ptr<element::Type_t::i64>()[0]; break;
|
||||
case element::Type_t::u8: axis_val = axis->get_data_ptr<element::Type_t::u8>()[0]; break;
|
||||
case element::Type_t::u16: axis_val = axis->get_data_ptr<element::Type_t::u16>()[0]; break;
|
||||
case element::Type_t::u32: axis_val = axis->get_data_ptr<element::Type_t::u32>()[0]; break;
|
||||
case element::Type_t::u64: axis_val = axis->get_data_ptr<element::Type_t::u64>()[0]; break;
|
||||
default: throw ngraph_error("axis element type is not integral data type");
|
||||
}
|
||||
|
||||
if (axis_val < 0)
|
||||
{
|
||||
axis_val =
|
||||
ngraph::normalize_axis(this, axis_val, static_cast<int64_t>(data->get_shape().size()));
|
||||
}
|
||||
|
||||
std::vector<int64_t> indices_casted_vector;
|
||||
switch (indices->get_element_type())
|
||||
{
|
||||
case element::Type_t::i8:
|
||||
{
|
||||
auto indices_ptr = indices->get_data_ptr<element::Type_t::i8>();
|
||||
indices_casted_vector =
|
||||
std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
|
||||
break;
|
||||
}
|
||||
case element::Type_t::i16:
|
||||
{
|
||||
auto indices_ptr = indices->get_data_ptr<element::Type_t::i16>();
|
||||
indices_casted_vector =
|
||||
std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
|
||||
break;
|
||||
}
|
||||
case element::Type_t::i32:
|
||||
{
|
||||
auto indices_ptr = indices->get_data_ptr<element::Type_t::i32>();
|
||||
indices_casted_vector =
|
||||
std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
|
||||
break;
|
||||
}
|
||||
case element::Type_t::i64:
|
||||
{
|
||||
auto indices_ptr = indices->get_data_ptr<element::Type_t::i64>();
|
||||
indices_casted_vector =
|
||||
std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
|
||||
break;
|
||||
}
|
||||
case element::Type_t::u8:
|
||||
{
|
||||
auto indices_ptr = indices->get_data_ptr<element::Type_t::u8>();
|
||||
indices_casted_vector =
|
||||
std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
|
||||
break;
|
||||
}
|
||||
case element::Type_t::u16:
|
||||
{
|
||||
auto indices_ptr = indices->get_data_ptr<element::Type_t::u16>();
|
||||
indices_casted_vector =
|
||||
std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
|
||||
break;
|
||||
}
|
||||
case element::Type_t::u32:
|
||||
{
|
||||
auto indices_ptr = indices->get_data_ptr<element::Type_t::u32>();
|
||||
indices_casted_vector =
|
||||
std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
|
||||
break;
|
||||
}
|
||||
case element::Type_t::u64:
|
||||
{
|
||||
auto indices_ptr = indices->get_data_ptr<element::Type_t::u64>();
|
||||
indices_casted_vector =
|
||||
std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
|
||||
break;
|
||||
}
|
||||
default: throw ngraph_error("indices element type is not integral data type");
|
||||
}
|
||||
|
||||
runtime::reference::scatter_update(data->get_data_ptr<char>(),
|
||||
indices_casted_vector.data(),
|
||||
updates->get_data_ptr<char>(),
|
||||
axis_val,
|
||||
out->get_data_ptr<char>(),
|
||||
elem_size,
|
||||
data->get_shape(),
|
||||
indices->get_shape(),
|
||||
updates->get_shape());
|
||||
|
||||
return true;
|
||||
}
|
||||
|
@ -54,6 +54,7 @@
|
||||
#include "ngraph/op/reshape.hpp"
|
||||
#include "ngraph/op/round.hpp"
|
||||
#include "ngraph/op/scatter_elements_update.hpp"
|
||||
#include "ngraph/op/scatter_update.hpp"
|
||||
#include "ngraph/op/shape_of.hpp"
|
||||
#include "ngraph/op/sigmoid.hpp"
|
||||
#include "ngraph/op/sign.hpp"
|
||||
@ -1937,3 +1938,180 @@ TEST(eval, reduce_logical_and__neg_axis)
|
||||
}),
|
||||
ngraph::ngraph_error);
|
||||
}
|
||||
|
||||
TEST(eval, evaluate_static_scatter_update_basic_axes_indices_i32)
|
||||
{
|
||||
const Shape data_shape{3, 3};
|
||||
const Shape indices_shape{1, 2};
|
||||
const Shape updates_shape{1, 2, 3};
|
||||
|
||||
auto arg1 = make_shared<op::Parameter>(element::f32, data_shape);
|
||||
auto arg2 = make_shared<op::Parameter>(element::i32, indices_shape);
|
||||
auto arg3 = make_shared<op::Parameter>(element::f32, updates_shape);
|
||||
auto arg4 = make_shared<op::Parameter>(element::i32, Shape{});
|
||||
auto scatter_update = make_shared<op::v3::ScatterUpdate>(arg1, arg2, arg3, arg4);
|
||||
auto fun = make_shared<Function>(OutputVector{scatter_update},
|
||||
ParameterVector{arg1, arg2, arg3, arg4});
|
||||
auto result_tensor = make_shared<HostTensor>();
|
||||
ASSERT_TRUE(fun->evaluate({result_tensor},
|
||||
{make_host_tensor<element::Type_t::f32>(
|
||||
data_shape, std::vector<float>(shape_size(data_shape))),
|
||||
make_host_tensor<element::Type_t::i32>(indices_shape, {1, 2}),
|
||||
make_host_tensor<element::Type_t::f32>(
|
||||
updates_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f}),
|
||||
make_host_tensor<element::Type_t::i32>({}, {0})}));
|
||||
EXPECT_EQ(result_tensor->get_element_type(), element::f32);
|
||||
EXPECT_EQ(result_tensor->get_shape(), (Shape{3, 3}));
|
||||
auto cval = read_vector<float>(result_tensor);
|
||||
vector<float> out{0.f, 0.f, 0.f, 1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f};
|
||||
ASSERT_EQ(cval, out);
|
||||
}
|
||||
|
||||
TEST(eval, evaluate_static_scatter_update_basic_axes_indices_i64)
|
||||
{
|
||||
const Shape data_shape{3, 3};
|
||||
const Shape indices_shape{1, 2};
|
||||
const Shape updates_shape{1, 2, 3};
|
||||
|
||||
auto arg1 = make_shared<op::Parameter>(element::f32, data_shape);
|
||||
auto arg2 = make_shared<op::Parameter>(element::i64, indices_shape);
|
||||
auto arg3 = make_shared<op::Parameter>(element::f32, updates_shape);
|
||||
auto arg4 = make_shared<op::Parameter>(element::i64, Shape{});
|
||||
auto scatter_update = make_shared<op::v3::ScatterUpdate>(arg1, arg2, arg3, arg4);
|
||||
auto fun = make_shared<Function>(OutputVector{scatter_update},
|
||||
ParameterVector{arg1, arg2, arg3, arg4});
|
||||
auto result_tensor = make_shared<HostTensor>();
|
||||
ASSERT_TRUE(fun->evaluate({result_tensor},
|
||||
{make_host_tensor<element::Type_t::f32>(
|
||||
data_shape, std::vector<float>(shape_size(data_shape))),
|
||||
make_host_tensor<element::Type_t::i64>(indices_shape, {1, 2}),
|
||||
make_host_tensor<element::Type_t::f32>(
|
||||
updates_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f}),
|
||||
make_host_tensor<element::Type_t::i64>({}, {0})}));
|
||||
EXPECT_EQ(result_tensor->get_element_type(), element::f32);
|
||||
EXPECT_EQ(result_tensor->get_shape(), (Shape{3, 3}));
|
||||
auto cval = read_vector<float>(result_tensor);
|
||||
vector<float> out{0.f, 0.f, 0.f, 1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f};
|
||||
ASSERT_EQ(cval, out);
|
||||
}
|
||||
|
||||
TEST(eval, evaluate_dynamic_scatter_update_basic)
|
||||
{
|
||||
const Shape data_shape{3, 3};
|
||||
const Shape indices_shape{1, 2};
|
||||
const Shape updates_shape{1, 2, 3};
|
||||
|
||||
auto arg1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto arg2 = make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
|
||||
auto arg3 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto arg4 = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
|
||||
|
||||
auto scatter_update = make_shared<op::v3::ScatterUpdate>(arg1, arg2, arg3, arg4);
|
||||
auto fun = make_shared<Function>(OutputVector{scatter_update},
|
||||
ParameterVector{arg1, arg2, arg3, arg4});
|
||||
auto result_tensor = make_shared<HostTensor>();
|
||||
ASSERT_TRUE(fun->evaluate({result_tensor},
|
||||
{make_host_tensor<element::Type_t::f32>(
|
||||
data_shape, std::vector<float>(shape_size(data_shape))),
|
||||
make_host_tensor<element::Type_t::i32>(indices_shape, {1, 2}),
|
||||
make_host_tensor<element::Type_t::f32>(
|
||||
updates_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f}),
|
||||
make_host_tensor<element::Type_t::i64>({}, {0})}));
|
||||
|
||||
EXPECT_EQ(result_tensor->get_element_type(), element::f32);
|
||||
EXPECT_EQ(result_tensor->get_partial_shape(), (PartialShape{3, 3}));
|
||||
auto cval = read_vector<float>(result_tensor);
|
||||
vector<float> out{0.f, 0.f, 0.f, 1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f};
|
||||
ASSERT_EQ(cval, out);
|
||||
}
|
||||
|
||||
TEST(eval, evaluate_dynamic_scatter_update_negative_axis)
|
||||
{
|
||||
const Shape data_shape{3, 3};
|
||||
const Shape indices_shape{1, 2};
|
||||
const Shape updates_shape{3, 1, 2};
|
||||
const Shape axis_shape{};
|
||||
|
||||
auto arg1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto arg2 = make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
|
||||
auto arg3 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto arg4 = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
|
||||
|
||||
auto scatter_update = make_shared<op::v3::ScatterUpdate>(arg1, arg2, arg3, arg4);
|
||||
auto fun = make_shared<Function>(OutputVector{scatter_update},
|
||||
ParameterVector{arg1, arg2, arg3, arg4});
|
||||
auto result_tensor = make_shared<HostTensor>();
|
||||
ASSERT_TRUE(fun->evaluate({result_tensor},
|
||||
{make_host_tensor<element::Type_t::f32>(
|
||||
data_shape, std::vector<float>(shape_size(data_shape))),
|
||||
make_host_tensor<element::Type_t::i32>(indices_shape, {1, 2}),
|
||||
make_host_tensor<element::Type_t::f32>(
|
||||
updates_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f}),
|
||||
make_host_tensor<element::Type_t::i64>(axis_shape, {-1})}));
|
||||
|
||||
EXPECT_EQ(result_tensor->get_element_type(), element::f32);
|
||||
EXPECT_EQ(result_tensor->get_partial_shape(), (PartialShape{3, 3}));
|
||||
auto cval = read_vector<float>(result_tensor);
|
||||
vector<float> out{0.f, 1.0f, 1.1f, 0.0f, 1.2f, 2.0f, 0.0f, 2.1f, 2.2f};
|
||||
ASSERT_EQ(cval, out);
|
||||
}
|
||||
|
||||
TEST(eval, evaluate_dynamic_scatter_update_1d_axis)
|
||||
{
|
||||
const Shape data_shape{3, 3};
|
||||
const Shape indices_shape{1, 2};
|
||||
const Shape updates_shape{3, 1, 2};
|
||||
|
||||
auto arg1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto arg2 = make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
|
||||
auto arg3 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto arg4 = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
|
||||
|
||||
auto scatter_update = make_shared<op::v3::ScatterUpdate>(arg1, arg2, arg3, arg4);
|
||||
auto fun = make_shared<Function>(OutputVector{scatter_update},
|
||||
ParameterVector{arg1, arg2, arg3, arg4});
|
||||
auto result_tensor = make_shared<HostTensor>();
|
||||
ASSERT_TRUE(fun->evaluate({result_tensor},
|
||||
{make_host_tensor<element::Type_t::f32>(
|
||||
data_shape, std::vector<float>(shape_size(data_shape))),
|
||||
make_host_tensor<element::Type_t::i32>(indices_shape, {1, 2}),
|
||||
make_host_tensor<element::Type_t::f32>(
|
||||
updates_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f}),
|
||||
make_host_tensor<element::Type_t::i64>({1}, {1})}));
|
||||
|
||||
EXPECT_EQ(result_tensor->get_element_type(), element::f32);
|
||||
EXPECT_EQ(result_tensor->get_partial_shape(), (PartialShape{3, 3}));
|
||||
auto cval = read_vector<float>(result_tensor);
|
||||
vector<float> out{0.f, 1.0f, 1.1f, 0.0f, 1.2f, 2.0f, 0.0f, 2.1f, 2.2f};
|
||||
ASSERT_EQ(cval, out);
|
||||
}
|
||||
|
||||
TEST(eval, evaluate_dynamic_scatter_update_one_elem_i32)
|
||||
{
|
||||
const Shape data_shape{3, 3, 2};
|
||||
const Shape indices_shape{1, 1};
|
||||
const Shape updates_shape{1, 1, 3, 2};
|
||||
|
||||
auto arg1 = make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
|
||||
auto arg2 = make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
|
||||
auto arg3 = make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
|
||||
auto arg4 = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
|
||||
|
||||
auto scatter_update = make_shared<op::v3::ScatterUpdate>(arg1, arg2, arg3, arg4);
|
||||
auto fun = make_shared<Function>(OutputVector{scatter_update},
|
||||
ParameterVector{arg1, arg2, arg3, arg4});
|
||||
auto result_tensor = make_shared<HostTensor>();
|
||||
ASSERT_TRUE(
|
||||
fun->evaluate({result_tensor},
|
||||
{make_host_tensor<element::Type_t::i32>(
|
||||
data_shape, std::vector<int32_t>(shape_size(data_shape))),
|
||||
make_host_tensor<element::Type_t::i32>(indices_shape, {1}),
|
||||
make_host_tensor<element::Type_t::i32>(updates_shape, {1, 2, 3, 4, 5, 6}),
|
||||
make_host_tensor<element::Type_t::i64>({}, {0})}));
|
||||
|
||||
EXPECT_EQ(result_tensor->get_element_type(), element::i32);
|
||||
EXPECT_EQ(result_tensor->get_partial_shape(), (PartialShape{3, 3, 2}));
|
||||
auto cval = read_vector<int32_t>(result_tensor);
|
||||
vector<int32_t> out{0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, 0};
|
||||
ASSERT_EQ(cval, out);
|
||||
}
|
||||
|
@ -79,7 +79,6 @@
|
||||
#include "ngraph/runtime/reference/reverse_sequence.hpp"
|
||||
#include "ngraph/runtime/reference/round.hpp"
|
||||
#include "ngraph/runtime/reference/scatter_nd_update.hpp"
|
||||
#include "ngraph/runtime/reference/scatter_update.hpp"
|
||||
#include "ngraph/runtime/reference/select.hpp"
|
||||
#include "ngraph/runtime/reference/sigmoid.hpp"
|
||||
#include "ngraph/runtime/reference/sign.hpp"
|
||||
@ -1195,48 +1194,6 @@ protected:
|
||||
|
||||
break;
|
||||
}
|
||||
case OP_TYPEID::ScatterUpdate_v3:
|
||||
{
|
||||
const op::v3::ScatterUpdate* scatterUpd =
|
||||
static_cast<const op::v3::ScatterUpdate*>(&node);
|
||||
|
||||
if (scatterUpd->get_input_element_type(3) != element::i64)
|
||||
throw ngraph_error(
|
||||
"ScatterNDUpdate layer support only i64 'axis' input precision!");
|
||||
|
||||
auto idxType = scatterUpd->get_input_element_type(1);
|
||||
if (idxType == element::i32)
|
||||
{
|
||||
reference::scatterUpdate<T, int32_t, int64_t>(
|
||||
args[0]->get_data_ptr<const T>(),
|
||||
args[1]->get_data_ptr<const int32_t>(),
|
||||
args[2]->get_data_ptr<const T>(),
|
||||
args[3]->get_data_ptr<const int64_t>(),
|
||||
out[0]->get_data_ptr<T>(),
|
||||
node.get_input_shape(0),
|
||||
node.get_input_shape(1),
|
||||
node.get_input_shape(2));
|
||||
}
|
||||
else if (idxType == element::i64)
|
||||
{
|
||||
reference::scatterUpdate<T, int64_t, int64_t>(
|
||||
args[0]->get_data_ptr<const T>(),
|
||||
args[1]->get_data_ptr<const int64_t>(),
|
||||
args[2]->get_data_ptr<const T>(),
|
||||
args[3]->get_data_ptr<const int64_t>(),
|
||||
out[0]->get_data_ptr<T>(),
|
||||
node.get_input_shape(0),
|
||||
node.get_input_shape(1),
|
||||
node.get_input_shape(2));
|
||||
}
|
||||
else
|
||||
{
|
||||
throw ngraph_error(
|
||||
"ScatterUpdate layer support only i32 and i64 'indices' input precision!");
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
// Fused Ops are not supported in interpreter. They need to be decomposed before execution
|
||||
case OP_TYPEID::DepthToSpace:
|
||||
@ -1255,6 +1212,7 @@ protected:
|
||||
case OP_TYPEID::NormalizeL2:
|
||||
case OP_TYPEID::PRelu:
|
||||
case OP_TYPEID::RNNCell:
|
||||
case OP_TYPEID::ScatterUpdate_v3:
|
||||
case OP_TYPEID::Selu:
|
||||
case OP_TYPEID::ShuffleChannels:
|
||||
case OP_TYPEID::SpaceToDepth:
|
||||
|
Loading…
Reference in New Issue
Block a user