Revise Round operation reference implementation (#6287)
* Revise OP Round improve the input type error check add attribute test case Signed-off-by: Hu, Yuan2 <yuan2.hu@intel.com> * fix clang code style issue Signed-off-by: Hu, Yuan2 <yuan2.hu@intel.com>
This commit is contained in:
parent
5ce5f9e0c8
commit
b68166fc3c
@ -83,6 +83,8 @@ bool ngraph::op::v5::Round::visit_attributes(AttributeVisitor& visitor)
|
||||
void op::v5::Round::validate_and_infer_types()
|
||||
{
|
||||
NGRAPH_OP_SCOPE(v5_Round_validate_and_infer_types);
|
||||
NODE_VALIDATION_CHECK(
|
||||
this, get_input_size() == 1, "Only accepts one argument. Got: ", get_input_size());
|
||||
set_output_size(1);
|
||||
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
|
||||
}
|
||||
|
@ -284,6 +284,7 @@ set(SRC
|
||||
visitors/op/reverse_sequence.cpp
|
||||
visitors/op/rnn_cell.cpp
|
||||
visitors/op/roi_pooling.cpp
|
||||
visitors/op/round.cpp
|
||||
visitors/op/selu.cpp
|
||||
visitors/op/shuffle_channels.cpp
|
||||
visitors/op/softmax.cpp
|
||||
|
@ -19,7 +19,7 @@ using namespace ngraph;
|
||||
|
||||
static string s_manifest = "${MANIFEST}";
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, round)
|
||||
NGRAPH_TEST(${BACKEND_NAME}, round_half_to_even)
|
||||
{
|
||||
Shape shape{5};
|
||||
auto A = make_shared<op::Parameter>(element::f32, shape);
|
||||
|
42
ngraph/test/visitors/op/round.cpp
Normal file
42
ngraph/test/visitors/op/round.cpp
Normal file
@ -0,0 +1,42 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "ngraph/op/util/attr_types.hpp"
|
||||
#include "ngraph/opsets/opset1.hpp"
|
||||
#include "ngraph/opsets/opset3.hpp"
|
||||
#include "ngraph/opsets/opset4.hpp"
|
||||
#include "ngraph/opsets/opset5.hpp"
|
||||
|
||||
#include "util/visitor.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
using ngraph::test::NodeBuilder;
|
||||
using ngraph::test::ValueMap;
|
||||
|
||||
void static test_mode(opset5::Round::RoundMode mode)
|
||||
{
|
||||
NodeBuilder::get_ops().register_factory<opset5::Round>();
|
||||
auto data = make_shared<op::Parameter>(element::f32, Shape{200});
|
||||
auto round = make_shared<opset5::Round>(data, opset5::Round::RoundMode::HALF_TO_EVEN);
|
||||
NodeBuilder builder(round);
|
||||
auto g_round = as_type_ptr<opset5::Round>(builder.create());
|
||||
|
||||
EXPECT_EQ(g_round->get_mode(), round->get_mode());
|
||||
}
|
||||
|
||||
TEST(attributes, round_op_enum_mode_half_to_even)
|
||||
{
|
||||
test_mode(opset5::Round::RoundMode::HALF_TO_EVEN);
|
||||
}
|
||||
|
||||
TEST(attributes, round_op_enum_mode_half_away_from_zero)
|
||||
{
|
||||
test_mode(opset5::Round::RoundMode::HALF_AWAY_FROM_ZERO);
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user