Revise Round operation reference implementation (#6287)

* Revise OP Round

improve the input type error check
add attribute test case

Signed-off-by: Hu, Yuan2 <yuan2.hu@intel.com>

* fix clang code style issue

Signed-off-by: Hu, Yuan2 <yuan2.hu@intel.com>
This commit is contained in:
Yuan Hu 2021-06-22 16:05:26 +08:00 committed by GitHub
parent 5ce5f9e0c8
commit b68166fc3c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 46 additions and 1 deletions

View File

@ -83,6 +83,8 @@ bool ngraph::op::v5::Round::visit_attributes(AttributeVisitor& visitor)
void op::v5::Round::validate_and_infer_types()
{
NGRAPH_OP_SCOPE(v5_Round_validate_and_infer_types);
NODE_VALIDATION_CHECK(
this, get_input_size() == 1, "Only accepts one argument. Got: ", get_input_size());
set_output_size(1);
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
}

View File

@ -284,6 +284,7 @@ set(SRC
visitors/op/reverse_sequence.cpp
visitors/op/rnn_cell.cpp
visitors/op/roi_pooling.cpp
visitors/op/round.cpp
visitors/op/selu.cpp
visitors/op/shuffle_channels.cpp
visitors/op/softmax.cpp

View File

@ -19,7 +19,7 @@ using namespace ngraph;
static string s_manifest = "${MANIFEST}";
NGRAPH_TEST(${BACKEND_NAME}, round)
NGRAPH_TEST(${BACKEND_NAME}, round_half_to_even)
{
Shape shape{5};
auto A = make_shared<op::Parameter>(element::f32, shape);

View File

@ -0,0 +1,42 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/opsets/opset1.hpp"
#include "ngraph/opsets/opset3.hpp"
#include "ngraph/opsets/opset4.hpp"
#include "ngraph/opsets/opset5.hpp"
#include "util/visitor.hpp"
using namespace std;
using namespace ngraph;
using ngraph::test::NodeBuilder;
using ngraph::test::ValueMap;
void static test_mode(opset5::Round::RoundMode mode)
{
NodeBuilder::get_ops().register_factory<opset5::Round>();
auto data = make_shared<op::Parameter>(element::f32, Shape{200});
auto round = make_shared<opset5::Round>(data, opset5::Round::RoundMode::HALF_TO_EVEN);
NodeBuilder builder(round);
auto g_round = as_type_ptr<opset5::Round>(builder.create());
EXPECT_EQ(g_round->get_mode(), round->get_mode());
}
TEST(attributes, round_op_enum_mode_half_to_even)
{
test_mode(opset5::Round::RoundMode::HALF_TO_EVEN);
}
TEST(attributes, round_op_enum_mode_half_away_from_zero)
{
test_mode(opset5::Round::RoundMode::HALF_AWAY_FROM_ZERO);
}