Reference implementation for ScatterUpdate (#1678)

* Reference implementation for ScatterUpdate and use of it in evaluate.

* Review comments. Clarify comments.

* Update file directory.

* Replace scatter_update reference implementation in ngraph/core/reference/

* Remove template code from ScatterUpdate reference implementation

* Apply review requests

Co-authored-by: mitruska <katarzyna.mitrus@intel.com>
Adam Osewski authored on 2020-08-25 05:12:39 +02:00, committed by GitHub
parent db2e5c0728
commit 393e9295cd
5 changed files with 390 additions and 105 deletions


@@ -18,6 +18,7 @@
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/scatter_base.hpp"
#include "ngraph/runtime/host_tensor.hpp"
namespace ngraph
{
@@ -49,6 +50,9 @@ namespace ngraph
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& inputs) const override;
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const override;
};
}
}


@@ -1,86 +1,120 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <string>
#include "ngraph/check.hpp"
#include "ngraph/coordinate_transform.hpp"
#include "ngraph/shape.hpp"
using namespace ngraph;
namespace ngraph
{
namespace runtime
{
namespace reference
{
template <typename dataType, typename indicesType, typename axisType>
void scatterUpdate(const dataType* inputData,
const indicesType* indices,
const dataType* updates,
const axisType* _axis,
dataType* outBuf,
const Shape& dataShape,
const Shape& indicesShape,
const Shape& updatesShape)
void scatter_update(const char* input_data,
const int64_t* indices,
const char* updates,
const int64_t axis,
char* out_buf,
const size_t elem_size,
const Shape& data_shape,
const Shape& indices_shape,
const Shape& updates_shape)
{
int rank = static_cast<int>(dataShape.size());
if (_axis[0] < -rank || _axis[0] > rank - 1)
// Copy the input data to the output buffer
std::memcpy(out_buf, input_data, elem_size * shape_size(data_shape));
// Algorithm overview
// data[..., indices[m, n, ..., p], ...] = updates[..., m, n, ..., p, ...]
// where the first ... in data corresponds to the first `axis` dimensions,
// and the last ... corresponds to the trailing rank(data) - (axis + 1) dimensions.
//
// for i_coord in indices[m, n, ..., p]:
// # get linear index
// i_idx = index(i_coord)
// # simultaneously iterate over slices of data and updates with the same element count
// for d_coord in slice data[..., i_idx, ...],
// u_coord in slice updates[..., i_coord, ...]
// data[index(d_coord)] = updates[index(u_coord)]
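//
// Illustrative trace (example values borrowed from the tests below):
// for data shape {3, 3}, axis = 0, indices = [[1, 2]] of shape {1, 2}:
//   i_coord (0, 0) -> i_idx = 1: data[1, :] = updates[0, 0, :]
//   i_coord (0, 1) -> i_idx = 2: data[2, :] = updates[0, 1, :]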
CoordinateTransform indices_transform{indices_shape};
CoordinateTransform data_transform{data_shape};
size_t indices_ndim = indices_shape.size();
size_t updates_ndim = updates_shape.size();
// Create an outer CoordinateTransform for "updates" that allows iterating
// only over the "indices" dimensions:
// all non-indices dimensions are clamped to extent 1
// updates[1, ..., 1, m, n, ..., p, 1, 1, ..., 1]
Coordinate updates_indices_start_corner(updates_ndim, 0);
Coordinate updates_indices_end_corner(updates_ndim, 1);
for (size_t i = 0; i < indices_ndim; ++i)
{
std::string error =
std::string("ScatterUpdate layer has out of bounds axis value: ") +
std::to_string(_axis[0]);
throw ngraph_error(error);
updates_indices_end_corner[axis + i] = updates_shape[axis + i];
}
size_t axis = _axis[0] < 0 ? _axis[0] + rank : _axis[0];
CoordinateTransform indicesTransform{indicesShape};
CoordinateTransform updates_indices_transform(
updates_shape, updates_indices_start_corner, updates_indices_end_corner);
// Needed to iterate over updates coordinates in lockstep with the
// iteration over indices.
auto updates_indices_coord_iter = updates_indices_transform.begin();
Shape dataShapeIter = dataShape;
dataShapeIter.erase(dataShapeIter.begin() + axis);
CoordinateTransform dataTransfIter{dataShapeIter};
CoordinateTransform updateTransform{updatesShape};
CoordinateTransform dataTransform{dataShape};
std::memcpy(outBuf, inputData, sizeof(dataType) * shape_size(dataShape));
for (const Coordinate& indicesCoordIt : indicesTransform)
for (const Coordinate& indices_cord : indices_transform)
{
const size_t indicesIdx = indicesTransform.index(indicesCoordIt);
const size_t indices_idx = indices_transform.index(indices_cord);
int64_t slice_index = indices[indices_idx];
if (indices[indicesIdx] < 0)
{
std::string error =
std::string("ScatterUpdate layer has negative index value: ") +
std::to_string(indices[indicesIdx]);
throw ngraph_error(error);
}
const size_t idx = static_cast<size_t>(indices[indicesIdx]);
if (dataShape[axis] <= idx)
{
std::string error =
std::string("ScatterUpdate layer has out of bounds coordinate: ") +
std::to_string(idx) + " on 'data' input on " + std::to_string(axis) +
"th axis";
throw ngraph_error(error);
}
// Define the extent of coordinates which will be updated.
Coordinate out_start_corner(data_shape.size(), 0);
Coordinate out_end_corner(data_shape);
out_start_corner[axis] = static_cast<size_t>(slice_index);
out_end_corner[axis] = out_start_corner[axis] + 1;
CoordinateTransform out_transform(data_shape, out_start_corner, out_end_corner);
for (const Coordinate& dataCoordIt : dataTransfIter)
// Define the CoordinateTransform for the updates coordinates:
// all dimensions except the indices dimensions.
Coordinate updates_update_start_corner = *updates_indices_coord_iter;
Coordinate updates_update_end_corner(updates_shape);
for (size_t i = 0; i < indices_ndim; ++i)
{
Coordinate dataCoord = dataCoordIt;
dataCoord.insert(dataCoord.begin() + axis, idx);
const size_t startIndices = dataTransform.index(dataCoord);
auto updCoord = dataCoordIt;
updCoord.insert(
updCoord.begin() + axis, indicesCoordIt.begin(), indicesCoordIt.end());
const size_t startUpd = updateTransform.index(updCoord);
outBuf[startIndices] = updates[startUpd];
updates_update_end_corner[axis + i] =
updates_update_start_corner[axis + i] + 1;
}
// The m, n, ..., p symbols stand for the values at those axes;
// m+1 means the value at axis m plus 1.
// updates_shape (start): [ 0, ..., 0, m , n , ... p , 0, ..., 0]
// updates_shape (end): [-1, ..., -1, m+1, n+1, ... p+1, -1, ..., -1]
CoordinateTransform updates_update_transform(
updates_shape, updates_update_start_corner, updates_update_end_corner);
auto updates_update_coord_iter = updates_update_transform.begin();
for (const Coordinate& out_cord : out_transform)
{
const auto src_idx =
updates_update_transform.index(*updates_update_coord_iter) * elem_size;
std::copy(updates + src_idx,
updates + (src_idx + elem_size),
out_buf + out_transform.index(out_cord) * elem_size);
updates_update_coord_iter++;
}
updates_indices_coord_iter++;
}
}
} // namespace reference
} // namespace runtime
} // namespace ngraph
}
}
}
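
To make the byte-oriented kernel's semantics concrete, here is a minimal, self-contained sketch of the same update rule for the simplest case (2-D data, axis 0, rank-1 indices). It is an illustration only, not the ngraph implementation; the function name scatter_update_axis0 is invented for this example.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// data: [R, C] (row-major), indices: [K], updates: [K, C]
// Semantics: data[indices[k], :] = updates[k, :]
void scatter_update_axis0(std::vector<float>& data,
                          size_t C,
                          const std::vector<int64_t>& indices,
                          const std::vector<float>& updates)
{
    for (size_t k = 0; k < indices.size(); ++k)
    {
        const size_t row = static_cast<size_t>(indices[k]);
        for (size_t c = 0; c < C; ++c)
        {
            data[row * C + c] = updates[k * C + c];
        }
    }
}

int main()
{
    std::vector<float> data(9, 0.f); // 3x3 zeros
    const std::vector<int64_t> indices{1, 2};
    const std::vector<float> updates{1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f};
    scatter_update_axis0(data, 3, indices, updates);
    for (float v : data)
        std::cout << v << ' '; // 0 0 0 1 1.1 1.2 2 2.1 2.2
    std::cout << '\n';
}

This reproduces the expected output of the basic axis-0 test added below.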


@@ -15,7 +15,11 @@
//*****************************************************************************
#include "ngraph/op/scatter_update.hpp"
#include "ngraph/runtime/reference/scatter_update.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
#include "ngraph/type/element_type_traits.hpp"
#include "ngraph/validation_util.hpp"
using namespace std;
using namespace ngraph;
@@ -36,3 +40,110 @@ shared_ptr<Node> op::v3::ScatterUpdate::clone_with_new_inputs(const OutputVector
return make_shared<v3::ScatterUpdate>(
new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3));
}
bool op::v3::ScatterUpdate::evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const
{
const auto& data = inputs[0];
const auto& indices = inputs[1];
const auto& updates = inputs[2];
const auto& axis = inputs[3];
const auto& out = outputs[0];
const auto elem_size = data->get_element_type().size();
out->set_shape(data->get_shape());
int64_t axis_val = 0;
switch (axis->get_element_type())
{
case element::Type_t::i8: axis_val = axis->get_data_ptr<element::Type_t::i8>()[0]; break;
case element::Type_t::i16: axis_val = axis->get_data_ptr<element::Type_t::i16>()[0]; break;
case element::Type_t::i32: axis_val = axis->get_data_ptr<element::Type_t::i32>()[0]; break;
case element::Type_t::i64: axis_val = axis->get_data_ptr<element::Type_t::i64>()[0]; break;
case element::Type_t::u8: axis_val = axis->get_data_ptr<element::Type_t::u8>()[0]; break;
case element::Type_t::u16: axis_val = axis->get_data_ptr<element::Type_t::u16>()[0]; break;
case element::Type_t::u32: axis_val = axis->get_data_ptr<element::Type_t::u32>()[0]; break;
case element::Type_t::u64: axis_val = axis->get_data_ptr<element::Type_t::u64>()[0]; break;
default: throw ngraph_error("axis element type is not integral data type");
}
if (axis_val < 0)
{
axis_val =
ngraph::normalize_axis(this, axis_val, static_cast<int64_t>(data->get_shape().size()));
}
std::vector<int64_t> indices_casted_vector;
switch (indices->get_element_type())
{
case element::Type_t::i8:
{
auto indices_ptr = indices->get_data_ptr<element::Type_t::i8>();
indices_casted_vector =
std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
break;
}
case element::Type_t::i16:
{
auto indices_ptr = indices->get_data_ptr<element::Type_t::i16>();
indices_casted_vector =
std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
break;
}
case element::Type_t::i32:
{
auto indices_ptr = indices->get_data_ptr<element::Type_t::i32>();
indices_casted_vector =
std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
break;
}
case element::Type_t::i64:
{
auto indices_ptr = indices->get_data_ptr<element::Type_t::i64>();
indices_casted_vector =
std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
break;
}
case element::Type_t::u8:
{
auto indices_ptr = indices->get_data_ptr<element::Type_t::u8>();
indices_casted_vector =
std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
break;
}
case element::Type_t::u16:
{
auto indices_ptr = indices->get_data_ptr<element::Type_t::u16>();
indices_casted_vector =
std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
break;
}
case element::Type_t::u32:
{
auto indices_ptr = indices->get_data_ptr<element::Type_t::u32>();
indices_casted_vector =
std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
break;
}
case element::Type_t::u64:
{
auto indices_ptr = indices->get_data_ptr<element::Type_t::u64>();
indices_casted_vector =
std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
break;
}
default: throw ngraph_error("indices element type is not integral data type");
}
runtime::reference::scatter_update(data->get_data_ptr<char>(),
indices_casted_vector.data(),
updates->get_data_ptr<char>(),
axis_val,
out->get_data_ptr<char>(),
elem_size,
data->get_shape(),
indices->get_shape(),
updates->get_shape());
return true;
}
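
Each of the eight indices cases above follows the same pattern: read the typed buffer and widen every element to int64_t for the reference kernel. A hypothetical templated helper (not part of this commit) shows that pattern in one place:

#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper, illustration only: widen a typed integer buffer to
// the std::vector<int64_t> expected by runtime::reference::scatter_update.
template <typename T>
std::vector<int64_t> widen_to_i64(const T* ptr, std::size_t count)
{
    return std::vector<int64_t>(ptr, ptr + count);
}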


@@ -54,6 +54,7 @@
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/round.hpp"
#include "ngraph/op/scatter_elements_update.hpp"
#include "ngraph/op/scatter_update.hpp"
#include "ngraph/op/shape_of.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/sign.hpp"
@@ -1937,3 +1938,180 @@ TEST(eval, reduce_logical_and__neg_axis)
}),
ngraph::ngraph_error);
}
TEST(eval, evaluate_static_scatter_update_basic_axes_indices_i32)
{
const Shape data_shape{3, 3};
const Shape indices_shape{1, 2};
const Shape updates_shape{1, 2, 3};
auto arg1 = make_shared<op::Parameter>(element::f32, data_shape);
auto arg2 = make_shared<op::Parameter>(element::i32, indices_shape);
auto arg3 = make_shared<op::Parameter>(element::f32, updates_shape);
auto arg4 = make_shared<op::Parameter>(element::i32, Shape{});
auto scatter_update = make_shared<op::v3::ScatterUpdate>(arg1, arg2, arg3, arg4);
auto fun = make_shared<Function>(OutputVector{scatter_update},
ParameterVector{arg1, arg2, arg3, arg4});
auto result_tensor = make_shared<HostTensor>();
ASSERT_TRUE(fun->evaluate({result_tensor},
{make_host_tensor<element::Type_t::f32>(
data_shape, std::vector<float>(shape_size(data_shape))),
make_host_tensor<element::Type_t::i32>(indices_shape, {1, 2}),
make_host_tensor<element::Type_t::f32>(
updates_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f}),
make_host_tensor<element::Type_t::i32>({}, {0})}));
EXPECT_EQ(result_tensor->get_element_type(), element::f32);
EXPECT_EQ(result_tensor->get_shape(), (Shape{3, 3}));
auto cval = read_vector<float>(result_tensor);
vector<float> out{0.f, 0.f, 0.f, 1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f};
ASSERT_EQ(cval, out);
}
TEST(eval, evaluate_static_scatter_update_basic_axes_indices_i64)
{
const Shape data_shape{3, 3};
const Shape indices_shape{1, 2};
const Shape updates_shape{1, 2, 3};
auto arg1 = make_shared<op::Parameter>(element::f32, data_shape);
auto arg2 = make_shared<op::Parameter>(element::i64, indices_shape);
auto arg3 = make_shared<op::Parameter>(element::f32, updates_shape);
auto arg4 = make_shared<op::Parameter>(element::i64, Shape{});
auto scatter_update = make_shared<op::v3::ScatterUpdate>(arg1, arg2, arg3, arg4);
auto fun = make_shared<Function>(OutputVector{scatter_update},
ParameterVector{arg1, arg2, arg3, arg4});
auto result_tensor = make_shared<HostTensor>();
ASSERT_TRUE(fun->evaluate({result_tensor},
{make_host_tensor<element::Type_t::f32>(
data_shape, std::vector<float>(shape_size(data_shape))),
make_host_tensor<element::Type_t::i64>(indices_shape, {1, 2}),
make_host_tensor<element::Type_t::f32>(
updates_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f}),
make_host_tensor<element::Type_t::i64>({}, {0})}));
EXPECT_EQ(result_tensor->get_element_type(), element::f32);
EXPECT_EQ(result_tensor->get_shape(), (Shape{3, 3}));
auto cval = read_vector<float>(result_tensor);
vector<float> out{0.f, 0.f, 0.f, 1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f};
ASSERT_EQ(cval, out);
}
TEST(eval, evaluate_dynamic_scatter_update_basic)
{
const Shape data_shape{3, 3};
const Shape indices_shape{1, 2};
const Shape updates_shape{1, 2, 3};
auto arg1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto arg2 = make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
auto arg3 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto arg4 = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto scatter_update = make_shared<op::v3::ScatterUpdate>(arg1, arg2, arg3, arg4);
auto fun = make_shared<Function>(OutputVector{scatter_update},
ParameterVector{arg1, arg2, arg3, arg4});
auto result_tensor = make_shared<HostTensor>();
ASSERT_TRUE(fun->evaluate({result_tensor},
{make_host_tensor<element::Type_t::f32>(
data_shape, std::vector<float>(shape_size(data_shape))),
make_host_tensor<element::Type_t::i32>(indices_shape, {1, 2}),
make_host_tensor<element::Type_t::f32>(
updates_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f}),
make_host_tensor<element::Type_t::i64>({}, {0})}));
EXPECT_EQ(result_tensor->get_element_type(), element::f32);
EXPECT_EQ(result_tensor->get_partial_shape(), (PartialShape{3, 3}));
auto cval = read_vector<float>(result_tensor);
vector<float> out{0.f, 0.f, 0.f, 1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f};
ASSERT_EQ(cval, out);
}
TEST(eval, evaluate_dynamic_scatter_update_negative_axis)
{
const Shape data_shape{3, 3};
const Shape indices_shape{1, 2};
const Shape updates_shape{3, 1, 2};
const Shape axis_shape{};
auto arg1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto arg2 = make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
auto arg3 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto arg4 = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto scatter_update = make_shared<op::v3::ScatterUpdate>(arg1, arg2, arg3, arg4);
auto fun = make_shared<Function>(OutputVector{scatter_update},
ParameterVector{arg1, arg2, arg3, arg4});
auto result_tensor = make_shared<HostTensor>();
ASSERT_TRUE(fun->evaluate({result_tensor},
{make_host_tensor<element::Type_t::f32>(
data_shape, std::vector<float>(shape_size(data_shape))),
make_host_tensor<element::Type_t::i32>(indices_shape, {1, 2}),
make_host_tensor<element::Type_t::f32>(
updates_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f}),
make_host_tensor<element::Type_t::i64>(axis_shape, {-1})}));
EXPECT_EQ(result_tensor->get_element_type(), element::f32);
EXPECT_EQ(result_tensor->get_partial_shape(), (PartialShape{3, 3}));
auto cval = read_vector<float>(result_tensor);
vector<float> out{0.f, 1.0f, 1.1f, 0.0f, 1.2f, 2.0f, 0.0f, 2.1f, 2.2f};
ASSERT_EQ(cval, out);
}
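
In the test above, axis -1 is normalized against rank(data) = 2, so the scatter runs along axis 1 and columns 1 and 2 of each row are replaced. A simplified model of that normalization (ngraph::normalize_axis also performs bounds checking) is:

#include <cstdint>

// Negative axes count from the back: for a rank-r tensor, a < 0 maps to
// a + r. Here -1 + 2 = 1, so axis 1 (columns) is scattered.
int64_t normalize_axis_simplified(int64_t axis, int64_t rank)
{
    return axis < 0 ? axis + rank : axis;
}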
TEST(eval, evaluate_dynamic_scatter_update_1d_axis)
{
const Shape data_shape{3, 3};
const Shape indices_shape{1, 2};
const Shape updates_shape{3, 1, 2};
auto arg1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto arg2 = make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
auto arg3 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto arg4 = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto scatter_update = make_shared<op::v3::ScatterUpdate>(arg1, arg2, arg3, arg4);
auto fun = make_shared<Function>(OutputVector{scatter_update},
ParameterVector{arg1, arg2, arg3, arg4});
auto result_tensor = make_shared<HostTensor>();
ASSERT_TRUE(fun->evaluate({result_tensor},
{make_host_tensor<element::Type_t::f32>(
data_shape, std::vector<float>(shape_size(data_shape))),
make_host_tensor<element::Type_t::i32>(indices_shape, {1, 2}),
make_host_tensor<element::Type_t::f32>(
updates_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f}),
make_host_tensor<element::Type_t::i64>({1}, {1})}));
EXPECT_EQ(result_tensor->get_element_type(), element::f32);
EXPECT_EQ(result_tensor->get_partial_shape(), (PartialShape{3, 3}));
auto cval = read_vector<float>(result_tensor);
vector<float> out{0.f, 1.0f, 1.1f, 0.0f, 1.2f, 2.0f, 0.0f, 2.1f, 2.2f};
ASSERT_EQ(cval, out);
}
TEST(eval, evaluate_dynamic_scatter_update_one_elem_i32)
{
const Shape data_shape{3, 3, 2};
const Shape indices_shape{1, 1};
const Shape updates_shape{1, 1, 3, 2};
auto arg1 = make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
auto arg2 = make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
auto arg3 = make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
auto arg4 = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto scatter_update = make_shared<op::v3::ScatterUpdate>(arg1, arg2, arg3, arg4);
auto fun = make_shared<Function>(OutputVector{scatter_update},
ParameterVector{arg1, arg2, arg3, arg4});
auto result_tensor = make_shared<HostTensor>();
ASSERT_TRUE(
fun->evaluate({result_tensor},
{make_host_tensor<element::Type_t::i32>(
data_shape, std::vector<int32_t>(shape_size(data_shape))),
make_host_tensor<element::Type_t::i32>(indices_shape, {1}),
make_host_tensor<element::Type_t::i32>(updates_shape, {1, 2, 3, 4, 5, 6}),
make_host_tensor<element::Type_t::i64>({}, {0})}));
EXPECT_EQ(result_tensor->get_element_type(), element::i32);
EXPECT_EQ(result_tensor->get_partial_shape(), (PartialShape{3, 3, 2}));
auto cval = read_vector<int32_t>(result_tensor);
vector<int32_t> out{0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, 0};
ASSERT_EQ(cval, out);
}


@@ -79,7 +79,6 @@
#include "ngraph/runtime/reference/reverse_sequence.hpp"
#include "ngraph/runtime/reference/round.hpp"
#include "ngraph/runtime/reference/scatter_nd_update.hpp"
#include "ngraph/runtime/reference/scatter_update.hpp"
#include "ngraph/runtime/reference/select.hpp"
#include "ngraph/runtime/reference/sigmoid.hpp"
#include "ngraph/runtime/reference/sign.hpp"
@@ -1195,48 +1194,6 @@ protected:
break;
}
case OP_TYPEID::ScatterUpdate_v3:
{
const op::v3::ScatterUpdate* scatterUpd =
static_cast<const op::v3::ScatterUpdate*>(&node);
if (scatterUpd->get_input_element_type(3) != element::i64)
throw ngraph_error(
"ScatterNDUpdate layer support only i64 'axis' input precision!");
auto idxType = scatterUpd->get_input_element_type(1);
if (idxType == element::i32)
{
reference::scatterUpdate<T, int32_t, int64_t>(
args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const int32_t>(),
args[2]->get_data_ptr<const T>(),
args[3]->get_data_ptr<const int64_t>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_input_shape(2));
}
else if (idxType == element::i64)
{
reference::scatterUpdate<T, int64_t, int64_t>(
args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const int64_t>(),
args[2]->get_data_ptr<const T>(),
args[3]->get_data_ptr<const int64_t>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_input_shape(2));
}
else
{
throw ngraph_error(
"ScatterUpdate layer support only i32 and i64 'indices' input precision!");
}
break;
}
// Fused Ops are not supported in interpreter. They need to be decomposed before execution
case OP_TYPEID::DepthToSpace:
@@ -1255,6 +1212,7 @@ protected:
case OP_TYPEID::NormalizeL2:
case OP_TYPEID::PRelu:
case OP_TYPEID::RNNCell:
case OP_TYPEID::ScatterUpdate_v3:
case OP_TYPEID::Selu:
case OP_TYPEID::ShuffleChannels:
case OP_TYPEID::SpaceToDepth: