From b68166fc3cb97116dc5685539f852d5174601b21 Mon Sep 17 00:00:00 2001 From: Yuan Hu Date: Tue, 22 Jun 2021 16:05:26 +0800 Subject: [PATCH] Revise Round operation reference implementation (#6287) * Revise OP Round improve the input type error check add attribute test case Signed-off-by: Hu, Yuan2 * fix clang code style issue Signed-off-by: Hu, Yuan2 --- ngraph/core/src/op/round.cpp | 2 ++ ngraph/test/CMakeLists.txt | 1 + ngraph/test/backend/round.in.cpp | 2 +- ngraph/test/visitors/op/round.cpp | 42 +++++++++++++++++++++++++++++++ 4 files changed, 46 insertions(+), 1 deletion(-) create mode 100644 ngraph/test/visitors/op/round.cpp diff --git a/ngraph/core/src/op/round.cpp b/ngraph/core/src/op/round.cpp index dde7a19d81b..83ed1fff66b 100644 --- a/ngraph/core/src/op/round.cpp +++ b/ngraph/core/src/op/round.cpp @@ -83,6 +83,8 @@ bool ngraph::op::v5::Round::visit_attributes(AttributeVisitor& visitor) void op::v5::Round::validate_and_infer_types() { NGRAPH_OP_SCOPE(v5_Round_validate_and_infer_types); + NODE_VALIDATION_CHECK( + this, get_input_size() == 1, "Only accepts one argument. Got: ", get_input_size()); set_output_size(1); set_output_type(0, get_input_element_type(0), get_input_partial_shape(0)); } diff --git a/ngraph/test/CMakeLists.txt b/ngraph/test/CMakeLists.txt index 6f4e2bb0db4..e25b6410ba0 100644 --- a/ngraph/test/CMakeLists.txt +++ b/ngraph/test/CMakeLists.txt @@ -284,6 +284,7 @@ set(SRC visitors/op/reverse_sequence.cpp visitors/op/rnn_cell.cpp visitors/op/roi_pooling.cpp + visitors/op/round.cpp visitors/op/selu.cpp visitors/op/shuffle_channels.cpp visitors/op/softmax.cpp diff --git a/ngraph/test/backend/round.in.cpp b/ngraph/test/backend/round.in.cpp index c58ecc7d8dc..d8fa657b010 100644 --- a/ngraph/test/backend/round.in.cpp +++ b/ngraph/test/backend/round.in.cpp @@ -19,7 +19,7 @@ using namespace ngraph; static string s_manifest = "${MANIFEST}"; -NGRAPH_TEST(${BACKEND_NAME}, round) +NGRAPH_TEST(${BACKEND_NAME}, round_half_to_even) { Shape shape{5}; auto A = make_shared(element::f32, shape); diff --git a/ngraph/test/visitors/op/round.cpp b/ngraph/test/visitors/op/round.cpp new file mode 100644 index 00000000000..159d07edd33 --- /dev/null +++ b/ngraph/test/visitors/op/round.cpp @@ -0,0 +1,42 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "gtest/gtest.h" + +#include "ngraph/ngraph.hpp" +#include "ngraph/op/util/attr_types.hpp" +#include "ngraph/opsets/opset1.hpp" +#include "ngraph/opsets/opset3.hpp" +#include "ngraph/opsets/opset4.hpp" +#include "ngraph/opsets/opset5.hpp" + +#include "util/visitor.hpp" + +using namespace std; +using namespace ngraph; +using ngraph::test::NodeBuilder; +using ngraph::test::ValueMap; + +void static test_mode(opset5::Round::RoundMode mode) +{ + NodeBuilder::get_ops().register_factory(); + auto data = make_shared(element::f32, Shape{200}); + auto round = make_shared(data, opset5::Round::RoundMode::HALF_TO_EVEN); + NodeBuilder builder(round); + auto g_round = as_type_ptr(builder.create()); + + EXPECT_EQ(g_round->get_mode(), round->get_mode()); +} + +TEST(attributes, round_op_enum_mode_half_to_even) +{ + test_mode(opset5::Round::RoundMode::HALF_TO_EVEN); +} + +TEST(attributes, round_op_enum_mode_half_away_from_zero) +{ + test_mode(opset5::Round::RoundMode::HALF_AWAY_FROM_ZERO); +} + +