Shell implementation for Roll operation. (#4843)

* Added shell implementation for Roll operation.

* Added test, scalar check corrected.

* Code style correction.

* Comment fixed.

* Removed redundant virtual.

* Added parentheses.

* Fixed tests, added axes values check.

* Fixed scalar check, added axes check.

* Added comment.

* Added static shape checks, added more tests.

* Corrected error messages.

* Corrected error messages.
This commit is contained in:
Anastasia Popova 2021-03-24 09:58:13 +03:00 committed by GitHub
parent 568f320cbc
commit 83e5bde4ea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 366 additions and 0 deletions

View File

@ -0,0 +1,56 @@
//*****************************************************************************
// Copyright 2017-2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
namespace v7
{
/// \brief Tensor roll operation.
class NGRAPH_API Roll : public Op
{
public:
NGRAPH_RTTI_DECLARATION;
Roll() = default;
///
/// \brief Constructs a roll operation.
///
/// \param data Node producing the tensor to be shifted.
/// \param shift Node producing the 0D or 1D tensor which specifies the
/// number of places by which the elements are shifted.
/// \param axes Node producing the 0D or 1D tensor which specifies axes
/// along which elements are shifted.
///
Roll(const Output<Node>& data, const Output<Node>& shift, const Output<Node>& axes);
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
};
} // namespace v7
} // namespace op
} // namespace ngraph

View File

@ -137,6 +137,7 @@
#include "ngraph/op/rnn_sequence.hpp"
#include "ngraph/op/roi_align.hpp"
#include "ngraph/op/roi_pooling.hpp"
#include "ngraph/op/roll.hpp"
#include "ngraph/op/round.hpp"
#include "ngraph/op/scatter_elements_update.hpp"
#include "ngraph/op/scatter_nd_update.hpp"

View File

@ -183,3 +183,4 @@ NGRAPH_OP(ReadValue, ngraph::op::v6) // new version
// New operations added in opset7
NGRAPH_OP(Gelu, ngraph::op::v7)
NGRAPH_OP(Roll, ngraph::op::v7)

118
ngraph/core/src/op/roll.cpp Normal file
View File

@ -0,0 +1,118 @@
//*****************************************************************************
// Copyright 2017-2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <ngraph/validation_util.hpp>
#include "itt.hpp"
#include "ngraph/op/roll.hpp"
using namespace std;
using namespace ngraph;
NGRAPH_RTTI_DEFINITION(op::v7::Roll, "Roll", 7);
op::v7::Roll::Roll(const Output<Node>& data, const Output<Node>& shift, const Output<Node>& axes)
: Op({data, shift, axes})
{
constructor_validate_and_infer_types();
}
void op::v7::Roll::validate_and_infer_types()
{
NGRAPH_OP_SCOPE(v7_Roll_validate_and_infer_types);
const auto& shift_et = get_input_element_type(1);
NODE_VALIDATION_CHECK(this,
shift_et.is_dynamic() || shift_et == element::i32 ||
shift_et == element::i64,
"Shift must have int32 or int64 element type.");
const auto& axes_et = get_input_element_type(2);
NODE_VALIDATION_CHECK(this,
axes_et.is_dynamic() || axes_et == element::i32 ||
axes_et == element::i64,
"Axes must have int32 or int64 element type.");
const auto& data_pshape = get_input_partial_shape(0);
const auto& shift_pshape = get_input_partial_shape(1);
const auto& axes_pshape = get_input_partial_shape(2);
if (shift_pshape.is_static())
{
const auto& shift_rank = shift_pshape.rank().get_length();
NODE_VALIDATION_CHECK(this, shift_rank <= 1, "Shift must be a scalar or 1D tensor.");
}
if (axes_pshape.is_static())
{
const auto& axes_rank = axes_pshape.rank().get_length();
NODE_VALIDATION_CHECK(this, axes_rank <= 1, "Axes must be a scalar or 1D tensor.");
}
// If shift is a scalar, than axes can be arbitrary 1d tensor and we don't need
// to check shift shape consistency with axes, otherwise the check is needed.
if (!(shift_pshape.is_static() && is_scalar(shift_pshape.to_shape())))
{
NODE_VALIDATION_CHECK(
this,
shift_pshape.compatible(axes_pshape),
"If shift is a 1D vector, axes must be a 1D tensor of the same size.");
}
if (const auto& const_axes = get_constant_from_source(input_value(2)))
{
auto axes = const_axes->cast_vector<int64_t>();
if (data_pshape.is_static())
{
const auto& data_rank = data_pshape.rank().get_length();
for (int64_t& axis : axes)
{
NODE_VALIDATION_CHECK(this,
axis < data_rank,
"Axes must be less than data tensor rank. Got "
"data tensor rank: ",
data_rank,
", axis: ",
axis);
if (axis < 0)
{
axis += data_rank;
}
NODE_VALIDATION_CHECK(this,
axis >= 0,
"Axes must be positive or equal to zero. Got "
"axis: ",
axis);
}
}
}
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
}
bool op::v7::Roll::visit_attributes(AttributeVisitor& visitor)
{
NGRAPH_OP_SCOPE(v7_Roll_visit_attributes);
return true;
}
shared_ptr<Node> op::v7::Roll::clone_with_new_inputs(const OutputVector& new_args) const
{
NGRAPH_OP_SCOPE(v7_Roll_clone_with_new_inputs);
check_new_args_count(this, new_args);
return make_shared<v7::Roll>(new_args[0], new_args[1], new_args[2]);
}

