diff --git a/ngraph/core/include/ngraph/op/broadcast.hpp b/ngraph/core/include/ngraph/op/broadcast.hpp index 665668bcab6..838115084a0 100644 --- a/ngraph/core/include/ngraph/op/broadcast.hpp +++ b/ngraph/core/include/ngraph/op/broadcast.hpp @@ -20,8 +20,8 @@ namespace ngraph class NGRAPH_API Broadcast : public util::BroadcastBase { public: - static constexpr NodeTypeInfo type_info{"Broadcast", 3}; - const NodeTypeInfo& get_type_info() const override { return type_info; } + NGRAPH_RTTI_DECLARATION; + /// \brief Constructs a broadcast operation. Broadcast() = default; /// \brief Constructs a broadcast operation. diff --git a/ngraph/core/include/ngraph/op/bucketize.hpp b/ngraph/core/include/ngraph/op/bucketize.hpp index f45cbd4746e..6f6145b7972 100644 --- a/ngraph/core/include/ngraph/op/bucketize.hpp +++ b/ngraph/core/include/ngraph/op/bucketize.hpp @@ -16,8 +16,8 @@ namespace ngraph class NGRAPH_API Bucketize : public Op { public: - static constexpr NodeTypeInfo type_info{"Bucketize", 3}; - const NodeTypeInfo& get_type_info() const override { return type_info; } + NGRAPH_RTTI_DECLARATION; + Bucketize() = default; /// \brief Constructs a Bucketize node diff --git a/ngraph/core/include/ngraph/op/cum_sum.hpp b/ngraph/core/include/ngraph/op/cum_sum.hpp index 72cfe892256..120d134300e 100644 --- a/ngraph/core/include/ngraph/op/cum_sum.hpp +++ b/ngraph/core/include/ngraph/op/cum_sum.hpp @@ -60,8 +60,8 @@ namespace ngraph class NGRAPH_API CumSum : public Op { public: - static constexpr NodeTypeInfo type_info{"CumSum", 0}; - const NodeTypeInfo& get_type_info() const override { return type_info; } + NGRAPH_RTTI_DECLARATION; + /// \brief Constructs a cumulative summation operation. CumSum() = default; diff --git a/ngraph/core/include/ngraph/op/floor_mod.hpp b/ngraph/core/include/ngraph/op/floor_mod.hpp index 9cc1d25557d..ba8af70fcc4 100644 --- a/ngraph/core/include/ngraph/op/floor_mod.hpp +++ b/ngraph/core/include/ngraph/op/floor_mod.hpp @@ -19,8 +19,8 @@ namespace ngraph class NGRAPH_API FloorMod : public util::BinaryElementwiseArithmetic { public: - static constexpr NodeTypeInfo type_info{"FloorMod", 1}; - const NodeTypeInfo& get_type_info() const override { return type_info; } + NGRAPH_RTTI_DECLARATION; + /// \brief Constructs an uninitialized addition operation FloorMod() : util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY){}; diff --git a/ngraph/core/include/ngraph/op/maximum.hpp b/ngraph/core/include/ngraph/op/maximum.hpp index 154801336e0..d4135285600 100644 --- a/ngraph/core/include/ngraph/op/maximum.hpp +++ b/ngraph/core/include/ngraph/op/maximum.hpp @@ -16,8 +16,8 @@ namespace ngraph class NGRAPH_API Maximum : public util::BinaryElementwiseArithmetic { public: - static constexpr NodeTypeInfo type_info{"Maximum", 1}; - const NodeTypeInfo& get_type_info() const override { return type_info; } + NGRAPH_RTTI_DECLARATION; + /// \brief Constructs a maximum operation. Maximum() : util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY) diff --git a/ngraph/core/include/ngraph/op/minimum.hpp b/ngraph/core/include/ngraph/op/minimum.hpp index 33ef1395e93..edf19276218 100644 --- a/ngraph/core/include/ngraph/op/minimum.hpp +++ b/ngraph/core/include/ngraph/op/minimum.hpp @@ -16,8 +16,8 @@ namespace ngraph class NGRAPH_API Minimum : public util::BinaryElementwiseArithmetic { public: - static constexpr NodeTypeInfo type_info{"Minimum", 1}; - const NodeTypeInfo& get_type_info() const override { return type_info; } + NGRAPH_RTTI_DECLARATION; + /// \brief Constructs a minimum operation. Minimum() : util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY) diff --git a/ngraph/core/include/ngraph/op/mod.hpp b/ngraph/core/include/ngraph/op/mod.hpp index 4b9851be6b6..50d351c3358 100644 --- a/ngraph/core/include/ngraph/op/mod.hpp +++ b/ngraph/core/include/ngraph/op/mod.hpp @@ -17,8 +17,8 @@ namespace ngraph class NGRAPH_API Mod : public util::BinaryElementwiseArithmetic { public: - static constexpr NodeTypeInfo type_info{"Mod", 0}; - const NodeTypeInfo& get_type_info() const override { return type_info; } + NGRAPH_RTTI_DECLARATION; + /// \brief Constructs a Mod node. Mod() : util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY) diff --git a/ngraph/core/include/ngraph/op/power.hpp b/ngraph/core/include/ngraph/op/power.hpp index 0d800f538e5..3fbd29a2994 100644 --- a/ngraph/core/include/ngraph/op/power.hpp +++ b/ngraph/core/include/ngraph/op/power.hpp @@ -31,8 +31,8 @@ namespace ngraph class NGRAPH_API Power : public util::BinaryElementwiseArithmetic { public: - static constexpr NodeTypeInfo type_info{"Power", 1}; - const NodeTypeInfo& get_type_info() const override { return type_info; } + NGRAPH_RTTI_DECLARATION; + Power() : util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY) { diff --git a/ngraph/core/include/ngraph/op/roi_pooling.hpp b/ngraph/core/include/ngraph/op/roi_pooling.hpp index 2fcf3443639..d1dc8fbd5e6 100644 --- a/ngraph/core/include/ngraph/op/roi_pooling.hpp +++ b/ngraph/core/include/ngraph/op/roi_pooling.hpp @@ -15,8 +15,8 @@ namespace ngraph class NGRAPH_API ROIPooling : public Op { public: - static constexpr NodeTypeInfo type_info{"ROIPooling", 0}; - const NodeTypeInfo& get_type_info() const override { return type_info; } + NGRAPH_RTTI_DECLARATION; + ROIPooling() = default; /// \brief Constructs a ROIPooling operation /// diff --git a/ngraph/core/include/ngraph/op/tanh.hpp b/ngraph/core/include/ngraph/op/tanh.hpp index b67849ed20c..61dec52f213 100644 --- a/ngraph/core/include/ngraph/op/tanh.hpp +++ b/ngraph/core/include/ngraph/op/tanh.hpp @@ -16,8 +16,8 @@ namespace ngraph class NGRAPH_API Tanh : public util::UnaryElementwiseArithmetic { public: - static constexpr NodeTypeInfo type_info{"Tanh", 0}; - const NodeTypeInfo& get_type_info() const override { return type_info; } + NGRAPH_RTTI_DECLARATION; + /// \brief Constructs a hyperbolic tangent operation. /// /// \param arg Node that produces the input tensor. diff --git a/ngraph/core/include/ngraph/op/util/binary_elementwise_comparison.hpp b/ngraph/core/include/ngraph/op/util/binary_elementwise_comparison.hpp index 64ca2502b41..e94a3c5a6b4 100644 --- a/ngraph/core/include/ngraph/op/util/binary_elementwise_comparison.hpp +++ b/ngraph/core/include/ngraph/op/util/binary_elementwise_comparison.hpp @@ -55,6 +55,8 @@ namespace ngraph const AutoBroadcastSpec& autob = AutoBroadcastSpec()); public: + NGRAPH_RTTI_DECLARATION; + void validate_and_infer_types() override; const AutoBroadcastSpec& get_autob() const override { return m_autob; } diff --git a/ngraph/core/src/op/broadcast.cpp b/ngraph/core/src/op/broadcast.cpp index 81993720203..1e775be5fa3 100644 --- a/ngraph/core/src/op/broadcast.cpp +++ b/ngraph/core/src/op/broadcast.cpp @@ -17,7 +17,7 @@ using namespace std; using namespace ngraph; -constexpr NodeTypeInfo op::v3::Broadcast::type_info; +NGRAPH_RTTI_DEFINITION(op::v3::Broadcast, "Broadcast", 3, op::util::BroadcastBase); op::v3::Broadcast::Broadcast(const Output& arg, const Output& target_shape, diff --git a/ngraph/core/src/op/bucketize.cpp b/ngraph/core/src/op/bucketize.cpp index 4ee5270bab2..fbe7f3ae7f3 100644 --- a/ngraph/core/src/op/bucketize.cpp +++ b/ngraph/core/src/op/bucketize.cpp @@ -8,7 +8,7 @@ using namespace ngraph; using namespace std; -constexpr NodeTypeInfo op::v3::Bucketize::type_info; +NGRAPH_RTTI_DEFINITION(op::v3::Bucketize, "Bucketize", 3); op::v3::Bucketize::Bucketize(const Output& data, const Output& buckets, diff --git a/ngraph/core/src/op/cum_sum.cpp b/ngraph/core/src/op/cum_sum.cpp index f8a7286eb86..00ad8c631ab 100644 --- a/ngraph/core/src/op/cum_sum.cpp +++ b/ngraph/core/src/op/cum_sum.cpp @@ -12,7 +12,7 @@ using namespace std; using namespace ngraph; -constexpr NodeTypeInfo op::v0::CumSum::type_info; +NGRAPH_RTTI_DEFINITION(op::v0::CumSum, "CumSum", 0); op::v0::CumSum::CumSum(const Output& arg, const Output& axis, diff --git a/ngraph/core/src/op/equal.cpp b/ngraph/core/src/op/equal.cpp index d70abc1537c..e9f8b57ec55 100644 --- a/ngraph/core/src/op/equal.cpp +++ b/ngraph/core/src/op/equal.cpp @@ -51,7 +51,7 @@ namespace equal //------------------------------- v1 ------------------------------------------- -NGRAPH_RTTI_DEFINITION(op::v1::Equal, "Equal", 1); +NGRAPH_RTTI_DEFINITION(op::v1::Equal, "Equal", 1, op::util::BinaryElementwiseComparison); op::v1::Equal::Equal(const Output& arg0, const Output& arg1, @@ -94,5 +94,6 @@ bool op::v1::Equal::has_evaluate() const bool op::v1::Equal::visit_attributes(AttributeVisitor& visitor) { NGRAPH_OP_SCOPE(v1_Equal_visit_attributes); + BinaryElementwiseComparison::visit_attributes(visitor); return true; } diff --git a/ngraph/core/src/op/floor_mod.cpp b/ngraph/core/src/op/floor_mod.cpp index 3ccb7a29524..23b84098146 100644 --- a/ngraph/core/src/op/floor_mod.cpp +++ b/ngraph/core/src/op/floor_mod.cpp @@ -10,7 +10,7 @@ using namespace std; using namespace ngraph; -constexpr NodeTypeInfo op::v1::FloorMod::type_info; +NGRAPH_RTTI_DEFINITION(op::v1::FloorMod, "FloorMod", 1, op::util::BinaryElementwiseArithmetic); op::v1::FloorMod::FloorMod(const Output& arg0, const Output& arg1, @@ -97,5 +97,6 @@ bool op::v1::FloorMod::has_evaluate() const bool op::v1::FloorMod::visit_attributes(AttributeVisitor& visitor) { NGRAPH_OP_SCOPE(v1_FloorMod_visit_attributes); + BinaryElementwiseArithmetic::visit_attributes(visitor); return true; } diff --git a/ngraph/core/src/op/greater.cpp b/ngraph/core/src/op/greater.cpp index bbc28493346..fc3333472fb 100644 --- a/ngraph/core/src/op/greater.cpp +++ b/ngraph/core/src/op/greater.cpp @@ -51,7 +51,7 @@ namespace greaterop //-------------------------------------- v1 ------------------------------------ -NGRAPH_RTTI_DEFINITION(op::v1::Greater, "Greater", 1); +NGRAPH_RTTI_DEFINITION(op::v1::Greater, "Greater", 1, op::util::BinaryElementwiseComparison); op::v1::Greater::Greater(const Output& arg0, const Output& arg1, diff --git a/ngraph/core/src/op/greater_eq.cpp b/ngraph/core/src/op/greater_eq.cpp index 3db1d4155a5..11c099dcc5d 100644 --- a/ngraph/core/src/op/greater_eq.cpp +++ b/ngraph/core/src/op/greater_eq.cpp @@ -51,7 +51,10 @@ namespace greater_equalop //---------------------------------- v1 ---------------------------------------- -NGRAPH_RTTI_DEFINITION(op::v1::GreaterEqual, "GreaterEqual", 1); +NGRAPH_RTTI_DEFINITION(op::v1::GreaterEqual, + "GreaterEqual", + 1, + op::util::BinaryElementwiseComparison); op::v1::GreaterEqual::GreaterEqual(const Output& arg0, const Output& arg1, @@ -95,5 +98,6 @@ bool op::v1::GreaterEqual::has_evaluate() const bool op::v1::GreaterEqual::visit_attributes(AttributeVisitor& visitor) { NGRAPH_OP_SCOPE(v1_GreaterEqual_visit_attributes); + BinaryElementwiseComparison::visit_attributes(visitor); return true; } diff --git a/ngraph/core/src/op/less.cpp b/ngraph/core/src/op/less.cpp index d9b4e8dfeb9..af0131f0d8b 100644 --- a/ngraph/core/src/op/less.cpp +++ b/ngraph/core/src/op/less.cpp @@ -51,7 +51,7 @@ namespace lessop // ----------------------------- v1 -------------------------------------------- -NGRAPH_RTTI_DEFINITION(op::v1::Less, "Less", 1); +NGRAPH_RTTI_DEFINITION(op::v1::Less, "Less", 1, op::util::BinaryElementwiseComparison); op::v1::Less::Less(const Output& arg0, const Output& arg1, diff --git a/ngraph/core/src/op/less_eq.cpp b/ngraph/core/src/op/less_eq.cpp index 3528090de46..9e00e738929 100644 --- a/ngraph/core/src/op/less_eq.cpp +++ b/ngraph/core/src/op/less_eq.cpp @@ -12,7 +12,7 @@ using namespace ngraph; // ---------------------------------- v1 --------------------------------------- -NGRAPH_RTTI_DEFINITION(op::v1::LessEqual, "LessEqual", 1); +NGRAPH_RTTI_DEFINITION(op::v1::LessEqual, "LessEqual", 1, op::util::BinaryElementwiseComparison); op::v1::LessEqual::LessEqual(const Output& arg0, const Output& arg1, diff --git a/ngraph/core/src/op/maximum.cpp b/ngraph/core/src/op/maximum.cpp index 0733759c2b5..7d06d67b356 100644 --- a/ngraph/core/src/op/maximum.cpp +++ b/ngraph/core/src/op/maximum.cpp @@ -58,7 +58,7 @@ namespace maximumop // ------------------------------------ v1 ------------------------------------- -constexpr NodeTypeInfo op::v1::Maximum::type_info; +NGRAPH_RTTI_DEFINITION(op::v1::Maximum, "Maximum", 1, op::util::BinaryElementwiseArithmetic); op::v1::Maximum::Maximum(const Output& arg0, const Output& arg1, diff --git a/ngraph/core/src/op/minimum.cpp b/ngraph/core/src/op/minimum.cpp index bd0ff3f79f6..cfa7abeffaf 100644 --- a/ngraph/core/src/op/minimum.cpp +++ b/ngraph/core/src/op/minimum.cpp @@ -56,7 +56,7 @@ namespace minimumop // ------------------------------ v1 ------------------------------------------- -constexpr NodeTypeInfo op::v1::Minimum::type_info; +NGRAPH_RTTI_DEFINITION(op::v1::Minimum, "Minimum", 1, op::util::BinaryElementwiseArithmetic); op::v1::Minimum::Minimum(const Output& arg0, const Output& arg1, diff --git a/ngraph/core/src/op/mod.cpp b/ngraph/core/src/op/mod.cpp index 8f3703c829f..12b323e6ebd 100644 --- a/ngraph/core/src/op/mod.cpp +++ b/ngraph/core/src/op/mod.cpp @@ -10,7 +10,7 @@ using namespace ngraph; // ------------------------------ v1 ------------------------------------------- -constexpr NodeTypeInfo op::v1::Mod::type_info; +NGRAPH_RTTI_DEFINITION(op::v1::Mod, "Mod", 1, op::util::BinaryElementwiseArithmetic); op::v1::Mod::Mod(const Output& arg0, const Output& arg1, diff --git a/ngraph/core/src/op/not_equal.cpp b/ngraph/core/src/op/not_equal.cpp index a53ea2ee74f..1e079c36163 100644 --- a/ngraph/core/src/op/not_equal.cpp +++ b/ngraph/core/src/op/not_equal.cpp @@ -51,7 +51,7 @@ namespace not_equalop // ----------------------------------- v1 -------------------------------------- -NGRAPH_RTTI_DEFINITION(op::v1::NotEqual, "NotEqual", 1); +NGRAPH_RTTI_DEFINITION(op::v1::NotEqual, "NotEqual", 1, op::util::BinaryElementwiseComparison); op::v1::NotEqual::NotEqual(const Output& arg0, const Output& arg1, @@ -95,5 +95,6 @@ bool op::v1::NotEqual::has_evaluate() const bool op::v1::NotEqual::visit_attributes(AttributeVisitor& visitor) { NGRAPH_OP_SCOPE(v1_NotEqual_visit_attributes); + BinaryElementwiseComparison::visit_attributes(visitor); return true; } diff --git a/ngraph/core/src/op/power.cpp b/ngraph/core/src/op/power.cpp index 10695c755b5..1a52c959cd8 100644 --- a/ngraph/core/src/op/power.cpp +++ b/ngraph/core/src/op/power.cpp @@ -54,7 +54,7 @@ namespace power // ------------------------------ v1 ------------------------------------------- -constexpr NodeTypeInfo op::v1::Power::type_info; +NGRAPH_RTTI_DEFINITION(op::v1::Power, "Power", 1, op::util::BinaryElementwiseArithmetic); op::v1::Power::Power(const Output& arg0, const Output& arg1, diff --git a/ngraph/core/src/op/roi_pooling.cpp b/ngraph/core/src/op/roi_pooling.cpp index 2aac3d9f786..3d287a6581c 100644 --- a/ngraph/core/src/op/roi_pooling.cpp +++ b/ngraph/core/src/op/roi_pooling.cpp @@ -8,7 +8,7 @@ using namespace std; using namespace ngraph; -constexpr NodeTypeInfo op::ROIPooling::type_info; +NGRAPH_RTTI_DEFINITION(op::ROIPooling, "ROIPooling", 0); op::ROIPooling::ROIPooling(const Output& input, const Output& coords, diff --git a/ngraph/core/src/op/tanh.cpp b/ngraph/core/src/op/tanh.cpp index c251f1c63a5..e392507475b 100644 --- a/ngraph/core/src/op/tanh.cpp +++ b/ngraph/core/src/op/tanh.cpp @@ -14,7 +14,7 @@ using namespace std; using namespace ngraph; -constexpr NodeTypeInfo op::Tanh::type_info; +NGRAPH_RTTI_DEFINITION(op::v0::Tanh, "Tanh", 0, op::util::UnaryElementwiseArithmetic); op::Tanh::Tanh(const Output& arg) : UnaryElementwiseArithmetic(arg) diff --git a/ngraph/core/src/op/util/binary_elementwise_comparison.cpp b/ngraph/core/src/op/util/binary_elementwise_comparison.cpp index e8b878c3ed7..3fd5ee8d3bc 100644 --- a/ngraph/core/src/op/util/binary_elementwise_comparison.cpp +++ b/ngraph/core/src/op/util/binary_elementwise_comparison.cpp @@ -10,6 +10,8 @@ using namespace std; using namespace ngraph; +NGRAPH_RTTI_DEFINITION(op::util::BinaryElementwiseComparison, "BinaryElementwiseComparison", 0); + op::util::BinaryElementwiseComparison::BinaryElementwiseComparison(const AutoBroadcastSpec& autob) : m_autob(autob) { diff --git a/ngraph/test/CMakeLists.txt b/ngraph/test/CMakeLists.txt index 1ac0ef0f345..917be3ceb5f 100644 --- a/ngraph/test/CMakeLists.txt +++ b/ngraph/test/CMakeLists.txt @@ -234,6 +234,7 @@ set(SRC visitors/op/acosh.cpp visitors/op/adaptive_avg_pool.cpp visitors/op/adaptive_max_pool.cpp + visitors/op/add.cpp visitors/op/asinh.cpp visitors/op/atan.cpp visitors/op/batch_norm.cpp @@ -253,16 +254,23 @@ set(SRC visitors/op/detection_output.cpp visitors/op/einsum.cpp visitors/op/elu.cpp + visitors/op/equal.cpp visitors/op/erf.cpp visitors/op/extractimagepatches.cpp visitors/op/fake_quantize.cpp + visitors/op/floor_mod.cpp visitors/op/floor.cpp visitors/op/gather.cpp visitors/op/gelu.cpp + visitors/op/greater_equal.cpp + visitors/op/greater.cpp visitors/op/grn.cpp visitors/op/group_conv.cpp visitors/op/interpolate.cpp + visitors/op/less_equal.cpp + visitors/op/less.cpp visitors/op/log.cpp + visitors/op/logical_or.cpp visitors/op/logical_xor.cpp visitors/op/lrn.cpp visitors/op/lstm_cell.cpp @@ -270,17 +278,22 @@ set(SRC visitors/op/matmul.cpp visitors/op/matrix_nms.cpp visitors/op/max_pool.cpp + visitors/op/maximum.cpp + visitors/op/minimum.cpp visitors/op/mish.cpp visitors/op/mod.cpp visitors/op/multiclass_nms.cpp + visitors/op/multiply.cpp visitors/op/mvn.cpp visitors/op/negative.cpp visitors/op/non_max_suppression.cpp visitors/op/non_zero.cpp visitors/op/normalize_l2.cpp + visitors/op/not_equal.cpp visitors/op/one_hot.cpp visitors/op/pad.cpp visitors/op/parameter.cpp + visitors/op/power.cpp visitors/op/prior_box.cpp visitors/op/prior_box_clustered.cpp visitors/op/proposal.cpp @@ -314,10 +327,11 @@ set(SRC visitors/op/space_to_batch.cpp visitors/op/space_to_depth.cpp visitors/op/split.cpp + visitors/op/sqrt.cpp visitors/op/squared_difference.cpp visitors/op/squeeze.cpp - visitors/op/sqrt.cpp visitors/op/strided_slice.cpp + visitors/op/subtract.cpp visitors/op/swish.cpp visitors/op/tanh.cpp visitors/op/topk.cpp diff --git a/ngraph/test/visitors/op/add.cpp b/ngraph/test/visitors/op/add.cpp new file mode 100644 index 00000000000..1f13a33c1ba --- /dev/null +++ b/ngraph/test/visitors/op/add.cpp @@ -0,0 +1,13 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "binary_ops.hpp" +#include "ngraph/opsets/opset1.hpp" + +using Type = ::testing::Types>; + +INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast, + BinaryOperatorVisitor, + Type, + BinaryOperatorTypeName); diff --git a/ngraph/test/visitors/op/binary_ops.hpp b/ngraph/test/visitors/op/binary_ops.hpp new file mode 100644 index 00000000000..6c94d5c4596 --- /dev/null +++ b/ngraph/test/visitors/op/binary_ops.hpp @@ -0,0 +1,60 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "gtest/gtest.h" +#include "util/visitor.hpp" + +template +class BinaryOperatorType +{ +public: + using op_type = T; + static constexpr ngraph::element::Type_t element_type = ELEMENT_TYPE; +}; + +template +class BinaryOperatorVisitor : public testing::Test +{ +}; + +class BinaryOperatorTypeName +{ +public: + template + static std::string GetName(int) + { + using OP_Type = typename T::op_type; + constexpr ngraph::element::Type precision(T::element_type); + const ngraph::Node::type_info_t typeinfo = OP_Type::get_type_info_static(); + return std::string{typeinfo.name} + "_" + precision.get_type_name(); + } +}; + +TYPED_TEST_SUITE_P(BinaryOperatorVisitor); + +TYPED_TEST_P(BinaryOperatorVisitor, Auto_Broadcast) +{ + using OP_Type = typename TypeParam::op_type; + const ngraph::element::Type_t element_type = TypeParam::element_type; + + ngraph::test::NodeBuilder::get_ops().register_factory(); + const auto A = + std::make_shared(element_type, ngraph::PartialShape{1, 2, 3}); + const auto B = + std::make_shared(element_type, ngraph::PartialShape{3, 2, 1}); + + auto auto_broadcast = ngraph::op::AutoBroadcastType::NUMPY; + + const auto op_func = std::make_shared(A, B, auto_broadcast); + ngraph::test::NodeBuilder builder(op_func); + const auto g_op_func = ngraph::as_type_ptr(builder.create()); + + const auto expected_attr_count = 1; + EXPECT_EQ(builder.get_value_map_size(), expected_attr_count); + EXPECT_EQ(op_func->get_autob(), g_op_func->get_autob()); +} + +REGISTER_TYPED_TEST_SUITE_P(BinaryOperatorVisitor, Auto_Broadcast); diff --git a/ngraph/test/visitors/op/equal.cpp b/ngraph/test/visitors/op/equal.cpp new file mode 100644 index 00000000000..af23535c866 --- /dev/null +++ b/ngraph/test/visitors/op/equal.cpp @@ -0,0 +1,13 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "binary_ops.hpp" +#include "ngraph/opsets/opset1.hpp" + +using Type = ::testing::Types>; + +INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast, + BinaryOperatorVisitor, + Type, + BinaryOperatorTypeName); diff --git a/ngraph/test/visitors/op/floor_mod.cpp b/ngraph/test/visitors/op/floor_mod.cpp new file mode 100644 index 00000000000..9a1b9e2a98b --- /dev/null +++ b/ngraph/test/visitors/op/floor_mod.cpp @@ -0,0 +1,13 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "binary_ops.hpp" +#include "ngraph/opsets/opset1.hpp" + +using Type = ::testing::Types>; + +INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast, + BinaryOperatorVisitor, + Type, + BinaryOperatorTypeName); diff --git a/ngraph/test/visitors/op/greater.cpp b/ngraph/test/visitors/op/greater.cpp new file mode 100644 index 00000000000..71362c489f8 --- /dev/null +++ b/ngraph/test/visitors/op/greater.cpp @@ -0,0 +1,13 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "binary_ops.hpp" +#include "ngraph/opsets/opset1.hpp" + +using Type = ::testing::Types>; + +INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast, + BinaryOperatorVisitor, + Type, + BinaryOperatorTypeName); diff --git a/ngraph/test/visitors/op/greater_equal.cpp b/ngraph/test/visitors/op/greater_equal.cpp new file mode 100644 index 00000000000..fb12a162ee0 --- /dev/null +++ b/ngraph/test/visitors/op/greater_equal.cpp @@ -0,0 +1,14 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "binary_ops.hpp" +#include "ngraph/opsets/opset1.hpp" + +using Type = + ::testing::Types>; + +INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast, + BinaryOperatorVisitor, + Type, + BinaryOperatorTypeName); diff --git a/ngraph/test/visitors/op/less.cpp b/ngraph/test/visitors/op/less.cpp new file mode 100644 index 00000000000..91ac6c742f1 --- /dev/null +++ b/ngraph/test/visitors/op/less.cpp @@ -0,0 +1,13 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "binary_ops.hpp" +#include "ngraph/opsets/opset1.hpp" + +using Type = ::testing::Types>; + +INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast, + BinaryOperatorVisitor, + Type, + BinaryOperatorTypeName); diff --git a/ngraph/test/visitors/op/less_equal.cpp b/ngraph/test/visitors/op/less_equal.cpp new file mode 100644 index 00000000000..33bb954532e --- /dev/null +++ b/ngraph/test/visitors/op/less_equal.cpp @@ -0,0 +1,13 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "binary_ops.hpp" +#include "ngraph/opsets/opset1.hpp" + +using Type = ::testing::Types>; + +INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast, + BinaryOperatorVisitor, + Type, + BinaryOperatorTypeName); diff --git a/ngraph/test/visitors/op/logical_or.cpp b/ngraph/test/visitors/op/logical_or.cpp new file mode 100644 index 00000000000..fe1ca378111 --- /dev/null +++ b/ngraph/test/visitors/op/logical_or.cpp @@ -0,0 +1,14 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "binary_ops.hpp" +#include "ngraph/opsets/opset1.hpp" + +using Type = + ::testing::Types>; + +INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast, + BinaryOperatorVisitor, + Type, + BinaryOperatorTypeName); diff --git a/ngraph/test/visitors/op/logical_xor.cpp b/ngraph/test/visitors/op/logical_xor.cpp index 80f0085b094..b30bfba4857 100644 --- a/ngraph/test/visitors/op/logical_xor.cpp +++ b/ngraph/test/visitors/op/logical_xor.cpp @@ -2,33 +2,13 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "gtest/gtest.h" - -#include "ngraph/ngraph.hpp" -#include "ngraph/op/util/attr_types.hpp" +#include "binary_ops.hpp" #include "ngraph/opsets/opset1.hpp" -#include "ngraph/opsets/opset3.hpp" -#include "ngraph/opsets/opset4.hpp" -#include "ngraph/opsets/opset5.hpp" -#include "util/visitor.hpp" +using Type = + ::testing::Types>; -using namespace std; -using namespace ngraph; -using ngraph::test::NodeBuilder; -using ngraph::test::ValueMap; - -TEST(attributes, logical_xor_op) -{ - NodeBuilder::get_ops().register_factory(); - auto x1 = make_shared(element::boolean, Shape{200}); - auto x2 = make_shared(element::boolean, Shape{200}); - - auto auto_broadcast = op::AutoBroadcastType::NUMPY; - - auto logical_xor = make_shared(x1, x2, auto_broadcast); - NodeBuilder builder(logical_xor); - auto g_logical_xor = as_type_ptr(builder.create()); - - EXPECT_EQ(g_logical_xor->get_autob(), logical_xor->get_autob()); -} +INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast, + BinaryOperatorVisitor, + Type, + BinaryOperatorTypeName); diff --git a/ngraph/test/visitors/op/maximum.cpp b/ngraph/test/visitors/op/maximum.cpp new file mode 100644 index 00000000000..26b748e019b --- /dev/null +++ b/ngraph/test/visitors/op/maximum.cpp @@ -0,0 +1,13 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "binary_ops.hpp" +#include "ngraph/opsets/opset1.hpp" + +using Type = ::testing::Types>; + +INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast, + BinaryOperatorVisitor, + Type, + BinaryOperatorTypeName); diff --git a/ngraph/test/visitors/op/minimum.cpp b/ngraph/test/visitors/op/minimum.cpp new file mode 100644 index 00000000000..f2c4d164280 --- /dev/null +++ b/ngraph/test/visitors/op/minimum.cpp @@ -0,0 +1,13 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "binary_ops.hpp" +#include "ngraph/opsets/opset1.hpp" + +using Type = ::testing::Types>; + +INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast, + BinaryOperatorVisitor, + Type, + BinaryOperatorTypeName); diff --git a/ngraph/test/visitors/op/mod.cpp b/ngraph/test/visitors/op/mod.cpp index dce8ef15a07..bae485b02d6 100644 --- a/ngraph/test/visitors/op/mod.cpp +++ b/ngraph/test/visitors/op/mod.cpp @@ -2,33 +2,12 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "gtest/gtest.h" - -#include "ngraph/ngraph.hpp" -#include "ngraph/op/util/attr_types.hpp" +#include "binary_ops.hpp" #include "ngraph/opsets/opset1.hpp" -#include "ngraph/opsets/opset3.hpp" -#include "ngraph/opsets/opset4.hpp" -#include "ngraph/opsets/opset5.hpp" -#include "util/visitor.hpp" +using Type = ::testing::Types>; -using namespace std; -using namespace ngraph; -using ngraph::test::NodeBuilder; -using ngraph::test::ValueMap; - -TEST(attributes, mod_op) -{ - NodeBuilder::get_ops().register_factory(); - auto A = make_shared(element::f32, Shape{1, 2}); - auto B = make_shared(element::f32, Shape{2, 1}); - - auto auto_broadcast = op::AutoBroadcastType::NUMPY; - - auto mod = make_shared(A, B, auto_broadcast); - NodeBuilder builder(mod); - auto g_mod = as_type_ptr(builder.create()); - - EXPECT_EQ(g_mod->get_autob(), mod->get_autob()); -} +INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast, + BinaryOperatorVisitor, + Type, + BinaryOperatorTypeName); diff --git a/ngraph/test/visitors/op/multiply.cpp b/ngraph/test/visitors/op/multiply.cpp new file mode 100644 index 00000000000..f60e2b5ebb1 --- /dev/null +++ b/ngraph/test/visitors/op/multiply.cpp @@ -0,0 +1,13 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "binary_ops.hpp" +#include "ngraph/opsets/opset1.hpp" + +using Type = ::testing::Types>; + +INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast, + BinaryOperatorVisitor, + Type, + BinaryOperatorTypeName); diff --git a/ngraph/test/visitors/op/not_equal.cpp b/ngraph/test/visitors/op/not_equal.cpp new file mode 100644 index 00000000000..93b3fc2eae4 --- /dev/null +++ b/ngraph/test/visitors/op/not_equal.cpp @@ -0,0 +1,13 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "binary_ops.hpp" +#include "ngraph/opsets/opset1.hpp" + +using Type = ::testing::Types>; + +INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast, + BinaryOperatorVisitor, + Type, + BinaryOperatorTypeName); diff --git a/ngraph/test/visitors/op/power.cpp b/ngraph/test/visitors/op/power.cpp new file mode 100644 index 00000000000..24c25b3d64c --- /dev/null +++ b/ngraph/test/visitors/op/power.cpp @@ -0,0 +1,13 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "binary_ops.hpp" +#include "ngraph/opsets/opset1.hpp" + +using Type = ::testing::Types>; + +INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast, + BinaryOperatorVisitor, + Type, + BinaryOperatorTypeName); diff --git a/ngraph/test/visitors/op/squared_difference.cpp b/ngraph/test/visitors/op/squared_difference.cpp index cf2a7e0981e..6ebc06579a4 100644 --- a/ngraph/test/visitors/op/squared_difference.cpp +++ b/ngraph/test/visitors/op/squared_difference.cpp @@ -2,31 +2,13 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "gtest/gtest.h" - -#include "ngraph/ngraph.hpp" -#include "ngraph/op/util/attr_types.hpp" +#include "binary_ops.hpp" #include "ngraph/opsets/opset1.hpp" -#include "ngraph/opsets/opset3.hpp" -#include "ngraph/opsets/opset4.hpp" -#include "ngraph/opsets/opset5.hpp" -#include "util/visitor.hpp" +using Type = + ::testing::Types>; -using namespace std; -using namespace ngraph; -using ngraph::test::NodeBuilder; -using ngraph::test::ValueMap; - -TEST(attributes, squared_difference_op) -{ - NodeBuilder::get_ops().register_factory(); - auto x1 = make_shared(element::i32, Shape{200}); - auto x2 = make_shared(element::i32, Shape{200}); - auto auto_broadcast = op::AutoBroadcastType::NUMPY; - auto squared_difference = make_shared(x1, x2, auto_broadcast); - NodeBuilder builder(squared_difference); - auto g_squared_difference = as_type_ptr(builder.create()); - - EXPECT_EQ(g_squared_difference->get_autob(), squared_difference->get_autob()); -} +INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast, + BinaryOperatorVisitor, + Type, + BinaryOperatorTypeName); diff --git a/ngraph/test/visitors/op/subtract.cpp b/ngraph/test/visitors/op/subtract.cpp new file mode 100644 index 00000000000..a2aa158c076 --- /dev/null +++ b/ngraph/test/visitors/op/subtract.cpp @@ -0,0 +1,13 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "binary_ops.hpp" +#include "ngraph/opsets/opset1.hpp" + +using Type = ::testing::Types>; + +INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast, + BinaryOperatorVisitor, + Type, + BinaryOperatorTypeName);