Shell implementation for Roll operation. (#4843)
* Added shell implementation for Roll operation. * Added test, scalar check corrected. * Code style correction. * Comment fixed. * Removed redundant virtual. * Added parentheses. * Fixed tests, added axes values check. * Fixed scalar check, added axes check. * Added comment. * Added static shape checks, added more tests. * Corrected error messages. * Corrected error messages.
This commit is contained in:
parent
568f320cbc
commit
83e5bde4ea
56
ngraph/core/include/ngraph/op/roll.hpp
Normal file
56
ngraph/core/include/ngraph/op/roll.hpp
Normal file
@ -0,0 +1,56 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/op/op.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace op
|
||||
{
|
||||
namespace v7
|
||||
{
|
||||
/// \brief Tensor roll operation.
|
||||
class NGRAPH_API Roll : public Op
|
||||
{
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
Roll() = default;
|
||||
|
||||
///
|
||||
/// \brief Constructs a roll operation.
|
||||
///
|
||||
/// \param data Node producing the tensor to be shifted.
|
||||
/// \param shift Node producing the 0D or 1D tensor which specifies the
|
||||
/// number of places by which the elements are shifted.
|
||||
/// \param axes Node producing the 0D or 1D tensor which specifies axes
|
||||
/// along which elements are shifted.
|
||||
///
|
||||
Roll(const Output<Node>& data, const Output<Node>& shift, const Output<Node>& axes);
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
|
||||
std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
};
|
||||
} // namespace v7
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
@ -137,6 +137,7 @@
|
||||
#include "ngraph/op/rnn_sequence.hpp"
|
||||
#include "ngraph/op/roi_align.hpp"
|
||||
#include "ngraph/op/roi_pooling.hpp"
|
||||
#include "ngraph/op/roll.hpp"
|
||||
#include "ngraph/op/round.hpp"
|
||||
#include "ngraph/op/scatter_elements_update.hpp"
|
||||
#include "ngraph/op/scatter_nd_update.hpp"
|
||||
|
@ -183,3 +183,4 @@ NGRAPH_OP(ReadValue, ngraph::op::v6) // new version
|
||||
|
||||
// New operations added in opset7
|
||||
NGRAPH_OP(Gelu, ngraph::op::v7)
|
||||
NGRAPH_OP(Roll, ngraph::op::v7)
|
||||
|
118
ngraph/core/src/op/roll.cpp
Normal file
118
ngraph/core/src/op/roll.cpp
Normal file
@ -0,0 +1,118 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include <ngraph/validation_util.hpp>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "ngraph/op/roll.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(op::v7::Roll, "Roll", 7);
|
||||
|
||||
op::v7::Roll::Roll(const Output<Node>& data, const Output<Node>& shift, const Output<Node>& axes)
|
||||
: Op({data, shift, axes})
|
||||
{
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
void op::v7::Roll::validate_and_infer_types()
|
||||
{
|
||||
NGRAPH_OP_SCOPE(v7_Roll_validate_and_infer_types);
|
||||
|
||||
const auto& shift_et = get_input_element_type(1);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
shift_et.is_dynamic() || shift_et == element::i32 ||
|
||||
shift_et == element::i64,
|
||||
"Shift must have int32 or int64 element type.");
|
||||
|
||||
const auto& axes_et = get_input_element_type(2);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
axes_et.is_dynamic() || axes_et == element::i32 ||
|
||||
axes_et == element::i64,
|
||||
"Axes must have int32 or int64 element type.");
|
||||
|
||||
const auto& data_pshape = get_input_partial_shape(0);
|
||||
const auto& shift_pshape = get_input_partial_shape(1);
|
||||
const auto& axes_pshape = get_input_partial_shape(2);
|
||||
|
||||
if (shift_pshape.is_static())
|
||||
{
|
||||
const auto& shift_rank = shift_pshape.rank().get_length();
|
||||
NODE_VALIDATION_CHECK(this, shift_rank <= 1, "Shift must be a scalar or 1D tensor.");
|
||||
}
|
||||
|
||||
if (axes_pshape.is_static())
|
||||
{
|
||||
const auto& axes_rank = axes_pshape.rank().get_length();
|
||||
NODE_VALIDATION_CHECK(this, axes_rank <= 1, "Axes must be a scalar or 1D tensor.");
|
||||
}
|
||||
|
||||
// If shift is a scalar, than axes can be arbitrary 1d tensor and we don't need
|
||||
// to check shift shape consistency with axes, otherwise the check is needed.
|
||||
if (!(shift_pshape.is_static() && is_scalar(shift_pshape.to_shape())))
|
||||
{
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
shift_pshape.compatible(axes_pshape),
|
||||
"If shift is a 1D vector, axes must be a 1D tensor of the same size.");
|
||||
}
|
||||
|
||||
if (const auto& const_axes = get_constant_from_source(input_value(2)))
|
||||
{
|
||||
auto axes = const_axes->cast_vector<int64_t>();
|
||||
|
||||
if (data_pshape.is_static())
|
||||
{
|
||||
const auto& data_rank = data_pshape.rank().get_length();
|
||||
for (int64_t& axis : axes)
|
||||
{
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
axis < data_rank,
|
||||
"Axes must be less than data tensor rank. Got "
|
||||
"data tensor rank: ",
|
||||
data_rank,
|
||||
", axis: ",
|
||||
axis);
|
||||
if (axis < 0)
|
||||
{
|
||||
axis += data_rank;
|
||||
}
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
axis >= 0,
|
||||
"Axes must be positive or equal to zero. Got "
|
||||
"axis: ",
|
||||
axis);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
|
||||
}
|
||||
|
||||
bool op::v7::Roll::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
NGRAPH_OP_SCOPE(v7_Roll_visit_attributes);
|
||||
return true;
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::v7::Roll::clone_with_new_inputs(const OutputVector& new_args) const
|
||||
{
|
||||
NGRAPH_OP_SCOPE(v7_Roll_clone_with_new_inputs);
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<v7::Roll>(new_args[0], new_args[1], new_args[2]);
|
||||
}
|
@ -177,6 +177,7 @@ set(SRC
|
||||
type_prop/reverse_sequence.cpp
|
||||
type_prop/roi_align.cpp
|
||||
type_prop/roi_pooling.cpp
|
||||
type_prop/roll.cpp
|
||||
type_prop/round.cpp
|
||||
type_prop/rnn_cell.cpp
|
||||
type_prop/rnn_sequence.cpp
|
||||
|
189
ngraph/test/type_prop/roll.cpp
Normal file
189
ngraph/test/type_prop/roll.cpp
Normal file
@ -0,0 +1,189 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "ngraph/opsets/opset7.hpp"
|
||||
#include "util/type_prop.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
TEST(type_prop, roll_output_shape_type_test)
|
||||
{
|
||||
auto arg = make_shared<opset7::Parameter>(element::f32, Shape{3, 3, 4, 1, 5});
|
||||
auto shift = make_shared<opset7::Parameter>(element::i32, Shape{2});
|
||||
auto axes = make_shared<opset7::Parameter>(element::i64, Shape{2});
|
||||
|
||||
auto r = make_shared<opset7::Roll>(arg, shift, axes);
|
||||
|
||||
EXPECT_EQ(r->get_output_element_type(0), element::f32);
|
||||
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape{3, 3, 4, 1, 5}));
|
||||
}
|
||||
|
||||
TEST(type_prop, roll_axis_const_test)
|
||||
{
|
||||
auto arg = make_shared<opset7::Parameter>(element::f32, Shape{3, 3, 3});
|
||||
auto shift = make_shared<opset7::Parameter>(element::i32, Shape{3});
|
||||
auto axes = opset7::Constant::create(element::i64, Shape{3}, {0, 1, -1});
|
||||
|
||||
auto r = make_shared<opset7::Roll>(arg, shift, axes);
|
||||
|
||||
EXPECT_EQ(r->get_output_element_type(0), element::f32);
|
||||
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape{3, 3, 3}));
|
||||
}
|
||||
|
||||
TEST(type_prop, roll_incorrect_axis_test)
|
||||
{
|
||||
auto arg = make_shared<opset7::Parameter>(element::f32, Shape{3, 3});
|
||||
auto shift = make_shared<opset7::Parameter>(element::i32, Shape{2});
|
||||
auto axes = opset7::Constant::create(element::i64, Shape{2}, {0, 2});
|
||||
|
||||
try
|
||||
{
|
||||
auto r = make_shared<opset7::Roll>(arg, shift, axes);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Unexpected pass with invalid axes and shift.";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Axes must be less than data tensor rank."));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, roll_incorrect_negative_axis_test)
|
||||
{
|
||||
auto arg = make_shared<opset7::Parameter>(element::f32, Shape{3, 3});
|
||||
auto shift = make_shared<opset7::Parameter>(element::i32, Shape{2});
|
||||
auto axes = opset7::Constant::create(element::i64, Shape{2}, {0, -5});
|
||||
|
||||
try
|
||||
{
|
||||
auto r = make_shared<opset7::Roll>(arg, shift, axes);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Unexpected pass with invalid axes and shift.";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Axes must be positive or equal to zero."));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, roll_axis_scalar_test)
|
||||
{
|
||||
auto arg = make_shared<opset7::Parameter>(element::i32, Shape{3, 3, 4});
|
||||
auto shift = opset7::Constant::create(element::i64, Shape{}, {5});
|
||||
auto axes = make_shared<opset7::Parameter>(element::i32, Shape{3});
|
||||
|
||||
auto r = make_shared<opset7::Roll>(arg, shift, axes);
|
||||
|
||||
EXPECT_EQ(r->get_output_element_type(0), element::i32);
|
||||
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape{3, 3, 4}));
|
||||
}
|
||||
|
||||
TEST(type_prop, roll_invalid_axes_check)
|
||||
{
|
||||
auto arg = make_shared<opset7::Parameter>(element::f32, Shape{3, 3, 4, 1, 5});
|
||||
auto shift = make_shared<opset7::Parameter>(element::i32, Shape{3});
|
||||
auto axes = make_shared<opset7::Parameter>(element::i64, Shape{1});
|
||||
|
||||
try
|
||||
{
|
||||
auto r = make_shared<opset7::Roll>(arg, shift, axes);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Unexpected pass with invalid axes and shift.";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(
|
||||
error.what(),
|
||||
std::string("If shift is a 1D vector, axes must be a 1D tensor of the same size."));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, roll_dynamic_shape)
|
||||
{
|
||||
auto arg = make_shared<opset7::Parameter>(
|
||||
element::f32, PartialShape{Dimension::dynamic(), Dimension::dynamic()});
|
||||
auto shift = make_shared<opset7::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
|
||||
auto axes = make_shared<opset7::Parameter>(element::i32, PartialShape{Dimension::dynamic()});
|
||||
|
||||
auto r = make_shared<opset7::Roll>(arg, shift, axes);
|
||||
|
||||
EXPECT_EQ(r->get_output_element_type(0), element::f32);
|
||||
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(2)));
|
||||
}
|
||||
|
||||
TEST(type_prop, roll_dynamic_ranks)
|
||||
{
|
||||
auto arg = make_shared<opset7::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto shift = make_shared<opset7::Parameter>(element::i64, PartialShape::dynamic());
|
||||
auto axes = make_shared<opset7::Parameter>(element::i32, PartialShape::dynamic());
|
||||
|
||||
auto r = make_shared<opset7::Roll>(arg, shift, axes);
|
||||
|
||||
EXPECT_EQ(r->get_output_element_type(0), element::f32);
|
||||
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
|
||||
}
|
||||
|
||||
TEST(type_prop, roll_dynamic_axes_static_shift)
|
||||
{
|
||||
auto arg = make_shared<opset7::Parameter>(element::i32, Shape{3, 3, 4, 2});
|
||||
auto shift = opset7::Constant::create(element::i64, Shape{}, {5});
|
||||
auto axes = make_shared<opset7::Parameter>(element::i32, PartialShape{Dimension::dynamic()});
|
||||
|
||||
auto r = make_shared<opset7::Roll>(arg, shift, axes);
|
||||
|
||||
EXPECT_EQ(r->get_output_element_type(0), element::i32);
|
||||
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(Shape{3, 3, 4, 2}));
|
||||
}
|
||||
|
||||
TEST(type_prop, roll_scatic_axes_dynamic_shift)
|
||||
{
|
||||
auto arg = make_shared<opset7::Parameter>(element::i32, Shape{1, 2, 4});
|
||||
auto shift = make_shared<opset7::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
|
||||
auto axes = make_shared<opset7::Parameter>(element::i32, Shape{3});
|
||||
|
||||
auto r = make_shared<opset7::Roll>(arg, shift, axes);
|
||||
|
||||
EXPECT_EQ(r->get_output_element_type(0), element::i32);
|
||||
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(Shape{1, 2, 4}));
|
||||
}
|
||||
|
||||
TEST(type_prop, roll_scatic_axes_dynamic_data)
|
||||
{
|
||||
auto arg = make_shared<opset7::Parameter>(
|
||||
element::f32, PartialShape{Dimension::dynamic(), Dimension::dynamic()});
|
||||
auto shift = opset7::Constant::create(element::i64, Shape{}, {5});
|
||||
auto axes = make_shared<opset7::Parameter>(element::i32, PartialShape{Dimension::dynamic()});
|
||||
|
||||
auto r = make_shared<opset7::Roll>(arg, shift, axes);
|
||||
|
||||
EXPECT_EQ(r->get_output_element_type(0), element::f32);
|
||||
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(2)));
|
||||
}
|
Loading…
Reference in New Issue
Block a user