Revise Round operation reference implementation (#6287)
* Revise OP Round improve the input type error check add attribute test case Signed-off-by: Hu, Yuan2 <yuan2.hu@intel.com> * fix clang code style issue Signed-off-by: Hu, Yuan2 <yuan2.hu@intel.com>
This commit is contained in:
parent
5ce5f9e0c8
commit
b68166fc3c
@ -83,6 +83,8 @@ bool ngraph::op::v5::Round::visit_attributes(AttributeVisitor& visitor)
|
|||||||
void op::v5::Round::validate_and_infer_types()
|
void op::v5::Round::validate_and_infer_types()
|
||||||
{
|
{
|
||||||
NGRAPH_OP_SCOPE(v5_Round_validate_and_infer_types);
|
NGRAPH_OP_SCOPE(v5_Round_validate_and_infer_types);
|
||||||
|
NODE_VALIDATION_CHECK(
|
||||||
|
this, get_input_size() == 1, "Only accepts one argument. Got: ", get_input_size());
|
||||||
set_output_size(1);
|
set_output_size(1);
|
||||||
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
|
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
|
||||||
}
|
}
|
||||||
|
@ -284,6 +284,7 @@ set(SRC
|
|||||||
visitors/op/reverse_sequence.cpp
|
visitors/op/reverse_sequence.cpp
|
||||||
visitors/op/rnn_cell.cpp
|
visitors/op/rnn_cell.cpp
|
||||||
visitors/op/roi_pooling.cpp
|
visitors/op/roi_pooling.cpp
|
||||||
|
visitors/op/round.cpp
|
||||||
visitors/op/selu.cpp
|
visitors/op/selu.cpp
|
||||||
visitors/op/shuffle_channels.cpp
|
visitors/op/shuffle_channels.cpp
|
||||||
visitors/op/softmax.cpp
|
visitors/op/softmax.cpp
|
||||||
|
@ -19,7 +19,7 @@ using namespace ngraph;
|
|||||||
|
|
||||||
static string s_manifest = "${MANIFEST}";
|
static string s_manifest = "${MANIFEST}";
|
||||||
|
|
||||||
NGRAPH_TEST(${BACKEND_NAME}, round)
|
NGRAPH_TEST(${BACKEND_NAME}, round_half_to_even)
|
||||||
{
|
{
|
||||||
Shape shape{5};
|
Shape shape{5};
|
||||||
auto A = make_shared<op::Parameter>(element::f32, shape);
|
auto A = make_shared<op::Parameter>(element::f32, shape);
|
||||||
|
42
ngraph/test/visitors/op/round.cpp
Normal file
42
ngraph/test/visitors/op/round.cpp
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "gtest/gtest.h"
|
||||||
|
|
||||||
|
#include "ngraph/ngraph.hpp"
|
||||||
|
#include "ngraph/op/util/attr_types.hpp"
|
||||||
|
#include "ngraph/opsets/opset1.hpp"
|
||||||
|
#include "ngraph/opsets/opset3.hpp"
|
||||||
|
#include "ngraph/opsets/opset4.hpp"
|
||||||
|
#include "ngraph/opsets/opset5.hpp"
|
||||||
|
|
||||||
|
#include "util/visitor.hpp"
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
using namespace ngraph;
|
||||||
|
using ngraph::test::NodeBuilder;
|
||||||
|
using ngraph::test::ValueMap;
|
||||||
|
|
||||||
|
void static test_mode(opset5::Round::RoundMode mode)
|
||||||
|
{
|
||||||
|
NodeBuilder::get_ops().register_factory<opset5::Round>();
|
||||||
|
auto data = make_shared<op::Parameter>(element::f32, Shape{200});
|
||||||
|
auto round = make_shared<opset5::Round>(data, opset5::Round::RoundMode::HALF_TO_EVEN);
|
||||||
|
NodeBuilder builder(round);
|
||||||
|
auto g_round = as_type_ptr<opset5::Round>(builder.create());
|
||||||
|
|
||||||
|
EXPECT_EQ(g_round->get_mode(), round->get_mode());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(attributes, round_op_enum_mode_half_to_even)
|
||||||
|
{
|
||||||
|
test_mode(opset5::Round::RoundMode::HALF_TO_EVEN);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(attributes, round_op_enum_mode_half_away_from_zero)
|
||||||
|
{
|
||||||
|
test_mode(opset5::Round::RoundMode::HALF_AWAY_FROM_ZERO);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user