View File

@ -177,6 +177,7 @@ set(SRC
type_prop/reverse_sequence.cpp
type_prop/roi_align.cpp
type_prop/roi_pooling.cpp
type_prop/roll.cpp
type_prop/round.cpp
type_prop/rnn_cell.cpp
type_prop/rnn_sequence.cpp

View File

@ -0,0 +1,189 @@
//*****************************************************************************
// Copyright 2017-2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/opsets/opset7.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
TEST(type_prop, roll_output_shape_type_test)
{
auto arg = make_shared<opset7::Parameter>(element::f32, Shape{3, 3, 4, 1, 5});
auto shift = make_shared<opset7::Parameter>(element::i32, Shape{2});
auto axes = make_shared<opset7::Parameter>(element::i64, Shape{2});
auto r = make_shared<opset7::Roll>(arg, shift, axes);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape{3, 3, 4, 1, 5}));
}
TEST(type_prop, roll_axis_const_test)
{
auto arg = make_shared<opset7::Parameter>(element::f32, Shape{3, 3, 3});
auto shift = make_shared<opset7::Parameter>(element::i32, Shape{3});
auto axes = opset7::Constant::create(element::i64, Shape{3}, {0, 1, -1});
auto r = make_shared<opset7::Roll>(arg, shift, axes);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape{3, 3, 3}));
}
TEST(type_prop, roll_incorrect_axis_test)
{
auto arg = make_shared<opset7::Parameter>(element::f32, Shape{3, 3});
auto shift = make_shared<opset7::Parameter>(element::i32, Shape{2});
auto axes = opset7::Constant::create(element::i64, Shape{2}, {0, 2});
try
{
auto r = make_shared<opset7::Roll>(arg, shift, axes);
// Should have thrown, so fail if it didn't
FAIL() << "Unexpected pass with invalid axes and shift.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Axes must be less than data tensor rank."));
}
catch (...)
{
FAIL() << "Check failed for unexpected reason";
}
}
TEST(type_prop, roll_incorrect_negative_axis_test)
{
auto arg = make_shared<opset7::Parameter>(element::f32, Shape{3, 3});
auto shift = make_shared<opset7::Parameter>(element::i32, Shape{2});
auto axes = opset7::Constant::create(element::i64, Shape{2}, {0, -5});
try
{
auto r = make_shared<opset7::Roll>(arg, shift, axes);
// Should have thrown, so fail if it didn't
FAIL() << "Unexpected pass with invalid axes and shift.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Axes must be positive or equal to zero."));
}
catch (...)
{
FAIL() << "Check failed for unexpected reason";
}
}
TEST(type_prop, roll_axis_scalar_test)
{
auto arg = make_shared<opset7::Parameter>(element::i32, Shape{3, 3, 4});
auto shift = opset7::Constant::create(element::i64, Shape{}, {5});
auto axes = make_shared<opset7::Parameter>(element::i32, Shape{3});
auto r = make_shared<opset7::Roll>(arg, shift, axes);
EXPECT_EQ(r->get_output_element_type(0), element::i32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape{3, 3, 4}));
}
TEST(type_prop, roll_invalid_axes_check)
{
auto arg = make_shared<opset7::Parameter>(element::f32, Shape{3, 3, 4, 1, 5});
auto shift = make_shared<opset7::Parameter>(element::i32, Shape{3});
auto axes = make_shared<opset7::Parameter>(element::i64, Shape{1});
try
{
auto r = make_shared<opset7::Roll>(arg, shift, axes);
// Should have thrown, so fail if it didn't
FAIL() << "Unexpected pass with invalid axes and shift.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("If shift is a 1D vector, axes must be a 1D tensor of the same size."));
}
catch (...)
{
FAIL() << "Check failed for unexpected reason";
}
}
TEST(type_prop, roll_dynamic_shape)
{
auto arg = make_shared<opset7::Parameter>(
element::f32, PartialShape{Dimension::dynamic(), Dimension::dynamic()});
auto shift = make_shared<opset7::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto axes = make_shared<opset7::Parameter>(element::i32, PartialShape{Dimension::dynamic()});
auto r = make_shared<opset7::Roll>(arg, shift, axes);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(2)));
}
TEST(type_prop, roll_dynamic_ranks)
{
auto arg = make_shared<opset7::Parameter>(element::f32, PartialShape::dynamic());
auto shift = make_shared<opset7::Parameter>(element::i64, PartialShape::dynamic());
auto axes = make_shared<opset7::Parameter>(element::i32, PartialShape::dynamic());
auto r = make_shared<opset7::Roll>(arg, shift, axes);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
}
TEST(type_prop, roll_dynamic_axes_static_shift)
{
auto arg = make_shared<opset7::Parameter>(element::i32, Shape{3, 3, 4, 2});
auto shift = opset7::Constant::create(element::i64, Shape{}, {5});
auto axes = make_shared<opset7::Parameter>(element::i32, PartialShape{Dimension::dynamic()});
auto r = make_shared<opset7::Roll>(arg, shift, axes);
EXPECT_EQ(r->get_output_element_type(0), element::i32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(Shape{3, 3, 4, 2}));
}
TEST(type_prop, roll_scatic_axes_dynamic_shift)
{
auto arg = make_shared<opset7::Parameter>(element::i32, Shape{1, 2, 4});
auto shift = make_shared<opset7::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto axes = make_shared<opset7::Parameter>(element::i32, Shape{3});
auto r = make_shared<opset7::Roll>(arg, shift, axes);
EXPECT_EQ(r->get_output_element_type(0), element::i32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(Shape{1, 2, 4}));
}
TEST(type_prop, roll_scatic_axes_dynamic_data)
{
auto arg = make_shared<opset7::Parameter>(
element::f32, PartialShape{Dimension::dynamic(), Dimension::dynamic()});
auto shift = opset7::Constant::create(element::i64, Shape{}, {5});
auto axes = make_shared<opset7::Parameter>(element::i32, PartialShape{Dimension::dynamic()});
auto r = make_shared<opset7::Roll>(arg, shift, axes);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(2)));
}