Refactor visitor tests for BinaryElementwiseArithmetic ops (#6667)

* add binary_elementwise file

* change binary_elementwise.hpp to binary_ops.hpp

* migrate mod operation test o typed template test

* add tests for remaining binary ops

* remove comment

* fix formatting to match clang-format

* add RVO-exploit string concatenating andbeautify the code

* add validation for attributes number

* add missing visit_attributes() calls

* add missing 4th param to NGRAPH_RTTI_DEFINITION calls

* fix formatting to match clang-format
This commit is contained in:
Dawid Kożykowski 2021-07-26 14:49:14 +02:00 committed by GitHub
parent 1a92a69515
commit b4ad7a1755
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
47 changed files with 326 additions and 116 deletions

View File

@ -20,8 +20,8 @@ namespace ngraph
class NGRAPH_API Broadcast : public util::BroadcastBase
{
public:
static constexpr NodeTypeInfo type_info{"Broadcast", 3};
const NodeTypeInfo& get_type_info() const override { return type_info; }
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a broadcast operation.
Broadcast() = default;
/// \brief Constructs a broadcast operation.

View File

@ -16,8 +16,8 @@ namespace ngraph
class NGRAPH_API Bucketize : public Op
{
public:
static constexpr NodeTypeInfo type_info{"Bucketize", 3};
const NodeTypeInfo& get_type_info() const override { return type_info; }
NGRAPH_RTTI_DECLARATION;
Bucketize() = default;
/// \brief Constructs a Bucketize node

View File

@ -60,8 +60,8 @@ namespace ngraph
class NGRAPH_API CumSum : public Op
{
public:
static constexpr NodeTypeInfo type_info{"CumSum", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a cumulative summation operation.
CumSum() = default;

View File

@ -19,8 +19,8 @@ namespace ngraph
class NGRAPH_API FloorMod : public util::BinaryElementwiseArithmetic
{
public:
static constexpr NodeTypeInfo type_info{"FloorMod", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs an uninitialized addition operation
FloorMod()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY){};

View File

@ -16,8 +16,8 @@ namespace ngraph
class NGRAPH_API Maximum : public util::BinaryElementwiseArithmetic
{
public:
static constexpr NodeTypeInfo type_info{"Maximum", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a maximum operation.
Maximum()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY)

View File

@ -16,8 +16,8 @@ namespace ngraph
class NGRAPH_API Minimum : public util::BinaryElementwiseArithmetic
{
public:
static constexpr NodeTypeInfo type_info{"Minimum", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a minimum operation.
Minimum()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY)

View File

@ -17,8 +17,8 @@ namespace ngraph
class NGRAPH_API Mod : public util::BinaryElementwiseArithmetic
{
public:
static constexpr NodeTypeInfo type_info{"Mod", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a Mod node.
Mod()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY)

View File

@ -31,8 +31,8 @@ namespace ngraph
class NGRAPH_API Power : public util::BinaryElementwiseArithmetic
{
public:
static constexpr NodeTypeInfo type_info{"Power", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
NGRAPH_RTTI_DECLARATION;
Power()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY)
{

View File

@ -15,8 +15,8 @@ namespace ngraph
class NGRAPH_API ROIPooling : public Op
{
public:
static constexpr NodeTypeInfo type_info{"ROIPooling", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
NGRAPH_RTTI_DECLARATION;
ROIPooling() = default;
/// \brief Constructs a ROIPooling operation
///

View File

@ -16,8 +16,8 @@ namespace ngraph
class NGRAPH_API Tanh : public util::UnaryElementwiseArithmetic
{
public:
static constexpr NodeTypeInfo type_info{"Tanh", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a hyperbolic tangent operation.
///
/// \param arg Node that produces the input tensor.

View File

@ -55,6 +55,8 @@ namespace ngraph
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
public:
NGRAPH_RTTI_DECLARATION;
void validate_and_infer_types() override;
const AutoBroadcastSpec& get_autob() const override { return m_autob; }

View File

@ -17,7 +17,7 @@
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::v3::Broadcast::type_info;
NGRAPH_RTTI_DEFINITION(op::v3::Broadcast, "Broadcast", 3, op::util::BroadcastBase);
op::v3::Broadcast::Broadcast(const Output<Node>& arg,
const Output<Node>& target_shape,

View File

@ -8,7 +8,7 @@
using namespace ngraph;
using namespace std;
constexpr NodeTypeInfo op::v3::Bucketize::type_info;
NGRAPH_RTTI_DEFINITION(op::v3::Bucketize, "Bucketize", 3);
op::v3::Bucketize::Bucketize(const Output<Node>& data,
const Output<Node>& buckets,

View File

@ -12,7 +12,7 @@
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::v0::CumSum::type_info;
NGRAPH_RTTI_DEFINITION(op::v0::CumSum, "CumSum", 0);
op::v0::CumSum::CumSum(const Output<Node>& arg,
const Output<Node>& axis,

View File

@ -51,7 +51,7 @@ namespace equal
//------------------------------- v1 -------------------------------------------
NGRAPH_RTTI_DEFINITION(op::v1::Equal, "Equal", 1);
NGRAPH_RTTI_DEFINITION(op::v1::Equal, "Equal", 1, op::util::BinaryElementwiseComparison);
op::v1::Equal::Equal(const Output<Node>& arg0,
const Output<Node>& arg1,
@ -94,5 +94,6 @@ bool op::v1::Equal::has_evaluate() const
bool op::v1::Equal::visit_attributes(AttributeVisitor& visitor)
{
NGRAPH_OP_SCOPE(v1_Equal_visit_attributes);
BinaryElementwiseComparison::visit_attributes(visitor);
return true;
}

View File

@ -10,7 +10,7 @@
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::v1::FloorMod::type_info;
NGRAPH_RTTI_DEFINITION(op::v1::FloorMod, "FloorMod", 1, op::util::BinaryElementwiseArithmetic);
op::v1::FloorMod::FloorMod(const Output<Node>& arg0,
const Output<Node>& arg1,
@ -97,5 +97,6 @@ bool op::v1::FloorMod::has_evaluate() const
bool op::v1::FloorMod::visit_attributes(AttributeVisitor& visitor)
{
NGRAPH_OP_SCOPE(v1_FloorMod_visit_attributes);
BinaryElementwiseArithmetic::visit_attributes(visitor);
return true;
}

View File

@ -51,7 +51,7 @@ namespace greaterop
//-------------------------------------- v1 ------------------------------------
NGRAPH_RTTI_DEFINITION(op::v1::Greater, "Greater", 1);
NGRAPH_RTTI_DEFINITION(op::v1::Greater, "Greater", 1, op::util::BinaryElementwiseComparison);
op::v1::Greater::Greater(const Output<Node>& arg0,
const Output<Node>& arg1,

View File

@ -51,7 +51,10 @@ namespace greater_equalop
//---------------------------------- v1 ----------------------------------------
NGRAPH_RTTI_DEFINITION(op::v1::GreaterEqual, "GreaterEqual", 1);
NGRAPH_RTTI_DEFINITION(op::v1::GreaterEqual,
"GreaterEqual",
1,
op::util::BinaryElementwiseComparison);
op::v1::GreaterEqual::GreaterEqual(const Output<Node>& arg0,
const Output<Node>& arg1,
@ -95,5 +98,6 @@ bool op::v1::GreaterEqual::has_evaluate() const
bool op::v1::GreaterEqual::visit_attributes(AttributeVisitor& visitor)
{
NGRAPH_OP_SCOPE(v1_GreaterEqual_visit_attributes);
BinaryElementwiseComparison::visit_attributes(visitor);
return true;
}

View File

@ -51,7 +51,7 @@ namespace lessop
// ----------------------------- v1 --------------------------------------------
NGRAPH_RTTI_DEFINITION(op::v1::Less, "Less", 1);
NGRAPH_RTTI_DEFINITION(op::v1::Less, "Less", 1, op::util::BinaryElementwiseComparison);
op::v1::Less::Less(const Output<Node>& arg0,
const Output<Node>& arg1,

View File

@ -12,7 +12,7 @@ using namespace ngraph;
// ---------------------------------- v1 ---------------------------------------
NGRAPH_RTTI_DEFINITION(op::v1::LessEqual, "LessEqual", 1);
NGRAPH_RTTI_DEFINITION(op::v1::LessEqual, "LessEqual", 1, op::util::BinaryElementwiseComparison);
op::v1::LessEqual::LessEqual(const Output<Node>& arg0,
const Output<Node>& arg1,

View File

@ -58,7 +58,7 @@ namespace maximumop
// ------------------------------------ v1 -------------------------------------
constexpr NodeTypeInfo op::v1::Maximum::type_info;
NGRAPH_RTTI_DEFINITION(op::v1::Maximum, "Maximum", 1, op::util::BinaryElementwiseArithmetic);
op::v1::Maximum::Maximum(const Output<Node>& arg0,
const Output<Node>& arg1,

View File

@ -56,7 +56,7 @@ namespace minimumop
// ------------------------------ v1 -------------------------------------------
constexpr NodeTypeInfo op::v1::Minimum::type_info;
NGRAPH_RTTI_DEFINITION(op::v1::Minimum, "Minimum", 1, op::util::BinaryElementwiseArithmetic);
op::v1::Minimum::Minimum(const Output<Node>& arg0,
const Output<Node>& arg1,

View File

@ -10,7 +10,7 @@ using namespace ngraph;
// ------------------------------ v1 -------------------------------------------
constexpr NodeTypeInfo op::v1::Mod::type_info;
NGRAPH_RTTI_DEFINITION(op::v1::Mod, "Mod", 1, op::util::BinaryElementwiseArithmetic);
op::v1::Mod::Mod(const Output<Node>& arg0,
const Output<Node>& arg1,

View File

@ -51,7 +51,7 @@ namespace not_equalop
// ----------------------------------- v1 --------------------------------------
NGRAPH_RTTI_DEFINITION(op::v1::NotEqual, "NotEqual", 1);
NGRAPH_RTTI_DEFINITION(op::v1::NotEqual, "NotEqual", 1, op::util::BinaryElementwiseComparison);
op::v1::NotEqual::NotEqual(const Output<Node>& arg0,
const Output<Node>& arg1,
@ -95,5 +95,6 @@ bool op::v1::NotEqual::has_evaluate() const
bool op::v1::NotEqual::visit_attributes(AttributeVisitor& visitor)
{
NGRAPH_OP_SCOPE(v1_NotEqual_visit_attributes);
BinaryElementwiseComparison::visit_attributes(visitor);
return true;
}

View File

@ -54,7 +54,7 @@ namespace power
// ------------------------------ v1 -------------------------------------------
constexpr NodeTypeInfo op::v1::Power::type_info;
NGRAPH_RTTI_DEFINITION(op::v1::Power, "Power", 1, op::util::BinaryElementwiseArithmetic);
op::v1::Power::Power(const Output<Node>& arg0,
const Output<Node>& arg1,

View File

@ -8,7 +8,7 @@
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::ROIPooling::type_info;
NGRAPH_RTTI_DEFINITION(op::ROIPooling, "ROIPooling", 0);
op::ROIPooling::ROIPooling(const Output<Node>& input,
const Output<Node>& coords,

View File

@ -14,7 +14,7 @@
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::Tanh::type_info;
NGRAPH_RTTI_DEFINITION(op::v0::Tanh, "Tanh", 0, op::util::UnaryElementwiseArithmetic);
op::Tanh::Tanh(const Output<Node>& arg)
: UnaryElementwiseArithmetic(arg)

View File

@ -10,6 +10,8 @@
using namespace std;
using namespace ngraph;
NGRAPH_RTTI_DEFINITION(op::util::BinaryElementwiseComparison, "BinaryElementwiseComparison", 0);
op::util::BinaryElementwiseComparison::BinaryElementwiseComparison(const AutoBroadcastSpec& autob)
: m_autob(autob)
{

View File

@ -234,6 +234,7 @@ set(SRC
visitors/op/acosh.cpp
visitors/op/adaptive_avg_pool.cpp
visitors/op/adaptive_max_pool.cpp
visitors/op/add.cpp
visitors/op/asinh.cpp
visitors/op/atan.cpp
visitors/op/batch_norm.cpp
@ -253,16 +254,23 @@ set(SRC
visitors/op/detection_output.cpp
visitors/op/einsum.cpp
visitors/op/elu.cpp
visitors/op/equal.cpp
visitors/op/erf.cpp
visitors/op/extractimagepatches.cpp
visitors/op/fake_quantize.cpp
visitors/op/floor_mod.cpp
visitors/op/floor.cpp
visitors/op/gather.cpp
visitors/op/gelu.cpp
visitors/op/greater_equal.cpp
visitors/op/greater.cpp
visitors/op/grn.cpp
visitors/op/group_conv.cpp
visitors/op/interpolate.cpp
visitors/op/less_equal.cpp
visitors/op/less.cpp
visitors/op/log.cpp
visitors/op/logical_or.cpp
visitors/op/logical_xor.cpp
visitors/op/lrn.cpp
visitors/op/lstm_cell.cpp
@ -270,17 +278,22 @@ set(SRC
visitors/op/matmul.cpp
visitors/op/matrix_nms.cpp
visitors/op/max_pool.cpp
visitors/op/maximum.cpp
visitors/op/minimum.cpp
visitors/op/mish.cpp
visitors/op/mod.cpp
visitors/op/multiclass_nms.cpp
visitors/op/multiply.cpp
visitors/op/mvn.cpp
visitors/op/negative.cpp
visitors/op/non_max_suppression.cpp
visitors/op/non_zero.cpp
visitors/op/normalize_l2.cpp
visitors/op/not_equal.cpp
visitors/op/one_hot.cpp
visitors/op/pad.cpp
visitors/op/parameter.cpp
visitors/op/power.cpp
visitors/op/prior_box.cpp
visitors/op/prior_box_clustered.cpp
visitors/op/proposal.cpp
@ -314,10 +327,11 @@ set(SRC
visitors/op/space_to_batch.cpp
visitors/op/space_to_depth.cpp
visitors/op/split.cpp
visitors/op/sqrt.cpp
visitors/op/squared_difference.cpp
visitors/op/squeeze.cpp
visitors/op/sqrt.cpp
visitors/op/strided_slice.cpp
visitors/op/subtract.cpp
visitors/op/swish.cpp
visitors/op/tanh.cpp
visitors/op/topk.cpp

View File

@ -0,0 +1,13 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "binary_ops.hpp"
#include "ngraph/opsets/opset1.hpp"
using Type = ::testing::Types<BinaryOperatorType<ngraph::opset1::Add, ngraph::element::f32>>;
INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast,
BinaryOperatorVisitor,
Type,
BinaryOperatorTypeName);

View File

@ -0,0 +1,60 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "gtest/gtest.h"
#include "util/visitor.hpp"
template <typename T, ngraph::element::Type_t ELEMENT_TYPE>
class BinaryOperatorType
{
public:
using op_type = T;
static constexpr ngraph::element::Type_t element_type = ELEMENT_TYPE;
};
template <typename T>
class BinaryOperatorVisitor : public testing::Test
{
};
class BinaryOperatorTypeName
{
public:
template <typename T>
static std::string GetName(int)
{
using OP_Type = typename T::op_type;
constexpr ngraph::element::Type precision(T::element_type);
const ngraph::Node::type_info_t typeinfo = OP_Type::get_type_info_static();
return std::string{typeinfo.name} + "_" + precision.get_type_name();
}
};
TYPED_TEST_SUITE_P(BinaryOperatorVisitor);
TYPED_TEST_P(BinaryOperatorVisitor, Auto_Broadcast)
{
using OP_Type = typename TypeParam::op_type;
const ngraph::element::Type_t element_type = TypeParam::element_type;
ngraph::test::NodeBuilder::get_ops().register_factory<OP_Type>();
const auto A =
std::make_shared<ngraph::op::Parameter>(element_type, ngraph::PartialShape{1, 2, 3});
const auto B =
std::make_shared<ngraph::op::Parameter>(element_type, ngraph::PartialShape{3, 2, 1});
auto auto_broadcast = ngraph::op::AutoBroadcastType::NUMPY;
const auto op_func = std::make_shared<OP_Type>(A, B, auto_broadcast);
ngraph::test::NodeBuilder builder(op_func);
const auto g_op_func = ngraph::as_type_ptr<OP_Type>(builder.create());
const auto expected_attr_count = 1;
EXPECT_EQ(builder.get_value_map_size(), expected_attr_count);
EXPECT_EQ(op_func->get_autob(), g_op_func->get_autob());
}
REGISTER_TYPED_TEST_SUITE_P(BinaryOperatorVisitor, Auto_Broadcast);

View File

@ -0,0 +1,13 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "binary_ops.hpp"
#include "ngraph/opsets/opset1.hpp"
using Type = ::testing::Types<BinaryOperatorType<ngraph::opset1::Equal, ngraph::element::f32>>;
INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast,
BinaryOperatorVisitor,
Type,
BinaryOperatorTypeName);

View File

@ -0,0 +1,13 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "binary_ops.hpp"
#include "ngraph/opsets/opset1.hpp"
using Type = ::testing::Types<BinaryOperatorType<ngraph::opset1::FloorMod, ngraph::element::f32>>;
INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast,
BinaryOperatorVisitor,
Type,
BinaryOperatorTypeName);

View File

@ -0,0 +1,13 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "binary_ops.hpp"
#include "ngraph/opsets/opset1.hpp"
using Type = ::testing::Types<BinaryOperatorType<ngraph::opset1::Greater, ngraph::element::f32>>;
INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast,
BinaryOperatorVisitor,
Type,
BinaryOperatorTypeName);

View File

@ -0,0 +1,14 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "binary_ops.hpp"
#include "ngraph/opsets/opset1.hpp"
using Type =
::testing::Types<BinaryOperatorType<ngraph::opset1::GreaterEqual, ngraph::element::f32>>;
INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast,
BinaryOperatorVisitor,
Type,
BinaryOperatorTypeName);

View File

@ -0,0 +1,13 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "binary_ops.hpp"
#include "ngraph/opsets/opset1.hpp"
using Type = ::testing::Types<BinaryOperatorType<ngraph::opset1::Less, ngraph::element::f32>>;
INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast,
BinaryOperatorVisitor,
Type,
BinaryOperatorTypeName);

View File

@ -0,0 +1,13 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "binary_ops.hpp"
#include "ngraph/opsets/opset1.hpp"
using Type = ::testing::Types<BinaryOperatorType<ngraph::opset1::LessEqual, ngraph::element::f32>>;
INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast,
BinaryOperatorVisitor,
Type,
BinaryOperatorTypeName);

View File

@ -0,0 +1,14 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "binary_ops.hpp"
#include "ngraph/opsets/opset1.hpp"
using Type =
::testing::Types<BinaryOperatorType<ngraph::opset1::LogicalOr, ngraph::element::boolean>>;
INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast,
BinaryOperatorVisitor,
Type,
BinaryOperatorTypeName);

View File

@ -2,33 +2,13 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "binary_ops.hpp"
#include "ngraph/opsets/opset1.hpp"
#include "ngraph/opsets/opset3.hpp"
#include "ngraph/opsets/opset4.hpp"
#include "ngraph/opsets/opset5.hpp"
#include "util/visitor.hpp"
using Type =
::testing::Types<BinaryOperatorType<ngraph::opset1::LogicalXor, ngraph::element::boolean>>;
using namespace std;
using namespace ngraph;
using ngraph::test::NodeBuilder;
using ngraph::test::ValueMap;
TEST(attributes, logical_xor_op)
{
NodeBuilder::get_ops().register_factory<opset1::LogicalXor>();
auto x1 = make_shared<op::Parameter>(element::boolean, Shape{200});
auto x2 = make_shared<op::Parameter>(element::boolean, Shape{200});
auto auto_broadcast = op::AutoBroadcastType::NUMPY;
auto logical_xor = make_shared<opset1::LogicalXor>(x1, x2, auto_broadcast);
NodeBuilder builder(logical_xor);
auto g_logical_xor = as_type_ptr<opset1::LogicalXor>(builder.create());
EXPECT_EQ(g_logical_xor->get_autob(), logical_xor->get_autob());
}
INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast,
BinaryOperatorVisitor,
Type,
BinaryOperatorTypeName);

View File

@ -0,0 +1,13 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "binary_ops.hpp"
#include "ngraph/opsets/opset1.hpp"
using Type = ::testing::Types<BinaryOperatorType<ngraph::opset1::Maximum, ngraph::element::f32>>;
INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast,
BinaryOperatorVisitor,
Type,
BinaryOperatorTypeName);

View File

@ -0,0 +1,13 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "binary_ops.hpp"
#include "ngraph/opsets/opset1.hpp"
using Type = ::testing::Types<BinaryOperatorType<ngraph::opset1::Minimum, ngraph::element::f32>>;
INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast,
BinaryOperatorVisitor,
Type,
BinaryOperatorTypeName);

View File

@ -2,33 +2,12 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "binary_ops.hpp"
#include "ngraph/opsets/opset1.hpp"
#include "ngraph/opsets/opset3.hpp"
#include "ngraph/opsets/opset4.hpp"
#include "ngraph/opsets/opset5.hpp"
#include "util/visitor.hpp"
using Type = ::testing::Types<BinaryOperatorType<ngraph::opset1::Mod, ngraph::element::f32>>;
using namespace std;
using namespace ngraph;
using ngraph::test::NodeBuilder;
using ngraph::test::ValueMap;
TEST(attributes, mod_op)
{
NodeBuilder::get_ops().register_factory<opset1::Mod>();
auto A = make_shared<op::Parameter>(element::f32, Shape{1, 2});
auto B = make_shared<op::Parameter>(element::f32, Shape{2, 1});
auto auto_broadcast = op::AutoBroadcastType::NUMPY;
auto mod = make_shared<opset1::Mod>(A, B, auto_broadcast);
NodeBuilder builder(mod);
auto g_mod = as_type_ptr<opset1::Mod>(builder.create());
EXPECT_EQ(g_mod->get_autob(), mod->get_autob());
}
INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast,
BinaryOperatorVisitor,
Type,
BinaryOperatorTypeName);

View File

@ -0,0 +1,13 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "binary_ops.hpp"
#include "ngraph/opsets/opset1.hpp"
using Type = ::testing::Types<BinaryOperatorType<ngraph::opset1::Multiply, ngraph::element::f32>>;
INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast,
BinaryOperatorVisitor,
Type,
BinaryOperatorTypeName);

View File

@ -0,0 +1,13 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "binary_ops.hpp"
#include "ngraph/opsets/opset1.hpp"
using Type = ::testing::Types<BinaryOperatorType<ngraph::opset1::NotEqual, ngraph::element::f32>>;
INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast,
BinaryOperatorVisitor,
Type,
BinaryOperatorTypeName);

View File

@ -0,0 +1,13 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "binary_ops.hpp"
#include "ngraph/opsets/opset1.hpp"
using Type = ::testing::Types<BinaryOperatorType<ngraph::opset1::Power, ngraph::element::f32>>;
INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast,
BinaryOperatorVisitor,
Type,
BinaryOperatorTypeName);

View File

@ -2,31 +2,13 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "binary_ops.hpp"
#include "ngraph/opsets/opset1.hpp"
#include "ngraph/opsets/opset3.hpp"
#include "ngraph/opsets/opset4.hpp"
#include "ngraph/opsets/opset5.hpp"
#include "util/visitor.hpp"
using Type =
::testing::Types<BinaryOperatorType<ngraph::opset1::SquaredDifference, ngraph::element::f32>>;
using namespace std;
using namespace ngraph;
using ngraph::test::NodeBuilder;
using ngraph::test::ValueMap;
TEST(attributes, squared_difference_op)
{
NodeBuilder::get_ops().register_factory<opset1::SquaredDifference>();
auto x1 = make_shared<op::Parameter>(element::i32, Shape{200});
auto x2 = make_shared<op::Parameter>(element::i32, Shape{200});
auto auto_broadcast = op::AutoBroadcastType::NUMPY;
auto squared_difference = make_shared<opset1::SquaredDifference>(x1, x2, auto_broadcast);
NodeBuilder builder(squared_difference);
auto g_squared_difference = as_type_ptr<opset1::SquaredDifference>(builder.create());
EXPECT_EQ(g_squared_difference->get_autob(), squared_difference->get_autob());
}
INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast,
BinaryOperatorVisitor,
Type,
BinaryOperatorTypeName);

View File

@ -0,0 +1,13 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "binary_ops.hpp"
#include "ngraph/opsets/opset1.hpp"
using Type = ::testing::Types<BinaryOperatorType<ngraph::opset1::Subtract, ngraph::element::f32>>;
INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast,
BinaryOperatorVisitor,
Type,
BinaryOperatorTypeName);